# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Methods for rewriting while_v2 grad functions with IndexedSlices output."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.util import nest


def rewrite_grad_indexed_slices(grads, body_grad_graph, loop_vars,
                                forward_inputs):
  """Handles the special case of IndexedSlices returned from a while gradient.

  Some gradient functions return IndexedSlices instead of a Tensor (e.g. the
  gradient of Gather ops). When this happens in the gradient of a while body,
  the resulting gradient body function will have mismatched inputs and
  outputs, since the input is a single Tensor, but the IndexedSlices gets
  unnested into three output Tensors.

  This function fixes this by rewriting the gradient body to have three inputs
  to match the three outputs, i.e., it effectively converts the input Tensor
  into an input IndexedSlices. It also returns new `loop_vars` to reflect the
  new inputs.

  Args:
    grads: the input gradient Tensors to the while gradient computation.
    body_grad_graph: _WhileBodyGradFuncGraph.
    loop_vars: list of Tensors. The inputs to body_grad_graph.
    forward_inputs: list of Tensors. The (flat) inputs to the forward-pass
      While op.

  Returns:
    The new loop_vars to pass to body_grad_graph.
  """
  # Match up body_grad_graph.structured_outputs with the corresponding
  # forward_inputs.
  #
  # Note that we don't expect a gradient computation to have structured output
  # (e.g. no nested lists), so no need to flatten
  # body_grad_graph.structured_outputs. However, structured_outputs may still
  # contain composite tensors such as IndexedSlices, unlike
  # body_grad_graph.outputs, which contains flattened composite tensors.
  inputs_with_grads = [
      t for g, t in zip(grads, forward_inputs) if g is not None
  ]
  # Skip the loop counter, maximum_iterations, and total number of loop
  # iterations.
  structured_outputs = body_grad_graph.structured_outputs[3:]

  for forward_input, output in zip(inputs_with_grads, structured_outputs):
    if not isinstance(output, ops.IndexedSlices):
      continue

    if forward_input.dtype == dtypes.resource:
      # TODO(skyewm): In theory we should use this for all captured inputs,
      # not just resource handles (which can only be captured). We can do this
      # by checking that forward_input is passed straight through to its
      # output.
      loop_vars = _rewrite_input_as_indexed_slices(body_grad_graph, output,
                                                   forward_input, loop_vars)
    else:
      _rewrite_output_as_tensor(body_grad_graph, output)

  return loop_vars
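

# Editorial sketch (not used by the rewrite): a minimal, hedged illustration
# of the input/output mismatch described in rewrite_grad_indexed_slices. A
# single IndexedSlices flattens into three component Tensors, while the
# matching gradient input starts out as one Tensor. The helper name below is
# illustrative only.
def _example_indexed_slices_flattening():
  """Shows that one IndexedSlices flattens into three component Tensors."""
  slices = ops.IndexedSlices(
      values=array_ops.ones([2, 3]),
      indices=constant_op.constant([0, 4]),
      dense_shape=constant_op.constant([5, 3]))
  flat = _flatten(slices)  # nest.flatten with expand_composites=True.
  assert len(flat) == 3  # values, indices, dense_shape
  return flat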


def _get_tensor_index_in_iterable(iterable, t):
  """Returns index of first occurrence of `t`, raises ValueError if not found."""
  for i, elem in enumerate(iterable):
    if t is elem:
      return i
  raise ValueError(f"Element `{t!r}` is not found in iterable `{iterable!r}`.")


def _rewrite_output_as_tensor(body_grad_graph, grad_output_slices):
  """Rewrites grad_output_slices to be a Tensor output.

  Args:
    body_grad_graph: _WhileBodyGradFuncGraph.
    grad_output_slices: IndexedSlices output of body_grad_graph.
  """
  with body_grad_graph.as_default():
    new_output = ops.convert_to_tensor_v2(grad_output_slices)

  idx = _get_tensor_index_in_iterable(body_grad_graph.structured_outputs,
                                      grad_output_slices)
  body_grad_graph.structured_outputs[idx] = new_output
  body_grad_graph.outputs = func_graph.flatten(
      body_grad_graph.structured_outputs)


def _rewrite_input_as_indexed_slices(body_grad_graph, grad_output_slices,
                                     forward_input, loop_vars):
  """Rewrites grad_output_slices's corresponding input to be an IndexedSlices.

  This rewrite requires that forward_input was captured in the forward loop,
  i.e. is not a user-specified loop variable. This is important because the
  rewrite assumes that forward_input is passed through to its corresponding
  output unchanged. This assumption is used in
  _rewrite_grad_indexed_slices_output, which depends on the exact gradient
  structure produced by the input's fanout.

  This can yield a more efficient computation than using
  _rewrite_output_as_tensor, since it preserves the IndexedSlices structure
  instead of converting the IndexedSlices to a dense Tensor.

  Args:
    body_grad_graph: _WhileBodyGradFuncGraph.
    grad_output_slices: IndexedSlices output of body_grad_graph.
    forward_input: the corresponding Tensor input to the forward loop.
    loop_vars: list of Tensors. The inputs to body_grad_graph.

  Returns:
    The new loop_vars to pass to body_grad_graph.
  """
  # Create an initial IndexedSlices that will be the input to the grad While
  # op. This will start as zeros, and accumulate the IndexedSlices grad
  # output. Note that because forward_input is captured and not a loop var,
  # its incoming gradient should always be zero.
  init_slices = _create_grad_indexed_slices_init(grad_output_slices,
                                                 forward_input)

  # Create a new version of grad_output_slices's gradient computation that
  # uses the new IndexedSlices input instead of the original Tensor input.
  # We'll return the new computation and leave the old computation as dead
  # code.
  # TODO(skyewm): consider pruning body_grad_graph to remove the old
  # computation.
  with body_grad_graph.as_default():
    input_slices = ops.IndexedSlices(
        values=body_grad_graph.capture(init_slices.values, allowlisted=True),
        indices=body_grad_graph.capture(init_slices.indices,
                                        allowlisted=True),
        dense_shape=body_grad_graph.capture(init_slices.dense_shape,
                                            allowlisted=True))

    # Remove the captured tensors from the function inputs. We'll add them
    # back at the correct index in _update_indexed_slices_param.
    for t in _flatten(init_slices):
      captured_t = body_grad_graph.captures.pop(t)
      body_grad_graph.inputs.remove(captured_t)

    new_output_slices = _rewrite_grad_indexed_slices_output(
        grad_output_slices, input_slices)

  # Update body_grad_graph's inputs and outputs to reflect the new
  # IndexedSlices computation.
  return _update_indexed_slices_param(body_grad_graph, loop_vars, init_slices,
                                      input_slices, new_output_slices,
                                      grad_output_slices)
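

# Editorial sketch (not used by the rewrite): what the dense fallback in
# _rewrite_output_as_tensor relies on. Converting an IndexedSlices to a
# Tensor sums its values into a dense tensor of shape dense_shape; rows not
# listed in indices come out as zeros. The helper name is illustrative only.
def _example_densify_indexed_slices():
  """Shows IndexedSlices -> dense Tensor conversion via convert_to_tensor_v2."""
  slices = ops.IndexedSlices(
      values=array_ops.ones([2, 3]),
      indices=constant_op.constant([0, 2]),
      dense_shape=constant_op.constant([4, 3]))
  dense = ops.convert_to_tensor_v2(slices)
  # dense has shape [4, 3]; rows 0 and 2 are ones, rows 1 and 3 are zeros.
  return dense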


def _create_grad_indexed_slices_init(grad_output_slices, forward_input):
  """Creates an IndexedSlices to pass as input to the while grad function.

  Args:
    grad_output_slices: IndexedSlices. The corresponding while grad function
      output.
    forward_input: Tensor. The corresponding input to the forward while op.

  Returns:
    Zeros IndexedSlices, created in the current Graph.
  """
  assert isinstance(grad_output_slices, ops.IndexedSlices)
  assert isinstance(forward_input, ops.Tensor)
  values_out = grad_output_slices.values
  indices_out = grad_output_slices.indices

  # Create the initial values tensor.
  if values_out.shape.is_fully_defined():
    values_shape = tensor_shape.TensorShape([0] +
                                            values_out.shape.as_list()[1:])
    values = array_ops.zeros(
        values_shape, dtype=values_out.dtype, name="values_init")
  else:
    if forward_input.dtype == dtypes.resource:
      forward_shape = gen_resource_variable_ops.variable_shape(forward_input)
    else:
      forward_shape = array_ops.shape(forward_input)
    values_shape = array_ops.concat([[0], forward_shape[1:]], 0)
    values = array_ops.zeros(
        values_shape, dtype=values_out.dtype, name="values_init")

  # Create the initial indices tensor.
  indices = constant_op.constant([], indices_out.dtype, name="indices_init")

  # Create the initial dense_shape tensor. We assume it is the same shape as
  # forward_input, since captured tensors don't change shape across loop
  # iterations.
  if forward_input.dtype == dtypes.resource:
    shape = gen_resource_variable_ops.variable_shape(
        forward_input, name="shape_init")
  else:
    shape = array_ops.shape(forward_input, name="shape_init")

  return ops.IndexedSlices(values=values, indices=indices, dense_shape=shape)


def _rewrite_grad_indexed_slices_output(old_output_slices, new_input_slices):
  """Creates a new version of old_output_slices with new_input_slices as input.

  This method assumes that old_output_slices.{values,indices} are produced by
  concatenating the incoming gradient Tensor input with the IndexedSlices
  produced by the gradient computation of the while body. See
  backprop.aggregate_indexed_slices_gradients for where these concats are
  constructed. We build new concats that use new_input_slices instead of the
  original Tensor input.

  Args:
    old_output_slices: original IndexedSlices output of while gradient.
    new_input_slices: new IndexedSlices to use as input to while gradient.

  Returns:
    A new IndexedSlices to replace old_output_slices.
  """

  def rewrite(old_output, new_input):
    assert old_output.type == "Identity"
    concat_op = old_output.inputs[0].op
    assert concat_op.type == "ConcatV2"
    # Don't include the axis arg.
    old_concat_args = concat_op.inputs[:-1]
    # We assume that the original gradient input was the first argument to
    # the concat op.
    # TODO(skyewm): do this in a more robust way.
    return array_ops.concat([new_input] + old_concat_args[1:], 0)

  values = rewrite(old_output_slices.values.op, new_input_slices.values)
  indices = rewrite(old_output_slices.indices.op, new_input_slices.indices)
  return ops.IndexedSlices(
      values=values, indices=indices,
      dense_shape=new_input_slices.dense_shape)
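

# Editorial sketch (not used by the rewrite): the aggregation structure that
# rewrite() above pattern-matches. Two IndexedSlices gradients can be added
# by concatenating their values and indices along axis 0 (this is the
# ConcatV2 structure backprop.aggregate_indexed_slices_gradients builds);
# duplicate indices are summed when the result is densified. The helper name
# is illustrative only.
def _example_aggregate_indexed_slices_by_concat():
  """Adds two IndexedSlices by concatenation, mirroring the ConcatV2 pattern."""
  a = ops.IndexedSlices(
      values=array_ops.ones([1, 2]),
      indices=constant_op.constant([0]),
      dense_shape=constant_op.constant([3, 2]))
  b = ops.IndexedSlices(
      values=array_ops.ones([1, 2]),
      indices=constant_op.constant([0]),
      dense_shape=constant_op.constant([3, 2]))
  summed = ops.IndexedSlices(
      values=array_ops.concat([a.values, b.values], 0),
      indices=array_ops.concat([a.indices, b.indices], 0),
      dense_shape=a.dense_shape)
  # Densifying sums the duplicate index 0: row 0 becomes [2., 2.].
  return ops.convert_to_tensor_v2(summed)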


def _update_indexed_slices_param(graph, loop_vars, init_slices, input_slices,
                                 output_slices, old_output_slices):
  """Updates graph with new IndexedSlices input/output.

  Updates graph's metadata to output the gradient computation defined by
  init_slices, input_slices, and output_slices, instead of outputting
  old_output_slices. Also returns a new version of loop_vars with init_slices
  replacing the old input.

  Args:
    graph: _WhileBodyGradFuncGraph.
    loop_vars: the inputs to graph.
    init_slices: the new IndexedSlices to use as input to graph.
    input_slices: the new IndexedSlices in graph that should be fed by
      init_slices.
    output_slices: the new IndexedSlices in graph that should be the
      corresponding output to input_slices.
    old_output_slices: the IndexedSlices in graph that are currently being
      output.

  Returns:
    New loop_vars to pass to graph.
  """
  structured_idx = _get_tensor_index_in_iterable(graph.structured_outputs,
                                                 old_output_slices)
  # We assume that the component tensors of old_output_slices appear
  # sequentially in graph.outputs. We use the first of these tensors as the
  # reference index.
  flat_idx = _get_tensor_index_in_iterable(
      graph.outputs,
      func_graph.flatten(old_output_slices)[0])

  graph.structured_outputs[structured_idx] = output_slices
  graph.outputs = func_graph.flatten(graph.structured_outputs)

  graph.inputs = (
      graph.inputs[:flat_idx] + _flatten(input_slices) +
      graph.inputs[flat_idx + 1:])

  return (loop_vars[:flat_idx] + _flatten(init_slices) +
          loop_vars[flat_idx + 1:])


def _flatten(arg):
  return nest.flatten(arg, expand_composites=True)
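

# Editorial, end-to-end sketch of a computation that can reach this rewrite.
# It is kept as a comment because this internal module cannot import the
# public `tensorflow` API, and the exact gradient types observed depend on
# the TF version, so treat it as an illustration rather than a guarantee.
# Taking the gradient of tf.gather on a variable captured by a while_v2 body
# (rather than passed as a loop variable) produces an IndexedSlices gradient
# inside the gradient loop body, which is the resource-handle case handled
# above:
#
#   import tensorflow as tf
#
#   params = tf.Variable(tf.ones([10, 3]))
#
#   @tf.function  # tf.while_loop uses while_v2 under TF2 control flow.
#   def grad_through_loop():
#     with tf.GradientTape() as tape:
#       # `params` is captured by the body, not passed as a loop variable.
#       _, acc = tf.while_loop(
#           lambda i, acc: i < 5,
#           lambda i, acc: (i + 1,
#                           acc + tf.reduce_sum(tf.gather(params, [0, 7]))),
#           [tf.constant(0), tf.constant(0.0)])
#     return tape.gradient(acc, params)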