• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""while_v2 and gradient.
16
17This is a version of while_loop that emits a single While op, as well as the
18gradient function for While ops produced by while_loop. This will eventually
19replace the current tf.while_loop implementation once it reaches feature and
20performance parity.
21"""
22from __future__ import absolute_import
23from __future__ import division
24from __future__ import print_function
25
26import collections
27
28from tensorflow.core.framework import attr_value_pb2
29from tensorflow.python.client import pywrap_tf_session as c_api
30from tensorflow.python.eager import backprop_util
31from tensorflow.python.framework import auto_control_deps_utils as acd
32from tensorflow.python.framework import constant_op
33from tensorflow.python.framework import dtypes
34from tensorflow.python.framework import func_graph as func_graph_module
35from tensorflow.python.framework import ops
36from tensorflow.python.framework import tensor_shape
37from tensorflow.python.framework import tensor_spec
38from tensorflow.python.framework import tensor_util
39from tensorflow.python.framework import type_spec
40from tensorflow.python.ops import array_ops
41from tensorflow.python.ops import control_flow_ops
42from tensorflow.python.ops import control_flow_util as util_v1
43from tensorflow.python.ops import control_flow_util_v2 as util
44from tensorflow.python.ops import default_gradient
45from tensorflow.python.ops import gen_functional_ops
46from tensorflow.python.ops import gen_resource_variable_ops
47from tensorflow.python.ops import gradients_util
48from tensorflow.python.ops import handle_data_util
49from tensorflow.python.ops import list_ops
50from tensorflow.python.ops import math_ops
51from tensorflow.python.ops import tensor_array_ops
52from tensorflow.python.ops import while_v2_indexed_slices_rewriter
53from tensorflow.python.util import compat
54from tensorflow.python.util import nest
55from tensorflow.python.util import object_identity
56
57# pylint: disable=protected-access
58
59# Controls parallelism in the presence of side-effecting ops like variable
60# operations, print, py_function, etc. Can be set to True, False, or
61# "stateless_cond" (default).
62# Note that loops without side-effecting operations always execute with maximum
63# parallelism, ignoring this setting. When False, loops with side-effecting ops
64# execute sequentially, one iteration at a time.
65# When True, loops with side-effecting ops may execute parts of different
66# iterations in parallel; caution: if the loop condition contains
67# side-effecting ops, this mode produces unspecified results.
68# Setting it to "stateless_cond" automatically sets this mode to True when
69# the loop condition is free of side-effecting ops.
70# TODO(b/152548567): Change this to "stateless_cond".
71glob_stateful_parallelism = False
72
73
74def while_loop(cond,
75               body,
76               loop_vars,
77               shape_invariants=None,
78               parallel_iterations=10,
79               maximum_iterations=None,
80               name=None,
81               return_same_structure=True,
82               back_prop=True):
83  """Like tf.while_loop, except emits a single While op."""
84  # Keep the original loop_vars around to know which args were TensorArrays.
85  orig_loop_vars = loop_vars
86  # Cache its length since we use it at multiple places below.
87  len_orig_loop_vars = len(orig_loop_vars)
88
89  # Convert TensorArrays to their flow variables. These get converted back to
90  # TensorArrays before calling `cond` and `body`. See `wrapped_cond` and
91  # `wrapped_body` below.
92  loop_vars = list(_tensor_array_to_flow(orig_loop_vars))
93  loop_vars = nest.map_structure(
94      ops.internal_convert_to_tensor_or_indexed_slices, loop_vars,
95      expand_composites=True)
96  if shape_invariants is not None:
97    nest.assert_same_structure(orig_loop_vars, shape_invariants,
98                               expand_composites=False)
99    signature = nest.map_structure(
100        control_flow_ops._shape_invariant_to_type_spec, loop_vars,
101        list(shape_invariants), expand_composites=False)
102    shape_invariants = nest.map_structure(
103        control_flow_ops._get_shape_invariant, loop_vars,
104        list(shape_invariants), expand_composites=False)
105
106  else:
107    signature = nest.map_structure(
108        type_spec.type_spec_from_value, loop_vars, expand_composites=False)
109    shape_invariants = nest.map_structure(
110        control_flow_ops._get_shape_invariant, loop_vars,
111        expand_composites=False)
112  if not name:
113    name = "while"
114
115  with ops.name_scope(name) as scope:
116    with ops.name_scope(None):
117      cond_name = util.unique_fn_name(scope, "cond")
118      body_name = util.unique_fn_name(scope, "body")
119    maximum_iterations_loop_var = _build_maximum_iterations_loop_var(
120        maximum_iterations)
121    loop_counter = constant_op.constant(
122        0,
123        dtype=maximum_iterations_loop_var.dtype
124        if maximum_iterations is not None else None,
125        name="loop_counter")
126    # Add loop counter needed for computing gradients.
127    loop_vars = [loop_counter, maximum_iterations_loop_var] + loop_vars
128
129    shape_invariants = [tensor_shape.TensorShape([])] * 2 + shape_invariants
130    signature = (
131        [tensor_spec.TensorSpec.from_tensor(loop_counter),
132         tensor_spec.TensorSpec.from_tensor(maximum_iterations_loop_var)] +
133        signature)
134
135    # Automatic control dependencies are added in defuns, but not in v1
136    # graphs. Propagate that behavior here.
137    add_control_dependencies = ops.get_default_graph()._add_control_dependencies
138
139    def wrapped_cond(loop_counter, maximum_iterations_arg, *args):
140      """Extra `cond` wrapper that can handle the extra counter loop_var."""
141      # Convert the flow variables in `args` to TensorArrays. `args` should
142      # already have the same structure as `orig_loop_vars` but currently there
143      # is no nest.zip so we call `_pack_sequence_as` which flattens both
144      # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
145      # and packs it into the structure of `orig_loop_vars`.
146      pred = cond(*_pack_sequence_as(orig_loop_vars, args))
147      if (tensor_util.is_tf_type(pred) and
148          (pred.shape.dims is None or pred.shape.dims)):
149        pred = array_ops.squeeze_v2(pred)
150
151      if maximum_iterations is None:
152        return pred
153      else:
154        return math_ops.logical_and(
155            loop_counter < maximum_iterations_arg, pred)
156
157    # NOTE(skyewm): we set collections to the outer graph's collections for
158    # compatibility with TPUEstimator.
159    cond_graph = func_graph_module.func_graph_from_py_func(
160        cond_name,
161        wrapped_cond,
162        [],  # We provide signature instead of args.
163        {},
164        signature=signature,
165        func_graph=util.WhileCondFuncGraph(
166            cond_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
167        add_control_dependencies=add_control_dependencies)
168
169    if glob_stateful_parallelism == "stateless_cond":
170      stateful_parallelism = (not any(
171          op._is_stateful for op in cond_graph.get_operations()))
172    else:
173      stateful_parallelism = glob_stateful_parallelism
174
175    def wrapped_body(loop_counter, maximum_iterations_arg, *args):
176      """Loop body augmented with counter update.
177
178      Args:
179        loop_counter: Loop counter which needs to be incremented in the body.
180        maximum_iterations_arg: Maximum iterations of the loop.
181        *args: List of args
182
183      Returns:
184        A list of tensors the same length as args.
185      """
186      # The function was created with a signature rather than tensors, so
187      # internal placeholders were created without handle data.
188      _copy_handle_data(nest.flatten(loop_vars[2:], expand_composites=True),
189                        nest.flatten(args, expand_composites=True))
190      # Capture the tensors already captured in cond_graph so that they appear
191      # in the same order in body_graph.external_captures.
192      for t in cond_graph.external_captures:
193        ops.get_default_graph().capture(t)
194
195      # Convert the flow variables in `args` to TensorArrays. `args` should
196      # already have the same structure as `orig_loop_vars` but currently there
197      # is no nest.zip so we call `_pack_sequence_as` which flattens both
198      # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
199      # and packs it into the structure of `orig_loop_vars`.
200      outputs = body(*_pack_sequence_as(orig_loop_vars, args))
201      if not nest.is_sequence_or_composite(outputs):
202        outputs = [outputs]
203      # Compare the structure of input and output of body converting the
204      # top-level tuples to list to be compatible with legacy while_loop.
205      nest.assert_same_structure(list(outputs), list(orig_loop_vars),
206                                 expand_composites=True)
207
208      outputs = _tensor_array_to_flow(outputs)
209
210      # TODO(srbs): Update lowering code to create _Enter nodes with
211      # is_constant=True for inputs that are directly passed to outputs.
212      return [loop_counter + 1, maximum_iterations_arg] + list(outputs)
213
214    body_graph = func_graph_module.func_graph_from_py_func(
215        body_name,
216        wrapped_body,
217        [],  # We provide signature instead of args.
218        {},
219        signature=signature,
220        func_graph=util.WhileBodyFuncGraph(
221            body_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
222        add_control_dependencies=add_control_dependencies,
223        acd_record_initial_resource_uses=stateful_parallelism)
224    # Add external captures of body to the list of loop vars.
225    # Note that external tensors will be treated as loop invariants, i.e.,
226    # the value of that tensor in each iteration is the same as it was at the
227    # beginning of the loop execution.
228    loop_vars = loop_vars + body_graph.external_captures
229    # TODO(srbs): Update lowering code to create _Enter nodes with
230    # is_constant=True for inputs that are directly passed to outputs.
231    body_graph.outputs.extend(body_graph.internal_captures)
232
233    # Capture the extra `external_captures` of `body_graph` in `cond_graph` so
234    # that it expects to receive those as arguments.
235    with cond_graph.as_default():
236      num_cond_captures = len(cond_graph.external_captures)
237      assert (cond_graph.external_captures ==
238              body_graph.external_captures[:num_cond_captures])
239      _duplicate_body_captures_in_cond(
240          cond_graph, body_graph.external_captures[num_cond_captures:])
241
242    # Make sure that the shapes of the loop outputs are compatible with the
243    # shape invariants, or the shapes of the loop vars if the invariants are not
244    # specified.
245    num_flattened_outputs = len(nest.flatten(orig_loop_vars,
246                                             expand_composites=True))
247    # First var is loop counter and second var is maximum_iterations.
248    first_loop_var_index = 2
249    _check_shapes_compat(
250        body_graph.outputs[first_loop_var_index:first_loop_var_index +
251                           num_flattened_outputs],
252        nest.flatten(
253            shape_invariants[first_loop_var_index:first_loop_var_index +
254                             len_orig_loop_vars], expand_composites=True),
255        nest.flatten(loop_vars[first_loop_var_index:first_loop_var_index +
256                               len_orig_loop_vars], expand_composites=True))
257
258    num_original_outputs = len(body_graph.outputs)
259    if back_prop and util.output_all_intermediates():
260      # Export all tensors in the loop body that may be needed for gradient
261      # computation. We do this by accumulating the intermediate values in
262      # TensorLists.
263      intermediate_tensors = _get_intermediates(body_graph)
264
265      for intermediate_tensor in intermediate_tensors:
266        tensor_list = list_ops.empty_tensor_list(
267            element_dtype=intermediate_tensor.dtype,
268            element_shape=intermediate_tensor.shape,
269            max_num_elements=maximum_iterations)
270        loop_vars.append(tensor_list)
271        with cond_graph.as_default():
272          # Add a placeholder to cond_graph's inputs corresponding to the
273          # tensor_list.
274          cond_graph.capture(tensor_list)
275        with body_graph.as_default():
276          # Push the intermediate tensor to the tensor list. This captures the
277          # `tensor_list` as well.
278          appended_tensor_list = list_ops.tensor_list_push_back(
279              tensor_list, intermediate_tensor)
280          # Add this modified tensor list to the list of outputs.
281          body_graph.outputs.append(appended_tensor_list)
282
283    flattened_loop_vars = nest.flatten(loop_vars, expand_composites=True)
284    _check_num_inputs_outputs(cond_graph, body_graph,
285                              len(flattened_loop_vars))
286    _check_inputs_outputs_types_match(body_graph, flattened_loop_vars)
287
288    with ops.control_dependencies(
289        list(cond_graph.control_captures) + list(body_graph.control_captures)):
290      output_shapes = [t.shape for t in body_graph.outputs]
291      orig_loop_vars_range = slice(first_loop_var_index,
292                                   first_loop_var_index + num_flattened_outputs)
293      output_shapes[orig_loop_vars_range] = nest.flatten(
294          shape_invariants, expand_composites=True)[orig_loop_vars_range]
295
296      outputs = _build_while_op(
297          flattened_loop_vars,
298          cond_graph,
299          body_graph,
300          output_shapes=output_shapes,
301          parallel_iterations=parallel_iterations,
302          name=scope,
303          num_original_outputs=num_original_outputs,
304          stateful_parallelism=stateful_parallelism)
305    if not ops.get_default_graph().building_function:
306      # In V1 graph mode, return identities for each output of the While op,
307      # rather than the output of the While op directly. This makes pruning work
308      # if the output of while_loop() is fetched: the lowering pass converts the
309      # While outputs into IdentityN outputs, which if fetched will cause all
310      # ops in the body to be run (since it takes all exit ops as input). After
311      # lowering, each output identity op will end up with only the appropriate
312      # exit op as input.
313      outputs = tuple(array_ops.identity(t) for t in outputs)
314
315  output_loop_vars = outputs[first_loop_var_index:first_loop_var_index +
316                             num_flattened_outputs]
317  if not back_prop:
318    output_loop_vars = [array_ops.stop_gradient(t) for t in output_loop_vars]
319  outputs = _pack_sequence_as(orig_loop_vars, output_loop_vars)
320
321  if return_same_structure:
322    return outputs
323
324  flattened_outputs = nest.flatten(outputs, expand_composites=True)
325  if len(flattened_outputs) == 1:
326    return flattened_outputs[0]
327  else:
328    return outputs
329
330
331@ops.RegisterGradient("StatelessWhile")
332@ops.RegisterGradient("While")
333def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
334  """The gradient of a While op produced by while_loop."""
335  # Note that op is not always the same as while_op because the gradient tape,
336  # for eager mode compatibility, forgets information about the proper op. Since
337  # the loop cannot run in eager mode, however, we can safely introspect into
338  # the graph here.
339  while_op = op.outputs[0].op
340  cond_graph = _get_graph(while_op, "cond", "_cond_graph")
341  body_graph = _get_graph(while_op, "body", "_body_graph")
342  orig_num_params = len(body_graph.outputs)
343
344  maximum_iterations = op.inputs[1]
345  parallel_iterations = op.get_attr("parallel_iterations")
346
347  try:
348    num_original_outputs = while_op.get_attr("_num_original_outputs")
349  except:  # pylint: disable=bare-except
350    num_original_outputs = len(while_op.outputs)
351
352  try:
353    stateful_parallelism = while_op.get_attr("_stateful_parallelism")
354  except:  # pylint: disable=bare-except
355    stateful_parallelism = False
356
357  num_intermediates = len(while_op.outputs) - num_original_outputs
358  grads = [
359      _preprocess_grad(grad, body_out, while_in, while_out)  # pylint: disable=g-complex-comprehension
360      for grad, body_out, while_in, while_out in zip(
361          grads[:num_original_outputs],
362          body_graph.outputs[:num_original_outputs],
363          while_op.inputs[:num_original_outputs],
364          while_op.outputs[:num_original_outputs])
365  ] + [None] * num_intermediates
366
367  # Skip gradients with respect to the captures whenever possible.
368  if "skip_input_indices" in op.__dict__ and op.skip_input_indices is not None:
369    captures_start_index = (
370        len(body_graph.inputs) - len(body_graph.internal_captures))
371    for i in op.skip_input_indices:
372      if i >= captures_start_index:
373        grads[i] = None
374
375  # We compute the gradient for the sub-graph between trainable ys and xs
376  # with non-None incoming gradients. We later pad the None's to the list of
377  # outputs.
378  ys, xs, non_none_grads = zip(*[(y, x, grad) for (y, x, grad) in zip(
379      body_graph.outputs, body_graph.inputs, grads) if grad is not None])
380
381  body_grad_graph, args = _create_grad_func(
382      ys, xs, non_none_grads, cond_graph, body_graph,
383      util.unique_grad_fn_name(body_graph.name), op, maximum_iterations,
384      stateful_parallelism)
385
386  if body_grad_graph.while_op_needs_rewrite:
387    # Modify 'op' to output the intermediate accumulators needed by the grad
388    # function.
389    # NOTE(skyewm): if there are any active sessions, this modification to `op`
390    # may make them unrunnable!
391
392    cond_graph.name += "_rewritten"
393    body_graph.name += "_rewritten"
394
395    # `body_grad_graph.extra_inputs` here is equivalent to skimming off the new
396    # `body_graph.external_captures` added during `_create_grad_func`.
397    new_inputs = body_grad_graph.extra_inputs
398    new_outputs = body_graph.outputs[orig_num_params:]
399
400    while_op._set_func_attr("cond", util.create_new_tf_function(cond_graph))
401    while_op._set_func_attr("body", util.create_new_tf_function(body_graph))
402    if len(body_graph.output_types) != len(while_op.inputs) + len(new_inputs):
403      # Continuing leads to an invalid graph with disconnected inputs.
404      raise AssertionError(
405          "Inputs and outputs constructed for the forward op of a While "
406          "gradient don't match with 'output_types' at  "
407          f"{len(body_graph.output_types)},'inputs' at length "
408          f"{len(while_op.inputs)}, and 'new_inputs' at length "
409          f"{len(new_inputs)}. This doesn't make sense, please file a bug.")
410    while_op._set_type_list_attr("T", body_graph.output_types)
411    while_op._set_shape_list_attr("output_shapes", body_graph.output_shapes)
412    while_op._add_while_inputs(new_inputs)
413    while_op._add_outputs([t.dtype for t in new_outputs],
414                          [t.shape for t in new_outputs])
415    _copy_handle_data(new_outputs, while_op.outputs[orig_num_params:])
416
417  # Do not ignore grads wrt extra outputs when computing higher order
418  # derivatives.
419  while_op._set_attr("_num_original_outputs",
420                     attr_value_pb2.AttrValue(i=len(while_op.outputs)))
421
422  captured_inputs = _resolve_grad_captures(body_graph, body_grad_graph,
423                                           while_op)
424  loop_vars = args + captured_inputs
425
426  # This modifies body_grad_graph.
427  loop_vars = while_v2_indexed_slices_rewriter.rewrite_grad_indexed_slices(
428      grads, body_grad_graph, loop_vars, while_op.inputs)
429
430  def grad_cond(counter, unused_maximum_iterations_arg, forward_loop_iters,
431                *unused_args):
432    return counter < forward_loop_iters
433
434  grad_cond_name = util.unique_grad_fn_name(op.get_attr("cond").name)
435  cond_grad_graph = func_graph_module.func_graph_from_py_func(
436      grad_cond_name, grad_cond, loop_vars, {},
437      func_graph=util.WhileCondFuncGraph(grad_cond_name))
438
439  _check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars))
440
441  outputs = _build_while_op(
442      loop_vars,
443      cond_grad_graph,
444      body_grad_graph,
445      output_shapes=[t.shape for t in body_grad_graph.outputs],
446      parallel_iterations=parallel_iterations,
447      name="%s_grad" % while_op.name,
448      num_original_outputs=len(body_grad_graph.outputs),
449      stateful_parallelism=stateful_parallelism)
450
451  # See comment in while_loop.
452  outputs = [array_ops.identity(t) for t in outputs]
453  return _get_structured_grad_output(outputs, grads, body_grad_graph)
454
455
456def _build_while_op(loop_vars, cond_graph, body_graph, output_shapes,
457                    parallel_iterations, name, num_original_outputs,
458                    stateful_parallelism):
459  """Builds the functional StatelessWhile/While op."""
460  cond_stateful_ops = [
461      op for op in cond_graph.get_operations() if op._is_stateful
462  ]
463  body_stateful_ops = [
464      op for op in body_graph.get_operations() if op._is_stateful
465  ]
466  if (cond_stateful_ops or body_stateful_ops):
467    op_fn = gen_functional_ops._while
468  else:
469    op_fn = gen_functional_ops.stateless_while
470
471  def _make_op(inputs):
472    while_op, tensors = util.get_op_and_outputs(op_fn(
473        inputs,
474        util.create_new_tf_function(cond_graph),
475        util.create_new_tf_function(body_graph),
476        output_shapes=output_shapes,
477        parallel_iterations=parallel_iterations,
478        name=name))
479    _copy_handle_data(body_graph.outputs, tensors)
480    util.maybe_set_lowering_attr(while_op)
481    util.maybe_propagate_compile_time_consts_in_xla(while_op)
482    _set_read_only_resource_inputs_attr(while_op, [cond_graph, body_graph])
483    # This is needed so we do not compute derivative wrt these extra outputs.
484    while_op._set_attr("_num_original_outputs",
485                       attr_value_pb2.AttrValue(i=num_original_outputs))
486    while_op._set_attr("_stateful_parallelism",
487                       attr_value_pb2.AttrValue(b=stateful_parallelism))
488    # The while op may be created inside a tf.function, in which case ops
489    # needs to capture "through" it when taking gradients; outer_graph is used
490    # as a sanity check that capturing only happens from parent to child.
491    cond_graph.outer_graph = ops.get_default_graph()
492    body_graph.outer_graph = ops.get_default_graph()
493    while_op._cond_graph = cond_graph
494    while_op._body_graph = body_graph
495    return tensors
496  return util.run_as_function_for_tape_gradients(_make_op, loop_vars)
497
498
499def _get_intermediates(func_graph):
500  """Returns all tensors in `func_graph` that should be accumulated."""
501  # We currently accumulate output tensors of most ops in the function and rely
502  # on the pruning pass to get rid of the unused accumulators at runtime.
503  # However, this can bloat the GraphDef and make debugging harder so we perform
504  # some optimizations.
505  #
506  # Optimization we currently perform:
507  # 1. We do not accumulate tensors which already have an accumulator
508  #    in the loop body.
509  # 2. We do not accumulate outputs of Identity nodes. When building the
510  #    FuncGraph, we add an Identity node for each output (see
511  #    `AutomaticControlDependencies.mark_as_return`). Accumulating outputs
512  #    of all these nodes bloats the GraphDef quite a bit so we remove those.
513  #    Since the gradient of an Identity node does not rely on its forward op's
514  #    input this is safe to do.
515  #
516  # Other possible optimizations:
517  # 1. Only accumulate tensors that will be required by the backward pass.
518  #    This will require running the gradient pass and hence would increase the
519  #    graph building time for the forward pass.
520  # 2. Do not accumulate Const nodes created inside the loop body.
521  # 3. Do not accumulate loop vars that are returned as-is just like captured
522  #    tensors.
523  intermediates = []
524  reverse_captures = dict((v.ref(), k) for k, v in func_graph.captures)
525
526  for op in func_graph.get_operations():
527    if op.type == "Identity":
528      continue
529    # Accumulating mutexes can cause deadlock.
530    if op.type == "MutexLock":
531      continue
532    for o in op.outputs:
533      if (o is not func_graph.inputs[0] and  # Loop counter.
534          o.dtype != dtypes.resource and  # Do not accumulate resource tensors.
535          _get_accumulator(o) is None and  # Has existing accumulator.
536          o.ref() not in reverse_captures
537         ):  # Captured value, hence loop invariant.
538        intermediates.append(o)
539  return intermediates
540
541
542def _preprocess_grad(grad, body_graph_output, while_op_input, while_op_output):
543  """Returns the initial gradient to be used for a given output tensor.
544
545  Args:
546    grad: the original gradient Tensor passed to the gradient function.
547    body_graph_output: the corresponding Tensor in the body graph.
548    while_op_input: the corresponding Tensor input of the While op.
549    while_op_output: the corresponding Tensor output of the While op.
550
551  Returns:
552    A Tensor or None.
553  """
554  # Set the incoming gradient of non-trainable inputs to None. It is possible
555  # that we receive non-None gradients for non-trainable types in nested while
556  # loops because we accumulate outputs of the inner while as variant tensors
557  # which are trainable and hence receive zeros_like tensors in the gradient
558  # pass. The non-trainable tensors then receive the popped zeros tensor from
559  # this zeros variant. The gradient for the loop vars corresponding to these
560  # tensors is None or zeros (this happens only if the loop var is accumulated
561  # as well) in _grad_fn so we reset these.
562  # TODO(b/118712257): Remove once we can handle None output grads in _grad_fn.
563  if not _is_trainable(body_graph_output):
564    return None
565
566  # GradientTape initializes resource and variant grads as None instead of
567  # zeros. Set to zeros so _GradientsHelper computes the gradients instead of
568  # returning None.
569  # TODO(b/143286622): The supports_default_grad check is needed
570  # because While op emits non-differentiable resource tensors
571  # as outputs. Remove this check when that is not the case.
572  # Note: We use `while_op_input` instead of `while_op_output` for the call
573  # to `supports_default_grad` because `while_op_output` may be missing
574  # handle_data if the While is in a restored saved model.
575  if (while_op_output.dtype in (dtypes.resource, dtypes.variant) and
576      default_gradient.supports_default_grad(while_op_input) and grad is None):
577    return _zeros_like(while_op_input, while_op_output)
578
579  # Convert IndexedSlices to dense tensors since it is unlikely that downstream
580  # gradient functions with properly handle indexed slices. This is similar to
581  # what we do in tf.function gradients.
582  if isinstance(grad, ops.IndexedSlices):
583    return ops.convert_to_tensor(grad)
584
585  return grad
586
587
588# TODO(skyewm): make this return constants if op_output's shape is fully
589# defined (this can be done by checking the "shape" attr of resource vars).
590def _zeros_like(op_input, op_output):
591  """Like array_ops.zeros_like() but also accepts resource var handles."""
592  if op_output.dtype == dtypes.resource:
593    # Note: We use `op_input` instead of `op_output` to get the zeros dtype
594    # because `op_output` may be missing handle_data if the While is in a
595    # restored saved model.
596    return array_ops.zeros(
597        gen_resource_variable_ops.variable_shape(op_output),
598        dtype=default_gradient.get_zeros_dtype(op_input))
599  return array_ops.zeros_like(op_output)
600
601
602def _is_trainable(tensor):
603  """Returns whether the given tensor is trainable."""
604  if not backprop_util.IsTrainable(tensor):
605    return False
606
607  # Special case: untrainable accumulator output. The gradients algorithm
608  # doesn't know about tensor lists of untrainable elements. In theory the
609  # tensor list gradient functions should return None as appropriate, but
610  # because we can't return None from the gradient function we filter out
611  # untrainable accumulator output here to avoid computing the gradient at all.
612  if tensor.op.type == "TensorListPopBack" and tensor.value_index == 0:
613    assert tensor.dtype == dtypes.variant
614    element_type = tensor.op.get_attr("element_dtype")
615    return backprop_util.IsTrainable(element_type)
616
617  return True
618
619
620def _get_graph(while_op, func_attr_name, attr_graph_name):
621  """Returns `FuncGraph` for the given function attribute.
622
623  Args:
624    while_op: The While Operation.
625    func_attr_name: string
626    attr_graph_name: cached forward graph name
627
628  Returns:
629    `FuncGraph`
630  """
631  func_graph = getattr(while_op, attr_graph_name, None)
632  if func_graph is None:
633    # TODO(srbs): Handle TensorShapeProto in function_def_to_graph.input_shapes.
634    input_shapes = [
635        tensor_shape.TensorShape(s) for s in while_op.get_attr("output_shapes")
636    ]
637    func_name = while_op.get_attr(func_attr_name).name
638    func_graph = util.get_func_graph(while_op, input_shapes, func_name)
639  func_graph._while = while_op
640  return func_graph
641
642
643def _create_grad_func(ys, xs, grads, cond_graph, body_graph, name, while_op,
644                      maximum_iterations, stateful_parallelism):
645  """Builds and returns the gradient FuncGraph of `func_graph` and its args.
646
647  The returned grad_func_graph must be called with the returned
648  args + grad_func_graph.captures.
649
650  Args:
651    ys: A `Tensor` or list of tensors to be differentiated.
652    xs: A `Tensor` or list of tensors to be used for differentiation.
653    grads: The incoming grads for `ys`.
654    cond_graph: FuncGraph for the forward cond function.
655    body_graph: FuncGraph for the forward body function.
656    name: Name of the returned gradient function.
657    while_op: The forward While op.
658    maximum_iterations: Tensor. The maximum number of iterations.
659    stateful_parallelism: Bool, see tf.while_loop.
660
661  Returns:
662    2-tuple of (grad_func_graph, args).
663  """
664  assert len(ys) == len(grads)
665
666  total_iters = while_op.outputs[0]
667  counter = constant_op.constant(
668      0, dtype=total_iters.dtype, name="grad_counter")
669
670  # Build frozen sets so that we do not have linear time lookups in
671  # `_is_loop_invariant`. Note: `body_graph.inputs` and `body_graph.outputs`
672  # may get updated during gradient computation because we add accumulators to
673  # the forward op. However, those are not loop invariants so wouldn't affect
674  # the output of `_is_loop_invariant`. Also we would never attempt to capture
675  # those accumulators so `_is_loop_invariant` should never receive those new
676  # tensors as args.
677  body_graph_inputs = object_identity.ObjectIdentitySet(body_graph.inputs)
678  body_graph_outputs = object_identity.ObjectIdentitySet(body_graph.outputs)
679
680  args = [counter, maximum_iterations, total_iters] + list(grads)
681  # Note: The returned function does not have `args` in the list of
682  # `external_captures`.
683  grad_func_graph = func_graph_module.func_graph_from_py_func(
684      name,
685      lambda *args: _grad_fn(ys, xs, args, body_graph),
686      args, {},
687      func_graph=_WhileBodyGradFuncGraph(name, cond_graph, body_graph,
688                                         maximum_iterations, while_op,
689                                         body_graph_inputs, body_graph_outputs),
690      acd_record_initial_resource_uses=stateful_parallelism)
691
692  # Update the list of outputs with tensors corresponding to the captured
693  # tensors. We capture 3 types of tensors when building the grad fn:
694  # 1. Accumulators for forward graph intermediates which are not loop
695  #    invariants. The outputs corresponding to these are populated in
696  #    `internal_capture_to_output` by `_WhileBodyGradFuncGraph`.
697  # 2. Resources, which are output as is.
698  # 3. Forward graph loop invariants, which are output as is.
699  for external_capture, internal_capture in grad_func_graph.captures:
700    if (ops.tensor_id(internal_capture)
701        in grad_func_graph.internal_capture_to_output):
702      new_output = grad_func_graph.internal_capture_to_output[ops.tensor_id(
703          internal_capture)]
704    else:
705      raise ValueError(
706          f"Tensor {str(internal_capture)} which captures "
707          f"{str(external_capture)} is in list of "
708          f"internal_captures but not in internal_capture_to_output.")
709    grad_func_graph.outputs.append(new_output)
710    grad_func_graph.structured_outputs.append(new_output)
711
712  return grad_func_graph, args
713
714
715def _grad_fn(ys, xs, args, func_graph):
716  """Computes the gradient of `func_graph` in the current graph.
717
718  This function builds the gradient graph of the corresponding forward-pass
719  `func_graph` by differentiating `func_graph`'s outputs w.r.t. its inputs.
720
721  Args:
722    ys: A `Tensor` or list of tensors to be differentiated.
723    xs: A `Tensor` or list of tensors to be used for differentiation.
724    args: The input arguments.
725      args[0] - Loop counter
726      args[1] - Total number of iterations.
727      args[2] - maximum_iterations.
728      args[3:] - Incoming gradients for `ys`.
729    func_graph: function.FuncGraph. The corresponding forward-pass function.
730
731  Returns:
732    The output gradient Tensors.
733  """
734  grad_ys = args[3:]
735
736  # Build the gradient graph. Note that this builds the gradient computation of
737  # func_graph in the current graph, which requires capturing tensors from
738  # func_graph. The captured func_graph tensors are resolved to external tensors
739  # after the forward While op has been rewritten in _resolve_grad_captures.
740  # TODO(srbs): Mark GradientsHelper as public?
741  grad_outs = gradients_util._GradientsHelper(
742      ys, xs, grad_ys=grad_ys, src_graph=func_graph,
743      unconnected_gradients="zero")
744
745  # TODO(b/118712257): Handle the case when grad_outs has None's e.g. when there
746  # is a tf.StopGradient in the loop body.
747  assert all(g is not None for g in grad_outs)
748  counter = args[0]
749  maximum_iterations = args[1]
750  total_iters = args[2]
751  return [counter + 1, maximum_iterations, total_iters] + grad_outs
752
753
754def _resolve_grad_captures(body_graph, body_grad_graph, while_op):
755  """Returns the tensors to pass as captured inputs to `body_grad_graph`.
756
757  `body_grad_graph` may have external references to:
758  1. Its outer graph containing the input gradients. These are left as-is.
759  2. Accumulators captured from the forward-pass graph. These should have been
760     added as `while_op` outputs after the gradient graph was built. We replace
761     these with the corresponding output of `while_op`, i.e. a tensor in
762     `body_graph.outer_graph`. In the case of nested control flow or functions,
763     the gradient logic handling `body_grad_graph.outer_graph` will make sure
764     the tensor from `body_graph.outer_graph` is also correctly captured.
765
766  Args:
767    body_graph: FuncGraph. The forward-pass body function.
768    body_grad_graph: FuncGraph. The body gradients function.
769    while_op: The forward-pass While Operation calling `body_graph`.
770
771  Returns:
772    A list of input tensors to be passed as the captured inputs to
773    `body_grad_graph`.
774  """
775  new_capture_inputs = []
776  for t in body_grad_graph.external_captures:
777    # Resolve tensors captured from the forward graph to the outputs of the
778    # forward while_op.
779    if t.graph == body_graph:
780      # Captured accumulator or loop invariant.
781      for i, output in enumerate(t.graph.outputs):
782        if output is t:
783          t = while_op.outputs[i]
784          break
785
786      # Note: We rely on the capturing logic of the gradient While op graph to
787      # correctly capture the tensors in `body_graph.outer_graph`. Both cond_v2
788      # and while_v2 handle this while building their gradient functions.
789      assert t.graph == body_graph.outer_graph
790
791    new_capture_inputs.append(t)
792  return new_capture_inputs
793
794
795def _get_structured_grad_output(outputs, grads, body_grad_graph):
796  """Returns the values that should be returned from the while grad function.
797
798  Args:
799    outputs: the raw Tensor outputs of the grad While op.
800    grads: the input gradients to the gradient function.
801    body_grad_graph: _WhileBodyGradFuncGraph.
802
803  Returns:
804    A list of gradient values. May include Nones.
805  """
806  result = []
807  # outputs[0] is the loop counter.
808  # outputs[1] is maximum_iterations.
809  # outputs[2] is the total number of loop iterations.
810  outputs_idx = 3
811  structured_outputs_idx = 3
812  for g in grads:
813    # Set None as the output gradient for tensors with None input gradient.
814    if g is None:
815      result.append(None)
816      continue
817    output = body_grad_graph.structured_outputs[structured_outputs_idx]
818    structured_outputs_idx += 1
819    if isinstance(output, ops.IndexedSlices):
820      # TODO(skyewm): is there a more robust way to determine the order of
821      # flattened IndexedSlices components?
822      result.append(ops.IndexedSlices(
823          values=outputs[outputs_idx],
824          indices=outputs[outputs_idx + 1],
825          dense_shape=outputs[outputs_idx + 2]))
826      outputs_idx += 3
827    else:
828      assert isinstance(output, ops.Tensor)
829      result.append(outputs[outputs_idx])
830      outputs_idx += 1
831
832  return result
833
834
835def _get_accumulator(tensor):
836  r"""Returns TensorList if any containing accumulated values of tensor.
837
838  We try to find a pattern of the form:
839
840     input_tl   tensor
841        \        /
842    (TensorListPushBack)
843            |
844        output_tl
845
846  which satisfies the following conditions:
847
848  1. input_tl must be in tensor.graph.inputs.
849  2. output_tl or Identity(output_tl) must be in tensor.graph.outputs.
850  3. tensor.graph.input_index(input_tl) == tensor.graph.output_index(output_t).
851
852  output_tl or Identity(output_tl) (whichever is in tensor.graph.outputs) is
853  returned if such a pattern is found else None is returned.
854
855  Args:
856    tensor: The Tensor to be accumulated.
857
858  Returns:
859    A variant tensor in the same graph as `tensor` or None if no accumulator is
860    found.
861  """
862  assert isinstance(tensor.graph, func_graph_module.FuncGraph)
863
864  def get_func_graph_output(t):
865    """Returns t or Identity(t) whichever exists in graph outputs else None."""
866    for output in tensor.graph.outputs:
867      if output is t:
868        return t
869    # tf.defun adds an Identity for each output, check whether that is the case.
870    identity_op = t.consumers()[0]
871    if (identity_op.type == "Identity" and
872        any(identity_op.outputs[0] is t for t in tensor.graph.outputs)):
873      return identity_op.outputs[0]
874    return None
875
876  for consumer in tensor.consumers():
877    # Find the consumer that is a TensorListPushBack node whose TensorList input
878    # is in the list of function inputs.
879    if consumer.type != "TensorListPushBack":
880      continue
881
882    accum_input_idx = -1
883    for accum_input_idx, inp in enumerate(tensor.graph.inputs):
884      if inp is consumer.inputs[0]:
885        break
886    else:
887      continue
888
889    output = get_func_graph_output(consumer.outputs[0])
890    if output is None:
891      # The TensorList output of `consumer` is not in the list of function
892      # outputs.
893      continue
894
895    for accum_output_idx, out in enumerate(tensor.graph.outputs):
896      if out is output:
897        if accum_input_idx == accum_output_idx:
898          return output
899        break
900
901  return None
902
903
904OptimizedReductionOpsCacheKey = collections.namedtuple(
905    "OptimizedReductionOpsCacheKey", [
906        "op_type",
907        "inputs",
908        "dtypes",
909        "input_types",
910        "name",
911        "attrs",
912        "op_def",
913        "compute_device",
914    ])
915
916
917class _WhileBodyGradFuncGraph(util.WhileBodyFuncGraph):
918  """FuncGraph for the gradient function of the body of a While op.
919
920  Contains the logic for capturing the tensors from the body of the forward
921  While op which is as follows:
922  1. If the tensor is of resource type (these are not accumulated):
923     a. Ensure that the tensor is a loop invariant, i.e., it exists in both loop
924        inputs and outputs at the same index.
925     b. Lookup the corresponding resource tensor in the forward outer graph and
926        try to capture that.
927  2. If the tensor is not of resource type:
928     a. Create an accumulator for that tensor and output it from the forward
929        pass. Note this also requires adding it as an input to the forward pass.
930     b. Capture the accumulator from the forward pass in this FuncGraph. This
931        will later be resolved to the correct output of the forward While op.
932     c. Pop a value from the captured placeholder and use it as the captured
933        value for the forward pass tensor.
934
935  This only allows capturing tensors in the forward graph. A ValueError is
936  raised if an attempt is made to capture a tensor not in the forward graph.
937  To manually capture a tensor that is not in the forward graph, call `capture`
938  with `allowlisted=True`.
939
940  Note: The `captures` dict does not contain the forward tensor since it is not
941  directly captured. It contains the accumulator corresponding to this forward
942  tensor.
943
944  Attributes:
945    while_op_needs_rewrite: True if any non-resource intermediates were
946      captured, meaning the forward While op needs to be rewritten to output the
947      corresponding accumulators.
948    extra_inputs: list of EmptyTensorList tensors to be used as initial input to
949    the new accumulators in the forward graph. It may also contain external
950    captures of the custom gradient function.
951    internal_capture_to_output: dict from a tensor_id(captured placeholder) to
952      the corresponding tensor that needs to be added to the list of outputs.
953      For instance, when capturing an accumulator TensorList this contains the
954      TensorList obtained after popping a tensor from the list. Other entries
955      in this dict are expected, though not enforced, to be identities.
956      This dict is needed because these output tensors need to be added to
957      FuncGraph.outputs "after" the tensors returned from the gradient function.
958  """
959
960  def __init__(self, name, forward_cond_graph, forward_body_graph,
961               maximum_iterations, forward_while_op, body_graph_inputs,
962               body_graph_outputs):
963    super(_WhileBodyGradFuncGraph, self).__init__(name)
964    self.extra_inputs = []
965    self.internal_capture_to_output = {}
966    # FuncGraph for the body of the forward While op.
967    self._forward_graph = forward_body_graph
968    # FuncGraph for the cond of the forward While op.
969    self._forward_cond_graph = forward_cond_graph
970    self._maximum_iterations = maximum_iterations
971    self._forward_while_op = forward_while_op
972    # Dict from forward intermediate tensor to its indirectly captured tensor
973    # in this graph. Indirect capturing happens in two ways:
974    # 1. For non-resource tensors we capture their accumulators from the forward
975    #    outer graph and pop values from that accumulator inside this graph
976    #    using TensorListPopBack.
977    # 2. For resource tensors we directly capture their corresponding tensor
978    #    in the forward outer graph.
979    self._indirect_captures = {}
980
981  @property
982  def while_op_needs_rewrite(self):
983    return self.extra_inputs
984
985  def _create_op_internal(
986      self,
987      op_type,
988      inputs,
989      dtypes=None,  # pylint: disable=redefined-outer-name
990      input_types=None,
991      name=None,
992      attrs=None,
993      op_def=None,
994      compute_device=True):
995    # For a reduction op, if op is in the gradient body graph and its input is
996    # from the forward graph, moving op to the forward graph means we would
997    # store the tensor after the reduction as opposed to the tensor before
998    # reduction, and therefore could significantly reduce memory consumption.
999    # For now, we do this only for a few ops.
1000    #
1001    # We don't do this if any input tensor has already been accumulated. This
1002    # can happen if we output all intermediates in the forward pass.
1003    #
1004    # If in XLA context, do not move constant ops to forward pass as pushing to
1005    # and popping from a TensorList removes the constant property of an op and
1006    # breaks XLA compilation, which requires certain inputs to be compile-time
1007    # constant for certain ops.
1008    #
1009    # This optimization is currently also disabled when under a persistent tape,
1010    # since it leads to an unbounded number of side outputs. With caching it may
1011    # be possible to re-enable it.
1012    optimized_reduction_ops = {
1013        "Shape", "Size", "Rank", "TensorListElementShape", "TensorListLength"
1014    }
1015    if (op_type in optimized_reduction_ops and
1016        not util.output_all_intermediates() and
1017        all(input.graph is self._forward_graph for input in inputs) and
1018        all(_get_accumulator(input) is None for input in inputs) and
1019        not util_v1.GraphOrParentsInXlaContext(self._forward_graph) and
1020        not util.graph_wrapped_for_higher_order_tape_gradients(
1021            self._forward_graph)):
1022      return self._move_op_to_forward_graph(
1023          op_type,
1024          inputs,
1025          dtypes=dtypes,
1026          input_types=input_types,
1027          name=name,
1028          attrs=attrs,
1029          op_def=op_def,
1030          compute_device=compute_device)
1031
1032    return super(_WhileBodyGradFuncGraph, self)._create_op_internal(
1033        op_type,
1034        inputs,
1035        dtypes=dtypes,
1036        input_types=input_types,
1037        name=name,
1038        attrs=attrs,
1039        op_def=op_def,
1040        compute_device=compute_device)
1041
1042  def _move_op_to_forward_graph(
1043      self,
1044      op_type,
1045      inputs,
1046      dtypes=None,  # pylint: disable=redefined-outer-name
1047      input_types=None,
1048      name=None,
1049      attrs=None,
1050      op_def=None,
1051      compute_device=True):
1052    # We have a cache of reduction ops that have already been moved to the
1053    # forward graph, and we will check it first to avoid moving an op twice.
1054    if not hasattr(self._forward_graph, "_optimized_reduction_ops_cache"):
1055      self._forward_graph._optimized_reduction_ops_cache = {}
1056    cache_key = self._get_optimized_reduction_ops_cache_key(
1057        op_type, inputs, dtypes, input_types, name, attrs, op_def,
1058        compute_device)
1059    cached_op = self._forward_graph._optimized_reduction_ops_cache.get(
1060        cache_key)
1061    if cached_op is not None:
1062      # This op has already been moved to the forward graph and we have it in
1063      # the cache.
1064      return cached_op
1065
1066    with self._forward_graph.as_default():
1067      # `name` was built using name_scope stack of gradient graph and may not
1068      # be unique in the forward graph. `Graph.create_op` does not uniquify
1069      # names which are name scopes i.e. end in `/`. To ensure that the op
1070      # created gets a unique name in the forward graph we get rid of the
1071      # trailing slash.
1072      name = ops.name_from_scope_name(name)
1073      result = self._forward_graph._create_op_internal(
1074          op_type,
1075          inputs,
1076          dtypes=dtypes,
1077          input_types=input_types,
1078          name=name,
1079          attrs=attrs,
1080          op_def=op_def,
1081          compute_device=compute_device)
1082
1083      # Store the op we just moved to the forward graph so that it does
1084      # not need to be added there again.
1085      self._forward_graph._optimized_reduction_ops_cache[cache_key] = result
1086      return result
1087
1088  def _get_optimized_reduction_ops_cache_key(
1089      self,
1090      op_type,
1091      inputs,
1092      dtypes=None,  # pylint: disable=redefined-outer-name
1093      input_types=None,
1094      name=None,
1095      attrs=None,
1096      op_def=None,
1097      compute_device=True):
1098    # We need all elements of CacheKey to be hashable.
1099    inputs = tuple(map(lambda t: t.ref(), inputs))
1100
1101    if dtypes is not None:
1102      dtypes = tuple(dtypes)
1103
1104    if input_types is not None:
1105      input_types = tuple(input_types)
1106
1107    if attrs is not None:
1108      hashable_attrs = []
1109      for attr_name, attr_value in sorted(attrs.items()):
1110        hashable_attrs.append((attr_name, attr_value.SerializeToString()))
1111      attrs = tuple(hashable_attrs)
1112
1113    if op_def is not None:
1114      op_def = op_def.SerializeToString()
1115
1116    return OptimizedReductionOpsCacheKey(op_type, inputs, dtypes, input_types,
1117                                         name, attrs, op_def, compute_device)
1118
1119  def _capture_helper(self, tensor, name):
1120    """Implements the capturing described in the class docstring."""
1121    captured_tensor = self._indirect_captures.get(ops.tensor_id(tensor))
1122    if captured_tensor is not None:
1123      return captured_tensor
1124
1125    if tensor.graph is not self._forward_graph:
1126      already_captured = self.captured(tensor)
1127      captured_tensor = super(_WhileBodyGradFuncGraph, self)._capture_helper(
1128          tensor, name)
1129      if not already_captured:
1130        # Adds the captured tensor to the list of outputs so that the input
1131        # and output signatures match.
1132        self.internal_capture_to_output[ops.tensor_id(
1133            captured_tensor)] = captured_tensor
1134        self._indirect_captures[ops.tensor_id(tensor)] = captured_tensor
1135      return captured_tensor
1136
1137    while tensor.op.type == "Identity":
1138      # We do not accumulate the output of identity nodes so we try to capture
1139      # the input of the Identity node instead.
1140      tensor = tensor.op.inputs[0]
1141
1142    captured_tensor = self._indirect_captures.get(ops.tensor_id(tensor))
1143    if captured_tensor is not None:
1144      return captured_tensor
1145
1146    # No need to accumulate loop invariants. Capture them directly.
1147    # The captured tensor gets resolved to the corresponding while output in
1148    # `_resolve_grad_captures`.
1149    if _is_loop_invariant(tensor, self._forward_graph.inputs,
1150                          self._forward_graph.outputs):
1151      captured_tensor = super(_WhileBodyGradFuncGraph,
1152                              self)._capture_helper(tensor, name)
1153      # Add to `internal_capture_to_output` so that this gets added to the list
1154      # of outputs.
      self.internal_capture_to_output[ops.tensor_id(
          captured_tensor)] = captured_tensor
      self._indirect_captures[ops.tensor_id(tensor)] = captured_tensor
      return captured_tensor

    # Do not accumulate Const nodes. Instead copy them directly in the backward
    # graph.
    # TODO(srbs): This just checks for `Const` nodes. Consider checking for
    # graph compile time consts in general.
    # TODO(srbs): Consider making this a loop input.
    if constant_op.is_constant(tensor):
      real_value = constant_op.constant(
          tensor_util.constant_value(tensor), dtype=tensor.dtype)
      self._indirect_captures[ops.tensor_id(tensor)] = real_value
      return real_value

    # Resource tensors are not accumulated and are handled specially.
    if tensor.dtype == dtypes.resource:
      return self._resource_capture_helper(tensor)

    # Create or find an existing accumulator output for `tensor` in the forward
    # graph, and fetch from this accumulator in the gradient graph to get the
    # raw intermediate value.
    accumulator = _get_accumulator(tensor)
    if accumulator is None:
      # Create the initial empty tensor list.
      #
      # Note: We clear the control dependencies to avoid a cycle in case a
      # control tensor has an input path to an output of the forward While.
      #
      # E.g.:
      # x = tf.while_loop(...)
      # y = f(x)
      # with tf.control_dependencies([y]):
      #   tf.gradients(y, x)
      #
      # Since the EmptyTensorList is fed back into the forward While, not
      # removing the control edge would cause a cycle.
      with self._forward_graph.outer_graph.as_default():
        with util.clear_control_inputs():
          tensor_list = list_ops.empty_tensor_list(
              element_dtype=tensor.dtype,
              element_shape=tensor.shape,
              max_num_elements=self._maximum_iterations,
              name=_build_accumulator_name(tensor))
      self.extra_inputs.append(tensor_list)

      # Push the intermediate tensor to the tensor list. This captures
      # `tensor_list`.
      with self._forward_graph.as_default():
        accumulator = list_ops.tensor_list_push_back(tensor_list, tensor)
      # Add the modified tensor list to the list of outputs. This output will
      # contain all the accumulated values.
      self._forward_graph.outputs.append(accumulator)

      # Capture in the cond graph as well so the forward cond and body inputs
      # match.
      with self._forward_cond_graph.as_default():
        self._forward_cond_graph.capture(tensor_list)

    # Capture the accumulator tensor list in the gradient graph directly from
    # the forward graph -- we'll later modify this to capture the final list
    # output by the forward While op instead.
    captured_accumulator = super(_WhileBodyGradFuncGraph, self)._capture_helper(
        accumulator, name)

    # Pop the intermediate value from the tensor list in the gradient graph.
    new_tensor_list, captured_tensor = list_ops.tensor_list_pop_back(
        captured_accumulator, element_dtype=tensor.dtype)

    self._indirect_captures[ops.tensor_id(tensor)] = captured_tensor
    self.internal_capture_to_output[ops.tensor_id(
        captured_accumulator)] = new_tensor_list
    return captured_tensor

  def _resource_capture_helper(self, tensor):
    """Returns the captured resource tensor.

    Resource-type tensors are not accumulated. If a resource tensor exists in
    the loop body, it must either be a loop input or an output of a nested
    While op inside the loop body that captured the external resource.

    Args:
      tensor: the external resource Tensor to be captured.

    Returns:
      Tensor in this graph.
    """
    assert tensor.dtype == dtypes.resource

    forward_graph_input_names = [t.name for t in self._forward_graph.inputs]
    forward_graph_name_to_opdef = {
        op.name: op.node_def for op in self._forward_graph.get_operations()}
    index = util.resource_input_index(
        tensor.name, forward_graph_input_names,
        forward_graph_name_to_opdef,
        self._forward_graph._functions)

    input_placeholder = self._forward_graph.inputs[index]
    tensor_in_outer_graph = self._forward_graph._while.inputs[index]

    assert input_placeholder.dtype == dtypes.resource
    assert tensor_in_outer_graph.dtype == dtypes.resource
    # This must be a loop invariant. However, infrastructure
    # (e.g. tf.vectorized_map) may insert identity nodes, function calls, conds,
    # etc. which take and return the resource tensor unmodified; this means that
    # the Python objects may differ.
    if index != util.resource_input_index(
        self._forward_graph.outputs[index].name, forward_graph_input_names,
        forward_graph_name_to_opdef,
        self._forward_graph._functions):
      raise AssertionError(
          f"Resource tensors must be loop invariants: {tensor_in_outer_graph}")

    self._indirect_captures[ops.tensor_id(tensor)] = self.capture(
        tensor_in_outer_graph)
    return self._indirect_captures[ops.tensor_id(tensor)]

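# Example (illustrative sketch; assumes the public `tf` API in graph mode): a
# gradient through a loop whose body reads a `tf.Variable` exercises
# `_resource_capture_helper` above, because the variable's resource handle is
# captured by the body but never accumulated -- it must flow through the
# While op as a loop invariant.
#   v = tf.Variable(2.0)
#   x = tf.constant(3.0)
#   _, y = tf.while_loop(lambda i, a: i < 4,
#                        lambda i, a: (i + 1, a * v),
#                        [tf.constant(0), x])
#   tf.gradients(y, x)  # The backward graph recaptures `v`'s handle from the
#                       # outer graph instead of popping an accumulated value.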

def _check_shapes_compat(output_tensors, shape_invariants, input_tensors):
  for (t, shape, input_t) in zip(output_tensors, shape_invariants,
                                 input_tensors):
    if not control_flow_ops._ShapeLessThanOrEqual(t.shape, shape):
      raise ValueError(
          f"Input tensor `{input_t.name}` enters the loop with shape {shape}, "
          f"but has shape {t.shape} after one iteration. To allow the shape to "
          "vary across iterations, use the `shape_invariants` argument of "
          "tf.while_loop to specify a less-specific shape.")

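# Example (illustrative sketch; assumes the public `tf` API): the second loop
# var below grows on every iteration, so its shape invariant must be relaxed
# via `shape_invariants`, exactly as the error message above suggests.
#   i = tf.constant(0)
#   x = tf.ones([1])
#   tf.while_loop(lambda i, x: i < 3,
#                 lambda i, x: (i + 1, tf.concat([x, x], axis=0)),
#                 [i, x],
#                 shape_invariants=[i.shape, tf.TensorShape([None])])
# Omitting `shape_invariants` here triggers the ValueError constructed above.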

def _check_num_inputs_outputs(cond_graph, body_graph, num_flattened_loop_vars):
  """Checks the number of inputs/outputs of `cond_graph` and `body_graph`."""
  assert len(cond_graph.inputs) == num_flattened_loop_vars, (
      "cond_graph takes %d inputs; Expected: %d" % (len(cond_graph.inputs),
                                                    num_flattened_loop_vars))
  assert len(cond_graph.outputs) == 1, (
      "cond_graph has %d outputs; Expected: 1" % len(cond_graph.outputs))
  assert len(body_graph.inputs) == num_flattened_loop_vars, (
      "body_graph takes %d inputs; Expected: %d" % (len(body_graph.inputs),
                                                    num_flattened_loop_vars))
  assert len(body_graph.outputs) == num_flattened_loop_vars, (
      "body_graph has %d outputs; Expected: %d" % (len(body_graph.outputs),
                                                   num_flattened_loop_vars))


def _check_inputs_outputs_types_match(body_graph, flattened_loop_vars):
  for inp, out, loop_var in zip(body_graph.inputs, body_graph.outputs,
                                flattened_loop_vars):
    if inp.dtype != out.dtype:
      raise TypeError(
          f"Loop var {loop_var.name} enters the loop with type {inp.dtype} "
          f"but has type {out.dtype} after 1 iteration. {loop_var.name} type "
          "should remain constant.")

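# Example (illustrative sketch; assumes the public `tf` API): a body that
# changes a loop var's dtype trips the TypeError above.
#   tf.while_loop(lambda x: x < 3,
#                 lambda x: tf.cast(x + 1, tf.float32),  # int32 -> float32
#                 [tf.constant(0)])
# Loop vars must keep the same dtype on every iteration.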

def _build_cond_placeholders_name_prefix(cond_graph):
  return cond_graph.unique_name(cond_graph.name + "___redundant_placeholder")


def _duplicate_body_captures_in_cond(cond_graph, body_graph_captures):
  """Creates placeholders for body captures in cond_graph.

  This is needed to match signatures of cond and body graphs.

  Args:
    cond_graph: cond branch graph
    body_graph_captures: Tensors which were captured when building the
      `body_graph`.
  """
  types = [t.dtype.as_datatype_enum for t in body_graph_captures]
  # TODO(srbs): Providing a unique prefix does not ensure that there is no
  # conflict between the placeholder names and existing nodes in the graph.
  # However, passing a list of strings may not be performant.
  # Ideally we should move `Graph.unique_name` to C++ or make
  # `Graph._names_in_use` a trie so that we can find a unique prefix.
  # TODO(b/143286622): This should not be required once captures are separated
  # from regular loop vars.
  placeholders = c_api.TF_CreatePlaceholders(
      cond_graph._c_graph, types,
      compat.as_str(_build_cond_placeholders_name_prefix(cond_graph)))
  placeholder_ops = [
      _OperationWithOutputs(ph.oper, cond_graph)
      for ph in placeholders
  ]

  tensors = []
  for op, ph, dtype in zip(placeholder_ops, placeholders, types):
    tensor = ops.Tensor._create_with_tf_output(op, 0, dtype, ph)
    op._outputs = [tensor]
    tensors.append(tensor)

  # Update `cond_graph._captures` and `cond_graph.inputs` to contain the
  # newly created placeholders.
  tuples = zip(body_graph_captures, tensors)
  keys = [id(t) for t in body_graph_captures]
  cond_graph._captures.update(zip(keys, tuples))
  cond_graph.inputs.extend(tensors)


def _copy_handle_data(src_tensors, tgt_tensors):
  for src_t, tgt_t in zip(src_tensors, tgt_tensors):
    handle_data_util.copy_handle_data(src_t, tgt_t)


def _graph_name(graph):
  if isinstance(graph, func_graph_module.FuncGraph):
    return graph.name
  return "Base"


def _pack_sequence_as(structure_with_tas, loop_vars):
  """Like `nest.pack_sequence_as` but also replaces flows with TensorArrays."""

  def flow_to_tensor_array(flow, ta):  # pylint: disable=missing-docstring
    return (tensor_array_ops.build_ta_with_new_flow(ta, flow) if isinstance(  # pylint: disable=g-long-ternary
        ta, tensor_array_ops.TensorArray) else flow)

  flattened_loop_vars = [
      flow_to_tensor_array(*z)
      for z in zip(nest.flatten(loop_vars, expand_composites=True),
                   nest.flatten(structure_with_tas, expand_composites=True))
  ]
  return nest.pack_sequence_as(structure_with_tas, flattened_loop_vars,
                               expand_composites=True)


def _tensor_array_to_flow(loop_vars):

  def f(maybe_ta):
    if isinstance(maybe_ta, tensor_array_ops.TensorArray):
      return maybe_ta.flow
    return maybe_ta

  return nest.map_structure(f, loop_vars, expand_composites=True)

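# Example (illustrative sketch) of how the two helpers above compose:
# `_tensor_array_to_flow` strips TensorArrays down to their flow tensors so
# they can be threaded through the While op as plain loop vars, and
# `_pack_sequence_as` later rebuilds TensorArrays around the updated flows,
# using the original structure as a template.
#   ta = tensor_array_ops.TensorArray(dtypes.float32, size=3)
#   loop_vars = [constant_op.constant(0), ta]
#   flows = _tensor_array_to_flow(loop_vars)       # [int scalar, flow]
#   rebuilt = _pack_sequence_as(loop_vars, flows)  # [int scalar, TensorArray]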

def _build_maximum_iterations_loop_var(maximum_iterations):
  if maximum_iterations is None:
    # Default value of max_num_elements for EmptyTensorList, meaning that the
    # list size is unbounded.
    maximum_iterations = -1
  # EmptyTensorList expects `max_num_elements` to be of type int32.
  return ops.convert_to_tensor(
      maximum_iterations, dtype=dtypes.int32, name="maximum_iterations")

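# For illustration (values per the conversion above):
#   _build_maximum_iterations_loop_var(None)  # -> int32 -1 (unbounded)
#   _build_maximum_iterations_loop_var(16)    # -> int32 16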

def _build_accumulator_name(tensor):
  # Tensor name may be of the form "pow/y:0". Name scope does not allow ":".
  return "{}/accumulator".format(tensor.name).replace(":", "_")

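# For illustration (following the replacement above): a tensor named "pow/y:0"
# gets the accumulator name "pow/y_0/accumulator", since ":" is not a valid
# name-scope character and is replaced by "_".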

def _is_loop_invariant(tensor, inputs, outputs):
  return (any(tensor is t for t in inputs) and
          any(tensor is t for t in outputs))

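# Note: the check above relies on object identity (`is`), not equality, so a
# tensor counts as a loop invariant only when the very same Tensor object
# appears both as a body input and as a body output, i.e. the body passes it
# through untouched.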

class _OperationWithOutputs(ops.Operation):
  """Operation with pre-built `TF_Output`s.

  The C API for creating the extra placeholders for the cond graph returns
  SWIG-wrapped TF_Output* pointers, which we can use directly for
  `Operation.outputs`. The default constructor for `Operation` does not provide
  a way of specifying pre-built output tensors and always creates them. This is
  a performance overhead. It is not clear whether adding that feature to the
  `Operation` API would be generally useful, so for now we just have our own
  lightweight `Operation` implementation. Note that this also does not extract
  a stacktrace, since we don't expect this operation to be used.

  TODO(b/143286622): This should not be required once captures are separated
  from regular loop vars.
  """

  def __init__(self, c_op, g):
    self._c_op = c_op
    self._graph = g
    self._outputs = None  # Initialized by _duplicate_body_captures_in_cond().
    self._id_value = g._add_op(self, self.name)
    self._is_stateful = False


def _set_read_only_resource_inputs_attr(op, branch_graphs):
  """Sets the list of resource inputs which are read-only.

  This is used by AutomaticControlDependencies.

  Args:
    op: While Operation.
    branch_graphs: List of branch FuncGraphs.
  """
  read_only_indices = set(range(len(op.inputs)))
  for branch_graph in branch_graphs:
    if not read_only_indices:
      break
    branch_read_only_indices = acd.get_read_only_resource_input_indices_graph(
        branch_graph)
    read_only_indices = read_only_indices.intersection(branch_read_only_indices)

  ops.set_int_list_attr(op, acd.READ_ONLY_RESOURCE_INPUTS_ATTR,
                        sorted(read_only_indices))

# pylint: enable=protected-access