# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradients for operators defined in array_ops.py."""

from tensorflow.compiler.tf2xla.ops import gen_xla_ops
from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices as indexed_slices_lib
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops


@ops.RegisterGradient("Pack")
def _PackGrad(op, grad):
  """Gradient for pack op."""
  return array_ops.unstack(grad, num=op.get_attr("N"), axis=op.get_attr("axis"))


@ops.RegisterGradient("Unpack")
def _UnpackGrad(op, *grads):
  """Gradient for unpack op."""
  return array_ops.stack(grads, axis=op.get_attr("axis"))


def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index):
  """Gradient for concat op.

  Args:
    op: An operation.
    grad: `Tensor` or `IndexedSlices` representing the gradients with respect to
      each output of the op.
    start_value_index: An integer index of the first value in the op.inputs.
    end_value_index: An integer index of the last value in the op.inputs.
    dim_index: An integer index of concat_dim or axis parameter in op.inputs.

  Returns:
    Tensors representing the partial gradients with respect to each input
    of the op.

  Raises:
    ValueError: if concat_dim/axis is not statically known.
  """

  def _CreateDenseMaskAndBegin(sizes, concat_dim):
    """Create variables for iteratively slicing a dense gradients tensor."""
    # Since shape is 1-D, shape_of_shape = [rank-of-inputs]
    shape_of_shape = array_ops.shape(sizes[0])
    # Make a vector of length equal to the input's dimensions,
    # with 0's everywhere and 1 in the concat dim position.
    # Note: Can't use sparse_to_dense since it isn't GPU-capable (for now)
    mask = array_ops.concat([
        array_ops.zeros(
            array_ops.expand_dims(concat_dim, 0), dtype=dtypes.int32), [1],
        array_ops.zeros(shape_of_shape - concat_dim - 1, dtype=dtypes.int32)
    ], 0)
    begin = array_ops.zeros(shape_of_shape, dtype=dtypes.int32)
    return mask, begin

  def _ExtractInputShapes(inputs):
    """Extract the shapes of a set of input tensors."""
    if context.executing_eagerly():
      return array_ops.shape_n(inputs)
    sizes = []
    fully_known = True
    for x in inputs:
      input_shape = array_ops.shape(x)
      if not isinstance(input_shape,
                        ops.Tensor) or input_shape.op.type != "Const":
        fully_known = False
        break
      sizes.append(input_shape)

    if fully_known:
      return sizes
    else:
      return array_ops.shape_n(inputs)

  # Degenerate concatenation, just return grad.
  if len(op.inputs) == 2:
    return grad + [None] if end_value_index <= dim_index else [None] + grad

  concat_dim = op.inputs[dim_index]
  input_values = op.inputs[start_value_index:end_value_index]

  out_grads = []
  if isinstance(grad, ops.Tensor):
    if context.executing_eagerly() or isinstance(concat_dim, ops.EagerTensor):
      # Using mod here for convenience since concat_dim is already verified
      # in concat implementation to be within the allowed [-rank, rank) range.
      non_neg_concat_dim = (
          concat_dim._numpy().item(0) % input_values[0]._rank())  # pylint: disable=protected-access
      # All inputs are guaranteed to be EagerTensors in eager mode
      sizes = pywrap_tfe.TFE_Py_TensorShapeSlice(input_values,
                                                 non_neg_concat_dim)
      out_grads = array_ops.split(grad, sizes, non_neg_concat_dim)
    else:
      if constant_op.is_constant(concat_dim):
        # If concat_dim is a constant defined in a different context,
        # then we duplicate it in the current context to avoid passing it
        # through an Enter node.
        # This is a small optimization in general, but it is required when
        # compiling with XLA, as XLA needs the concat input to be folded into a
        # constant.
        grad_context = control_flow_util.GetOutputContext(grad.op)
        dim_context = control_flow_util.GetOutputContext(concat_dim.op)
        if dim_context != grad_context:
          value = tensor_util.constant_value(concat_dim)
          concat_dim = constant_op.constant(value=value, dtype=concat_dim.dtype)

      # Using mod here for convenience since concat_dim is already verified
      # in concat implementation to be within the allowed [-rank, rank) range.
      non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0])

      # Get the inputs' tensor shapes
      sizes = _ExtractInputShapes(input_values)
      # The magic number of 16 was found through benchmarking a range of sizes
      # on CPUs and a Maxwell TitanX.  A speedup was seen in a large majority of
      # cases when switching implementations at N=16, but it is possible that
      # there will be a small number of performance regressions.
      if len(sizes) > 16:
        # extract the size of each input along the concat dimension
        sizes = array_ops.squeeze(
            array_ops.slice(
                array_ops.stack(sizes, axis=1), [non_neg_concat_dim, 0],
                [1, -1]))
        out_grads = array_ops.split(grad, sizes, non_neg_concat_dim)
      else:
        offset = gen_array_ops.concat_offset(non_neg_concat_dim, sizes)
        for (begin, size) in zip(offset, sizes):
          out_grads.append(array_ops.slice(grad, begin, size))
  elif isinstance(grad, indexed_slices_lib.IndexedSlices):
    # Using mod here for convenience since concat_dim is already verified
    # in concat implementation to be within the allowed [-rank, rank) range.
    non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0])
    concat_dim_static = tensor_util.constant_value(concat_dim)
    if concat_dim_static is None:
      raise ValueError("Can only compute IndexedSlices gradient with "
                       "statically-known concat_dim")
    if concat_dim_static < 0:
      rank = tensor_util.constant_value(array_ops.rank(input_values[0]))
      if rank is None:
        raise ValueError("Can only compute IndexedSlices gradient with "
                         "negative concat_dim when first value rank is "
                         "statically-known.")
      concat_dim_static %= rank
    # Get the inputs' tensor shapes
    sizes = [array_ops.shape(x) for x in input_values]
    if concat_dim_static > 0:
      # IndexedSlices, non_neg_concat_dim > 0. Each input gets IndexedSlices
      # gradients with all the indices, but with grad.values sliced accordingly.
      # This is like the Tensor case, except shape(grad.values)[0] is not equal
      # to shape(sizes[i])[0], since only a subset of the dim-0 values are
      # stored.
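      #
      # Illustrative example (not in the original source): concatenating
      # values of shape [10, 3] and [10, 5] along dim 1 yields grad.values of
      # shape [k, 8]; the loop below hands the first input grad.values[:, 0:3]
      # and the second grad.values[:, 3:8], both keeping grad.indices, by
      # advancing `begin` with size * mask on each iteration.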
      mask, begin = _CreateDenseMaskAndBegin(sizes, non_neg_concat_dim)
      for size in sizes:
        new_values = array_ops.slice(
            grad.values, begin,
            array_ops.concat([[-1], array_ops.slice(size, [1], [-1])], 0))
        out_grads.append(
            indexed_slices_lib.IndexedSlices(new_values, grad.indices, size))
        # Lint complains begin = begin + ...
        begin = math_ops.add(begin, size * mask)
    else:
      # IndexedSlices, concat_dim == 0. Each input gets IndexedSlices gradients
      # only for the relevant indices.
      start = constant_op.constant(0, dtype=grad.indices.dtype)
      for size in sizes:
        size_concat_dim = array_ops.gather(size, non_neg_concat_dim)
        if size_concat_dim.dtype != grad.indices.dtype:
          size_concat_dim = math_ops.cast(
              size_concat_dim, dtype=grad.indices.dtype)
        end = start + size_concat_dim
        # Compute the 1-D Tensor of indices relevant for this input.
        indices_to_select = array_ops.squeeze(
            array_ops.where(
                math_ops.logical_and(grad.indices >= start,
                                     grad.indices < end)),
            axis=[1])
        new_indices = array_ops.gather(grad.indices, indices_to_select) - start
        new_values = array_ops.gather(grad.values, indices_to_select)
        out_grads.append(
            indexed_slices_lib.IndexedSlices(new_values, new_indices, size))
        start = end
  else:
    raise TypeError("Expected Tensor or IndexedSlices, got %s" % type(grad))

  return (out_grads + [None] if end_value_index <= dim_index else [None] +
          out_grads)


@ops.RegisterGradient("Concat")
def _ConcatGrad(op, grad):
  return _ConcatGradHelper(
      op,
      grad,
      start_value_index=1,
      end_value_index=len(op.inputs),
      dim_index=0)


@ops.RegisterGradient("ConcatV2")
def _ConcatGradV2(op, grad):
  return _ConcatGradHelper(
      op, grad, start_value_index=0, end_value_index=-1, dim_index=-1)


ops.NotDifferentiable("ConcatOffset")


@ops.RegisterGradient("Slice")
def _SliceGrad(op, grad):
  """Gradient for Slice op."""
  # Create an Nx2 padding where the first column represents how many
  # zeros are to be prepended for each dimension, and the second
  # column indicates how many zeros are appended.
  #
  # The number of zeros to append is the shape of the input
  # elementwise-subtracted by both the begin vector and sizes vector.
  #
  # Some more reshaping is needed to assemble this tensor with the
  # right dimensions.
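  #
  # Illustrative example (not in the original source): for an input of shape
  # [4, 5] sliced with begin = [1, 2] and size = [2, 2], before_pad is
  # [[1], [2]] and after_pad is [[4 - 2 - 1], [5 - 2 - 2]] = [[1], [1]], so
  # paddings = [[1, 1], [2, 1]] and the gradient is zero-padded back to [4, 5].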
  input_vec = op.inputs[0]
  begin_vec = op.inputs[1]
  input_rank = array_ops.rank(input_vec)
  slice_size = array_ops.shape(op.outputs[0])
  if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
    return gen_xla_ops.xla_dynamic_update_slice(array_ops.zeros_like(input_vec),
                                                grad, begin_vec), None, None

  shape = array_ops.stack([input_rank, 1])
  before_pad = array_ops.reshape(begin_vec, shape)
  after_pad = array_ops.reshape(
      array_ops.shape(input_vec) - slice_size - begin_vec, shape)
  paddings = array_ops.concat([before_pad, after_pad], 1)
  return array_ops.pad(grad, paddings), None, None


@ops.RegisterGradient("StridedSlice")
def _StridedSliceGrad(op, grad):
  """Gradient for StridedSlice op."""
  begin = op.inputs[1]
  end = op.inputs[2]
  strides = op.inputs[3]
  # StridedSliceGrad requires `x`, `begin`, `end` and `strides` to be of the
  # same dtype so we build a shape of the same type as other args.
  # Note that the choice of `begin` for specifying `out_type` is arbitrary.
  # We could choose any of {begin|end|strides}.dtype since they are required to
  # be the same.
  x = array_ops.shape(op.inputs[0], out_type=begin.dtype)

  x_static = tensor_util.constant_value(x)
  x = x_static if x_static is not None else x
  begin_static = tensor_util.constant_value(begin)
  begin = begin_static if begin_static is not None else begin
  end_static = tensor_util.constant_value(end)
  end = end_static if end_static is not None else end
  strides_static = tensor_util.constant_value(strides)
  strides = strides_static if strides_static is not None else strides

  return array_ops.strided_slice_grad(
      x,
      begin,
      end,
      strides,
      grad,
      begin_mask=op.get_attr("begin_mask"),
      end_mask=op.get_attr("end_mask"),
      ellipsis_mask=op.get_attr("ellipsis_mask"),
      new_axis_mask=op.get_attr("new_axis_mask"),
      shrink_axis_mask=op.get_attr("shrink_axis_mask")), None, None, None


@ops.RegisterGradient("StridedSliceGrad")
def _StridedSliceGradGrad(op, grad):
  """Gradient for StridedSliceGrad op."""
  begin = op.inputs[1]
  end = op.inputs[2]
  strides = op.inputs[3]

  return None, None, None, None, array_ops.strided_slice(
      grad,
      begin,
      end,
      strides,
      begin_mask=op.get_attr("begin_mask"),
      end_mask=op.get_attr("end_mask"),
      ellipsis_mask=op.get_attr("ellipsis_mask"),
      new_axis_mask=op.get_attr("new_axis_mask"),
      shrink_axis_mask=op.get_attr("shrink_axis_mask"))


@ops.RegisterGradient("TensorStridedSliceUpdate")
def _TensorStridedSliceUpdateGrad(op, grad):  # pylint:disable=missing-function-docstring
  begin = op.inputs[1]
  end = op.inputs[2]
  strides = op.inputs[3]
  begin_mask = op.get_attr("begin_mask")
  end_mask = op.get_attr("end_mask")
  ellipsis_mask = op.get_attr("ellipsis_mask")
  new_axis_mask = op.get_attr("new_axis_mask")
  shrink_axis_mask = op.get_attr("shrink_axis_mask")
  def Apply(f, *args):
    return f(*args,
             begin_mask=begin_mask,
             end_mask=end_mask,
             shrink_axis_mask=shrink_axis_mask,
             new_axis_mask=new_axis_mask,
             ellipsis_mask=ellipsis_mask)
  dy = Apply(array_ops.strided_slice,
             grad, begin, end, strides)
  dx = Apply(array_ops.tensor_strided_slice_update,
             grad, begin, end, strides, array_ops.zeros_like(dy))

  # The value is potentially broadcast to the shape of the strided slice, so we
  # may need to adjust dy.
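  # Illustrative example (not in the original source): if op.inputs[4] has
  # shape [1] but the updated slice has shape [3], dy starts out with shape
  # [3]; broadcast_gradient_args below yields reduction_axes = [0], so dy is
  # summed and reshaped back to the value's shape [1].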
  slice_shape = array_ops.shape(dy, out_type=begin.dtype)
  value_shape = array_ops.shape(op.inputs[4], out_type=slice_shape.dtype)

  _, reduction_axes = gen_array_ops.broadcast_gradient_args(
      slice_shape, value_shape)
  dy_reshaped = math_ops.reduce_sum(dy, axis=reduction_axes, keepdims=True)
  dy = array_ops.reshape(dy_reshaped, value_shape)

  return dx, None, None, None, dy


@ops.RegisterGradient("Split")
def _SplitGrad(op, *grads):
  return None, array_ops.concat(list(grads), op.inputs[0])


@ops.RegisterGradient("SplitV")
def _SplitVGrad(op, *grads):
  returnval = array_ops.concat(list(grads), op.inputs[2])
  returnval = [returnval] + [
      None,
  ] * (
      len(op.inputs) - 1)
  return returnval


ops.NotDifferentiable("Const")


@ops.RegisterGradient("Diag")
def _DiagGrad(_, grad):
  return array_ops.diag_part(grad)


@ops.RegisterGradient("DiagPart")
def _DiagPartGrad(_, grad):
  return array_ops.diag(grad)


@ops.RegisterGradient("MatrixDiag")
def _MatrixDiagGrad(_, grad):
  return array_ops.matrix_diag_part(grad)


@ops.RegisterGradient("MatrixDiagV2")
def _MatrixDiagV2Grad(op, grad):
  return array_ops.matrix_diag_part(
      grad, k=op.inputs[1]), None, None, None, None


@ops.RegisterGradient("MatrixDiagV3")
def _MatrixDiagV3Grad(op, grad):
  return array_ops.matrix_diag_part(
      grad, k=op.inputs[1], align=op.get_attr("align")), None, None, None, None


@ops.RegisterGradient("MatrixDiagPart")
def _MatrixDiagPartGrad(op, grad):
  matrix_shape = op.inputs[0].get_shape()[-2:]
  if matrix_shape.is_fully_defined() and matrix_shape[0] == matrix_shape[1]:
    return array_ops.matrix_diag(grad)
  else:
    return array_ops.matrix_set_diag(array_ops.zeros_like(op.inputs[0]), grad)


@ops.RegisterGradient("MatrixDiagPartV2")
def _MatrixDiagPartV2Grad(op, grad):
  """Gradient for MatrixDiagPartV2."""
  matrix_shape = op.inputs[0].get_shape()[-2:]
  if matrix_shape.is_fully_defined():
    return array_ops.matrix_diag(
        grad,
        k=op.inputs[1],
        num_rows=matrix_shape[0],
        num_cols=matrix_shape[1]), None, None
  else:
    return array_ops.matrix_set_diag(
        array_ops.zeros_like(op.inputs[0]), grad, k=op.inputs[1]), None, None


@ops.RegisterGradient("MatrixDiagPartV3")
def _MatrixDiagPartV3Grad(op, grad):
  """Gradient for MatrixDiagPartV3."""
  matrix_shape = op.inputs[0].get_shape()[-2:]
  align = op.get_attr("align")
  if matrix_shape.is_fully_defined():
    return array_ops.matrix_diag(
        grad,
        k=op.inputs[1],
        num_rows=matrix_shape[0],
        num_cols=matrix_shape[1],
        align=align), None, None
  else:
    return array_ops.matrix_set_diag(
        array_ops.zeros_like(op.inputs[0]), grad, k=op.inputs[1],
        align=align), None, None


@ops.RegisterGradient("MatrixSetDiag")
def _MatrixSetDiagGrad(op, grad):
  """Gradient for MatrixSetDiag."""
  input_shape = op.inputs[0].get_shape().merge_with(grad.get_shape())
  diag_shape = op.inputs[1].get_shape()
  batch_shape = input_shape[:-2].merge_with(diag_shape[:-1])
  matrix_shape = input_shape[-2:]
  if batch_shape.is_fully_defined() and matrix_shape.is_fully_defined():
    diag_shape = batch_shape.as_list() + [min(matrix_shape.as_list())]
  else:
    with ops.colocate_with(grad):
      grad_shape = array_ops.shape(grad)
      grad_rank = array_ops.rank(grad)
      batch_shape = array_ops.slice(grad_shape, [0], [grad_rank - 2])
      matrix_shape = array_ops.slice(grad_shape, [grad_rank - 2], [2])
      min_dim = math_ops.reduce_min(matrix_shape)
      diag_shape = array_ops.concat([batch_shape, [min_dim]], 0)
  grad_input = array_ops.matrix_set_diag(
      grad, array_ops.zeros(diag_shape, dtype=grad.dtype))
  grad_diag = array_ops.matrix_diag_part(grad)
  return (grad_input, grad_diag)


@ops.RegisterGradient("MatrixSetDiagV2")
def _MatrixSetDiagGradV2(op, grad):
  """Gradient for MatrixSetDiagV2."""
  diag_shape = op.inputs[1].get_shape()
  if not diag_shape.is_fully_defined():
    # Need to know the values of `d_lower` and `d_upper` to infer diag_shape.
    grad_shape = array_ops.shape(grad)
    batch_shape = grad_shape[:-2]
    matrix_shape = grad_shape[-2:]
    diag_index = array_ops.reshape(op.inputs[2], [-1])  # Converts to vector.
    d_lower = diag_index[0]
    d_upper = diag_index[-1]  # Works both when len(diag_index) is 1 and 2.
    y_offset = control_flow_ops.cond(
        math_ops.less(d_upper, 0), lambda: d_upper, lambda: 0)
    x_offset = control_flow_ops.cond(
        math_ops.greater(d_lower, 0), lambda: -d_lower, lambda: 0)

    max_diag_len = math_ops.minimum(matrix_shape[0] + y_offset,
                                    matrix_shape[1] + x_offset)
    # pylint: disable=g-long-lambda
    # pyformat: disable
    postfix = control_flow_ops.cond(
        math_ops.equal(d_lower, d_upper),
        lambda: ops.convert_to_tensor([max_diag_len]),
        lambda: ops.convert_to_tensor([d_upper - d_lower + 1,
                                       max_diag_len]))
    # pyformat: enable
    # pylint: enable=g-long-lambda
    diag_shape = array_ops.concat([batch_shape, postfix], 0)

  grad_input = array_ops.matrix_set_diag(
      grad, array_ops.zeros(diag_shape, dtype=grad.dtype), k=op.inputs[2])
  grad_diag = array_ops.matrix_diag_part(grad, k=op.inputs[2])
  return (grad_input, grad_diag, None)


@ops.RegisterGradient("MatrixSetDiagV3")
def _MatrixSetDiagGradV3(op, grad):
  """Gradient for MatrixSetDiagV3."""
  diag_shape = op.inputs[1].get_shape()
  align = op.get_attr("align")
  if not diag_shape.is_fully_defined():
    # Need to know the values of `d_lower` and `d_upper` to infer diag_shape.
    grad_shape = array_ops.shape(grad)
    batch_shape = grad_shape[:-2]
    matrix_shape = grad_shape[-2:]
    diag_index = array_ops.reshape(op.inputs[2], [-1])  # Converts to vector.
    d_lower = diag_index[0]
    d_upper = diag_index[-1]  # Works both when len(diag_index) is 1 and 2.
    y_offset = control_flow_ops.cond(
        math_ops.less(d_upper, 0), lambda: d_upper, lambda: 0)
    x_offset = control_flow_ops.cond(
        math_ops.greater(d_lower, 0), lambda: -d_lower, lambda: 0)

    max_diag_len = math_ops.minimum(matrix_shape[0] + y_offset,
                                    matrix_shape[1] + x_offset)
    # pylint: disable=g-long-lambda
    # pyformat: disable
    postfix = control_flow_ops.cond(
        math_ops.equal(d_lower, d_upper),
        lambda: ops.convert_to_tensor([max_diag_len]),
        lambda: ops.convert_to_tensor([d_upper - d_lower + 1,
                                       max_diag_len]))
    # pyformat: enable
    # pylint: enable=g-long-lambda
    diag_shape = array_ops.concat([batch_shape, postfix], 0)

  grad_input = array_ops.matrix_set_diag(
      grad,
      array_ops.zeros(diag_shape, dtype=grad.dtype),
      k=op.inputs[2],
      align=align)
  grad_diag = array_ops.matrix_diag_part(grad, k=op.inputs[2], align=align)
  return (grad_input, grad_diag, None)


@ops.RegisterGradient("MatrixBandPart")
def _MatrixBandPartGrad(op, grad):
  num_lower = op.inputs[1]
  num_upper = op.inputs[2]
  return (array_ops.matrix_band_part(grad, num_lower, num_upper), None, None)


# Edit Distance has no gradient (but can be used to eval seq2seq or CTC).
ops.NotDifferentiable("EditDistance")


@ops.RegisterGradient("Fill")
def _FillGrad(_, grad):
  return None, math_ops.reduce_sum(grad)


ops.NotDifferentiable("ZerosLike")
ops.NotDifferentiable("OnesLike")


@ops.RegisterGradient("PreventGradient")
def _PreventGradientGrad(op, _):
  raise LookupError("Gradient explicitly disabled. Reason: %s" %
                    op.get_attr("message"))


def _IndexedSlicesToTensorNoWarning(indexed_slices):
  """Converts an IndexedSlices to a Tensor without sparse->dense warnings."""
  if not isinstance(indexed_slices, indexed_slices_lib.IndexedSlices):
    # If it is not IndexedSlices, it had better be a Tensor.
    return indexed_slices
  if indexed_slices.dense_shape is None:
    raise ValueError(
        "Tensor conversion requested for IndexedSlices without dense_shape: %s"
        % str(indexed_slices))
  return math_ops.unsorted_segment_sum(indexed_slices.values,
                                       indexed_slices.indices,
                                       indexed_slices.dense_shape[0])


@ops.RegisterGradient("Gather")
def _GatherGrad(op, grad):
  """Gradient for Gather op."""
  # params can be large, so colocate the shape calculation with it.
  params = op.inputs[0]
  with ops.colocate_with(params):
    params_shape = array_ops.shape(params)

  # Build appropriately shaped IndexedSlices
  indices = op.inputs[1]
  size = array_ops.expand_dims(array_ops.size(indices), 0)
  values_shape = array_ops.concat([size, params_shape[1:]], 0)
  values = array_ops.reshape(
      _IndexedSlicesToTensorNoWarning(grad), values_shape)
  indices = array_ops.reshape(indices, size)
  return [indexed_slices_lib.IndexedSlices(values, indices, params_shape), None]


def _GetBatchIndices(params_shape, indices, batch_dims):
  """Adds the batch offsets to the given indices and returns the results."""
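  # Illustrative example (not in the original source): with
  # params_shape = [2, 5], batch_dims = 1 and indices = [[0, 3], [1, 4]],
  # row i receives the offset i * 5, so the result is [[0, 3], [6, 9]].
  # These flattened indices are what unsorted_segment_sum expects downstream.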
  batch_indices = indices
  indices_dtype = indices.dtype.base_dtype
  casted_params_shape = math_ops.cast(params_shape, indices_dtype)
  accum_dim_value = array_ops.ones((), dtype=indices_dtype)
  for dim in range(batch_dims, 0, -1):
    dim_value = casted_params_shape[dim - 1]
    accum_dim_value *= casted_params_shape[dim]
    start = array_ops.zeros((), dtype=indices_dtype)
    step = array_ops.ones((), dtype=indices_dtype)
    dim_indices = math_ops.range(start, dim_value, step)
    dim_indices *= accum_dim_value
    dim_shape = array_ops.concat([
        array_ops.tile([1], [dim - 1]), [dim_value],
        array_ops.tile([1], [array_ops.rank(indices) - dim])
    ], axis=0)
    batch_indices += array_ops.reshape(dim_indices, dim_shape)

  return batch_indices


def _BatchGatherGrad(params_shape, values, indices, batch_dims,
                     gather_dim_size):
  """Returns the gradient of GatherV2 with batch dimensions."""

  # Axis is the first non-batch dimension.
  indices_size = array_ops.expand_dims(array_ops.size(indices), 0)
  if batch_dims:
    values_shape = array_ops.shape(values)
    # Add the batch offsets to indices and flatten the batch dimensions.
    outer_shape = values_shape[:batch_dims]
    inner_shape = values_shape[batch_dims:][1:]
    batch_size = gen_math_ops.prod(outer_shape, [0], False)
    flat_values_shape = array_ops.concat([[-1], inner_shape], 0)
    gather_dim_size *= batch_size

    indices = _GetBatchIndices(params_shape, indices, batch_dims)
    values = array_ops.reshape(
        _IndexedSlicesToTensorNoWarning(values), flat_values_shape)

  indices = array_ops.reshape(indices, indices_size)
  params_grad = math_ops.unsorted_segment_sum(values, indices, gather_dim_size)

  if batch_dims:
    # Put back the batch dimensions.
    params_grad = array_ops.reshape(
        params_grad, array_ops.concat([outer_shape, flat_values_shape], 0))

  return params_grad


@ops.RegisterGradient("GatherV2")
def _GatherV2Grad(op, grad):
  """Gradient for GatherV2 op."""
  # params can be large, so colocate the shape calculation with it.
  #
  # params can be very large for sparse models; array_ops.shape raises an
  # exception on the Windows platform when any dimension is larger than int32.
  # params_shape is not used in optimizer apply_sparse gradients, so it's fine
  # to convert it back to int32 regardless of truncation.
  params = op.inputs[0]
  with ops.colocate_with(params):
    params_shape = array_ops.shape(params, out_type=ops.dtypes.int64)
    params_shape = math_ops.cast(params_shape, dtypes.int32)

  indices = op.inputs[1]
  indices_size = array_ops.expand_dims(array_ops.size(indices), 0)
  axis = op.inputs[2]
  axis_static = tensor_util.constant_value(axis)
  batch_dims = int(op.get_attr("batch_dims"))

  if batch_dims < 0:
    if indices.shape.ndims is None:
      raise ValueError(
          f"Currently, it is unsupported to take the gradient of tf.gather "
          f"when batch_dims < 0 and the rank of the indices is unknown. Please "
          f"pass a positive batch_dims or use tf.ensure_shape to update the "
          f"shape of indices when calling tf.gather. Got "
          f"batch_dims={batch_dims} and indices={indices}")
    batch_dims += indices.shape.ndims

  # For axis 0 gathers, build an appropriately shaped IndexedSlices.
  if axis_static == 0:
    if context.executing_eagerly():
      with ops.device(indices_size.device):
        params_tail_shape = array_ops.identity(params_shape)[1:]
    else:
      params_tail_shape = params_shape[1:]
    values_shape = array_ops.concat([indices_size, params_tail_shape], 0)
    values = array_ops.reshape(
        _IndexedSlicesToTensorNoWarning(grad), values_shape)
    indices = array_ops.reshape(indices, indices_size)
    params_grad = indexed_slices_lib.IndexedSlices(values, indices,
                                                   params_shape)
  else:
    # Handle axis by transposing the axis dimension to be the first non-batch
    # dimension, compute the gradient and transpose the result back.
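    #
    # Illustrative example (not in the original source): for params of shape
    # [4, 5, 6], axis = 1 and batch_dims = 0, grad is reshaped to [4, -1, 6],
    # transposed with perm [1, 0, 2] so the gathered dimension leads, summed
    # per index by _BatchGatherGrad, and transposed back with perm [1, 0, 2].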
    outer_shape = params_shape[:axis]
    inner_shape = params_shape[axis:][1:]
    values_shape = array_ops.concat([outer_shape, [-1], inner_shape], 0)

    values_dims = array_ops.size(values_shape)
    axis_dims = array_ops.size(outer_shape)

    outer_batches_indices = math_ops.range(batch_dims)
    batch_axis_indices = math_ops.range(batch_dims, axis_dims)
    inner_axes_indices = math_ops.range(axis_dims + 1, values_dims)

    values = array_ops.reshape(
        _IndexedSlicesToTensorNoWarning(grad), values_shape)

    # Move values[axis] up to values[batch_dims]
    transpose_dims = array_ops.concat([
        outer_batches_indices, [axis_dims], batch_axis_indices,
        inner_axes_indices
    ], 0)
    values_transpose = array_ops.transpose(values, transpose_dims)
    params_shape_transpose = array_ops.gather(params_shape, transpose_dims)

    params_grad = _BatchGatherGrad(params_shape_transpose, values_transpose,
                                   indices, batch_dims, params_shape[axis])

    # Inverts the above transpose by moving dimension batch_dims back to its
    # original position.
    invert_transpose_dims = array_ops.concat([
        outer_batches_indices, batch_axis_indices + 1, [batch_dims],
        inner_axes_indices
    ], 0)
    params_grad = array_ops.transpose(params_grad, invert_transpose_dims)

  return [params_grad, None, None]


@ops.RegisterGradient("GatherNd")
def _GatherNdGrad(op, grad):
  ref = op.inputs[0]
  indices = op.inputs[1]
  ref_shape = array_ops.shape(ref, out_type=indices.dtype)
  if indices.shape.ndims == 2 and indices.shape.dims[-1].value == 1:
    ref_grad = indexed_slices_lib.IndexedSlices(
        grad, array_ops.squeeze(indices, axis=-1), ref_shape)
  else:
    ref_grad = array_ops.scatter_nd(indices, grad, ref_shape)
  return [ref_grad, None]


@ops.RegisterGradient("ResourceGatherNd")
def _ResourceGatherNdGrad(op, grad):  # pylint: disable=missing-docstring
  ref = op.inputs[0]
  indices = op.inputs[1]
  ref_shape = gen_resource_variable_ops.variable_shape(ref, indices.dtype)
  if indices.shape.ndims == 2 and indices.shape.dims[-1].value == 1:
    ref_grad = indexed_slices_lib.IndexedSlices(
        grad, array_ops.squeeze(indices, axis=-1), ref_shape)
  else:
    ref_grad = array_ops.scatter_nd(indices, grad, ref_shape)
  return [ref_grad, None]


@ops.RegisterGradient("CheckNumerics")
def _CheckNumericsGrad(op, grad):
  """Gradient for check_numerics op."""
  return array_ops.check_numerics(
      grad,
      "Not a number (NaN) or infinity (Inf) values detected in gradient. %s" %
      op.get_attr("message"))


@ops.RegisterGradient("CheckNumericsV2")
def _CheckNumericsV2Grad(op, grad):
  """Gradient for check_numerics op."""
  return array_ops.check_numerics_v2(
      grad,
      "Not a number (NaN) or infinity (Inf) values detected in gradient. %s" %
      op.get_attr("message"))


@ops.RegisterGradient("PlaceholderWithDefault")
@ops.RegisterGradient("Identity")
def _IdGrad(_, grad):
  return grad


@ops.RegisterGradient("_EagerConst")
def _EagerConstGrad(_, grad):
  raise AssertionError(
      "This op should never interact with gradient APIs. Please file a bug.")


@ops.RegisterGradient("RefIdentity")
def _RefIdGrad(_, grad):
  return grad


@ops.RegisterGradient("IdentityN")
def _IdNGrad(_, *grad):
  return grad


ops.NotDifferentiable("StopGradient")


@ops.RegisterGradient("Reshape")
def _ReshapeGrad(op, grad):
  return [
      array_ops.reshape(
          _IndexedSlicesToTensorNoWarning(grad), array_ops.shape(op.inputs[0])),
      None
  ]


ops.NotDifferentiable("InvertPermutation")


def _ReshapeToInput(op, grad):
  """Reshapes the gradient to the shape of the original input."""
  return array_ops.reshape(
      _IndexedSlicesToTensorNoWarning(grad), array_ops.shape(op.inputs[0]))


@ops.RegisterGradient("ExpandDims")
def _ExpandDimsGrad(op, grad):
  return [_ReshapeToInput(op, grad), None]


@ops.RegisterGradient("Squeeze")
def _SqueezeGrad(op, grad):
  return _ReshapeToInput(op, grad)


@ops.RegisterGradient("Transpose")
def _TransposeGrad(op, grad):
  """Returns unshuffle(grad)."""
  p = op.inputs[1]
  return [array_ops.transpose(grad, array_ops.invert_permutation(p)), None]


@ops.RegisterGradient("ConjugateTranspose")
def _ConjugateTransposeGrad(op, grad):
  """Returns conj(unshuffle(grad))."""
  p = op.inputs[1]
  return [
      array_ops.transpose(
          grad, array_ops.invert_permutation(p), conjugate=True), None
  ]


ops.NotDifferentiable("Shape")

ops.NotDifferentiable("ShapeN")

ops.NotDifferentiable("Rank")

ops.NotDifferentiable("Size")


@ops.RegisterGradient("Tile")
def _TileGrad(op, grad):
  """Sum reduces grad along the tiled dimensions."""
  input_shape = array_ops.shape(op.inputs[0], out_type=op.inputs[1].dtype)
  # We interleave multiples and input_shape to get split_shape,
  # reshape grad to split_shape, and reduce along all even
  # dimensions (the tiled dimensions) to get the result
  # with shape input_shape.  For example
  #   input_shape = [20, 30, 40]
  #   multiples = [2, 3, 4]
  #   split_shape = [2, 20, 3, 30, 4, 40]
  #   axes = [0, 2, 4]
  split_shape = array_ops.reshape(
      array_ops.transpose(array_ops.stack([op.inputs[1], input_shape])), [-1])
  axes = math_ops.range(0, array_ops.size(split_shape), 2)
  # Sum reduces grad along the first dimension for IndexedSlices
  if isinstance(grad, indexed_slices_lib.IndexedSlices):
    input_shape_0 = math_ops.cast(input_shape[0], grad.indices.dtype)
    grad = math_ops.unsorted_segment_sum(
        grad.values, math_ops.mod(grad.indices, input_shape_0), input_shape_0)
    split_shape = array_ops.concat([[1], split_shape[1:]], axis=0)
  input_grad = math_ops.reduce_sum(array_ops.reshape(grad, split_shape), axes)
  # Fix shape inference
  if not context.executing_eagerly():
    input_grad.set_shape(op.inputs[0].get_shape())
  return [input_grad, None]


ops.NotDifferentiable("BroadcastGradientArgs")


def _PadGrad(op, grad):
  """Gradient for Pad."""
  # Pad introduces values around the original tensor, so the gradient function
  # slices the original shape out of the gradient.
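  #
  # Illustrative example (not in the original source): for x of shape [2, 3]
  # padded with paddings = [[1, 1], [2, 2]], begin becomes [1, 2] and sizes
  # [2, 3], so the gradient slice grad[1:3, 2:5] recovers x's shape.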
  x = op.inputs[0]
  a = op.inputs[1]  # [Rank(x), 2]
  # Takes a slice of a. The 1st column. [Rank(x), 1].
  pad_before = array_ops.slice(a, [0, 0],
                               array_ops.stack([array_ops.rank(x), 1]))
  # Make it a 1-D tensor.
  begin = array_ops.reshape(pad_before, [-1])
  sizes = array_ops.shape(x, out_type=begin.dtype)
  x_grad = array_ops.slice(grad, begin, sizes)
  if len(op.inputs) == 3:
    return x_grad, None, None
  else:
    return x_grad, None


ops.RegisterGradient("Pad")(_PadGrad)
ops.RegisterGradient("PadV2")(_PadGrad)


# ReverseSequence is just a permutation.  The gradient permutes back.
@ops.RegisterGradient("ReverseSequence")
def _ReverseSequenceGrad(op, grad):
  seq_lengths = op.inputs[1]
  return [
      array_ops.reverse_sequence(
          grad,
          batch_axis=op.get_attr("batch_dim"),
          seq_axis=op.get_attr("seq_dim"),
          seq_lengths=seq_lengths), None
  ]


@ops.RegisterGradient("Reverse")
def _ReverseGrad(op, grad):
  reverse_dims = op.inputs[1]
  return gen_array_ops.reverse(grad, reverse_dims), None


@ops.RegisterGradient("ReverseV2")
def _ReverseV2Grad(op, grad):
  axis = op.inputs[1]
  return array_ops.reverse_v2(grad, axis), None


@ops.RegisterGradient("SpaceToBatch")
def _SpaceToBatchGrad(op, grad):
  # Its gradient is the opposite op: BatchToSpace.
  block_size = op.get_attr("block_size")
  return [
      array_ops.batch_to_space(grad, op.inputs[1], block_size=block_size), None
  ]


@ops.RegisterGradient("SpaceToBatchND")
def _SpaceToBatchNDGrad(op, grad):
  # Its gradient is the opposite op: BatchToSpaceND.
  return [
      array_ops.batch_to_space_nd(grad, op.inputs[1], op.inputs[2]), None, None
  ]


@ops.RegisterGradient("BatchToSpace")
def _BatchToSpaceGrad(op, grad):
  # Its gradient is the opposite op: SpaceToBatch.
  block_size = op.get_attr("block_size")
  return [
      array_ops.space_to_batch(grad, op.inputs[1], block_size=block_size), None
  ]


@ops.RegisterGradient("BatchToSpaceND")
def _BatchToSpaceNDGrad(op, grad):
  # Its gradient is the opposite op: SpaceToBatchND.
  return [
      array_ops.space_to_batch_nd(grad, op.inputs[1], op.inputs[2]), None, None
  ]


@ops.RegisterGradient("SpaceToDepth")
def _SpaceToDepthGrad(op, grad):
  # Its gradient is the opposite op: DepthToSpace.
  block_size = op.get_attr("block_size")
  data_format = op.get_attr("data_format")
  if data_format == "NCHW_VECT_C":
    raise ValueError("Cannot compute SpaceToDepth gradient with NCHW_VECT_C. "
                     "NCHW_VECT_C requires qint8 data type.")
  return array_ops.depth_to_space(grad, block_size, data_format=data_format)


@ops.RegisterGradient("DepthToSpace")
def _DepthToSpaceGrad(op, grad):
  # Its gradient is the opposite op: SpaceToDepth.
  block_size = op.get_attr("block_size")
  data_format = op.get_attr("data_format")
  if data_format == "NCHW_VECT_C":
    raise ValueError("Cannot compute DepthToSpace gradient with NCHW_VECT_C. "
                     "NCHW_VECT_C requires qint8 data type.")
  return array_ops.space_to_depth(grad, block_size, data_format=data_format)


ops.NotDifferentiable("OneHot")


@ops.RegisterGradient("MirrorPad")
def _MirrorPadGrad(op, grad):
  mode = op.get_attr("mode")
  return [gen_array_ops.mirror_pad_grad(grad, op.inputs[1], mode=mode), None]


@ops.RegisterGradient("MirrorPadGrad")
def _MirrorPadGradGrad(op, grad):
  mode = op.get_attr("mode")
  return [gen_array_ops.mirror_pad(grad, op.inputs[1], mode=mode), None]


@ops.RegisterGradient("QuantizeAndDequantize")
def _QuantizeAndDequantizeGrad(_, grad):
  return grad


@ops.RegisterGradient("QuantizeAndDequantizeV2")
def _QuantizeAndDequantizeV2Grad(_, grad):
  return [grad, None, None]


@ops.RegisterGradient("QuantizeAndDequantizeV3")
def _QuantizeAndDequantizeV3Grad(_, grad):
  # Only propagate the gradient for the unquantized input.
  return [grad, None, None, None]


@ops.RegisterGradient("ExtractImagePatches")
def _ExtractImagePatchesGrad(op, grad):
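  """Gradient for ExtractImagePatches.

  Builds a sparse (input index -> output index) map between input locations
  and patch positions, then multiplies it with the flattened incoming
  gradient so input locations covered by several patches accumulate their
  contributions.
  """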
  input_bhwc = array_ops.shape(op.inputs[0], out_type=dtypes.int64)
  batch_size, rows_in, cols_in, channels = input_bhwc[0], input_bhwc[1], \
                                           input_bhwc[2], input_bhwc[3]

  # Create indices matrix for input tensor.
  # Note that 0 is reserved for padding locations,
  # so indices for the input run from 1 to rows_in * cols_in.
  input_indices_num = 1 + rows_in * cols_in
  input_idx = array_ops.reshape(
      math_ops.range(1, input_indices_num, dtype=ops.dtypes.int64),
      (1, rows_in, cols_in, 1))
  input_idx_patched = gen_array_ops.extract_image_patches(
      input_idx, op.get_attr("ksizes"), op.get_attr("strides"),
      op.get_attr("rates"), op.get_attr("padding"))

  # Create indices matrix for output tensor.
  output_bhwc = array_ops.shape(op.outputs[0], out_type=dtypes.int64)
  rows_out, cols_out = output_bhwc[1], output_bhwc[2]
  _, ksize_r, ksize_c, _ = op.get_attr("ksizes")
  # Indices for output start from 0.
  output_indices_num = rows_out * cols_out * ksize_r * ksize_c
  output_idx = array_ops.reshape(
      math_ops.range(output_indices_num, dtype=ops.dtypes.int64),
      (1, rows_out, cols_out, ksize_r * ksize_c))

  # Construct mapping table for indices: (input -> output).
  idx_matrix = array_ops.concat([
      array_ops.expand_dims(input_idx_patched, axis=-1),
      array_ops.expand_dims(output_idx, axis=-1)
  ],
                                axis=-1)
  idx_map = array_ops.reshape(idx_matrix, (-1, 2))

  sp_shape = (input_indices_num, output_indices_num)
  sp_mat_full = sparse_tensor.SparseTensor(
      idx_map, array_ops.ones([output_indices_num], dtype=grad.dtype), sp_shape)
  # Remove all padding locations [0, :].
  sp_mat = sparse_ops.sparse_slice(sp_mat_full, (1, 0),
                                   (input_indices_num - 1, output_indices_num))

  grad_expanded = array_ops.transpose(
      array_ops.reshape(
          _IndexedSlicesToTensorNoWarning(grad),
          (batch_size, rows_out, cols_out, ksize_r, ksize_c, channels)),
      (1, 2, 3, 4, 0, 5))
  grad_flat = array_ops.reshape(grad_expanded, (-1, batch_size * channels))

  jac = sparse_ops.sparse_tensor_dense_matmul(sp_mat, grad_flat)

  grad_out = array_ops.reshape(jac, (rows_in, cols_in, batch_size, channels))
  grad_out = array_ops.transpose(grad_out, (2, 0, 1, 3))

  return [grad_out]


@ops.RegisterGradient("ExtractVolumePatches")
def _ExtractVolumePatchesGrad(op, grad):
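  """Gradient for ExtractVolumePatches.

  Mirrors _ExtractImagePatchesGrad above, with an extra planes dimension: a
  sparse (input index -> output index) map is multiplied with the flattened
  incoming gradient.
  """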
  batch_size, planes_in, rows_in, cols_in, channels = [
      dim.value for dim in op.inputs[0].shape.dims
  ]
  input_bphwc = array_ops.shape(op.inputs[0])
  batch_size = input_bphwc[0]
  channels = input_bphwc[4]

  # Create indices matrix for input tensor.
  # Note that 0 is reserved for padding locations,
  # so indices for the input run from 1 to planes_in * rows_in * cols_in.
  input_indices_num = 1 + planes_in * rows_in * cols_in
  input_idx = array_ops.reshape(
      math_ops.range(1, input_indices_num, dtype=ops.dtypes.int64),
      (1, planes_in, rows_in, cols_in, 1))
  input_idx_patched = gen_array_ops.extract_volume_patches(
      input_idx, op.get_attr("ksizes"), op.get_attr("strides"),
      op.get_attr("padding"))

  # Create indices matrix for output tensor.
  _, planes_out, rows_out, cols_out, _ = [
      dim.value for dim in op.outputs[0].shape.dims
  ]
  _, ksize_p, ksize_r, ksize_c, _ = op.get_attr("ksizes")
  # Indices for output start from 0.
  prc_indices_num = planes_out * rows_out * cols_out
  output_indices_num = prc_indices_num * ksize_p * ksize_r * ksize_c
  output_idx = array_ops.reshape(
      math_ops.range(output_indices_num, dtype=ops.dtypes.int64),
      (1, planes_out, rows_out, cols_out, ksize_p * ksize_r * ksize_c))

  # Construct mapping table for indices: (input -> output).
  idx_matrix = array_ops.concat([
      array_ops.expand_dims(input_idx_patched, axis=-1),
      array_ops.expand_dims(output_idx, axis=-1)
  ],
                                axis=-1)
  idx_map = array_ops.reshape(idx_matrix, (-1, 2))

  sp_shape = (input_indices_num, output_indices_num)
  sp_mat_full = sparse_tensor.SparseTensor(
      idx_map, array_ops.ones([output_indices_num], dtype=grad.dtype), sp_shape)
  # Remove all padding locations [0, :].
  sp_mat = sparse_ops.sparse_slice(sp_mat_full, (1, 0),
                                   (input_indices_num - 1, output_indices_num))

  grad_expanded = array_ops.transpose(
      array_ops.reshape(
          _IndexedSlicesToTensorNoWarning(grad),
          (batch_size, planes_out, rows_out, cols_out, ksize_p, ksize_r,
           ksize_c, channels)), (1, 2, 3, 4, 5, 6, 0, 7))
  grad_flat = array_ops.reshape(grad_expanded, (-1, batch_size * channels))

  jac = sparse_ops.sparse_tensor_dense_matmul(sp_mat, grad_flat)

  grad_out = array_ops.reshape(
      jac, (planes_in, rows_in, cols_in, batch_size, channels))
  grad_out = array_ops.transpose(grad_out, (3, 0, 1, 2, 4))

  return [grad_out]


@ops.RegisterGradient("ScatterNd")
def _ScatterNdGrad(op, grad):
  indices = op.inputs[0]
  updates_grad = array_ops.gather_nd(grad, indices)
  return [None, updates_grad, None]


@ops.RegisterGradient("TensorScatterUpdate")
def _TensorScatterUpdateGrad(op, grad):
  indices = op.inputs[1]
  updates_grad = array_ops.gather_nd(grad, indices)
  tensor_grad = array_ops.tensor_scatter_update(
      array_ops.identity(grad), indices,
      array_ops.zeros_like(op.inputs[2], dtype=grad.dtype))
  return [tensor_grad, None, updates_grad]


@ops.RegisterGradient("TensorScatterAdd")
def _TensorScatterAddGrad(op, grad):
  indices = op.inputs[1]
  updates_grad = array_ops.gather_nd(grad, indices)
  tensor_grad = array_ops.identity(grad)
  return [tensor_grad, None, updates_grad]


def _TensorScatterMinOrMaxGrad(op, grad):
  """Gradient for TensorScatterMin and TensorScatterMax."""
  indices = op.inputs[1]
  x = op.inputs[0]
  y = op.inputs[2]
  output = op.outputs[0]
  x_indicators = math_ops.cast(math_ops.equal(x, output), grad.dtype)
  y_output = array_ops.gather_nd(output, indices)
  y_indicators = math_ops.cast(math_ops.equal(y, y_output), grad.dtype)
  ys_indicators = array_ops.scatter_nd(
      indices, y_indicators, array_ops.shape(x, out_type=indices.dtype))
  indicators = x_indicators + ys_indicators  # All elements are >= 1.
  # If there are multiple minimum or maximum elements then the gradient will be
  # divided between them.
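  # Illustrative example (not in the original source): for TensorScatterMax
  # with x = [3., 2.], indices = [[0]] and updates y = [3.], the output is
  # [3., 2.]; both x[0] and y[0] tie at the maximum, so indicators is [2., 1.]
  # and each receives half of grad[0], while x[1] keeps grad[1] in full.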
  x_grad = grad * x_indicators / indicators
  y_grad = array_ops.gather_nd(grad / indicators, indices) * y_indicators
  return [x_grad, None, y_grad]


@ops.RegisterGradient("TensorScatterMax")
def _TensorScatterMaxGrad(op, grad):
  """Gradient for TensorScatterMax op."""
  return _TensorScatterMinOrMaxGrad(op, grad)


@ops.RegisterGradient("TensorScatterMin")
def _TensorScatterMinGrad(op, grad):
  """Gradient for TensorScatterMin op."""
  return _TensorScatterMinOrMaxGrad(op, grad)


@ops.RegisterGradient("TensorScatterSub")
def _TensorScatterSubGrad(op, grad):
  indices = op.inputs[1]
  updates_grad = array_ops.gather_nd(grad, indices)
  tensor_grad = array_ops.identity(grad)
  return [tensor_grad, None, -updates_grad]


@ops.RegisterGradient("ScatterNdNonAliasingAdd")
def _ScatterNdNonAliasingAddGrad(op, grad):
  indices = op.inputs[1]
  updates_grad = array_ops.gather_nd(grad, indices)
  return [grad, None, updates_grad]


@ops.RegisterGradient("BroadcastTo")
def _BroadcastToGrad(op, grad):
  input_value = op.inputs[0]
  broadcast_shape = op.inputs[1]
  shape_dtype = dtypes.int32
  if isinstance(broadcast_shape, ops.Tensor):
    shape_dtype = broadcast_shape.dtype

  input_value_shape = array_ops.shape(input_value, out_type=shape_dtype)
  if not isinstance(broadcast_shape, ops.EagerTensor):
    broadcast_shape_static = tensor_shape.TensorShape(
        tensor_util.try_evaluate_constant(broadcast_shape))
    if broadcast_shape_static.is_fully_defined():
      broadcast_shape = constant_op.constant(
          broadcast_shape_static.as_list(), dtype=shape_dtype)
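  # Illustrative example (not in the original source): broadcasting an input
  # of shape [3] to [2, 3] makes reduction_axes == [0] below, so grad is
  # summed over axis 0 and reshaped back to [3].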
  _, reduction_axes = gen_array_ops.broadcast_gradient_args(
      broadcast_shape, input_value_shape)
  updates_grad_reshaped = math_ops.reduce_sum(
      grad, axis=reduction_axes, keepdims=True)
  updates_grad = array_ops.reshape(updates_grad_reshaped, input_value_shape)
  return [updates_grad, None]