1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""while_v2 and gradient.
16
17This is a version of while_loop that emits a single While op, as well as the
18gradient function for While ops produced by while_loop. This will eventually
19replace the current tf.while_loop implementation once it reaches feature and
20performance parity.
21"""
22from __future__ import absolute_import
23from __future__ import division
24from __future__ import print_function
25
26import collections
27
28from tensorflow.core.framework import attr_value_pb2
29from tensorflow.python.client import pywrap_tf_session as c_api
30from tensorflow.python.eager import backprop_util
31from tensorflow.python.framework import auto_control_deps_utils as acd
32from tensorflow.python.framework import constant_op
33from tensorflow.python.framework import dtypes
34from tensorflow.python.framework import func_graph as func_graph_module
35from tensorflow.python.framework import ops
36from tensorflow.python.framework import tensor_shape
37from tensorflow.python.framework import tensor_spec
38from tensorflow.python.framework import tensor_util
39from tensorflow.python.framework import type_spec
40from tensorflow.python.ops import array_ops
41from tensorflow.python.ops import control_flow_ops
42from tensorflow.python.ops import control_flow_util as util_v1
43from tensorflow.python.ops import control_flow_util_v2 as util
44from tensorflow.python.ops import custom_gradient
45from tensorflow.python.ops import default_gradient
46from tensorflow.python.ops import gen_functional_ops
47from tensorflow.python.ops import gen_resource_variable_ops
48from tensorflow.python.ops import gradients_util
49from tensorflow.python.ops import list_ops
50from tensorflow.python.ops import math_ops
51from tensorflow.python.ops import tensor_array_ops
52from tensorflow.python.ops import while_v2_indexed_slices_rewriter
53from tensorflow.python.util import compat
54from tensorflow.python.util import nest
55from tensorflow.python.util import object_identity
56
57# pylint: disable=protected-access
58
59
60def while_loop(cond,
61               body,
62               loop_vars,
63               shape_invariants=None,
64               parallel_iterations=10,
65               maximum_iterations=None,
66               name=None,
67               return_same_structure=True,
68               back_prop=True):
69  """Like tf.while_loop, except emits a single While op."""
70  # Keep the original loop_vars around to know which args were TensorArrays.
71  orig_loop_vars = loop_vars
72  # Cache its length since we use it at multiple places below.
73  len_orig_loop_vars = len(orig_loop_vars)
74
75  # Convert TensorArrays to their flow variables. These get converted back to
76  # TensorArrays before calling `cond` and `body`. See `wrapped_cond` and
77  # `wrapped_body` below.
78  loop_vars = list(_tensor_array_to_flow(orig_loop_vars))
79  loop_vars = nest.map_structure(
80      ops.internal_convert_to_tensor_or_indexed_slices, loop_vars,
81      expand_composites=True)
82  if shape_invariants is not None:
83    nest.assert_same_structure(orig_loop_vars, shape_invariants,
84                               expand_composites=False)
85    signature = nest.map_structure(
86        control_flow_ops._shape_invariant_to_type_spec, loop_vars,
87        list(shape_invariants), expand_composites=False)
88    shape_invariants = nest.map_structure(
89        control_flow_ops._get_shape_invariant, loop_vars,
90        list(shape_invariants), expand_composites=False)
91
92  else:
93    signature = nest.map_structure(
94        type_spec.type_spec_from_value, loop_vars, expand_composites=False)
95    shape_invariants = nest.map_structure(
96        control_flow_ops._get_shape_invariant, loop_vars,
97        expand_composites=False)
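  # Illustrative example (assuming the standard TensorSpec conversion in the
  # helpers above): a loop var `tf.zeros([5])` with invariant
  # `tf.TensorShape([None])` yields a signature entry of
  # `tf.TensorSpec([None], tf.float32)`, letting the leading dimension vary
  # across iterations, whereas omitting the invariant pins the full shape [5].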
98  if not name:
99    name = "while"
100
101  with ops.name_scope(name) as scope:
102    with ops.name_scope(None):
103      cond_name = util.unique_fn_name(scope, "cond")
104      body_name = util.unique_fn_name(scope, "body")
105    maximum_iterations_loop_var = _build_maximum_iterations_loop_var(
106        maximum_iterations)
107    loop_counter = constant_op.constant(
108        0,
109        dtype=maximum_iterations_loop_var.dtype
110        if maximum_iterations is not None else None,
111        name="loop_counter")
112    # Add loop counter needed for computing gradients.
113    loop_vars = [loop_counter, maximum_iterations_loop_var] + loop_vars
114
115    shape_invariants = [tensor_shape.TensorShape([])] * 2 + shape_invariants
116    signature = (
117        [tensor_spec.TensorSpec.from_tensor(loop_counter),
118         tensor_spec.TensorSpec.from_tensor(maximum_iterations_loop_var)] +
119        signature)
120
121    # Automatic control dependencies are added in defuns, but not in v1
122    # graphs. Propagate that behavior here.
123    add_control_dependencies = ops.get_default_graph()._add_control_dependencies
124
125    def wrapped_cond(loop_counter, maximum_iterations_arg, *args):
126      """Extra `cond` wrapper that can handle the extra counter loop_var."""
127      # Convert the flow variables in `args` to TensorArrays. `args` should
128      # already have the same structure as `orig_loop_vars` but currently there
129      # is no nest.zip so we call `_pack_sequence_as` which flattens both
130      # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
131      # and packs it into the structure of `orig_loop_vars`.
132      pred = cond(*_pack_sequence_as(orig_loop_vars, args))
133      if (tensor_util.is_tf_type(pred) and
134          (pred.shape.dims is None or pred.shape.dims)):
135        pred = array_ops.squeeze_v2(pred)
136
137      if maximum_iterations is None:
138        return pred
139      else:
140        return math_ops.logical_and(
141            loop_counter < maximum_iterations_arg, pred)
142
143    # NOTE(skyewm): we set collections to the outer graph's collections for
144    # compatibility with TPUEstimator.
145    cond_graph = func_graph_module.func_graph_from_py_func(
146        cond_name,
147        wrapped_cond,
148        [],  # We provide signature instead of args.
149        {},
150        signature=signature,
151        func_graph=util.WhileCondFuncGraph(
152            cond_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
153        add_control_dependencies=add_control_dependencies)
154
155    def wrapped_body(loop_counter, maximum_iterations_arg, *args):
156      """Loop body augmented with counter update.
157
158      Args:
159        loop_counter: Loop counter which needs to be incremented in the body.
160        maximum_iterations_arg: Maximum iterations of the loop.
161        *args: List of args
162
163      Returns:
164        A list of tensors the same length as args.
165      """
166      # The function was created with a signature rather than tensors, so
167      # internal placeholders were created without handle data.
168      _copy_handle_data(nest.flatten(loop_vars[2:], expand_composites=True),
169                        nest.flatten(args, expand_composites=True))
170      # Capture the tensors already captured in cond_graph so that they appear
171      # in the same order in body_graph.external_captures.
172      for t in cond_graph.external_captures:
173        ops.get_default_graph().capture(t)
174
175      # Convert the flow variables in `args` to TensorArrays. `args` should
176      # already have the same structure as `orig_loop_vars` but currently there
177      # is no nest.zip so we call `_pack_sequence_as` which flattens both
178      # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
179      # and packs it into the structure of `orig_loop_vars`.
180      outputs = body(*_pack_sequence_as(orig_loop_vars, args))
181      if not nest.is_sequence_or_composite(outputs):
182        outputs = [outputs]
183      # Compare the structure of input and output of body converting the
184      # top-level tuples to list to be compatible with legacy while_loop.
185      nest.assert_same_structure(list(outputs), list(orig_loop_vars),
186                                 expand_composites=True)
187
188      outputs = _tensor_array_to_flow(outputs)
189
190      # TODO(srbs): Update lowering code to create _Enter nodes with
191      # is_constant=True for inputs that are directly passed to outputs.
192      return [loop_counter + 1, maximum_iterations_arg] + list(outputs)
193
194    body_graph = func_graph_module.func_graph_from_py_func(
195        body_name,
196        wrapped_body,
197        [],  # We provide signature instead of args.
198        {},
199        signature=signature,
200        func_graph=util.WhileBodyFuncGraph(
201            body_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
202        add_control_dependencies=add_control_dependencies)
203    # Add external captures of body to the list of loop vars.
204    # Note that external tensors will be treated as loop invariants, i.e.,
205    # the value of that tensor in each iteration is the same as it was at the
206    # beginning of the loop execution.
207    loop_vars = loop_vars + body_graph.external_captures
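    # The loop vars are now laid out as
    #   [loop_counter, maximum_iterations, *loop_vars, *external_captures]
    # (composites still packed); they are flattened below, and intermediate
    # accumulators may additionally be appended when intermediates are exported.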
208    # TODO(srbs): Update lowering code to create _Enter nodes with
209    # is_constant=True for inputs that are directly passed to outputs.
210    body_graph.outputs.extend(body_graph.internal_captures)
211
212    # Capture the extra `external_captures` of `body_graph` in `cond_graph` so
213    # that it expects to receive those as arguments.
214    with cond_graph.as_default():
215      num_cond_captures = len(cond_graph.external_captures)
216      assert (cond_graph.external_captures ==
217              body_graph.external_captures[:num_cond_captures])
218      _duplicate_body_captures_in_cond(
219          cond_graph, body_graph.external_captures[num_cond_captures:])
220
221    # Make sure that the shapes of the loop outputs are compatible with the
222    # shape invariants, or the shapes of the loop vars if the invariants are not
223    # specified.
224    num_flattened_outputs = len(nest.flatten(orig_loop_vars,
225                                             expand_composites=True))
226    # First var is loop counter and second var is maximum_iterations.
227    first_loop_var_index = 2
228    _check_shapes_compat(
229        body_graph.outputs[first_loop_var_index:first_loop_var_index +
230                           num_flattened_outputs],
231        nest.flatten(
232            shape_invariants[first_loop_var_index:first_loop_var_index +
233                             len_orig_loop_vars], expand_composites=True),
234        nest.flatten(loop_vars[first_loop_var_index:first_loop_var_index +
235                               len_orig_loop_vars], expand_composites=True))
236
237    num_original_outputs = len(body_graph.outputs)
238    if back_prop and util.output_all_intermediates():
239      # Export all tensors in the loop body that may be needed for gradient
240      # computation. We do this by accumulating the intermediate values in
241      # TensorLists.
242      intermediate_tensors = _get_intermediates(body_graph)
243
244      for intermediate_tensor in intermediate_tensors:
245        tensor_list = list_ops.empty_tensor_list(
246            element_dtype=intermediate_tensor.dtype,
247            element_shape=intermediate_tensor.shape,
248            max_num_elements=maximum_iterations)
249        loop_vars.append(tensor_list)
250        with cond_graph.as_default():
251          # Add a placeholder to cond_graph's inputs corresponding to the
252          # tensor_list.
253          cond_graph.capture(tensor_list)
254        with body_graph.as_default():
255          # Push the intermediate tensor to the tensor list. This captures the
256          # `tensor_list` as well.
257          appended_tensor_list = list_ops.tensor_list_push_back(
258              tensor_list, intermediate_tensor)
259          # Add this modified tensor list to the list of outputs.
260          body_graph.outputs.append(appended_tensor_list)
261
262    flattened_loop_vars = nest.flatten(loop_vars, expand_composites=True)
263    _check_num_inputs_outputs(cond_graph, body_graph,
264                              len(flattened_loop_vars))
265    _check_inputs_outputs_types_match(body_graph, flattened_loop_vars)
266
267    with ops.control_dependencies(
268        list(cond_graph.control_captures) + list(body_graph.control_captures)):
269      output_shapes = [t.shape for t in body_graph.outputs]
270      orig_loop_vars_range = slice(first_loop_var_index,
271                                   first_loop_var_index + num_flattened_outputs)
272      output_shapes[orig_loop_vars_range] = nest.flatten(
273          shape_invariants, expand_composites=True)[orig_loop_vars_range]
274
275      outputs = _build_while_op(
276          flattened_loop_vars,
277          cond_graph,
278          body_graph,
279          output_shapes=output_shapes,
280          parallel_iterations=parallel_iterations,
281          name=scope,
282          num_original_outputs=num_original_outputs)
283    if not ops.get_default_graph().building_function:
284      # In V1 graph mode, return identities for each output of the While op,
285      # rather than the output of the While op directly. This makes pruning work
286      # if the output of while_loop() is fetched: the lowering pass converts the
287      # While outputs into IdentityN outputs, which if fetched will cause all
288      # ops in the body to be run (since it takes all exit ops as input). After
289      # lowering, each output identity op will end up with only the appropriate
290      # exit op as input.
291      outputs = tuple(array_ops.identity(t) for t in outputs)
292
293  output_loop_vars = outputs[first_loop_var_index:first_loop_var_index +
294                             num_flattened_outputs]
295  if not back_prop:
296    output_loop_vars = [array_ops.stop_gradient(t) for t in output_loop_vars]
297  outputs = _pack_sequence_as(orig_loop_vars, output_loop_vars)
298
299  if return_same_structure:
300    return outputs
301
302  flattened_outputs = nest.flatten(outputs, expand_composites=True)
303  if len(flattened_outputs) == 1:
304    return flattened_outputs[0]
305  else:
306    return outputs
307
308
309@ops.RegisterGradient("StatelessWhile")
310@ops.RegisterGradient("While")
311def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
312  """The gradient of a While op produced by while_loop."""
313  # Note that op is not always the same as while_op because the gradient tape,
314  # for eager mode compatibility, forgets information about the proper op. Since
315  # the loop cannot run in eager mode, however, we can safely introspect into
316  # the graph here.
317  while_op = op.outputs[0].op
318  cond_graph = _get_graph(while_op, "cond", "_cond_graph")
319  body_graph = _get_graph(while_op, "body", "_body_graph")
320  orig_num_params = len(body_graph.outputs)
321
322  maximum_iterations = op.inputs[1]
323  parallel_iterations = op.get_attr("parallel_iterations")
324
325  try:
326    num_original_outputs = while_op.get_attr("_num_original_outputs")
327  except:  # pylint: disable=bare-except
328    num_original_outputs = len(while_op.outputs)
329
330  num_intermediates = len(while_op.outputs) - num_original_outputs
331  grads = [
332      _preprocess_grad(grad, body_out, while_in, while_out)  # pylint: disable=g-complex-comprehension
333      for grad, body_out, while_in, while_out in zip(
334          grads[:num_original_outputs],
335          body_graph.outputs[:num_original_outputs],
336          while_op.inputs[:num_original_outputs],
337          while_op.outputs[:num_original_outputs])
338  ] + [None] * num_intermediates
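  # `grads` now has one entry per While output: preprocessed gradients for the
  # `num_original_outputs` user-visible outputs, followed by None placeholders
  # for the intermediate accumulator outputs.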
339
340  # Skip gradients with respect to the captures whenever possible.
341  if "skip_input_indices" in op.__dict__ and op.skip_input_indices is not None:
342    captures_start_index = (
343        len(body_graph.inputs) - len(body_graph.internal_captures))
344    for i in op.skip_input_indices:
345      if i >= captures_start_index:
346        grads[i] = None
347
348  # We compute the gradient for the sub-graph between trainable ys and xs
349  # with non-None incoming gradients. We later pad the None's to the list of
350  # outputs.
351  ys, xs, non_none_grads = zip(*[(y, x, grad) for (y, x, grad) in zip(
352      body_graph.outputs, body_graph.inputs, grads) if grad is not None])
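  # Example: if grads == [g0, None, g2], only indices 0 and 2 are kept here; the
  # Nones are re-inserted for the final result by `_get_structured_grad_output`.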
353
354  body_grad_graph, args = _create_grad_func(
355      ys, xs, non_none_grads, cond_graph, body_graph,
356      util.unique_grad_fn_name(body_graph.name), op, maximum_iterations)
357
358  if body_grad_graph.while_op_needs_rewrite:
359    # Modify 'op' to output the intermediate accumulators needed by the grad
360    # function.
361    # NOTE(skyewm): if there are any active sessions, this modification to `op`
362    # may make them unrunnable!
363
364    cond_graph.name += "_rewritten"
365    body_graph.name += "_rewritten"
366
367    # `body_grad_graph.extra_inputs` here is equivalent to skimming off the new
368    # `body_graph.external_captures` added during `_create_grad_func`.
369    new_inputs = body_grad_graph.extra_inputs
370    new_outputs = body_graph.outputs[orig_num_params:]
371
372    while_op._set_func_attr("cond", util.create_new_tf_function(cond_graph))
373    while_op._set_func_attr("body", util.create_new_tf_function(body_graph))
374    if len(body_graph.output_types) != len(while_op.inputs) + len(new_inputs):
375      # Continuing leads to an invalid graph with disconnected inputs.
376      raise AssertionError(
377          "Inputs and outputs constructed for the forward op of a While "
378          "gradient don't match. This doesn't make sense, please file a bug.")
379    while_op._set_type_list_attr("T", body_graph.output_types)
380    while_op._set_shape_list_attr("output_shapes", body_graph.output_shapes)
381    while_op._add_while_inputs(new_inputs)
382    while_op._add_outputs([t.dtype for t in new_outputs],
383                          [t.shape for t in new_outputs])
384    _copy_handle_data(new_outputs, while_op.outputs[orig_num_params:])
385
386  # Do not ignore grads wrt extra outputs when computing higher order
387  # derivatives.
388  while_op._set_attr("_num_original_outputs",
389                     attr_value_pb2.AttrValue(i=len(while_op.outputs)))
390
391  captured_inputs = _resolve_grad_captures(body_graph, body_grad_graph,
392                                           while_op)
393  loop_vars = args + captured_inputs
394
395  # This modifies body_grad_graph.
396  loop_vars = while_v2_indexed_slices_rewriter.rewrite_grad_indexed_slices(
397      grads, body_grad_graph, loop_vars, while_op.inputs)
398
399  def grad_cond(counter, unused_maximum_iterations_arg, forward_loop_iters,
400                *unused_args):
401    return counter < forward_loop_iters
402
403  grad_cond_name = util.unique_grad_fn_name(op.get_attr("cond").name)
404  cond_grad_graph = func_graph_module.func_graph_from_py_func(
405      grad_cond_name, grad_cond, loop_vars, {},
406      func_graph=util.WhileCondFuncGraph(grad_cond_name))
407
408  _check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars))
409
410  outputs = _build_while_op(
411      loop_vars,
412      cond_grad_graph,
413      body_grad_graph,
414      output_shapes=[t.shape for t in body_grad_graph.outputs],
415      parallel_iterations=parallel_iterations,
416      name="%s_grad" % while_op.name,
417      num_original_outputs=len(body_grad_graph.outputs))
418
419  # See comment in while_loop.
420  outputs = [array_ops.identity(t) for t in outputs]
421  return _get_structured_grad_output(outputs, grads, body_grad_graph)
422
423
424def _build_while_op(loop_vars, cond_graph, body_graph, output_shapes,
425                    parallel_iterations, name, num_original_outputs):
426  """Builds the functional StatelessWhile/While op."""
427  cond_stateful_ops = [
428      op for op in cond_graph.get_operations() if op._is_stateful
429  ]
430  body_stateful_ops = [
431      op for op in body_graph.get_operations() if op._is_stateful
432  ]
433  if (cond_stateful_ops or body_stateful_ops):
434    op_fn = gen_functional_ops._while
435  else:
436    op_fn = gen_functional_ops.stateless_while
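  # For example, a body that calls tf.print or assigns to a tf.Variable contains
  # stateful ops and therefore yields a `While` op, while a purely functional
  # cond/body pair yields a `StatelessWhile` op.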
437
438  def _make_op(inputs):
439    while_op, tensors = util.get_op_and_outputs(op_fn(
440        inputs,
441        util.create_new_tf_function(cond_graph),
442        util.create_new_tf_function(body_graph),
443        output_shapes=output_shapes,
444        parallel_iterations=parallel_iterations,
445        name=name))
446    _copy_handle_data(body_graph.outputs, tensors)
447    util.maybe_set_lowering_attr(while_op)
448    util.maybe_propagate_compile_time_consts_in_xla(while_op)
449    _set_read_only_resource_inputs_attr(while_op, [cond_graph, body_graph])
450    # This is needed so we do not compute derivatives wrt these extra outputs.
451    while_op._set_attr("_num_original_outputs",
452                       attr_value_pb2.AttrValue(i=num_original_outputs))
453    # The while op may be created inside a tf.function, in which case ops
454    # need to capture "through" it when taking gradients; outer_graph is used
455    # as a sanity check that capturing only happens from parent to child.
456    cond_graph.outer_graph = ops.get_default_graph()
457    body_graph.outer_graph = ops.get_default_graph()
458    while_op._cond_graph = cond_graph
459    while_op._body_graph = body_graph
460    return tensors
461  return util.run_as_function_for_tape_gradients(_make_op, loop_vars)
462
463
464def _get_intermediates(func_graph):
465  """Returns all tensors in `func_graph` that should be accumulated."""
466  # We currently accumulate output tensors of most ops in the function and rely
467  # on the pruning pass to get rid of the unused accumulators at runtime.
468  # However, this can bloat the GraphDef and make debugging harder so we perform
469  # some optimizations.
470  #
471  # Optimizations we currently perform:
472  # 1. We do not accumulate tensors which already have an accumulator
473  #    in the loop body.
474  # 2. We do not accumulate outputs of Identity nodes. When building the
475  #    FuncGraph, we add an Identity node for each output (see
476  #    `AutomaticControlDependencies.mark_as_return`). Accumulating outputs
477  #    of all these nodes bloats the GraphDef quite a bit so we remove those.
478  #    Since the gradient of an Identity node does not rely on its forward op's
479  #    input this is safe to do.
480  #
481  # Other possible optimizations:
482  # 1. Only accumulate tensors that will be required by the backward pass.
483  #    This will require running the gradient pass and hence would increase the
484  #    graph building time for the forward pass.
485  # 2. Do not accumulate Const nodes created inside the loop body.
486  # 3. Do not accumulate loop vars that are returned as-is just like captured
487  #    tensors.
488  intermediates = []
489  reverse_captures = dict((v.ref(), k) for k, v in func_graph.captures)
490
491  for op in func_graph.get_operations():
492    if op.type == "Identity":
493      continue
494    # Accumulating mutexes can cause deadlock.
495    if op.type == "MutexLock":
496      continue
497    for o in op.outputs:
498      if (o is not func_graph.inputs[0] and  # Loop counter.
499          o.dtype != dtypes.resource and  # Do not accumulate resource tensors.
500          _get_accumulator(o) is None and  # Has existing accumulator.
501          o.ref() not in reverse_captures
502         ):  # Captured value, hence loop invariant.
503        intermediates.append(o)
504  return intermediates
505
506
507def _preprocess_grad(grad, body_graph_output, while_op_input, while_op_output):
508  """Returns the initial gradient to be used for a given output tensor.
509
510  Args:
511    grad: the original gradient Tensor passed to the gradient function.
512    body_graph_output: the corresponding Tensor in the body graph.
513    while_op_input: the corresponding Tensor input of the While op.
514    while_op_output: the corresponding Tensor output of the While op.
515
516  Returns:
517    A Tensor or None.
518  """
519  # Set the incoming gradient of non-trainable inputs to None. It is possible
520  # that we receive non-None gradients for non-trainable types in nested while
521  # loops because we accumulate outputs of the inner while as variant tensors
522  # which are trainable and hence receive zeros_like tensors in the gradient
523  # pass. The non-trainable tensors then receive the popped zeros tensor from
524  # this zeros variant. The gradient for the loop vars corresponding to these
525  # tensors is None or zeros (this happens only if the loop var is accumulated
526  # as well) in _grad_fn so we reset these.
527  # TODO(b/118712257): Remove once we can handle None output grads in _grad_fn.
528  if not _is_trainable(body_graph_output):
529    return None
530
531  # GradientTape initializes resource and variant grads as None instead of
532  # zeros. Set to zeros so _GradientsHelper computes the gradients instead of
533  # returning None.
534  # TODO(b/143286622): The supports_default_grad check is needed
535  # because While op emits non-differentiable resource tensors
536  # as outputs. Remove this check when that is not the case.
537  # Note: We use `while_op_input` instead of `while_op_output` for the call
538  # to `supports_default_grad` because `while_op_output` may be missing
539  # handle_data if the While is in a restored saved model.
540  if (while_op_output.dtype in (dtypes.resource, dtypes.variant) and
541      default_gradient.supports_default_grad(while_op_input) and grad is None):
542    return _zeros_like(while_op_input, while_op_output)
543
544  # Convert IndexedSlices to dense tensors since it is unlikely that downstream
545  # gradient functions will properly handle indexed slices. This is similar to
546  # what we do in tf.function gradients.
547  if isinstance(grad, ops.IndexedSlices):
548    return ops.convert_to_tensor(grad)
549
550  return grad
551
552
553# TODO(skyewm): make this return constants if op_output's shape is fully
554# defined (this can be done by checking the "shape" attr of resource vars).
555def _zeros_like(op_input, op_output):
556  """Like array_ops.zeros_like() but also accepts resource var handles."""
557  if op_output.dtype == dtypes.resource:
558    # Note: We use `op_input` instead of `op_output` to get the zeros dtype
559    # because `op_output` may be missing handle_data if the While is in a
560    # restored saved model.
561    return array_ops.zeros(
562        gen_resource_variable_ops.variable_shape(op_output),
563        dtype=default_gradient.get_zeros_dtype(op_input))
564  return array_ops.zeros_like(op_output)
565
566
567def _is_trainable(tensor):
568  """Returns whether the given tensor is trainable."""
569  if not backprop_util.IsTrainable(tensor):
570    return False
571
572  # Special case: untrainable accumulator output. The gradients algorithm
573  # doesn't know about tensor lists of untrainable elements. In theory the
574  # tensor list gradient functions should return None as appropriate, but
575  # because we can't return None from the gradient function we filter out
576  # untrainable accumulator output here to avoid computing the gradient at all.
577  if tensor.op.type == "TensorListPopBack" and tensor.value_index == 0:
578    assert tensor.dtype == dtypes.variant
579    element_type = tensor.op.get_attr("element_dtype")
580    return backprop_util.IsTrainable(element_type)
581
582  return True
583
584
585def _get_graph(while_op, func_attr_name, attr_graph_name):
586  """Returns `FuncGraph` for the given function attribute.
587
588  Args:
589    while_op: The While Operation.
590    func_attr_name: string, either "cond" or "body".
591    attr_graph_name: string, attribute that caches the FuncGraph on the op.
592
593  Returns:
594    `FuncGraph`
595  """
596  func_graph = getattr(while_op, attr_graph_name, None)
597  if func_graph is None:
598    # TODO(srbs): Handle TensorShapeProto in function_def_to_graph.input_shapes.
599    input_shapes = [
600        tensor_shape.TensorShape(s) for s in while_op.get_attr("output_shapes")
601    ]
602    func_name = while_op.get_attr(func_attr_name).name
603    func_graph = util.get_func_graph(while_op, input_shapes, func_name)
604  func_graph._while = while_op
605  return func_graph
606
607
608def _create_grad_func(ys, xs, grads, cond_graph, body_graph, name, while_op,
609                      maximum_iterations):
610  """Builds and returns the gradient FuncGraph of `func_graph` and its args.
611
612  The returned grad_func_graph must be called with the returned
613  args + grad_func_graph.captures.
614
615  Args:
616    ys: A `Tensor` or list of tensors to be differentiated.
617    xs: A `Tensor` or list of tensors to be used for differentiation.
618    grads: The incoming grads for `ys`.
619    cond_graph: FuncGraph for the forward cond function.
620    body_graph: FuncGraph for the forward body function.
621    name: Name of the returned gradient function.
622    while_op: The forward While op.
623    maximum_iterations: Tensor. The maximum number of iterations.
624
625  Returns:
626    2-tuple of (grad_func_graph, args).
627  """
628  assert len(ys) == len(grads)
629
630  total_iters = while_op.outputs[0]
631  counter = constant_op.constant(
632      0, dtype=total_iters.dtype, name="grad_counter")
633
634  # Build frozen sets so that we do not have linear time lookups in
635  # `_is_loop_invariant`. Note: `body_graph.inputs` and `body_graph.outputs`
636  # may get updated during gradient computation because we add accumulators to
637  # the forward op. However, those are not loop invariants so wouldn't affect
638  # the output of `_is_loop_invariant`. Also we would never attempt to capture
639  # those accumulators so `_is_loop_invariant` should never receive those new
640  # tensors as args.
641  body_graph_inputs = object_identity.ObjectIdentitySet(body_graph.inputs)
642  body_graph_outputs = object_identity.ObjectIdentitySet(body_graph.outputs)
643
644  args = [counter, maximum_iterations, total_iters] + list(grads)
645  # Note: The returned function does not have `args` in the list of
646  # `external_captures`.
647  grad_func_graph = func_graph_module.func_graph_from_py_func(
648      name,
649      lambda *args: _grad_fn(ys, xs, args, body_graph),
650      args, {},
651      func_graph=_WhileBodyGradFuncGraph(name, cond_graph, body_graph,
652                                         maximum_iterations, while_op,
653                                         body_graph_inputs, body_graph_outputs))
654
655  # Update the list of outputs with tensors corresponding to the captured
656  # tensors. We capture 3 types of tensors when building the grad fn:
657  # 1. Accumulators for forward graph intermediates which are not loop
658  #    invariants. The outputs corresponding to these are populated in
659  #    `internal_capture_to_output` by `_WhileBodyGradFuncGraph`.
660  # 2. Resources, which are output as is.
661  # 3. Forward graph loop invariants, which are output as is.
662  for external_capture, internal_capture in grad_func_graph.captures:
663    if (ops.tensor_id(internal_capture)
664        in grad_func_graph.internal_capture_to_output):
665      new_output = grad_func_graph.internal_capture_to_output[ops.tensor_id(
666          internal_capture)]
667    else:
668      raise ValueError(
669          "Tensor %s which captures %s is in list of "
670          "internal_captures but not in internal_capture_to_output." %
671          (str(internal_capture), str(external_capture)))
672    grad_func_graph.outputs.append(new_output)
673    grad_func_graph.structured_outputs.append(new_output)
674
675  return grad_func_graph, args
676
677
678def _grad_fn(ys, xs, args, func_graph):
679  """Computes the gradient of `func_graph` in the current graph.
680
681  This function builds the gradient graph of the corresponding forward-pass
682  `func_graph` by differentiating `func_graph`'s outputs w.r.t. its inputs.
683
684  Args:
685    ys: A `Tensor` or list of tensors to be differentiated.
686    xs: A `Tensor` or list of tensors to be used for differentiation.
687    args: The input arguments.
688      args[0] - Loop counter
689      args[1] - maximum_iterations.
690      args[2] - Total number of iterations.
691      args[3:] - Incoming gradients for `ys`.
692    func_graph: function.FuncGraph. The corresponding forward-pass function.
693
694  Returns:
695    The output gradient Tensors.
696  """
697  grad_ys = args[3:]
698
699  # Build the gradient graph. Note that this builds the gradient computation of
700  # func_graph in the current graph, which requires capturing tensors from
701  # func_graph. The captured func_graph tensors are resolved to external tensors
702  # after the forward While op has been rewritten in _resolve_grad_captures.
703  # TODO(srbs): Mark GradientsHelper as public?
704  grad_outs = gradients_util._GradientsHelper(
705      ys, xs, grad_ys=grad_ys, src_graph=func_graph,
706      unconnected_gradients="zero")
707
708  # TODO(b/118712257): Handle the case when grad_outs has None's e.g. when there
709  # is a tf.StopGradient in the loop body.
710  assert all(g is not None for g in grad_outs)
711  counter = args[0]
712  maximum_iterations = args[1]
713  total_iters = args[2]
714  return [counter + 1, maximum_iterations, total_iters] + grad_outs
715
716
717def _resolve_grad_captures(body_graph, body_grad_graph, while_op):
718  """Returns the tensors to pass as captured inputs to `body_grad_graph`.
719
720  `body_grad_graph` may have external references to:
721  1. Its outer graph containing the input gradients. These are left as-is.
722  2. Accumulators captured from the forward-pass graph. These should have been
723     added as `while_op` outputs after the gradient graph was built. We replace
724     these with the corresponding output of `while_op`, i.e. a tensor in
725     `body_graph.outer_graph`. In the case of nested control flow or functions,
726     the gradient logic handling `body_grad_graph.outer_graph` will make sure
727     the tensor from `body_graph.outer_graph` is also correctly captured.
728
729  Args:
730    body_graph: FuncGraph. The forward-pass body function.
731    body_grad_graph: FuncGraph. The body gradients function.
732    while_op: The forward-pass While Operation calling `body_graph`.
733
734  Returns:
735    A list of input tensors to be passed as the captured inputs to
736    `body_grad_graph`.
737  """
738  new_capture_inputs = []
739  for t in body_grad_graph.external_captures:
740    # Resolve tensors captured from the forward graph to the outputs of the
741    # forward while_op.
742    if t.graph == body_graph:
743      # Captured accumulator or loop invariant.
744      for i, output in enumerate(t.graph.outputs):
745        if output is t:
746          t = while_op.outputs[i]
747          break
748
749      # Note: We rely on the capturing logic of the gradient While op graph to
750      # correctly capture the tensors in `body_graph.outer_graph`. Both cond_v2
751      # and while_v2 handle this while building their gradient functions.
752      assert t.graph == body_graph.outer_graph
753
754    new_capture_inputs.append(t)
755  return new_capture_inputs
756
757
758def _get_structured_grad_output(outputs, grads, body_grad_graph):
759  """Returns the values that should be returned from the while grad function.
760
761  Args:
762    outputs: the raw Tensor outputs of the grad While op.
763    grads: the input gradients to the gradient function.
764    body_grad_graph: _WhileBodyGradFuncGraph.
765
766  Returns:
767    A list of gradient values. May include Nones.
768  """
769  result = []
770  # outputs[0] is the loop counter.
771  # outputs[1] is maximum_iterations.
772  # outputs[2] is the total number of loop iterations.
773  outputs_idx = 3
774  structured_outputs_idx = 3
775  for g in grads:
776    # Set None as the output gradient for tensors with None input gradient.
777    if g is None:
778      result.append(None)
779      continue
780    output = body_grad_graph.structured_outputs[structured_outputs_idx]
781    structured_outputs_idx += 1
782    if isinstance(output, ops.IndexedSlices):
783      # TODO(skyewm): is there a more robust way to determine the order of
784      # flattened IndexedSlices components?
785      result.append(ops.IndexedSlices(
786          values=outputs[outputs_idx],
787          indices=outputs[outputs_idx + 1],
788          dense_shape=outputs[outputs_idx + 2]))
789      outputs_idx += 3
790    else:
791      assert isinstance(output, ops.Tensor)
792      result.append(outputs[outputs_idx])
793      outputs_idx += 1
794
795  return result
796
797
798def _get_accumulator(tensor):
799  r"""Returns TensorList if any containing accumulated values of tensor.
800
801  We try to find a pattern of the form:
802
803     input_tl   tensor
804        \        /
805    (TensorListPushBack)
806            |
807        output_tl
808
809  which satisfies the following conditions:
810
811  1. input_tl must be in tensor.graph.inputs.
812  2. output_tl or Identity(output_tl) must be in tensor.graph.outputs.
813  3. tensor.graph.input_index(input_tl) == tensor.graph.output_index(output_tl).
814
815  output_tl or Identity(output_tl) (whichever is in tensor.graph.outputs) is
816  returned if such a pattern is found else None is returned.
817
818  Args:
819    tensor: The Tensor to be accumulated.
820
821  Returns:
822    A variant tensor in the same graph as `tensor` or None if no accumulator is
823    found.
824  """
825  assert isinstance(tensor.graph, func_graph_module.FuncGraph)
826
827  def get_func_graph_output(t):
828    """Returns t or Identity(t) whichever exists in graph outputs else None."""
829    for output in tensor.graph.outputs:
830      if output is t:
831        return t
832    # tf.defun adds an Identity for each output, check whether that is the case.
833    identity_op = t.consumers()[0]
834    if (identity_op.type == "Identity" and
835        any(identity_op.outputs[0] is t for t in tensor.graph.outputs)):
836      return identity_op.outputs[0]
837    return None
838
839  for consumer in tensor.consumers():
840    # Find the consumer that is a TensorListPushBack node whose TensorList input
841    # is in the list of function inputs.
842    if consumer.type != "TensorListPushBack":
843      continue
844
845    accum_input_idx = -1
846    for accum_input_idx, inp in enumerate(tensor.graph.inputs):
847      if inp is consumer.inputs[0]:
848        break
849    else:
850      continue
851
852    output = get_func_graph_output(consumer.outputs[0])
853    if output is None:
854      # The TensorList output of `consumer` is not in the list of function
855      # outputs.
856      continue
857
858    for accum_output_idx, out in enumerate(tensor.graph.outputs):
859      if out is output:
860        if accum_input_idx == accum_output_idx:
861          return output
862        break
863
864  return None
865
866
867OptimizedReductionOpsCacheKey = collections.namedtuple(
868    "OptimizedReductionOpsCacheKey", [
869        "op_type",
870        "inputs",
871        "dtypes",
872        "input_types",
873        "name",
874        "attrs",
875        "op_def",
876        "compute_device",
877    ])
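# Keys of this type index `_forward_graph._optimized_reduction_ops_cache` (see
# `_WhileBodyGradFuncGraph._move_op_to_forward_graph` below), so that a
# reduction op requested more than once while building the gradient graph is
# only created once in the forward graph.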
878
879
880class _WhileBodyGradFuncGraph(util.WhileBodyFuncGraph):
881  """FuncGraph for the gradient function of the body of a While op.
882
883  Contains the logic for capturing the tensors from the body of the forward
884  While op which is as follows:
885  1. If the tensor is of resource type (these are not accumulated):
886     a. Ensure that the tensor is a loop invariant, i.e., it exists in both loop
887        inputs and outputs at the same index.
888     b. Lookup the corresponding resource tensor in the forward outer graph and
889        try to capture that.
890  2. If the tensor is not of resource type:
891     a. Create an accumulator for that tensor and output it from the forward
892        pass. Note this also requires adding it as an input to the forward pass.
893     b. Capture the accumulator from the forward pass in this FuncGraph. This
894        will later be resolved to the correct output of the forward While op.
895     c. Pop a value from the captured placeholder and use it as the captured
896        value for the forward pass tensor.
897
898  This only allows capturing tensors in the forward graph. A ValueError is
899  raised if an attempt is made to capture a tensor not in the forward graph.
900  To manually capture a tensor that is not in the forward graph, call `capture`
901  with `allowlisted=True`.
902
903  Note: The `captures` dict does not contain the forward tensor since it is not
904  directly captured. It contains the accumulator corresponding to this forward
905  tensor.
906
907  Attributes:
908    while_op_needs_rewrite: True if any non-resource intermediates were
909      captured, meaning the forward While op needs to be rewritten to output the
910      corresponding accumulators.
911    extra_inputs: list of EmptyTensorList tensors to be used as initial input to
912      the new accumulators in the forward graph. It may also contain external
913      captures of the custom gradient function.
914    internal_capture_to_output: dict from a tensor_id(captured placeholder) to
915      the corresponding tensor that needs to be added to the list of outputs.
916      For instance, when capturing an accumulator TensorList this contains the
917      TensorList obtained after popping a tensor from the list. Other entries
918      in this dict are expected, though not enforced, to be identities.
919      This dict is needed because these output tensors need to be added to
920      FuncGraph.outputs "after" the tensors returned from the gradient function.
921  """
922
923  def __init__(self, name, forward_cond_graph, forward_body_graph,
924               maximum_iterations, forward_while_op, body_graph_inputs,
925               body_graph_outputs):
926    super(_WhileBodyGradFuncGraph, self).__init__(name)
927    self.extra_inputs = []
928    self.internal_capture_to_output = {}
929    # FuncGraph for the body of the forward While op.
930    self._forward_graph = forward_body_graph
931    # FuncGraph for the cond of the forward While op.
932    self._forward_cond_graph = forward_cond_graph
933    self._maximum_iterations = maximum_iterations
934    self._forward_while_op = forward_while_op
935    # Dict from forward intermediate tensor to its indirectly captured tensor
936    # in this graph. Indirect capturing happens in two ways:
937    # 1. For non-resource tensors we capture their accumulators from the forward
938    #    outer graph and pop values from that accumulator inside this graph
939    #    using TensorListPopBack.
940    # 2. For resource tensors we directly capture their corresponding tensor
941    #    in the forward outer graph.
942    self._indirect_captures = {}
943
944  @property
945  def while_op_needs_rewrite(self):
946    return self.extra_inputs
947
948  def _create_op_internal(
949      self,
950      op_type,
951      inputs,
952      dtypes=None,  # pylint: disable=redefined-outer-name
953      input_types=None,
954      name=None,
955      attrs=None,
956      op_def=None,
957      compute_device=True):
958    # For a reduction op, if op is in the gradient body graph and its input is
959    # from the forward graph, moving op to the forward graph means we would
960    # store the tensor after the reduction as opposed to the tensor before
961    # reduction, and therefore could significantly reduce memory consumption.
962    # For now, we do this only for a few ops.
963    #
964    # We don't do this if any input tensor has already been accumulated. This
965    # can happen if we output all intermediates in the forward pass.
966    #
967    # If in XLA context, do not move constant ops to forward pass as pushing to
968    # and popping from a TensorList removes the constant property of an op and
969    # breaks XLA compilation, which requires certain inputs to be compile-time
970    # constant for certain ops.
971    #
972    # This optimization is currently also disabled when under a persistent tape,
973    # since it leads to an unbounded number of side outputs. With caching it may
974    # be possible to re-enable it.
975    optimized_reduction_ops = {
976        "Shape", "Size", "Rank", "TensorListElementShape", "TensorListLength"
977    }
978    if (op_type in optimized_reduction_ops and
979        not util.output_all_intermediates() and
980        all(input.graph is self._forward_graph for input in inputs) and
981        all(_get_accumulator(input) is None for input in inputs) and
982        not util_v1.GraphOrParentsInXlaContext(self._forward_graph) and
983        not util.graph_wrapped_for_higher_order_tape_gradients(
984            self._forward_graph)):
985      return self._move_op_to_forward_graph(
986          op_type,
987          inputs,
988          dtypes=dtypes,
989          input_types=input_types,
990          name=name,
991          attrs=attrs,
992          op_def=op_def,
993          compute_device=compute_device)
994
995    return super(_WhileBodyGradFuncGraph, self)._create_op_internal(
996        op_type,
997        inputs,
998        dtypes=dtypes,
999        input_types=input_types,
1000        name=name,
1001        attrs=attrs,
1002        op_def=op_def,
1003        compute_device=compute_device)
1004
1005  def _move_op_to_forward_graph(
1006      self,
1007      op_type,
1008      inputs,
1009      dtypes=None,  # pylint: disable=redefined-outer-name
1010      input_types=None,
1011      name=None,
1012      attrs=None,
1013      op_def=None,
1014      compute_device=True):
1015    # We have a cache of reduction ops that have already been moved to the
1016    # forward graph, and we will check it first to avoid moving an op twice.
1017    if not hasattr(self._forward_graph, "_optimized_reduction_ops_cache"):
1018      self._forward_graph._optimized_reduction_ops_cache = {}
1019    cache_key = self._get_optimized_reduction_ops_cache_key(
1020        op_type, inputs, dtypes, input_types, name, attrs, op_def,
1021        compute_device)
1022    cached_op = self._forward_graph._optimized_reduction_ops_cache.get(
1023        cache_key)
1024    if cached_op is not None:
1025      # This op has already been moved to the forward graph and we have it in
1026      # the cache.
1027      return cached_op
1028
1029    with self._forward_graph.as_default():
1030      # `name` was built using name_scope stack of gradient graph and may not
1031      # be unique in the forward graph. `Graph.create_op` does not uniquify
1032      # names which are name scopes i.e. end in `/`. To ensure that the op
1033      # created gets a unique name in the forward graph we get rid of the
1034      # trailing slash.
1035      name = ops.name_from_scope_name(name)
1036      result = self._forward_graph._create_op_internal(
1037          op_type,
1038          inputs,
1039          dtypes=dtypes,
1040          input_types=input_types,
1041          name=name,
1042          attrs=attrs,
1043          op_def=op_def,
1044          compute_device=compute_device)
1045
1046      # Store the op we just moved to the forward graph so that it does
1047      # not need to be added there again.
1048      self._forward_graph._optimized_reduction_ops_cache[cache_key] = result
1049      return result
1050
1051  def _get_optimized_reduction_ops_cache_key(
1052      self,
1053      op_type,
1054      inputs,
1055      dtypes=None,  # pylint: disable=redefined-outer-name
1056      input_types=None,
1057      name=None,
1058      attrs=None,
1059      op_def=None,
1060      compute_device=True):
1061    # We need all elements of CacheKey to be hashable.
1062    inputs = tuple(map(lambda t: t.ref(), inputs))
1063
1064    if dtypes is not None:
1065      dtypes = tuple(dtypes)
1066
1067    if input_types is not None:
1068      input_types = tuple(input_types)
1069
1070    if attrs is not None:
1071      hashable_attrs = []
1072      for attr_name, attr_value in sorted(attrs.items()):
1073        hashable_attrs.append((attr_name, attr_value.SerializeToString()))
1074      attrs = tuple(hashable_attrs)
1075
1076    if op_def is not None:
1077      op_def = op_def.SerializeToString()
1078
1079    return OptimizedReductionOpsCacheKey(op_type, inputs, dtypes, input_types,
1080                                         name, attrs, op_def, compute_device)
1081
1082  def _capture_helper(self, tensor, name):
1083    """Implements the capturing described in the class docstring."""
1084    captured_tensor = self._indirect_captures.get(ops.tensor_id(tensor))
1085    if captured_tensor is not None:
1086      return captured_tensor
1087
1088    if tensor.graph is not self._forward_graph:
1089      already_captured = self.captured(tensor)
1090      captured_tensor = super(_WhileBodyGradFuncGraph, self)._capture_helper(
1091          tensor, name)
1092      if not already_captured:
1093        # Adds the captured tensor to the list of outputs so that the input
1094        # and output signatures match.
1095        self.internal_capture_to_output[ops.tensor_id(
1096            captured_tensor)] = captured_tensor
1097        self._indirect_captures[ops.tensor_id(tensor)] = captured_tensor
1098      return captured_tensor
1099
1100    while tensor.op.type == "Identity":
1101      # We do not accumulate the output of identity nodes so we try to capture
1102      # the input of the Identity node instead.
1103      tensor = tensor.op.inputs[0]
1104
1105    captured_tensor = self._indirect_captures.get(ops.tensor_id(tensor))
1106    if captured_tensor is not None:
1107      return captured_tensor
1108
1109    # No need to accumulate loop invariants. Capture them directly.
1110    # The captured tensor gets resolved to the corresponding while output in
1111    # `_resolve_grad_captures`.
1112    if _is_loop_invariant(tensor, self._forward_graph.inputs,
1113                          self._forward_graph.outputs):
1114      captured_tensor = super(_WhileBodyGradFuncGraph,
1115                              self)._capture_helper(tensor, name)
1116      # Add to `internal_capture_to_output` so that this gets added to the list
1117      # of outputs.
1118      self.internal_capture_to_output[ops.tensor_id(
1119          captured_tensor)] = captured_tensor
1120      self._indirect_captures[ops.tensor_id(tensor)] = captured_tensor
1121      return captured_tensor
1122
1123    # Do not accumulate Const nodes. Instead copy them directly in the backward
1124    # graph.
1125    # TODO(srbs): This just checks for `Const` nodes. Consider checking for
1126    # graph compile time consts in general.
1127    # TODO(srbs): Consider making this a loop input.
1128    if constant_op.is_constant(tensor):
1129      real_value = constant_op.constant(
1130          tensor_util.constant_value(tensor), dtype=tensor.dtype)
1131      self._indirect_captures[ops.tensor_id(tensor)] = real_value
1132      return real_value
1133
1134    # Resource tensors are not accumulated and handled specially.
1135    if tensor.dtype == dtypes.resource:
1136      return self._resource_capture_helper(tensor)
1137
1138    # Create or find an existing accumulator output for `tensor` in the forward
1139    # graph, and fetch from this accumulator in the gradient graph to get the
1140    # raw intermediate value.
1141    accumulator = _get_accumulator(tensor)
1142    if accumulator is None:
1143      # Create the initial empty tensor list.
1144      #
1145      # Note: We clear the control dependencies to avoid a cycle in case a
1146      # control tensor has an input path to an output of the forward While.
1147      #
1148      # E.g.:
1149      # x = tf.while_loop(...)
1150      # y = f(x)
1151      # with tf.control_dependencies([y]):
1152      #   tf.gradients(y, x)
1153      #
1154      # Since the EmptyTensorList is fed back into the forward While, not
1155      # removing the control edge would cause a cycle.
1156      with self._forward_graph.outer_graph.as_default():
1157        with util.clear_control_inputs():
1158          tensor_list = list_ops.empty_tensor_list(
1159              element_dtype=tensor.dtype,
1160              element_shape=tensor.shape,
1161              max_num_elements=self._maximum_iterations,
1162              name=_build_accumulator_name(tensor))
1163      self.extra_inputs.append(tensor_list)
1164
1165      # Push the intermediate tensor to the tensor list. This captures
1166      # `tensor_list`.
1167      with self._forward_graph.as_default():
1168        accumulator = list_ops.tensor_list_push_back(tensor_list, tensor)
1169      # Add the modified tensor list to the list of outputs. This output will be
1170      # all the accumulated values.
1171      self._forward_graph.outputs.append(accumulator)
1172
1173      # Capture in the cond graph as well so the forward cond and body inputs
1174      # match.
1175      with self._forward_cond_graph.as_default():
1176        self._forward_cond_graph.capture(tensor_list)
1177
1178    # Capture the accumulator tensor list in the gradient graph directly from
1179    # the forward graph -- we'll later modify this to capture the final list
1180    # output by the forward While op instead.
1181    captured_accumulator = super(_WhileBodyGradFuncGraph, self)._capture_helper(
1182        accumulator, name)
1183
1184    # Pop the intermediate value from the tensor list in the gradient graph.
1185    new_tensor_list, captured_tensor = list_ops.tensor_list_pop_back(
1186        captured_accumulator, element_dtype=tensor.dtype)
1187
1188    self._indirect_captures[ops.tensor_id(tensor)] = captured_tensor
1189    self.internal_capture_to_output[ops.tensor_id(
1190        captured_accumulator)] = new_tensor_list
1191    return captured_tensor
1192
1193  def _resource_capture_helper(self, tensor):
1194    """Returns the captured resource tensor.
1195
1196    Resource-type tensors are not accumulated. If a resource tensor exists in
1197    the loop body it must either be a loop input or an output of a nested While
1198    op inside the loop body which had captured the external resource.
1199
1200    Args:
1201      tensor: the external resource Tensor to be captured.
1202
1203    Returns:
1204      Tensor in this graph.
1205    """
1206    assert tensor.dtype == dtypes.resource
1207
1208    index = util.resource_input_index(
1209        tensor.name, [t.name for t in self._forward_graph.inputs],
1210        {op.name: op.node_def for op in self._forward_graph.get_operations()},
1211        self._forward_graph._functions)
1212
1213    input_placeholder = self._forward_graph.inputs[index]
1214    tensor_in_outer_graph = self._forward_graph._while.inputs[index]
1215
1216    assert input_placeholder.dtype == dtypes.resource
1217    assert tensor_in_outer_graph.dtype == dtypes.resource
1218    # This must be a loop invariant.
1219    assert input_placeholder is self._forward_graph.outputs[index], (
1220        "Resource tensors must be loop invariants %s." % tensor_in_outer_graph)
1221
1222    self._indirect_captures[ops.tensor_id(tensor)] = self.capture(
1223        tensor_in_outer_graph)
1224    return self._indirect_captures[ops.tensor_id(tensor)]
1225
1226
def _check_shapes_compat(output_tensors, shape_invariants, input_tensors):
  for (t, shape, input_t) in zip(output_tensors, shape_invariants,
                                 input_tensors):
    if not control_flow_ops._ShapeLessThanOrEqual(t.shape, shape):
      raise ValueError(
          "Input tensor '%s' enters the loop with shape %s, but has "
          "shape %s after one iteration. To allow the shape to vary across "
          "iterations, use the `shape_invariants` argument of tf.while_loop to "
          "specify a less-specific shape." % (input_t.name, shape, t.shape))


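# A minimal, illustrative sketch (not part of the implementation, intended to
# be traced inside a graph context such as tf.function): a loop variable whose
# shape changes across iterations needs a relaxed `shape_invariants` entry;
# without it, `_check_shapes_compat` raises the ValueError above.
def _example_shape_invariants():
  v = constant_op.constant([1.0])
  return while_loop(
      cond=lambda x: array_ops.size(x) < 8,
      # Doubling the vector changes its static shape on every iteration.
      body=lambda x: array_ops.concat([x, x], axis=0),
      loop_vars=[v],
      shape_invariants=[tensor_shape.TensorShape([None])])
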
def _check_num_inputs_outputs(cond_graph, body_graph, num_flattened_loop_vars):
  """Checks the number of inputs/outputs of `cond_graph` and `body_graph`."""
  assert len(cond_graph.inputs) == num_flattened_loop_vars, (
      "cond_graph takes %d inputs; Expected: %d" % (len(cond_graph.inputs),
                                                    num_flattened_loop_vars))
  assert len(cond_graph.outputs) == 1, (
      "cond_graph has %d outputs; Expected: 1" % len(cond_graph.outputs))
  assert len(body_graph.inputs) == num_flattened_loop_vars, (
      "body_graph takes %d inputs; Expected: %d" % (len(body_graph.inputs),
                                                    num_flattened_loop_vars))
  assert len(body_graph.outputs) == num_flattened_loop_vars, (
      "body_graph has %d outputs; Expected: %d" % (len(body_graph.outputs),
                                                   num_flattened_loop_vars))


def _check_inputs_outputs_types_match(body_graph, flattened_loop_vars):
  for inp, out, loop_var in zip(body_graph.inputs, body_graph.outputs,
                                flattened_loop_vars):
    if inp.dtype != out.dtype:
      raise TypeError("Loop var {} enters the loop with type {} "
                      "but has type {} after 1 iteration.".format(
                          loop_var.name, inp.dtype, out.dtype))


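# A minimal, illustrative sketch (not part of the implementation): a body
# whose output dtype differs from the corresponding input dtype would be
# rejected with a TypeError like the one above. Here the int32 counter is
# returned as float32.
def _example_dtype_mismatch():
  return while_loop(
      cond=lambda i: i < 3,
      body=lambda i: math_ops.cast(i, dtypes.float32) + 1.0,
      loop_vars=[constant_op.constant(0)])
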
def _build_cond_placeholders_name_prefix(cond_graph):
  return cond_graph.unique_name(cond_graph.name + "___redundant_placeholder")


def _duplicate_body_captures_in_cond(cond_graph, body_graph_captures):
  """Creates placeholders for body captures in cond_graph.

  This is needed to match signatures of cond and body graphs.

  Args:
    cond_graph: cond branch graph
    body_graph_captures: Tensors which were captured when building the
      `body_graph`.
  """
  types = [t.dtype.as_datatype_enum for t in body_graph_captures]
  # TODO(srbs): Providing a unique prefix does not ensure that there is no
  # conflict between the placeholder names and existing nodes in the graph.
  # However passing a list of strings may not be performant.
  # Ideally we should move `Graph.unique_name` to C++ or make
  # `Graph._names_in_use` a trie so that we can find a unique prefix.
  # TODO(b/143286622): This should not be required once captures are separated
  # from regular loop vars.
  placeholders = c_api.TF_CreatePlaceholders(
      cond_graph._c_graph, types,
      compat.as_str(_build_cond_placeholders_name_prefix(cond_graph)))
  placeholder_ops = [
      _OperationWithOutputs(ph.oper, cond_graph)
      for ph in placeholders
  ]

  tensors = []
  for op, ph, dtype in zip(placeholder_ops, placeholders, types):
    tensor = ops.Tensor._create_with_tf_output(op, 0, dtype, ph)
    op._outputs = [tensor]
    tensors.append(tensor)

  # Update `cond_graph._captures` and `cond_graph.inputs` to contain the
  # newly created placeholders.
  tuples = zip(body_graph_captures, tensors)
  keys = [id(t) for t in body_graph_captures]
  cond_graph._captures.update(zip(keys, tuples))
  cond_graph.inputs.extend(tensors)


def _copy_handle_data(src_tensors, tgt_tensors):
  for src_t, tgt_t in zip(src_tensors, tgt_tensors):
    custom_gradient.copy_handle_data(src_t, tgt_t)


def _graph_name(graph):
  if isinstance(graph, func_graph_module.FuncGraph):
    return graph.name
  return "Base"


def _pack_sequence_as(structure_with_tas, loop_vars):
  """Like `nest.pack_sequence_as` but also replaces flows with TensorArrays."""

  def flow_to_tensor_array(flow, ta):  # pylint: disable=missing-docstring
    return (tensor_array_ops.build_ta_with_new_flow(ta, flow) if isinstance(  # pylint: disable=g-long-ternary
        ta, tensor_array_ops.TensorArray) else flow)

  flattened_loop_vars = [
      flow_to_tensor_array(*z)
      for z in zip(nest.flatten(loop_vars, expand_composites=True),
                   nest.flatten(structure_with_tas, expand_composites=True))
  ]
  return nest.pack_sequence_as(structure_with_tas, flattened_loop_vars,
                               expand_composites=True)


def _tensor_array_to_flow(loop_vars):

  def f(maybe_ta):
    if isinstance(maybe_ta, tensor_array_ops.TensorArray):
      return maybe_ta.flow
    return maybe_ta

  return nest.map_structure(f, loop_vars, expand_composites=True)


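# A minimal, illustrative sketch (not part of the implementation, intended for
# graph construction): how `_tensor_array_to_flow` and `_pack_sequence_as`
# complement each other. TensorArrays travel through the While op as their
# scalar `flow` tensors and are rebuilt into TensorArrays afterwards.
def _example_tensor_array_round_trip():
  ta = tensor_array_ops.TensorArray(dtypes.float32, size=3)
  loop_vars = [constant_op.constant(0), ta]
  # What the While op carries: the TensorArray replaced by its flow tensor.
  flows = _tensor_array_to_flow(loop_vars)
  # What the caller gets back: the flow re-wrapped into a TensorArray.
  return _pack_sequence_as(loop_vars, flows)
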
def _build_maximum_iterations_loop_var(maximum_iterations):
  if maximum_iterations is None:
    # -1 is the default value of `max_num_elements` for EmptyTensorList,
    # meaning that the list size is unbounded.
    maximum_iterations = -1
  # EmptyTensorList expects `max_num_elements` to be of type int32.
  return ops.convert_to_tensor(
      maximum_iterations, dtype=dtypes.int32, name="maximum_iterations")


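# A minimal, illustrative sketch (not part of the implementation): the
# converted `maximum_iterations` doubles as `max_num_elements` for the
# accumulator TensorLists, where -1 means the list may grow without bound.
def _example_unbounded_accumulator_list():
  max_iters = _build_maximum_iterations_loop_var(None)  # int32 tensor == -1
  return list_ops.empty_tensor_list(
      element_shape=[], element_dtype=dtypes.float32,
      max_num_elements=max_iters)
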
def _build_accumulator_name(tensor):
  # Tensor name may be of the form "pow/y:0". Name scope does not allow ":".
  return "{}/accumulator".format(tensor.name).replace(":", "_")


def _is_loop_invariant(tensor, inputs, outputs):
  return (any(tensor is t for t in inputs) and
          any(tensor is t for t in outputs))


class _OperationWithOutputs(ops.Operation):
  """Operation with pre-built `TF_Output`s.

  The C API for creating the extra placeholders for the cond graph returns
  SWIG-wrapped TF_Output* pointers which we can use directly for
  `Operation.outputs`. The default constructor for `Operation` does not provide
  a way of specifying pre-built output tensors and always creates them, which
  is a performance overhead. It is not clear whether adding that feature to the
  `Operation` API would be generally useful, so for now we keep our own
  lightweight `Operation` implementation. Note that this does not extract a
  stack trace either, since we don't expect this operation to be used.

  TODO(b/143286622): This should not be required once captures are separated
  from regular loop vars.
  """

  def __init__(self, c_op, g):
    self._c_op = c_op
    self._graph = g
    self._outputs = None  # Initialized by _duplicate_body_captures_in_cond().
    self._id_value = g._add_op(self, self.name)
    self._is_stateful = False


def _set_read_only_resource_inputs_attr(op, branch_graphs):
  """Sets the list of resource inputs which are read-only.

  This is used by AutomaticControlDependencies.

  Args:
    op: While Operation.
    branch_graphs: List of branch FuncGraphs.
  """
  read_only_indices = set(range(len(op.inputs)))
  for branch_graph in branch_graphs:
    if not read_only_indices:
      break
    branch_read_only_indices = acd.get_read_only_resource_input_indices_graph(
        branch_graph)
    read_only_indices = read_only_indices.intersection(branch_read_only_indices)

  ops.set_int_list_attr(op, acd.READ_ONLY_RESOURCE_INPUTS_ATTR,
                        sorted(read_only_indices))

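# A minimal, illustrative sketch (not part of the implementation): an input is
# marked read-only only if every branch graph treats it as read-only, i.e. the
# attribute holds the intersection of the per-branch index sets.
def _example_read_only_intersection():
  cond_read_only_indices = {0, 1, 2}
  body_read_only_indices = {1, 2}
  return sorted(cond_read_only_indices & body_read_only_indices)  # == [1, 2]
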
# pylint: enable=protected-access
