# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradients for operators defined in nn_ops.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.eager import backprop
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops


@ops.RegisterGradient("Conv2DBackpropInput")
def _Conv2DBackpropInputGrad(op, grad):
  """The derivatives for deconvolution.

  Args:
    op: the Deconvolution op.
    grad: the tensor representing the gradient w.r.t. the output

  Returns:
    the gradients w.r.t. the input and the filter
  """
  # We call the gen_nn_ops backprop functions instead of nn_ops backprop
  # functions for performance reasons in Eager mode. See _Conv2DGrad.
  return [
      None,
      gen_nn_ops.conv2d_backprop_filter(
          grad,
          array_ops.shape(op.inputs[1]),
          op.inputs[2],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          explicit_paddings=op.get_attr("explicit_paddings"),
          use_cudnn_on_gpu=op.get_attr("use_cudnn_on_gpu"),
          data_format=op.get_attr("data_format").decode()),
      gen_nn_ops.conv2d(
          grad,
          op.inputs[1],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          explicit_paddings=op.get_attr("explicit_paddings"),
          use_cudnn_on_gpu=op.get_attr("use_cudnn_on_gpu"),
          data_format=op.get_attr("data_format").decode())
  ]


@ops.RegisterGradient("Conv2DBackpropFilter")
def _Conv2DBackpropFilterGrad(op, grad):
  # We call the gen_nn_ops backprop functions instead of nn_ops backprop
  # functions for performance reasons in Eager mode. See _Conv2DGrad.
  return [
      gen_nn_ops.conv2d_backprop_input(
          array_ops.shape(op.inputs[0]),
          grad,
          op.inputs[2],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          explicit_paddings=op.get_attr("explicit_paddings"),
          use_cudnn_on_gpu=op.get_attr("use_cudnn_on_gpu"),
          data_format=op.get_attr("data_format").decode()), None,
      gen_nn_ops.conv2d(
          op.inputs[0],
          grad,
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          explicit_paddings=op.get_attr("explicit_paddings"),
          use_cudnn_on_gpu=op.get_attr("use_cudnn_on_gpu"),
          data_format=op.get_attr("data_format").decode())
  ]


@ops.RegisterGradient("DepthwiseConv2dNativeBackpropInput")
def _DepthwiseConv2dNativeBackpropInputGrad(op, grad):
  """The derivatives for deconvolution.

  Args:
    op: the Deconvolution op.
    grad: the tensor representing the gradient w.r.t. the output

  Returns:
    the gradients w.r.t. the input and the filter
  """
  return [
      None,
      gen_nn_ops.depthwise_conv2d_native_backprop_filter(
          grad,
          array_ops.shape(op.inputs[1]),
          op.inputs[2],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          explicit_paddings=op.get_attr("explicit_paddings"),
          data_format=op.get_attr("data_format")),
      gen_nn_ops.depthwise_conv2d_native(
          grad,
          op.inputs[1],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          explicit_paddings=op.get_attr("explicit_paddings"),
          data_format=op.get_attr("data_format"))
  ]


@ops.RegisterGradient("DepthwiseConv2dNativeBackpropFilter")
def _DepthwiseConv2dNativeBackpropFilterGrad(op, grad):
  return [
      gen_nn_ops.depthwise_conv2d_native_backprop_input(
          array_ops.shape(op.inputs[0]),
          grad,
          op.inputs[2],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          explicit_paddings=op.get_attr("explicit_paddings"),
          data_format=op.get_attr("data_format")), None,
      gen_nn_ops.depthwise_conv2d_native(
          op.inputs[0],
          grad,
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          explicit_paddings=op.get_attr("explicit_paddings"),
          data_format=op.get_attr("data_format"))
  ]


@ops.RegisterGradient("Conv3D")
def _Conv3DGrad(op, grad):
  data_format = op.get_attr("data_format").decode()
  return [
      nn_ops.conv3d_backprop_input_v2(
          array_ops.shape(op.inputs[0]),
          op.inputs[1],
          grad,
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          data_format=data_format),
      nn_ops.conv3d_backprop_filter_v2(
          op.inputs[0],
          array_ops.shape(op.inputs[1]),
          grad,
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          data_format=data_format)
  ]


@ops.RegisterGradient("Conv3DBackpropInputV2")
def _Conv3DBackpropInputGrad(op, grad):
  data_format = op.get_attr("data_format").decode()
  return [
      None,
      nn_ops.conv3d_backprop_filter_v2(
          grad,
          array_ops.shape(op.inputs[1]),
          op.inputs[2],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          data_format=data_format),
      nn_ops.conv3d(
          grad,
          op.inputs[1],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          data_format=data_format)
  ]


@ops.RegisterGradient("Conv3DBackpropFilterV2")
def _Conv3DBackpropFilterGrad(op, grad):
  data_format = op.get_attr("data_format").decode()
  return [
      nn_ops.conv3d_backprop_input_v2(
          array_ops.shape(op.inputs[0]),
          grad,
          op.inputs[2],
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          data_format=data_format), None,
      nn_ops.conv3d(
          op.inputs[0],
          grad,
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          data_format=data_format)
  ]


@ops.RegisterGradient("AvgPool3D")
def _AvgPool3DGrad(op, grad):
  return gen_nn_ops.avg_pool3d_grad(
      array_ops.shape(op.inputs[0]),
      grad,
      ksize=op.get_attr("ksize"),
      strides=op.get_attr("strides"),
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format").decode())


@ops.RegisterGradient("AvgPool3DGrad")
def _AvgPool3DGradGrad(op, grad):
  return (array_ops.stop_gradient(op.inputs[0]),
          gen_nn_ops.avg_pool3d(
              grad,
              op.get_attr("ksize"),
              op.get_attr("strides"),
              op.get_attr("padding"),
              data_format=op.get_attr("data_format").decode()))


@ops.RegisterGradient("MaxPool3D")
def _MaxPool3DGrad(op, grad):
  return gen_nn_ops.max_pool3d_grad(
      op.inputs[0],
      op.outputs[0],
      grad,
      ksize=op.get_attr("ksize"),
      strides=op.get_attr("strides"),
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format").decode())


@ops.RegisterGradient("MaxPool3DGrad")
def _MaxPool3DGradGrad(op, grad):
  return (array_ops.zeros(
      shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype),
          array_ops.zeros(
              shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
          gen_nn_ops.max_pool3d_grad_grad(
              op.inputs[0],
              op.inputs[1],
              grad,
              op.get_attr("ksize"),
              op.get_attr("strides"),
              padding=op.get_attr("padding"),
              data_format=op.get_attr("data_format").decode()))


@ops.RegisterGradient("MaxPool3DGradGrad")
def _MaxPool3DGradGradGrad(op, grad):
  return (array_ops.zeros(
      shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype),
          array_ops.zeros(
              shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
          gen_nn_ops.max_pool3d_grad(
              op.inputs[0],
              op.inputs[1],
              grad,
              op.get_attr("ksize"),
              op.get_attr("strides"),
              padding=op.get_attr("padding"),
              data_format=op.get_attr("data_format").decode()))


@ops.RegisterGradient("Softmax")
def _SoftmaxGrad(op, grad_softmax):
  """The derivative of the softmax nonlinearity.

  We assume that probs is of shape [batch_size, dim].
  The formula for dsoftmax / dx is (diag(softmax) - softmax * softmax').
  This matrix is diagonal minus a rank-one matrix, so it is easy to implement
  as follows:

    grad_x = grad_softmax * softmax - sum(grad_softmax * softmax) * softmax

  Args:
    op: the Softmax op.
    grad_softmax: the tensor representing the gradient w.r.t. the softmax
      output.

  Returns:
    gradient w.r.t. the input to the softmax
  """
  softmax = op.outputs[0]
  sum_channels = math_ops.reduce_sum(grad_softmax * softmax, -1, keepdims=True)
  return (grad_softmax - sum_channels) * softmax


@ops.RegisterGradient("LogSoftmax")
def _LogSoftmaxGrad(op, grad):
  """The gradient for log_softmax.

    log_softmax = input - log(sum(exp(input)))
    dlog_softmax/dinput = diag - softmax(input)

  Args:
    op: The log softmax op.
    grad: The tensor representing the gradient w.r.t. the output.

  Returns:
    The gradients w.r.t. the input.
  """
  softmax = math_ops.exp(op.outputs[0])
  return grad - math_ops.reduce_sum(grad, -1, keepdims=True) * softmax
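

# Illustrative sketch only, not used by any gradient registration in this
# file: the _SoftmaxGrad docstring above describes the Jacobian as
# diag(softmax) - softmax * softmax'. For a single 1-D row of logits, that
# dense Jacobian can be materialized and contracted with an upstream gradient
# to reproduce the closed-form expression used above. The helper name and its
# arguments are hypothetical.
def _softmax_jacobian_sketch(logits_row, grad_softmax_row):
  """Returns (jacobian_vector_product, closed_form) for a 1-D logits row."""
  softmax = nn_ops.softmax(logits_row)
  # Dense Jacobian: diag(softmax) - outer(softmax, softmax). It is symmetric,
  # so multiplying it by the upstream gradient gives the backpropagated value.
  jacobian = array_ops.matrix_diag(softmax) - (
      softmax[:, None] * softmax[None, :])
  jacobian_vector_product = math_ops.matvec(jacobian, grad_softmax_row)
  # Closed form used by _SoftmaxGrad above.
  closed_form = (grad_softmax_row - math_ops.reduce_sum(
      grad_softmax_row * softmax, -1, keepdims=True)) * softmax
  return jacobian_vector_product, closed_form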


@ops.RegisterGradient("BiasAdd")
def _BiasAddGrad(op, received_grad):
  """Return the gradients for the 2 inputs of bias_op.

  The first input of the bias op is the tensor t, and its gradient is
  just the gradient the op received.

  The second input of the bias op is the bias vector, which has one fewer
  dimension than "received_grad" (the batch dimension). Its gradient is the
  received gradient summed over the batch dimension, which is the first
  dimension.

  Args:
    op: The BiasOp for which we need to generate gradients.
    received_grad: Tensor. The gradients passed to the BiasOp.

  Returns:
    Two tensors, the first one for the "tensor" input of the BiasOp,
    the second one for the "bias" input of the BiasOp.
  """
  try:
    data_format = op.get_attr("data_format")
  except ValueError:
    data_format = None
  return (received_grad,
          gen_nn_ops.bias_add_grad(
              out_backprop=received_grad, data_format=data_format))


@ops.RegisterGradient("BiasAddGrad")
def _BiasAddGradGrad(op, received_grad):
  """Gradient for the BiasAddGrad op.

  Args:
    op: BiasAddGrad op for which we are calculating gradients.
    received_grad: The gradients passed to the BiasAddGrad op.

  Returns:
    A single gradient Tensor for the input to BiasAddGrad (which
    is the gradient of the bias term in BiasAdd)
  """

  try:
    data_format = op.get_attr("data_format")
  except ValueError:
    data_format = None

  shape = array_ops.shape(op.inputs[0])
  bias_shape = array_ops.shape(received_grad)

  if data_format == b"NCHW":
    expanded_shape = array_ops.concat([
        array_ops.ones_like(shape[:1]), bias_shape,
        array_ops.ones_like(shape[2:])
    ], 0)
    tile_mults = array_ops.concat([shape[:1], [1], shape[2:]], 0)
  else:
    expanded_shape = array_ops.concat(
        [array_ops.ones_like(shape[:-1]), bias_shape], 0)
    tile_mults = array_ops.concat([shape[:-1], [1]], 0)

  expanded_grad = array_ops.reshape(received_grad, expanded_shape)
  return array_ops.tile(expanded_grad, tile_mults)
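

# Illustrative sketch only: for the default channels-last layout, the tiling
# in _BiasAddGradGrad above is equivalent to broadcasting the per-channel
# gradient back to the shape of BiasAdd's value input, because BiasAddGrad is
# a plain sum over every dimension except the channel one. The helper name is
# hypothetical and nothing in this file calls it.
def _bias_add_grad_grad_nhwc_sketch(received_grad, value_shape):
  """Broadcasts a [C] gradient to `value_shape`, whose last dimension is C."""
  # Reshape [C] to [1, ..., 1, C] and broadcast to the full value shape.
  expanded_shape = array_ops.concat(
      [array_ops.ones_like(value_shape[:-1]),
       array_ops.shape(received_grad)], 0)
  return array_ops.broadcast_to(
      array_ops.reshape(received_grad, expanded_shape), value_shape)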


@ops.RegisterGradient("BiasAddV1")
def _BiasAddGradV1(unused_bias_op, received_grad):
  """Return the gradients for the 2 inputs of bias_op.

  The first input of unused_bias_op is the tensor t, and its gradient is
  just the gradient the unused_bias_op received.

  The second input of unused_bias_op is the bias vector, which has one fewer
  dimension than "received_grad" (the batch dimension). Its gradient is the
  received gradient summed over the batch dimension, which is the first
  dimension.

  Args:
    unused_bias_op: The BiasOp for which we need to generate gradients.
    received_grad: Tensor. The gradients passed to the BiasOp.

  Returns:
    Two tensors, the first one for the "tensor" input of the BiasOp,
    the second one for the "bias" input of the BiasOp.
  """
  reduction_dim_tensor = math_ops.range(array_ops.rank(received_grad) - 1)
  return (received_grad, math_ops.reduce_sum(received_grad,
                                             reduction_dim_tensor))


@ops.RegisterGradient("Relu")
def _ReluGrad(op, grad):
  return gen_nn_ops.relu_grad(grad, op.outputs[0])


@ops.RegisterGradient("EluGrad")
def _EluGradGrad(op, grad):
  elu_x = op.inputs[1]
  return (gen_nn_ops.elu_grad(grad, elu_x),
          array_ops.where(
              elu_x < 0, grad * op.inputs[0], array_ops.zeros_like(elu_x)))


@ops.RegisterGradient("SeluGrad")
def _SeluGradGrad(op, grad):
  selu_x = op.inputs[1]
  return (gen_nn_ops.selu_grad(grad, selu_x),
          array_ops.where(
              selu_x < 0., grad * op.inputs[0], array_ops.zeros_like(selu_x)))


@ops.RegisterGradient("Relu6")
def _Relu6Grad(op, grad):
  return gen_nn_ops.relu6_grad(grad, op.outputs[0])


@ops.RegisterGradient("Relu6Grad")
def _Relu6GradGrad(op, grad):
  x = op.inputs[1]
  return (gen_nn_ops.relu6_grad(grad, x),
          array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype))


@ops.RegisterGradient("LeakyRelu")
def _LeakyReluGrad(op, grad):
  x = op.inputs[0]
  alpha = op.get_attr("alpha")
  return gen_nn_ops.leaky_relu_grad(grad, x, alpha=alpha)


@ops.RegisterGradient("LeakyReluGrad")
def _LeakyReluGradGrad(op, grad):
  x = op.inputs[1]
  alpha = op.get_attr("alpha")
  return (gen_nn_ops.leaky_relu_grad(grad, x, alpha=alpha),
          array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype))


@ops.RegisterGradient("Elu")
def _EluGrad(op, grad):
  return gen_nn_ops.elu_grad(grad, op.outputs[0])


@ops.RegisterGradient("Selu")
def _SeluGrad(op, grad):
  return gen_nn_ops.selu_grad(grad, op.outputs[0])


@ops.RegisterGradient("Softplus")
def _SoftplusGrad(op, grad):
  return grad * math_ops.sigmoid(op.inputs[0])


@ops.RegisterGradient("SoftplusGrad")
def _SoftplusGradGrad(op, grad):
  # Let:
  #   y = tf.nn.softplus(x)
  #   dx = gen_nn_ops.softplus_grad(dy, x) = dy / (1 + exp(-x))
  # This op computes (ddy, d2x) from op.inputs == [dy, x] and grad == ddx.
  dy, x = op.inputs
  with ops.control_dependencies([grad]):
    ddy = gen_nn_ops.softplus_grad(grad, x)
    d2x = grad * dy / (math_ops.exp(-x) + 2.0 + math_ops.exp(x))
    return (ddy, d2x)
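

# Illustrative sketch only: the denominator exp(-x) + 2 + exp(x) used for d2x
# in _SoftplusGradGrad above equals 1 / (sigmoid(x) * sigmoid(-x)), so the
# second-derivative term can equivalently be written with sigmoids. The helper
# name is hypothetical and is not used by the registration.
def _softplus_grad_grad_identity_sketch(grad, dy, x):
  """Equivalent form of the d2x term computed in _SoftplusGradGrad."""
  sig = math_ops.sigmoid(x)
  # sigmoid(x) * (1 - sigmoid(x)) == 1 / (exp(-x) + 2 + exp(x)).
  return grad * dy * sig * (1.0 - sig)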


@ops.RegisterGradient("Softsign")
def _SoftsignGrad(op, grad):
  return gen_nn_ops.softsign_grad(grad, op.inputs[0])


@ops.RegisterGradient("ReluGrad")
def _ReluGradGrad(op, grad):
  x = op.inputs[1]
  return (gen_nn_ops.relu_grad(grad, x),
          array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype))


def _BroadcastMul(vec, mat):
  """Multiply after broadcasting vec to match dimensions of mat.

  Args:
    vec: A 1-D tensor of dimension [D0]
    mat: A 2-D tensor of dimension [D0, D1]

  Returns:
    A tensor of dimension [D0, D1], the result of vec * mat
  """
  # Reshape vec to [D0, 1]
  vec = array_ops.expand_dims(vec, -1)
  return vec * mat


@ops.RegisterGradient("SoftmaxCrossEntropyWithLogits")
def _SoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad):
  """Gradient function for SoftmaxCrossEntropyWithLogits."""
  # grad_loss is the backprop for cost, and we multiply it with the gradients
  # (which is output[1])
  # grad_grad is the backprop for softmax gradient.
  #
  # Second derivative is just softmax derivative w.r.t. logits.
  softmax_grad = op.outputs[1]
  grad = _BroadcastMul(grad_loss, softmax_grad)

  logits = op.inputs[0]
  if (grad_grad is not None and
      not getattr(grad_grad, "_is_zeros_tensor", False)):
    softmax = nn_ops.softmax(logits)

    grad += ((grad_grad - array_ops.squeeze(
        math_ops.matmul(
            array_ops.expand_dims(grad_grad, 1),
            array_ops.expand_dims(softmax, 2)),
        axis=1)) * softmax)

  return grad, _BroadcastMul(grad_loss, -nn_ops.log_softmax(logits))


@ops.RegisterGradient("SparseSoftmaxCrossEntropyWithLogits")
def _SparseSoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad):
  """Gradient function for SparseSoftmaxCrossEntropyWithLogits."""
  # grad_loss is the backprop for cost, and we multiply it with the gradients
  # (which is output[1])
  # grad_grad is the backprop for softmax gradient.
  # There is no gradient for the labels.
  #
  # Second derivative is just softmax derivative w.r.t. logits.
  softmax_grad = op.outputs[1]
  grad = _BroadcastMul(grad_loss, softmax_grad)

  logits = op.inputs[0]
  if (grad_grad is not None and
      not getattr(grad_grad, "_is_zeros_tensor", False)):
    softmax = nn_ops.softmax(logits)

    grad += ((grad_grad - array_ops.squeeze(
        math_ops.matmul(
            array_ops.expand_dims(grad_grad, 1),
            array_ops.expand_dims(softmax, 2)),
        axis=1)) * softmax)

  return grad, None
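

# Illustrative sketch only: for SoftmaxCrossEntropyWithLogits, the per-example
# gradient stored in op.outputs[1] is softmax(logits) - labels, so the
# first-order part of the gradient computed above can also be written directly
# in terms of labels and logits. The helper name is hypothetical and is not
# used by the registrations.
def _xent_logits_grad_reference_sketch(labels, logits, grad_loss):
  """First-order gradient of the loss w.r.t. logits, scaled by grad_loss."""
  return _BroadcastMul(grad_loss, nn_ops.softmax(logits) - labels)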


@ops.RegisterGradient("Conv2D")
def _Conv2DGrad(op, grad):
  """Gradient function for Conv2D."""
  dilations = op.get_attr("dilations")
  strides = op.get_attr("strides")
  padding = op.get_attr("padding")
  explicit_paddings = op.get_attr("explicit_paddings")
  use_cudnn_on_gpu = op.get_attr("use_cudnn_on_gpu")
  data_format = op.get_attr("data_format")
  shape_0, shape_1 = array_ops.shape_n([op.inputs[0], op.inputs[1]])

  # We call the gen_nn_ops backprop functions instead of nn_ops backprop
  # functions for performance reasons in Eager mode. gen_nn_ops functions take
  # an `explicit_paddings` parameter, but nn_ops functions do not. So if we
  # were to use the nn_ops functions, we would have to convert `padding` and
  # `explicit_paddings` into a single `padding` parameter, increasing overhead
  # in Eager mode.
  return [
      gen_nn_ops.conv2d_backprop_input(
          shape_0,
          op.inputs[1],
          grad,
          dilations=dilations,
          strides=strides,
          padding=padding,
          explicit_paddings=explicit_paddings,
          use_cudnn_on_gpu=use_cudnn_on_gpu,
          data_format=data_format),
      gen_nn_ops.conv2d_backprop_filter(
          op.inputs[0],
          shape_1,
          grad,
          dilations=dilations,
          strides=strides,
          padding=padding,
          explicit_paddings=explicit_paddings,
          use_cudnn_on_gpu=use_cudnn_on_gpu,
          data_format=data_format)
  ]
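

# Illustrative sketch only: the input-gradient half of _Conv2DGrad above is
# what a GradientTape produces when a forward conv2d is differentiated, since
# the tape dispatches to this registration. The helper name and arguments are
# hypothetical and nothing in this file uses it.
def _conv2d_input_grad_check_sketch(x, filters, strides, padding):
  """Returns d(sum(conv2d(x, filters)))/dx computed via the tape."""
  with backprop.GradientTape() as tape:
    tape.watch(x)
    y = nn_ops.conv2d(x, filters, strides, padding)
  # For a non-scalar target, tape.gradient uses an all-ones upstream gradient,
  # i.e. it differentiates the sum of the conv output.
  return tape.gradient(y, x)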


@ops.RegisterGradient("DepthwiseConv2dNative")
def _DepthwiseConv2dNativeGrad(op, grad):
  return [
      gen_nn_ops.depthwise_conv2d_native_backprop_input(
          array_ops.shape(op.inputs[0]),
          op.inputs[1],
          grad,
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          explicit_paddings=op.get_attr("explicit_paddings"),
          data_format=op.get_attr("data_format")),
      gen_nn_ops.depthwise_conv2d_native_backprop_filter(
          op.inputs[0],
          array_ops.shape(op.inputs[1]),
          grad,
          dilations=op.get_attr("dilations"),
          strides=op.get_attr("strides"),
          padding=op.get_attr("padding"),
          explicit_paddings=op.get_attr("explicit_paddings"),
          data_format=op.get_attr("data_format"))
  ]


@ops.RegisterGradient("Dilation2D")
def _Dilation2DGrad(op, grad):
  return [
      nn_ops.dilation2d_backprop_input(op.inputs[0], op.inputs[1], grad,
                                       op.get_attr("strides"),
                                       op.get_attr("rates"),
                                       op.get_attr("padding")),
      nn_ops.dilation2d_backprop_filter(op.inputs[0], op.inputs[1], grad,
                                        op.get_attr("strides"),
                                        op.get_attr("rates"),
                                        op.get_attr("padding"))
  ]


@ops.RegisterGradient("LRN")
def _LRNGrad(op, grad):
  depth_radius = op.get_attr("depth_radius")
  bias = op.get_attr("bias")
  alpha = op.get_attr("alpha")
  beta = op.get_attr("beta")
  return [
      gen_nn_ops.lrn_grad(grad, op.inputs[0], op.outputs[0], depth_radius, bias,
                          alpha, beta)
  ]


@ops.RegisterGradient("AvgPool")
def _AvgPoolGrad(op, grad):
  return gen_nn_ops.avg_pool_grad(
      array_ops.shape(op.inputs[0]),
      grad,
      op.get_attr("ksize"),
      op.get_attr("strides"),
      op.get_attr("padding"),
      data_format=op.get_attr("data_format"))


@ops.RegisterGradient("AvgPoolGrad")
def _AvgPoolGradGrad(op, grad):
  return (array_ops.stop_gradient(op.inputs[0]),
          gen_nn_ops.avg_pool(
              grad,
              op.get_attr("ksize"),
              op.get_attr("strides"),
              op.get_attr("padding"),
              data_format=op.get_attr("data_format")))


@ops.RegisterGradient("MaxPool")
def _MaxPoolGrad(op, grad):
  return gen_nn_ops.max_pool_grad(
      op.inputs[0],
      op.outputs[0],
      grad,
      op.get_attr("ksize"),
      op.get_attr("strides"),
      padding=op.get_attr("padding"),
      explicit_paddings=op.get_attr("explicit_paddings"),
      data_format=op.get_attr("data_format"))


@ops.RegisterGradient("MaxPoolV2")
def _MaxPoolGradV2(op, grad):
  ksize = op.inputs[1]
  strides = op.inputs[2]
  return gen_nn_ops.max_pool_grad_v2(
      op.inputs[0],
      op.outputs[0],
      grad,
      ksize,
      strides,
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format")), None, None


@ops.RegisterGradient("MaxPoolWithArgmax")
def _MaxPoolGradWithArgmax(op, grad, unused_argmax_grad):
  del unused_argmax_grad
  return gen_nn_ops.max_pool_grad_with_argmax(
      op.inputs[0],
      grad,
      op.outputs[1],
      op.get_attr("ksize"),
      op.get_attr("strides"),
      padding=op.get_attr("padding"),
      include_batch_in_index=op.get_attr("include_batch_in_index"))


@ops.RegisterGradient("MaxPoolGrad")
def _MaxPoolGradGrad(op, grad):
  return (array_ops.zeros(
      shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype),
          array_ops.zeros(
              shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
          gen_nn_ops.max_pool_grad_grad(
              op.inputs[0],
              op.inputs[1],
              grad,
              op.get_attr("ksize"),
              op.get_attr("strides"),
              padding=op.get_attr("padding"),
              data_format=op.get_attr("data_format")))


@ops.RegisterGradient("MaxPoolGradV2")
def _MaxPoolGradGradV2(op, grad):
  ksize = op.inputs[3]
  strides = op.inputs[4]
  return (array_ops.zeros(
      shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype),
          array_ops.zeros(
              shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
          gen_nn_ops.max_pool_grad_grad_v2(
              op.inputs[0],
              op.inputs[1],
              grad,
              ksize,
              strides,
              padding=op.get_attr("padding"),
              data_format=op.get_attr("data_format")), None, None)


@ops.RegisterGradient("MaxPoolGradGrad")
def _MaxPoolGradGradGrad(op, grad):
  return (array_ops.zeros(
      shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype),
          array_ops.zeros(
              shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
          gen_nn_ops.max_pool_grad(
              op.inputs[0],
              op.inputs[1],
              grad,
              op.get_attr("ksize"),
              op.get_attr("strides"),
              padding=op.get_attr("padding"),
              data_format=op.get_attr("data_format")))


@ops.RegisterGradient("FractionalMaxPool")
def _FractionalMaxPoolGrad(op, grad_0, unused_grad_1, unused_grad_2):
  """Returns gradient for FractionalMaxPool.

  Since FractionalMaxPool has three outputs, three gradients are passed in,
  one for each of the outputs. Only the first one is useful; the other two
  gradients are empty.

  Args:
    op: The FractionalMaxPoolOp.
    grad_0: Gradient with respect to op.outputs[0]
    unused_grad_1: Gradient with respect to op.outputs[1]/row_seq. It is empty.
    unused_grad_2: Gradient with respect to op.outputs[2]/col_seq. It is empty.

  Returns:
    Input backprop for FractionalMaxPool op.
  """
  return gen_nn_ops.fractional_max_pool_grad(
      op.inputs[0], op.outputs[0], grad_0, op.outputs[1], op.outputs[2],
      op.get_attr("overlapping"))


@ops.RegisterGradient("FractionalAvgPool")
def _FractionalAvgPoolGrad(op, grad_0, unused_grad_1, unused_grad_2):
  """Returns gradient for FractionalAvgPool.

  Since FractionalAvgPool has three outputs, three gradients are passed in,
  one for each of the outputs. Only the first one is useful; the other two
  gradients are empty.

  Args:
    op: The FractionalAvgPoolOp.
    grad_0: Gradient with respect to op.outputs[0]
    unused_grad_1: Gradient with respect to op.outputs[1]/row_seq. It is empty.
    unused_grad_2: Gradient with respect to op.outputs[2]/col_seq. It is empty.

  Returns:
    Input backprop for FractionalAvgPool op.
  """
  return gen_nn_ops.fractional_avg_pool_grad(op.inputs[0].get_shape(), grad_0,
                                             op.outputs[1], op.outputs[2],
                                             op.get_attr("overlapping"))


@ops.RegisterGradient("BatchNormWithGlobalNormalization")
def _BatchNormWithGlobalNormalizationGrad(op, grad):
  """Return the gradients for the 5 inputs of BatchNormWithGlobalNormalization.

  We do not backprop anything for the mean and var intentionally as they are
  not being trained with backprop in the operation.

  Args:
    op: The BatchNormOp for which we need to generate gradients.
    grad: Tensor. The gradients passed to the BatchNormOp.

  Returns:
    dx: Backprop for input, which is (grad * (g * rsqrt(v + epsilon)))
    dm: Backprop for mean, which is
        sum_over_rest(grad * g) * (-1 / rsqrt(v + epsilon))
    dv: Backprop for variance, which is
        sum_over_rest(grad * g * (x - m)) * (-1/2) * (v + epsilon) ^ (-3/2)
    db: Backprop for beta, which is grad reduced in all except the
        last dimension.
    dg: Backprop for gamma, which is (grad * ((x - m) * rsqrt(v + epsilon)))
  """
  dx, dm, dv, db, dg = gen_nn_ops.batch_norm_with_global_normalization_grad(
      op.inputs[0], op.inputs[1], op.inputs[2], op.inputs[4], grad,
      op.get_attr("variance_epsilon"),
      op.get_attr("scale_after_normalization"))
  return dx, dm, dv, db, dg
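

# Illustrative sketch only: two of the formulas quoted in the docstring above,
# written out directly. It assumes a 4-D NHWC input, assumes
# scale_after_normalization=True, and makes the per-channel reduction for dg
# explicit so the result matches gamma's shape. The helper name is
# hypothetical and nothing in this file calls it.
def _batch_norm_global_dx_dg_sketch(x, m, v, gamma, grad, variance_epsilon):
  """Returns (dx, dg) for the dx and dg formulas listed in the docstring."""
  inv_std = math_ops.rsqrt(v + variance_epsilon)
  dx = grad * gamma * inv_std
  # Reduce over batch and spatial dimensions so dg has gamma's [C] shape.
  dg = math_ops.reduce_sum(grad * (x - m) * inv_std, axis=[0, 1, 2])
  return dx, dg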


def _BaseFusedBatchNormGrad(op, version, *grad):
  """Return the gradients for the 3 inputs of BatchNorm.

  Args:
    op: The BatchNormOp for which we need to compute gradients.
    version: Integer indicating which version of the fused batch norm gradient
      to use.
    *grad: An argument list for tensors of gradients wrt the outputs with
      grad[0] as grad_y.

  Returns:
    grad_x: gradient for x, which is scale * rsqrt(variance + epsilon) *
      [grad_y - mean(grad_y) - (x - mean(x)) *
      mean(grad_y * (x - mean(x))) / (variance + epsilon)]
      in training mode; grad_y * scale * rsqrt(pop_variance + epsilon)
      in freeze mode.

    grad_scale: gradient for scale, which is sum(grad_y * (x - mean(x)) *
      rsqrt(variance + epsilon)) in training mode;
      sum(grad_y * (x - pop_mean) * rsqrt(pop_variance + epsilon))
      in freeze mode.

    grad_offset: gradient for offset, which is sum(grad_y) in training mode;
      sum(grad_y) in freeze mode.
  """
  x = op.inputs[0]
  grad_y = grad[0]
  scale = op.inputs[1]
  epsilon = op.get_attr("epsilon")
  data_format = op.get_attr("data_format")
  is_training = op.get_attr("is_training")
  if version == 2:
    grad_fun = gen_nn_ops.fused_batch_norm_grad_v3
  elif version == 1:
    grad_fun = gen_nn_ops.fused_batch_norm_grad_v2
  else:
    grad_fun = gen_nn_ops.fused_batch_norm_grad
  if is_training:
    args = {
        "y_backprop": grad_y,
        "x": x,
        "scale": scale,
        "reserve_space_1": op.outputs[3],
        "reserve_space_2": op.outputs[4],
        "epsilon": epsilon,
        "data_format": data_format,
        "is_training": is_training
    }
    if version == 2:
      args["reserve_space_3"] = op.outputs[5]
    dx, dscale, doffset, _, _ = grad_fun(**args)
  else:
    pop_mean = op.inputs[3]
    pop_var = op.inputs[4]
    if data_format == b"NCHW":
      x = array_ops.transpose(x, [0, 2, 3, 1])
      grad_y = array_ops.transpose(grad_y, [0, 2, 3, 1])
    elif data_format == b"NCDHW":
      x = array_ops.transpose(x, [0, 2, 3, 4, 1])
      grad_y = array_ops.transpose(grad_y, [0, 2, 3, 4, 1])
    target_data_format = ("NHWC" if data_format in (b"NCHW",
                                                    b"NHWC") else "NDHWC")
    args = {
        "y_backprop": grad_y,
        "x": x,
        "scale": scale,
        "reserve_space_1": pop_mean,
        "reserve_space_2": pop_var,
        "epsilon": epsilon,
        "data_format": target_data_format,
        "is_training": is_training
    }
    if version == 2:
      args["reserve_space_3"] = op.outputs[5]
    dx, dscale, doffset, _, _ = grad_fun(**args)
    if data_format == b"NCHW":
      dx = array_ops.transpose(dx, [0, 3, 1, 2])
    elif data_format == b"NCDHW":
      dx = array_ops.transpose(dx, [0, 4, 1, 2, 3])
  return dx, dscale, doffset, None, None


@ops.RegisterGradient("FusedBatchNorm")
def _FusedBatchNormGrad(op, *grad):
  return _BaseFusedBatchNormGrad(op, 0, *grad)


@ops.RegisterGradient("FusedBatchNormV2")
def _FusedBatchNormV2Grad(op, *grad):
  return _BaseFusedBatchNormGrad(op, 1, *grad)


@ops.RegisterGradient("FusedBatchNormV3")
def _FusedBatchNormV3Grad(op, *grad):
  return _BaseFusedBatchNormGrad(op, 2, *grad)


def _BatchNormGrad(grad_y,
                   x,
                   scale,
                   pop_mean,
                   pop_var,
                   epsilon,
                   data_format,
                   is_training=True):
  """Returns the gradients for the 3 inputs of BatchNorm.

  Args:
    grad_y: A `Tensor` of 4 or 5 dimensions for gradient for y.
    x: A `Tensor` of 4 or 5 dimensions for x.
    scale: A `Tensor` of 1 dimension for scaling.
    pop_mean: A `Tensor` of 1 dimension for the population mean. Only used when
      is_training=False.
    pop_var: A `Tensor` of 1 dimension for the population variance. Only used
      when is_training=False.
    epsilon: A small float number added to the variance of x.
    data_format: The data format for input. One of b"NHWC", b"NCHW", b"NDHWC",
      or b"NCDHW".
    is_training: A bool value to indicate the operation is for training
      (default) or inference.

  Returns:
    A tuple (grad_x, grad_scale, grad_offset), where grad_x is the gradient
    for x, grad_scale the gradient for scale, and grad_offset the gradient
    for offset.
  """
  x_dtype = x.dtype.base_dtype
  if x_dtype == dtypes.float16:
    # float16 math is too imprecise, so we do the batch norm gradient
    # computations in float32.
    x = math_ops.cast(x, dtypes.float32)
    grad_y = math_ops.cast(grad_y, dtypes.float32)
  if is_training:
    if data_format == b"NHWC":
      keepdims = False
      reduce_axis = [0, 1, 2]
    elif data_format == b"NDHWC":
      keepdims = False
      reduce_axis = [0, 1, 2, 3]
    elif data_format == b"NCHW":
      keepdims = True
      reduce_axis = [0, 2, 3]
      shape = [1, array_ops.size(scale), 1, 1]
      scale = array_ops.reshape(scale, shape)
    else:
      keepdims = True
      reduce_axis = [0, 2, 3, 4]
      shape = [1, array_ops.size(scale), 1, 1, 1]
      scale = array_ops.reshape(scale, shape)
    mean_grad_y = math_ops.reduce_mean(grad_y, reduce_axis, keepdims=keepdims)
    mean_x = math_ops.reduce_mean(x, reduce_axis, keepdims=keepdims)
    var_x = math_ops.reduce_mean(
        math_ops.squared_difference(x, array_ops.stop_gradient(mean_x)),
        reduce_axis,
        keepdims=keepdims)
    grad_y_offset = grad_y - mean_grad_y
    x_offset = x - mean_x
    mean = math_ops.reduce_mean(
        grad_y * x_offset, axis=reduce_axis, keepdims=keepdims)
    grad_x = scale * math_ops.rsqrt(var_x + epsilon) * (
        grad_y_offset - math_ops.reciprocal(var_x + epsilon) * mean * x_offset)
    grad_scale = math_ops.rsqrt(var_x + epsilon) * math_ops.reduce_sum(
        grad_y * x_offset, axis=reduce_axis, keepdims=keepdims)
    if data_format == b"NCHW" or data_format == b"NCDHW":
      grad_scale = array_ops.squeeze(grad_scale)
    grad_offset = math_ops.reduce_sum(grad_y, axis=reduce_axis)
    return math_ops.cast(grad_x, x_dtype), grad_scale, grad_offset
  else:
    if data_format == b"NHWC":
      reduce_axis = [0, 1, 2]
    elif data_format == b"NDHWC":
      reduce_axis = [0, 1, 2, 3]
    elif data_format == b"NCHW":
      reduce_axis = [0, 2, 3]
      shape = [1, array_ops.size(pop_mean), 1, 1]
      pop_mean = array_ops.reshape(pop_mean, shape)
      pop_var = array_ops.reshape(pop_var, shape)
      scale = array_ops.reshape(scale, shape)
    else:
      reduce_axis = [0, 2, 3, 4]
      shape = [1, array_ops.size(pop_mean), 1, 1, 1]
      pop_mean = array_ops.reshape(pop_mean, shape)
      pop_var = array_ops.reshape(pop_var, shape)
      scale = array_ops.reshape(scale, shape)

    grad_offset = math_ops.reduce_sum(grad_y, axis=reduce_axis)
    var_rsqrt = math_ops.rsqrt(pop_var + epsilon)
    grad_scale = math_ops.reduce_sum(
        grad_y * (x - pop_mean) * var_rsqrt, axis=reduce_axis)
    grad_x = grad_y * scale * var_rsqrt
    return math_ops.cast(grad_x, x_dtype), grad_scale, grad_offset


@ops.RegisterGradient("FusedBatchNormGrad")
def _FusedBatchNormGradGrad(op, *grad):
  """Returns the gradients for the 3 inputs of FusedBatchNormGrad.

  Args:
    op: The FusedBatchNormGradOp for which we need to compute gradients.
    *grad: An argument list for tensors of gradients wrt the outputs with
      grad[0] as grad_grad_x, grad[1] as grad_grad_scale, grad[2] as
      grad_grad_offset.

  Returns:
    A tuple (grad_grad_y, grad_x, grad_scale, None, None), where grad_grad_y
    is the gradient for grad_y, grad_x the gradient for x, grad_scale the
    gradient for scale.
  """
  data_format = op.get_attr("data_format")
  epsilon = op.get_attr("epsilon")
  is_training = op.get_attr("is_training")
  grad_y = op.inputs[0]
  x = op.inputs[1]
  scale = op.inputs[2]
  pop_mean = op.inputs[3]
  pop_var = op.inputs[4]
  grad_grad_x = grad[0]
  grad_grad_scale = grad[1]
  grad_grad_offset = grad[2]
  with backprop.GradientTape() as tape:
    tape.watch(grad_y)
    tape.watch(x)
    tape.watch(scale)
    grad_x, grad_scale, grad_offset = _BatchNormGrad(
        grad_y, x, scale, pop_mean, pop_var, epsilon, data_format, is_training)
    grad_initial = [grad_grad_x, grad_grad_scale, grad_grad_offset]
  grad_grad_y, grad_x, grad_scale = tape.gradient(
      [grad_x, grad_scale, grad_offset], [grad_y, x, scale], grad_initial)
  return grad_grad_y, grad_x, grad_scale, None, None


@ops.RegisterGradient("FusedBatchNormGradV2")
def _FusedBatchNormGradGradV2(op, *grad):
  return _FusedBatchNormGradGrad(op, *grad)


@ops.RegisterGradient("FusedBatchNormGradV3")
def _FusedBatchNormGradGradV3(op, *grad):
  grad_grad_y, grad_x, grad_scale, _, _ = _FusedBatchNormGradGrad(op, *grad)
  return grad_grad_y, grad_x, grad_scale, None, None, None


@ops.RegisterGradient("L2Loss")
def _L2LossGrad(op, grad):
  """Return the gradients for L2Loss.

  Args:
    op: The L2LossOp for which we need to generate gradients.
    grad: Tensor containing a single number.

  Returns:
    The gradient, which is (x * grad).
  """
  return op.inputs[0] * grad


@ops.RegisterGradient("TopK")
@ops.RegisterGradient("TopKV2")
def _TopKGrad(op, grad, _):
  """Return the gradients for TopK.

  Args:
    op: The TopKOp for which we need to generate gradients.
    grad: Tensor. The gradients passed to the TopKOp.

  Returns:
    A list of two tensors, the first being the gradient w.r.t. the input of
    TopK, and the second being the gradient w.r.t. the indices (all zero).
  """
  in_shape = array_ops.shape(op.inputs[0])
  ind_shape = array_ops.shape(op.outputs[1])

  # int32 is not supported on GPU hence up-casting
  ind_lastdim = array_ops.gather(
      math_ops.cast(ind_shape, dtypes.int64),
      array_ops.size(ind_shape) - 1)
  # Flatten indices to 2D.
  ind_2d = array_ops.reshape(op.outputs[1], array_ops.stack([-1, ind_lastdim]))

  in_lastdim = array_ops.gather(
      math_ops.cast(in_shape, dtypes.int64),
      array_ops.size(in_shape) - 1)
  outerdim = array_ops.shape(ind_2d)[0]
  # Compute linear indices (flattened to 1D).
  ind = array_ops.reshape(
      ind_2d + math_ops.cast(
          array_ops.expand_dims(
              math_ops.range(0,
                             math_ops.cast(outerdim, dtypes.int64) * in_lastdim,
                             in_lastdim), -1), dtypes.int32), [-1])

  # Substitute grad to appropriate locations and fill the rest with zeros,
  # finally reshaping it to the original input shape.
  return [
      array_ops.reshape(
          array_ops.scatter_nd(
              array_ops.expand_dims(ind, -1), array_ops.reshape(grad, [-1]),
              [math_ops.reduce_prod(in_shape)]), in_shape),
      array_ops.zeros([], dtype=dtypes.int32)
  ]
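

# Illustrative sketch only: for the common 2-D case, the flattened scatter in
# _TopKGrad above is equivalent to a batched scatter_nd with explicit row
# indices. The helper name is hypothetical and nothing in this file uses it.
def _topk_grad_2d_sketch(grad, indices, num_columns):
  """Scatters [batch, k] gradients back to a [batch, num_columns] tensor."""
  batch = array_ops.shape(indices)[0]
  k = array_ops.shape(indices)[1]
  # Pair each top-k column index with its row index to form (row, col) pairs.
  row_ids = array_ops.tile(
      array_ops.expand_dims(math_ops.range(batch), 1), [1, k])
  full_indices = array_ops.stack([row_ids, indices], axis=-1)
  # Positions not selected by TopK receive zero gradient.
  return array_ops.scatter_nd(full_indices, grad,
                              array_ops.stack([batch, num_columns]))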


@ops.RegisterGradient("NthElement")
def _NthElementGrad(op, grad):
  """Return the gradients for NthElement.

  Args:
    op: The NthElementOp for which we need to generate gradients.
    grad: Tensor. The gradients passed to the NthElementOp.

  Returns:
    A list of two tensors, the first being the gradient w.r.t. the input,
    the second being the gradient w.r.t. the N (None).
  """
  input = op.inputs[0]  # pylint: disable=redefined-builtin
  output = op.outputs[0]

  # Compute the number of elements which are equal to output in each reduction
  # dimension. If there are multiple elements then the gradient will be
  # divided between them.
  indicators = math_ops.cast(
      math_ops.equal(array_ops.expand_dims(output, -1), input), grad.dtype)

  grad = array_ops.expand_dims(grad, -1)
  num_selected = array_ops.expand_dims(math_ops.reduce_sum(indicators, -1), -1)

  return [math_ops.divide(indicators, num_selected) * grad, None]


def _MeanAggregator(inputs, segments):
  """Replaces each segment with its mean along the last axis.

  Specifically, each value in the `inputs` tensor gets replaced by the mean
  value computed from the values that belong to the same segment.

  Args:
    inputs: A 2-tensor. Aggregation is done over dimension 1.
    segments: A 2-tensor, same shape as `inputs`.

  Returns:
    The result, same shape and type as `inputs`.
  """
  result = []
  for inputs_i, segments_i in zip(
      array_ops.split(inputs, inputs.shape[0]),
      array_ops.split(segments, segments.shape[0])):
    # Note that we do not use tf.math.segment_mean, as it has no TPU support.
    means_i = math_ops.unsorted_segment_mean(
        inputs_i, segments_i, num_segments=math_ops.reduce_max(segments_i) + 1)
    result.append(
        array_ops.reshape(array_ops.gather(means_i, segments_i), [-1]))
  return array_ops.stack(result, axis=0)


# We have to register the gradients for these ops so that tensorflow will know
# how to differentiate them.
@ops.RegisterGradient("IsotonicRegression")
def _IsotonicRegressionGrad(op, grad_output, grad_segments):
  """Gradient for the isotonic regression function.

  Args:
    op: The IsotonicRegression tensorflow op.
    grad_output: Tensor of incoming gradients with respect to the output.
    grad_segments: Tensor of incoming gradients with respect to the segments.

  Returns:
    A tensor, same size as `grad_output` with the gradient with respect to
    the input.
  """
  del grad_segments  # Discrete, non-differentiable.
  segments = op.outputs[1]
  return _MeanAggregator(grad_output, segments)