# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradients for operators defined in nn_ops.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops


@ops.RegisterGradient("Conv2DBackpropInput")
def _Conv2DBackpropInputGrad(op, grad):
  """The derivatives for deconvolution.

  Args:
    op: the Deconvolution op.
    grad: the tensor representing the gradient w.r.t. the output

  Returns:
    the gradients w.r.t. the input and the filter
  """
  return [
      None,
      nn_ops.conv2d_backprop_filter(
          grad,
          array_ops.shape(op.inputs[1]),
          op.inputs[2],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          use_cudnn_on_gpu=op.get_attr("use_cudnn_on_gpu"),
          data_format=op.get_attr("data_format").decode()),
      nn_ops.conv2d(
          grad,
          op.inputs[1],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          use_cudnn_on_gpu=op.get_attr("use_cudnn_on_gpu"),
          data_format=op.get_attr("data_format").decode())
  ]


@ops.RegisterGradient("Conv2DBackpropFilter")
def _Conv2DBackpropFilterGrad(op, grad):
  return [
      nn_ops.conv2d_backprop_input(
          array_ops.shape(op.inputs[0]),
          grad,
          op.inputs[2],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          use_cudnn_on_gpu=op.get_attr("use_cudnn_on_gpu"),
          data_format=op.get_attr("data_format").decode()), None,
      nn_ops.conv2d(
          op.inputs[0],
          grad,
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          use_cudnn_on_gpu=op.get_attr("use_cudnn_on_gpu"),
          data_format=op.get_attr("data_format").decode())
  ]
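

# Note on the two registrations above: Conv2DBackpropInput is linear in both
# its filter and out_backprop arguments, so differentiating it w.r.t. either
# one is again a convolution-family op (conv2d_backprop_filter for the filter,
# a plain conv2d for the incoming gradient). Rough usage sketch that would
# exercise _Conv2DBackpropInputGrad (illustrative only, kept as a comment;
# the shapes are made up):
#
#   x = tf.random.normal([1, 8, 8, 3])           # NHWC input
#   w = tf.random.normal([3, 3, 3, 16])          # HWIO filter
#   y = tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding="SAME")
#   dx = tf.gradients(y, x)[0]                   # emits Conv2DBackpropInput
#   dw_of_dx = tf.gradients(dx, w)[0]            # uses the gradient above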


@ops.RegisterGradient("DepthwiseConv2dNativeBackpropInput")
def _DepthwiseConv2dNativeBackpropInputGrad(op, grad):
  """The derivatives for deconvolution.

  Args:
    op: the Deconvolution op.
    grad: the tensor representing the gradient w.r.t. the output

  Returns:
    the gradients w.r.t. the input and the filter
  """
  return [
      None,
      nn_ops.depthwise_conv2d_native_backprop_filter(
          grad,
          array_ops.shape(op.inputs[1]),
          op.inputs[2],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          data_format=op.get_attr("data_format")),
      nn_ops.depthwise_conv2d_native(
          grad,
          op.inputs[1],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          data_format=op.get_attr("data_format"))
  ]


@ops.RegisterGradient("DepthwiseConv2dNativeBackpropFilter")
def _DepthwiseConv2dNativeBackpropFilterGrad(op, grad):
  return [
      nn_ops.depthwise_conv2d_native_backprop_input(
          array_ops.shape(op.inputs[0]),
          grad,
          op.inputs[2],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          data_format=op.get_attr("data_format")), None,
      nn_ops.depthwise_conv2d_native(
          op.inputs[0],
          grad,
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          data_format=op.get_attr("data_format"))
  ]


@ops.RegisterGradient("Conv3D")
def _Conv3DGrad(op, grad):
  data_format = op.get_attr("data_format").decode()
  return [
      nn_ops.conv3d_backprop_input_v2(
          array_ops.shape(op.inputs[0]),
          op.inputs[1],
          grad,
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          data_format=data_format),
      nn_ops.conv3d_backprop_filter_v2(
          op.inputs[0],
          array_ops.shape(op.inputs[1]),
          grad,
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          data_format=data_format)
  ]


@ops.RegisterGradient("Conv3DBackpropInputV2")
def _Conv3DBackpropInputGrad(op, grad):
  data_format = op.get_attr("data_format").decode()
  return [
      None,
      nn_ops.conv3d_backprop_filter_v2(
          grad,
          array_ops.shape(op.inputs[1]),
          op.inputs[2],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          data_format=data_format),
      nn_ops.conv3d(
          grad,
          op.inputs[1],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          data_format=data_format)
  ]


@ops.RegisterGradient("Conv3DBackpropFilterV2")
def _Conv3DBackpropFilterGrad(op, grad):
  data_format = op.get_attr("data_format").decode()
  return [
      nn_ops.conv3d_backprop_input_v2(
          array_ops.shape(op.inputs[0]),
          grad,
          op.inputs[2],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          data_format=data_format), None,
      nn_ops.conv3d(
          op.inputs[0],
          grad,
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          data_format=data_format)
  ]


@ops.RegisterGradient("AvgPool3D")
def _AvgPool3DGrad(op, grad):
  return gen_nn_ops.avg_pool3d_grad(
      array_ops.shape(op.inputs[0]),
      grad,
      ksize=op.get_attr("ksize"),
      strides=op.get_attr("strides"),
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format").decode())


@ops.RegisterGradient("AvgPool3DGrad")
def _AvgPool3DGradGrad(op, grad):
  return (array_ops.stop_gradient(op.inputs[0]),
          gen_nn_ops.avg_pool3d(
              grad,
              op.get_attr("ksize"),
              op.get_attr("strides"),
              op.get_attr("padding"),
              data_format=op.get_attr("data_format").decode()))


@ops.RegisterGradient("MaxPool3D")
def _MaxPool3DGrad(op, grad):
  return gen_nn_ops.max_pool3d_grad(
      op.inputs[0],
      op.outputs[0],
      grad,
      ksize=op.get_attr("ksize"),
      strides=op.get_attr("strides"),
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format").decode())


@ops.RegisterGradient("MaxPool3DGrad")
def _MaxPool3DGradGrad(op, grad):
  return (array_ops.zeros(
      shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype),
          array_ops.zeros(
              shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
          gen_nn_ops.max_pool3d_grad_grad(
              op.inputs[0],
              op.inputs[1],
              grad,
              op.get_attr("ksize"),
              op.get_attr("strides"),
              padding=op.get_attr("padding"),
              data_format=op.get_attr("data_format").decode()))


@ops.RegisterGradient("MaxPool3DGradGrad")
def _MaxPool3DGradGradGrad(op, grad):
  return (array_ops.zeros(
      shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype),
          array_ops.zeros(
              shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
          gen_nn_ops.max_pool3d_grad(
              op.inputs[0],
              op.inputs[1],
              grad,
              op.get_attr("ksize"),
              op.get_attr("strides"),
              padding=op.get_attr("padding"),
              data_format=op.get_attr("data_format").decode()))


@ops.RegisterGradient("Softmax")
def _SoftmaxGrad(op, grad_softmax):
  """The derivative of the softmax nonlinearity.

  We assume that probs is of shape [batch_size, dim].
  The formula for dsoftmax / dx is (diag(softmax) - softmax * softmax').
  This matrix is diagonal minus a rank one matrix, so it is easy to implement
  as follows:

    grad_x = grad_softmax * softmax - sum(grad_softmax * softmax) * softmax

  Args:
    op: the Softmax op.
    grad_softmax: the tensor representing the gradient w.r.t. the softmax
      output.

  Returns:
    gradient w.r.t. the input to the softmax

  """
  softmax = op.outputs[0]
  sum_channels = math_ops.reduce_sum(grad_softmax * softmax, -1, keepdims=True)
  return (grad_softmax - sum_channels) * softmax
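

# The Jacobian of softmax is J = diag(s) - s s^T, so the backprop above is the
# vector-Jacobian product g^T J = s * (g - sum(g * s)). Illustrative check
# (NumPy-style, kept as a comment only; names are for illustration):
#
#   s = np.array([0.1, 0.2, 0.7])
#   g = np.array([1.0, 0.0, 0.0])
#   jac = np.diag(s) - np.outer(s, s)
#   np.allclose(jac.T.dot(g), (g - (g * s).sum()) * s)   # -> True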


@ops.RegisterGradient("LogSoftmax")
def _LogSoftmaxGrad(op, grad):
  """The gradient for log_softmax.

  log_softmax = input - log(sum(exp(input)))
  dlog_softmax/dinput = identity - softmax(input)

  Args:
    op: The log softmax op.
    grad: The tensor representing the gradient w.r.t. the output.

  Returns:
    The gradients w.r.t. the input.
  """
  softmax = math_ops.exp(op.outputs[0])
  return grad - math_ops.reduce_sum(grad, -1, keepdims=True) * softmax


@ops.RegisterGradient("BiasAdd")
def _BiasAddGrad(op, received_grad):
  """Return the gradients for the 2 inputs of bias_op.

  The first input of the bias op is the tensor t, and its gradient is
  just the gradient the op received.

  The second input of the bias op is the bias vector, which has one fewer
  dimension than "received_grad" (the batch dimension). Its gradient is the
  received gradient summed on the batch dimension, which is the first
  dimension.

  Args:
    op: The BiasOp for which we need to generate gradients.
    received_grad: Tensor. The gradients passed to the BiasOp.

  Returns:
    Two tensors, the first one for the "tensor" input of the BiasOp,
    the second one for the "bias" input of the BiasOp.
  """
  try:
    data_format = op.get_attr("data_format")
  except ValueError:
    data_format = None
  return (received_grad,
          gen_nn_ops.bias_add_grad(
              out_backprop=received_grad, data_format=data_format))


@ops.RegisterGradient("BiasAddGrad")
def _BiasAddGradGrad(op, received_grad):
  """Gradient for the BiasAddGrad op.

  Args:
    op: BiasAddGrad op for which we are calculating gradients.
    received_grad: The gradients passed to the BiasAddGrad op.

  Returns:
    A single gradient Tensor for the input to BiasAddGrad (which
    is the gradient of the bias term in BiasAdd)
  """

  try:
    data_format = op.get_attr("data_format")
  except ValueError:
    data_format = None

  shape = array_ops.shape(op.inputs[0])
  bias_shape = array_ops.shape(received_grad)

  if data_format == b"NCHW":
    expanded_shape = array_ops.concat([
        array_ops.ones_like(shape[:1]), bias_shape,
        array_ops.ones_like(shape[2:])
    ], 0)
    tile_mults = array_ops.concat([shape[:1], [1], shape[2:]], 0)
  else:
    expanded_shape = array_ops.concat(
        [array_ops.ones_like(shape[:-1]), bias_shape], 0)
    tile_mults = array_ops.concat([shape[:-1], [1]], 0)

  expanded_grad = array_ops.reshape(received_grad, expanded_shape)
  return array_ops.tile(expanded_grad, tile_mults)
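

# BiasAddGrad reduces the incoming gradient down to a length-C bias gradient,
# so its own gradient simply broadcasts the received length-C tensor back to
# the shape of the reduced input. Illustrative example (hypothetical shapes,
# comment only): with an NHWC input of shape [N, H, W, C] and received_grad of
# shape [C],
#   expanded_shape = [1, 1, 1, C]    tile_mults = [N, H, W, 1]
# so the result is received_grad broadcast to [N, H, W, C].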


@ops.RegisterGradient("BiasAddV1")
def _BiasAddGradV1(unused_bias_op, received_grad):
  """Return the gradients for the 2 inputs of bias_op.

  The first input of unused_bias_op is the tensor t, and its gradient is
  just the gradient the unused_bias_op received.

  The second input of unused_bias_op is the bias vector, which has one fewer
  dimension than "received_grad" (the batch dimension). Its gradient is the
  received gradient summed on the batch dimension, which is the first
  dimension.

  Args:
    unused_bias_op: The BiasOp for which we need to generate gradients.
    received_grad: Tensor. The gradients passed to the BiasOp.

  Returns:
    Two tensors, the first one for the "tensor" input of the BiasOp,
    the second one for the "bias" input of the BiasOp.
  """
  reduction_dim_tensor = math_ops.range(array_ops.rank(received_grad) - 1)
  return (received_grad, math_ops.reduce_sum(received_grad,
                                             reduction_dim_tensor))


@ops.RegisterGradient("Relu")
def _ReluGrad(op, grad):
  return gen_nn_ops.relu_grad(grad, op.outputs[0])


@ops.RegisterGradient("EluGrad")
def _EluGradGrad(op, grad):
  elu_x = op.inputs[1]
  return (gen_nn_ops.elu_grad(grad, op.outputs[0]),
          array_ops.where(
              elu_x < 0, grad * op.inputs[0],
              array_ops.zeros(shape=array_ops.shape(elu_x), dtype=elu_x.dtype)))


@ops.RegisterGradient("SeluGrad")
def _SeluGradGrad(op, grad):
  x = op.inputs[1]
  scale_alpha = 1.7580993408473768599402175208123
  return (gen_nn_ops.elu_grad(grad, op.outputs[0]),
          array_ops.where(
              x < 0., gen_nn_ops.elu_grad(grad, op.outputs[0] + scale_alpha),
              array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype)))


@ops.RegisterGradient("Relu6")
def _Relu6Grad(op, grad):
  return gen_nn_ops.relu6_grad(grad, op.outputs[0])


@ops.RegisterGradient("Relu6Grad")
def _Relu6GradGrad(op, grad):
  x = op.inputs[1]
  return (gen_nn_ops.relu6_grad(grad, x),
          array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype))


@ops.RegisterGradient("LeakyRelu")
def _LeakyReluGrad(op, grad):
  x = op.inputs[0]
  alpha = op.get_attr("alpha")
  return gen_nn_ops.leaky_relu_grad(grad, x, alpha=alpha)


@ops.RegisterGradient("LeakyReluGrad")
def _LeakyReluGradGrad(op, grad):
  x = op.inputs[1]
  alpha = op.get_attr("alpha")
  return (gen_nn_ops.leaky_relu_grad(grad, x, alpha=alpha),
          array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype))


@ops.RegisterGradient("Elu")
def _EluGrad(op, grad):
  return gen_nn_ops.elu_grad(grad, op.outputs[0])


@ops.RegisterGradient("Selu")
def _SeluGrad(op, grad):
  return gen_nn_ops.selu_grad(grad, op.outputs[0])


@ops.RegisterGradient("Softplus")
def _SoftplusGrad(op, grad):
  return gen_nn_ops.softplus_grad(grad, op.inputs[0])


@ops.RegisterGradient("SoftplusGrad")
def _SoftplusGradGrad(op, grad):
  # Let:
  #   y = tf.nn.softplus(x)
  #   dx = gen_nn_ops.softplus_grad(dy, x) = dy / (1 + exp(-x))
  # This op computes (ddy, d2x) from op.inputs == [dy, x] and grad == ddx.
  dy, x = op.inputs
  with ops.control_dependencies([grad]):
    ddy = gen_nn_ops.softplus_grad(grad, x)
    d2x = grad * dy / (math_ops.exp(-x) + 2.0 + math_ops.exp(x))
    return (ddy, d2x)
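

# Note on _SoftplusGradGrad above: the denominator uses the identity
#   exp(-x) + 2 + exp(x) = (1 + exp(x)) * (1 + exp(-x)) = 1 / sigmoid'(x),
# so d2x = ddx * dy * sigmoid'(x), i.e. the second derivative of softplus
# (which is sigmoid'(x)) applied to the incoming gradients.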


@ops.RegisterGradient("Softsign")
def _SoftsignGrad(op, grad):
  return gen_nn_ops.softsign_grad(grad, op.inputs[0])


@ops.RegisterGradient("ReluGrad")
def _ReluGradGrad(op, grad):
  x = op.inputs[1]
  return (gen_nn_ops.relu_grad(grad, x),
          array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype))


def _BroadcastMul(vec, mat):
  """Multiply after broadcasting vec to match dimensions of mat.

  Args:
    vec: A 1-D tensor of dimension [D0]
    mat: A 2-D tensor of dimension [D0, D1]

  Returns:
    A tensor of dimension [D0, D1], the result of vec * mat
  """
  # Reshape vec to [D0, 1]
  vec = array_ops.expand_dims(vec, -1)
  return vec * mat


@ops.RegisterGradient("SoftmaxCrossEntropyWithLogits")
def _SoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad):
  """Gradient function for SoftmaxCrossEntropyWithLogits."""
  # grad_loss is the backprop for cost, and we multiply it with the gradients
  # (which is output[1])
  # grad_grad is the backprop for softmax gradient.
  #
  # Second derivative is just softmax derivative w.r.t. logits.
  softmax_grad = op.outputs[1]
  grad = _BroadcastMul(grad_loss, softmax_grad)

  def IsZero(g):
    # Some introspection to check if the gradient is feeding zeros
    if context.executing_eagerly():
      # TODO(apassos) add an efficient way to detect eager zeros here.
      return False
    if g.op.type in ("ZerosLike", "Zeros"):
      return True
    const_fill_value = tensor_util.constant_value(g)
    return const_fill_value is not None and (const_fill_value == 0).all()

  logits = op.inputs[0]
  if grad_grad is not None and not IsZero(grad_grad):
    softmax = nn_ops.softmax(logits)

    grad += ((grad_grad - array_ops.squeeze(
        math_ops.matmul(
            array_ops.expand_dims(grad_grad, 1),
            array_ops.expand_dims(softmax, 2)),
        axis=1)) * softmax)

  return grad, _BroadcastMul(grad_loss, -nn_ops.log_softmax(logits))
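

# Note on the second-order term in _SoftmaxCrossEntropyWithLogitsGrad above:
# for each example the correction is (grad_grad - <grad_grad, softmax>) *
# softmax, i.e. the softmax Jacobian (diag(softmax) - softmax softmax^T)
# applied to grad_grad; the batched matmul/squeeze just computes the inner
# product <grad_grad, softmax> row by row.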


@ops.RegisterGradient("SparseSoftmaxCrossEntropyWithLogits")
def _SparseSoftmaxCrossEntropyWithLogitsGrad(op, grad_0, _):
  """Gradient function for SparseSoftmaxCrossEntropyWithLogits."""
  # grad_0 is the backprop for cost, and we multiply it with the gradients
  # (which is output[1])
  # There is no gradient for the labels
  #
  # Currently there is no way to take the second derivative of this op
  # due to the fused implementation's interaction with tf.gradients(),
  # so we make sure we prevent silently incorrect results by raising
  # an error if the second derivative is requested via prevent_gradient.
  sparse_softmax_grad_without_gradient = array_ops.prevent_gradient(
      op.outputs[1],
      message="Currently there is no way to take the second "
      "derivative of sparse_softmax_cross_entropy_with_logits due to the fused "
      "implementation's interaction with tf.gradients()")
  return _BroadcastMul(grad_0, sparse_softmax_grad_without_gradient), None


@ops.RegisterGradient("Conv2D")
def _Conv2DGrad(op, grad):
  """Gradient function for Conv2D."""
  dilations = op.get_attr("dilations")
  strides = op.get_attr("strides")
  padding = op.get_attr("padding")
  explicit_paddings = op.get_attr("explicit_paddings")
  use_cudnn_on_gpu = op.get_attr("use_cudnn_on_gpu")
  data_format = op.get_attr("data_format")
  shape_0, shape_1 = array_ops.shape_n([op.inputs[0], op.inputs[1]])

  # We call the gen_nn_ops backprop functions instead of nn_ops backprop
  # functions for performance reasons in Eager mode. gen_nn_ops functions take
  # an `explicit_paddings` parameter, but nn_ops functions do not. So if we
  # were to use the nn_ops functions, we would have to convert `padding` and
  # `explicit_paddings` into a single `padding` parameter, increasing overhead
  # in Eager mode.
  return [
      gen_nn_ops.conv2d_backprop_input(
          shape_0,
          op.inputs[1],
          grad,
          dilations=dilations,
          strides=strides,
          padding=padding,
          explicit_paddings=explicit_paddings,
          use_cudnn_on_gpu=use_cudnn_on_gpu,
          data_format=data_format),
      gen_nn_ops.conv2d_backprop_filter(
          op.inputs[0],
          shape_1,
          grad,
          dilations=dilations,
          strides=strides,
          padding=padding,
          explicit_paddings=explicit_paddings,
          use_cudnn_on_gpu=use_cudnn_on_gpu,
          data_format=data_format)
  ]


@ops.RegisterGradient("DepthwiseConv2dNative")
def _DepthwiseConv2dNativeGrad(op, grad):
  return [
      nn_ops.depthwise_conv2d_native_backprop_input(
          array_ops.shape(op.inputs[0]),
          op.inputs[1],
          grad,
          op.get_attr("strides"),
          op.get_attr("padding"),
          data_format=op.get_attr("data_format")),
      nn_ops.depthwise_conv2d_native_backprop_filter(
          op.inputs[0],
          array_ops.shape(op.inputs[1]),
          grad,
          op.get_attr("strides"),
          op.get_attr("padding"),
          data_format=op.get_attr("data_format"))
  ]


@ops.RegisterGradient("Dilation2D")
def _Dilation2DGrad(op, grad):
  return [
      nn_ops.dilation2d_backprop_input(op.inputs[0], op.inputs[1], grad,
                                       op.get_attr("strides"),
                                       op.get_attr("rates"),
                                       op.get_attr("padding")),
      nn_ops.dilation2d_backprop_filter(op.inputs[0], op.inputs[1], grad,
                                        op.get_attr("strides"),
                                        op.get_attr("rates"),
                                        op.get_attr("padding"))
  ]


@ops.RegisterGradient("LRN")
def _LRNGrad(op, grad):
  depth_radius = op.get_attr("depth_radius")
  bias = op.get_attr("bias")
  alpha = op.get_attr("alpha")
  beta = op.get_attr("beta")
  return [
      gen_nn_ops.lrn_grad(grad, op.inputs[0], op.outputs[0], depth_radius, bias,
                          alpha, beta)
  ]


@ops.RegisterGradient("AvgPool")
def _AvgPoolGrad(op, grad):
  return gen_nn_ops.avg_pool_grad(
      array_ops.shape(op.inputs[0]),
      grad,
      op.get_attr("ksize"),
      op.get_attr("strides"),
      op.get_attr("padding"),
      data_format=op.get_attr("data_format"))


@ops.RegisterGradient("AvgPoolGrad")
def _AvgPoolGradGrad(op, grad):
  return (array_ops.stop_gradient(op.inputs[0]),
          gen_nn_ops.avg_pool(
              grad,
              op.get_attr("ksize"),
              op.get_attr("strides"),
              op.get_attr("padding"),
              data_format=op.get_attr("data_format")))


@ops.RegisterGradient("MaxPool")
def _MaxPoolGrad(op, grad):
  return gen_nn_ops.max_pool_grad(
      op.inputs[0],
      op.outputs[0],
      grad,
      op.get_attr("ksize"),
      op.get_attr("strides"),
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format"))


@ops.RegisterGradient("MaxPoolV2")
def _MaxPoolGradV2(op, grad):
  ksize = op.inputs[1]
  strides = op.inputs[2]
  return gen_nn_ops.max_pool_grad_v2(
      op.inputs[0],
      op.outputs[0],
      grad,
      ksize,
      strides,
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format")), None, None


@ops.RegisterGradient("MaxPoolWithArgmax")
def _MaxPoolGradWithArgmax(op, grad, unused_argmax_grad):
  del unused_argmax_grad
  return gen_nn_ops.max_pool_grad_with_argmax(
      op.inputs[0],
      grad,
      op.outputs[1],
      op.get_attr("ksize"),
      op.get_attr("strides"),
      padding=op.get_attr("padding"),
      include_batch_in_index=op.get_attr("include_batch_in_index"))


@ops.RegisterGradient("MaxPoolGrad")
def _MaxPoolGradGrad(op, grad):
  return (array_ops.zeros(
      shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype),
          array_ops.zeros(
              shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
          gen_nn_ops.max_pool_grad_grad(
              op.inputs[0],
              op.inputs[1],
              grad,
              op.get_attr("ksize"),
              op.get_attr("strides"),
              padding=op.get_attr("padding"),
              data_format=op.get_attr("data_format")))


@ops.RegisterGradient("MaxPoolGradV2")
def _MaxPoolGradGradV2(op, grad):
  ksize = op.inputs[3]
  strides = op.inputs[4]
  return (array_ops.zeros(
      shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype),
          array_ops.zeros(
              shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
          gen_nn_ops.max_pool_grad_grad_v2(
              op.inputs[0],
              op.inputs[1],
              grad,
              ksize,
              strides,
              padding=op.get_attr("padding"),
              data_format=op.get_attr("data_format")), None, None)


@ops.RegisterGradient("MaxPoolGradGrad")
def _MaxPoolGradGradGrad(op, grad):
  return (array_ops.zeros(
      shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype),
          array_ops.zeros(
              shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
          gen_nn_ops.max_pool_grad(
              op.inputs[0],
              op.inputs[1],
              grad,
              op.get_attr("ksize"),
              op.get_attr("strides"),
              padding=op.get_attr("padding"),
              data_format=op.get_attr("data_format")))
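

# The three registrations above return zeros for the original-input and
# original-output arguments of the MaxPool gradient ops: those tensors only
# determine *where* the incoming gradient is routed (the argmax pattern),
# which is piecewise constant, so their derivative is zero almost everywhere.
# Only the incoming-gradient argument, in which the ops are linear, gets a
# nontrivial backprop (max_pool_grad_grad / max_pool_grad).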


@ops.RegisterGradient("FractionalMaxPool")
def _FractionalMaxPoolGrad(op, grad_0, unused_grad_1, unused_grad_2):
  """Returns gradient for FractionalMaxPool.

  Since FractionalMaxPool has three outputs, there are three gradients passed
  in, one for each of the outputs. Only the first one is useful; the other two
  gradients are empty.

  Args:
    op: The FractionalMaxPoolOp.
    grad_0: Gradient with respect to op.outputs[0]
    unused_grad_1: Gradient with respect to op.outputs[1]/row_seq. It is empty.
    unused_grad_2: Gradient with respect to op.outputs[2]/col_seq. It is empty.

  Returns:
    Input backprop for FractionalMaxPool op.
  """
  return gen_nn_ops.fractional_max_pool_grad(
      op.inputs[0], op.outputs[0], grad_0, op.outputs[1], op.outputs[2],
      op.get_attr("overlapping"))


@ops.RegisterGradient("FractionalAvgPool")
def _FractionalAvgPoolGrad(op, grad_0, unused_grad_1, unused_grad_2):
  """Returns gradient for FractionalAvgPool.

  Since FractionalAvgPool has three outputs, there are three gradients passed
  in, one for each of the outputs. Only the first one is useful; the other two
  gradients are empty.

  Args:
    op: The FractionalAvgPoolOp.
    grad_0: Gradient with respect to op.outputs[0]
    unused_grad_1: Gradient with respect to op.outputs[1]/row_seq. It is empty.
    unused_grad_2: Gradient with respect to op.outputs[2]/col_seq. It is empty.

  Returns:
    Input backprop for FractionalAvgPool op.
  """
  return gen_nn_ops.fractional_avg_pool_grad(op.inputs[0].get_shape(), grad_0,
                                             op.outputs[1], op.outputs[2],
                                             op.get_attr("overlapping"))


@ops.RegisterGradient("BatchNormWithGlobalNormalization")
def _BatchNormWithGlobalNormalizationGrad(op, grad):
  """Return the gradients for the 5 inputs of BatchNormWithGlobalNormalization.

  We do not backprop anything for the mean and var intentionally as they are
  not being trained with backprop in the operation.

  Args:
    op: The BatchNormOp for which we need to generate gradients.
    grad: Tensor. The gradients passed to the BatchNormOp.

  Returns:
    dx: Backprop for input, which is (grad * (g * rsqrt(v + epsilon)))
    dm: Backprop for mean, which is
        sum_over_rest(grad * g) * (-1 / rsqrt(v + epsilon))
    dv: Backprop for variance, which is
        sum_over_rest(grad * g * (x - m)) * (-1/2) * (v + epsilon) ^ (-3/2)
    db: Backprop for beta, which is grad reduced in all except the
        last dimension.
    dg: Backprop for gamma, which is (grad * ((x - m) * rsqrt(v + epsilon)))
  """
  dx, dm, dv, db, dg = gen_nn_ops.batch_norm_with_global_normalization_grad(
      op.inputs[0], op.inputs[1], op.inputs[2], op.inputs[4], grad,
      op.get_attr("variance_epsilon"), op.get_attr("scale_after_normalization"))
  return dx, dm, dv, db, dg


def _BaseFusedBatchNormGrad(op, use_v2, *grad):
  """Return the gradients for the 3 inputs of BatchNorm.

  Args:
    op: The BatchNormOp for which we need to compute gradients.
    use_v2: Boolean indicating whether to use the V2 version of the fused batch
      norm gradient.
    *grad: An argument list for tensors of gradients wrt the outputs with
      grad[0] as grad_y.

  Returns:
    grad_x: gradient for x, which is scale * rsqrt(variance + epsilon) *
            [grad_y - mean(grad_y) - (x - mean(x)) *
            mean(grad_y * (x - mean(x))) / (variance + epsilon)]
            in training mode; grad_y * scale * rsqrt(pop_variance + epsilon)
            in freeze mode.

    grad_scale: gradient for scale, which is sum(grad_y * (x - mean(x)) *
                rsqrt(variance + epsilon)) in training mode;
                sum(grad_y * (x - pop_mean) * rsqrt(pop_variance + epsilon))
                in freeze mode.

    grad_offset: gradient for offset, which is sum(grad_y) in training mode;
                 sum(grad_y) in freeze mode.
  """
  x = op.inputs[0]
  grad_y = grad[0]
  scale = op.inputs[1]
  epsilon = op.get_attr("epsilon")
  data_format = op.get_attr("data_format")
  is_training = op.get_attr("is_training")
  grad_fun = (
      gen_nn_ops.fused_batch_norm_grad_v2
      if use_v2 else gen_nn_ops.fused_batch_norm_grad)
  if is_training:
    return grad_fun(
        grad_y,
        x,
        scale,
        op.outputs[3],
        op.outputs[4],
        epsilon=epsilon,
        data_format=data_format,
        is_training=is_training)
  else:
    pop_mean = op.inputs[3]
    pop_var = op.inputs[4]
    if data_format == b"NCHW":
      x = array_ops.transpose(x, [0, 2, 3, 1])
      grad_y = array_ops.transpose(grad_y, [0, 2, 3, 1])
    dx, dscale, doffset, _, _ = grad_fun(
        grad_y,
        x,
        scale,
        pop_mean,
        pop_var,
        epsilon=epsilon,
        data_format="NHWC",
        is_training=is_training)
    if data_format == b"NCHW":
      dx = array_ops.transpose(dx, [0, 3, 1, 2])
    return dx, dscale, doffset, None, None


@ops.RegisterGradient("FusedBatchNorm")
def _FusedBatchNormGrad(op, *grad):
  return _BaseFusedBatchNormGrad(op, False, *grad)


@ops.RegisterGradient("FusedBatchNormV2")
def _FusedBatchNormV2Grad(op, *grad):
  return _BaseFusedBatchNormGrad(op, True, *grad)


def _BatchNormGrad(grad_y,
                   x,
                   scale,
                   pop_mean,
                   pop_var,
                   epsilon,
                   data_format,
                   is_training=True):
  """Returns the gradients for the 3 inputs of BatchNorm.

  Args:
    grad_y: A `Tensor` of 4 dimensions for gradient for y.
    x: A `Tensor` of 4 dimensions for x.
    scale: A `Tensor` of 1 dimension for scaling.
    pop_mean: A `Tensor` of 1 dimension for the population mean. Only used when
      is_training=False.
    pop_var: A `Tensor` of 1 dimension for the population variance. Only used
      when is_training=False.
    epsilon: A small float number added to the variance of x.
    data_format: The data format for input. Either b"NHWC" or b"NCHW".
    is_training: A bool value to indicate the operation is for training
      (default) or inference.

  Returns:
    A tuple (grad_x, grad_scale, grad_offset), where grad_x is the gradient
    for x, grad_scale the gradient for scale, and grad_offset the gradient
    for offset.
  """
  x_dtype = x.dtype.base_dtype
  if x_dtype == dtypes.float16:
    # float16 math is too imprecise, so we do the batch norm gradient
    # computations in float32.
    x = math_ops.cast(x, dtypes.float32)
    grad_y = math_ops.cast(grad_y, dtypes.float32)
  if is_training:
    if data_format == b"NHWC":
      keepdims = False
      reduce_axis = [0, 1, 2]
    else:
      keepdims = True
      reduce_axis = [0, 2, 3]
      shape = [1, array_ops.size(scale), 1, 1]
      scale = array_ops.reshape(scale, shape)
    mean_grad_y = math_ops.reduce_mean(grad_y, reduce_axis, keepdims=keepdims)
    mean_x = math_ops.reduce_mean(x, reduce_axis, keepdims=keepdims)
    var_x = math_ops.reduce_mean(
        math_ops.squared_difference(x, array_ops.stop_gradient(mean_x)),
        reduce_axis,
        keepdims=keepdims)
    grad_y_offset = grad_y - mean_grad_y
    x_offset = x - mean_x
    mean = math_ops.reduce_mean(
        grad_y * x_offset, axis=reduce_axis, keepdims=keepdims)
    grad_x = scale * math_ops.rsqrt(var_x + epsilon) * (
        grad_y_offset - math_ops.reciprocal(var_x + epsilon) * mean * x_offset)
    grad_scale = math_ops.rsqrt(var_x + epsilon) * math_ops.reduce_sum(
        grad_y * x_offset, axis=reduce_axis, keepdims=keepdims)
    if data_format == b"NCHW":
      grad_scale = array_ops.squeeze(grad_scale)
    grad_offset = math_ops.reduce_sum(grad_y, axis=reduce_axis)
    return math_ops.cast(grad_x, x_dtype), grad_scale, grad_offset
  else:
    if data_format == b"NHWC":
      reduce_axis = [0, 1, 2]
    else:
      reduce_axis = [0, 2, 3]
      shape = [1, array_ops.size(pop_mean), 1, 1]
      pop_mean = array_ops.reshape(pop_mean, shape)
      pop_var = array_ops.reshape(pop_var, shape)
      scale = array_ops.reshape(scale, shape)

    grad_offset = math_ops.reduce_sum(grad_y, axis=reduce_axis)
    var_rsqrt = math_ops.rsqrt(pop_var + epsilon)
    grad_scale = math_ops.reduce_sum(
        grad_y * (x - pop_mean) * var_rsqrt, axis=reduce_axis)
    grad_x = grad_y * scale * var_rsqrt
    return math_ops.cast(grad_x, x_dtype), grad_scale, grad_offset
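

# _BatchNormGrad re-derives, with regular TF ops, the same quantities that the
# fused FusedBatchNormGrad kernel produces (compare the formulas in the
# docstring of _BaseFusedBatchNormGrad above). Expressing the gradient this
# way keeps it differentiable, which is what _FusedBatchNormGradGrad below
# relies on: it replays this computation under a GradientTape and then
# differentiates through it.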


@ops.RegisterGradient("FusedBatchNormGrad")
def _FusedBatchNormGradGrad(op, *grad):
  """Returns the gradients for the 3 inputs of FusedBatchNormGrad.

  Args:
    op: The FusedBatchNormGradOp for which we need to compute gradients.
    *grad: An argument list for tensors of gradients wrt the outputs with
      grad[0] as grad_grad_x, grad[1] as grad_grad_scale, grad[2] as
      grad_grad_offset.

  Returns:
    A tuple (grad_grad_y, grad_x, grad_scale, None, None), where grad_grad_y
    is the gradient for grad_y, grad_x the gradient for x, grad_scale the
    gradient for scale.
  """
  data_format = op.get_attr("data_format")
  epsilon = op.get_attr("epsilon")
  is_training = op.get_attr("is_training")
  grad_y = op.inputs[0]
  x = op.inputs[1]
  scale = op.inputs[2]
  pop_mean = op.inputs[3]
  pop_var = op.inputs[4]
  grad_grad_x = grad[0]
  grad_grad_scale = grad[1]
  grad_grad_offset = grad[2]
  with backprop.GradientTape() as tape:
    tape.watch(grad_y)
    tape.watch(x)
    tape.watch(scale)
    grad_x, grad_scale, grad_offset = _BatchNormGrad(
        grad_y, x, scale, pop_mean, pop_var, epsilon, data_format, is_training)
    grad_initial = [grad_grad_x, grad_grad_scale, grad_grad_offset]
  grad_grad_y, grad_x, grad_scale = tape.gradient(
      [grad_x, grad_scale, grad_offset], [grad_y, x, scale], grad_initial)
  return grad_grad_y, grad_x, grad_scale, None, None


@ops.RegisterGradient("FusedBatchNormGradV2")
def _FusedBatchNormGradGradV2(op, *grad):
  return _FusedBatchNormGradGrad(op, *grad)


@ops.RegisterGradient("L2Loss")
def _L2LossGrad(op, grad):
  """Return the gradients for L2Loss.

  Args:
    op: The L2LossOp for which we need to generate gradients.
    grad: Tensor containing a single number.

  Returns:
    The gradient, which is (x * grad).
  """
  return op.inputs[0] * grad


@ops.RegisterGradient("TopK")
@ops.RegisterGradient("TopKV2")
def _TopKGrad(op, grad, _):
  """Return the gradients for TopK.

  Args:
    op: The TopKOp for which we need to generate gradients.
    grad: Tensor. The gradients passed to the TopKOp.

  Returns:
    A list of two tensors, the first being the gradient w.r.t. the input of
    TopK, and the second being the gradient w.r.t. the indices (all zero).
  """
  in_shape = array_ops.shape(op.inputs[0])
  ind_shape = array_ops.shape(op.outputs[1])

  # int32 is not supported on GPU hence up-casting
  ind_lastdim = array_ops.gather(
      math_ops.cast(ind_shape, dtypes.int64),
      array_ops.size(ind_shape) - 1)
  # Flatten indices to 2D.
  ind_2d = array_ops.reshape(op.outputs[1], array_ops.stack([-1, ind_lastdim]))

  in_lastdim = array_ops.gather(
      math_ops.cast(in_shape, dtypes.int64),
      array_ops.size(in_shape) - 1)
  outerdim = array_ops.shape(ind_2d)[0]
  # Compute linear indices (flattened to 1D).
  ind = array_ops.reshape(
      ind_2d + math_ops.cast(
          array_ops.expand_dims(
              math_ops.range(0,
                             math_ops.cast(outerdim, dtypes.int64) * in_lastdim,
                             in_lastdim), -1), dtypes.int32), [-1])

  # Substitute grad to appropriate locations and fill the rest with zeros,
  # finally reshaping it to the original input shape.
  return [
      array_ops.reshape(
          array_ops.scatter_nd(
              array_ops.expand_dims(ind, -1), array_ops.reshape(grad, [-1]),
              [math_ops.reduce_prod(in_shape)]), in_shape),
      array_ops.zeros([], dtype=dtypes.int32)
  ]
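

# _TopKGrad scatters the incoming gradient back to the positions named by the
# TopK indices. Worked example (comment only, hypothetical values): with
# op.inputs[0] of shape [2, 3] and k = 2, suppose op.outputs[1] is
# [[2, 0], [1, 2]]. Then outerdim = 2, in_lastdim = 3, the per-row offsets are
# [[0], [3]], and the flattened linear indices become [2, 0, 4, 5]. The
# flattened grad is scattered into a zero vector of size 6 and reshaped back
# to [2, 3].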


@ops.RegisterGradient("NthElement")
def _NthElementGrad(op, grad):
  """Return the gradients for NthElement.

  Args:
    op: The NthElementOp for which we need to generate gradients.
    grad: Tensor. The gradients passed to the NthElementOp.

  Returns:
    A list of two tensors, the first being the gradient w.r.t. the input,
    the second being the gradient w.r.t. the N (None).
  """
  input = op.inputs[0]  # pylint: disable=redefined-builtin
  output = op.outputs[0]

  # Compute the number of elements which are equal to output in each reduction
  # dimension. If there are multiple elements then the gradient will be
  # divided between them.
  indicators = math_ops.cast(
      math_ops.equal(array_ops.expand_dims(output, -1), input), grad.dtype)

  grad = array_ops.expand_dims(grad, -1)
  num_selected = array_ops.expand_dims(math_ops.reduce_sum(indicators, -1), -1)

  return [math_ops.div(indicators, num_selected) * grad, None]
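

# Worked example for _NthElementGrad (comment only, hypothetical values): if a
# row of the input is [1., 3., 3.] and the selected nth element is 3., then
# indicators is [0., 1., 1.] and num_selected is 2., so each of the two tied
# positions receives half of the incoming gradient for that row.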