# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""AutomaticControlDependencies and related functionality."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import enum

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import auto_control_deps_utils as utils
from tensorflow.python.framework import dtypes as dtypes_module
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
from tensorflow.python.framework import registry
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
from tensorflow.python.util import tf_decorator

# LINT.IfChange
# Op types that should not run in program order, e.g. because they need to run
# asynchronously to avoid deadlock.
ASYNC_STATEFUL_OPS = [
    "CollectiveGather",
    "CollectiveGatherV2",
    "CollectiveReduce",
    "CollectiveReduceV2",
    "CollectiveBcastSend",
    "CollectiveBcastSendV2",
    "CollectiveBcastRecv",
    "CollectiveBcastRecvV2",
    "NcclAllReduce",
    # We do not add "Send" here since we want it to be added as a control
    # output in order to avoid being pruned.
    "Recv",
]

LEGACY_RANDOM_OPS = [
    # These may be used in variable initializers -- thus their execution
    # should not be dependent on other stateful operations. This is because
    # even though, according to program order, tf.Variables may be created in
    # sequence, their initialization happens outside of the program order.
    # Specifically, in graph mode initialization happens by calling a grouped
    # initializer operation; in eager mode initialization is lifted out of the
    # tf.function and executed the first time the function is called.
    #
    # Unless there is a specific dependency between the initializers
    # themselves (e.g. one initializer depends on a Variable whose value
    # depends on another initializer), the initialization can happen in any
    # order so long as it's before the associated Variable read operations.
    #
    # Note that in general the randomness of legacy random operations is only
    # guaranteed by providing a graph-level and op-level seed (and ordering of
    # the same op across multiple iterations of a while_loop is specifically
    # not guaranteed; see the discussion below).
    #
    # There is a possible race condition inside while_loop where the same
    # random OpKernel instantiation is reused across multiple steps
    # of the loop. Since legacy Random OpKernels have an internal rng state,
    # automatic dependency tracking across loop steps would likely
    # fix this race; and for that case this denylist is problematic.
    # However, since automatic dependency tracking inside while loops is not
    # currently supported, and there are no other examples of OpKernel reuse
    # (each OpKernel is associated with a unique op in graph mode),
    # this denylist has no effect on the aforementioned behavior.
    #
    # TODO(ebrevdo,skyewm): Modify the check against this denylist to
    # only occur when the op is inside a "variable initialization scope"; and
    # add proper autodeps inside while_loops that respect this updated check.
    "RandomUniform",
    "RandomUniformInt",
    "RandomStandardNormal",
    "ParameterizedTruncatedNormal",
    "TruncatedNormal",
    "RandomShuffle",
    "Multinomial",
    "RandomGamma",
    "RandomGammaGrad",
    "RandomPoisson",
    "RandomPoissonV2",
]

_ORDER_INSENSITIVE_STATEFUL_OPS = [
    "CudnnRNN", "CudnnRNNBackprop", "CudnnRNNV2", "CudnnRNNV3",
    "CudnnRNNBackpropV2", "CudnnRNNBackpropV3",
    "EnqueueTPUEmbeddingSparseBatch", "EnqueueTPUEmbeddingIntegerBatch",
    "EnqueueTPUEmbeddingSparseTensorBatch",
    "EnqueueTPUEmbeddingRaggedTensorBatch", "RestoreV2", "SaveV2"
]
# LINT.ThenChange(//tensorflow/core/grappler/optimizers/function_optimizer.cc)

_ALL_DENYLISTED_OPS = (
    set(ASYNC_STATEFUL_OPS) | set(LEGACY_RANDOM_OPS)
    | set(_ORDER_INSENSITIVE_STATEFUL_OPS))

# Op types that are marked as stateless, but should be allowlisted to add auto
# control dependencies.
_ALLOWLIST_STATELESS_OPS = [
    # As TPU collective ops are blocking, if there is more than one collective
    # op in the function, we need to make sure different collective ops are
    # scheduled in certain orders. Otherwise if at the same time all the
    # replicas are launching different collective ops/programs, it may cause
    # deadlock.
    "AllToAll",
    "CrossReplicaSum",
    "CollectivePermute",
]


def op_is_stateful(op):
  # pylint: disable=protected-access
  return (op._is_stateful and op.type not in _ALL_DENYLISTED_OPS) or (
      op.type in _ALLOWLIST_STATELESS_OPS)
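

# For reference, a few concrete cases of the predicate above (an illustrative
# note based on the lists defined in this module; the op type names are
# standard TensorFlow ops):
#   "AssignAddVariableOp" -> True   (stateful and not denylisted)
#   "RandomUniform"       -> False  (stateful, but in LEGACY_RANDOM_OPS)
#   "CollectiveReduce"    -> False  (stateful, but in ASYNC_STATEFUL_OPS)
#   "CrossReplicaSum"     -> True   (stateless, but allowlisted above)
#   "MatMul"              -> False  (stateless and not allowlisted)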


class ResourceType(enum.Enum):
  READ_ONLY = "read-only"
  READ_WRITE = "read-write"


def collective_manager_ids_from_op(op):
  """Returns CollectiveManager IDs from the op if any exist, else [].

  CollectiveManager adds collective and no_op operations tagged with an ID,
  unique to the manager object. This function extracts those IDs, returning an
  empty list if the node was not generated by a CollectiveManager.

  Args:
    op: `Operation` to get the collective manager IDs from.

  Returns:
    List of CollectiveManager IDs used by the op.
  """
  if op.type == "CollectiveReduce":
    try:
      return [op.get_attr("_collective_manager_id")]
    except ValueError:
      pass
  elif op.type == "StatefulPartitionedCall":
    try:
      return op.get_attr(utils.COLLECTIVE_MANAGER_IDS)
    except ValueError:
      pass
  return []


class AutomaticControlDependencies(object):
  """Context manager to automatically add control dependencies.

  Code under this context manager will act as if a sensible set of control
  dependencies were present. More specifically:
  1. All stateful ops in the scope will execute (with the exception of ops in
     ASYNC_STATEFUL_OPS and LEGACY_RANDOM_OPS)
  2. Stateful ops which modify the same resource will execute in program order

  Note: creating variables in an automatic control dependencies context is not
  supported (the value of the variables will never change as they will keep
  getting reinitialized).

  NOT THREAD SAFE
  """

  def __init__(self,
               record_initial_resource_uses=False,
               record_uses_of_resource_ids=None):
    self._returned_tensors = object_identity.ObjectIdentitySet()
    self.ops_which_must_run = set()
    self.record_initial_resource_uses = record_initial_resource_uses
    self.record_uses_of_resource_ids = record_uses_of_resource_ids

  def mark_as_return(self, tensor):
    """Acts like identity but marks the `Tensor` as a return value.

    This will possibly return a copy of the `Tensor`. Usage:

    ```
    with AutomaticControlDependencies() as a:
      ...
      t = a.mark_as_return(t)
    _ = ...(t...)  # i.e. it's safe to use t here
    ```

    Args:
      tensor: the `Tensor` to be marked

    Returns:
      a copy of the `Tensor`.
    """
    if isinstance(tensor, ops.IndexedSlices):
      values = array_ops.identity(tensor.values)
      indices = array_ops.identity(tensor.indices)
      self._returned_tensors.add(indices)
      self._returned_tensors.add(values)
      return ops.IndexedSlices(values, indices, dense_shape=tensor.dense_shape)
    elif isinstance(tensor, sparse_tensor.SparseTensor):
      values = array_ops.identity(tensor.values)
      indices = array_ops.identity(tensor.indices)
      self._returned_tensors.add(indices)
      self._returned_tensors.add(values)
      return sparse_tensor.SparseTensor(
          indices, values, dense_shape=tensor.dense_shape)
    elif isinstance(tensor, tensor_array_ops.TensorArray):
      flow = array_ops.identity(tensor.flow)
      self._returned_tensors.add(flow)
      return tensor_array_ops.build_ta_with_new_flow(tensor, flow)
    # We want to make the return values depend on the stateful operations, but
    # we don't want to introduce a cycle, so we make the return value the
    # result of a new identity operation that the stateful operations
    # definitely don't depend on.
    tensor = array_ops.identity(tensor)
    self._returned_tensors.add(tensor)
    return tensor

  def __enter__(self):
    if context.executing_eagerly():
      return self
    # This code assumes no other thread is adding ops to the graph while
    # we're adding ops to the graph.
    # TODO(apassos): Fix this by locking the graph or using a temporary
    # graph (but that would mess up devices and collections at least,
    # probably other things as well).
    self._graph = ops.get_default_graph()
    self._graph._add_control_dependencies = True  # pylint: disable=protected-access
    self._n_operations = len(self._graph.get_operations())
    return self
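
  # Illustrative sketch (comment only; nothing here executes): for a cond
  # branch that uses a resource tensor r, _process_switch below rewrites the
  # graph roughly as follows:
  #
  #   r -> Switch -> branch ops using r
  #          \            |
  #           \           | (control edges)
  #            \          v
  #             -> artificial_merge
  #
  # The artificial merge is placed in the switch's outer context and marked
  # must-run; any later user of r is ordered after it.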

  def _process_switch(self, switch_op, ops_which_must_run,
                      last_write_to_resource, merge_for_resource):
    """Processes a switch node for a resource input.

    When tensorflow creates a cond, it creates a control flow context for each
    branch of the cond. Each external tensor accessed by that branch is routed
    through a switch op, which gets created in the graph _after_ the op which
    uses that tensor is created.

    If the resource comes from another switch op we process that one first.

    _process_switch creates a corresponding merge node for the switch node.
    This merge node is added to the outer control flow context of the switch
    node. We also ensure that:

      1. The switch node executes after the previous op which used the
         resource tensor

      2. Any op which uses a resource output of the switch node executes
         before the merge for the switch node.

      3. The next op which uses the input resource to the switch node (which
         might be another switch node for the other branch of the conditional)
         will execute after the merge node is done.

      4. The merge node is marked as must_run so it will run even if no
         subsequent operation uses the resource.

    Args:
      switch_op: the switch op to be processed
      ops_which_must_run: the set of ops which must run
      last_write_to_resource: map from resource tensor to the last op updating
        it
      merge_for_resource: map from resource tensor to the merge which must
        follow all usages of it.
    """
    # pylint: disable=protected-access
    inp = switch_op.inputs[0]
    input_id = ops.tensor_id(inp)
    if inp.dtype == dtypes_module.resource and inp.op.type == "Switch":
      self._process_switch(inp.op, ops_which_must_run, last_write_to_resource,
                           merge_for_resource)
    output = switch_op.outputs[0]
    output_id = ops.tensor_id(output)
    if output_id in merge_for_resource:
      return
    new_merge = control_flow_ops.merge(
        switch_op.outputs, name="artificial_merge")
    new_merge[0].op._control_flow_context = (
        switch_op._control_flow_context.outer_context)
    # Ensures the merge always runs
    ops_which_must_run.add(new_merge[0].op)
    if input_id in last_write_to_resource:
      # Ensures the switch executes after the previous op using the resource.
      switch_op._add_control_input(last_write_to_resource[input_id])
    # Ensure the next op outside the cond happens after the merge.
    last_write_to_resource[input_id] = new_merge[0].op
    if input_id in merge_for_resource:
      merge_for_resource[input_id]._add_control_input(new_merge[0].op)
    for o in switch_op.outputs:
      # Ensures the merge will execute after all ops inside the cond
      merge_for_resource[ops.tensor_id(o)] = new_merge[0].op

  def __exit__(self, unused_type, unused_value, unused_traceback):
    # pylint: disable=protected-access
    if context.executing_eagerly():
      return

    if self._graph is not ops.get_default_graph():
      raise RuntimeError(
          "Graph changed while trying to add control dependencies.")

    if hasattr(self._graph, "outer_graph"):
      outer_val = self._graph.outer_graph._add_control_dependencies
      self._graph._add_control_dependencies = outer_val
    else:
      self._graph._add_control_dependencies = False

    # Map from resource tensor to the last op which wrote to it.
    last_write_to_resource = {}
    # Map from resource tensor to the list of reads from it since the last
    # write or since the beginning of the function.
    reads_since_last_write_to_resource = collections.defaultdict(list)
    # CollectiveManager manager_ids within a particular function call should
    # not be needed outside of that function call. So we keep them separate
    # (though the general idea of the maps is the same; in the future, we'll
    # need to correctly thread the control output outside).
    # Map from collective manager scope to the last op which used it.
    collective_manager_scopes_opened = {}
    collective_manager_scopes_used = {}
    # Set of conditional and loop exits.
    ops_which_must_run = set()
    # Merge which must depend on ops which use this resource.
    merge_for_resource = {}

    new_operations = self._graph.get_operations()[self._n_operations:]
    first_use_for_res = {}
    resources_by_op = {}

    # Ensures that uses of resource tensors get serialized properly and all
    # execute. This is done by keeping a map from resource tensor to the last
    # op in graph-construction order which used it (last_write_to_resource).
    #
    # Conditionals are written in TensorFlow such that every external tensor
    # accessed in the conditional goes through a switch op and every return
    # tensor (it's guaranteed that there will be at least one) goes through a
    # merge op.
    #
    # To handle conditionals, switches are handled in a special way (see
    # comments for _process_switch). Merge nodes created by TF's conditional
    # logic (as opposed to by _process_switch) are forced to run and also get
    # a control dependency added to them to ensure all stateful ops inside
    # their control flow context run.
    #
    # We also ensure that if an op is using a resource output by a switch node
    # (that is, a resource tensor for which there's a value in
    # merge_for_resource) this op will run before the merge for that resource.
    #
    # We try to add control inputs to nodes respecting their control flow
    # contexts to avoid dead nodes propagating everywhere and leading to
    # "retval[0] doesn't have value" errors. If a node gets a control
    # dependency on a dead node (i.e. a node from an untaken control flow
    # branch) that node will be marked as dead unless it's a merge node.
    #
    # TODO(apassos): serialize non-resource-taking stateful ops as well, and
    # test that it works. Support while loops. Support init_scope escaping
    # from this.
    for op in new_operations:
      # TODO(apassos) make this code safely support while loops.
      if control_flow_util.IsInWhileLoop(op):
        continue
      control_inputs = set()
      # Ensure stateful ops run.
      # Read-only ops are added to control outputs if the read value is
      # consumed. This covers the case when the read value is returned from
      # the function since that goes through a tf.identity in mark_as_return.
      if (op_def_registry.get(op.type) is None or
          (op_is_stateful(op) and
           (op.type not in utils.RESOURCE_READ_OPS or
            any(output.consumers() for output in op.outputs)))):
        ops_which_must_run.add(op)
      # Make a note of all opened manager_ids.
      if op.type == "NoOp":
        try:
          collective_manager_scopes_opened[op.get_attr(
              "_collective_manager_id")] = op
        except ValueError:
          pass
      # Ignore switches (they're handled separately).
      if op.type == "Switch" and op.inputs[0].dtype == dtypes_module.resource:
        continue
      # Make merges trigger all other computation which must run.
      # TODO(mdan): Don't do this. Write a transform to chains instead.
      # See core/common_runtime/control_flow_deps_to_chains.cc.
      if op.type == "Merge":
        for o in ops_which_must_run:
          op._add_control_input(o)
          for inp in o.inputs:
            input_id = ops.tensor_id(inp)
            if input_id in last_write_to_resource:
              last_write_to_resource[input_id] = op
        ops_which_must_run = set([op])
        continue

      resource_inputs = set()
      # Check for any resource inputs. If we find any, we update
      # control_inputs and last_write_to_resource.
      for inp, resource_type in _get_resource_inputs(op):
        is_read = resource_type == ResourceType.READ_ONLY
        input_id = ops.tensor_id(inp)

        # If the op receives the same resource tensor twice as an input, we
        # skip to avoid the op getting a control dependency on itself.
        if input_id in resource_inputs:
          continue

        resource_inputs.add(input_id)
        # Deal with switches, finally.
        if inp.op.type == "Switch":
          self._process_switch(inp.op, ops_which_must_run,
                               last_write_to_resource, merge_for_resource)
        is_building_function = op.graph.building_function
        # Ensure uses of resources are serialized.
        if input_id in last_write_to_resource:
          if is_building_function or (
              last_write_to_resource[input_id]._control_flow_context
              is op._control_flow_context):
            control_inputs.add(last_write_to_resource[input_id])
        # Ensure merges happen after the closing of a cond block.
        if input_id in merge_for_resource:
          merge_for_resource[input_id]._add_control_input(op)

        do_record = (
            self.record_initial_resource_uses and
            input_id not in first_use_for_res)

        if is_read:
          reads_list = reads_since_last_write_to_resource[input_id]
          reads_list.append(op)

          if do_record:
            # Note: this will track the entire list that
            # reads_since_last_write_to_resource maintains. Updates to it will
            # and should be tracked, until the first write is encountered. At
            # that point, reads_since_last_write_to_resource will contain a
            # new empty list. This logic relies on that behavior.
            first_use_for_res[input_id] = reads_list

        else:
          control_inputs.update(reads_since_last_write_to_resource[input_id])
          reads_since_last_write_to_resource[input_id] = []
          last_write_to_resource[input_id] = op

          if do_record:
            first_use_for_res[input_id] = [op]

      if self.record_initial_resource_uses and op_is_stateful(op):
        if resource_inputs:
          resources_by_op[op] = tuple(resource_inputs)
        else:
          if None not in first_use_for_res:
            first_use_for_res[None] = [op]
          resources_by_op[op] = (None,)

      if (op_is_stateful(op) and not resource_inputs
          and op._control_flow_context is None):
        if None in last_write_to_resource:
          op._add_control_input(last_write_to_resource[None])
        last_write_to_resource[None] = op

      # Ensure ordering of collective ops.
      manager_ids = collective_manager_ids_from_op(op)
      for manager_id in manager_ids:
        if manager_id in collective_manager_scopes_opened:
          # Chain this function call if the scope was opened.
          op._add_control_input(collective_manager_scopes_opened[manager_id])
          collective_manager_scopes_opened[manager_id] = op
        else:
          # If this op is in a scope not created here, create a chain starting
          # at this op.
          if manager_id in collective_manager_scopes_used:
            op._add_control_input(collective_manager_scopes_used[manager_id])
          collective_manager_scopes_used[manager_id] = op

      if control_inputs and not is_building_function:
        control_inputs = [
            c for c in control_inputs
            if c._control_flow_context is op._control_flow_context
        ]

      op._add_control_inputs(control_inputs)

    # Record the ops which first use resources touched by "ops which must
    # run".
    if self.record_initial_resource_uses:
      first_uses_by_output_ops = {}
      for op in ops_which_must_run:
        if op not in resources_by_op:
          # This may happen with Merge/Switch nodes which are special cased
          # above.
          continue
        for r in resources_by_op[op]:
          if op not in first_uses_by_output_ops:
            first_uses_by_output_ops[op] = set()
          first_uses_by_output_ops[op].update(first_use_for_res[r])
      # For each "op which must run", set a private attr indicating the ops
      # that used the same resources it did.
      for op in first_uses_by_output_ops:
        others = [
            other.name.encode() for other in first_uses_by_output_ops[op]
        ]
        l = attr_value_pb2.AttrValue.ListValue(s=others)
        # TODO(mdan): Is there a way which doesn't use anonymous attrs?
        op._set_attr("_res_first_used_by", attr_value_pb2.AttrValue(list=l))

    # Ensure all ops which must run do run.
    self.ops_which_must_run.update(ops_which_must_run)
    control_output_op = None
    for idx, r in enumerate(
        nest.flatten(list(self._returned_tensors), expand_composites=True)):
      if self.ops_which_must_run:
        updated_ops_which_must_run = []
        if r.graph.building_function:
          # There may be many stateful ops in the graph. Adding them as
          # control inputs to each function output could create excessive
          # control edges in the graph. Thus we create an intermediate No-op
          # to chain the control dependencies between stateful ops and
          # function outputs.
          if idx == 0:
            control_output_op = control_flow_ops.no_op()
            control_output_op._set_attr("_acd_function_control_output",
                                        attr_value_pb2.AttrValue(b=True))
            control_output_op._add_control_inputs(self.ops_which_must_run)
          updated_ops_which_must_run = [control_output_op]
        else:
          updated_ops_which_must_run = [
              o for o in self.ops_which_must_run
              if o._control_flow_context is r.op._control_flow_context
          ]
        r.op._add_control_inputs(updated_ops_which_must_run)

    self.collective_manager_ids_used = collective_manager_scopes_used


_acd_resource_resolvers_registry = registry.Registry("acd_resource_resolvers")


def register_acd_resource_resolver(f):
  """Register a function for resolving resources touched by an op.

  `f` is called for every Operation added in the ACD context with the op's
  original resource reads and writes. `f` is expected to update the sets of
  resource reads and writes in place and return True if it updated either of
  the sets, False otherwise.

  Example:
  @register_acd_resource_resolver
  def ResolveIdentity(op, resource_reads, resource_writes):
    # op: The `Operation` being processed by ACD currently.
    # resource_reads: An `ObjectIdentitySet` of read-only resources.
    # resource_writes: An `ObjectIdentitySet` of read-write resources.
    if not resource_reads or resource_writes:
      return False
    def update(resource_inputs):
      to_add = []
      to_remove = []
      for t in resource_inputs:
        if t.op.type == "Identity":
          to_remove.append(t)
          to_add.append(t.op.inputs[0])
      if not to_add and not to_remove:
        return False
      for t in to_remove:
        resource_inputs.discard(t)
      resource_inputs.update(to_add)
      return True
    return update(resource_reads) or update(resource_writes)

  Args:
    f: Python function with signature
      (Operation, ObjectIdentitySet, ObjectIdentitySet) -> bool

  Returns:
    The function `f` after adding it to the registry.
  """
  _acd_resource_resolvers_registry.register(f)
  return f


def _get_resource_inputs(op):
  """Returns an iterable of resources touched by this `op`."""
  reads, writes = utils.get_read_write_resource_inputs(op)
  saturated = False
  while not saturated:
    saturated = True
    for key in _acd_resource_resolvers_registry.list():
      # Resolvers should return True if they are updating the list of
      # resource_inputs.
      # TODO(srbs): An alternative would be to just compare the old and new
      # sets, but that may not be as fast.
      updated = _acd_resource_resolvers_registry.lookup(key)(op, reads, writes)
      if updated:
        # Conservatively remove any resources from `reads` that are also
        # writes.
        reads = reads.difference(writes)
      saturated = saturated and not updated

  # Note: A resource handle that is not written to is treated as read-only.
  # We don't have a special way of denoting an unused resource.
  for t in reads:
    yield (t, ResourceType.READ_ONLY)
  for t in writes:
    yield (t, ResourceType.READ_WRITE)


def automatic_control_dependencies(f):
  """Wraps f to automatically insert control dependencies.

  The inserted dependencies ensure that:
  1. All stateful ops in f run when the result of f runs
  2. Updates to the same resources happen in order.

  Args:
    f: the function to be wrapped.

  Returns:
    The wrapped function.
  """

  def wrapper(*args, **kwargs):
    with AutomaticControlDependencies() as a:
      result = f(*args, **kwargs)
      result_flat = [a.mark_as_return(t) for t in nest.flatten(result)]
      return nest.pack_sequence_as(result, result_flat)

  return tf_decorator.make_decorator(f, wrapper)
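

# The usage sketch below is illustrative only and is not part of the
# TensorFlow API surface. `_automatic_control_dependencies_example` is a
# hypothetical helper (never called by library code) showing how the
# decorator above orders a variable write before a subsequent read.
def _automatic_control_dependencies_example():
  """Sketch: the assign runs before the read when `f` is executed."""
  from tensorflow.python.ops import resource_variable_ops

  # The variable is created outside the decorated function; creating
  # variables inside an ACD context is not supported (see the class
  # docstring above).
  v = resource_variable_ops.ResourceVariable(1.0)

  @automatic_control_dependencies
  def f():
    # A stateful write whose result is unused: without automatic control
    # dependencies this op could be pruned or reordered past the read.
    v.assign_add(1.0)
    # The decorator adds a control edge so this read observes the write.
    return v.read_value()

  return f()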