# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Methods for rewriting while_v2 grad functions with IndexedSlices output."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.util import nest


def rewrite_grad_indexed_slices(grads, body_grad_graph, loop_vars,
                                forward_inputs):
  """Handles special case of IndexedSlices returned from while gradient.

  Some gradient functions return IndexedSlices instead of a Tensor (e.g. the
  gradient of Gather ops). When this happens in the gradient of a while body,
  the resulting gradient body function will have mismatched inputs and outputs,
  since the input is a single Tensor, but the IndexedSlices gets unnested into
  three output Tensors.

  This function fixes this by rewriting the gradient body to have three inputs
  to match the three outputs, i.e., it effectively converts the input Tensor
  into an input IndexedSlices. It also returns new `loop_vars` to reflect the
  new inputs.

  Args:
    grads: the input gradient Tensors to the while gradient computation.
    body_grad_graph: _WhileBodyGradFuncGraph.
    loop_vars: list of Tensors. The inputs to body_grad_graph.
    forward_inputs: list of Tensors. The (flat) inputs to the forward-pass
      While op.

  Returns:
    The new loop_vars to pass to body_grad_graph.
  """
  # Match up body_grad_graph.structured_outputs with the corresponding
  # forward_inputs.
  #
  # Note that we don't expect a gradient computation to have structured output
  # (e.g. no nested lists), so no need to flatten
  # body_grad_graph.structured_outputs. However, structured_outputs may still
  # contain composite tensors such as IndexedSlices, unlike
  # body_grad_graph.outputs, which contains flattened composite tensors.
  inputs_with_grads = [t for g, t in zip(grads, forward_inputs)
                       if g is not None]
  # Skip loop counter, maximum_iterations and total number of loop iterations.
  structured_outputs = body_grad_graph.structured_outputs[3:]

  for forward_input, output in zip(inputs_with_grads, structured_outputs):
    if not isinstance(output, ops.IndexedSlices): continue

    if forward_input.dtype == dtypes.resource:
      # TODO(skyewm): In theory we should use this for all captured inputs, not
      # just resource handles (which can only be captured). We can do this by
      # checking that forward_input is passed straight through to its output.
      loop_vars = _rewrite_input_as_indexed_slices(body_grad_graph, output,
                                                   forward_input, loop_vars)
    else:
      _rewrite_output_as_tensor(body_grad_graph, output)

  return loop_vars


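# Illustrative sketch (hypothetical helper, never called by the rewriter): the
# gradient of a gather arrives as an IndexedSlices rather than a Tensor, which
# is exactly the input/output mismatch described in the docstring above.
# Assumes graph mode; gradients_impl is imported lazily inside the sketch.
def _example_gather_grad_is_indexed_slices():
  """Shows that d(gather)/d(params) is an IndexedSlices, not a Tensor."""
  from tensorflow.python.ops import gradients_impl  # assumed import path
  params = array_ops.zeros([10, 3])
  taken = array_ops.gather(params, [1, 4])
  grad, = gradients_impl.gradients(taken, params)
  # grad flattens into three tensors (values, indices, dense_shape), while the
  # corresponding input is the single `params` tensor.
  assert isinstance(grad, ops.IndexedSlices)
  return grad

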
def _get_tensor_index_in_iterable(iterable, t):
  """Returns index of first occurrence of `t`; raises ValueError if not found."""
  for i, elem in enumerate(iterable):
    if t is elem:
      return i
  raise ValueError("%s is not in iterable" % str(t))


def _rewrite_output_as_tensor(body_grad_graph, grad_output_slices):
  """Rewrites grad_output_slices to be a Tensor output.

  Args:
    body_grad_graph: _WhileBodyGradFuncGraph.
    grad_output_slices: IndexedSlices output of body_grad_graph.
  """
  with body_grad_graph.as_default():
    new_output = ops.convert_to_tensor_v2(grad_output_slices)

  idx = _get_tensor_index_in_iterable(body_grad_graph.structured_outputs,
                                      grad_output_slices)
  body_grad_graph.structured_outputs[idx] = new_output
  body_grad_graph.outputs = func_graph.flatten(
      body_grad_graph.structured_outputs)


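# Sketch (hypothetical helper, unused): converting an IndexedSlices to a
# Tensor, as _rewrite_output_as_tensor does above, materializes the dense
# gradient. It is roughly equivalent to scattering the slice values into a
# zeros tensor of shape dense_shape, with duplicate indices summed.
def _example_densify_sketch(slices):
  """Rough dense equivalent of ops.convert_to_tensor_v2(slices)."""
  return array_ops.scatter_nd(
      array_ops.expand_dims(slices.indices, 1), slices.values,
      slices.dense_shape)

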
def _rewrite_input_as_indexed_slices(body_grad_graph, grad_output_slices,
                                     forward_input, loop_vars):
  """Rewrites grad_output_slices's corresponding input to be an IndexedSlices.

  This rewrite requires that forward_input was captured in the forward loop,
  i.e. is not a user-specified loop variable. This is important because the
  rewrite assumes that forward_input is passed through to its corresponding
  output unchanged. This assumption is used in
  _rewrite_grad_indexed_slices_output, which depends on the exact gradient
  structure produced by the input's fanout.

  This can yield a more efficient computation than using
  _rewrite_output_as_tensor, since it preserves the IndexedSlices structure
  instead of converting the IndexedSlices to a dense Tensor.

  Args:
    body_grad_graph: _WhileBodyGradFuncGraph.
    grad_output_slices: IndexedSlices output of body_grad_graph.
    forward_input: the corresponding Tensor input to the forward loop.
    loop_vars: list of Tensors. The inputs to body_grad_graph.

  Returns:
    The new loop_vars to pass to body_grad_graph.
  """
  # Create initial IndexedSlices that will be the input to the grad While
  # op. This will start as zeros, and accumulate the IndexedSlices grad output.
  # Note that because forward_input is captured and not a loop var, its incoming
  # gradient should always be zero.
  init_slices = _create_grad_indexed_slices_init(grad_output_slices,
                                                 forward_input)

  # Create a new version of grad_output_slices's gradient computation that uses
  # the new IndexedSlices input instead of the original Tensor input. We'll
  # return the new computation and leave the old computation as dead code.
  # TODO(skyewm): consider pruning body_grad_graph to remove the old
  # computation.
  with body_grad_graph.as_default():
    input_slices = ops.IndexedSlices(
        values=body_grad_graph.capture(init_slices.values, allowlisted=True),
        indices=body_grad_graph.capture(init_slices.indices, allowlisted=True),
        dense_shape=body_grad_graph.capture(
            init_slices.dense_shape, allowlisted=True))

    # Remove the captured tensors from the function inputs. We'll add them back
    # at the correct index in _update_indexed_slices_param.
    for t in _flatten(init_slices):
      captured_t = body_grad_graph.captures.pop(t)
      body_grad_graph.inputs.remove(captured_t)

    new_output_slices = _rewrite_grad_indexed_slices_output(grad_output_slices,
                                                            input_slices)

  # Update body_grad_graph's inputs and outputs to reflect the new
  # IndexedSlices computation.
  return _update_indexed_slices_param(
      body_grad_graph, loop_vars, init_slices, input_slices, new_output_slices,
      grad_output_slices)


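# Net effect of the rewrite above, sketched on hypothetical names: one flat
# Tensor input is replaced by the three IndexedSlices components, so
#   before: body_grad(counter, max_iters, total_iters, ..., x_grad)
#           where x_grad is a single dense Tensor
#   after:  body_grad(counter, max_iters, total_iters, ..., values, indices,
#                     dense_shape)
# and each loop iteration now passes the gradient through component-wise.

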
def _create_grad_indexed_slices_init(grad_output_slices, forward_input):
  """Creates an IndexedSlices to pass as input to the while grad function.

  Args:
    grad_output_slices: IndexedSlices. The corresponding while grad function
      output.
    forward_input: Tensor. The corresponding input to the forward while op.

  Returns:
    Zeros IndexedSlices, created in current Graph.
  """
  assert isinstance(grad_output_slices, ops.IndexedSlices)
  assert isinstance(forward_input, ops.Tensor)
  values_out = grad_output_slices.values
  indices_out = grad_output_slices.indices

  # Create the initial values tensor.
  if values_out.shape.is_fully_defined():
    values_shape = tensor_shape.TensorShape([0] +
                                            values_out.shape.as_list()[1:])
    values = array_ops.zeros(values_shape, dtype=values_out.dtype,
                             name="values_init")
  else:
    if forward_input.dtype == dtypes.resource:
      forward_shape = gen_resource_variable_ops.variable_shape(forward_input)
    else:
      forward_shape = array_ops.shape(forward_input)
    values_shape = array_ops.concat([[0], forward_shape[1:]], 0)
    values = array_ops.zeros(values_shape, dtype=values_out.dtype,
                             name="values_init")

  # Create the initial indices tensor.
  indices = constant_op.constant([], indices_out.dtype, name="indices_init")

  # Create the initial dense_shape tensor. We assume it is the same shape as
  # forward_input, since captured tensors don't change shape across loop
  # iterations.
  if forward_input.dtype == dtypes.resource:
    shape = gen_resource_variable_ops.variable_shape(forward_input,
                                                     name="shape_init")
  else:
    shape = array_ops.shape(forward_input, name="shape_init")

  return ops.IndexedSlices(values=values, indices=indices, dense_shape=shape)


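# Worked sketch (hypothetical shapes, helper unused): for a float32 forward
# input with static shape [10, 3], _create_grad_indexed_slices_init builds an
# empty IndexedSlices whose dense_shape matches the forward input.
def _example_init_slices_sketch():
  """Zeros IndexedSlices for a [10, 3] forward input, per the logic above."""
  values = array_ops.zeros([0, 3], dtype=dtypes.float32, name="values_init")
  indices = constant_op.constant([], dtypes.int32, name="indices_init")
  dense_shape = constant_op.constant([10, 3], name="shape_init")
  return ops.IndexedSlices(values=values, indices=indices,
                           dense_shape=dense_shape)

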
def _rewrite_grad_indexed_slices_output(old_output_slices, new_input_slices):
  """Creates a new version of old_output_slices with new_input_slices as input.

  This method assumes that old_output_slices.{values,indices} are produced by
  concatenating the incoming gradient Tensor input with the IndexedSlices
  produced by the gradient computation of the while body. See
  backprop.aggregate_indexed_slices_gradients for where these concats are
  constructed. We build new concats that use new_input_slices instead of the
  original Tensor input.

  Args:
    old_output_slices: original IndexedSlices output of while gradient.
    new_input_slices: new IndexedSlices to use as input to while gradient.

  Returns:
    A new IndexedSlices to replace old_output_slices.
  """

  def rewrite(old_output, new_input):
    assert old_output.type == "Identity"
    concat_op = old_output.inputs[0].op
    assert concat_op.type == "ConcatV2"
    # Don't include axis arg
    old_concat_args = concat_op.inputs[:-1]
    # We assume that the original gradient input was the first argument to the
    # concat op.
    # TODO(skyewm): do this in a more robust way.
    return array_ops.concat([new_input] + old_concat_args[1:], 0)

  values = rewrite(old_output_slices.values.op, new_input_slices.values)
  indices = rewrite(old_output_slices.indices.op, new_input_slices.indices)
  return ops.IndexedSlices(values=values, indices=indices,
                           dense_shape=new_input_slices.dense_shape)


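# Sketch (hypothetical helper, unused): IndexedSlices gradients are aggregated
# by concatenating their values and indices (the ConcatV2 pattern produced by
# backprop.aggregate_indexed_slices_gradients that rewrite() above matches),
# yielding an unsorted-sum representation of the combined gradient.
def _example_slices_aggregation_sketch(a, b):
  """Sums two IndexedSlices with the same dense_shape by concatenation."""
  return ops.IndexedSlices(
      values=array_ops.concat([a.values, b.values], 0),
      indices=array_ops.concat([a.indices, b.indices], 0),
      dense_shape=a.dense_shape)

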
def _update_indexed_slices_param(graph, loop_vars, init_slices, input_slices,
                                 output_slices, old_output_slices):
  """Updates graph with new IndexedSlices input/output.

  Updates graph's metadata to output the gradient computation defined by
  init_slices, input_slices, and output_slices, instead of outputting
  old_output_slices. Also returns a new version of loop_vars with init_slices
  replacing the old input.

  Args:
    graph: _WhileBodyGradFuncGraph.
    loop_vars: the inputs to graph.
    init_slices: the new IndexedSlices to use as input to graph.
    input_slices: the new IndexedSlices in graph that should be fed by
      init_slices.
    output_slices: the new IndexedSlices in graph that should be the
      corresponding output to input_slices.
    old_output_slices: the IndexedSlices in graph that are currently
      being output.

  Returns:
    New loop_vars to pass to graph.
  """
  structured_idx = _get_tensor_index_in_iterable(graph.structured_outputs,
                                                 old_output_slices)
  # We assume that the component tensors of old_output_slices appear
  # sequentially in graph.outputs. We use the first of these tensors as the
  # reference index; since a while body's inputs and outputs are aligned, the
  # same index locates the input to replace in graph.inputs and loop_vars.
  flat_idx = _get_tensor_index_in_iterable(
      graph.outputs,
      func_graph.flatten(old_output_slices)[0])

  graph.structured_outputs[structured_idx] = output_slices
  graph.outputs = func_graph.flatten(
      graph.structured_outputs)

  graph.inputs = (graph.inputs[:flat_idx] + _flatten(input_slices) +
                  graph.inputs[flat_idx + 1:])

  return loop_vars[:flat_idx] + _flatten(init_slices) + loop_vars[flat_idx + 1:]


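# Pure-Python sketch (hypothetical names, helper unused) of the splice
# performed above: the single flat entry at flat_idx is replaced by the three
# IndexedSlices components in both graph.inputs and loop_vars.
def _example_input_splice_sketch():
  """Returns the spliced loop_vars list for flat_idx == 3."""
  loop_vars = ["counter", "max_iters", "total_iters", "x_grad_tensor"]
  slices_components = ["values", "indices", "dense_shape"]
  flat_idx = 3
  return (loop_vars[:flat_idx] + slices_components +
          loop_vars[flat_idx + 1:])

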
def _flatten(arg):
  return nest.flatten(arg, expand_composites=True)
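

# Note (illustrative): with expand_composites=True, flattening an
# IndexedSlices yields its component tensors in order:
# [values, indices, dense_shape].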