# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradients for operators defined in array_ops.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.compiler.tf2xla.ops import gen_xla_ops
from tensorflow.python import pywrap_tfe
from tensorflow.python.client import pywrap_tf_session
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops


@ops.RegisterGradient("Pack")
def _PackGrad(op, grad):
  """Gradient for Pack: unstacks `grad` along the op's `axis` attribute.

  Args:
    op: The Pack operation; its "N" and "axis" attributes are read.
    grad: Gradient with respect to the op's output.

  Returns:
    The result of `array_ops.unstack(grad, num=N, axis=axis)`.
  """
  num_packed = op.get_attr("N")
  pack_axis = op.get_attr("axis")
  return array_ops.unstack(grad, num=num_packed, axis=pack_axis)


@ops.RegisterGradient("Unpack")
def _UnpackGrad(op, *grads):
  """Gradient for unpack op: stacks `grads` along the op's `axis` attribute."""
  return array_ops.stack(grads, axis=op.get_attr("axis"))


def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index):
  """Gradient for concat op.

  Args:
    op: An operation.
    grad: `Tensor` or `IndexedSlices` representing the gradients with respect to
      each output of the op.
    start_value_index: An integer index of the first value in the op.inputs.
    end_value_index: An integer used as the exclusive end of the slice of
      op.inputs that holds the values (`op.inputs[start:end]`).
    dim_index: An integer index of concat_dim or axis parameter in op.inputs.

  Returns:
    A list with one partial gradient per value input of the op, plus a `None`
    entry for the concat_dim/axis input. The `None` is placed last when
    `end_value_index <= dim_index` and first otherwise.

  Raises:
    ValueError: if `grad` is an `IndexedSlices` and concat_dim/axis (or, for a
      negative concat_dim, the rank of the first value) is not statically
      known.
    TypeError: if `grad` is neither a `Tensor` nor an `IndexedSlices`.
  """

  def _CreateDenseMaskAndBegin(sizes, concat_dim):
    """Create variables for iteratively slicing a dense gradients tensor."""
    # Since shape is 1-D, shape_of_shape = [rank-of-inputs]
    shape_of_shape = array_ops.shape(sizes[0])
    # Make a vector of length equal to the input's dimensions,
    # with 0's everywhere and 1 in the concat dim position.
    # Note: Can't use sparse_to_dense since it isn't GPU-capable (for now)
    mask = array_ops.concat([
        array_ops.fill(array_ops.expand_dims(concat_dim, 0), 0), [1],
        array_ops.fill(shape_of_shape - concat_dim - 1, 0)
    ], 0)
    begin = array_ops.fill(shape_of_shape, 0)
    return mask, begin

  def _ExtractInputShapes(inputs):
    """Extract the shapes of a set of input tensors."""
    if context.executing_eagerly():
      return array_ops.shape_n(inputs)
    sizes = []
    fully_known = True
    for x in inputs:
      input_shape = array_ops.shape(x)
      if not isinstance(input_shape,
                        ops.Tensor) or input_shape.op.type != "Const":
        fully_known = False
        break
      sizes.append(input_shape)

    if fully_known:
      return sizes
    else:
      return array_ops.shape_n(inputs)

  # Degenerate concatenation (a single value plus the dim input): the gradient
  # of the value is grad itself. Build the two-element list explicitly; adding
  # a Python list to a Tensor/IndexedSlices does not produce a gradient list.
  if len(op.inputs) == 2:
    return [grad, None] if end_value_index <= dim_index else [None, grad]

  concat_dim = op.inputs[dim_index]
  input_values = op.inputs[start_value_index:end_value_index]

  out_grads = []
  if isinstance(grad, ops.Tensor):
    if context.executing_eagerly() or isinstance(concat_dim, ops.EagerTensor):
      # Using mod here for convenience since concat_dim is already verified
      # in concat implementation to be within the allowed [-rank, rank) range.
      non_neg_concat_dim = (
          concat_dim._numpy().item(0) % input_values[0]._rank())  # pylint: disable=protected-access
      # All inputs are guaranteed to be EagerTensors in eager mode
      sizes = pywrap_tfe.TFE_Py_TensorShapeSlice(input_values,
                                                 non_neg_concat_dim)
      out_grads = array_ops.split(grad, sizes, non_neg_concat_dim)
    else:
      if constant_op.is_constant(concat_dim):
        # If concat_dim is a constant defined in a different context,
        # then we duplicate it in the current context to avoid passing it
        # through an Enter node.
        # This is a small optimization in general, but it is required when
        # compiling with XLA, as XLA needs the concat input to be folded into a
        # constant.
        grad_context = control_flow_util.GetOutputContext(grad.op)
        dim_context = control_flow_util.GetOutputContext(concat_dim.op)
        if dim_context != grad_context:
          value = tensor_util.constant_value(concat_dim)
          concat_dim = constant_op.constant(value=value, dtype=concat_dim.dtype)

      # Using mod here for convenience since concat_dim is already verified
      # in concat implementation to be within the allowed [-rank, rank) range.
      non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0])

      # Get the inputs' tensor shapes
      sizes = _ExtractInputShapes(input_values)
      # The magic number of 16 was found through benchmarking a range of sizes
      # on CPUs and a Maxwell TitanX. A speedup was seen in a large majority of
      # cases when switching implementations at N=16, but it is possible that
      # there will be a small number of performance regressions.
      if len(sizes) > 16:
        # extract the size of each input along the concat dimension
        sizes = array_ops.squeeze(
            array_ops.slice(
                array_ops.stack(sizes, axis=1), [non_neg_concat_dim, 0],
                [1, -1]))
        out_grads = array_ops.split(grad, sizes, non_neg_concat_dim)
      else:
        offset = gen_array_ops.concat_offset(non_neg_concat_dim, sizes)
        for (begin, size) in zip(offset, sizes):
          out_grads.append(array_ops.slice(grad, begin, size))
  elif isinstance(grad, ops.IndexedSlices):
    # Using mod here for convenience since concat_dim is already verified
    # in concat implementation to be within the allowed [-rank, rank) range.
    non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0])
    concat_dim_static = tensor_util.constant_value(concat_dim)
    if concat_dim_static is None:
      raise ValueError("Can only compute IndexedSlices gradient with "
                       "statically-known concat_dim")
    if concat_dim_static < 0:
      rank = tensor_util.constant_value(array_ops.rank(input_values[0]))
      if rank is None:
        raise ValueError("Can only compute IndexedSlices gradient with "
                         "negative concat_dim when first value rank is "
                         "statically-known.")
      concat_dim_static %= rank
    # Get the inputs' tensor shapes
    sizes = [array_ops.shape(x) for x in input_values]
    if concat_dim_static > 0:
      # IndexedSlices, non_neg_concat_dim > 0. Each input gets IndexedSlices
      # gradients with all the indices, but with grad.values sliced accordingly.
      # This is like the Tensor case, except shape(grad.values)[0] is not equal
      # to shape(sizes[i])[0], since only a subset of the dim-0 values are
      # stored.
      mask, begin = _CreateDenseMaskAndBegin(sizes, non_neg_concat_dim)
      for size in sizes:
        new_values = array_ops.slice(
            grad.values, begin,
            array_ops.concat([[-1], array_ops.slice(size, [1], [-1])], 0))
        out_grads.append(ops.IndexedSlices(new_values, grad.indices, size))
        # Lint complains begin = begin + ...
        begin = math_ops.add(begin, size * mask)
    else:
      # IndexedSlices, concat_dim == 0. Each input gets IndexedSlices gradients
      # only for the relevant indices.
      start = constant_op.constant(0, dtype=grad.indices.dtype)
      for size in sizes:
        size_concat_dim = array_ops.gather(size, non_neg_concat_dim)
        if size_concat_dim.dtype != grad.indices.dtype:
          size_concat_dim = math_ops.cast(
              size_concat_dim, dtype=grad.indices.dtype)
        end = start + size_concat_dim
        # Compute the 1-D Tensor of indices relevant for this input.
        indices_to_select = array_ops.squeeze(
            array_ops.where(
                math_ops.logical_and(grad.indices >= start,
                                     grad.indices < end)),
            axis=[1])
        new_indices = array_ops.gather(grad.indices, indices_to_select) - start
        new_values = array_ops.gather(grad.values, indices_to_select)
        out_grads.append(ops.IndexedSlices(new_values, new_indices, size))
        start = end
  else:
    raise TypeError("Expected Tensor or IndexedSlices, got %s" % type(grad))

  return (out_grads + [None] if end_value_index <= dim_index else [None] +
          out_grads)


@ops.RegisterGradient("Concat")
def _ConcatGrad(op, grad):
  """Gradient for Concat: input 0 is concat_dim, the values follow it."""
  return _ConcatGradHelper(
      op,
      grad,
      start_value_index=1,
      end_value_index=len(op.inputs),
      dim_index=0)


@ops.RegisterGradient("ConcatV2")
def _ConcatGradV2(op, grad):
  """Gradient for ConcatV2: the values come first, the last input is axis."""
  return _ConcatGradHelper(
      op, grad, start_value_index=0, end_value_index=-1, dim_index=-1)


ops.NotDifferentiable("ConcatOffset")


@ops.RegisterGradient("Slice")
def _SliceGrad(op, grad):
  """Gradient for Slice op.

  Args:
    op: The Slice operation; inputs are (input, begin, size).
    grad: Gradient with respect to the op's output.

  Returns:
    A 3-tuple: the gradient for the input, and `None` for begin and size.
  """
  input_vec = op.inputs[0]
  begin_vec = op.inputs[1]
  if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
    # In an XLA context, write grad into a zeros tensor shaped like the input
    # at offset begin_vec rather than building a padding tensor.
    return gen_xla_ops.xla_dynamic_update_slice(array_ops.zeros_like(input_vec),
                                                grad, begin_vec), None, None

  # Create an Nx2 padding where the first column represents how many
  # zeros are to be prepended for each dimension, and the second
  # column indicates how many zeros are appended.
  #
  # The number of zeros to append is the shape of the input
  # elementwise-subtracted by both the begin vector and sizes vector.
  #
  # Some more reshaping is needed to assemble this tensor with the
  # right dimensions.
  input_rank = array_ops.rank(input_vec)
  slice_size = array_ops.shape(op.outputs[0])
  shape = array_ops.stack([input_rank, 1])
  before_pad = array_ops.reshape(begin_vec, shape)
  after_pad = array_ops.reshape(
      array_ops.shape(input_vec) - slice_size - begin_vec, shape)
  paddings = array_ops.concat([before_pad, after_pad], 1)
  return array_ops.pad(grad, paddings), None, None


@ops.RegisterGradient("StridedSlice")
def _StridedSliceGrad(op, grad):
  """Gradient for StridedSlice op.

  Args:
    op: The StridedSlice operation; inputs are (input, begin, end, strides).
    grad: Gradient with respect to the op's output.

  Returns:
    A 4-tuple: the gradient for the input, and `None` for begin, end and
    strides.
  """
  begin = op.inputs[1]
  end = op.inputs[2]
  strides = op.inputs[3]
  # StridedSliceGrad requires `x`, `begin`, `end` and `strides` to be of the
  # same dtype so we build a shape of the same type as other args.
  # Note that the choice of `begin` for specifying `out_type` is arbitrary.
  # We could choose any of {begin|end|strides}.dtype since they are required to
  # be the same.
  x = array_ops.shape(op.inputs[0], out_type=begin.dtype)

  # Prefer the statically-known value of each argument when one is available.
  x_static = tensor_util.constant_value(x)
  x = x_static if x_static is not None else x
  begin_static = tensor_util.constant_value(begin)
  begin = begin_static if begin_static is not None else begin
  end_static = tensor_util.constant_value(end)
  end = end_static if end_static is not None else end
  strides_static = tensor_util.constant_value(strides)
  strides = strides_static if strides_static is not None else strides

  return array_ops.strided_slice_grad(
      x,
      begin,
      end,
      strides,
      grad,
      begin_mask=op.get_attr("begin_mask"),
      end_mask=op.get_attr("end_mask"),
      ellipsis_mask=op.get_attr("ellipsis_mask"),
      new_axis_mask=op.get_attr("new_axis_mask"),
      shrink_axis_mask=op.get_attr("shrink_axis_mask")), None, None, None


@ops.RegisterGradient("StridedSliceGrad")
def _StridedSliceGradGrad(op, grad):
  """Gradient for StridedSliceGrad op.

  Only the fifth input receives a gradient: `grad` strided-sliced with the
  op's own begin/end/strides inputs and mask attributes.
  """
  begin = op.inputs[1]
  end = op.inputs[2]
  strides = op.inputs[3]

  return None, None, None, None, array_ops.strided_slice(
      grad,
      begin,
      end,
      strides,
      begin_mask=op.get_attr("begin_mask"),
      end_mask=op.get_attr("end_mask"),
      ellipsis_mask=op.get_attr("ellipsis_mask"),
      new_axis_mask=op.get_attr("new_axis_mask"),
      shrink_axis_mask=op.get_attr("shrink_axis_mask"))


@ops.RegisterGradient("TensorStridedSliceUpdate")
def _TensorStridedSliceUpdateGrad(op, grad):  # pylint:disable=missing-function-docstring
  begin = op.inputs[1]
  end = op.inputs[2]
  strides = op.inputs[3]
  begin_mask = op.get_attr("begin_mask")
  end_mask = op.get_attr("end_mask")
  ellipsis_mask = op.get_attr("ellipsis_mask")
  new_axis_mask = op.get_attr("new_axis_mask")
  shrink_axis_mask = op.get_attr("shrink_axis_mask")

  def Apply(f, *args):
    # Calls `f` with the forward op's mask attributes.
    return f(*args,
             begin_mask=begin_mask,
             end_mask=end_mask,
             shrink_axis_mask=shrink_axis_mask,
             new_axis_mask=new_axis_mask,
             ellipsis_mask=ellipsis_mask)

  # Gradient for the fifth input: the strided slice of grad.
  dy = Apply(array_ops.strided_slice,
             grad, begin, end, strides)
  # Gradient for the first input: grad with that same slice replaced by zeros.
  dx = Apply(array_ops.tensor_strided_slice_update,
             grad, begin, end, strides, array_ops.zeros_like(dy))
  return dx, None, None, None, dy


@ops.RegisterGradient("Split")
def _SplitGrad(op, *grads):
  """Gradient for Split: concatenates `grads` along op.inputs[0]."""
  return None, array_ops.concat(list(grads), op.inputs[0])


@ops.RegisterGradient("SplitV")
def _SplitVGrad(op, *grads):
  """Gradient for SplitV: concatenates `grads` along op.inputs[2].

  Returns:
    A list whose first entry is the concatenated gradient and whose remaining
    `len(op.inputs) - 1` entries are `None`.
  """
  value_grad = array_ops.concat(list(grads), op.inputs[2])
  return [value_grad] + [None] * (len(op.inputs) - 1)


ops.NotDifferentiable("Const")


@ops.RegisterGradient("Diag")
def _DiagGrad(_, grad):
  """Gradient for Diag: the diagonal part of `grad`."""
  return array_ops.diag_part(grad)


@ops.RegisterGradient("DiagPart")
def _DiagPartGrad(_, grad):
  """Gradient for DiagPart: `grad` placed on a diagonal."""
  return array_ops.diag(grad)


@ops.RegisterGradient("MatrixDiag")
def _MatrixDiagGrad(_, grad):
  """Gradient for MatrixDiag: the matrix diagonal part of `grad`."""
  return array_ops.matrix_diag_part(grad)


@ops.RegisterGradient("MatrixDiagV2")
def _MatrixDiagV2Grad(op, grad):
  """Gradient for MatrixDiagV2; only the first input receives a gradient."""
  return array_ops.matrix_diag_part(
      grad, k=op.inputs[1]), None, None, None, None


@ops.RegisterGradient("MatrixDiagV3")
def _MatrixDiagV3Grad(op, grad):
  """Gradient for MatrixDiagV3; only the first input receives a gradient."""
  return array_ops.matrix_diag_part(
      grad, k=op.inputs[1], align=op.get_attr("align")), None, None, None, None


@ops.RegisterGradient("MatrixDiagPart")
def _MatrixDiagPartGrad(op, grad):
  """Gradient for MatrixDiagPart."""
  matrix_shape = op.inputs[0].get_shape()[-2:]
  if matrix_shape.is_fully_defined() and matrix_shape[0] == matrix_shape[1]:
    # Statically-known square matrix: matrix_diag is used directly.
    return array_ops.matrix_diag(grad)
  else:
    return array_ops.matrix_set_diag(array_ops.zeros_like(op.inputs[0]), grad)


@ops.RegisterGradient("MatrixDiagPartV2")
def _MatrixDiagPartV2Grad(op, grad):
  """Gradient for MatrixDiagPartV2."""
  matrix_shape = op.inputs[0].get_shape()[-2:]
  if matrix_shape.is_fully_defined():
    return array_ops.matrix_diag(
        grad,
        k=op.inputs[1],
        num_rows=matrix_shape[0],
        num_cols=matrix_shape[1]), None, None
  else:
    return array_ops.matrix_set_diag(
        array_ops.zeros_like(op.inputs[0]), grad, k=op.inputs[1]), None, None


@ops.RegisterGradient("MatrixDiagPartV3")
def _MatrixDiagPartV3Grad(op, grad):
  """Gradient for MatrixDiagPartV3."""
  matrix_shape = op.inputs[0].get_shape()[-2:]
  align = op.get_attr("align")
  if matrix_shape.is_fully_defined():
    return array_ops.matrix_diag(
        grad,
        k=op.inputs[1],
        num_rows=matrix_shape[0],
        num_cols=matrix_shape[1],
        align=align), None, None
  else:
    return array_ops.matrix_set_diag(
        array_ops.zeros_like(op.inputs[0]), grad, k=op.inputs[1],
        align=align), None, None


@ops.RegisterGradient("MatrixSetDiag")
def _MatrixSetDiagGrad(op, grad):
  """Gradient for MatrixSetDiag.

  Returns:
    A pair (grad_input, grad_diag): `grad` with its diagonal set to zeros, and
    the diagonal part of `grad`.
  """
  input_shape = op.inputs[0].get_shape().merge_with(grad.get_shape())
  diag_shape = op.inputs[1].get_shape()
  batch_shape = input_shape[:-2].merge_with(diag_shape[:-1])
  matrix_shape = input_shape[-2:]
  if batch_shape.is_fully_defined() and matrix_shape.is_fully_defined():
    diag_shape = batch_shape.as_list() + [min(matrix_shape.as_list())]
  else:
    # Shapes are not fully known statically; compute diag_shape dynamically.
    with ops.colocate_with(grad):
      grad_shape = array_ops.shape(grad)
      grad_rank = array_ops.rank(grad)
      batch_shape = array_ops.slice(grad_shape, [0], [grad_rank - 2])
      matrix_shape = array_ops.slice(grad_shape, [grad_rank - 2], [2])
      min_dim = math_ops.reduce_min(matrix_shape)
      diag_shape = array_ops.concat([batch_shape, [min_dim]], 0)
  grad_input = array_ops.matrix_set_diag(
      grad, array_ops.zeros(diag_shape, dtype=grad.dtype))
  grad_diag = array_ops.matrix_diag_part(grad)
  return (grad_input, grad_diag)


def _MatrixSetDiagDynamicDiagShape(grad, diag_index):
  """Computes the diagonal input's shape from `grad` and the diagonal index.

  Shared by the MatrixSetDiagV2 and MatrixSetDiagV3 gradients when the static
  shape of the diagonal input is not fully defined.

  Args:
    grad: Gradient with respect to the op's output.
    diag_index: The op's `k` input; reshaped to a vector whose first and last
      elements are taken as `d_lower` and `d_upper`.

  Returns:
    A 1-D tensor: the batch dimensions of `grad` followed by
    `[max_diag_len]` when `d_lower == d_upper`, otherwise by
    `[d_upper - d_lower + 1, max_diag_len]`.
  """
  # Need to know the values of `d_lower` and `d_upper` to infer diag_shape.
  grad_shape = array_ops.shape(grad)
  batch_shape = grad_shape[:-2]
  matrix_shape = grad_shape[-2:]
  diag_index = array_ops.reshape(diag_index, [-1])  # Converts to vector.
  d_lower = diag_index[0]
  d_upper = diag_index[-1]  # Works both when len(diag_index) is 1 and 2.
  y_offset = control_flow_ops.cond(
      math_ops.less(d_upper, 0), lambda: d_upper, lambda: 0)
  x_offset = control_flow_ops.cond(
      math_ops.greater(d_lower, 0), lambda: -d_lower, lambda: 0)

  max_diag_len = math_ops.minimum(matrix_shape[0] + y_offset,
                                  matrix_shape[1] + x_offset)
  # pylint: disable=g-long-lambda
  # pyformat: disable
  postfix = control_flow_ops.cond(
      math_ops.equal(d_lower, d_upper),
      lambda: ops.convert_to_tensor([max_diag_len]),
      lambda: ops.convert_to_tensor([d_upper - d_lower + 1,
                                     max_diag_len]))
  # pyformat: enable
  # pylint: enable=g-long-lambda
  return array_ops.concat([batch_shape, postfix], 0)


@ops.RegisterGradient("MatrixSetDiagV2")
def _MatrixSetDiagGradV2(op, grad):
  """Gradient for MatrixSetDiagV2.

  Returns:
    A triple (grad_input, grad_diag, None); the `k` input gets no gradient.
  """
  diag_shape = op.inputs[1].get_shape()
  if not diag_shape.is_fully_defined():
    diag_shape = _MatrixSetDiagDynamicDiagShape(grad, op.inputs[2])

  grad_input = array_ops.matrix_set_diag(
      grad, array_ops.zeros(diag_shape, dtype=grad.dtype), k=op.inputs[2])
  grad_diag = array_ops.matrix_diag_part(grad, k=op.inputs[2])
  return (grad_input, grad_diag, None)


@ops.RegisterGradient("MatrixSetDiagV3")
def _MatrixSetDiagGradV3(op, grad):
  """Gradient for MatrixSetDiagV3.

  Returns:
    A triple (grad_input, grad_diag, None); the `k` input gets no gradient.
  """
  diag_shape = op.inputs[1].get_shape()
  align = op.get_attr("align")
  if not diag_shape.is_fully_defined():
    diag_shape = _MatrixSetDiagDynamicDiagShape(grad, op.inputs[2])

  grad_input = array_ops.matrix_set_diag(
      grad,
      array_ops.zeros(diag_shape, dtype=grad.dtype),
      k=op.inputs[2],
      align=align)
  grad_diag = array_ops.matrix_diag_part(grad, k=op.inputs[2], align=align)
  return (grad_input, grad_diag, None)


@ops.RegisterGradient("MatrixBandPart")
def _MatrixBandPartGrad(op, grad):
  """Gradient for MatrixBandPart: the same band part taken from `grad`."""
  num_lower = op.inputs[1]
  num_upper = op.inputs[2]
  return (array_ops.matrix_band_part(grad, num_lower, num_upper), None, None)


# Edit Distance has no gradient (but can be used to eval seq2seq or CTC).
ops.NotDifferentiable("EditDistance")


@ops.RegisterGradient("Fill")
def _FillGrad(_, grad):
  """Gradient for Fill: `None` for the first input, sum of `grad` for the second."""
  return None, math_ops.reduce_sum(grad)


ops.NotDifferentiable("ZerosLike")
ops.NotDifferentiable("OnesLike")


@ops.RegisterGradient("PreventGradient")
def _PreventGradientGrad(op, _):
  """Always raises LookupError carrying the op's `message` attribute."""
  raise LookupError("Gradient explicitly disabled. Reason: %s" %
                    op.get_attr("message"))


def _IndexedSlicesToTensorNoWarning(indexed_slices):
  """Converts an IndexedSlices to a Tensor without sparse->dense warnings.

  Args:
    indexed_slices: An `IndexedSlices`, or any other value (returned as is).

  Returns:
    The input unchanged if it is not an `IndexedSlices`; otherwise the
    unsorted segment sum of its values over its indices, with
    `dense_shape[0]` segments.

  Raises:
    ValueError: if `indexed_slices` is an `IndexedSlices` without dense_shape.
  """
  if not isinstance(indexed_slices, ops.IndexedSlices):
    # If it is not IndexedSlices, it's better be a tensor.
    return indexed_slices
  if indexed_slices.dense_shape is None:
    raise ValueError(
        "Tensor conversion requested for IndexedSlices without dense_shape: %s"
        % str(indexed_slices))
  return math_ops.unsorted_segment_sum(indexed_slices.values,
                                       indexed_slices.indices,
                                       indexed_slices.dense_shape[0])


@ops.RegisterGradient("Gather")
def _GatherGrad(op, grad):
  """Gradient for Gather op.

  Returns:
    A list [params_grad, None] where params_grad is an `IndexedSlices` built
    from `grad`, the flattened indices and the shape of params.
  """
  # params can be large, so colocate the shape calculation with it.
  params = op.inputs[0]
  with ops.colocate_with(params):
    params_shape = array_ops.shape(params)

  # Build appropriately shaped IndexedSlices
  indices = op.inputs[1]
  size = array_ops.expand_dims(array_ops.size(indices), 0)
  values_shape = array_ops.concat([size, params_shape[1:]], 0)
  values = array_ops.reshape(
      _IndexedSlicesToTensorNoWarning(grad), values_shape)
  indices = array_ops.reshape(indices, size)
  return [ops.IndexedSlices(values, indices, params_shape), None]


def _GetBatchIndices(params_shape, indices, batch_dims):
  """Adds the batch offsets to the given indices and returns the results."""
  batch_indices = indices
  indices_ndims = indices.shape.ndims
  indices_dtype = indices.dtype.base_dtype
  casted_params_shape = math_ops.cast(params_shape, indices_dtype)
  accum_dim_value = array_ops.ones((), dtype=indices_dtype)
  # Range bounds are the same for every batch dimension; build them once.
  start = array_ops.zeros((), dtype=indices_dtype)
  step = array_ops.ones((), dtype=indices_dtype)
  for dim in range(batch_dims, 0, -1):
    dim_value = casted_params_shape[dim - 1]
    accum_dim_value *= casted_params_shape[dim]
    dim_indices = math_ops.range(start, dim_value, step)
    dim_indices *= accum_dim_value
    # Shape that is `dim_value` at position dim - 1 and 1 everywhere else.
    dim_shape = array_ops.stack(
        [1] * (dim - 1) + [dim_value] + [1] * (indices_ndims - dim), axis=0)
    batch_indices += array_ops.reshape(dim_indices, dim_shape)

  return batch_indices


def _BatchGatherGrad(params_shape, values, indices, batch_dims,
                     gather_dim_size):
  """Returns the gradient of GatherV2 with batch dimensions."""

  # Axis is the first non-batch dimension.
  indices_size = array_ops.expand_dims(array_ops.size(indices), 0)
  if batch_dims:
    values_shape = array_ops.shape(values)
    # Add the batch offsets to indices and flatten the batch dimensions.
    outer_shape = values_shape[:batch_dims]
    inner_shape = values_shape[batch_dims:][1:]
    batch_size = gen_math_ops.prod(outer_shape, [0], False)
    flat_values_shape = array_ops.concat([[-1], inner_shape], 0)
    gather_dim_size *= batch_size

    indices = _GetBatchIndices(params_shape, indices, batch_dims)
    values = array_ops.reshape(
        _IndexedSlicesToTensorNoWarning(values), flat_values_shape)

  indices = array_ops.reshape(indices, indices_size)
  params_grad = math_ops.unsorted_segment_sum(values, indices, gather_dim_size)

  if batch_dims:
    # Put back the batch dimensions.
    params_grad = array_ops.reshape(
        params_grad, array_ops.concat([outer_shape, flat_values_shape], 0))

  return params_grad


@ops.RegisterGradient("GatherV2")
def _GatherV2Grad(op, grad):
  """Gradient for GatherV2 op.

  Returns:
    A list [params_grad, None, None]. When axis is statically 0, params_grad
    is an `IndexedSlices`; otherwise it is computed by `_BatchGatherGrad`.
  """
  # params can be large, so colocate the shape calculation with it.
  #
  # params can be very large for sparse model, array_ops.shape raises
  # exception on the Windows platform when any dimension is larger than
  # int32. params_shape is not used in optimizer apply_sparse gradients,
  # so it's fine to convert it back to int32 regardless of truncation.
  params = op.inputs[0]
  with ops.colocate_with(params):
    params_shape = array_ops.shape(params, out_type=dtypes.int64)
    params_shape = math_ops.cast(params_shape, dtypes.int32)

  indices = op.inputs[1]
  indices_size = array_ops.expand_dims(array_ops.size(indices), 0)
  axis = op.inputs[2]
  axis_static = tensor_util.constant_value(axis)
  batch_dims = int(op.get_attr("batch_dims"))

  if batch_dims < 0:
    batch_dims += indices.shape.ndims

  # For axis 0 gathers, build an appropriately shaped IndexedSlices.
  if axis_static == 0:
    if context.executing_eagerly():
      with ops.device(indices_size.device):
        params_tail_shape = array_ops.identity(params_shape)[1:]
    else:
      params_tail_shape = params_shape[1:]
    values_shape = array_ops.concat([indices_size, params_tail_shape], 0)
    values = array_ops.reshape(
        _IndexedSlicesToTensorNoWarning(grad), values_shape)
    indices = array_ops.reshape(indices, indices_size)
    params_grad = ops.IndexedSlices(values, indices, params_shape)
  else:
    # Handle axis by transposing the axis dimension to be the first non-batch
    # dimension, compute the gradient and transpose the result back.
    outer_shape = params_shape[:axis]
    inner_shape = params_shape[axis:][1:]
    values_shape = array_ops.concat([outer_shape, [-1], inner_shape], 0)

    values_dims = array_ops.size(values_shape)
    axis_dims = array_ops.size(outer_shape)

    outer_batches_indices = math_ops.range(batch_dims)
    batch_axis_indices = math_ops.range(batch_dims, axis_dims)
    inner_axes_indices = math_ops.range(axis_dims + 1, values_dims)

    values = array_ops.reshape(
        _IndexedSlicesToTensorNoWarning(grad), values_shape)

    # Move values[axis] up to values[batch_dims]
    transpose_dims = array_ops.concat([
        outer_batches_indices, [axis_dims], batch_axis_indices,
        inner_axes_indices
    ], 0)
    values_transpose = array_ops.transpose(values, transpose_dims)

    params_grad = _BatchGatherGrad(params_shape, values_transpose, indices,
                                   batch_dims, params_shape[axis])

    # Inverts the above transpose by moving dimension batch_dims back to its
    # original position.
    invert_transpose_dims = array_ops.concat([
        outer_batches_indices, batch_axis_indices + 1, [batch_dims],
        inner_axes_indices
    ], 0)
    params_grad = array_ops.transpose(params_grad, invert_transpose_dims)

  return [params_grad, None, None]


@ops.RegisterGradient("GatherNd")
def _GatherNdGrad(op, grad):
  """Gradient for GatherNd.

  Returns:
    A list [ref_grad, None]. ref_grad is an `IndexedSlices` when indices is
    statically rank 2 with a last dimension of 1, and a `scatter_nd` of `grad`
    otherwise.
  """
  ref = op.inputs[0]
  indices = op.inputs[1]
  ref_shape = array_ops.shape(ref, out_type=indices.dtype)
  if indices.shape.ndims == 2 and indices.shape.dims[-1].value == 1:
    ref_grad = ops.IndexedSlices(grad, array_ops.squeeze(indices, axis=-1),
                                 ref_shape)
  else:
    ref_grad = array_ops.scatter_nd(indices, grad, ref_shape)
  return [ref_grad, None]


@ops.RegisterGradient("ResourceGatherNd")
def _ResourceGatherNdGrad(op, grad):  # pylint: disable=missing-docstring
  ref = op.inputs[0]
  indices = op.inputs[1]
  # Same as the GatherNd gradient, except the shape is read via variable_shape.
  ref_shape = gen_resource_variable_ops.variable_shape(ref, indices.dtype)
  if indices.shape.ndims == 2 and indices.shape.dims[-1].value == 1:
    ref_grad = ops.IndexedSlices(grad, array_ops.squeeze(indices, axis=-1),
                                 ref_shape)
  else:
    ref_grad = array_ops.scatter_nd(indices, grad, ref_shape)
  return [ref_grad, None]


@ops.RegisterGradient("CheckNumerics")
def _CheckNumericsGrad(op, grad):
  """Gradient for check_numerics op."""
  return array_ops.check_numerics(
      grad,
      "Not a number (NaN) or infinity (Inf) values detected in gradient. %s" %
      op.get_attr("message"))


@ops.RegisterGradient("CheckNumericsV2")
def _CheckNumericsV2Grad(op, grad):
  """Gradient for check_numerics_v2 op."""
  return array_ops.check_numerics_v2(
      grad,
      "Not a number (NaN) or infinity (Inf) values detected in gradient. %s" %
      op.get_attr("message"))


@ops.RegisterGradient("PlaceholderWithDefault")
@ops.RegisterGradient("Identity")
def _IdGrad(_, grad):
  """Gradient for Identity and PlaceholderWithDefault: passes `grad` through."""
  return grad


@ops.RegisterGradient("RefIdentity")
def _RefIdGrad(_, grad):
  """Gradient for RefIdentity: passes `grad` through."""
  return grad


@ops.RegisterGradient("IdentityN")
def _IdNGrad(_, *grad):
  """Gradient for IdentityN: passes all gradients through as a tuple."""
  return grad


ops.NotDifferentiable("StopGradient")


@ops.RegisterGradient("Reshape")
def _ReshapeGrad(op, grad):
  """Gradient for Reshape: `grad` reshaped to the input's shape, plus `None`."""
  return [_ReshapeToInput(op, grad), None]


ops.NotDifferentiable("InvertPermutation")


def _ReshapeToInput(op, grad):
  """Reshapes the gradient to the shape of the original input."""
  return array_ops.reshape(
      _IndexedSlicesToTensorNoWarning(grad), array_ops.shape(op.inputs[0]))


@ops.RegisterGradient("ExpandDims")
def _ExpandDimsGrad(op, grad):
  """Gradient for ExpandDims: `grad` reshaped to the input's shape, plus `None`."""
  return [_ReshapeToInput(op, grad), None]


@ops.RegisterGradient("Squeeze")
def _SqueezeGrad(op, grad):
  """Gradient for Squeeze: `grad` reshaped to the input's shape."""
  return _ReshapeToInput(op, grad)


@ops.RegisterGradient("Transpose")
def _TransposeGrad(op, grad):
  """Returns unshuffle(grad)."""
  p = op.inputs[1]
  return [array_ops.transpose(grad, array_ops.invert_permutation(p)), None]


@ops.RegisterGradient("ConjugateTranspose")
def _ConjugateTransposeGrad(op, grad):
  """Returns conj(unshuffle(grad))."""
  p = op.inputs[1]
  return [
      array_ops.transpose(
          grad, array_ops.invert_permutation(p), conjugate=True), None
  ]


ops.NotDifferentiable("Shape")

ops.NotDifferentiable("ShapeN")

ops.NotDifferentiable("Rank")

ops.NotDifferentiable("Size")


@ops.RegisterGradient("Tile")
def _TileGrad(op, grad):
  """Sum reduces grad along the tiled dimensions."""
  input_shape = array_ops.shape(op.inputs[0], out_type=op.inputs[1].dtype)
  # We interleave multiples and input_shape to get split_shape,
  # reshape grad to split_shape, and reduce along all even
  # dimensions (the tiled dimensions) to get the result
  # with shape input_shape. For example
  #   input_shape = [20, 30, 40]
  #   multiples = [2, 3, 4]
  #   split_shape = [2, 20, 3, 30, 4, 40]
  #   axes = [0, 2, 4]
  split_shape = array_ops.reshape(
      array_ops.transpose(array_ops.stack([op.inputs[1], input_shape])), [-1])
  axes = math_ops.range(0, array_ops.size(split_shape), 2)
  # Sum reduces grad along the first dimension for IndexedSlices
  if isinstance(grad, ops.IndexedSlices):
    input_shape_0 = math_ops.cast(input_shape[0], grad.indices.dtype)
    grad = math_ops.unsorted_segment_sum(
        grad.values, math_ops.mod(grad.indices, input_shape_0), input_shape_0)
    split_shape = array_ops.concat([[1], split_shape[1:]], axis=0)
  input_grad = math_ops.reduce_sum(array_ops.reshape(grad, split_shape), axes)
  # Fix shape inference
  if not context.executing_eagerly():
    input_grad.set_shape(op.inputs[0].get_shape())
  return [input_grad, None]


ops.NotDifferentiable("BroadcastGradientArgs")


def _PadGrad(op, grad):
  """Gradient for Pad.

  Registered for both "Pad" and "PadV2".

  Returns:
    A tuple whose first entry is the gradient for the padded input, followed
    by one `None` per remaining op input (two when the op has three inputs,
    one otherwise).
  """
  # Pad introduces values around the original tensor, so the gradient function
  # slices the original shape out of the gradient.
  x = op.inputs[0]
  a = op.inputs[1]  # [Rank(x), 2]
  # Takes a slice of a. The 1st column. [Rank(x), 1].
  pad_before = array_ops.slice(a, [0, 0],
                               array_ops.stack([array_ops.rank(x), 1]))
  # Make it a 1-D tensor.
  begin = array_ops.reshape(pad_before, [-1])
  sizes = array_ops.shape(x, out_type=begin.dtype)
  x_grad = array_ops.slice(grad, begin, sizes)
  if len(op.inputs) == 3:
    return x_grad, None, None
  else:
    return x_grad, None


ops.RegisterGradient("Pad")(_PadGrad)
ops.RegisterGradient("PadV2")(_PadGrad)


# ReverseSequence is just a permutation. The gradient permutes back.
@ops.RegisterGradient("ReverseSequence")
def _ReverseSequenceGrad(op, grad):
  """Gradient for ReverseSequence.

  Applies `reverse_sequence` to `grad` with the forward op's `batch_dim` and
  `seq_dim` attributes and its `seq_lengths` input.

  Args:
    op: The ReverseSequence operation.
    grad: Gradient with respect to the op's output.

  Returns:
    `[input_grad, None]`; `seq_lengths` receives no gradient.
  """
  seq_lengths = op.inputs[1]
  return [
      array_ops.reverse_sequence(
          grad,
          batch_axis=op.get_attr("batch_dim"),
          seq_axis=op.get_attr("seq_dim"),
          seq_lengths=seq_lengths), None
  ]


@ops.RegisterGradient("Reverse")
def _ReverseGrad(op, grad):
  """Gradient for Reverse: reverses `grad` along the same dims."""
  reverse_dims = op.inputs[1]
  return gen_array_ops.reverse(grad, reverse_dims), None


@ops.RegisterGradient("ReverseV2")
def _ReverseV2Grad(op, grad):
  """Gradient for ReverseV2: reverses `grad` along the same axis input."""
  axis = op.inputs[1]
  return array_ops.reverse_v2(grad, axis), None


@ops.RegisterGradient("SpaceToBatch")
def _SpaceToBatchGrad(op, grad):
  """Gradient for SpaceToBatch, computed with the opposite op BatchToSpace."""
  # Its gradient is the opposite op: BatchToSpace.
  block_size = op.get_attr("block_size")
  return [
      array_ops.batch_to_space(grad, op.inputs[1], block_size=block_size), None
  ]


@ops.RegisterGradient("SpaceToBatchND")
def _SpaceToBatchNDGrad(op, grad):
  """Gradient for SpaceToBatchND, computed with BatchToSpaceND."""
  # Its gradient is the opposite op: BatchToSpaceND.
  return [
      array_ops.batch_to_space_nd(grad, op.inputs[1], op.inputs[2]), None, None
  ]


@ops.RegisterGradient("BatchToSpace")
def _BatchToSpaceGrad(op, grad):
  """Gradient for BatchToSpace, computed with the opposite op SpaceToBatch."""
  # Its gradient is the opposite op: SpaceToBatch.
  block_size = op.get_attr("block_size")
  return [
      array_ops.space_to_batch(grad, op.inputs[1], block_size=block_size), None
  ]


@ops.RegisterGradient("BatchToSpaceND")
def _BatchToSpaceNDGrad(op, grad):
  """Gradient for BatchToSpaceND, computed with SpaceToBatchND."""
  # Its gradient is the opposite op: SpaceToBatchND.
  return [
      array_ops.space_to_batch_nd(grad, op.inputs[1], op.inputs[2]), None, None
  ]


@ops.RegisterGradient("SpaceToDepth")
def _SpaceToDepthGrad(op, grad):
  """Gradient for SpaceToDepth, computed with the opposite op DepthToSpace.

  Raises:
    ValueError: if the op's `data_format` attribute is NCHW_VECT_C.
  """
  # Its gradient is the opposite op: DepthToSpace.
  block_size = op.get_attr("block_size")
  data_format = op.get_attr("data_format")
  # NOTE(review): string attrs may be returned as `bytes`, in which case a
  # comparison against a `str` alone never matches; accept both spellings.
  if data_format in ("NCHW_VECT_C", b"NCHW_VECT_C"):
    raise ValueError("Cannot compute SpaceToDepth gradient with NCHW_VECT_C. "
                     "NCHW_VECT_C requires qint8 data type.")
  return array_ops.depth_to_space(grad, block_size, data_format=data_format)


@ops.RegisterGradient("DepthToSpace")
def _DepthToSpaceGrad(op, grad):
  """Gradient for DepthToSpace, computed with the opposite op SpaceToDepth.

  Raises:
    ValueError: if the op's `data_format` attribute is NCHW_VECT_C.
  """
  # Its gradient is the opposite op: SpaceToDepth.
  block_size = op.get_attr("block_size")
  data_format = op.get_attr("data_format")
  # NOTE(review): string attrs may be returned as `bytes`, in which case a
  # comparison against a `str` alone never matches; accept both spellings.
  if data_format in ("NCHW_VECT_C", b"NCHW_VECT_C"):
    raise ValueError("Cannot compute DepthToSpace gradient with NCHW_VECT_C. "
                     "NCHW_VECT_C requires qint8 data type.")
  return array_ops.space_to_depth(grad, block_size, data_format=data_format)


ops.NotDifferentiable("OneHot")


@ops.RegisterGradient("MirrorPad")
def _MirrorPadGrad(op, grad):
  """Gradient for MirrorPad, delegated to the MirrorPadGrad op."""
  mode = op.get_attr("mode")
  return [gen_array_ops.mirror_pad_grad(grad, op.inputs[1], mode=mode), None]


@ops.RegisterGradient("MirrorPadGrad")
def _MirrorPadGradGrad(op, grad):
  """Gradient for MirrorPadGrad, delegated to the MirrorPad op."""
  mode = op.get_attr("mode")
  return [gen_array_ops.mirror_pad(grad, op.inputs[1], mode=mode), None]


@ops.RegisterGradient("QuantizeAndDequantize")
def _QuantizeAndDequantizeGrad(_, grad):
  """Gradient for QuantizeAndDequantize: passes `grad` through unchanged."""
  return grad


@ops.RegisterGradient("QuantizeAndDequantizeV2")
def _QuantizeAndDequantizeV2Grad(_, grad):
  """Gradient for QuantizeAndDequantizeV2: `grad` for input 0, None for rest."""
  return [grad, None, None]


@ops.RegisterGradient("QuantizeAndDequantizeV3")
def _QuantizeAndDequantizeV3Grad(_, grad):
  """Gradient for QuantizeAndDequantizeV3: `grad` for input 0, None for rest."""
  # Only propagate the gradient for the unquantized input.
  return [grad, None, None, None]


@ops.RegisterGradient("ExtractImagePatches")
def _ExtractImagePatchesGrad(op, grad):
  """Gradient for ExtractImagePatches.

  Runs the forward op on a tensor of input position indices to learn which
  input position feeds every patch element, turns that mapping into a sparse
  (input position x output position) matrix of ones, and multiplies it with
  the flattened gradient.

  Args:
    op: The ExtractImagePatches operation.
    grad: Gradient with respect to the op's output.

  Returns:
    A single-element list holding the gradient with respect to the input.
  """
  input_bhwc = array_ops.shape(op.inputs[0], out_type=dtypes.int64)
  batch_size, rows_in, cols_in, channels = (
      input_bhwc[0], input_bhwc[1], input_bhwc[2], input_bhwc[3])

  # Create indices matrix for input tensor.
  # Note that 0 is preserved for padding location,
  # so indices for input start from 1 to 1 + rows_in * cols_in.
  input_indices_num = 1 + rows_in * cols_in
  input_idx = array_ops.reshape(
      math_ops.range(1, input_indices_num, dtype=dtypes.int64),
      (1, rows_in, cols_in, 1))
  input_idx_patched = gen_array_ops.extract_image_patches(
      input_idx, op.get_attr("ksizes"), op.get_attr("strides"),
      op.get_attr("rates"), op.get_attr("padding"))

  # Create indices matrix for output tensor.
  output_bhwc = array_ops.shape(op.outputs[0], out_type=dtypes.int64)
  rows_out, cols_out = output_bhwc[1], output_bhwc[2]
  _, ksize_r, ksize_c, _ = op.get_attr("ksizes")
  # Indices for output start from 0.
  output_indices_num = rows_out * cols_out * ksize_r * ksize_c
  output_idx = array_ops.reshape(
      math_ops.range(output_indices_num, dtype=dtypes.int64),
      (1, rows_out, cols_out, ksize_r * ksize_c))

  # Construct mapping table for indices: (input -> output).
  idx_matrix = array_ops.concat([
      array_ops.expand_dims(input_idx_patched, axis=-1),
      array_ops.expand_dims(output_idx, axis=-1)
  ],
                                axis=-1)
  idx_map = array_ops.reshape(idx_matrix, (-1, 2))

  sp_shape = (input_indices_num, output_indices_num)
  sp_mat_full = sparse_tensor.SparseTensor(
      idx_map, array_ops.ones([output_indices_num], dtype=grad.dtype), sp_shape)
  # Remove all padding locations [0, :].
  sp_mat = sparse_ops.sparse_slice(sp_mat_full, (1, 0),
                                   (input_indices_num - 1, output_indices_num))

  # Move batch next to channels so that each row of `grad_flat` corresponds to
  # one output position and each column to one (batch, channel) pair.
  grad_expanded = array_ops.transpose(
      array_ops.reshape(
          _IndexedSlicesToTensorNoWarning(grad),
          (batch_size, rows_out, cols_out, ksize_r, ksize_c, channels)),
      (1, 2, 3, 4, 0, 5))
  grad_flat = array_ops.reshape(grad_expanded, (-1, batch_size * channels))

  jac = sparse_ops.sparse_tensor_dense_matmul(sp_mat, grad_flat)

  grad_out = array_ops.reshape(jac, (rows_in, cols_in, batch_size, channels))
  grad_out = array_ops.transpose(grad_out, (2, 0, 1, 3))

  return [grad_out]


@ops.RegisterGradient("ExtractVolumePatches")
def _ExtractVolumePatchesGrad(op, grad):
  """Gradient for ExtractVolumePatches.

  Same construction as `_ExtractImagePatchesGrad` with an extra `planes`
  dimension: the forward op is run on a tensor of input position indices, the
  resulting mapping becomes a sparse (input position x output position) matrix
  of ones, and that matrix is multiplied with the flattened gradient.

  Args:
    op: The ExtractVolumePatches operation.
    grad: Gradient with respect to the op's output.

  Returns:
    A single-element list holding the gradient with respect to the input.
  """
  # Read every dimension from the dynamic shape (as _ExtractImagePatchesGrad
  # does) so the spatial dimensions need not be statically known.
  input_bphwc = array_ops.shape(op.inputs[0], out_type=dtypes.int64)
  batch_size, planes_in, rows_in, cols_in, channels = (
      input_bphwc[0], input_bphwc[1], input_bphwc[2], input_bphwc[3],
      input_bphwc[4])

  # Create indices matrix for input tensor.
  # Note that 0 is preserved for padding location,
  # so indices for input start from 1 to 1 + planes_in * rows_in * cols_in.
  input_indices_num = 1 + planes_in * rows_in * cols_in
  input_idx = array_ops.reshape(
      math_ops.range(1, input_indices_num, dtype=dtypes.int64),
      (1, planes_in, rows_in, cols_in, 1))
  input_idx_patched = gen_array_ops.extract_volume_patches(
      input_idx, op.get_attr("ksizes"), op.get_attr("strides"),
      op.get_attr("padding"))

  # Create indices matrix for output tensor.
  output_bphwc = array_ops.shape(op.outputs[0], out_type=dtypes.int64)
  planes_out, rows_out, cols_out = (
      output_bphwc[1], output_bphwc[2], output_bphwc[3])
  _, ksize_p, ksize_r, ksize_c, _ = op.get_attr("ksizes")
  # Indices for output start from 0.
  prc_indices_num = planes_out * rows_out * cols_out
  output_indices_num = prc_indices_num * ksize_p * ksize_r * ksize_c
  output_idx = array_ops.reshape(
      math_ops.range(output_indices_num, dtype=dtypes.int64),
      (1, planes_out, rows_out, cols_out, ksize_p * ksize_r * ksize_c))

  # Construct mapping table for indices: (input -> output).
  idx_matrix = array_ops.concat([
      array_ops.expand_dims(input_idx_patched, axis=-1),
      array_ops.expand_dims(output_idx, axis=-1)
  ],
                                axis=-1)
  idx_map = array_ops.reshape(idx_matrix, (-1, 2))

  sp_shape = (input_indices_num, output_indices_num)
  sp_mat_full = sparse_tensor.SparseTensor(
      idx_map, array_ops.ones([output_indices_num], dtype=grad.dtype), sp_shape)
  # Remove all padding locations [0, :].
  sp_mat = sparse_ops.sparse_slice(sp_mat_full, (1, 0),
                                   (input_indices_num - 1, output_indices_num))

  # Move batch next to channels so that each row of `grad_flat` corresponds to
  # one output position and each column to one (batch, channel) pair.
  grad_expanded = array_ops.transpose(
      array_ops.reshape(
          _IndexedSlicesToTensorNoWarning(grad),
          (batch_size, planes_out, rows_out, cols_out, ksize_p, ksize_r,
           ksize_c, channels)), (1, 2, 3, 4, 5, 6, 0, 7))
  grad_flat = array_ops.reshape(grad_expanded, (-1, batch_size * channels))

  jac = sparse_ops.sparse_tensor_dense_matmul(sp_mat, grad_flat)

  grad_out = array_ops.reshape(
      jac, (planes_in, rows_in, cols_in, batch_size, channels))
  grad_out = array_ops.transpose(grad_out, (3, 0, 1, 2, 4))

  return [grad_out]


@ops.RegisterGradient("ScatterNd")
def _ScatterNdGrad(op, grad):
  """Gradient for ScatterNd: gathers `grad` at `indices` for the updates."""
  indices = op.inputs[0]
  updates_grad = array_ops.gather_nd(grad, indices)
  return [None, updates_grad, None]


@ops.RegisterGradient("TensorScatterUpdate")
def _TensorScatterUpdateGrad(op, grad):
  """Gradient for TensorScatterUpdate.

  The updates receive `grad` gathered at `indices`; the tensor receives `grad`
  with the entries at `indices` overwritten by zeros.
  """
  indices = op.inputs[1]
  updates_grad = array_ops.gather_nd(grad, indices)
  tensor_grad = array_ops.tensor_scatter_update(
      array_ops.identity(grad), indices,
      array_ops.zeros_like(op.inputs[2], dtype=grad.dtype))
  return [tensor_grad, None, updates_grad]


@ops.RegisterGradient("TensorScatterAdd")
def _TensorScatterAddGrad(op, grad):
  """Gradient for TensorScatterAdd.

  The tensor receives `grad` unchanged; the updates receive `grad` gathered at
  `indices`.
  """
  indices = op.inputs[1]
  updates_grad = array_ops.gather_nd(grad, indices)
  tensor_grad = array_ops.identity(grad)
  return [tensor_grad, None, updates_grad]


def _TensorScatterMinOrMaxGrad(op, grad):
  """Gradient for TensorScatterMin and TensorScatterMax."""
  indices = op.inputs[1]
  x = op.inputs[0]
  y = op.inputs[2]
  output = op.outputs[0]
  # 1 where the tensor element equals the output element, 0 elsewhere.
  x_indicators = math_ops.cast(math_ops.equal(x, output), grad.dtype)
  y_output = array_ops.gather_nd(output, indices)
  # 1 where an update equals the output element it was scattered to.
  y_indicators = math_ops.cast(math_ops.equal(y, y_output), grad.dtype)
  ys_indicators = array_ops.scatter_nd(indices, y_indicators,
                                       array_ops.shape(x))
  indicators = x_indicators + ys_indicators  # All elements are >= 1.
  # If there are multiple minimum or maximum elements then the gradient will be
  # divided between them.
  x_grad = grad * x_indicators / indicators
  y_grad = array_ops.gather_nd(grad / indicators, indices) * y_indicators
  return [x_grad, None, y_grad]


@ops.RegisterGradient("TensorScatterMax")
def _TensorScatterMaxGrad(op, grad):
  """Gradient for TensorScatterMax op."""
  return _TensorScatterMinOrMaxGrad(op, grad)


@ops.RegisterGradient("TensorScatterMin")
def _TensorScatterMinGrad(op, grad):
  """Gradient for TensorScatterMin op."""
  return _TensorScatterMinOrMaxGrad(op, grad)


@ops.RegisterGradient("TensorScatterSub")
def _TensorScatterSubGrad(op, grad):
  """Gradient for TensorScatterSub.

  The tensor receives `grad` unchanged; the updates receive the negated `grad`
  gathered at `indices`.
  """
  indices = op.inputs[1]
  updates_grad = array_ops.gather_nd(grad, indices)
  tensor_grad = array_ops.identity(grad)
  return [tensor_grad, None, -updates_grad]


@ops.RegisterGradient("ScatterNdNonAliasingAdd")
def _ScatterNdNonAliasingAddGrad(op, grad):
  """Gradient for ScatterNdNonAliasingAdd.

  The input receives `grad` unchanged; the updates receive `grad` gathered at
  `indices`.
  """
  indices = op.inputs[1]
  updates_grad = array_ops.gather_nd(grad, indices)
  return [grad, None, updates_grad]


@ops.RegisterGradient("BroadcastTo")
def _BroadcastToGrad(op, grad):
  """Gradient for BroadcastTo.

  Sums `grad` over the axes reported by `broadcast_gradient_args` for the
  broadcast shape versus the input shape, then reshapes the result to the
  input's shape.

  Args:
    op: The BroadcastTo operation.
    grad: Gradient with respect to the op's output.

  Returns:
    `[input_grad, None]`; the shape input receives no gradient.
  """
  input_value = op.inputs[0]
  broadcast_shape = op.inputs[1]
  input_value_shape = array_ops.shape(input_value)
  if not isinstance(broadcast_shape, ops.EagerTensor):
    # For graph tensors, try to evaluate the shape statically and, when it is
    # fully defined, substitute an int32 constant for it.
    broadcast_shape_static = tensor_shape.TensorShape(
        pywrap_tf_session.TF_TryEvaluateConstant_wrapper(
            broadcast_shape.graph._c_graph, broadcast_shape._as_tf_output()))  # pylint: disable=protected-access
    if broadcast_shape_static.is_fully_defined():
      broadcast_shape = constant_op.constant(
          broadcast_shape_static.as_list(), dtype=dtypes.int32)
  _, reduction_axes = gen_array_ops.broadcast_gradient_args(
      broadcast_shape, input_value_shape)
  updates_grad_reshaped = math_ops.reduce_sum(
      grad, axis=reduction_axes, keepdims=True)
  updates_grad = array_ops.reshape(updates_grad_reshaped, input_value_shape)
  return [updates_grad, None]
1201