# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradients for operators defined in math_ops.py."""
import numpy as np

from tensorflow.python.compat import compat
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import special_math_ops


def _safe_shape_div(x, y):
  """Divides `x / y` assuming `x, y >= 0`, treating `0 / 0 = 0`."""
  return x // math_ops.maximum(y, 1)


@ops.RegisterGradient("ArgMax")
def _ArgMaxGrad(op, grad):
  del op, grad
  return [None, None]


@ops.RegisterGradient("ArgMin")
def _ArgMinGrad(op, grad):
  del op, grad
  return [None, None]


@ops.RegisterGradient("EuclideanNorm")
def _EuclideanNormGrad(op, grad):
  """Gradient for EuclideanNorm."""

  output = op.outputs[0]

  if not op.get_attr("keep_dims"):
    output_shape_kept_dims = math_ops.reduced_shape(
        array_ops.shape(op.inputs[0]), op.inputs[1])
    output = array_ops.reshape(output, output_shape_kept_dims)
    grad = array_ops.reshape(grad, output_shape_kept_dims)

  return math_ops.truediv(op.inputs[0], output / grad), None

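# A minimal worked example for the EuclideanNorm gradient above, written
# against the public `tf` API (illustrative sketch only; nothing here is
# used by this module):
#
#   x = tf.constant([3.0, 4.0])
#   with tf.GradientTape() as tape:
#     tape.watch(x)
#     y = tf.math.reduce_euclidean_norm(x)   # y == 5.0
#   # d||x||/dx = x / ||x||, so tape.gradient(y, x) == [0.6, 0.8], which is
#   # what truediv(op.inputs[0], output / grad) yields when grad == 1.
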
def SmartBroadcastGradientArgs(x, y, grad):
  """Optimized version of `broadcast_gradient_args` that caches results.

  This implementation avoids creating `broadcast_gradient_args` ops in the case
  that the input shapes are fully defined, and provides hints to the calling
  code that can be used to avoid creating reduction and reshaping ops.

  Args:
    x: The left input tensor to a broadcasting binary op.
    y: The right input tensor to a broadcasting binary op.
    grad: The incoming gradient tensor for a broadcasting binary op.

  Returns:
    A pair of tuples, containing:
      * A 3-tuple of broadcast information for x, containing:
        * The shape of x (as a tuple or Tensor).
        * The reduction indices for x (as a tuple or Tensor).
        * A boolean, which if True, indicates that x's shape differs from grad's
          shape (and so x's gradient must be reduced and/or reshaped).
      * A 3-tuple of broadcast information for y, containing the respective
        details for y.
  """
  # NOTE: It may be productive to apply these optimizations in the eager case
  # as well.
  if context.executing_eagerly() or not (
      isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor)
      and isinstance(grad, ops.Tensor)):
    sx = array_ops.shape(x)
    sy = array_ops.shape(y)
    rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
    return (sx, rx, True), (sy, ry, True)

  # pylint: disable=protected-access
  x_shape_tuple = x._shape_tuple()
  y_shape_tuple = y._shape_tuple()
  grad_shape_tuple = grad._shape_tuple()
  # pylint: enable=protected-access

  if (x_shape_tuple is None or None in x_shape_tuple or
      y_shape_tuple is None or None in y_shape_tuple):
    sx = array_ops.shape_internal(x, optimize=False)
    sy = array_ops.shape_internal(y, optimize=False)
    rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
    return (sx, rx, True), (sy, ry, True)

  x_needs_reduction = x_shape_tuple != grad_shape_tuple
  y_needs_reduction = y_shape_tuple != grad_shape_tuple

  # Get the default graph rather than relying on `x.graph`, `y.graph`, or
  # `grad.graph`, because these may be eager tensors.
  g = ops.get_default_graph()

  try:
    rx, ry = g._bcast_grad_args_cache[(x_shape_tuple, y_shape_tuple)]  # pylint: disable=protected-access
    return (x_shape_tuple, rx, x_needs_reduction), (
        y_shape_tuple, ry, y_needs_reduction)
  except KeyError:
    rx, ry = array_ops.broadcast_gradient_args(x_shape_tuple, y_shape_tuple)
    # TODO(mrry): If this becomes a bottleneck, add a multi-output version of
    # `TF_TryEvaluateConstant()`.
    rx_value = tuple(tensor_util.try_evaluate_constant(rx))
    assert rx_value is not None
    ry_value = tuple(tensor_util.try_evaluate_constant(ry))
    assert ry_value is not None
    g._bcast_grad_args_cache[(x_shape_tuple, y_shape_tuple)] = (  # pylint: disable=protected-access
        rx_value, ry_value)

    return (x_shape_tuple, rx_value, x_needs_reduction), (
        y_shape_tuple, ry_value, y_needs_reduction)

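# Worked example for `SmartBroadcastGradientArgs` above, using the raw op
# that backs `broadcast_gradient_args` (illustration only, and the raw-op
# spelling is an assumption of this sketch):
#
#   r0, r1 = tf.raw_ops.BroadcastGradientArgs(s0=[2, 3], s1=[3])
#   # r0 == []   -> x's gradient needs no reduction (its shape already
#   #               matches the broadcast result).
#   # r1 == [0]  -> y's gradient must be summed over axis 0 and reshaped
#   #               back to [3].
# When both shapes are static, the function above computes these indices
# once, caches them on the graph, and returns plain tuples so callers can
# skip building the reduction ops entirely.
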
_empty_tuple = ()


def _IsScalar(x):
  return x._shape_tuple() is _empty_tuple  # pylint: disable=protected-access


@ops.RegisterGradient("Sum")
def _SumGrad(op, grad):
  """Gradient for Sum."""
  # Fast path for when reducing to a scalar and ndims is known: adds only
  # Reshape and Tile ops (and possibly a Shape).
  input_0_shape = op.inputs[0]._shape_tuple()  # pylint: disable=protected-access
  if input_0_shape is not None:
    axes = tensor_util.constant_value(op.inputs[1])
    if axes is not None:
      rank = len(input_0_shape)
      if np.array_equal(axes, np.arange(rank)):  # Reduce all dims.
        if context.executing_eagerly():
          ctx = context.context()
          new_shape = ctx.ones_rank_cache().get(rank)
          if new_shape is None:
            new_shape = constant_op.constant([1] * rank, dtype=dtypes.int32)
            ctx.ones_rank_cache().put(rank, new_shape)
        else:
          new_shape = [1] * rank
        grad = array_ops.reshape(grad, new_shape)
        # If shape is not fully defined (but rank is), we use Shape.
        if None not in input_0_shape:
          input_shape = constant_op.constant(input_0_shape, dtype=dtypes.int32)
        else:
          input_shape = array_ops.shape(op.inputs[0])
        return [array_ops.tile(grad, input_shape), None]
      elif None not in input_0_shape and not context.executing_eagerly():
        # The shape and reduction indices are statically known, so we use a
        # graph-level cache to avoid recomputing `reduced_shape()` for each
        # invocation.
        graph = ops.get_default_graph()

        # Canonicalize `axes` to be a tuple of indices. The incoming
        # value may be a scalar or a vector, and may include negative indices.
        axes = tuple(axes.reshape(-1))

        try:
          output_shape_kept_dims, tile_scaling = graph._reduced_shape_cache[  # pylint: disable=protected-access
              (input_0_shape, axes)]
        except KeyError:

          # Compute and cache `output_shape_kept_dims` and `tile_scaling`.
          def EvaluateAsTuple(t):
            if tensor_util.is_tf_type(t):
              value = tensor_util.try_evaluate_constant(t)
              assert value is not None
            else:
              value = t
            return tuple(value)

          output_shape_kept_dims = EvaluateAsTuple(
              math_ops.reduced_shape(input_0_shape, axes))
          tile_scaling = EvaluateAsTuple(
              _safe_shape_div(input_0_shape, output_shape_kept_dims))
          graph._reduced_shape_cache[(input_0_shape, axes)] = (  # pylint:disable=protected-access
              output_shape_kept_dims, tile_scaling)

        grad = array_ops.reshape(grad, output_shape_kept_dims)
        return [array_ops.tile(grad, tile_scaling), None]

  input_shape = array_ops.shape(op.inputs[0])

  if not op.get_attr("keep_dims"):
    with ops.colocate_with(input_shape):
      # TODO(apassos) remove this once device placement for eager ops makes
      # more sense.
      output_shape_kept_dims = math_ops.reduced_shape(input_shape,
                                                      op.inputs[1])
    grad = array_ops.reshape(grad, output_shape_kept_dims)
  return [array_ops.broadcast_to(grad, input_shape), None]

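# Shape-level sketch of the Sum gradient's fast path above (illustration
# only): for an input of shape [2, 3] reduced over all axes, the incoming
# gradient is a scalar g, and the gradient w.r.t. the input is g broadcast
# back to [2, 3]:
#
#   grad = tf.reshape(g, [1, 1])   # new_shape == [1] * rank
#   dx = tf.tile(grad, [2, 3])     # every input element receives g
#
# The slow path at the end does the same thing for dynamic shapes via
# `reduced_shape` and `broadcast_to`.
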
def _MinOrMaxGrad(op, grad):
  """Gradient for Min or Max. Amazingly it's precisely the same code."""
  input_shape = array_ops.shape(op.inputs[0])
  y = op.outputs[0]
  if not op.get_attr("keep_dims"):
    output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1])
    y = array_ops.reshape(y, output_shape_kept_dims)
    grad = array_ops.reshape(grad, output_shape_kept_dims)
  else:
    output_shape_kept_dims = array_ops.shape(y)

  # Compute the number of selected (maximum or minimum) elements in each
  # reduction dimension. If there are multiple minimum or maximum elements
  # then the gradient will be divided between them.
  indicators = math_ops.cast(math_ops.equal(y, op.inputs[0]), grad.dtype)
  num_selected = array_ops.reshape(
      math_ops.reduce_sum(indicators, op.inputs[1]), output_shape_kept_dims)

  return [math_ops.divide(indicators, num_selected) * grad, None]

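# Numeric sketch of the tie handling in `_MinOrMaxGrad` above (illustration
# only):
#
#   x = tf.constant([1., 3., 3.])
#   with tf.GradientTape() as tape:
#     tape.watch(x)
#     y = tf.reduce_max(x)         # y == 3.
#   # indicators == [0., 1., 1.] and num_selected == 2., so
#   # tape.gradient(y, x) == [0., 0.5, 0.5]: the gradient is split evenly
#   # among the tied maxima.
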
@ops.RegisterGradient("Max")
def _MaxGrad(op, grad):
  """Gradient for Max."""
  return _MinOrMaxGrad(op, grad)


@ops.RegisterGradient("Min")
def _MinGrad(op, grad):
  return _MinOrMaxGrad(op, grad)


@ops.RegisterGradient("Mean")
def _MeanGrad(op, grad):
  """Gradient for Mean."""
  sum_grad = _SumGrad(op, grad)[0]
  input_shape = op.inputs[0]._shape_tuple()  # pylint: disable=protected-access
  output_shape = op.outputs[0]._shape_tuple()  # pylint: disable=protected-access
  if (input_shape is not None and output_shape is not None and
      None not in input_shape and None not in output_shape):
    input_size = np.prod(input_shape)
    output_size = np.prod(output_shape)
    factor = input_size // max(output_size, 1)
    factor = constant_op.constant(factor, dtype=sum_grad.dtype)
  else:
    input_shape = array_ops.shape(op.inputs[0])
    output_shape = array_ops.shape(op.outputs[0])
    factor = _safe_shape_div(
        math_ops.reduce_prod(input_shape), math_ops.reduce_prod(output_shape))
  return math_ops.truediv(sum_grad, math_ops.cast(factor, sum_grad.dtype)), None

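# Numeric sketch of the Mean gradient above (illustration only): the Sum
# gradient is divided by the number of elements that were averaged.
#
#   x = tf.ones([2, 3])
#   with tf.GradientTape() as tape:
#     tape.watch(x)
#     y = tf.reduce_mean(x)
#   # input_size == 6 and output_size == 1, so factor == 6 and
#   # tape.gradient(y, x) is a [2, 3] tensor filled with 1/6.
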
@ops.RegisterGradient("Prod")
def _ProdGrad(op, grad):
  """Gradient for Prod."""
  # The gradient can be expressed by dividing the product by each entry of the
  # input tensor, but this approach can't deal with zeros in the input.
  # Here, we avoid this problem by composing the output as a product of two
  # cumprod operations.

  input_shape = array_ops.shape(op.inputs[0])
  # Reshape reduction indices for the case where the parameter is a scalar.
  reduction_indices = array_ops.reshape(op.inputs[1], [-1])

  # Expand grad to full input shape.
  if not op.get_attr("keep_dims"):
    output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1])
    grad = array_ops.reshape(grad, output_shape_kept_dims)

  grad = array_ops.broadcast_to(grad, input_shape)

  # Pack all reduced dimensions into a single one, so we can perform the
  # cumprod ops. If the reduction dims list is empty, it defaults to float32,
  # so we need to cast here.  We put all the shape-related ops on CPU to avoid
  # copying back and forth, and since listdiff is CPU only.
  with ops.device("/cpu:0"):
    rank = array_ops.rank(op.inputs[0])
    reduction_indices = (reduction_indices + rank) % rank
    reduced = math_ops.cast(reduction_indices, dtypes.int32)
    idx = math_ops.range(0, rank)
    other, _ = gen_array_ops.list_diff(idx, reduced, dtypes.int32)
    perm = array_ops.concat([reduced, other], 0)
    reduced_num = math_ops.reduce_prod(array_ops.gather(input_shape, reduced))
    other_num = math_ops.reduce_prod(array_ops.gather(input_shape, other))
  permuted = array_ops.transpose(op.inputs[0], perm)
  permuted_shape = array_ops.shape(permuted)
  reshaped = array_ops.reshape(permuted, (reduced_num, other_num))

  # Calculate product, leaving out the current entry.
  left = math_ops.cumprod(reshaped, axis=0, exclusive=True)
  right = math_ops.cumprod(reshaped, axis=0, exclusive=True, reverse=True)
  # For complex inputs, the gradient is in the conjugate direction.
  y = array_ops.reshape(
      math_ops.conj(left) * math_ops.conj(right), permuted_shape)

  # Invert the transpose and reshape operations.
  # Make sure to set the statically known shape information through a reshape.
  out = grad * array_ops.transpose(y, array_ops.invert_permutation(perm))
  return array_ops.reshape(out, input_shape), None

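# Worked example of the two-cumprod trick in `_ProdGrad` above (illustration
# only). For a single reduced row [a, b, c]:
#
#   left  = exclusive cumprod          -> [1,   a,   a*b]
#   right = reverse exclusive cumprod  -> [b*c, c,   1  ]
#   left * right                       -> [b*c, a*c, a*b]
#
# i.e. the product of all *other* entries, which is exactly
# d(a*b*c)/d[a, b, c], and it stays finite even when some entries are zero.
# For x = [2., 3., 4.] the gradient of tf.reduce_prod(x) is [12., 8., 6.].
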
@ops.RegisterGradient("SegmentSum")
def _SegmentSumGrad(op, grad):
  """Gradient for SegmentSum."""
  return array_ops.gather(grad, op.inputs[1]), None


@ops.RegisterGradient("SegmentMean")
def _SegmentMeanGrad(op, grad):
  """Gradient for SegmentMean."""
  input_rank = array_ops.rank(op.inputs[0])
  ones_shape = array_ops.concat([
      array_ops.shape(op.inputs[1]),
      array_ops.ones(
          array_ops.expand_dims(input_rank - 1, 0), dtype=dtypes.int32)
  ], 0)
  ones = array_ops.ones(ones_shape, dtype=grad.dtype)
  scaled_grad = math_ops.divide(grad, math_ops.segment_sum(ones, op.inputs[1]))
  return array_ops.gather(scaled_grad, op.inputs[1]), None


@ops.RegisterGradient("SparseSegmentSum")
def _SparseSegmentSumGrad(op, grad):
  """Gradient for SparseSegmentSum."""
  dim0 = array_ops.shape(op.inputs[0])[0]
  if compat.forward_compatible(2021, 6, 10):
    return (math_ops.sparse_segment_sum_grad(grad, op.inputs[1], op.inputs[2],
                                             dim0), None, None)
  else:
    return (math_ops.unsorted_segment_sum(
        array_ops.gather(grad, op.inputs[2]), op.inputs[1], dim0), None, None)


@ops.RegisterGradient("SparseSegmentSumWithNumSegments")
def _SparseSegmentSumWithNumSegmentsGrad(op, grad):
  """Gradient for SparseSegmentSumWithNumSegments."""
  dim0 = array_ops.shape(op.inputs[0])[0]
  if compat.forward_compatible(2021, 6, 10):
    return (math_ops.sparse_segment_sum_grad(grad, op.inputs[1], op.inputs[2],
                                             dim0), None, None, None)
  else:
    return (math_ops.unsorted_segment_sum(
        array_ops.gather(grad, op.inputs[2]), op.inputs[1],
        dim0), None, None, None)


@ops.RegisterGradient("SparseSegmentMean")
def _SparseSegmentMeanGrad(op, grad):
  """Gradient for SparseSegmentMean."""
  dim0 = array_ops.shape(op.inputs[0])[0]
  return (math_ops.sparse_segment_mean_grad(grad, op.inputs[1], op.inputs[2],
                                            dim0), None, None)


@ops.RegisterGradient("SparseSegmentMeanWithNumSegments")
def _SparseSegmentMeanWithNumSegmentsGrad(op, grad):
  """Gradient for SparseSegmentMeanWithNumSegments."""
  dim0 = array_ops.shape(op.inputs[0])[0]
  return (math_ops.sparse_segment_mean_grad(grad, op.inputs[1], op.inputs[2],
                                            dim0), None, None, None)


@ops.RegisterGradient("SparseSegmentSqrtN")
def _SparseSegmentSqrtNGrad(op, grad):
  """Gradient for SparseSegmentSqrtN."""
  dim0 = array_ops.shape(op.inputs[0])[0]
  return (math_ops.sparse_segment_sqrt_n_grad(grad, op.inputs[1], op.inputs[2],
                                              dim0), None, None)


@ops.RegisterGradient("SparseSegmentSqrtNWithNumSegments")
def _SparseSegmentSqrtNWithNumSegmentsGrad(op, grad):
  """Gradient for SparseSegmentSqrtNWithNumSegments."""
  dim0 = array_ops.shape(op.inputs[0])[0]
  return (math_ops.sparse_segment_sqrt_n_grad(grad, op.inputs[1], op.inputs[2],
                                              dim0), None, None, None)


def _SegmentMinOrMaxGrad(op, grad):
  """Gradient for SegmentMin and SegmentMax."""
  zeros = array_ops.zeros_like(op.inputs[0], dtype=op.inputs[0].dtype)
  # Get the number of selected (minimum or maximum) elements in each segment.
  gathered_outputs = array_ops.gather(op.outputs[0], op.inputs[1])
  is_selected = math_ops.equal(op.inputs[0], gathered_outputs)
  num_selected = math_ops.segment_sum(
      math_ops.cast(is_selected, grad.dtype), op.inputs[1])
  # Compute the gradient for each segment. The gradient for the ith segment is
  # divided evenly among the selected elements in that segment.
  weighted_grads = math_ops.divide(grad, num_selected)
  gathered_grads = array_ops.gather(weighted_grads, op.inputs[1])
  return array_ops.where_v2(is_selected, gathered_grads, zeros), None


@ops.RegisterGradient("SegmentMin")
def _SegmentMinGrad(op, grad):
  """Gradient for SegmentMin."""
  return _SegmentMinOrMaxGrad(op, grad)


@ops.RegisterGradient("SegmentMax")
def _SegmentMaxGrad(op, grad):
  """Gradient for SegmentMax."""
  return _SegmentMinOrMaxGrad(op, grad)


@ops.RegisterGradient("SegmentProd")
def _SegmentProdGrad(op, grad):
  """Gradient for SegmentProd.

  The gradient can be expressed for each segment by dividing the segment's
  product by each element of the segment input tensor, but this approach can't
  deal with zeros in the input.
  Unlike reduce_prod we can't use the cumprod trick here, as individual
  segments may have a different number of elements. Therefore we consider
  three cases:
  1) A segment input contains no zeros and we can safely divide by the input
     tensor.
  2) A segment contains exactly one zero. Then the gradient of each input of
     the segment is zero except for the 0-input, where the gradient is the
     product of the remaining segment entries.
  3) A segment contains at least two zeros. The gradient is zero for all
     segment inputs.
  """
  data = op.inputs[0]
  segment_ids = op.inputs[1]
  is_zero = math_ops.equal(data, 0)
  num_zeros = gen_math_ops.segment_sum(
      math_ops.cast(is_zero, dtype=dtypes.int32), segment_ids)
  # Handle case 3 and set the gradient to 0 for segments with more than one
  # 0 as input.
  grad = array_ops.where_v2(
      math_ops.greater(num_zeros, 1), array_ops.zeros_like(grad), grad)
  # Replace all zeros with ones and compute the segment_prod.
  non_zero_data = array_ops.where_v2(is_zero, array_ops.ones_like(data), data)
  non_zero_prod = gen_math_ops.segment_prod(non_zero_data, segment_ids)
  gathered_prod = array_ops.gather(op.outputs[0], segment_ids)
  gathered_non_zero_prod = array_ops.gather(non_zero_prod, segment_ids)
  prod_divided_by_el = gathered_prod / non_zero_data
  # Now fetch the individual results for segments containing 0 and those that
  # don't.
  partial_derivative = array_ops.where_v2(is_zero, gathered_non_zero_prod,
                                          prod_divided_by_el)
  gathered_grad = array_ops.gather(grad, segment_ids)
  return gathered_grad * partial_derivative, None

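# Numeric sketch of the zero handling in `_SegmentProdGrad` above
# (illustration only). With data = [0., 2., 3.] in a single segment:
#
#   segment_prod == 0., and the gradient of the product w.r.t. the data is
#   [2.*3., 0., 0.] == [6., 0., 0.]
#
# The zero entry gets the product of the remaining entries (case 2), the
# non-zero entries get product/entry, which is 0 here, and a segment with
# two or more zeros would get an all-zero gradient (case 3).
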
def _GatherDropNegatives(params,
                         ids,
                         zero_clipped_indices=None,
                         is_positive=None):
  """Helper function for unsorted segment ops.

  Gathers params for positive segment ids and gathers 0 for inputs with
  negative segment id. Also returns the clipped indices and a boolean mask
  with the same shape as ids where a positive id is masked as true. With
  this, the latter two can be passed as arguments to this function to reuse
  them.
  """
  if zero_clipped_indices is None:
    zero_clipped_indices = math_ops.maximum(ids, array_ops.zeros_like(ids))
  gathered = array_ops.gather(params, zero_clipped_indices)
  if is_positive is None:
    is_positive = math_ops.greater_equal(ids, 0)
    # tf.where(condition, x, y) requires condition to have the same shape as x
    # and y.
    is_positive_shape = array_ops.shape(is_positive)
    broadcastable_shape = array_ops.concat(
        [is_positive_shape,
         array_ops.ones([array_ops.rank(gathered)
                         - array_ops.rank(is_positive)],
                        dtype=is_positive_shape.dtype)],
        axis=0)
    is_positive = array_ops.reshape(is_positive, broadcastable_shape)
    is_positive = (
        is_positive & array_ops.ones_like(gathered, dtype=dtypes.bool))
  # Replace gathered params of negative indices with 0.
  zero_slice = array_ops.zeros_like(gathered)
  return (array_ops.where_v2(is_positive, gathered,
                             zero_slice), zero_clipped_indices, is_positive)

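# Behavioural sketch of `_GatherDropNegatives` above (illustration only):
#
#   params = tf.constant([[1., 2.], [3., 4.]])
#   ids = tf.constant([0, -1, 1])
#   # The first returned value is [[1., 2.], [0., 0.], [3., 4.]]: negative
#   # ids are clipped to 0 for the gather, and their rows are then zeroed
#   # out via the boolean mask, so they contribute nothing to the gradient.
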
def _UnsortedSegmentMinOrMaxGrad(op, grad):
  """Gradient for UnsortedSegmentMin and UnsortedSegmentMax."""
  # Get the number of selected (minimum or maximum) elements in each segment.
  gathered_outputs, zero_clipped_indices, is_positive = \
      _GatherDropNegatives(op.outputs[0], op.inputs[1])
  is_selected = math_ops.equal(op.inputs[0], gathered_outputs)
  is_selected = math_ops.logical_and(is_selected, is_positive)
  num_selected = math_ops.unsorted_segment_sum(
      math_ops.cast(is_selected, grad.dtype), op.inputs[1], op.inputs[2])
  # Compute the gradient for each segment. The gradient for the ith segment is
  # divided evenly among the selected elements in that segment.
  weighted_grads = math_ops.divide(grad, num_selected)
  gathered_grads, _, _ = _GatherDropNegatives(weighted_grads, None,
                                              zero_clipped_indices, is_positive)
  zeros = array_ops.zeros_like(gathered_grads)
  return array_ops.where_v2(is_selected, gathered_grads, zeros), None, None


@ops.RegisterGradient("UnsortedSegmentSum")
def _UnsortedSegmentSumGrad(op, grad):
  """Gradient for UnsortedSegmentSum."""
  return _GatherDropNegatives(grad, op.inputs[1])[0], None, None


@ops.RegisterGradient("UnsortedSegmentMax")
def _UnsortedSegmentMaxGrad(op, grad):
  """Gradient for UnsortedSegmentMax."""
  return _UnsortedSegmentMinOrMaxGrad(op, grad)


@ops.RegisterGradient("UnsortedSegmentMin")
def _UnsortedSegmentMinGrad(op, grad):
  """Gradient for UnsortedSegmentMin."""
  return _UnsortedSegmentMinOrMaxGrad(op, grad)


@ops.RegisterGradient("UnsortedSegmentProd")
def _UnsortedSegmentProdGrad(op, grad):
  """Gradient for UnsortedSegmentProd.

  The gradient can be expressed for each segment by dividing the segment's
  product by each element of the segment input tensor, but this approach can't
  deal with zeros in the input.
  Unlike reduce_prod we can't use the cumprod trick here, as individual
  segments may have a different number of elements. Therefore we consider
  three cases:
  1) A segment input contains no zeros and we can safely divide by the input
     tensor.
  2) A segment contains exactly one zero. Then the gradient of each input of
     the segment is zero except for the 0-input, where the gradient is the
     product of the remaining segment entries.
  3) A segment contains at least two zeros. The gradient is zero for all
     segment inputs.
  """
  # Note that unsorted_segment_sum will filter out the negative indices,
  # so we don't need to do a logical_and with is_positive here.
  is_zero = math_ops.equal(op.inputs[0], 0)
  num_zeros = gen_math_ops.unsorted_segment_sum(
      math_ops.cast(is_zero, dtype=dtypes.int32), op.inputs[1], op.inputs[2])
  # Handle case 3 and set the gradient to 0 for segments with more than one
  # 0 as input.
  grad = array_ops.where_v2(
      math_ops.greater(num_zeros, 1), array_ops.zeros_like(grad), grad)
  # Replace all zeros with ones and compute the unsorted_segment_prod.
  non_zero_data = array_ops.where_v2(is_zero, array_ops.ones_like(op.inputs[0]),
                                     op.inputs[0])
  non_zero_prod = gen_math_ops.unsorted_segment_prod(non_zero_data,
                                                     op.inputs[1], op.inputs[2])
  # Clip the indices for gather to be positive.
  zero_clipped_indices = math_ops.maximum(op.inputs[1],
                                          array_ops.zeros_like(op.inputs[1]))
  gathered_prod = array_ops.gather(op.outputs[0], zero_clipped_indices)
  gathered_non_zero_prod = array_ops.gather(non_zero_prod, zero_clipped_indices)
  prod_divided_by_el = gathered_prod / op.inputs[0]  # May contain nan/inf.
  # Now fetch the individual results for segments containing 0 and those that
  # don't. is_zero will also fetch results for entries with negative index,
  # but the following gather_drop_negatives sets the corresponding entry in
  # grad to 0 for these.
  partial_derivative = array_ops.where_v2(is_zero, gathered_non_zero_prod,
                                          prod_divided_by_el)
  gathered_grad = _GatherDropNegatives(grad, op.inputs[1],
                                       zero_clipped_indices)[0]
  return gathered_grad * partial_derivative, None, None

@ops.RegisterGradient("Abs")
def _AbsGrad(op, grad):
  x = op.inputs[0]
  return grad * math_ops.sign(x)


@ops.RegisterGradient("Neg")
def _NegGrad(_, grad):
  """Returns -grad."""
  return -grad


@ops.RegisterGradient("Inv")
def _InvGrad(op, grad):
  """Returns -grad * (1 / x^2)."""
  y = op.outputs[0]  # y = 1 / x
  return gen_math_ops.reciprocal_grad(y, grad)


@ops.RegisterGradient("Reciprocal")
def _ReciprocalGrad(op, grad):
  """Returns -grad * (1 / x^2)."""
  y = op.outputs[0]  # y = 1 / x
  return gen_math_ops.reciprocal_grad(y, grad)


@ops.RegisterGradient("InvGrad")
def _InvGradGrad(op, grad):
  b = op.inputs[1]
  # op.outputs[0]: y = -b * conj(a)^2
  with ops.control_dependencies([grad]):
    ca = math_ops.conj(op.inputs[0])
    cg = math_ops.conj(grad)
    return cg * -2.0 * b * ca, gen_math_ops.reciprocal_grad(ca, grad)


@ops.RegisterGradient("ReciprocalGrad")
def _ReciprocalGradGrad(op, grad):
  b = op.inputs[1]
  # op.outputs[0]: y = -b * conj(a)^2
  with ops.control_dependencies([grad]):
    ca = math_ops.conj(op.inputs[0])
    cg = math_ops.conj(grad)
    return cg * -2.0 * b * ca, gen_math_ops.reciprocal_grad(ca, grad)


@ops.RegisterGradient("Square")
def _SquareGrad(op, grad):
  x = op.inputs[0]
  # Added control dependencies to prevent 2*x from being computed too early.
  with ops.control_dependencies([grad]):
    x = math_ops.conj(x)
    y = constant_op.constant(2.0, dtype=x.dtype)
    return math_ops.multiply(grad, math_ops.multiply(x, y))


@ops.RegisterGradient("Sqrt")
def _SqrtGrad(op, grad):
  y = op.outputs[0]  # y = x^(1/2)
  return gen_math_ops.sqrt_grad(y, grad)


@ops.RegisterGradient("SqrtGrad")
def _SqrtGradGrad(op, grad):
  a = op.inputs[0]
  y = op.outputs[0]  # y = 0.5 * b / conj(a)
  with ops.control_dependencies([grad]):
    ga = grad / a
    return -math_ops.conj(ga) * y, 0.5 * ga  # pylint: disable=invalid-unary-operand-type


@ops.RegisterGradient("Rsqrt")
def _RsqrtGrad(op, grad):
  """Returns -0.5 * grad * conj(y)^3."""
  y = op.outputs[0]  # y = x^(-1/2)
  return gen_math_ops.rsqrt_grad(y, grad)


@ops.RegisterGradient("RsqrtGrad")
def _RsqrtGradGrad(op, grad):
  """Returns backprop gradient for f(a,b) = -0.5 * b * conj(a)^3."""
  a = op.inputs[0]  # a = x^{-1/2}
  b = op.inputs[1]  # backprop gradient for a
  with ops.control_dependencies([grad]):
    ca = math_ops.conj(a)
    cg = math_ops.conj(grad)
    grad_a = -1.5 * cg * b * math_ops.square(ca)
    grad_b = gen_math_ops.rsqrt_grad(ca, grad)
    return grad_a, grad_b


@ops.RegisterGradient("Exp")
def _ExpGrad(op, grad):
  """Returns grad * exp(x)."""
  y = op.outputs[0]  # y = e^x
  with ops.control_dependencies([grad]):
    y = math_ops.conj(y)
    return grad * y


@ops.RegisterGradient("Expm1")
def _Expm1Grad(op, grad):
  """Returns grad * exp(x)."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    x = math_ops.conj(x)
    y = math_ops.exp(x)
    return grad * y


@ops.RegisterGradient("Log")
def _LogGrad(op, grad):
  """Returns grad * (1/x)."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    x = math_ops.conj(x)
    return grad * math_ops.reciprocal(x)


@ops.RegisterGradient("Log1p")
def _Log1pGrad(op, grad):
  """Returns grad * (1/(1 + x))."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    x = math_ops.conj(x)
    return grad * math_ops.reciprocal(1 + x)


@ops.RegisterGradient("Xlogy")
def _XLogyGrad(op, grad):
  """Returns gradient of xlogy(x, y) with respect to x and y."""
  x = op.inputs[0]
  y = op.inputs[1]
  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
  with ops.control_dependencies([grad]):
    not_zero_x = math_ops.cast(
        math_ops.not_equal(x, math_ops.cast(0., dtype=x.dtype)), dtype=x.dtype)
    partial_x = gen_math_ops.xlogy(not_zero_x, y)
    partial_y = gen_math_ops.xdivy(x, y)
    return (array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx),
            array_ops.reshape(math_ops.reduce_sum(partial_y * grad, ry), sy))


@ops.RegisterGradient("Xlog1py")
def _XLog1pyGrad(op, grad):
  """Returns gradient of xlog1py(x, y) with respect to x and y."""
  x = op.inputs[0]
  y = op.inputs[1]
  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
  with ops.control_dependencies([grad]):
    not_zero_x = math_ops.cast(
        math_ops.not_equal(x, math_ops.cast(0., dtype=x.dtype)), dtype=x.dtype)
    partial_x = gen_math_ops.xlog1py(not_zero_x, y)
    partial_y = gen_math_ops.xdivy(x, y + 1.)
    return (array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx),
            array_ops.reshape(math_ops.reduce_sum(partial_y * grad, ry), sy))


@ops.RegisterGradient("Xdivy")
def _XDivyGrad(op, grad):
  """Returns gradient of xdivy(x, y) with respect to x and y."""
  x = op.inputs[0]
  y = op.inputs[1]
  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
  with ops.control_dependencies([grad]):
    not_zero_x = math_ops.cast(
        math_ops.not_equal(x, math_ops.cast(0., dtype=x.dtype)), dtype=x.dtype)
    partial_x = gen_math_ops.xdivy(not_zero_x, y)
    partial_y = gen_math_ops.xdivy(math_ops.negative(x), y**2)
    return (array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx),
            array_ops.reshape(math_ops.reduce_sum(partial_y * grad, ry), sy))

@ops.RegisterGradient("Sinh")
def _SinhGrad(op, grad):
  """Returns grad * cosh(x)."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    x = math_ops.conj(x)
    return grad * math_ops.cosh(x)


@ops.RegisterGradient("Cosh")
def _CoshGrad(op, grad):
  """Returns grad * sinh(x)."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    x = math_ops.conj(x)
    return grad * math_ops.sinh(x)


@ops.RegisterGradient("Tanh")
def _TanhGrad(op, grad):
  """Returns grad * (1 - tanh(x) * tanh(x))."""
  y = op.outputs[0]  # y = tanh(x)
  with ops.control_dependencies([grad]):
    y = math_ops.conj(y)
    return gen_math_ops.tanh_grad(y, grad)


@ops.RegisterGradient("Asinh")
def _AsinhGrad(op, grad):
  """Returns grad * 1/cosh(y)."""
  y = op.outputs[0]
  with ops.control_dependencies([grad]):
    y = math_ops.conj(y)
    return grad / math_ops.cosh(y)


@ops.RegisterGradient("Acosh")
def _AcoshGrad(op, grad):
  """Returns grad * 1/sinh(y)."""
  y = op.outputs[0]
  with ops.control_dependencies([grad]):
    y = math_ops.conj(y)
    return grad / math_ops.sinh(y)


@ops.RegisterGradient("Atanh")
def _AtanhGrad(op, grad):
  """Returns grad * 1/ (1 - x^2)."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    x = math_ops.conj(x)
    x2 = math_ops.square(x)
    one = constant_op.constant(1, dtype=grad.dtype)
    inv = math_ops.reciprocal(math_ops.subtract(one, x2))
    return grad * inv


@ops.RegisterGradient("TanhGrad")
def _TanhGradGrad(op, grad):
  with ops.control_dependencies([grad]):
    a = math_ops.conj(op.inputs[0])
    b = math_ops.conj(op.inputs[1])
    return grad * -2.0 * b * a, gen_math_ops.tanh_grad(a, grad)


@ops.RegisterGradient("Erf")
def _ErfGrad(op, grad):
  """Returns grad * 2/sqrt(pi) * exp(-x**2)."""
  x = op.inputs[0]
  two_over_root_pi = constant_op.constant(2 / np.sqrt(np.pi), dtype=grad.dtype)
  with ops.control_dependencies([grad]):
    x = math_ops.conj(x)
    return grad * two_over_root_pi * math_ops.exp(-math_ops.square(x))


@ops.RegisterGradient("Erfc")
def _ErfcGrad(op, grad):
  """Returns -grad * 2/sqrt(pi) * exp(-x**2)."""
  x = op.inputs[0]
  minus_two_over_root_pi = constant_op.constant(
      -2 / np.sqrt(np.pi), dtype=grad.dtype)
  with ops.control_dependencies([grad]):
    x = math_ops.conj(x)
    return grad * minus_two_over_root_pi * math_ops.exp(-math_ops.square(x))


@ops.RegisterGradient("Erfinv")
def _ErfinvGrad(op, grad):
  """Returns grad * sqrt(pi) / 2 * exp(erfinv(x)**2)."""
  root_pi_over_two = constant_op.constant(np.sqrt(np.pi) / 2, dtype=grad.dtype)
  with ops.control_dependencies([grad]):
    return grad * root_pi_over_two * math_ops.exp(
        math_ops.square(op.outputs[0]))


@ops.RegisterGradient("Ndtri")
def _NdtriGrad(op, grad):
  """Returns grad * sqrt(2 * pi) * exp(ndtri(x)**2 / 2)."""
  root_two_pi = constant_op.constant(np.sqrt(2 * np.pi), dtype=grad.dtype)
  with ops.control_dependencies([grad]):
    return grad * root_two_pi * math_ops.exp(
        math_ops.square(op.outputs[0]) / 2.)


@ops.RegisterGradient("Lgamma")
def _LgammaGrad(op, grad):
  """Returns grad * digamma(x)."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    x = math_ops.conj(x)
    return grad * math_ops.digamma(x)


@ops.RegisterGradient("Digamma")
def _DigammaGrad(op, grad):
  """Compute gradient of the digamma function with respect to its argument."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    x = math_ops.conj(x)
    partial_x = math_ops.polygamma(array_ops.constant(1, dtype=x.dtype), x)
    return grad * partial_x


@ops.RegisterGradient("Dawsn")
def _DawsnGrad(op, grad):
  """Compute gradient of dawsn(x) with respect to its argument."""
  x = op.inputs[0]
  y = op.outputs[0]
  with ops.control_dependencies([grad]):
    return grad * (1. - 2 * x * y)


@ops.RegisterGradient("Expint")
def _ExpintGrad(op, grad):
  """Compute gradient of expint(x) with respect to its argument."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    return grad * math_ops.exp(x) / x


@ops.RegisterGradient("FresnelCos")
def _FresnelCosGrad(op, grad):
  """Compute gradient of fresnel_cos(x) with respect to its argument."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    return grad * math_ops.cos((np.pi / 2.) * math_ops.square(x))


@ops.RegisterGradient("FresnelSin")
def _FresnelSinGrad(op, grad):
  """Compute gradient of fresnel_sin(x) with respect to its argument."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    return grad * math_ops.sin((np.pi / 2.) * math_ops.square(x))


@ops.RegisterGradient("Spence")
def _SpenceGrad(op, grad):
  """Compute gradient of spence(x) with respect to its argument."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    partial_x = math_ops.log(x) / (1 - x)
    partial_x = array_ops.where(
        math_ops.equal(x, 1.), -array_ops.ones_like(x), partial_x)  # pylint: disable=invalid-unary-operand-type
    return grad * partial_x

@ops.RegisterGradient("BesselI0")
def _BesselI0Grad(op, grad):
  """Compute gradient of bessel_i0(x) with respect to its argument."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    partial_x = special_math_ops.bessel_i1(x)
    return grad * partial_x


@ops.RegisterGradient("BesselI0e")
def _BesselI0eGrad(op, grad):
  """Compute gradient of bessel_i0e(x) with respect to its argument."""
  x = op.inputs[0]
  y = op.outputs[0]
  with ops.control_dependencies([grad]):
    partial_x = (special_math_ops.bessel_i1e(x) - math_ops.sign(x) * y)
    return grad * partial_x


@ops.RegisterGradient("BesselI1")
def _BesselI1Grad(op, grad):
  """Compute gradient of bessel_i1(x) with respect to its argument."""
  x = op.inputs[0]
  y = op.outputs[0]
  with ops.control_dependencies([grad]):
    # For x = 0, the correct gradient is 1.0.
    # However, the main branch gives NaN because of the division by x, so
    # we impute the gradient manually.
    # An alternative solution is to express the gradient via bessel_i0 and
    # bessel_i2, but the latter is not yet implemented in Eigen.
    dy_dx = array_ops.where_v2(
        math_ops.equal(x, 0.), math_ops.cast(1., x.dtype),
        special_math_ops.bessel_i0(x) - math_ops.div(y, x))
    return grad * dy_dx


@ops.RegisterGradient("BesselI1e")
def _BesselI1eGrad(op, grad):
  """Compute gradient of bessel_i1e(x) with respect to its argument."""
  x = op.inputs[0]
  y = op.outputs[0]
  with ops.control_dependencies([grad]):
    # For x = 0, the correct gradient is 0.5.
    # However, the main branch gives NaN because of the division by x, so
    # we impute the gradient manually.
    # An alternative solution is to express the gradient via bessel_i0e and
    # bessel_i2e, but the latter is not yet implemented in Eigen.
    dy_dx = array_ops.where_v2(
        math_ops.equal(x, 0.), math_ops.cast(0.5, x.dtype),
        special_math_ops.bessel_i0e(x) - y *
        (math_ops.sign(x) + math_ops.reciprocal(x)))
    return grad * dy_dx


@ops.RegisterGradient("BesselK0")
def _BesselK0Grad(op, grad):
  """Compute gradient of bessel_k0(x) with respect to its argument."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    partial_x = -special_math_ops.bessel_k1(x)
    return grad * partial_x


@ops.RegisterGradient("BesselK0e")
def _BesselK0eGrad(op, grad):
  """Compute gradient of bessel_k0e(x) with respect to its argument."""
  x = op.inputs[0]
  y = op.outputs[0]
  with ops.control_dependencies([grad]):
    partial_x = (y - special_math_ops.bessel_k1e(x))
    return grad * partial_x


@ops.RegisterGradient("BesselK1")
def _BesselK1Grad(op, grad):
  """Compute gradient of bessel_k1(x) with respect to its argument."""
  x = op.inputs[0]
  y = op.outputs[0]
  with ops.control_dependencies([grad]):
    # At 0., this is NaN which is fine since the derivative is undefined
    # at 0.
    partial_x = -special_math_ops.bessel_k0(x) - math_ops.div(y, x)
    return grad * partial_x


@ops.RegisterGradient("BesselK1e")
def _BesselK1eGrad(op, grad):
  """Compute gradient of bessel_k1e(x) with respect to its argument."""
  x = op.inputs[0]
  y = op.outputs[0]
  with ops.control_dependencies([grad]):
    # At 0., this is NaN which is fine since the derivative is undefined
    # at 0.
    partial_x = (
        y * (1. - math_ops.reciprocal(x)) - special_math_ops.bessel_k0e(x))
    return grad * partial_x


@ops.RegisterGradient("BesselJ0")
def _BesselJ0Grad(op, grad):
  """Compute gradient of bessel_j0(x) with respect to its argument."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    partial_x = -special_math_ops.bessel_j1(x)
    return grad * partial_x


@ops.RegisterGradient("BesselJ1")
def _BesselJ1Grad(op, grad):
  """Compute gradient of bessel_j1(x) with respect to its argument."""
  x = op.inputs[0]
  y = op.outputs[0]
  with ops.control_dependencies([grad]):
    # For x = 0, the correct gradient is 0.5.
    # However, the main branch gives NaN because of the division by x, so
    # we impute the gradient manually.
    # An alternative solution is to express the gradient via bessel_j0 and
    # bessel_j2, but the latter is not yet implemented in Eigen.
    dy_dx = array_ops.where_v2(
        math_ops.equal(x, 0.), math_ops.cast(0.5, x.dtype),
        special_math_ops.bessel_j0(x) - math_ops.div(y, x))
    return grad * dy_dx


@ops.RegisterGradient("BesselY0")
def _BesselY0Grad(op, grad):
  """Compute gradient of bessel_y0(x) with respect to its argument."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    partial_x = -special_math_ops.bessel_y1(x)
    return grad * partial_x


@ops.RegisterGradient("BesselY1")
def _BesselY1Grad(op, grad):
  """Compute gradient of bessel_y1(x) with respect to its argument."""
  x = op.inputs[0]
  y = op.outputs[0]
  with ops.control_dependencies([grad]):
    # At 0., this is NaN which is fine since the derivative is undefined
    # at 0.
    partial_x = special_math_ops.bessel_y0(x) - math_ops.div(y, x)
    return grad * partial_x


@ops.RegisterGradient("Igamma")
def _IgammaGrad(op, grad):
  """Returns gradient of igamma(a, x) with respect to a and x."""
  a = op.inputs[0]
  x = op.inputs[1]
  sa = array_ops.shape(a)
  sx = array_ops.shape(x)
  ra, rx = gen_array_ops.broadcast_gradient_args(sa, sx)

  with ops.control_dependencies([grad]):
    partial_a = gen_math_ops.igamma_grad_a(a, x)
    # Perform operations in log space before summing, because Gamma(a)
    # and Gamma'(a) can grow large.
    partial_x = math_ops.exp(-x + (a - 1) * math_ops.log(x) -
                             math_ops.lgamma(a))
    return (array_ops.reshape(math_ops.reduce_sum(partial_a * grad, ra), sa),
            array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx))

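# The log-space trick in `_IgammaGrad` above, written out (illustration
# only): the derivative of the regularized lower incomplete gamma function
# with respect to x is the Gamma(a) density,
#
#   d igamma(a, x) / dx = x**(a - 1) * exp(-x) / Gamma(a)
#                       = exp(-x + (a - 1) * log(x) - lgamma(a)),
#
# and evaluating it via the second form avoids overflow in Gamma(a) for
# large a.
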
@ops.RegisterGradient("Igammac")
def _IgammacGrad(op, grad):
  """Returns gradient of igammac(a, x) = 1 - igamma(a, x) w.r.t. a and x."""
  igamma_grad_a, igamma_grad_x = _IgammaGrad(op, grad)
  return (-igamma_grad_a, -igamma_grad_x)


@ops.RegisterGradient("Betainc")
def _BetaincGrad(op, grad):
  """Returns gradient of betainc(a, b, x) with respect to x."""
  # TODO(ebrevdo): Perhaps add the derivative w.r.t. a, b
  a, b, x = op.inputs

  # two cases: x is a scalar and a/b are same-shaped tensors, or vice
  # versa; so it's sufficient to check against shape(a).
  sa = array_ops.shape(a)
  sx = array_ops.shape(x)
  _, rx = gen_array_ops.broadcast_gradient_args(sa, sx)

  # Perform operations in log space before summing, because terms
  # can grow large.
  log_beta = (
      gen_math_ops.lgamma(a) + gen_math_ops.lgamma(b) -
      gen_math_ops.lgamma(a + b))
  # We use xlog1py and xlogy since the derivatives should tend to
  # zero at one of the tails when a is 1. or b is 1.
  partial_x = math_ops.exp(math_ops.xlog1py(b - 1, -x) +
                           math_ops.xlogy(a - 1, x) - log_beta)

  return (
      None,  # da
      None,  # db
      array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx))

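# The closed form used in `_BetaincGrad` above (illustration only): the
# derivative of the regularized incomplete beta function w.r.t. x is the
# Beta(a, b) density,
#
#   d betainc(a, b, x) / dx = x**(a - 1) * (1 - x)**(b - 1) / B(a, b),
#
# computed as exp(xlog1py(b - 1, -x) + xlogy(a - 1, x) - log_beta) so that
# the (a - 1) * log(x) and (b - 1) * log1p(-x) terms are exactly zero when
# a == 1 or b == 1, instead of producing 0 * -inf = NaN at the endpoints.
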
@ops.RegisterGradient("Zeta")
def _ZetaGrad(op, grad):
  """Returns gradient of zeta(x, q) with respect to x and q."""
  # TODO(tillahoffmann): Add derivative with respect to x
  x = op.inputs[0]
  q = op.inputs[1]
  # Broadcast gradients
  sx = array_ops.shape(x)
  sq = array_ops.shape(q)
  unused_rx, rq = gen_array_ops.broadcast_gradient_args(sx, sq)
  # Evaluate gradient
  with ops.control_dependencies([grad]):
    x = math_ops.conj(x)
    q = math_ops.conj(q)
    partial_q = -x * math_ops.zeta(x + 1, q)  # pylint: disable=invalid-unary-operand-type
    return (None,
            array_ops.reshape(math_ops.reduce_sum(partial_q * grad, rq), sq))


@ops.RegisterGradient("Polygamma")
def _PolygammaGrad(op, grad):
  """Returns gradient of psi(n, x) with respect to n and x."""
  # TODO(tillahoffmann): Add derivative with respect to n
  n = op.inputs[0]
  x = op.inputs[1]
  # Broadcast gradients
  sn = array_ops.shape(n)
  sx = array_ops.shape(x)
  unused_rn, rx = gen_array_ops.broadcast_gradient_args(sn, sx)
  # Evaluate gradient
  with ops.control_dependencies([grad]):
    n = math_ops.conj(n)
    x = math_ops.conj(x)
    partial_x = math_ops.polygamma(n + 1, x)
    return (None,
            array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx))


@ops.RegisterGradient("Sigmoid")
def _SigmoidGrad(op, grad):
  """Returns grad * sigmoid(x) * (1 - sigmoid(x))."""
  y = op.outputs[0]  # y = sigmoid(x)
  with ops.control_dependencies([grad]):
    y = math_ops.conj(y)
    return gen_math_ops.sigmoid_grad(y, grad)


@ops.RegisterGradient("SigmoidGrad")
def _SigmoidGradGrad(op, grad):
  with ops.control_dependencies([grad]):
    a = math_ops.conj(op.inputs[0])
    b = math_ops.conj(op.inputs[1])
    gb = grad * b
    return gb - 2.0 * gb * a, gen_math_ops.sigmoid_grad(a, grad)


@ops.RegisterGradient("Sign")
def _SignGrad(op, _):
  """Returns 0."""
  x = op.inputs[0]
  return array_ops.zeros_like(x)


@ops.RegisterGradient("Sin")
def _SinGrad(op, grad):
  """Returns grad * cos(x)."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    x = math_ops.conj(x)
    return grad * math_ops.cos(x)


@ops.RegisterGradient("Cos")
def _CosGrad(op, grad):
  """Returns grad * -sin(x)."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    x = math_ops.conj(x)
    return -grad * math_ops.sin(x)


@ops.RegisterGradient("Tan")
def _TanGrad(op, grad):
  """Returns grad * 1/sec^2(x)."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    x = math_ops.conj(x)
    secx = math_ops.reciprocal(math_ops.cos(x))
    secx2 = math_ops.square(secx)
    return secx2 * grad


@ops.RegisterGradient("Asin")
def _AsinGrad(op, grad):
  """Returns grad * 1/sqrt(1-x^2)."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    x = math_ops.conj(x)
    x2 = math_ops.square(x)
    one = constant_op.constant(1, dtype=grad.dtype)
    den = math_ops.sqrt(math_ops.subtract(one, x2))
    inv = math_ops.reciprocal(den)
    return grad * inv


@ops.RegisterGradient("Acos")
def _AcosGrad(op, grad):
  """Returns grad * -1/sqrt(1-x^2)."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    x = math_ops.conj(x)
    x2 = math_ops.square(x)
    one = constant_op.constant(1, dtype=grad.dtype)
    den = math_ops.sqrt(math_ops.subtract(one, x2))
    inv = math_ops.reciprocal(den)
    return -grad * inv


@ops.RegisterGradient("Atan")
def _AtanGrad(op, grad):
  """Returns grad * 1/ (1 + x^2)."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    x = math_ops.conj(x)
    x2 = math_ops.square(x)
    one = constant_op.constant(1, dtype=grad.dtype)
    inv = math_ops.reciprocal(math_ops.add(one, x2))
    return grad * inv


@ops.RegisterGradient("Atan2")
def _Atan2Grad(op, grad):
  """Returns grad * x / (x^2 + y^2), grad * -y / (x^2 + y^2)."""
  y = op.inputs[0]
  x = op.inputs[1]
  with ops.control_dependencies([grad]):
    grad_inv = grad / (math_ops.square(x) + math_ops.square(y))
    return x * grad_inv, -y * grad_inv


@ops.RegisterGradient("AddN")
def _AddNGrad(op, grad):
  """Copies the gradient to all inputs."""
  # Not broadcasting.
  return [grad] * len(op.inputs)


def _ShapesFullySpecifiedAndEqual(x, y, grad):
  # pylint: disable=protected-access
  x_shape = x._shape_tuple()
  y_shape = y._shape_tuple()
  grad_shape = grad._shape_tuple()
  # pylint: enable=protected-access
  return (x_shape == y_shape and x_shape == grad_shape and
          x_shape is not None and None not in x_shape)


@ops.RegisterGradient("Add")
@ops.RegisterGradient("AddV2")
def _AddGrad(op, grad):
  """Gradient for Add."""
  y = op.inputs[1]
  skip_input_indices = None
  try:
    skip_input_indices = op.skip_input_indices
    if skip_input_indices is not None and 1 in skip_input_indices and _IsScalar(
        y):
      return grad, None
  except AttributeError:
    # No gradient skipping, so do the full gradient computation
    pass
  x = op.inputs[0]
  if (isinstance(grad, ops.Tensor) and
      _ShapesFullySpecifiedAndEqual(x, y, grad)):
    return grad, grad
  (sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = (
      SmartBroadcastGradientArgs(x, y, grad))
  if skip_input_indices is not None and 0 in skip_input_indices:
    gx = None
  elif not must_reduce_x:
    gx = grad
  else:
    gx = array_ops.reshape(math_ops.reduce_sum(grad, rx), sx)
  if skip_input_indices is not None and 1 in skip_input_indices:
    gy = None
  elif not must_reduce_y:
    gy = grad
  else:
    gy = array_ops.reshape(math_ops.reduce_sum(grad, ry), sy)
  return (gx, gy)

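# Shape-level sketch of the broadcast handling in `_AddGrad` above, which
# the Sub/Mul/Pow gradients below follow as well (illustration only):
#
#   x = tf.zeros([2, 3])
#   y = tf.zeros([3])
#   with tf.GradientTape() as tape:
#     tape.watch([x, y])
#     z = x + y                    # z has shape [2, 3]
#   gx, gy = tape.gradient(z, [x, y])
#   # gx keeps shape [2, 3] (no reduction needed), while gy is the incoming
#   # gradient summed over the broadcast axis 0 and reshaped to [3],
#   # exactly the reduce_sum/reshape pair built from rx, ry above.
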
1314@ops.RegisterGradient("Sub")
1315def _SubGrad(op, grad):
1316  """Gradient for Sub."""
1317  y = op.inputs[1]
1318  skip_input_indices = None
1319  try:
1320    skip_input_indices = op.skip_input_indices
1321    if skip_input_indices is not None and 1 in skip_input_indices and _IsScalar(
1322        y):
1323      return grad, None
1324  except AttributeError:
1325    # No gradient skipping, so do the full gradient computation
1326    pass
1327  x = op.inputs[0]
1328  if (isinstance(grad, ops.Tensor) and
1329      _ShapesFullySpecifiedAndEqual(x, y, grad)):
1330    return grad, -grad
1331  (sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = (
1332      SmartBroadcastGradientArgs(x, y, grad))
1333  if skip_input_indices is not None and 0 in skip_input_indices:
1334    gx = None
1335  elif not must_reduce_x:
1336    gx = grad
1337  else:
1338    gx = array_ops.reshape(math_ops.reduce_sum(grad, rx), sx)
1339  if skip_input_indices is not None and 1 in skip_input_indices:
1340    gy = None
1341  elif not must_reduce_y:
1342    gy = -grad
1343  else:
1344    gy = array_ops.reshape(math_ops.reduce_sum(-grad, ry), sy)
1345  return (gx, gy)
1346
1347
1348@ops.RegisterGradient("Mul")
1349def _MulGrad(op, grad):
1350  """The gradient of scalar multiplication."""
1351  y = op.inputs[1]
1352  skip_input_indices = None
1353  try:
1354    skip_input_indices = op.skip_input_indices
1355    if skip_input_indices is not None and 1 in skip_input_indices and _IsScalar(
1356        y):
1357      return gen_math_ops.mul(grad, math_ops.conj(y)), None
1358  except AttributeError:
1359    # No gradient skipping, so do the full gradient computation
1360    pass
1361  x = op.inputs[0]
1362  if (isinstance(grad, ops.Tensor) and
1363      _ShapesFullySpecifiedAndEqual(x, y, grad) and
1364      grad.dtype in (dtypes.int32, dtypes.float32)):
1365    return gen_math_ops.mul(grad, y), gen_math_ops.mul(grad, x)
1366  assert x.dtype.base_dtype == y.dtype.base_dtype, (x.dtype, " vs. ", y.dtype)
1367
1368  (sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = (
1369      SmartBroadcastGradientArgs(x, y, grad))
1370  x = math_ops.conj(x)
1371  y = math_ops.conj(y)
1372  if skip_input_indices is not None and 0 in skip_input_indices:
1373    gx = None
1374  elif not must_reduce_x:
1375    gx = gen_math_ops.mul(grad, y)
1376  else:
1377    gx = array_ops.reshape(
1378        math_ops.reduce_sum(gen_math_ops.mul(grad, y), rx), sx)
1379  if skip_input_indices is not None and 1 in skip_input_indices:
1380    gy = None
1381  elif not must_reduce_y:
1382    gy = gen_math_ops.mul(x, grad)
1383  else:
1384    gy = array_ops.reshape(
1385        math_ops.reduce_sum(gen_math_ops.mul(x, grad), ry), sy)
1386  return (gx, gy)
1387
1388
1389@ops.RegisterGradient("MulNoNan")
1390def _MulNoNanGrad(op, grad):
1391  """The gradient of scalar multiplication with NaN-suppression."""
1392  x = op.inputs[0]
1393  y = op.inputs[1]
1394  if (isinstance(grad, ops.Tensor) and
1395      _ShapesFullySpecifiedAndEqual(x, y, grad)):
1396    return gen_math_ops.mul_no_nan(grad, y), gen_math_ops.mul_no_nan(x, grad)
1397  assert x.dtype.base_dtype == y.dtype.base_dtype, (x.dtype, " vs. ", y.dtype)
1398  sx = array_ops.shape(x)
1399  sy = array_ops.shape(y)
1400  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
1401  return (array_ops.reshape(
1402      math_ops.reduce_sum(gen_math_ops.mul_no_nan(grad, y), rx), sx),
1403          array_ops.reshape(
1404              math_ops.reduce_sum(gen_math_ops.mul_no_nan(x, grad), ry), sy))
1405
1406
1407@ops.RegisterGradient("Div")
1408def _DivGrad(op, grad):
1409  """The gradient for the Div operator."""
1410  x = op.inputs[0]
1411  y = op.inputs[1]
1412  sx = array_ops.shape(x)
1413  sy = array_ops.shape(y)
1414  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
1415  x = math_ops.conj(x)
1416  y = math_ops.conj(y)
1417  # pylint: disable=invalid-unary-operand-type
1418  return (
1419      array_ops.reshape(math_ops.reduce_sum(math_ops.divide(grad, y), rx), sx),
1420      array_ops.reshape(
1421          math_ops.reduce_sum(grad * math_ops.divide(math_ops.divide(-x, y), y),
1422                              ry), sy))
1423
1424
1425@ops.RegisterGradient("FloorDiv")
1426def _FloorDivGrad(_, unused_grad):
1427  """The gradient for the FloorDiv operator."""
1428  return None, None
1429
1430
1431@ops.RegisterGradient("FloorMod")
1432def _FloorModGrad(op, grad):
1433  """Returns grad * (1, -floor(x/y))."""
  x = math_ops.conj(op.inputs[0])
  y = math_ops.conj(op.inputs[1])

  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
  floor_xy = math_ops.floor_div(x, y)
  gx = array_ops.reshape(math_ops.reduce_sum(grad, rx), sx)
  gy = array_ops.reshape(
      math_ops.reduce_sum(grad * math_ops.negative(floor_xy), ry), sy)
  return gx, gy


@ops.RegisterGradient("TruncateDiv")
def _TruncateDivGrad(_, unused_grad):
  return None, None


@ops.RegisterGradient("RealDiv")
def _RealDivGrad(op, grad):
  """RealDiv op gradient."""
  x = op.inputs[0]
  y = op.inputs[1]
  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
  x = math_ops.conj(x)
  y = math_ops.conj(y)
  return (array_ops.reshape(
      math_ops.reduce_sum(math_ops.realdiv(grad, y), rx), sx),
          array_ops.reshape(
              math_ops.reduce_sum(
                  grad * math_ops.realdiv(math_ops.realdiv(-x, y), y), ry), sy))  # pylint: disable=invalid-unary-operand-type


@ops.RegisterGradient("DivNoNan")
def _DivNoNanGrad(op, grad):
  """DivNoNan op gradient."""
  x = op.inputs[0]
  y = op.inputs[1]
  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
  x = math_ops.conj(x)
  y = math_ops.conj(y)
  return (
      array_ops.reshape(
          math_ops.reduce_sum(math_ops.div_no_nan(grad, y), rx), sx),
      array_ops.reshape(
          math_ops.reduce_sum(
              grad * math_ops.div_no_nan(math_ops.div_no_nan(-x, y), y),  # pylint: disable=invalid-unary-operand-type
              ry),
          sy))


@ops.RegisterGradient("Pow")
def _PowGrad(op, grad):
  """Returns grad * (y*x^(y-1), z*log(x))."""
  x = op.inputs[0]
  y = op.inputs[1]
  skip_input_indices = None
  try:
    skip_input_indices = op.skip_input_indices
    # TODO(mrry): If `y` is a constant, we can combine `tf.sub()` and the
    # constant `1` into a single constant op.
    if skip_input_indices is not None and 1 in skip_input_indices and _IsScalar(
        y):
      x = math_ops.conj(x)
      y = math_ops.conj(y)
      return grad * y * math_ops.pow(x, y - 1), None

  except AttributeError:
    # No gradient skipping, so do the full gradient computation
    pass

  (sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = (
      SmartBroadcastGradientArgs(x, y, grad))
  x = math_ops.conj(x)
  y = math_ops.conj(y)

  if skip_input_indices is None or 0 not in skip_input_indices:
    gx = grad * y * math_ops.pow(x, y - 1)
    if must_reduce_x:
      gx = array_ops.reshape(math_ops.reduce_sum(gx, rx), sx)
  else:
    gx = None

  if skip_input_indices is None or 1 not in skip_input_indices:
    z = math_ops.conj(op.outputs[0])

    # Avoid false singularity at x = 0
    if x.dtype.is_complex:
      # real(x) < 0 is fine for the complex case
      mask = math_ops.not_equal(x, 0)
    else:
      # There's no sensible real value to return if x < 0, so return 0
      mask = x > 0
    safe_x = array_ops.where(mask, x, array_ops.ones_like(x))
    log_x = array_ops.where(mask, math_ops.log(safe_x), array_ops.zeros_like(x))
    gy = grad * z * log_x
    if must_reduce_y:
      gy = array_ops.reshape(math_ops.reduce_sum(gy, ry), sy)
  else:
    gy = None

  return gx, gy


def _MaximumMinimumGradInputOnly(op, grad, selector_op):
  x = op.inputs[0]
  y = op.inputs[1]
  zeros = array_ops.zeros_like(grad)
  xmask = selector_op(x, y)
  xgrad = array_ops.where_v2(xmask, grad, zeros)
  ygrad = None  # y's gradient is being skipped, so returning None is allowed.
  return (xgrad, ygrad)


def _MaximumMinimumGrad(op, grad, selector_op):
  """Factor out the code for the gradient of Maximum or Minimum."""
  y = op.inputs[1]
  skip_input_indices = None
  try:
    skip_input_indices = op.skip_input_indices
    if skip_input_indices is not None and 1 in skip_input_indices and _IsScalar(
        y):
      # When we want to get gradients for the first input only, and the second
      # input tensor is a scalar, we can do a much simpler calculation
      return _MaximumMinimumGradInputOnly(op, grad, selector_op)
  except AttributeError:
    # No gradient skipping, so do the full gradient computation
    pass
  x = op.inputs[0]
  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  zeros = array_ops.zeros_like(grad)
  xmask = selector_op(x, y)
  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
  if skip_input_indices is not None and 0 in skip_input_indices:
    gx = None
  else:
    xgrad = array_ops.where_v2(xmask, grad, zeros)
    gx = array_ops.reshape(math_ops.reduce_sum(xgrad, rx), sx)

  if skip_input_indices is not None and 1 in skip_input_indices:
    gy = None
  else:
    ygrad = array_ops.where_v2(xmask, zeros, grad)
    gy = array_ops.reshape(math_ops.reduce_sum(ygrad, ry), sy)

  return (gx, gy)


@ops.RegisterGradient("Maximum")
def _MaximumGrad(op, grad):
  """Returns grad*(x >= y, x < y) with type of grad."""
  return _MaximumMinimumGrad(op, grad, math_ops.greater_equal)


@ops.RegisterGradient("Minimum")
def _MinimumGrad(op, grad):
  """Returns grad*(x <= y, x > y) with type of grad."""
  return _MaximumMinimumGrad(op, grad, math_ops.less_equal)


@ops.RegisterGradient("SquaredDifference")
def _SquaredDifferenceGrad(op, grad):
  """Returns the gradient for (x-y)^2."""
  x = op.inputs[0]
  y = op.inputs[1]
  skip_input_indices = None
  try:
    skip_input_indices = op.skip_input_indices
  except AttributeError:
    # No gradient skipping, so do the full gradient computation
    pass

  with ops.control_dependencies([grad]):
    # The parentheses ensure that if grad is IndexedSlices, it is multiplied
    # by a Tensor (not a plain number like 2.0), which converts it to a Tensor.
    x_grad = math_ops.scalar_mul(2.0, grad) * (x - y)

  if (isinstance(grad, ops.Tensor) and
      _ShapesFullySpecifiedAndEqual(x, y, grad)):
    return x_grad, -x_grad

  (sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = (
      SmartBroadcastGradientArgs(x, y, grad))

  if skip_input_indices is not None and 0 in skip_input_indices:
    gx = None
  elif must_reduce_x:
    gx = array_ops.reshape(math_ops.reduce_sum(x_grad, rx), sx)
  else:
    gx = x_grad

  if skip_input_indices is not None and 1 in skip_input_indices:
    gy = None
  elif must_reduce_y:
    gy = -array_ops.reshape(math_ops.reduce_sum(x_grad, ry), sy)
  else:
    gy = -x_grad
  return (gx, gy)


# Logical operations have no gradients.
ops.NotDifferentiable("Less")
ops.NotDifferentiable("LessEqual")
ops.NotDifferentiable("Greater")
ops.NotDifferentiable("GreaterEqual")
ops.NotDifferentiable("Equal")
ops.NotDifferentiable("ApproximateEqual")
ops.NotDifferentiable("NotEqual")
ops.NotDifferentiable("LogicalAnd")
ops.NotDifferentiable("LogicalOr")
ops.NotDifferentiable("LogicalNot")


@ops.RegisterGradient("Select")
def _SelectGrad(op, grad):
  c = op.inputs[0]
  x = op.inputs[1]
  zeros = array_ops.zeros_like(x)
  return (None, array_ops.where(c, grad, zeros), array_ops.where(
      c, zeros, grad))


@ops.RegisterGradient("SelectV2")
def _SelectGradV2(op, grad):
  c = op.inputs[0]
  x = op.inputs[1]
  y = op.inputs[2]
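  # The condition receives no gradient; grad flows to x where c is True and
  # to y where it is False. Because SelectV2 broadcasts its arguments, each
  # gradient is summed over the broadcast dimensions and reshaped back to
  # its input's shape.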
  zeros = array_ops.zeros([], dtype=grad.dtype.base_dtype)
  gx = array_ops.where_v2(c, grad, zeros)
  x_shape = array_ops.shape(x)
  output_shape = array_ops.shape(op.outputs[0])
  # Reduce away broadcasted leading dims.
  reduce_x, _ = gen_array_ops.broadcast_gradient_args(x_shape, output_shape)
  gx = math_ops.reduce_sum(gx, keepdims=True, axis=reduce_x)
  gx = array_ops.reshape(gx, x_shape)

  gy = array_ops.where_v2(c, zeros, grad)
  y_shape = array_ops.shape(y)
  # Reduce away broadcasted leading dims.
  reduce_y, _ = gen_array_ops.broadcast_gradient_args(y_shape, output_shape)
  gy = math_ops.reduce_sum(gy, keepdims=True, axis=reduce_y)
  gy = array_ops.reshape(gy, y_shape)

  return (None, gx, gy)


def _MatMulGradAgainstFirstOnly(op, grad):
  """Gradient for MatMul, only for the first input."""
  t_a = op.get_attr("transpose_a")
  t_b = op.get_attr("transpose_b")
  b = math_ops.conj(op.inputs[1])
  if not t_a and not t_b:
    grad_a = gen_math_ops.mat_mul(grad, b, transpose_b=True)
  elif not t_a and t_b:
    grad_a = gen_math_ops.mat_mul(grad, b)
  elif t_a and not t_b:
    grad_a = gen_math_ops.mat_mul(b, grad, transpose_b=True)
  elif t_a and t_b:
    grad_a = gen_math_ops.mat_mul(b, grad, transpose_a=True, transpose_b=True)
  return grad_a, None


def _MatMulGradAgainstSecondOnly(op, grad):
  """Gradient for MatMul, only for the second input."""
  t_a = op.get_attr("transpose_a")
  t_b = op.get_attr("transpose_b")
  a = math_ops.conj(op.inputs[0])
  if not t_a and not t_b:
    grad_b = gen_math_ops.mat_mul(a, grad, transpose_a=True)
  elif not t_a and t_b:
    grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True)
  elif t_a and not t_b:
    grad_b = gen_math_ops.mat_mul(a, grad)
  elif t_a and t_b:
    grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True, transpose_b=True)
  return None, grad_b


@ops.RegisterGradient("MatMul")
def _MatMulGrad(op, grad):
  """Gradient for MatMul."""
  try:
    skip_input_indices = op.skip_input_indices
    if skip_input_indices is not None:
      if 1 in skip_input_indices:
        return _MatMulGradAgainstFirstOnly(op, grad)
      elif 0 in skip_input_indices:
        return _MatMulGradAgainstSecondOnly(op, grad)
  except AttributeError:
    # No gradient skipping, so do the full gradient computation
    pass

  t_a = op.get_attr("transpose_a")
  t_b = op.get_attr("transpose_b")
  a = math_ops.conj(op.inputs[0])
  b = math_ops.conj(op.inputs[1])
  if not t_a and not t_b:
    grad_a = gen_math_ops.mat_mul(grad, b, transpose_b=True)
    grad_b = gen_math_ops.mat_mul(a, grad, transpose_a=True)
  elif not t_a and t_b:
    grad_a = gen_math_ops.mat_mul(grad, b)
    grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True)
  elif t_a and not t_b:
    grad_a = gen_math_ops.mat_mul(b, grad, transpose_b=True)
    grad_b = gen_math_ops.mat_mul(a, grad)
  elif t_a and t_b:
    grad_a = gen_math_ops.mat_mul(b, grad, transpose_a=True, transpose_b=True)
    grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True, transpose_b=True)
  return grad_a, grad_b


@ops.RegisterGradient("SparseMatMul")
def _SparseMatMulGrad(op, grad):
  """Gradient for SparseMatMul."""

  t_a = op.get_attr("transpose_a")
  t_b = op.get_attr("transpose_b")
  is_sparse = {}
  is_sparse[op.inputs[0].ref()] = op.get_attr("a_is_sparse")
  is_sparse[op.inputs[1].ref()] = op.get_attr("b_is_sparse")
  # Use a heuristic to figure out whether grad might be sparse.
  is_sparse[grad.ref()] = not context.executing_eagerly() and (
      grad.op.type == "ReluGrad")

  def _SparseMatMul(t1, t2, out_dtype, transpose_a=False, transpose_b=False):
    """Helper function to create SparseMatMul op."""

    assert t1.ref() in is_sparse and t2.ref() in is_sparse
    t1_sparse = is_sparse[t1.ref()]
    t2_sparse = is_sparse[t2.ref()]
    if transpose_b:
      t2 = array_ops.transpose(t2)
      transpose_b = False
    prod = math_ops.matmul(
        t1,
        t2,
        transpose_a=transpose_a,
        transpose_b=transpose_b,
        a_is_sparse=t1_sparse,
        b_is_sparse=t2_sparse)
    if prod.dtype != out_dtype:
      prod = math_ops.cast(prod, out_dtype)
    return prod

  dtype_a = op.inputs[0].dtype
  dtype_b = op.inputs[1].dtype
  if not t_a and not t_b:
    return (_SparseMatMul(grad, op.inputs[1], dtype_a, transpose_b=True),
            _SparseMatMul(op.inputs[0], grad, dtype_b, transpose_a=True))
  elif not t_a and t_b:
    return (_SparseMatMul(grad, op.inputs[1], dtype_a),
            _SparseMatMul(grad, op.inputs[0], dtype_b, transpose_a=True))
  elif t_a and not t_b:
    return (_SparseMatMul(op.inputs[1], grad, dtype_a, transpose_b=True),
            _SparseMatMul(op.inputs[0], grad, dtype_b))
  elif t_a and t_b:
    return (_SparseMatMul(
        op.inputs[1], grad, dtype_a, transpose_a=True, transpose_b=True),
            _SparseMatMul(
                grad, op.inputs[0], dtype_b, transpose_a=True,
                transpose_b=True))


@ops.RegisterGradient("Floor")
def _FloorGrad(_, unused_grad):
  return [None]


@ops.RegisterGradient("Ceil")
def _CeilGrad(_, unused_grad):
  return [None]


@ops.RegisterGradient("Round")
def _RoundGrad(_, unused_grad):
  return [None]


@ops.RegisterGradient("Rint")
def _RintGrad(_, unused_grad):
  # The gradient of Rint is zero (the op is piecewise constant).
  return [None]


@ops.RegisterGradient("BatchMatMul")
def _BatchMatMul(op, grad):
  """Returns the gradient of x and y given the gradient of x * y."""
  x = op.inputs[0]
  y = op.inputs[1]
  adj_x = op.get_attr("adj_x")
  adj_y = op.get_attr("adj_y")

  if not adj_x:
    if not adj_y:
      grad_x = math_ops.matmul(grad, y, adjoint_a=False, adjoint_b=True)
      grad_y = math_ops.matmul(x, grad, adjoint_a=True, adjoint_b=False)
    else:
      grad_x = math_ops.matmul(grad, y, adjoint_a=False, adjoint_b=False)
      grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=False)
  else:
    if not adj_y:
      grad_x = math_ops.matmul(y, grad, adjoint_a=False, adjoint_b=True)
      grad_y = math_ops.matmul(x, grad, adjoint_a=False, adjoint_b=False)
    else:
      grad_x = math_ops.matmul(y, grad, adjoint_a=True, adjoint_b=True)
      grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=True)

  return grad_x, grad_y


@ops.RegisterGradient("BatchMatMulV2")
@ops.RegisterGradient("BatchMatMulV3")
def _BatchMatMulV2(op, grad):
  """Returns the gradient of x and y given the gradient of x * y."""
  x = op.inputs[0]
  y = op.inputs[1]
  adj_x = op.get_attr("adj_x")
  adj_y = op.get_attr("adj_y")

  if not adj_x:
    if not adj_y:
      grad_x = math_ops.matmul(grad, y, adjoint_a=False, adjoint_b=True)
      grad_y = math_ops.matmul(x, grad, adjoint_a=True, adjoint_b=False)
    else:
      grad_x = math_ops.matmul(grad, y, adjoint_a=False, adjoint_b=False)
      grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=False)
  else:
    if not adj_y:
      grad_x = math_ops.matmul(y, grad, adjoint_a=False, adjoint_b=True)
      grad_y = math_ops.matmul(x, grad, adjoint_a=False, adjoint_b=False)
    else:
      grad_x = math_ops.matmul(y, grad, adjoint_a=True, adjoint_b=True)
      grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=True)

  # Possibly reduce along the broadcasted batch dimensions, if broadcasting
  # is required.
  shape_x_static = x.get_shape()
  shape_y_static = y.get_shape()
  output_may_have_non_empty_batch_shape = (
      (shape_x_static.rank is None or shape_x_static.rank > 2) or
      (shape_y_static.rank is None or shape_y_static.rank > 2))
  batch_shapes_match = (
      shape_x_static[:-2].is_fully_defined() and
      shape_y_static[:-2].is_fully_defined() and
      shape_x_static[:-2] == shape_y_static[:-2])
  if (not output_may_have_non_empty_batch_shape) or batch_shapes_match:
    return grad_x, grad_y

  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  rx, ry = gen_array_ops.broadcast_gradient_args(sx[:-2], sy[:-2])
  grad_x = array_ops.reshape(math_ops.reduce_sum(grad_x, rx), sx)
  grad_y = array_ops.reshape(math_ops.reduce_sum(grad_y, ry), sy)
  return grad_x, grad_y


ops.NotDifferentiable("Range")
ops.NotDifferentiable("LinSpace")


@ops.RegisterGradient("Complex")
def _ComplexGrad(op, grad):
  """Returns the real and imaginary components of 'grad', respectively."""
  x = op.inputs[0]
  y = op.inputs[1]
  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
  return (array_ops.reshape(math_ops.reduce_sum(math_ops.real(grad), rx), sx),
          array_ops.reshape(math_ops.reduce_sum(math_ops.imag(grad), ry), sy))


@ops.RegisterGradient("Real")
def _RealGrad(_, grad):
1913  """Returns 'grad' as the real part and set the imaginary part 0."""
  zero = constant_op.constant(0, dtype=grad.dtype)
  return math_ops.complex(grad, zero)


@ops.RegisterGradient("Imag")
def _ImagGrad(_, grad):
1920  """Returns 'grad' as the imaginary part and set the real part 0."""
  zero = constant_op.constant(0, dtype=grad.dtype)
  return math_ops.complex(zero, grad)


@ops.RegisterGradient("Angle")
def _AngleGrad(op, grad):
  """Returns -grad / (Im(x) + iRe(x))"""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    re = math_ops.real(x)
    im = math_ops.imag(x)
    z = math_ops.reciprocal(math_ops.complex(im, re))
    zero = constant_op.constant(0, dtype=grad.dtype)
    complex_grad = math_ops.complex(grad, zero)
    return -complex_grad * z


@ops.RegisterGradient("Conj")
def _ConjGrad(_, grad):
  """Returns the complex conjugate of grad."""
  return math_ops.conj(grad)


@ops.RegisterGradient("ComplexAbs")
def _ComplexAbsGrad(op, grad):
  """Returns the gradient of ComplexAbs."""
  return math_ops.div_no_nan(
      math_ops.complex(
          grad, array_ops.zeros_like(grad)) * op.inputs[0],
      math_ops.complex(
          op.outputs[0], array_ops.zeros_like(op.outputs[0])))


@ops.RegisterGradient("Cast")
def _CastGrad(op, grad):
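  # Cast is only differentiable between floating-point / complex dtypes; in
  # that case the gradient is the incoming gradient cast back to the source
  # dtype, otherwise there is no gradient.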
  t = [
      dtypes.float16, dtypes.float32, dtypes.float64, dtypes.bfloat16,
      dtypes.complex64, dtypes.complex128
  ]
  src_type = op.inputs[0].dtype.base_dtype
  dst_type = grad.dtype.base_dtype
  if src_type in t and dst_type in t:
    return math_ops.cast(grad, src_type)
  else:
    return None


@ops.RegisterGradient("Cross")
def _CrossGrad(op, grad):
  u = op.inputs[0]
  v = op.inputs[1]
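  # With z = cross(u, v), pairing grad with the differential gives
  # <grad, cross(du, v)> = <cross(v, grad), du> and
  # <grad, cross(u, dv)> = <cross(grad, u), dv>, hence the two results below.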
  return (math_ops.cross(v, grad), math_ops.cross(grad, u))


@ops.RegisterGradient("Cumsum")
def _CumsumGrad(op, grad):
  axis = op.inputs[1]
  exclusive = op.get_attr("exclusive")
  reverse = op.get_attr("reverse")
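  # The adjoint of a cumulative sum is a cumulative sum taken in the opposite
  # direction, so the gradient is the cumsum of grad with `reverse` flipped
  # (and the same `exclusive` setting); the axis input gets no gradient.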
  return [
      math_ops.cumsum(grad, axis, exclusive=exclusive, reverse=not reverse),
      None
  ]


@ops.RegisterGradient("Cumprod")
def _CumprodGrad(op, grad):
  x = op.inputs[0]
  axis = op.inputs[1]
  exclusive = op.get_attr("exclusive")
  reverse = op.get_attr("reverse")

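  # For the inclusive, forward case, dL/dx_j = sum_{i >= j} grad_i *
  # cumprod(x)_i / x_j, i.e. the reversed cumsum of prod * grad divided by x;
  # div_no_nan maps the resulting 0/0 at x == 0 to a zero gradient.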
  prod = math_ops.cumprod(x, axis, exclusive=exclusive, reverse=reverse)
  out = math_ops.cumsum(
      prod * grad, axis, exclusive=exclusive, reverse=not reverse)
  return [math_ops.div_no_nan(out, x), None]


@ops.RegisterGradient("CumulativeLogsumexp")
def _CumulativeLogsumexpGrad(op, grad):
  x = op.inputs[0]
  axis = op.inputs[1]
  cumulative_logsumexp = op.outputs[0]

  exclusive = op.get_attr("exclusive")
  reverse = op.get_attr("reverse")

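  # With out = logcumsumexp(x), d(out_i)/dx_j = exp(x_j - out_i) for every
  # entry i whose sum includes x_j, so dL/dx_j = sum_i grad_i *
  # exp(x_j - out_i). That sum is evaluated in log space, as a cumulative
  # logsumexp in the opposite direction, for numerical stability.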
  # Split the incoming gradient into positive and negative parts
  # in order to take logs. This is required for stable results.
  log_grad_positive = array_ops.where_v2(
      math_ops.greater(grad, 0),
      math_ops.log(grad),
      grad.dtype.min)

  log_grad_negative = array_ops.where_v2(
      math_ops.less(grad, 0),
      math_ops.log(-grad),
      grad.dtype.min)

  output_pos = math_ops.exp(
      math_ops.cumulative_logsumexp(
          log_grad_positive - cumulative_logsumexp,
          axis=axis, reverse=not reverse, exclusive=exclusive) + x)

  output_neg = math_ops.exp(
      math_ops.cumulative_logsumexp(
          log_grad_negative - cumulative_logsumexp,
          axis=axis, reverse=not reverse, exclusive=exclusive) + x)

  return [output_pos - output_neg, None]


@ops.RegisterGradient("NextAfter")
def _NextAfterGrad(op, grad):
  """Returns gradient of nextafter(x1, x2) with respect to x1 and x2."""
  x1 = op.inputs[0]
  x2 = op.inputs[1]
  s_x1 = array_ops.shape(x1)
  s_x2 = array_ops.shape(x2)
  r_x1, r_x2 = gen_array_ops.broadcast_gradient_args(s_x1, s_x2)
  with ops.control_dependencies([grad]):
    partial_x1 = array_ops.ones(s_x1, dtype=x1.dtype)
    partial_x2 = array_ops.zeros(s_x2, dtype=x2.dtype)
    return (array_ops.reshape(
        math_ops.reduce_sum(partial_x1 * grad, r_x1), s_x1),
            array_ops.reshape(
                math_ops.reduce_sum(partial_x2 * grad, r_x2), s_x2))