# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Methods for rewriting while_v2 grad functions with IndexedSlices output."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.util import nest


def rewrite_grad_indexed_slices(grads, body_grad_graph, loop_vars,
                                forward_inputs):
  """Handles special case of IndexedSlices returned from while gradient.

  Some gradient functions return IndexedSlices instead of a Tensor (e.g. the
  gradient of Gather ops). When this happens in the gradient of a while body,
  the resulting gradient body function will have mismatched inputs and outputs,
  since the input is a single Tensor, but the IndexedSlices gets unnested into
  three output Tensors.

  This function fixes this by rewriting the gradient body to have three inputs
  to match the three outputs, i.e., it effectively converts the input Tensor
  into an input IndexedSlices. It also returns new `loop_vars` to reflect the
  new inputs.

  Args:
    grads: the input gradient Tensors to the while gradient computation.
    body_grad_graph: _WhileBodyGradFuncGraph.
    loop_vars: list of Tensors. The inputs to body_grad_graph.
    forward_inputs: list of Tensors. The (flat) inputs to the forward-pass While
      op.

  Returns:
    The new loop_vars to pass to body_grad_graph.
  """
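  # Illustrative sketch (hypothetical names, for exposition only; not executed
  # by this module): if the forward body computes
  #
  #   y = tf.gather(captured_params, idx)
  #
  # the gradient w.r.t. `captured_params` is an ops.IndexedSlices, which
  # flattens into three tensors (values, indices, dense_shape) in the grad
  # body's outputs, while the grad body's corresponding input is the single
  # incoming gradient Tensor.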
  # Match up body_grad_graph.structured_outputs with the corresponding
  # forward_inputs.
  #
  # Note that we don't expect a gradient computation to have structured output
  # (e.g. no nested lists), so no need to flatten
  # body_grad_graph.structured_outputs. However, structured_outputs may still
  # contain composite tensors such as IndexedSlices, unlike
  # body_grad_graph.outputs, which contains flattened composite tensors.
  inputs_with_grads = [
      t for g, t in zip(grads, forward_inputs) if g is not None
  ]
  # Skip loop counter, maximum_iterations and total number of loop iterations.
  structured_outputs = body_grad_graph.structured_outputs[3:]

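  # The zip below relies on the remaining outputs lining up 1:1 with
  # inputs_with_grads, i.e. the grad body produces one output per forward
  # input that has a non-None incoming gradient.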
  for forward_input, output in zip(inputs_with_grads, structured_outputs):
    if not isinstance(output, ops.IndexedSlices):
      continue

    if forward_input.dtype == dtypes.resource:
      # TODO(skyewm): In theory we should use this for all captured inputs, not
      # just resource handles (which can only be captured). We can do this by
      # checking that forward_input is passed straight through to its output.
      loop_vars = _rewrite_input_as_indexed_slices(body_grad_graph, output,
                                                   forward_input, loop_vars)
    else:
      _rewrite_output_as_tensor(body_grad_graph, output)

  return loop_vars

def _get_tensor_index_in_iterable(iterable, t):
  """Returns index of first occurrence of `t`, raises ValueError if not found."""
  for i, elem in enumerate(iterable):
    if t is elem:
      return i
  raise ValueError(f"Element `{t!r}` is not found in iterable `{iterable!r}`.")

def _rewrite_output_as_tensor(body_grad_graph, grad_output_slices):
  """Rewrites grad_output_slices to be a Tensor output.

  Args:
    body_grad_graph: _WhileBodyGradFuncGraph.
    grad_output_slices: IndexedSlices output of body_grad_graph.
  """
  with body_grad_graph.as_default():
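    # Editorial note: converting the IndexedSlices to a Tensor materializes
    # the full dense gradient, which is why rewrite_grad_indexed_slices
    # prefers _rewrite_input_as_indexed_slices for captured resource inputs
    # (see that function's docstring).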
    new_output = ops.convert_to_tensor_v2(grad_output_slices)

  idx = _get_tensor_index_in_iterable(body_grad_graph.structured_outputs,
                                      grad_output_slices)
  body_grad_graph.structured_outputs[idx] = new_output
  body_grad_graph.outputs = func_graph.flatten(
      body_grad_graph.structured_outputs)

def _rewrite_input_as_indexed_slices(body_grad_graph, grad_output_slices,
                                     forward_input, loop_vars):
  """Rewrites grad_output_slices's corresponding input to be an IndexedSlices.

  This rewrite requires that forward_input was captured in the forward loop,
  i.e. is not a user-specified loop variable. This is important because the
  rewrite assumes that forward_input is passed through to its corresponding
  output unchanged. This assumption is used in
  _rewrite_grad_indexed_slices_output, which depends on the exact gradient
  structure produced by the input's fanout.

  This can yield a more efficient computation than using
  _rewrite_output_as_tensor, since it preserves the IndexedSlices structure
  instead of converting the IndexedSlices to a dense Tensor.

  Args:
    body_grad_graph: _WhileBodyGradFuncGraph.
    grad_output_slices: IndexedSlices output of body_grad_graph.
    forward_input: the corresponding Tensor input to the forward loop.
    loop_vars: list of Tensors. The inputs to body_grad_graph.

  Returns:
    The new loop_vars to pass to body_grad_graph.
  """
  # Create the initial IndexedSlices that will be the input to the grad While
  # op. This will start as zeros and accumulate the IndexedSlices grad output.
  # Note that because forward_input is captured and not a loop var, its incoming
  # gradient should always be zero.
  init_slices = _create_grad_indexed_slices_init(grad_output_slices,
                                                 forward_input)

  # Create a new version of grad_output_slices's gradient computation that uses
  # the new IndexedSlices input instead of the original Tensor input. We'll
  # return the new computation and leave the old computation as dead code.
  # TODO(skyewm): consider pruning body_grad_graph to remove the old
  # computation.
  with body_grad_graph.as_default():
    input_slices = ops.IndexedSlices(
        values=body_grad_graph.capture(init_slices.values, allowlisted=True),
        indices=body_grad_graph.capture(init_slices.indices, allowlisted=True),
        dense_shape=body_grad_graph.capture(
            init_slices.dense_shape, allowlisted=True))

    # Remove the captured tensors from the function inputs. We'll add them back
    # at the correct index in _update_indexed_slices_param.
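    # (Editorial note: `captures` maps each external tensor to the placeholder
    # created for it in this graph; popping the mapping and dropping the
    # placeholder from `inputs` keeps the graph consistent until the
    # components are re-inserted as explicit loop vars.)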
    for t in _flatten(init_slices):
      captured_t = body_grad_graph.captures.pop(t)
      body_grad_graph.inputs.remove(captured_t)

    new_output_slices = _rewrite_grad_indexed_slices_output(
        grad_output_slices, input_slices)

  # Update body_grad_graph's inputs and outputs to reflect the new
  # IndexedSlices computation.
  return _update_indexed_slices_param(body_grad_graph, loop_vars, init_slices,
                                      input_slices, new_output_slices,
                                      grad_output_slices)

def _create_grad_indexed_slices_init(grad_output_slices, forward_input):
  """Creates an IndexedSlices to pass as input to the while grad function.

  Args:
    grad_output_slices: IndexedSlices. The corresponding while grad function
      output.
    forward_input: Tensor. The corresponding input to the forward while op.

  Returns:
    Zeros IndexedSlices, created in the current Graph.
  """
  assert isinstance(grad_output_slices, ops.IndexedSlices)
  assert isinstance(forward_input, ops.Tensor)
  values_out = grad_output_slices.values
  indices_out = grad_output_slices.indices

  # Create the initial values tensor.
  if values_out.shape.is_fully_defined():
    values_shape = tensor_shape.TensorShape([0] +
                                            values_out.shape.as_list()[1:])
    values = array_ops.zeros(
        values_shape, dtype=values_out.dtype, name="values_init")
  else:
    if forward_input.dtype == dtypes.resource:
      forward_shape = gen_resource_variable_ops.variable_shape(forward_input)
    else:
      forward_shape = array_ops.shape(forward_input)
    values_shape = array_ops.concat([[0], forward_shape[1:]], 0)
    values = array_ops.zeros(
        values_shape, dtype=values_out.dtype, name="values_init")

  # Create the initial indices tensor.
  indices = constant_op.constant([], indices_out.dtype, name="indices_init")

  # Create the initial dense_shape tensor. We assume it has the same shape as
  # forward_input, since captured tensors don't change shape across loop
  # iterations.
  if forward_input.dtype == dtypes.resource:
    shape = gen_resource_variable_ops.variable_shape(
        forward_input, name="shape_init")
  else:
    shape = array_ops.shape(forward_input, name="shape_init")

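  # Editorial summary: the result is an "empty" IndexedSlices whose values
  # have shape [0] + forward_input.shape[1:], whose indices have shape [0],
  # and whose dense_shape matches forward_input.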
  return ops.IndexedSlices(values=values, indices=indices, dense_shape=shape)


def _rewrite_grad_indexed_slices_output(old_output_slices, new_input_slices):
  """Creates a new version of old_output_slices with new_input_slices as input.

  This method assumes that old_output_slices.{values,indices} are produced by
  concatenating the incoming gradient Tensor input with the IndexedSlices
  produced by the gradient computation of the while body. See
  backprop.aggregate_indexed_slices_gradients for where these concats are
  constructed. We build new concats that use new_input_slices instead of the
  original Tensor input.

  Args:
    old_output_slices: original IndexedSlices output of while gradient.
    new_input_slices: new IndexedSlices to use as input to while gradient.

  Returns:
    A new IndexedSlices to replace old_output_slices.
  """

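  # Expected structure, per the docstring above (editorial sketch):
  #
  #   old_output = Identity(ConcatV2(original_grad_input, body_slices..., axis))
  #
  # Replacing the first concat argument rewires the accumulation onto the new
  # IndexedSlices input; the asserts below check this structure.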
  def rewrite(old_output, new_input):
    assert old_output.type == "Identity"
    concat_op = old_output.inputs[0].op
    assert concat_op.type == "ConcatV2"
    # Don't include axis arg
    old_concat_args = concat_op.inputs[:-1]
    # We assume that the original gradient input was the first argument to the
    # concat op.
    # TODO(skyewm): do this in a more robust way.
    return array_ops.concat([new_input] + old_concat_args[1:], 0)

  values = rewrite(old_output_slices.values.op, new_input_slices.values)
  indices = rewrite(old_output_slices.indices.op, new_input_slices.indices)
  return ops.IndexedSlices(
      values=values, indices=indices, dense_shape=new_input_slices.dense_shape)


def _update_indexed_slices_param(graph, loop_vars, init_slices, input_slices,
                                 output_slices, old_output_slices):
  """Updates graph with new IndexedSlices input/output.

  Updates graph's metadata to output the gradient computation defined by
  init_slices, input_slices, and output_slices, instead of outputting
  old_output_slices. Also returns a new version of loop_vars with init_slices
  replacing the old input.

  Args:
    graph: _WhileBodyGradFuncGraph.
    loop_vars: the inputs to graph.
    init_slices: the new IndexedSlices to use as input to graph.
    input_slices: the new IndexedSlices in graph that should be fed by
      init_slices.
    output_slices: the new IndexedSlices in graph that should be the
      corresponding output to input_slices.
    old_output_slices: the IndexedSlices in graph that are currently being
      output.

  Returns:
    New loop_vars to pass to graph.
  """
  structured_idx = _get_tensor_index_in_iterable(graph.structured_outputs,
                                                 old_output_slices)
  # We assume that the component tensors of old_output_slices appear
  # sequentially in graph.outputs. We use the first of these tensors
  # as the reference index.
  flat_idx = _get_tensor_index_in_iterable(
      graph.outputs,
      func_graph.flatten(old_output_slices)[0])

  graph.structured_outputs[structured_idx] = output_slices
  graph.outputs = func_graph.flatten(graph.structured_outputs)

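  # Replace the single Tensor placeholder at flat_idx with the three flattened
  # components of input_slices (values, indices, dense_shape).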
  graph.inputs = (
      graph.inputs[:flat_idx] + _flatten(input_slices) +
      graph.inputs[flat_idx + 1:])

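  # Likewise replace the single old loop var with the flattened components of
  # init_slices, so the grad While op is fed the zeros IndexedSlices created in
  # _create_grad_indexed_slices_init.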
  return loop_vars[:flat_idx] + _flatten(init_slices) + loop_vars[flat_idx + 1:]


def _flatten(arg):
  """Flattens `arg` into a list of Tensors, expanding composite tensors."""
  return nest.flatten(arg, expand_composites=True)