# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""while_v2 and gradient.

This is a version of while_loop that emits a single While op, as well as the
gradient function for While ops produced by while_loop. This will eventually
replace the current tf.while_loop implementation once it reaches feature and
performance parity.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph as func_graph_module
from tensorflow.python.framework import function_def_to_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import control_flow_util_v2 as util
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import gradients_util
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import while_v2_indexed_slices_rewriter
from tensorflow.python.util import nest

# pylint: disable=protected-access

# TODO(b/79881896): Handle external control dependencies. tf.while_loop allows
# control dependencies on external nodes with at least 1 output.
# Another idea is to create const nodes outside the loop and add control edges
# to them and then pass those in as data inputs. This should probably be
# handled in the CapturingGraph itself.


def while_loop(cond,
               body,
               loop_vars,
               shape_invariants=None,
               parallel_iterations=10,
               maximum_iterations=None,
               name=None,
               return_same_structure=True):
  """Like tf.while_loop, except emits a single While op."""
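  # Illustrative sketch, not part of the implementation: assuming graph mode
  # (e.g. inside a tf.Graph or a defun), a minimal counting loop built with
  # this function might look like
  #
  #   i = constant_op.constant(0)
  #   [final_i] = while_loop(lambda i: i < 10, lambda i: i + 1, [i])
  #
  # With the default return_same_structure=True the result keeps the structure
  # of `loop_vars`, here a single-element list.
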
  # Keep the original loop_vars around to know which args were TensorArrays.
  orig_loop_vars = loop_vars
  # Cache its length since we use it at multiple places below.
  len_orig_loop_vars = len(orig_loop_vars)

  # Convert TensorArrays to their flow variables. These get converted back to
  # TensorArrays before calling `cond` and `body`. See `wrapped_cond` and
  # `wrapped_body` below.
  loop_vars = list(_tensor_array_to_flow(orig_loop_vars))
  loop_vars = nest.map_structure(
      ops.internal_convert_to_tensor_or_indexed_slices, loop_vars)
  if shape_invariants is not None:
    nest.assert_same_structure(orig_loop_vars, shape_invariants)
  else:
    shape_invariants = nest.map_structure(lambda t: t.shape, loop_vars)

  if not name:
    name = "while"

  with ops.name_scope(name) as scope:
    with ops.name_scope(None):
      cond_name = util.unique_fn_name(scope, "cond")
      body_name = util.unique_fn_name(scope, "body")
    maximum_iterations_loop_var = _build_maximum_iterations_loop_var(
        maximum_iterations)
    loop_counter = constant_op.constant(
        0,
        dtype=maximum_iterations_loop_var.dtype
        if maximum_iterations is not None else None,
        name="loop_counter")
    # Add loop counter needed for computing gradients.
    loop_vars = [loop_counter, maximum_iterations_loop_var] + loop_vars

    shape_invariants = type(shape_invariants)(
        [tensor_shape.scalar(), tensor_shape.scalar()]) + shape_invariants

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = ops.get_default_graph()._add_control_dependencies

    # Build a `cond` wrapper that can handle the extra counter loop_var.
    def wrapped_cond(loop_counter, maximum_iterations_arg, *args):
      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently there
      # is no nest.zip so we call `_pack_sequence_as` which flattens both
      # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
      # and packs it into the structure of `orig_loop_vars`.
      if maximum_iterations is None:
        return cond(*_pack_sequence_as(orig_loop_vars, args))
      else:
        return math_ops.logical_and(
            loop_counter < maximum_iterations_arg,
            cond(*_pack_sequence_as(orig_loop_vars, args)))

    # NOTE(skyewm): we set collections to the outer graph's collections for
    # compatibility with TPUEstimator.
    cond_graph = func_graph_module.func_graph_from_py_func(
        cond_name,
        wrapped_cond,
        [],  # We provide signature instead of args.
        {},
        signature=_build_signature(loop_vars, shape_invariants),
        func_graph=util.WhileCondFuncGraph(
            cond_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        add_control_dependencies=add_control_dependencies)

    def wrapped_body(loop_counter, maximum_iterations_arg, *args):
      """Loop body augmented with counter update.

      Args:
        loop_counter: Loop counter which needs to be incremented in the body.
        maximum_iterations_arg: Maximum iterations of the loop.
        *args: List of args

      Returns:
        A list of tensors the same length as args.
      """
      # Capture the tensors already captured in cond_graph so that they appear
      # in the same order in body_graph.external_captures.
      for t in cond_graph.external_captures:
        ops.get_default_graph().capture(t)

      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently there
      # is no nest.zip so we call `_pack_sequence_as` which flattens both
      # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
      # and packs it into the structure of `orig_loop_vars`.
      outputs = body(*_pack_sequence_as(orig_loop_vars, args))
      if not nest.is_sequence(outputs):
        outputs = [outputs]
      # Compare the structure of input and output of body converting the
      # top-level tuples to list to be compatible with legacy while_loop.
      nest.assert_same_structure(list(outputs), list(orig_loop_vars))

      outputs = _tensor_array_to_flow(outputs)

      # TODO(srbs): Update lowering code to create _Enter nodes with
      # is_constant=True for inputs that are directly passed to outputs.
      return [loop_counter + 1, maximum_iterations_arg] + list(outputs)

    body_graph = func_graph_module.func_graph_from_py_func(
        body_name,
        wrapped_body,
        [],  # We provide signature instead of args.
        {},
        signature=_build_signature(loop_vars, shape_invariants),
        func_graph=util.WhileBodyFuncGraph(
            body_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        add_control_dependencies=add_control_dependencies)
    # Add external captures of body to the list of loop vars.
    # Note that external tensors will be treated as loop invariants, i.e.,
    # the value of that tensor in each iteration is the same as it was at the
    # beginning of the loop execution.
    loop_vars = loop_vars + body_graph.external_captures
    # TODO(srbs): Update lowering code to create _Enter nodes with
    # is_constant=True for inputs that are directly passed to outputs.
    body_graph.outputs.extend(body_graph.internal_captures)

    # Capture the extra `external_captures` of `body_graph` in `cond_graph` so
    # that it expects to receive those as arguments.
    with cond_graph.as_default():
      num_cond_captures = len(cond_graph.external_captures)
      assert (cond_graph.external_captures ==
              body_graph.external_captures[:num_cond_captures])
      for body_capture in body_graph.external_captures[num_cond_captures:]:
        assert body_capture not in cond_graph.captures
        cond_graph.capture(body_capture)

    # Make sure that the shapes of the loop outputs are compatible with the
    # shape invariants, or the shapes of the loop vars if the invariants are not
    # specified.
    num_flattened_outputs = len(nest.flatten(orig_loop_vars))
    # First var is loop counter and second var is maximum_iterations.
    first_loop_var_index = 2
    _check_shapes_compat(
        body_graph.outputs[first_loop_var_index:first_loop_var_index +
                           num_flattened_outputs],
        nest.flatten(
            shape_invariants[first_loop_var_index:first_loop_var_index +
                             len_orig_loop_vars]),
        nest.flatten(loop_vars[first_loop_var_index:first_loop_var_index +
                               len_orig_loop_vars]))
    flattened_loop_vars = nest.flatten(loop_vars)
    _check_num_inputs_outputs(cond_graph, body_graph,
                              len(flattened_loop_vars))

    with ops.control_dependencies(
        list(cond_graph.control_captures) + list(body_graph.control_captures)):
      outputs = gen_functional_ops._while(
          flattened_loop_vars,
          util.create_new_tf_function(cond_graph),
          util.create_new_tf_function(body_graph),
          output_shapes=[t.shape for t in body_graph.outputs],
          parallel_iterations=parallel_iterations,
          name=scope)

    _copy_handle_data(body_graph.outputs, outputs)
    util.maybe_set_lowering_attr(outputs[0].op)
    util.maybe_propagate_compile_time_consts_in_xla(outputs[0].op)

    # Return identities for each output of the While op, rather than the output
    # of the While op directly. This makes pruning work if the output of
    # while_loop() is fetched: the lowering pass converts the While outputs into
    # IdentityN outputs, which if fetched will cause all ops in the body to be
    # run (since it takes all exit ops as input). After lowering, each output
    # identity op will end up with only the appropriate exit op as input.
    outputs = tuple(array_ops.identity(t) for t in outputs)

  outputs = _pack_sequence_as(
      orig_loop_vars, outputs[first_loop_var_index:first_loop_var_index +
                              num_flattened_outputs])

  if return_same_structure:
    return outputs

  flattened_outputs = nest.flatten(outputs)
  if len(flattened_outputs) == 1:
    return flattened_outputs[0]
  else:
    return outputs


@ops.RegisterGradient("While")
def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
  """The gradient of a While op produced by while_loop."""
  # Note that op is not always the same as while_op because the gradient tape,
  # for eager mode compatibility, forgets information about the proper op. Since
  # the loop cannot run in eager mode, however, we can safely introspect into
  # the graph here.
  while_op = op.outputs[0].op
  cond_graph = _get_graph(while_op, "cond")
  body_graph = _get_graph(while_op, "body")
  orig_num_params = len(body_graph.outputs)

  maximum_iterations = op.inputs[1]
  parallel_iterations = op.get_attr("parallel_iterations")

  grads = [_preprocess_grad(grad, body_out, while_out)
           for grad, body_out, while_out
           in zip(grads, body_graph.outputs, while_op.outputs)]

  # We compute the gradient for the sub-graph between trainable ys and xs
  # with non-None incoming gradients. We later pad the None's to the list of
  # outputs.
  ys, xs, non_none_grads = zip(*[(y, x, grad) for (y, x, grad) in zip(
      body_graph.outputs, body_graph.inputs, grads) if grad is not None])

  body_grad_graph, args = _create_grad_func(
      ys, xs, non_none_grads, cond_graph, body_graph,
      util.unique_grad_fn_name(body_graph.name), op, maximum_iterations)

  if body_grad_graph.while_op_needs_rewrite:
    # Modify 'op' to output the intermediate accumulators needed by the grad
    # function.
    # NOTE(skyewm): if there are any active sessions, this modification to `op`
    # may make them unrunnable!

    cond_graph.name += "_rewritten"
    body_graph.name += "_rewritten"

    new_inputs = body_grad_graph.empty_tensor_lists
    new_outputs = body_graph.outputs[orig_num_params:]

    while_op._set_func_attr("cond", util.create_new_tf_function(cond_graph))
    while_op._set_func_attr("body", util.create_new_tf_function(body_graph))
    while_op._set_type_list_attr("T", body_graph.output_types)
    while_op._set_shape_list_attr("output_shapes", body_graph.output_shapes)
    while_op._add_while_inputs(new_inputs)
    while_op._add_outputs([t.dtype for t in new_outputs],
                          [t.shape for t in new_outputs])
    _copy_handle_data(new_outputs, op.outputs[orig_num_params:])

  captured_inputs = _resolve_grad_captures(body_graph, body_grad_graph,
                                           while_op)
  loop_vars = args + captured_inputs

  # This modifies body_grad_graph.
  loop_vars = while_v2_indexed_slices_rewriter.rewrite_grad_indexed_slices(
      grads, body_grad_graph, loop_vars, while_op.inputs)

  def grad_cond(counter, unused_maximum_iterations_arg, forward_loop_iters,
                *unused_args):
    return counter < forward_loop_iters

  grad_cond_name = util.unique_grad_fn_name(op.get_attr("cond").name)
  cond_grad_graph = func_graph_module.func_graph_from_py_func(
      grad_cond_name, grad_cond, loop_vars, {},
      func_graph=util.WhileCondFuncGraph(grad_cond_name))

  _check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars))

  outputs = gen_functional_ops._while(
      loop_vars,
      util.create_new_tf_function(cond_grad_graph),
      util.create_new_tf_function(body_grad_graph),
      output_shapes=[t.shape for t in body_grad_graph.outputs],
      parallel_iterations=parallel_iterations,
      name="%s_grad" % while_op.name)
  grad_op = outputs[0].op

  _copy_handle_data(body_grad_graph.outputs, outputs)
  util.maybe_set_lowering_attr(grad_op)
  util.maybe_propagate_compile_time_consts_in_xla(grad_op)

  # See comment in while_loop.
  outputs = [array_ops.identity(t) for t in outputs]
  return _get_structured_grad_output(outputs, grads, body_grad_graph)
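
# Illustrative sketch, not part of the implementation: _WhileGrad is invoked
# by the standard gradients machinery. Assuming graph mode, something like
#
#   from tensorflow.python.ops import gradients_impl
#   x = constant_op.constant(2.0)
#   [y] = while_loop(lambda v: v < 16., lambda v: v * v, [x])
#   [dy_dx] = gradients_impl.gradients(y, x)
#
# triggers _WhileGrad, which builds a second While op (named
# "<forward_name>_grad") that runs the body gradient function
# `forward_loop_iters` times.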


def _preprocess_grad(grad, body_graph_output, while_op_output):
  """Returns the initial gradient to be used for a given output tensor.

  Args:
    grad: the original gradient Tensor passed to the gradient function.
    body_graph_output: the corresponding Tensor in the body graph.
    while_op_output: the corresponding Tensor output of the While op.

  Returns:
    A Tensor or None.
  """
  # Set the incoming gradient of non-trainable inputs to None. It is possible
  # that we receive non-None gradients for non-trainable types in nested while
  # loops because we accumulate outputs of the inner while as variant tensors
  # which are trainable and hence receive zeros_like tensors in the gradient
  # pass. The non-trainable tensors then receive the popped zeros tensor from
  # this zeros variant. The gradient for the loop vars corresponding to these
  # tensors is None or zeros (this happens only if the loop var is accumulated
  # as well) in _grad_fn so we reset these.
  # TODO(b/118712257): Remove once we can handle None output grads in _grad_fn.
  if not _is_trainable(body_graph_output):
    return None

  # GradientTape initializes resource and variant grads as None instead of
  # zeros. Set to zeros so _GradientsHelper computes the gradients instead of
  # returning None.
  if (while_op_output.dtype in (dtypes.resource, dtypes.variant)
      and grad is None):
    return _zeros_like(while_op_output)

  return grad


# TODO(skyewm): make this return constants if op_output's shape is fully
# defined (this can be done by checking the "shape" attr of resource vars).
def _zeros_like(op_output):
  """Like array_ops.zeros_like() but also accepts resource var handles."""
  if op_output.dtype == dtypes.resource:
    return array_ops.zeros(
        gen_resource_variable_ops.variable_shape(op_output))
  return array_ops.zeros_like(op_output)


def _is_trainable(tensor):
  """Returns whether the given tensor is trainable."""
  if not gradients_util.IsTrainable(tensor):
    return False

  # Special case: untrainable accumulator output. The gradients algorithm
  # doesn't know about tensor lists of untrainable elements. In theory the
  # tensor list gradient functions should return None as appropriate, but
  # because we can't return None from the gradient function we filter out
  # untrainable accumulator output here to avoid computing the gradient at all.
  if tensor.op.type == "TensorListPopBack" and tensor.value_index == 0:
    assert tensor.dtype == dtypes.variant
    element_type = tensor.op.get_attr("element_dtype")
    return gradients_util.IsTrainable(element_type)

  return True


# TODO(srbs): Pull this into common utils for cond_v2 and while_v2.
def _get_graph(while_op, func_attr_name):
  """Returns `FuncGraph` for the given function attribute.

  Args:
    while_op: The While Operation.
    func_attr_name: string

  Returns:
    `FuncGraph`
  """
  # TODO(srbs): Handle TensorShapeProto in function_def_to_graph.input_shapes.
  input_shapes = [
      tensor_shape.TensorShape(s) for s in while_op.get_attr("output_shapes")
  ]
  func_name = while_op.get_attr(func_attr_name).name
  fdef = while_op.graph._get_function(func_name).definition
  # `while_op.graph` may not be the same as `ops.get_default_graph()` e.g.
  # if the `while_op` is in the body of another if/while/defun. We build the
  # `func_graph` with `while_op.graph` as its `outer_graph`. This resembles how
  # the `FuncGraph` was built in the forward pass. We need this so that we can
  # appropriately capture references to outer tensors in the nested grad graphs.
  with while_op.graph.as_default():
    func_graph = function_def_to_graph.function_def_to_graph(fdef, input_shapes)
  func_graph._while = while_op
  return func_graph


def _create_grad_func(ys, xs, grads, cond_graph, body_graph, name, while_op,
                      maximum_iterations):
  """Builds and returns the gradient FuncGraph of `func_graph` and its args.

  The returned grad_func_graph must be called with the returned
  args + grad_func_graph.captures.

  Args:
    ys: A `Tensor` or list of tensors to be differentiated.
    xs: A `Tensor` or list of tensors to be used for differentiation.
    grads: The incoming grads for `ys`.
    cond_graph: FuncGraph for the forward cond function.
    body_graph: FuncGraph for the forward body function.
    name: Name of the returned gradient function.
    while_op: The forward While op.
    maximum_iterations: Tensor. The maximum number of iterations.

  Returns:
    2-tuple of (grad_func_graph, args).
  """
  assert len(ys) == len(grads)

  total_iters = while_op.outputs[0]
  counter = constant_op.constant(
      0, dtype=total_iters.dtype, name="grad_counter")

  args = [counter, maximum_iterations, total_iters] + list(grads)
  # Note: The returned function does not have `args` in the list of
  # `external_captures`.
  grad_func_graph = func_graph_module.func_graph_from_py_func(
      name,
      lambda *args: _grad_fn(ys, xs, args, body_graph),
      args, {},
      func_graph=_WhileBodyGradFuncGraph(name, cond_graph, body_graph,
                                         maximum_iterations))

  # Add the popped accumulators to the list of outputs.
  for internal_capture in grad_func_graph.internal_captures:
    if internal_capture in grad_func_graph.popped_tensor_lists:
      new_output = grad_func_graph.popped_tensor_lists[internal_capture]
    elif internal_capture.dtype == dtypes.resource:
      new_output = internal_capture
    else:
      raise ValueError("Tensor %s is in list of internal_captures but is"
                       " neither a resource nor is in popped_tensor_lists." %
                       str(internal_capture))
    grad_func_graph.outputs.append(new_output)
    grad_func_graph.structured_outputs.append(new_output)

  return grad_func_graph, args


def _grad_fn(ys, xs, args, func_graph):
  """Computes the gradient of `func_graph` in the current graph.

  This function builds the gradient graph of the corresponding forward-pass
  `func_graph` by differentiating `func_graph`'s outputs w.r.t. its inputs.

  Args:
    ys: A `Tensor` or list of tensors to be differentiated.
    xs: A `Tensor` or list of tensors to be used for differentiation.
    args: The input arguments.
      args[0] - Loop counter
      args[1] - maximum_iterations.
      args[2] - Total number of iterations.
      args[3:] - Incoming gradients for `ys`.
    func_graph: function.FuncGraph. The corresponding forward-pass function.

  Returns:
    The output gradient Tensors.
  """
  grad_ys = args[3:]

  # Build the gradient graph. Note that this builds the gradient computation of
  # func_graph in the current graph, which requires capturing tensors from
  # func_graph. The captured func_graph tensors are resolved to external tensors
  # after the forward While op has been rewritten in _resolve_grad_captures.
  # TODO(srbs): Mark GradientsHelper as public?
  grad_outs = gradients_util._GradientsHelper(
      ys, xs, grad_ys=grad_ys, src_graph=func_graph,
      unconnected_gradients="zero")

  # TODO(b/118712257): Handle the case when grad_outs has None's e.g. when there
  # is a tf.StopGradient in the loop body.
  assert all(g is not None for g in grad_outs)
  counter = args[0]
  maximum_iterations = args[1]
  total_iters = args[2]
  return [counter + 1, maximum_iterations, total_iters] + grad_outs


def _resolve_grad_captures(body_graph, body_grad_graph, while_op):
  """Returns the tensors to pass as captured inputs to `body_grad_graph`.

  `body_grad_graph` may have external references to:
  1. Its outer graph containing the input gradients. These are left as-is.
  2. Accumulators captured from the forward-pass graph. These should have been
     added as `while_op` outputs after the gradient graph was built. We replace
     these with the corresponding output of `while_op`, i.e. a tensor in
     `body_graph.outer_graph`. In the case of nested control flow or functions,
     the gradient logic handling `body_grad_graph.outer_graph` will make sure
     the tensor from `body_graph.outer_graph` is also correctly captured.

  Args:
    body_graph: FuncGraph. The forward-pass body function.
    body_grad_graph: FuncGraph. The body gradients function.
    while_op: The forward-pass While Operation calling `body_graph`.

  Returns:
    A list of input tensors to be passed as the captured inputs to
      `body_grad_graph`.
  """
  new_capture_inputs = []
  for t in body_grad_graph.external_captures:
    # All values captured by gradient computation should be from the forward
    # graph or a captured resource variable (note that input gradients are
    # regular non-captured inputs).
    if t.graph == body_graph:
      # Captured accumulator
      t = while_op.outputs[t.graph.outputs.index(t)]
      # Note: We rely on the capturing logic of the gradient While op graph to
      # correctly capture the tensors in `body_graph.outer_graph`. Both cond_v2
      # and while_v2 handle this while building their gradient functions.
      assert t.graph == body_graph.outer_graph
    else:
      # Captured resource variable
      assert t.dtype == dtypes.resource

    new_capture_inputs.append(t)
  return new_capture_inputs


def _get_structured_grad_output(outputs, grads, body_grad_graph):
  """Returns the values that should be returned from the while grad function.

  Args:
    outputs: the raw Tensor outputs of the grad While op.
    grads: the input gradients to the gradient function.
    body_grad_graph: _WhileBodyGradFuncGraph.

  Returns:
    A list of gradient values. May include Nones.
  """
  result = []
  # outputs[0] is the loop counter.
  # outputs[1] is maximum_iterations.
  # outputs[2] is the total number of loop iterations.
  outputs_idx = 3
  structured_outputs_idx = 3
  for g in grads:
    # Set None as the output gradient for tensors with None input gradient.
    if g is None:
      result.append(None)
      continue
    output = body_grad_graph.structured_outputs[structured_outputs_idx]
    structured_outputs_idx += 1
    if isinstance(output, ops.IndexedSlices):
      # TODO(skyewm): is there a more robust way to determine the order of
      # flattened IndexedSlices components?
      result.append(ops.IndexedSlices(
          values=outputs[outputs_idx],
          indices=outputs[outputs_idx + 1],
          dense_shape=outputs[outputs_idx + 2]))
      outputs_idx += 3
    else:
      assert isinstance(output, ops.Tensor)
      result.append(outputs[outputs_idx])
      outputs_idx += 1

  return result


def _get_accumulator(tensor):
  r"""Returns TensorList if any containing accumulated values of tensor.

  We try to find a pattern of the form:

     input_tl   tensor
        \        /
    (TensorListPushBack)
            |
        output_tl

  which satisfies the following conditions:

  1. input_tl must be in tensor.graph.inputs.
  2. output_tl or Identity(output_tl) must be in tensor.graph.outputs.
  3. tensor.graph.input_index(input_tl) == tensor.graph.output_index(output_tl).

  output_tl or Identity(output_tl) (whichever is in tensor.graph.outputs) is
  returned if such a pattern is found else None is returned.

  Args:
    tensor: The Tensor to be accumulated.

  Returns:
    A variant tensor in the same graph as `tensor` or None if no accumulator is
    found.
  """
  assert isinstance(tensor.graph, func_graph_module.FuncGraph)

  def get_func_graph_output(t):
    """Returns t or Identity(t) whichever exists in graph outputs else None."""
    if t in tensor.graph.outputs:
      return t
    # tf.defun adds an Identity for each output, check whether that is the case.
    identity_op = t.consumers()[0]
    if (identity_op.type == "Identity" and
        identity_op.outputs[0] in tensor.graph.outputs):
      return identity_op.outputs[0]
    return None

  for consumer in tensor.consumers():
    # Find the consumer that is a TensorListPushBack node whose TensorList input
    # is in the list of function inputs.
    if (consumer.type != "TensorListPushBack" or
        consumer.inputs[0] not in tensor.graph.inputs):
      continue

    output = get_func_graph_output(consumer.outputs[0])
    if output is None:
      # The TensorList output of `consumer` is not in the list of function
      # outputs.
      continue

    accum_input_idx = tensor.graph.inputs.index(consumer.inputs[0])
    accum_output_idx = tensor.graph.outputs.index(output)
    if accum_input_idx == accum_output_idx:
      return output
  return None
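
# Illustrative sketch, not part of the implementation: the accumulator pattern
# _get_accumulator searches for is the one built by _WhileBodyGradFuncGraph
# below. For an intermediate `tensor` of the forward body it roughly amounts to
#
#   # In the forward outer graph:
#   tensor_list = list_ops.empty_tensor_list(
#       element_dtype=tensor.dtype, element_shape=tensor.shape,
#       max_num_elements=maximum_iterations)
#   # In the forward body graph, where `tensor_list` becomes a captured input
#   # and the pushed-back list a matching output:
#   accumulator = list_ops.tensor_list_push_back(tensor_list, tensor)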


class _WhileBodyGradFuncGraph(util.WhileBodyFuncGraph):
  """FuncGraph for the gradient function of the body of a While op.

  Contains the logic for capturing the tensors from the body of the forward
  While op which is as follows:
  1. If the tensor is of resource type (these are not accumulated):
     a. Ensure that the tensor is a loop invariant, i.e., it exists in both loop
        inputs and outputs at the same index.
     b. Lookup the corresponding resource tensor in the forward outer graph and
        try to capture that.
  2. If the tensor is not of resource type:
     a. Create an accumulator for that tensor and output it from the forward
        pass. Note this also requires adding it as an input to the forward pass.
     b. Capture the accumulator from the forward pass in this FuncGraph. This
        will later be resolved to the correct output of the forward While op.
     c. Pop a value from the captured placeholder and use it as the captured
        value for the forward pass tensor.

  This only allows capturing tensors in the forward graph. A ValueError is
  raised if an attempt is made to capture a tensor not in the forward graph.
  To manually capture a tensor that is not in the forward graph, call
  `capture` with `whitelisted=True`.

  Note: The `captures` dict does not contain the forward tensor since it is not
  directly captured. It contains the accumulator corresponding to this forward
  tensor.

  Attributes:
    while_op_needs_rewrite: True if any non-resource intermediates were
      captured, meaning the forward While op needs to be rewritten to output the
      corresponding accumulators.
    empty_tensor_lists: list of EmptyTensorList tensors to be used as initial
      input to the new accumulators in the forward graph.
    popped_tensor_lists: dict from the captured accumulator placeholder to the
      TensorList obtained after popping the intermediate tensor from it. The
      values of this dict need to be added to the list of outputs.
  """

  def __init__(self, name, forward_cond_graph, forward_body_graph,
               maximum_iterations):
    super(_WhileBodyGradFuncGraph, self).__init__(name)
    self.empty_tensor_lists = []
    self.popped_tensor_lists = {}
    # FuncGraph for the body of the forward While op.
    self._forward_graph = forward_body_graph
    # FuncGraph for the cond of the forward While op.
    self._forward_cond_graph = forward_cond_graph
    self._maximum_iterations = maximum_iterations
    # Dict from forward intermediate tensor to its indirectly captured tensor
    # in this graph. Indirect capturing happens in two ways:
    # 1. For non-resource tensors we capture their accumulators from the forward
    #    outer graph and pop values from that accumulator inside this graph
    #    using TensorListPopBack.
    # 2. For resource tensors we directly capture their corresponding tensor
    #    in the forward outer graph.
    self._indirect_captures = {}

  @property
  def while_op_needs_rewrite(self):
    return self.empty_tensor_lists

  def capture(self, tensor, name=None, whitelisted=False):
    """Selectively captures external tensors.

    If `whitelisted` is False only allows capturing tensors in the
    `_forward_graph`.

    Args:
      tensor: Tensor. May be from this FuncGraph or a different graph.
      name: Optional name if a placeholder is created.
      whitelisted: If False (default), only allows capturing tensors from the
        forward graph.

    Returns:
      The placeholder in this graph for the tensor.

    Raises:
      ValueError: If attempting to capture an external tensor not in the forward
        graph with `whitelisted` set to False.
    """
    if (not whitelisted and tensor.graph is not self and
        tensor.graph != self._forward_graph):
      raise ValueError("Attempting to capture tensor %s which is not in the "
                       "forward graph but in %s." %
                       (str(tensor), _graph_name(tensor.graph)))
    return super(_WhileBodyGradFuncGraph, self).capture(tensor, name)

  def _capture_helper(self, tensor, name):
    if tensor.graph is not self._forward_graph:
      return super(_WhileBodyGradFuncGraph, self)._capture_helper(tensor, name)

    while tensor.op.type == "Identity":
      # We do not accumulate the output of identity nodes so we try to capture
      # the input of the Identity node instead.
      tensor = tensor.op.inputs[0]

    captured_tensor = self._indirect_captures.get(tensor)
    if captured_tensor is not None:
      return captured_tensor

    # Resource tensors are not accumulated and handled specially.
    if tensor.dtype == dtypes.resource:
      return self._resource_capture_helper(tensor)

    # Create or find an existing accumulator output for `tensor` in the forward
    # graph, and fetch from this accumulator in the gradient graph to get the
    # raw intermediate value.
    accumulator = _get_accumulator(tensor)
    if accumulator is None:
      # Create the initial empty tensor list.
      with self._forward_graph.outer_graph.as_default():
        tensor_list = list_ops.empty_tensor_list(
            element_dtype=tensor.dtype, element_shape=tensor.shape,
            max_num_elements=self._maximum_iterations)
      self.empty_tensor_lists.append(tensor_list)

      # Push the intermediate tensor to the tensor list. This captures
      # `tensor_list`.
      with self._forward_graph.as_default():
        accumulator = list_ops.tensor_list_push_back(tensor_list, tensor)
      # Add the modified tensor list to the list of outputs. This output will be
      # all the accumulated values.
      self._forward_graph.outputs.append(accumulator)

      # Capture in the cond graph as well so the forward cond and body inputs
      # match.
      with self._forward_cond_graph.as_default():
        self._forward_cond_graph.capture(tensor_list)

    # Capture the accumulator tensor list in the gradient graph directly from
    # the forward graph -- we'll later modify this to capture the final list
    # output by the forward While op instead.
    captured_accumulator = super(_WhileBodyGradFuncGraph, self)._capture_helper(
        accumulator, name)

    # Pop the intermediate value from the tensor list in the gradient graph.
    new_tensor_list, captured_tensor = list_ops.tensor_list_pop_back(
        captured_accumulator, element_dtype=tensor.dtype)

    self._indirect_captures[tensor] = captured_tensor
    self.popped_tensor_lists[captured_accumulator] = new_tensor_list
    return captured_tensor

  def _resource_capture_helper(self, tensor):
    """Returns the captured resource tensor.

    Resource-type tensors are not accumulated. If a resource tensor exists in
    the loop body it must either be a loop input or an output of a nested While
    op inside the loop body which had captured the external resource.

    Args:
      tensor: the external resource Tensor to be captured.

    Returns:
      Tensor in this graph.
    """
    assert tensor.dtype == dtypes.resource

    index = self._resource_input_index(
        tensor.name,
        [t.name for t in self._forward_graph.inputs],
        {op.name: op.node_def for op in self._forward_graph.get_operations()},
        self._forward_graph._functions)

    input_placeholder = self._forward_graph.inputs[index]
    tensor_in_outer_graph = self._forward_graph._while.inputs[index]

    assert input_placeholder.dtype == dtypes.resource
    assert tensor_in_outer_graph.dtype == dtypes.resource
    # This must be a loop invariant.
    assert input_placeholder == self._forward_graph.outputs[index], (
        "Resource tensors must be loop invariants %s." %
        tensor_in_outer_graph)

    self._indirect_captures[tensor] = self.capture(
        tensor_in_outer_graph, whitelisted=True)
    return self._indirect_captures[tensor]

  def _resource_input_index(self, tensor_name, input_names, node_defs,
                            functions):
    """Returns the index of the input corresponding to `tensor_name`.

    This method is used to find the corresponding index of an arbitrary resource
    tensor in a function (the function could be a loop body). We assume that
    resource handles are never created in functions, so that every resource
    tensor can be traced back to a function input.

    The awkward signature of this method is to make it work with both FuncGraphs
    and FunctionDefs. This is so we can recurse on function call ops without
    building the corresponding FuncGraph (note that even if a FuncGraph for a
    FunctionDef already exists, the input/output/node names may have been
    changed when the FuncGraph was serialized to the FunctionDef, which makes it
    unusable with this algorithm).

    Args:
      tensor_name: the name of the resource tensor to be resolved to an input.
      input_names: a list of the names of all inputs to the function.
      node_defs: a dict mapping op name -> NodeDef for every op in the function.
      functions: a dict mapping function name -> _EagerDefinedFunction.

    Returns:
      The index into input_names corresponding to `tensor_name`.
    """
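    # Note on naming conventions (assumed here for illustration): graph tensor
    # names look like "op_name:output_idx" (e.g. "while/Identity:0"), while
    # FunctionDef tensor names look like "op_name:arg_name:output_idx"
    # (e.g. "while/Identity:output:0"); a bare "op_name" refers to output 0.
    # The loop below parses whichever form it encounters.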
    while tensor_name not in input_names:
      # FunctionDefs and graphs use different tensor naming conventions.
      parts = tensor_name.split(":")
      if len(parts) == 3:
        op_name, _, output_idx = parts
      elif len(parts) == 2:
        op_name, output_idx = parts
      else:
        assert len(parts) == 1
        op_name = parts[0]
        output_idx = 0
      output_idx = int(output_idx)
      node_def = node_defs[op_name]

      if node_def.op == "While":
        # Captured resources occur at the same index in the lists of inputs and
        # outputs of a while op. So we lookup the input of `tensor.op` at the
        # same index as the index of `tensor` in the `tensor.op.outputs`.
        tensor_name = node_def.input[output_idx]
      elif node_def.op in ("PartitionedCall", "StatefulPartitionedCall"):
        # Functions output any captured resource tensors used by their
        # gradients.  `tensor_name` is one of these outputs from a nested
        # function call, so recursively find the corresponding input in the
        # nested FunctionDef.
        func_name = node_def.attr["f"].func.name
        fdef = functions[func_name].definition
        output_arg_name = fdef.signature.output_arg[output_idx].name
        output_tensor_name = fdef.ret[output_arg_name]
        input_index = self._resource_input_index(
            output_tensor_name,
            [arg.name for arg in fdef.signature.input_arg],
            {ndef.name: ndef for ndef in fdef.node_def},
            functions)
        tensor_name = node_def.input[input_index]
      else:
        # We assume there are no other ops types that will "forward" resource
        # handles like this, so all other handles must have been created by the
        # op. (Note that cond_v2 wraps resource handle outputs in optionals,
        # which we'll end up accumulating).
        raise ValueError(
            "Taking gradient of a while loop which creates "
            "a resource in its body is not supported: %s" % op_name)

    return input_names.index(tensor_name)


def _check_shapes_compat(output_tensors, shape_invariants, input_tensors):
  for (t, shape, input_t) in zip(output_tensors, shape_invariants,
                                 input_tensors):
    if not control_flow_ops._ShapeLessThanOrEqual(t.shape, shape):
      raise ValueError(
          "Input tensor '%s' enters the loop with shape %s, but has "
          "shape %s after one iteration. To allow the shape to vary across "
          "iterations, use the `shape_invariants` argument of tf.while_loop to "
          "specify a less-specific shape." % (input_t.name, shape, t.shape))
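
# Illustrative sketch, not part of the implementation: callers can avoid the
# error above by passing `shape_invariants` to while_loop. Assuming graph
# mode, e.g.
#
#   x = constant_op.constant([1.])
#   [out] = while_loop(
#       lambda v: math_ops.reduce_sum(v) < 10.,
#       lambda v: array_ops.concat([v, v], 0),
#       [x],
#       shape_invariants=[tensor_shape.TensorShape([None])])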


def _check_num_inputs_outputs(cond_graph, body_graph, num_flattened_loop_vars):
  """Checks the number of inputs/outputs of `cond_graph` and `body_graph`."""
  assert len(cond_graph.inputs) == num_flattened_loop_vars, (
      "cond_graph takes %d inputs; Expected: %d" % (len(cond_graph.inputs),
                                                    num_flattened_loop_vars))
  assert len(cond_graph.outputs) == 1, (
      "cond_graph has %d outputs; Expected: 1" % len(cond_graph.outputs))
  assert len(body_graph.inputs) == num_flattened_loop_vars, (
      "body_graph takes %d inputs; Expected: %d" % (len(body_graph.inputs),
                                                    num_flattened_loop_vars))
  assert len(body_graph.outputs) == num_flattened_loop_vars, (
      "body_graph has %d outputs; Expected: %d" % (len(body_graph.outputs),
                                                   num_flattened_loop_vars))


def _copy_handle_data(src_tensors, tgt_tensors):
  for src_t, tgt_t in zip(src_tensors, tgt_tensors):
    custom_gradient.copy_handle_data(src_t, tgt_t)


# TODO(srbs): This method should be in control_flow_util but that introduces
# a circular dependency ops -> control_flow_util -> ops.
def _is_in_xla_context():
  """Returns whether the current context is inside an XLA context."""
  outer_graph = ops.get_default_graph()
  # The `_control_flow_context` is not copied when building a FuncGraph so
  # we look it up from the base graph.
  while isinstance(outer_graph, func_graph_module.FuncGraph):
    outer_graph = outer_graph.outer_graph
  cur_ctxt = outer_graph._get_control_flow_context()  # pylint: disable=protected-access
  return control_flow_util.GetContainingXLAContext(cur_ctxt) is not None


def _graph_name(graph):
  if isinstance(graph, func_graph_module.FuncGraph):
    return graph.name
  return "Base"


def _pack_sequence_as(structure_with_tas, loop_vars):
  """Like `nest.pack_sequence_as` but also replaces flows with TensorArrays."""

  def flow_to_tensor_array(flow, ta):  # pylint: disable=missing-docstring
    return (tensor_array_ops.build_ta_with_new_flow(ta, flow) if isinstance(  # pylint: disable=g-long-ternary
        ta, tensor_array_ops.TensorArray) else flow)

  flattened_loop_vars = [
      flow_to_tensor_array(*z)
      for z in zip(nest.flatten(loop_vars), nest.flatten(structure_with_tas))
  ]
  return nest.pack_sequence_as(structure_with_tas, flattened_loop_vars)


def _tensor_array_to_flow(loop_vars):

  def f(maybe_ta):
    if isinstance(maybe_ta, tensor_array_ops.TensorArray):
      return maybe_ta.flow
    return maybe_ta

  return nest.map_structure(f, loop_vars)
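
# Illustrative sketch, not part of the implementation: the two helpers above
# let TensorArray loop vars ride through the While op as their flow tensors.
# Roughly,
#
#   ta = tensor_array_ops.TensorArray(dtypes.float32, size=3)
#   flow = _tensor_array_to_flow(ta)           # plain scalar flow tensor
#   ta_again = _pack_sequence_as(ta, flow)     # TensorArray rebuilt around flow
#
# where `ta_again` is produced via tensor_array_ops.build_ta_with_new_flow.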


def _build_signature(loop_vars, shape_invariants):
  return nest.pack_sequence_as(loop_vars, [
      tensor_spec.TensorSpec(s, t.dtype, name=t.op.name)
      for s, t in zip(nest.flatten(shape_invariants), nest.flatten(loop_vars))
  ])


def _build_maximum_iterations_loop_var(maximum_iterations):
  if maximum_iterations is None:
    # -1 is the default value of max_num_elements for EmptyTensorList, meaning
    # that the list size is unbounded.
    maximum_iterations = -1
  # EmptyTensorList expects `max_num_elements` to be of type int32.
  return ops.convert_to_tensor(
      maximum_iterations, dtype=dtypes.int32, name="maximum_iterations")

# pylint: enable=protected-access