• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Gradients for operators defined in array_ops.py."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.compiler.tf2xla.ops import gen_xla_ops
22from tensorflow.python import pywrap_tfe
23from tensorflow.python.client import pywrap_tf_session
24from tensorflow.python.eager import context
25from tensorflow.python.framework import constant_op
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import ops
28from tensorflow.python.framework import sparse_tensor
29from tensorflow.python.framework import tensor_shape
30from tensorflow.python.framework import tensor_util
31from tensorflow.python.ops import array_ops
32from tensorflow.python.ops import control_flow_ops
33from tensorflow.python.ops import control_flow_util
34from tensorflow.python.ops import gen_array_ops
35from tensorflow.python.ops import gen_math_ops
36from tensorflow.python.ops import gen_resource_variable_ops
37from tensorflow.python.ops import math_ops
38from tensorflow.python.ops import sparse_ops
39
40
@ops.RegisterGradient("Pack")
def _PackGrad(op, grad):
  """Computes the gradient of `Pack` by unstacking the incoming gradient.

  Args:
    op: The forward `Pack` operation; its `N` and `axis` attributes are read.
    grad: Gradient with respect to the packed output.

  Returns:
    The result of `array_ops.unstack` applied to `grad` with `num=N` along the
    op's `axis`.
  """
  num_inputs = op.get_attr("N")
  pack_axis = op.get_attr("axis")
  return array_ops.unstack(grad, num=num_inputs, axis=pack_axis)
45
46
@ops.RegisterGradient("Unpack")
def _UnpackGrad(op, *grads):
  """Computes the gradient of `Unpack` by stacking the per-output gradients.

  Args:
    op: The forward `Unpack` operation; its `axis` attribute is read.
    *grads: Gradients with respect to each output of the op.

  Returns:
    The result of `array_ops.stack` applied to `grads` along the op's `axis`.
  """
  unpack_axis = op.get_attr("axis")
  return array_ops.stack(grads, axis=unpack_axis)
51
52
def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index):
  """Gradient for concat op.

  Args:
    op: An operation.
    grad: `Tensor` or `IndexedSlices` representing the gradients with respect to
      each output of the op.
    start_value_index: An integer index of the first value in the op.inputs.
    end_value_index: An integer index one past the last value in the op.inputs
      (it is used as the exclusive end of a slice of op.inputs).
    dim_index: An integer index of concat_dim or axis parameter in op.inputs.

  Returns:
    A list holding one partial gradient per value input of the op plus a single
    `None` for the concat_dim/axis input. The `None` is placed last when
    `end_value_index <= dim_index` and first otherwise.

  Raises:
    ValueError: if `grad` is an `IndexedSlices` and concat_dim/axis is not
      statically known, or it is negative while the rank of the first value is
      not statically known.
    TypeError: if `grad` is neither a `Tensor` nor an `IndexedSlices`.
  """

  def _CreateDenseMaskAndBegin(sizes, concat_dim):
    """Create variables for iteratively slicing a dense gradients tensor."""
    # Since shape is 1-D, shape_of_shape = [rank-of-inputs]
    shape_of_shape = array_ops.shape(sizes[0])
    # Make a vector of length equal to the input's dimensions,
    # with 0's everywhere and 1 in the concat dim position.
    # Note: Can't use sparse_to_dense since it isn't GPU-capable (for now)
    mask = array_ops.concat([
        array_ops.zeros(
            array_ops.expand_dims(concat_dim, 0), dtype=dtypes.int32), [1],
        array_ops.zeros(shape_of_shape - concat_dim - 1, dtype=dtypes.int32)
    ], 0)
    begin = array_ops.zeros(shape_of_shape, dtype=dtypes.int32)
    return mask, begin

  def _ExtractInputShapes(inputs):
    """Extract the shapes of a set of input tensors."""
    if context.executing_eagerly():
      return array_ops.shape_n(inputs)
    sizes = []
    fully_known = True
    for x in inputs:
      input_shape = array_ops.shape(x)
      # Only keep the per-input shapes when every one of them is a graph
      # constant; otherwise fall back to a single shape_n call below.
      if not isinstance(input_shape,
                        ops.Tensor) or input_shape.op.type != "Const":
        fully_known = False
        break
      sizes.append(input_shape)

    if fully_known:
      return sizes
    else:
      return array_ops.shape_n(inputs)

  # Degenerate concatenation (one value plus the dim input), just return grad.
  # `grad` is a single Tensor/IndexedSlices rather than a list, so the result
  # list is built explicitly instead of with `grad + [None]`, which would be
  # tensor arithmetic rather than list concatenation.
  if len(op.inputs) == 2:
    return [grad, None] if end_value_index <= dim_index else [None, grad]

  concat_dim = op.inputs[dim_index]
  input_values = op.inputs[start_value_index:end_value_index]

  out_grads = []
  if isinstance(grad, ops.Tensor):
    if context.executing_eagerly() or isinstance(concat_dim, ops.EagerTensor):
      # Using mod here for convenience since concat_dim is already verified
      # in concat implementation to be within the allowed [-rank, rank) range.
      non_neg_concat_dim = (
          concat_dim._numpy().item(0) % input_values[0]._rank())  # pylint: disable=protected-access
      # All inputs are guaranteed to be EagerTensors in eager mode
      sizes = pywrap_tfe.TFE_Py_TensorShapeSlice(input_values,
                                                 non_neg_concat_dim)
      out_grads = array_ops.split(grad, sizes, non_neg_concat_dim)
    else:
      if constant_op.is_constant(concat_dim):
        # If concat_dim is a constant defined in a different context,
        # then we duplicate it in the current context to avoid passing it
        # through an Enter node.
        # This is a small optimization in general, but it is required when
        # compiling with XLA, as XLA needs the concat input to be folded into a
        # constant.
        grad_context = control_flow_util.GetOutputContext(grad.op)
        dim_context = control_flow_util.GetOutputContext(concat_dim.op)
        if dim_context != grad_context:
          value = tensor_util.constant_value(concat_dim)
          concat_dim = constant_op.constant(value=value, dtype=concat_dim.dtype)

      # Using mod here for convenience since concat_dim is already verified
      # in concat implementation to be within the allowed [-rank, rank) range.
      non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0])

      # Get the inputs' tensor shapes
      sizes = _ExtractInputShapes(input_values)
      # The magic number of 16 was found through benchmarking a range of sizes
      # on CPUs and a Maxwell TitanX.  A speedup was seen in a large majority of
      # cases when switching implementations at N=16, but it is possible that
      # there will be a small number of performance regressions.
      if len(sizes) > 16:
        # extract the size of each input along the concat dimension
        sizes = array_ops.squeeze(
            array_ops.slice(
                array_ops.stack(sizes, axis=1), [non_neg_concat_dim, 0],
                [1, -1]))
        out_grads = array_ops.split(grad, sizes, non_neg_concat_dim)
      else:
        offset = gen_array_ops.concat_offset(non_neg_concat_dim, sizes)
        for (begin, size) in zip(offset, sizes):
          out_grads.append(array_ops.slice(grad, begin, size))
  elif isinstance(grad, ops.IndexedSlices):
    # Using mod here for convenience since concat_dim is already verified
    # in concat implementation to be within the allowed [-rank, rank) range.
    non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0])
    concat_dim_static = tensor_util.constant_value(concat_dim)
    if concat_dim_static is None:
      raise ValueError("Can only compute IndexedSlices gradient with "
                       "statically-known concat_dim")
    if concat_dim_static < 0:
      rank = tensor_util.constant_value(array_ops.rank(input_values[0]))
      if rank is None:
        raise ValueError("Can only compute IndexedSlices gradient with "
                         "negative concat_dim when first value rank is "
                         "statically-known.")
      concat_dim_static %= rank
    # Get the inputs' tensor shapes
    sizes = [array_ops.shape(x) for x in input_values]
    if concat_dim_static > 0:
      # IndexedSlices, non_neg_concat_dim > 0. Each input gets IndexedSlices
      # gradients with all the indices, but with grad.values sliced accordingly.
      # This is like the Tensor case, except shape(grad.values)[0] is not equal
      # to shape(sizes[i])[0], since only a subset of the dim-0 values are
      # stored.
      mask, begin = _CreateDenseMaskAndBegin(sizes, non_neg_concat_dim)
      for size in sizes:
        new_values = array_ops.slice(
            grad.values, begin,
            array_ops.concat([[-1], array_ops.slice(size, [1], [-1])], 0))
        out_grads.append(ops.IndexedSlices(new_values, grad.indices, size))
        # Lint complains begin = begin + ...
        begin = math_ops.add(begin, size * mask)
    else:
      # IndexedSlices, concat_dim == 0. Each input gets IndexedSlices gradients
      # only for the relevant indices.
      start = constant_op.constant(0, dtype=grad.indices.dtype)
      for size in sizes:
        size_concat_dim = array_ops.gather(size, non_neg_concat_dim)
        if size_concat_dim.dtype != grad.indices.dtype:
          size_concat_dim = math_ops.cast(
              size_concat_dim, dtype=grad.indices.dtype)
        end = start + size_concat_dim
        # Compute the 1-D Tensor of indices relevant for this input.
        indices_to_select = array_ops.squeeze(
            array_ops.where(
                math_ops.logical_and(grad.indices >= start,
                                     grad.indices < end)),
            axis=[1])
        new_indices = array_ops.gather(grad.indices, indices_to_select) - start
        new_values = array_ops.gather(grad.values, indices_to_select)
        out_grads.append(ops.IndexedSlices(new_values, new_indices, size))
        start = end
  else:
    raise TypeError("Expected Tensor or IndexedSlices, got %s" % type(grad))

  return (out_grads + [None] if end_value_index <= dim_index else [None] +
          out_grads)
215
216
@ops.RegisterGradient("Concat")
def _ConcatGrad(op, grad):
  """Gradient for `Concat`.

  Delegates to `_ConcatGradHelper` with the concat dimension taken from input 0
  and the values taken from inputs 1 through the end.

  Args:
    op: The forward `Concat` operation.
    grad: Gradient with respect to the op's output.

  Returns:
    Whatever `_ConcatGradHelper` returns for this input layout.
  """
  num_inputs = len(op.inputs)
  return _ConcatGradHelper(op, grad, 1, num_inputs, 0)
225
226
@ops.RegisterGradient("ConcatV2")
def _ConcatGradV2(op, grad):
  """Gradient for `ConcatV2`.

  Delegates to `_ConcatGradHelper` with the axis taken from the last input and
  the values taken from every input before it.

  Args:
    op: The forward `ConcatV2` operation.
    grad: Gradient with respect to the op's output.

  Returns:
    Whatever `_ConcatGradHelper` returns for this input layout.
  """
  first_value, values_end, axis_position = 0, -1, -1
  return _ConcatGradHelper(op, grad, first_value, values_end, axis_position)
231
232
# Register the "ConcatOffset" op type as not differentiable.
ops.NotDifferentiable("ConcatOffset")
234
235
@ops.RegisterGradient("Slice")
def _SliceGrad(op, grad):
  """Gradient for Slice op.

  Outside of XLA the gradient is `grad` zero-padded back to the shape of the
  sliced input. The padding is an Nx2 tensor: column 0 holds the number of
  zeros prepended in each dimension (the slice's begin vector) and column 1
  the number appended (input shape minus slice size minus begin).

  Args:
    op: The forward `Slice` operation.
    grad: Gradient with respect to the slice output.

  Returns:
    A 3-tuple `(input_grad, None, None)`.
  """
  sliced_input = op.inputs[0]
  offsets = op.inputs[1]
  rank_of_input = array_ops.rank(sliced_input)
  output_extent = array_ops.shape(op.outputs[0])

  in_xla = control_flow_util.GraphOrParentsInXlaContext(
      ops.get_default_graph())
  if in_xla:
    # Under XLA, write `grad` into a zero tensor at the slice offsets.
    zero_input = array_ops.zeros_like(sliced_input)
    scattered = gen_xla_ops.xla_dynamic_update_slice(zero_input, grad, offsets)
    return scattered, None, None

  # Both padding columns are reshaped to [rank, 1] so they can be joined
  # side by side into the [rank, 2] paddings tensor.
  column_shape = array_ops.stack([rank_of_input, 1])
  leading_zeros = array_ops.reshape(offsets, column_shape)
  trailing_counts = array_ops.shape(sliced_input) - output_extent - offsets
  trailing_zeros = array_ops.reshape(trailing_counts, column_shape)
  pad_spec = array_ops.concat([leading_zeros, trailing_zeros], 1)
  return array_ops.pad(grad, pad_spec), None, None
262
263
@ops.RegisterGradient("StridedSlice")
def _StridedSliceGrad(op, grad):
  """Gradient for StridedSlice op.

  Args:
    op: The forward `StridedSlice` operation; inputs 1-3 are read as
      `begin`, `end` and `strides`, and the five mask attributes are forwarded.
    grad: Gradient with respect to the slice output.

  Returns:
    A 4-tuple `(input_grad, None, None, None)`.
  """

  def _StaticValueOr(tensor):
    """Returns the statically known value of `tensor`, else `tensor` itself."""
    known = tensor_util.constant_value(tensor)
    return tensor if known is None else known

  begin, end, strides = op.inputs[1], op.inputs[2], op.inputs[3]
  # StridedSliceGrad requires the shape, `begin`, `end` and `strides` to share
  # one dtype, so the shape is built with the dtype of `begin`. Picking `begin`
  # is arbitrary: `end` and `strides` are required to have the same dtype.
  input_shape = array_ops.shape(op.inputs[0], out_type=begin.dtype)

  # Substitute constant values wherever they are known statically.
  input_shape = _StaticValueOr(input_shape)
  begin = _StaticValueOr(begin)
  end = _StaticValueOr(end)
  strides = _StaticValueOr(strides)

  mask_names = ("begin_mask", "end_mask", "ellipsis_mask", "new_axis_mask",
                "shrink_axis_mask")
  masks = {name: op.get_attr(name) for name in mask_names}
  input_grad = array_ops.strided_slice_grad(input_shape, begin, end, strides,
                                            grad, **masks)
  return input_grad, None, None, None
297
298
@ops.RegisterGradient("StridedSliceGrad")
def _StridedSliceGradGrad(op, grad):
  """Gradient for StridedSliceGrad op.

  Args:
    op: The forward `StridedSliceGrad` operation; inputs 1-3 are read as
      `begin`, `end` and `strides`, and the five mask attributes are forwarded.
    grad: Gradient with respect to the op's output.

  Returns:
    A 5-tuple whose first four entries are `None` and whose last entry is
    `grad` strided-sliced with the op's own slicing arguments.
  """
  begin, end, strides = op.inputs[1], op.inputs[2], op.inputs[3]
  mask_names = ("begin_mask", "end_mask", "ellipsis_mask", "new_axis_mask",
                "shrink_axis_mask")
  masks = {name: op.get_attr(name) for name in mask_names}
  sliced_grad = array_ops.strided_slice(grad, begin, end, strides, **masks)
  return None, None, None, None, sliced_grad
316
317
@ops.RegisterGradient("TensorStridedSliceUpdate")
def _TensorStridedSliceUpdateGrad(op, grad):
  """Gradient for TensorStridedSliceUpdate op.

  Args:
    op: The forward `TensorStridedSliceUpdate` operation; inputs 1-3 are read
      as `begin`, `end` and `strides`, and the five mask attributes are
      forwarded to both slicing calls.
    grad: Gradient with respect to the op's output.

  Returns:
    A 5-tuple `(dx, None, None, None, dy)`: `dy` is the strided slice of `grad`
    and `dx` is `grad` with that same slice overwritten by zeros.
  """
  begin, end, strides = op.inputs[1], op.inputs[2], op.inputs[3]
  mask_kwargs = {
      "begin_mask": op.get_attr("begin_mask"),
      "end_mask": op.get_attr("end_mask"),
      "ellipsis_mask": op.get_attr("ellipsis_mask"),
      "new_axis_mask": op.get_attr("new_axis_mask"),
      "shrink_axis_mask": op.get_attr("shrink_axis_mask"),
  }
  value_grad = array_ops.strided_slice(grad, begin, end, strides,
                                       **mask_kwargs)
  zero_patch = array_ops.zeros_like(value_grad)
  input_grad = array_ops.tensor_strided_slice_update(
      grad, begin, end, strides, zero_patch, **mask_kwargs)
  return input_grad, None, None, None, value_grad
340
341
@ops.RegisterGradient("Split")
def _SplitGrad(op, *grads):
  """Gradient for `Split`: concatenates the per-output gradients.

  Args:
    op: The forward `Split` operation; input 0 is used as the concat axis.
    *grads: Gradients with respect to each output of the op.

  Returns:
    A 2-tuple `(None, value_grad)`.
  """
  pieces = list(grads)
  value_grad = array_ops.concat(pieces, op.inputs[0])
  return None, value_grad
345
346
@ops.RegisterGradient("SplitV")
def _SplitVGrad(op, *grads):
  """Gradient for `SplitV`: concatenates the per-output gradients.

  Args:
    op: The forward `SplitV` operation; input 2 is used as the concat axis.
    *grads: Gradients with respect to each output of the op.

  Returns:
    A list with the concatenated gradient first, followed by one `None` for
    each remaining input of the op.
  """
  pieces = list(grads)
  value_grad = array_ops.concat(pieces, op.inputs[2])
  no_grads = [None] * (len(op.inputs) - 1)
  return [value_grad] + no_grads
355
356
# Register the "Const" op type as not differentiable.
ops.NotDifferentiable("Const")
358
359
@ops.RegisterGradient("Diag")
def _DiagGrad(_, grad):
  """Gradient for `Diag`: the `diag_part` of the incoming gradient.

  Args:
    _: The forward operation (unused).
    grad: Gradient with respect to the op's output.

  Returns:
    `array_ops.diag_part(grad)`.
  """
  diagonal_grad = array_ops.diag_part(grad)
  return diagonal_grad
363
364
@ops.RegisterGradient("DiagPart")
def _DiagPartGrad(_, grad):
  """Gradient for `DiagPart`: the incoming gradient placed on a diagonal.

  Args:
    _: The forward operation (unused).
    grad: Gradient with respect to the op's output.

  Returns:
    `array_ops.diag(grad)`.
  """
  full_grad = array_ops.diag(grad)
  return full_grad
368
369
@ops.RegisterGradient("MatrixDiag")
def _MatrixDiagGrad(_, grad):
  """Gradient for `MatrixDiag`: the matrix diagonal part of the gradient.

  Args:
    _: The forward operation (unused).
    grad: Gradient with respect to the op's output.

  Returns:
    `array_ops.matrix_diag_part(grad)`.
  """
  diagonal_grad = array_ops.matrix_diag_part(grad)
  return diagonal_grad
373
374
@ops.RegisterGradient("MatrixDiagV2")
def _MatrixDiagV2Grad(op, grad):
  """Gradient for `MatrixDiagV2`.

  Args:
    op: The forward `MatrixDiagV2` operation; input 1 is passed on as `k`.
    grad: Gradient with respect to the op's output.

  Returns:
    A 5-tuple: the `matrix_diag_part` of `grad` followed by four `None`s.
  """
  diag_index = op.inputs[1]
  diagonal_grad = array_ops.matrix_diag_part(grad, k=diag_index)
  return diagonal_grad, None, None, None, None
379
380
@ops.RegisterGradient("MatrixDiagV3")
def _MatrixDiagV3Grad(op, grad):
  """Gradient for `MatrixDiagV3`.

  Args:
    op: The forward `MatrixDiagV3` operation; input 1 is passed on as `k` and
      the `align` attribute is forwarded.
    grad: Gradient with respect to the op's output.

  Returns:
    A 5-tuple: the `matrix_diag_part` of `grad` followed by four `None`s.
  """
  diag_index = op.inputs[1]
  alignment = op.get_attr("align")
  diagonal_grad = array_ops.matrix_diag_part(
      grad, k=diag_index, align=alignment)
  return diagonal_grad, None, None, None, None
385
386
@ops.RegisterGradient("MatrixDiagPart")
def _MatrixDiagPartGrad(op, grad):
  """Gradient for `MatrixDiagPart`.

  Args:
    op: The forward `MatrixDiagPart` operation.
    grad: Gradient with respect to the extracted diagonal.

  Returns:
    `matrix_diag(grad)` when the last two static dimensions of the input are
    fully defined and equal; otherwise a zero tensor shaped like the input
    with `grad` written onto its diagonal.
  """
  matrix = op.inputs[0]
  trailing_dims = matrix.get_shape()[-2:]
  statically_square = (
      trailing_dims.is_fully_defined() and
      trailing_dims[0] == trailing_dims[1])
  if not statically_square:
    return array_ops.matrix_set_diag(array_ops.zeros_like(matrix), grad)
  return array_ops.matrix_diag(grad)
394
395
@ops.RegisterGradient("MatrixDiagPartV2")
def _MatrixDiagPartV2Grad(op, grad):
  """Gradient for MatrixDiagPartV2.

  Args:
    op: The forward `MatrixDiagPartV2` operation; input 1 is passed on as `k`.
    grad: Gradient with respect to the extracted diagonal(s).

  Returns:
    A 3-tuple `(input_grad, None, None)`. `input_grad` is built with
    `matrix_diag` when the last two static dimensions of the input are fully
    defined, and otherwise by setting the diagonal(s) of a zero tensor shaped
    like the input.
  """
  matrix = op.inputs[0]
  trailing_dims = matrix.get_shape()[-2:]
  if not trailing_dims.is_fully_defined():
    zero_matrix = array_ops.zeros_like(matrix)
    input_grad = array_ops.matrix_set_diag(zero_matrix, grad, k=op.inputs[1])
  else:
    input_grad = array_ops.matrix_diag(
        grad,
        k=op.inputs[1],
        num_rows=trailing_dims[0],
        num_cols=trailing_dims[1])
  return input_grad, None, None
409
410
@ops.RegisterGradient("MatrixDiagPartV3")
def _MatrixDiagPartV3Grad(op, grad):
  """Gradient for MatrixDiagPartV3.

  Args:
    op: The forward `MatrixDiagPartV3` operation; input 1 is passed on as `k`
      and the `align` attribute is forwarded.
    grad: Gradient with respect to the extracted diagonal(s).

  Returns:
    A 3-tuple `(input_grad, None, None)`. `input_grad` is built with
    `matrix_diag` when the last two static dimensions of the input are fully
    defined, and otherwise by setting the diagonal(s) of a zero tensor shaped
    like the input.
  """
  matrix = op.inputs[0]
  trailing_dims = matrix.get_shape()[-2:]
  alignment = op.get_attr("align")
  if not trailing_dims.is_fully_defined():
    zero_matrix = array_ops.zeros_like(matrix)
    input_grad = array_ops.matrix_set_diag(
        zero_matrix, grad, k=op.inputs[1], align=alignment)
  else:
    input_grad = array_ops.matrix_diag(
        grad,
        k=op.inputs[1],
        num_rows=trailing_dims[0],
        num_cols=trailing_dims[1],
        align=alignment)
  return input_grad, None, None
427
428
@ops.RegisterGradient("MatrixSetDiag")
def _MatrixSetDiagGrad(op, grad):
  """Gradient for MatrixSetDiag.

  Args:
    op: The forward `MatrixSetDiag` operation.
    grad: Gradient with respect to the op's output.

  Returns:
    A 2-tuple `(grad_input, grad_diag)`: `grad` with its diagonal zeroed, and
    the diagonal part of `grad`.
  """
  static_input_shape = op.inputs[0].get_shape().merge_with(grad.get_shape())
  static_diag_shape = op.inputs[1].get_shape()
  static_batch = static_input_shape[:-2].merge_with(static_diag_shape[:-1])
  static_matrix = static_input_shape[-2:]

  if static_batch.is_fully_defined() and static_matrix.is_fully_defined():
    # Everything is known statically: the diagonal length is the shorter of
    # the two matrix dimensions.
    zeros_shape = static_batch.as_list() + [min(static_matrix.as_list())]
  else:
    # Otherwise derive the same shape at run time from the gradient.
    with ops.colocate_with(grad):
      dynamic_shape = array_ops.shape(grad)
      dynamic_rank = array_ops.rank(grad)
      dynamic_batch = array_ops.slice(dynamic_shape, [0], [dynamic_rank - 2])
      dynamic_matrix = array_ops.slice(dynamic_shape, [dynamic_rank - 2], [2])
      shortest_side = math_ops.reduce_min(dynamic_matrix)
      zeros_shape = array_ops.concat([dynamic_batch, [shortest_side]], 0)

  zero_diag = array_ops.zeros(zeros_shape, dtype=grad.dtype)
  input_grad = array_ops.matrix_set_diag(grad, zero_diag)
  diag_grad = array_ops.matrix_diag_part(grad)
  return (input_grad, diag_grad)
450
451
@ops.RegisterGradient("MatrixSetDiagV2")
def _MatrixSetDiagGradV2(op, grad):
  """Gradient for MatrixSetDiagV2.

  Args:
    op: The forward `MatrixSetDiagV2` operation; input 2 is passed on as `k`.
    grad: Gradient with respect to the op's output.

  Returns:
    A 3-tuple `(grad_input, grad_diag, None)`: `grad` with the selected
    diagonal(s) zeroed, and the selected diagonal(s) of `grad`.
  """
  zeros_shape = op.inputs[1].get_shape()
  if not zeros_shape.is_fully_defined():
    # The static diagonal shape is incomplete, so it is rebuilt at run time;
    # this needs the values of `d_lower` and `d_upper`.
    dynamic_shape = array_ops.shape(grad)
    leading_dims = dynamic_shape[:-2]
    trailing_dims = dynamic_shape[-2:]
    flat_index = array_ops.reshape(op.inputs[2], [-1])  # Converts to vector.
    d_lower = flat_index[0]
    # Indexing with -1 works whether the vector has one or two elements.
    d_upper = flat_index[-1]
    row_adjust = control_flow_ops.cond(
        math_ops.less(d_upper, 0), lambda: d_upper, lambda: 0)
    col_adjust = control_flow_ops.cond(
        math_ops.greater(d_lower, 0), lambda: -d_lower, lambda: 0)

    longest_diag = math_ops.minimum(trailing_dims[0] + row_adjust,
                                    trailing_dims[1] + col_adjust)
    # A single diagonal contributes one trailing dimension; a band of
    # diagonals contributes [number of diagonals, longest diagonal].
    # pylint: disable=g-long-lambda
    # pyformat: disable
    diag_dims = control_flow_ops.cond(
        math_ops.equal(d_lower, d_upper),
        lambda: ops.convert_to_tensor([longest_diag]),
        lambda: ops.convert_to_tensor([d_upper - d_lower + 1,
                                       longest_diag]))
    # pyformat: enable
    # pylint: enable=g-long-lambda
    zeros_shape = array_ops.concat([leading_dims, diag_dims], 0)

  zero_diag = array_ops.zeros(zeros_shape, dtype=grad.dtype)
  input_grad = array_ops.matrix_set_diag(grad, zero_diag, k=op.inputs[2])
  diag_grad = array_ops.matrix_diag_part(grad, k=op.inputs[2])
  return (input_grad, diag_grad, None)
486
487
@ops.RegisterGradient("MatrixSetDiagV3")
def _MatrixSetDiagGradV3(op, grad):
  """Gradient for MatrixSetDiagV3.

  Args:
    op: The forward `MatrixSetDiagV3` operation; input 2 is passed on as `k`
      and the `align` attribute is forwarded.
    grad: Gradient with respect to the op's output.

  Returns:
    A 3-tuple `(grad_input, grad_diag, None)`: `grad` with the selected
    diagonal(s) zeroed, and the selected diagonal(s) of `grad`.
  """
  zeros_shape = op.inputs[1].get_shape()
  alignment = op.get_attr("align")
  if not zeros_shape.is_fully_defined():
    # The static diagonal shape is incomplete, so it is rebuilt at run time;
    # this needs the values of `d_lower` and `d_upper`.
    dynamic_shape = array_ops.shape(grad)
    leading_dims = dynamic_shape[:-2]
    trailing_dims = dynamic_shape[-2:]
    flat_index = array_ops.reshape(op.inputs[2], [-1])  # Converts to vector.
    d_lower = flat_index[0]
    # Indexing with -1 works whether the vector has one or two elements.
    d_upper = flat_index[-1]
    row_adjust = control_flow_ops.cond(
        math_ops.less(d_upper, 0), lambda: d_upper, lambda: 0)
    col_adjust = control_flow_ops.cond(
        math_ops.greater(d_lower, 0), lambda: -d_lower, lambda: 0)

    longest_diag = math_ops.minimum(trailing_dims[0] + row_adjust,
                                    trailing_dims[1] + col_adjust)
    # A single diagonal contributes one trailing dimension; a band of
    # diagonals contributes [number of diagonals, longest diagonal].
    # pylint: disable=g-long-lambda
    # pyformat: disable
    diag_dims = control_flow_ops.cond(
        math_ops.equal(d_lower, d_upper),
        lambda: ops.convert_to_tensor([longest_diag]),
        lambda: ops.convert_to_tensor([d_upper - d_lower + 1,
                                       longest_diag]))
    # pyformat: enable
    # pylint: enable=g-long-lambda
    zeros_shape = array_ops.concat([leading_dims, diag_dims], 0)

  zero_diag = array_ops.zeros(zeros_shape, dtype=grad.dtype)
  input_grad = array_ops.matrix_set_diag(
      grad, zero_diag, k=op.inputs[2], align=alignment)
  diag_grad = array_ops.matrix_diag_part(
      grad, k=op.inputs[2], align=alignment)
  return (input_grad, diag_grad, None)
526
527
@ops.RegisterGradient("MatrixBandPart")
def _MatrixBandPartGrad(op, grad):
  """Gradient for `MatrixBandPart`: the same band applied to the gradient.

  Args:
    op: The forward `MatrixBandPart` operation; inputs 1 and 2 are reused as
      the lower and upper band arguments.
    grad: Gradient with respect to the op's output.

  Returns:
    A 3-tuple `(input_grad, None, None)`.
  """
  lower_bound, upper_bound = op.inputs[1], op.inputs[2]
  banded_grad = array_ops.matrix_band_part(grad, lower_bound, upper_bound)
  return (banded_grad, None, None)
533
534
# Edit Distance has no gradient (but can be used to eval seq2seq or CTC), so
# the "EditDistance" op type is registered as not differentiable.
ops.NotDifferentiable("EditDistance")
537
538
@ops.RegisterGradient("Fill")
def _FillGrad(_, grad):
  """Gradient for `Fill`: the sum of the incoming gradient.

  Args:
    _: The forward operation (unused).
    grad: Gradient with respect to the filled output.

  Returns:
    A 2-tuple `(None, math_ops.reduce_sum(grad))`.
  """
  value_grad = math_ops.reduce_sum(grad)
  return None, value_grad
542
543
# Register the "ZerosLike" and "OnesLike" op types as not differentiable.
ops.NotDifferentiable("ZerosLike")
ops.NotDifferentiable("OnesLike")
546
547
@ops.RegisterGradient("PreventGradient")
def _PreventGradientGrad(op, _):
  """Always fails: differentiating through `PreventGradient` is disallowed.

  Args:
    op: The forward `PreventGradient` operation; its `message` attribute is
      included in the error text.
    _: The incoming gradient (unused).

  Raises:
    LookupError: always.
  """
  reason = op.get_attr("message")
  raise LookupError("Gradient explicitly disabled. Reason: %s" % reason)
552
553
def _IndexedSlicesToTensorNoWarning(indexed_slices):
  """Converts an IndexedSlices to a Tensor without sparse->dense warnings.

  Args:
    indexed_slices: An `IndexedSlices`, or any other value (expected to already
      be a tensor), which is returned untouched.

  Returns:
    For `IndexedSlices`, the `unsorted_segment_sum` of its values by its
    indices with `dense_shape[0]` segments; otherwise the argument itself.

  Raises:
    ValueError: if the `IndexedSlices` has no `dense_shape`.
  """
  if isinstance(indexed_slices, ops.IndexedSlices):
    dense_shape = indexed_slices.dense_shape
    if dense_shape is None:
      raise ValueError(
          "Tensor conversion requested for IndexedSlices without dense_shape: %s"
          % str(indexed_slices))
    return math_ops.unsorted_segment_sum(indexed_slices.values,
                                         indexed_slices.indices,
                                         dense_shape[0])
  # Anything that is not an IndexedSlices is handed back unchanged.
  return indexed_slices
566
567
@ops.RegisterGradient("Gather")
def _GatherGrad(op, grad):
  """Gradient for Gather op.

  Args:
    op: The forward `Gather` operation; its inputs are read as
      (params, indices).
    grad: Gradient with respect to the gathered output.

  Returns:
    `[params_grad, None]`, where `params_grad` is an `IndexedSlices` holding
    the flattened indices and the gradient reshaped to one row per index.
  """
  params, indices = op.inputs[0], op.inputs[1]
  # params can be large, so colocate the shape calculation with it.
  with ops.colocate_with(params):
    params_shape = array_ops.shape(params)

  # Build appropriately shaped IndexedSlices: one row of values per index.
  flat_count = array_ops.expand_dims(array_ops.size(indices), 0)
  row_shape = array_ops.concat([flat_count, params_shape[1:]], 0)
  dense_grad = _IndexedSlicesToTensorNoWarning(grad)
  row_values = array_ops.reshape(dense_grad, row_shape)
  flat_indices = array_ops.reshape(indices, flat_count)
  return [ops.IndexedSlices(row_values, flat_indices, params_shape), None]
584
585
def _GetBatchIndices(params_shape, indices, batch_dims):
  """Adds the batch offsets to the given indices and returns the results.

  Args:
    params_shape: Shape tensor; it is cast to the dtype of `indices`.
    indices: Index tensor whose static rank (`indices.shape.ndims`) is used.
    batch_dims: Python integer number of leading batch dimensions.

  Returns:
    `indices` with an offset added for each batch dimension, processed from
    the innermost batch dimension outwards.
  """
  offset_indices = indices
  rank_of_indices = indices.shape.ndims
  index_dtype = indices.dtype.base_dtype
  shape_as_index_dtype = math_ops.cast(params_shape, index_dtype)
  # Running product of the dimension sizes that follow the current one.
  stride = array_ops.ones((), dtype=index_dtype)
  for dim in reversed(range(1, batch_dims + 1)):
    extent = shape_as_index_dtype[dim - 1]
    stride *= shape_as_index_dtype[dim]
    range_start = array_ops.zeros((), dtype=index_dtype)
    range_step = array_ops.ones((), dtype=index_dtype)
    offsets = math_ops.range(range_start, extent, range_step)
    offsets *= stride
    # Shape the offsets so they broadcast along dimension `dim - 1` only.
    broadcast_shape = array_ops.stack(
        [1] * (dim - 1) + [extent] + [1] * (rank_of_indices - dim), axis=0)
    offset_indices += array_ops.reshape(offsets, broadcast_shape)

  return offset_indices
605
606
def _BatchGatherGrad(params_shape, values, indices, batch_dims,
                     gather_dim_size):
  """Returns the gradient of GatherV2 with batch dimensions.

  The gather axis is taken to be the first non-batch dimension.

  Args:
    params_shape: Shape tensor forwarded to `_GetBatchIndices`.
    values: Gradient values to scatter back.
    indices: Indices used by the forward gather.
    batch_dims: Python integer number of leading batch dimensions.
    gather_dim_size: Size of the gathered dimension; used as the number of
      segments (scaled by the batch size when `batch_dims` is non-zero).

  Returns:
    The `unsorted_segment_sum` of `values` by `indices`, with the batch
    dimensions restored when `batch_dims` is non-zero.
  """
  flat_count = array_ops.expand_dims(array_ops.size(indices), 0)
  if batch_dims:
    dynamic_values_shape = array_ops.shape(values)
    # Add the batch offsets to indices and flatten the batch dimensions.
    batch_part = dynamic_values_shape[:batch_dims]
    trailing_part = dynamic_values_shape[batch_dims:][1:]
    num_batches = gen_math_ops.prod(batch_part, [0], False)
    flattened_shape = array_ops.concat([[-1], trailing_part], 0)
    gather_dim_size *= num_batches

    indices = _GetBatchIndices(params_shape, indices, batch_dims)
    values = array_ops.reshape(
        _IndexedSlicesToTensorNoWarning(values), flattened_shape)

  indices = array_ops.reshape(indices, flat_count)
  summed = math_ops.unsorted_segment_sum(values, indices, gather_dim_size)

  if batch_dims:
    # Put back the batch dimensions.
    restored_shape = array_ops.concat([batch_part, flattened_shape], 0)
    summed = array_ops.reshape(summed, restored_shape)

  return summed
635
636
@ops.RegisterGradient("GatherV2")
def _GatherV2Grad(op, grad):
  """Gradient for GatherV2 op.

  Args:
    op: The forward `GatherV2` operation; its inputs are read as
      (params, indices, axis) and its `batch_dims` attribute is used.
    grad: Gradient with respect to the gathered output.

  Returns:
    `[params_grad, None, None]`. `params_grad` is an `IndexedSlices` when the
    axis is statically known to be 0; otherwise it is the result of
    `_BatchGatherGrad` transposed back to the layout of `params`.
  """
  # params can be large, so colocate the shape calculation with it.
  #
  # params can be very large for sparse model, array_ops.shape raises
  # exception on the Windows platform when any dimension is larger than
  # int32. params_shape is not used in optimizer apply_sparse gradients,
  # so it's fine to convert it back to int32 regardless of truncation.
  params = op.inputs[0]
  with ops.colocate_with(params):
    params_shape = array_ops.shape(params, out_type=dtypes.int64)
    params_shape = math_ops.cast(params_shape, dtypes.int32)

  indices = op.inputs[1]
  indices_size = array_ops.expand_dims(array_ops.size(indices), 0)
  axis = op.inputs[2]
  # None when the axis is not statically known; the comparison with 0 below is
  # then False and the general (transpose-based) path is taken.
  axis_static = tensor_util.constant_value(axis)
  batch_dims = int(op.get_attr("batch_dims"))

  # A negative batch_dims counts from the end of the indices' dimensions.
  if batch_dims < 0:
    batch_dims += indices.shape.ndims

  # For axis 0 gathers, build an appropriately shaped IndexedSlices.
  if axis_static == 0:
    if context.executing_eagerly():
      with ops.device(indices_size.device):
        params_tail_shape = array_ops.identity(params_shape)[1:]
    else:
      params_tail_shape = params_shape[1:]
    values_shape = array_ops.concat([indices_size, params_tail_shape], 0)
    values = array_ops.reshape(
        _IndexedSlicesToTensorNoWarning(grad), values_shape)
    indices = array_ops.reshape(indices, indices_size)
    params_grad = ops.IndexedSlices(values, indices, params_shape)
  else:
    # Handle axis by transposing the axis dimension to be the first non-batch
    # dimension, compute the gradient and transpose the result back.
    outer_shape = params_shape[:axis]
    inner_shape = params_shape[axis:][1:]
    values_shape = array_ops.concat([outer_shape, [-1], inner_shape], 0)

    values_dims = array_ops.size(values_shape)
    axis_dims = array_ops.size(outer_shape)

    outer_batches_indices = math_ops.range(batch_dims)
    batch_axis_indices = math_ops.range(batch_dims, axis_dims)
    inner_axes_indices = math_ops.range(axis_dims + 1, values_dims)

    values = array_ops.reshape(
        _IndexedSlicesToTensorNoWarning(grad), values_shape)

    # Move values[axis] up to values[batch_dims]
    transpose_dims = array_ops.concat([
        outer_batches_indices, [axis_dims], batch_axis_indices,
        inner_axes_indices
    ], 0)
    values_transpose = array_ops.transpose(values, transpose_dims)
    params_shape_transpose = array_ops.gather(params_shape, transpose_dims)

    params_grad = _BatchGatherGrad(params_shape_transpose, values_transpose,
                                   indices, batch_dims, params_shape[axis])

    # Inverts the above transpose by moving dimension batch_dims back to its
    # original position.
    invert_transpose_dims = array_ops.concat([
        outer_batches_indices, batch_axis_indices + 1, [batch_dims],
        inner_axes_indices
    ], 0)
    params_grad = array_ops.transpose(params_grad, invert_transpose_dims)

  return [params_grad, None, None]
709
710
@ops.RegisterGradient("GatherNd")
def _GatherNdGrad(op, grad):
  """Gradient for GatherNd op.

  Args:
    op: The forward `GatherNd` operation; its inputs are read as
      (params, indices).
    grad: Gradient with respect to the gathered output.

  Returns:
    `[params_grad, None]`. `params_grad` is an `IndexedSlices` when `indices`
    has static rank 2 with a last dimension of 1, and the result of
    `scatter_nd` otherwise.
  """
  params, indices = op.inputs[0], op.inputs[1]
  params_shape = array_ops.shape(params, out_type=indices.dtype)
  index_shape = indices.shape
  one_index_per_row = (
      index_shape.ndims == 2 and index_shape.dims[-1].value == 1)
  if not one_index_per_row:
    return [array_ops.scatter_nd(indices, grad, params_shape), None]
  row_ids = array_ops.squeeze(indices, axis=-1)
  return [ops.IndexedSlices(grad, row_ids, params_shape), None]
722
723
@ops.RegisterGradient("ResourceGatherNd")
def _ResourceGatherNdGrad(op, grad):
  """Gradient for ResourceGatherNd op.

  Args:
    op: The forward `ResourceGatherNd` operation; its inputs are read as
      (resource, indices).
    grad: Gradient with respect to the gathered output.

  Returns:
    `[resource_grad, None]`. `resource_grad` is an `IndexedSlices` when
    `indices` has static rank 2 with a last dimension of 1, and the result of
    `scatter_nd` otherwise.
  """
  resource, indices = op.inputs[0], op.inputs[1]
  # The shape is read from the variable resource, in the dtype of the indices.
  resource_shape = gen_resource_variable_ops.variable_shape(
      resource, indices.dtype)
  index_shape = indices.shape
  one_index_per_row = (
      index_shape.ndims == 2 and index_shape.dims[-1].value == 1)
  if not one_index_per_row:
    return [array_ops.scatter_nd(indices, grad, resource_shape), None]
  row_ids = array_ops.squeeze(indices, axis=-1)
  return [ops.IndexedSlices(grad, row_ids, resource_shape), None]
735
736
@ops.RegisterGradient("CheckNumerics")
def _CheckNumericsGrad(op, grad):
  """Gradient for check_numerics op.

  Args:
    op: The forward `CheckNumerics` operation; its `message` attribute is
      appended to the error text used for the gradient check.
    grad: Gradient with respect to the op's output.

  Returns:
    `grad` wrapped in `array_ops.check_numerics`.
  """
  forward_message = op.get_attr("message")
  error_text = (
      "Not a number (NaN) or infinity (Inf) values detected in gradient. %s" %
      forward_message)
  return array_ops.check_numerics(grad, error_text)
744
745
@ops.RegisterGradient("CheckNumericsV2")
def _CheckNumericsV2Grad(op, grad):
  """Gradient for check_numerics_v2 op.

  Args:
    op: The forward `CheckNumericsV2` operation; its `message` attribute is
      appended to the error text used for the gradient check.
    grad: Gradient with respect to the op's output.

  Returns:
    `grad` wrapped in `array_ops.check_numerics_v2`.
  """
  forward_message = op.get_attr("message")
  error_text = (
      "Not a number (NaN) or infinity (Inf) values detected in gradient. %s" %
      forward_message)
  return array_ops.check_numerics_v2(grad, error_text)
753
754
@ops.RegisterGradient("PlaceholderWithDefault")
@ops.RegisterGradient("Identity")
def _IdGrad(_, grad):
  """Gradient for `Identity` and `PlaceholderWithDefault`.

  Args:
    _: The forward operation (unused).
    grad: Gradient with respect to the op's output.

  Returns:
    `grad`, unchanged.
  """
  return grad
759
760
@ops.RegisterGradient("_EagerConst")
def _EagerConstGrad(_, grad):
  """Always fails: `_EagerConst` must never reach the gradient machinery.

  Args:
    _: The forward operation (unused).
    grad: The incoming gradient (unused).

  Raises:
    AssertionError: always.
  """
  failure_text = (
      "This op should never interact with gradient APIs. Please file a bug.")
  raise AssertionError(failure_text)
765
766
@ops.RegisterGradient("RefIdentity")
def _RefIdGrad(_, grad):
  """Gradient for `RefIdentity`.

  Args:
    _: The forward operation (unused).
    grad: Gradient with respect to the op's output.

  Returns:
    `grad`, unchanged.
  """
  return grad
770
771
@ops.RegisterGradient("IdentityN")
def _IdNGrad(_, *grad):
  """Gradient for `IdentityN`.

  Args:
    _: The forward operation (unused).
    *grad: Gradients with respect to each output of the op.

  Returns:
    The tuple of incoming gradients, unchanged.
  """
  return grad
775
776
# Register the "StopGradient" op type as not differentiable.
ops.NotDifferentiable("StopGradient")
778
779
@ops.RegisterGradient("Reshape")
def _ReshapeGrad(op, grad):
  """Gradient for `Reshape`: the gradient reshaped to the input's shape.

  Args:
    op: The forward `Reshape` operation.
    grad: Gradient with respect to the reshaped output.

  Returns:
    `[input_grad, None]`.
  """
  # Use the file's shared helper, as _ExpandDimsGrad does, rather than
  # repeating the same reshape expression here.
  return [_ReshapeToInput(op, grad), None]
787
788
# Register the InvertPermutation op type as not differentiable.
ops.NotDifferentiable("InvertPermutation")
790
791
def _ReshapeToInput(op, grad):
  """Reshapes `grad` so that it has the shape of `op`'s first input.

  Args:
    op: The forward op whose first input supplies the target shape.
    grad: Gradient to reshape; it is first passed through
      `_IndexedSlicesToTensorNoWarning`.

  Returns:
    The reshaped gradient tensor.
  """
  converted_grad = _IndexedSlicesToTensorNoWarning(grad)
  target_shape = array_ops.shape(op.inputs[0])
  return array_ops.reshape(converted_grad, target_shape)
796
797
@ops.RegisterGradient("ExpandDims")
def _ExpandDimsGrad(op, grad):
  """Gradient for ExpandDims: reshape `grad` to the first input's shape.

  Returns:
    A two-element list: the reshaped gradient and `None` for the second input.
  """
  input_grad = _ReshapeToInput(op, grad)
  return [input_grad, None]
801
802
@ops.RegisterGradient("Squeeze")
def _SqueezeGrad(op, grad):
  """Gradient for Squeeze: reshape `grad` to the first input's shape."""
  input_grad = _ReshapeToInput(op, grad)
  return input_grad
806
807
@ops.RegisterGradient("Transpose")
def _TransposeGrad(op, grad):
  """Returns unshuffle(grad).

  The gradient is transposed with the inverse of the permutation held in the
  forward op's second input.

  Returns:
    A two-element list: the transposed gradient and `None` for the
    permutation input.
  """
  perm = op.inputs[1]
  inverse_perm = array_ops.invert_permutation(perm)
  return [array_ops.transpose(grad, inverse_perm), None]
813
814
@ops.RegisterGradient("ConjugateTranspose")
def _ConjugateTransposeGrad(op, grad):
  """Returns conj(unshuffle(grad)).

  The gradient is transposed with the inverse of the permutation held in the
  forward op's second input, with `conjugate=True`.

  Returns:
    A two-element list: the conjugate-transposed gradient and `None` for the
    permutation input.
  """
  perm = op.inputs[1]
  inverse_perm = array_ops.invert_permutation(perm)
  unshuffled = array_ops.transpose(grad, inverse_perm, conjugate=True)
  return [unshuffled, None]
823
824
# Register the Shape, ShapeN, Rank and Size op types as not differentiable.
ops.NotDifferentiable("Shape")

ops.NotDifferentiable("ShapeN")

ops.NotDifferentiable("Rank")

ops.NotDifferentiable("Size")
832
833
@ops.RegisterGradient("Tile")
def _TileGrad(op, grad):
  """Sum reduces grad along the tiled dimensions.

  Args:
    op: The forward Tile op; inputs are the tiled tensor and the multiples.
    grad: Gradient with respect to the op's output; may be an IndexedSlices.

  Returns:
    A two-element list: the gradient for the tiled tensor and `None` for the
    multiples.
  """
  tiled_input, multiples = op.inputs[0], op.inputs[1]
  in_shape = array_ops.shape(tiled_input, out_type=multiples.dtype)
  # Interleaving multiples with in_shape yields the shape that grad is
  # reshaped to; summing over every even axis (the tiled ones) then leaves a
  # tensor of shape in_shape.  For example
  #   in_shape = [20, 30, 40]
  #   multiples = [2, 3, 4]
  #   interleaved shape = [2, 20, 3, 30, 4, 40]
  #   tiled axes = [0, 2, 4]
  stacked_pairs = array_ops.transpose(array_ops.stack([multiples, in_shape]))
  interleaved_shape = array_ops.reshape(stacked_pairs, [-1])
  tiled_axes = math_ops.range(0, array_ops.size(interleaved_shape), 2)
  if isinstance(grad, ops.IndexedSlices):
    # For IndexedSlices the first dimension is summed here, so its multiple
    # in the interleaved shape is replaced by 1.
    dim0 = math_ops.cast(in_shape[0], grad.indices.dtype)
    wrapped_rows = math_ops.mod(grad.indices, dim0)
    grad = math_ops.unsorted_segment_sum(grad.values, wrapped_rows, dim0)
    interleaved_shape = array_ops.concat([[1], interleaved_shape[1:]], axis=0)
  split_grad = array_ops.reshape(grad, interleaved_shape)
  result = math_ops.reduce_sum(split_grad, tiled_axes)
  # Fix shape inference
  if not context.executing_eagerly():
    result.set_shape(tiled_input.get_shape())
  return [result, None]
860
861
# Register the BroadcastGradientArgs op type as not differentiable.
ops.NotDifferentiable("BroadcastGradientArgs")
863
864
def _PadGrad(op, grad):
  """Gradient for Pad.

  Pad introduces values around the original tensor, so the gradient slices the
  region with the original shape back out of `grad`.

  Args:
    op: The forward op; inputs are the padded tensor, the paddings and, when
      there are three inputs, one further input.
    grad: Gradient with respect to the op's output.

  Returns:
    A tuple with the gradient for the first input followed by `None` for each
    remaining input (two or three entries in total).
  """
  padded_input = op.inputs[0]
  paddings = op.inputs[1]  # [Rank(x), 2]
  # Slice out the first column of the paddings: [Rank(x), 1].
  first_column_size = array_ops.stack([array_ops.rank(padded_input), 1])
  leading_pads = array_ops.slice(paddings, [0, 0], first_column_size)
  # Flatten it into a 1-D tensor of slice offsets.
  offsets = array_ops.reshape(leading_pads, [-1])
  extents = array_ops.shape(padded_input, out_type=offsets.dtype)
  input_grad = array_ops.slice(grad, offsets, extents)
  if len(op.inputs) != 3:
    return input_grad, None
  return input_grad, None, None
882
883
# The same gradient function is registered for both the Pad and PadV2 op types.
ops.RegisterGradient("Pad")(_PadGrad)
ops.RegisterGradient("PadV2")(_PadGrad)
886
887
# ReverseSequence is just a permutation.  The gradient permutes back.
@ops.RegisterGradient("ReverseSequence")
def _ReverseSequenceGrad(op, grad):
  """Gradient for ReverseSequence: applies the same reversal to `grad`.

  Returns:
    A two-element list: the reversed gradient and `None` for the sequence
    lengths input.
  """
  lengths = op.inputs[1]
  batch_axis = op.get_attr("batch_dim")
  seq_axis = op.get_attr("seq_dim")
  restored = array_ops.reverse_sequence(
      grad, batch_axis=batch_axis, seq_axis=seq_axis, seq_lengths=lengths)
  return [restored, None]
899
900
@ops.RegisterGradient("Reverse")
def _ReverseGrad(op, grad):
  """Gradient for Reverse: reverses `grad` using the op's second input.

  Returns:
    A tuple of the reversed gradient and `None` for the second input.
  """
  dims_input = op.inputs[1]
  flipped = gen_array_ops.reverse(grad, dims_input)
  return flipped, None
905
906
@ops.RegisterGradient("ReverseV2")
def _ReverseV2Grad(op, grad):
  """Gradient for ReverseV2: reverses `grad` along the op's axis input.

  Returns:
    A tuple of the reversed gradient and `None` for the axis input.
  """
  axis_input = op.inputs[1]
  flipped = array_ops.reverse_v2(grad, axis_input)
  return flipped, None
911
912
@ops.RegisterGradient("SpaceToBatch")
def _SpaceToBatchGrad(op, grad):
  """Gradient for SpaceToBatch, computed with the opposite op BatchToSpace.

  Returns:
    A two-element list: the gradient for the first input and `None` for the
    second input.
  """
  second_input = op.inputs[1]
  block = op.get_attr("block_size")
  restored = array_ops.batch_to_space(grad, second_input, block_size=block)
  return [restored, None]
920
921
@ops.RegisterGradient("SpaceToBatchND")
def _SpaceToBatchNDGrad(op, grad):
  """Gradient for SpaceToBatchND, computed with the opposite op BatchToSpaceND.

  Returns:
    A three-element list: the gradient for the first input and `None` for the
    other two inputs.
  """
  second_input, third_input = op.inputs[1], op.inputs[2]
  restored = array_ops.batch_to_space_nd(grad, second_input, third_input)
  return [restored, None, None]
928
929
@ops.RegisterGradient("BatchToSpace")
def _BatchToSpaceGrad(op, grad):
  """Gradient for BatchToSpace, computed with the opposite op SpaceToBatch.

  Returns:
    A two-element list: the gradient for the first input and `None` for the
    second input.
  """
  second_input = op.inputs[1]
  block = op.get_attr("block_size")
  restored = array_ops.space_to_batch(grad, second_input, block_size=block)
  return [restored, None]
937
938
@ops.RegisterGradient("BatchToSpaceND")
def _BatchToSpaceNDGrad(op, grad):
  """Gradient for BatchToSpaceND, computed with the opposite op SpaceToBatchND.

  Returns:
    A three-element list: the gradient for the first input and `None` for the
    other two inputs.
  """
  second_input, third_input = op.inputs[1], op.inputs[2]
  restored = array_ops.space_to_batch_nd(grad, second_input, third_input)
  return [restored, None, None]
945
946
@ops.RegisterGradient("SpaceToDepth")
def _SpaceToDepthGrad(op, grad):
  """Gradient for SpaceToDepth, computed with the opposite op DepthToSpace.

  Args:
    op: The forward SpaceToDepth op.
    grad: Gradient with respect to the op's output.

  Returns:
    `grad` passed through `depth_to_space` with the forward op's `block_size`
    and `data_format`.

  Raises:
    ValueError: if the op's `data_format` is NCHW_VECT_C.
  """
  block_size = op.get_attr("block_size")
  data_format = op.get_attr("data_format")
  # String attributes may come back from `get_attr` as `bytes`, which never
  # compare equal to a `str` literal on Python 3; normalise before testing.
  if isinstance(data_format, bytes):
    data_format_text = data_format.decode("utf-8", errors="replace")
  else:
    data_format_text = data_format
  if data_format_text == "NCHW_VECT_C":
    raise ValueError("Cannot compute SpaceToDepth gradient with NCHW_VECT_C. "
                     "NCHW_VECT_C requires qint8 data type.")
  # The attribute is forwarded exactly as it was read.
  return array_ops.depth_to_space(grad, block_size, data_format=data_format)
956
957
@ops.RegisterGradient("DepthToSpace")
def _DepthToSpaceGrad(op, grad):
  """Gradient for DepthToSpace, computed with the opposite op SpaceToDepth.

  Args:
    op: The forward DepthToSpace op.
    grad: Gradient with respect to the op's output.

  Returns:
    `grad` passed through `space_to_depth` with the forward op's `block_size`
    and `data_format`.

  Raises:
    ValueError: if the op's `data_format` is NCHW_VECT_C.
  """
  block_size = op.get_attr("block_size")
  data_format = op.get_attr("data_format")
  # String attributes may come back from `get_attr` as `bytes`, which never
  # compare equal to a `str` literal on Python 3; normalise before testing.
  if isinstance(data_format, bytes):
    data_format_text = data_format.decode("utf-8", errors="replace")
  else:
    data_format_text = data_format
  if data_format_text == "NCHW_VECT_C":
    raise ValueError("Cannot compute DepthToSpace gradient with NCHW_VECT_C. "
                     "NCHW_VECT_C requires qint8 data type.")
  # The attribute is forwarded exactly as it was read.
  return array_ops.space_to_depth(grad, block_size, data_format=data_format)
967
968
# Register the OneHot op type as not differentiable.
ops.NotDifferentiable("OneHot")
970
971
@ops.RegisterGradient("MirrorPad")
def _MirrorPadGrad(op, grad):
  """Gradient for MirrorPad, computed with the `mirror_pad_grad` op.

  Returns:
    A two-element list: the gradient for the first input and `None` for the
    second input.
  """
  second_input = op.inputs[1]
  pad_mode = op.get_attr("mode")
  input_grad = gen_array_ops.mirror_pad_grad(grad, second_input, mode=pad_mode)
  return [input_grad, None]
976
977
@ops.RegisterGradient("MirrorPadGrad")
def _MirrorPadGradGrad(op, grad):
  """Gradient for MirrorPadGrad, computed with the `mirror_pad` op.

  Returns:
    A two-element list: the gradient for the first input and `None` for the
    second input.
  """
  second_input = op.inputs[1]
  pad_mode = op.get_attr("mode")
  input_grad = gen_array_ops.mirror_pad(grad, second_input, mode=pad_mode)
  return [input_grad, None]
982
983
@ops.RegisterGradient("QuantizeAndDequantize")
def _QuantizeAndDequantizeGrad(_, grad):
  """Passes the incoming gradient through unchanged."""
  return grad
987
988
@ops.RegisterGradient("QuantizeAndDequantizeV2")
def _QuantizeAndDequantizeV2Grad(_, grad):
  """Passes `grad` to the first input; the other two inputs get `None`."""
  per_input_grads = [grad, None, None]
  return per_input_grads
992
993
@ops.RegisterGradient("QuantizeAndDequantizeV3")
def _QuantizeAndDequantizeV3Grad(_, grad):
  """Passes `grad` to the first input; the other three inputs get `None`."""
  # Only the unquantized input receives a gradient.
  per_input_grads = [grad, None, None, None]
  return per_input_grads
998
999
@ops.RegisterGradient("ExtractImagePatches")
def _ExtractImagePatchesGrad(op, grad):
  """Gradient for ExtractImagePatches op.

  The function numbers every input spatial location, runs the same patch
  extraction over those numbers to learn which input location feeds each
  output element, turns that correspondence into a sparse 0/1 matrix, and
  multiplies it with the flattened gradient.

  Args:
    op: The forward ExtractImagePatches op; its `ksizes`, `strides`, `rates`
      and `padding` attributes are reused.
    grad: Gradient with respect to the op's output.

  Returns:
    A one-element list holding the gradient with respect to the op's input.
  """
  # All sizes are read dynamically as int64 scalars.
  input_bhwc = array_ops.shape(op.inputs[0], out_type=dtypes.int64)
  batch_size, rows_in, cols_in, channels = input_bhwc[0], input_bhwc[1], \
                                           input_bhwc[2], input_bhwc[3]

  # Create indices matrix for input tensor.
  # Note that 0 is preserved for padding location,
  # so indices for input run from 1 through rows_in * cols_in.
  input_indices_num = 1 + rows_in * cols_in
  input_idx = array_ops.reshape(
      math_ops.range(1, input_indices_num, dtype=ops.dtypes.int64),
      (1, rows_in, cols_in, 1))
  # Extract patches from the index image with the forward op's attributes.
  input_idx_patched = gen_array_ops.extract_image_patches(
      input_idx, op.get_attr("ksizes"), op.get_attr("strides"),
      op.get_attr("rates"), op.get_attr("padding"))

  # Create indices matrix for output tensor.
  output_bhwc = array_ops.shape(op.outputs[0], out_type=dtypes.int64)
  rows_out, cols_out = output_bhwc[1], output_bhwc[2]
  _, ksize_r, ksize_c, _ = op.get_attr("ksizes")
  # Indices for output start from 0.
  output_indices_num = rows_out * cols_out * ksize_r * ksize_c
  output_idx = array_ops.reshape(
      math_ops.range(output_indices_num, dtype=ops.dtypes.int64),
      (1, rows_out, cols_out, ksize_r * ksize_c))

  # Construct mapping table for indices: (input -> output).
  idx_matrix = array_ops.concat([
      array_ops.expand_dims(input_idx_patched, axis=-1),
      array_ops.expand_dims(output_idx, axis=-1)
  ],
                                axis=-1)
  # One (input index, output index) pair per row.
  idx_map = array_ops.reshape(idx_matrix, (-1, 2))

  # Sparse matrix of ones with a row per input index and a column per output
  # index.
  sp_shape = (input_indices_num, output_indices_num)
  sp_mat_full = sparse_tensor.SparseTensor(
      idx_map, array_ops.ones([output_indices_num], dtype=grad.dtype), sp_shape)
  # Remove all padding locations [0, :].
  sp_mat = sparse_ops.sparse_slice(sp_mat_full, (1, 0),
                                   (input_indices_num - 1, output_indices_num))

  # Move the batch axis next to the channel axis so that the leading axes
  # flatten into the output-index dimension of the sparse matrix.
  grad_expanded = array_ops.transpose(
      array_ops.reshape(
          _IndexedSlicesToTensorNoWarning(grad),
          (batch_size, rows_out, cols_out, ksize_r, ksize_c, channels)),
      (1, 2, 3, 4, 0, 5))
  grad_flat = array_ops.reshape(grad_expanded, (-1, batch_size * channels))

  jac = sparse_ops.sparse_tensor_dense_matmul(sp_mat, grad_flat)

  # Unflatten and move the batch axis back to the front.
  grad_out = array_ops.reshape(jac, (rows_in, cols_in, batch_size, channels))
  grad_out = array_ops.transpose(grad_out, (2, 0, 1, 3))

  return [grad_out]
1055
1056
@ops.RegisterGradient("ExtractVolumePatches")
def _ExtractVolumePatchesGrad(op, grad):
  """Gradient for ExtractVolumePatches op.

  The function numbers every input spatial location, runs the same patch
  extraction over those numbers to learn which input location feeds each
  output element, turns that correspondence into a sparse 0/1 matrix, and
  multiplies it with the flattened gradient.

  All sizes are read dynamically with `array_ops.shape`, so inputs whose
  dimensions are not statically known are handled the same way as in
  `_ExtractImagePatchesGrad`.

  Args:
    op: The forward ExtractVolumePatches op; its `ksizes`, `strides` and
      `padding` attributes are reused.
    grad: Gradient with respect to the op's output.

  Returns:
    A one-element list holding the gradient with respect to the op's input.
  """
  # int64 matches the dtype of the index tensors built below.
  input_bphwc = array_ops.shape(op.inputs[0], out_type=dtypes.int64)
  batch_size, planes_in, rows_in, cols_in, channels = (
      input_bphwc[0], input_bphwc[1], input_bphwc[2], input_bphwc[3],
      input_bphwc[4])

  # Create indices matrix for input tensor.
  # Note that 0 is preserved for padding location,
  # so indices for input run from 1 through planes_in * rows_in * cols_in.
  input_indices_num = 1 + planes_in * rows_in * cols_in
  input_idx = array_ops.reshape(
      math_ops.range(1, input_indices_num, dtype=dtypes.int64),
      (1, planes_in, rows_in, cols_in, 1))
  # Extract patches from the index volume with the forward op's attributes.
  input_idx_patched = gen_array_ops.extract_volume_patches(
      input_idx, op.get_attr("ksizes"), op.get_attr("strides"),
      op.get_attr("padding"))

  # Create indices matrix for output tensor.
  output_bphwc = array_ops.shape(op.outputs[0], out_type=dtypes.int64)
  planes_out, rows_out, cols_out = (
      output_bphwc[1], output_bphwc[2], output_bphwc[3])
  _, ksize_p, ksize_r, ksize_c, _ = op.get_attr("ksizes")
  # Indices for output start from 0.
  prc_indices_num = planes_out * rows_out * cols_out
  output_indices_num = prc_indices_num * ksize_p * ksize_r * ksize_c
  output_idx = array_ops.reshape(
      math_ops.range(output_indices_num, dtype=dtypes.int64),
      (1, planes_out, rows_out, cols_out, ksize_p * ksize_r * ksize_c))

  # Construct mapping table for indices: (input -> output).
  idx_matrix = array_ops.concat([
      array_ops.expand_dims(input_idx_patched, axis=-1),
      array_ops.expand_dims(output_idx, axis=-1)
  ],
                                axis=-1)
  # One (input index, output index) pair per row.
  idx_map = array_ops.reshape(idx_matrix, (-1, 2))

  # Sparse matrix of ones with a row per input index and a column per output
  # index.
  sp_shape = (input_indices_num, output_indices_num)
  sp_mat_full = sparse_tensor.SparseTensor(
      idx_map, array_ops.ones([output_indices_num], dtype=grad.dtype), sp_shape)
  # Remove all padding locations [0, :].
  sp_mat = sparse_ops.sparse_slice(sp_mat_full, (1, 0),
                                   (input_indices_num - 1, output_indices_num))

  # Move the batch axis next to the channel axis so that the leading axes
  # flatten into the output-index dimension of the sparse matrix.
  grad_expanded = array_ops.transpose(
      array_ops.reshape(
          _IndexedSlicesToTensorNoWarning(grad),
          (batch_size, planes_out, rows_out, cols_out, ksize_p, ksize_r,
           ksize_c, channels)), (1, 2, 3, 4, 5, 6, 0, 7))
  grad_flat = array_ops.reshape(grad_expanded, (-1, batch_size * channels))

  jac = sparse_ops.sparse_tensor_dense_matmul(sp_mat, grad_flat)

  # Unflatten and move the batch axis back to the front.
  grad_out = array_ops.reshape(
      jac, (planes_in, rows_in, cols_in, batch_size, channels))
  grad_out = array_ops.transpose(grad_out, (3, 0, 1, 2, 4))

  return [grad_out]
1118
1119
@ops.RegisterGradient("ScatterNd")
def _ScatterNdGrad(op, grad):
  """Gradient for ScatterNd: gathers `grad` at the op's indices.

  Returns:
    A three-element list: `None`, the gradient for the second input, `None`.
  """
  scatter_indices = op.inputs[0]
  gathered = array_ops.gather_nd(grad, scatter_indices)
  return [None, gathered, None]
1125
1126
@ops.RegisterGradient("TensorScatterUpdate")
def _TensorScatterUpdateGrad(op, grad):
  """Gradient for TensorScatterUpdate.

  The updates receive `grad` gathered at the indices; the tensor receives
  `grad` with zeros written at those same indices.

  Returns:
    A three-element list: the tensor gradient, `None` for the indices, and the
    updates gradient.
  """
  scatter_indices = op.inputs[1]
  grad_for_updates = array_ops.gather_nd(grad, scatter_indices)
  grad_copy = array_ops.identity(grad)
  zero_updates = array_ops.zeros_like(op.inputs[2], dtype=grad.dtype)
  grad_for_tensor = array_ops.tensor_scatter_update(
      grad_copy, scatter_indices, zero_updates)
  return [grad_for_tensor, None, grad_for_updates]
1135
1136
@ops.RegisterGradient("TensorScatterAdd")
def _TensorScatterAddGrad(op, grad):
  """Gradient for TensorScatterAdd.

  Returns:
    A three-element list: an identity of `grad` for the tensor, `None` for the
    indices, and `grad` gathered at the indices for the updates.
  """
  scatter_indices = op.inputs[1]
  grad_for_updates = array_ops.gather_nd(grad, scatter_indices)
  grad_for_tensor = array_ops.identity(grad)
  return [grad_for_tensor, None, grad_for_updates]
1143
1144
def _TensorScatterMinOrMaxGrad(op, grad):
  """Gradient for TensorScatterMin and TensorScatterMax.

  An element of the tensor or of the updates receives gradient only where it
  equals the corresponding output value.  When several elements match, the
  gradient is divided between them.

  Args:
    op: The forward op; inputs are the tensor, the indices and the updates.
    grad: Gradient with respect to the op's output.

  Returns:
    A three-element list: the tensor gradient, `None` for the indices, and the
    updates gradient.
  """
  tensor, scatter_indices, updates = op.inputs[0], op.inputs[1], op.inputs[2]
  result = op.outputs[0]
  # 1 where the tensor element equals the output, 0 elsewhere.
  tensor_hits = math_ops.cast(math_ops.equal(tensor, result), grad.dtype)
  result_at_updates = array_ops.gather_nd(result, scatter_indices)
  # 1 where the update equals the output at its target location.
  update_hits = math_ops.cast(
      math_ops.equal(updates, result_at_updates), grad.dtype)
  scattered_update_hits = array_ops.scatter_nd(
      scatter_indices, update_hits, array_ops.shape(tensor))
  hit_counts = tensor_hits + scattered_update_hits  # All elements are >= 1.
  # Dividing by the hit count shares the gradient between tied elements.
  tensor_grad = grad * tensor_hits / hit_counts
  shared_grad = array_ops.gather_nd(grad / hit_counts, scatter_indices)
  updates_grad = shared_grad * update_hits
  return [tensor_grad, None, updates_grad]
1162
1163
@ops.RegisterGradient("TensorScatterMax")
def _TensorScatterMaxGrad(op, grad):
  """Gradient for TensorScatterMax op; see `_TensorScatterMinOrMaxGrad`."""
  grads = _TensorScatterMinOrMaxGrad(op, grad)
  return grads
1168
1169
@ops.RegisterGradient("TensorScatterMin")
def _TensorScatterMinGrad(op, grad):
  """Gradient for TensorScatterMin op; see `_TensorScatterMinOrMaxGrad`."""
  grads = _TensorScatterMinOrMaxGrad(op, grad)
  return grads
1174
1175
@ops.RegisterGradient("TensorScatterSub")
def _TensorScatterSubGrad(op, grad):
  """Gradient for TensorScatterSub.

  Returns:
    A three-element list: an identity of `grad` for the tensor, `None` for the
    indices, and the negated `grad` gathered at the indices for the updates.
  """
  scatter_indices = op.inputs[1]
  gathered = array_ops.gather_nd(grad, scatter_indices)
  grad_for_tensor = array_ops.identity(grad)
  grad_for_updates = -gathered
  return [grad_for_tensor, None, grad_for_updates]
1182
1183
@ops.RegisterGradient("ScatterNdNonAliasingAdd")
def _ScatterNdNonAliasingAddGrad(op, grad):
  """Gradient for ScatterNdNonAliasingAdd.

  Returns:
    A three-element list: `grad` itself for the first input, `None` for the
    indices, and `grad` gathered at the indices for the updates.
  """
  scatter_indices = op.inputs[1]
  grad_for_updates = array_ops.gather_nd(grad, scatter_indices)
  return [grad, None, grad_for_updates]
1189
1190
@ops.RegisterGradient("BroadcastTo")
def _BroadcastToGrad(op, grad):
  """Gradient for BroadcastTo.

  `grad` is summed over the axes reported by `broadcast_gradient_args` and then
  reshaped to the shape of the op's first input.

  Args:
    op: The forward BroadcastTo op; inputs are the value and the target shape.
    grad: Gradient with respect to the op's output.

  Returns:
    A two-element list: the gradient for the value and `None` for the shape.
  """
  source = op.inputs[0]
  target_shape = op.inputs[1]
  source_shape = array_ops.shape(source)
  if not isinstance(target_shape, ops.EagerTensor):
    # For non-eager tensors, try to evaluate the target shape as a constant
    # and substitute an int32 constant when it is fully defined.
    # pylint: disable=protected-access
    evaluated = pywrap_tf_session.TF_TryEvaluateConstant_wrapper(
        target_shape.graph._c_graph, target_shape._as_tf_output())
    # pylint: enable=protected-access
    static_target = tensor_shape.TensorShape(evaluated)
    if static_target.is_fully_defined():
      target_shape = constant_op.constant(
          static_target.as_list(), dtype=dtypes.int32)
  _, summed_axes = gen_array_ops.broadcast_gradient_args(
      target_shape, source_shape)
  summed = math_ops.reduce_sum(grad, axis=summed_axes, keepdims=True)
  source_grad = array_ops.reshape(summed, source_shape)
  return [source_grad, None]
1209