# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Bprop (gradient) definitions for array operations."""

import numpy as np
import mindspore as ms
from mindspore.ops import composite as C
from .. import operations as P
from ..operations import _grad_ops as G
from ..operations import _inner_ops as inner
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..functional import broadcast_gradient_args
from .. import functional as F
from .grad_base import bprop_getters
from ..primitive import constexpr
from ... import context
from ...common import dtype as mstype
from ...common.tensor import RowTensor
from .._utils.utils import range_op, get_1d_shape, generate_shape_index

# Primitive instances shared by the bprop closures below.
reduce_sum = P.ReduceSum()
unsorted_segment_sum = P.UnsortedSegmentSum()
transpose = P.Transpose()
shape_op = P.Shape()
dyn_shape_op = P.DynamicShape()
reshape = P.Reshape()
size_op = P.Size()
invert_permutation = P.InvertPermutation()
logical_and = P.LogicalAnd()
is_sub_class = P.IsSubClass()


@bprop_getters.register(P.Fill)
def get_bprop_fill(self):
    """Generate bprop for Fill"""

    def bprop(dtype, dims, x, out, dout):
        # NOTE(review): returns two gradients for three inputs (no entry for
        # `dtype`); presumably acceptable because dtype is a constant input — confirm.
        return zeros_like(dims), zeros_like(x)

    return bprop


@bprop_getters.register(P.Ones)
def get_bprop_ones(self):
    """Generate bprop for Ones"""

    def bprop(dims, dtype, out, dout):
        # NOTE(review): returns a bare value rather than a tuple; presumably fine
        # because both inputs are constants — confirm.
        return zeros_like(dims)

    return bprop


@bprop_getters.register(P.Zeros)
def get_bprop_zeros(self):
    """Generate bprop for Zeros"""

    def bprop(dims, dtype, out, dout):
        # NOTE(review): returns a bare value rather than a tuple (see Ones).
        return zeros_like(dims)

    return bprop


@bprop_getters.register(P.DType)
def get_bprop_dtype(self):
    """Generate bprop for DType"""

    def bprop(x, out, dout):
        return (zeros_like(x),)

    return bprop


# Casts an incoming gradient to the dtype of the forward input; dispatches on
# (dout, x) type pairs so RowTensor gradients are handled in sparse mode.
dout_cast = C.MultitypeFuncGraph("dout_cast")


@dout_cast.register("Tensor", "Tensor")
def dout_cast_tensor(dout, x):
    """Casts dout to the dtype of x for Tensor."""
    cast = P.Cast()
    get_dtype = P.DType()
    dx = cast(dout, get_dtype(x))
    return dx


@dout_cast.register("Number", "Number")
def dout_cast_number(dout, x):
    """Casts dout to the dtype of x for Number."""
    cast = P.Cast()
    get_dtype = P.DType()
    dx = cast(dout, get_dtype(x))
    return dx


@dout_cast.register("RowTensor", "Tensor")
def dout_cast_row_tensor(dout, x):
    """Casts dout values to the dtype of x for RowTensor."""
    cast = P.Cast()
    get_dtype = P.DType()
    values = cast(dout.values, get_dtype(x))
    return RowTensor(dout.indices, values, dout.dense_shape)


@bprop_getters.register(P.Cast)
def get_bprop_cast(self):
    """Generate bprop for Cast.

    The gradient is cast back to the input's dtype. When the context flag
    `enable_sparse` is set, the multitype `dout_cast` is used so RowTensor
    gradients are also supported.
    """
    cast = P.Cast()
    get_dtype = P.DType()

    def bprop(x, t, out, dout):
        dx = cast(dout, get_dtype(x))
        return dx, zeros_like(t)

    def bprop_sparse(x, t, out, dout):
        dx = dout_cast(dout, x)
        return dx, zeros_like(t)

    if context.get_context('enable_sparse'):
        return bprop_sparse

    return bprop


@bprop_getters.register(P.Shape)
def get_bprop_shape(self):
    """Generate bprop for Shape"""

    def bprop(x, out, dout):
        return (zeros_like(x),)

    return bprop


@bprop_getters.register(P.DynamicShape)
def get_bprop_dynamicshape(self):
    """Generate bprop for DynamicShape"""

    def bprop(x, out, dout):
        return (zeros_like(x),)

    return bprop


@bprop_getters.register(P.Split)
def get_bprop_split(self):
    """Generate bprop for Split: concatenate the output gradients along `axis`."""
    axis = self.axis

    def bprop(x, out, dout):
        concat_op = P.Concat(axis)
        dx = concat_op(dout)
        return (dx,)

    return bprop


@bprop_getters.register(P.Rank)
def get_bprop_rank(self):
    """Generate bprop for Rank"""

    def bprop(x, out, dout):
        return (zeros_like(x),)

    return bprop


@bprop_getters.register(P.Reshape)
def get_bprop_reshape(self):
    """Generate bprop for Reshape: reshape dout back to the input shape."""

    def bprop(x, shp, out, dout):
        shapex = shape_op(x)
        return reshape(dout, shapex), zeros_like(shp)

    return bprop


@bprop_getters.register(P.ExpandDims)
def get_bprop_expand_dims(self):
    """Generate bprop for ExpandDims: reshape dout back to the input shape."""

    def bprop(x, axis, out, dout):
        shapex = shape_op(x)
        return reshape(dout, shapex), zeros_like(axis)

    return bprop


@bprop_getters.register(P.Squeeze)
def get_bprop_squeeze(self):
    """Generate bprop for Squeeze: reshape dout back to the input shape."""

    def bprop(x, out, dout):
        shapex = shape_op(x)
        return (reshape(dout, shapex),)

    return bprop


@bprop_getters.register(P.Flatten)
def get_bprop_flatten(self):
    """Generate bprop for Flatten: reshape dout back to the input shape."""
    flatten_grad = P.Reshape()

    def bprop(x, out, dout):
        dx = flatten_grad(dout, shape_op(x))
        return (dx,)

    return bprop


@constexpr
def _tile_shape(multiples, shapex):
    """Interleave `multiples` with `shapex` to build Tile's gradient reshape target.

    Example: (1, 2), (3, 4) -> (1, 3, 2, 4).

    The two sequences are aligned on their trailing dimensions; whichever is
    shorter is left-padded with 1s, e.g. (2, 1, 2), (3, 4) -> (2, 1, 1, 3, 2, 4)
    and (2,), (3, 4) -> (1, 3, 2, 4).
    """
    rank_diff = len(multiples) - len(shapex)
    padded_multiples = (1,) * max(-rank_diff, 0) + tuple(multiples)
    padded_shape = (1,) * max(rank_diff, 0) + tuple(shapex)
    ret = []
    for multiple, dim in zip(padded_multiples, padded_shape):
        ret.extend((multiple, dim))
    return tuple(ret)


@bprop_getters.register(P.Tile)
def get_bprop_tile(self):
    """Generate bprop for Tile.

    dout is reshaped to (m0, d0, m1, d1, ...) and summed over the multiple
    positions (even axes), then reshaped back to the input shape.
    """

    def bprop(x, multiples, out, dout):
        shapex = shape_op(x)
        r_shape = _tile_shape(multiples, shapex)
        # 0 represents the start index, and 2 represents the step
        axis = F.make_range(0, len(r_shape), 2)
        dx = reduce_sum(reshape(dout, r_shape), axis)
        dx = reshape(dx, shapex)
        return dx, zeros_like(multiples)

    return bprop


@bprop_getters.register(P.EmbeddingLookup)
def get_bprop_embedding_lookup(self):
    """Generate bprop for EmbeddingLookup.

    Produces a RowTensor gradient whose indices are `indices - offset`
    (flattened) and whose values are dout reshaped to (num_indices,) + x_shape[1:].
    """
    sub_op = P.Sub()
    reshape_op = P.Reshape()

    def bprop_sparse(x, indices, offset, out, dout):
        x_shp = shape_op(x)
        new_indices = sub_op(indices, offset)
        indices_size = size_op(new_indices)
        if indices_size > 0:
            # Reshape the 'new_indices'
            new_indices_shape_changed = (indices_size,)
            new_indices = reshape_op(new_indices, new_indices_shape_changed)
        else:
            new_indices_shape_changed = ()
        x_shp_tail = x_shp[1:]
        actual_dout_shape_changed = new_indices_shape_changed + x_shp_tail
        # Reshape the 'actual_dout' on device
        actual_dout = reshape_op(dout, actual_dout_shape_changed)
        return RowTensor(new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset)

    return bprop_sparse


@constexpr
def make_begin(shp):
    """Creates a tuple with zero according to the shape."""
    begin = tuple([0 for _ in shp])
    return begin


@bprop_getters.register(P.Padding)
def get_bprop_padding(self):
    """Grad definition for `Padding` operation: slice the input-shaped region from dout."""

    def bprop(x, out, dout):
        shp = shape_op(x)
        begin = make_begin(shp)
        dx = P.Slice()(dout, begin, shp)
        return (dx,)

    return bprop


@bprop_getters.register(P.Transpose)
def get_bprop_transpose(self):
    """Generate bprop for Transpose: apply the inverse permutation to dout."""

    def bprop(x, perm, out, dout):
        return transpose(dout, invert_permutation(perm)), zeros_like(perm)

    return bprop


@constexpr
def _concat_grad_uniform(input_shapes, input_nums):
    """Return True when all `input_nums` shapes in `input_shapes` are equal (bprop of Concat)."""
    is_uniform = True
    for i in range(1, input_nums):
        if input_shapes[i - 1] != input_shapes[i]:
            is_uniform = False
            break
    return is_uniform


@bprop_getters.register(P.Concat)
def get_bprop_concat(self):
    """Generate bprop for Concat.

    Equal-shaped inputs are handled with a single Split; otherwise each input's
    gradient is sliced from dout at the offset reported by ConcatOffset. The
    gradient container type (list or tuple) matches the input's.
    """
    axis = self.axis

    def bprop(x, out, dout):
        out_offset = G.ConcatOffset(len(x), axis)(x)
        input_nums = len(x)
        input_shapes = ()
        for i in range(input_nums):
            input_shapes = input_shapes + (shape_op(x[i]),)
        is_uniform = _concat_grad_uniform(input_shapes, input_nums)
        if isinstance(x, list):
            dx = []
            if is_uniform:
                dx_tuple = P.Split(axis, input_nums)(dout)
                for item in dx_tuple:
                    dx = dx + [item]
            else:
                for i in range(input_nums):
                    slice_out = P.Slice()(dout, out_offset[i], input_shapes[i])
                    dx = dx + [slice_out]
        else:
            dx = ()
            if is_uniform:
                dx = P.Split(axis, input_nums)(dout)
            else:
                for i in range(input_nums):
                    slice_out = P.Slice()(dout, out_offset[i], input_shapes[i])
                    dx = dx + (slice_out,)
        return (dx,)

    return bprop


@constexpr
def _slice_grad_pad(begins, sizes, shapes):
    """Return per-dimension (before, after) padding that places a slice back into its source shape."""
    pads = tuple((begin, shape - begin - size) for begin, size, shape in zip(begins, sizes, shapes))
    return pads


@bprop_getters.register(P.Slice)
def get_bprop_slice(self):
    """Generate bprop for Slice"""

    def bprop(x, begin, size, out, dout):
        dx = G.SliceGrad()(dout, x, begin, size)
        return (dx, zeros_like(begin), zeros_like(size))

    return bprop


@constexpr
def _generate_inverse_index(x_shape, axis):
    """Permutation that moves dimension 0 to position `axis` (negative axis allowed).

    Example: rank 3, axis 1 -> (1, 0, 2); rank 3, axis 2 -> (1, 2, 0).
    """
    x_rank = len(x_shape)
    index = tuple(range(x_rank))
    if axis < 0:
        axis += x_rank
    perm = index[1:1 + axis] + (0,) + index[1 + axis:]
    return perm


@constexpr
def _regenerate_output_shape(x_shp, ind_shp, axis):
    """Shape of a gather output: `x_shp` with dimension `axis` replaced by `ind_shp`."""
    rank = len(x_shp)
    if axis < 0:
        axis += rank
    out_shape = x_shp[:axis] + ind_shp + x_shp[axis + 1:]
    return out_shape


@bprop_getters.register(P.Gather)
@bprop_getters.register(P.GatherV2)
def get_bprop_gather_v2(self):
    """Generate bprop for GatherV2.

    The gradient is built by moving the index dimensions of dout to the front,
    summing rows that share an index with UnsortedSegmentSum, and transposing
    the result back so the gathered axis returns to position `axis`.
    """

    def bprop(x, indices, axis, out, dout):
        orig_indices = indices
        if F.rank(dout) == 0:
            dout = P.ExpandDims()(dout, -1)
        if F.rank(indices) == 0:
            indices = P.ExpandDims()(indices, -1)
        x_shp = shape_op(x)
        ind_shp = shape_op(indices)
        # Scalar indices drop the gathered axis from `out`; restore it in dout.
        out_shp = _regenerate_output_shape(x_shp, ind_shp, axis)
        dout = reshape(dout, out_shp)

        out_shp = shape_op(dout)
        # Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
        perm_1 = generate_shape_index(out_shp, ind_shp, axis)
        values_transpose = transpose(dout, perm_1)
        if -1 in x_shp:
            params_grad = unsorted_segment_sum(values_transpose, indices, dyn_shape_op(x)[axis])
        else:
            params_grad = unsorted_segment_sum(values_transpose, indices, x_shp[axis])
        # Example: out_shape:(3,2,3) axis 2 -> (1,2,0)
        perm_2 = _generate_inverse_index(x_shp, axis)
        params_grad = transpose(params_grad, perm_2)
        return params_grad, zeros_like(orig_indices), zeros_like(axis)

    return bprop


@bprop_getters.register(P.GatherD)
def get_bprop_gather_d(self):
    """Generate bprop for GatherD"""

    def bprop(x, dim, index, out, dout):
        x_shp = shape_op(x)
        dx = G.GatherDGrad(dim, x_shp)(index, dout)
        return dx, zeros_like(dim), zeros_like(index)

    return bprop


@bprop_getters.register(G.GatherDGrad)
def get_bprop_gather_d_grad(self):
    """Generate bprop for GatherDGrad.

    GatherDGrad(index, grad) scatters `grad` into a tensor of shape `out_shape`,
    so its gradient with respect to `grad` is dout read back at the same
    positions: for each flattened element (i, j, k) of `index` (before / at /
    after `dim`), read dout[i, index[i, j, k], k].
    """
    op = P.Gather()
    dim = self.dim
    x_shp = self.out_shape
    # Normalise a negative axis so the before/after products below are correct.
    if dim < 0:
        dim += len(x_shp)

    def bprop(index, x, out, dout):
        index_shp = shape_op(index)
        dim_before_axis = 1
        for i in range(dim):
            dim_before_axis *= x_shp[i]
        dim_at_axis_index = index_shp[dim]
        dim_at_axis_output = x_shp[dim]
        dim_after_axis = 1
        for i in range(dim + 1, len(x_shp)):
            dim_after_axis *= x_shp[i]
        element = dim_before_axis * dim_at_axis_index * dim_after_axis
        id_ = range_op(0, element, 1, index.dtype)
        i = id_ // (dim_at_axis_index * dim_after_axis)
        k = id_ % dim_after_axis
        # Negative indices address the gathered tensor along `dim`, so they wrap
        # by that tensor's size (x_shp[dim]), not by the index tensor's size.
        j = P.Cast()(index < 0, index.dtype)
        j_read = dim_at_axis_output * j + index
        j_read = P.Reshape()(j_read, (-1,))
        read_id = i * dim_at_axis_output * dim_after_axis + j_read * dim_after_axis + k
        dout = P.Reshape()(dout, (-1,))
        dx = op(dout, read_id, 0)
        dx = P.Reshape()(dx, shape_op(x))
        return zeros_like(index), dx

    return bprop


@bprop_getters.register(P.SparseGatherV2)
def get_bprop_sparse_gather_v2(self):
    """Generate bprop for SparseGatherV2.

    For axis 0 the gradient is a RowTensor (indices flattened, values reshaped
    to (num_indices,) + x_shape[1:]); otherwise it is densified the same way as
    GatherV2's bprop.
    """

    def bprop(x, indices, axis, out, dout):
        x_shp = shape_op(x)
        if axis == 0:
            indices_size = (size_op(indices),)
            if len(x_shp) <= 1:
                x_tail_shp = ()
            else:
                x_tail_shp = x_shp[1:]
            values_shape = indices_size + x_tail_shp
            values = reshape(dout, values_shape)
            indices_new = reshape(indices, indices_size)
            return RowTensor(indices_new, values, x_shp), zeros_like(indices), zeros_like(axis)
        # Keep the original indices so their zero gradient has the input's shape.
        orig_indices = indices
        if F.rank(dout) == 0:
            dout = P.ExpandDims()(dout, -1)
        if F.rank(indices) == 0:
            indices = P.ExpandDims()(indices, -1)
        ind_shp = shape_op(indices)
        # Scalar indices drop the gathered axis from `out`; restore it in dout
        # (mirrors GatherV2's bprop).
        out_shp = _regenerate_output_shape(x_shp, ind_shp, axis)
        dout = reshape(dout, out_shp)
        out_shp = shape_op(dout)
        # Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
        perm_1 = generate_shape_index(out_shp, ind_shp, axis)
        values_transpose = transpose(dout, perm_1)
        params_grad = unsorted_segment_sum(values_transpose, indices, x_shp[axis])
        # Example: out_shape:(3,2,3) axis 2 -> (1,2,0)
        perm_2 = _generate_inverse_index(x_shp, axis)
        params_grad = transpose(params_grad, perm_2)
        return params_grad, zeros_like(orig_indices), zeros_like(axis)

    return bprop


@constexpr
def _get_transposition(axis, rank):
    """Permutation swapping dimension `axis` with the last dimension (helper for Sort's grad)."""
    if axis < 0:
        axis += rank
    transposition = np.r_[np.arange(axis), [rank - 1], np.arange(axis + 1, rank - 1), [axis]]
    trans = tuple(transposition.tolist())
    return trans


@bprop_getters.register(P.Sort)
def get_bprop_sort(self):
    """Grad definition for `Sort` operation.

    The sort permutation is recovered with TopK over the sorted axis (moved to
    the last position when needed; ascending sort is handled by negation).
    dout's value gradient is then scattered back to the original positions.
    """
    axis = self.axis
    descending = self.descending
    scatter = P.ScatterNd()
    expand_dims = P.ExpandDims()
    reshape_op = P.Reshape()
    dtype = P.DType()
    topk = P.TopK()
    neg = P.Neg()
    transpose_op = P.Transpose()

    def bprop(input_x, out, dout):
        x_shape = input_x.shape
        k = x_shape[axis]
        rank = F.rank(input_x)
        dvalue = dout[0]
        if not descending:
            input_x = neg(input_x)
            dvalue = neg(dvalue)
        if axis == -1 or (axis + 1) == rank:
            transposition = None
            top_k_input = input_x
        else:
            transposition = _get_transposition(axis, rank)
            top_k_input = transpose_op(input_x, transposition)

        _, indices = topk(top_k_input, k)
        ind_shape = indices.shape
        top_k_input_shape = top_k_input.shape
        in_lastdim = top_k_input_shape[-1]
        ind_lastdim = ind_shape[-1]
        ind_2d = reshape_op(indices, (-1, ind_lastdim))
        outer_dim = ind_2d.shape[0]

        indices_dtype = dtype(indices)
        # Flat offset of the start of each row of ind_2d.
        range_flatten_index = range_op(0, outer_dim * in_lastdim, in_lastdim, indices_dtype)

        # expand_dims to a column of row offsets, then broadcast-add to the per-row indices
        ind = reshape_op(ind_2d + expand_dims(range_flatten_index, -1), (-1,))
        x_shape_1d = get_1d_shape(top_k_input_shape)

        if transposition is not None:
            dvalue = transpose_op(dvalue, invert_permutation(transposition))
            out_grad = reshape_op(
                scatter(expand_dims(ind, -1), reshape_op(dvalue, (-1,)), x_shape_1d), top_k_input_shape)
            dx = transpose_op(out_grad, invert_permutation(transposition))
        else:
            dx = reshape_op(scatter(expand_dims(ind, -1), reshape_op(dvalue, (-1,)), x_shape_1d), top_k_input_shape)
        if not descending:
            dx = neg(dx)
        return (dx,)

    return bprop


@bprop_getters.register(P.Identity)
def get_bprop_identity(self):
    """Generate bprop for Identity"""

    def bprop(x, out, dout):
        return (dout,)

    return bprop


@bprop_getters.register(inner.Range)
def get_bprop_range(self):
    """Generate bprop for Range"""

    def bprop(x, out, dout):
        return (zeros_like(x),)

    return bprop


@bprop_getters.register(P.Pack)
@bprop_getters.register(P.Stack)
def get_bprop_stack(self):
    """Generate bprop for Stack: unstack dout, returning a list when the input was a list."""
    axis = self.axis

    def bprop(x, out, dout):
        stack_grad = P.Unstack(axis)
        out = stack_grad(dout)
        if is_sub_class(F.typeof(x), ms.list_):
            ret = []
            for item in out:
                ret.append(item)
            return (ret,)
        return (out,)

    return bprop


@bprop_getters.register(P.ReverseV2)
def get_bprop_reverse_v2(self):
    """Generate bprop for ReverseV2: reverse dout along the same axes."""
    axis = self.axis

    def bprop(x, out, dout):
        reverse_grad = P.ReverseV2(axis)
        dx = reverse_grad(dout)
        return (dx,)

    return bprop


@bprop_getters.register(P.Unstack)
def get_bprop_unstack(self):
    """Generate bprop for Unstack: stack the output gradients along `axis`."""
    axis = self.axis

    def bprop(x, out, dout):
        unstack_grad = P.Stack(axis)
        out = unstack_grad(dout)
        return (out,)

    return bprop


@bprop_getters.register(P.StridedSlice)
def get_bprop_strided_slice(self):
    """Generate bprop for StridedSlice (uses the dynamic shape when the static shape has -1)."""
    input_grad = G.StridedSliceGrad(self.begin_mask,
                                    self.end_mask,
                                    self.ellipsis_mask,
                                    self.new_axis_mask,
                                    self.shrink_axis_mask)

    def bprop(x, begin, end, strides, out, dout):
        x_shape = shape_op(x)
        if -1 in x_shape:
            x_shape = dyn_shape_op(x)
        dx = input_grad(dout, x_shape, begin, end, strides)
        return dx, zeros_like(begin), zeros_like(end), zeros_like(strides)

    return bprop


@bprop_getters.register(P.Eye)
def get_bprop_eye(self):
    """Generate bprop for Eye"""

    def bprop(n, m, t, out, dout):
        return zeros_like(n), zeros_like(m), zeros_like(t)

    return bprop


@bprop_getters.register(P.Select)
def get_bprop_select(self):
    """Generate bprop for Select: route dout to x where cond holds, else to y."""
    select = P.Select()

    def bprop(cond, x, y, out, dout):
        return zeros_like(cond), select(cond, dout, zeros_like(x)), select(cond, zeros_like(y), dout)

    return bprop


@bprop_getters.register(P.OnesLike)
def get_bprop_oneslike(self):
    """Generate bprop for OnesLike"""

    def bprop(x, out, dout):
        return (zeros_like(x),)

    return bprop


@bprop_getters.register(P.ZerosLike)
def get_bprop_zeroslike(self):
    """Generate bprop for ZerosLike"""

    def bprop(x, out, dout):
        return (zeros_like(x),)

    return bprop


@bprop_getters.register(P.ResizeNearestNeighbor)
def get_bprop_resize_nearest_neighbor(self):
    """Generate bprop for ResizeNearestNeighbor"""
    op = G.ResizeNearestNeighborGrad(self.align_corners)

    def bprop(inputs, out, dout):
        shp = shape_op(inputs)
        # 2 and 3 represent the height and width
        shp = (shp[2], shp[3])
        return (op(dout, shp),)

    return bprop


@bprop_getters.register(P.GatherNd)
def get_bprop_gather_nd(self):
    """Generate bprop for GatherNd: scatter dout back into an input-shaped tensor."""
    op = P.ScatterNd()

    def bprop(x, indices, out, dout):
        shp = shape_op(x)
        return op(indices, dout, shp), zeros_like(indices)

    return bprop


@bprop_getters.register(P.ScatterNd)
def get_bprop_scatter_nd(self):
    """Generate bprop for ScatterNd: gather dout at the scattered positions."""
    op = P.GatherNd()

    def bprop(indices, x, shape, out, dout):
        return zeros_like(indices), op(dout, indices), zeros_like(shape)

    return bprop


@bprop_getters.register(P.ScatterNdUpdate)
def get_bprop_scatter_nd_update(self):
    """Generate bprop for ScatterNdUpdate"""
    op = P.GatherNd()

    def bprop(x, indices, update, out, dout):
        # NOTE(review): unlike TensorScatterUpdate's bprop, dout is not zeroed at
        # the overwritten positions for x — confirm this is intended.
        return dout, zeros_like(indices), op(dout, indices)

    return bprop


@bprop_getters.register(P.ScatterNonAliasingAdd)
def get_bprop_scatter_non_aliasing_add_update(self):
    """Generate bprop for ScatterNonAliasingAdd"""
    op = P.GatherNd()

    def bprop(x, indices, update, out, dout):
        return dout, zeros_like(indices), op(dout, indices)

    return bprop


@bprop_getters.register(P.TensorScatterUpdate)
def get_bprop_tensor_scatter_update(self):
    """Generate bprop for TensorScatterUpdate.

    x receives dout with the overwritten positions zeroed; update receives dout
    gathered at those positions.
    """
    gather_nd = P.GatherNd()
    tensor_scatter_update = P.TensorScatterUpdate()

    def bprop(x, indices, update, out, dout):
        x_grad = tensor_scatter_update(dout, indices, zeros_like(update))
        update_grad = gather_nd(dout, indices)
        return x_grad, zeros_like(indices), update_grad

    return bprop
@bprop_getters.register(P.TensorScatterAdd)
def get_bprop_tensor_scatter_add(self):
    """Build the bprop for TensorScatterAdd.

    The input tensor receives dout unchanged; each update receives dout
    gathered (GatherNd) at its scatter position.
    """
    gather_nd_op = P.GatherNd()

    def bprop(x, indices, update, out, dout):
        indices_grad = zeros_like(indices)
        update_grad = gather_nd_op(dout, indices)
        return dout, indices_grad, update_grad

    return bprop


@bprop_getters.register(P.ScatterMax)
def get_bprop_scatter_max(self):
    """Build the bprop for ScatterMax.

    dout is passed through to the input, and dout gathered along axis 0 at
    `indices` is passed to the updates.
    """
    gather_op = P.Gather()

    def bprop(x, indices, update, out, dout):
        # NOTE(review): both x and update receive gradient regardless of which
        # value won the max — confirm this approximation is intended.
        update_grad = gather_op(dout, indices, 0)
        return dout, zeros_like(indices), update_grad

    return bprop


@bprop_getters.register(P.Argmax)
def get_bprop_argmax(self):
    """Build the bprop for Argmax (index output, so the gradient is zero)."""

    def bprop(x, out, dout):
        zero_grad = zeros_like(x)
        return (zero_grad,)

    return bprop


@bprop_getters.register(P.Argmin)
def get_bprop_argmin(self):
    """Build the bprop for Argmin (index output, so the gradient is zero)."""

    def bprop(x, out, dout):
        zero_grad = zeros_like(x)
        return (zero_grad,)

    return bprop


@bprop_getters.register(P.SpaceToDepth)
def get_bprop_space_to_depth(self):
    """Build the bprop for SpaceToDepth: DepthToSpace with the same block size."""
    depth_to_space = P.DepthToSpace(self.block_size)

    def bprop(x, out, dout):
        dx = depth_to_space(dout)
        return (dx,)

    return bprop


@bprop_getters.register(P.DepthToSpace)
def get_bprop_depth_to_space(self):
    """Build the bprop for DepthToSpace: SpaceToDepth with the same block size."""
    space_to_depth = P.SpaceToDepth(self.block_size)

    def bprop(x, out, dout):
        dx = space_to_depth(dout)
        return (dx,)

    return bprop


@bprop_getters.register(P.Diag)
def get_bprop_diag(self):
    """Build the bprop for Diag: take the diagonal part of dout."""
    diag_part = P.DiagPart()

    def bprop(x, out, dout):
        dx = diag_part(dout)
        return (dx,)

    return bprop


@bprop_getters.register(P.DiagPart)
def get_bprop_diag_part(self):
    """Build the bprop for DiagPart: embed dout on a diagonal."""
    diag = P.Diag()

    def bprop(x, out, dout):
        dx = diag(dout)
        return (dx,)

    return bprop


def _gather_drop_negatives(params,
                           ids,
                           zero_clipped_indices=None,
                           is_positive=None):
    """Gather `params` along axis 0 by `ids`, zeroing entries whose id is negative.

    Helper shared by the unsorted segment bprops.

    Args:
        params: Tensor gathered along axis 0.
        ids: Segment ids. May be None when both `zero_clipped_indices` and
            `is_positive` are supplied (e.g. reused from an earlier call).
        zero_clipped_indices: `ids` with negative values clamped to 0; computed
            from `ids` when None.
        is_positive: Boolean mask of non-negative ids, broadcast to the gathered
            shape; computed from `ids` when None.

    Returns:
        Tuple `(masked_gather, zero_clipped_indices, is_positive)` so callers
        can reuse the clipped indices and mask.
    """
    maximum_op = P.Maximum()
    gather_op = P.Gather()
    greater_equal_op = P.GreaterEqual()
    rank_op = P.Rank()
    fill_op = P.Fill()
    select_op = P.Select()

    if zero_clipped_indices is None:
        zero_clipped_indices = maximum_op(ids, zeros_like(ids))
    gathered = gather_op(params, zero_clipped_indices, 0)
    if is_positive is None:
        is_positive = greater_equal_op(ids, 0)
        # Append trailing 1s so the mask has the same rank as `gathered`.
        mask_shape = shape_op(is_positive)
        for _ in range(rank_op(gathered) - rank_op(is_positive)):
            mask_shape += (1,)
        is_positive = reshape(is_positive, mask_shape)
        # AND with an all-True tensor of gathered's shape to broadcast the mask fully.
        is_positive = logical_and(is_positive, fill_op(mstype.bool_, shape_op(gathered), 1))
    zero_slice = zeros_like(gathered)
    return (select_op(is_positive, gathered, zero_slice), zero_clipped_indices, is_positive)


def _unsorted_segment_min_or_max_grad(x, segment_ids, num_segments, out, dout):
    """Shared gradient for UnsortedSegmentMin and UnsortedSegmentMax.

    Each segment's gradient is divided evenly among the elements of `x` equal
    to that segment's result (ties share it). All other elements, including
    those with negative segment ids, get zero.
    """
    equal_op = P.Equal()
    cast_op = P.Cast()
    real_div = P.RealDiv()
    dtype_of = P.DType()
    select_op = P.Select()

    gathered_outputs, zero_clipped_indices, is_positive = _gather_drop_negatives(out, segment_ids, None, None)
    is_selected = logical_and(equal_op(x, gathered_outputs), is_positive)
    selected_mask = cast_op(is_selected, dtype_of(dout))
    num_selected = unsorted_segment_sum(selected_mask, segment_ids, num_segments)
    weighted_grads = real_div(dout, num_selected)
    gathered_grads, _, _ = _gather_drop_negatives(weighted_grads, None,
                                                  zero_clipped_indices, is_positive)
    dx = select_op(is_selected, gathered_grads, zeros_like(gathered_grads))
    return dx, zeros_like(segment_ids), zeros_like(num_segments)


@bprop_getters.register(P.UnsortedSegmentSum)
def get_bprop_unsorted_segment_sum(self):
    """Build the bprop for UnsortedSegmentSum: gather dout by segment id (negative ids get zero)."""

    def bprop(x, segment_ids, num_segments, out, dout):
        dx, _, _ = _gather_drop_negatives(dout, segment_ids, None, None)
        return dx, zeros_like(segment_ids), zeros_like(num_segments)

    return bprop


@bprop_getters.register(P.UnsortedSegmentMin)
def get_bprop_unsorted_segment_min(self):
    """Build the bprop for UnsortedSegmentMin via the shared min/max helper."""

    def bprop(x, segment_ids, num_segments, out, dout):
        grads = _unsorted_segment_min_or_max_grad(x, segment_ids, num_segments, out, dout)
        return grads

    return bprop


@bprop_getters.register(P.UnsortedSegmentMax)
def get_bprop_unsorted_segment_max(self):
    """Build the bprop for UnsortedSegmentMax via the shared min/max helper."""

    def bprop(x, segment_ids, num_segments, out, dout):
        grads = _unsorted_segment_min_or_max_grad(x, segment_ids, num_segments, out, dout)
        return grads

    return bprop


@bprop_getters.register(P.UnsortedSegmentProd)
def get_bprop_unsorted_segment_prod(self):
    """Build the bprop for UnsortedSegmentProd.

    Non-zero elements get `segment_product / x`. A zero element gets the
    product of the segment's non-zero elements. Segments containing more than
    one zero pass no gradient.
    """
    equal_op = P.Equal()
    cast_op = P.Cast()
    select_op = P.Select()
    gather_op = P.Gather()
    greater_op = P.Greater()
    ones_like_op = P.OnesLike()
    maximum_op = P.Maximum()
    segment_prod = P.UnsortedSegmentProd()

    def bprop(x, segment_ids, num_segments, out, dout):
        # Zero the incoming gradient for segments that contain more than one zero.
        is_zero = equal_op(x, 0)
        zero_count = unsorted_segment_sum(cast_op(is_zero, mstype.int32), segment_ids, num_segments)
        segment_grad = select_op(greater_op(zero_count, 1), zeros_like(dout), dout)
        # Product over each segment with zeros replaced by 1.
        non_zero_data = select_op(is_zero, ones_like_op(x), x)
        non_zero_prod = segment_prod(non_zero_data, segment_ids, num_segments)
        clipped_ids = maximum_op(segment_ids, zeros_like(segment_ids))
        gathered_prod = gather_op(out, clipped_ids, 0)
        gathered_non_zero_prod = gather_op(non_zero_prod, clipped_ids, 0)
        prod_divided_by_x = gathered_prod / x
        partial_derivative = select_op(is_zero, gathered_non_zero_prod, prod_divided_by_x)
        gathered_grad, _, _ = _gather_drop_negatives(segment_grad, segment_ids, clipped_ids, None)
        dx = gathered_grad * partial_derivative
        return dx, zeros_like(segment_ids), zeros_like(num_segments)

    return bprop


@bprop_getters.register(P.SpaceToBatch)
def get_bprop_space_to_batch(self):
    """Build the bprop for SpaceToBatch: BatchToSpace with the same block size and paddings."""
    batch_to_space = P.BatchToSpace(self.block_size, self.paddings)

    def bprop(x, out, dout):
        dx = batch_to_space(dout)
        return (dx,)

    return bprop


@bprop_getters.register(P.BatchToSpace)
def get_bprop_batch_to_space(self):
    """Build the bprop for BatchToSpace: SpaceToBatch with the same block size and crops."""
    space_to_batch = P.SpaceToBatch(self.block_size, self.crops)

    def bprop(x, out, dout):
        dx = space_to_batch(dout)
        return (dx,)

    return bprop


@bprop_getters.register(P.SpaceToBatchND)
def get_bprop_space_to_batch_nd(self):
    """Build the bprop for SpaceToBatchND: BatchToSpaceND with the same block shape and paddings."""
    batch_to_space_nd = P.BatchToSpaceND(self.block_shape, self.paddings)

    def bprop(x, out, dout):
        dx = batch_to_space_nd(dout)
        return (dx,)

    return bprop


@bprop_getters.register(P.BatchToSpaceND)
def get_bprop_batch_to_space_nd(self):
    """Build the bprop for BatchToSpaceND: SpaceToBatchND with the same block shape and crops."""
    space_to_batch_nd = P.SpaceToBatchND(self.block_shape, self.crops)

    def bprop(x, out, dout):
        dx = space_to_batch_nd(dout)
        return (dx,)

    return bprop


@bprop_getters.register(P.BroadcastTo)
def get_bprop_broadcast_to(self):
    """Build the bprop for BroadcastTo.

    When no broadcasting happened, dout is returned as-is. Otherwise dout is
    summed over the broadcast axes and reshaped to the input shape.
    """
    reduce_sum_keep_dims = P.ReduceSum(keep_dims=True)

    def bprop(x, out, dout):
        x_shape = shape_op(x)
        if x_shape == shape_op(dout):
            return (dout,)
        broadcast_shape = shape_op(out)
        _, reduction_axes = broadcast_gradient_args(broadcast_shape, x_shape)
        reduced_grad = reduce_sum_keep_dims(dout, reduction_axes)
        dx = reshape(reduced_grad, x_shape)
        return (dx,)

    return bprop


@bprop_getters.register(P.ReverseSequence)
def get_bprop_reverse_sequence(self):
    """Build the bprop for ReverseSequence: reverse dout with the same lengths and dims."""
    reverse_sequence = P.ReverseSequence(batch_dim=self.batch_dim_, seq_dim=self.seq_dim_)

    def bprop(x, seq_lengths, out, dout):
        dx = reverse_sequence(dout, seq_lengths)
        return dx, zeros_like(seq_lengths)

    return bprop


@bprop_getters.register(P.TransShape)
def get_bprop_trans_shape(self):
    """Build the bprop for TransShape: TransShape dout back to the input's shape."""
    trans_shape = P.TransShape()

    def bprop(x, shape, out, dout):
        x_shape = shape_op(x)
        dx = trans_shape(dout, x_shape)
        return (dx, zeros_like(shape))

    return bprop


@bprop_getters.register(P.Unique)
def get_bprop_unique(self):
    """Build the bprop for Unique using UniqueGrad on (dout, out)."""
    unique_grad = G.UniqueGrad()

    def bprop(x, out, dout):
        dx = unique_grad(dout, out)
        return (dx,)

    return bprop


@bprop_getters.register(P.MaskedSelect)
def get_bprop_masked_select(self):
    """Build the bprop for MaskedSelect using MaskedSelectGrad."""
    masked_select_grad = G.MaskedSelectGrad()

    def bprop(x, mask, out, dout):
        dx = masked_select_grad(x, mask, dout)
        return (dx, zeros_like(mask))

    return bprop