1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""FuncGraph and related functionality.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections as py_collections 22import itertools 23import weakref 24 25from tensorflow.core.framework import attr_value_pb2 26from tensorflow.python.eager import context 27from tensorflow.python.eager import execute 28from tensorflow.python.eager import tape 29from tensorflow.python.eager.graph_only_ops import graph_placeholder 30from tensorflow.python.framework import dtypes 31from tensorflow.python.framework import ops 32from tensorflow.python.framework import tensor_spec 33from tensorflow.python.framework.auto_control_deps import AutomaticControlDependencies 34from tensorflow.python.ops import array_ops 35from tensorflow.python.ops import custom_gradient 36from tensorflow.python.ops import resource_variable_ops 37from tensorflow.python.ops import tensor_array_ops 38from tensorflow.python.ops import variable_scope 39from tensorflow.python.util import compat 40from tensorflow.python.util import memory 41from tensorflow.python.util import nest 42from tensorflow.python.util import tf_contextlib 43from tensorflow.python.util import tf_decorator 44from tensorflow.python.util.lazy_loader import LazyLoader 45 46# This is to avoid a circular dependency: 47# function -> func_graph 48function = LazyLoader("function", globals(), 49 "tensorflow.python.eager.function") 50def_function = LazyLoader( 51 "def_function", globals(), 52 "tensorflow.python.eager.def_function") 53 54WHITELIST_COLLECTIONS = [ 55 ops.GraphKeys.GLOBAL_VARIABLES, 56 ops.GraphKeys.LOCAL_VARIABLES, 57 ops.GraphKeys.TRAINABLE_VARIABLES, 58 variable_scope._VARSTORE_KEY, # pylint: disable=protected-access 59 variable_scope._VARSCOPESTORE_KEY # pylint: disable=protected-access 60] 61 62 63class UnknownArgument(object): 64 """Signifies an argument which is not currently handled.""" 65 pass 66 67 68def convert_structure_to_signature(structure, arg_names=None): 69 """Convert a potentially nested structure to a signature. 70 71 Args: 72 structure: Structure to convert, where top level collection is a list or a 73 tuple. 74 arg_names: Optional list of arguments that has equal number of elements as 75 `structure` and is used for naming corresponding TensorSpecs. 76 77 Returns: 78 Identical structure that has TensorSpec objects instead of Tensors and 79 UknownArgument instead of any unsupported types. 80 """ 81 def encode_arg(arg, path): 82 """A representation for this argument, for converting into signatures.""" 83 if isinstance(arg, ops.Tensor): 84 user_specified_name = None 85 try: 86 user_specified_name = compat.as_str( 87 arg.op.get_attr("_user_specified_name")) 88 except ValueError: 89 pass 90 91 if path and user_specified_name and user_specified_name != path[0]: 92 # The user has explicitly named the argument differently than the name 93 # of the function argument. 94 name = user_specified_name 95 else: 96 name = "/".join([str(p) for p in path]) 97 return tensor_spec.TensorSpec(arg.shape, arg.dtype, name) 98 if isinstance(arg, ( 99 int, 100 float, 101 bool, 102 type(None), 103 dtypes.DType, 104 tensor_spec.TensorSpec, 105 )): 106 return arg 107 return UnknownArgument() 108 109 # We are using the flattened paths to name the TensorSpecs. We need an 110 # explicit name for them downstream. 111 flattened = nest.flatten_with_tuple_paths(structure) 112 if arg_names: 113 if len(arg_names) != len(structure): 114 raise ValueError( 115 "Passed in arg_names don't match actual signature (%s)." % arg_names) 116 # Replace all top-level names with their actual arg_names. If a path before 117 # was "(2,'a',1)", it will become "(arg_names[2],'a',1)". 118 flattened = [ 119 ((arg_names[path[0]],) + path[1:], arg) for path, arg in flattened 120 ] 121 122 mapped = [encode_arg(arg, path) for path, arg in flattened] 123 return nest.pack_sequence_as(structure, mapped) 124 125 126class FuncGraph(ops.Graph): 127 """Graph representing a function body. 128 129 Attributes: 130 name: The name of the function. 131 inputs: Placeholder tensors representing the inputs to this function. The 132 tensors are in this FuncGraph. This represents "regular" inputs as well as 133 captured inputs (i.e. the values of self.captures), with the regular 134 inputs coming first. 135 outputs: Tensors that will be returned by this function. The tensors are in 136 this FuncGraph. 137 control_outputs: Operations that must be executed before the function 138 represented by this graph can be said to have been executed. 139 structured_input_signature: A tuple of (args, kwargs), which are both 140 possibly-nested python objects that were received by this function. Note 141 that these structures might contain Python `None`s. 142 structured_outputs: A possibly-nested python object which will be returned 143 by this function. The Tensors in this structure are the same as those of 144 self.outputs. Note that this structure might contain Python `None`s. 145 variables: Variables that should be watched during function execution. 146 outer_graph: The graph this function is defined in. May be another FuncGraph 147 or the global default Graph. 148 captures: Maps external tensor -> internal tensor (i.e. input placeholder). 149 The entries are in the order they were captured. 150 control_captures: Set of external ops on which this graph has a control 151 dependency. 152 seed: The graph-level random seed. 153 capture_by_value: If True, the func graph will capture Variables by value 154 instead of reference. 155 """ 156 157 def __init__(self, name, collections=None, capture_by_value=None): 158 """Construct a new FuncGraph. 159 160 The graph will inherit its graph key, collections, seed, and distribution 161 strategy stack from the current context or graph. 162 163 Args: 164 name: the name of the function. 165 collections: a dictionary of collections this FuncGraph should start 166 with. If not specified (None), the FuncGraph will read (but not write 167 to) the outer graph's collections that are not whitelisted, and both 168 read and write to the outer graph's collections that are whitelisted. 169 The current whitelisted collections are the global variables, the 170 local variables, and the trainable variables. 171 Defaults to None. 172 capture_by_value: An optional boolean. If True, the func graph will 173 capture Variables by value instead of reference. By default inherit 174 from outer graphs, and failing that will default to False. 175 """ 176 super(FuncGraph, self).__init__() 177 178 self.name = name 179 self.inputs = [] 180 self.outputs = [] 181 self.control_outputs = [] 182 self.control_captures = set() 183 self.structured_input_signature = None 184 self.structured_outputs = None 185 self._weak_variables = [] 186 self.outer_graph = ops.get_default_graph() 187 self.captures = py_collections.OrderedDict() 188 # Inherit capture-by-value from outer graph. 189 if capture_by_value is not None: 190 self.capture_by_value = capture_by_value 191 elif self.outer_graph is not None and isinstance( 192 self.outer_graph, FuncGraph): 193 self.capture_by_value = self.outer_graph.capture_by_value 194 else: 195 self.capture_by_value = False 196 197 self._building_function = True 198 # Map from resource tensor name to last op (in program order) which uses 199 # this tensor. Used to enforce that execution order matches program order 200 # for resource tensors. 201 self._last_op_using_resource_tensor = {} 202 203 graph = self.outer_graph 204 205 if context.executing_eagerly(): 206 self.seed = context.global_seed() 207 # [for tf-data user migration from TF1.0 to 2.0] seed_used keep track of 208 # any None op_seed for random_op in the function, in which case we end up 209 # using function seed, which could be unintended behavior for the op. 210 self._seed_used = False 211 else: 212 self.seed = graph.seed 213 self._seed_used = False 214 # TODO(allenl): Figure out if we can remove colocation stack 215 # specialization (currently used in cond_v2), here and in the cache key. 216 self._colocation_stack = graph._colocation_stack.copy() # pylint: disable=protected-access 217 218 if collections is None: 219 for collection_name in graph.get_all_collection_keys(): 220 if collection_name not in WHITELIST_COLLECTIONS: 221 self._collections[collection_name] = graph.get_collection( 222 collection_name) 223 for collection_name in WHITELIST_COLLECTIONS: 224 self._collections[collection_name] = graph.get_collection_ref( 225 collection_name) 226 else: 227 self._collections = collections 228 229 def __str__(self): 230 return "FuncGraph(name=%s, id=%s)" % (self.name, id(self)) 231 232 def control_dependencies(self, control_inputs): 233 """Handles control dependencies. 234 235 FuncGraph wraps Graph's control_dependencies logic by first filtering out 236 any external tensors / operations and storing them in the graph's 237 control_captures member. Any consumers of this function graph must then 238 decide how to handle the control captures. 239 240 Args: 241 control_inputs: A list of `Operation` or `Tensor` objects which 242 must be executed or computed before running the operations 243 defined in the context. Can also be `None` to clear the control 244 dependencies. 245 246 Returns: 247 A context manager that specifies control dependencies for all 248 operations constructed within the context. 249 250 Raises: 251 TypeError: If `control_inputs` is not a list of `Operation` or 252 `Tensor` objects. 253 """ 254 if control_inputs is None: 255 return super(FuncGraph, self).control_dependencies(control_inputs) 256 257 filtered_control_inputs = [] 258 for c in control_inputs: 259 # Check for _UnreadVariable 260 if (isinstance(c, ops.IndexedSlices) or 261 (hasattr(c, "_handle") and hasattr(c, "op"))): 262 c = c.op 263 graph_element = ops._as_graph_element(c) # pylint: disable=protected-access 264 if graph_element is None: 265 graph_element = c 266 if graph_element is not None and getattr( 267 graph_element, "graph", None) is not self: 268 self.control_captures.add(graph_element) 269 else: 270 filtered_control_inputs.append(graph_element) 271 return super(FuncGraph, self).control_dependencies(filtered_control_inputs) 272 273 def as_default(self): 274 outer_cm = super(FuncGraph, self).as_default() 275 276 @tf_contextlib.contextmanager 277 def inner_cm(): 278 """Context manager for copying distribute.Strategy scope information.""" 279 graph = ops.get_default_graph() 280 # pylint: disable=protected-access 281 # TODO(b/112906995, nareshmodi): distribution strategy depends on 282 # inheriting this stack from the default graph even in eager mode. Maybe 283 # it should be part of the eager context? This would also allow us to 284 # remove a get_default_graph() call from the function cache lookup. 285 old_strategy_stack = self._distribution_strategy_stack 286 self._distribution_strategy_stack = list( 287 graph._distribution_strategy_stack) 288 # We ignore device placements from any outer scopes while tracing the 289 # function when possible, to avoid hard-coding them in the function 290 # graph. "Default" placements come from the PartitionedCallOp's placement, 291 # so that the same trace of the Python function may be placed on several 292 # different devices and saved functions may be placed on new devices when 293 # restored. 294 old_device_stack = self._device_function_stack 295 if context.executing_eagerly(): 296 if self._distribution_strategy_stack: 297 self._add_device_to_stack(context.context().device_name) 298 else: 299 if (self._distribution_strategy_stack 300 or device_stack_has_callable(graph._device_function_stack)): 301 # Hard-code devices from device functions in the function body 302 self._device_function_stack = graph._device_function_stack.copy() 303 304 old_creator_stack = self._variable_creator_stack 305 self._variable_creator_stack = graph._variable_creator_stack 306 # Inherit the graph key, since this is used for matching variables in 307 # optimizers. 308 old_graph_key = self._graph_key 309 self._graph_key = graph._graph_key 310 # pylint: enable=protected-access 311 312 with outer_cm as g: 313 try: 314 yield g 315 finally: 316 self._distribution_strategy_stack = old_strategy_stack 317 self._device_function_stack = old_device_stack 318 self._variable_creator_stack = old_creator_stack 319 self._graph_key = old_graph_key 320 return inner_cm() 321 322 @property 323 def output_types(self): 324 return [t.dtype for t in self.outputs] 325 326 @property 327 def output_shapes(self): 328 return [t.shape for t in self.outputs] 329 330 @property 331 def variables(self): 332 """A list of variables accessed by this FuncGraph. 333 334 Note that functions keep only weak references to variables. Calling the 335 function after a variable it accesses has been deleted is an error. 336 337 Yields: 338 Strong references to variables accessed by this FuncGraph. 339 """ 340 for weak_v in self._weak_variables: 341 v = weak_v() 342 if v is None: 343 raise AssertionError( 344 "Called a function referencing variables which have been deleted. " 345 "This likely means that function-local variables were created and " 346 "not referenced elsewhere in the program. This is generally a " 347 "mistake; consider storing variables in an object attribute on " 348 "first call.") 349 yield v 350 351 @variables.setter 352 def variables(self, var_list): 353 self._weak_variables = [weakref.ref(v) for v in var_list] 354 355 def _capture_by_value( 356 self, 357 op_type, 358 inputs, 359 dtypes, # pylint: disable=redefined-outer-name 360 input_types=None, 361 name=None, 362 attrs=None, 363 op_def=None, 364 compute_shapes=True, 365 compute_device=True): 366 # When capturing by value, do the read outside 367 reverse_captures = dict((v, k) for k, v in self.captures.items()) 368 uncaptured_inputs = [reverse_captures.get(t, t) for t in inputs] 369 with ops.init_scope(): 370 if context.executing_eagerly(): 371 attr_list = ("dtype", int(attrs["dtype"].type)) 372 value, = execute.execute( 373 compat.as_bytes(op_type), 1, uncaptured_inputs, attr_list, 374 context.context()) 375 else: 376 op = ops.get_default_graph().create_op( 377 op_type, uncaptured_inputs, dtypes, input_types, name, attrs, 378 op_def, compute_shapes, compute_device) 379 value = op.outputs[0] 380 captured_value = self.capture(value) 381 return captured_value.op 382 383 def create_op( 384 self, 385 op_type, 386 inputs, 387 dtypes=None, # pylint: disable=redefined-outer-name 388 input_types=None, 389 name=None, 390 attrs=None, 391 op_def=None, 392 compute_shapes=True, 393 compute_device=True): 394 """Like Graph.create_op, except handles external input tensors. 395 396 This overload adds functionality to create_op to "capture" any external 397 input tensors, i.e. tensors from the eager context or outer function graphs 398 if this is a nested function. See `capture` for more information. 399 400 Args: 401 op_type: The `Operation` type to create. This corresponds to the 402 `OpDef.name` field for the proto that defines the operation. 403 inputs: A list of `Tensor` objects that will be inputs to the `Operation`. 404 dtypes: (Optional) A list of `DType` objects that will be the types of the 405 tensors that the operation produces. 406 input_types: (Optional.) A list of `DType`s that will be the types of 407 the tensors that the operation consumes. By default, uses the base 408 `DType` of each input in `inputs`. Operations that expect 409 reference-typed inputs must specify `input_types` explicitly. 410 name: (Optional.) A string name for the operation. If not specified, a 411 name is generated based on `op_type`. 412 attrs: (Optional.) A dictionary where the key is the attribute name (a 413 string) and the value is the respective `attr` attribute of the 414 `NodeDef` proto that will represent the operation (an `AttrValue` 415 proto). 416 op_def: (Optional.) The `OpDef` proto that describes the `op_type` that 417 the operation will have. 418 compute_shapes: (Optional.) Deprecated. Has no effect (shapes are always 419 computed). 420 compute_device: (Optional.) If True, device functions will be executed 421 to compute the device property of the Operation. 422 423 Returns: 424 An `Operation` object. 425 """ 426 if self.capture_by_value and op_type in ["ReadVariableOp", 427 "ResourceGather"]: 428 return self._capture_by_value( 429 op_type, inputs, dtypes, input_types, name, attrs, op_def, 430 compute_shapes, compute_device) 431 432 # This capturing logic interacts poorly with control flow contexts which 433 # want to replace inputs of ops far too late in the process. This can lead 434 # the context to get confused and try to create an Enter for an Enter. We 435 # can detect this here and skip the additional Enter which can confuse loop 436 # validation logic. 437 if op_type == "Enter" and inputs[0].op.type == "Enter": 438 if inputs[0].op.get_attr("frame_name") == attrs["frame_name"].s: 439 return inputs[0].op 440 # Calling AddValue on the control flow contexts to force creation of the 441 # backward accumulators in the original graph before we create placeholders 442 # to capture the inputs. 443 ctxt = ops.get_default_graph()._control_flow_context # pylint: disable=protected-access 444 for i, inp in enumerate(inputs): 445 # TPU Estimator defines a control flow context with no AddValue method. 446 if ctxt is not None and hasattr(ctxt, "AddValue"): 447 inp = ctxt.AddValue(inp) 448 inp = self.capture(inp) 449 inputs[i] = inp 450 return super(FuncGraph, self).create_op( 451 op_type, inputs, dtypes, input_types, name, attrs, op_def, 452 compute_device=compute_device) 453 454 def capture(self, tensor, name=None): 455 """Captures `tensor` if it's external to this graph. 456 457 If `tensor` is from a different graph, returns a placeholder for it. 458 `tensor` and the placeholder will appear in self.captures, and the 459 placeholder will appear in self.inputs. Multiple calls to this method with 460 the same `tensor` argument will return the same placeholder. If `tensor` is 461 from this graph, returns `tensor`. 462 463 Args: 464 tensor: Tensor. May be from this FuncGraph or a different graph. 465 name: Optional name if a placeholder is created. 466 467 Returns: 468 Tensor from this FuncGraph. 469 """ 470 if isinstance(tensor, ops.EagerTensor): 471 if name is None: 472 name = str(ops.uid()) 473 return self._capture_helper(tensor, name) 474 if tensor.graph is not self: 475 if name is None: 476 name = tensor.op.name 477 inner_graph = tensor.graph 478 while inner_graph is not None and isinstance(inner_graph, FuncGraph): 479 if inner_graph is self: 480 raise ValueError( 481 "Trying to capture a tensor from an inner function. This can be " 482 "caused by accessing a tensor defined inside a loop or " 483 "conditional body, or a subfunction, from a calling function, " 484 "without going through the proper return value mechanism. " 485 "Consider using TensorFlow mechanisms such as TensorArrays " 486 "to return tensors from inner functions or loop / conditional " 487 "bodies. Tensor: %s; tensor graph: %s; this graph: %s" 488 % (tensor, tensor.graph, self)) 489 inner_graph = inner_graph.outer_graph 490 return self._capture_helper(tensor, name) 491 return tensor 492 493 def _capture_helper(self, tensor, name): 494 captured_tensor = self.captures.get(tensor, None) 495 if captured_tensor is None: 496 captured_tensor = _create_substitute_placeholder(tensor, name=name, 497 dtype=tensor.dtype) 498 self.captures[tensor] = captured_tensor 499 self.inputs.append(captured_tensor) 500 tape.record_operation("captured_value", [captured_tensor], [tensor], 501 lambda x: [x]) 502 return captured_tensor 503 504 @property 505 def external_captures(self): 506 """External tensors captured by this function.""" 507 return list(self.captures.keys()) 508 509 @property 510 def internal_captures(self): 511 """Placeholders in this function corresponding captured tensors.""" 512 return list(self.captures.values()) 513 514 515def func_graph_from_py_func(name, 516 python_func, 517 args, 518 kwargs, 519 signature=None, 520 func_graph=None, 521 autograph=False, 522 autograph_options=None, 523 add_control_dependencies=True, 524 arg_names=None, 525 op_return_value=None, 526 collections=None, 527 capture_by_value=None, 528 override_flat_arg_shapes=None): 529 """Returns a `FuncGraph` generated from `python_func`. 530 531 Args: 532 name: an identifier for the function. 533 python_func: the Python function to trace. 534 args: the positional args with which the Python function should be called; 535 ignored if a signature is provided. 536 kwargs: the keyword args with which the Python function should be called; 537 ignored if a signature is provided. 538 signature: a possibly nested sequence of `TensorSpecs` specifying the shapes 539 and dtypes of the arguments. When a signature is provided, `args` and 540 `kwargs` are ignored, and `python_func` is traced with Tensors conforming 541 to `signature`. If `None`, the shapes and dtypes are inferred from the 542 inputs. 543 func_graph: Optional. An instance of FuncGraph. If provided, we will use 544 this graph else a new one is built and returned. 545 autograph: whether to use autograph to compile `python_func`. 546 See https://www.tensorflow.org/guide/autograph for more information. 547 autograph_options: additional knobs to control when `autograph=True`. 548 See https://www.tensorflow.org/guide/autograph for more information. 549 add_control_dependencies: If True, automatically adds control dependencies 550 to ensure program order matches execution order and stateful ops always 551 execute. 552 arg_names: Optional list of argument names, used to give input placeholders 553 recognizable names. 554 op_return_value: Optional. A Tensor. If set and `python_func` returns 555 Operations, those return values will be replaced with this value. If not 556 set, returning an Operation triggers an error. 557 collections: a dictionary of collections this FuncGraph should start 558 with. If not specified (None), the FuncGraph will read (but not write to) 559 the outer graph's collections that are not whitelisted, and both 560 read and write to the outer graph's collections that are whitelisted. 561 The current whitelisted collections are the global variables, the 562 local variables, and the trainable variables. 563 Defaults to None. 564 capture_by_value: An optional boolean. If True, the func graph will capture 565 Variables by value instead of reference. By default inherit from outer 566 graphs, and failing that will default to False. 567 override_flat_arg_shapes: An optional list of instances that are either 568 `None` or `TensorShape`. The length must match that of 569 `nest.flatten((args, kwargs))`. The entries containing value `None` 570 must match entries in flattened arguments containing non-tensors, while 571 entries containing a `TensorShape` must match entries in the flattened 572 arguments containing tensors. 573 574 Returns: 575 A FuncGraph. 576 577 Raises: 578 TypeError: If any of `python_func`'s return values is neither `None` nor a 579 `Tensor`. 580 ValueError: If both `signature` and `override_flat_arg_shapes` are 581 passed in. 582 """ 583 if op_return_value is not None: 584 assert isinstance(op_return_value, ops.Tensor), op_return_value 585 if func_graph is None: 586 func_graph = FuncGraph(name, collections=collections, 587 capture_by_value=capture_by_value) 588 assert isinstance(func_graph, FuncGraph) 589 if add_control_dependencies: 590 control_manager = AutomaticControlDependencies() 591 else: 592 control_manager = ops.NullContextmanager() 593 with func_graph.as_default(), control_manager as a: 594 current_scope = variable_scope.get_variable_scope() 595 default_use_recource = current_scope.use_resource 596 current_scope.set_use_resource(True) 597 598 if signature is not None and override_flat_arg_shapes is not None: 599 raise ValueError( 600 "Passed both signature and override_flat_arg_shapes: %s and %s." 601 % (signature, override_flat_arg_shapes)) 602 603 if signature is not None: 604 args = signature 605 kwargs = {} 606 607 # Creates and names placeholders for all arguments. 608 if override_flat_arg_shapes is not None: 609 flat_args = nest.flatten(args) 610 arg_shapes = override_flat_arg_shapes[:len(flat_args)] 611 kwarg_shapes = override_flat_arg_shapes[len(flat_args):] 612 else: 613 arg_shapes = None 614 kwarg_shapes = None 615 func_args = _get_defun_inputs_from_args( 616 args, arg_names, flat_shapes=arg_shapes) 617 func_kwargs = _get_defun_inputs_from_kwargs( 618 kwargs, flat_shapes=kwarg_shapes) 619 620 # Convert all Tensors into TensorSpecs before saving the structured inputs. 621 # If storing pure concrete functions that are not called through polymorphic 622 # functions, we don't have access to FunctionSpec, so we need to call the 623 # TensorSpecs by their `arg_names` for later binding. 624 func_graph.structured_input_signature = ( 625 convert_structure_to_signature(func_args, arg_names), 626 convert_structure_to_signature(func_kwargs)) 627 628 flat_func_args = nest.flatten(func_args) 629 flat_func_kwargs = nest.flatten(func_kwargs) 630 # Temporarily set inputs to allow graph building code to inspect 631 # them. Reassigned below. 632 func_graph.inputs = [arg for arg in flat_func_args + flat_func_kwargs 633 if isinstance(arg, ops.Tensor)] 634 635 # Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`. 636 # Variables to help check whether mutation happens in calling the function 637 # Copy the recursive list, tuple and map structure, but not base objects 638 func_args_before = nest.pack_sequence_as(func_args, flat_func_args) 639 func_kwargs_before = nest.pack_sequence_as( 640 func_kwargs, flat_func_kwargs) 641 642 def convert(x): 643 """Converts a function output to a Tensor.""" 644 if x is None: 645 return None 646 if op_return_value is not None and isinstance(x, ops.Operation): 647 # TODO(b/79881896): we currently can't capture external control deps, so 648 # this won't work if x needs to be captured (i.e. if python_func returns 649 # captured Operations). 650 with ops.control_dependencies([x]): 651 x = array_ops.identity(op_return_value) 652 elif not isinstance(x, tensor_array_ops.TensorArray): 653 try: 654 x = ops.convert_to_tensor_or_composite(x) 655 except (ValueError, TypeError): 656 raise TypeError( 657 "To be compatible with tf.contrib.eager.defun, Python functions " 658 "must return zero or more Tensors; in compilation of %s, found " 659 "return value of type %s, which is not a Tensor." % 660 (str(python_func), type(x))) 661 if add_control_dependencies: 662 x = a.mark_as_return(x) 663 return x 664 665 this_tape = tape.push_new_tape() 666 try: 667 if autograph: 668 from tensorflow.python import autograph # pylint: disable=g-import-not-at-top 669 _, original_func = tf_decorator.unwrap(python_func) 670 671 def wrapper(*args, **kwargs): 672 # Note: functions annotated with @tf.function should always be 673 # converted even though they would meet autograph's whitelisting 674 # criteria. 675 # If this assumption is ever broken, converted_call will need to 676 # handle the possibility of original_func still being a shim, e.g. 677 # bound to WeakrefSelf. 678 return autograph.converted_call( 679 original_func, None, 680 autograph.ConversionOptions( 681 recursive=True, 682 optional_features=autograph_options, 683 force_conversion=True, 684 ), args, kwargs) 685 686 # Wrapping around a decorator allows checks like tf_inspect.getargspec 687 # to be accurate. 688 converted_func = tf_decorator.make_decorator(original_func, wrapper) 689 python_func = tf_decorator.rewrap(python_func, original_func, 690 converted_func) 691 692 func_outputs = python_func(*func_args, **func_kwargs) 693 694 # invariant: `func_outputs` contains only Tensors, IndexedSlices, 695 # SparseTensors, TensorArrays and `None`s. 696 func_outputs = nest.map_structure(convert, func_outputs) 697 698 check_mutation(func_args_before, func_args) 699 check_mutation(func_kwargs_before, func_kwargs) 700 finally: 701 tape.pop_tape(this_tape) 702 current_scope.set_use_resource(default_use_recource) 703 704 # Variables in `func_args`, `func_kwargs` should be explicit inputs 705 # to the function, not captured inputs. 706 tape_variables = this_tape.watched_variables() 707 arg_variables = set() 708 inputs = [] 709 for arg in nest.flatten(func_args) + nest.flatten(func_kwargs): 710 if isinstance(arg, resource_variable_ops.ResourceVariable): 711 # Even if an argument variable was not used in the function, we've 712 # already manually captured the resource Tensor when creating argument 713 # placeholders. 714 resource_placeholder = func_graph.captures.pop(arg.handle, None) 715 if resource_placeholder is None: 716 continue 717 arg_variables.add(arg) 718 inputs.append(resource_placeholder) 719 elif isinstance(arg, ops.Tensor): 720 inputs.append(arg) 721 variables = [v for v in tape_variables if v not in arg_variables] 722 func_graph.inputs = inputs + list(func_graph.captures.values()) 723 724 func_graph.structured_outputs = func_outputs 725 # Returning a closed-over tensor does not trigger convert_to_tensor. 726 func_graph.outputs.extend( 727 func_graph.capture(x) 728 for x in flatten(func_graph.structured_outputs) 729 if x is not None) 730 731 func_graph.variables = variables 732 733 if add_control_dependencies: 734 func_graph.control_outputs.extend(control_manager.ops_which_must_run) 735 736# Register any other functions defined in the graph. 737 with ops.init_scope(): 738 if context.executing_eagerly(): 739 for f in func_graph._functions.values(): # pylint: disable=protected-access 740 # TODO(ashankar): What about the gradient registry? 741 context.add_function(f._c_func.func) # pylint: disable=protected-access 742 743 return func_graph 744 745 746def maybe_captured(tensor): 747 """If t is a captured value placeholder, returns the original captured value. 748 749 Args: 750 tensor: Tensor. 751 752 Returns: 753 A tensor, potentially from a different Graph/FuncGraph. 754 """ 755 if (not isinstance(tensor, ops.EagerTensor) and 756 tensor.op.graph.building_function and tensor.op.type == "Placeholder"): 757 for input_t, placeholder_t in tensor.op.graph.captures.items(): 758 if tensor == placeholder_t: 759 return maybe_captured(input_t) 760 # pylint: enable=protected-access 761 return tensor 762 763 764def device_stack_has_callable(device_stack): 765 """Checks whether a device stack contains a callable.""" 766 return any(callable(spec._device_name_or_function) # pylint: disable=protected-access 767 for spec in device_stack.peek_objs()) 768 769 770def check_mutation(n1, n2): 771 """Check if two list of arguments are exactly the same.""" 772 errmsg = ("Function to be traced should not modify structure of input " 773 "arguments. Check if your function has list and dictionary " 774 "operations that alter input arguments, " 775 "such as `list.pop`, `list.append`") 776 try: 777 nest.assert_same_structure(n1, n2) 778 except ValueError: 779 raise ValueError(errmsg) 780 781 for arg1, arg2 in zip(nest.flatten(n1), nest.flatten(n2)): 782 if arg1 is not arg2: 783 raise ValueError(errmsg) 784 785 786def flatten(sequence): 787 """Like `nest.flatten` but also unpacks other Tensor-like objects. 788 789 Flattens non-tensor objects into their constituent tensors. 790 791 Args: 792 sequence: A nested structure of Tensors, CompositeTensors, and 793 TensorArrays. 794 795 Returns: 796 A list of tensors. 797 """ 798 # TODO(akshayka): Support `SparseTensor` in a similar fashion. 799 flat_sequence = nest.flatten(sequence, expand_composites=True) 800 return [ 801 item.flow if isinstance(item, tensor_array_ops.TensorArray) else item 802 for item in flat_sequence] 803 804 805def pack_sequence_as(structure, flat_sequence): 806 """Like `nest.pack_sequence_as` but also packs other Tensor-like objects. 807 808 Args: 809 structure: The structure to pack into. May contain Tensors, 810 CompositeTensors, or TensorArrays. 811 flat_sequence: An iterable containing tensors. 812 813 Returns: 814 A nested structure. 815 816 Raises: 817 AssertionError if `structure` and `flat_sequence` are not compatible. 818 """ 819 flat_sequence = list(flat_sequence) 820 flattened_structure = nest.flatten(structure, expand_composites=True) 821 if len(flattened_structure) != len(flat_sequence): 822 raise ValueError("Mismatch in element count") 823 for i in range(len(flat_sequence)): 824 if isinstance(flattened_structure[i], tensor_array_ops.TensorArray): 825 flat_sequence[i] = tensor_array_ops.build_ta_with_new_flow( 826 old_ta=flattened_structure[i], flow=flat_sequence[i]) 827 return nest.pack_sequence_as(structure, flat_sequence, expand_composites=True) 828 829 830 831def _create_substitute_placeholder(value, name=None, dtype=None): 832 """Creates a placeholder for `value` and propagates shape info to it.""" 833 # Note: setting ops.control_dependencies(None) ensures we always put 834 # capturing placeholders outside of any control flow context. 835 with ops.control_dependencies(None): 836 placeholder = graph_placeholder( 837 dtype=dtype or value.dtype, shape=value.shape, name=name) 838 custom_gradient.copy_handle_data(value, placeholder) 839 return placeholder 840 841 842def _get_defun_inputs_from_args(args, names, flat_shapes=None): 843 """Maps Python function positional args to graph-construction inputs.""" 844 return _get_defun_inputs( 845 args, names, structure=args, flat_shapes=flat_shapes) 846 847 848def _get_defun_inputs(args, names, structure, flat_shapes=None): 849 """Maps python function args to graph-construction inputs. 850 851 Args: 852 args: A flat list of user-specified arguments. 853 names: A list of strings with user-specified argument names, same length as 854 `args`. May be `None`, in which case a generic name is used. 855 structure: The original argument list or dictionary. 856 flat_shapes: A flat list of values that are either `None` or 857 instances of `TensorShape`. If provided, then length must match 858 that of `nest.flatten(args)`; and locations where `args` are 859 instances of `Tensor` must have a corresponding `TensorShape` in 860 `flat_shapes`. May be `None`, in which case exact shapes are read 861 directly from the args. 862 863 Returns: 864 Placeholders with the same structure as `structure`. 865 866 Raises: 867 RuntimeError: if `flat_shapes` is provided, but 868 `len(flat_shapes) != len(nest.flatten(args))`. 869 RuntimeError: if a shape from `flat_shapes` is not None 870 for an argument that is not a `Tensor`, `TensorSpec`, 871 or `ResourceVariable`. 872 """ 873 func_graph = ops.get_default_graph() 874 function_inputs = [] 875 if names is None: 876 names = [None] * len(args) 877 if flat_shapes is None: 878 shapes_iter = itertools.repeat(None) 879 else: 880 len_flat_args = len(nest.flatten(args)) 881 if len_flat_args != len(flat_shapes): 882 raise RuntimeError( 883 "Length of fully flat shapes (%d) must match that of " 884 "flatten(args) (%d). args: %s, flat_shapes: %s" 885 % (len(flat_shapes), 886 len_flat_args, 887 args, 888 flat_shapes)) 889 shapes_iter = iter(flat_shapes) 890 for arg_value, name in zip(args, names): 891 flattened = nest.flatten(arg_value) 892 tensor_specs = [ 893 arg for arg in flattened if isinstance(arg, tensor_spec.TensorSpec) 894 ] 895 specified_names = [arg.name for arg in tensor_specs if arg.name] 896 if specified_names and len(specified_names) < len(tensor_specs): 897 raise ValueError("If specifying TensorSpec names for nested structures, " 898 "either zero or all names have to be specified.") 899 900 for arg in flattened: 901 # We have a shape entry for each arg, regadless of whether it's a real 902 # Tensor or not. For non-tensor entries it should be None. 903 shape = next(shapes_iter) 904 if isinstance(arg, (ops.Tensor, tensor_spec.TensorSpec)): 905 if isinstance(arg, tensor_spec.TensorSpec) and arg.name: 906 requested_name = arg.name 907 else: 908 requested_name = name 909 placeholder_shape = shape if shape is not None else arg.shape 910 try: 911 placeholder = graph_placeholder( 912 arg.dtype, placeholder_shape, 913 name=requested_name) 914 except ValueError: 915 # Sometimes parameter names are not valid op names, so fall back to 916 # unnamed placeholders. 917 placeholder = graph_placeholder(arg.dtype, placeholder_shape) 918 if name is not None: 919 # Record the requested/user-specified name in case it's different than 920 # the uniquified name, for validation when exporting signatures. 921 placeholder.op._set_attr( # pylint: disable=protected-access 922 "_user_specified_name", 923 attr_value_pb2.AttrValue(s=compat.as_bytes(requested_name))) 924 function_inputs.append(placeholder) 925 elif isinstance(arg, resource_variable_ops.ResourceVariable): 926 # Capture arg variables to create placeholders for them. These will be 927 # removed as captures after the function is traced (since otherwise we'd 928 # just add it back with a new placeholder when the variable was 929 # referenced). 930 placeholder = func_graph.capture(arg.handle, name=name) 931 placeholder.op._set_attr( # pylint: disable=protected-access 932 "_user_specified_name", 933 attr_value_pb2.AttrValue(s=compat.as_bytes(name))) 934 function_inputs.append(arg) 935 else: 936 if shape is not None: 937 raise RuntimeError( 938 "Expected provided shape override to be None for arg that isn't " 939 "a Tensor, but saw arg: '%s', shape: '%s'. args: %s" 940 % (arg, shape, args)) 941 function_inputs.append(arg) 942 return nest.pack_sequence_as(structure, function_inputs) 943 944 945def _get_defun_inputs_from_kwargs(kwargs, flat_shapes): 946 """Maps Python function keyword args to graph-construction inputs.""" 947 if kwargs: 948 names, args = zip(*sorted(kwargs.items())) 949 else: 950 names = [] 951 args = [] 952 return _get_defun_inputs( 953 args, names, structure=kwargs, flat_shapes=flat_shapes) 954 955 956def dismantle_func_graph(func_graph): 957 """Removes reference cycles in `func_graph` FuncGraph. 958 959 Helpful for making sure the garbage collector doesn't need to run when 960 the FuncGraph goes out of scope, e.g. in tests using defun with 961 @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True). 962 963 Args: 964 func_graph: A `FuncGraph` object to destroy. `func_graph` is unusable 965 after this function. 966 """ 967 # TODO(b/115366440): Delete this method when a custom OrderedDict is added. 968 # Clearing captures using clear() leaves some cycles around. 969 while func_graph.captures: 970 func_graph.captures.popitem() 971 memory.dismantle_ordered_dict(func_graph.captures) 972 ops.dismantle_graph(func_graph) 973