# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=unidiomatic-typecheck
"""Defun decorator for defining graph-mode functions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import functools
import threading
import types as types_lib
import weakref

import numpy as np
import six

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.eager import execute
from tensorflow.python.eager import tape
from tensorflow.python.eager.graph_only_ops import graph_placeholder
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import error_interpolation
from tensorflow.python.framework import errors
from tensorflow.python.framework import func_graph as func_graph_module
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gradients_util
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import function_utils
from tensorflow.python.util import memory
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect


FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name"
BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name"


CacheKey = collections.namedtuple("CacheKey", [
    "input_signature", "parent_graph", "device_functions",
    "colocation_stack"])

CacheKey.replace = CacheKey._replace  # pylint: disable=protected-access


def _flat_shape_list(*params):
  """Return a flat list of TensorShapes, one for each tensor[spec] in `*params`.

  Args:
    *params: Set of nested entries containing Tensors, TensorSpec, and
      non-tensors.

  Returns:
    A list of entries containing either `None` or `TensorShape`.
  """
  return [tensor_shape.TensorShape(x.shape)
          if isinstance(x, (ops.Tensor, tensor_spec.TensorSpec)) else None
          for x in nest.flatten(params)]


def _shape_less_specific_than(relaxed, to_check):
  """Checks if `relaxed` is less specific than `to_check`.

  This is an asymmetric check, unlike `TensorShape.is_compatible_with`. If
  `to_check` has a dimension with an undefined shape, `relaxed` must also have
  an undefined shape for that dimension.

  Args:
    relaxed: A `TensorShape` to check against.
    to_check: A second `TensorShape`.

  Returns:
    True if `to_check` represents a set of shapes which is a subset of
    `relaxed`'s shapes and False otherwise.
  """
  if to_check.dims is not None and relaxed.dims is not None:
    if to_check.rank != relaxed.rank:
      return False
    for check_dim, relaxed_dim in zip(to_check.dims, relaxed.dims):
      if check_dim.value is None and relaxed_dim.value is not None:
        return False
      if not relaxed_dim.is_compatible_with(check_dim):
        return False
  return True


def _compatible_shapes(flat_relaxed, flat_to_check):
  """Check if lists of TensorShapes contain compatible shapes.

  Checks that each `flat_relaxed` shape covers a superset of the shapes of the
  corresponding `flat_to_check` shape.

  Args:
    flat_relaxed: List of TensorShape or None.
    flat_to_check: List of TensorShape or None.

  Returns:
    A python bool.

  Raises:
    RuntimeError:
      if `len(flat_relaxed) != len(flat_to_check)`.
    RuntimeError:
      if `flat_relaxed[i] is None != flat_to_check[i] is None` for any `i`.
  """

  if len(flat_relaxed) != len(flat_to_check):
    raise RuntimeError("Expected shape lists of identical lengths, but saw: "
                       "%s and %s" % (flat_relaxed, flat_to_check))
  def is_compatible(relaxed, to_check):
    """Internal help function.

    Args:
      relaxed: TensorShape or None.
      to_check: TensorShape or None.

    Returns:
      Python bool.

    Raises:
      RuntimeError: If `relaxed is None != to_check is None`.
    """
    # If both x and y are None, there is no shape to compare. Otherwise check
    # if they are compatible with each other. Either way, both input signatures
    # must have Tensors in the same entries. If not, raise an assertion error.
    if relaxed is None != to_check is None:
      raise RuntimeError(
          "Expected signature type matches between flattened input shapes "
          "%s and %s; but saw that (%s is None) != (%s is None)"
          % (flat_relaxed, flat_to_check, relaxed, to_check))
    return relaxed is None or _shape_less_specific_than(relaxed, to_check)
  return all(is_compatible(relaxed, to_check)
             for relaxed, to_check in zip(flat_relaxed, flat_to_check))


def _common_shape(x, y):
  """Find a `TensorShape` that is compatible with both `x` and `y`."""
  if x is None != y is None:
    raise RuntimeError(
        "Cannot find a common shape when LHS shape is None but RHS shape "
        "is not (or vice versa): %s vs. %s" % (x, y))
  if x is None:
    return None  # The associated input was not a Tensor, no shape generated.
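  # Illustrative examples of the common-shape rule implemented below: a
  # dimension is kept only when both inputs agree on a defined value, e.g.
  #   _common_shape(TensorShape([2, 3]), TensorShape([2, 5]))
  #       -> TensorShape([2, None])
  #   _common_shape(TensorShape([2, 3]), TensorShape([2, 3, 4]))
  #       -> TensorShape(None)  # rank mismatch relaxes the whole shape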
  if not isinstance(x, tensor_shape.TensorShape):
    raise TypeError("Expected x to be a TensorShape but saw %s" % (x,))
  if not isinstance(y, tensor_shape.TensorShape):
    raise TypeError("Expected y to be a TensorShape but saw %s" % (y,))
  if x.rank != y.rank or x.rank is None:
    return tensor_shape.TensorShape(None)
  dims = []
  for dim_x, dim_y in zip(x.dims, y.dims):
    if (dim_x != dim_y
        or tensor_shape.dimension_value(dim_x) is None
        or tensor_shape.dimension_value(dim_y) is None):
      dims.append(None)
    else:
      dims.append(tensor_shape.dimension_value(dim_x))
  return tensor_shape.TensorShape(dims)


def is_same_structure(structure1,
                      structure2,
                      check_values=False):
  """Check two structures for equality, optionally of types and of values."""
  try:
    nest.assert_same_structure(structure1, structure2)
  except (ValueError, TypeError):
    return False
  if check_values:
    flattened1 = nest.flatten(structure1)
    flattened2 = nest.flatten(structure2)
    # First check the types to avoid AttributeErrors.
    if any(type(f1) != type(f2) for f1, f2 in zip(flattened1, flattened2)):
      return False
    return flattened1 == flattened2
  return True


def _parse_func_attrs(attributes):
  """Convert the keyword arguments into function_def attributes.

  Currently only supports primitive types: bool, int, float and string.

  Args:
    attributes: the dictionary of attributes.
  Returns:
    A dict of attributes where the key is the name of attribute and the value
    is the AttrValue proto.
  Raises:
    ValueError: If the kwargs contains unwhitelisted name or unsupported value
      types.
  """
  attrs = {}
  for key, value in attributes.items():
    if isinstance(value, attr_value_pb2.AttrValue):
      attrs[key] = value
    # bool type check has to happen before int since bool is a subclass of int.
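    # For example, isinstance(True, int) is True in Python, so checking int
    # first would encode booleans as AttrValue(i=...) rather than
    # AttrValue(b=...).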
    elif isinstance(value, bool):
      attrs[key] = attr_value_pb2.AttrValue(b=value)
    elif isinstance(value, int):
      attrs[key] = attr_value_pb2.AttrValue(i=value)
    elif isinstance(value, float):
      attrs[key] = attr_value_pb2.AttrValue(f=value)
    elif isinstance(value, (str, bytes, six.text_type)):
      attrs[key] = attr_value_pb2.AttrValue(s=compat.as_bytes(value))
    else:
      raise ValueError("Unsupported attribute type for %s with type %s" %
                       (key, type(value)))
  return attrs


class _InterpolateFunctionError(object):
  """Context Manager that interpolates the exception from 'top_level_func'."""

  def __init__(self, top_level_func):
    self._func = top_level_func

  def __enter__(self):
    pass

  def __exit__(self, typ, exc, tb):
    if not exc or not isinstance(exc, errors.OpError):
      return False
    message = compat.as_text(exc.message)
    _, tags = error_interpolation.parse_message(message)
    g = None
    func_stack = []
    # pylint: disable=protected-access
    for t in tags:
      if t.type == "function_node":
        if t.name == compat.as_str(self._func.name):
          g = self._func._graph
        elif g:
          next_func = g._get_function(t.name)
          if next_func is not None and isinstance(next_func,
                                                  _EagerDefinedFunction):
            g = next_func._graph
        if g:
          func_stack.append(g.name)
        else:
          func_stack.append("<unknown>")
    # pylint: enable=protected-access
    if g:
      message = error_interpolation.interpolate(message, g)
      message += "\n\nFunction call stack:\n"
      message += " -> ".join(func_stack)
      message += "\n"
      exc._message = message  # pylint: disable=protected-access
    return False


def _forward_name(n):
  """The name of a generated forward defun named n."""
  return "__forward_%s_%s" % (n, ops.uid())


def _backward_name(n):
  """The name of a generated backward defun named n."""
  return "__backward_%s_%s" % (n, ops.uid())


def _inference_name(n):
  """The name of a forward-but-no-gradient defun named n."""
  return "__inference_%s_%s" % (n, ops.uid())


# TODO(apassos) get rid of this by splitting framework.function._DefinedFunction
# so it doesn't have the definition-generating logic and is just a container
# for an already-defined function.
class _EagerDefinedFunction(object):
  """Callable with the interface of `framework.function._DefinedFunction`.

  `_EagerDefinedFunction` encapsulates a function definition and its
  properties, and it provides a method for calling the encapsulated function.
  Some Ops take functions as attributes, which have type `func`; an instance
  of this class may be provided as the value of these `func` attributes.
  """

  def __init__(self, name, graph, inputs, outputs, attrs):
    """Initializes an eager defined function.

    Args:
      name: str, the name for the created function.
      graph: Graph, the graph containing the operations in the function
      inputs: the tensors in the graph to be used as inputs to the function
      outputs: the tensors in the graph which will be outputs to the function
      attrs: dict mapping names of attributes to their AttrValue values
    """
    input_ops = set(arg.op for arg in inputs)
    operations = [op for op in graph.get_operations() if op not in input_ops]

    fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
        graph._c_graph,  # pylint: disable=protected-access
        compat.as_str(name),
        False,
        [o._c_op for o in operations],  # pylint: disable=protected-access
        [t._as_tf_output() for t in inputs],  # pylint: disable=protected-access
        [t._as_tf_output() for t in outputs],  # pylint: disable=protected-access
        [],
        [o._c_op for o in graph.control_outputs],  # pylint: disable=protected-access
        [],  # control_output_names
        None,
        compat.as_str(""))

    for name, attr_value in attrs.items():
      serialized = attr_value.SerializeToString()
      # TODO(iga): this creates and deletes a new TF_Status for every attr.
      # It might be worth creating a convenient way to re-use status.
      pywrap_tensorflow.TF_FunctionSetAttrValueProto(
          fn, compat.as_str(name), serialized)

    # TODO(apassos) avoid creating a FunctionDef (especially to grab the
    # signature), but also in general it's nice not to depend on it.
    with c_api_util.tf_buffer() as buffer_:
      pywrap_tensorflow.TF_FunctionToFunctionDef(fn, buffer_)
      proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
    function_def = function_pb2.FunctionDef()
    function_def.ParseFromString(compat.as_bytes(proto_data))
    with ops.init_scope():
      if context.executing_eagerly():
        context.add_function(fn)
    self.definition = function_def
    self.name = compat.as_bytes(function_def.signature.name)
    self.signature = function_def.signature
    self._num_outputs = len(self.signature.output_arg)
    self._output_types = [o.type for o in self.signature.output_arg]
    self._output_shapes = [o.shape for o in outputs]
    self._control_captures = graph.control_captures
    self._func_graph_outputs = outputs
    self.grad_func_name = None
    self.python_grad_func = None
    self._c_func = c_api_util.ScopedTFFunction(fn)
    self._grad_func = None
    self._graph = graph
    self._stateful_ops = tuple(op for op in operations if op.op_def.is_stateful)

  def add_to_graph(self, g=None):
    # pylint: disable=protected-access
    if not g and context.executing_eagerly():
      context.context().add_function_def(self.definition)
    else:
      if self.name not in g._functions:
        g._add_function(self)
      for f in self._graph._functions.values():
        if f.name not in g._functions:
          g._add_function(f)
    # pylint: enable=protected-access

  @property
  def stateful_ops(self):
    return self._stateful_ops

  def call(self, ctx, args):
    """Calls this function with `args` as inputs.

    `ConcreteFunction` execution respects device annotations only if the
    function won't be compiled with xla.

    Args:
      ctx: a Context object
      args: a list of arguments to supply this function with.

    Returns:
      The outputs of the function call.

    Raises:
      ValueError: if the number of arguments is incorrect.
    """
    if len(args) != len(self.signature.input_arg):
      raise ValueError(
          "Arguments and signature arguments do not match: %s %s " %
          (len(args), len(list(self.signature.input_arg))))

    function_call_options = ctx.function_call_options
    if function_call_options.config_proto_serialized is None:
      config = function_utils.get_disabled_rewriter_config()
    else:
      config = function_call_options.config_proto_serialized
    executor_type = function_call_options.executor_type or ""

    executing_eagerly = ctx.executing_eagerly()
    if executing_eagerly:
      with _InterpolateFunctionError(self):
        outputs = execute.execute(
            str(self.signature.name),
            num_outputs=self._num_outputs,
            inputs=args,
            attrs=("executor_type", executor_type,
                   "config_proto", config),
            ctx=ctx)
      # Replace empty list with None
      outputs = outputs or None
    else:
      # TODO(akshayka): Either remove this if the FunctionLibraryRuntime
      # creates `PartitionedCallOp` kernels by default, or remove the previous
      # branch if a TPU kernel is registered for `PartitionedCall`.
      with _InterpolateFunctionError(self):
        with ops.control_dependencies(self._control_captures):
          outputs = functional_ops.partitioned_call(
              args=args,
              f=self,
              tout=self._output_types,
              executing_eagerly=executing_eagerly,
              config=config,
              executor_type=executor_type)

    if executing_eagerly:
      return outputs
    else:
      for i, shape in enumerate(self._output_shapes):
        outputs[i].set_shape(shape)
      for i, func_graph_output in enumerate(self._func_graph_outputs):
        custom_gradient.copy_handle_data(func_graph_output, outputs[i])
      return outputs


class ConcreteFunction(object):
  """Callable object encapsulating a function definition and its gradient.

  `ConcreteFunction` is a callable that encapsulates a function definition and
  is differentiable under `tf.GradientTape` objects.
  """

  def __init__(self, func_graph, attrs=None, signature=None):
    """Initialize a `ConcreteFunction`.

    Args:
      func_graph: An instance of FuncGraph: the function body to wrap.
      attrs: (optional) dict mapping names of attributes to their AttrValue
        values. Attributes in `attrs` will be included in this function's
        definition.
      signature: a nested sequence of `TensorSpec` objects specifying the input
        signature of this function.

    Raises:
      ValueError: If number of input_placeholders is not equal to the number
        of function inputs.
    """
    self._arg_keywords = None
    self._num_positional_args = None
    self._func_graph = func_graph
    self._captured_inputs = list(self._func_graph.captures.keys())
    self._num_outputs = len(self._func_graph.outputs)
    self._output_shapes = tuple(
        output.shape for output in self._func_graph.outputs)
    self._attrs = _parse_func_attrs(attrs or {})

    self._inference_function = _EagerDefinedFunction(
        _inference_name(self._func_graph.name), self._func_graph,
        self._func_graph.inputs, self._func_graph.outputs, self._attrs)
    self._backward_graph_function = None
    self._signature = signature
    self._gradient_name = None

  def __call__(self, *args, **kwargs):
    """Executes the wrapped function.

    Args:
      *args: Tensors or Variables. Positional arguments are only accepted when
        they correspond one-to-one with arguments of the traced Python
        function.
      **kwargs: Tensors or Variables specified by name. When
        `get_concrete_function` was called to create this `ConcreteFunction`,
        each Tensor input was given a name, defaulting to the name of the
        Python function's argument but possibly overridden by the `name=`
        argument to `tf.TensorSpec`. These names become the argument names for
        the concrete function.

    Returns:
      The result of applying the TF function on the given Tensors.

    Raises:
      AssertionError: If this `ConcreteFunction` was not created through
        `get_concrete_function`.
      ValueError: If arguments contain anything other than Tensors or
        Variables.
      TypeError: For invalid positional/keyword argument combinations.
    """
    if self._arg_keywords is None or self._num_positional_args is None:
      if self._signature is not None:
        if kwargs:
          raise NotImplementedError(
              "Keyword arguments not supported when calling a "
              "wrap_function-decorated function.")
        return self._call_flat(args)
      raise AssertionError(
          "Tried to call a concrete function obtained from an internal API "
          "through the public interface. Use get_concrete_function instead.")
    if len(args) > self._num_positional_args:
      raise TypeError(
          ("Expected at most {} positional arguments (and the rest keywords, "
           "of {}), got {}. When calling a concrete function, positional "
           "arguments may not be bound to Tensors within nested structures."
          ).format(self._num_positional_args, self._arg_keywords, args))
    args = list(args)
    for keyword in self._arg_keywords[len(args):]:
      try:
        args.append(kwargs.pop(compat.as_str(keyword)))
      except KeyError:
        specified_keywords = (list(self._arg_keywords[:len(args)])
                              + list(kwargs.keys()))
        raise TypeError(
            "Expected argument names {} but got values for {}. Missing: {}."
            .format(
                list(self._arg_keywords),
                specified_keywords,
                list(set(self._arg_keywords) - set(specified_keywords))))
    if kwargs:
      positional_arg_keywords = set(self._arg_keywords[:len(args)])
      for unused_key in kwargs:
        if unused_key in positional_arg_keywords:
          raise TypeError("Got two values for keyword '{}'.".format(unused_key))
      raise TypeError("Keyword arguments {} unknown. Expected {}.".format(
          list(kwargs.keys()), list(self._arg_keywords)))
    return self._call_flat(args)

  def _filtered_call(self, args, kwargs):
    """Executes the function, filtering arguments from the Python function.

    Objects aside from Tensors and Variables are ignored.

    Args:
      args: Canonicalized positional arguments of the Python function.
      kwargs: Canonicalized keyword arguments of the Python function.

    Returns:
      The result of applying the function on the Tensors/Variables contained in
      `args` and `kwargs`.
    """
    return self._call_flat(
        (t for t in nest.flatten((args, kwargs))
         if isinstance(t, (ops.Tensor,
                           resource_variable_ops.ResourceVariable))))

  def _call_flat(self, args):
    """Executes the wrapped function.

    Args:
      args: a list of Tensors or Variables.

    Returns:
      The result of applying the TF function to `args`.

    Raises:
      ValueError: If `args` contains anything other than Tensors or Variables.
    """
    ctx = context.context()

    tape.variables_accessed(self._func_graph.variables)

    tensor_inputs = []
    variables_used = set([])
    for i, arg in enumerate(args):
      if isinstance(arg, resource_variable_ops.ResourceVariable):
        # We can pass a variable more than once, and in this case we need to
        # pass its handle only once.
        if arg.handle in variables_used:
          continue
        if arg.trainable:
          tape.variable_accessed(arg)
        tensor_inputs.append(arg.handle)
        variables_used.add(arg.handle)
      elif isinstance(arg, ops.Tensor):
        tensor_inputs.append(arg)
      elif (self._signature is not None and
            isinstance(self._signature[i], tensor_spec.TensorSpec)):
        tensor_inputs.append(
            ops.convert_to_tensor(arg, self._signature[i].dtype))
      else:
        raise ValueError("All inputs to `ConcreteFunction`s must be Tensors; "
                         "on invocation of %s, the %d-th input (%s) was not a "
                         "Tensor." % (self._func_graph.name, i, str(arg)))
    args = tensor_inputs + self._captured_inputs

    if (tape.should_record(tensor_inputs) or
        tape.should_record(self._captured_inputs)):
      if context.executing_eagerly():
        return self._eager_backprop_call(args)
      else:
        return self._backprop_call_with_delayed_rewrite(args)

    # Only need to override the gradient in graph mode and when we have
    # outputs.
    if context.executing_eagerly() or not self.outputs:
      outputs = self._inference_function.call(ctx, args)
    else:
      self._register_gradient()
      with ops.get_default_graph().gradient_override_map(
          {"PartitionedCall": self._gradient_name,
           "StatefulPartitionedCall": self._gradient_name}):
        outputs = self._inference_function.call(ctx, args)
    return self._build_call_outputs(outputs)

  def _register_gradient(self):
    """Registers the gradient for this `ConcreteFunction`.

    The gradient rewrites an inference call op to a forward call op, but does
    not modify a pre-existing forward call op. It then computes the gradient
    from the output's gradients and the side outputs of the forward op.
    """
    if self._gradient_name:
      return
    self._gradient_name = "PartitionedCall-%s" % ops.uid()

    @ops.RegisterGradient(self._gradient_name)
    def _registered_grad_fn(op, *doutputs):  # pylint: disable=unused-variable
      return self._grad_fn(op, *doutputs)

  def _grad_fn(self, op, *doutputs):
    """Gradients of this function."""
    if self._backward_graph_function is None:
      self._construct_backprop_function()

    # pylint: disable=protected-access
    self._forward_function.add_to_graph(op.graph)
    num_inference_outputs = self._inference_function._num_outputs

    # Rewrite an inference call op to be a forward call op
    if op.get_attr("f").name.encode() == self._inference_function.name:
      op._set_func_attr("f", self._forward_function.name)
      op._set_type_list_attr("Tout", self._forward_function._output_types)
      op._add_outputs(
          self._forward_function._output_types[num_inference_outputs:],
          self._forward_function._output_shapes[num_inference_outputs:])
      for i in range(num_inference_outputs, len(op.outputs)):
        func_graph_output = self._forward_function._func_graph_outputs[i]
        custom_gradient.copy_handle_data(func_graph_output, op.outputs[i])
    # pylint: enable=protected-access
    # Compute the gradients using the side outputs
    side_outputs = op.outputs[num_inference_outputs:]
    args = list(doutputs[:num_inference_outputs]) + list(side_outputs)
    return self._backward_graph_function._call_flat(  # pylint: disable=protected-access
        (a for a in args if a is not None))

  @property
  def name(self):
    """`ConcreteFunction` name."""
    return self._inference_function.name

  @property
  def graph(self):
    """Returns the graph from which this function was constructed."""
    return self._func_graph

  @property
  def inputs(self):
    """Returns tensors in `self.graph` corresponding to arguments."""
    return self._func_graph.inputs

  @property
  def structured_input_signature(self):
    """Returns structured signature of the original function."""
    return self._func_graph.structured_input_signature

  @property
  def outputs(self):
    """Returns tensors in `self.graph` corresponding to returned tensors."""
    return self._func_graph.outputs

  @property
  def structured_outputs(self):
    """Returns outputs in `self.graph` as returned by the original function."""
    return self._func_graph.structured_outputs

  @property
  def captured_inputs(self):
    """Returns external Tensors captured by this function.

    self.__call__(*args) passes `args + self.captured_inputs` to the function.
    """
    return self._captured_inputs

  @property
  def function_def(self):
    """Returns a `FunctionDef` object representing this function."""
    return self._inference_function.definition

  @property
  def output_shapes(self):
    """The function's output shapes."""
    # TODO(ebrevdo): Should we only keep the output shapes associated
    # with len(self._python_returns) outputs?
    # TODO(akshayka): Consider removing this.
    outputs_list = nest.flatten(self._func_graph.structured_outputs)
    j = 0
    for i, o in enumerate(outputs_list):
      if o is not None:
        if isinstance(o, ops.IndexedSlices):
          # Extract the shape of the `IndexedSlices` object's `values` field.
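          # An IndexedSlices output with a dense_shape occupies three
          # consecutive slots in the flat `self._output_shapes` list
          # (values, indices, dense_shape); only the `values` shape is kept
          # here, and `j` skips past the rest below.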
          outputs_list[i] = self._output_shapes[j]  # the `values` shape
          if o.dense_shape is not None:
            j += 3  # skip over shapes for `values`, `indices`, `dense_shape`
          else:
            j += 2  # skip over shapes for `values`, `indices`
        else:
          outputs_list[i] = self._output_shapes[j]
          j += 1
    return nest.pack_sequence_as(self._func_graph.structured_outputs,
                                 outputs_list)

  @property
  def output_dtypes(self):
    # TODO(akshayka): Consider removing this.
    return nest.map_structure(lambda x: x.dtype if x is not None else None,
                              self._func_graph.structured_outputs)

  def add_to_graph(self, g=None, register_gradient_functions=False):
    """Registers the function, adds it to the graph g or default graph."""
    # If we are not executing eagerly, adds the function to the default graph
    # if no graph is specified.
    # In case of eager execution, the function definition gets added to the
    # context during construction itself.

    # TODO(allenl/shivaniagrawal): rename this to register to reflect the
    # method's functionality better. Remove the register_gradient_functions
    # argument and figure out if these need to be registered.

    if not context.executing_eagerly() and not g:
      g = ops.get_default_graph()
    self._inference_function.add_to_graph(g)  # pylint: disable=protected-access

    # pylint: disable=protected-access
    if register_gradient_functions:
      # There are two situations for the actual call of a defun:
      # 1. If none of the input args are resource variables or watched by any
      #    tape, it will run the _inference_function of concrete_func for the
      #    forward pass, and the gradient will be generated by the standard
      #    mechanism.
      # 2. Otherwise, defun will create two functions, one for the forward
      #    pass, and the backward pass will be created via tape.
      # When registering the function, we register both cases.
      if self._backward_graph_function is None:
        self._construct_backprop_function()
      forward_function = self._forward_function
      backward_function = self._backward_graph_function._inference_function
      # pylint: enable=protected-access
      forward_function.add_to_graph(g)
      backward_function.add_to_graph(g)

  def _construct_backprop_function(self):
    """Constructs the backprop function object for this function."""
    backwards_graph = func_graph_module.FuncGraph(
        _backward_name(self._func_graph.name))
    forward_function_name = _forward_name(self._func_graph.name)
    outputs = [x for x in self._func_graph.outputs
               if gradients_util.IsTrainable(x)]
    with backwards_graph.as_default():
      gradients_wrt_outputs = [
          graph_placeholder(x.dtype, x.shape) for x in outputs
      ]
      gradients_wrt_inputs = gradients_util._GradientsHelper(  # pylint: disable=protected-access
          outputs,
          self._func_graph.inputs,
          grad_ys=gradients_wrt_outputs,
          src_graph=self._func_graph)

    backwards_graph_captures = list(backwards_graph.captures.keys())

    backward_function_attr = _parse_func_attrs(
        {FORWARD_FUNCTION_ATTRIBUTE_NAME: forward_function_name})
    backward_function_attr.update(self._attrs)

    # The ordering of `backwards_graph.inputs` is important: inputs of
    # `self._backward_graph_function` correspond to outputs of
    # `self._forward_function`.
    backwards_graph.inputs = gradients_wrt_outputs + list(
        backwards_graph.captures.values())
    # Clear captures, since we pass them in as inputs.
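    # (The captured tensors were appended to `backwards_graph.inputs` above;
    # the forward function constructed below exposes the same tensors as
    # extra outputs so the backward function can consume them explicitly.)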
    backwards_graph.captures = {}
    backwards_graph.outputs.extend(
        grad
        for grad in nest.flatten(gradients_wrt_inputs, expand_composites=True)
        if grad is not None)
    backwards_graph.structured_outputs = gradients_wrt_inputs
    self._backward_graph_function = ConcreteFunction(
        backwards_graph, attrs=backward_function_attr)

    forward_function_attr = _parse_func_attrs({
        BACKWARD_FUNCTION_ATTRIBUTE_NAME:
            self._backward_graph_function._inference_function.name})  # pylint: disable=protected-access
    forward_function_attr.update(self._attrs)
    self._forward_function = _EagerDefinedFunction(
        forward_function_name, self._func_graph, self._func_graph.inputs,
        self._func_graph.outputs + backwards_graph_captures,
        forward_function_attr)

  def _eager_backprop_call(self, args):
    """Calls the forward function and records the result on a tape.

    This method fully constructs the forward and backward functions before
    calling the function and recording them on the tape.

    (Only records results on a tape if the function has outputs).

    Args:
      args: All inputs to the function, including resolved captured inputs

    Returns:
      The call output.
    """
    if self._backward_graph_function is None:
      self._construct_backprop_function()

    ctx = context.context()

    self._register_gradient()
    with ops.get_default_graph().gradient_override_map(
        {"PartitionedCall": self._gradient_name,
         "StatefulPartitionedCall": self._gradient_name}):
      outputs = self._forward_function.call(ctx, args)

    if isinstance(outputs, ops.Operation) or outputs is None:
      return outputs

    # `real_outputs` are the actual outputs of the inference graph function;
    # `side_outputs` are the intermediate Tensors that were added as outputs to
    # the forward graph function so that we can compute its gradient.
    real_outputs = outputs[:self._num_outputs]
    skip_positions = [i for i, t in enumerate(real_outputs)
                      if not gradients_util.IsTrainable(t)]
    side_outputs = outputs[self._num_outputs:]

    def backward_function(*args):
      args = [a for i, a in enumerate(args)
              if a is not None and i not in skip_positions]
      return self._backward_graph_function._call_flat(  # pylint: disable=protected-access
          list(args) + side_outputs)

    tape.record_operation(self._forward_function.signature.name, real_outputs,
                          args, backward_function)
    return self._build_call_outputs(real_outputs)

  def _backprop_call_with_delayed_rewrite(self, args):
    """Calls the inference function and records the result on a tape.

    The recorded backwards function will construct the backwards graph and
    rewrite the inference function to the forward function. This only happens
    if the recorded backwards function ends up being used to compute gradients.

    This approach avoids constructing unnecessary graphs, but it only works if
    we are calling this function when not executing eagerly.

    (Only records results on a tape if the function has outputs)

    Args:
      args: All inputs to the function, including resolved captured inputs

    Returns:
      The call output.
    """
    ctx = context.context()

    self._register_gradient()
    with ops.get_default_graph().gradient_override_map(
        {"PartitionedCall": self._gradient_name,
         "StatefulPartitionedCall": self._gradient_name}):
      outputs = self._inference_function.call(ctx, args)

    if isinstance(outputs, ops.Operation) or outputs is None:
      return outputs

    call_op = outputs[0].op

    def backward_function(*args):
      return self._grad_fn(call_op, *args)

    tape.record_operation(self._inference_function.signature.name, outputs,
                          args, backward_function)
    return self._build_call_outputs(outputs)

  def _build_call_outputs(self, result):
    """Maps the fdef output list to actual output structure.

    Args:
      result: Output lists defined by FunctionDef.
    Returns:
      The actual call output.
    """
    if self._func_graph.structured_outputs is None:
      return result

    # Use `nest.flatten` instead of `func_graph_module.flatten` in order to
    # preserve any IndexedSlices in `self._func_graph.structured_outputs`.
    outputs_list = nest.flatten(self._func_graph.structured_outputs)
    j = 0
    for i, o in enumerate(outputs_list):
      if o is not None:
        if isinstance(o, ops.IndexedSlices):
          # Repack Tensors for IndexedSlices.
          if o.dense_shape is not None:
            outputs_list[i] = ops.IndexedSlices(
                values=result[j],
                indices=result[j + 1],
                dense_shape=result[j + 2])
            j += 3
          else:
            outputs_list[i] = ops.IndexedSlices(
                values=result[j], indices=result[j + 1])
            j += 2
        else:
          outputs_list[i] = result[j]
          j += 1
    ret = nest.pack_sequence_as(self._func_graph.structured_outputs,
                                outputs_list)
    return ret


pywrap_tensorflow.RegisterType("Tensor", ops.Tensor)
pywrap_tensorflow.RegisterType("IndexedSlices", ops.IndexedSlices)


def _deterministic_dict_values(dictionary):
  return tuple(dictionary[key] for key in sorted(dictionary))


class FunctionSpec(object):
  """Specification of how to bind arguments to a function."""

  @staticmethod
  def from_function_and_signature(python_function, input_signature):
    """Create a FunctionSpec instance given a python function and signature."""
    if isinstance(python_function, functools.partial):
      python_function_to_inspect = python_function.func
      args_to_prepend = python_function.args or tuple()
      kwargs_to_include = python_function.keywords or {}
      if input_signature is not None:
        # TODO(b/124441704): Add support for input_signature + partial.
        raise NotImplementedError(
            "Missing support for input_signature when using partial functions.")
    else:
      python_function_to_inspect = python_function
      args_to_prepend = tuple()
      kwargs_to_include = {}

    fullargspec = tf_inspect.getfullargspec(python_function_to_inspect)
    is_method = tf_inspect.ismethod(python_function_to_inspect)

    return FunctionSpec(fullargspec, is_method, args_to_prepend,
                        kwargs_to_include, input_signature)

  def __init__(self, fullargspec, is_method, args_to_prepend, kwargs_to_include,
               input_signature):
    self._fullargspec = fullargspec
    self._is_method = is_method
    self._args_to_prepend = args_to_prepend
    self._kwargs_to_include = kwargs_to_include
    self._default_values = fullargspec.defaults

    if self._is_method:
      # Remove `self`: default arguments shouldn't be matched to it.
      args = fullargspec.args[1:]
    else:
      args = fullargspec.args

    # A cache mapping from argument name to index, for canonicalizing
    # arguments that are called in a keyword-like fashion.
    self._args_to_indices = {arg: i for i, arg in enumerate(args)}
    self.arg_names = args
    self.vararg_name = fullargspec.varargs

    # A cache mapping from arg index to default value, for canonicalization.
    offset = len(args) - len(fullargspec.defaults or [])
    self._arg_indices_to_default_values = {
        offset + index: default
        for index, default in enumerate(fullargspec.defaults or [])
    }
    self._default_values_start_index = offset
    if input_signature is None:
      self._input_signature = None
    else:
      if fullargspec.varkw is not None or fullargspec.kwonlyargs:
        raise ValueError("Cannot define a TensorFlow function from a Python "
                         "function with keyword arguments when "
                         "input_signature is provided.")

      if not isinstance(input_signature, (tuple, list)):
        raise TypeError("input_signature must be either a tuple or a "
                        "list, received " + str(type(input_signature)))

      self._input_signature = tuple(input_signature)
      self._flat_input_signature = tuple(nest.flatten(input_signature))

  @property
  def fullargspec(self):
    return self._fullargspec

  @property
  def is_method(self):
    return self._is_method

  @property
  def args_to_prepend(self):
    return self._args_to_prepend

  @property
  def kwargs_to_include(self):
    return self._kwargs_to_include

  @property
  def input_signature(self):
    return self._input_signature

  @property
  def flat_input_signature(self):
    return self._flat_input_signature

  def canonicalize_function_inputs(self, *args, **kwargs):
    """Canonicalizes `args` and `kwargs`.

    Canonicalizes the inputs to the Python function using a `FunctionSpec`
    instance. In particular, we parse the varargs and kwargs that the
    original function was called with into a tuple corresponding to the
    Python function's positional (named) arguments and a dictionary
    corresponding to its kwargs.

    Args:
      *args: The varargs this object was called with.
      **kwargs: The keyword args this function was called with.

    Returns:
      A canonicalized ordering of the inputs represented by a tuple in the
      form (args, kwargs). Here: `args` is a full list of bound arguments, and
      `kwargs` contains only true keyword arguments, as opposed to named
      arguments called in a keyword-like fashion.

    Raises:
      ValueError: If a keyword in `kwargs` cannot be matched with a positional
        argument when an input signature is specified, or when the inputs
        do not conform to the input signature.
    """
    if self._input_signature is not None:
      if len(args) > len(self._input_signature):
        raise TypeError(
            "When input_signature is provided, only pass arguments "
            "covered by it. Received %d argument(s)." % len(args))
      for arg in six.iterkeys(kwargs):
        index = self._args_to_indices.get(arg, None)
        if index is None:
          raise TypeError(
              "Function got an unexpected keyword argument %s" % arg)
        if index >= len(self._input_signature):
          raise TypeError(
              "When input_signature is provided, only pass arguments "
              "covered by it. Received argument %s." % arg)

    args = self._args_to_prepend + args
    kwargs = dict(kwargs, **self._kwargs_to_include)
    if not kwargs:
      if self._default_values:
        inputs = args + self._default_values[
            len(args) - self._default_values_start_index:]
      else:
        inputs = args
    else:
      # Maps from index of arg to its corresponding value, according to `args`
      # and `kwargs`; seeded with the default values for the named args that
      # aren't in `args`.
      arg_indices_to_values = {
          index: default for index, default in six.iteritems(
              self._arg_indices_to_default_values) if index >= len(args)
      }
      consumed_args = []
      for arg, value in six.iteritems(kwargs):
        index = self._args_to_indices.get(arg, None)
        if index is not None:
          arg_indices_to_values[index] = value
          consumed_args.append(arg)
        elif self._input_signature is not None:
          raise ValueError("Cannot define a TensorFlow function from a Python "
                           "function with keyword arguments when "
                           "input_signature is provided.")
      for arg in consumed_args:
        # After this loop, `kwargs` will only contain true keyword arguments,
        # as opposed to named arguments called in a keyword-like fashion.
        kwargs.pop(arg)
      inputs = args + _deterministic_dict_values(arg_indices_to_values)

    if self._input_signature is None:
      inputs = _convert_numpy_inputs(inputs)
      return inputs, kwargs
    else:
      assert not kwargs
      inputs = _convert_inputs_to_signature(
          inputs,
          self._input_signature,
          self._flat_input_signature)
      return inputs, {}


def _convert_numpy_inputs(inputs):
  """Convert numpy array inputs to tensors."""
  flat_inputs = nest.flatten(inputs)

  # Check for NumPy arrays in arguments and convert them to Tensors.
  # TODO(nareshmodi): Skip ndarray conversion to tensor altogether, perhaps
  # finding a way to store them directly in the cache key (currently not
  # possible since ndarrays are not hashable).
  need_packing = False
  for index, value in enumerate(flat_inputs):
    if type(value) == np.ndarray:
      flat_inputs[index] = constant_op.constant(value)
      need_packing = True
  if need_packing:
    return nest.pack_sequence_as(
        structure=inputs, flat_sequence=flat_inputs)
  else:
    return inputs


def _convert_inputs_to_signature(inputs, input_signature, flat_input_signature):
  """Convert inputs to pass into a function with an explicit signature."""
  try:
    # TODO(b/124370185): Use all elements as inputs to throw an error if there
    # are ignored arguments. Calling with arguments that are not part of the
    # signature should throw an error.
    flatten_inputs = nest.flatten_up_to(
        input_signature,
        inputs[:len(input_signature)])
  except ValueError:
    raise ValueError("Structure of Python function inputs does not match "
                     "input_signature. Inputs (%s), input_signature(%s)." %
                     (str(inputs), str(input_signature)))

  need_packing = False
  for index, (value, spec) in enumerate(zip(flatten_inputs,
                                            flat_input_signature)):
    if not pywrap_tensorflow.IsTensor(value):
      try:
        flatten_inputs[index] = ops.convert_to_tensor(
            value, dtype_hint=spec.dtype)
        need_packing = True
      except ValueError:
        raise ValueError("When input_signature is provided, all inputs to "
                         "the Python function must be convertible to tensors."
                         "Inputs (%s), input_signature(%s)." %
                         (str(inputs), str(input_signature)))

  if any(not spec.is_compatible_with(other) for spec, other in zip(
      flat_input_signature,
      flatten_inputs)):
    raise ValueError("Python inputs incompatible with input_signature: "
                     "inputs (%s), input_signature (%s)" %
                     (str(inputs), str(input_signature)))

  if need_packing:
    inputs = nest.pack_sequence_as(
        structure=input_signature,
        flat_sequence=flatten_inputs)

  return inputs


class FunctionCache(object):
  """A lightweight container for cached functions.
  """

  def __init__(self):
    # The set of functions that have been missed; entries are CacheKey with
    # input_signature `None` (e.g. a "call context key")
    self.missed = set()
    # The primary cache, mapping a fully shaped CacheKey to a function.
    self.primary = collections.OrderedDict()
    # A cache key lookup, mapping a CacheKey generated without shape info to a
    # flat list of relaxed shapes (one for each argument). Arguments that are
    # not Tensors contain a `None` for the corresponding relaxed shape.
    self.arg_relaxed_shapes = collections.OrderedDict()
    # The secondary cache, mapping a CacheKey generated without shape info to a
    # function.
    self.arg_relaxed = collections.OrderedDict()
    # All OrderedDicts require manual garbage collection.
    self._garbage_collectors = [
        _FunctionGarbageCollector(self.primary),
        _FunctionGarbageCollector(self.arg_relaxed),
        _FunctionGarbageCollector(self.arg_relaxed_shapes)]

  def all_values(self):
    """A set of all `ConcreteFunction` instances held by this cache."""
    return set(self.primary.values()) | set(self.arg_relaxed.values())


class Function(object):
  """Wrapper class for the graph functions defined for a Python function.

  See the documentation for `defun` for more information on the semantics of
  defined functions.

  The `Function` class is thread-compatible, meaning that minimal usage of
  defuns (defining and calling) is thread-safe, but if users call other
  methods or invoke the base `python_function` themselves, external
  synchronization is necessary.
  """

  def __init__(self,
               python_function,
               name,
               input_signature=None,
               attributes=None,
               autograph=True,
               autograph_options=None,
               capture_by_value=None):
    """Initializes a `Function`.

    Args:
      python_function: the function to be wrapped.
      name: the name given to it.
      input_signature: a possibly nested sequence of `TensorSpec` objects
        specifying the input signature of this function. If `None`, a separate
        function is instantiated for each inferred input signature.
      attributes: dict, extra keyword arguments that will be added as
        attributes of the function.
      autograph: whether to use autograph to compile
        `python_function`. See https://www.tensorflow.org/guide/autograph for
        more information.
      autograph_options: Experimental knobs to control behavior when
        `autograph=True`. See https://www.tensorflow.org/guide/autograph
        for more information.
      capture_by_value: Experimental. Whether to capture resource variables by
        value or reference. If None, will inherit from a parent context or
        default to False.

    Raises:
      ValueError: if `input_signature` is not None and the `python_function`'s
        argspec has keyword arguments.
    """
    if isinstance(python_function, functools.partial):
      self._python_function = python_function.func
    else:
      self._python_function = python_function
    self._function_spec = FunctionSpec.from_function_and_signature(
        python_function, input_signature)
    self._name = name
    self._autograph = autograph
    self._autograph_options = autograph_options
    self._function_cache = FunctionCache()
    self._function_attributes = attributes or {}
    self._capture_by_value = capture_by_value

    self._lock = threading.Lock()
    # _descriptor_cache is a mapping from an instance of a class to an
    # instance-specific `Function`, used to make sure defun-decorated methods
    # create different functions for each instance.
    self._descriptor_cache = weakref.WeakKeyDictionary()

  def __call__(self, *args, **kwargs):
    """Calls a graph function specialized to the inputs."""
    graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
    return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access

  @property
  def python_function(self):
    """Returns the wrapped Python function."""
    return self._python_function  # pylint: disable=protected-access

  @property
  def function_spec(self):
    return self._function_spec

  @property
  def input_signature(self):
    """Returns the input signature."""
    return self._function_spec.input_signature

  @property
  def flat_input_signature(self):
    """Returns the flattened input signature."""
    return self._function_spec.flat_input_signature

  def _get_concrete_function_internal_garbage_collected(self, *args, **kwargs):
    """Returns a concrete function which cleans up its graph function."""
    if self.input_signature:
      args, kwargs = None, None
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
    return graph_function

  def _get_concrete_function_internal(self, *args, **kwargs):
    """Bypasses error checking when getting a graph function."""
    graph_function = self._get_concrete_function_internal_garbage_collected(
        *args, **kwargs)
    # We're returning this concrete function to someone, and they may keep a
    # reference to the FuncGraph without keeping a reference to the
    # ConcreteFunction object. So we won't clean up the reference cycles
    # manually and instead will leave them to Python's garbage collector.
    graph_function._garbage_collector.release()  # pylint: disable=protected-access
    return graph_function

  def get_concrete_function(self, *args, **kwargs):
    """Returns a `ConcreteFunction` specialized to inputs and execution context.

    Args:
      *args: inputs to specialize on.
      **kwargs: inputs to specialize on.
    """
    if self.input_signature:
      if kwargs:
        raise ValueError("Cannot define a TensorFlow function from a Python "
                         "function with keyword arguments when "
                         "input_signature is provided.")
      if args:
        # If args are provided, they must match the input signature.
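        # For instance, with input_signature=(tf.TensorSpec([None], tf.int32),)
        # a call like get_concrete_function(tf.constant([1, 2, 3])) passes the
        # checks below, while get_concrete_function(tf.constant(1.0)) fails
        # the compatibility check.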
        if not is_same_structure(self.input_signature, args):
          raise ValueError("Structure of Python function inputs does not match "
                           "input_signature.")
        flat_inputs = nest.flatten(args)
        if any(not isinstance(arg, (ops.Tensor, tensor_spec.TensorSpec))
               for arg in flat_inputs):
          raise ValueError("When input_signature is provided, all inputs to "
                           "the Python function must be Tensors or "
                           "tf.TensorSpec objects.")
        if any(not spec.is_compatible_with(other)
               for spec, other in zip(self.flat_input_signature, flat_inputs)):
          raise ValueError("Python inputs incompatible with input_signature: "
                           "inputs (%s), input_signature (%s)" %
                           (str(args), str(self.input_signature)))
      args, kwargs = None, None
    graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
    if self.input_signature:
      args = self.input_signature
      kwargs = {}
    seen_names = set()
    captured = frozenset(graph_function.graph.internal_captures)
    allowed_positional = 0
    if args:
      for outer_arg in args:
        # TODO(allenl): Consider allowing arguments with defaults in the Python
        # function's signature to be passed as positional arguments to the
        # concrete function.
        if not isinstance(
            outer_arg,
            (ops.Tensor, resource_variable_ops.ResourceVariable,
             tensor_spec.TensorSpec)):
          break
        allowed_positional += 1
    # pylint: disable=protected-access
    graph_function._num_positional_args = allowed_positional
    graph_function._arg_keywords = []
    # pylint: enable=protected-access
    for arg in graph_function.graph.inputs:
      if arg in captured:
        break
      user_arg_name = arg.op.get_attr("_user_specified_name")
      if user_arg_name in seen_names:
        raise ValueError(
            ("Unable to construct a concrete function for {} since some "
             "arguments do not have unique names. Got two arguments named "
             "'{}'. When constructing a concrete TensorFlow function from a "
             "Python function which takes nested structures or variadic "
             "positional arguments, pass unique names to tf.TensorSpec objects "
             "used to identify these Tensor inputs. These names may then be "
             "used as keyword arguments to the concrete function.")
            .format(
                self._python_function,
                compat.as_str(arg.op.get_attr("_user_specified_name"))))
      seen_names.add(user_arg_name)
      graph_function._arg_keywords.append(user_arg_name)  # pylint: disable=protected-access
    return graph_function

  def __get__(self, instance, owner):
    """Makes it possible to defun instance methods."""
    del owner
    # `instance` here is the instance that this `Function` was accessed through
    # e.g., for
    #
    # class Foo(object):
    #
    #   @function.defun
    #   def bar(self):
    #     ...
    #
    # foo = Foo()
    # foo.bar()  # `foo.bar` is a `Function` instance
    #
    # then `instance` will be `foo` (and `owner` will be `Foo`). We create a
    # new instance of `Function` here to allow different instances each
    # to create variables once, thereby allowing methods to be decorated with
    # defun. Keeps a cache to avoid retracing the function every time the
    # descriptor is accessed.
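    # For example, with the `Foo` class above, `bar` accessed through two
    # different `Foo` instances resolves to two distinct instance-specific
    # `Function` objects, each with its own cache of traced graphs.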
    if instance not in self._descriptor_cache:
      if instance is None:
        return self
      # If there is no instance-specific `Function` in the cache, we construct
      # an instance-specific `Function` that uses a weak reference to the
      # instance (so that the instance will be correctly gc'd).

      # And finally add the wrapped function to the descriptor cache
      self._descriptor_cache[instance] = class_method_to_instance_method(
          self, instance)

    # Return the cached `Function` for the instance
    return self._descriptor_cache[instance]

  def _cache_key(self, args, kwargs, include_tensor_ranks_only=False):
    """Computes the cache key given inputs and execution context."""
    if self.input_signature is None:
      inputs = (args, kwargs) if kwargs else args
      input_signature = pywrap_tensorflow.TFE_Py_EncodeArg(
          inputs, include_tensor_ranks_only)
    else:
      del args, kwargs
      assert not include_tensor_ranks_only
      input_signature = self.flat_input_signature

    ctx = context.context()

    # Don't need to open an init_scope if the _cache_key call is in eager mode
    # already.
    executing_eagerly = ctx.executing_eagerly()
    parent_graph = None
    if not executing_eagerly:
      with ops.init_scope():
        # The graph, or whether we're executing eagerly, should be a part of
        # the cache key so we don't improperly capture tensors such as
        # variables.
        executing_eagerly = ctx.executing_eagerly()
        parent_graph = None if executing_eagerly else ops.get_default_graph()

    # pylint: disable=protected-access
    default_graph = ops.get_default_graph()
    # TODO(b/117617952): The current distribution strategy will affect graph
    # building (e.g. accessing different variables from different devices) and
    # so requires retracing for each device.
    uses_distribution_strategy = bool(
        default_graph._distribution_strategy_stack)
    if executing_eagerly:
      colocation_stack = ()
      if uses_distribution_strategy:
        device_functions = (pydev.merge_device(ctx.device_name),)
      else:
        device_functions = ()
    else:
      colocation_stack = tuple(default_graph._colocation_stack.peek_objs())
      if (uses_distribution_strategy
          or func_graph_module.device_stack_has_callable(
              default_graph._device_function_stack)):
        # Putting the device in the cache key ensures that call-site device
        # annotations are respected.
        device_functions = tuple(default_graph._device_functions_outer_to_inner)
      else:
        device_functions = ()
    # pylint: enable=protected-access
    return CacheKey(input_signature, parent_graph, device_functions,
                    colocation_stack)

  def _create_graph_function(self, args, kwargs, override_flat_arg_shapes=None):
    """Create a `ConcreteFunction` from `args` and `kwargs`."""
    if self.input_signature is None:
      arglen = len(args)
    else:
      arglen = len(self.input_signature)
    base_arg_names = self._function_spec.arg_names[:arglen]
    num_missing_args = arglen - len(self._function_spec.arg_names)
    missing_arg_names = [self._function_spec.vararg_name] * num_missing_args
    # Produce a list of missing args of the form ["arg_0", "arg_1", ...],
    # where arg is based on the self._function_spec.vararg_name.
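    # For example, a vararg_name of "args" with two missing positions yields
    # ["args_0", "args_1"].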
    missing_arg_names = [
        "%s_%d" % (arg, i) for i, arg in enumerate(missing_arg_names)
    ]
    arg_names = base_arg_names + missing_arg_names
    graph_function = ConcreteFunction(
        func_graph_module.func_graph_from_py_func(
            self._name,
            self._python_function,
            args,
            kwargs,
            self.input_signature,
            autograph=self._autograph,
            autograph_options=self._autograph_options,
            arg_names=arg_names,
            override_flat_arg_shapes=override_flat_arg_shapes,
            capture_by_value=self._capture_by_value),
        self._function_attributes)

    # pylint: disable=protected-access
    # Tell the ConcreteFunction to clean up its graph once it goes out of
    # scope. ConcreteFunction does not do this in its constructor since it
    # gets used in some places (like Keras) where the FuncGraph lives
    # longer than the ConcreteFunction.
    graph_function._garbage_collector = ConcreteFunctionGarbageCollector(
        graph_function.graph)
    # pylint: enable=protected-access

    return graph_function

  def _maybe_define_function(self, args, kwargs):
    """Gets a function for these inputs, defining it if necessary.

    `args` and `kwargs` can be None if this `Function` was created with an
    `input_signature`.

    Args:
      args: The varargs for the Python function.
      kwargs: The keyword args for the Python function.

    Returns:
      A graph function corresponding to the input signature implied by args and
      kwargs, as well as the inputs that the object should be called with.

    Raises:
      ValueError: If inputs are incompatible with the input signature.
      TypeError: If the function inputs include non-hashable objects.
      RuntimeError: If there's an internal bug (inconsistency) in handling
        shape relaxation retracing.
    """
    if self.input_signature is None or args is not None or kwargs is not None:
      args, kwargs = self._function_spec.canonicalize_function_inputs(
          *args, **kwargs)
    cache_key = self._cache_key(args, kwargs)

    try:
      hash(cache_key)
    except TypeError as e:
      raise TypeError(
          "Arguments supplied to `defun`-generated functions must be"
          " hashable. Original error: %s" % e)

    with self._lock:
      graph_function = self._function_cache.primary.get(cache_key, None)
      if graph_function is not None:
        return graph_function, args, kwargs

      logging.vlog(1,
                   "Creating new FuncGraph for Python function %r (key: %r)",
                   self._python_function, cache_key)
      logging.vlog(2,
                   "Python function signature [args: %s] [kwargs: %s]",
                   args,
                   kwargs)

      call_context_key = cache_key.replace(input_signature=None)

      # If there's a provided input signature, or
      # there's no cache miss for this calling context so far, go ahead and
      # build the function and bypass shape relaxation retracing.
      if (self.input_signature is not None
          or call_context_key not in self._function_cache.missed):
        self._function_cache.missed.add(call_context_key)
        graph_function = self._create_graph_function(args, kwargs)
        self._function_cache.primary[cache_key] = graph_function
        return graph_function, args, kwargs

      rank_only_cache_key = self._cache_key(
          args, kwargs, include_tensor_ranks_only=True)

      arg_shapes = _flat_shape_list(args, kwargs)
      relaxed_arg_shapes = self._function_cache.arg_relaxed_shapes.get(
          rank_only_cache_key, None)
      relaxed_arg_function = self._function_cache.arg_relaxed.get(
          rank_only_cache_key, None)

      if (relaxed_arg_function is not None
          and _compatible_shapes(flat_relaxed=relaxed_arg_shapes,
                                 flat_to_check=arg_shapes)):
        return relaxed_arg_function, args, kwargs

      if relaxed_arg_shapes is None:
        relaxed_arg_shapes = arg_shapes
      else:
        if len(arg_shapes) != len(relaxed_arg_shapes):
          raise RuntimeError("Expected arg_shapes len to match "
                             "relaxed_arg_shapes len: %d vs. %d"
                             % (len(arg_shapes), len(relaxed_arg_shapes)))
        relaxed_arg_shapes = [
            _common_shape(x, y) for (x, y) in zip(
                arg_shapes, relaxed_arg_shapes)]
      self._function_cache.arg_relaxed_shapes[rank_only_cache_key] = (
          relaxed_arg_shapes)
      graph_function = self._create_graph_function(
          args, kwargs, override_flat_arg_shapes=relaxed_arg_shapes)
      self._function_cache.arg_relaxed[rank_only_cache_key] = graph_function

      return graph_function, args, kwargs


def register(func, *args, **kwargs):
  """Register a specialization of a `Function` into the graph.

  This does not actually call the function with the given inputs; it only adds
  the function definition to the graph. Registering the function with different
  input parameters results in multiple versions of the function being
  registered in the graph.
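  As a minimal usage sketch (the `matmul_fn` function below is illustrative,
  not part of this module), a `defun`-wrapped function can be added to a graph
  without being executed:

  ```python
  import tensorflow as tf

  @tf.contrib.eager.defun
  def matmul_fn(a, b):
    return tf.matmul(a, b)

  with tf.Graph().as_default():
    t = tf.constant([[1.0, 2.0], [3.0, 4.0]])
    # Adds the traced function (and its gradient functions) to the default
    # graph without running it.
    concrete = register(matmul_fn, t, t)
  ```
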
  Args:
    func: the `Function` instance generated by a `@defun` decorator.
    *args: input arguments for the Python function.
    **kwargs: input keyword arguments for the Python function.

  Returns:
    A `ConcreteFunction` object specialized to inputs and execution context.

  Raises:
    ValueError: When `func` is not a `defun`-wrapped Python function.
  """
  if not isinstance(func, Function):
    raise ValueError("Only defun-wrapped functions can be registered. "
                     "Got type: %s" % type(func))
  concrete_func = func.get_concrete_function(*args, **kwargs)
  concrete_func.add_to_graph(register_gradient_functions=True)
  return concrete_func


def validate_signature(signature):
  if any(not isinstance(arg, tensor_spec.TensorSpec)
         for arg in nest.flatten(signature)):
    raise TypeError("Invalid input_signature %s; input_signature must be "
                    "a possibly nested sequence of TensorSpec objects." %
                    (signature,))


def defun(func=None,
          input_signature=None,
          autograph=True,
          experimental_autograph_options=None):
  """Compiles a Python function into a callable TensorFlow graph.

  `defun` (short for "define function") compiles a Python function
  composed of TensorFlow operations into a callable that executes a `tf.Graph`
  containing those operations. The callable produced by `defun` contains only
  the subgraph of TensorFlow operations that were executed when the Python
  function was called with a particular input signature, defined as a list
  of the shapes and dtypes of the Python function's Tensor-valued arguments and
  the values of its non-Tensor Python objects.

  When eager execution is enabled, the ability to create graphs from Python
  functions makes it possible to incrementally trade off debuggability and
  interactivity for performance. Functions compiled with `defun` cannot be
  inspected with `pdb`; however, executing a graph
  generated by `defun` sometimes takes less time and memory than eagerly
  executing the corresponding Python function, since specifying computations as
  graphs allows for optimizations like automatic buffer reuse and
  parallelization among ops. Note that executing a `defun`-compiled function
  incurs a small constant overhead, so eagerly executing sufficiently small
  Python functions might take less time than executing their corresponding
  `defun`-generated graphs.

  For a Python function to be compatible with `defun`, all of its arguments must
  be hashable Python objects or lists thereof. The function itself may not
  modify the list/map structure of its arguments. Additionally, it must return
  zero or more `tf.Tensor` objects. If the Python function returns
  a `tf.Variable`, its compiled version will return the value of that variable
  as a `tf.Tensor`.

  Executing a graph generated by `defun` respects device annotations (i.e.,
  all `with tf.device` directives present in a Python function will also be
  present in its corresponding graph), but it is not yet possible to execute the
  generated graphs across multiple machines.

  _Example Usage_

  ```python
  import tensorflow as tf

  tf.enable_eager_execution()

  # A simple example.
  def f(x, y):
    return tf.reduce_mean(tf.multiply(x ** 2, 3) + y)

  g = tf.contrib.eager.defun(f)

  x = tf.constant([[2.0, 3.0]])
  y = tf.constant([[3.0, -2.0]])

  # `f` and `g` will return the same value, but `g` will be executed as a
  # TensorFlow graph.
  assert f(x, y).numpy() == g(x, y).numpy()

  # `defun` is capable of compiling Python functions that close over Python
  # objects, including Tensors and Variables.
  @tf.contrib.eager.defun
  def h():
    return f(x, y)

  assert (h().numpy() == f(x, y).numpy()).all()

  # `defun` automatically lifts variables out of the graphs it creates,
  # allowing you to compile the `call` methods of `tf.keras.layers.Layer` and
  # `tf.keras.Model` objects.
  class MyModel(tf.keras.Model):

    def __init__(self, keep_probability=0.2):
      super(MyModel, self).__init__()
      self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
      self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
      self.keep_probability = keep_probability

    @tf.contrib.eager.defun
    def call(self, inputs, training=True):
      x = self.dense2(self.dense1(inputs))
      if training:
        return tf.nn.dropout(x, self.keep_probability)
      else:
        return x

  model = MyModel()
  model(x, training=True)   # executes a graph, with dropout
  model(x, training=False)  # executes a graph, without dropout

  # `defun`-compiled functions are differentiable.
  optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
  with tf.GradientTape() as tape:
    outputs = model(x)
  gradient = tape.gradient(outputs, model.trainable_variables)
  optimizer.apply_gradients((grad, var) for grad, var in zip(gradient,
                            model.trainable_variables))
  ```

  When using `defun`, there are subtleties regarding inputs, Python control
  flow, and variable creation that one should be aware of. For concreteness, let
  `f` be a Python function that returns zero or more `tf.Tensor` objects and
  let `F = defun(f)`. `F` builds a graph for each unique input signature it
  sees, Python control flow is baked into graphs, and operations related to
  variable initialization are automatically lifted out of the graphs that `F`
  generates and placed in the eager context if executing eagerly or into an
  outer graph otherwise.

  _Input Signatures_

  By default, `F = tf.contrib.eager.defun(f)` instantiates a separate graph
  for every unique sequence of the shapes and dtypes of Tensor arguments and
  the values of Python objects it is invoked with. For example, calling
  `F(tf.random_uniform([2]))` will execute a different graph than
  `F(tf.random_uniform([3]))` because the two inputs have different shapes.
  The first time that `F(*args, **kwargs)` is called with a particular sequence
  of Tensor shapes and dtypes and Python values, it constructs a graph by
  tracing the execution of `f(*args, **kwargs)`; this graph is bound to an
  input signature inferred from `(*args, **kwargs)` and cached for future reuse.

  NumPy arrays passed as inputs to `F` are converted to `tf.Tensor` objects
  before being passed to `f`, and are treated as Tensors for caching. This
  allows a function to be called multiple times with NumPy arrays having
  different values but the same shape and dtype without re-tracing each time.

  `tf.contrib.eager.defun` caches graphs for your convenience, letting you
  define TensorFlow functions without explicitly specifying their signatures.
  However, this policy is conservative and potentially expensive; for example,
  when different invocations of your function have differently-shaped Tensor
  inputs, this policy might generate more graph functions than necessary. To
  eliminate such costs, `tf.contrib.eager.defun` allows you to supply an
  optional `input_signature` argument specifying the shapes and dtypes of the
  inputs. In particular, the shapes may be partially unspecified, with `None`s
  in the unknown dimensions. When an input signature is provided,
  `tf.contrib.eager.defun` will only instantiate a single graph for the
  decorated Python function. The following is an example:

  ```python
  import tensorflow as tf

  # The first `TensorSpec` below describes the shape and dtype of `words`,
  # and the second describes the shape and dtype of `another_tensor`. Note that
  # the last dimension of the `words` `TensorSpec` is left unspecified.
  @tf.contrib.eager.defun(input_signature=[
      tf.contrib.eager.TensorSpec(shape=[50, 300, None], dtype=tf.float32),
      tf.contrib.eager.TensorSpec(shape=[300, 100], dtype=tf.float32)
  ])
  def my_sequence_model(words, another_tensor):
    ...

  # Note how the third dimension of the first input can vary freely.
  words = tf.random_uniform([50, 300, 10])
  second_input = tf.random_uniform([300, 100])
  my_sequence_model(words, second_input)

  words = tf.random_uniform([50, 300, 20])
  my_sequence_model(words, second_input)

  # Passing an input with an incompatible shape will raise an error.
  words = tf.random_uniform([50, 100, 20])
  my_sequence_model(words, second_input)  # <---- This will raise an error.
  ```

  Python functions that are compiled with an `input_signature` must only accept
  Tensors as arguments and must not take unnamed keyword arguments (**kwargs).

  _Tracing_

  Be aware that because `F` only logs TensorFlow operations, all the other
  Python code that `f` executes will only shape the _construction_ of the graphs
  that `F` executes: the Python code won't be executed when the graphs
  themselves are executed, though it will be executed every time the Python
  function is traced (and a given Python function might be traced multiple
  times, once for each input signature it is invoked with). For example, whereas
  the Python function

  ```python
  import tensorflow as tf
  import numpy as np

  tf.enable_eager_execution()

  def add_noise():
    return tf.eye(5) + np.random.randn(5, 5)
  ```

  will return a different output every time it is invoked, the compiled function
  `compiled = tf.contrib.eager.defun(add_noise)` will return the same value
  every time it is called, since a particular random offset generated by NumPy
  will be inserted into the graph as a TensorFlow constant. The solution is to
  replace the call to `np.random.randn` with `tf.random_normal((5, 5))`.

  _Python Side-Effects_

  A corollary of the previous discussion on tracing is the following: If a
  Python function `f` has Python side-effects, then executing `f` multiple times
  will not necessarily be semantically equivalent to executing `F =
  tf.contrib.eager.defun(f)` multiple times; this difference is due to the fact
  that `defun` only captures the subgraph of TensorFlow operations that is
  constructed when `f` is called in a graph-building context.
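  As a minimal sketch of this behavior (the `trace_log` list and `log_and_add`
  function below are illustrative names, not part of this module), a Python
  side-effect runs only while the function is being traced:

  ```python
  import tensorflow as tf

  tf.enable_eager_execution()

  trace_log = []

  def log_and_add(x, y):
    trace_log.append("traced")  # a Python side-effect
    return x + y

  compiled = tf.contrib.eager.defun(log_and_add)

  a = tf.constant(1.0)
  b = tf.constant(2.0)

  compiled(a, b)
  compiled(a, b)

  # Both calls share one input signature, so `log_and_add` was traced exactly
  # once and the Python-level append happened exactly once.
  assert len(trace_log) == 1
  ```
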
  _Python Control Flow_

  The structure of many machine learning computations depends upon whether one
  is training or validating, and it is common to nest specialized logic under
  `if training:` blocks. By mapping each input signature to a unique graph,
  `defun` lets users transparently compile such code, as the following code
  snippet demonstrates:

  ```python
  import tensorflow as tf

  tf.enable_eager_execution()

  @tf.contrib.eager.defun
  def lossy_matmul(W, x, training=True):
    outputs = tf.matmul(W, x)
    if training:
      outputs = tf.nn.dropout(outputs, keep_prob=0.2)
    return outputs

  W = tf.random_normal((3, 5))
  x = tf.random_normal((5, 1))

  # Executes a graph that applies dropout.
  lossy_outputs = lossy_matmul(W, x, training=True)

  # Executes a graph that does not apply dropout.
  exact_outputs = lossy_matmul(W, x, training=False)
  ```

  _TensorFlow Control Flow_

  When `autograph` is `True`, data-dependent control flow is allowed as well.
  Control flow statements that depend on `Tensor` values are staged into
  corresponding TensorFlow ops. For example, the following code will work as
  expected:

  ```python
  @tf.contrib.eager.defun
  def dynamic_rnn_loop(cell, seq):
    state, output = cell.zero_state()
    for input in seq:
      state, output = cell(input, state)
    return output
  ```

  For more information see `tf.autograph`.

  _Variables_

  TensorFlow operations related to variable creation and initialization are
  automatically lifted out of the graphs generated by `defun`. In practice, this
  implies that variable creation and initialization only happen the first time
  `F` is called, and that variables are reused every time thereafter. Many
  TensorFlow APIs, like `tf.keras.layers.Layer` objects, create variables the
  first time they are called and reuse them thereafter. Automatic variable
  lifting makes it possible to compile these APIs without extra effort, at the
  cost of introducing a discrepancy between the semantics of executing Python
  functions and their corresponding compiled functions. For example:

  ```python
  import tensorflow as tf

  tf.enable_eager_execution()

  def fn():
    x = tf.Variable(0.0)
    x.assign_add(1.0)
    return x.read_value()

  # `fn` is a Python function, so x is created, initialized, and destroyed upon
  # every invocation.
  assert fn().numpy() == fn().numpy() == 1.0

  compiled = tf.contrib.eager.defun(fn)

  # Compiling `fn` with `defun` hoists all variables outside of the generated
  # graph, so initialization happens exactly once.
  assert compiled().numpy() == 1.0
  assert compiled().numpy() == 2.0
  ```

  Finally, because each input signature is bound to a unique graph, if your
  Python function constructs `tf.Variable` objects, then each graph constructed
  for that Python function will reference a unique set of variables. To
  circumvent this problem, we recommend against compiling Python functions that
  create `tf.Variable` objects. Instead, Python functions should either
  lexically close over `tf.Variable` objects or accept them as arguments,
  preferably encapsulated in an object-oriented container. If you must create
  variables inside your Python function and you want each graph generated for it
  to reference the same set of variables, add logic to your Python function that
  ensures that variables are only created the first time it is called and are
  reused for every subsequent invocation; note that this is precisely what
  `tf.keras.layers.Layer` objects do, so we recommend using them to represent
  variable-bearing computations whenever possible.
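  As a sketch of the recommended pattern (the `scale` function and the variable
  `v` below are illustrative), a compiled function can lexically close over a
  variable that was created outside of it:

  ```python
  import tensorflow as tf

  tf.enable_eager_execution()

  v = tf.Variable(2.0)

  @tf.contrib.eager.defun
  def scale(x):
    # `v` is captured, not created here, so every graph generated for `scale`
    # references this single variable.
    return v * x

  assert scale(tf.constant(3.0)).numpy() == 6.0
  v.assign(4.0)
  assert scale(tf.constant(3.0)).numpy() == 12.0
  ```
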
  Args:
    func: function to be compiled. If `func` is None, returns a
      decorator that can be invoked with a single argument - `func`. The
      end result is equivalent to providing all the arguments up front.
      In other words, defun(input_signature=...)(func) is equivalent to
      defun(func, input_signature=...). The former allows
      the following use case:
        @tf.contrib.eager.defun(input_signature=...)
        def foo(...):
          ...

    input_signature: A possibly nested sequence of
      `tf.contrib.eager.TensorSpec` objects specifying the shapes and dtypes of
      the Tensors that will be supplied to this function. If `None`, a separate
      function is instantiated for each inferred input signature. If a
      signature is specified, every input to `func` must be a `Tensor`, and
      `func` cannot accept `**kwargs`.
    autograph: Whether `func` should be compiled before
      constructing the graph. See https://www.tensorflow.org/guide/autograph
      for more information.
    experimental_autograph_options: Experimental knobs (in the form of a tuple
      of tensorflow.autograph.Feature values) to control behavior when
      autograph=True.

  Returns:
    If `func` is not None, returns a callable that will execute the compiled
    function (and return zero or more `tf.Tensor` objects).
    If `func` is None, returns a decorator that, when invoked with a single
    `func` argument, returns a callable equivalent to the case above.

  Raises:
    TypeError: If `input_signature` is neither `None` nor a sequence of
      `tf.contrib.eager.TensorSpec` objects.
  """
  return defun_with_attributes(
      func=func,
      input_signature=input_signature,
      autograph=autograph,
      experimental_autograph_options=experimental_autograph_options)


def defun_with_attributes(func=None,
                          input_signature=None,
                          attributes=None,
                          autograph=True,
                          experimental_autograph_options=None):
  """Compiles a Python function into a callable TensorFlow graph.

  This function supports adding extra function attributes. See detailed
  documentation in defun(). Currently this is not exposed in the public API
  since we don't expect users to use attributes directly, and attributes do not
  work on their own. This assumption might change in the future.
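  A minimal sketch of the intended usage (the `square` function below is
  illustrative); `func_name` is the whitelisted attribute described under
  `attributes` below, and it names the resulting `ConcreteFunction` in the
  graph:

  ```python
  def square(x):
    return x * x

  square = defun_with_attributes(square, attributes={"func_name": "my_square"})
  ```
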
  Args:
    func: function to be compiled.
    input_signature: same as defun()'s input_signature.
    attributes: A dictionary of arguments which will be added to the function
      def as attributes. Currently only primitive types are supported as
      values, and only whitelisted attribute names are allowed. An
      unwhitelisted attribute name or an unsupported value will result in a
      ValueError. `func_name` is one of the whitelisted arguments; it is a
      Python string and sets the name for this `ConcreteFunction` in the graph.
    autograph: same as defun()'s autograph.
    experimental_autograph_options: same as defun()'s
      experimental_autograph_options.

  Returns:
    Same as the return value of defun, with attributes added to the function in
    the graph.
  """
  if input_signature is not None:
    validate_signature(input_signature)

  # TODO(apassos): deal with captured global state. Deal with control flow.
  def decorated(function):
    try:
      if attributes:
        name = attributes.pop("func_name", function.__name__)
      else:
        name = function.__name__
    except AttributeError:
      name = "function"
    return tf_decorator.make_decorator(
        function,
        Function(
            function,
            name,
            input_signature=input_signature,
            attributes=attributes,
            autograph=autograph,
            autograph_options=experimental_autograph_options))

  # This code path is for the `foo = tfe.defun(foo, ...)` use case.
  if func is not None:
    return decorated(func)

  # This code path is for the
  #
  #   @tfe.defun(...)
  #   def foo(...):
  #     ...
  #
  # use case, which is equivalent to `foo = tfe.defun(...)(foo)`.
  return decorated


# When a method is bound to objects of this type, it allows AutoGraph to
# recover a weak reference to the original method's self pointer, so that it
# can execute it consistently with class_method_to_instance_method's
# bound_method_wrapper.
# TODO(b/119246461): This is not pretty. Use a descriptor instead?
class TfMethodTarget(object):
  """Binding target for methods replaced by function and defun."""

  def __init__(self, target, original_python_function):
    self.weakrefself_target__ = target
    self.weakrefself_func__ = weakref.ref(original_python_function)

  @property
  def target(self):
    return self.weakrefself_target__()

  def call(self, args, kwargs):
    wrapped_fn = self.weakrefself_func__()
    if tf_inspect.ismethod(wrapped_fn):
      wrapped_fn = six.get_unbound_function(wrapped_fn)
    return wrapped_fn(self.weakrefself_target__(), *args, **kwargs)


def class_method_to_instance_method(original_function, instance):
  """Constructs a new `Function` with `self` bound."""
  weak_instance = weakref.ref(instance)

  # Note: while we could bind to a weakref proxy instead, that causes the
  # bound method to be unhashable.
  bound_method = types_lib.MethodType(
      original_function.python_function,
      TfMethodTarget(weak_instance, original_function.python_function))

  # original_function is expected to be of one of the two `Function` types
  # (defined either in function.py or def_function.py).
  assert hasattr(original_function, "_name")
  assert hasattr(original_function, "_autograph")
  assert hasattr(original_function, "_function_spec")
  assert hasattr(original_function, "python_function")

  weak_bound_method_wrapper = None
  def bound_method_wrapper(*args, **kwargs):
    """Wraps either a dummy MethodType or a converted AutoGraph function."""
    # __wrapped__ allows AutoGraph to swap in a converted function.
    strong_bound_method_wrapper = weak_bound_method_wrapper()
    wrapped_fn = strong_bound_method_wrapper.__wrapped__

    if wrapped_fn is strong_bound_method_wrapper.__original_wrapped__:
      # If __wrapped__ was not replaced, then call original_function.
      # TODO(mdan): For better consistency, use the wrapper's call().
      wrapped_fn = original_function.python_function
      if tf_inspect.ismethod(wrapped_fn):
        wrapped_fn = six.get_unbound_function(wrapped_fn)
      return wrapped_fn(weak_instance(), *args, **kwargs)

    # If __wrapped__ was replaced, then it is always an unbound function.
    # However, the replacer is still responsible for attaching self properly.
    # TODO(mdan): Is it possible to do it here instead?
    return wrapped_fn(*args, **kwargs)
  weak_bound_method_wrapper = weakref.ref(bound_method_wrapper)

  # pylint: disable=protected-access
  # We make a dummy MethodType object to generate the correct bound method
  # signature. The actual call is to a function with a weak reference to
  # `instance`.
  instance_func = type(original_function)(
      tf_decorator.make_decorator(bound_method, bound_method_wrapper),
      name=original_function._name,
      autograph=original_function._autograph,
      input_signature=original_function.input_signature)
  # pylint: enable=protected-access

  # And we wrap the function with tf_decorator so inspection works correctly.
  wrapped_instance_func = tf_decorator.make_decorator(
      original_function.python_function, instance_func)
  return wrapped_instance_func


class _FunctionGarbageCollector(object):
  """Cleans up cycles when a defun goes out of scope."""

  def __init__(self, cache):
    self._cache = cache

  def __del__(self):
    if func_graph_module is None or memory is None:
      return
    try:
      while self._cache:
        self._cache.popitem()
      memory.dismantle_ordered_dict(self._cache)
    except:  # pylint: disable=bare-except
      pass


class ConcreteFunctionGarbageCollector(object):
  """Cleans up reference cycles when a `ConcreteFunction` goes out of scope."""

  def __init__(self, func_graph):
    self._func_graph = func_graph

  def release(self):
    """Call off the FuncGraph deletion."""
    self._func_graph = None

  def __del__(self):
    if func_graph_module is None or memory is None or self._func_graph is None:
      return
    try:
      func_graph_module.dismantle_func_graph(self._func_graph)
    except:  # pylint: disable=bare-except
      pass