# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""AutomaticControlDependencies and related functionality."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import enum

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import auto_control_deps_utils as utils
from tensorflow.python.framework import dtypes as dtypes_module
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
from tensorflow.python.framework import registry
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
from tensorflow.python.util import tf_decorator

# LINT.IfChange
# Op types that should not run in program order, e.g. because they need to run
# asynchronously to avoid deadlock.
ASYNC_STATEFUL_OPS = [
    "CollectiveGather",
    "CollectiveGatherV2",
    "CollectiveReduce",
    "CollectiveReduceV2",
    "CollectiveBcastSend",
    "CollectiveBcastSendV2",
    "CollectiveBcastRecv",
    "CollectiveBcastRecvV2",
    "NcclAllReduce",
    # We do not add "Send" here since we want it to be added as a control output
    # in order to avoid being pruned.
    "Recv",
]

LEGACY_RANDOM_OPS = [
    # These may be used in variable initializers -- thus their execution should
    # not be dependent on other stateful operations.  This is because although
    # according to program order, tf.Variables may be created in sequence,
    # their initialization happens outside of the program order (specifically,
    # in graph mode their initialization happens by calling a grouped
    # initializer operation or in eager mode, where initialization is lifted
    # out of the tf.function and executed the first time the function is
    # executed).
    #
    # Unless there is a specific dependency between the initializers
    # themselves (e.g. one initializer depends on a Variable whose value depends
    # on another initializer), the initialization can happen in any order so
    # long as it's before the associated Variable read operations.
    #
    # Note that in general the randomness of legacy random operations is only
    # guaranteed by providing a graph-level and op-level seed (and ordering of
    # the same op across multiple iterations of a while_loop is specifically not
    # guaranteed; see the discussion below).
    #
    # There is a possible race condition inside while_loop where the same
    # random OpKernel instantiation is reused across multiple steps
    # of the loop.  Since legacy Random OpKernels have an internal rng state,
    # automatic dependency tracking across loop steps would likely
    # fix this race; and for that case this denylist is problematic.
    # However, since automatic dependency tracking inside while loops is not
    # currently supported, and there are no other examples of OpKernel reuse
    # (each OpKernel is associated with a unique op in graph mode),
    # this denylist has no effect on the aforementioned behavior.
    #
    # TODO(ebrevdo,skyewm): Modify the check against this denylist to
    # only occur when the op is inside a "variable initialization scope"; and
    # add proper autodeps inside while_loops that respects this updated check.
    "RandomUniform",
    "RandomUniformInt",
    "RandomStandardNormal",
    "ParameterizedTruncatedNormal",
    "TruncatedNormal",
    "RandomShuffle",
    "Multinomial",
    "RandomGamma",
    "RandomGammaGrad",
    "RandomPoisson",
    "RandomPoissonV2",
]

_ORDER_INSENSITIVE_STATEFUL_OPS = [
    "CudnnRNN", "CudnnRNNBackprop", "CudnnRNNV2", "CudnnRNNV3",
    "CudnnRNNBackpropV2", "CudnnRNNBackpropV3",
    "EnqueueTPUEmbeddingSparseBatch", "EnqueueTPUEmbeddingIntegerBatch",
    "EnqueueTPUEmbeddingSparseTensorBatch",
    "EnqueueTPUEmbeddingRaggedTensorBatch", "RestoreV2", "SaveV2"
]
# LINT.ThenChange(//tensorflow/core/grappler/optimizers/function_optimizer.cc)

_ALL_DENYLISTED_OPS = (
    set(ASYNC_STATEFUL_OPS) | set(LEGACY_RANDOM_OPS)
    | set(_ORDER_INSENSITIVE_STATEFUL_OPS))

# Op types that are marked as stateless, but should be allowlisted to add auto
# control dependencies.
_ALLOWLIST_STATELESS_OPS = [
    # As TPU collective ops are blocking, if there is more than one collective
    # op in the function, we need to make sure the different collective ops are
    # scheduled in a consistent order. Otherwise, if all the replicas launch
    # different collective ops/programs at the same time, it may cause
    # deadlock.
    "AllToAll",
    "CrossReplicaSum",
    "CollectivePermute",
]


def op_is_stateful(op):
  """Returns whether `op` is stateful for auto control dependency purposes."""
  # pylint: disable=protected-access
  return (op._is_stateful and op.type not in _ALL_DENYLISTED_OPS) or (
      op.type in _ALLOWLIST_STATELESS_OPS)


class ResourceType(enum.Enum):
  READ_ONLY = "read-only"
  READ_WRITE = "read-write"


def collective_manager_ids_from_op(op):
  """Returns CollectiveManager IDs used by the op, if any.

  CollectiveManager adds collective and no_op operations tagged with an ID,
  unique to the manager object. This function extracts those IDs and returns
  an empty list if the node was not generated by a CollectiveManager.

  Args:
    op: `Operation` to get the collective manager IDs from.

  Returns:
    List of CollectiveManager IDs used by the op (an empty list if none).
  """
  if op.type == "CollectiveReduce":
    try:
      return [op.get_attr("_collective_manager_id")]
    except ValueError:
      pass
  elif op.type == "StatefulPartitionedCall":
    try:
      return op.get_attr(utils.COLLECTIVE_MANAGER_IDS)
    except ValueError:
      pass
  return []


class AutomaticControlDependencies(object):
  """Context manager to automatically add control dependencies.

  Code under this context manager will act as if a sensible set of control
  dependencies were present. More specifically:
    1. All stateful ops in the scope will execute (with the exception of ops in
       ASYNC_STATEFUL_OPS and LEGACY_RANDOM_OPS)
    2. Stateful ops which modify the same resource will execute in program order

  Note: creating variables in an automatic control dependencies context is not
  supported (the value of the variables will never change as they will keep
  getting reinitialized).

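  Example (an illustrative sketch, assuming graph construction such as inside a
  `FuncGraph`; `v` is a resource variable created outside the scope, since
  creating variables inside the scope is not supported):

  ```
    with AutomaticControlDependencies() as a:
      v.assign(2.0)
      read = a.mark_as_return(v.read_value())
    # On exit, the read is constrained to run after the assign, and both ops
    # are recorded in `a.ops_which_must_run`.
  ```
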
  NOT THREAD SAFE
  """

  def __init__(self,
               record_initial_resource_uses=False,
               record_uses_of_resource_ids=None):
    self._returned_tensors = object_identity.ObjectIdentitySet()
    self.ops_which_must_run = set()
    self.record_initial_resource_uses = record_initial_resource_uses
    self.record_uses_of_resource_ids = record_uses_of_resource_ids

  def mark_as_return(self, tensor):
    """Acts like identity but marks the `Tensor` as a return value.

    This will possibly return a copy of the `Tensor`. Usage:

    ```
      with AutomaticControlDependencies() as a:
       ...
       t = a.mark_as_return(t)
      _ = ...(t...)  # i.e. it's safe to use t here
    ```

    Args:
      tensor: the `Tensor` to be marked

    Returns:
      a copy of the `Tensor`.
    """
    if isinstance(tensor, ops.IndexedSlices):
      values = array_ops.identity(tensor.values)
      indices = array_ops.identity(tensor.indices)
      self._returned_tensors.add(indices)
      self._returned_tensors.add(values)
      return ops.IndexedSlices(values, indices, dense_shape=tensor.dense_shape)
    elif isinstance(tensor, sparse_tensor.SparseTensor):
      values = array_ops.identity(tensor.values)
      indices = array_ops.identity(tensor.indices)
      self._returned_tensors.add(indices)
      self._returned_tensors.add(values)
      return sparse_tensor.SparseTensor(
          indices, values, dense_shape=tensor.dense_shape)
    elif isinstance(tensor, tensor_array_ops.TensorArray):
      flow = array_ops.identity(tensor.flow)
      self._returned_tensors.add(flow)
      return tensor_array_ops.build_ta_with_new_flow(tensor, flow)
    # We want to make the return values depend on the stateful operations, but
    # we don't want to introduce a cycle, so we make the return value the result
    # of a new identity operation that the stateful operations definitely don't
    # depend on.
    tensor = array_ops.identity(tensor)
    self._returned_tensors.add(tensor)
    return tensor

  def __enter__(self):
    if context.executing_eagerly():
      return self
    # This code assumes no other thread is adding ops to the graph while
    # we're adding ops to the graph.
    # TODO(apassos): Fix this by locking the graph or using a temporary
    # graph (but that would mess up devices and collections at least,
    # probably other things as well).
    self._graph = ops.get_default_graph()
    self._graph._add_control_dependencies = True  # pylint: disable=protected-access
    self._n_operations = len(self._graph.get_operations())
    return self

  def _process_switch(self, switch_op, ops_which_must_run,
                      last_write_to_resource, merge_for_resource):
    """Processes a switch node for a resource input.

    When TensorFlow creates a cond, it creates a control flow context for each
    branch of the cond. Each external tensor accessed by that branch is routed
    through a switch op, which gets created in the graph _after_ the op which
    uses that tensor is created.

    If the resource comes from another switch op we process that one first.

    _process_switch creates a corresponding merge node for the switch node. This
    merge node is added to the outer control flow context of the switch
    node. We also ensure that:

      1. The switch node executes after the previous op which used the resource
         tensor

      2. Any op which uses a resource output of the switch node executes before
         the merge for the switch node.

      3. The next op which uses the input resource to the switch node (which
         might be another switch node for the other branch of the conditional)
         will execute after the merge node is done.

      4. The merge node is marked as must_run so it will run even if no
         subsequent operation uses the resource.

    Args:
      switch_op: the switch op to be processed
      ops_which_must_run: the set of ops which must run
      last_write_to_resource: map from resource tensor to last op updating
        it
      merge_for_resource: map from resource tensor to merge which must follow
        all usages of it.
    """
    # pylint: disable=protected-access
    inp = switch_op.inputs[0]
    input_id = ops.tensor_id(inp)
    if inp.dtype == dtypes_module.resource and inp.op.type == "Switch":
      self._process_switch(inp.op, ops_which_must_run, last_write_to_resource,
                           merge_for_resource)
    output = switch_op.outputs[0]
    output_id = ops.tensor_id(output)
    if output_id in merge_for_resource:
      return
    new_merge = control_flow_ops.merge(
        switch_op.outputs, name="artificial_merge")
    new_merge[0].op._control_flow_context = (
        switch_op._control_flow_context.outer_context)
    # Ensures the merge always runs
    ops_which_must_run.add(new_merge[0].op)
    if input_id in last_write_to_resource:
      # Ensures the switch executes after the previous op using the resource.
      switch_op._add_control_input(last_write_to_resource[input_id])
    # Ensure the next op outside the cond happens after the merge.
    last_write_to_resource[input_id] = new_merge[0].op
    if input_id in merge_for_resource:
      merge_for_resource[input_id]._add_control_input(new_merge[0].op)
    for o in switch_op.outputs:
      # Ensures the merge will execute after all ops inside the cond
      merge_for_resource[ops.tensor_id(o)] = new_merge[0].op

  def __exit__(self, unused_type, unused_value, unused_traceback):
    # pylint: disable=protected-access
    if context.executing_eagerly():
      return

    if self._graph is not ops.get_default_graph():
      raise RuntimeError(
          "Graph changed while trying to add control dependencies.")

    if hasattr(self._graph, "outer_graph"):
      outer_val = self._graph.outer_graph._add_control_dependencies
      self._graph._add_control_dependencies = outer_val
    else:
      self._graph._add_control_dependencies = False

    # map from resource tensor to the last op which wrote to it
    last_write_to_resource = {}
    # map from resource tensor to the list of reads from it since the last
    # write or since the beginning of the function.
    reads_since_last_write_to_resource = collections.defaultdict(list)
    # CollectiveManager manager_ids within a particular function call should not
    # be needed outside of that function call, so we keep them separate. (The
    # general idea of the maps is the same; in the future we'll need to
    # correctly thread the control output outside.)
    # Map from collective manager scope to the last op which used it
    collective_manager_scopes_opened = {}
    collective_manager_scopes_used = {}
    # set of conditional and loop exits
    ops_which_must_run = set()
    # merge which must depend on ops which use this resource
    merge_for_resource = {}

    new_operations = self._graph.get_operations()[self._n_operations:]
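    # Only populated when record_initial_resource_uses is set: a map from
    # resource tensor id to the ops that first used it, and a map from each
    # stateful op to the resources it touched.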
    first_use_for_res = {}
    resources_by_op = {}

    # Ensures that uses of resource tensors get serialized properly and all
    # execute. This is done by keeping a map from resource tensor to the last op
    # in graph-construction order which used it (last_write_to_resource).
    #
    # Conditionals are written in TensorFlow such that every external tensor
    # accessed in the conditional goes through a switch op and every return
    # tensor (it's guaranteed that there will be at least one) goes through a
    # merge op.
    #
    # To handle conditionals, switches are handled in a special way (see
    # comments for _process_switch). Merge nodes created by TF's conditional
    # logic (as opposed to by _process_switch) are forced to run and also get a
    # control dependency added to them to ensure all stateful ops inside their
    # control flow context run.
    #
    # We also ensure that if an op is using a resource output by a switch node
    # (that is, a resource tensor for which there's a value in
    # merge_for_resource) this op will run before the merge for that resource.
    #
    # We try to add control inputs to nodes respecting their control flow
    # contexts to avoid dead nodes propagating everywhere and leading to
    # "retval[0] doesn't have value" errors. If a node gets a control dependency
    # on a dead node (i.e. a node from an untaken control flow branch) that node
    # will be marked as dead unless it's a merge node.
    #
    # TODO(apassos): serialize non-resource-taking stateful ops as well, and
    # test that it works. Support while loops. Support init_scope escaping from
    # this.
    for op in new_operations:
      # TODO(apassos) make this code safely support while loops.
      if control_flow_util.IsInWhileLoop(op):
        continue
      control_inputs = set()
      # Ensure stateful ops run.
      # Read-only ops are added to control outputs if the read value is
      # consumed. This covers the case when the read value is returned from
      # the function since that goes through a tf.identity in mark_as_return.
      if (op_def_registry.get(op.type) is None or
          (op_is_stateful(op) and
           (op.type not in utils.RESOURCE_READ_OPS or
            any(output.consumers() for output in op.outputs)))):
        ops_which_must_run.add(op)
      # Make a note of all opened manager_ids.
      if op.type == "NoOp":
        try:
          collective_manager_scopes_opened[op.get_attr(
              "_collective_manager_id")] = op
        except ValueError:
          pass
      # Ignore switches (they're handled separately)
      if op.type == "Switch" and op.inputs[0].dtype == dtypes_module.resource:
        continue
      # Make merges trigger all other computation which must run
      # TODO(mdan): Don't do this. Write a transform to chains instead.
      # See core/common_runtime/control_flow_deps_to_chains.cc.
      if op.type == "Merge":
        for o in ops_which_must_run:
          op._add_control_input(o)
          for inp in o.inputs:
            input_id = ops.tensor_id(inp)
            if input_id in last_write_to_resource:
              last_write_to_resource[input_id] = op
        ops_which_must_run = set([op])
        continue

      resource_inputs = set()
      # Check for any resource inputs. If we find any, we update control_inputs
      # and last_write_to_resource.
      for inp, resource_type in _get_resource_inputs(op):
        is_read = resource_type == ResourceType.READ_ONLY
        input_id = ops.tensor_id(inp)

        # If the op receives the same resource tensor twice as an input, we
        # skip the duplicate to avoid the op getting a control dependency on
        # itself.
        if input_id in resource_inputs:
          continue

        resource_inputs.add(input_id)
        # Deal with switches, finally.
        if inp.op.type == "Switch":
          self._process_switch(inp.op, ops_which_must_run,
                               last_write_to_resource, merge_for_resource)
        is_building_function = op.graph.building_function
        # Ensure uses of resources are serialized
        if input_id in last_write_to_resource:
          if is_building_function or (
              last_write_to_resource[input_id]._control_flow_context
              is op._control_flow_context):
            control_inputs.add(last_write_to_resource[input_id])
        # Ensure merges happen after the closing of a cond block
        if input_id in merge_for_resource:
          merge_for_resource[input_id]._add_control_input(op)

        do_record = (
            self.record_initial_resource_uses and
            input_id not in first_use_for_res)

        if is_read:
          reads_list = reads_since_last_write_to_resource[input_id]
          reads_list.append(op)

          if do_record:
            # Note: this will track the entire list that
            # reads_since_last_write_to_resource maintains. Updates to it will
            # and should be tracked, until the first write is encountered. At
            # that point, reads_since_last_write_to_resource will contain a new
            # empty list. This logic relies on that behavior.
            first_use_for_res[input_id] = reads_list

        else:
          control_inputs.update(reads_since_last_write_to_resource[input_id])
          reads_since_last_write_to_resource[input_id] = []
          last_write_to_resource[input_id] = op

          if do_record:
            first_use_for_res[input_id] = [op]

      if self.record_initial_resource_uses and op_is_stateful(op):
        if resource_inputs:
          resources_by_op[op] = tuple(resource_inputs)
        else:
          if None not in first_use_for_res:
            first_use_for_res[None] = [op]
          resources_by_op[op] = (None,)

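      # Stateful ops with no resource inputs (and no control flow context) are
      # chained with each other, in program order, using `None` as a
      # pseudo-resource key in last_write_to_resource.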
      if (op_is_stateful(op) and not resource_inputs
          and op._control_flow_context is None):
        if None in last_write_to_resource:
          op._add_control_input(last_write_to_resource[None])
        last_write_to_resource[None] = op

      # Ensure ordering of collective ops
      manager_ids = collective_manager_ids_from_op(op)
      for manager_id in manager_ids:
        if manager_id in collective_manager_scopes_opened:
          # Chain this function call if the scope was opened.
          op._add_control_input(collective_manager_scopes_opened[manager_id])
          collective_manager_scopes_opened[manager_id] = op
        else:
          # If this op is in a scope not created here, create a chain starting
          # at this op.
          if manager_id in collective_manager_scopes_used:
            op._add_control_input(collective_manager_scopes_used[manager_id])
          collective_manager_scopes_used[manager_id] = op

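      # Outside of function building, only keep control inputs that come from
      # the same control flow context as `op`; a dependency that crosses
      # contexts could make `op` depend on a dead branch (see the note on dead
      # nodes above).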
      if control_inputs and not is_building_function:
        control_inputs = [
            c for c in control_inputs
            if c._control_flow_context is op._control_flow_context
        ]

      op._add_control_inputs(control_inputs)

    # Record the ops which first use resources touched by "ops which must run".
    if self.record_initial_resource_uses:
      first_uses_by_output_ops = {}
      for op in ops_which_must_run:
        if op not in resources_by_op:
          # This may happen with Merge/Switch nodes which are special cased
          # above.
          continue
        for r in resources_by_op[op]:
          if op not in first_uses_by_output_ops:
            first_uses_by_output_ops[op] = set()
          first_uses_by_output_ops[op].update(first_use_for_res[r])
      # For each "op which must run", set a private attr indicating the ops that
      # used the same resources it did.
      for op in first_uses_by_output_ops:
        others = [
            other.name.encode() for other in first_uses_by_output_ops[op]
        ]
        l = attr_value_pb2.AttrValue.ListValue(s=others)
        # TODO(mdan): Is there a way which doesn't use anonymous attrs?
        op._set_attr("_res_first_used_by", attr_value_pb2.AttrValue(list=l))

    # Ensure all ops which must run do run
    self.ops_which_must_run.update(ops_which_must_run)
    control_output_op = None
    for idx, r in enumerate(
        nest.flatten(list(self._returned_tensors), expand_composites=True)):
      if self.ops_which_must_run:
        updated_ops_which_must_run = []
        if r.graph.building_function:
          # There may be many stateful ops in the graph. Adding them as
          # control inputs to each function output could create excessive
          # control edges in the graph. Thus we create an intermediate No-op
          # to chain the control dependencies between stateful ops and
          # function outputs.
          if idx == 0:
            control_output_op = control_flow_ops.no_op()
            control_output_op._set_attr("_acd_function_control_output",
                                        attr_value_pb2.AttrValue(b=True))
            control_output_op._add_control_inputs(self.ops_which_must_run)
          updated_ops_which_must_run = [control_output_op]
        else:
          updated_ops_which_must_run = [
              o for o in self.ops_which_must_run
              if o._control_flow_context is r.op._control_flow_context
          ]
        r.op._add_control_inputs(updated_ops_which_must_run)

    self.collective_manager_ids_used = collective_manager_scopes_used


_acd_resource_resolvers_registry = registry.Registry("acd_resource_resolvers")


def register_acd_resource_resolver(f):
  """Register a function for resolving resources touched by an op.

  `f` is called for every Operation added in the ACD context with the op's
  original resource reads and writes. `f` is expected to update the sets of
  resource reads and writes in-place and return True if it updated either of the
  sets, False otherwise.

  Example:
  @register_acd_resource_resolver
  def ResolveIdentity(op, resource_reads, resource_writes):
    # op: The `Operation` being processed by ACD currently.
    # resource_reads: An `ObjectIdentitySet` of read-only resources.
    # resource_writes: An `ObjectIdentitySet` of read-write resources.
    if not (resource_reads or resource_writes):
      return False
    def update(resource_inputs):
      to_add = []
      to_remove = []
      for t in resource_inputs:
        if t.op.type == "Identity":
          to_remove.append(t)
          to_add.append(t.op.inputs[0])
      if not to_add and not to_remove:
        return False
      for t in to_remove:
        resource_inputs.discard(t)
      resource_inputs.update(to_add)
      return True
    return update(resource_reads) or update(resource_writes)

  Args:
    f: Python function with signature
    (Operation, ObjectIdentitySet, ObjectIdentitySet) -> bool

  Returns:
    The function `f` after adding it to the registry.
  """
  _acd_resource_resolvers_registry.register(f)
  return f


def _get_resource_inputs(op):
  """Returns an iterable of resources touched by this `op`."""
  reads, writes = utils.get_read_write_resource_inputs(op)
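  # Run the registered resolvers until none of them reports further changes,
  # i.e. until the sets of resource reads and writes reach a fixed point.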
  saturated = False
  while not saturated:
    saturated = True
    for key in _acd_resource_resolvers_registry.list():
      # Resolvers should return true if they are updating the list of
      # resource_inputs.
      # TODO(srbs): An alternate would be to just compare the old and new set
      # but that may not be as fast.
      updated = _acd_resource_resolvers_registry.lookup(key)(op, reads, writes)
      if updated:
        # Conservatively remove any resources from `reads` that are also writes.
        reads = reads.difference(writes)
      saturated = saturated and not updated

  # Note: A resource handle that is not written to is treated as read-only. We
  # don't have a special way of denoting an unused resource.
  for t in reads:
    yield (t, ResourceType.READ_ONLY)
  for t in writes:
    yield (t, ResourceType.READ_WRITE)


def automatic_control_dependencies(f):
  """Wraps f to automatically insert control dependencies.

  The inserted dependencies ensure that:
    1. All stateful ops in f run when the result of f runs
    2. Updates to the same resources happen in order.

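  Example (an illustrative sketch; `tf.Variable` and the function below are
  assumptions, not part of this module):

  ```
    v = tf.Variable(0.0)

    @automatic_control_dependencies
    def f():
      v.assign_add(1.0)
      return v.read_value()
  ```

  Here, evaluating the result of `f` also runs the `assign_add`, and the read
  happens after the update.
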
  Args:
    f: the function to be wrapped.

  Returns:
    The wrapped function.
  """

  def wrapper(*args, **kwargs):
    with AutomaticControlDependencies() as a:
      result = f(*args, **kwargs)
      result_flat = [a.mark_as_return(t) for t in nest.flatten(result)]
      return nest.pack_sequence_as(result, result_flat)

  return tf_decorator.make_decorator(f, wrapper)