# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradients for operators defined in array_ops.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.compiler.tf2xla.ops import gen_xla_ops
from tensorflow.python import pywrap_tfe
from tensorflow.python.client import pywrap_tf_session
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops


@ops.RegisterGradient("Pack")
def _PackGrad(op, grad):
  """Gradient for pack op."""
  return array_ops.unstack(grad, num=op.get_attr("N"), axis=op.get_attr("axis"))


@ops.RegisterGradient("Unpack")
def _UnpackGrad(op, *grads):
  """Gradient for unpack op."""
  return array_ops.stack(grads, axis=op.get_attr("axis"))
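

# Illustrative sketch (hypothetical helper, not used by TensorFlow itself):
# the gradient of Pack (tf.stack) is Unpack (tf.unstack) along the same axis,
# and vice versa, so the two registrations above are inverses of each other.
def _example_pack_grad_is_unstack():
  xs = [constant_op.constant([1.0, 2.0]), constant_op.constant([3.0, 4.0])]
  packed = array_ops.stack(xs, axis=0)  # Shape [2, 2].
  upstream = array_ops.ones_like(packed)
  # _PackGrad splits the upstream gradient back into one piece per input.
  return array_ops.unstack(upstream, num=len(xs), axis=0)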


def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index):
  """Gradient for concat op.

  Args:
    op: An operation.
    grad: `Tensor` or `IndexedSlices` representing the gradients with respect to
      each output of the op.
    start_value_index: An integer index of the first value in the op.inputs.
    end_value_index: An integer index of the last value in the op.inputs.
    dim_index: An integer index of concat_dim or axis parameter in op.inputs.

  Returns:
    Tensors representing the partial gradients with respect to each input
    of the op.

  Raises:
    ValueError: if concat_dim/axis is not statically known.
  """

  def _CreateDenseMaskAndBegin(sizes, concat_dim):
    """Create variables for iteratively slicing a dense gradients tensor."""
    # Since shape is 1-D, shape_of_shape = [rank-of-inputs]
    shape_of_shape = array_ops.shape(sizes[0])
    # Make a vector of length equal to the input's dimensions,
    # with 0's everywhere and 1 in the concat dim position.
    # Note: Can't use sparse_to_dense since it isn't GPU-capable (for now)
    mask = array_ops.concat([
        array_ops.zeros(
            array_ops.expand_dims(concat_dim, 0), dtype=dtypes.int32), [1],
        array_ops.zeros(shape_of_shape - concat_dim - 1, dtype=dtypes.int32)
    ], 0)
    begin = array_ops.zeros(shape_of_shape, dtype=dtypes.int32)
    return mask, begin

  def _ExtractInputShapes(inputs):
    """Extract the shapes of a set of input tensors."""
    if context.executing_eagerly():
      return array_ops.shape_n(inputs)
    sizes = []
    fully_known = True
    for x in inputs:
      input_shape = array_ops.shape(x)
      if not isinstance(input_shape,
                        ops.Tensor) or input_shape.op.type != "Const":
        fully_known = False
        break
      sizes.append(input_shape)

    if fully_known:
      return sizes
    else:
      return array_ops.shape_n(inputs)

  # Degenerate concatenation, just return grad.
  if len(op.inputs) == 2:
    return grad + [None] if end_value_index <= dim_index else [None] + grad

  concat_dim = op.inputs[dim_index]
  input_values = op.inputs[start_value_index:end_value_index]

  out_grads = []
  if isinstance(grad, ops.Tensor):
    if context.executing_eagerly() or isinstance(concat_dim, ops.EagerTensor):
      # Using mod here for convenience since concat_dim is already verified
      # in concat implementation to be within the allowed [-rank, rank) range.
      non_neg_concat_dim = (
          concat_dim._numpy().item(0) % input_values[0]._rank())  # pylint: disable=protected-access
      # All inputs are guaranteed to be EagerTensors in eager mode
      sizes = pywrap_tfe.TFE_Py_TensorShapeSlice(input_values,
                                                 non_neg_concat_dim)
      out_grads = array_ops.split(grad, sizes, non_neg_concat_dim)
    else:
      if constant_op.is_constant(concat_dim):
        # If concat_dim is a constant defined in a different context,
        # then we duplicate it in the current context to avoid passing it
        # through an Enter node.
        # This is a small optimization in general, but it is required when
        # compiling with XLA, as XLA needs the concat input to be folded into a
        # constant.
        grad_context = control_flow_util.GetOutputContext(grad.op)
        dim_context = control_flow_util.GetOutputContext(concat_dim.op)
        if dim_context != grad_context:
          value = tensor_util.constant_value(concat_dim)
          concat_dim = constant_op.constant(value=value, dtype=concat_dim.dtype)

      # Using mod here for convenience since concat_dim is already verified
      # in concat implementation to be within the allowed [-rank, rank) range.
      non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0])

      # Get the inputs' tensor shapes
      sizes = _ExtractInputShapes(input_values)
      # The magic number of 16 was found through benchmarking a range of sizes
      # on CPUs and a Maxwell TitanX. A speedup was seen in a large majority of
      # cases when switching implementations at N=16, but it is possible that
      # there will be a small number of performance regressions.
      if len(sizes) > 16:
        # extract the size of each input along the concat dimension
        sizes = array_ops.squeeze(
            array_ops.slice(
                array_ops.stack(sizes, axis=1), [non_neg_concat_dim, 0],
                [1, -1]))
        out_grads = array_ops.split(grad, sizes, non_neg_concat_dim)
      else:
        offset = gen_array_ops.concat_offset(non_neg_concat_dim, sizes)
        for (begin, size) in zip(offset, sizes):
          out_grads.append(array_ops.slice(grad, begin, size))
  elif isinstance(grad, ops.IndexedSlices):
    # Using mod here for convenience since concat_dim is already verified
    # in concat implementation to be within the allowed [-rank, rank) range.
    non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0])
    concat_dim_static = tensor_util.constant_value(concat_dim)
    if concat_dim_static is None:
      raise ValueError("Can only compute IndexedSlices gradient with "
                       "statically-known concat_dim")
    if concat_dim_static < 0:
      rank = tensor_util.constant_value(array_ops.rank(input_values[0]))
      if rank is None:
        raise ValueError("Can only compute IndexedSlices gradient with "
                         "negative concat_dim when first value rank is "
                         "statically-known.")
      concat_dim_static %= rank
    # Get the inputs' tensor shapes
    sizes = [array_ops.shape(x) for x in input_values]
    if concat_dim_static > 0:
      # IndexedSlices, non_neg_concat_dim > 0. Each input gets IndexedSlices
      # gradients with all the indices, but with grad.values sliced accordingly.
      # This is like the Tensor case, except shape(grad.values)[0] is not equal
      # to shape(sizes[i])[0], since only a subset of the dim-0 values are
      # stored.
      mask, begin = _CreateDenseMaskAndBegin(sizes, non_neg_concat_dim)
      for size in sizes:
        new_values = array_ops.slice(
            grad.values, begin,
            array_ops.concat([[-1], array_ops.slice(size, [1], [-1])], 0))
        out_grads.append(ops.IndexedSlices(new_values, grad.indices, size))
        # Lint complains begin = begin + ...
        begin = math_ops.add(begin, size * mask)
    else:
      # IndexedSlices, concat_dim == 0. Each input gets IndexedSlices gradients
      # only for the relevant indices.
      start = constant_op.constant(0, dtype=grad.indices.dtype)
      for size in sizes:
        size_concat_dim = array_ops.gather(size, non_neg_concat_dim)
        if size_concat_dim.dtype != grad.indices.dtype:
          size_concat_dim = math_ops.cast(
              size_concat_dim, dtype=grad.indices.dtype)
        end = start + size_concat_dim
        # Compute the 1-D Tensor of indices relevant for this input.
        indices_to_select = array_ops.squeeze(
            array_ops.where(
                math_ops.logical_and(grad.indices >= start,
                                     grad.indices < end)),
            axis=[1])
        new_indices = array_ops.gather(grad.indices, indices_to_select) - start
        new_values = array_ops.gather(grad.values, indices_to_select)
        out_grads.append(ops.IndexedSlices(new_values, new_indices, size))
        start = end
  else:
    raise TypeError("Expected Tensor or IndexedSlices, got %s" % type(grad))

  return (out_grads + [None] if end_value_index <= dim_index else [None] +
          out_grads)


@ops.RegisterGradient("Concat")
def _ConcatGrad(op, grad):
  return _ConcatGradHelper(
      op,
      grad,
      start_value_index=1,
      end_value_index=len(op.inputs),
      dim_index=0)


@ops.RegisterGradient("ConcatV2")
def _ConcatGradV2(op, grad):
  return _ConcatGradHelper(
      op, grad, start_value_index=0, end_value_index=-1, dim_index=-1)


ops.NotDifferentiable("ConcatOffset")
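

# Illustrative sketch (hypothetical helper, not used elsewhere): the gradient
# of ConcatV2 is the upstream gradient split back into the original pieces
# along the concat axis, which is what _ConcatGradHelper computes above.
def _example_concat_grad_is_split():
  a = array_ops.ones([2, 3])
  b = array_ops.ones([2, 5])
  upstream = array_ops.ones_like(array_ops.concat([a, b], 1))  # Shape [2, 8].
  # Splitting the upstream gradient recovers one piece per concat input.
  return array_ops.split(upstream, [3, 5], axis=1)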


@ops.RegisterGradient("Slice")
def _SliceGrad(op, grad):
  """Gradient for Slice op."""
  # Create an Nx2 padding where the first column represents how many
  # zeros are to be prepended for each dimension, and the second
  # column indicates how many zeros are appended.
  #
  # The number of zeros to append is the shape of the input
  # elementwise-subtracted by both the begin vector and sizes vector.
  #
  # Some more reshaping is needed to assemble this tensor with the
  # right dimensions.
  input_vec = op.inputs[0]
  begin_vec = op.inputs[1]
  input_rank = array_ops.rank(input_vec)
  slice_size = array_ops.shape(op.outputs[0])
  if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
    return gen_xla_ops.xla_dynamic_update_slice(array_ops.zeros_like(input_vec),
                                                grad, begin_vec), None, None

  shape = array_ops.stack([input_rank, 1])
  before_pad = array_ops.reshape(begin_vec, shape)
  after_pad = array_ops.reshape(
      array_ops.shape(input_vec) - slice_size - begin_vec, shape)
  paddings = array_ops.concat([before_pad, after_pad], 1)
  return array_ops.pad(grad, paddings), None, None
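

# Illustrative sketch (hypothetical helper): the gradient of Slice pads the
# upstream gradient with zeros so that it regains the shape of the original
# input, placing the gradient values at the sliced positions.
def _example_slice_grad_is_pad():
  x = array_ops.ones([5])
  y = array_ops.slice(x, [1], [3])  # x[1:4]
  upstream = array_ops.ones_like(y)
  # One zero before the slice and one after restores the shape of x.
  return array_ops.pad(upstream, [[1, 1]])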


@ops.RegisterGradient("StridedSlice")
def _StridedSliceGrad(op, grad):
  """Gradient for StridedSlice op."""
  begin = op.inputs[1]
  end = op.inputs[2]
  strides = op.inputs[3]
  # StridedSliceGrad requires `x`, `begin`, `end` and `strides` to be of the
  # same dtype so we build a shape of the same type as other args.
  # Note that the choice of `begin` for specifying `out_type` is arbitrary.
  # We could choose any of {begin|end|strides}.dtype since they are required to
  # be the same.
  x = array_ops.shape(op.inputs[0], out_type=begin.dtype)

  x_static = tensor_util.constant_value(x)
  x = x_static if x_static is not None else x
  begin_static = tensor_util.constant_value(begin)
  begin = begin_static if begin_static is not None else begin
  end_static = tensor_util.constant_value(end)
  end = end_static if end_static is not None else end
  strides_static = tensor_util.constant_value(strides)
  strides = strides_static if strides_static is not None else strides

  return array_ops.strided_slice_grad(
      x,
      begin,
      end,
      strides,
      grad,
      begin_mask=op.get_attr("begin_mask"),
      end_mask=op.get_attr("end_mask"),
      ellipsis_mask=op.get_attr("ellipsis_mask"),
      new_axis_mask=op.get_attr("new_axis_mask"),
      shrink_axis_mask=op.get_attr("shrink_axis_mask")), None, None, None


@ops.RegisterGradient("StridedSliceGrad")
def _StridedSliceGradGrad(op, grad):
  """Gradient for StridedSliceGrad op."""
  begin = op.inputs[1]
  end = op.inputs[2]
  strides = op.inputs[3]

  return None, None, None, None, array_ops.strided_slice(
      grad,
      begin,
      end,
      strides,
      begin_mask=op.get_attr("begin_mask"),
      end_mask=op.get_attr("end_mask"),
      ellipsis_mask=op.get_attr("ellipsis_mask"),
      new_axis_mask=op.get_attr("new_axis_mask"),
      shrink_axis_mask=op.get_attr("shrink_axis_mask"))


@ops.RegisterGradient("TensorStridedSliceUpdate")
def _TensorStridedSliceUpdateGrad(op, grad):  # pylint:disable=missing-function-docstring
  begin = op.inputs[1]
  end = op.inputs[2]
  strides = op.inputs[3]
  begin_mask = op.get_attr("begin_mask")
  end_mask = op.get_attr("end_mask")
  ellipsis_mask = op.get_attr("ellipsis_mask")
  new_axis_mask = op.get_attr("new_axis_mask")
  shrink_axis_mask = op.get_attr("shrink_axis_mask")
  def Apply(f, *args):
    return f(*args,
             begin_mask=begin_mask,
             end_mask=end_mask,
             shrink_axis_mask=shrink_axis_mask,
             new_axis_mask=new_axis_mask,
             ellipsis_mask=ellipsis_mask)
  dy = Apply(array_ops.strided_slice,
             grad, begin, end, strides)
  dx = Apply(array_ops.tensor_strided_slice_update,
             grad, begin, end, strides, array_ops.zeros_like(dy))
  return dx, None, None, None, dy


@ops.RegisterGradient("Split")
def _SplitGrad(op, *grads):
  return None, array_ops.concat(list(grads), op.inputs[0])


@ops.RegisterGradient("SplitV")
def _SplitVGrad(op, *grads):
  returnval = array_ops.concat(list(grads), op.inputs[2])
  returnval = [returnval] + [
      None,
  ] * (
      len(op.inputs) - 1)
  return returnval


ops.NotDifferentiable("Const")


@ops.RegisterGradient("Diag")
def _DiagGrad(_, grad):
  return array_ops.diag_part(grad)


@ops.RegisterGradient("DiagPart")
def _DiagPartGrad(_, grad):
  return array_ops.diag(grad)


@ops.RegisterGradient("MatrixDiag")
def _MatrixDiagGrad(_, grad):
  return array_ops.matrix_diag_part(grad)


@ops.RegisterGradient("MatrixDiagV2")
def _MatrixDiagV2Grad(op, grad):
  return array_ops.matrix_diag_part(
      grad, k=op.inputs[1]), None, None, None, None


@ops.RegisterGradient("MatrixDiagV3")
def _MatrixDiagV3Grad(op, grad):
  return array_ops.matrix_diag_part(
      grad, k=op.inputs[1], align=op.get_attr("align")), None, None, None, None


@ops.RegisterGradient("MatrixDiagPart")
def _MatrixDiagPartGrad(op, grad):
  matrix_shape = op.inputs[0].get_shape()[-2:]
  if matrix_shape.is_fully_defined() and matrix_shape[0] == matrix_shape[1]:
    return array_ops.matrix_diag(grad)
  else:
    return array_ops.matrix_set_diag(array_ops.zeros_like(op.inputs[0]), grad)


@ops.RegisterGradient("MatrixDiagPartV2")
def _MatrixDiagPartV2Grad(op, grad):
  """Gradient for MatrixDiagPartV2."""
  matrix_shape = op.inputs[0].get_shape()[-2:]
  if matrix_shape.is_fully_defined():
    return array_ops.matrix_diag(
        grad,
        k=op.inputs[1],
        num_rows=matrix_shape[0],
        num_cols=matrix_shape[1]), None, None
  else:
    return array_ops.matrix_set_diag(
        array_ops.zeros_like(op.inputs[0]), grad, k=op.inputs[1]), None, None


@ops.RegisterGradient("MatrixDiagPartV3")
def _MatrixDiagPartV3Grad(op, grad):
  """Gradient for MatrixDiagPartV3."""
  matrix_shape = op.inputs[0].get_shape()[-2:]
  align = op.get_attr("align")
  if matrix_shape.is_fully_defined():
    return array_ops.matrix_diag(
        grad,
        k=op.inputs[1],
        num_rows=matrix_shape[0],
        num_cols=matrix_shape[1],
        align=align), None, None
  else:
    return array_ops.matrix_set_diag(
        array_ops.zeros_like(op.inputs[0]), grad, k=op.inputs[1],
        align=align), None, None


@ops.RegisterGradient("MatrixSetDiag")
def _MatrixSetDiagGrad(op, grad):
  """Gradient for MatrixSetDiag."""
  input_shape = op.inputs[0].get_shape().merge_with(grad.get_shape())
  diag_shape = op.inputs[1].get_shape()
  batch_shape = input_shape[:-2].merge_with(diag_shape[:-1])
  matrix_shape = input_shape[-2:]
  if batch_shape.is_fully_defined() and matrix_shape.is_fully_defined():
    diag_shape = batch_shape.as_list() + [min(matrix_shape.as_list())]
  else:
    with ops.colocate_with(grad):
      grad_shape = array_ops.shape(grad)
      grad_rank = array_ops.rank(grad)
      batch_shape = array_ops.slice(grad_shape, [0], [grad_rank - 2])
      matrix_shape = array_ops.slice(grad_shape, [grad_rank - 2], [2])
      min_dim = math_ops.reduce_min(matrix_shape)
      diag_shape = array_ops.concat([batch_shape, [min_dim]], 0)
  grad_input = array_ops.matrix_set_diag(
      grad, array_ops.zeros(diag_shape, dtype=grad.dtype))
  grad_diag = array_ops.matrix_diag_part(grad)
  return (grad_input, grad_diag)
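

# Illustrative sketch (hypothetical helper): for MatrixSetDiag the gradient
# w.r.t. the input matrix is the upstream gradient with its diagonal zeroed
# out, and the gradient w.r.t. the new diagonal is the diagonal of the
# upstream gradient, mirroring _MatrixSetDiagGrad above.
def _example_matrix_set_diag_grads():
  upstream = array_ops.ones([3, 3])
  grad_input = array_ops.matrix_set_diag(
      upstream, array_ops.zeros([3], dtype=upstream.dtype))
  grad_diag = array_ops.matrix_diag_part(upstream)
  return grad_input, grad_diag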


@ops.RegisterGradient("MatrixSetDiagV2")
def _MatrixSetDiagGradV2(op, grad):
  """Gradient for MatrixSetDiagV2."""
  diag_shape = op.inputs[1].get_shape()
  if not diag_shape.is_fully_defined():
    # Need to know the values of `d_lower` and `d_upper` to infer diag_shape.
    grad_shape = array_ops.shape(grad)
    batch_shape = grad_shape[:-2]
    matrix_shape = grad_shape[-2:]
    diag_index = array_ops.reshape(op.inputs[2], [-1])  # Converts to vector.
    d_lower = diag_index[0]
    d_upper = diag_index[-1]  # Works both when len(diag_index) is 1 and 2.
    y_offset = control_flow_ops.cond(
        math_ops.less(d_upper, 0), lambda: d_upper, lambda: 0)
    x_offset = control_flow_ops.cond(
        math_ops.greater(d_lower, 0), lambda: -d_lower, lambda: 0)

    max_diag_len = math_ops.minimum(matrix_shape[0] + y_offset,
                                    matrix_shape[1] + x_offset)
    # pylint: disable=g-long-lambda
    # pyformat: disable
    postfix = control_flow_ops.cond(
        math_ops.equal(d_lower, d_upper),
        lambda: ops.convert_to_tensor([max_diag_len]),
        lambda: ops.convert_to_tensor([d_upper - d_lower + 1,
                                       max_diag_len]))
    # pyformat: enable
    # pylint: enable=g-long-lambda
    diag_shape = array_ops.concat([batch_shape, postfix], 0)

  grad_input = array_ops.matrix_set_diag(
      grad, array_ops.zeros(diag_shape, dtype=grad.dtype), k=op.inputs[2])
  grad_diag = array_ops.matrix_diag_part(grad, k=op.inputs[2])
  return (grad_input, grad_diag, None)


@ops.RegisterGradient("MatrixSetDiagV3")
def _MatrixSetDiagGradV3(op, grad):
  """Gradient for MatrixSetDiagV3."""
  diag_shape = op.inputs[1].get_shape()
  align = op.get_attr("align")
  if not diag_shape.is_fully_defined():
    # Need to know the values of `d_lower` and `d_upper` to infer diag_shape.
    grad_shape = array_ops.shape(grad)
    batch_shape = grad_shape[:-2]
    matrix_shape = grad_shape[-2:]
    diag_index = array_ops.reshape(op.inputs[2], [-1])  # Converts to vector.
    d_lower = diag_index[0]
    d_upper = diag_index[-1]  # Works both when len(diag_index) is 1 and 2.
    y_offset = control_flow_ops.cond(
        math_ops.less(d_upper, 0), lambda: d_upper, lambda: 0)
    x_offset = control_flow_ops.cond(
        math_ops.greater(d_lower, 0), lambda: -d_lower, lambda: 0)

    max_diag_len = math_ops.minimum(matrix_shape[0] + y_offset,
                                    matrix_shape[1] + x_offset)
    # pylint: disable=g-long-lambda
    # pyformat: disable
    postfix = control_flow_ops.cond(
        math_ops.equal(d_lower, d_upper),
        lambda: ops.convert_to_tensor([max_diag_len]),
        lambda: ops.convert_to_tensor([d_upper - d_lower + 1,
                                       max_diag_len]))
    # pyformat: enable
    # pylint: enable=g-long-lambda
    diag_shape = array_ops.concat([batch_shape, postfix], 0)

  grad_input = array_ops.matrix_set_diag(
      grad,
      array_ops.zeros(diag_shape, dtype=grad.dtype),
      k=op.inputs[2],
      align=align)
  grad_diag = array_ops.matrix_diag_part(grad, k=op.inputs[2], align=align)
  return (grad_input, grad_diag, None)


@ops.RegisterGradient("MatrixBandPart")
def _MatrixBandPartGrad(op, grad):
  num_lower = op.inputs[1]
  num_upper = op.inputs[2]
  return (array_ops.matrix_band_part(grad, num_lower, num_upper), None, None)


# Edit Distance has no gradient (but can be used to eval seq2seq or CTC).
ops.NotDifferentiable("EditDistance")


@ops.RegisterGradient("Fill")
def _FillGrad(_, grad):
  return None, math_ops.reduce_sum(grad)


ops.NotDifferentiable("ZerosLike")
ops.NotDifferentiable("OnesLike")


@ops.RegisterGradient("PreventGradient")
def _PreventGradientGrad(op, _):
  raise LookupError("Gradient explicitly disabled. Reason: %s" %
                    op.get_attr("message"))


def _IndexedSlicesToTensorNoWarning(indexed_slices):
  """Converts an IndexedSlices to a Tensor without sparse->dense warnings."""
  if not isinstance(indexed_slices, ops.IndexedSlices):
    # If it is not IndexedSlices, it had better be a tensor.
    return indexed_slices
  if indexed_slices.dense_shape is None:
    raise ValueError(
        "Tensor conversion requested for IndexedSlices without dense_shape: %s"
        % str(indexed_slices))
  return math_ops.unsorted_segment_sum(indexed_slices.values,
                                       indexed_slices.indices,
                                       indexed_slices.dense_shape[0])


@ops.RegisterGradient("Gather")
def _GatherGrad(op, grad):
  """Gradient for Gather op."""
  # params can be large, so colocate the shape calculation with it.
  params = op.inputs[0]
  with ops.colocate_with(params):
    params_shape = array_ops.shape(params)

  # Build appropriately shaped IndexedSlices
  indices = op.inputs[1]
  size = array_ops.expand_dims(array_ops.size(indices), 0)
  values_shape = array_ops.concat([size, params_shape[1:]], 0)
  values = array_ops.reshape(
      _IndexedSlicesToTensorNoWarning(grad), values_shape)
  indices = array_ops.reshape(indices, size)
  return [ops.IndexedSlices(values, indices, params_shape), None]
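

# Illustrative sketch (hypothetical helper): the gradient of Gather is sparse.
# Rather than scattering into a dense zeros tensor, _GatherGrad returns an
# IndexedSlices that pairs the upstream gradient rows with the gathered
# indices and the dense shape of params.
def _example_gather_grad_as_indexed_slices():
  params_shape = constant_op.constant([4, 2])
  indices = constant_op.constant([1, 3])
  upstream = array_ops.ones([2, 2])  # One gradient row per gathered row.
  return ops.IndexedSlices(upstream, indices, params_shape)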


def _GetBatchIndices(params_shape, indices, batch_dims):
  """Adds the batch offsets to the given indices and returns the results."""
  batch_indices = indices
  indices_ndims = indices.shape.ndims
  indices_dtype = indices.dtype.base_dtype
  casted_params_shape = math_ops.cast(params_shape, indices_dtype)
  accum_dim_value = array_ops.ones((), dtype=indices_dtype)
  for dim in range(batch_dims, 0, -1):
    dim_value = casted_params_shape[dim - 1]
    accum_dim_value *= casted_params_shape[dim]
    start = array_ops.zeros((), dtype=indices_dtype)
    step = array_ops.ones((), dtype=indices_dtype)
    dim_indices = math_ops.range(start, dim_value, step)
    dim_indices *= accum_dim_value
    dim_shape = array_ops.stack(
        [1] * (dim - 1) + [dim_value] + [1] * (indices_ndims - dim), axis=0)
    batch_indices += array_ops.reshape(dim_indices, dim_shape)

  return batch_indices


def _BatchGatherGrad(params_shape, values, indices, batch_dims,
                     gather_dim_size):
  """Returns the gradient of GatherV2 with batch dimensions."""

  # Axis is the first non-batch dimension.
  indices_size = array_ops.expand_dims(array_ops.size(indices), 0)
  if batch_dims:
    values_shape = array_ops.shape(values)
    # Add the batch offsets to indices and flatten the batch dimensions.
    outer_shape = values_shape[:batch_dims]
    inner_shape = values_shape[batch_dims:][1:]
    batch_size = gen_math_ops.prod(outer_shape, [0], False)
    flat_values_shape = array_ops.concat([[-1], inner_shape], 0)
    gather_dim_size *= batch_size

    indices = _GetBatchIndices(params_shape, indices, batch_dims)
    values = array_ops.reshape(
        _IndexedSlicesToTensorNoWarning(values), flat_values_shape)

  indices = array_ops.reshape(indices, indices_size)
  params_grad = math_ops.unsorted_segment_sum(values, indices, gather_dim_size)

  if batch_dims:
    # Put back the batch dimensions.
    params_grad = array_ops.reshape(
        params_grad, array_ops.concat([outer_shape, flat_values_shape], 0))

  return params_grad


@ops.RegisterGradient("GatherV2")
def _GatherV2Grad(op, grad):
  """Gradient for GatherV2 op."""
  # params can be large, so colocate the shape calculation with it.
  #
  # params can be very large for sparse models; array_ops.shape raises an
  # exception on the Windows platform when any dimension is larger than
  # int32. params_shape is not used in optimizer apply_sparse gradients,
  # so it's fine to convert it back to int32 regardless of truncation.
  params = op.inputs[0]
  with ops.colocate_with(params):
    params_shape = array_ops.shape(params, out_type=ops.dtypes.int64)
    params_shape = math_ops.cast(params_shape, dtypes.int32)

  indices = op.inputs[1]
  indices_size = array_ops.expand_dims(array_ops.size(indices), 0)
  axis = op.inputs[2]
  axis_static = tensor_util.constant_value(axis)
  batch_dims = int(op.get_attr("batch_dims"))

  if batch_dims < 0:
    batch_dims += indices.shape.ndims

  # For axis 0 gathers, build an appropriately shaped IndexedSlices.
  if axis_static == 0:
    if context.executing_eagerly():
      with ops.device(indices_size.device):
        params_tail_shape = array_ops.identity(params_shape)[1:]
    else:
      params_tail_shape = params_shape[1:]
    values_shape = array_ops.concat([indices_size, params_tail_shape], 0)
    values = array_ops.reshape(
        _IndexedSlicesToTensorNoWarning(grad), values_shape)
    indices = array_ops.reshape(indices, indices_size)
    params_grad = ops.IndexedSlices(values, indices, params_shape)
  else:
    # Handle axis by transposing the axis dimension to be the first non-batch
    # dimension, compute the gradient and transpose the result back.
    outer_shape = params_shape[:axis]
    inner_shape = params_shape[axis:][1:]
    values_shape = array_ops.concat([outer_shape, [-1], inner_shape], 0)

    values_dims = array_ops.size(values_shape)
    axis_dims = array_ops.size(outer_shape)

    outer_batches_indices = math_ops.range(batch_dims)
    batch_axis_indices = math_ops.range(batch_dims, axis_dims)
    inner_axes_indices = math_ops.range(axis_dims + 1, values_dims)

    values = array_ops.reshape(
        _IndexedSlicesToTensorNoWarning(grad), values_shape)

    # Move values[axis] up to values[batch_dims]
    transpose_dims = array_ops.concat([
        outer_batches_indices, [axis_dims], batch_axis_indices,
        inner_axes_indices
    ], 0)
    values_transpose = array_ops.transpose(values, transpose_dims)
    params_shape_transpose = array_ops.gather(params_shape, transpose_dims)

    params_grad = _BatchGatherGrad(params_shape_transpose, values_transpose,
                                   indices, batch_dims, params_shape[axis])

    # Inverts the above transpose by moving dimension batch_dims back to its
    # original position.
    invert_transpose_dims = array_ops.concat([
        outer_batches_indices, batch_axis_indices + 1, [batch_dims],
        inner_axes_indices
    ], 0)
    params_grad = array_ops.transpose(params_grad, invert_transpose_dims)

  return [params_grad, None, None]


@ops.RegisterGradient("GatherNd")
def _GatherNdGrad(op, grad):
  ref = op.inputs[0]
  indices = op.inputs[1]
  ref_shape = array_ops.shape(ref, out_type=indices.dtype)
  if indices.shape.ndims == 2 and indices.shape.dims[-1].value == 1:
    ref_grad = ops.IndexedSlices(grad, array_ops.squeeze(indices, axis=-1),
                                 ref_shape)
  else:
    ref_grad = array_ops.scatter_nd(indices, grad, ref_shape)
  return [ref_grad, None]


@ops.RegisterGradient("ResourceGatherNd")
def _ResourceGatherNdGrad(op, grad):  # pylint: disable=missing-docstring
  ref = op.inputs[0]
  indices = op.inputs[1]
  ref_shape = gen_resource_variable_ops.variable_shape(ref, indices.dtype)
  if indices.shape.ndims == 2 and indices.shape.dims[-1].value == 1:
    ref_grad = ops.IndexedSlices(grad, array_ops.squeeze(indices, axis=-1),
                                 ref_shape)
  else:
    ref_grad = array_ops.scatter_nd(indices, grad, ref_shape)
  return [ref_grad, None]


@ops.RegisterGradient("CheckNumerics")
def _CheckNumericsGrad(op, grad):
  """Gradient for check_numerics op."""
  return array_ops.check_numerics(
      grad,
      "Not a number (NaN) or infinity (Inf) values detected in gradient. %s" %
      op.get_attr("message"))


@ops.RegisterGradient("CheckNumericsV2")
def _CheckNumericsV2Grad(op, grad):
  """Gradient for check_numerics op."""
  return array_ops.check_numerics_v2(
      grad,
      "Not a number (NaN) or infinity (Inf) values detected in gradient. %s" %
      op.get_attr("message"))


@ops.RegisterGradient("PlaceholderWithDefault")
@ops.RegisterGradient("Identity")
def _IdGrad(_, grad):
  return grad


@ops.RegisterGradient("_EagerConst")
def _EagerConstGrad(_, grad):
  raise AssertionError(
      "This op should never interact with gradient APIs. Please file a bug.")


@ops.RegisterGradient("RefIdentity")
def _RefIdGrad(_, grad):
  return grad


@ops.RegisterGradient("IdentityN")
def _IdNGrad(_, *grad):
  return grad


ops.NotDifferentiable("StopGradient")


@ops.RegisterGradient("Reshape")
def _ReshapeGrad(op, grad):
  return [
      array_ops.reshape(
          _IndexedSlicesToTensorNoWarning(grad), array_ops.shape(op.inputs[0])),
      None
  ]


ops.NotDifferentiable("InvertPermutation")


def _ReshapeToInput(op, grad):
  """Reshapes the gradient to the shape of the original input."""
  return array_ops.reshape(
      _IndexedSlicesToTensorNoWarning(grad), array_ops.shape(op.inputs[0]))


@ops.RegisterGradient("ExpandDims")
def _ExpandDimsGrad(op, grad):
  return [_ReshapeToInput(op, grad), None]


@ops.RegisterGradient("Squeeze")
def _SqueezeGrad(op, grad):
  return _ReshapeToInput(op, grad)


@ops.RegisterGradient("Transpose")
def _TransposeGrad(op, grad):
  """Returns unshuffle(grad)."""
  p = op.inputs[1]
  return [array_ops.transpose(grad, array_ops.invert_permutation(p)), None]


@ops.RegisterGradient("ConjugateTranspose")
def _ConjugateTransposeGrad(op, grad):
  """Returns conj(unshuffle(grad))."""
  p = op.inputs[1]
  return [
      array_ops.transpose(
          grad, array_ops.invert_permutation(p), conjugate=True), None
  ]


ops.NotDifferentiable("Shape")

ops.NotDifferentiable("ShapeN")

ops.NotDifferentiable("Rank")

ops.NotDifferentiable("Size")


@ops.RegisterGradient("Tile")
def _TileGrad(op, grad):
  """Sum reduces grad along the tiled dimensions."""
  input_shape = array_ops.shape(op.inputs[0], out_type=op.inputs[1].dtype)
  # We interleave multiples and input_shape to get split_shape,
  # reshape grad to split_shape, and reduce along all even
  # dimensions (the tiled dimensions) to get the result
  # with shape input_shape. For example
  #   input_shape = [20, 30, 40]
  #   multiples = [2, 3, 4]
  #   split_shape = [2, 20, 3, 30, 4, 40]
  #   axes = [0, 2, 4]
  split_shape = array_ops.reshape(
      array_ops.transpose(array_ops.stack([op.inputs[1], input_shape])), [-1])
  axes = math_ops.range(0, array_ops.size(split_shape), 2)
  # Sum reduces grad along the first dimension for IndexedSlices
  if isinstance(grad, ops.IndexedSlices):
    input_shape_0 = math_ops.cast(input_shape[0], grad.indices.dtype)
    grad = math_ops.unsorted_segment_sum(
        grad.values, math_ops.mod(grad.indices, input_shape_0), input_shape_0)
    split_shape = array_ops.concat([[1], split_shape[1:]], axis=0)
  input_grad = math_ops.reduce_sum(array_ops.reshape(grad, split_shape), axes)
  # Fix shape inference
  if not context.executing_eagerly():
    input_grad.set_shape(op.inputs[0].get_shape())
  return [input_grad, None]


ops.NotDifferentiable("BroadcastGradientArgs")
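

# Illustrative sketch (hypothetical helper): the gradient of Tile sums the
# upstream gradient over the tiled copies. Reshaping to the interleaved
# split_shape described in _TileGrad and reducing over the even axes collapses
# each group of copies back onto the original input shape.
def _example_tile_grad_reduces_copies():
  x_shape = [3]
  multiples = [2]
  upstream = array_ops.ones([6])  # Gradient of tile(x, [2]), shape [2 * 3].
  # split_shape = [2, 3]; summing over axis 0 adds the two copies together.
  return math_ops.reduce_sum(
      array_ops.reshape(upstream, multiples + x_shape), axis=0)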


def _PadGrad(op, grad):
  """Gradient for Pad."""
  # Pad introduces values around the original tensor, so the gradient function
  # slices the original shape out of the gradient.
  x = op.inputs[0]
  a = op.inputs[1]  # [Rank(x), 2]
  # Takes a slice of a. The 1st column. [Rank(x), 1].
  pad_before = array_ops.slice(a, [0, 0],
                               array_ops.stack([array_ops.rank(x), 1]))
  # Make it a 1-D tensor.
  begin = array_ops.reshape(pad_before, [-1])
  sizes = array_ops.shape(x, out_type=begin.dtype)
  x_grad = array_ops.slice(grad, begin, sizes)
  if len(op.inputs) == 3:
    return x_grad, None, None
  else:
    return x_grad, None


ops.RegisterGradient("Pad")(_PadGrad)
ops.RegisterGradient("PadV2")(_PadGrad)


# ReverseSequence is just a permutation. The gradient permutes back.
@ops.RegisterGradient("ReverseSequence")
def _ReverseSequenceGrad(op, grad):
  seq_lengths = op.inputs[1]
  return [
      array_ops.reverse_sequence(
          grad,
          batch_axis=op.get_attr("batch_dim"),
          seq_axis=op.get_attr("seq_dim"),
          seq_lengths=seq_lengths), None
  ]


@ops.RegisterGradient("Reverse")
def _ReverseGrad(op, grad):
  reverse_dims = op.inputs[1]
  return gen_array_ops.reverse(grad, reverse_dims), None


@ops.RegisterGradient("ReverseV2")
def _ReverseV2Grad(op, grad):
  axis = op.inputs[1]
  return array_ops.reverse_v2(grad, axis), None


@ops.RegisterGradient("SpaceToBatch")
def _SpaceToBatchGrad(op, grad):
  # Its gradient is the opposite op: BatchToSpace.
  block_size = op.get_attr("block_size")
  return [
      array_ops.batch_to_space(grad, op.inputs[1], block_size=block_size), None
  ]


@ops.RegisterGradient("SpaceToBatchND")
def _SpaceToBatchNDGrad(op, grad):
  # Its gradient is the opposite op: BatchToSpaceND.
  return [
      array_ops.batch_to_space_nd(grad, op.inputs[1], op.inputs[2]), None, None
  ]


@ops.RegisterGradient("BatchToSpace")
def _BatchToSpaceGrad(op, grad):
  # Its gradient is the opposite op: SpaceToBatch.
  block_size = op.get_attr("block_size")
  return [
      array_ops.space_to_batch(grad, op.inputs[1], block_size=block_size), None
  ]


@ops.RegisterGradient("BatchToSpaceND")
def _BatchToSpaceNDGrad(op, grad):
  # Its gradient is the opposite op: SpaceToBatchND.
  return [
      array_ops.space_to_batch_nd(grad, op.inputs[1], op.inputs[2]), None, None
  ]


@ops.RegisterGradient("SpaceToDepth")
def _SpaceToDepthGrad(op, grad):
  # Its gradient is the opposite op: DepthToSpace.
  block_size = op.get_attr("block_size")
  data_format = op.get_attr("data_format")
  if data_format == "NCHW_VECT_C":
    raise ValueError("Cannot compute SpaceToDepth gradient with NCHW_VECT_C. "
                     "NCHW_VECT_C requires qint8 data type.")
  return array_ops.depth_to_space(grad, block_size, data_format=data_format)


@ops.RegisterGradient("DepthToSpace")
def _DepthToSpaceGrad(op, grad):
  # Its gradient is the opposite op: SpaceToDepth.
  block_size = op.get_attr("block_size")
  data_format = op.get_attr("data_format")
  if data_format == "NCHW_VECT_C":
    raise ValueError("Cannot compute DepthToSpace gradient with NCHW_VECT_C. "
                     "NCHW_VECT_C requires qint8 data type.")
  return array_ops.space_to_depth(grad, block_size, data_format=data_format)


ops.NotDifferentiable("OneHot")


@ops.RegisterGradient("MirrorPad")
def _MirrorPadGrad(op, grad):
  mode = op.get_attr("mode")
  return [gen_array_ops.mirror_pad_grad(grad, op.inputs[1], mode=mode), None]


@ops.RegisterGradient("MirrorPadGrad")
def _MirrorPadGradGrad(op, grad):
  mode = op.get_attr("mode")
  return [gen_array_ops.mirror_pad(grad, op.inputs[1], mode=mode), None]


@ops.RegisterGradient("QuantizeAndDequantize")
def _QuantizeAndDequantizeGrad(_, grad):
  return grad


@ops.RegisterGradient("QuantizeAndDequantizeV2")
def _QuantizeAndDequantizeV2Grad(_, grad):
  return [grad, None, None]


@ops.RegisterGradient("QuantizeAndDequantizeV3")
def _QuantizeAndDequantizeV3Grad(_, grad):
  # Only propagate the gradient for the unquantized input.
  return [grad, None, None, None]


@ops.RegisterGradient("ExtractImagePatches")
def _ExtractImagePatchesGrad(op, grad):
  input_bhwc = array_ops.shape(op.inputs[0], out_type=dtypes.int64)
  batch_size, rows_in, cols_in, channels = input_bhwc[0], input_bhwc[1], \
                                           input_bhwc[2], input_bhwc[3]

  # Create indices matrix for input tensor.
  # Note that 0 is preserved for padding location,
  # so indices for the input run from 1 to rows_in * cols_in.
  input_indices_num = 1 + rows_in * cols_in
  input_idx = array_ops.reshape(
      math_ops.range(1, input_indices_num, dtype=ops.dtypes.int64),
      (1, rows_in, cols_in, 1))
  input_idx_patched = gen_array_ops.extract_image_patches(
      input_idx, op.get_attr("ksizes"), op.get_attr("strides"),
      op.get_attr("rates"), op.get_attr("padding"))

  # Create indices matrix for output tensor.
  output_bhwc = array_ops.shape(op.outputs[0], out_type=dtypes.int64)
  rows_out, cols_out = output_bhwc[1], output_bhwc[2]
  _, ksize_r, ksize_c, _ = op.get_attr("ksizes")
  # Indices for output start from 0.
  output_indices_num = rows_out * cols_out * ksize_r * ksize_c
  output_idx = array_ops.reshape(
      math_ops.range(output_indices_num, dtype=ops.dtypes.int64),
      (1, rows_out, cols_out, ksize_r * ksize_c))

  # Construct mapping table for indices: (input -> output).
  idx_matrix = array_ops.concat([
      array_ops.expand_dims(input_idx_patched, axis=-1),
      array_ops.expand_dims(output_idx, axis=-1)
  ],
                                axis=-1)
  idx_map = array_ops.reshape(idx_matrix, (-1, 2))

  sp_shape = (input_indices_num, output_indices_num)
  sp_mat_full = sparse_tensor.SparseTensor(
      idx_map, array_ops.ones([output_indices_num], dtype=grad.dtype), sp_shape)
  # Remove all padding locations [0, :].
  sp_mat = sparse_ops.sparse_slice(sp_mat_full, (1, 0),
                                   (input_indices_num - 1, output_indices_num))

  grad_expanded = array_ops.transpose(
      array_ops.reshape(
          _IndexedSlicesToTensorNoWarning(grad),
          (batch_size, rows_out, cols_out, ksize_r, ksize_c, channels)),
      (1, 2, 3, 4, 0, 5))
  grad_flat = array_ops.reshape(grad_expanded, (-1, batch_size * channels))

  jac = sparse_ops.sparse_tensor_dense_matmul(sp_mat, grad_flat)

  grad_out = array_ops.reshape(jac, (rows_in, cols_in, batch_size, channels))
  grad_out = array_ops.transpose(grad_out, (2, 0, 1, 3))

  return [grad_out]


@ops.RegisterGradient("ExtractVolumePatches")
def _ExtractVolumePatchesGrad(op, grad):
  batch_size, planes_in, rows_in, cols_in, channels = [
      dim.value for dim in op.inputs[0].shape.dims
  ]
  input_bphwc = array_ops.shape(op.inputs[0])
  batch_size = input_bphwc[0]
  channels = input_bphwc[4]

  # Create indices matrix for input tensor.
  # Note that 0 is preserved for padding location,
  # so indices for the input run from 1 to planes_in * rows_in * cols_in.
  input_indices_num = 1 + planes_in * rows_in * cols_in
  input_idx = array_ops.reshape(
      math_ops.range(1, input_indices_num, dtype=ops.dtypes.int64),
      (1, planes_in, rows_in, cols_in, 1))
  input_idx_patched = gen_array_ops.extract_volume_patches(
      input_idx, op.get_attr("ksizes"), op.get_attr("strides"),
      op.get_attr("padding"))

  # Create indices matrix for output tensor.
  _, planes_out, rows_out, cols_out, _ = [
      dim.value for dim in op.outputs[0].shape.dims
  ]
  _, ksize_p, ksize_r, ksize_c, _ = op.get_attr("ksizes")
  # Indices for output start from 0.
  prc_indices_num = planes_out * rows_out * cols_out
  output_indices_num = prc_indices_num * ksize_p * ksize_r * ksize_c
  output_idx = array_ops.reshape(
      math_ops.range(output_indices_num, dtype=ops.dtypes.int64),
      (1, planes_out, rows_out, cols_out, ksize_p * ksize_r * ksize_c))

  # Construct mapping table for indices: (input -> output).
  idx_matrix = array_ops.concat([
      array_ops.expand_dims(input_idx_patched, axis=-1),
      array_ops.expand_dims(output_idx, axis=-1)
  ],
                                axis=-1)
  idx_map = array_ops.reshape(idx_matrix, (-1, 2))

  sp_shape = (input_indices_num, output_indices_num)
  sp_mat_full = sparse_tensor.SparseTensor(
      idx_map, array_ops.ones([output_indices_num], dtype=grad.dtype), sp_shape)
  # Remove all padding locations [0, :].
  sp_mat = sparse_ops.sparse_slice(sp_mat_full, (1, 0),
                                   (input_indices_num - 1, output_indices_num))

  grad_expanded = array_ops.transpose(
      array_ops.reshape(
          _IndexedSlicesToTensorNoWarning(grad),
          (batch_size, planes_out, rows_out, cols_out, ksize_p, ksize_r,
           ksize_c, channels)), (1, 2, 3, 4, 5, 6, 0, 7))
  grad_flat = array_ops.reshape(grad_expanded, (-1, batch_size * channels))

  jac = sparse_ops.sparse_tensor_dense_matmul(sp_mat, grad_flat)

  grad_out = array_ops.reshape(
      jac, (planes_in, rows_in, cols_in, batch_size, channels))
  grad_out = array_ops.transpose(grad_out, (3, 0, 1, 2, 4))

  return [grad_out]


@ops.RegisterGradient("ScatterNd")
def _ScatterNdGrad(op, grad):
  indices = op.inputs[0]
  updates_grad = array_ops.gather_nd(grad, indices)
  return [None, updates_grad, None]


@ops.RegisterGradient("TensorScatterUpdate")
def _TensorScatterUpdateGrad(op, grad):
  indices = op.inputs[1]
  updates_grad = array_ops.gather_nd(grad, indices)
  tensor_grad = array_ops.tensor_scatter_update(
      array_ops.identity(grad), indices,
      array_ops.zeros_like(op.inputs[2], dtype=grad.dtype))
  return [tensor_grad, None, updates_grad]
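

# Illustrative sketch (hypothetical helper): for TensorScatterUpdate the
# updated positions take their gradient from the updates, so the gradient
# w.r.t. the original tensor is the upstream gradient with those positions
# zeroed, and the gradient w.r.t. the updates is gathered from them.
def _example_tensor_scatter_update_grads():
  indices = constant_op.constant([[1], [3]])
  upstream = array_ops.ones([5])
  updates_grad = array_ops.gather_nd(upstream, indices)  # [1., 1.]
  tensor_grad = array_ops.tensor_scatter_update(
      array_ops.identity(upstream), indices, array_ops.zeros([2]))
  return tensor_grad, updates_grad  # tensor_grad = [1., 0., 1., 0., 1.]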


@ops.RegisterGradient("TensorScatterAdd")
def _TensorScatterAddGrad(op, grad):
  indices = op.inputs[1]
  updates_grad = array_ops.gather_nd(grad, indices)
  tensor_grad = array_ops.identity(grad)
  return [tensor_grad, None, updates_grad]


def _TensorScatterMinOrMaxGrad(op, grad):
  """Gradient for TensorScatterMin and TensorScatterMax."""
  indices = op.inputs[1]
  x = op.inputs[0]
  y = op.inputs[2]
  output = op.outputs[0]
  x_indicators = math_ops.cast(math_ops.equal(x, output), grad.dtype)
  y_output = array_ops.gather_nd(output, indices)
  y_indicators = math_ops.cast(math_ops.equal(y, y_output), grad.dtype)
  ys_indicators = array_ops.scatter_nd(indices, y_indicators,
                                       array_ops.shape(x))
  indicators = x_indicators + ys_indicators  # All elements are >= 1.
  # If there are multiple minimum or maximum elements then the gradient will be
  # divided between them.
  x_grad = grad * x_indicators / indicators
  y_grad = array_ops.gather_nd(grad / indicators, indices) * y_indicators
  return [x_grad, None, y_grad]


@ops.RegisterGradient("TensorScatterMax")
def _TensorScatterMaxGrad(op, grad):
  """Gradient for TensorScatterMax op."""
  return _TensorScatterMinOrMaxGrad(op, grad)


@ops.RegisterGradient("TensorScatterMin")
def _TensorScatterMinGrad(op, grad):
  """Gradient for TensorScatterMin op."""
  return _TensorScatterMinOrMaxGrad(op, grad)


@ops.RegisterGradient("TensorScatterSub")
def _TensorScatterSubGrad(op, grad):
  indices = op.inputs[1]
  updates_grad = array_ops.gather_nd(grad, indices)
  tensor_grad = array_ops.identity(grad)
  return [tensor_grad, None, -updates_grad]


@ops.RegisterGradient("ScatterNdNonAliasingAdd")
def _ScatterNdNonAliasingAddGrad(op, grad):
  indices = op.inputs[1]
  updates_grad = array_ops.gather_nd(grad, indices)
  return [grad, None, updates_grad]


@ops.RegisterGradient("BroadcastTo")
def _BroadcastToGrad(op, grad):
  input_value = op.inputs[0]
  broadcast_shape = op.inputs[1]
  input_value_shape = array_ops.shape(input_value)
  if not isinstance(broadcast_shape, ops.EagerTensor):
    broadcast_shape_static = tensor_shape.TensorShape(
        pywrap_tf_session.TF_TryEvaluateConstant_wrapper(
            broadcast_shape.graph._c_graph, broadcast_shape._as_tf_output()))  # pylint: disable=protected-access
    if broadcast_shape_static.is_fully_defined():
      broadcast_shape = constant_op.constant(
          broadcast_shape_static.as_list(), dtype=dtypes.int32)
  _, reduction_axes = gen_array_ops.broadcast_gradient_args(
      broadcast_shape, input_value_shape)
  updates_grad_reshaped = math_ops.reduce_sum(
      grad, axis=reduction_axes, keepdims=True)
  updates_grad = array_ops.reshape(updates_grad_reshaped, input_value_shape)
  return [updates_grad, None]
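

# Illustrative sketch (hypothetical helper): the gradient of BroadcastTo sums
# the upstream gradient over the broadcast dimensions and reshapes the result
# back to the input shape, mirroring _BroadcastToGrad above.
def _example_broadcast_to_grad_reduces():
  # Gradient of broadcast_to(x, [2, 3]) where x has shape [1, 3].
  upstream = array_ops.ones([2, 3])
  summed = math_ops.reduce_sum(upstream, axis=[0], keepdims=True)
  return array_ops.reshape(summed, [1, 3])  # Each entry is 2.0.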