# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Methods for rewriting while_v2 grad functions with IndexedSlices output."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.util import nest


def rewrite_grad_indexed_slices(grads, body_grad_graph, loop_vars,
                                forward_inputs):
  """Handles special case of IndexedSlices returned from while gradient.

  Some gradient functions return IndexedSlices instead of a Tensor (e.g. the
  gradient of Gather ops). When this happens in the gradient of a while body,
  the resulting gradient body function will have mismatched inputs and outputs,
  since the input is a single Tensor, but the IndexedSlices gets unnested into
  three output Tensors.

  This function fixes this by rewriting the gradient body to have three inputs
  to match the three outputs, i.e., it effectively converts the input Tensor
  into an input IndexedSlices. It also returns new `loop_vars` to reflect the
  new inputs.

  Args:
    grads: the input gradient Tensors to the while gradient computation.
    body_grad_graph: _WhileBodyGradFuncGraph.
    loop_vars: list of Tensors. The inputs to body_grad_graph.
    forward_inputs: list of Tensors. The (flat) inputs to the forward-pass
      While op.

  Returns:
    The new loop_vars to pass to body_grad_graph.
  """
  # Match up body_grad_graph.structured_outputs with the corresponding
  # forward_inputs.
  #
  # Note that we don't expect a gradient computation to have structured output
  # (e.g. no nested lists), so no need to flatten
  # body_grad_graph.structured_outputs. However, structured_outputs may still
  # contain composite tensors such as IndexedSlices, unlike
  # body_grad_graph.outputs, which contains flattened composite tensors.
  inputs_with_grads = [t for g, t in zip(grads, forward_inputs)
                       if g is not None]
  # Skip loop counter, maximum_iterations and total number of loop iterations.
  structured_outputs = body_grad_graph.structured_outputs[3:]

  for forward_input, output in zip(inputs_with_grads, structured_outputs):
    if not isinstance(output, ops.IndexedSlices):
      continue

    if forward_input.dtype == dtypes.resource:
      # TODO(skyewm): In theory we should use this for all captured inputs, not
      # just resource handles (which can only be captured). We can do this by
      # checking that forward_input is passed straight through to its output.
      loop_vars = _rewrite_input_as_indexed_slices(body_grad_graph, output,
                                                   forward_input, loop_vars)
    else:
      _rewrite_output_as_tensor(body_grad_graph, output)

  return loop_vars
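

# Illustrative only: a minimal sketch (not part of the rewriter itself) of the
# input/output mismatch described above. A composite IndexedSlices flattens
# into three component tensors, while the incoming gradient is a single dense
# Tensor, so the grad body ends up with more flat outputs than inputs. All
# shapes below are hypothetical.
def _example_indexed_slices_flattening():
  dense_grad = array_ops.zeros([4, 2])  # the single incoming gradient Tensor
  slices = ops.IndexedSlices(
      values=array_ops.zeros([0, 2]),
      indices=constant_op.constant([], dtypes.int32),
      dense_shape=constant_op.constant([4, 2]))
  # One flat input vs. three flat outputs: the mismatch repaired above.
  assert len(nest.flatten(dense_grad, expand_composites=True)) == 1
  assert len(nest.flatten(slices, expand_composites=True)) == 3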


def _get_tensor_index_in_iterable(iterable, t):
  """Returns index of first occurrence of `t`, raising ValueError if absent."""
  for i, elem in enumerate(iterable):
    if t is elem:
      return i
  raise ValueError("%s is not in iterable" % str(t))


def _rewrite_output_as_tensor(body_grad_graph, grad_output_slices):
  """Rewrites grad_output_slices to be a Tensor output.

  Args:
    body_grad_graph: _WhileBodyGradFuncGraph.
    grad_output_slices: IndexedSlices output of body_grad_graph.
  """
  with body_grad_graph.as_default():
    new_output = ops.convert_to_tensor_v2(grad_output_slices)

  idx = _get_tensor_index_in_iterable(body_grad_graph.structured_outputs,
                                      grad_output_slices)
  body_grad_graph.structured_outputs[idx] = new_output
  body_grad_graph.outputs = func_graph.flatten(
      body_grad_graph.structured_outputs)
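

# A hedged sketch of the conversion performed by _rewrite_output_as_tensor:
# densifying an IndexedSlices into a single Tensor. The values below are
# hypothetical; densification is correct but forfeits the sparse
# representation, which is why the input-side rewrite is preferred when it
# applies.
def _example_densify_indexed_slices():
  slices = ops.IndexedSlices(
      values=constant_op.constant([[1., 1.]]),
      indices=constant_op.constant([2]),
      dense_shape=constant_op.constant([4, 2]))
  # Materializes a [4, 2] Tensor with row 2 set to [1., 1.], zeros elsewhere.
  return ops.convert_to_tensor_v2(slices)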


def _rewrite_input_as_indexed_slices(body_grad_graph, grad_output_slices,
                                     forward_input, loop_vars):
  """Rewrites grad_output_slices's corresponding input to be an IndexedSlices.

  This rewrite requires that forward_input was captured in the forward loop,
  i.e. is not a user-specified loop variable. This is important because the
  rewrite assumes that forward_input is passed through to its corresponding
  output unchanged. This assumption is used in
  _rewrite_grad_indexed_slices_output, which depends on the exact gradient
  structure produced by the input's fanout.

  This can yield a more efficient computation than using
  _rewrite_output_as_tensor, since it preserves the IndexedSlices structure
  instead of converting the IndexedSlices to a dense Tensor.

  Args:
    body_grad_graph: _WhileBodyGradFuncGraph.
    grad_output_slices: IndexedSlices output of body_grad_graph.
    forward_input: the corresponding Tensor input to the forward loop.
    loop_vars: list of Tensors. The inputs to body_grad_graph.

  Returns:
    The new loop_vars to pass to body_grad_graph.
  """
  # Create the initial IndexedSlices that will be the input to the grad While
  # op. This will start as zeros, and accumulate the IndexedSlices grad output.
  # Note that because forward_input is captured and not a loop var, its
  # incoming gradient should always be zero.
  init_slices = _create_grad_indexed_slices_init(grad_output_slices,
                                                 forward_input)

  # Create a new version of grad_output_slices's gradient computation that uses
  # the new IndexedSlices input instead of the original Tensor input. We'll
  # return the new computation and leave the old computation as dead code.
  # TODO(skyewm): consider pruning body_grad_graph to remove the old
  # computation.
  with body_grad_graph.as_default():
    input_slices = ops.IndexedSlices(
        values=body_grad_graph.capture(init_slices.values, allowlisted=True),
        indices=body_grad_graph.capture(init_slices.indices, allowlisted=True),
        dense_shape=body_grad_graph.capture(
            init_slices.dense_shape, allowlisted=True))

    # Remove the captured tensors from the function inputs. We'll add them back
    # at the correct index in _update_indexed_slices_param.
    for t in _flatten(init_slices):
      captured_t = body_grad_graph.captures.pop(t)
      body_grad_graph.inputs.remove(captured_t)

    new_output_slices = _rewrite_grad_indexed_slices_output(grad_output_slices,
                                                            input_slices)

  # Update body_grad_graph's inputs and outputs to reflect the new
  # IndexedSlices computation.
  return _update_indexed_slices_param(
      body_grad_graph, loop_vars, init_slices, input_slices, new_output_slices,
      grad_output_slices)


def _create_grad_indexed_slices_init(grad_output_slices, forward_input):
  """Creates an IndexedSlices to pass as input to the while grad function.

  Args:
    grad_output_slices: IndexedSlices. The corresponding while grad function
      output.
    forward_input: Tensor. The corresponding input to the forward while op.

  Returns:
    Zeros IndexedSlices, created in the current Graph.
  """
  assert isinstance(grad_output_slices, ops.IndexedSlices)
  assert isinstance(forward_input, ops.Tensor)
  values_out = grad_output_slices.values
  indices_out = grad_output_slices.indices

  # Create the initial values tensor.
  if values_out.shape.is_fully_defined():
    values_shape = tensor_shape.TensorShape([0] +
                                            values_out.shape.as_list()[1:])
    values = array_ops.zeros(values_shape, dtype=values_out.dtype,
                             name="values_init")
  else:
    if forward_input.dtype == dtypes.resource:
      forward_shape = gen_resource_variable_ops.variable_shape(forward_input)
    else:
      forward_shape = array_ops.shape(forward_input)
    values_shape = array_ops.concat([[0], forward_shape[1:]], 0)
    values = array_ops.zeros(values_shape, dtype=values_out.dtype,
                             name="values_init")

  # Create the initial indices tensor.
  indices = constant_op.constant([], indices_out.dtype, name="indices_init")

  # Create the initial dense_shape tensor. We assume it is the same shape as
  # forward_input, since captured tensors don't change shape across loop
  # iterations.
  if forward_input.dtype == dtypes.resource:
    shape = gen_resource_variable_ops.variable_shape(forward_input,
                                                     name="shape_init")
  else:
    shape = array_ops.shape(forward_input, name="shape_init")

  return ops.IndexedSlices(values=values, indices=indices, dense_shape=shape)
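

# Illustrative only (hypothetical shapes): the zero-valued IndexedSlices that
# _create_grad_indexed_slices_init builds for a dense [4, 2] forward input:
# zero rows of values, an empty indices vector, and a dense_shape matching the
# forward input.
def _example_zeros_init():
  forward_input = array_ops.zeros([4, 2])
  values_shape = array_ops.concat(
      [[0], array_ops.shape(forward_input)[1:]], 0)  # [0, 2]
  values = array_ops.zeros(values_shape, dtype=dtypes.float32)
  indices = constant_op.constant([], dtypes.int32)  # nothing accumulated yet
  dense_shape = array_ops.shape(forward_input)  # [4, 2]
  return ops.IndexedSlices(values=values, indices=indices,
                           dense_shape=dense_shape)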


def _rewrite_grad_indexed_slices_output(old_output_slices, new_input_slices):
  """Creates a new version of old_output_slices with new_input_slices as input.

  This method assumes that old_output_slices.{values,indices} are produced by
  concatenating the incoming gradient Tensor input with the IndexedSlices
  produced by the gradient computation of the while body. See
  backprop.aggregate_indexed_slices_gradients for where these concats are
  constructed. We build new concats that use new_input_slices instead of the
  original Tensor input.

  Args:
    old_output_slices: original IndexedSlices output of while gradient.
    new_input_slices: new IndexedSlices to use as input to while gradient.

  Returns:
    A new IndexedSlices to replace old_output_slices.
  """

  def rewrite(old_output, new_input):
    assert old_output.type == "Identity"
    concat_op = old_output.inputs[0].op
    assert concat_op.type == "ConcatV2"
    # Don't include the axis arg.
    old_concat_args = concat_op.inputs[:-1]
    # We assume that the original gradient input was the first argument to the
    # concat op.
    # TODO(skyewm): do this in a more robust way.
    return array_ops.concat([new_input] + old_concat_args[1:], 0)

  values = rewrite(old_output_slices.values.op, new_input_slices.values)
  indices = rewrite(old_output_slices.indices.op, new_input_slices.indices)
  return ops.IndexedSlices(values=values, indices=indices,
                           dense_shape=new_input_slices.dense_shape)


def _update_indexed_slices_param(graph, loop_vars, init_slices, input_slices,
                                 output_slices, old_output_slices):
  """Updates graph with new IndexedSlices input/output.

  Updates graph's metadata to output the gradient computation defined by
  init_slices, input_slices, and output_slices, instead of outputting
  old_output_slices. Also returns a new version of loop_vars with init_slices
  replacing the old input.

  Args:
    graph: _WhileBodyGradFuncGraph.
    loop_vars: the inputs to graph.
    init_slices: the new IndexedSlices to use as input to graph.
    input_slices: the new IndexedSlices in graph that should be fed by
      init_slices.
    output_slices: the new IndexedSlices in graph that should be the
      corresponding output to input_slices.
    old_output_slices: the IndexedSlices in graph that are currently
      being output.

  Returns:
    New loop_vars to pass to graph.
  """
  structured_idx = _get_tensor_index_in_iterable(graph.structured_outputs,
                                                 old_output_slices)
  # We assume that the component tensors of old_output_slices appear
  # sequentially in graph.outputs. We use the first of these tensors
  # as the reference index.
  flat_idx = _get_tensor_index_in_iterable(
      graph.outputs,
      func_graph.flatten(old_output_slices)[0])

  graph.structured_outputs[structured_idx] = output_slices
  graph.outputs = func_graph.flatten(graph.structured_outputs)

  graph.inputs = (graph.inputs[:flat_idx] + _flatten(input_slices) +
                  graph.inputs[flat_idx + 1:])

  return loop_vars[:flat_idx] + _flatten(init_slices) + loop_vars[flat_idx + 1:]


def _flatten(arg):
  return nest.flatten(arg, expand_composites=True)
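

# Illustrative only (hypothetical values): the splice performed by
# _update_indexed_slices_param and mirrored in its return value. The three
# flattened components of the init IndexedSlices replace the single dense
# gradient slot in the flat loop_vars list, restoring the one-to-one pairing
# between inputs and flat outputs.
def _example_splice_loop_vars():
  loop_vars = [constant_op.constant(0),   # loop counter
               constant_op.constant(10),  # maximum_iterations
               constant_op.constant(0),   # total iterations
               array_ops.zeros([4, 2])]   # dense grad slot being replaced
  init_slices = ops.IndexedSlices(
      values=array_ops.zeros([0, 2]),
      indices=constant_op.constant([], dtypes.int32),
      dense_shape=constant_op.constant([4, 2]))
  flat_idx = 3  # position of the dense slot in the flat list
  new_loop_vars = (loop_vars[:flat_idx] + _flatten(init_slices) +
                   loop_vars[flat_idx + 1:])
  assert len(new_loop_vars) == 6  # 3 bookkeeping vars + 3 slice components
  return new_loop_vars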