# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradients for operators defined in math_ops.py."""
import numpy as np

from tensorflow.python.compat import compat
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import special_math_ops


def _safe_shape_div(x, y):
  """Divides `x / y` assuming `x, y >= 0`, treating `0 / 0 = 0`."""
  return x // math_ops.maximum(y, 1)


@ops.RegisterGradient("ArgMax")
def _ArgMaxGrad(op, grad):
  del op, grad
  return [None, None]


@ops.RegisterGradient("ArgMin")
def _ArgMinGrad(op, grad):
  del op, grad
  return [None, None]


@ops.RegisterGradient("EuclideanNorm")
def _EuclideanNormGrad(op, grad):
  """Gradient for EuclideanNorm."""

  output = op.outputs[0]

  if not op.get_attr("keep_dims"):
    output_shape_kept_dims = math_ops.reduced_shape(
        array_ops.shape(op.inputs[0]), op.inputs[1])
    output = array_ops.reshape(output, output_shape_kept_dims)
    grad = array_ops.reshape(grad, output_shape_kept_dims)

  return math_ops.truediv(op.inputs[0], output / grad), None


def SmartBroadcastGradientArgs(x, y, grad):
  """Optimized version of `broadcast_gradient_args` that caches results.

  This implementation avoids creating `broadcast_gradient_args` ops in the case
  that the input shapes are fully defined, and provides hints to the calling
  code that can be used to avoid creating reduction and reshaping ops.

  Args:
    x: The left input tensor to a broadcasting binary op.
    y: The right input tensor to a broadcasting binary op.
    grad: The incoming gradient tensor for a broadcasting binary op.

  Returns:
    A pair of tuples, containing:
      * A 3-tuple of broadcast information for x, containing:
        * The shape of x (as a tuple or Tensor).
        * The reduction indices for x (as a tuple or Tensor).
        * A boolean, which if True, indicates that x's shape differs from grad's
          shape (and so x's gradient must be reduced and/or reshaped).
      * A 3-tuple of broadcast information for y, containing the respective
        details for y.
  """
  # NOTE: It may be productive to apply these optimizations in the eager case
  # as well.
  if context.executing_eagerly() or not (
      isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor)
      and isinstance(grad, ops.Tensor)):
    sx = array_ops.shape(x)
    sy = array_ops.shape(y)
    rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
    return (sx, rx, True), (sy, ry, True)

  # pylint: disable=protected-access
  x_shape_tuple = x._shape_tuple()
  y_shape_tuple = y._shape_tuple()
  grad_shape_tuple = grad._shape_tuple()
  # pylint: enable=protected-access

  if (x_shape_tuple is None or None in x_shape_tuple or
      y_shape_tuple is None or None in y_shape_tuple):
    sx = array_ops.shape_internal(x, optimize=False)
    sy = array_ops.shape_internal(y, optimize=False)
    rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
    return (sx, rx, True), (sy, ry, True)

  x_needs_reduction = x_shape_tuple != grad_shape_tuple
  y_needs_reduction = y_shape_tuple != grad_shape_tuple

  # Get the default graph rather than relying on `x.graph`, `y.graph`, or
  # `grad.graph`, because these may be eager tensors.
  g = ops.get_default_graph()

  try:
    rx, ry = g._bcast_grad_args_cache[(x_shape_tuple, y_shape_tuple)]  # pylint: disable=protected-access
    return (x_shape_tuple, rx, x_needs_reduction), (
        y_shape_tuple, ry, y_needs_reduction)
  except KeyError:
    rx, ry = array_ops.broadcast_gradient_args(x_shape_tuple, y_shape_tuple)
    # TODO(mrry): If this becomes a bottleneck, add a multi-output version of
    # `TF_TryEvaluateConstant()`.
    rx_value = tuple(tensor_util.try_evaluate_constant(rx))
    assert rx_value is not None
    ry_value = tuple(tensor_util.try_evaluate_constant(ry))
    assert ry_value is not None
    g._bcast_grad_args_cache[(x_shape_tuple, y_shape_tuple)] = (  # pylint: disable=protected-access
        rx_value, ry_value)

    return (x_shape_tuple, rx_value, x_needs_reduction), (
        y_shape_tuple, ry_value, y_needs_reduction)


_empty_tuple = ()


def _IsScalar(x):
  return x._shape_tuple() is _empty_tuple  # pylint: disable=protected-access


@ops.RegisterGradient("Sum")
def _SumGrad(op, grad):
  """Gradient for Sum."""
  # Fast path for when reducing to a scalar and ndims is known: adds only
  # Reshape and Tile ops (and possibly a Shape).
  input_0_shape = op.inputs[0]._shape_tuple()  # pylint: disable=protected-access
  if input_0_shape is not None:
    axes = tensor_util.constant_value(op.inputs[1])
    if axes is not None:
      rank = len(input_0_shape)
      if np.array_equal(axes, np.arange(rank)):  # Reduce all dims.
        if context.executing_eagerly():
          ctx = context.context()
          new_shape = ctx.ones_rank_cache().get(rank)
          if new_shape is None:
            new_shape = constant_op.constant([1] * rank, dtype=dtypes.int32)
            ctx.ones_rank_cache().put(rank, new_shape)
        else:
          new_shape = [1] * rank
        grad = array_ops.reshape(grad, new_shape)
        # If shape is not fully defined (but rank is), we use Shape.
        if None not in input_0_shape:
          input_shape = constant_op.constant(input_0_shape, dtype=dtypes.int32)
        else:
          input_shape = array_ops.shape(op.inputs[0])
        return [array_ops.tile(grad, input_shape), None]
      elif None not in input_0_shape and not context.executing_eagerly():
        # The shape and reduction indices are statically known, so we use a
        # graph-level cache to avoid recomputing `reduced_shape()` for each
        # invocation.
        graph = ops.get_default_graph()

        # Canonicalize `axes` to be a tuple of indices. The incoming
        # value may be a scalar or a vector, and may include negative indices.
        axes = tuple(axes.reshape(-1))

        try:
          output_shape_kept_dims, tile_scaling = graph._reduced_shape_cache[  # pylint: disable=protected-access
              (input_0_shape, axes)]
        except KeyError:

          # Compute and cache `output_shape_kept_dims` and `tile_scaling`.
          def EvaluateAsTuple(t):
            if tensor_util.is_tf_type(t):
              value = tensor_util.try_evaluate_constant(t)
              assert value is not None
            else:
              value = t
            return tuple(value)

          output_shape_kept_dims = EvaluateAsTuple(
              math_ops.reduced_shape(input_0_shape, axes))
          tile_scaling = EvaluateAsTuple(
              _safe_shape_div(input_0_shape, output_shape_kept_dims))
          graph._reduced_shape_cache[(input_0_shape, axes)] = (  # pylint:disable=protected-access
              output_shape_kept_dims, tile_scaling)

        grad = array_ops.reshape(grad, output_shape_kept_dims)
        return [array_ops.tile(grad, tile_scaling), None]

  input_shape = array_ops.shape(op.inputs[0])

  if not op.get_attr("keep_dims"):
    with ops.colocate_with(input_shape):
      # TODO(apassos) remove this once device placement for eager ops makes
      # more sense.
      output_shape_kept_dims = math_ops.reduced_shape(input_shape,
                                                      op.inputs[1])
    grad = array_ops.reshape(grad, output_shape_kept_dims)
  return [array_ops.broadcast_to(grad, input_shape), None]


def _MinOrMaxGrad(op, grad):
  """Gradient for Min or Max. Amazingly it's precisely the same code."""
  input_shape = array_ops.shape(op.inputs[0])
  y = op.outputs[0]
  if not op.get_attr("keep_dims"):
    output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1])
    y = array_ops.reshape(y, output_shape_kept_dims)
    grad = array_ops.reshape(grad, output_shape_kept_dims)
  else:
    output_shape_kept_dims = array_ops.shape(y)

  # Compute the number of selected (maximum or minimum) elements in each
  # reduction dimension. If there are multiple minimum or maximum elements
  # then the gradient will be divided between them.
  indicators = math_ops.cast(math_ops.equal(y, op.inputs[0]), grad.dtype)
  num_selected = array_ops.reshape(
      math_ops.reduce_sum(indicators, op.inputs[1]), output_shape_kept_dims)

  return [math_ops.divide(indicators, num_selected) * grad, None]
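# Illustrative example of the tie-splitting behaviour above (a sketch assuming
# an eager TF 2.x `tf` namespace; illustration only):
#
#   x = tf.constant([1., 3., 3.])
#   with tf.GradientTape() as tape:
#     tape.watch(x)
#     y = tf.reduce_max(x)
#   tape.gradient(y, x)  # ==> [0. , 0.5, 0.5]
#
# Both maximal entries count as "selected", so the incoming gradient of 1.0 is
# divided evenly between them.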
@ops.RegisterGradient("Max")
def _MaxGrad(op, grad):
  """Gradient for Max."""
  return _MinOrMaxGrad(op, grad)


@ops.RegisterGradient("Min")
def _MinGrad(op, grad):
  return _MinOrMaxGrad(op, grad)


@ops.RegisterGradient("Mean")
def _MeanGrad(op, grad):
  """Gradient for Mean."""
  sum_grad = _SumGrad(op, grad)[0]
  input_shape = op.inputs[0]._shape_tuple()  # pylint: disable=protected-access
  output_shape = op.outputs[0]._shape_tuple()  # pylint: disable=protected-access
  if (input_shape is not None and output_shape is not None and
      None not in input_shape and None not in output_shape):
    input_size = np.prod(input_shape)
    output_size = np.prod(output_shape)
    factor = input_size // max(output_size, 1)
    factor = constant_op.constant(factor, dtype=sum_grad.dtype)
  else:
    input_shape = array_ops.shape(op.inputs[0])
    output_shape = array_ops.shape(op.outputs[0])
    factor = _safe_shape_div(
        math_ops.reduce_prod(input_shape), math_ops.reduce_prod(output_shape))
  return math_ops.truediv(sum_grad, math_ops.cast(factor, sum_grad.dtype)), None


@ops.RegisterGradient("Prod")
def _ProdGrad(op, grad):
  """Gradient for Prod."""
  # The gradient can be expressed by dividing the product by each entry of the
  # input tensor, but this approach can't deal with zeros in the input.
  # Here, we avoid this problem by composing the output as a product of two
  # cumprod operations.

  input_shape = array_ops.shape(op.inputs[0])
  # Reshape reduction indices for the case where the parameter is a scalar
  reduction_indices = array_ops.reshape(op.inputs[1], [-1])

  # Expand grad to full input shape
  if not op.get_attr("keep_dims"):
    output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1])
    grad = array_ops.reshape(grad, output_shape_kept_dims)

  grad = array_ops.broadcast_to(grad, input_shape)

  # Pack all reduced dimensions into a single one, so we can perform the
  # cumprod ops. If the reduction dims list is empty, it defaults to float32,
  # so we need to cast here. We put all the shape-related ops on CPU to avoid
  # copying back and forth, and since listdiff is CPU only.
  with ops.device("/cpu:0"):
    rank = array_ops.rank(op.inputs[0])
    reduction_indices = (reduction_indices + rank) % rank
    reduced = math_ops.cast(reduction_indices, dtypes.int32)
    idx = math_ops.range(0, rank)
    other, _ = gen_array_ops.list_diff(idx, reduced, dtypes.int32)
    perm = array_ops.concat([reduced, other], 0)
    reduced_num = math_ops.reduce_prod(array_ops.gather(input_shape, reduced))
    other_num = math_ops.reduce_prod(array_ops.gather(input_shape, other))
    permuted = array_ops.transpose(op.inputs[0], perm)
    permuted_shape = array_ops.shape(permuted)
    reshaped = array_ops.reshape(permuted, (reduced_num, other_num))

  # Calculate product, leaving out the current entry
  left = math_ops.cumprod(reshaped, axis=0, exclusive=True)
  right = math_ops.cumprod(reshaped, axis=0, exclusive=True, reverse=True)
  # For complex inputs, the gradient is in the conjugate direction.
  y = array_ops.reshape(
      math_ops.conj(left) * math_ops.conj(right), permuted_shape)

  # Invert the transpose and reshape operations.
  # Make sure to set the statically known shape information through a reshape.
  out = grad * array_ops.transpose(y, array_ops.invert_permutation(perm))
  return array_ops.reshape(out, input_shape), None
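# Sketch of what the two-cumprod construction above computes (assumes an eager
# TF 2.x `tf` namespace; illustration only). For zero-free inputs it matches
# "product divided by each entry", but it stays finite when zeros are present:
#
#   x = tf.constant([2., 0., 4.])
#   with tf.GradientTape() as tape:
#     tape.watch(x)
#     y = tf.reduce_prod(x)
#   tape.gradient(y, x)  # ==> [0., 8., 0.]
#
# Each entry receives the product of all the other entries: the zero entry gets
# 2 * 4 = 8, while every other entry gets a product that contains the zero.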
@ops.RegisterGradient("SegmentSum")
def _SegmentSumGrad(op, grad):
  """Gradient for SegmentSum."""
  return array_ops.gather(grad, op.inputs[1]), None


@ops.RegisterGradient("SegmentMean")
def _SegmentMeanGrad(op, grad):
  """Gradient for SegmentMean."""
  input_rank = array_ops.rank(op.inputs[0])
  ones_shape = array_ops.concat([
      array_ops.shape(op.inputs[1]),
      array_ops.ones(
          array_ops.expand_dims(input_rank - 1, 0), dtype=dtypes.int32)
  ], 0)
  ones = array_ops.ones(ones_shape, dtype=grad.dtype)
  scaled_grad = math_ops.divide(grad, math_ops.segment_sum(ones, op.inputs[1]))
  return array_ops.gather(scaled_grad, op.inputs[1]), None


@ops.RegisterGradient("SparseSegmentSum")
def _SparseSegmentSumGrad(op, grad):
  """Gradient for SparseSegmentSum."""
  dim0 = array_ops.shape(op.inputs[0])[0]
  if compat.forward_compatible(2021, 6, 10):
    return (math_ops.sparse_segment_sum_grad(grad, op.inputs[1], op.inputs[2],
                                             dim0), None, None)
  else:
    return (math_ops.unsorted_segment_sum(
        array_ops.gather(grad, op.inputs[2]), op.inputs[1], dim0), None, None)


@ops.RegisterGradient("SparseSegmentSumWithNumSegments")
def _SparseSegmentSumWithNumSegmentsGrad(op, grad):
  """Gradient for SparseSegmentSumWithNumSegments."""
  dim0 = array_ops.shape(op.inputs[0])[0]
  if compat.forward_compatible(2021, 6, 10):
    return (math_ops.sparse_segment_sum_grad(grad, op.inputs[1], op.inputs[2],
                                             dim0), None, None, None)
  else:
    return (math_ops.unsorted_segment_sum(
        array_ops.gather(grad, op.inputs[2]), op.inputs[1],
        dim0), None, None, None)


@ops.RegisterGradient("SparseSegmentMean")
def _SparseSegmentMeanGrad(op, grad):
  """Gradient for SparseSegmentMean."""
  dim0 = array_ops.shape(op.inputs[0])[0]
  return (math_ops.sparse_segment_mean_grad(grad, op.inputs[1], op.inputs[2],
                                            dim0), None, None)


@ops.RegisterGradient("SparseSegmentMeanWithNumSegments")
def _SparseSegmentMeanWithNumSegmentsGrad(op, grad):
  """Gradient for SparseSegmentMeanWithNumSegments."""
  dim0 = array_ops.shape(op.inputs[0])[0]
  return (math_ops.sparse_segment_mean_grad(grad, op.inputs[1], op.inputs[2],
                                            dim0), None, None, None)


@ops.RegisterGradient("SparseSegmentSqrtN")
def _SparseSegmentSqrtNGrad(op, grad):
  """Gradient for SparseSegmentSqrtN."""
  dim0 = array_ops.shape(op.inputs[0])[0]
  return (math_ops.sparse_segment_sqrt_n_grad(grad, op.inputs[1], op.inputs[2],
                                              dim0), None, None)


@ops.RegisterGradient("SparseSegmentSqrtNWithNumSegments")
def _SparseSegmentSqrtNWithNumSegmentsGrad(op, grad):
  """Gradient for SparseSegmentSqrtNWithNumSegments."""
  dim0 = array_ops.shape(op.inputs[0])[0]
  return (math_ops.sparse_segment_sqrt_n_grad(grad, op.inputs[1], op.inputs[2],
                                              dim0), None, None, None)


def _SegmentMinOrMaxGrad(op, grad):
  """ Gradient for SegmentMin and SegmentMax. """
  zeros = array_ops.zeros_like(op.inputs[0], dtype=op.inputs[0].dtype)
  # Get the number of selected (minimum or maximum) elements in each segment.
  gathered_outputs = array_ops.gather(op.outputs[0], op.inputs[1])
  is_selected = math_ops.equal(op.inputs[0], gathered_outputs)
  num_selected = math_ops.segment_sum(
      math_ops.cast(is_selected, grad.dtype), op.inputs[1])
  # Compute the gradient for each segment. The gradient for the ith segment is
  # divided evenly among the selected elements in that segment.
  weighted_grads = math_ops.divide(grad, num_selected)
  gathered_grads = array_ops.gather(weighted_grads, op.inputs[1])
  return array_ops.where_v2(is_selected, gathered_grads, zeros), None


@ops.RegisterGradient("SegmentMin")
def _SegmentMinGrad(op, grad):
  """Gradient for SegmentMin."""
  return _SegmentMinOrMaxGrad(op, grad)


@ops.RegisterGradient("SegmentMax")
def _SegmentMaxGrad(op, grad):
  """Gradient for SegmentMax."""
  return _SegmentMinOrMaxGrad(op, grad)


@ops.RegisterGradient("SegmentProd")
def _SegmentProdGrad(op, grad):
  """Gradient for SegmentProd.

  The gradient can be expressed for each segment by dividing the segment's
  product by each element of the segment input tensor, but this approach can't
  deal with zeros in the input.
  Unlike reduce_prod we can't use cumsum here as individual segments may have
  a different number of elements. Therefore we consider three cases:
  1) A segment input contains no zeros and we can safely divide by the input
     tensor.
  2) A segment contains exactly one zero. Then the gradient of each input of
     the segment is zero except for the 0-input, where the gradient is
     the product of the remaining segment entries.
  3) A segment contains at least two zeros. The gradient is zero for all
     segment inputs.
  """
  data = op.inputs[0]
  segment_ids = op.inputs[1]
  is_zero = math_ops.equal(data, 0)
  num_zeros = gen_math_ops.segment_sum(
      math_ops.cast(is_zero, dtype=dtypes.int32), segment_ids)
  # handle case 3 and set the gradient to 0 for segments with more than one
  # 0 as input
  grad = array_ops.where_v2(
      math_ops.greater(num_zeros, 1), array_ops.zeros_like(grad), grad)
  # replace all zeros with ones and compute the segment_prod
  non_zero_data = array_ops.where_v2(is_zero, array_ops.ones_like(data), data)
  non_zero_prod = gen_math_ops.segment_prod(non_zero_data, segment_ids)
  gathered_prod = array_ops.gather(op.outputs[0], segment_ids)
  gathered_non_zero_prod = array_ops.gather(non_zero_prod, segment_ids)
  prod_divided_by_el = gathered_prod / non_zero_data
  # Now fetch the individual results for segments containing 0 and those that
  # don't.
  partial_derivative = array_ops.where_v2(is_zero, gathered_non_zero_prod,
                                          prod_divided_by_el)
  gathered_grad = array_ops.gather(grad, segment_ids)
  return gathered_grad * partial_derivative, None
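# A small worked example of the three docstring cases above (sketch, assuming
# an eager TF 2.x `tf` namespace; illustration only):
#
#   data        = tf.constant([3., 2., 0., 5., 0., 0.])
#   segment_ids = tf.constant([0, 0, 1, 1, 2, 2])
#   with tf.GradientTape() as tape:
#     tape.watch(data)
#     y = tf.math.segment_prod(data, segment_ids)  # ==> [6., 0., 0.]
#   tape.gradient(y, data)  # ==> [2., 3., 5., 0., 0., 0.]
#
# Segment 0 has no zeros (case 1), segment 1 has exactly one zero (case 2: the
# zero entry receives the product of the rest), and segment 2 has two zeros
# (case 3: every gradient in the segment is zero).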
469 """ 470 if zero_clipped_indices is None: 471 zero_clipped_indices = math_ops.maximum(ids, array_ops.zeros_like(ids)) 472 gathered = array_ops.gather(params, zero_clipped_indices) 473 if is_positive is None: 474 is_positive = math_ops.greater_equal(ids, 0) 475 # tf.where(condition, x, y) requires condition to have the same shape as x 476 # and y. 477 is_positive_shape = array_ops.shape(is_positive) 478 broadcastable_shape = array_ops.concat( 479 [is_positive_shape, 480 array_ops.ones([array_ops.rank(gathered) 481 - array_ops.rank(is_positive)], 482 dtype=is_positive_shape.dtype)], 483 axis=0) 484 is_positive = array_ops.reshape(is_positive, broadcastable_shape) 485 is_positive = ( 486 is_positive & array_ops.ones_like(gathered, dtype=dtypes.bool)) 487 # replace gathered params of negative indices with 0 488 zero_slice = array_ops.zeros_like(gathered) 489 return (array_ops.where_v2(is_positive, gathered, 490 zero_slice), zero_clipped_indices, is_positive) 491 492 493def _UnsortedSegmentMinOrMaxGrad(op, grad): 494 """ Gradient for UnsortedSegmentMin and UnsortedSegmentMax. """ 495 # Get the number of selected (minimum or maximum) elements in each segment. 496 gathered_outputs, zero_clipped_indices, is_positive = \ 497 _GatherDropNegatives(op.outputs[0], op.inputs[1]) 498 is_selected = math_ops.equal(op.inputs[0], gathered_outputs) 499 is_selected = math_ops.logical_and(is_selected, is_positive) 500 num_selected = math_ops.unsorted_segment_sum( 501 math_ops.cast(is_selected, grad.dtype), op.inputs[1], op.inputs[2]) 502 # Compute the gradient for each segment. The gradient for the ith segment is 503 # divided evenly among the selected elements in that segment. 504 weighted_grads = math_ops.divide(grad, num_selected) 505 gathered_grads, _, _ = _GatherDropNegatives(weighted_grads, None, 506 zero_clipped_indices, is_positive) 507 zeros = array_ops.zeros_like(gathered_grads) 508 return array_ops.where_v2(is_selected, gathered_grads, zeros), None, None 509 510 511@ops.RegisterGradient("UnsortedSegmentSum") 512def _UnsortedSegmentSumGrad(op, grad): 513 """Gradient for UnsortedSegmentSum.""" 514 return _GatherDropNegatives(grad, op.inputs[1])[0], None, None 515 516 517@ops.RegisterGradient("UnsortedSegmentMax") 518def _UnsortedSegmentMaxGrad(op, grad): 519 """ Gradient for UnsortedSegmentMax. """ 520 return _UnsortedSegmentMinOrMaxGrad(op, grad) 521 522 523@ops.RegisterGradient("UnsortedSegmentMin") 524def _UnsortedSegmentMinGrad(op, grad): 525 """ Gradient for UnsortedSegmentMin. """ 526 return _UnsortedSegmentMinOrMaxGrad(op, grad) 527 528 529@ops.RegisterGradient("UnsortedSegmentProd") 530def _UnsortedSegmentProdGrad(op, grad): 531 """ Gradient for UnsortedSegmentProd. 532 533 The gradient can be expressed for each segment by dividing the segment's 534 product by each element of the segment input tensor, but this approach can't 535 deal with zeros in the input. 536 Unlike reduce_prod we can't use cumsum here as individual segments may have 537 a different number of elements. Therefore we consider three cases: 538 1) A segment input contains no zeros and we can safely divide by the input 539 tensor. 540 2) A segment contains exactly one zero. Then the gradient of each input of 541 the segment is zero except for the 0-input, there the gradient is 542 the product of the remaining segment entries. 543 3) A segment contains at least two zeros. The gradient is zero for all 544 segment inputs. 
545 """ 546 # Note that unsorted_segment_sum will filter out the negative indices, 547 # so we don't need to do a logical_and with is_positive here 548 is_zero = math_ops.equal(op.inputs[0], 0) 549 num_zeros = gen_math_ops.unsorted_segment_sum( 550 math_ops.cast(is_zero, dtype=dtypes.int32), op.inputs[1], op.inputs[2]) 551 # handle case 3 and set the gradient to 0 for segments with more than one 552 # 0 as input 553 grad = array_ops.where_v2( 554 math_ops.greater(num_zeros, 1), array_ops.zeros_like(grad), grad) 555 # replace all zeros with ones and compute the unsorted_segment_prod 556 non_zero_data = array_ops.where_v2(is_zero, array_ops.ones_like(op.inputs[0]), 557 op.inputs[0]) 558 non_zero_prod = gen_math_ops.unsorted_segment_prod(non_zero_data, 559 op.inputs[1], op.inputs[2]) 560 # clip the indices for gather to be positive 561 zero_clipped_indices = math_ops.maximum(op.inputs[1], 562 array_ops.zeros_like(op.inputs[1])) 563 gathered_prod = array_ops.gather(op.outputs[0], zero_clipped_indices) 564 gathered_non_zero_prod = array_ops.gather(non_zero_prod, zero_clipped_indices) 565 prod_divided_by_el = gathered_prod / op.inputs[0] # May contain nan/inf. 566 # Now fetch the individual results for segments containing 0 and those that 567 # don't. is_zero will also fetch results for entries with negative index 568 # but the following gather_drop_negatives sets the corresponding entry in 569 # grad to 0 for these 570 partial_derivative = array_ops.where_v2(is_zero, gathered_non_zero_prod, 571 prod_divided_by_el) 572 gathered_grad = _GatherDropNegatives(grad, op.inputs[1], 573 zero_clipped_indices)[0] 574 return gathered_grad * partial_derivative, None, None 575 576 577@ops.RegisterGradient("Abs") 578def _AbsGrad(op, grad): 579 x = op.inputs[0] 580 return grad * math_ops.sign(x) 581 582 583@ops.RegisterGradient("Neg") 584def _NegGrad(_, grad): 585 """Returns -grad.""" 586 return -grad 587 588 589@ops.RegisterGradient("Inv") 590def _InvGrad(op, grad): 591 """Returns -grad * (1 / x^2).""" 592 y = op.outputs[0] # y = 1 / x 593 return gen_math_ops.reciprocal_grad(y, grad) 594 595 596@ops.RegisterGradient("Reciprocal") 597def _ReciprocalGrad(op, grad): 598 """Returns -grad * (1 / x^2).""" 599 y = op.outputs[0] # y = 1 / x 600 return gen_math_ops.reciprocal_grad(y, grad) 601 602 603@ops.RegisterGradient("InvGrad") 604def _InvGradGrad(op, grad): 605 b = op.inputs[1] 606 # op.output[0]: y = -b * conj(a)^2 607 with ops.control_dependencies([grad]): 608 ca = math_ops.conj(op.inputs[0]) 609 cg = math_ops.conj(grad) 610 return cg * -2.0 * b * ca, gen_math_ops.reciprocal_grad(ca, grad) 611 612 613@ops.RegisterGradient("ReciprocalGrad") 614def _ReciprocalGradGrad(op, grad): 615 b = op.inputs[1] 616 # op.output[0]: y = -b * conj(a)^2 617 with ops.control_dependencies([grad]): 618 ca = math_ops.conj(op.inputs[0]) 619 cg = math_ops.conj(grad) 620 return cg * -2.0 * b * ca, gen_math_ops.reciprocal_grad(ca, grad) 621 622 623@ops.RegisterGradient("Square") 624def _SquareGrad(op, grad): 625 x = op.inputs[0] 626 # Added control dependencies to prevent 2*x from being computed too early. 
  with ops.control_dependencies([grad]):
    x = math_ops.conj(x)
    y = constant_op.constant(2.0, dtype=x.dtype)
    return math_ops.multiply(grad, math_ops.multiply(x, y))


@ops.RegisterGradient("Sqrt")
def _SqrtGrad(op, grad):
  y = op.outputs[0]  # y = x^(1/2)
  return gen_math_ops.sqrt_grad(y, grad)


@ops.RegisterGradient("SqrtGrad")
def _SqrtGradGrad(op, grad):
  a = op.inputs[0]
  y = op.outputs[0]  # y = 0.5 * b / conj(a)
  with ops.control_dependencies([grad]):
    ga = grad / a
    return -math_ops.conj(ga) * y, 0.5 * ga  # pylint: disable=invalid-unary-operand-type


@ops.RegisterGradient("Rsqrt")
def _RsqrtGrad(op, grad):
  """Returns -0.5 * grad * conj(y)^3."""
  y = op.outputs[0]  # y = x^(-1/2)
  return gen_math_ops.rsqrt_grad(y, grad)


@ops.RegisterGradient("RsqrtGrad")
def _RsqrtGradGrad(op, grad):
  """Returns backprop gradient for f(a,b) = -0.5 * b * conj(a)^3."""
  a = op.inputs[0]  # a = x^{-1/2}
  b = op.inputs[1]  # backprop gradient for a
  with ops.control_dependencies([grad]):
    ca = math_ops.conj(a)
    cg = math_ops.conj(grad)
    grad_a = -1.5 * cg * b * math_ops.square(ca)
    grad_b = gen_math_ops.rsqrt_grad(ca, grad)
    return grad_a, grad_b


@ops.RegisterGradient("Exp")
def _ExpGrad(op, grad):
  """Returns grad * exp(x)."""
  y = op.outputs[0]  # y = e^x
  with ops.control_dependencies([grad]):
    y = math_ops.conj(y)
    return grad * y


@ops.RegisterGradient("Expm1")
def _Expm1Grad(op, grad):
  """Returns grad * exp(x)."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    x = math_ops.conj(x)
    y = math_ops.exp(x)
    return grad * y


@ops.RegisterGradient("Log")
def _LogGrad(op, grad):
  """Returns grad * (1/x)."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    x = math_ops.conj(x)
    return grad * math_ops.reciprocal(x)


@ops.RegisterGradient("Log1p")
def _Log1pGrad(op, grad):
  """Returns grad * (1/(1 + x))."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    x = math_ops.conj(x)
    return grad * math_ops.reciprocal(1 + x)


@ops.RegisterGradient("Xlogy")
def _XLogyGrad(op, grad):
  """Returns gradient of xlogy(x, y) with respect to x and y."""
  x = op.inputs[0]
  y = op.inputs[1]
  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
  with ops.control_dependencies([grad]):
    not_zero_x = math_ops.cast(
        math_ops.not_equal(x, math_ops.cast(0., dtype=x.dtype)), dtype=x.dtype)
    partial_x = gen_math_ops.xlogy(not_zero_x, y)
    partial_y = gen_math_ops.xdivy(x, y)
    return (array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx),
            array_ops.reshape(math_ops.reduce_sum(partial_y * grad, ry), sy))


@ops.RegisterGradient("Xlog1py")
def _XLog1pyGrad(op, grad):
  """Returns gradient of xlog1py(x, y) with respect to x and y."""
  x = op.inputs[0]
  y = op.inputs[1]
  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
  with ops.control_dependencies([grad]):
    not_zero_x = math_ops.cast(
        math_ops.not_equal(x, math_ops.cast(0., dtype=x.dtype)), dtype=x.dtype)
    partial_x = gen_math_ops.xlog1py(not_zero_x, y)
    partial_y = gen_math_ops.xdivy(x, y + 1.)
    return (array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx),
            array_ops.reshape(math_ops.reduce_sum(partial_y * grad, ry), sy))
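# The `not_zero_x` masking above gives xlogy/xlog1py a well-defined gradient at
# x = 0 (sketch, assuming an eager TF 2.x `tf` namespace; illustration only):
#
#   x, y = tf.constant(0.), tf.constant(3.)
#   with tf.GradientTape() as tape:
#     tape.watch([x, y])
#     z = tf.math.xlogy(x, y)   # ==> 0.
#   tape.gradient(z, [x, y])    # ==> [0., 0.]
#
# d/dx xlogy(x, y) is log(y) wherever x != 0 but is taken to be 0 at x = 0, and
# d/dy xlogy(x, y) = x / y, which is also 0 here.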
@ops.RegisterGradient("Xdivy")
def _XDivyGrad(op, grad):
  """Returns gradient of xdivy(x, y) with respect to x and y."""
  x = op.inputs[0]
  y = op.inputs[1]
  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
  with ops.control_dependencies([grad]):
    not_zero_x = math_ops.cast(
        math_ops.not_equal(x, math_ops.cast(0., dtype=x.dtype)), dtype=x.dtype)
    partial_x = gen_math_ops.xdivy(not_zero_x, y)
    partial_y = gen_math_ops.xdivy(math_ops.negative(x), y**2)
    return (array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx),
            array_ops.reshape(math_ops.reduce_sum(partial_y * grad, ry), sy))


@ops.RegisterGradient("Sinh")
def _SinhGrad(op, grad):
  """Returns grad * cosh(x)."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    x = math_ops.conj(x)
    return grad * math_ops.cosh(x)


@ops.RegisterGradient("Cosh")
def _CoshGrad(op, grad):
  """Returns grad * sinh(x)."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    x = math_ops.conj(x)
    return grad * math_ops.sinh(x)


@ops.RegisterGradient("Tanh")
def _TanhGrad(op, grad):
  """Returns grad * (1 - tanh(x) * tanh(x))."""
  y = op.outputs[0]  # y = tanh(x)
  with ops.control_dependencies([grad]):
    y = math_ops.conj(y)
    return gen_math_ops.tanh_grad(y, grad)


@ops.RegisterGradient("Asinh")
def _AsinhGrad(op, grad):
  """Returns grad * 1/cosh(y)."""
  y = op.outputs[0]
  with ops.control_dependencies([grad]):
    y = math_ops.conj(y)
    return grad / math_ops.cosh(y)


@ops.RegisterGradient("Acosh")
def _AcoshGrad(op, grad):
  """Returns grad * 1/sinh(y)."""
  y = op.outputs[0]
  with ops.control_dependencies([grad]):
    y = math_ops.conj(y)
    return grad / math_ops.sinh(y)


@ops.RegisterGradient("Atanh")
def _AtanhGrad(op, grad):
  """Returns grad * 1/ (1 - x^2)."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    x = math_ops.conj(x)
    x2 = math_ops.square(x)
    one = constant_op.constant(1, dtype=grad.dtype)
    inv = math_ops.reciprocal(math_ops.subtract(one, x2))
    return grad * inv


@ops.RegisterGradient("TanhGrad")
def _TanhGradGrad(op, grad):
  with ops.control_dependencies([grad]):
    a = math_ops.conj(op.inputs[0])
    b = math_ops.conj(op.inputs[1])
    return grad * -2.0 * b * a, gen_math_ops.tanh_grad(a, grad)


@ops.RegisterGradient("Erf")
def _ErfGrad(op, grad):
  """Returns grad * 2/sqrt(pi) * exp(-x**2)."""
  x = op.inputs[0]
  two_over_root_pi = constant_op.constant(2 / np.sqrt(np.pi), dtype=grad.dtype)
  with ops.control_dependencies([grad]):
    x = math_ops.conj(x)
    return grad * two_over_root_pi * math_ops.exp(-math_ops.square(x))


@ops.RegisterGradient("Erfc")
def _ErfcGrad(op, grad):
  """Returns -grad * 2/sqrt(pi) * exp(-x**2)."""
  x = op.inputs[0]
  minus_two_over_root_pi = constant_op.constant(
      -2 / np.sqrt(np.pi), dtype=grad.dtype)
  with ops.control_dependencies([grad]):
    x = math_ops.conj(x)
    return grad * minus_two_over_root_pi * math_ops.exp(-math_ops.square(x))


@ops.RegisterGradient("Erfinv")
def _ErfinvGrad(op, grad):
  """Returns grad * sqrt(pi) / 2 * exp(erfinv(x)**2)."""
  root_pi_over_two = constant_op.constant(
      np.sqrt(np.pi) / 2, dtype=grad.dtype)
  with ops.control_dependencies([grad]):
    return grad * root_pi_over_two * math_ops.exp(
        math_ops.square(op.outputs[0]))


@ops.RegisterGradient("Ndtri")
def _NdtriGrad(op, grad):
  """Returns grad * sqrt(2 * pi) * exp(ndtri(x)**2 / 2)."""
  root_two_pi = constant_op.constant(np.sqrt(2 * np.pi), dtype=grad.dtype)
  with ops.control_dependencies([grad]):
    return grad * root_two_pi * math_ops.exp(
        math_ops.square(op.outputs[0]) / 2.)


@ops.RegisterGradient("Lgamma")
def _LgammaGrad(op, grad):
  """Returns grad * digamma(x)."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    x = math_ops.conj(x)
    return grad * math_ops.digamma(x)


@ops.RegisterGradient("Digamma")
def _DigammaGrad(op, grad):
  """Compute gradient of the digamma function with respect to its argument."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    x = math_ops.conj(x)
    partial_x = math_ops.polygamma(array_ops.constant(1, dtype=x.dtype), x)
    return grad * partial_x


@ops.RegisterGradient("Dawsn")
def _DawsnGrad(op, grad):
  """Compute gradient of dawsn(x) with respect to its argument."""
  x = op.inputs[0]
  y = op.outputs[0]
  with ops.control_dependencies([grad]):
    return grad * (1. - 2 * x * y)


@ops.RegisterGradient("Expint")
def _ExpintGrad(op, grad):
  """Compute gradient of expint(x) with respect to its argument."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    return grad * math_ops.exp(x) / x


@ops.RegisterGradient("FresnelCos")
def _FresnelCosGrad(op, grad):
  """Compute gradient of fresnel_cos(x) with respect to its argument."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    return grad * math_ops.cos((np.pi / 2.) * math_ops.square(x))


@ops.RegisterGradient("FresnelSin")
def _FresnelSinGrad(op, grad):
  """Compute gradient of fresnel_sin(x) with respect to its argument."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    return grad * math_ops.sin((np.pi / 2.) * math_ops.square(x))


@ops.RegisterGradient("Spence")
def _SpenceGrad(op, grad):
  """Compute gradient of spence(x) with respect to its argument."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    partial_x = math_ops.log(x) / (1 - x)
    partial_x = array_ops.where(
        math_ops.equal(x, 1.), -array_ops.ones_like(x), partial_x)  # pylint: disable=invalid-unary-operand-type
    return grad * partial_x


@ops.RegisterGradient("BesselI0")
def _BesselI0Grad(op, grad):
  """Compute gradient of bessel_i0(x) with respect to its argument."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    partial_x = special_math_ops.bessel_i1(x)
    return grad * partial_x


@ops.RegisterGradient("BesselI0e")
def _BesselI0eGrad(op, grad):
  """Compute gradient of bessel_i0e(x) with respect to its argument."""
  x = op.inputs[0]
  y = op.outputs[0]
  with ops.control_dependencies([grad]):
    partial_x = (special_math_ops.bessel_i1e(x) - math_ops.sign(x) * y)
    return grad * partial_x


@ops.RegisterGradient("BesselI1")
def _BesselI1Grad(op, grad):
  """Compute gradient of bessel_i1(x) with respect to its argument."""
  x = op.inputs[0]
  y = op.outputs[0]
  with ops.control_dependencies([grad]):
    # For x = 0, the correct gradient is 0.5.
    # However, the main branch gives NaN because of the division by x, so
    # we impute the gradient manually.
    # An alternative solution is to express the gradient via bessel_i0 and
    # bessel_i2, but the latter is not yet implemented in Eigen.
    dy_dx = array_ops.where_v2(
        math_ops.equal(x, 0.), math_ops.cast(0.5, x.dtype),
        special_math_ops.bessel_i0(x) - math_ops.div(y, x))
    return grad * dy_dx


@ops.RegisterGradient("BesselI1e")
def _BesselI1eGrad(op, grad):
  """Compute gradient of bessel_i1e(x) with respect to its argument."""
  x = op.inputs[0]
  y = op.outputs[0]
  with ops.control_dependencies([grad]):
    # For x = 0, the correct gradient is 0.5.
    # However, the main branch gives NaN because of the division by x, so
    # we impute the gradient manually.
    # An alternative solution is to express the gradient via bessel_i0e and
    # bessel_i2e, but the latter is not yet implemented in Eigen.
    dy_dx = array_ops.where_v2(
        math_ops.equal(x, 0.), math_ops.cast(0.5, x.dtype),
        special_math_ops.bessel_i0e(x) - y *
        (math_ops.sign(x) + math_ops.reciprocal(x)))
    return grad * dy_dx


@ops.RegisterGradient("BesselK0")
def _BesselK0Grad(op, grad):
  """Compute gradient of bessel_k0(x) with respect to its argument."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    partial_x = -special_math_ops.bessel_k1(x)
    return grad * partial_x


@ops.RegisterGradient("BesselK0e")
def _BesselK0eGrad(op, grad):
  """Compute gradient of bessel_k0e(x) with respect to its argument."""
  x = op.inputs[0]
  y = op.outputs[0]
  with ops.control_dependencies([grad]):
    partial_x = (y - special_math_ops.bessel_k1e(x))
    return grad * partial_x


@ops.RegisterGradient("BesselK1")
def _BesselK1Grad(op, grad):
  """Compute gradient of bessel_k1(x) with respect to its argument."""
  x = op.inputs[0]
  y = op.outputs[0]
  with ops.control_dependencies([grad]):
    # At 0., this is NaN which is fine since the derivative is undefined
    # at 0.
    partial_x = -special_math_ops.bessel_k0(x) - math_ops.div(y, x)
    return grad * partial_x


@ops.RegisterGradient("BesselK1e")
def _BesselK1eGrad(op, grad):
  """Compute gradient of bessel_k1e(x) with respect to its argument."""
  x = op.inputs[0]
  y = op.outputs[0]
  with ops.control_dependencies([grad]):
    # At 0., this is NaN which is fine since the derivative is undefined
    # at 0.
    partial_x = (
        y * (1. - math_ops.reciprocal(x)) - special_math_ops.bessel_k0e(x))
    return grad * partial_x


@ops.RegisterGradient("BesselJ0")
def _BesselJ0Grad(op, grad):
  """Compute gradient of bessel_j0(x) with respect to its argument."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    partial_x = -special_math_ops.bessel_j1(x)
    return grad * partial_x


@ops.RegisterGradient("BesselJ1")
def _BesselJ1Grad(op, grad):
  """Compute gradient of bessel_j1(x) with respect to its argument."""
  x = op.inputs[0]
  y = op.outputs[0]
  with ops.control_dependencies([grad]):
    # For x = 0, the correct gradient is 0.5.
    # However, the main branch gives NaN because of the division by x, so
    # we impute the gradient manually.
    # An alternative solution is to express the gradient via bessel_j0 and
    # bessel_j2, but the latter is not yet implemented in Eigen.
    dy_dx = array_ops.where_v2(
        math_ops.equal(x, 0.), math_ops.cast(0.5, x.dtype),
        special_math_ops.bessel_j0(x) - math_ops.div(y, x))
    return grad * dy_dx
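# The `where_v2` imputation used by the Bessel gradients above avoids NaN at
# the removable singularity x = 0 (sketch, assuming TF 2.4+ with the
# `tf.math.special` namespace; illustration only):
#
#   x = tf.constant(0.)
#   with tf.GradientTape() as tape:
#     tape.watch(x)
#     y = tf.math.special.bessel_j1(x)
#   tape.gradient(y, x)  # ==> 0.5
#
# The j0(x) - j1(x)/x branch alone would evaluate to 1.0 - 0/0 = NaN at x = 0.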
@ops.RegisterGradient("BesselY0")
def _BesselY0Grad(op, grad):
  """Compute gradient of bessel_y0(x) with respect to its argument."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    partial_x = -special_math_ops.bessel_y1(x)
    return grad * partial_x


@ops.RegisterGradient("BesselY1")
def _BesselY1Grad(op, grad):
  """Compute gradient of bessel_y1(x) with respect to its argument."""
  x = op.inputs[0]
  y = op.outputs[0]
  with ops.control_dependencies([grad]):
    # At 0., this is NaN which is fine since the derivative is undefined
    # at 0.
    partial_x = special_math_ops.bessel_y0(x) - math_ops.div(y, x)
    return grad * partial_x


@ops.RegisterGradient("Igamma")
def _IgammaGrad(op, grad):
  """Returns gradient of igamma(a, x) with respect to a and x."""
  a = op.inputs[0]
  x = op.inputs[1]
  sa = array_ops.shape(a)
  sx = array_ops.shape(x)
  ra, rx = gen_array_ops.broadcast_gradient_args(sa, sx)

  with ops.control_dependencies([grad]):
    partial_a = gen_math_ops.igamma_grad_a(a, x)
    # Perform operations in log space before summing, because Gamma(a)
    # and Gamma'(a) can grow large.
    partial_x = math_ops.exp(-x + (a - 1) * math_ops.log(x) -
                             math_ops.lgamma(a))
    return (array_ops.reshape(math_ops.reduce_sum(partial_a * grad, ra), sa),
            array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx))


@ops.RegisterGradient("Igammac")
def _IgammacGrad(op, grad):
  """Returns gradient of igammac(a, x) = 1 - igamma(a, x) w.r.t. a and x."""
  igamma_grad_a, igamma_grad_x = _IgammaGrad(op, grad)
  return (-igamma_grad_a, -igamma_grad_x)


@ops.RegisterGradient("Betainc")
def _BetaincGrad(op, grad):
  """Returns gradient of betainc(a, b, x) with respect to x."""
  # TODO(ebrevdo): Perhaps add the derivative w.r.t. a, b
  a, b, x = op.inputs

  # two cases: x is a scalar and a/b are same-shaped tensors, or vice
  # versa; so it's sufficient to check against shape(a).
  sa = array_ops.shape(a)
  sx = array_ops.shape(x)
  _, rx = gen_array_ops.broadcast_gradient_args(sa, sx)

  # Perform operations in log space before summing, because terms
  # can grow large.
  log_beta = (
      gen_math_ops.lgamma(a) + gen_math_ops.lgamma(b) -
      gen_math_ops.lgamma(a + b))
  # We use xlog1py and xlogy since the derivatives should tend to
  # zero in one of the tails when a is 1. or b is 1.
  partial_x = math_ops.exp(math_ops.xlog1py(b - 1, -x) +
                           math_ops.xlogy(a - 1, x) - log_beta)

  return (
      None,  # da
      None,  # db
      array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx))
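# In closed form, the x-derivative assembled above is
#
#     d/dx betainc(a, b, x) = x**(a - 1) * (1 - x)**(b - 1) / B(a, b),
#
# where B(a, b) = Gamma(a) * Gamma(b) / Gamma(a + b). The code evaluates
# exp((a - 1) * log(x) + (b - 1) * log1p(-x) - log B(a, b)) so that the large
# Gamma terms cancel in log space before the final exponentiation.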
@ops.RegisterGradient("Zeta")
def _ZetaGrad(op, grad):
  """Returns gradient of zeta(x, q) with respect to x and q."""
  # TODO(tillahoffmann): Add derivative with respect to x
  x = op.inputs[0]
  q = op.inputs[1]
  # Broadcast gradients
  sx = array_ops.shape(x)
  sq = array_ops.shape(q)
  unused_rx, rq = gen_array_ops.broadcast_gradient_args(sx, sq)
  # Evaluate gradient
  with ops.control_dependencies([grad]):
    x = math_ops.conj(x)
    q = math_ops.conj(q)
    partial_q = -x * math_ops.zeta(x + 1, q)  # pylint: disable=invalid-unary-operand-type
    return (None,
            array_ops.reshape(math_ops.reduce_sum(partial_q * grad, rq), sq))


@ops.RegisterGradient("Polygamma")
def _PolygammaGrad(op, grad):
  """Returns gradient of psi(n, x) with respect to n and x."""
  # TODO(tillahoffmann): Add derivative with respect to n
  n = op.inputs[0]
  x = op.inputs[1]
  # Broadcast gradients
  sn = array_ops.shape(n)
  sx = array_ops.shape(x)
  unused_rn, rx = gen_array_ops.broadcast_gradient_args(sn, sx)
  # Evaluate gradient
  with ops.control_dependencies([grad]):
    n = math_ops.conj(n)
    x = math_ops.conj(x)
    partial_x = math_ops.polygamma(n + 1, x)
    return (None,
            array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx))


@ops.RegisterGradient("Sigmoid")
def _SigmoidGrad(op, grad):
  """Returns grad * sigmoid(x) * (1 - sigmoid(x))."""
  y = op.outputs[0]  # y = sigmoid(x)
  with ops.control_dependencies([grad]):
    y = math_ops.conj(y)
    return gen_math_ops.sigmoid_grad(y, grad)


@ops.RegisterGradient("SigmoidGrad")
def _SigmoidGradGrad(op, grad):
  with ops.control_dependencies([grad]):
    a = math_ops.conj(op.inputs[0])
    b = math_ops.conj(op.inputs[1])
    gb = grad * b
    return gb - 2.0 * gb * a, gen_math_ops.sigmoid_grad(a, grad)


@ops.RegisterGradient("Sign")
def _SignGrad(op, _):
  """Returns 0."""
  x = op.inputs[0]
  return array_ops.zeros_like(x)


@ops.RegisterGradient("Sin")
def _SinGrad(op, grad):
  """Returns grad * cos(x)."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    x = math_ops.conj(x)
    return grad * math_ops.cos(x)


@ops.RegisterGradient("Cos")
def _CosGrad(op, grad):
  """Returns grad * -sin(x)."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    x = math_ops.conj(x)
    return -grad * math_ops.sin(x)


@ops.RegisterGradient("Tan")
def _TanGrad(op, grad):
  """Returns grad * 1/cos^2(x)."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    x = math_ops.conj(x)
    secx = math_ops.reciprocal(math_ops.cos(x))
    secx2 = math_ops.square(secx)
    return secx2 * grad


@ops.RegisterGradient("Asin")
def _AsinGrad(op, grad):
  """Returns grad * 1/sqrt(1-x^2)."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    x = math_ops.conj(x)
    x2 = math_ops.square(x)
    one = constant_op.constant(1, dtype=grad.dtype)
    den = math_ops.sqrt(math_ops.subtract(one, x2))
    inv = math_ops.reciprocal(den)
    return grad * inv


@ops.RegisterGradient("Acos")
def _AcosGrad(op, grad):
  """Returns grad * -1/sqrt(1-x^2)."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    x = math_ops.conj(x)
    x2 = math_ops.square(x)
    one = constant_op.constant(1, dtype=grad.dtype)
    den = math_ops.sqrt(math_ops.subtract(one, x2))
    inv = math_ops.reciprocal(den)
    return -grad * inv


@ops.RegisterGradient("Atan")
def _AtanGrad(op, grad):
  """Returns grad * 1/ (1 + x^2)."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    x = math_ops.conj(x)
    x2 = math_ops.square(x)
    one = constant_op.constant(1, dtype=grad.dtype)
    inv = math_ops.reciprocal(math_ops.add(one, x2))
    return grad * inv


@ops.RegisterGradient("Atan2")
def _Atan2Grad(op, grad):
  """Returns grad * x / (x^2 + y^2), grad * -y / (x^2 + y^2)."""
  y = op.inputs[0]
  x = op.inputs[1]
  with ops.control_dependencies([grad]):
    grad_inv = grad / (math_ops.square(x) + math_ops.square(y))
    return x * grad_inv, -y * grad_inv


@ops.RegisterGradient("AddN")
def _AddNGrad(op, grad):
  """Copies the gradient to all inputs."""
  # Not broadcasting.
  return [grad] * len(op.inputs)


def _ShapesFullySpecifiedAndEqual(x, y, grad):
  # pylint: disable=protected-access
  x_shape = x._shape_tuple()
  y_shape = y._shape_tuple()
  grad_shape = grad._shape_tuple()
  # pylint: enable=protected-access
  return (x_shape == y_shape and x_shape == grad_shape and
          x_shape is not None and None not in x_shape)


@ops.RegisterGradient("Add")
@ops.RegisterGradient("AddV2")
def _AddGrad(op, grad):
  """Gradient for Add."""
  y = op.inputs[1]
  skip_input_indices = None
  try:
    skip_input_indices = op.skip_input_indices
    if skip_input_indices is not None and 1 in skip_input_indices and _IsScalar(
        y):
      return grad, None
  except AttributeError:
    # No gradient skipping, so do the full gradient computation
    pass
  x = op.inputs[0]
  if (isinstance(grad, ops.Tensor) and
      _ShapesFullySpecifiedAndEqual(x, y, grad)):
    return grad, grad
  (sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = (
      SmartBroadcastGradientArgs(x, y, grad))
  if skip_input_indices is not None and 0 in skip_input_indices:
    gx = None
  elif not must_reduce_x:
    gx = grad
  else:
    gx = array_ops.reshape(math_ops.reduce_sum(grad, rx), sx)
  if skip_input_indices is not None and 1 in skip_input_indices:
    gy = None
  elif not must_reduce_y:
    gy = grad
  else:
    gy = array_ops.reshape(math_ops.reduce_sum(grad, ry), sy)
  return (gx, gy)


@ops.RegisterGradient("Sub")
def _SubGrad(op, grad):
  """Gradient for Sub."""
  y = op.inputs[1]
  skip_input_indices = None
  try:
    skip_input_indices = op.skip_input_indices
    if skip_input_indices is not None and 1 in skip_input_indices and _IsScalar(
        y):
      return grad, None
  except AttributeError:
    # No gradient skipping, so do the full gradient computation
    pass
  x = op.inputs[0]
  if (isinstance(grad, ops.Tensor) and
      _ShapesFullySpecifiedAndEqual(x, y, grad)):
    return grad, -grad
  (sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = (
      SmartBroadcastGradientArgs(x, y, grad))
  if skip_input_indices is not None and 0 in skip_input_indices:
    gx = None
  elif not must_reduce_x:
    gx = grad
  else:
    gx = array_ops.reshape(math_ops.reduce_sum(grad, rx), sx)
  if skip_input_indices is not None and 1 in skip_input_indices:
    gy = None
  elif not must_reduce_y:
    gy = -grad
  else:
    gy = array_ops.reshape(math_ops.reduce_sum(-grad, ry), sy)
  return (gx, gy)
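# How the broadcast reduction above plays out on concrete shapes (sketch,
# assuming an eager TF 2.x `tf` namespace; illustration only):
#
#   x = tf.ones([2, 3])
#   y = tf.ones([3])            # broadcast along axis 0
#   with tf.GradientTape() as tape:
#     tape.watch([x, y])
#     z = x + y                 # shape [2, 3]
#   gx, gy = tape.gradient(z, [x, y])
#   gx.shape, gy.shape          # ==> ([2, 3], [3]); gy == [2., 2., 2.]
#
# broadcast_gradient_args([2, 3], [3]) yields rx = [] and ry = [0], so the
# gradient for y is reduce_sum'ed over axis 0 and reshaped back to y's shape.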
@ops.RegisterGradient("Mul")
def _MulGrad(op, grad):
  """The gradient of scalar multiplication."""
  y = op.inputs[1]
  skip_input_indices = None
  try:
    skip_input_indices = op.skip_input_indices
    if skip_input_indices is not None and 1 in skip_input_indices and _IsScalar(
        y):
      return gen_math_ops.mul(grad, math_ops.conj(y)), None
  except AttributeError:
    # No gradient skipping, so do the full gradient computation
    pass
  x = op.inputs[0]
  if (isinstance(grad, ops.Tensor) and
      _ShapesFullySpecifiedAndEqual(x, y, grad) and
      grad.dtype in (dtypes.int32, dtypes.float32)):
    return gen_math_ops.mul(grad, y), gen_math_ops.mul(grad, x)
  assert x.dtype.base_dtype == y.dtype.base_dtype, (x.dtype, " vs. ", y.dtype)

  (sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = (
      SmartBroadcastGradientArgs(x, y, grad))
  x = math_ops.conj(x)
  y = math_ops.conj(y)
  if skip_input_indices is not None and 0 in skip_input_indices:
    gx = None
  elif not must_reduce_x:
    gx = gen_math_ops.mul(grad, y)
  else:
    gx = array_ops.reshape(
        math_ops.reduce_sum(gen_math_ops.mul(grad, y), rx), sx)
  if skip_input_indices is not None and 1 in skip_input_indices:
    gy = None
  elif not must_reduce_y:
    gy = gen_math_ops.mul(x, grad)
  else:
    gy = array_ops.reshape(
        math_ops.reduce_sum(gen_math_ops.mul(x, grad), ry), sy)
  return (gx, gy)


@ops.RegisterGradient("MulNoNan")
def _MulNoNanGrad(op, grad):
  """The gradient of scalar multiplication with NaN-suppression."""
  x = op.inputs[0]
  y = op.inputs[1]
  if (isinstance(grad, ops.Tensor) and
      _ShapesFullySpecifiedAndEqual(x, y, grad)):
    return gen_math_ops.mul_no_nan(grad, y), gen_math_ops.mul_no_nan(x, grad)
  assert x.dtype.base_dtype == y.dtype.base_dtype, (x.dtype, " vs. ", y.dtype)
", y.dtype) 1398 sx = array_ops.shape(x) 1399 sy = array_ops.shape(y) 1400 rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) 1401 return (array_ops.reshape( 1402 math_ops.reduce_sum(gen_math_ops.mul_no_nan(grad, y), rx), sx), 1403 array_ops.reshape( 1404 math_ops.reduce_sum(gen_math_ops.mul_no_nan(x, grad), ry), sy)) 1405 1406 1407@ops.RegisterGradient("Div") 1408def _DivGrad(op, grad): 1409 """The gradient for the Div operator.""" 1410 x = op.inputs[0] 1411 y = op.inputs[1] 1412 sx = array_ops.shape(x) 1413 sy = array_ops.shape(y) 1414 rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) 1415 x = math_ops.conj(x) 1416 y = math_ops.conj(y) 1417 # pylint: disable=invalid-unary-operand-type 1418 return ( 1419 array_ops.reshape(math_ops.reduce_sum(math_ops.divide(grad, y), rx), sx), 1420 array_ops.reshape( 1421 math_ops.reduce_sum(grad * math_ops.divide(math_ops.divide(-x, y), y), 1422 ry), sy)) 1423 1424 1425@ops.RegisterGradient("FloorDiv") 1426def _FloorDivGrad(_, unused_grad): 1427 """The gradient for the FloorDiv operator.""" 1428 return None, None 1429 1430 1431@ops.RegisterGradient("FloorMod") 1432def _FloorModGrad(op, grad): 1433 """Returns grad * (1, -floor(x/y)).""" 1434 x = math_ops.conj(op.inputs[0]) 1435 y = math_ops.conj(op.inputs[1]) 1436 1437 sx = array_ops.shape(x) 1438 sy = array_ops.shape(y) 1439 rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) 1440 floor_xy = math_ops.floor_div(x, y) 1441 gx = array_ops.reshape(math_ops.reduce_sum(grad, rx), sx) 1442 gy = array_ops.reshape( 1443 math_ops.reduce_sum(grad * math_ops.negative(floor_xy), ry), sy) 1444 return gx, gy 1445 1446 1447@ops.RegisterGradient("TruncateDiv") 1448def _TruncateDivGrad(_, unused_grad): 1449 return None, None 1450 1451 1452@ops.RegisterGradient("RealDiv") 1453def _RealDivGrad(op, grad): 1454 """RealDiv op gradient.""" 1455 x = op.inputs[0] 1456 y = op.inputs[1] 1457 sx = array_ops.shape(x) 1458 sy = array_ops.shape(y) 1459 rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) 1460 x = math_ops.conj(x) 1461 y = math_ops.conj(y) 1462 return (array_ops.reshape( 1463 math_ops.reduce_sum(math_ops.realdiv(grad, y), rx), sx), 1464 array_ops.reshape( 1465 math_ops.reduce_sum( 1466 grad * math_ops.realdiv(math_ops.realdiv(-x, y), y), ry), sy)) # pylint: disable=invalid-unary-operand-type 1467 1468 1469@ops.RegisterGradient("DivNoNan") 1470def _DivNoNanGrad(op, grad): 1471 """DivNoNan op gradient.""" 1472 x = op.inputs[0] 1473 y = op.inputs[1] 1474 sx = array_ops.shape(x) 1475 sy = array_ops.shape(y) 1476 rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy) 1477 x = math_ops.conj(x) 1478 y = math_ops.conj(y) 1479 return ( 1480 array_ops.reshape( 1481 math_ops.reduce_sum(math_ops.div_no_nan(grad, y), rx), sx), 1482 array_ops.reshape( 1483 math_ops.reduce_sum( 1484 grad * math_ops.div_no_nan(math_ops.div_no_nan(-x, y), y), # pylint: disable=invalid-unary-operand-type 1485 ry), 1486 sy)) 1487 1488 1489@ops.RegisterGradient("Pow") 1490def _PowGrad(op, grad): 1491 """Returns grad * (y*x^(y-1), z*log(x)).""" 1492 x = op.inputs[0] 1493 y = op.inputs[1] 1494 skip_input_indices = None 1495 try: 1496 skip_input_indices = op.skip_input_indices 1497 # TODO(mrry): If `y` is a constant, we can combine `tf.sub()` and the 1498 # constant `1` into a single constant op. 
@ops.RegisterGradient("Pow")
def _PowGrad(op, grad):
  """Returns grad * (y*x^(y-1), z*log(x))."""
  x = op.inputs[0]
  y = op.inputs[1]
  skip_input_indices = None
  try:
    skip_input_indices = op.skip_input_indices
    # TODO(mrry): If `y` is a constant, we can combine `tf.sub()` and the
    # constant `1` into a single constant op.
    if skip_input_indices is not None and 1 in skip_input_indices and _IsScalar(
        y):
      x = math_ops.conj(x)
      y = math_ops.conj(y)
      return grad * y * math_ops.pow(x, y - 1), None

  except AttributeError:
    # No gradient skipping, so do the full gradient computation
    pass

  (sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = (
      SmartBroadcastGradientArgs(x, y, grad))
  x = math_ops.conj(x)
  y = math_ops.conj(y)

  if skip_input_indices is None or 0 not in skip_input_indices:
    gx = grad * y * math_ops.pow(x, y - 1)
    if must_reduce_x:
      gx = array_ops.reshape(math_ops.reduce_sum(gx, rx), sx)
  else:
    gx = None

  if skip_input_indices is None or 1 not in skip_input_indices:
    z = math_ops.conj(op.outputs[0])

    # Avoid false singularity at x = 0
    if x.dtype.is_complex:
      # real(x) < 0 is fine for the complex case
      mask = math_ops.not_equal(x, 0)
    else:
      # There's no sensible real value to return if x < 0, so return 0
      mask = x > 0
    safe_x = array_ops.where(mask, x, array_ops.ones_like(x))
    log_x = array_ops.where(mask, math_ops.log(safe_x), array_ops.zeros_like(x))
    gy = grad * z * log_x
    if must_reduce_y:
      gy = array_ops.reshape(math_ops.reduce_sum(gy, ry), sy)
  else:
    gy = None

  return gx, gy


def _MaximumMinimumGradInputOnly(op, grad, selector_op):
  x = op.inputs[0]
  y = op.inputs[1]
  zeros = array_ops.zeros_like(grad)
  xmask = selector_op(x, y)
  xgrad = array_ops.where_v2(xmask, grad, zeros)
  ygrad = None  # Return None for ygrad since the config allows that.
  return (xgrad, ygrad)


def _MaximumMinimumGrad(op, grad, selector_op):
  """Factor out the code for the gradient of Maximum or Minimum."""
  y = op.inputs[1]
  skip_input_indices = None
  try:
    skip_input_indices = op.skip_input_indices
    if skip_input_indices is not None and 1 in skip_input_indices and _IsScalar(
        y):
      # When we want to get gradients for the first input only, and the second
      # input tensor is a scalar, we can do a much simpler calculation
      return _MaximumMinimumGradInputOnly(op, grad, selector_op)
  except AttributeError:
    # No gradient skipping, so do the full gradient computation
    pass
  x = op.inputs[0]
  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  zeros = array_ops.zeros_like(grad)
  xmask = selector_op(x, y)
  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
  if skip_input_indices is not None and 0 in skip_input_indices:
    gx = None
  else:
    xgrad = array_ops.where_v2(xmask, grad, zeros)
    gx = array_ops.reshape(math_ops.reduce_sum(xgrad, rx), sx)

  if skip_input_indices is not None and 1 in skip_input_indices:
    gy = None
  else:
    ygrad = array_ops.where_v2(xmask, zeros, grad)
    gy = array_ops.reshape(math_ops.reduce_sum(ygrad, ry), sy)

  return (gx, gy)


@ops.RegisterGradient("Maximum")
def _MaximumGrad(op, grad):
  """Returns grad*(x >= y, x < y) with type of grad."""
  return _MaximumMinimumGrad(op, grad, math_ops.greater_equal)


@ops.RegisterGradient("Minimum")
def _MinimumGrad(op, grad):
  """Returns grad*(x <= y, x > y) with type of grad."""
  return _MaximumMinimumGrad(op, grad, math_ops.less_equal)


@ops.RegisterGradient("SquaredDifference")
def _SquaredDifferenceGrad(op, grad):
  """Returns the gradient for (x-y)^2."""
(x-y)^2.""" 1602 x = op.inputs[0] 1603 y = op.inputs[1] 1604 skip_input_indices = None 1605 try: 1606 skip_input_indices = op.skip_input_indices 1607 except AttributeError: 1608 # No gradient skipping, so do the full gradient computation 1609 pass 1610 1611 with ops.control_dependencies([grad]): 1612 # The parens ensure that if grad is IndexedSlices, it'll get multiplied by 1613 # Tensor (not a number like 2.0) which causes it to convert to Tensor. 1614 x_grad = math_ops.scalar_mul(2.0, grad) * (x - y) 1615 1616 if (isinstance(grad, ops.Tensor) and 1617 _ShapesFullySpecifiedAndEqual(x, y, grad)): 1618 return x_grad, -x_grad 1619 1620 (sx, rx, must_reduce_x), (sy, ry, must_reduce_y) = ( 1621 SmartBroadcastGradientArgs(x, y, grad)) 1622 1623 if skip_input_indices is not None and 0 in skip_input_indices: 1624 gx = None 1625 elif must_reduce_x: 1626 gx = array_ops.reshape(math_ops.reduce_sum(x_grad, rx), sx) 1627 else: 1628 gx = x_grad 1629 1630 if skip_input_indices is not None and 1 in skip_input_indices: 1631 gy = None 1632 elif must_reduce_y: 1633 gy = -array_ops.reshape(math_ops.reduce_sum(x_grad, ry), sy) 1634 else: 1635 gy = -x_grad 1636 return (gx, gy) 1637 1638 1639# Logical operations have no gradients. 1640ops.NotDifferentiable("Less") 1641ops.NotDifferentiable("LessEqual") 1642ops.NotDifferentiable("Greater") 1643ops.NotDifferentiable("GreaterEqual") 1644ops.NotDifferentiable("Equal") 1645ops.NotDifferentiable("ApproximateEqual") 1646ops.NotDifferentiable("NotEqual") 1647ops.NotDifferentiable("LogicalAnd") 1648ops.NotDifferentiable("LogicalOr") 1649ops.NotDifferentiable("LogicalNot") 1650 1651 1652@ops.RegisterGradient("Select") 1653def _SelectGrad(op, grad): 1654 c = op.inputs[0] 1655 x = op.inputs[1] 1656 zeros = array_ops.zeros_like(x) 1657 return (None, array_ops.where(c, grad, zeros), array_ops.where( 1658 c, zeros, grad)) 1659 1660 1661@ops.RegisterGradient("SelectV2") 1662def _SelectGradV2(op, grad): 1663 c = op.inputs[0] 1664 x = op.inputs[1] 1665 y = op.inputs[2] 1666 zeros = array_ops.zeros([], dtype=grad.dtype.base_dtype) 1667 gx = array_ops.where_v2(c, grad, zeros) 1668 x_shape = array_ops.shape(x) 1669 output_shape = array_ops.shape(op.outputs[0]) 1670 # Reduce away broadcasted leading dims. 1671 reduce_x, _ = gen_array_ops.broadcast_gradient_args(x_shape, output_shape) 1672 gx = math_ops.reduce_sum(gx, keepdims=True, axis=reduce_x) 1673 gx = array_ops.reshape(gx, x_shape) 1674 1675 gy = array_ops.where_v2(c, zeros, grad) 1676 y_shape = array_ops.shape(y) 1677 # Reduce away broadcasted leading dims. 
@ops.RegisterGradient("SelectV2")
def _SelectGradV2(op, grad):
  c = op.inputs[0]
  x = op.inputs[1]
  y = op.inputs[2]
  zeros = array_ops.zeros([], dtype=grad.dtype.base_dtype)
  gx = array_ops.where_v2(c, grad, zeros)
  x_shape = array_ops.shape(x)
  output_shape = array_ops.shape(op.outputs[0])
  # Reduce away broadcasted leading dims.
  reduce_x, _ = gen_array_ops.broadcast_gradient_args(x_shape, output_shape)
  gx = math_ops.reduce_sum(gx, keepdims=True, axis=reduce_x)
  gx = array_ops.reshape(gx, x_shape)

  gy = array_ops.where_v2(c, zeros, grad)
  y_shape = array_ops.shape(y)
  # Reduce away broadcasted leading dims.
  reduce_y, _ = gen_array_ops.broadcast_gradient_args(y_shape, output_shape)
  gy = math_ops.reduce_sum(gy, keepdims=True, axis=reduce_y)
  gy = array_ops.reshape(gy, y_shape)

  return (None, gx, gy)


def _MatMulGradAgainstFirstOnly(op, grad):
  """Gradient for MatMul, only for the first input."""
  t_a = op.get_attr("transpose_a")
  t_b = op.get_attr("transpose_b")
  b = math_ops.conj(op.inputs[1])
  if not t_a and not t_b:
    grad_a = gen_math_ops.mat_mul(grad, b, transpose_b=True)
  elif not t_a and t_b:
    grad_a = gen_math_ops.mat_mul(grad, b)
  elif t_a and not t_b:
    grad_a = gen_math_ops.mat_mul(b, grad, transpose_b=True)
  elif t_a and t_b:
    grad_a = gen_math_ops.mat_mul(b, grad, transpose_a=True, transpose_b=True)
  return grad_a, None


def _MatMulGradAgainstSecondOnly(op, grad):
  """Gradient for MatMul, only for the second input."""
  t_a = op.get_attr("transpose_a")
  t_b = op.get_attr("transpose_b")
  a = math_ops.conj(op.inputs[0])
  if not t_a and not t_b:
    grad_b = gen_math_ops.mat_mul(a, grad, transpose_a=True)
  elif not t_a and t_b:
    grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True)
  elif t_a and not t_b:
    grad_b = gen_math_ops.mat_mul(a, grad)
  elif t_a and t_b:
    grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True, transpose_b=True)
  return None, grad_b


@ops.RegisterGradient("MatMul")
def _MatMulGrad(op, grad):
  """Gradient for MatMul."""
  try:
    skip_input_indices = op.skip_input_indices
    if skip_input_indices is not None:
      if 1 in skip_input_indices:
        return _MatMulGradAgainstFirstOnly(op, grad)
      elif 0 in skip_input_indices:
        return _MatMulGradAgainstSecondOnly(op, grad)
  except AttributeError:
    # No gradient skipping, so do the full gradient computation
    pass

  t_a = op.get_attr("transpose_a")
  t_b = op.get_attr("transpose_b")
  a = math_ops.conj(op.inputs[0])
  b = math_ops.conj(op.inputs[1])
  if not t_a and not t_b:
    grad_a = gen_math_ops.mat_mul(grad, b, transpose_b=True)
    grad_b = gen_math_ops.mat_mul(a, grad, transpose_a=True)
  elif not t_a and t_b:
    grad_a = gen_math_ops.mat_mul(grad, b)
    grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True)
  elif t_a and not t_b:
    grad_a = gen_math_ops.mat_mul(b, grad, transpose_b=True)
    grad_b = gen_math_ops.mat_mul(a, grad)
  elif t_a and t_b:
    grad_a = gen_math_ops.mat_mul(b, grad, transpose_a=True, transpose_b=True)
    grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True, transpose_b=True)
  return grad_a, grad_b

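
# Illustrative note: with neither input transposed, C = A @ B has gradients
# dA = dC @ B^T and dB = A^T @ dC, which is what the first branch of
# _MatMulGrad above computes. A minimal sketch, assuming
# `import tensorflow as tf` in user code:
#
#   a = tf.constant([[1.0, 2.0]])    # shape [1, 2]
#   b = tf.constant([[3.0], [4.0]])  # shape [2, 1]
#   with tf.GradientTape(persistent=True) as tape:
#     tape.watch([a, b])
#     c = tf.matmul(a, b)            # shape [1, 1]
#   tape.gradient(c, a)  # dC @ B^T = [[3.0, 4.0]]
#   tape.gradient(c, b)  # A^T @ dC = [[1.0], [2.0]]
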
@ops.RegisterGradient("SparseMatMul")
def _SparseMatMulGrad(op, grad):
  """Gradient for SparseMatMul."""

  t_a = op.get_attr("transpose_a")
  t_b = op.get_attr("transpose_b")
  is_sparse = {}
  is_sparse[op.inputs[0].ref()] = op.get_attr("a_is_sparse")
  is_sparse[op.inputs[1].ref()] = op.get_attr("b_is_sparse")
  # Use heuristic to figure out if grad might be sparse
  is_sparse[grad.ref()] = not context.executing_eagerly() and (
      grad.op.type == "ReluGrad")

  def _SparseMatMul(t1, t2, out_dtype, transpose_a=False, transpose_b=False):
    """Helper function to create SparseMatMul op."""

    assert t1.ref() in is_sparse and t2.ref() in is_sparse
    t1_sparse = is_sparse[t1.ref()]
    t2_sparse = is_sparse[t2.ref()]
    if transpose_b:
      t2 = array_ops.transpose(t2)
      transpose_b = False
    prod = math_ops.matmul(
        t1,
        t2,
        transpose_a=transpose_a,
        transpose_b=transpose_b,
        a_is_sparse=t1_sparse,
        b_is_sparse=t2_sparse)
    if prod.dtype != out_dtype:
      prod = math_ops.cast(prod, out_dtype)
    return prod

  dtype_a = op.inputs[0].dtype
  dtype_b = op.inputs[1].dtype
  if not t_a and not t_b:
    return (_SparseMatMul(grad, op.inputs[1], dtype_a, transpose_b=True),
            _SparseMatMul(op.inputs[0], grad, dtype_b, transpose_a=True))
  elif not t_a and t_b:
    return (_SparseMatMul(grad, op.inputs[1], dtype_a),
            _SparseMatMul(grad, op.inputs[0], dtype_b, transpose_a=True))
  elif t_a and not t_b:
    return (_SparseMatMul(op.inputs[1], grad, dtype_a, transpose_b=True),
            _SparseMatMul(op.inputs[0], grad, dtype_b))
  elif t_a and t_b:
    return (_SparseMatMul(
        op.inputs[1], grad, dtype_a, transpose_a=True, transpose_b=True),
            _SparseMatMul(
                grad, op.inputs[0], dtype_b, transpose_a=True,
                transpose_b=True))


@ops.RegisterGradient("Floor")
def _FloorGrad(_, unused_grad):
  return [None]


@ops.RegisterGradient("Ceil")
def _CeilGrad(_, unused_grad):
  return [None]


@ops.RegisterGradient("Round")
def _RoundGrad(_, unused_grad):
  return [None]


@ops.RegisterGradient("Rint")
def _RintGrad(_, unused_grad):
  # the gradient of Rint is zero
  return [None]


@ops.RegisterGradient("BatchMatMul")
def _BatchMatMul(op, grad):
  """Returns the gradient of x and y given the gradient of x * y."""
  x = op.inputs[0]
  y = op.inputs[1]
  adj_x = op.get_attr("adj_x")
  adj_y = op.get_attr("adj_y")

  if not adj_x:
    if not adj_y:
      grad_x = math_ops.matmul(grad, y, adjoint_a=False, adjoint_b=True)
      grad_y = math_ops.matmul(x, grad, adjoint_a=True, adjoint_b=False)
    else:
      grad_x = math_ops.matmul(grad, y, adjoint_a=False, adjoint_b=False)
      grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=False)
  else:
    if not adj_y:
      grad_x = math_ops.matmul(y, grad, adjoint_a=False, adjoint_b=True)
      grad_y = math_ops.matmul(x, grad, adjoint_a=False, adjoint_b=False)
    else:
      grad_x = math_ops.matmul(y, grad, adjoint_a=True, adjoint_b=True)
      grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=True)

  return grad_x, grad_y

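
# Illustrative note: BatchMatMul applies the MatMul gradient identities to each
# batch entry independently, so for z = tf.matmul(x, y) with x of shape
# [B, m, k] and y of shape [B, k, n], grad_x has shape [B, m, k] and grad_y has
# shape [B, k, n]. A minimal sketch, assuming `import tensorflow as tf` in user
# code:
#
#   x = tf.ones([4, 2, 3])
#   y = tf.ones([4, 3, 5])
#   with tf.GradientTape() as tape:
#     tape.watch([x, y])
#     z = tf.matmul(x, y)
#   dx, dy = tape.gradient(z, [x, y])
#   dx.shape  # TensorShape([4, 2, 3])
#   dy.shape  # TensorShape([4, 3, 5])
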
@ops.RegisterGradient("BatchMatMulV2")
@ops.RegisterGradient("BatchMatMulV3")
def _BatchMatMulV2(op, grad):
  """Returns the gradient of x and y given the gradient of x * y."""
  x = op.inputs[0]
  y = op.inputs[1]
  adj_x = op.get_attr("adj_x")
  adj_y = op.get_attr("adj_y")

  if not adj_x:
    if not adj_y:
      grad_x = math_ops.matmul(grad, y, adjoint_a=False, adjoint_b=True)
      grad_y = math_ops.matmul(x, grad, adjoint_a=True, adjoint_b=False)
    else:
      grad_x = math_ops.matmul(grad, y, adjoint_a=False, adjoint_b=False)
      grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=False)
  else:
    if not adj_y:
      grad_x = math_ops.matmul(y, grad, adjoint_a=False, adjoint_b=True)
      grad_y = math_ops.matmul(x, grad, adjoint_a=False, adjoint_b=False)
    else:
      grad_x = math_ops.matmul(y, grad, adjoint_a=True, adjoint_b=True)
      grad_y = math_ops.matmul(grad, x, adjoint_a=True, adjoint_b=True)

  # Possibly reduce along the broadcasted batch dimensions, if broadcasting
  # is required.
  shape_x_static = x.get_shape()
  shape_y_static = y.get_shape()
  output_may_have_non_empty_batch_shape = (
      (shape_x_static.rank is None or shape_x_static.rank > 2) or
      (shape_y_static.rank is None or shape_y_static.rank > 2))
  batch_shapes_match = (
      shape_x_static[:-2].is_fully_defined() and
      shape_y_static[:-2].is_fully_defined() and
      shape_x_static[:-2] == shape_y_static[:-2])
  if (not output_may_have_non_empty_batch_shape) or batch_shapes_match:
    return grad_x, grad_y

  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  rx, ry = gen_array_ops.broadcast_gradient_args(sx[:-2], sy[:-2])
  grad_x = array_ops.reshape(math_ops.reduce_sum(grad_x, rx), sx)
  grad_y = array_ops.reshape(math_ops.reduce_sum(grad_y, ry), sy)
  return grad_x, grad_y


ops.NotDifferentiable("Range")
ops.NotDifferentiable("LinSpace")


@ops.RegisterGradient("Complex")
def _ComplexGrad(op, grad):
  """Returns the real and imaginary components of 'grad', respectively."""
  x = op.inputs[0]
  y = op.inputs[1]
  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
  return (array_ops.reshape(math_ops.reduce_sum(math_ops.real(grad), rx), sx),
          array_ops.reshape(math_ops.reduce_sum(math_ops.imag(grad), ry), sy))


@ops.RegisterGradient("Real")
def _RealGrad(_, grad):
  """Returns 'grad' as the real part and sets the imaginary part to 0."""
  zero = constant_op.constant(0, dtype=grad.dtype)
  return math_ops.complex(grad, zero)


@ops.RegisterGradient("Imag")
def _ImagGrad(_, grad):
  """Returns 'grad' as the imaginary part and sets the real part to 0."""
  zero = constant_op.constant(0, dtype=grad.dtype)
  return math_ops.complex(zero, grad)


@ops.RegisterGradient("Angle")
def _AngleGrad(op, grad):
  """Returns -grad / (Im(x) + iRe(x))."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    re = math_ops.real(x)
    im = math_ops.imag(x)
    z = math_ops.reciprocal(math_ops.complex(im, re))
    zero = constant_op.constant(0, dtype=grad.dtype)
    complex_grad = math_ops.complex(grad, zero)
    return -complex_grad * z


@ops.RegisterGradient("Conj")
def _ConjGrad(_, grad):
  """Returns the complex conjugate of grad."""
  return math_ops.conj(grad)


@ops.RegisterGradient("ComplexAbs")
def _ComplexAbsGrad(op, grad):
  """Returns the gradient of ComplexAbs."""
  return math_ops.div_no_nan(
      math_ops.complex(
          grad, array_ops.zeros_like(grad)) * op.inputs[0],
      math_ops.complex(
          op.outputs[0], array_ops.zeros_like(op.outputs[0])))


@ops.RegisterGradient("Cast")
def _CastGrad(op, grad):
  t = [
      dtypes.float16, dtypes.float32, dtypes.float64, dtypes.bfloat16,
      dtypes.complex64, dtypes.complex128
  ]
  src_type = op.inputs[0].dtype.base_dtype
  dst_type = grad.dtype.base_dtype
  if src_type in t and dst_type in t:
    return math_ops.cast(grad, src_type)
  else:
    return None


@ops.RegisterGradient("Cross")
def _CrossGrad(op, grad):
  u = op.inputs[0]
  v = op.inputs[1]
  return (math_ops.cross(v, grad), math_ops.cross(grad, u))

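
# Illustrative note: _CastGrad above only propagates a gradient between
# floating-point and complex dtypes; a cast to or from an integer dtype yields
# no gradient (None). A minimal sketch, assuming `import tensorflow as tf` in
# user code:
#
#   x = tf.constant(3.0)
#   with tf.GradientTape() as tape:
#     tape.watch(x)
#     y = tf.cast(x, tf.float64) * tf.constant(2.0, dtype=tf.float64)
#   tape.gradient(y, x)  # 2.0, cast back to float32 by the gradient
#
#   with tf.GradientTape() as tape:
#     tape.watch(x)
#     y = tf.cast(x, tf.int32)
#   tape.gradient(y, x)  # None
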
@ops.RegisterGradient("Cumsum")
def _CumsumGrad(op, grad):
  axis = op.inputs[1]
  exclusive = op.get_attr("exclusive")
  reverse = op.get_attr("reverse")
  return [
      math_ops.cumsum(grad, axis, exclusive=exclusive, reverse=not reverse),
      None
  ]


@ops.RegisterGradient("Cumprod")
def _CumprodGrad(op, grad):
  x = op.inputs[0]
  axis = op.inputs[1]
  exclusive = op.get_attr("exclusive")
  reverse = op.get_attr("reverse")

  prod = math_ops.cumprod(x, axis, exclusive=exclusive, reverse=reverse)
  out = math_ops.cumsum(
      prod * grad, axis, exclusive=exclusive, reverse=not reverse)
  return [math_ops.div_no_nan(out, x), None]


@ops.RegisterGradient("CumulativeLogsumexp")
def _CumulativeLogsumexpGrad(op, grad):
  x = op.inputs[0]
  axis = op.inputs[1]
  cumulative_logsumexp = op.outputs[0]

  exclusive = op.get_attr("exclusive")
  reverse = op.get_attr("reverse")

  # Split the incoming gradient into positive and negative parts
  # in order to take logs. This is required for stable results.
  log_grad_positive = array_ops.where_v2(
      math_ops.greater(grad, 0),
      math_ops.log(grad),
      grad.dtype.min)

  log_grad_negative = array_ops.where_v2(
      math_ops.less(grad, 0),
      math_ops.log(-grad),
      grad.dtype.min)

  output_pos = math_ops.exp(
      math_ops.cumulative_logsumexp(
          log_grad_positive - cumulative_logsumexp,
          axis=axis, reverse=not reverse, exclusive=exclusive) + x)

  output_neg = math_ops.exp(
      math_ops.cumulative_logsumexp(
          log_grad_negative - cumulative_logsumexp,
          axis=axis, reverse=not reverse, exclusive=exclusive) + x)

  return [output_pos - output_neg, None]


@ops.RegisterGradient("NextAfter")
def _NextAfterGrad(op, grad):
  """Returns gradient of nextafter(x1, x2) with respect to x1 and x2."""
  x1 = op.inputs[0]
  x2 = op.inputs[1]
  s_x1 = array_ops.shape(x1)
  s_x2 = array_ops.shape(x2)
  r_x1, r_x2 = gen_array_ops.broadcast_gradient_args(s_x1, s_x2)
  with ops.control_dependencies([grad]):
    partial_x1 = array_ops.ones(s_x1, dtype=x1.dtype)
    partial_x2 = array_ops.zeros(s_x2, dtype=x2.dtype)
    return (array_ops.reshape(
        math_ops.reduce_sum(partial_x1 * grad, r_x1), s_x1),
            array_ops.reshape(
                math_ops.reduce_sum(partial_x2 * grad, r_x2), s_x2))
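
# Illustrative note: nextafter(x1, x2) moves x1 by at most one ulp toward x2,
# so its derivative is treated as 1 with respect to x1 and 0 with respect to
# x2, as encoded by partial_x1/partial_x2 above. A minimal sketch, assuming
# `import tensorflow as tf` in user code:
#
#   x1 = tf.constant(1.0)
#   x2 = tf.constant(2.0)
#   with tf.GradientTape(persistent=True) as tape:
#     tape.watch([x1, x2])
#     y = tf.math.nextafter(x1, x2)
#   tape.gradient(y, x1)  # 1.0
#   tape.gradient(y, x2)  # 0.0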