# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradients for operators defined in array_ops.py."""

from tensorflow.compiler.tf2xla.ops import gen_xla_ops
from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices as indexed_slices_lib
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops


@ops.RegisterGradient("Pack")
def _PackGrad(op, grad):
  """Gradient for pack op.

  Unstacks `grad` along the op's `axis` attribute into `N` pieces, one per
  packed input.
  """
  return array_ops.unstack(grad, num=op.get_attr("N"), axis=op.get_attr("axis"))


@ops.RegisterGradient("Unpack")
def _UnpackGrad(op, *grads):
  """Gradient for unpack op.

  Stacks the per-output gradients back together along the op's `axis`
  attribute, producing the gradient of the single unpacked input.
  """
  return array_ops.stack(grads, axis=op.get_attr("axis"))


def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index):
  """Gradient for concat op.

  Args:
    op: An operation.
    grad: `Tensor` or `IndexedSlices` representing the gradient with respect to
      the (single) output of the op.
    start_value_index: An integer index of the first value in the op.inputs.
    end_value_index: An integer index one past the last value in the
      op.inputs, i.e. the concatenated values are
      `op.inputs[start_value_index:end_value_index]`. May be negative.
    dim_index: An integer index of concat_dim or axis parameter in op.inputs.

  Returns:
    A list with one entry per op input: the partial gradient for each
    concatenated value, and `None` for the concat_dim/axis input.

  Raises:
    ValueError: if `grad` is an `IndexedSlices` and concat_dim/axis is not
      statically known, or is negative while the rank of the first value is
      not statically known.
    TypeError: if `grad` is neither a `Tensor` nor an `IndexedSlices`.
  """

  def _CreateDenseMaskAndBegin(sizes, concat_dim):
    """Create tensors for iteratively slicing a dense gradients tensor.

    Returns:
      mask: int32 vector with a 1 at position `concat_dim` and 0 elsewhere.
      begin: int32 zero vector of the same length, used as the initial slice
        offset.
    """
    # Since shape is 1-D, shape_of_shape = [rank-of-inputs]
    shape_of_shape = array_ops.shape(sizes[0])
    # Make a vector of length equal to the input's dimensions,
    # with 0's everywhere and 1 in the concat dim position.
    # Note: Can't use sparse_to_dense since it isn't GPU-capable (for now)
    mask = array_ops.concat([
        array_ops.zeros(
            array_ops.expand_dims(concat_dim, 0), dtype=dtypes.int32), [1],
        array_ops.zeros(shape_of_shape - concat_dim - 1, dtype=dtypes.int32)
    ], 0)
    begin = array_ops.zeros(shape_of_shape, dtype=dtypes.int32)
    return mask, begin

  def _ExtractInputShapes(inputs):
    """Extract the shapes of a set of input tensors."""
    if context.executing_eagerly():
      return array_ops.shape_n(inputs)
    # In graph mode, keep the individual `shape` tensors only when every one
    # of them is a Const; otherwise emit a single ShapeN for all inputs.
    sizes = []
    fully_known = True
    for x in inputs:
      input_shape = array_ops.shape(x)
      if not isinstance(input_shape,
                        ops.Tensor) or input_shape.op.type != "Const":
        fully_known = False
        break
      sizes.append(input_shape)

    if fully_known:
      return sizes
    else:
      return array_ops.shape_n(inputs)

  # Degenerate concatenation of a single value: its gradient is `grad` itself.
  # `grad` is one Tensor/IndexedSlices (concat has a single output), so it has
  # to be wrapped in a list; `grad + [None]` would attempt tensor arithmetic.
  if len(op.inputs) == 2:
    return [grad, None] if end_value_index <= dim_index else [None, grad]

  concat_dim = op.inputs[dim_index]
  input_values = op.inputs[start_value_index:end_value_index]

  out_grads = []
  if isinstance(grad, ops.Tensor):
    if context.executing_eagerly() or isinstance(concat_dim, ops.EagerTensor):
      # Using mod here for convenience since concat_dim is already verified
      # in concat implementation to be within the allowed [-rank, rank) range.
      non_neg_concat_dim = (
          concat_dim._numpy().item(0) % input_values[0]._rank())  # pylint: disable=protected-access
      # All inputs are guaranteed to be EagerTensors in eager mode
      sizes = pywrap_tfe.TFE_Py_TensorShapeSlice(input_values,
                                                 non_neg_concat_dim)
      out_grads = array_ops.split(grad, sizes, non_neg_concat_dim)
    else:
      if constant_op.is_constant(concat_dim):
        # If concat_dim is a constant defined in a different context,
        # then we duplicate it in the current context to avoid passing it
        # through an Enter node.
        # This is a small optimization in general, but it is required when
        # compiling with XLA, as XLA needs the concat input to be folded into a
        # constant.
        grad_context = control_flow_util.GetOutputContext(grad.op)
        dim_context = control_flow_util.GetOutputContext(concat_dim.op)
        if dim_context != grad_context:
          value = tensor_util.constant_value(concat_dim)
          concat_dim = constant_op.constant(value=value, dtype=concat_dim.dtype)

      # Using mod here for convenience since concat_dim is already verified
      # in concat implementation to be within the allowed [-rank, rank) range.
      non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0])

      # Get the inputs' tensor shapes
      sizes = _ExtractInputShapes(input_values)
      # The magic number of 16 was found through benchmarking a range of sizes
      # on CPUs and a Maxwell TitanX.  A speedup was seen in a large majority of
      # cases when switching implementations at N=16, but it is possible that
      # there will be a small number of performance regressions.
      if len(sizes) > 16:
        # extract the size of each input along the concat dimension
        sizes = array_ops.squeeze(
            array_ops.slice(
                array_ops.stack(sizes, axis=1), [non_neg_concat_dim, 0],
                [1, -1]))
        out_grads = array_ops.split(grad, sizes, non_neg_concat_dim)
      else:
        # ConcatOffset gives each input's start offset inside the output.
        offset = gen_array_ops.concat_offset(non_neg_concat_dim, sizes)
        for (begin, size) in zip(offset, sizes):
          out_grads.append(array_ops.slice(grad, begin, size))
  elif isinstance(grad, indexed_slices_lib.IndexedSlices):
    # Using mod here for convenience since concat_dim is already verified
    # in concat implementation to be within the allowed [-rank, rank) range.
    non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0])
    concat_dim_static = tensor_util.constant_value(concat_dim)
    if concat_dim_static is None:
      raise ValueError("Can only compute IndexedSlices gradient with "
                       "statically-known concat_dim")
    if concat_dim_static < 0:
      rank = tensor_util.constant_value(array_ops.rank(input_values[0]))
      if rank is None:
        raise ValueError("Can only compute IndexedSlices gradient with "
                         "negative concat_dim when first value rank is "
                         "statically-known.")
      concat_dim_static %= rank
    # Get the inputs' tensor shapes
    sizes = [array_ops.shape(x) for x in input_values]
    if concat_dim_static > 0:
      # IndexedSlices, non_neg_concat_dim > 0. Each input gets IndexedSlices
      # gradients with all the indices, but with grad.values sliced accordingly.
      # This is like the Tensor case, except shape(grad.values)[0] is not equal
      # to shape(sizes[i])[0], since only a subset of the dim-0 values are
      # stored.
      mask, begin = _CreateDenseMaskAndBegin(sizes, non_neg_concat_dim)
      for size in sizes:
        new_values = array_ops.slice(
            grad.values, begin,
            array_ops.concat([[-1], array_ops.slice(size, [1], [-1])], 0))
        out_grads.append(
            indexed_slices_lib.IndexedSlices(new_values, grad.indices, size))
        # Lint complains begin = begin + ...
        begin = math_ops.add(begin, size * mask)
    else:
      # IndexedSlices, concat_dim == 0. Each input gets IndexedSlices gradients
      # only for the relevant indices.
      start = constant_op.constant(0, dtype=grad.indices.dtype)
      for size in sizes:
        size_concat_dim = array_ops.gather(size, non_neg_concat_dim)
        if size_concat_dim.dtype != grad.indices.dtype:
          size_concat_dim = math_ops.cast(
              size_concat_dim, dtype=grad.indices.dtype)
        end = start + size_concat_dim
        # Compute the 1-D Tensor of indices relevant for this input.
        indices_to_select = array_ops.squeeze(
            array_ops.where(
                math_ops.logical_and(grad.indices >= start,
                                     grad.indices < end)),
            axis=[1])
        new_indices = array_ops.gather(grad.indices, indices_to_select) - start
        new_values = array_ops.gather(grad.values, indices_to_select)
        out_grads.append(
            indexed_slices_lib.IndexedSlices(new_values, new_indices, size))
        start = end
  else:
    raise TypeError("Expected Tensor or IndexedSlices, got %s" % type(grad))

  # The concat_dim/axis input gets no gradient; place its `None` on the side
  # where that input sits.
  return (out_grads + [None] if end_value_index <= dim_index else [None] +
          out_grads)


@ops.RegisterGradient("Concat")
def _ConcatGrad(op, grad):
  """Gradient for Concat, whose inputs are `(concat_dim, *values)`."""
  return _ConcatGradHelper(
      op,
      grad,
      start_value_index=1,
      end_value_index=len(op.inputs),
      dim_index=0)


@ops.RegisterGradient("ConcatV2")
def _ConcatGradV2(op, grad):
  """Gradient for ConcatV2, whose inputs are `(*values, axis)`."""
  return _ConcatGradHelper(
      op, grad, start_value_index=0, end_value_index=-1, dim_index=-1)


ops.NotDifferentiable("ConcatOffset")


@ops.RegisterGradient("Slice")
def _SliceGrad(op, grad):
  """Gradient for Slice op.

  Returns a tensor shaped like the sliced input that holds `grad` at the
  sliced position and zeros elsewhere, plus `None` for the other two inputs.
  """
  # Create an Nx2 padding where the first column represents how many
  # zeros are to be prepended for each dimension, and the second
  # column indicates how many zeros are appended.
  #
  # The number of zeros to append is the shape of the input
  # elementwise-subtracted by both the begin vector and sizes vector.
  #
  # Some more reshaping is needed to assemble this tensor with the
  # right dimensions.
  input_vec = op.inputs[0]
  begin_vec = op.inputs[1]
  input_rank = array_ops.rank(input_vec)
  slice_size = array_ops.shape(op.outputs[0])
  # Inside an XLA context, write `grad` into a zeros tensor at `begin` with a
  # dynamic-update-slice instead of building a padding spec.
  if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
    return gen_xla_ops.xla_dynamic_update_slice(array_ops.zeros_like(input_vec),
                                                grad, begin_vec), None, None

  shape = array_ops.stack([input_rank, 1])
  before_pad = array_ops.reshape(begin_vec, shape)
  after_pad = array_ops.reshape(
      array_ops.shape(input_vec) - slice_size - begin_vec, shape)
  paddings = array_ops.concat([before_pad, after_pad], 1)
  return array_ops.pad(grad, paddings), None, None


@ops.RegisterGradient("StridedSlice")
def _StridedSliceGrad(op, grad):
  """Gradient for StridedSlice op.

  Only the sliced input gets a gradient; begin/end/strides get `None`.
  """
  begin = op.inputs[1]
  end = op.inputs[2]
  strides = op.inputs[3]
  # StridedSliceGrad requires `x`, `begin`, `end` and `strides` to be of the
  # same dtype so we build a shape of the same type as other args.
  # Note that the choice of `begin` for specifying `out_type` is arbitrary.
  # We could choose any of {begin|end|strides}.dtype since they are required to
  # be the same.
  x = array_ops.shape(op.inputs[0], out_type=begin.dtype)

  # Substitute statically-known values where available; otherwise keep the
  # symbolic tensors.
  x_static = tensor_util.constant_value(x)
  x = x_static if x_static is not None else x
  begin_static = tensor_util.constant_value(begin)
  begin = begin_static if begin_static is not None else begin
  end_static = tensor_util.constant_value(end)
  end = end_static if end_static is not None else end
  strides_static = tensor_util.constant_value(strides)
  strides = strides_static if strides_static is not None else strides

  return array_ops.strided_slice_grad(
      x,
      begin,
      end,
      strides,
      grad,
      begin_mask=op.get_attr("begin_mask"),
      end_mask=op.get_attr("end_mask"),
      ellipsis_mask=op.get_attr("ellipsis_mask"),
      new_axis_mask=op.get_attr("new_axis_mask"),
      shrink_axis_mask=op.get_attr("shrink_axis_mask")), None, None, None


@ops.RegisterGradient("StridedSliceGrad")
def _StridedSliceGradGrad(op, grad):
  """Gradient for StridedSliceGrad op.

  Only the last of the five inputs receives a gradient: `grad` strided-sliced
  with the op's begin/end/strides inputs and mask attributes.
  """
  begin = op.inputs[1]
  end = op.inputs[2]
  strides = op.inputs[3]

  return None, None, None, None, array_ops.strided_slice(
      grad,
      begin,
      end,
      strides,
      begin_mask=op.get_attr("begin_mask"),
      end_mask=op.get_attr("end_mask"),
      ellipsis_mask=op.get_attr("ellipsis_mask"),
      new_axis_mask=op.get_attr("new_axis_mask"),
      shrink_axis_mask=op.get_attr("shrink_axis_mask"))


@ops.RegisterGradient("TensorStridedSliceUpdate")
def _TensorStridedSliceUpdateGrad(op, grad):
  """Gradient for TensorStridedSliceUpdate op.

  Returns:
    A 5-tuple: the gradient for the updated tensor (`grad` with the updated
    slice zeroed out), `None` for begin/end/strides, and the gradient for the
    value input (the strided slice of `grad`, sum-reduced over dimensions the
    value was broadcast along and reshaped to the value's shape).
  """
  begin = op.inputs[1]
  end = op.inputs[2]
  strides = op.inputs[3]
  begin_mask = op.get_attr("begin_mask")
  end_mask = op.get_attr("end_mask")
  ellipsis_mask = op.get_attr("ellipsis_mask")
  new_axis_mask = op.get_attr("new_axis_mask")
  shrink_axis_mask = op.get_attr("shrink_axis_mask")

  def Apply(f, *args):
    # Call `f` with this op's mask attributes.
    return f(*args,
             begin_mask=begin_mask,
             end_mask=end_mask,
             shrink_axis_mask=shrink_axis_mask,
             new_axis_mask=new_axis_mask,
             ellipsis_mask=ellipsis_mask)

  dy = Apply(array_ops.strided_slice,
             grad, begin, end, strides)
  dx = Apply(array_ops.tensor_strided_slice_update,
             grad, begin, end, strides, array_ops.zeros_like(dy))

  # The value is potentially broadcast to the shape of the strided slice, so we
  # may need to adjust dy.
  slice_shape = array_ops.shape(dy, out_type=begin.dtype)
  value_shape = array_ops.shape(op.inputs[4], out_type=slice_shape.dtype)

  _, reduction_axes = gen_array_ops.broadcast_gradient_args(
      slice_shape, value_shape)
  dy_reshaped = math_ops.reduce_sum(dy, axis=reduction_axes, keepdims=True)
  dy = array_ops.reshape(dy_reshaped, value_shape)

  return dx, None, None, None, dy


@ops.RegisterGradient("Split")
def _SplitGrad(op, *grads):
  """Gradient for Split: concatenate the output gradients along inputs[0]."""
  return None, array_ops.concat(list(grads), op.inputs[0])


@ops.RegisterGradient("SplitV")
def _SplitVGrad(op, *grads):
  """Gradient for SplitV op.

  Concatenates the output gradients along `op.inputs[2]`; only the first input
  gets a gradient and every other input gets `None`.
  """
  returnval = array_ops.concat(list(grads), op.inputs[2])
  return [returnval] + [None] * (len(op.inputs) - 1)


ops.NotDifferentiable("Const")


@ops.RegisterGradient("Diag")
def _DiagGrad(_, grad):
  """Gradient for Diag: the diagonal part of `grad`."""
  return array_ops.diag_part(grad)


@ops.RegisterGradient("DiagPart")
def _DiagPartGrad(_, grad):
  """Gradient for DiagPart: a diagonal tensor built from `grad`."""
  return array_ops.diag(grad)


@ops.RegisterGradient("MatrixDiag")
def _MatrixDiagGrad(_, grad):
  """Gradient for MatrixDiag: the matrix diagonal part of `grad`."""
  return array_ops.matrix_diag_part(grad)


@ops.RegisterGradient("MatrixDiagV2")
def _MatrixDiagV2Grad(op, grad):
  """Gradient for MatrixDiagV2: diagonals `k` (inputs[1]) of `grad`."""
  return array_ops.matrix_diag_part(
      grad, k=op.inputs[1]), None, None, None, None


@ops.RegisterGradient("MatrixDiagV3")
def _MatrixDiagV3Grad(op, grad):
  """Gradient for MatrixDiagV3: diagonals `k` of `grad`, same `align`."""
  return array_ops.matrix_diag_part(
      grad, k=op.inputs[1], align=op.get_attr("align")), None, None, None, None


@ops.RegisterGradient("MatrixDiagPart")
def _MatrixDiagPartGrad(op, grad):
  """Gradient for MatrixDiagPart.

  If the innermost matrices are statically square, `matrix_diag(grad)` already
  has the input's shape; otherwise `grad` is written onto the diagonal of a
  zeros tensor shaped like the input.
  """
  matrix_shape = op.inputs[0].get_shape()[-2:]
  if matrix_shape.is_fully_defined() and matrix_shape[0] == matrix_shape[1]:
    return array_ops.matrix_diag(grad)
  else:
    return array_ops.matrix_set_diag(array_ops.zeros_like(op.inputs[0]), grad)


@ops.RegisterGradient("MatrixDiagPartV2")
def _MatrixDiagPartV2Grad(op, grad):
  """Gradient for MatrixDiagPartV2.

  Uses the static matrix dimensions when known; otherwise writes `grad` into
  the `k` diagonals of a zeros tensor shaped like the input.
  """
  matrix_shape = op.inputs[0].get_shape()[-2:]
  if matrix_shape.is_fully_defined():
    return array_ops.matrix_diag(
        grad,
        k=op.inputs[1],
        num_rows=matrix_shape[0],
        num_cols=matrix_shape[1]), None, None
  else:
    return array_ops.matrix_set_diag(
        array_ops.zeros_like(op.inputs[0]), grad, k=op.inputs[1]), None, None


@ops.RegisterGradient("MatrixDiagPartV3")
def _MatrixDiagPartV3Grad(op, grad):
  """Gradient for MatrixDiagPartV3 (as V2, forwarding the `align` attr)."""
  matrix_shape = op.inputs[0].get_shape()[-2:]
  align = op.get_attr("align")
  if matrix_shape.is_fully_defined():
    return array_ops.matrix_diag(
        grad,
        k=op.inputs[1],
        num_rows=matrix_shape[0],
        num_cols=matrix_shape[1],
        align=align), None, None
  else:
    return array_ops.matrix_set_diag(
        array_ops.zeros_like(op.inputs[0]), grad, k=op.inputs[1],
        align=align), None, None


@ops.RegisterGradient("MatrixSetDiag")
def _MatrixSetDiagGrad(op, grad):
  """Gradient for MatrixSetDiag.

  The input's gradient is `grad` with its diagonal replaced by zeros; the
  diagonal's gradient is `matrix_diag_part(grad)`.
  """
  # Shape of the zero diagonal: batch dims + [min(rows, cols)]; computed
  # statically when fully known, otherwise from the runtime shape of `grad`.
  input_shape = op.inputs[0].get_shape().merge_with(grad.get_shape())
  diag_shape = op.inputs[1].get_shape()
  batch_shape = input_shape[:-2].merge_with(diag_shape[:-1])
  matrix_shape = input_shape[-2:]
  if batch_shape.is_fully_defined() and matrix_shape.is_fully_defined():
    diag_shape = batch_shape.as_list() + [min(matrix_shape.as_list())]
  else:
    with ops.colocate_with(grad):
      grad_shape = array_ops.shape(grad)
      grad_rank = array_ops.rank(grad)
      batch_shape = array_ops.slice(grad_shape, [0], [grad_rank - 2])
      matrix_shape = array_ops.slice(grad_shape, [grad_rank - 2], [2])
      min_dim = math_ops.reduce_min(matrix_shape)
      diag_shape = array_ops.concat([batch_shape, [min_dim]], 0)
  grad_input = array_ops.matrix_set_diag(
      grad, array_ops.zeros(diag_shape, dtype=grad.dtype))
  grad_diag = array_ops.matrix_diag_part(grad)
  return (grad_input, grad_diag)


@ops.RegisterGradient("MatrixSetDiagV2")
def _MatrixSetDiagGradV2(op, grad):
  """Gradient for MatrixSetDiagV2.

  The input gets `grad` with diagonals `k` (inputs[2]) zeroed; the diagonal
  input gets those diagonals of `grad`; `k` gets `None`.
  """
  diag_shape = op.inputs[1].get_shape()
  if not diag_shape.is_fully_defined():
    # Need to know the values of `d_lower` and `d_upper` to infer diag_shape.
    grad_shape = array_ops.shape(grad)
    batch_shape = grad_shape[:-2]
    matrix_shape = grad_shape[-2:]
    diag_index = array_ops.reshape(op.inputs[2], [-1])  # Converts to vector.
    d_lower = diag_index[0]
    d_upper = diag_index[-1]  # Works both when len(diag_index) is 1 and 2.
    y_offset = control_flow_ops.cond(
        math_ops.less(d_upper, 0), lambda: d_upper, lambda: 0)
    x_offset = control_flow_ops.cond(
        math_ops.greater(d_lower, 0), lambda: -d_lower, lambda: 0)

    max_diag_len = math_ops.minimum(matrix_shape[0] + y_offset,
                                    matrix_shape[1] + x_offset)
    # pylint: disable=g-long-lambda
    # pyformat: disable
    postfix = control_flow_ops.cond(
        math_ops.equal(d_lower, d_upper),
        lambda: ops.convert_to_tensor([max_diag_len]),
        lambda: ops.convert_to_tensor([d_upper - d_lower + 1,
                                       max_diag_len]))
    # pyformat: enable
    # pylint: enable=g-long-lambda
    diag_shape = array_ops.concat([batch_shape, postfix], 0)

  grad_input = array_ops.matrix_set_diag(
      grad, array_ops.zeros(diag_shape, dtype=grad.dtype), k=op.inputs[2])
  grad_diag = array_ops.matrix_diag_part(grad, k=op.inputs[2])
  return (grad_input, grad_diag, None)


@ops.RegisterGradient("MatrixSetDiagV3")
def _MatrixSetDiagGradV3(op, grad):
  """Gradient for MatrixSetDiagV3 (as V2, forwarding the `align` attr)."""
  diag_shape = op.inputs[1].get_shape()
  align = op.get_attr("align")
  if not diag_shape.is_fully_defined():
    # Need to know the values of `d_lower` and `d_upper` to infer diag_shape.
    grad_shape = array_ops.shape(grad)
    batch_shape = grad_shape[:-2]
    matrix_shape = grad_shape[-2:]
    diag_index = array_ops.reshape(op.inputs[2], [-1])  # Converts to vector.
    d_lower = diag_index[0]
    d_upper = diag_index[-1]  # Works both when len(diag_index) is 1 and 2.
    y_offset = control_flow_ops.cond(
        math_ops.less(d_upper, 0), lambda: d_upper, lambda: 0)
    x_offset = control_flow_ops.cond(
        math_ops.greater(d_lower, 0), lambda: -d_lower, lambda: 0)

    max_diag_len = math_ops.minimum(matrix_shape[0] + y_offset,
                                    matrix_shape[1] + x_offset)
    # pylint: disable=g-long-lambda
    # pyformat: disable
    postfix = control_flow_ops.cond(
        math_ops.equal(d_lower, d_upper),
        lambda: ops.convert_to_tensor([max_diag_len]),
        lambda: ops.convert_to_tensor([d_upper - d_lower + 1,
                                       max_diag_len]))
    # pyformat: enable
    # pylint: enable=g-long-lambda
    diag_shape = array_ops.concat([batch_shape, postfix], 0)

  grad_input = array_ops.matrix_set_diag(
      grad,
      array_ops.zeros(diag_shape, dtype=grad.dtype),
      k=op.inputs[2],
      align=align)
  grad_diag = array_ops.matrix_diag_part(grad, k=op.inputs[2], align=align)
  return (grad_input, grad_diag, None)


@ops.RegisterGradient("MatrixBandPart")
def _MatrixBandPartGrad(op, grad):
  """Gradient for MatrixBandPart: apply the same band to `grad`."""
  num_lower = op.inputs[1]
  num_upper = op.inputs[2]
  return (array_ops.matrix_band_part(grad, num_lower, num_upper), None, None)


# Edit Distance has no gradient (but can be used to eval seq2seq or CTC).
ops.NotDifferentiable("EditDistance")


@ops.RegisterGradient("Fill")
def _FillGrad(_, grad):
  """Gradient for Fill.

  The first input gets `None`; the second gets the sum of all elements of
  `grad`.
  """
  return None, math_ops.reduce_sum(grad)


ops.NotDifferentiable("ZerosLike")
ops.NotDifferentiable("OnesLike")


@ops.RegisterGradient("PreventGradient")
def _PreventGradientGrad(op, _):
  """Raises `LookupError` carrying the op's `message` attribute."""
  raise LookupError("Gradient explicitly disabled. Reason: %s" %
                    op.get_attr("message"))


def _IndexedSlicesToTensorNoWarning(indexed_slices):
  """Converts an IndexedSlices to a Tensor without sparse->dense warnings.

  Inputs that are not `IndexedSlices` are returned unchanged. Otherwise the
  values are summed per index with `unsorted_segment_sum` into
  `dense_shape[0]` rows.

  Raises:
    ValueError: if the `IndexedSlices` has no `dense_shape`.
  """
  if not isinstance(indexed_slices, indexed_slices_lib.IndexedSlices):
    # If it is not IndexedSlices, it had better be a tensor.
    return indexed_slices
  if indexed_slices.dense_shape is None:
    raise ValueError(
        "Tensor conversion requested for IndexedSlices without dense_shape: %s"
        % str(indexed_slices))
  return math_ops.unsorted_segment_sum(indexed_slices.values,
                                       indexed_slices.indices,
                                       indexed_slices.dense_shape[0])


@ops.RegisterGradient("Gather")
def _GatherGrad(op, grad):
  """Gradient for Gather op.

  Returns an `IndexedSlices` over the params' first dimension whose indices
  are the flattened gather indices, and `None` for the indices input.
  """
  # params can be large, so colocate the shape calculation with it.
  params = op.inputs[0]
  with ops.colocate_with(params):
    params_shape = array_ops.shape(params)

  # Build appropriately shaped IndexedSlices
  indices = op.inputs[1]
  size = array_ops.expand_dims(array_ops.size(indices), 0)
  values_shape = array_ops.concat([size, params_shape[1:]], 0)
  values = array_ops.reshape(
      _IndexedSlicesToTensorNoWarning(grad), values_shape)
  indices = array_ops.reshape(indices, size)
  return [indexed_slices_lib.IndexedSlices(values, indices, params_shape), None]


def _GetBatchIndices(params_shape, indices, batch_dims):
  """Adds the batch offsets to the given indices and returns the results.

  For each batch axis `dim - 1` (walked from innermost to outermost), a batch
  position `b` adds `b * prod(params_shape[dim:batch_dims + 1])`, so the
  returned indices address rows of the batch dims flattened together with
  dimension `batch_dims` of `params_shape`.
  """
  batch_indices = indices
  indices_dtype = indices.dtype.base_dtype
  casted_params_shape = math_ops.cast(params_shape, indices_dtype)
  accum_dim_value = array_ops.ones((), dtype=indices_dtype)
  for dim in range(batch_dims, 0, -1):
    dim_value = casted_params_shape[dim - 1]
    accum_dim_value *= casted_params_shape[dim]
    start = array_ops.zeros((), dtype=indices_dtype)
    step = array_ops.ones((), dtype=indices_dtype)
    dim_indices = math_ops.range(start, dim_value, step)
    dim_indices *= accum_dim_value
    # Shape [1, ..., dim_value, ..., 1] so the offsets broadcast along axis
    # `dim - 1` of `indices`.
    dim_shape = array_ops.concat([
        array_ops.tile([1], [dim - 1]), [dim_value],
        array_ops.tile([1], [array_ops.rank(indices) - dim])
    ], axis=0)
    batch_indices += array_ops.reshape(dim_indices, dim_shape)

  return batch_indices


def _BatchGatherGrad(params_shape, values, indices, batch_dims,
                     gather_dim_size):
  """Returns the gradient of GatherV2 with batch dimensions.

  Args:
    params_shape: Shape of params, with the gather axis at position
      `batch_dims`.
    values: Incoming gradient, laid out like `params_shape`.
    indices: The gather indices.
    batch_dims: Number of leading batch dimensions.
    gather_dim_size: Size of params along the gather axis.
  """

  # Axis is the first non-batch dimension.
  # NOTE(review): callers are expected to have moved the gather axis to
  # position `batch_dims` already (as `_GatherV2Grad` does via transpose).
  indices_size = array_ops.expand_dims(array_ops.size(indices), 0)
  if batch_dims:
    values_shape = array_ops.shape(values)
    # Add the batch offsets to indices and flatten the batch dimensions.
    outer_shape = values_shape[:batch_dims]
    inner_shape = values_shape[batch_dims:][1:]
    batch_size = gen_math_ops.prod(outer_shape, [0], False)
    flat_values_shape = array_ops.concat([[-1], inner_shape], 0)
    gather_dim_size *= batch_size

    indices = _GetBatchIndices(params_shape, indices, batch_dims)
    values = array_ops.reshape(
        _IndexedSlicesToTensorNoWarning(values), flat_values_shape)

  indices = array_ops.reshape(indices, indices_size)
  params_grad = math_ops.unsorted_segment_sum(values, indices, gather_dim_size)

  if batch_dims:
    # Put back the batch dimensions.
    params_grad = array_ops.reshape(
        params_grad, array_ops.concat([outer_shape, flat_values_shape], 0))

  return params_grad


@ops.RegisterGradient("GatherV2")
def _GatherV2Grad(op, grad):
  """Gradient for GatherV2 op.

  For a statically-zero axis, returns an `IndexedSlices`; otherwise moves the
  gather axis next to the batch dims, accumulates with `_BatchGatherGrad`, and
  transposes back to a dense gradient. indices and axis get `None`.

  Raises:
    ValueError: if `batch_dims` is negative and the rank of indices is unknown.
  """
  # params can be large, so colocate the shape calculation with it.
  #
  # params can be very large for sparse model, array_ops.shape raises
  # exception on the Windows platform when any dimension is larger than
  # int32. params_shape is not used in optimizer apply_sparse gradients,
  # so it's fine to convert it back to int32 regardless of truncation.
  params = op.inputs[0]
  with ops.colocate_with(params):
    params_shape = array_ops.shape(params, out_type=ops.dtypes.int64)
    params_shape = math_ops.cast(params_shape, dtypes.int32)

  indices = op.inputs[1]
  indices_size = array_ops.expand_dims(array_ops.size(indices), 0)
  axis = op.inputs[2]
  axis_static = tensor_util.constant_value(axis)
  batch_dims = int(op.get_attr("batch_dims"))

  # Normalize a negative batch_dims against the static rank of indices.
  if batch_dims < 0:
    if indices.shape.ndims is None:
      raise ValueError(
          f"Currently, it is unsupported to take the gradient of tf.gather "
          f"when batch_dims < 0 and the rank of the indices is unknown. Please "
          f"pass a positive batch_dims or use tf.ensure_shape to update the "
          f"shape of indices when calling tf.gather. Got "
          f"batch_dims={batch_dims} and indices={indices}")
    batch_dims += indices.shape.ndims

  # For axis 0 gathers, build an appropriately shaped IndexedSlices.
  if axis_static == 0:
    if context.executing_eagerly():
      with ops.device(indices_size.device):
        params_tail_shape = array_ops.identity(params_shape)[1:]
    else:
      params_tail_shape = params_shape[1:]
    values_shape = array_ops.concat([indices_size, params_tail_shape], 0)
    values = array_ops.reshape(
        _IndexedSlicesToTensorNoWarning(grad), values_shape)
    indices = array_ops.reshape(indices, indices_size)
    params_grad = indexed_slices_lib.IndexedSlices(values, indices,
                                                   params_shape)
  else:
    # Handle axis by transposing the axis dimension to be the first non-batch
    # dimension, compute the gradient and transpose the result back.
    outer_shape = params_shape[:axis]
    inner_shape = params_shape[axis:][1:]
    values_shape = array_ops.concat([outer_shape, [-1], inner_shape], 0)

    values_dims = array_ops.size(values_shape)
    axis_dims = array_ops.size(outer_shape)

    outer_batches_indices = math_ops.range(batch_dims)
    batch_axis_indices = math_ops.range(batch_dims, axis_dims)
    inner_axes_indices = math_ops.range(axis_dims + 1, values_dims)

    values = array_ops.reshape(
        _IndexedSlicesToTensorNoWarning(grad), values_shape)

    # Move values[axis] up to values[batch_dims]
    transpose_dims = array_ops.concat([
        outer_batches_indices, [axis_dims], batch_axis_indices,
        inner_axes_indices
    ], 0)
    values_transpose = array_ops.transpose(values, transpose_dims)
    params_shape_transpose = array_ops.gather(params_shape, transpose_dims)

    params_grad = _BatchGatherGrad(params_shape_transpose, values_transpose,
                                   indices, batch_dims, params_shape[axis])

    # Inverts the above transpose by moving dimension batch_dims back to its
    # original position.
    invert_transpose_dims = array_ops.concat([
        outer_batches_indices, batch_axis_indices + 1, [batch_dims],
        inner_axes_indices
    ], 0)
    params_grad = array_ops.transpose(params_grad, invert_transpose_dims)

  return [params_grad, None, None]


@ops.RegisterGradient("GatherNd")
def _GatherNdGrad(op, grad):
  """Gradient for GatherNd op.

  When indices are statically 2-D with last dimension 1, returns an
  `IndexedSlices` keyed by the squeezed indices; otherwise scatters `grad`
  into a dense tensor of the params' shape with `scatter_nd`.
  """
  ref = op.inputs[0]
  indices = op.inputs[1]
  ref_shape = array_ops.shape(ref, out_type=indices.dtype)
  if indices.shape.ndims == 2 and indices.shape.dims[-1].value == 1:
    ref_grad = indexed_slices_lib.IndexedSlices(
        grad, array_ops.squeeze(indices, axis=-1), ref_shape)
  else:
    ref_grad = array_ops.scatter_nd(indices, grad, ref_shape)
  return [ref_grad, None]


@ops.RegisterGradient("ResourceGatherNd")
def _ResourceGatherNdGrad(op, grad):  # pylint: disable=missing-docstring
  """Gradient for ResourceGatherNd; as GatherNd, using the variable's shape."""
  ref = op.inputs[0]
  indices = op.inputs[1]
  ref_shape = gen_resource_variable_ops.variable_shape(ref, indices.dtype)
  if indices.shape.ndims == 2 and indices.shape.dims[-1].value == 1:
    ref_grad = indexed_slices_lib.IndexedSlices(
        grad, array_ops.squeeze(indices, axis=-1), ref_shape)
  else:
    ref_grad = array_ops.scatter_nd(indices, grad, ref_shape)
  return [ref_grad, None]


@ops.RegisterGradient("CheckNumerics")
def _CheckNumericsGrad(op, grad):
  """Gradient for check_numerics op: checks `grad` for NaN/Inf too."""
  return array_ops.check_numerics(
      grad,
      "Not a number (NaN) or infinity (Inf) values detected in gradient. %s" %
      op.get_attr("message"))


@ops.RegisterGradient("CheckNumericsV2")
def _CheckNumericsV2Grad(op, grad):
  """Gradient for check_numerics_v2 op: checks `grad` for NaN/Inf too."""
  return array_ops.check_numerics_v2(
      grad,
      "Not a number (NaN) or infinity (Inf) values detected in gradient. %s" %
      op.get_attr("message"))


@ops.RegisterGradient("PlaceholderWithDefault")
@ops.RegisterGradient("Identity")
def _IdGrad(_, grad):
  """Identity-like ops pass the gradient through unchanged."""
  return grad


@ops.RegisterGradient("_EagerConst")
def _EagerConstGrad(_, grad):
  """Always raises: gradients must never be requested for this op."""
  raise AssertionError(
      "This op should never interact with gradient APIs. Please file a bug.")


@ops.RegisterGradient("RefIdentity")
def _RefIdGrad(_, grad):
  """Passes the gradient through unchanged."""
  return grad


@ops.RegisterGradient("IdentityN")
def _IdNGrad(_, *grad):
  """Passes each output's gradient through to the matching input."""
  return grad


ops.NotDifferentiable("StopGradient")


@ops.RegisterGradient("Reshape")
def _ReshapeGrad(op, grad):
  """Reshapes `grad` (densified) back to the input's shape."""
  return [
      array_ops.reshape(
          _IndexedSlicesToTensorNoWarning(grad), array_ops.shape(op.inputs[0])),
      None
  ]


ops.NotDifferentiable("InvertPermutation")


def _ReshapeToInput(op, grad):
  """Reshapes the gradient to the shape of the original input."""
  return array_ops.reshape(
      _IndexedSlicesToTensorNoWarning(grad), array_ops.shape(op.inputs[0]))


@ops.RegisterGradient("ExpandDims")
def _ExpandDimsGrad(op, grad):
  """Reshapes `grad` back to the input's shape; the axis input gets None."""
  return [_ReshapeToInput(op, grad), None]


@ops.RegisterGradient("Squeeze")
def _SqueezeGrad(op, grad):
  """Reshapes `grad` back to the input's shape."""
  return _ReshapeToInput(op, grad)


@ops.RegisterGradient("Transpose")
def _TransposeGrad(op, grad):
  """Returns unshuffle(grad)."""
  p = op.inputs[1]
  return [array_ops.transpose(grad, array_ops.invert_permutation(p)), None]


@ops.RegisterGradient("ConjugateTranspose")
def _ConjugateTransposeGrad(op, grad):
  """Returns conj(unshuffle(grad))."""
  p = op.inputs[1]
  return [
      array_ops.transpose(
          grad, array_ops.invert_permutation(p), conjugate=True), None
  ]


ops.NotDifferentiable("Shape")

ops.NotDifferentiable("ShapeN")

ops.NotDifferentiable("Rank")

ops.NotDifferentiable("Size")


@ops.RegisterGradient("Tile")
def _TileGrad(op, grad):
  """Sum reduces grad along the tiled dimensions."""
  input_shape = array_ops.shape(op.inputs[0], out_type=op.inputs[1].dtype)
  # We interleave multiples and input_shape to get split_shape,
  # reshape grad to split_shape, and reduce along all even
  # dimensions (the tiled dimensions) to get the result
  # with shape input_shape.  For example
  #   input_shape = [20, 30, 40]
  #   multiples = [2, 3, 4]
  #   split_shape = [2, 20, 3, 30, 4, 40]
  #   axes = [0, 2, 4]
  split_shape = array_ops.reshape(
      array_ops.transpose(array_ops.stack([op.inputs[1], input_shape])), [-1])
  axes = math_ops.range(0, array_ops.size(split_shape), 2)
  # Sum reduces grad along the first dimension for IndexedSlices: rows are
  # folded modulo input_shape[0] and summed, so the leading multiple in
  # split_shape becomes 1.
  if isinstance(grad, indexed_slices_lib.IndexedSlices):
    input_shape_0 = math_ops.cast(input_shape[0], grad.indices.dtype)
    grad = math_ops.unsorted_segment_sum(
        grad.values, math_ops.mod(grad.indices, input_shape_0), input_shape_0)
    split_shape = array_ops.concat([[1], split_shape[1:]], axis=0)
  input_grad = math_ops.reduce_sum(array_ops.reshape(grad, split_shape), axes)
  # Fix shape inference
  if not context.executing_eagerly():
    input_grad.set_shape(op.inputs[0].get_shape())
  return [input_grad, None]


ops.NotDifferentiable("BroadcastGradientArgs")


def _PadGrad(op, grad):
  """Gradient for Pad."""
  # Pad introduces values around the original tensor, so the gradient function
  # slices the original shape out of the gradient.
  x = op.inputs[0]
  a = op.inputs[1]  # [Rank(x), 2]
  # Takes a slice of a. The 1st column. [Rank(x), 1].
  pad_before = array_ops.slice(a, [0, 0],
                               array_ops.stack([array_ops.rank(x), 1]))
  # Make it a 1-D tensor.
  begin = array_ops.reshape(pad_before, [-1])
  sizes = array_ops.shape(x, out_type=begin.dtype)
  x_grad = array_ops.slice(grad, begin, sizes)
  # PadV2 has a third input (registered below for both ops), which also gets
  # no gradient.
  if len(op.inputs) == 3:
    return x_grad, None, None
  else:
    return x_grad, None


ops.RegisterGradient("Pad")(_PadGrad)
ops.RegisterGradient("PadV2")(_PadGrad)


# ReverseSequence is just a permutation.  The gradient permutes back.
@ops.RegisterGradient("ReverseSequence")
def _ReverseSequenceGrad(op, grad):
  """Reverses `grad` with the op's own batch/seq dims and seq_lengths."""
  seq_lengths = op.inputs[1]
  return [
      array_ops.reverse_sequence(
          grad,
          batch_axis=op.get_attr("batch_dim"),
          seq_axis=op.get_attr("seq_dim"),
          seq_lengths=seq_lengths), None
  ]


@ops.RegisterGradient("Reverse")
def _ReverseGrad(op, grad):
  """Reverses `grad` along the same dimensions as the forward op."""
  reverse_dims = op.inputs[1]
  return gen_array_ops.reverse(grad, reverse_dims), None


@ops.RegisterGradient("ReverseV2")
def _ReverseV2Grad(op, grad):
  """Reverses `grad` along the same axes as the forward op."""
  axis = op.inputs[1]
  return array_ops.reverse_v2(grad, axis), None


@ops.RegisterGradient("SpaceToBatch")
def _SpaceToBatchGrad(op, grad):
  """Gradient for SpaceToBatch."""
  # Its gradient is the opposite op: BatchToSpace.
  block_size = op.get_attr("block_size")
  return [
      array_ops.batch_to_space(grad, op.inputs[1], block_size=block_size), None
  ]


@ops.RegisterGradient("SpaceToBatchND")
def _SpaceToBatchNDGrad(op, grad):
  """Gradient for SpaceToBatchND."""
  # Its gradient is the opposite op: BatchToSpaceND.
  return [
      array_ops.batch_to_space_nd(grad, op.inputs[1], op.inputs[2]), None, None
  ]


@ops.RegisterGradient("BatchToSpace")
def _BatchToSpaceGrad(op, grad):
  """Gradient for BatchToSpace."""
  # Its gradient is the opposite op: SpaceToBatch.
  block_size = op.get_attr("block_size")
  return [
      array_ops.space_to_batch(grad, op.inputs[1], block_size=block_size), None
  ]


@ops.RegisterGradient("BatchToSpaceND")
def _BatchToSpaceNDGrad(op, grad):
  """Gradient for BatchToSpaceND."""
  # Its gradient is the opposite op: SpaceToBatchND.
  return [
      array_ops.space_to_batch_nd(grad, op.inputs[1], op.inputs[2]), None, None
  ]


@ops.RegisterGradient("SpaceToDepth")
def _SpaceToDepthGrad(op, grad):
  """Gradient for SpaceToDepth; rejects the NCHW_VECT_C data format."""
  # Its gradient is the opposite op: DepthToSpace.
  block_size = op.get_attr("block_size")
  data_format = op.get_attr("data_format")
  if data_format == "NCHW_VECT_C":
    raise ValueError("Cannot compute SpaceToDepth gradient with NCHW_VECT_C. "
                     "NCHW_VECT_C requires qint8 data type.")
  return array_ops.depth_to_space(grad, block_size, data_format=data_format)


@ops.RegisterGradient("DepthToSpace")
def _DepthToSpaceGrad(op, grad):
  """Gradient for DepthToSpace; rejects the NCHW_VECT_C data format."""
  # Its gradient is the opposite op: SpaceToDepth.
  block_size = op.get_attr("block_size")
  data_format = op.get_attr("data_format")
  if data_format == "NCHW_VECT_C":
    raise ValueError("Cannot compute DepthToSpace gradient with NCHW_VECT_C. "
                     "NCHW_VECT_C requires qint8 data type.")
  return array_ops.space_to_depth(grad, block_size, data_format=data_format)


ops.NotDifferentiable("OneHot")


@ops.RegisterGradient("MirrorPad")
def _MirrorPadGrad(op, grad):
  """Gradient for MirrorPad via MirrorPadGrad with the same mode."""
  mode = op.get_attr("mode")
  return [gen_array_ops.mirror_pad_grad(grad, op.inputs[1], mode=mode), None]


@ops.RegisterGradient("MirrorPadGrad")
def _MirrorPadGradGrad(op, grad):
  """Gradient for MirrorPadGrad via MirrorPad with the same mode."""
  mode = op.get_attr("mode")
  return [gen_array_ops.mirror_pad(grad, op.inputs[1], mode=mode), None]


@ops.RegisterGradient("QuantizeAndDequantize")
def _QuantizeAndDequantizeGrad(_, grad):
  """Passes `grad` through unchanged."""
  return grad


@ops.RegisterGradient("QuantizeAndDequantizeV2")
def _QuantizeAndDequantizeV2Grad(_, grad):
  """Passes `grad` to the first input; the other two inputs get None."""
  return [grad, None, None]


@ops.RegisterGradient("QuantizeAndDequantizeV3")
def _QuantizeAndDequantizeV3Grad(_, grad):
  # Only propagate the gradient for the unquantized input.
1015 return [grad, None, None, None] 1016 1017 1018@ops.RegisterGradient("ExtractImagePatches") 1019def _ExtractImagePatchesGrad(op, grad): 1020 input_bhwc = array_ops.shape(op.inputs[0], out_type=dtypes.int64) 1021 batch_size, rows_in, cols_in, channels = input_bhwc[0], input_bhwc[1], \ 1022 input_bhwc[2], input_bhwc[3] 1023 1024 # Create indices matrix for input tensor. 1025 # Note that 0 is preserved for padding location, 1026 # so indices for input start from 1 to 1 + rows_in * cols_in. 1027 input_indices_num = 1 + rows_in * cols_in 1028 input_idx = array_ops.reshape( 1029 math_ops.range(1, input_indices_num, dtype=ops.dtypes.int64), 1030 (1, rows_in, cols_in, 1)) 1031 input_idx_patched = gen_array_ops.extract_image_patches( 1032 input_idx, op.get_attr("ksizes"), op.get_attr("strides"), 1033 op.get_attr("rates"), op.get_attr("padding")) 1034 1035 # Create indices matrix for output tensor. 1036 output_bhwc = array_ops.shape(op.outputs[0], out_type=dtypes.int64) 1037 rows_out, cols_out = output_bhwc[1], output_bhwc[2] 1038 _, ksize_r, ksize_c, _ = op.get_attr("ksizes") 1039 # Indices for output start from 0. 1040 output_indices_num = rows_out * cols_out * ksize_r * ksize_c 1041 output_idx = array_ops.reshape( 1042 math_ops.range(output_indices_num, dtype=ops.dtypes.int64), 1043 (1, rows_out, cols_out, ksize_r * ksize_c)) 1044 1045 # Construct mapping table for indices: (input -> output). 1046 idx_matrix = array_ops.concat([ 1047 array_ops.expand_dims(input_idx_patched, axis=-1), 1048 array_ops.expand_dims(output_idx, axis=-1) 1049 ], 1050 axis=-1) 1051 idx_map = array_ops.reshape(idx_matrix, (-1, 2)) 1052 1053 sp_shape = (input_indices_num, output_indices_num) 1054 sp_mat_full = sparse_tensor.SparseTensor( 1055 idx_map, array_ops.ones([output_indices_num], dtype=grad.dtype), sp_shape) 1056 # Remove all padding locations [0, :]. 
1057 sp_mat = sparse_ops.sparse_slice(sp_mat_full, (1, 0), 1058 (input_indices_num - 1, output_indices_num)) 1059 1060 grad_expanded = array_ops.transpose( 1061 array_ops.reshape( 1062 _IndexedSlicesToTensorNoWarning(grad), 1063 (batch_size, rows_out, cols_out, ksize_r, ksize_c, channels)), 1064 (1, 2, 3, 4, 0, 5)) 1065 grad_flat = array_ops.reshape(grad_expanded, (-1, batch_size * channels)) 1066 1067 jac = sparse_ops.sparse_tensor_dense_matmul(sp_mat, grad_flat) 1068 1069 grad_out = array_ops.reshape(jac, (rows_in, cols_in, batch_size, channels)) 1070 grad_out = array_ops.transpose(grad_out, (2, 0, 1, 3)) 1071 1072 return [grad_out] 1073 1074 1075@ops.RegisterGradient("ExtractVolumePatches") 1076def _ExtractVolumePatchesGrad(op, grad): 1077 batch_size, planes_in, rows_in, cols_in, channels = [ 1078 dim.value for dim in op.inputs[0].shape.dims 1079 ] 1080 input_bphwc = array_ops.shape(op.inputs[0]) 1081 batch_size = input_bphwc[0] 1082 channels = input_bphwc[4] 1083 1084 # Create indices matrix for input tensor. 1085 # Note that 0 is preserved for padding location, 1086 # so indices for input start from 1 to 1 + rows_in * cols_in. 1087 input_indices_num = 1 + planes_in * rows_in * cols_in 1088 input_idx = array_ops.reshape( 1089 math_ops.range(1, input_indices_num, dtype=ops.dtypes.int64), 1090 (1, planes_in, rows_in, cols_in, 1)) 1091 input_idx_patched = gen_array_ops.extract_volume_patches( 1092 input_idx, op.get_attr("ksizes"), op.get_attr("strides"), 1093 op.get_attr("padding")) 1094 1095 # Create indices matrix for output tensor. 1096 _, planes_out, rows_out, cols_out, _ = [ 1097 dim.value for dim in op.outputs[0].shape.dims 1098 ] 1099 _, ksize_p, ksize_r, ksize_c, _ = op.get_attr("ksizes") 1100 # Indices for output start from 0. 
1101 prc_indices_num = planes_out * rows_out * cols_out 1102 output_indices_num = prc_indices_num * ksize_p * ksize_r * ksize_c 1103 output_idx = array_ops.reshape( 1104 math_ops.range(output_indices_num, dtype=ops.dtypes.int64), 1105 (1, planes_out, rows_out, cols_out, ksize_p * ksize_r * ksize_c)) 1106 1107 # Construct mapping table for indices: (input -> output). 1108 idx_matrix = array_ops.concat([ 1109 array_ops.expand_dims(input_idx_patched, axis=-1), 1110 array_ops.expand_dims(output_idx, axis=-1) 1111 ], 1112 axis=-1) 1113 idx_map = array_ops.reshape(idx_matrix, (-1, 2)) 1114 1115 sp_shape = (input_indices_num, output_indices_num) 1116 sp_mat_full = sparse_tensor.SparseTensor( 1117 idx_map, array_ops.ones([output_indices_num], dtype=grad.dtype), sp_shape) 1118 # Remove all padding locations [0, :]. 1119 sp_mat = sparse_ops.sparse_slice(sp_mat_full, (1, 0), 1120 (input_indices_num - 1, output_indices_num)) 1121 1122 grad_expanded = array_ops.transpose( 1123 array_ops.reshape( 1124 _IndexedSlicesToTensorNoWarning(grad), 1125 (batch_size, planes_out, rows_out, cols_out, ksize_p, ksize_r, 1126 ksize_c, channels)), (1, 2, 3, 4, 5, 6, 0, 7)) 1127 grad_flat = array_ops.reshape(grad_expanded, (-1, batch_size * channels)) 1128 1129 jac = sparse_ops.sparse_tensor_dense_matmul(sp_mat, grad_flat) 1130 1131 grad_out = array_ops.reshape( 1132 jac, (planes_in, rows_in, cols_in, batch_size, channels)) 1133 grad_out = array_ops.transpose(grad_out, (3, 0, 1, 2, 4)) 1134 1135 return [grad_out] 1136 1137 1138@ops.RegisterGradient("ScatterNd") 1139def _ScatterNdGrad(op, grad): 1140 indices = op.inputs[0] 1141 updates_grad = array_ops.gather_nd(grad, indices) 1142 return [None, updates_grad, None] 1143 1144 1145@ops.RegisterGradient("TensorScatterUpdate") 1146def _TensorScatterUpdateGrad(op, grad): 1147 indices = op.inputs[1] 1148 updates_grad = array_ops.gather_nd(grad, indices) 1149 tensor_grad = array_ops.tensor_scatter_update( 1150 array_ops.identity(grad), indices, 1151 
array_ops.zeros_like(op.inputs[2], dtype=grad.dtype)) 1152 return [tensor_grad, None, updates_grad] 1153 1154 1155@ops.RegisterGradient("TensorScatterAdd") 1156def _TensorScatterAddGrad(op, grad): 1157 indices = op.inputs[1] 1158 updates_grad = array_ops.gather_nd(grad, indices) 1159 tensor_grad = array_ops.identity(grad) 1160 return [tensor_grad, None, updates_grad] 1161 1162 1163def _TensorScatterMinOrMaxGrad(op, grad): 1164 """Gradient for TensorScatterMin and TensorScatterMax.""" 1165 indices = op.inputs[1] 1166 x = op.inputs[0] 1167 y = op.inputs[2] 1168 output = op.outputs[0] 1169 x_indicators = math_ops.cast(math_ops.equal(x, output), grad.dtype) 1170 y_output = array_ops.gather_nd(output, indices) 1171 y_indicators = math_ops.cast(math_ops.equal(y, y_output), grad.dtype) 1172 ys_indicators = array_ops.scatter_nd( 1173 indices, y_indicators, array_ops.shape(x, out_type=indices.dtype)) 1174 indicators = x_indicators + ys_indicators # All elements are >= 1. 1175 # If there are multiple minimum or maximum elements then the gradient will be 1176 # divided between them. 
1177 x_grad = grad * x_indicators / indicators 1178 y_grad = array_ops.gather_nd(grad / indicators, indices) * y_indicators 1179 return [x_grad, None, y_grad] 1180 1181 1182@ops.RegisterGradient("TensorScatterMax") 1183def _TensorScatterMaxGrad(op, grad): 1184 """Gradient for TensorScatterMax op.""" 1185 return _TensorScatterMinOrMaxGrad(op, grad) 1186 1187 1188@ops.RegisterGradient("TensorScatterMin") 1189def _TensorScatterMinGrad(op, grad): 1190 """Gradient for TensorScatterMin op.""" 1191 return _TensorScatterMinOrMaxGrad(op, grad) 1192 1193 1194@ops.RegisterGradient("TensorScatterSub") 1195def _TensorScatterSubGrad(op, grad): 1196 indices = op.inputs[1] 1197 updates_grad = array_ops.gather_nd(grad, indices) 1198 tensor_grad = array_ops.identity(grad) 1199 return [tensor_grad, None, -updates_grad] 1200 1201 1202@ops.RegisterGradient("ScatterNdNonAliasingAdd") 1203def _ScatterNdNonAliasingAddGrad(op, grad): 1204 indices = op.inputs[1] 1205 updates_grad = array_ops.gather_nd(grad, indices) 1206 return [grad, None, updates_grad] 1207 1208 1209@ops.RegisterGradient("BroadcastTo") 1210def _BroadcastToGrad(op, grad): 1211 input_value = op.inputs[0] 1212 broadcast_shape = op.inputs[1] 1213 shape_dtype = dtypes.int32 1214 if isinstance(broadcast_shape, ops.Tensor): 1215 shape_dtype = broadcast_shape.dtype 1216 1217 input_value_shape = array_ops.shape(input_value, out_type=shape_dtype) 1218 if not isinstance(broadcast_shape, ops.EagerTensor): 1219 broadcast_shape_static = tensor_shape.TensorShape( 1220 tensor_util.try_evaluate_constant(broadcast_shape)) 1221 if broadcast_shape_static.is_fully_defined(): 1222 broadcast_shape = constant_op.constant( 1223 broadcast_shape_static.as_list(), dtype=shape_dtype) 1224 _, reduction_axes = gen_array_ops.broadcast_gradient_args( 1225 broadcast_shape, input_value_shape) 1226 updates_grad_reshaped = math_ops.reduce_sum( 1227 grad, axis=reduction_axes, keepdims=True) 1228 updates_grad = array_ops.reshape(updates_grad_reshaped, 
input_value_shape) 1229 return [updates_grad, None] 1230