# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradients for operators defined in array_ops.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.compiler.tf2xla.ops import gen_xla_ops
from tensorflow.python import pywrap_tfe
from tensorflow.python.client import pywrap_tf_session
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops


@ops.RegisterGradient("Pack")
def _PackGrad(op, grad):
  """Gradient for pack op."""
  return array_ops.unstack(grad, num=op.get_attr("N"), axis=op.get_attr("axis"))


@ops.RegisterGradient("Unpack")
def _UnpackGrad(op, *grads):
  """Gradient for unpack op."""
  return array_ops.stack(grads, axis=op.get_attr("axis"))

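# A minimal usage sketch (illustrative only, not part of this module) of how
# the two gradients above mirror each other: stacking is undone by unstacking
# the incoming gradient along the same axis, and vice versa. Assuming eager
# execution and `import tensorflow as tf`:
#
#   xs = [tf.constant([1., 2.]), tf.constant([3., 4.])]
#   with tf.GradientTape() as tape:
#     tape.watch(xs)
#     y = tf.reduce_sum(tf.stack(xs, axis=0))
#   dxs = tape.gradient(y, xs)  # two tensors of ones, each with shape [2]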

def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index):
  """Gradient for concat op.

  Args:
    op: An operation.
    grad: `Tensor` or `IndexedSlices` representing the gradients with respect to
      each output of the op.
    start_value_index: An integer index of the first value in the op.inputs.
    end_value_index: An integer index of the last value in the op.inputs.
    dim_index: An integer index of concat_dim or axis parameter in op.inputs.

  Returns:
    Tensors representing the partial gradients with respect to each input
    of the op.

  Raises:
    ValueError: if concat_dim/axis is not statically known.
  """

  def _CreateDenseMaskAndBegin(sizes, concat_dim):
    """Create variables for iteratively slicing a dense gradients tensor."""
    # Since shape is 1-D, shape_of_shape = [rank-of-inputs]
    shape_of_shape = array_ops.shape(sizes[0])
    # Make a vector of length equal to the input's dimensions,
    # with 0's everywhere and 1 in the concat dim position.
    # Note: Can't use sparse_to_dense since it isn't GPU-capable (for now)
    mask = array_ops.concat([
        array_ops.fill(array_ops.expand_dims(concat_dim, 0), 0), [1],
        array_ops.fill(shape_of_shape - concat_dim - 1, 0)
    ], 0)
    begin = array_ops.fill(shape_of_shape, 0)
    return mask, begin

  def _ExtractInputShapes(inputs):
    """Extract the shapes of a set of input tensors."""
    if context.executing_eagerly():
      return array_ops.shape_n(inputs)
    sizes = []
    fully_known = True
    for x in inputs:
      input_shape = array_ops.shape(x)
      if not isinstance(input_shape,
                        ops.Tensor) or input_shape.op.type != "Const":
        fully_known = False
        break
      sizes.append(input_shape)

    if fully_known:
      return sizes
    else:
      return array_ops.shape_n(inputs)

  # Degenerate concatenation, just return grad.
  if len(op.inputs) == 2:
    return grad + [None] if end_value_index <= dim_index else [None] + grad

  concat_dim = op.inputs[dim_index]
  input_values = op.inputs[start_value_index:end_value_index]

  out_grads = []
  if isinstance(grad, ops.Tensor):
    if context.executing_eagerly() or isinstance(concat_dim, ops.EagerTensor):
      # Using mod here for convenience since concat_dim is already verified
      # in concat implementation to be within the allowed [-rank, rank) range.
      non_neg_concat_dim = (
          concat_dim._numpy().item(0) % input_values[0]._rank())  # pylint: disable=protected-access
      # All inputs are guaranteed to be EagerTensors in eager mode
      sizes = pywrap_tfe.TFE_Py_TensorShapeSlice(input_values,
                                                 non_neg_concat_dim)
      out_grads = array_ops.split(grad, sizes, non_neg_concat_dim)
    else:
      if constant_op.is_constant(concat_dim):
        # If concat_dim is a constant defined in a different context,
        # then we duplicate it in the current context to avoid passing it
        # through an Enter node.
        # This is a small optimization in general, but it is required when
        # compiling with XLA, as XLA needs the concat input to be folded into a
        # constant.
        grad_context = control_flow_util.GetOutputContext(grad.op)
        dim_context = control_flow_util.GetOutputContext(concat_dim.op)
        if dim_context != grad_context:
          value = tensor_util.constant_value(concat_dim)
          concat_dim = constant_op.constant(value=value, dtype=concat_dim.dtype)

      # Using mod here for convenience since concat_dim is already verified
      # in concat implementation to be within the allowed [-rank, rank) range.
      non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0])

      # Get the inputs' tensor shapes
      sizes = _ExtractInputShapes(input_values)
      # The magic number of 16 was found through benchmarking a range of sizes
      # on CPUs and a Maxwell TitanX.  A speedup was seen in a large majority of
      # cases when switching implementations at N=16, but it is possible that
      # there will be a small number of performance regressions.
      if len(sizes) > 16:
        # extract the size of each input along the concat dimension
        sizes = array_ops.squeeze(
            array_ops.slice(
                array_ops.stack(sizes, axis=1), [non_neg_concat_dim, 0],
                [1, -1]))
        out_grads = array_ops.split(grad, sizes, non_neg_concat_dim)
      else:
        offset = gen_array_ops.concat_offset(non_neg_concat_dim, sizes)
        for (begin, size) in zip(offset, sizes):
          out_grads.append(array_ops.slice(grad, begin, size))
  elif isinstance(grad, ops.IndexedSlices):
    # Using mod here for convenience since concat_dim is already verified
    # in concat implementation to be within the allowed [-rank, rank) range.
    non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0])
    concat_dim_static = tensor_util.constant_value(concat_dim)
    if concat_dim_static is None:
      raise ValueError("Can only compute IndexedSlices gradient with "
                       "statically-known concat_dim")
    if concat_dim_static < 0:
      rank = tensor_util.constant_value(array_ops.rank(input_values[0]))
      if rank is None:
        raise ValueError("Can only compute IndexedSlices gradient with "
                         "negative concat_dim when first value rank is "
                         "statically-known.")
      concat_dim_static %= rank
    # Get the inputs' tensor shapes
    sizes = [array_ops.shape(x) for x in input_values]
    if concat_dim_static > 0:
      # IndexedSlices, non_neg_concat_dim > 0. Each input gets IndexedSlices
      # gradients with all the indices, but with grad.values sliced accordingly.
      # This is like the Tensor case, except shape(grad.values)[0] is not equal
      # to shape(sizes[i])[0], since only a subset of the dim-0 values are
      # stored.
      mask, begin = _CreateDenseMaskAndBegin(sizes, non_neg_concat_dim)
      for size in sizes:
        new_values = array_ops.slice(
            grad.values, begin,
            array_ops.concat([[-1], array_ops.slice(size, [1], [-1])], 0))
        out_grads.append(ops.IndexedSlices(new_values, grad.indices, size))
        # Lint complains begin = begin + ...
        begin = math_ops.add(begin, size * mask)
    else:
      # IndexedSlices, concat_dim == 0. Each input gets IndexedSlices gradients
      # only for the relevant indices.
      start = constant_op.constant(0, dtype=grad.indices.dtype)
      for size in sizes:
        size_concat_dim = array_ops.gather(size, non_neg_concat_dim)
        if size_concat_dim.dtype != grad.indices.dtype:
          size_concat_dim = math_ops.cast(
              size_concat_dim, dtype=grad.indices.dtype)
        end = start + size_concat_dim
        # Compute the 1-D Tensor of indices relevant for this input.
        indices_to_select = array_ops.squeeze(
            array_ops.where(
                math_ops.logical_and(grad.indices >= start,
                                     grad.indices < end)),
            axis=[1])
        new_indices = array_ops.gather(grad.indices, indices_to_select) - start
        new_values = array_ops.gather(grad.values, indices_to_select)
        out_grads.append(ops.IndexedSlices(new_values, new_indices, size))
        start = end
  else:
    raise TypeError("Expected Tensor or IndexedSlices, got %s" % type(grad))

  return (out_grads + [None] if end_value_index <= dim_index else [None] +
          out_grads)

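# A small worked example of the splitting logic above (shapes only, purely
# illustrative): concatenating a [2, 3] tensor and a [2, 5] tensor along
# axis 1 yields a [2, 8] output, so the incoming [2, 8] gradient is split
# back into a [2, 3] piece and a [2, 5] piece along that same axis, e.g.
#   out_grads = array_ops.split(grad, [3, 5], 1)
# with None returned for the non-differentiable axis input.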

@ops.RegisterGradient("Concat")
def _ConcatGrad(op, grad):
  return _ConcatGradHelper(
      op,
      grad,
      start_value_index=1,
      end_value_index=len(op.inputs),
      dim_index=0)


@ops.RegisterGradient("ConcatV2")
def _ConcatGradV2(op, grad):
  return _ConcatGradHelper(
      op, grad, start_value_index=0, end_value_index=-1, dim_index=-1)


ops.NotDifferentiable("ConcatOffset")


@ops.RegisterGradient("Slice")
def _SliceGrad(op, grad):
  """Gradient for Slice op."""
  # Create an Nx2 padding where the first column represents how many
  # zeros are to be prepended for each dimension, and the second
  # column indicates how many zeros are appended.
  #
  # The number of zeros to append is the shape of the input
  # elementwise-subtracted by both the begin vector and sizes vector.
  #
  # Some more reshaping is needed to assemble this tensor with the
  # right dimensions.
  input_vec = op.inputs[0]
  begin_vec = op.inputs[1]
  input_rank = array_ops.rank(input_vec)
  slice_size = array_ops.shape(op.outputs[0])
  if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
    return gen_xla_ops.xla_dynamic_update_slice(array_ops.zeros_like(input_vec),
                                                grad, begin_vec), None, None

  shape = array_ops.stack([input_rank, 1])
  before_pad = array_ops.reshape(begin_vec, shape)
  after_pad = array_ops.reshape(
      array_ops.shape(input_vec) - slice_size - begin_vec, shape)
  paddings = array_ops.concat([before_pad, after_pad], 1)
  return array_ops.pad(grad, paddings), None, None

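# A concrete instance of the padding construction above (a sketch, assuming a
# 1-D input): slicing 2 elements starting at index 1 out of a length-5 vector
# means the incoming gradient of size 2 is padded with 1 zero in front and
# 5 - 2 - 1 = 2 zeros behind, i.e. paddings = [[1, 2]], restoring length 5.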

@ops.RegisterGradient("StridedSlice")
def _StridedSliceGrad(op, grad):
  """Gradient for StridedSlice op."""
  begin = op.inputs[1]
  end = op.inputs[2]
  strides = op.inputs[3]
  # StridedSliceGrad requires `x`, `begin`, `end` and `strides` to be of the
  # same dtype so we build a shape of the same type as other args.
  # Note that the choice of `begin` for specifying `out_type` is arbitrary.
  # We could choose any of {begin|end|strides}.dtype since they are required to
  # be the same.
  x = array_ops.shape(op.inputs[0], out_type=begin.dtype)

  x_static = tensor_util.constant_value(x)
  x = x_static if x_static is not None else x
  begin_static = tensor_util.constant_value(begin)
  begin = begin_static if begin_static is not None else begin
  end_static = tensor_util.constant_value(end)
  end = end_static if end_static is not None else end
  strides_static = tensor_util.constant_value(strides)
  strides = strides_static if strides_static is not None else strides

  return array_ops.strided_slice_grad(
      x,
      begin,
      end,
      strides,
      grad,
      begin_mask=op.get_attr("begin_mask"),
      end_mask=op.get_attr("end_mask"),
      ellipsis_mask=op.get_attr("ellipsis_mask"),
      new_axis_mask=op.get_attr("new_axis_mask"),
      shrink_axis_mask=op.get_attr("shrink_axis_mask")), None, None, None


@ops.RegisterGradient("StridedSliceGrad")
def _StridedSliceGradGrad(op, grad):
  """Gradient for StridedSliceGrad op."""
  begin = op.inputs[1]
  end = op.inputs[2]
  strides = op.inputs[3]

  return None, None, None, None, array_ops.strided_slice(
      grad,
      begin,
      end,
      strides,
      begin_mask=op.get_attr("begin_mask"),
      end_mask=op.get_attr("end_mask"),
      ellipsis_mask=op.get_attr("ellipsis_mask"),
      new_axis_mask=op.get_attr("new_axis_mask"),
      shrink_axis_mask=op.get_attr("shrink_axis_mask"))


@ops.RegisterGradient("TensorStridedSliceUpdate")
def _TensorStridedSliceUpdateGrad(op, grad):  # pylint:disable=missing-function-docstring
  begin = op.inputs[1]
  end = op.inputs[2]
  strides = op.inputs[3]
  begin_mask = op.get_attr("begin_mask")
  end_mask = op.get_attr("end_mask")
  ellipsis_mask = op.get_attr("ellipsis_mask")
  new_axis_mask = op.get_attr("new_axis_mask")
  shrink_axis_mask = op.get_attr("shrink_axis_mask")
  def Apply(f, *args):
    return f(*args,
             begin_mask=begin_mask,
             end_mask=end_mask,
             shrink_axis_mask=shrink_axis_mask,
             new_axis_mask=new_axis_mask,
             ellipsis_mask=ellipsis_mask)
  dy = Apply(array_ops.strided_slice,
             grad, begin, end, strides)
  dx = Apply(array_ops.tensor_strided_slice_update,
             grad, begin, end, strides, array_ops.zeros_like(dy))
  return dx, None, None, None, dy


@ops.RegisterGradient("Split")
def _SplitGrad(op, *grads):
  return None, array_ops.concat(list(grads), op.inputs[0])


@ops.RegisterGradient("SplitV")
def _SplitVGrad(op, *grads):
  returnval = array_ops.concat(list(grads), op.inputs[2])
  returnval = [returnval] + [
      None,
  ] * (
      len(op.inputs) - 1)
  return returnval


ops.NotDifferentiable("Const")


@ops.RegisterGradient("Diag")
def _DiagGrad(_, grad):
  return array_ops.diag_part(grad)


@ops.RegisterGradient("DiagPart")
def _DiagPartGrad(_, grad):
  return array_ops.diag(grad)


@ops.RegisterGradient("MatrixDiag")
def _MatrixDiagGrad(_, grad):
  return array_ops.matrix_diag_part(grad)


@ops.RegisterGradient("MatrixDiagV2")
def _MatrixDiagV2Grad(op, grad):
  return array_ops.matrix_diag_part(
      grad, k=op.inputs[1]), None, None, None, None


@ops.RegisterGradient("MatrixDiagV3")
def _MatrixDiagV3Grad(op, grad):
  return array_ops.matrix_diag_part(
      grad, k=op.inputs[1], align=op.get_attr("align")), None, None, None, None


@ops.RegisterGradient("MatrixDiagPart")
def _MatrixDiagPartGrad(op, grad):
  matrix_shape = op.inputs[0].get_shape()[-2:]
  if matrix_shape.is_fully_defined() and matrix_shape[0] == matrix_shape[1]:
    return array_ops.matrix_diag(grad)
  else:
    return array_ops.matrix_set_diag(array_ops.zeros_like(op.inputs[0]), grad)


@ops.RegisterGradient("MatrixDiagPartV2")
def _MatrixDiagPartV2Grad(op, grad):
  """Gradient for MatrixDiagPartV2."""
  matrix_shape = op.inputs[0].get_shape()[-2:]
  if matrix_shape.is_fully_defined():
    return array_ops.matrix_diag(
        grad,
        k=op.inputs[1],
        num_rows=matrix_shape[0],
        num_cols=matrix_shape[1]), None, None
  else:
    return array_ops.matrix_set_diag(
        array_ops.zeros_like(op.inputs[0]), grad, k=op.inputs[1]), None, None


@ops.RegisterGradient("MatrixDiagPartV3")
def _MatrixDiagPartV3Grad(op, grad):
  """Gradient for MatrixDiagPartV3."""
  matrix_shape = op.inputs[0].get_shape()[-2:]
  align = op.get_attr("align")
  if matrix_shape.is_fully_defined():
    return array_ops.matrix_diag(
        grad,
        k=op.inputs[1],
        num_rows=matrix_shape[0],
        num_cols=matrix_shape[1],
        align=align), None, None
  else:
    return array_ops.matrix_set_diag(
        array_ops.zeros_like(op.inputs[0]), grad, k=op.inputs[1],
        align=align), None, None


@ops.RegisterGradient("MatrixSetDiag")
def _MatrixSetDiagGrad(op, grad):
  """Gradient for MatrixSetDiag."""
  input_shape = op.inputs[0].get_shape().merge_with(grad.get_shape())
  diag_shape = op.inputs[1].get_shape()
  batch_shape = input_shape[:-2].merge_with(diag_shape[:-1])
  matrix_shape = input_shape[-2:]
  if batch_shape.is_fully_defined() and matrix_shape.is_fully_defined():
    diag_shape = batch_shape.as_list() + [min(matrix_shape.as_list())]
  else:
    with ops.colocate_with(grad):
      grad_shape = array_ops.shape(grad)
      grad_rank = array_ops.rank(grad)
      batch_shape = array_ops.slice(grad_shape, [0], [grad_rank - 2])
      matrix_shape = array_ops.slice(grad_shape, [grad_rank - 2], [2])
      min_dim = math_ops.reduce_min(matrix_shape)
      diag_shape = array_ops.concat([batch_shape, [min_dim]], 0)
  grad_input = array_ops.matrix_set_diag(
      grad, array_ops.zeros(diag_shape, dtype=grad.dtype))
  grad_diag = array_ops.matrix_diag_part(grad)
  return (grad_input, grad_diag)


@ops.RegisterGradient("MatrixSetDiagV2")
def _MatrixSetDiagGradV2(op, grad):
  """Gradient for MatrixSetDiagV2."""
  diag_shape = op.inputs[1].get_shape()
  if not diag_shape.is_fully_defined():
    # Need to know the values of `d_lower` and `d_upper` to infer diag_shape.
    grad_shape = array_ops.shape(grad)
    batch_shape = grad_shape[:-2]
    matrix_shape = grad_shape[-2:]
    diag_index = array_ops.reshape(op.inputs[2], [-1])  # Converts to vector.
    d_lower = diag_index[0]
    d_upper = diag_index[-1]  # Works both when len(diag_index) is 1 and 2.
    y_offset = control_flow_ops.cond(
        math_ops.less(d_upper, 0), lambda: d_upper, lambda: 0)
    x_offset = control_flow_ops.cond(
        math_ops.greater(d_lower, 0), lambda: -d_lower, lambda: 0)

    max_diag_len = math_ops.minimum(matrix_shape[0] + y_offset,
                                    matrix_shape[1] + x_offset)
    # pylint: disable=g-long-lambda
    # pyformat: disable
    postfix = control_flow_ops.cond(
        math_ops.equal(d_lower, d_upper),
        lambda: ops.convert_to_tensor([max_diag_len]),
        lambda: ops.convert_to_tensor([d_upper - d_lower + 1,
                                       max_diag_len]))
    # pyformat: enable
    # pylint: enable=g-long-lambda
    diag_shape = array_ops.concat([batch_shape, postfix], 0)

  grad_input = array_ops.matrix_set_diag(
      grad, array_ops.zeros(diag_shape, dtype=grad.dtype), k=op.inputs[2])
  grad_diag = array_ops.matrix_diag_part(grad, k=op.inputs[2])
  return (grad_input, grad_diag, None)


@ops.RegisterGradient("MatrixSetDiagV3")
def _MatrixSetDiagGradV3(op, grad):
  """Gradient for MatrixSetDiagV3."""
  diag_shape = op.inputs[1].get_shape()
  align = op.get_attr("align")
  if not diag_shape.is_fully_defined():
    # Need to know the values of `d_lower` and `d_upper` to infer diag_shape.
    grad_shape = array_ops.shape(grad)
    batch_shape = grad_shape[:-2]
    matrix_shape = grad_shape[-2:]
    diag_index = array_ops.reshape(op.inputs[2], [-1])  # Converts to vector.
    d_lower = diag_index[0]
    d_upper = diag_index[-1]  # Works both when len(diag_index) is 1 and 2.
    y_offset = control_flow_ops.cond(
        math_ops.less(d_upper, 0), lambda: d_upper, lambda: 0)
    x_offset = control_flow_ops.cond(
        math_ops.greater(d_lower, 0), lambda: -d_lower, lambda: 0)

    max_diag_len = math_ops.minimum(matrix_shape[0] + y_offset,
                                    matrix_shape[1] + x_offset)
    # pylint: disable=g-long-lambda
    # pyformat: disable
    postfix = control_flow_ops.cond(
        math_ops.equal(d_lower, d_upper),
        lambda: ops.convert_to_tensor([max_diag_len]),
        lambda: ops.convert_to_tensor([d_upper - d_lower + 1,
                                       max_diag_len]))
    # pyformat: enable
    # pylint: enable=g-long-lambda
    diag_shape = array_ops.concat([batch_shape, postfix], 0)

  grad_input = array_ops.matrix_set_diag(
      grad,
      array_ops.zeros(diag_shape, dtype=grad.dtype),
      k=op.inputs[2],
      align=align)
  grad_diag = array_ops.matrix_diag_part(grad, k=op.inputs[2], align=align)
  return (grad_input, grad_diag, None)


@ops.RegisterGradient("MatrixBandPart")
def _MatrixBandPartGrad(op, grad):
  num_lower = op.inputs[1]
  num_upper = op.inputs[2]
  return (array_ops.matrix_band_part(grad, num_lower, num_upper), None, None)


# Edit Distance has no gradient (but can be used to eval seq2seq or CTC).
ops.NotDifferentiable("EditDistance")


@ops.RegisterGradient("Fill")
def _FillGrad(_, grad):
  return None, math_ops.reduce_sum(grad)

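# Intuition for the gradient above (a sketch): every element of the filled
# output is the same scalar `value`, so each output element contributes one
# term to d(output)/d(value) and the gradient for `value` is reduce_sum(grad);
# the integer `dims` input is not differentiable and receives None.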

ops.NotDifferentiable("ZerosLike")
ops.NotDifferentiable("OnesLike")


@ops.RegisterGradient("PreventGradient")
def _PreventGradientGrad(op, _):
  raise LookupError("Gradient explicitly disabled. Reason: %s" %
                    op.get_attr("message"))


def _IndexedSlicesToTensorNoWarning(indexed_slices):
  """Converts an IndexedSlices to a Tensor without sparse->dense warnings."""
  if not isinstance(indexed_slices, ops.IndexedSlices):
    # If it is not IndexedSlices, it had better be a tensor.
    return indexed_slices
  if indexed_slices.dense_shape is None:
    raise ValueError(
        "Tensor conversion requested for IndexedSlices without dense_shape: %s"
        % str(indexed_slices))
  return math_ops.unsorted_segment_sum(indexed_slices.values,
                                       indexed_slices.indices,
                                       indexed_slices.dense_shape[0])


@ops.RegisterGradient("Gather")
def _GatherGrad(op, grad):
  """Gradient for Gather op."""
  # params can be large, so colocate the shape calculation with it.
  params = op.inputs[0]
  with ops.colocate_with(params):
    params_shape = array_ops.shape(params)

  # Build appropriately shaped IndexedSlices
  indices = op.inputs[1]
  size = array_ops.expand_dims(array_ops.size(indices), 0)
  values_shape = array_ops.concat([size, params_shape[1:]], 0)
  values = array_ops.reshape(
      _IndexedSlicesToTensorNoWarning(grad), values_shape)
  indices = array_ops.reshape(indices, size)
  return [ops.IndexedSlices(values, indices, params_shape), None]

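# A minimal sketch (shapes only, purely illustrative) of the IndexedSlices
# gradient built above: gathering rows [0, 2] from a [4, 3] params tensor
# gives a [2, 3] result, so the gradient flows back as IndexedSlices whose
# values are the incoming [2, 3] gradient, whose indices are [0, 2], and whose
# dense_shape is [4, 3]; rows 1 and 3 of params simply receive no gradient.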

def _GetBatchIndices(params_shape, indices, batch_dims):
586  """Addds the batch offsets to the given indices and returns the results."""
  batch_indices = indices
  indices_ndims = indices.shape.ndims
  indices_dtype = indices.dtype.base_dtype
  casted_params_shape = math_ops.cast(params_shape, indices_dtype)
  accum_dim_value = array_ops.ones((), dtype=indices_dtype)
  for dim in range(batch_dims, 0, -1):
    dim_value = casted_params_shape[dim - 1]
    accum_dim_value *= casted_params_shape[dim]
    start = array_ops.zeros((), dtype=indices_dtype)
    step = array_ops.ones((), dtype=indices_dtype)
    dim_indices = math_ops.range(start, dim_value, step)
    dim_indices *= accum_dim_value
    dim_shape = array_ops.stack(
        [1] * (dim - 1) + [dim_value] + [1] * (indices_ndims - dim), axis=0)
    batch_indices += array_ops.reshape(dim_indices, dim_shape)

  return batch_indices


def _BatchGatherGrad(params_shape, values, indices, batch_dims,
                     gather_dim_size):
  """Returns the gradient of GatherV2 with batch dimensions."""

  # Axis is the first non-batch dimension.
  indices_size = array_ops.expand_dims(array_ops.size(indices), 0)
  if batch_dims:
    values_shape = array_ops.shape(values)
    # Add the batch offsets to indices and flatten the batch dimensions.
    outer_shape = values_shape[:batch_dims]
    inner_shape = values_shape[batch_dims:][1:]
    batch_size = gen_math_ops.prod(outer_shape, [0], False)
    flat_values_shape = array_ops.concat([[-1], inner_shape], 0)
    gather_dim_size *= batch_size

    indices = _GetBatchIndices(params_shape, indices, batch_dims)
    values = array_ops.reshape(
        _IndexedSlicesToTensorNoWarning(values), flat_values_shape)

  indices = array_ops.reshape(indices, indices_size)
  params_grad = math_ops.unsorted_segment_sum(values, indices, gather_dim_size)

  if batch_dims:
    # Put back the batch dimensions.
    params_grad = array_ops.reshape(
        params_grad, array_ops.concat([outer_shape, flat_values_shape], 0))

  return params_grad


@ops.RegisterGradient("GatherV2")
def _GatherV2Grad(op, grad):
  """Gradient for GatherV2 op."""
  # params can be large, so colocate the shape calculation with it.
  #
  # params can be very large for sparse model, array_ops.shape raises
  # exception on the Windows platform when any dimension is larger than
  # int32. params_shape is not used in optimizer apply_sparse gradients,
  # so it's fine to convert it back to int32 regardless of truncation.
  params = op.inputs[0]
  with ops.colocate_with(params):
    params_shape = array_ops.shape(params, out_type=ops.dtypes.int64)
    params_shape = math_ops.cast(params_shape, dtypes.int32)

  indices = op.inputs[1]
  indices_size = array_ops.expand_dims(array_ops.size(indices), 0)
  axis = op.inputs[2]
  axis_static = tensor_util.constant_value(axis)
  batch_dims = int(op.get_attr("batch_dims"))

  if batch_dims < 0:
    batch_dims += indices.shape.ndims

  # For axis 0 gathers, build an appropriately shaped IndexedSlices.
  if axis_static == 0:
    if context.executing_eagerly():
      with ops.device(indices_size.device):
        params_tail_shape = array_ops.identity(params_shape)[1:]
    else:
      params_tail_shape = params_shape[1:]
    values_shape = array_ops.concat([indices_size, params_tail_shape], 0)
    values = array_ops.reshape(
        _IndexedSlicesToTensorNoWarning(grad), values_shape)
    indices = array_ops.reshape(indices, indices_size)
    params_grad = ops.IndexedSlices(values, indices, params_shape)
  else:
    # Handle axis by transposing the axis dimension to be the first non-batch
    # dimension, compute the gradient and transpose the result back.
    outer_shape = params_shape[:axis]
    inner_shape = params_shape[axis:][1:]
    values_shape = array_ops.concat([outer_shape, [-1], inner_shape], 0)

    values_dims = array_ops.size(values_shape)
    axis_dims = array_ops.size(outer_shape)

    outer_batches_indices = math_ops.range(batch_dims)
    batch_axis_indices = math_ops.range(batch_dims, axis_dims)
    inner_axes_indices = math_ops.range(axis_dims + 1, values_dims)

    values = array_ops.reshape(
        _IndexedSlicesToTensorNoWarning(grad), values_shape)

    # Move values[axis] up to values[batch_dims]
    transpose_dims = array_ops.concat([
        outer_batches_indices, [axis_dims], batch_axis_indices,
        inner_axes_indices
    ], 0)
    values_transpose = array_ops.transpose(values, transpose_dims)

    params_grad = _BatchGatherGrad(params_shape, values_transpose, indices,
                                   batch_dims, params_shape[axis])

    # Inverts the above transpose by moving dimension batch_dims back to its
    # original position.
    invert_transpose_dims = array_ops.concat([
        outer_batches_indices, batch_axis_indices + 1, [batch_dims],
        inner_axes_indices
    ], 0)
    params_grad = array_ops.transpose(params_grad, invert_transpose_dims)

  return [params_grad, None, None]

@ops.RegisterGradient("GatherNd")
def _GatherNdGrad(op, grad):
  ref = op.inputs[0]
  indices = op.inputs[1]
  ref_shape = array_ops.shape(ref, out_type=indices.dtype)
  if indices.shape.ndims == 2 and indices.shape.dims[-1].value == 1:
    ref_grad = ops.IndexedSlices(grad, array_ops.squeeze(indices, axis=-1),
                                 ref_shape)
  else:
    ref_grad = array_ops.scatter_nd(indices, grad, ref_shape)
  return [ref_grad, None]


@ops.RegisterGradient("ResourceGatherNd")
def _ResourceGatherNdGrad(op, grad):  # pylint: disable=missing-docstring
  ref = op.inputs[0]
  indices = op.inputs[1]
  ref_shape = gen_resource_variable_ops.variable_shape(ref, indices.dtype)
  if indices.shape.ndims == 2 and indices.shape.dims[-1].value == 1:
    ref_grad = ops.IndexedSlices(grad, array_ops.squeeze(indices, axis=-1),
                                 ref_shape)
  else:
    ref_grad = array_ops.scatter_nd(indices, grad, ref_shape)
  return [ref_grad, None]


@ops.RegisterGradient("CheckNumerics")
def _CheckNumericsGrad(op, grad):
  """Gradient for check_numerics op."""
  return array_ops.check_numerics(
      grad,
      "Not a number (NaN) or infinity (Inf) values detected in gradient. %s" %
      op.get_attr("message"))


@ops.RegisterGradient("CheckNumericsV2")
def _CheckNumericsV2Grad(op, grad):
  """Gradient for check_numerics op."""
  return array_ops.check_numerics_v2(
      grad,
      "Not a number (NaN) or infinity (Inf) values detected in gradient. %s" %
      op.get_attr("message"))


@ops.RegisterGradient("PlaceholderWithDefault")
@ops.RegisterGradient("Identity")
def _IdGrad(_, grad):
  return grad


@ops.RegisterGradient("RefIdentity")
def _RefIdGrad(_, grad):
  return grad


@ops.RegisterGradient("IdentityN")
def _IdNGrad(_, *grad):
  return grad


ops.NotDifferentiable("StopGradient")


@ops.RegisterGradient("Reshape")
def _ReshapeGrad(op, grad):
  return [
      array_ops.reshape(
          _IndexedSlicesToTensorNoWarning(grad), array_ops.shape(op.inputs[0])),
      None
  ]


ops.NotDifferentiable("InvertPermutation")


def _ReshapeToInput(op, grad):
  """Reshapes the gradient to the shape of the original input."""
  return array_ops.reshape(
      _IndexedSlicesToTensorNoWarning(grad), array_ops.shape(op.inputs[0]))


@ops.RegisterGradient("ExpandDims")
def _ExpandDimsGrad(op, grad):
  return [_ReshapeToInput(op, grad), None]


@ops.RegisterGradient("Squeeze")
def _SqueezeGrad(op, grad):
  return _ReshapeToInput(op, grad)


@ops.RegisterGradient("Transpose")
def _TransposeGrad(op, grad):
  """Returns unshuffle(grad)."""
  p = op.inputs[1]
  return [array_ops.transpose(grad, array_ops.invert_permutation(p)), None]


@ops.RegisterGradient("ConjugateTranspose")
def _ConjugateTransposeGrad(op, grad):
  """Returns conj(unshuffle(grad))."""
  p = op.inputs[1]
  return [
      array_ops.transpose(
          grad, array_ops.invert_permutation(p), conjugate=True), None
  ]


ops.NotDifferentiable("Shape")

ops.NotDifferentiable("ShapeN")

ops.NotDifferentiable("Rank")

ops.NotDifferentiable("Size")


@ops.RegisterGradient("Tile")
def _TileGrad(op, grad):
  """Sum reduces grad along the tiled dimensions."""
  input_shape = array_ops.shape(op.inputs[0], out_type=op.inputs[1].dtype)
  # We interleave multiples and input_shape to get split_shape,
  # reshape grad to split_shape, and reduce along all even
  # dimensions (the tiled dimensions) to get the result
  # with shape input_shape.  For example
  #   input_shape = [20, 30, 40]
  #   multiples = [2, 3, 4]
  #   split_shape = [2, 20, 3, 30, 4, 40]
  #   axes = [0, 2, 4]
  split_shape = array_ops.reshape(
      array_ops.transpose(array_ops.stack([op.inputs[1], input_shape])), [-1])
  axes = math_ops.range(0, array_ops.size(split_shape), 2)
  # Sum reduces grad along the first dimension for IndexedSlices
  if isinstance(grad, ops.IndexedSlices):
    input_shape_0 = math_ops.cast(input_shape[0], grad.indices.dtype)
    grad = math_ops.unsorted_segment_sum(
        grad.values, math_ops.mod(grad.indices, input_shape_0), input_shape_0)
    split_shape = array_ops.concat([[1], split_shape[1:]], axis=0)
  input_grad = math_ops.reduce_sum(array_ops.reshape(grad, split_shape), axes)
  # Fix shape inference
  if not context.executing_eagerly():
    input_grad.set_shape(op.inputs[0].get_shape())
  return [input_grad, None]


ops.NotDifferentiable("BroadcastGradientArgs")


def _PadGrad(op, grad):
  """Gradient for Pad."""
  # Pad introduces values around the original tensor, so the gradient function
  # slices the original shape out of the gradient.
  x = op.inputs[0]
  a = op.inputs[1]  # [Rank(x), 2]
  # Takes a slice of a. The 1st column. [Rank(x), 1].
  pad_before = array_ops.slice(a, [0, 0],
                               array_ops.stack([array_ops.rank(x), 1]))
  # Make it a 1-D tensor.
  begin = array_ops.reshape(pad_before, [-1])
  sizes = array_ops.shape(x, out_type=begin.dtype)
  x_grad = array_ops.slice(grad, begin, sizes)
  if len(op.inputs) == 3:
    return x_grad, None, None
  else:
    return x_grad, None


ops.RegisterGradient("Pad")(_PadGrad)
ops.RegisterGradient("PadV2")(_PadGrad)


# ReverseSequence is just a permutation.  The gradient permutes back.
@ops.RegisterGradient("ReverseSequence")
def _ReverseSequenceGrad(op, grad):
  seq_lengths = op.inputs[1]
  return [
      array_ops.reverse_sequence(
          grad,
          batch_axis=op.get_attr("batch_dim"),
          seq_axis=op.get_attr("seq_dim"),
          seq_lengths=seq_lengths), None
  ]


@ops.RegisterGradient("Reverse")
def _ReverseGrad(op, grad):
  reverse_dims = op.inputs[1]
  return gen_array_ops.reverse(grad, reverse_dims), None


@ops.RegisterGradient("ReverseV2")
def _ReverseV2Grad(op, grad):
  axis = op.inputs[1]
  return array_ops.reverse_v2(grad, axis), None


@ops.RegisterGradient("SpaceToBatch")
def _SpaceToBatchGrad(op, grad):
  # Its gradient is the opposite op: BatchToSpace.
  block_size = op.get_attr("block_size")
  return [
      array_ops.batch_to_space(grad, op.inputs[1], block_size=block_size), None
  ]


@ops.RegisterGradient("SpaceToBatchND")
def _SpaceToBatchNDGrad(op, grad):
  # Its gradient is the opposite op: BatchToSpaceND.
  return [
      array_ops.batch_to_space_nd(grad, op.inputs[1], op.inputs[2]), None, None
  ]


@ops.RegisterGradient("BatchToSpace")
def _BatchToSpaceGrad(op, grad):
  # Its gradient is the opposite op: SpaceToBatch.
  block_size = op.get_attr("block_size")
  return [
      array_ops.space_to_batch(grad, op.inputs[1], block_size=block_size), None
  ]


@ops.RegisterGradient("BatchToSpaceND")
def _BatchToSpaceNDGrad(op, grad):
  # Its gradient is the opposite op: SpaceToBatchND.
  return [
      array_ops.space_to_batch_nd(grad, op.inputs[1], op.inputs[2]), None, None
  ]


@ops.RegisterGradient("SpaceToDepth")
def _SpaceToDepthGrad(op, grad):
  # Its gradient is the opposite op: DepthToSpace.
  block_size = op.get_attr("block_size")
  data_format = op.get_attr("data_format")
  if data_format == "NCHW_VECT_C":
    raise ValueError("Cannot compute SpaceToDepth gradient with NCHW_VECT_C. "
                     "NCHW_VECT_C requires qint8 data type.")
  return array_ops.depth_to_space(grad, block_size, data_format=data_format)


@ops.RegisterGradient("DepthToSpace")
def _DepthToSpaceGrad(op, grad):
  # Its gradient is the opposite op: SpaceToDepth.
  block_size = op.get_attr("block_size")
  data_format = op.get_attr("data_format")
  if data_format == "NCHW_VECT_C":
    raise ValueError("Cannot compute DepthToSpace gradient with NCHW_VECT_C. "
                     "NCHW_VECT_C requires qint8 data type.")
  return array_ops.space_to_depth(grad, block_size, data_format=data_format)


ops.NotDifferentiable("OneHot")


@ops.RegisterGradient("MirrorPad")
def _MirrorPadGrad(op, grad):
  mode = op.get_attr("mode")
  return [gen_array_ops.mirror_pad_grad(grad, op.inputs[1], mode=mode), None]


@ops.RegisterGradient("MirrorPadGrad")
def _MirrorPadGradGrad(op, grad):
  mode = op.get_attr("mode")
  return [gen_array_ops.mirror_pad(grad, op.inputs[1], mode=mode), None]


@ops.RegisterGradient("QuantizeAndDequantize")
def _QuantizeAndDequantizeGrad(_, grad):
  return grad


@ops.RegisterGradient("QuantizeAndDequantizeV2")
def _QuantizeAndDequantizeV2Grad(_, grad):
  return [grad, None, None]


@ops.RegisterGradient("QuantizeAndDequantizeV3")
def _QuantizeAndDequantizeV3Grad(_, grad):
  # Only propagate the gradient for the unquantized input.
  return [grad, None, None, None]


@ops.RegisterGradient("ExtractImagePatches")
def _ExtractImagePatchesGrad(op, grad):
  input_bhwc = array_ops.shape(op.inputs[0], out_type=dtypes.int64)
  batch_size, rows_in, cols_in, channels = input_bhwc[0], input_bhwc[1], \
                                           input_bhwc[2], input_bhwc[3]

  # Create indices matrix for input tensor.
  # Note that 0 is preserved for padding location,
  # so indices for input start from 1 to 1 + rows_in * cols_in.
  input_indices_num = 1 + rows_in * cols_in
  input_idx = array_ops.reshape(
      math_ops.range(1, input_indices_num, dtype=ops.dtypes.int64),
      (1, rows_in, cols_in, 1))
  input_idx_patched = gen_array_ops.extract_image_patches(
      input_idx, op.get_attr("ksizes"), op.get_attr("strides"),
      op.get_attr("rates"), op.get_attr("padding"))

  # Create indices matrix for output tensor.
  output_bhwc = array_ops.shape(op.outputs[0], out_type=dtypes.int64)
  rows_out, cols_out = output_bhwc[1], output_bhwc[2]
  _, ksize_r, ksize_c, _ = op.get_attr("ksizes")
  # Indices for output start from 0.
  output_indices_num = rows_out * cols_out * ksize_r * ksize_c
  output_idx = array_ops.reshape(
      math_ops.range(output_indices_num, dtype=ops.dtypes.int64),
      (1, rows_out, cols_out, ksize_r * ksize_c))

  # Construct mapping table for indices: (input -> output).
  idx_matrix = array_ops.concat([
      array_ops.expand_dims(input_idx_patched, axis=-1),
      array_ops.expand_dims(output_idx, axis=-1)
  ],
                                axis=-1)
  idx_map = array_ops.reshape(idx_matrix, (-1, 2))

  sp_shape = (input_indices_num, output_indices_num)
  sp_mat_full = sparse_tensor.SparseTensor(
      idx_map, array_ops.ones([output_indices_num], dtype=grad.dtype), sp_shape)
  # Remove all padding locations [0, :].
  sp_mat = sparse_ops.sparse_slice(sp_mat_full, (1, 0),
                                   (input_indices_num - 1, output_indices_num))

  grad_expanded = array_ops.transpose(
      array_ops.reshape(
          _IndexedSlicesToTensorNoWarning(grad),
          (batch_size, rows_out, cols_out, ksize_r, ksize_c, channels)),
      (1, 2, 3, 4, 0, 5))
  grad_flat = array_ops.reshape(grad_expanded, (-1, batch_size * channels))

  jac = sparse_ops.sparse_tensor_dense_matmul(sp_mat, grad_flat)

  grad_out = array_ops.reshape(jac, (rows_in, cols_in, batch_size, channels))
  grad_out = array_ops.transpose(grad_out, (2, 0, 1, 3))

  return [grad_out]


@ops.RegisterGradient("ExtractVolumePatches")
def _ExtractVolumePatchesGrad(op, grad):
  batch_size, planes_in, rows_in, cols_in, channels = [
      dim.value for dim in op.inputs[0].shape.dims
  ]
  input_bphwc = array_ops.shape(op.inputs[0])
  batch_size = input_bphwc[0]
  channels = input_bphwc[4]

  # Create indices matrix for input tensor.
  # Note that 0 is preserved for padding location,
  # so indices for input start from 1 to 1 + planes_in * rows_in * cols_in.
  input_indices_num = 1 + planes_in * rows_in * cols_in
  input_idx = array_ops.reshape(
      math_ops.range(1, input_indices_num, dtype=ops.dtypes.int64),
      (1, planes_in, rows_in, cols_in, 1))
  input_idx_patched = gen_array_ops.extract_volume_patches(
      input_idx, op.get_attr("ksizes"), op.get_attr("strides"),
      op.get_attr("padding"))

  # Create indices matrix for output tensor.
  _, planes_out, rows_out, cols_out, _ = [
      dim.value for dim in op.outputs[0].shape.dims
  ]
  _, ksize_p, ksize_r, ksize_c, _ = op.get_attr("ksizes")
  # Indices for output start from 0.
  prc_indices_num = planes_out * rows_out * cols_out
  output_indices_num = prc_indices_num * ksize_p * ksize_r * ksize_c
  output_idx = array_ops.reshape(
      math_ops.range(output_indices_num, dtype=ops.dtypes.int64),
      (1, planes_out, rows_out, cols_out, ksize_p * ksize_r * ksize_c))

  # Construct mapping table for indices: (input -> output).
  idx_matrix = array_ops.concat([
      array_ops.expand_dims(input_idx_patched, axis=-1),
      array_ops.expand_dims(output_idx, axis=-1)
  ],
                                axis=-1)
  idx_map = array_ops.reshape(idx_matrix, (-1, 2))

  sp_shape = (input_indices_num, output_indices_num)
  sp_mat_full = sparse_tensor.SparseTensor(
      idx_map, array_ops.ones([output_indices_num], dtype=grad.dtype), sp_shape)
  # Remove all padding locations [0, :].
  sp_mat = sparse_ops.sparse_slice(sp_mat_full, (1, 0),
                                   (input_indices_num - 1, output_indices_num))

  grad_expanded = array_ops.transpose(
      array_ops.reshape(
          _IndexedSlicesToTensorNoWarning(grad),
          (batch_size, planes_out, rows_out, cols_out, ksize_p, ksize_r,
           ksize_c, channels)), (1, 2, 3, 4, 5, 6, 0, 7))
  grad_flat = array_ops.reshape(grad_expanded, (-1, batch_size * channels))

  jac = sparse_ops.sparse_tensor_dense_matmul(sp_mat, grad_flat)

  grad_out = array_ops.reshape(
      jac, (planes_in, rows_in, cols_in, batch_size, channels))
  grad_out = array_ops.transpose(grad_out, (3, 0, 1, 2, 4))

  return [grad_out]


@ops.RegisterGradient("ScatterNd")
def _ScatterNdGrad(op, grad):
  indices = op.inputs[0]
  updates_grad = array_ops.gather_nd(grad, indices)
  return [None, updates_grad, None]


@ops.RegisterGradient("TensorScatterUpdate")
def _TensorScatterUpdateGrad(op, grad):
  indices = op.inputs[1]
  updates_grad = array_ops.gather_nd(grad, indices)
  tensor_grad = array_ops.tensor_scatter_update(
      array_ops.identity(grad), indices,
      array_ops.zeros_like(op.inputs[2], dtype=grad.dtype))
  return [tensor_grad, None, updates_grad]


@ops.RegisterGradient("TensorScatterAdd")
def _TensorScatterAddGrad(op, grad):
  indices = op.inputs[1]
  updates_grad = array_ops.gather_nd(grad, indices)
  tensor_grad = array_ops.identity(grad)
  return [tensor_grad, None, updates_grad]


def _TensorScatterMinOrMaxGrad(op, grad):
  """Gradient for TensorScatterMin and TensorScatterMax."""
  indices = op.inputs[1]
  x = op.inputs[0]
  y = op.inputs[2]
  output = op.outputs[0]
  x_indicators = math_ops.cast(math_ops.equal(x, output), grad.dtype)
  y_output = array_ops.gather_nd(output, indices)
  y_indicators = math_ops.cast(math_ops.equal(y, y_output), grad.dtype)
  ys_indicators = array_ops.scatter_nd(indices, y_indicators,
                                       array_ops.shape(x))
  indicators = x_indicators + ys_indicators  # All elements are >= 1.
  # If there are multiple minimum or maximum elements then the gradient will be
  # divided between them.
  x_grad = grad * x_indicators / indicators
  y_grad = array_ops.gather_nd(grad / indicators, indices) * y_indicators
  return [x_grad, None, y_grad]

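# A worked example of the tie handling above (illustrative only): for
# TensorScatterMax with x = [1., 5.], indices = [[1]], updates y = [5.], the
# output is [1., 5.]. Both x[1] and y[0] equal the output at position 1, so
# `indicators` is 2 there and each side receives half of the incoming gradient
# at that position, while x[0] matches the output alone and keeps its full
# gradient.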

@ops.RegisterGradient("TensorScatterMax")
def _TensorScatterMaxGrad(op, grad):
  """Gradient for TensorScatterMax op."""
  return _TensorScatterMinOrMaxGrad(op, grad)


@ops.RegisterGradient("TensorScatterMin")
def _TensorScatterMinGrad(op, grad):
  """Gradient for TensorScatterMin op."""
  return _TensorScatterMinOrMaxGrad(op, grad)


@ops.RegisterGradient("TensorScatterSub")
def _TensorScatterSubGrad(op, grad):
  indices = op.inputs[1]
  updates_grad = array_ops.gather_nd(grad, indices)
  tensor_grad = array_ops.identity(grad)
  return [tensor_grad, None, -updates_grad]


@ops.RegisterGradient("ScatterNdNonAliasingAdd")
def _ScatterNdNonAliasingAddGrad(op, grad):
  indices = op.inputs[1]
  updates_grad = array_ops.gather_nd(grad, indices)
  return [grad, None, updates_grad]


@ops.RegisterGradient("BroadcastTo")
def _BroadcastToGrad(op, grad):
  input_value = op.inputs[0]
  broadcast_shape = op.inputs[1]
  input_value_shape = array_ops.shape(input_value)
  if not isinstance(broadcast_shape, ops.EagerTensor):
    broadcast_shape_static = tensor_shape.TensorShape(
        pywrap_tf_session.TF_TryEvaluateConstant_wrapper(
            broadcast_shape.graph._c_graph, broadcast_shape._as_tf_output()))  # pylint: disable=protected-access
    if broadcast_shape_static.is_fully_defined():
      broadcast_shape = constant_op.constant(
          broadcast_shape_static.as_list(), dtype=dtypes.int32)
  _, reduction_axes = gen_array_ops.broadcast_gradient_args(
      broadcast_shape, input_value_shape)
  updates_grad_reshaped = math_ops.reduce_sum(
      grad, axis=reduction_axes, keepdims=True)
  updates_grad = array_ops.reshape(updates_grad_reshaped, input_value_shape)
  return [updates_grad, None]

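# A shape-level sketch of the reduction above (illustrative only): broadcasting
# a [3] input to [2, 3] replicates it along a new leading axis, so
# broadcast_gradient_args reports axis 0 as the reduction axis for the input,
# and the incoming [2, 3] gradient is summed over that axis (keeping dims) and
# reshaped back to [3].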