# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=unidiomatic-typecheck
"""Defun decorator for defining graph-mode functions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import functools
import itertools
import pprint
import threading
import types as types_lib
import weakref

import numpy as np
import six
from six.moves import map

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
from tensorflow.python import pywrap_tfe
from tensorflow.python.client import pywrap_tf_session
from tensorflow.python.eager import backprop
from tensorflow.python.eager import backprop_util
from tensorflow.python.eager import context
from tensorflow.python.eager import execute
from tensorflow.python.eager import forwardprop_util
from tensorflow.python.eager import monitoring
from tensorflow.python.eager import tape
from tensorflow.python.eager.graph_only_ops import graph_placeholder
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import error_interpolation
from tensorflow.python.framework import errors
from tensorflow.python.framework import func_graph as func_graph_module
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import default_gradient
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gradients_util
from tensorflow.python.ops import resource_variable_ops

from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.profiler import trace
from tensorflow.python.saved_model import save_context
from tensorflow.python.util import _pywrap_utils
from tensorflow.python.util import compat
from tensorflow.python.util import function_utils
from tensorflow.python.util import lazy_loader
from tensorflow.python.util import memory
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export

# Loaded lazily due to a circular dependency (roughly
# tf.function->autograph->dataset->tf.function).
# TODO(b/133251390): Use a regular import.
ag_ctx = lazy_loader.LazyLoader(
    "ag_ctx", globals(),
    "tensorflow.python.autograph.core.ag_ctx")
np_arrays = lazy_loader.LazyLoader(
    "np_arrays", globals(),
    "tensorflow.python.ops.numpy_ops.np_arrays")


FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name"
BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name"
IMPLEMENTS_ATTRIBUTE_NAME = "_implements"
SHARED_RENDEZVOUS_ATTRIBUTE_NAME = "shared_rendezvous"

_graph_building_time_counter = monitoring.Counter(
    "/tensorflow/core/tf_function/graph_building_time_usecs",
    "Time for tf.function to build a graph (us).")


def _make_input_signature_hashable(elem):
  """Rewrite input signature to be hashable.

  We replace nested variables in the input signature with TensorSpec in order
  to be hashable.

  Args:
    elem: Input signature element

  Returns:
    A hashable object for the requested input signature
  """
  try:
    hash(elem)
  except TypeError:
    # TODO(slebedev): consider using nest.
    if isinstance(elem, tuple):
      return tuple(map(_make_input_signature_hashable, elem))

    # TFE_Py_EncodeArg weakrefs arguments it does not recognize, and we expect
    # all recognized types to be hashable.
    assert isinstance(elem, weakref.ReferenceType)
    v = elem()

    if resource_variable_ops.is_resource_variable(v):
      # We special case variables here to use unique_id as the cache key. This
      # ensures we have to retrace whenever a different variable is passed in.
      # This is needed to support cases where the user may use the id of a
      # variable in the function perhaps as a lookup in a dictionary.
      #
      # This choice leads to more retracing when we could have possibly used
      # the shape and dtype instead. However, we expect the number of
      # variables in a program to be bounded, and correspondingly the number
      # of retraces.
      #
      # Note we also include the class name to avoid collisions with strings.
      return v.__class__, v._unique_id  # pylint: disable=protected-access

    if _is_ndarray(v):
      # Numpy arrays are not hashable, but when calling functions we treat
      # them in the same way as tf.Tensors.
      if not hasattr(v, "shape") or not hasattr(v, "dtype"):
        # TODO(tomhennigan) De-dup with _as_ndarray in _convert_numpy_inputs.
        v = _as_ndarray(v)
      return tensor_spec.TensorSpec(v.shape, v.dtype)

    raise ValueError("Arguments to a tf.function must be Tensors, Variables, "
                     "or hashable Python objects (or nested structures of "
                     "these types).\nGot type: %s" % type(v).__name__)

  return elem
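
# A rough illustration of the cache-key behavior above (comments only; the
# example ids are hypothetical). Two distinct variables map to different
# hashable keys because the key is (class, _unique_id), so passing a
# different variable to a tf.function triggers a retrace, while ndarray-like
# arguments only contribute a TensorSpec of their shape and dtype:
#
#   v1, v2 = tf.Variable(1.0), tf.Variable(2.0)
#   _make_input_signature_hashable(weakref.ref(v1))  # e.g. (ResourceVariable, "1")
#   _make_input_signature_hashable(weakref.ref(v2))  # e.g. (ResourceVariable, "2")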


CacheKey = collections.namedtuple("CacheKey", [
    "input_signature",
    "parent_graph",
    "device_functions",
    "colocation_stack",
    "in_cross_replica_context",
    "variable_policy",
    "xla_context_id",
])


def _type_spec_for(x):
  """Returns a TypeSpec for `x`, or `None` if `x` doesn't have a TensorSpec."""
  if isinstance(x, ops.Tensor):
    return tensor_spec.TensorSpec.from_tensor(x)
  elif isinstance(x, type_spec.TypeSpec):
    return x
  elif isinstance(x, composite_tensor.CompositeTensor):
    return x._type_spec  # pylint: disable=protected-access
  else:
    return None


def _is_type_subset(a, b):
  """Returns true if TypeSpec `b` is a subset of type `a` (or if a is None)."""
  if a is None:
    return True
  else:
    return a.most_specific_compatible_type(b) == a


def _shape_relaxed_type_for_composite_tensor(x):
  """Returns a shape-relaxed TypeSpec for x (if composite) or x (if not)."""
  if isinstance(x, composite_tensor.CompositeTensor):
    # pylint: disable=protected-access
    return x._type_spec._with_tensor_ranks_only()
  else:
    return x


def common_shape(x, y):
  """Find a `TensorShape` that is compatible with both `x` and `y`."""
  if (x is None) != (y is None):
    raise RuntimeError(
        "Cannot find a common shape when LHS shape is None but RHS shape "
        "is not (or vice versa): %s vs. %s" % (x, y))
  if x is None:
    return None  # The associated input was not a Tensor, no shape generated.
  if not isinstance(x, tensor_shape.TensorShape):
    raise TypeError("Expected x to be a TensorShape but saw %s" % (x,))
  if not isinstance(y, tensor_shape.TensorShape):
    raise TypeError("Expected y to be a TensorShape but saw %s" % (y,))
  if x.rank != y.rank or x.rank is None:
    return tensor_shape.TensorShape(None)
  dims = []
  for dim_x, dim_y in zip(x.dims, y.dims):
    if (dim_x != dim_y
        or tensor_shape.dimension_value(dim_x) is None
        or tensor_shape.dimension_value(dim_y) is None):
      dims.append(None)
    else:
      dims.append(tensor_shape.dimension_value(dim_x))
  return tensor_shape.TensorShape(dims)
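
# Illustrative behavior of `common_shape` (comments only; the results follow
# from the logic above, they are not doctests):
#
#   common_shape(tensor_shape.TensorShape([2, 3]),
#                tensor_shape.TensorShape([4, 3]))    # TensorShape([None, 3])
#   common_shape(tensor_shape.TensorShape([None, 3]),
#                tensor_shape.TensorShape([2, 3]))    # TensorShape([None, 3])
#   common_shape(tensor_shape.TensorShape(None),
#                tensor_shape.TensorShape([2, 3]))    # TensorShape(None)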


def is_same_structure(structure1,
                      structure2,
                      check_values=False):
  """Check two structures for equality, optionally of types and of values."""
  try:
    nest.assert_same_structure(structure1, structure2, expand_composites=True)
  except (ValueError, TypeError):
    return False
  if check_values:
    flattened1 = nest.flatten(structure1, expand_composites=True)
    flattened2 = nest.flatten(structure2, expand_composites=True)
    # First check the types to avoid AttributeErrors.
    if any(type(f1) != type(f2) for f1, f2 in zip(flattened1, flattened2)):
      return False
    return flattened1 == flattened2
  return True


def _parse_func_attrs(attributes):
  """Convert the keyword arguments into function_def attributes.

  Currently only supports primitive types: bool, int, float and string.

  Args:
    attributes: the dictionary of attributes.
  Returns:
    A dict of attributes where the key is the name of attribute and the value
    is the AttrValue proto.
  Raises:
    ValueError: If the kwargs contain unallowlisted names or unsupported value
      types.
  """
  attrs = {}
  for key, value in attributes.items():
    if isinstance(value, attr_value_pb2.AttrValue):
      attrs[key] = value
    # bool type check has to happen before int since bool is a subclass of int.
    elif isinstance(value, bool):
      attrs[key] = attr_value_pb2.AttrValue(b=value)
    elif isinstance(value, int):
      attrs[key] = attr_value_pb2.AttrValue(i=value)
    elif isinstance(value, float):
      attrs[key] = attr_value_pb2.AttrValue(f=value)
    elif isinstance(value, (str, bytes, six.text_type)):
      attrs[key] = attr_value_pb2.AttrValue(s=compat.as_bytes(value))
    else:
      raise ValueError("Unsupported attribute type for %s with type %s" %
                       (key, type(value)))
  return attrs
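
# Illustrative conversion performed by `_parse_func_attrs` (comments only; the
# attribute names here are hypothetical):
#
#   _parse_func_attrs({"api_implements": "my_op", "_noinline": True})
#   # -> {"api_implements": AttrValue(s=b"my_op"),
#   #     "_noinline": AttrValue(b=True)}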


class _InterpolateFunctionError(object):
  """Context Manager that interpolates the exception from 'top_level_func'."""

  __slots__ = ["_func"]

  def __init__(self, top_level_func):
    self._func = top_level_func

  def __enter__(self):
    pass

  def __exit__(self, typ, exc, tb):
    if not exc or not isinstance(exc, errors.OpError):
      return False
    message = compat.as_text(exc.message)
    _, tags = error_interpolation.parse_message(message)
    g = None
    func_stack = []
    for t in tags:
      if t.type == "function_node":
        # TODO(mdan): Tests should cover this.
        if t.name == compat.as_str(self._func.name):
          g = self._func.graph
        elif g:
          next_func = g._get_function(t.name)  # pylint: disable=protected-access
          if next_func is not None and isinstance(next_func,
                                                  _EagerDefinedFunction):
            g = next_func.graph
        if g:
          func_stack.append(g.name)
        else:
          func_stack.append("<unknown>")
    if g:
      message = error_interpolation.interpolate(message, g)
      message += "\n\nFunction call stack:\n"
      message += " -> ".join(func_stack)
      message += "\n"
      exc._message = message  # pylint: disable=protected-access
    return False


_function_callbacks = set()


def add_function_callback(function_callback):
  """Add a callback function for Function creation.

  The callback function has the signature:

    `def function_callback(function, name, graph, inputs, outputs):`

  where:
  - `function`: _EagerDefinedFunction being created before finalizing the
    graph. Do not modify the function directly but instead modify the graph.
  - `name`: name of the function.
  - `graph`: Graph of the function.
  - `inputs`: `tuple` of tensors used as inputs to the function.
  - `outputs`: `tuple` of tensors used as outputs from the function.

  The callback is invoked at the top of the `_EagerDefinedFunction`
  construction, giving the callback an opportunity to make the last edits to
  the graph. Do not make changes to `graph`, `inputs`, and `outputs` manually,
  but, instead, set the `graph` as the default then define ops.

  Repeated registration of the same callback function is idempotent.
  After a callback is added, it can be removed with the
  `remove_function_callback()` method.

  Args:
    function_callback: The callback to add.
  """
  _function_callbacks.add(function_callback)


def remove_function_callback(function_callback):
  """Remove an already-added function callback.

  See the doc string of `add_function_callback()` for more information.

  Args:
    function_callback: The callback to remove.
  """
  _function_callbacks.remove(function_callback)
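
# Hedged usage sketch for the callback registry above (comments only; the
# callback name is hypothetical). A registered callback sees every
# _EagerDefinedFunction as it is finalized:
#
#   def _log_new_function(function, name, graph, inputs, outputs):
#     logging.info("Finalizing %s (%d inputs, %d outputs)",
#                  name, len(inputs), len(outputs))
#
#   add_function_callback(_log_new_function)
#   ...  # trace some tf.functions
#   remove_function_callback(_log_new_function)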


def clear_function_callbacks():
  """Clear all function callbacks, if any have been registered."""
  _function_callbacks.clear()


_FORWARD_PREFIX = "__forward_"
_BACKWARD_PREFIX = "__backward_"
_INFERENCE_PREFIX = "__inference_"


def _forward_name(n):
  """The name of a generated forward defun named n."""
  return "%s%s_%s" % (_FORWARD_PREFIX, n, ops.uid())


def _backward_name(n):
  """The name of a generated backward defun named n."""
  return "%s%s_%s" % (_BACKWARD_PREFIX, n, ops.uid())


def _inference_name(n):
  """The name of a forward-but-no-gradient defun named n."""
  return "%s%s_%s" % (_INFERENCE_PREFIX, n, ops.uid())


def _enclosing_xla_context():
  """Returns the XLAControlFlowContext, which exists inside a tpu.rewrite()."""
  graph = ops.get_default_graph()
  while graph is not None:
    # pylint: disable=protected-access
    context_ = graph._get_control_flow_context()
    # pylint: enable=protected-access
    while context_ is not None:
      if isinstance(context_, control_flow_ops.XLAControlFlowContext):
        return context_
      context_ = context_.outer_context
    # This may be a FuncGraph due to defuns or v2 control flow. We need to
    # find the original graph with the XLAControlFlowContext.
    graph = getattr(graph, "outer_graph", None)
  return None


class _EagerDefinedFunctionDeleter(object):
  """Unregister function from eager context."""

  __slots__ = ["name"]

  def __init__(self, name):
    self.name = name

  def __del__(self):
    try:
      context.remove_function(self.name)
    except TypeError:
      # Suppress some exceptions, mainly for the case when we're running on
      # module deletion. Things that can go wrong include the context module
      # already being unloaded, self._handle._handle_data no longer being
      # valid, and so on. Printing warnings in these cases is silly
      # (exceptions raised from __del__ are printed as warnings to stderr).
      pass  # 'NoneType' object is not callable when the handle has been
      # partially unloaded.
    except AttributeError:
      pass  # 'NoneType' object has no attribute 'eager_mode' when context has
      # been unloaded. Will catch other module unloads as well.


# TODO(apassos) get rid of this by splitting framework.function._DefinedFunction
# so it doesn't have the definition-generating logic and is just a container
# for an already-defined function.
class _EagerDefinedFunction(object):
  """Callable with the interface of `framework.function._DefinedFunction`.

  `_EagerDefinedFunction` encapsulates a function definition and its
  properties, and it provides a method for calling the encapsulated function.
  Some Ops take functions as attributes, which have type `func`; an instance
  of this class may be provided as the value of these `func` attributes.
  """

  def __init__(self, name, graph, inputs, outputs, attrs):
    """Initializes an eager defined function.

    Args:
      name: str, the name for the created function.
      graph: Graph, the graph containing the operations in the function
      inputs: the tensors in the graph to be used as inputs to the function
      outputs: the tensors in the graph which will be outputs from the function
      attrs: dict mapping names of attributes to their AttrValue values
    """
    for function_callback in _function_callbacks:
      function_callback(self, name, graph, tuple(inputs), tuple(outputs))

    input_ops = set(arg.op for arg in inputs)
    operations = [op for op in graph.get_operations() if op not in input_ops]

    graph_output_names = graph._output_names  # pylint: disable=protected-access
    if (graph_output_names is not None and
        all(ops.tensor_id(t) in graph_output_names for t in outputs)):
      output_names = [
          compat.as_bytes(graph_output_names[ops.tensor_id(t)]) for t in outputs
      ]
      if len(set(output_names)) != len(output_names):
        # There are duplicate names for some reason, probably an invalid
        # signature. Revert to auto-naming.
        output_names = []
    else:
      output_names = []
    fn = pywrap_tf_session.TF_GraphToFunction_wrapper(
        graph._c_graph,  # pylint: disable=protected-access
        compat.as_str(name),
        False,
        [o._c_op for o in operations],  # pylint: disable=protected-access
        [t._as_tf_output() for t in inputs],  # pylint: disable=protected-access
        [t._as_tf_output() for t in outputs],  # pylint: disable=protected-access
        output_names,
        [o._c_op for o in graph.control_outputs],  # pylint: disable=protected-access
        [],  # control_output_names
        None,
        compat.as_str(""))

    for name, attr_value in attrs.items():
      serialized = attr_value.SerializeToString()
      # TODO(iga): this creates and deletes a new TF_Status for every attr.
      # It might be worth creating a convenient way to re-use status.
      pywrap_tf_session.TF_FunctionSetAttrValueProto(fn, compat.as_str(name),
                                                     serialized)

    # TODO(apassos) avoid creating a FunctionDef (especially to grab the
    # signature), but also in general it's nice not to depend on it.
    with c_api_util.tf_buffer() as buffer_:
      pywrap_tf_session.TF_FunctionToFunctionDef(fn, buffer_)
      proto_data = pywrap_tf_session.TF_GetBuffer(buffer_)
    function_def = function_pb2.FunctionDef()
    function_def.ParseFromString(compat.as_bytes(proto_data))
    self._name = compat.as_bytes(function_def.signature.name)
    with ops.init_scope():
      if context.executing_eagerly():
        context.ensure_initialized()
        context.add_function(fn)
        self._function_deleter = _EagerDefinedFunctionDeleter(self.name)
        self._registered_on_context = True
    self.definition = function_def
    self.signature = function_def.signature
    self._num_outputs = len(self.signature.output_arg)
    self._output_types = [o.type for o in self.signature.output_arg]
    self._output_shapes = [o.shape for o in outputs]
    self._control_captures = graph.control_captures
    # Shallow copy outputs since ConcreteFunction may mutate it.
    self._func_graph_outputs = list(outputs)
    self.grad_func_name = None
    self.python_grad_func = None
    self._c_func = c_api_util.ScopedTFFunction(fn)
    self._grad_func = None
    self.graph = graph
    self._stateful_ops = tuple(op for op in operations if op._is_stateful)  # pylint: disable=protected-access

  def add_to_graph(self, g=None):
    """Add the function to the current context or a graph, if supplied.

    Args:
      g: the graph to add the function to. If not supplied, the function will
        be added to the current context.
    """
    # pylint: disable=protected-access
    if not g and context.executing_eagerly():
      ctx = context.context()
      if not ctx.has_function(self.name):
        ctx.add_function_def(self.definition)
    else:
      if not g._is_function(self.name):
        g._add_function(self)
      for f in self.graph._functions.values():
        if not g._is_function(f.name):
          g._add_function(f)
    # pylint: enable=protected-access

  @property
  def name(self):
    return self._name

  @property
  def stateful_ops(self):
    return self._stateful_ops

  def call(self, ctx, args, cancellation_manager=None):
    """Calls this function with `args` as inputs.

    `ConcreteFunction` execution respects device annotations only if the
    function won't be compiled with XLA.

    Args:
      ctx: a Context object
      args: a list of arguments to supply this function with.
      cancellation_manager: a `CancellationManager` object that can be used to
        cancel function execution.

    Returns:
      The outputs of the function call.

    Raises:
      ValueError: if the number of arguments is incorrect.
    """
    if len(args) != len(self.signature.input_arg):
      raise ValueError(
          "Arguments and signature arguments do not match. "
          "got: %s, expected: %s " %
          (len(args), len(list(self.signature.input_arg))))

    function_call_options = ctx.function_call_options
    if function_call_options.config_proto_serialized is None:
      config = function_utils.get_disabled_rewriter_config()
    else:
      config = function_call_options.config_proto_serialized
    executor_type = function_call_options.executor_type or ""

    executing_eagerly = ctx.executing_eagerly()
    attrs = ("executor_type", executor_type, "config_proto", config)
    if executing_eagerly:
      with _InterpolateFunctionError(self):
        if cancellation_manager is None:
          outputs = execute.execute(
              str(self.signature.name),
              num_outputs=self._num_outputs,
              inputs=args,
              attrs=attrs,
              ctx=ctx)
        else:
          outputs = execute.execute_with_cancellation(
              str(self.signature.name),
              num_outputs=self._num_outputs,
              inputs=args,
              attrs=attrs,
              ctx=ctx,
              cancellation_manager=cancellation_manager)
      # Replace empty list with None
      outputs = outputs or None
    else:
      # TODO(akshayka): Either remove this if the FunctionLibraryRuntime
      # creates `PartitionedCallOp` kernels by default, or remove the previous
      # branch if a TPU kernel is registered for `PartitionedCall`.
      with _InterpolateFunctionError(self):
        with ops.control_dependencies(self._control_captures):
          # The caller must use record_operation to record this operation in
          # the eager case, so we enforce the same requirement for the
          # non-eager case by explicitly pausing recording. We don't have a
          # gradient registered for PartitionedCall, so recording this
          # operation confuses forwardprop code (GradientTape manages to
          # ignore it).
          with tape.stop_recording():
            outputs = functional_ops.partitioned_call(
                args=args,
                f=self,
                tout=self._output_types,
                executing_eagerly=executing_eagerly,
                config=config,
                executor_type=executor_type)

    for i, func_graph_output in enumerate(self._func_graph_outputs):
      custom_gradient.copy_handle_data(func_graph_output, outputs[i])
    if executing_eagerly:
      return outputs
    else:
      # TODO(b/128924522): This additional set_shape should not be
      # necessary. ShapeRefiner likely needs to inspect handle_data. Remove
      # this once that's done.
      for i, shape in enumerate(self._output_shapes):
        outputs[i].set_shape(shape)
      return outputs
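
# Note on the two call paths in `_EagerDefinedFunction.call` above (comments
# only): when the context is executing eagerly the function is dispatched
# directly via `execute.execute` (or `execute_with_cancellation`), while
# during graph building it is emitted as a `PartitionedCall` /
# `StatefulPartitionedCall` op through `functional_ops.partitioned_call`,
# with tape recording explicitly paused because no gradient is registered for
# PartitionedCall.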


def _create_forward_backward_with_graph(attrs, forward_graph, backwards_graph):
  """Creates forward and backward functions from the function graphs."""
  forward_function_name = _forward_name(forward_graph.name)
  common_attributes = dict(attrs)
  # NB: forward and backward functions need to drop the "_implements"
  # attribute, because their signature contains all the intermediate tensors
  # that they compute. Thus they don't have a stable signature which can
  # be directly optimized downstream.
  # See for more details:
  # https://github.com/tensorflow/community/blob/master/rfcs/20190610-standardizing-composite_ops.md#appendix-future-support-for-optimizing-gradient-functions
  common_attributes.pop(IMPLEMENTS_ATTRIBUTE_NAME, None)
  backward_function_attr = _parse_func_attrs(
      {FORWARD_FUNCTION_ATTRIBUTE_NAME: forward_function_name})
  backward_function_attr.update(common_attributes)
  backward_function = ConcreteFunction(
      backwards_graph, attrs=backward_function_attr)
  forward_function_attr = _parse_func_attrs({
      BACKWARD_FUNCTION_ATTRIBUTE_NAME:
          backward_function.name})
  forward_function_attr.update(common_attributes)
  forward_function = _EagerDefinedFunction(
      forward_function_name, forward_graph, forward_graph.inputs,
      forward_graph.outputs, forward_function_attr)
  return forward_function, backward_function


class _DelayedRewriteGradientFunctions(object):
  """Caches forward/backward functions with a delayed forward rewrite."""

  def __init__(self, func_graph, attrs, func_graph_deleter):
    """Construct an inference function and initialize caches."""
    # A map from the number of forward function outputs with accepted gradients
    # to forward and backward functions, used to cache non-tape backward
    # function generation.
    self._cached_function_pairs = {}
    self._func_graph = func_graph
    self._inference_function = _EagerDefinedFunction(
        _inference_name(self._func_graph.name), self._func_graph,
        self._func_graph.inputs, self._func_graph.outputs, attrs)
    self._attrs = attrs
    self._gradient_name = None
    # Note that the FuncGraph is mutated later, so we need to inspect it now to
    # figure out the user-specified outputs of the inference function.
    self._num_inference_outputs = len(self._func_graph.outputs)
    self._func_graph_deleter = func_graph_deleter

  def forward_backward(self, num_doutputs=None):
    """A possibly-cached pair of forward and backward functions."""
    if num_doutputs is None:
      num_doutputs = self._num_inference_outputs
    forward_backward = self._cached_function_pairs.get(num_doutputs)
    if forward_backward is not None:
      return forward_backward
    forward, backward = self._construct_forward_backward(num_doutputs)
    self._cached_function_pairs[num_doutputs] = (forward, backward)
    return forward, backward

  def _construct_forward_backward(self, num_doutputs):
    """Constructs a pair of forward and backward functions.

    Args:
      num_doutputs: The constructed backprop function will take output
        gradients for the first `num_doutputs` outputs of the forward
        function. Defaults to the number of outputs for the inference
        function, but when higher-order gradients are computed this will
        increase to include side outputs.

    Returns:
      A pair of (forward_function, backward_function):
        forward_function: A re-generated inference function (an
          _EagerDefinedFunction) to account for new side outputs, if any extra
          were required when building the backward pass.
        backward_function: A ConcreteFunction that takes `num_doutputs`
          arguments and returns gradients with respect to inputs of the
          forward function.
    """
    trainable_outputs = [
        output for output in self._func_graph.outputs[:num_doutputs]
        if backprop_util.IsTrainable(output)]

    signature = []
    for t in trainable_outputs:
      signature.append(
          tensor_spec.TensorSpec(*default_gradient.shape_and_dtype(t)))

    def _backprop_function(*grad_ys):
      with ops.device(None):
        return gradients_util._GradientsHelper(  # pylint: disable=protected-access
            trainable_outputs,
            self._func_graph.inputs,
            grad_ys=grad_ys,
            src_graph=self._func_graph)

    with self._func_graph.as_default():
      backwards_graph = func_graph_module.FuncGraph(
          _backward_name(self._func_graph.name))
      func_graph_module.func_graph_from_py_func(
          name=backwards_graph.name,
          python_func=_backprop_function,
          args=[], kwargs={},
          signature=signature,
          func_graph=backwards_graph)
      backwards_graph_captures = backwards_graph.external_captures
      captures_from_forward = [
          c for c in backwards_graph_captures if
          not isinstance(c, ops.EagerTensor) and c.graph is self._func_graph]

      existing_outputs = object_identity.ObjectIdentitySet(
          self._func_graph.outputs)
      for capture in captures_from_forward:
        if capture not in existing_outputs:
          existing_outputs.add(capture)
          self._func_graph.outputs.append(capture)

      forward_function, backward_function = _create_forward_backward_with_graph(
          self._attrs, self._func_graph, backwards_graph)
      return forward_function, backward_function

  def _rewrite_forward_and_call_backward(self, op, *doutputs):
    """Add outputs to the forward call and feed them to the grad function."""
    forward_function, backwards_function = self.forward_backward(len(doutputs))
    if not backwards_function.outputs:
      return backwards_function.structured_outputs
    forward_function.add_to_graph(op.graph)

    # pylint: disable=protected-access
    # Rewrite an inference call op to be a forward call op
    op._set_func_attr("f", forward_function.name)
    op._set_type_list_attr("Tout", forward_function._output_types)
    op._add_outputs(
        forward_function._output_types[len(op.outputs):],
        forward_function._output_shapes[len(op.outputs):])
    for i in range(len(op.outputs)):
      func_graph_output = forward_function._func_graph_outputs[i]
      custom_gradient.copy_handle_data(func_graph_output, op.outputs[i])
    # pylint: enable=protected-access

    capture_mapping = dict(
        zip((ops.tensor_id(t) for t in self._func_graph.outputs), op.outputs))
    remapped_captures = [
        capture_mapping.get(ops.tensor_id(capture), capture)
        for capture in backwards_function.captured_inputs
    ]

    # Replace Nones with zeros since we're calling a graph function which
    # expects numeric inputs.
    cleaned_doutputs = []
    for doutput, placeholder in zip(doutputs, self._func_graph.outputs):
      if backprop_util.IsTrainable(placeholder):
        if isinstance(doutput, ops.IndexedSlices):
          # Gradient passed to a backward ConcreteFunction must be tf.Tensor,
          # so we convert tf.IndexedSlices to tf.Tensor.
          cleaned_doutputs.append(ops.convert_to_tensor(doutput))
        elif doutput is not None:
          cleaned_doutputs.append(doutput)
        else:
          cleaned_doutputs.append(default_gradient.zeros_like(placeholder))

    # Compute the gradients using the side outputs
    return backwards_function._call_flat(  # pylint: disable=protected-access
        cleaned_doutputs, remapped_captures)

  def get_gradient_function(self):
    """Returns gradient function.

    The gradient rewrites an inference call op to a forward call op, but does
    not modify a pre-existing forward call op. It then computes the gradient
    from the output's gradients and the side outputs of the forward op.
    """
    return self._rewrite_forward_and_call_backward

  def forward(self, inference_args=None, input_tangents=None):
    """A forward function with only user-specified outputs.

    The call operation for the returned inference function can be rewritten
    into a forward function. This only happens if the backward function (from
    the `backward` method) ends up being used to compute gradients.

    This approach avoids constructing unnecessary graphs, but it only works if
    we are calling this function when not executing eagerly.

    Args:
      inference_args: A flat list of Tensors, arguments to the inference
        function. Unused, but taken for compatibility with
        _TapeGradientFunctions.
      input_tangents: A flat list of Tensors, jvps associated with
        `inference_args`. Unused; if required, tape functions must be used
        instead.

    Returns:
      An _EagerDefinedFunction.
    """
    del inference_args  # unused
    if input_tangents:
      # This class does not support special-cased forwardprop. The arguments
      # are here for compatibility with _TapeGradientFunctions.
      raise AssertionError(
          "Internal error: unexpectedly got forwardprop information in a class "
          "that does not support forwardprop.")
    return self._inference_function

  def _backward(self, outputs):
    """Fetch a backward function for `outputs` from the forward function."""
    def _backward_function(*args):
      call_op = outputs[0].op
      return self._rewrite_forward_and_call_backward(call_op, *args)
    return _backward_function, outputs

  def record(self, flat_outputs, inference_args, input_tangents):
    """Record the function call operation.

    _DelayedRewriteGradientFunctions supports only first-order backprop tape
    gradients (and then only when graph building). It does not work with
    higher-order tape gradients or forward autodiff, but does work with
    higher-order symbolic gradients (tf.gradients).

    Args:
      flat_outputs: The result of running `forward`.
      inference_args: A flat list of Tensors with inference inputs to the
        operation.
      input_tangents: A flat list of Tensors with input tangents consumed by
        the operation.
    """
    backward_function, to_record = self._backward(flat_outputs)
    tape.record_operation(self._inference_function.signature.name,
                          to_record, inference_args + input_tangents,
                          backward_function)
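
# Note on when the delayed-rewrite path above is used (comments only): it
# backs symbolic gradients such as `tf.gradients` while graph building -- the
# inference call op is rewritten in place into a forward op whose side
# outputs feed the generated backward function -- whereas eager
# `tf.GradientTape` gradients go through the tape-based
# `_TapeGradientFunctions` classes below.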

# Contains information about a forward function wrapped to compute jvps.
_ForwardWrapper = collections.namedtuple(
    "_ForwardWrapper", (
        # The wrapper Graph.
        "graph",
        # A flat list of non-tangent Tensor outputs from the wrapped forward
        # function.
        "outputs",
        # Indices for output tangents, same format as
        # forwardprop_util.pack_tangents.
        "output_indices",
        # A flat list of tangents for `outputs`.
        "output_tangents"))


class _TapeGradientFunctions(object):
  """Caches forward and backward functions compatible with eager gradients.

  In contrast to the delayed-rewrite approach in
  `_DelayedRewriteGradientFunctions` which only works with delayed execution,
  the forward function generated by this class has a fixed set of outputs
  which may be preserved by a tape in order to compute gradients later.

  This class is abstract; its child classes differ in how many side outputs of
  the forward function their backward function accepts gradients for, which
  determines whether higher-order tape gradients are possible.
  """

  def __init__(self, func_graph, attrs, func_graph_deleter,
               forwardprop_input_indices, delayed_rewrite_functions,
               need_gradients_for_jvps):
    self._func_graph = func_graph
    self._forward_graph = None
    self._attrs = attrs
    self._forward = None
    self._backward = None
    self._num_outputs = len(func_graph.outputs)
    self._func_graph_deleter = func_graph_deleter
    self._forwardprop_input_indices = forwardprop_input_indices
    self._forwardprop_output_indices = None
    self._num_forwardprop_outputs = 0
    self._num_inference_outputs = len(func_graph.outputs)
    self._num_trainable_inference_outputs = len(
        [t for t in func_graph.outputs if backprop_util.IsTrainable(t)])
    self._delayed_rewrite_functions = delayed_rewrite_functions
    self._need_gradients_for_jvps = need_gradients_for_jvps

  def _build_functions_for_outputs(
      self, outputs, inference_args, input_tangents):
    """Forward+backward functions where the backward function sees `outputs`."""
    # First figure out which of `outputs` are trainable. We'll accept gradients
    # for each of these in the backward function.
    handles_to_variables = self._func_graph.variable_captures
    trainable_outputs = []
    trainable_indices = []
    for index, output in enumerate(outputs):

      if backprop_util.IsTrainable(output):
        # Swap in the Variable object for resource handles if we can so
        # sparse gradients work.
        output = handles_to_variables.get(id(output), output)
        trainable_outputs.append(output)
        trainable_indices.append(index)

    backwards_graph = func_graph_module.FuncGraph(
        _backward_name(self._func_graph.name))
    with backwards_graph.as_default():
      gradients_wrt_outputs = []
      for output in trainable_outputs:
        gradient_shape, gradient_dtype = default_gradient.shape_and_dtype(
            output)
        gradient_placeholder = graph_placeholder(gradient_dtype, gradient_shape)
        custom_gradient.copy_handle_data(output, gradient_placeholder)
        gradients_wrt_outputs.append(gradient_placeholder)
      with ops.device(None):
        gradients_wrt_inputs = gradients_util._GradientsHelper(  # pylint: disable=protected-access
            trainable_outputs,
            self._func_graph.inputs,
            grad_ys=gradients_wrt_outputs,
            src_graph=self._func_graph)

      if input_tangents:
        # Convert IndexedSlices to dense tensors (as we do elsewhere for
        # function gradients). Our C++ bindings don't know how to handle them
        # currently.
        gradients_wrt_inputs = nest.map_structure(
            lambda x: ops.convert_to_tensor(x) if x is not None else None,
            gradients_wrt_inputs)
      captures_from_forward = [
          c for c in backwards_graph.external_captures
          if not isinstance(c, ops.EagerTensor) and c.graph is self._func_graph
      ]
      existing_outputs = object_identity.ObjectIdentitySet(
          self._func_graph.outputs)
      for capture in captures_from_forward:
        if capture not in existing_outputs:
          existing_outputs.add(capture)
          self._func_graph.outputs.append(capture)

    # The ordering of `backwards_graph.inputs` is important: inputs of
    # `backward_function` correspond to outputs (including
    # side outputs) of `self._tape_forward_function`.
    backwards_graph.inputs = (
        gradients_wrt_outputs + backwards_graph.internal_captures)
    backwards_graph.outputs.extend(
        grad
        for grad in nest.flatten(gradients_wrt_inputs, expand_composites=True)
        if grad is not None)
    backwards_graph.structured_outputs = gradients_wrt_inputs

    forward_function, backward_function = _create_forward_backward_with_graph(
        self._attrs, self._func_graph, backwards_graph)

    if not input_tangents:
      # There is no need to special-case forwardprop, so we can return the
      # forward+backward pair we've created without further wrapping.
      return (forward_function, self._func_graph, backward_function,
              # No forwardprop outputs.
              None, 0)
    forward_wrapper = self._wrap_forward_function_with_jvps(
        forward_function, backward_function, inference_args, input_tangents)
    (wrapped_backwards_graph,
     forward_wrapper) = self._wrap_backward_function_with_jvp_backprop(
         backward_function, gradients_wrt_outputs, forward_wrapper)
    # Now that we've added new captures, we need to make sure forward outputs
    # are in the same order the backward function expects them to be in:
    # [inference outputs] + [jvps] + [side outputs] + [captures].
    forward_wrapper = self._shuffle_forward_outputs(forward_wrapper)
    (wrapped_forward_function,
     wrapped_backward_function) = _create_forward_backward_with_graph(
         self._attrs, forward_wrapper.graph, wrapped_backwards_graph)
    if (len(inference_args) + len(input_tangents)
        != len(forward_wrapper.graph.inputs)):
      raise AssertionError(
          ("Internal error: the forward graph had {} inputs, but we expected"
           " {} ({} inference inputs and {} input tangents)")
          .format(len(forward_wrapper.graph.inputs),
                  len(inference_args) + len(input_tangents),
                  len(inference_args), len(input_tangents)))
    return (wrapped_forward_function, forward_wrapper.graph,
            wrapped_backward_function, forward_wrapper.output_indices,
            len(forward_wrapper.output_tangents))

  def _wrap_forward_function_with_jvps(
      self, forward_function, backward_function,
      inference_args, input_tangents):
    """Adds inline JVP computation to a forward function."""
    forward_wrapper_graph = func_graph_module.FuncGraph(
        _forward_name(self._func_graph.name))
    with forward_wrapper_graph.as_default():
      # Tell forward accumulators to free up space for new JVP computations,
      # since one may be in the process of computing a JVP (if that
      # computation triggered this function building).
      #
      # We'll make symbolic versions of input JVPs, run the forward function
      # under forward accumulators to get symbolic output JVPs, then set those
      # as outputs of the new wrapped forward function.
      with forwardprop_util.push_forwardprop_state():
        forward_captures = {
            ops.tensor_id(internal): external
            for external, internal in self._func_graph.captures}
        for input_index, real_input in enumerate(self._func_graph.inputs):
          # This loop is more or less equivalent to running tf.identity on
          # each of self._func_graph.inputs. However, doing that also captures
          # jvps for resource handles, which confuses the jvp capturing code
          # below (since primal inputs are interwoven with jvp inputs).
          input_placeholder = array_ops.placeholder(
              dtype=real_input.dtype,
              shape=real_input.shape)
          capture = forward_captures.get(ops.tensor_id(real_input))
          if capture is not None:
            forward_wrapper_graph.add_capture(capture, input_placeholder)
            if capture.dtype == dtypes.resource:
              custom_gradient.copy_handle_data(capture, input_placeholder)
          else:
            forward_wrapper_graph.inputs.append(input_placeholder)
        for inp, arg in zip(forward_wrapper_graph.inputs, inference_args):
          tape.record_operation(
              "captured_value", [inp], [arg],
              backward_function=lambda x: [x],
              forward_function=lambda x: [x])
        num_inference_inputs = len(inference_args)
        for tape_indices in self._forwardprop_input_indices:
          for input_index, jvp_index in tape_indices:
            input_placeholder = forward_wrapper_graph.inputs[input_index]
            if len(forward_wrapper_graph.inputs) != jvp_index:
              raise AssertionError(
                  ("Internal error: expected {} forward graph inputs, but "
                   "found {}.")
                  .format(jvp_index, len(forward_wrapper_graph.inputs)))
            gradient_shape, gradient_dtype = default_gradient.shape_and_dtype(
                input_placeholder)
            jvp_placeholder = graph_placeholder(gradient_dtype, gradient_shape)
            external_jvp = input_tangents[jvp_index - num_inference_inputs]
            forward_wrapper_graph.add_capture(external_jvp, jvp_placeholder)
            tensor_shape.TensorShape(
                external_jvp.shape).assert_is_compatible_with(
                    jvp_placeholder.shape)
            tape.record_operation(
                "captured_value",
                [jvp_placeholder],
                [external_jvp],
                backward_function=lambda x: [x],
                forward_function=lambda x: [x])
        forward_inputs = forward_wrapper_graph.inputs[:num_inference_inputs]
        gradient_function = (
            self._delayed_rewrite_functions._rewrite_forward_and_call_backward)  # pylint: disable=protected-access
        with ops.get_default_graph()._override_gradient_function(  # pylint: disable=protected-access
            {"PartitionedCall": gradient_function,
             "StatefulPartitionedCall": gradient_function}):
          forward_outputs = forward_function.call(context.context(),
                                                  forward_inputs)
          if isinstance(forward_outputs, ops.Operation):
            # _wrapped_backward_function expects a list, but if the function
            # has no outputs its call() returns an Operation. We need to undo
            # that so we don't cause problems later.
            forward_outputs = []
        py_backward, _ = self._wrap_backward_function(
            self._func_graph, backward_function, forward_outputs)
      # We will never request backward tape gradients for this operation
      # directly since we're wrapping the call; forwardprop will call the
      # backward function (and nested forward accumulators may build
      # higher-order gradients), but any watching GradientTapes should ignore
      # it.
      #
      # TODO(allenl): It might be better to explicitly stop backward recording
      # so we don't use the second-order tape cases unnecessarily.
      tape.record_operation_forwardprop_only(
          forward_function.signature.name,
          forward_outputs, forward_inputs, py_backward, None)
      output_indices, output_tangents = (
          pywrap_tfe.TFE_Py_PackJVPs(forward_outputs))
      output_tangents = [forward_wrapper_graph.capture(t)
                         for t in output_tangents]
    return _ForwardWrapper(
        graph=forward_wrapper_graph, outputs=forward_outputs,
        output_indices=output_indices, output_tangents=output_tangents)
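
  # Note on when the JVP wrapping above is exercised (comments only; the
  # sketch below is hypothetical but uses public API names): it is what
  # supports forward-mode autodiff through traced functions, e.g.
  #
  #   with tf.autodiff.ForwardAccumulator(primals=x, tangents=dx) as acc:
  #     y = traced_function(x)   # forward function also emits output JVPs
  #   acc.jvp(y)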

  def _wrap_backward_function_with_jvp_backprop(
      self, backward_function, gradients_wrt_outputs, forward_wrapper):
    """Wraps `backward_function` to include gradients for JVPs."""
    wrapped_backwards_graph = func_graph_module.FuncGraph(
        _backward_name(self._func_graph.name))
    with wrapped_backwards_graph.as_default():
      py_backward, recorded_outputs = self._wrap_backward_function(
          self._func_graph, backward_function, forward_wrapper.outputs)
      trainable_index = 0
      forward_doutputs = []
      doutput_args = []
      for output in recorded_outputs:
        if backprop_util.IsTrainable(output):
          doutput = gradients_wrt_outputs[trainable_index]
          doutput_placeholder = graph_placeholder(doutput.dtype, doutput.shape)
          doutput_args.append(doutput_placeholder)
          forward_doutputs.append(doutput_placeholder)
          trainable_index += 1
        else:
          doutput_args.append(None)

      dinputs = py_backward(*doutput_args)
      existing_outputs = object_identity.ObjectIdentitySet(
          forward_wrapper.outputs + forward_wrapper.output_tangents)
      num_processed_output_tangents = 0
      gradients_wrt_output_tangents = []
      tangent_doutputs = []
      output_tangents = forward_wrapper.output_tangents
      output_indices = forward_wrapper.output_indices
      if self._need_gradients_for_jvps:
        # TODO(allenl): Consider using a throwaway graph to avoid extra
        # gradient evaluations; gradients for jvps may have common subgraphs.
        while num_processed_output_tangents != len(output_tangents):
          for output in output_tangents[num_processed_output_tangents:]:
            gradient_shape, gradient_dtype = default_gradient.shape_and_dtype(
                output)
            placeholder = graph_placeholder(gradient_dtype, gradient_shape)
            gradients_wrt_output_tangents.append(placeholder)
            tangent_doutputs.append(placeholder)
          num_processed_output_tangents = len(output_tangents)
          with ops.device(None):
            gradients_wrt_inputs = gradients_util._GradientsHelper(  # pylint: disable=protected-access
                output_tangents,
                forward_wrapper.graph.inputs,
                grad_ys=gradients_wrt_output_tangents,
                src_graph=forward_wrapper.graph)
          dinputs = [
              backprop.aggregate_indexed_slices_gradients((existing, new))
              for existing, new in zip(dinputs, gradients_wrt_inputs)
              if existing is not None or new is not None]
          dinputs.extend(gradients_wrt_inputs[len(dinputs):])
          captures_from_forward = [
              c for c in wrapped_backwards_graph.external_captures
              if (not isinstance(c, ops.EagerTensor)
                  and c.graph is forward_wrapper.graph)]
          for capture in captures_from_forward:
            if capture not in existing_outputs:
              existing_outputs.add(capture)
              forward_wrapper.outputs.append(capture)
          output_indices, output_tangents = (
              forwardprop_util.pack_tangents(forward_wrapper.outputs))
          output_tangents = [forward_wrapper.graph.capture(t)
                             for t in output_tangents]
          for t in output_tangents:
            existing_outputs.add(t)
    wrapped_backwards_graph.inputs = (
        forward_doutputs[:self._num_trainable_inference_outputs]
        + tangent_doutputs
        + forward_doutputs[self._num_trainable_inference_outputs:]
        + wrapped_backwards_graph.internal_captures)
    wrapped_backwards_graph.structured_outputs = dinputs
    wrapped_backwards_graph.outputs = [t for t in dinputs if t is not None]
    return (wrapped_backwards_graph,
            forward_wrapper._replace(output_indices=output_indices,
                                     output_tangents=output_tangents))

  def _shuffle_forward_outputs(self, forward_wrapper):
    """Reorders function outputs so captures are last."""
    def _index_map(original):
      if original < self._num_inference_outputs:
        return original
      if original >= len(forward_wrapper.outputs):
        return (original - len(forward_wrapper.outputs)
                + self._num_inference_outputs)
      return original + len(forward_wrapper.output_tangents)
    output_indices = nest.map_structure(
        _index_map, forward_wrapper.output_indices)
    forward_wrapper.graph.outputs = (
        forward_wrapper.outputs[:self._num_inference_outputs]
        + forward_wrapper.output_tangents
        + forward_wrapper.outputs[self._num_inference_outputs:])
    return forward_wrapper._replace(output_indices=output_indices)

  def forward(self, inference_args, input_tangents):
    """Construct or fetch a forward function with side-outputs.

    When graph building without a tape active, symbolic gradients rely on
    regenerating the backward function for higher-order gradients (to account
    for new side outputs of the rewritten forward function call). Thus there
    is no fixed backward function for this case. However, when a tape is
    active (eager or graph building), we generate fixed backward and forward
    functions at forward function call time.

    This difference between the tape and non-tape cases is to avoid building
    unneeded backward functions while graph building (where we may or may not
    eventually need gradients).

    Args:
      inference_args: A flat list of Tensors, arguments to the inference
        function.
      input_tangents: A flat list of Tensors, jvps associated with
        `inference_args`.

    Returns:
      A forward _EagerDefinedFunction.
    """
    if self._forward is None:
      (self._forward, self._forward_graph, self._backward,
       self._forwardprop_output_indices, self._num_forwardprop_outputs) = (
           self._forward_and_backward_functions(inference_args,
                                                input_tangents))
    return self._forward

  def _wrap_backward_function(self, forward_graph, backward, outputs):
    """Create a backward function given `outputs` from the forward function."""
    capture_mapping = dict(
        zip((ops.tensor_id(t) for t in forward_graph.outputs), outputs))
    captured_inputs = backward.captured_inputs
    remapped_captures = [
        capture_mapping.get(ops.tensor_id(capture), capture)
        for capture in captured_inputs
    ]
    if any(t.graph is forward_graph for t in remapped_captures
           if not isinstance(t, ops.EagerTensor)):
      raise AssertionError(
          "Internal error: failed to map all backward graph captures to the "
          "forward graph. Incorrectly mapped: {}".format(
              [t for t in remapped_captures
               if (not isinstance(t, ops.EagerTensor)
                   and t.graph is not forward_graph)]))
    # We may need to use zeros_like to get a zero for variant Tensors with
    # unconnected gradients. We do that in advance so we don't have to hold on
    # to the outputs themselves, which may not be needed otherwise.
    variant_zeros_like = {}
    backward_function_inputs = (len(backward.inputs) - len(captured_inputs))
    recorded_outputs = []
    trainable_recorded_outputs = 0
    skip_positions = []
    if self._num_forwardprop_outputs and not self._need_gradients_for_jvps:
      relevant_outputs = (
          outputs[:self._num_inference_outputs]
          + outputs[self._num_inference_outputs
                    + self._num_forwardprop_outputs:])
    else:
      relevant_outputs = outputs
    for output_index, output in enumerate(relevant_outputs):
      if trainable_recorded_outputs < backward_function_inputs:
        recorded_outputs.append(output)
      if backprop_util.IsTrainable(output):
        trainable_recorded_outputs += 1
      else:
        skip_positions.append(output_index)
      if output.dtype == dtypes.variant:
        variant_zeros_like[output_index] = default_gradient.zeros_like(output)

    def _backward_function_wrapper(*args):
      """Process output gradients and call the backward function."""
      if not backward.outputs:
        return backward.structured_outputs

      processed_args = []
      input_index = 0
      for output_index, arg in enumerate(args):
        # Convert IndexedSlices to dense tensors. The IndexedSlices
        # optimization is only really effective when doing tf.gather(variable)
        # as the adjoint functions for most operations are unlikely to
        # preserve the sparsity in IndexedSlices.
        if isinstance(arg, ops.IndexedSlices):
          arg = ops.convert_to_tensor(arg)
        if output_index in skip_positions:
          continue
        if arg is None:
          # We're calling a (non-polymorphic) ConcreteFunction, so we need to
          # have a Tensor value for each Tensor we thought would be trainable
          # based on its dtype, even if it ended up being unconnected.
          input_placeholder = backward.inputs[input_index]
          if input_placeholder.dtype == dtypes.variant:
            arg = variant_zeros_like[output_index]
          else:
            arg = array_ops.zeros(
                *default_gradient.shape_and_dtype(input_placeholder))
        processed_args.append(arg)
        input_index += 1
        if input_index >= backward_function_inputs:
          break
      return backward._call_flat(  # pylint: disable=protected-access
          processed_args, remapped_captures)

    return _backward_function_wrapper, recorded_outputs

  def record(self, flat_outputs, inference_args, input_tangents):
    """Record the function call operation.

    For backprop, indicates the backward function to use and which new Tensors
    must be watched. For forwardprop from eager, the function call itself will
    have produced tangents which need to be recorded.

    Args:
      flat_outputs: The result of running `forward`.
      inference_args: A flat list of Tensors with inference inputs to the
        operation.
      input_tangents: A flat list of Tensors with input tangents consumed by
        the operation.
    """
    backward_function, to_record = self._wrap_backward_function(
        self._forward_graph, self._backward, flat_outputs)
    if self._forwardprop_output_indices:
      tape.record_operation_backprop_only(
          self._forward.signature.name,
          to_record, inference_args,
          backward_function)
      tape.record_operation_forwardprop_only(
          self._forward.signature.name,
          flat_outputs, inference_args + input_tangents,
          backward_function,
          self._forwardprop_output_indices)
    else:
      tape.record_operation(self._forward.signature.name,
                            to_record, inference_args + input_tangents,
                            backward_function)
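
# Note on the two subclasses below (comments only): a single non-persistent
# GradientTape can only ever request first-order gradients, so the cheaper
# `_FirstOrderTapeGradientFunctions` shortcut is sufficient; when higher-order
# tape gradients are possible (for example with persistent or nested tapes),
# `_HigherOrderTapeGradientFunctions` builds a backward function that also
# accepts gradients for the forward function's side outputs.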


class _FirstOrderTapeGradientFunctions(_TapeGradientFunctions):
  """Caches tape-friendly functions for first-order gradients."""

  def __init__(self, func_graph, attrs, func_graph_deleter,
               forwardprop_input_indices, delayed_rewrite_functions,
               need_gradients_for_jvps):
    super(_FirstOrderTapeGradientFunctions, self).__init__(
        func_graph, attrs, func_graph_deleter, forwardprop_input_indices,
        delayed_rewrite_functions, need_gradients_for_jvps)
    self._func_graph_deleter = func_graph_deleter
    self._forwardprop_input_indices = forwardprop_input_indices

  def _forward_and_backward_functions(self, inference_args, input_tangents):
    """Shortcut for when only first-order gradients are required.

    The returned backward function does not accept gradients with respect to
    side outputs of forward_function. This is fine as long as the user can't
    possibly request second order tape gradients, as when they've used a
    single non-persistent GradientTape. Since we don't need the backward
    function to take gradients with respect to side outputs, we can skip some
    potentially slow graph building.

    Args:
      inference_args: A flat list of Tensors, arguments to the inference
        function.
      input_tangents: A flat list of Tensors, jvps associated with
        `inference_args`.

    Returns:
      A tuple of (forward_function, backward_function):
        forward_function: Takes the same inputs as the inference function, but
          returns side outputs used by backward_function in addition to the
          inference function's outputs.
        backward_function: Takes side outputs from forward_function and
          gradients with respect to the "real" outputs of forward_function and
          returns gradients with respect to the inputs.
    """
    outputs = self._func_graph.outputs[:self._num_inference_outputs]
    return self._build_functions_for_outputs(
        outputs, inference_args, input_tangents)


class _HigherOrderTapeGradientFunctions(_TapeGradientFunctions):
  """Caches tape-friendly functions for higher-order gradients."""

  # TODO(b/136189779): Cond/while under a tape may need similar logic. Consider
  # generalizing if so.
  def _forward_and_backward_functions(self, inference_args, input_tangents):
    """Forward and backward functions suitable for higher-order gradients.

    Unlike in `_FirstOrderTapeGradientFunctions`, the backward function built
    by this method accepts gradients for all of the outputs of the returned
    forward function, including side outputs.

    Args:
      inference_args: A flat list of Tensors, arguments to the inference
        function.
      input_tangents: A flat list of Tensors, jvps associated with
        `inference_args`.

    Returns:
      A tuple of (forward_function, backward_function):
        forward_function: Takes the same inputs as the inference function, but
          returns side outputs used by backward_function in addition to the
          inference function's outputs.
        backward_function: Takes side outputs from forward_function and
          gradients with respect to all of its outputs, real and side. Returns
          gradients with respect to the inputs.
    """
    outputs = []
    iteration_count = 0
    # First we need to figure out how many side outputs from the forward pass
    # will be required. We do this in a temporary graph to avoid actually
    # running multiple copies of the backward pass (one per _GradientsHelper
    # call).
    #
    # While computing gradients, the backward function captures Tensors from
    # the forward function. We add these as side outputs of the original
    # function. However, we then need to accept output gradients with respect
    # to these side outputs for higher order gradients to work. Thus we loop
    # until the number of outputs of the function stabilizes. Note that this
    # is only required for tape gradients, where we need to declare in advance
    # all of the forward op's outputs: symbolic gradients with tf.gradients
    # instead rely on regenerating backward functions when higher-order
    # gradients are requested.
    while (len(outputs) < len(self._func_graph.outputs)
           # It's possible for gradient generation to add new ops to the
           # forward pass. If all of the new outputs are non-trainable, there's
           # no reason to continue.
           and any(backprop_util.IsTrainable(output)
                   for output in self._func_graph.outputs[len(outputs):])):
      iteration_count += 1
      if iteration_count >= 20 and iteration_count % 5 == 0:
        new_op_with_trainable_output = None
        num_new_trainable_outputs = 0
        for output in self._func_graph.outputs[len(outputs):]:
          if backprop_util.IsTrainable(output):
            num_new_trainable_outputs += 1
            new_op_with_trainable_output = output.op
        logging.warning(
            ("Determining side outputs for the function '{}' is taking longer "
             "than expected ({} iterations, typically this converges in 5 or "
             "so). This could indicate that a gradient registration is adding "
             "new ops to the forward pass every time gradients are generated. "
             "{} new trainable output(s) were added this iteration, one from "
             "the following op:\n {}\nThis may indicate a TensorFlow bug, or "
             "an issue in a tf.custom_gradient.")
            .format(
                self._func_graph.name, iteration_count,
                num_new_trainable_outputs, new_op_with_trainable_output))
      outputs = list(self._func_graph.outputs)
      self._build_functions_for_outputs(
          outputs, inference_args, input_tangents)

    (forward_function, forward_graph,
     backward_function, output_indices, num_output_tangents) = (
         self._build_functions_for_outputs(
             outputs, inference_args, input_tangents))
    if (len(self._func_graph.outputs) > len(outputs)
        and any(backprop_util.IsTrainable(output)
                for output in self._func_graph.outputs[len(outputs):])):
      raise AssertionError(
          ("Unexpectedly added new outputs to the forward function when "
           "building the backward function: {}").format(
               self._func_graph.outputs[len(outputs):]))
    return (forward_function, forward_graph, backward_function, output_indices,
            num_output_tangents)
This could indicate that a gradient registration is adding " 1424 "new ops to the forward pass every time gradients are generated. " 1425 "{} new trainable output(s) were added this iteration, one from " 1426 "the following op:\n {}\nThis may indicate a TensorFlow bug, or " 1427 "an issue in a tf.custom_gradient.") 1428 .format( 1429 self._func_graph.name, iteration_count, 1430 num_new_trainable_outputs, new_op_with_trainable_output)) 1431 outputs = list(self._func_graph.outputs) 1432 self._build_functions_for_outputs( 1433 outputs, inference_args, input_tangents) 1434 1435 (forward_function, forward_graph, 1436 backward_function, output_indices, num_output_tangents) = ( 1437 self._build_functions_for_outputs( 1438 outputs, inference_args, input_tangents)) 1439 if (len(self._func_graph.outputs) > len(outputs) 1440 and any(backprop_util.IsTrainable(output) 1441 for output in self._func_graph.outputs[len(outputs):])): 1442 raise AssertionError( 1443 ("Unexpectedly added new outputs to the forward function when " 1444 "building the backward function: {}").format( 1445 self._func_graph.outputs[len(outputs):])) 1446 return (forward_function, forward_graph, backward_function, output_indices, 1447 num_output_tangents) 1448 1449 1450class _ForwardBackwardCall(object): 1451 """Holds the state of a function call between execution and recording.""" 1452 1453 __slots__ = [ 1454 "_functions", "_inference_args", "_input_tangents", "_tape_watching" 1455 ] 1456 1457 def __init__(self, functions, inference_args, input_tangents, tape_watching): 1458 """Collects information about the function call. 1459 1460 Args: 1461 functions: An object which produces forward and backward functions, either 1462 a _DelayedRewriteGradientFunctions or a _TapeGradientFunctions object. 1463 inference_args: A flat list of Tensors, arguments to the inference 1464 function. 1465 input_tangents: A flat list of Tensors, jvps associated with 1466 `inference_args`. 1467 tape_watching: Boolean, with True indicating that recording is necessary. 1468 """ 1469 self._functions = functions 1470 self._inference_args = inference_args 1471 self._input_tangents = input_tangents 1472 self._tape_watching = tape_watching 1473 1474 def forward(self): 1475 """Builds or retrieves a forward function for this call.""" 1476 forward_function = self._functions.forward( 1477 self._inference_args, self._input_tangents) 1478 return forward_function, self._inference_args + self._input_tangents 1479 1480 def record(self, flat_outputs): 1481 """Given outputs from the execution of `forward`, records the operation.""" 1482 if (self._tape_watching 1483 and not isinstance(flat_outputs, ops.Operation) 1484 and flat_outputs is not None): 1485 # We only record function calls which have outputs, and then only when a 1486 # tape is watching. 1487 self._functions.record( 1488 flat_outputs, self._inference_args, self._input_tangents) 1489 1490 1491# Sentinel value used by with ConcreteFunction's structured signature to 1492# indicate that a non-tensor parameter should use the value that was 1493# specified when the concrete function was created. 1494_BOUND_VALUE = object() 1495 1496 1497class ConcreteFunction(object): 1498 """Callable object encapsulating a function definition and its gradient. 1499 1500 `ConcreteFunction` is a callable that encapsulates a function definition and 1501 is differentiable under `tf.GradientTape` objects. 
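
  For example, a `ConcreteFunction` is usually obtained from a `tf.function`
  and then called like the original Python function (a rough sketch; `f` and
  `x` below are hypothetical):

    @tf.function
    def f(x):
      return x * x

    cf = f.get_concrete_function(tf.TensorSpec(shape=None, dtype=tf.float32))
    x = tf.constant(3.0)
    with tf.GradientTape() as tape:
      tape.watch(x)
      y = cf(x)
    dy_dx = tape.gradient(y, x)  # Gradients flow through the concrete function.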
1502 """ 1503 1504 def __init__(self, 1505 func_graph, 1506 attrs=None, 1507 shared_func_graph=True, 1508 function_spec=None): 1509 """Initialize a `ConcreteFunction`. 1510 1511 Args: 1512 func_graph: An instance of FuncGraph: the function body to wrap. 1513 attrs: (optional) dict mapping names of attributes to their AttrValue 1514 values. Attributes in `attrs` will be included in this function's 1515 definition. 1516 shared_func_graph: If False, the ConcreteFunction takes ownership of 1517 `func_graph` and will break reference cycles when it is deleted. This 1518 makes the FuncGraph inoperable. 1519 function_spec: FunctionSpec for the original function. If not specified, 1520 then this ConcreteFunction may only be called using the flat signature. 1521 1522 Raises: 1523 ValueError: If number of input_placeholders is not equal to the number 1524 of function inputs. 1525 """ 1526 # _arg_keywords and _num_positional_args define the flat signature. They 1527 # are assigned after construction. 1528 self._arg_keywords = None 1529 self._num_positional_args = None 1530 1531 self._func_graph = func_graph 1532 self._captured_inputs = self._func_graph.external_captures 1533 self._captured_closures = self._func_graph.deferred_external_captures 1534 1535 # function_spec defines the structured signature. 1536 self._set_function_spec(function_spec) 1537 1538 if attrs and IMPLEMENTS_ATTRIBUTE_NAME in attrs: 1539 # The alternative is to silently drop "implements" tag 1540 # but it seems likely it would lead to hard to catch bugs. 1541 # Another alternative is to make func_body to preserve the order 1542 # of arguments if variables are present. Yet another option 1543 # is to automatically replace variables as arguments to functions 1544 # to v.read_value() whenever "implements" tag is present 1545 # Anytime we annotate existing function we probably want to wrap 1546 # it with safe read_value for backward compatibility. 1547 has_resource_vars = any(inp.dtype == dtypes.resource 1548 for inp in self.inputs) 1549 1550 assert not any( 1551 (has_resource_vars, self._captured_inputs, self._captured_closures) 1552 ), ('Function {name} has "{attr}={value}" attribute and thus can not ' 1553 "depend on any tensors outside of its signature or modify variables. " 1554 "\n\nNote: variables are always captured and cause function " 1555 "re-tracing for every variable called.\n" 1556 " inputs: {inputs}\n captures: {captured}\n" 1557 " closures: {closures}.\n\n" 1558 "To pass a variable to such function use " 1559 "use variable.read_value().".format( 1560 name=func_graph.name, 1561 attr=IMPLEMENTS_ATTRIBUTE_NAME, 1562 value=attrs[IMPLEMENTS_ATTRIBUTE_NAME], 1563 inputs=self.inputs, 1564 captured=self._captured_inputs, 1565 closures=self._captured_closures)) 1566 self._output_shapes = tuple( 1567 output.shape for output in self._func_graph.outputs) 1568 self._attrs = _parse_func_attrs(attrs or {}) 1569 1570 if shared_func_graph: 1571 self._garbage_collector = None 1572 else: 1573 self._garbage_collector = ConcreteFunctionGarbageCollector(func_graph) 1574 1575 # Pairs of forward and backward functions used for computing gradients. 1576 # 1577 # These each get a reference to the FuncGraph deleter since they use the 1578 # FuncGraph directly. 
1579 self._delayed_rewrite_functions = _DelayedRewriteGradientFunctions( 1580 func_graph, self._attrs, self._garbage_collector) 1581 self._first_order_tape_functions = {} 1582 self._higher_order_tape_functions = {} 1583 # Cache the inference function to avoid a (Python) function call when not 1584 # building gradients. 1585 self._inference_function = self._delayed_rewrite_functions.forward() 1586 1587 def _set_function_spec(self, function_spec): 1588 """Enables the structured signature by supplying a function_spec.""" 1589 self._function_spec = None 1590 self._pre_initialized_function_spec = function_spec 1591 1592 # Note: when ConcreteFunctions are built by recreate_function() in 1593 # function_deserialization.py, they don't have a structured_input_signature 1594 # yet. In that case, _initialize_function_spec() gets called by 1595 # _setup_functions_structures() in load.py. 1596 if (function_spec is not None and 1597 self.structured_input_signature is not None): 1598 self._initialize_function_spec() 1599 1600 def _initialize_function_spec(self): 1601 """Updates `self._function_spec` to include varargs and bound variables. 1602 1603 Adds new positional arguments for any varargs (i.e., for args that are 1604 in `structured_input_signature`, but not in the original fullargspec.args). 1605 1606 Replaces `defaults` and `kwonlydefaults` with the `_BOUND_VALUE`, for 1607 all args and kwargs in `structured_input_signature`. 1608 1609 Sets `varkw` and `varargs` to None. 1610 """ 1611 if self._pre_initialized_function_spec is None: 1612 return # e.g., SavedBareConcreteFunction doesn't have function_spec yet. 1613 assert not self._function_spec, "already initialized" 1614 function_spec = self._pre_initialized_function_spec 1615 args = function_spec.fullargspec.args 1616 arg_specs, kwarg_specs = self.structured_input_signature 1617 vararg_indices = range(len(function_spec.arg_names), len(arg_specs)) 1618 fullargspec = tf_inspect.FullArgSpec( 1619 args=list(args) + ["<arg{}>".format(i + 1) for i in vararg_indices], 1620 varargs=None, 1621 varkw=None, 1622 defaults=[_BOUND_VALUE] * len(arg_specs), 1623 kwonlyargs=list(sorted(kwarg_specs)), 1624 kwonlydefaults=dict((k, _BOUND_VALUE) for k in kwarg_specs), 1625 annotations=function_spec.fullargspec.annotations) 1626 self._function_spec = FunctionSpec( 1627 fullargspec, 1628 function_spec.is_method, 1629 function_spec.input_signature, 1630 function_spec.is_pure, 1631 name=self._func_graph.name) 1632 1633 @property 1634 def variables(self): 1635 """Sequence of variables for this function.""" 1636 return tuple(self._func_graph.variables) 1637 1638 @property 1639 def trainable_variables(self): 1640 """Sequence of trainable variables for this function.""" 1641 return tuple(self._func_graph.trainable_variables) 1642 1643 def __call__(self, *args, **kwargs): 1644 """Executes the wrapped function. 1645 1646 ConcreteFunctions have two signatures: 1647 1648 * The signature of the original function wrapped by this ConcreteFunction. 1649 * A flat signature, where each argument accepts a single Tensor. 1650 1651 The original function signature is generally preferred, but the flat input 1652 signature is supported for backward compatibility. 1653 1654 ### Original Function Signature 1655 1656 When calling a ConcreteFunction with the signature of the original function, 1657 each argument must match the type or value that was used when the 1658 ConcreteFunction's graph was traced. 
In particular: 1659 1660 * Tensor arguments (including CompositeTensors, such as RaggedTensor) must 1661 have matching `TypeSpec`s. 1662 * Non-Tensor arguments (such as booleans or ints) must have equal values. 1663 * Nested arguments (such as lists, tuples, or dictionaries) must have the 1664 same nesting structure; and each nested value must have a matching type 1665 or value. 1666 1667 The default value for any arguments that were traced with non-Tensor values 1668 is the value that was used in the trace. Arguments that were traced with 1669 tensor arguments do not have a default value (even if the original function 1670 had a default value for that argument). 1671 1672 ### Flat Signature 1673 1674 When calling a ConcreteFunction with the flat signature, the arguments 1675 correspond to the flattened component tensors of the arguments that were 1676 used to construct the ConcreteFunction. Parameter names are assigned based 1677 on `TensorSpec.name` (when specified) or the original argument names (with 1678 suffixes automatically added for nested arguments or composite tensors with 1679 multiple components). 1680 1681 Args: 1682 *args: Positional arguments to the concrete function. 1683 **kwargs: Keyword arguments to the concrete function. 1684 1685 Returns: 1686 The result of applying the TF function on the given Tensors. 1687 1688 Raises: 1689 AssertionError: If this `ConcreteFunction` was not created through 1690 `get_concrete_function`. 1691 TypeError: If the arguments do not match the function's signature. 1692 """ 1693 return self._call_impl(args, kwargs) 1694 1695 def _call_impl(self, args, kwargs, cancellation_manager=None): 1696 """See `__call__` for details.""" 1697 with trace.Trace(self._func_graph.name, tf_function_call="concrete"): 1698 # Construct the list of input tensors: check if the structured signature 1699 # applies first; and if not, then use the flat signature. 1700 if self._function_spec is not None: 1701 try: 1702 return self._call_with_structured_signature(args, kwargs, 1703 cancellation_manager) 1704 except TypeError as structured_err: 1705 try: 1706 return self._call_with_flat_signature(args, kwargs, 1707 cancellation_manager) 1708 except TypeError: 1709 raise structured_err 1710 1711 return self._call_with_flat_signature(args, kwargs, cancellation_manager) 1712 1713 def _call_with_flat_signature(self, args, kwargs, cancellation_manager): 1714 """Executes the wrapped function with the flat signature. 1715 1716 Args: 1717 args: Positional arguments to the concrete function. 1718 kwargs: Keyword arguments to the concrete function. 1719 cancellation_manager: A `CancellationManager` that can be used to cancel 1720 function invocation. 1721 1722 Returns: 1723 The result of applying the function on the Tensors/Variables contained in 1724 `args` and `kwargs`. 1725 Raises: 1726 TypeError: if `args` and `kwargs` do not match the flat signature of this 1727 `ConcreteFunction`. 
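
    For example, with `_arg_keywords = ["x", "y"]` and
    `_num_positional_args = 2` (hypothetical values), these calls resolve to
    the same flat argument list `[x_tensor, y_tensor]` when they are handled
    by this method:

      cf(x_tensor, y_tensor)
      cf(x_tensor, y=y_tensor)  # "y" is looked up in `kwargs` by keyword.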
1728 """ 1729 if len(args) > self._num_positional_args: 1730 raise TypeError( 1731 "{} takes {} positional arguments but {} were given".format( 1732 self._flat_signature_summary(), self._num_positional_args, 1733 len(args))) 1734 args = list(args) 1735 kwargs = dict(kwargs) 1736 for keyword in self._arg_keywords[len(args):]: 1737 try: 1738 args.append(kwargs.pop(compat.as_str(keyword))) 1739 except KeyError: 1740 specified_keywords = ( 1741 list(self._arg_keywords[:len(args)]) + list(kwargs.keys())) 1742 raise TypeError("{} missing required arguments: {}".format( 1743 self._flat_signature_summary(), ", ".join( 1744 sorted(set(self._arg_keywords) - set(specified_keywords))))) 1745 if kwargs: 1746 positional_arg_keywords = set(self._arg_keywords[:len(args)]) 1747 for unused_key in kwargs: 1748 if unused_key in positional_arg_keywords: 1749 raise TypeError("{} got two values for argument '{}'".format( 1750 self._flat_signature_summary(), unused_key)) 1751 raise TypeError("{} got unexpected keyword arguments: {}.".format( 1752 self._flat_signature_summary(), ", ".join(sorted(kwargs)))) 1753 1754 for i, arg in enumerate(args): 1755 if not isinstance( 1756 arg, (ops.Tensor, resource_variable_ops.BaseResourceVariable)): 1757 raise TypeError("{}: expected argument #{}(zero-based) to be a Tensor; " 1758 "got {} ({})".format(self._flat_signature_summary(), i, 1759 type(arg).__name__, str(arg))) 1760 return self._call_flat(args, self.captured_inputs, cancellation_manager) 1761 1762 def _call_with_structured_signature(self, args, kwargs, cancellation_manager): 1763 """Executes the wrapped function with the structured signature. 1764 1765 Args: 1766 args: Positional arguments to the concrete function. 1767 kwargs: Keyword arguments to the concrete function. 1768 cancellation_manager: A `CancellationManager` that can be used to cancel 1769 function invocation. 1770 1771 Returns: 1772 The result of applying the function on the Tensors/Variables contained in 1773 `args` and `kwargs`. 1774 Raises: 1775 TypeError: if `args` and `kwargs` do not match the structured signature 1776 of this `ConcreteFunction`. 
1777 """ 1778 args, kwargs, _, filtered_flat_args = \ 1779 self._function_spec.canonicalize_function_inputs(*args, **kwargs) 1780 self._structured_signature_check_missing_args(args, kwargs) 1781 self._structured_signature_check_unexpected_args(args, kwargs) 1782 self._structured_signature_check_arg_types(args, kwargs) 1783 return self._call_flat( 1784 filtered_flat_args, 1785 captured_inputs=self.captured_inputs, 1786 cancellation_manager=cancellation_manager) 1787 1788 def _structured_signature_check_missing_args(self, args, kwargs): 1789 """Raises a TypeError if any args are missing.""" 1790 arg_specs, kwarg_specs = self.structured_input_signature 1791 missing_arguments = [] 1792 for i, (arg, spec) in enumerate(zip(args, arg_specs)): 1793 if arg is _BOUND_VALUE and _contains_type_spec(spec): 1794 missing_arguments.append(self._function_spec.arg_names[i]) 1795 for (name, arg) in kwargs.items(): 1796 if arg is _BOUND_VALUE and _contains_type_spec(kwarg_specs[name]): 1797 missing_arguments.append(name) 1798 if missing_arguments: 1799 raise TypeError("{} missing required arguments: {}".format( 1800 self._structured_signature_summary(), 1801 ", ".join(sorted(missing_arguments)))) 1802 1803 def _structured_signature_check_unexpected_args(self, args, kwargs): 1804 """Raises a TypeError if there are any extra args.""" 1805 arg_specs, kwarg_specs = self.structured_input_signature 1806 if len(args) > len(arg_specs): 1807 raise TypeError( 1808 "{} takes {} positional arguments but {} were given".format( 1809 self._structured_signature_summary(), 1810 len(self._function_spec.arg_names), len(args))) 1811 if len(kwargs) > len(kwarg_specs): 1812 extra_args = set(kwargs) - set(kwarg_specs) 1813 raise TypeError("{} got unexpected keyword arguments: {}".format( 1814 self._structured_signature_summary(), ", ".join(extra_args))) 1815 1816 def _structured_signature_check_arg_types(self, args, kwargs): 1817 """Raises a TypeError if any args have the wrong type.""" 1818 # Check argument types 1819 arg_specs, kwarg_specs = self.structured_input_signature 1820 for i, (arg, spec) in enumerate(zip(args, arg_specs)): 1821 name = self._function_spec.arg_names[i] 1822 self._structured_signature_check_arg_type(arg, spec, name) 1823 for (name, arg) in kwargs.items(): 1824 self._structured_signature_check_arg_type(arg, kwarg_specs[name], name) 1825 1826 def _structured_signature_check_arg_type(self, arg, spec, name): 1827 """Raise TypeError if `arg`'s type doesn't match `spec`.""" 1828 if arg is _BOUND_VALUE: 1829 return 1830 1831 # Check the overall nested structure of the argument. 1832 try: 1833 nest.assert_same_structure(arg, spec, expand_composites=True) 1834 except (ValueError, TypeError): 1835 try: 1836 nest.assert_same_structure(arg, spec, expand_composites=False) 1837 expected, got = spec, arg 1838 except (ValueError, TypeError): 1839 expected, got = _structure_summary(spec), _structure_summary(arg) 1840 raise TypeError("{}: argument {} had incorrect type\n" 1841 " expected: {}\n got: {}".format( 1842 self._structured_signature_summary(), name, expected, 1843 got)) 1844 1845 # Check the type for each leaf in the nested structure. 1846 arg_pieces = nest.flatten(arg, expand_composites=True) 1847 spec_pieces = nest.flatten(spec, expand_composites=True) 1848 for (arg_piece, spec_piece) in zip(arg_pieces, spec_pieces): 1849 if isinstance(spec_piece, tensor_spec.DenseSpec): 1850 # TODO(edloper): Consider calling convert_to_tensor on non-tensor 1851 # values here. 
That would match the behavior of 1852 # _call_concrete_function() in function_deserialization.py. If 1853 # we do, then we need to change the nest assert_same_structure and 1854 # flatten calls above to use shallow variants. 1855 tensor_types = (ops.Tensor, resource_variable_ops.BaseResourceVariable) 1856 if not isinstance(arg_piece, tensor_types): 1857 raise TypeError( 1858 "{} expected a Tensor in {}, but got {} value {}".format( 1859 self._structured_signature_summary(), name, 1860 type(arg_piece).__name__, arg_piece)) 1861 elif arg_piece is not _BOUND_VALUE and arg_piece != spec_piece: 1862 raise TypeError("ConcreteFunction {} was constructed with {} value " 1863 "{} in {}, but was called with {} value {}".format( 1864 self._structured_signature_summary(), 1865 type(spec_piece).__name__, spec_piece, name, 1866 type(arg_piece).__name__, arg_piece)) 1867 1868 def _call_flat(self, args, captured_inputs, cancellation_manager=None): 1869 """Executes the wrapped function. 1870 1871 Args: 1872 args: a list of Tensors or Variables. Arguments from the Python function 1873 should be filtered before calling this method: objects aside from 1874 Tensors, CompositeTensors, and Variables are ignored. Any 1875 CompositeTensors should be expanded before calling this method. 1876 captured_inputs: the captured inputs that are also part of the input args 1877 to the actual execution. By default, it should be self._captured_inputs. 1878 cancellation_manager: (Optional.) A `CancellationManager` that can be 1879 used to cancel function invocation. 1880 1881 Returns: 1882 The result of applying the TF function to `args`. 1883 1884 Raises: 1885 ValueError: If `args` contains anything other than Tensors or Variables. 1886 """ 1887 ctx = context.context() 1888 executing_eagerly = ctx.executing_eagerly() 1889 1890 # Copy saveable status of function's graph to current FuncGraph. 1891 default_graph = ops.get_default_graph() 1892 if default_graph.building_function and not self._func_graph.saveable: 1893 default_graph.mark_as_unsaveable(self._func_graph.saving_errors) 1894 1895 if (tape.could_possibly_record() or 1896 hasattr(default_graph, "watch_variable")): 1897 for v in self._func_graph.variables: 1898 resource_variable_ops.variable_accessed(v) 1899 1900 tensor_inputs = [] 1901 variables_used = set([]) 1902 for i, arg in enumerate(args): 1903 if isinstance(arg, resource_variable_ops.BaseResourceVariable): 1904 # We can pass a variable more than once, and in this case we need to 1905 # pass its handle only once. 1906 if id(arg.handle) in variables_used: 1907 continue 1908 resource_variable_ops.variable_accessed(arg) 1909 tensor_inputs.append(arg.handle) 1910 variables_used.add(id(arg.handle)) 1911 elif isinstance(arg, ops.Tensor): 1912 tensor_inputs.append(arg) 1913 if not executing_eagerly: 1914 # If we're graph building, shape inference is on. We check for input 1915 # compatibility up front to avoid hard to debug incompatibilities 1916 # later. 1917 graph_input_shape = tensor_shape.TensorShape( 1918 self._func_graph.inputs[i].shape) 1919 if not graph_input_shape.is_compatible_with(arg.shape): 1920 if self._arg_keywords: 1921 arg_name = "'{}'".format(self._arg_keywords[i]) 1922 else: 1923 arg_name = "with index {}".format(i) 1924 raise ValueError( 1925 ("The argument {} (value {}) is not compatible with the shape " 1926 "this function was traced with. Expected shape {}, but got " 1927 "shape {}.\n\nIf you called get_concrete_function, you may " 1928 "need to pass a tf.TensorSpec(..., shape=...) 
with a less " 1929 "specific shape, having None on axes which can vary.").format( 1930 arg_name, arg, 1931 self._func_graph.inputs[i].shape, 1932 arg.shape)) 1933 else: 1934 raise ValueError("All inputs to `ConcreteFunction`s must be Tensors; " 1935 "on invocation of %s, the %d-th input (%s) was not a " 1936 "Tensor." % (self._func_graph.name, i, str(arg))) 1937 args = tensor_inputs + captured_inputs 1938 possible_gradient_type = gradients_util.PossibleTapeGradientTypes(args) 1939 if (possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_NONE 1940 and executing_eagerly): 1941 # No tape is watching; skip to running the function. 1942 return self._build_call_outputs(self._inference_function.call( 1943 ctx, args, cancellation_manager=cancellation_manager)) 1944 forward_backward = self._select_forward_and_backward_functions( 1945 args, 1946 possible_gradient_type, 1947 executing_eagerly) 1948 forward_function, args_with_tangents = forward_backward.forward() 1949 if executing_eagerly: 1950 flat_outputs = forward_function.call( 1951 ctx, args_with_tangents, cancellation_manager=cancellation_manager) 1952 else: 1953 with default_graph._override_gradient_function( # pylint: disable=protected-access 1954 {"PartitionedCall": self._get_gradient_function(), 1955 "StatefulPartitionedCall": self._get_gradient_function()}): 1956 flat_outputs = forward_function.call(ctx, args_with_tangents) 1957 forward_backward.record(flat_outputs) 1958 return self._build_call_outputs(flat_outputs) 1959 1960 def _experimental_with_cancellation_manager(self, cancellation_manager): 1961 """Returns a callable that invokes a cancellable version of this function. 1962 1963 Args: 1964 cancellation_manager: A `CancellationManager` object that can be used to 1965 cancel function invocation. 1966 1967 Returns: 1968 A callable with the same signature as this concrete function. 1969 """ 1970 1971 def cancellable_call(*args, **kwargs): 1972 return self._call_impl( 1973 args, kwargs, cancellation_manager=cancellation_manager) 1974 1975 return cancellable_call 1976 1977 @property 1978 def name(self): 1979 """`ConcreteFunction` name.""" 1980 return self._delayed_rewrite_functions.forward().name 1981 1982 @property 1983 def graph(self): 1984 """Returns the graph from which this function was constructed.""" 1985 return self._func_graph 1986 1987 @property 1988 def inputs(self): 1989 """Returns tensors in `self.graph` corresponding to arguments.""" 1990 return self._func_graph.inputs 1991 1992 @property 1993 def structured_input_signature(self): 1994 """Returns structured signature for this concrete function. 1995 1996 Returns: 1997 A tuple `(args, kwargs)`, where: 1998 1999 * `args` is a tuple that specifies the expected type or value each for 2000 positional argument. 2001 * `kwargs` is a dictionary that specifies the expected type or value 2002 for each keyword-only argument. 2003 2004 The type or value for each argument is specified using one of the 2005 following: 2006 2007 * A `tf.TypeSpec`, indicating that a Tensor or other TensorFlow-native 2008 value is expected. 2009 * A Python value, such as an integer, indicating that an equal value 2010 is expected. 2011 * A nested structure of `tf.TypeSpec`s and Python values, indicating 2012 that a corresponding nested structure is expected. 
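
    For example, for a function traced as `f(x, n=3)` with `x` a float32
    vector, this property might look roughly like:

      ((TensorSpec(shape=(None,), dtype=tf.float32, name='x'), 3), {})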
2013 """ 2014 return self._func_graph.structured_input_signature 2015 2016 @property 2017 def outputs(self): 2018 """Returns tensors in `self.graph` corresponding to returned tensors.""" 2019 return self._func_graph.outputs 2020 2021 @property 2022 def structured_outputs(self): 2023 """Returns outputs in `self.graph` as returned by the original function.""" 2024 return self._func_graph.structured_outputs 2025 2026 @property 2027 def captured_inputs(self): 2028 """Returns external Tensors captured by this function. 2029 2030 self.__call__(*args) passes `args + self.captured_inputs` to the function. 2031 """ 2032 from_closures = nest.flatten([x() for x in self._captured_closures], 2033 expand_composites=True) 2034 return self._captured_inputs + from_closures 2035 2036 @property 2037 def function_def(self): 2038 """Returns a `FunctionDef` object representing this function.""" 2039 return self._delayed_rewrite_functions.forward().definition 2040 2041 @property 2042 def output_shapes(self): 2043 """The function's output shapes.""" 2044 return nest.map_structure( 2045 lambda x: getattr(x, "shape", tensor_shape.TensorShape(None)), 2046 composite_tensor.replace_composites_with_components( 2047 self._func_graph.structured_outputs), 2048 expand_composites=False) 2049 2050 @property 2051 def output_dtypes(self): 2052 # TODO(akshayka): Consider removing this. 2053 return nest.map_structure( 2054 lambda x: x.dtype if x is not None else None, 2055 composite_tensor.replace_composites_with_components( 2056 self._func_graph.structured_outputs), 2057 expand_composites=False) 2058 2059 def add_to_graph(self, g=None): 2060 """Registers the function, adds it to the graph g or default graph. 2061 2062 Args: 2063 g: If specified, registers the function with this graph. Defaults to the 2064 current context (either the default graph or the eager context). 2065 """ 2066 # If we are not executing eagerly, adds the function to default graph if no 2067 # graph is specified. 2068 # In case of eager execution, function definition gets added to context 2069 # during construction itself. 2070 2071 if not context.executing_eagerly() and not g: 2072 g = ops.get_default_graph() 2073 self._delayed_rewrite_functions.forward().add_to_graph(g) 2074 2075 def add_gradient_functions_to_graph(self, g=None): 2076 """Add forward/backward functions to graph `g` or the current context.""" 2077 if not context.executing_eagerly() and not g: 2078 g = ops.get_default_graph() 2079 self._delayed_rewrite_functions.forward().add_to_graph(g) 2080 forward_function, backward_function = ( 2081 self._delayed_rewrite_functions.forward_backward()) 2082 forward_function.add_to_graph(g) 2083 backward_function.add_to_graph(g) 2084 2085 def _get_gradient_function(self): 2086 """Returns gradient function. It will be lazily created at first call.""" 2087 return self._delayed_rewrite_functions._rewrite_forward_and_call_backward # pylint: disable=protected-access 2088 2089 def _select_forward_and_backward_functions( 2090 self, args, possible_gradient_type, executing_eagerly): 2091 """Selects forward and backward functions based on the calling context. 2092 2093 The forward function computes the "real" function outputs, `self._outputs`, 2094 and any extra values needed by the corresponding backward function. 2095 2096 Args: 2097 args: A flat list of Tensors with all of the inputs to the forward 2098 function (including user-specified and captured inputs). 2099 possible_gradient_type: One of gradients_util.POSSIBLE_GRADIENT_TYPES_*. 
2100 executing_eagerly: Boolean, the value of context.executing_eagerly(). 2101 2102 Returns: 2103 An object with a `forward` method returning a tuple of (forward_function : 2104 _EagerDefinedFunction, augmented_arguments : List), and a corresponding 2105 `record` method which takes outputs from the forward function and records 2106 the operation. forward_function should be called with augmented_arguments. 2107 """ 2108 if executing_eagerly: 2109 input_tangents = forwardprop_util.pack_tangents(args) 2110 else: 2111 input_tangents = forwardprop_util.TangentInfo() 2112 need_gradients_for_jvps = tape.should_record_backprop( 2113 input_tangents.tangents) 2114 # Allows re-use of forward and backward function pairs depending on the 2115 # tapes and forward accumulators watching its inputs. 2116 cache_key = (need_gradients_for_jvps, input_tangents.indices) 2117 if (possible_gradient_type 2118 == gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER): 2119 if input_tangents.indices or executing_eagerly: 2120 # There is a single non-persistent tape active, so the user can only 2121 # request first-order gradients from a tape. We can spend less time 2122 # graph building since we know this. 2123 # 2124 # We may still end up computing higher-order gradients, but that'd be 2125 # through `tf.gradients`, which can re-write the forward pass and so 2126 # needs no preparation here. 2127 functions = self._first_order_tape_functions.get(cache_key, None) 2128 if functions is None: 2129 functions = _FirstOrderTapeGradientFunctions( 2130 self._func_graph, self._attrs, self._garbage_collector, 2131 forwardprop_input_indices=input_tangents.indices, 2132 delayed_rewrite_functions=self._delayed_rewrite_functions, 2133 need_gradients_for_jvps=need_gradients_for_jvps) 2134 self._first_order_tape_functions[cache_key] = functions 2135 return _ForwardBackwardCall( 2136 functions, args, input_tangents.tangents, tape_watching=True) 2137 else: 2138 # We can avoid computing second-order gradients in some cases by doing a 2139 # delayed rewrite when graph building. Since we know we'll only compute 2140 # first-order tape gradients, the delayed rewrite is safe: we won't need 2141 # to tell the tape about side outputs. 2142 # 2143 # TODO(allenl): This case is really dirty. It would be better if we 2144 # could temporarily pop all of the current tapes to avoid 2145 # accidentally taking second-order gradients. 2146 return _ForwardBackwardCall( 2147 self._delayed_rewrite_functions, args, input_tangents.tangents, 2148 tape_watching=True) 2149 elif (possible_gradient_type 2150 == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER): 2151 # Either there's a persistent tape watching, or there are multiple nested 2152 # tapes. Either way, the user may request higher-order gradients. We'll 2153 # spend a bit more time and make sure higher-order gradients are correct. 2154 functions = self._higher_order_tape_functions.get( 2155 cache_key, None) 2156 if functions is None: 2157 functions = _HigherOrderTapeGradientFunctions( 2158 self._func_graph, self._attrs, self._garbage_collector, 2159 forwardprop_input_indices=input_tangents.indices, 2160 delayed_rewrite_functions=self._delayed_rewrite_functions, 2161 need_gradients_for_jvps=need_gradients_for_jvps) 2162 self._higher_order_tape_functions[cache_key] = functions 2163 return _ForwardBackwardCall(functions, args, input_tangents.tangents, 2164 tape_watching=True) 2165 # else possible_gradient_type == POSSIBLE_GRADIENT_TYPES_NONE, meaning no 2166 # tape is recording. 
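    # No tape is watching, so the delayed-rewrite pair is sufficient; with
    # tape_watching=False the subsequent record() call is a no-op.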
2167 return _ForwardBackwardCall(
2168 self._delayed_rewrite_functions, args, input_tangents.tangents,
2169 tape_watching=False)
2170
2171 def _build_call_outputs(self, result):
2172 """Maps the fdef output list to actual output structure.
2173
2174 Args:
2175 result: Output lists defined by FunctionDef.
2176 Returns:
2177 The actual call output.
2178 """
2179 # TODO(jlchu): call C++ version in function.cc when speed is improved
2180 if self._func_graph.structured_outputs is None:
2181 return result
2182
2183 # Replace outputs with results, skipping over any 'None' values.
2184 outputs_list = nest.flatten(
2185 self._func_graph.structured_outputs, expand_composites=True)
2186 j = 0
2187 for i, o in enumerate(outputs_list):
2188 if o is not None:
2189 custom_gradient.copy_handle_data(self.outputs[j], result[j])
2190 outputs_list[i] = result[j]
2191 j += 1
2192 ret = nest.pack_sequence_as(self._func_graph.structured_outputs,
2193 outputs_list, expand_composites=True)
2194 return ret
2195
2196 @property
2197 def _as_name_attr_list(self):
2198 """Returns a `NameAttrList` representing this function."""
2199 ret = attr_value_pb2.NameAttrList(name=self.name)
2200 for name, value in self._attrs.items():
2201 ret.attr[name].CopyFrom(value)
2202 return ret
2203
2204 def _structured_signature_summary(self, default_values=False):
2205 """Returns a string summarizing this function's structured signature.
2206
2207 Args:
2208 default_values: If true, then include default values in the signature.
2209
2210 Returns:
2211 A `string`.
2212 """
2213 # Note: we can't just use self._function_spec.signature_summary(), because
2214 # that would show "_BOUND_VALUE" as the default value for all arguments.
2215 assert self._function_spec is not None
2216 arg_specs, kwarg_specs = self.structured_input_signature
2217 arg_names = list(self._function_spec.arg_names)
2218
2219 # If an explicit input_signature is provided to @tf.function, then any
2220 # arguments with defaults that are not covered by that explicit signature
2221 # are simply dropped from the signature.
2222 # TODO(b/159639913) Look into whether dropping arguments with default values
2223 # from the signature is the right thing to do.
2224 arg_names = arg_names[:len(arg_specs)] 2225 2226 if default_values: 2227 for i in range(len(arg_names)): 2228 if not _contains_type_spec(arg_specs[i]): 2229 arg_names[i] += "={}".format(arg_specs[i]) 2230 if kwarg_specs: 2231 arg_names.append("*") 2232 for name, spec in kwarg_specs.items(): 2233 arg_names.append(name) 2234 if default_values and not _contains_type_spec(spec): 2235 arg_names[-1] += "={}".format(spec) 2236 signature = "{}({})".format(self._func_graph.name, ", ".join(arg_names)) 2237 2238 return signature 2239 2240 def _flat_signature_summary(self): 2241 """Returns a string summarizing this function's flat signature.""" 2242 assert self._arg_keywords is not None 2243 assert self._num_positional_args is not None 2244 arg_names = self._arg_keywords 2245 if self._num_positional_args > len(arg_names): 2246 arg_names.extend( 2247 "<arg{}>".format(i + 1) 2248 for i in range(len(arg_names), self._num_positional_args)) 2249 return "{}({})".format(self._func_graph.name, ", ".join(arg_names)) 2250 2251 def pretty_printed_signature(self, verbose=True): 2252 """Returns a string summarizing the signature of this concrete function.""" 2253 if not verbose: 2254 return self._structured_signature_summary(default_values=True) 2255 2256 def pretty_print_spec(spec): 2257 """Returns a string describing the spec for a single argument.""" 2258 if isinstance(spec, tensor_spec.TensorSpec): 2259 return "{} Tensor, shape={}".format(spec.dtype.name, spec.shape) 2260 elif nest.is_sequence(spec): 2261 pieces = nest.flatten(spec, expand_composites=False) 2262 markers = [_Marker("<{}>".format(i + 1)) for i in range(len(pieces))] 2263 structure = nest.pack_sequence_as(spec, markers) 2264 # Ensure dictionaries are sorted by key (for determinism) 2265 result = pprint.pformat(structure, width=10000) 2266 for (marker, piece) in zip(markers, pieces): 2267 result += "\n {}: {}".format(marker, pretty_print_spec(piece)) 2268 return result 2269 else: 2270 return repr(spec) 2271 2272 lines = [self._structured_signature_summary(default_values=True)] 2273 arg_specs, kwarg_specs = self.structured_input_signature 2274 names = list(self._function_spec.arg_names) 2275 2276 # If an explicit input_signature is provided to @tf.function, then any 2277 # arguments with defaults that are not covered by that explicit signature 2278 # are simply dropped from the signature. 2279 # TODO(b/159639913) Look into whether dropping arguments with default values 2280 # from the signature is the right thing to do. 2281 names = names[:len(arg_specs)] 2282 2283 names.extend(sorted(kwarg_specs)) 2284 specs = list(arg_specs) + list(kwarg_specs.values()) 2285 # note: we can skip bound args, since we already displayed thier bound 2286 # value in the signature summary. 2287 arg_details = [] 2288 for (name, spec) in zip(names, specs): 2289 if _contains_type_spec(spec): 2290 arg_details.append(" {}: {}".format(name, pretty_print_spec(spec))) 2291 if arg_details: 2292 lines.append(" Args:") 2293 lines.extend(arg_details) 2294 lines.append(" Returns:") 2295 2296 def spec_from_value(value): 2297 # For loaded function, structured_outputs are already specs. 
2298 if isinstance(value, type_spec.TypeSpec): 2299 return value 2300 return type_spec.type_spec_from_value(value) 2301 2302 lines.append(" {}".format( 2303 pretty_print_spec( 2304 nest.map_structure(spec_from_value, self.structured_outputs)))) 2305 2306 return "\n".join(lines) 2307 2308 def __repr__(self): 2309 if self._function_spec is not None: 2310 return "<ConcreteFunction {} at 0x{:X}>".format( 2311 self.pretty_printed_signature(verbose=False), id(self)) 2312 elif not (self._num_positional_args is None or self._arg_keywords is None): 2313 return "<ConcreteFunction {} at 0x{:X}>".format( 2314 self._flat_signature_summary(), id(self)) 2315 else: 2316 return object.__repr__(self) 2317 2318 def __str__(self): 2319 if self._function_spec is not None: 2320 return "ConcreteFunction {}".format(self.pretty_printed_signature()) 2321 else: 2322 return self.__repr__() 2323 2324 2325_pywrap_utils.RegisterType("Tensor", ops.Tensor) 2326_pywrap_utils.RegisterType("EagerTensor", ops.EagerTensor) 2327_pywrap_utils.RegisterType("IndexedSlices", ops.IndexedSlices) 2328 2329 2330def _deterministic_dict_values(dictionary): 2331 return tuple(dictionary[key] for key in sorted(dictionary)) 2332 2333 2334class FunctionSpec(object): 2335 """Specification of how to bind arguments to a function.""" 2336 2337 @staticmethod 2338 def from_function_and_signature(python_function, 2339 input_signature, 2340 is_pure=False, 2341 experimental_follow_type_hints=False, 2342 jit_compile=None): 2343 """Create a FunctionSpec instance given a python function and signature. 2344 2345 Args: 2346 python_function: a function to inspect 2347 input_signature: a signature of the function (None, if variable) 2348 is_pure: if True all input arguments (including variables and constants) 2349 will be converted to tensors and no variable changes allowed. 2350 experimental_follow_type_hints: see `tf.function` 2351 jit_compile: see `tf.function` 2352 2353 Returns: 2354 instance of FunctionSpec 2355 """ 2356 fullargspec = tf_inspect.getfullargspec(python_function) 2357 # Treat a wrapped partial function as a special case. For all arguments that 2358 # were overridden with keywords in the partial: 2359 # - remove the corresponding arguments, 2360 # - remove the corresponding keywords. 2361 _, unwrapped = tf_decorator.unwrap(python_function) 2362 # TODO(b/131153379): Consider Python3's fullargspec.kwonlyargs and 2363 # fullargspec.kwonlydefaults. 2364 if isinstance(unwrapped, functools.partial): 2365 # Also consider the Python3 case with kwonlydefaults. 2366 if fullargspec.defaults or fullargspec.kwonlydefaults: 2367 new_defaults = fullargspec.defaults 2368 new_args = fullargspec.args 2369 if fullargspec.defaults: 2370 # To be able to canonicalize the function properly, we want to ignore 2371 # default values that are overridden via a partial kwarg. For example: 2372 # 2373 # def func(a, b, c, d=5, e=7): 2374 # return a, b, c, d, e 2375 # p_func = functools.partial(tf.function(func, 10, e=9)) 2376 # 2377 # Here we want to drop from the defaults the parameter `e`. If we 2378 # forwarded the call to the partial function with a default for `e` 2379 # we would get an error for passing two values for one parameter. 2380 # 2381 # Note that this has a limitation: we can only override parameters at 2382 # the end of the parameter list. 2383 # 2384 # In this case we want to end up with 3 arguments (b, c, d) and 1 2385 # default value (5). We do this by constructing a mask where 0 stands 2386 # for a value that was overridden by a partial kwarg. 
The seemingly 2387 # complicated logic below does just that - for arguments (b, c, d, e) 2388 # we would get a mask (1, 1, 1, 0). 2389 old_args = fullargspec.args 2390 old_defaults = fullargspec.defaults 2391 2392 no_default = object() 2393 num_args_without_defaults = len(old_args) - len(old_defaults) 2394 left_padding = tuple([no_default] * num_args_without_defaults) 2395 2396 args_with_defaults = zip(old_args, left_padding + old_defaults) 2397 2398 # Create a mask where 0 stands for args that had a partial kwarg 2399 # defined. 2400 non_keyword_defaults_mask = [ 2401 0 if key in unwrapped.keywords else 1 for key in old_args 2402 ] 2403 # Keep only arguments and defaults that were not kwargs of partial. 2404 new_args_with_defaults = list( 2405 itertools.compress(args_with_defaults, non_keyword_defaults_mask)) 2406 # Keep all args. 2407 new_args = [arg for arg, _ in new_args_with_defaults] 2408 # Keep only real default values. 2409 new_defaults = [ 2410 default for _, default in new_args_with_defaults 2411 if default is not no_default 2412 ] 2413 fullargspec = tf_inspect.FullArgSpec( 2414 args=new_args, 2415 varargs=fullargspec.varargs, 2416 varkw=fullargspec.varkw, 2417 defaults=new_defaults, 2418 kwonlyargs=[], 2419 kwonlydefaults={}, 2420 annotations=fullargspec.annotations) 2421 2422 # inspect.ismethod() and inspect.isfunction() both return False on a 2423 # functools.partial-wrapped function. We set it to False to 2424 # maintain consistency with prior versions. 2425 is_method = False 2426 2427 else: 2428 # Instead of using tf_inspect.ismethod() which only checks the 2429 # final unwrapped target, we check if any decorated target along the chain 2430 # is a method. 2431 is_method = tf_inspect.isanytargetmethod(python_function) 2432 2433 # In the following scenario, 'python_function' is a callable object. 2434 # python_function(...) is equal to python_function.__call__(self, ...) 2435 if not is_method and not tf_inspect.isfunction( 2436 python_function) and hasattr( 2437 python_function, "__class__") and hasattr( 2438 python_function.__class__, "__call__"): 2439 is_method = True 2440 2441 # Get the function's name. Remove functools.partial wrappers if necessary. 2442 while isinstance(python_function, functools.partial): 2443 python_function = python_function.func 2444 name = getattr(python_function, "__name__", "f") 2445 2446 return FunctionSpec( 2447 fullargspec, 2448 is_method, 2449 input_signature, 2450 is_pure=is_pure, 2451 jit_compile=jit_compile, 2452 experimental_follow_type_hints=experimental_follow_type_hints, 2453 name=name) 2454 2455 def __init__(self, 2456 fullargspec, 2457 is_method, 2458 input_signature, 2459 is_pure=False, 2460 experimental_follow_type_hints=False, 2461 name=None, 2462 jit_compile=None): 2463 """Constructs a FunctionSpec describing a python function. 2464 2465 Args: 2466 fullargspec: `tf_inspect.FullArgSpec` object describing the function. 2467 is_method: True if the function is a method. 2468 input_signature: a signature of the function (None, if variable) 2469 is_pure: if True all input arguments (including variables and constants) 2470 will be converted to tensors and no variable changes allowed. 2471 experimental_follow_type_hints: see `tf.function`. 2472 name: Name of the function 2473 jit_compile: see `tf.function`. 
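
    For example, `FunctionSpec.from_function_and_signature(lambda x, y=3: x + y,
    input_signature=None)` yields a spec whose `arg_names` are `['x', 'y']` and
    which records 3 as the default value for `y`.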
2474 """ 2475 self._fullargspec = fullargspec 2476 self._is_method = is_method 2477 self._is_pure = is_pure 2478 self._jit_compile = jit_compile 2479 self._experimental_follow_type_hints = experimental_follow_type_hints 2480 2481 # TODO(edloper): Include name when serializing for SavedModel? 2482 self._name = name or "f" 2483 2484 if self._is_method: 2485 # Remove `self`: default arguments shouldn't be matched to it. 2486 # TODO(b/127938157): Should this error out if there is no arg to 2487 # be removed? 2488 args = fullargspec.args[1:] 2489 else: 2490 args = fullargspec.args 2491 2492 # A cache mapping from argument name to index, for canonicalizing 2493 # arguments that are called in a keyword-like fashion. 2494 self._args_to_indices = {arg: i for i, arg in enumerate(args)} 2495 self._arg_names = args 2496 2497 # A cache mapping from arg index to default value, for canonicalization. 2498 default_values = fullargspec.defaults 2499 offset = len(args) - len(default_values or []) 2500 self._arg_indices_to_default_values = { 2501 offset + index: default 2502 for index, default in enumerate(default_values or []) 2503 } 2504 self._arg_indices_no_default_values = set(range(len(args))) - set( 2505 self._arg_indices_to_default_values) 2506 if input_signature is None: 2507 self._input_signature = None 2508 else: 2509 if set(fullargspec.kwonlyargs) - set(fullargspec.kwonlydefaults or ()): 2510 raise ValueError("Cannot define a TensorFlow function from a Python " 2511 "function with keyword-only arguments when " 2512 "input_signature is provided.") 2513 2514 if not isinstance(input_signature, (tuple, list)): 2515 raise TypeError("input_signature must be either a tuple or a " 2516 "list, received " + str(type(input_signature))) 2517 2518 self._input_signature = tuple(input_signature) 2519 self._flat_input_signature = tuple(nest.flatten(input_signature, 2520 expand_composites=True)) 2521 2522 @property 2523 def fullargspec(self): 2524 return self._fullargspec 2525 2526 @property 2527 def is_method(self): 2528 return self._is_method 2529 2530 @property 2531 def args_to_indices(self): 2532 return self._args_to_indices 2533 2534 @property 2535 def kwargs_to_include(self): 2536 return self._kwargs_to_include 2537 2538 @property 2539 def input_signature(self): 2540 return self._input_signature 2541 2542 @property 2543 def flat_input_signature(self): 2544 return self._flat_input_signature 2545 2546 @property 2547 def is_pure(self): 2548 return self._is_pure 2549 2550 @property 2551 def jit_compile(self): 2552 return self._jit_compile 2553 2554 @property 2555 def arg_names(self): 2556 return self._arg_names 2557 2558 @property 2559 def vararg_name(self): 2560 return self._fullargspec.varargs 2561 2562 @property 2563 def varkw_name(self): 2564 return self._fullargspec.varkw 2565 2566 def signature_summary(self, default_values=False): 2567 """Returns a string summarizing this function's signature. 2568 2569 Args: 2570 default_values: If true, then include default values in the signature. 2571 2572 Returns: 2573 A `string`. 
2574 """ 2575 args = list(self._arg_names) 2576 if default_values: 2577 for (i, default) in self._arg_indices_to_default_values.items(): 2578 args[i] += "={}".format(default) 2579 if self._fullargspec.kwonlyargs: 2580 args.append("*") 2581 for arg_name in self._fullargspec.kwonlyargs: 2582 args.append(arg_name) 2583 if default_values and arg_name in self._fullargspec.kwonlydefaults: 2584 args[-1] += "={}".format(self._fullargspec.kwonlydefaults[arg_name]) 2585 return "{}({})".format(self._name, ", ".join(args)) 2586 2587 def _to_tensor_or_tensor_spec(self, x): 2588 return (x if isinstance(x, (ops.Tensor, tensor_spec.TensorSpec)) 2589 else ops.convert_to_tensor(x)) 2590 2591 def _convert_variables_to_tensors(self, args, kwargs): 2592 args = [self._to_tensor_or_tensor_spec(x) for x in args] 2593 kwargs = {kw: self._to_tensor_or_tensor_spec(x) 2594 for kw, x in kwargs.items()} 2595 return tuple(args), kwargs 2596 2597 def _convert_annotated_args_to_tensors(self, args, kwargs): 2598 """Attempts to autobox arguments annotated as tf.Tensor.""" 2599 if self.input_signature is not None: 2600 return 2601 2602 args = list(args) 2603 for i, arg in enumerate(args): 2604 # See 2605 # https://docs.python.org/3/library/inspect.html#inspect.getfullargspec 2606 if i < len(self._fullargspec.args): 2607 annotation_key = self._fullargspec.args[i] 2608 else: 2609 annotation_key = self._fullargspec.varargs 2610 arg_annotation = self._fullargspec.annotations.get(annotation_key, None) 2611 2612 # TODO(rahulkamat): Change to TensorLike (here ans below) 2613 if arg_annotation == ops.Tensor: 2614 args[i] = self._to_tensor_or_tensor_spec(arg) 2615 2616 for kw, v in kwargs.items(): 2617 if kw in self._fullargspec.kwonlyargs or kw in self._fullargspec.args: 2618 annotation_key = kw 2619 else: 2620 annotation_key = self._fullargspec.varkw 2621 kwarg_annotation = self._fullargspec.annotations.get(annotation_key, None) 2622 if kwarg_annotation == ops.Tensor: 2623 kwargs[kw] = self._to_tensor_or_tensor_spec(v) 2624 return tuple(args), kwargs 2625 2626 def canonicalize_function_inputs(self, *args, **kwargs): 2627 """Canonicalizes `args` and `kwargs`. 2628 2629 Canonicalize the inputs to the Python function using a `FunctionSpec` 2630 instance. In particular, we parse the varargs and kwargs that the 2631 original function was called with into a tuple corresponding to the 2632 Python function's positional (named) arguments and a dictionary 2633 corresponding to its kwargs. Missing default arguments are added. 2634 2635 If this `FunctionSpec` has an input signature, then it is used to convert 2636 arguments to tensors; otherwise, any inputs containing numpy arrays are 2637 converted to tensors. 2638 2639 Additionally, any inputs containing numpy arrays are converted to Tensors. 2640 2641 Args: 2642 *args: The varargs this object was called with. 2643 **kwargs: The keyword args this function was called with. 2644 2645 Returns: 2646 A canonicalized ordering of the inputs, as well as full and filtered 2647 (Tensors and Variables only) versions of their concatenated flattened 2648 representations, represented by a tuple in the form (args, kwargs, 2649 flat_args, filtered_flat_args). Here: `args` is a full list of bound 2650 arguments, and `kwargs` contains only true keyword arguments, as opposed 2651 to named arguments called in a keyword-like fashion. 
2652 2653 Raises: 2654 ValueError: If a keyword in `kwargs` cannot be matched with a positional 2655 argument when an input signature is specified, or when the inputs 2656 do not conform to the input signature. 2657 """ 2658 if self._is_pure: 2659 args, kwargs = self._convert_variables_to_tensors(args, kwargs) 2660 if self._experimental_follow_type_hints: 2661 args, kwargs = self._convert_annotated_args_to_tensors(args, kwargs) 2662 # Pre-calculate to reduce overhead 2663 arglen = len(args) 2664 if self._input_signature is not None: 2665 if arglen > len(self._input_signature): 2666 raise TypeError("{} takes {} positional arguments (as specified by the " 2667 "input_signature) but {} were given".format( 2668 self.signature_summary(), 2669 len(self._input_signature), arglen)) 2670 for arg in six.iterkeys(kwargs): 2671 index = self._args_to_indices.get(arg, None) 2672 if index is None: 2673 raise TypeError("{} got unexpected keyword argument `{}`".format( 2674 self.signature_summary(), arg)) 2675 if index >= len(self._input_signature): 2676 raise TypeError( 2677 "{} got keyword argument `{}` that was not included in " 2678 "input_signature".format(self.signature_summary(), arg)) 2679 2680 if not kwargs: 2681 inputs = args 2682 if self._arg_indices_to_default_values: 2683 try: 2684 inputs += tuple(self._arg_indices_to_default_values[i] 2685 for i in range(arglen, len(self._arg_names))) 2686 except KeyError: 2687 missing_args = [ 2688 self._arg_names[i] 2689 for i in range(arglen, len(self._arg_names)) 2690 if i not in self._arg_indices_to_default_values 2691 ] 2692 raise TypeError("{} missing required arguments: {}".format( 2693 self.signature_summary(), ", ".join(missing_args))) 2694 2695 if self._fullargspec.kwonlydefaults: 2696 kwargs.update(self._fullargspec.kwonlydefaults) 2697 else: 2698 # Maps from index of arg to its corresponding value, according to `args` 2699 # and `kwargs`; seeded with the default values for the named args that 2700 # aren't in `args`. 2701 arg_indices_to_values = { 2702 index: default for index, default in six.iteritems( 2703 self._arg_indices_to_default_values) if index >= arglen 2704 } 2705 consumed_args = [] 2706 missing_arg_indices = self._arg_indices_no_default_values - set( 2707 range(arglen)) 2708 for arg, value in six.iteritems(kwargs): 2709 index = self._args_to_indices.get(arg, None) 2710 if index is not None: 2711 if index < arglen: 2712 raise TypeError("{} got two values for argument '{}'".format( 2713 self.signature_summary(), arg)) 2714 arg_indices_to_values[index] = value 2715 # These arguments in 'kwargs' might also belong to 2716 # positional arguments 2717 missing_arg_indices.discard(index) 2718 consumed_args.append(arg) 2719 for arg in consumed_args: 2720 # After this loop, `kwargs` will only contain keyword_only arguments, 2721 # and all positional_or_keyword arguments have been moved to `inputs`. 
2722 kwargs.pop(arg) 2723 inputs = args + _deterministic_dict_values(arg_indices_to_values) 2724 # Exclude positional args with values 2725 if missing_arg_indices: 2726 missing_args = [self._arg_names[i] for i in sorted(missing_arg_indices)] 2727 if len(missing_args) == 1: 2728 raise TypeError("{} missing 1 required argument: {}".format( 2729 self.signature_summary(), missing_args[0])) 2730 else: 2731 raise TypeError("{} missing required arguments: {}".format( 2732 self.signature_summary(), ", ".join(missing_args))) 2733 2734 if kwargs and self._input_signature is not None: 2735 raise TypeError( 2736 "{} got unexpected keyword arguments: {}\n(Cannot define a " 2737 "TensorFlow function from a Python function with keyword arguments " 2738 "when input_signature is provided.)".format( 2739 self.signature_summary(), ", ".join(kwargs))) 2740 2741 if self._fullargspec.kwonlydefaults: 2742 for (kwarg, default) in self._fullargspec.kwonlydefaults.items(): 2743 kwargs.setdefault(kwarg, default) 2744 2745 if self._input_signature is None: 2746 inputs, flat_inputs, filtered_flat_inputs = _convert_numpy_inputs(inputs) 2747 kwargs, flat_kwargs, filtered_flat_kwargs = _convert_numpy_inputs(kwargs) 2748 return (inputs, kwargs, flat_inputs + flat_kwargs, 2749 filtered_flat_inputs + filtered_flat_kwargs) 2750 else: 2751 assert not kwargs 2752 inputs, flat_inputs, filtered_flat_inputs = _convert_inputs_to_signature( 2753 inputs, self._input_signature, self._flat_input_signature) 2754 return inputs, {}, flat_inputs, filtered_flat_inputs 2755 2756 2757def _as_ndarray(value): 2758 """Converts value to an ndarray, assumes _is_ndarray(value).""" 2759 # TODO(tomhennigan) Support __array_interface__ too. 2760 return value.__array__() 2761 2762 2763def _is_ndarray(value): 2764 """Tests whether the given value is an ndarray (and not a TF tensor/var).""" 2765 # TODO(tomhennigan) Support __array_interface__ too. 2766 return hasattr(value, "__array__") and not ( 2767 isinstance(value, ops.Tensor) 2768 or isinstance(value, resource_variable_ops.BaseResourceVariable) 2769 or hasattr(value, "_should_act_as_resource_variable") 2770 2771 # For legacy reasons we do not automatically promote Numpy strings. 2772 or isinstance(value, np.str_) 2773 # NumPy dtypes have __array__ as unbound methods. 2774 or isinstance(value, type) 2775 # CompositeTensors should be flattened instead. 2776 or isinstance(value, composite_tensor.CompositeTensor)) 2777 2778 2779def _convert_numpy_inputs(inputs): 2780 """Convert numpy array inputs to tensors.""" 2781 # We assume that any CompositeTensors have already converted their components 2782 # from numpy arrays to Tensors, so we don't need to expand composites here for 2783 # the numpy array conversion. Instead, we do so because the flattened inputs 2784 # are eventually passed to ConcreteFunction()._call_flat, which requires 2785 # expanded composites. 2786 flat_inputs = nest.flatten(inputs, expand_composites=True) 2787 2788 # Check for NumPy arrays in arguments and convert them to Tensors. 2789 # TODO(nareshmodi): Skip ndarray conversion to tensor altogether, perhaps 2790 # finding a way to store them directly in the cache key (currently not 2791 # possible since ndarrays are not hashable). 
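  # The loop below inlines a check equivalent to `_is_ndarray`: Tensor and
  # Variable leaves are collected into `filtered_flat_inputs` unchanged,
  # ndarray-like values are replaced with constants, and `need_packing`
  # records whether the flattened structure must be re-packed afterwards.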
2792 need_packing = False 2793 filtered_flat_inputs = [] 2794 for index, value in enumerate(flat_inputs): 2795 if isinstance(value, 2796 (ops.Tensor, resource_variable_ops.BaseResourceVariable)): 2797 filtered_flat_inputs.append(value) 2798 elif hasattr(value, "__array__") and not ( 2799 hasattr(value, "_should_act_as_resource_variable") or 2800 isinstance(value, (np.str_, type, composite_tensor.CompositeTensor))): 2801 # This case is equivalent to _is_ndarray(value) == True 2802 a = _as_ndarray(value) 2803 if not isinstance(a, np.ndarray): 2804 raise TypeError("The output of __array__ must be an np.ndarray " 2805 "(got {} from {}).".format(type(a), type(value))) 2806 flat_inputs[index] = constant_op.constant(a) 2807 filtered_flat_inputs.append(flat_inputs[index]) 2808 need_packing = True 2809 if need_packing: 2810 return (nest.pack_sequence_as( 2811 structure=inputs, flat_sequence=flat_inputs, 2812 expand_composites=True), flat_inputs, filtered_flat_inputs) 2813 else: 2814 return inputs, flat_inputs, filtered_flat_inputs 2815 2816 2817def _convert_inputs_to_signature(inputs, input_signature, flat_input_signature): 2818 """Convert inputs to pass into a function with an explicit signature.""" 2819 2820 def format_error_message(inputs, input_signature): 2821 return (" inputs: (\n" + " " + ",\n ".join(str(i) for i in inputs) + 2822 ")\n" + " input_signature: (\n" + " " + 2823 ",\n ".join(str(i) for i in input_signature) + ")") 2824 2825 try: 2826 flatten_inputs = nest.flatten_up_to( 2827 input_signature, 2828 inputs[:len(input_signature)], 2829 expand_composites=True, 2830 check_types=False) # lists are convert to tuples for `tf.data`. 2831 except ValueError: 2832 raise ValueError("Structure of Python function inputs does not match " 2833 "input_signature:\n%s" % 2834 format_error_message(inputs, input_signature)) 2835 2836 need_packing = False 2837 for index, (value, spec) in enumerate(zip(flatten_inputs, 2838 flat_input_signature)): 2839 if (isinstance(spec, tensor_spec.TensorSpec) and 2840 not _pywrap_utils.IsTensor(value)): 2841 try: 2842 flatten_inputs[index] = ops.convert_to_tensor( 2843 value, dtype_hint=spec.dtype) 2844 need_packing = True 2845 except ValueError: 2846 raise ValueError("When input_signature is provided, all inputs to " 2847 "the Python function must be convertible to " 2848 "tensors:\n%s" % 2849 format_error_message(inputs, input_signature)) 2850 2851 if any(not spec.is_compatible_with(other) for spec, other in zip( 2852 flat_input_signature, 2853 flatten_inputs)): 2854 raise ValueError("Python inputs incompatible with input_signature:\n%s" % 2855 format_error_message(inputs, input_signature)) 2856 2857 if need_packing: 2858 inputs = nest.pack_sequence_as( 2859 structure=input_signature, 2860 flat_sequence=flatten_inputs, 2861 expand_composites=True) 2862 2863 flat_inputs = nest.flatten(inputs, expand_composites=True) 2864 2865 return (inputs, flat_inputs, [ 2866 t for t in flat_inputs 2867 if isinstance(t, (ops.Tensor, resource_variable_ops.BaseResourceVariable)) 2868 ]) 2869 2870 2871class FunctionCache(object): 2872 """A lightweight container for cached functions. 2873 """ 2874 2875 __slots__ = [ 2876 "missed", "primary", "arg_relaxed_specs", "arg_relaxed", 2877 "_garbage_collectors" 2878 ] 2879 2880 def __init__(self): 2881 # The set of functions that have been missed; entries are CacheKey with 2882 # input_signature `None` (e.g. a "call context key") 2883 self.missed = set() 2884 # The primary cache, mapping a fully shaped CacheKey to a function. 
2885 self.primary = collections.OrderedDict() 2886 # A cache key lookup, mapping a CacheKey generated without shape info to a 2887 # flat list of `TypeSpec`s with relaxed shapes (one for each flattened 2888 # argument). Arguments that are not Tensors or `CompositeTensor`s contain a 2889 # `None` for the corresponding relaxed spec. 2890 self.arg_relaxed_specs = collections.OrderedDict() 2891 # The secondary cache, mapping a CacheKey generated without shape info to a 2892 # function. 2893 self.arg_relaxed = collections.OrderedDict() 2894 # All OrderedDicts require manual garbage collection. 2895 self._garbage_collectors = [ 2896 _FunctionGarbageCollector(self.primary), 2897 _FunctionGarbageCollector(self.arg_relaxed), 2898 _FunctionGarbageCollector(self.arg_relaxed_specs)] 2899 2900 def all_values(self): 2901 """A list of all `ConcreteFunction` instances held by this cache.""" 2902 # We need to simultaneously make sure our returned concrete functions are 2903 # unique *and* make sure they are returned in a deterministic order for 2904 # serialization. 2905 # 2906 # TODO(b/174215821): It's likely that we ultimately would just prefer to 2907 # choose the most specific concrete function shape given a set of 2908 # arguments. If and when that is implemented, this logic can be revisited. 2909 primary_functions = set(self.primary.values()) 2910 return list(self.primary.values()) + [ 2911 v for v in self.arg_relaxed.values() if v not in primary_functions] 2912 2913 2914class Function(object): 2915 """Wrapper class for the graph functions defined for a Python function. 2916 2917 See the documentation for `defun` for more information on the semantics of 2918 defined functions. 2919 2920 `Function` class is thread-compatible meaning that minimal usage of defuns 2921 (defining and calling) is thread-safe, but if users call other methods or 2922 invoke the base `python_function` themselves, external synchronization is 2923 necessary. 2924 In addition, Function is not reentrant, so recursive functions need to call 2925 the wrapped function, not the wrapper. 2926 """ 2927 2928 def __init__(self, 2929 python_function, 2930 name, 2931 input_signature=None, 2932 attributes=None, 2933 autograph=True, 2934 autograph_options=None, 2935 experimental_relax_shapes=False, 2936 capture_by_value=None, 2937 jit_compile=None, 2938 experimental_follow_type_hints=False): 2939 """Initializes a `Function`. 2940 2941 Args: 2942 python_function: the function to be wrapped. 2943 name: the name given to it. 2944 input_signature: a possibly nested sequence of `TensorSpec` objects 2945 specifying the input signature of this function. If `None`, a separate 2946 function is instantiated for each inferred input signature. 2947 attributes: dict, extra keyword arguments that will be added as attribute 2948 of the function. 2949 autograph: whether to use autograph to compile 2950 `python_function`. See https://www.tensorflow.org/guide/autograph for 2951 more information. 2952 autograph_options: Experimental knobs to control behavior 2953 `when autograph=True`. See https://www.tensorflow.org/guide/autograph 2954 for more information. 2955 experimental_relax_shapes: When true, argument shapes may be relaxed to 2956 avoid unnecessary retracing. 2957 capture_by_value: Experimental. Whether to capture resource variables by 2958 value or reference. If None, will inherit from a parent context or 2959 default to False. 2960 jit_compile: Force-compile the function with XLA, cf. 2961 def_function.Function doc on jit_compile. 
2962 experimental_follow_type_hints: See the documentation for `tf.function`. 2963 2964 Raises: 2965 ValueError: if `input_signature` is not None and the `python_function`'s 2966 argspec has keyword arguments. 2967 """ 2968 self._python_function = python_function 2969 pure_function = attributes and IMPLEMENTS_ATTRIBUTE_NAME in attributes 2970 self._function_spec = FunctionSpec.from_function_and_signature( 2971 python_function, 2972 input_signature, 2973 is_pure=pure_function, 2974 experimental_follow_type_hints=experimental_follow_type_hints) 2975 self._name = name 2976 self._autograph = autograph 2977 self._autograph_options = autograph_options 2978 self._experimental_relax_shapes = experimental_relax_shapes 2979 self._function_cache = FunctionCache() 2980 self._function_attributes = attributes or {} 2981 self._capture_by_value = capture_by_value 2982 self.tracing_count = 0 2983 if self.input_signature is not None: 2984 self._hashable_input_signature = _make_input_signature_hashable( 2985 self.flat_input_signature) 2986 2987 self._lock = threading.Lock() 2988 # _descriptor_cache is a of instance of a class to an instance-specific 2989 # `Function`, used to make sure defun-decorated methods create different 2990 # functions for each instance. 2991 self._descriptor_cache = weakref.WeakKeyDictionary() 2992 self._jit_compile = jit_compile 2993 self._experimental_follow_type_hints = experimental_follow_type_hints 2994 2995 def __call__(self, *args, **kwargs): 2996 """Calls a graph function specialized to the inputs.""" 2997 with self._lock: 2998 (graph_function, 2999 filtered_flat_args) = self._maybe_define_function(args, kwargs) 3000 return graph_function._call_flat( 3001 filtered_flat_args, captured_inputs=graph_function.captured_inputs) # pylint: disable=protected-access 3002 3003 @property 3004 def python_function(self): 3005 """Returns the wrapped Python function.""" 3006 return self._python_function # pylint: disable=protected-access 3007 3008 @property 3009 def function_spec(self): 3010 return self._function_spec 3011 3012 @property 3013 def input_signature(self): 3014 """Returns the input signature.""" 3015 return self._function_spec.input_signature 3016 3017 @property 3018 def flat_input_signature(self): 3019 """Returns the flattened input signature.""" 3020 return self._function_spec.flat_input_signature 3021 3022 def _get_concrete_function_internal_garbage_collected(self, *args, **kwargs): 3023 """Returns a concrete function which cleans up its graph function.""" 3024 if self.input_signature: 3025 args, kwargs = None, None 3026 with self._lock: 3027 graph_function, _ = self._maybe_define_function(args, kwargs) 3028 return graph_function 3029 3030 def _get_concrete_function_internal(self, *args, **kwargs): 3031 """Bypasses error checking when getting a graph function.""" 3032 graph_function = self._get_concrete_function_internal_garbage_collected( 3033 *args, **kwargs) 3034 # We're returning this concrete function to someone, and they may keep a 3035 # reference to the FuncGraph without keeping a reference to the 3036 # ConcreteFunction object. So we won't clean up the reference cycles 3037 # manually and instead will leave them to Python's garbage collector. 3038 graph_function._garbage_collector.release() # pylint: disable=protected-access 3039 return graph_function 3040 3041 def _get_concrete_function_garbage_collected(self, *args, **kwargs): 3042 """Returns a `ConcreteFunction` specialized to inputs and execution context. 
3043 3044 Unlike `get_concrete_function(...)`, the graph will be deleted when the 3045 returned function is deleted. It's useful to avoid creating a reference 3046 cycle when you know for sure that the graph will be no longer used without 3047 the returned function. 3048 3049 Args: 3050 *args: inputs to specialize on. 3051 **kwargs: inputs to specialize on. 3052 """ 3053 if self.input_signature: 3054 if kwargs: 3055 raise ValueError("Cannot define a TensorFlow function from a Python " 3056 "function with keyword arguments when " 3057 "input_signature is provided.") 3058 if args: 3059 # If args are provided, they must match the input signature. 3060 if not is_same_structure(self.input_signature, args): 3061 raise ValueError("Structure of Python function inputs does not match " 3062 "input_signature.") 3063 flat_inputs = nest.flatten(args, expand_composites=True) 3064 if any(not isinstance(arg, (ops.Tensor, tensor_spec.DenseSpec, 3065 resource_variable_ops.BaseResourceVariable)) 3066 for arg in flat_inputs): 3067 raise ValueError("When input_signature is provided, all inputs to " 3068 "the Python function must be Tensors, Variables, " 3069 "tf.TensorSpec or tf.VariableSpec objects.") 3070 if any(not spec.is_compatible_with(other) 3071 for spec, other in zip(self.flat_input_signature, flat_inputs)): 3072 raise ValueError("Python inputs incompatible with input_signature: " 3073 "inputs (%s), input_signature (%s)" % 3074 (str(args), str(self.input_signature))) 3075 args, kwargs = None, None 3076 with self._lock: 3077 graph_function, _ = self._maybe_define_function(args, kwargs) 3078 seen_names = set() 3079 captured = object_identity.ObjectIdentitySet( 3080 graph_function.graph.internal_captures) 3081 # pylint: disable=protected-access 3082 graph_function._arg_keywords = [] 3083 prefix_counts = {} 3084 # pylint: enable=protected-access 3085 num_positional = 0 3086 for arg in graph_function.graph.inputs: 3087 if arg in captured: 3088 break 3089 num_positional += 1 3090 user_arg_name = compat.as_str(arg.op.get_attr("_user_specified_name")) 3091 proposal = user_arg_name 3092 while proposal in seen_names: 3093 index = prefix_counts.get(user_arg_name, 1) 3094 proposal = "{}_{}".format(user_arg_name, index) 3095 prefix_counts[user_arg_name] = index + 1 3096 seen_names.add(proposal) 3097 graph_function._arg_keywords.append(proposal) # pylint: disable=protected-access 3098 # Anything can be a positional argument, in the same order as .inputs 3099 graph_function._num_positional_args = num_positional # pylint: disable=protected-access 3100 return graph_function 3101 3102 def get_concrete_function(self, *args, **kwargs): 3103 """Returns a `ConcreteFunction` specialized to inputs and execution context. 3104 3105 Args: 3106 *args: inputs to specialize on. Can be concrete values (e.g. 1) 3107 or `tf.Tensor` or `tf.TensorSpec`. 3108 **kwargs: keyword inputs to specialize on. Concrete values (e.g. 1) 3109 or `tf.Tensor` or `tf.TensorSpec`. 3110 """ 3111 graph_function = self._get_concrete_function_garbage_collected( 3112 *args, **kwargs) 3113 graph_function._garbage_collector.release() # pylint: disable=protected-access 3114 return graph_function 3115 3116 def __get__(self, instance, owner): 3117 """Makes it possible to defun instance methods.""" 3118 del owner 3119 # `instance` here is the instance that this `Function` was accessed through 3120 # e.g., for 3121 # 3122 # class Foo(object): 3123 # 3124 # @function.defun 3125 # def bar(self): 3126 # ... 
3127 # 3128 # foo = Foo() 3129 # foo.bar() # `foo.bar` is a `Function` instance 3130 # 3131 # then `instance` will be `foo` (and `owner` will be `Foo`). We create a 3132 # new instance of `Function` here to allow different instances each 3133 # to create variables once, thereby allowing methods to be decorated with 3134 # defun. Keeps a cache to avoid retracing the function every time the 3135 # descriptor is accessed. 3136 if instance not in self._descriptor_cache: 3137 if instance is None: 3138 return self 3139 # If there is no instance-specific `Function` in the cache, we construct 3140 # an instance-specific `Function` that uses a weak reference to the 3141 # instance (so that the instance will be correctly gc'd). 3142 3143 # And finally add the wrapped function to the description cache 3144 self._descriptor_cache[instance] = class_method_to_instance_method( 3145 self, instance) 3146 3147 # Return the cached `Function` for the instance 3148 return self._descriptor_cache[instance] 3149 3150 def _cache_key(self, 3151 args, 3152 kwargs, 3153 cache_key_context, 3154 include_tensor_ranks_only=False): 3155 """Computes the cache key given inputs and execution context.""" 3156 if self.input_signature is None: 3157 inputs = (args, kwargs) if kwargs else args 3158 input_signature = pywrap_tfe.TFE_Py_EncodeArg(inputs, 3159 include_tensor_ranks_only) 3160 hashable_input_signature = _make_input_signature_hashable(input_signature) 3161 else: 3162 del args, kwargs 3163 assert not include_tensor_ranks_only 3164 hashable_input_signature = self._hashable_input_signature 3165 3166 (parent_graph, device_functions, colocation_stack, in_cross_replica_context, 3167 variable_policy, xla_context_id) = cache_key_context 3168 3169 return CacheKey(hashable_input_signature, parent_graph, device_functions, 3170 colocation_stack, in_cross_replica_context, variable_policy, 3171 xla_context_id) 3172 3173 def _cache_key_context(self): 3174 """Returns execution context.""" 3175 ctx = context.context() 3176 3177 # Don't need to open an init_scope if the _cache_key call is in eager mode 3178 # already. 3179 executing_eagerly = ctx.executing_eagerly() 3180 parent_graph = None 3181 xla_context_id = 0 3182 if not executing_eagerly: 3183 # We want to force function retracing for each different 3184 # XLAControlFlowContext, so add `xla_context_id` to the cache key. 3185 xla_context = _enclosing_xla_context() 3186 if xla_context is not None and \ 3187 xla_context.RequiresUniqueFunctionRetracing(): 3188 xla_context_id = id(xla_context) 3189 3190 with ops.init_scope(): 3191 # The graph, or whether we're executing eagerly, should be a part of the 3192 # cache key so we don't improperly capture tensors such as variables. 3193 executing_eagerly = ctx.executing_eagerly() 3194 parent_graph = None if executing_eagerly else ops.get_default_graph() 3195 3196 # pylint: disable=protected-access 3197 default_graph = ops.get_default_graph() 3198 # TODO(b/117617952): The current distribution strategy will affect graph 3199 # building (e.g. accessing different variables from different devices) and 3200 # so requires retracing for each device. 
3201 strategy_stack = default_graph._distribution_strategy_stack 3202 uses_distribution_strategy = ( 3203 strategy_stack and 3204 strategy_stack[-1].strategy.extended._retrace_functions_for_each_device 3205 ) 3206 if executing_eagerly: 3207 colocation_stack = () 3208 if uses_distribution_strategy: 3209 device_functions = (pydev.merge_device(ctx.device_name),) 3210 else: 3211 device_functions = () 3212 else: 3213 colocation_stack = tuple(default_graph._colocation_stack.peek_objs()) 3214 if (uses_distribution_strategy 3215 or func_graph_module.device_stack_has_callable( 3216 default_graph._device_function_stack)): 3217 # Putting the device in the cache key ensures that call-site device 3218 # annotations are respected. 3219 device_functions = tuple(default_graph._device_functions_outer_to_inner) 3220 else: 3221 device_functions = () 3222 3223 in_cross_replica_context = False 3224 try: 3225 in_cross_replica_context = (strategy_stack[-1].replica_context is None) # pylint: disable=protected-access 3226 except (AttributeError, IndexError): 3227 pass 3228 3229 if save_context.in_save_context(): 3230 variable_policy = ( 3231 save_context.get_save_options().experimental_variable_policy) 3232 else: 3233 variable_policy = None 3234 3235 return (parent_graph, device_functions, colocation_stack, 3236 in_cross_replica_context, variable_policy, xla_context_id) 3237 3238 def _create_graph_function(self, args, kwargs, override_flat_arg_shapes=None): 3239 """Create a `ConcreteFunction` from `args` and `kwargs`.""" 3240 self.tracing_count += 1 3241 3242 if self.input_signature is None: 3243 arglen = len(args) 3244 else: 3245 arglen = len(self.input_signature) 3246 base_arg_names = self._function_spec.arg_names[:arglen] 3247 num_missing_args = arglen - len(self._function_spec.arg_names) 3248 missing_arg_names = [self._function_spec.vararg_name] * num_missing_args 3249 # Produce a list of missing args of the form ["arg_0", "arg_1", ...], 3250 # where arg is based on the self._function_spec.vararg_name. 3251 missing_arg_names = [ 3252 "%s_%d" % (arg, i) for i, arg in enumerate(missing_arg_names) 3253 ] 3254 arg_names = base_arg_names + missing_arg_names 3255 graph_function = ConcreteFunction( 3256 func_graph_module.func_graph_from_py_func( 3257 self._name, 3258 self._python_function, 3259 args, 3260 kwargs, 3261 self.input_signature, 3262 autograph=self._autograph, 3263 autograph_options=self._autograph_options, 3264 arg_names=arg_names, 3265 override_flat_arg_shapes=override_flat_arg_shapes, 3266 capture_by_value=self._capture_by_value), 3267 self._function_attributes, 3268 function_spec=self.function_spec, 3269 # Tell the ConcreteFunction to clean up its graph once it goes out of 3270 # scope. This is not the default behavior since it gets used in some 3271 # places (like Keras) where the FuncGraph lives longer than the 3272 # ConcreteFunction. 3273 shared_func_graph=False) 3274 return graph_function 3275 3276 def _define_function_with_shape_relaxation(self, args, kwargs, flat_args, 3277 filtered_flat_args, 3278 cache_key_context): 3279 """Define a function, relaxing arg shapes to avoid unnecessary retracing.""" 3280 flat_no_comp = nest.flatten((args, kwargs), expand_composites=False) 3281 3282 any_composite_args = any( 3283 isinstance(x, composite_tensor.CompositeTensor) for x in flat_no_comp) 3284 3285 # Build a cache key where TensorShapes include only rank information (and 3286 # not information about the size of each dimension). 
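    # For example (illustrative): calls with float32 Tensors of shape [2, 3]
    # and [2, 5] produce the same rank-only cache key; their TypeSpecs are then
    # merged below via `most_specific_compatible_type` into a relaxed spec such
    # as TensorSpec([2, None]), so a single relaxed function can serve both
    # shapes.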
3287 if not any_composite_args: 3288 rank_only_cache_key = self._cache_key( 3289 args, kwargs, cache_key_context, include_tensor_ranks_only=True) 3290 else: 3291 # For the rank-only cache key, replace any composite tensors with 3292 # shape-relaxed TypeSpecs. 3293 (cache_key_args, cache_key_kwargs) = nest.map_structure( 3294 _shape_relaxed_type_for_composite_tensor, (args, kwargs)) 3295 rank_only_cache_key = self._cache_key( 3296 cache_key_args, 3297 cache_key_kwargs, 3298 cache_key_context, 3299 include_tensor_ranks_only=True) 3300 3301 arg_specs = [_type_spec_for(x) for x in flat_no_comp] 3302 relaxed_arg_specs = self._function_cache.arg_relaxed_specs.get( 3303 rank_only_cache_key, None) 3304 relaxed_arg_function = self._function_cache.arg_relaxed.get( 3305 rank_only_cache_key, None) 3306 3307 if (relaxed_arg_function is not None 3308 and all(_is_type_subset(x, y) for (x, y) in 3309 zip(relaxed_arg_specs, arg_specs))): 3310 return relaxed_arg_function, filtered_flat_args 3311 3312 if relaxed_arg_specs is None: 3313 relaxed_arg_specs = arg_specs 3314 else: 3315 if len(arg_specs) != len(relaxed_arg_specs): 3316 raise RuntimeError("Expected arg_specs len to match " 3317 "relaxed_arg_specs len: %d vs. %d" 3318 % (len(arg_specs), len(relaxed_arg_specs))) 3319 relaxed_arg_specs = [ 3320 x if x is None else x.most_specific_compatible_type(y) 3321 for (x, y) in zip(arg_specs, relaxed_arg_specs)] 3322 self._function_cache.arg_relaxed_specs[rank_only_cache_key] = ( 3323 relaxed_arg_specs) 3324 relaxed_arg_shapes = [ 3325 x if x is None else x.shape 3326 for x in nest.flatten(relaxed_arg_specs, expand_composites=True)] 3327 3328 if any_composite_args: 3329 # Rebuild composite tensors with the relaxed TypeSpecs. For example, 3330 # if a tf.data iterator is passed as an argument, then we need to relax 3331 # the TensorShapes in its element_spec. 3332 (relaxed_arg_specs, relaxed_kwarg_specs) = nest.pack_sequence_as( 3333 (args, kwargs), relaxed_arg_specs, expand_composites=False) 3334 (args, kwargs) = nest.pack_sequence_as( 3335 (relaxed_arg_specs, relaxed_kwarg_specs), 3336 flat_args, 3337 expand_composites=True) 3338 3339 graph_function = self._create_graph_function( 3340 args, kwargs, override_flat_arg_shapes=relaxed_arg_shapes) 3341 self._function_cache.arg_relaxed[rank_only_cache_key] = graph_function 3342 3343 return (graph_function, [ 3344 t for t in nest.flatten((args, kwargs), expand_composites=True) 3345 if isinstance(t, (ops.Tensor, 3346 resource_variable_ops.BaseResourceVariable)) 3347 ]) 3348 3349 def _maybe_define_function(self, args, kwargs): 3350 """Gets a function for these inputs, defining it if necessary. 3351 3352 `args` and `kwargs` can be None if this `Function` was created with an 3353 `input_signature`. 3354 3355 Caller must hold self._lock. 3356 3357 Args: 3358 args: The varargs for the Python function. 3359 kwargs: The keyword args for the Python function. 3360 3361 Returns: 3362 A graph function corresponding to the input signature implied by args and 3363 kwargs, as well as filtered flattened inputs (only Tensors and Variables) 3364 that the object should be called with. 3365 3366 Raises: 3367 ValueError: If inputs are incompatible with the input signature. 3368 TypeError: If the function inputs include non-hashable objects 3369 RuntimeError: If there's an internal bug (inconsistency) in handling 3370 shape relaxation retracing. 
3371     """
3372     if self.input_signature is None or args is not None or kwargs is not None:
3373       args, kwargs, flat_args, filtered_flat_args = \
3374           self._function_spec.canonicalize_function_inputs(*args, **kwargs)
3375     else:
3376       flat_args, filtered_flat_args = [None], []
3377
3378     cache_key_context = self._cache_key_context()
3379     cache_key = self._cache_key(args, kwargs, cache_key_context)
3380
3381     try:
3382       hash(cache_key)
3383     except TypeError as e:
3384       raise TypeError(
3385           "Arguments supplied to `defun`-generated functions must be"
3386           " hashable. Original error: %s" % e)
3387
3388     graph_function = self._function_cache.primary.get(cache_key, None)
3389     if graph_function is not None:
3390       return graph_function, filtered_flat_args
3391
3392     with monitoring.MonitoredTimer(_graph_building_time_counter.get_cell()):
3393       with trace.Trace("tf.function-graph_building"):
3394         logging.vlog(1,
3395                      "Creating new FuncGraph for Python function %r (key: %r)",
3396                      self._python_function, cache_key)
3397         logging.vlog(2, "Python function signature [args: %s] [kwargs: %s]",
3398                      args, kwargs)
3399
3400         # pylint: disable=protected-access
3401         call_context_key = cache_key._replace(input_signature=None)
3402         # pylint: disable=protected-access
3403
3404         ag_status = (
3405             ag_ctx.Status.ENABLED
3406             if self._autograph else ag_ctx.Status.DISABLED)
3407         with ag_ctx.ControlStatusCtx(
3408             status=ag_status, options=self._autograph_options):
3409
3410           # Build a function with shape relaxation retracing if:
3411           # 1. shape relaxation is explicitly enabled
3412           # and 2. there's no provided input signature
3413           # and 3. there's been a cache miss for this calling context
3414           if (self._experimental_relax_shapes and
3415               self.input_signature is None and
3416               call_context_key in self._function_cache.missed):
3417             return self._define_function_with_shape_relaxation(
3418                 args, kwargs, flat_args, filtered_flat_args, cache_key_context)
3419
3420           self._function_cache.missed.add(call_context_key)
3421           graph_function = self._create_graph_function(args, kwargs)
3422           self._function_cache.primary[cache_key] = graph_function
3423
3424           return graph_function, filtered_flat_args
3425
3426
3427 def register(func, *args, **kwargs):
3428   """Register a specialization of a `Function` into the graph.
3429
3430   This does not actually call the function with the given inputs; it only adds
3431   the function definition to the graph. Registering the same function with
3432   different input parameters results in multiple versions of the function
3433   being registered in the graph.
3434
3435   Args:
3436     func: the `Function` instance generated by a `defun` decorator.
3437     *args: input arguments for the Python function.
3438     **kwargs: input keyword arguments for the Python function.
3439
3440   Returns:
3441     A `ConcreteFunction` object specialized to the inputs and execution context.
3442
3443   Raises:
3444     ValueError: When the input function is not a `defun`-wrapped Python function.
3445   """
3446   if not isinstance(func, Function):
3447     raise ValueError("Only defun function is allowed to be registered. 
" 3447 "Got type: %s" % type(func)) 3448 concrete_func = func.get_concrete_function(*args, **kwargs) 3449 concrete_func.add_to_graph() 3450 concrete_func.add_gradient_functions_to_graph() 3451 return concrete_func 3452 3453 3454def validate_signature(signature): 3455 if any(not isinstance(arg, tensor_spec.DenseSpec) 3456 for arg in nest.flatten(signature, expand_composites=True)): 3457 raise TypeError("Invalid input_signature {}; input_signature must be " 3458 "a possibly nested sequence of TensorSpec objects." 3459 .format(signature)) 3460 3461 3462def defun(func=None, 3463 input_signature=None, 3464 autograph=True, 3465 experimental_autograph_options=None, 3466 experimental_relax_shapes=False): 3467 """Compiles a Python function into a callable TensorFlow graph. 3468 3469 `defun` (short for "define function") compiles a Python function 3470 composed of TensorFlow operations into a callable that executes a `tf.Graph` 3471 containing those operations. The callable produced by `defun` contains only 3472 the subgraph of TensorFlow operations that were executed when the Python 3473 function was called with a particular input signature, defined as a list 3474 of the shapes and dtypes of the Python function's Tensor-valued arguments and 3475 the values of its non-Tensor Python objects. 3476 3477 When eager execution is enabled, the ability to create graphs from Python 3478 functions makes it possible to incrementally trade off debuggability and 3479 interactivity for performance. Functions compiled with `defun` cannot be 3480 inspected with `pdb`; however, executing a graph 3481 generated by `defun` sometimes takes less time and memory than eagerly 3482 executing the corresponding Python function, since specifying computations as 3483 graphs allows for optimizations like automatic buffer reuse and 3484 parallelization among ops. Note that executing a `defun`-compiled function 3485 incurs a small constant overhead, so eagerly executing sufficiently small 3486 Python functions might take less time than executing their corresponding 3487 `defun`-generated graphs. 3488 3489 For a Python function to be compatible with `defun`, all of its arguments must 3490 be hashable Python objects or lists thereof. The function itself may not 3491 modify the list/map structure of its arguments. Additionally, it must return 3492 zero or more `tf.Tensor` objects. If the Python function returns 3493 a `tf.Variable`, its compiled version will return the value of that variable 3494 as a `tf.Tensor`. 3495 3496 Executing a graph generated by `defun` respects device annotations (i.e., 3497 all `with tf.device` directives present in a Python function will also be 3498 present in its corresponding graph), but it is not yet possible to execute the 3499 generated graphs across multiple machines. 3500 3501 _Example Usage_ 3502 3503 ```python 3504 import tensorflow as tf 3505 3506 tf.compat.v1.enable_eager_execution() 3507 3508 # A simple example. 3509 def f(x, y): 3510 return tf.reduce_mean(tf.multiply(x ** 2, 3) + y) 3511 3512 g = tf.contrib.eager.defun(f) 3513 3514 x = tf.constant([[2.0, 3.0]]) 3515 y = tf.constant([[3.0, -2.0]]) 3516 3517 # `f` and `g` will return the same value, but `g` will be executed as a 3518 # TensorFlow graph. 3519 assert f(x, y).numpy() == g(x, y).numpy() 3520 3521 # `defun` is capable of compiling Python functions that close over Python 3522 # objects, including Tensors and Variables. 
3523   @tf.contrib.eager.defun
3524   def h():
3525     return f(x, y)
3526
3527   assert (h().numpy() == f(x, y).numpy()).all()
3528
3529   # `defun` automatically lifts variables out of the graphs it creates,
3530   # allowing you to compile the `call` methods of `tf.keras.layers.Layer` and
3531   # `tf.keras.Model` objects.
3532   class MyModel(tf.keras.Model):
3533
3534     def __init__(self, keep_probability=0.2):
3535       super(MyModel, self).__init__()
3536       self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
3537       self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
3538       self.keep_probability = keep_probability
3539
3540     @tf.contrib.eager.defun
3541     def call(self, inputs, training=True):
3542       x = self.dense2(self.dense1(inputs))
3543       if training:
3544         return tf.nn.dropout(x, self.keep_probability)
3545       else:
3546         return x
3547
3548   model = MyModel()
3549   model(x, training=True)   # executes a graph, with dropout
3550   model(x, training=False)  # executes a graph, without dropout
3551
3552   # `defun`-compiled functions are differentiable.
3553   optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=0.01)
3554   with tf.GradientTape() as tape:
3555     outputs = model(x)
3556   gradient = tape.gradient(outputs, model.trainable_variables)
3557   optimizer.apply_gradients((grad, var) for grad, var in zip(gradient,
3558                             model.trainable_variables))
3559   ```
3560
3561   When using `defun`, there are subtleties regarding inputs, Python control
3562   flow, and variable creation that one should be aware of. For concreteness, let
3563   `f` be a Python function that returns zero or more `tf.Tensor` objects and
3564   let `F = defun(f)`. `F` builds a graph for each unique input signature it
3565   sees, Python control flow is baked into graphs, and operations related to
3566   variable initialization are automatically lifted out of the graphs that `F`
3567   generates and placed in the eager context if executing eagerly or into an
3568   outer graph otherwise.
3569
3570   _Input Signatures_
3571
3572   By default, `F = tf.contrib.eager.defun(f)` instantiates a separate graph
3573   for every unique sequence of the shapes and dtypes of Tensor arguments and
3574   the values of Python objects it is invoked with. For example, calling
3575   `F(tf.random.uniform([2]))` will execute a different graph than
3576   `F(tf.random.uniform([3]))` because the two inputs have different shapes.
3577   The first time that `F(*args, **kwargs)` is called with a particular sequence
3578   of Tensor shapes and dtypes and Python values, it constructs a graph by
3579   tracing the execution of `f(*args, **kwargs)`; this graph is bound to an
3580   input signature inferred from `(*args, **kwargs)` and cached for future reuse.
3581
3582   NumPy arrays passed as inputs to `F` are converted to `tf.Tensor` objects
3583   before being passed to `f`, and are treated as Tensors for caching. This
3584   allows a function to be called multiple times with NumPy arrays having
3585   different values but the same shape and dtype without re-tracing each time.
3586
3587   `tf.contrib.eager.defun` caches graphs for your convenience, letting you
3588   define TensorFlow functions without explicitly specifying their signatures.
3589   However, this policy is conservative and potentially expensive; for example,
3590   when different invocations of your function have differently-shaped Tensor
3591   inputs, this policy might generate more graph functions than necessary. To
3592   eliminate such costs, `tf.contrib.eager.defun` allows you to supply an
3593   optional `input_signature` argument specifying the shapes and dtypes of the
3594   inputs. In particular, the shapes may be partially unspecified, with `None`s
3595   in the unknown dimensions. When an input signature is provided,
3596   `tf.contrib.eager.defun` will only instantiate a single graph for the
3597   decorated Python function. The following is an example:
3598
3599   ```python
3600   import tensorflow as tf
3601
3602   # The first `TensorSpec` below describes the shape and dtype of `words`,
3603   # and the second describes the shape and dtype of `another_tensor`. Note that
3604   # the last dimension of the `words` `TensorSpec` is left unspecified.
3605   @tf.contrib.eager.defun(input_signature=[
3606     tf.contrib.eager.TensorSpec(shape=[50, 300, None], dtype=tf.float32),
3607     tf.contrib.eager.TensorSpec(shape=[300, 100], dtype=tf.float32)
3608   ])
3609   def my_sequence_model(words, another_tensor):
3610     ...
3611
3612   # Note how the third dimension of the first input can vary freely.
3613   words = tf.random.uniform([50, 300, 10])
3614   second_input = tf.random.uniform([300, 100])
3615   my_sequence_model(words, second_input)
3616
3617   words = tf.random.uniform([50, 300, 20])
3618   my_sequence_model(words, second_input)
3619
3620   # Passing an input with an incompatible shape will raise an error.
3621   words = tf.random.uniform([50, 100, 20])
3622   my_sequence_model(words, second_input)  # <---- This will raise an error.
3623
3624   ```
3625
3626   Python functions that are compiled with an `input_signature` must only accept
3627   Tensors as arguments and must not take unnamed keyword arguments (**kwargs).
3628
3629   _Tracing_
3630
3631   Be aware that because `F` only logs TensorFlow operations, all the other
3632   Python code that `f` executes will only shape the _construction_ of the graphs
3633   that `F` executes: the Python code won't be executed when the graphs
3634   themselves are executed, though it will be executed every time the Python
3635   function is traced (and a given Python function might be traced multiple
3636   times, once for each input signature it is invoked with). For example, whereas
3637   the Python function
3638
3639   ```python
3640   import tensorflow as tf
3641   import numpy as np
3642
3643   tf.compat.v1.enable_eager_execution()
3644
3645   def add_noise():
3646     return tf.eye(5) + np.random.randn(5, 5)
3647   ```
3648
3649   will return a different output every time it is invoked, the compiled function
3650   `compiled = tf.contrib.eager.defun(add_noise)` will return the same value
3651   every time it is called, since a particular random offset generated by NumPy
3652   will be inserted into the graph as a TensorFlow constant. The solution is to
3653   replace the call to `np.random.randn` with `tf.random.normal((5, 5))`.
3654
3655   _Python Side-Effects_
3656
3657   A corollary of the previous discussion on tracing is the following: If a
3658   Python function `f` has Python side-effects, then executing `f` multiple times
3659   will not necessarily be semantically equivalent to executing `F =
3660   tf.contrib.eager.defun(f)` multiple times; this difference is due to the fact
3661   that `defun` only captures the subgraph of TensorFlow operations that is
3662   constructed when `f` is called in a graph-building context.
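
  For example, in the following illustrative snippet, the Python-level
  `side_effects.append` runs while `defun` traces `log_and_add`, not every time
  the compiled function is executed:

  ```python
  import tensorflow as tf

  tf.compat.v1.enable_eager_execution()

  side_effects = []

  def log_and_add(x, y):
    side_effects.append("traced")  # Python side-effect; runs only during tracing.
    return x + y

  compiled = tf.contrib.eager.defun(log_and_add)
  compiled(tf.constant(1.0), tf.constant(2.0))
  compiled(tf.constant(3.0), tf.constant(4.0))  # Same input signature: no re-trace.

  # The function was traced once, so the Python side-effect happened once even
  # though the compiled function executed twice.
  assert len(side_effects) == 1
  ```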
3663
3664   _Python Control Flow_
3665
3666   The structure of many machine learning computations depends upon whether one is
3667   training or validating, and it is common to nest specialized logic under `if
3668   training:` blocks. By mapping each input signature to a unique graph, `defun`
3669   lets users transparently compile such code, as the following code snippet
3670   demonstrates:
3671
3672   ```python
3673   import tensorflow as tf
3674
3675   tf.compat.v1.enable_eager_execution()
3676
3677   @tf.contrib.eager.defun
3678   def lossy_matmul(W, x, training=True):
3679     outputs = tf.matmul(W, x)
3680     if training:
3681       outputs = tf.nn.dropout(outputs, keep_prob=0.2)
3682     return outputs
3683
3684   W = tf.random.normal((3, 5))
3685   x = tf.random.normal((5, 1))
3686
3687   # Executes a graph that applies dropout.
3688   lossy_outputs = lossy_matmul(W, x, training=True)
3689
3690   # Executes a graph that does not apply dropout.
3691   exact_outputs = lossy_matmul(W, x, training=False)
3692   ```
3693
3694   _TensorFlow Control Flow_
3695
3696   When `autograph` is `True`, data-dependent control flow is allowed as well.
3697   Control flow statements that depend on `Tensor` values are staged into
3698   corresponding TensorFlow ops. For example, the following code will work as
3699   expected:
3700
3701   ```python
3702   @tf.contrib.eager.defun
3703   def dynamic_rnn_loop(cell, seq):
3704     state, output = cell.zero_state()
3705     for input in seq:
3706       state, output = cell(input, state)
3707     return output
3708   ```
3709
3710   For more information see `tf.autograph`.
3711
3712   _Variables_
3713
3714   TensorFlow operations related to variable creation and initialization are
3715   automatically lifted out of the graphs generated by `defun`. In practice, this
3716   implies that variable creation and initialization only happen the first time
3717   `F` is called, and that variables are reused every time thereafter. Many
3718   TensorFlow APIs, like `tf.keras.layers.Layer` objects, create variables the
3719   first time they are called and reuse them thereafter. Automatic variable
3720   lifting makes it possible to compile these APIs without extra effort, at the
3721   cost of introducing a discrepancy between the semantics of executing Python
3722   functions and their corresponding compiled functions. For example:
3723
3724   ```python
3725   import tensorflow as tf
3726
3727   tf.compat.v1.enable_eager_execution()
3728
3729   def fn():
3730     x = tf.Variable(0.0)
3731     x.assign_add(1.0)
3732     return x.read_value()
3733
3734   # `fn` is a Python function, so x is created, initialized, and destroyed upon
3735   # every invocation
3736   assert fn().numpy() == fn().numpy() == 1.0
3737
3738   compiled = tf.contrib.eager.defun(fn)
3739
3740   # Compiling `fn` with `defun` hoists all variables outside of the generated
3741   # graph, so initialization happens exactly once.
3742   assert compiled().numpy() == 1.0
3743   assert compiled().numpy() == 2.0
3744   ```
3745
3746   Finally, because each input signature is bound to a unique graph, if your
3747   Python function constructs `tf.Variable` objects, then each graph constructed
3748   for that Python function will reference a unique set of variables. To
3749   circumvent this problem, we recommend against compiling Python functions that
3750   create `tf.Variable` objects. Instead, Python functions should either
3751   lexically close over `tf.Variable` objects or accept them as arguments,
3752   preferably encapsulated in an object-oriented container.
If you must create 3753 variables inside your Python function and you want each graph generated for it 3754 to reference the same set of variables, add logic to your Python function that 3755 ensures that variables are only created the first time it is called and are 3756 reused for every subsequent invocation; note that this is precisely what 3757 `tf.keras.layers.Layer` objects do, so we recommend using them to represent 3758 variable-bearing computations whenever possible. 3759 3760 Args: 3761 func: function to be compiled. If `func` is None, returns a 3762 decorator that can be invoked with a single argument - `func`. The 3763 end result is equivalent to providing all the arguments up front. 3764 In other words, defun(input_signature=...)(func) is equivalent to 3765 defun(func, input_signature=...). The former allows 3766 the following use case: 3767 @tf.contrib.eager.defun(input_signature=...) 3768 def foo(...): 3769 ... 3770 3771 input_signature: A possibly nested sequence of 3772 `tf.contrib.eager.TensorSpec` objects specifying the shapes and dtypes of 3773 the Tensors that will be supplied to this function. If `None`, a separate 3774 function is instantiated for each inferred input signature. If a 3775 signature is specified, every input to `func` must be a `Tensor`, and 3776 `func` cannot accept `**kwargs`. 3777 autograph: Whether `func` should be compiled before 3778 constructing the graph. See https://www.tensorflow.org/guide/autograph 3779 for more information. 3780 experimental_autograph_options: Experimental knobs (in the form of a tuple 3781 of tensorflow.autograph.Feature values) to control behavior when 3782 autograph=True. 3783 experimental_relax_shapes: When true, argument shapes may be relaxed to 3784 avoid unnecessary retracing. 3785 3786 Returns: 3787 If `func` is not None, returns a callable that will execute the compiled 3788 function (and return zero or more `tf.Tensor` objects). 3789 If `func` is None, returns a decorator that, when invoked with a single 3790 `func` argument, returns a callable equivalent to the case above. 3791 3792 Raises: 3793 TypeError: If `input_signature` is neither `None` nor a sequence of 3794 `tf.contrib.eager.TensorSpec` objects. 3795 """ 3796 return defun_with_attributes( 3797 func=func, 3798 input_signature=input_signature, 3799 autograph=autograph, 3800 experimental_autograph_options=experimental_autograph_options, 3801 experimental_relax_shapes=experimental_relax_shapes) 3802 3803 3804@tf_export("__internal__.function.defun_with_attributes", v1=[]) 3805def defun_with_attributes(func=None, 3806 input_signature=None, 3807 attributes=None, 3808 autograph=True, 3809 experimental_autograph_options=None, 3810 jit_compile=None, 3811 experimental_relax_shapes=False, 3812 experimental_follow_type_hints=False): 3813 """Compiles a Python function into a callable TensorFlow graph. 3814 3815 This function supports adding extra function attributes. See detailed 3816 documentation in defun(). Currently this is not exposed in public API since we 3817 don't expect user to directly use attributes, and attribute won't work by 3818 itself. This assumption might change in future. 3819 3820 Args: 3821 func: function to be compiled. 3822 input_signature: same as defun()'s input_signature. 3823 attributes: A dictionary of arguments which will be added to function def as 3824 attributes. Currently only support primitive types as value, and only 3825 allowlisted attribute name is allowed. 
Unallowlisted attribute name or 3826 unsupported value will result into ValueError. `func_name` is also one of 3827 the allowlisted argument which is a python string, and sets the name for 3828 this `ConcreteFunction` in the graph. 3829 autograph: same as defun()'s autograph. 3830 experimental_autograph_options: same as defun()'s 3831 experimental_autograph_options. 3832 jit_compile: same as defun()'s jit_compile. 3833 experimental_relax_shapes: same as defun()'s experimental_relax_shapes 3834 experimental_follow_type_hints: see `tf.function`. 3835 3836 Returns: 3837 Same as the return value of defun, with attributes added to the function in 3838 graph. 3839 """ 3840 if input_signature is not None: 3841 validate_signature(input_signature) 3842 3843 # TODO(apassos): deal with captured global state. Deal with control flow. 3844 def decorated(function): 3845 try: 3846 if attributes: 3847 name = attributes.pop("func_name", function.__name__) 3848 else: 3849 name = function.__name__ 3850 except AttributeError: 3851 name = "function" 3852 return tf_decorator.make_decorator( 3853 function, 3854 Function( 3855 function, 3856 name, 3857 input_signature=input_signature, 3858 attributes=attributes, 3859 autograph=autograph, 3860 autograph_options=experimental_autograph_options, 3861 jit_compile=jit_compile, 3862 experimental_relax_shapes=experimental_relax_shapes, 3863 experimental_follow_type_hints=experimental_follow_type_hints)) 3864 3865 # This code path is for the `foo = tfe.defun(foo, ...)` use case 3866 if func is not None: 3867 return decorated(func) 3868 3869 # This code path is for the 3870 # 3871 # @tfe.defun(...) 3872 # def foo(...): 3873 # ... 3874 # 3875 # use case, which is equivalent to `foo = tfe.defun(...)(foo)` 3876 return decorated 3877 3878 3879# When a method is bound to objects of this type, it allows AutoGraph to 3880# recover a weak reference the original method's self pointer, so that it can 3881# execute it consistent with class_method_to_instance_method's 3882# bound_method_wrapper. 3883# TODO(b/119246461): This is not pretty. Use a descriptor instead? 3884class TfMethodTarget(object): 3885 """Binding target for methods replaced by function and defun.""" 3886 3887 __slots__ = ("weakrefself_target__", "weakrefself_func__") 3888 3889 def __init__(self, target, original_python_function): 3890 self.weakrefself_target__ = target 3891 self.weakrefself_func__ = weakref.ref(original_python_function) 3892 3893 @property 3894 def target(self): 3895 return self.weakrefself_target__() 3896 3897 @property 3898 def target_class(self): 3899 true_self = self.weakrefself_target__() 3900 if tf_inspect.isclass(true_self): 3901 # Class method 3902 return true_self 3903 else: 3904 return true_self.__class__ 3905 3906 def call(self, args, kwargs): 3907 wrapped_fn = self.weakrefself_func__() 3908 if tf_inspect.ismethod(wrapped_fn): 3909 wrapped_fn = six.get_unbound_function(wrapped_fn) 3910 return wrapped_fn(self.weakrefself_target__(), *args, **kwargs) 3911 3912 3913def class_method_to_instance_method(original_function, instance): 3914 """Constructs a new `Function` with `self` bound.""" 3915 weak_instance = weakref.ref(instance) 3916 3917 # Note: while we could bind to a weakref proxy instead, that causes the 3918 # bound method to be unhashable. 
3919 bound_method = types_lib.MethodType( 3920 original_function.python_function, 3921 TfMethodTarget(weak_instance, original_function.python_function)) 3922 3923 # original_function is expected to be of one of the two `Function` types 3924 # (defined either in function.py or def_function.py). 3925 assert hasattr(original_function, "_name") 3926 assert hasattr(original_function, "_autograph") 3927 assert hasattr(original_function, "_function_spec") 3928 assert hasattr(original_function, "python_function") 3929 3930 weak_bound_method_wrapper = None 3931 def bound_method_wrapper(*args, **kwargs): 3932 """Wraps either a dummy MethodType or a converted AutoGraph function.""" 3933 # __wrapped__ allows AutoGraph to swap in a converted function. 3934 strong_bound_method_wrapper = weak_bound_method_wrapper() 3935 wrapped_fn = strong_bound_method_wrapper.__wrapped__ 3936 3937 if wrapped_fn is strong_bound_method_wrapper.__original_wrapped__: 3938 # If __wrapped__ was not replaced, then call original_function. 3939 # TODO(mdan): For better consistency, use the wrapper's call(). 3940 wrapped_fn = original_function.python_function 3941 if tf_inspect.ismethod(wrapped_fn): 3942 wrapped_fn = six.get_unbound_function(wrapped_fn) 3943 return wrapped_fn(weak_instance(), *args, **kwargs) 3944 3945 # If __wrapped__ was replaced, then it is always an unbound function. 3946 # However, the replacer is still responsible for attaching self properly. 3947 # TODO(mdan): Is it possible to do it here instead? 3948 return wrapped_fn(*args, **kwargs) 3949 weak_bound_method_wrapper = weakref.ref(bound_method_wrapper) 3950 3951 # pylint: disable=protected-access 3952 # We make a dummy MethodType object to generate the correct bound method 3953 # signature. The actual call is to a function with a weak reference to 3954 # `instance`. 
3955 instance_func = type(original_function)( 3956 tf_decorator.make_decorator(bound_method, bound_method_wrapper), 3957 name=original_function._name, 3958 autograph=original_function._autograph, 3959 input_signature=original_function.input_signature, 3960 experimental_relax_shapes=original_function._experimental_relax_shapes, 3961 jit_compile=original_function._jit_compile) 3962 # pylint: enable=protected-access 3963 3964 # We wrap the the bound method with tf_decorator so inspection works correctly 3965 wrapped_instance_func = tf_decorator.make_decorator(bound_method, 3966 instance_func) 3967 return wrapped_instance_func 3968 3969 3970class _FunctionGarbageCollector(object): 3971 """Cleans up cycles when a defun goes out of scope.""" 3972 3973 __slots__ = ["_cache"] 3974 3975 def __init__(self, cache): 3976 self._cache = cache 3977 3978 def __del__(self): 3979 if func_graph_module is None or memory is None: 3980 return 3981 try: 3982 while self._cache: 3983 self._cache.popitem() 3984 memory.dismantle_ordered_dict(self._cache) 3985 except: # pylint: disable=bare-except 3986 pass 3987 3988 3989class ConcreteFunctionGarbageCollector(object): 3990 """Cleans up reference cycles when a `ConcreteFunction` goes out of scope.""" 3991 3992 __slots__ = ["_func_graph"] 3993 3994 def __init__(self, func_graph): 3995 self._func_graph = func_graph 3996 3997 def release(self): 3998 """Call off the FuncGraph deletion.""" 3999 self._func_graph = None 4000 4001 def __del__(self): 4002 if func_graph_module is None or memory is None or self._func_graph is None: 4003 return 4004 try: 4005 func_graph_module.dismantle_func_graph(self._func_graph) 4006 except: # pylint: disable=bare-except 4007 pass 4008 4009 4010class _Marker(object): 4011 """Markers used to pretty-print nested args in function signatures.""" 4012 4013 __slots__ = ["_s"] 4014 4015 def __init__(self, s): 4016 self._s = s 4017 4018 def __repr__(self): 4019 return str(self._s) 4020 4021 4022def _structure_summary(structure): 4023 """Displays a summary of the nesting structure of the given value.""" 4024 4025 def type_name(x): 4026 if isinstance(x, type_spec.TypeSpec): 4027 return x.value_type.__name__ 4028 else: 4029 return type(x).__name__ 4030 4031 markers = [_Marker(type_name(v)) for v in nest.flatten(structure)] 4032 return str(nest.pack_sequence_as(structure, markers)) 4033 4034 4035def _contains_type_spec(value): 4036 return any(isinstance(x, type_spec.TypeSpec) for x in nest.flatten(value)) 4037
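
# Example (illustrative) of the pretty-printing helpers above:
#
#   _structure_summary({"a": tf.TensorSpec([2], tf.float32), "b": 3})
#   # -> "{'a': Tensor, 'b': int}"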