• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020-2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15
16"""array_ops"""
17
18import numpy as np
19import mindspore as ms
20from mindspore.ops import composite as C
21from .. import operations as P
22from ..operations import _grad_ops as G
23from ..operations import _inner_ops as inner
24from ..composite.multitype_ops.zeros_like_impl import zeros_like
25from ..functional import broadcast_gradient_args
26from .. import functional as F
27from .grad_base import bprop_getters
28from ..primitive import constexpr
29from ... import context
30from ...common import dtype as mstype
31from ...common.tensor import RowTensor
32from .._utils.utils import range_op, get_1d_shape, generate_shape_index
33
# Module-level primitive instances shared by the bprop definitions below, so
# that each gradient function does not have to instantiate them again.
reduce_sum = P.ReduceSum()
unsorted_segment_sum = P.UnsortedSegmentSum()
transpose = P.Transpose()
shape_op = P.Shape()
dyn_shape_op = P.DynamicShape()
reshape = P.Reshape()
size_op = P.Size()
invert_permutation = P.InvertPermutation()
logical_and = P.LogicalAnd()
is_sub_class = P.IsSubClass()
44
45
@bprop_getters.register(P.Fill)
def get_bprop_fill(self):
    """Build the backward function for the Fill primitive.

    The returned bprop ignores ``dout`` and yields zero gradients shaped like
    ``dims`` and ``x``.
    """

    def bprop(dtype, dims, x, out, dout):
        dims_grad = zeros_like(dims)
        x_grad = zeros_like(x)
        return dims_grad, x_grad

    return bprop
54
55
@bprop_getters.register(P.Ones)
def get_bprop_ones(self):
    """Build the backward function for the Ones primitive.

    The returned bprop ignores ``dout`` and yields ``zeros_like(dims)``.
    """

    def bprop(dims, dtype, out, dout):
        dims_grad = zeros_like(dims)
        return dims_grad

    return bprop
64
65
@bprop_getters.register(P.Zeros)
def get_bprop_zeros(self):
    """Build the backward function for the Zeros primitive.

    The returned bprop ignores ``dout`` and yields ``zeros_like(dims)``.
    """

    def bprop(dims, dtype, out, dout):
        dims_grad = zeros_like(dims)
        return dims_grad

    return bprop
74
75
@bprop_getters.register(P.DType)
def get_bprop_dtype(self):
    """Build the backward function for the DType primitive.

    The returned bprop ignores ``dout`` and yields a zero gradient for ``x``.
    """

    def bprop(x, out, dout):
        dx = zeros_like(x)
        return (dx,)

    return bprop
84
85
# Multitype dispatch graph used by the sparse Cast bprop; the overloads
# registered below select the cast implementation from the types of (dout, x).
dout_cast = C.MultitypeFuncGraph("dout_cast")
87
88
@dout_cast.register("Tensor", "Tensor")
def dout_cast_tensor(dout, x):
    """Cast the Tensor gradient `dout` to the dtype of Tensor `x`."""
    target_dtype = P.DType()(x)
    return P.Cast()(dout, target_dtype)
96
97
@dout_cast.register("Number", "Number")
def dout_cast_number(dout, x):
    """Cast the Number gradient `dout` to the dtype of Number `x`."""
    target_dtype = P.DType()(x)
    return P.Cast()(dout, target_dtype)
105
106
@dout_cast.register("RowTensor", "Tensor")
def dout_cast_row_tensor(dout, x):
    """Cast the values of RowTensor `dout` to the dtype of Tensor `x`.

    The indices and dense shape of `dout` are carried over unchanged.
    """
    target_dtype = P.DType()(x)
    cast_values = P.Cast()(dout.values, target_dtype)
    return RowTensor(dout.indices, cast_values, dout.dense_shape)
114
115
@bprop_getters.register(P.Cast)
def get_bprop_cast(self):
    """Build the backward function for the Cast primitive.

    The gradient is `dout` cast back to the dtype of `x`; the target type `t`
    gets a zero gradient. When the 'enable_sparse' context flag is set, the
    variant that dispatches through `dout_cast` is returned instead.
    """
    cast = P.Cast()
    get_dtype = P.DType()

    def bprop_sparse(x, t, out, dout):
        return dout_cast(dout, x), zeros_like(t)

    def bprop(x, t, out, dout):
        return cast(dout, get_dtype(x)), zeros_like(t)

    return bprop_sparse if context.get_context('enable_sparse') else bprop
134
135
@bprop_getters.register(P.Shape)
def get_bprop_shape(self):
    """Build the backward function for the Shape primitive.

    The returned bprop ignores ``dout`` and yields a zero gradient for ``x``.
    """

    def bprop(x, out, dout):
        dx = zeros_like(x)
        return (dx,)

    return bprop
144
145
@bprop_getters.register(P.DynamicShape)
def get_bprop_dynamicshape(self):
    """Build the backward function for the DynamicShape primitive.

    The returned bprop ignores ``dout`` and yields a zero gradient for ``x``.
    """

    def bprop(x, out, dout):
        dx = zeros_like(x)
        return (dx,)

    return bprop
154
155
@bprop_getters.register(P.Split)
def get_bprop_split(self):
    """Build the backward function for the Split primitive.

    The incoming gradients `dout` are concatenated along the axis this Split
    was configured with.
    """
    concat_op = P.Concat(self.axis)

    def bprop(x, out, dout):
        return (concat_op(dout),)

    return bprop
167
168
@bprop_getters.register(P.Rank)
def get_bprop_rank(self):
    """Build the backward function for the Rank primitive.

    The returned bprop ignores ``dout`` and yields a zero gradient for ``x``.
    """

    def bprop(x, out, dout):
        dx = zeros_like(x)
        return (dx,)

    return bprop
177
178
@bprop_getters.register(P.Reshape)
def get_bprop_reshape(self):
    """Build the backward function for the Reshape primitive.

    `dout` is reshaped back to the shape of `x`; the target shape `shp`
    receives a zero gradient.
    """

    def bprop(x, shp, out, dout):
        dx = reshape(dout, shape_op(x))
        return dx, zeros_like(shp)

    return bprop
188
189
@bprop_getters.register(P.ExpandDims)
def get_bprop_expand_dims(self):
    """Build the backward function for the ExpandDims primitive.

    `dout` is reshaped back to the shape of `x`; `axis` receives a zero
    gradient.
    """

    def bprop(x, axis, out, dout):
        dx = reshape(dout, shape_op(x))
        return dx, zeros_like(axis)

    return bprop
199
200
@bprop_getters.register(P.Squeeze)
def get_bprop_squeeze(self):
    """Build the backward function for the Squeeze primitive.

    `dout` is reshaped back to the shape of `x`.
    """

    def bprop(x, out, dout):
        dx = reshape(dout, shape_op(x))
        return (dx,)

    return bprop
210
211
@bprop_getters.register(P.Flatten)
def get_bprop_flatten(self):
    """Build the backward function for the Flatten primitive.

    `dout` is reshaped back to the shape of `x` with a dedicated Reshape
    instance.
    """
    restore_shape = P.Reshape()

    def bprop(x, out, dout):
        return (restore_shape(dout, shape_op(x)),)

    return bprop
222
223
@constexpr
def _tile_shape(multiples, shapex):
    """Interleave `multiples` and `shapex`: [1, 2], [3, 4] -> [1, 3, 2, 4].

    When the two sequences differ in length, the shorter one is left-padded
    with 1 so that both are aligned on their trailing dimensions, e.g.
    multiples (2, 3, 4) and shape (5, 6) give (2, 1, 3, 5, 4, 6).

    The previous loop never advanced the shape cursor in the
    "shape longer than multiples" branch, so it repeated the first dimension
    and then indexed past the end of `multiples`; padding up front handles
    both directions symmetrically.

    Args:
        multiples (tuple[int]): Repetition count per dimension.
        shapex (tuple[int]): Shape of the tiled input.

    Returns:
        tuple[int], of length 2 * max(len(multiples), len(shapex)), with the
        multiples at even positions and the input dimensions at odd positions.
    """
    len_muli = len(multiples)
    rank = len(shapex)
    max_len = max(len_muli, rank)
    padded_multiples = (1,) * (max_len - len_muli) + tuple(multiples)
    padded_shape = (1,) * (max_len - rank) + tuple(shapex)
    ret = []
    for multiple, dim in zip(padded_multiples, padded_shape):
        ret.append(multiple)
        ret.append(dim)
    return tuple(ret)
250
251
@bprop_getters.register(P.Tile)
def get_bprop_tile(self):
    """Build the backward function for the Tile primitive.

    `dout` is reshaped to the interleaved (multiple, dim) layout produced by
    `_tile_shape`, summed over the positions that hold the multiples, and
    reshaped to the shape of `x`. `multiples` receives a zero gradient.
    """

    def bprop(x, multiples, out, dout):
        x_shape = shape_op(x)
        interleaved_shape = _tile_shape(multiples, x_shape)
        # Even positions (start 0, step 2) of the interleaved shape are reduced.
        multiple_axes = F.make_range(0, len(interleaved_shape), 2)
        summed = reduce_sum(reshape(dout, interleaved_shape), multiple_axes)
        return reshape(summed, x_shape), zeros_like(multiples)

    return bprop
266
267
@bprop_getters.register(P.EmbeddingLookup)
def get_bprop_embedding_lookup(self):
    """Build the backward function for the EmbeddingLookup primitive.

    The gradient of `x` is returned as a RowTensor whose indices are
    `indices - offset` (flattened) and whose values are `dout` reshaped to
    (number of indices,) + shape(x)[1:]. `indices` and `offset` receive zero
    gradients.
    """
    sub_op = P.Sub()
    reshape_op = P.Reshape()

    def bprop_sparse(x, indices, offset, out, dout):
        x_shp = shape_op(x)
        shifted_indices = sub_op(indices, offset)
        num_indices = size_op(shifted_indices)
        if num_indices > 0:
            # Flatten the shifted indices to one dimension.
            flat_shape = (num_indices,)
            shifted_indices = reshape_op(shifted_indices, flat_shape)
        else:
            flat_shape = ()
        values_shape = flat_shape + x_shp[1:]
        # Reshape the incoming gradient so its leading axis matches the indices.
        values = reshape_op(dout, values_shape)
        return RowTensor(shifted_indices, values, x_shp), zeros_like(indices), zeros_like(offset)

    return bprop_sparse
291
292
@constexpr
def make_begin(shp):
    """Return a tuple of zeros with one entry per element of `shp`."""
    return tuple(0 for _ in shp)
298
299
@bprop_getters.register(P.Padding)
def get_bprop_padding(self):
    """Grad definition for `Padding` operation.

    The gradient is the slice of `dout` that starts at the all-zero offset and
    has the shape of `x`.
    """
    slice_op = P.Slice()

    def bprop(x, out, dout):
        x_shape = shape_op(x)
        dx = slice_op(dout, make_begin(x_shape), x_shape)
        return (dx,)

    return bprop
311
312
@bprop_getters.register(P.Transpose)
def get_bprop_transpose(self):
    """Build the backward function for the Transpose primitive.

    `dout` is transposed with the inverse of `perm`; `perm` receives a zero
    gradient.
    """

    def bprop(x, perm, out, dout):
        inverse_perm = invert_permutation(perm)
        dx = transpose(dout, inverse_perm)
        return dx, zeros_like(perm)

    return bprop
321
322
@constexpr
def _concat_grad_uniform(input_shapes, input_nums):
    """Helper function for bprop of Concat.

    Returns True when the first `input_nums` entries of `input_shapes` are all
    equal to their neighbour, False otherwise.
    """
    return all(input_shapes[idx - 1] == input_shapes[idx] for idx in range(1, input_nums))
332
333
@bprop_getters.register(P.Concat)
def get_bprop_concat(self):
    """Generate bprop for Concat.

    The gradient is `dout` cut back into one piece per input. When all inputs
    have the same shape a single Split is used; otherwise each piece is taken
    with Slice at the offset reported by ConcatOffset. The result has the same
    container kind (list or tuple) as `x`.
    """
    axis = self.axis

    def bprop(x, out, dout):
        # Offsets of every input inside the concatenated output.
        out_offset = G.ConcatOffset(len(x), axis)(x)
        input_nums = len(x)
        # Collect the shape of every input as a tuple of shapes.
        input_shapes = ()
        for i in range(input_nums):
            input_shapes = input_shapes + (shape_op(x[i]),)
        is_uniform = _concat_grad_uniform(input_shapes, input_nums)
        if isinstance(x, list):
            # List input: the gradient is returned as a list.
            dx = []
            if is_uniform:
                dx_tuple = P.Split(axis, input_nums)(dout)
                for _, i in enumerate(dx_tuple):
                    dx = dx + [i]
            else:
                for i in range(input_nums):
                    slice_out = P.Slice()(dout, out_offset[i], input_shapes[i])
                    dx = dx + [slice_out]
        else:
            # Any other input container: the gradient is returned as a tuple.
            dx = ()
            if is_uniform:
                dx = P.Split(axis, input_nums)(dout)
            else:
                for i in range(input_nums):
                    slice_out = P.Slice()(dout, out_offset[i], input_shapes[i])
                    dx = dx + (slice_out,)
        return (dx,)

    return bprop
367
368
@constexpr
def _slice_grad_pad(begins, sizes, shapes):
    """Return per-dimension (leading, trailing) pad amounts for a slice.

    For each dimension the pair is (begin, shape - begin - size).
    """
    pads = []
    for begin, size, shape in zip(begins, sizes, shapes):
        pads.append((begin, shape - begin - size))
    return tuple(pads)
373
374
@bprop_getters.register(P.Slice)
def get_bprop_slice(self):
    """Build the backward function for the Slice primitive.

    The gradient of `x` is computed by SliceGrad; `begin` and `size` receive
    zero gradients.
    """
    slice_grad = G.SliceGrad()

    def bprop(x, begin, size, out, dout):
        dx = slice_grad(dout, x, begin, size)
        return dx, zeros_like(begin), zeros_like(size)

    return bprop
384
385
@constexpr
def _generate_inverse_index(x_shape, axis):
    """Return the permutation that moves dimension 0 to position `axis`.

    A negative `axis` is counted from the end of `x_shape`.
    """
    rank = len(x_shape)
    if axis < 0:
        axis += rank
    dims = tuple(range(rank))
    leading = dims[1:1 + axis]
    trailing = dims[1 + axis:]
    return leading + (0,) + trailing
394
395
@constexpr
def _regenerate_output_shape(x_shp, ind_shp, axis):
    """Return `x_shp` with the dimension at `axis` replaced by `ind_shp`.

    A negative `axis` is counted from the end of `x_shp`.
    """
    if axis < 0:
        axis += len(x_shp)
    before = x_shp[:axis]
    after = x_shp[axis + 1:]
    return before + ind_shp + after
403
404
@bprop_getters.register(P.Gather)
@bprop_getters.register(P.GatherV2)
def get_bprop_gather_v2(self):
    """Build the backward function for Gather / GatherV2.

    `dout` is transposed so the gathered axis comes first, accumulated per
    index with UnsortedSegmentSum, and transposed back. `indices` and `axis`
    receive zero gradients.
    """
    expand_dims = P.ExpandDims()

    def bprop(x, indices, axis, out, dout):
        orig_indices = indices
        if F.rank(dout) == 0:
            dout = expand_dims(dout, -1)
        if F.rank(indices) == 0:
            # A scalar index dropped the gathered axis; restore it on both
            # the indices and the incoming gradient.
            indices = expand_dims(indices, -1)
            restored_shp = _regenerate_output_shape(shape_op(x), shape_op(indices), axis)
            dout = reshape(dout, restored_shp)

        x_shp = shape_op(x)
        out_shp = shape_op(dout)
        ind_shp = shape_op(indices)
        # Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
        perm_1 = generate_shape_index(out_shp, ind_shp, axis)
        values_transpose = transpose(dout, perm_1)
        # A -1 in the static shape selects the runtime (dynamic) shape.
        if -1 in x_shp:
            num_segments = dyn_shape_op(x)[axis]
        else:
            num_segments = x_shp[axis]
        params_grad = unsorted_segment_sum(values_transpose, indices, num_segments)
        # Example: out_shape:(3,2,3) axis 2 -> (1,2,0)
        perm_2 = _generate_inverse_index(x_shp, axis)
        params_grad = transpose(params_grad, perm_2)
        return params_grad, zeros_like(orig_indices), zeros_like(axis)

    return bprop
437
438
@bprop_getters.register(P.GatherD)
def get_bprop_gather_d(self):
    """Build the backward function for the GatherD primitive.

    The gradient of `x` is computed by GatherDGrad configured with `dim` and
    the shape of `x`; `dim` and `index` receive zero gradients.
    """

    def bprop(x, dim, index, out, dout):
        grad_op = G.GatherDGrad(dim, shape_op(x))
        dx = grad_op(index, dout)
        return dx, zeros_like(dim), zeros_like(index)

    return bprop
449
450
@bprop_getters.register(G.GatherDGrad)
def get_bprop_gather_d_grad(self):
    """Generate bprop for GatherDGrad.

    The gradient for the second input is obtained by flattening `dout` and
    gathering it at linearised positions derived from `index`; `index` itself
    receives a zero gradient.
    """
    op = P.Gather()
    dim = self.dim
    x_shp = self.out_shape

    def bprop(index, x, out, dout):
        index_shp = shape_op(index)
        # Product of the output-shape dimensions in front of `dim`.
        # NOTE(review): range(dim) is empty for a negative `dim` - confirm
        # that `dim` is always non-negative here.
        dim_before_axis = 1
        for i in range(dim):
            dim_before_axis *= x_shp[i]
        dim_at_axis_index = index_shp[dim]
        dim_at_axis_output = x_shp[dim]
        # Product of the output-shape dimensions behind `dim`.
        dim_after_axis = 1
        for i in range(dim+1, len(x_shp)):
            dim_after_axis *= x_shp[i]
        element = dim_before_axis * dim_at_axis_index * dim_after_axis
        # Presumably the flat positions 0 .. element - 1 in the dtype of
        # `index`; depends on range_op, which is defined elsewhere.
        id_ = range_op(0, element, 1, index.dtype)
        # Outer coordinate (before `dim`) and inner coordinate (after `dim`).
        i = id_ // (dim_at_axis_index * dim_after_axis)
        k = id_ % dim_after_axis
        # `j` is 1 where the index is negative and 0 elsewhere.
        j = P.Cast()(index < 0, index.dtype)
        # NOTE(review): negative indices are shifted by the extent of `index`
        # along `dim`, not by `dim_at_axis_output`; confirm this is intended
        # when the two extents differ.
        j_read = dim_at_axis_index * j + index
        j_read = P.Reshape()(j_read, (-1,))
        # Linear offset into the flattened `dout`, using the output strides.
        read_id = i*dim_at_axis_output*dim_after_axis + j_read * dim_after_axis + k
        dout = P.Reshape()(dout, (-1,))
        dx = op(dout, read_id, 0)
        dx = P.Reshape()(dx, shape_op(x))
        return zeros_like(index), dx

    return bprop
482
@bprop_getters.register(P.SparseGatherV2)
def get_bprop_sparse_gather_v2(self):
    """Build the backward function for the SparseGatherV2 primitive.

    For `axis == 0` the gradient of `x` is returned as a RowTensor built from
    the flattened indices and the reshaped `dout`. For any other axis the
    gradient is accumulated with UnsortedSegmentSum between two transposes.
    `indices` and `axis` receive zero gradients.
    """
    expand_dims = P.ExpandDims()

    def bprop(x, indices, axis, out, dout):
        x_shp = shape_op(x)
        if axis == 0:
            flat_ind_shape = (size_op(indices),)
            if len(x_shp) <= 1:
                tail_shape = ()
            else:
                tail_shape = x_shp[1:]
            values = reshape(dout, flat_ind_shape + tail_shape)
            flat_indices = reshape(indices, flat_ind_shape)
            return RowTensor(flat_indices, values, x_shp), zeros_like(indices), zeros_like(axis)
        if F.rank(dout) == 0:
            dout = expand_dims(dout, -1)
        if F.rank(indices) == 0:
            indices = expand_dims(indices, -1)
        out_shp = shape_op(dout)
        ind_shp = shape_op(indices)
        # Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
        perm_1 = generate_shape_index(out_shp, ind_shp, axis)
        values_transpose = transpose(dout, perm_1)
        num_segments = shape_op(x)[axis]
        params_grad = unsorted_segment_sum(values_transpose, indices, num_segments)
        # Example: out_shape:(3,2,3) axis 2 -> (1,2,0)
        perm_2 = _generate_inverse_index(x_shp, axis)
        params_grad = transpose(params_grad, perm_2)
        return params_grad, zeros_like(indices), zeros_like(axis)

    return bprop
515
516
@constexpr
def _get_transposition(axis, rank):
    """Helper function for grad of Sort.

    Returns the permutation of range(rank) that swaps `axis` with the last
    dimension. A negative `axis` is counted from the end.

    The permutation is built by swapping two entries of the identity, so it
    always has exactly `rank` entries; the earlier concatenation-based
    construction produced `rank + 1` entries (the last index twice) when
    `axis` already was the last dimension.

    Args:
        axis (int): Dimension to exchange with the last one.
        rank (int): Number of dimensions.

    Returns:
        tuple[int], the permutation.
    """
    if axis < 0:
        axis += rank
    perm = list(range(rank))
    perm[axis], perm[rank - 1] = perm[rank - 1], perm[axis]
    return tuple(perm)
525
526
@bprop_getters.register(P.Sort)
def get_bprop_sort(self):
    """Grad definition for `Sort` operation.

    The sorting order is recomputed with TopK over the full length of the
    sorted axis (after moving that axis last when necessary). The gradient of
    the sorted values, `dout[0]`, is then scattered back to the positions the
    values came from.
    """
    axis = self.axis
    descending = self.descending
    scatter = P.ScatterNd()
    expand_dims = P.ExpandDims()
    reshape_op = P.Reshape()
    dtype = P.DType()
    topk = P.TopK()
    neg = P.Neg()
    transpose_op = P.Transpose()

    def bprop(input_x, out, dout):
        k = input_x.shape[axis]
        rank = F.rank(input_x)
        dvalue = dout[0]
        if not descending:
            # Negate the input before TopK for ascending sorts; the gradient
            # is negated here and once more at the end.
            input_x = neg(input_x)
            dvalue = neg(dvalue)
        if axis == -1 or (axis + 1) == rank:
            transposition = None
            top_k_input = input_x
        else:
            transposition = _get_transposition(axis, rank)
            top_k_input = transpose_op(input_x, transposition)

        _, indices = topk(top_k_input, k)
        top_k_input_shape = top_k_input.shape
        in_lastdim = top_k_input_shape[-1]
        ind_lastdim = indices.shape[-1]
        ind_2d = reshape_op(indices, (-1, ind_lastdim))
        outer_dim = ind_2d.shape[0]

        # Start offset of every row in the flattened TopK input.
        row_offsets = range_op(0, outer_dim * in_lastdim, in_lastdim, dtype(indices))

        # expand_dims to (k, 1), then broadcast
        ind = reshape_op(ind_2d + expand_dims(row_offsets, -1), (-1,))
        x_shape_1d = get_1d_shape(top_k_input_shape)

        if transposition is not None:
            dvalue = transpose_op(dvalue, invert_permutation(transposition))
        flat_grad = scatter(expand_dims(ind, -1), reshape_op(dvalue, (-1,)), x_shape_1d)
        dx = reshape_op(flat_grad, top_k_input_shape)
        if transposition is not None:
            dx = transpose_op(dx, invert_permutation(transposition))
        if not descending:
            dx = neg(dx)
        return (dx,)

    return bprop
582
583
@bprop_getters.register(P.Identity)
def get_bprop_identity(self):
    """Build the backward function for the Identity primitive.

    The incoming gradient `dout` is passed through unchanged.
    """

    def bprop(x, out, dout):
        dx = dout
        return (dx,)

    return bprop
592
593
@bprop_getters.register(inner.Range)
def get_bprop_range(self):
    """Build the backward function for the inner Range primitive.

    The returned bprop ignores ``dout`` and yields a zero gradient for ``x``.
    """

    def bprop(x, out, dout):
        dx = zeros_like(x)
        return (dx,)

    return bprop
602
603
@bprop_getters.register(P.Pack)
@bprop_getters.register(P.Stack)
def get_bprop_stack(self):
    """Build the backward function for Pack / Stack.

    `dout` is unstacked along the configured axis. When the forward input `x`
    is a list the pieces are returned as a list, otherwise they are returned
    as produced by Unstack.
    """
    axis = self.axis

    def bprop(x, out, dout):
        pieces = P.Unstack(axis)(dout)
        if is_sub_class(F.typeof(x), ms.list_):
            as_list = []
            for piece in pieces:
                as_list.append(piece)
            return (as_list,)
        return (pieces,)

    return bprop
621
622
@bprop_getters.register(P.ReverseV2)
def get_bprop_reverse_v2(self):
    """Build the backward function for the ReverseV2 primitive.

    `dout` is reversed along the same axis setting as the forward op.
    """
    reverse_grad = P.ReverseV2(self.axis)

    def bprop(x, out, dout):
        return (reverse_grad(dout),)

    return bprop
634
635
@bprop_getters.register(P.Unstack)
def get_bprop_unstack(self):
    """Build the backward function for the Unstack primitive.

    The incoming gradients `dout` are stacked along the configured axis.
    """
    stack_op = P.Stack(self.axis)

    def bprop(x, out, dout):
        dx = stack_op(dout)
        return (dx,)

    return bprop
647
648
@bprop_getters.register(P.StridedSlice)
def get_bprop_strided_slice(self):
    """Build the backward function for the StridedSlice primitive.

    The gradient of `x` is computed by StridedSliceGrad configured with the
    same masks as the forward op; `begin`, `end` and `strides` receive zero
    gradients.
    """
    masks = (self.begin_mask, self.end_mask, self.ellipsis_mask,
             self.new_axis_mask, self.shrink_axis_mask)
    strided_slice_grad = G.StridedSliceGrad(*masks)

    def bprop(x, begin, end, strides, out, dout):
        x_shape = shape_op(x)
        # A -1 in the static shape selects the runtime (dynamic) shape.
        if -1 in x_shape:
            x_shape = dyn_shape_op(x)
        dx = strided_slice_grad(dout, x_shape, begin, end, strides)
        return dx, zeros_like(begin), zeros_like(end), zeros_like(strides)

    return bprop
666
667
@bprop_getters.register(P.Eye)
def get_bprop_eye(self):
    """Build the backward function for the Eye primitive.

    The returned bprop ignores ``dout`` and yields zero gradients for ``n``,
    ``m`` and ``t``.
    """

    def bprop(n, m, t, out, dout):
        n_grad = zeros_like(n)
        m_grad = zeros_like(m)
        t_grad = zeros_like(t)
        return n_grad, m_grad, t_grad

    return bprop
676
677
@bprop_getters.register(P.Select)
def get_bprop_select(self):
    """Build the backward function for the Select primitive.

    `dout` is routed to `x` where `cond` holds and to `y` elsewhere, with
    zeros in the remaining positions. `cond` receives a zero gradient.
    """
    select = P.Select()

    def bprop(cond, x, y, out, dout):
        cond_grad = zeros_like(cond)
        dx = select(cond, dout, zeros_like(x))
        dy = select(cond, zeros_like(y), dout)
        return cond_grad, dx, dy

    return bprop
687
688
@bprop_getters.register(P.OnesLike)
def get_bprop_oneslike(self):
    """Build the backward function for the OnesLike primitive.

    The returned bprop ignores ``dout`` and yields a zero gradient for ``x``.
    """

    def bprop(x, out, dout):
        dx = zeros_like(x)
        return (dx,)

    return bprop
697
698
@bprop_getters.register(P.ZerosLike)
def get_bprop_zeroslike(self):
    """Build the backward function for the ZerosLike primitive.

    The returned bprop ignores ``dout`` and yields a zero gradient for ``x``.
    """

    def bprop(x, out, dout):
        dx = zeros_like(x)
        return (dx,)

    return bprop
707
708
@bprop_getters.register(P.ResizeNearestNeighbor)
def get_bprop_resize_nearest_neighbor(self):
    """Build the backward function for the ResizeNearestNeighbor primitive.

    ResizeNearestNeighborGrad is applied to `dout` together with dimensions 2
    and 3 of the input shape.
    """
    op = G.ResizeNearestNeighborGrad(self.align_corners)

    def bprop(inputs, out, dout):
        in_shape = shape_op(inputs)
        # 2 and 3 represent the height and width
        spatial_size = (in_shape[2], in_shape[3])
        dx = op(dout, spatial_size)
        return (dx,)

    return bprop
721
722
@bprop_getters.register(P.GatherNd)
def get_bprop_gather_nd(self):
    """Build the backward function for the GatherNd primitive.

    `dout` is scattered with ScatterNd at `indices` into a tensor of the shape
    of `x`; `indices` receives a zero gradient.
    """
    scatter_nd = P.ScatterNd()

    def bprop(x, indices, out, dout):
        dx = scatter_nd(indices, dout, shape_op(x))
        return dx, zeros_like(indices)

    return bprop
733
734
@bprop_getters.register(P.ScatterNd)
def get_bprop_scatter_nd(self):
    """Build the backward function for the ScatterNd primitive.

    The gradient of `x` is `dout` gathered with GatherNd at `indices`;
    `indices` and `shape` receive zero gradients.
    """
    gather_nd = P.GatherNd()

    def bprop(indices, x, shape, out, dout):
        indices_grad = zeros_like(indices)
        dx = gather_nd(dout, indices)
        return indices_grad, dx, zeros_like(shape)

    return bprop
744
745
@bprop_getters.register(P.ScatterNdUpdate)
def get_bprop_scatter_nd_update(self):
    """Build the backward function for the ScatterNdUpdate primitive.

    `x` receives `dout` unchanged, `indices` a zero gradient, and `update`
    receives `dout` gathered with GatherNd at `indices`.
    """
    gather_nd = P.GatherNd()

    def bprop(x, indices, update, out, dout):
        indices_grad = zeros_like(indices)
        update_grad = gather_nd(dout, indices)
        return dout, indices_grad, update_grad

    return bprop
755
756
@bprop_getters.register(P.ScatterNonAliasingAdd)
def get_bprop_scatter_non_aliasing_add_update(self):
    """Build the backward function for the ScatterNonAliasingAdd primitive.

    `x` receives `dout` unchanged, `indices` a zero gradient, and `update`
    receives `dout` gathered with GatherNd at `indices`.
    """
    gather_nd = P.GatherNd()

    def bprop(x, indices, update, out, dout):
        indices_grad = zeros_like(indices)
        update_grad = gather_nd(dout, indices)
        return dout, indices_grad, update_grad

    return bprop
766
767
@bprop_getters.register(P.TensorScatterUpdate)
def get_bprop_tensor_scatter_update(self):
    """Build the backward function for the TensorScatterUpdate primitive.

    The gradient of `x` is `dout` with the positions addressed by `indices`
    overwritten by zeros; the gradient of `update` is `dout` gathered at
    `indices`. `indices` receives a zero gradient.
    """
    gather_nd = P.GatherNd()
    tensor_scatter_update = P.TensorScatterUpdate()

    def bprop(x, indices, update, out, dout):
        zero_update = zeros_like(update)
        x_grad = tensor_scatter_update(dout, indices, zero_update)
        return x_grad, zeros_like(indices), gather_nd(dout, indices)

    return bprop
780
781
@bprop_getters.register(P.TensorScatterAdd)
def get_bprop_tensor_scatter_add(self):
    """Build the backward function for the TensorScatterAdd primitive.

    `x` receives `dout` unchanged, `indices` a zero gradient, and `update`
    receives `dout` gathered with GatherNd at `indices`.
    """
    gather_nd = P.GatherNd()

    def bprop(x, indices, update, out, dout):
        return dout, zeros_like(indices), gather_nd(dout, indices)

    return bprop
792
793
@bprop_getters.register(P.ScatterMax)
def get_bprop_scatter_max(self):
    """Build the backward function for the ScatterMax primitive.

    `x` receives `dout` unchanged, `indices` a zero gradient, and `update`
    receives `dout` gathered along axis 0 at `indices`.
    """
    gather = P.Gather()

    def bprop(x, indices, update, out, dout):
        indices_grad = zeros_like(indices)
        update_grad = gather(dout, indices, 0)
        return dout, indices_grad, update_grad

    return bprop
803
804
@bprop_getters.register(P.Argmax)
def get_bprop_argmax(self):
    """Build the backward function for the Argmax primitive.

    The returned bprop ignores ``dout`` and yields a zero gradient for ``x``.
    """

    def bprop(x, out, dout):
        dx = zeros_like(x)
        return (dx,)

    return bprop
813
814
@bprop_getters.register(P.Argmin)
def get_bprop_argmin(self):
    """Build the backward function for the Argmin primitive.

    The returned bprop ignores ``dout`` and yields a zero gradient for ``x``.
    """

    def bprop(x, out, dout):
        dx = zeros_like(x)
        return (dx,)

    return bprop
823
824
@bprop_getters.register(P.SpaceToDepth)
def get_bprop_space_to_depth(self):
    """Build the backward function for the SpaceToDepth primitive.

    DepthToSpace with the same block size is applied to `dout`.
    """
    depth_to_space = P.DepthToSpace(self.block_size)

    def bprop(x, out, dout):
        dx = depth_to_space(dout)
        return (dx,)

    return bprop
834
835
@bprop_getters.register(P.DepthToSpace)
def get_bprop_depth_to_space(self):
    """Build the backward function for the DepthToSpace primitive.

    SpaceToDepth with the same block size is applied to `dout`.
    """
    space_to_depth = P.SpaceToDepth(self.block_size)

    def bprop(x, out, dout):
        dx = space_to_depth(dout)
        return (dx,)

    return bprop
845
846
@bprop_getters.register(P.Diag)
def get_bprop_diag(self):
    """Build the backward function for the Diag primitive.

    DiagPart is applied to `dout`.
    """
    diag_part = P.DiagPart()

    def bprop(x, out, dout):
        dx = diag_part(dout)
        return (dx,)

    return bprop
856
857
@bprop_getters.register(P.DiagPart)
def get_bprop_diag_part(self):
    """Build the backward function for the DiagPart primitive.

    Diag is applied to `dout`.
    """
    diag = P.Diag()

    def bprop(x, out, dout):
        dx = diag(dout)
        return (dx,)

    return bprop
867
868
def _gather_drop_negatives(params,
                           ids,
                           zero_clipped_indices=None,
                           is_positive=None):
    """Helper function for unsorted segment ops.

    Gathers `params` along axis 0 at the ids clipped to be non-negative and
    replaces the entries that belong to negative ids with zeros.

    `zero_clipped_indices` and `is_positive` are computed from `ids` when they
    are None; passing the values returned by an earlier call skips that work.

    Returns:
        Tuple of (masked gathered values, zero_clipped_indices, is_positive).
    """
    maximum_op = P.Maximum()
    gather_op = P.Gather()
    greater_equal_op = P.GreaterEqual()
    rank_op = P.Rank()
    fill_op = P.Fill()
    select_op = P.Select()

    if zero_clipped_indices is None:
        zero_clipped_indices = maximum_op(ids, zeros_like(ids))
    gathered = gather_op(params, zero_clipped_indices, 0)
    if is_positive is None:
        is_positive = greater_equal_op(ids, 0)
        # Append trailing 1s so the mask has the same rank as the gathered data.
        mask_shape = shape_op(is_positive)
        for _ in range(rank_op(gathered) - rank_op(is_positive)):
            mask_shape += (1,)
        is_positive = reshape(is_positive, mask_shape)
        # AND with an all-true tensor to expand the mask to the gathered shape.
        all_true = fill_op(mstype.bool_, shape_op(gathered), 1)
        is_positive = logical_and(is_positive, all_true)
    masked = select_op(is_positive, gathered, zeros_like(gathered))
    return (masked, zero_clipped_indices, is_positive)
895
896
def _unsorted_segment_min_or_max_grad(x, segment_ids, num_segments, out, dout):
    """Gradient for UnsortedSegmentMin or UnsortedSegmentMax.

    Elements of `x` that equal the output of their segment (and have a
    non-negative segment id) share the segment's gradient equally; all other
    elements get zero. `segment_ids` and `num_segments` receive zero
    gradients.
    """
    equal_op = P.Equal()
    cast_op = P.Cast()
    divide_op = P.RealDiv()
    dtype_op = P.DType()
    select_op = P.Select()

    gathered_outputs, zero_clipped_indices, is_positive = _gather_drop_negatives(out, segment_ids, None, None)
    is_selected = logical_and(equal_op(x, gathered_outputs), is_positive)
    # Number of selected elements per segment, in the dtype of the gradient.
    selected_as_grad_dtype = cast_op(is_selected, dtype_op(dout))
    num_selected = unsorted_segment_sum(selected_as_grad_dtype, segment_ids, num_segments)
    weighted_grads = divide_op(dout, num_selected)
    gathered_grads, _, _ = _gather_drop_negatives(weighted_grads, None,
                                                  zero_clipped_indices, is_positive)
    dx = select_op(is_selected, gathered_grads, zeros_like(gathered_grads))
    return dx, zeros_like(segment_ids), zeros_like(num_segments)
915
916
@bprop_getters.register(P.UnsortedSegmentSum)
def get_bprop_unsorted_segment_sum(self):
    """Build the backward function for the UnsortedSegmentSum primitive.

    The gradient of `x` is `dout` gathered at `segment_ids`, with the entries
    of negative ids zeroed. `segment_ids` and `num_segments` receive zero
    gradients.
    """

    def bprop(x, segment_ids, num_segments, out, dout):
        dx = _gather_drop_negatives(dout, segment_ids, None, None)[0]
        return dx, zeros_like(segment_ids), zeros_like(num_segments)

    return bprop
926
927
@bprop_getters.register(P.UnsortedSegmentMin)
def get_bprop_unsorted_segment_min(self):
    """Build the backward function for the UnsortedSegmentMin primitive.

    Delegates to `_unsorted_segment_min_or_max_grad`.
    """

    def bprop(x, segment_ids, num_segments, out, dout):
        grads = _unsorted_segment_min_or_max_grad(x, segment_ids, num_segments, out, dout)
        return grads

    return bprop
936
937
@bprop_getters.register(P.UnsortedSegmentMax)
def get_bprop_unsorted_segment_max(self):
    """Build the backward function for the UnsortedSegmentMax primitive.

    Delegates to `_unsorted_segment_min_or_max_grad`.
    """

    def bprop(x, segment_ids, num_segments, out, dout):
        grads = _unsorted_segment_min_or_max_grad(x, segment_ids, num_segments, out, dout)
        return grads

    return bprop
946
947
@bprop_getters.register(P.UnsortedSegmentProd)
def get_bprop_unsorted_segment_prod(self):
    """Build the backward function for the UnsortedSegmentProd primitive.

    For a non-zero element the partial derivative is the segment product
    divided by the element; for a zero element it is the product of the
    non-zero elements of its segment. Segments with more than one zero get a
    zero gradient. `segment_ids` and `num_segments` receive zero gradients.
    """
    equal = P.Equal()
    cast = P.Cast()
    select = P.Select()
    gather = P.Gather()
    greater = P.Greater()
    ones_like = P.OnesLike()
    maximum = P.Maximum()
    unsorted_segment_prod = P.UnsortedSegmentProd()

    def bprop(x, segment_ids, num_segments, out, dout):
        is_zero = equal(x, 0)
        # Count the zeros in every segment.
        zero_count = unsorted_segment_sum(cast(is_zero, mstype.int32), segment_ids, num_segments)
        segment_grad = select(greater(zero_count, 1), zeros_like(dout), dout)
        # Segment products with every zero replaced by one.
        zeros_as_ones = select(is_zero, ones_like(x), x)
        non_zero_prod = unsorted_segment_prod(zeros_as_ones, segment_ids, num_segments)
        zero_clipped_indices = maximum(segment_ids, zeros_like(segment_ids))
        gathered_prod = gather(out, zero_clipped_indices, 0)
        gathered_non_zero_prod = gather(non_zero_prod, zero_clipped_indices, 0)
        partial_derivative = select(is_zero, gathered_non_zero_prod, gathered_prod / x)
        gathered_grad, _, _ = _gather_drop_negatives(segment_grad, segment_ids, zero_clipped_indices, None)
        dx = gathered_grad * partial_derivative
        return dx, zeros_like(segment_ids), zeros_like(num_segments)

    return bprop
976
977
@bprop_getters.register(P.SpaceToBatch)
def get_bprop_space_to_batch(self):
    """Build the backward function for the SpaceToBatch primitive.

    BatchToSpace, built from this op's block size and paddings, is applied to
    `dout`.
    """
    batch_to_space = P.BatchToSpace(self.block_size, self.paddings)

    def bprop(x, out, dout):
        return (batch_to_space(dout),)

    return bprop
988
989
@bprop_getters.register(P.BatchToSpace)
def get_bprop_batch_to_space(self):
    """Build the backward function for the BatchToSpace primitive.

    SpaceToBatch, built from this op's block size and crops, is applied to
    `dout`.
    """
    space_to_batch = P.SpaceToBatch(self.block_size, self.crops)

    def bprop(x, out, dout):
        return (space_to_batch(dout),)

    return bprop
1000
1001
@bprop_getters.register(P.SpaceToBatchND)
def get_bprop_space_to_batch_nd(self):
    """Build the backward function for the SpaceToBatchND primitive.

    BatchToSpaceND, built from this op's block shape and paddings, is applied
    to `dout`.
    """
    batch_to_space_nd = P.BatchToSpaceND(self.block_shape, self.paddings)

    def bprop(x, out, dout):
        return (batch_to_space_nd(dout),)

    return bprop
1012
1013
@bprop_getters.register(P.BatchToSpaceND)
def get_bprop_batch_to_space_nd(self):
    """Build the backward function for the BatchToSpaceND primitive.

    SpaceToBatchND, built from this op's block shape and crops, is applied to
    `dout`.
    """
    space_to_batch_nd = P.SpaceToBatchND(self.block_shape, self.crops)

    def bprop(x, out, dout):
        return (space_to_batch_nd(dout),)

    return bprop
1024
1025
@bprop_getters.register(P.BroadcastTo)
def get_bprop_broadcast_to(self):
    """Build the backward function for the BroadcastTo primitive.

    When `x` and `dout` already have the same shape, `dout` is returned as is.
    Otherwise `dout` is summed (keeping dims) over the axes reported by
    `broadcast_gradient_args` and reshaped to the shape of `x`.
    """
    reduce_keep_dim = P.ReduceSum(keep_dims=True)

    def bprop(x, out, dout):
        x_shape = shape_op(x)
        if x_shape == shape_op(dout):
            return (dout,)
        _, reduction_axes = broadcast_gradient_args(shape_op(out), x_shape)
        summed = reduce_keep_dim(dout, reduction_axes)
        return (reshape(summed, x_shape),)

    return bprop
1044
1045
@bprop_getters.register(P.ReverseSequence)
def get_bprop_reverse_sequence(self):
    """Build the backward function for the ReverseSequence primitive.

    ReverseSequence with the same batch and sequence dimensions is applied to
    `dout` with the same `seq_lengths`; `seq_lengths` receives a zero
    gradient.
    """
    reverse_sequence_grad = P.ReverseSequence(batch_dim=self.batch_dim_, seq_dim=self.seq_dim_)

    def bprop(x, seq_lengths, out, dout):
        return reverse_sequence_grad(dout, seq_lengths), zeros_like(seq_lengths)

    return bprop
1056
1057
@bprop_getters.register(P.TransShape)
def get_bprop_trans_shape(self):
    """Build the backward function for the TransShape primitive.

    TransShape is applied to `dout` with the shape of `x`; `shape` receives a
    zero gradient.
    """
    trans_shape = P.TransShape()

    def bprop(x, shape, out, dout):
        return trans_shape(dout, shape_op(x)), zeros_like(shape)

    return bprop
1068
1069
@bprop_getters.register(P.Unique)
def get_bprop_unique(self):
    """Build the backward function for the Unique primitive.

    The gradient is computed by UniqueGrad from `dout` and the forward output
    `out`.
    """
    unique_grad = G.UniqueGrad()

    def bprop(x, out, dout):
        return (unique_grad(dout, out),)

    return bprop
1080
1081
@bprop_getters.register(P.MaskedSelect)
def get_bprop_masked_select(self):
    """Build the backward function for the MaskedSelect primitive.

    The gradient of `x` is computed by MaskedSelectGrad from `x`, `mask` and
    `dout`; `mask` receives a zero gradient.
    """
    masked_select_grad = G.MaskedSelectGrad()

    def bprop(x, mask, out, dout):
        return masked_select_grad(x, mask, dout), zeros_like(mask)

    return bprop
1092