# Copyright 2020-2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Operators for gradients."""
# pylint: disable=unused-import
from __future__ import absolute_import

from __future__ import division
from mindspore._checkparam import _check_3d_int_or_tuple
from mindspore.ops.operations.nn_ops import _check_positive_int_or_tuple
from mindspore.ops import signature as sig
from mindspore.ops._utils import get_concat_offset
from mindspore.ops.primitive import Primitive, PrimitiveWithInfer, prim_attr_register
import mindspore.context as context
from mindspore import _checkparam as validator
from mindspore.common import dtype as mstype
from mindspore.communication.management import GlobalComm
from mindspore.common._utils import is_shape_unknown, is_dim_unknown
from ..auto_generate import (AbsGrad, ACosGrad, LogitGrad, AcoshGrad, AsinGrad, AsinhGrad, ReciprocalGrad, RsqrtGrad,
                             SqrtGrad, BatchNormGrad, BatchNormGradGrad, BiasAddGrad, GeLUGrad, FastGeLUGrad,
                             AvgPoolGrad, MinimumGrad, LogSoftmaxGrad, PReLUGrad, ReluGrad, ReLU6Grad, EluGrad,
                             GatherDGradV2, ResizeBilinearGrad, ResizeLinear1DGrad, ResizeNearestNeighborV2Grad,
                             SigmoidGrad, HSwishGrad, NLLLossGrad, AtanGrad, GridSampler3DGrad, GridSampler2DGrad,
                             ResizeBicubicGrad, HSigmoidGrad, CholeskyGrad, ResizeNearestNeighborGrad, LayerNormGrad,
                             HShrinkGrad, LayerNormGradGrad, SiLUGrad, MaximumGrad, MaximumGradGrad, RmsNormGrad,
                             FlashAttentionScoreGrad, UpsampleTrilinear3DGrad, UpsampleNearest3DGrad,
                             BinaryCrossEntropyGrad)


class SparseFillEmptyRowsGrad(Primitive):
    """Performs grad of SparseFillEmptyRows operation."""

    @prim_attr_register
    def __init__(self):
        """Initialize SparseFillEmptyRowsGrad."""
        self.init_prim_io_names(inputs=['reverse_index_map', 'grad_values'],
                                outputs=['y_values', 'y_default_value'])


class ScaleAndTranslateGrad(Primitive):
    """Performs grad of ScaleAndTranslate operation."""

    @prim_attr_register
    def __init__(self, kernel_type="lanczos3", antialias=True):
        """Initialize ScaleAndTranslateGrad"""
        validator.check_value_type("kernel_type", kernel_type, [str], self.name)
        validator.check_string(kernel_type, ["lanczos1", "lanczos3", "lanczos5", "gaussian", "box", "triangle",
                                             "keyscubic", "mitchellcubic"], "kernel_type", self.name)
        validator.check_value_type("antialias", antialias, [bool], self.name)


class SoftmaxGrad(Primitive):
    """Performs grad of Softmax operation."""

    @prim_attr_register
    def __init__(self):
        """Initialize SoftmaxGrad"""
        self.init_prim_io_names(inputs=['y', 'dy'], outputs=['z'])


class SyncBatchNormGrad(Primitive):
    """Performs grad of SyncBatchNorm operation."""

    @prim_attr_register
    def __init__(self, epsilon=1e-5, group="group0", device_num=2):
        validator.check_float_range(epsilon, 0, 1, validator.INC_RIGHT, 'epsilon', self.name)
        if not isinstance(group, str):
            raise TypeError("The group attr of SyncBatchNormGrad must be str.")
        validator.check_int(device_num, 2, validator.GE, "device_num", self.name)


class KLDivLossGrad(Primitive):
    """Computes gradients for `KLDivLoss` operation."""

    @prim_attr_register
    def __init__(self, reduction='mean'):
        device_target = context.get_context("device_target")
        if device_target == "CPU":
            support_mode = ['none', 'mean', 'batchmean', 'sum']
        elif device_target == "GPU":
            support_mode = ['none', 'mean', 'sum']
        elif device_target == "Ascend":
            support_mode = ['none', 'mean', 'batchmean', 'sum']
        else:
            raise ValueError(f"'{self.name}' unknown device target: '{device_target}'")
        self.reduction = validator.check_string(reduction, support_mode, 'reduction', self.name)


class LuUnpackGrad(Primitive):
    """Computes gradients for `LuUnpack` operation."""

    @prim_attr_register
    def __init__(self, L_grad_flag, U_grad_flag):
        validator.check_value_type("L_grad_flag", L_grad_flag, [bool], self.name)
        validator.check_value_type("U_grad_flag", U_grad_flag, [bool], self.name)
        self.add_prim_attr("cust_aicpu", self.name)


class ConcatOffset(PrimitiveWithInfer):
    """primitive for computing Concat's gradient."""

    @prim_attr_register
    def __init__(self, N=2, axis=0):
        """Initialize ConcatOffset"""

    def __infer__(self, input_x):
        axis = self.axis
        x_shp = input_x['shape']
        x_type = input_x['dtype']
        self.add_prim_attr('T', x_type[0].element_type())

        # input_x is dynamic rank
        rank = -1
        is_dyn_rank = False
        for _, sh in enumerate(x_shp):
            if is_dim_unknown(sh):
                is_dyn_rank = True
            else:
                rank = len(sh)
        if is_dyn_rank:
            return {
                'shape': [len(x_shp), rank],
                'dtype': mstype.int64,
                'value': None
            }

        # if the dimension of input_x on the axis is dynamic
        if axis < -rank or axis >= rank:
            raise ValueError("For 'ConcatOffset', 'axis' must be in range [{}, {}), but got {}"
                             .format(-rank, rank, axis))
        if axis < 0:
            axis = axis + rank
        for each in x_shp:
            if each[axis] == -1:
                return {
                    'shape': [len(x_shp), len(x_shp[0])],
                    'dtype': mstype.int64,
                    'value': None
                }

        offset, _, axis = get_concat_offset(x_shp, x_type, axis, self.name)
        offset_values = []
        for i in range(len(x_shp)):
            values = []
            for j in range(len(x_shp[0])):
                value = 0
                if j == axis:
                    value = offset[i]
                values.append(value)
            offset_values.append(tuple(values))
        out = {'shape': None,
               'dtype': None,
               'value': tuple(offset_values)}
        return out

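# Worked example of the offset computation above (for illustration only): for two
# inputs with shapes (2, 3) and (2, 4) concatenated along axis 1, the returned
# value is ((0, 0), (0, 3)) -- each row gives the starting index of the
# corresponding input inside the concatenated output along every dimension.
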

class Conv3DBackpropFilter(Primitive):
    """
    Computes the gradients of convolution 3D with respect to the filter.

    Args:
        out_channel (int): The dimension of the output.
        kernel_size (Union[int, tuple[int]]): The kernel size of the 3D convolution.
        mode (int): Modes for different convolutions. Not currently used.
        pad_mode (str): Modes to fill padding. It could be "valid", "same", or "pad". Default: "valid".
        pad (Union(int, tuple[int])): The pad value to be filled. Default: 0. If `pad` is an integer, the paddings of
                    head, tail, top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of six
                    integers, the padding of head, tail, top, bottom, left and right equals pad[0], pad[1], pad[2],
                    pad[3], pad[4] and pad[5] correspondingly.
        stride (Union(int, tuple[int])): The stride to be applied to the convolution filter. Default: 1.
        dilation (Union(int, tuple[int])): Specifies the space to use between kernel elements. Default: 1.
        group (int): Splits input into groups. Default: 1.
        data_format (str): The optional value for data format. Currently only support 'NCDHW'.

    Inputs:
        - **x** (Tensor) - The input of the convolution with shape
          :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`. Currently the data type only supports float16 and float32.
        - **dout** (Tensor) - The gradients w.r.t the output of the convolution. The shape conforms to the default
          data_format :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})`. Currently dout data type only supports float16
          and float32.
        - **w_size** (tuple(int)) - A tuple describing the shape of the weight which conforms to the format
          :math:`(C_{out}, C_{in}, D_{k}, H_{k}, W_{k})`.

    Outputs:
        Tensor, the gradients w.r.t the weight of convolution 3D. It has the same shape as the weight.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> x = Tensor(np.ones([16, 32, 13, 37, 33]), mindspore.float16)
        >>> dout = Tensor(np.ones([16, 32, 10, 32, 32]), mindspore.float16)
        >>> w = Tensor(np.ones([32, 32, 4, 6, 2]), mindspore.float16)
        >>> conv3d_backprop_filter = Conv3DBackpropFilter(out_channel=32, kernel_size=(4, 6, 2))
        >>> output = conv3d_backprop_filter(x, dout, F.shape(w))
        >>> print(output.shape)
        (32, 32, 4, 6, 2)
    """

    @prim_attr_register
    def __init__(self,
                 out_channel,
                 kernel_size,
                 mode=1,
                 pad_mode="valid",
                 pad=0,
                 stride=(1, 1, 1, 1, 1),
                 dilation=(1, 1, 1, 1, 1),
                 group=1,
                 data_format="NCDHW"):
        """Initialize Convolution"""
        self.init_prim_io_names(inputs=['x', 'out_backprop', 'filter_size'], outputs=['y'])
        self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name)
        self.kernel_size = _check_3d_int_or_tuple('kernel_size', kernel_size, self.name)
        self.stride = _check_3d_int_or_tuple('stride', stride, self.name, allow_five=True, ret_five=True)
        self.add_prim_attr('strides', self.stride)
        self.dilation = _check_3d_int_or_tuple('dilation', dilation, self.name, allow_five=True, ret_five=True)
        self.add_prim_attr('dilations', self.dilation)
        validator.check_value_type('pad', pad, (int, tuple), self.name)
        if isinstance(pad, int):
            pad = (pad,) * 6
        validator.check_equal_int(len(pad), 6, 'pad size', self.name)
        self.add_prim_attr('pad', pad)
        self.pad_list = pad
        self.add_prim_attr('pad_list', self.pad_list)

        validator.check_value_type('pad_mode', pad_mode, [str], self.name)
        self.pad_mode = validator.check_string(pad_mode.lower(), ['valid', 'same', 'pad'], 'pad_mode', self.name)
        if self.pad_mode != 'pad' and self.pad_list != (0, 0, 0, 0, 0, 0):
            raise ValueError(f"For '{self.name}', when pad is not 0, pad_mode must be set as 'pad'.")
        if self.pad_mode == 'pad':
            for item in pad:
                validator.check_non_negative_int(item, 'pad item', self.name)
        self.add_prim_attr('pad_mode', self.pad_mode)

        self.mode = validator.check_equal_int(mode, 1, 'mode', self.name)
        self.add_prim_attr('mode', self.mode)
        self.group = validator.check_positive_int(group, 'group', self.name)
        self.add_prim_attr('groups', self.group)
        self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
        self.add_prim_attr('data_format', self.format)

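# Shape check for the docstring example above (illustrative only): with a 'valid'
# 3D convolution each output spatial size is (in - kernel) // stride + 1, so an
# input of (16, 32, 13, 37, 33) with kernel (4, 6, 2) and stride 1 gives dout of
# shape (16, 32, 10, 32, 32), and the returned filter gradient has the weight's
# shape (32, 32, 4, 6, 2).
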

class Conv2DBackpropFilter(Primitive):
    """
    Computes the gradients of convolution with respect to the filter.

    Args:
        out_channel (int): The dimensionality of the output space.
        kernel_size (Union[int, tuple[int]]): The size of the convolution window.
        pad_mode (str): Modes to fill padding. It could be "valid", "same", or "pad". Default: "valid".
        pad (Union(int, tuple[int])): The pad value to be filled. Default: 0. If `pad` is an integer, the paddings of
                    top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of four integers, the
                    padding of top, bottom, left and right equals pad[0], pad[1], pad[2], and pad[3] correspondingly.
        pad_list (tuple): The pad list like (top, bottom, left, right). Default: (0, 0, 0, 0).
        mode (int): Modes for different convolutions. 0 Math convolution, 1 cross-correlation convolution,
                    2 deconvolution, 3 depthwise convolution. Default: 1.
        stride (tuple): The stride to be applied to the convolution filter. Default: (1, 1).
        dilation (tuple): Specifies the dilation rate to be used for the dilated convolution. Default: (1, 1, 1, 1).
        group (int): Splits input into groups. Default: 1.
        data_format (str): The format of input and output data. It should be 'NHWC' or 'NCHW'.
            Default: 'NCHW'.

    Returns:
        Tensor, the gradients of convolution.
    """

    @prim_attr_register
    def __init__(self,
                 out_channel,
                 kernel_size,
                 pad_mode="valid",
                 pad=0,
                 pad_list=(0, 0, 0, 0),
                 mode=1,
                 stride=(1, 1),
                 dilation=(1, 1, 1, 1),
                 group=1,
                 data_format="NCHW"):
        """Initialize Convolution"""
        self.init_prim_io_names(inputs=['out_backprop', 'input', 'filter_sizes'], outputs=['output'])
        self.out_channel = out_channel
        self.kernel_size = kernel_size
        self.mode = mode
        pad_mode = pad_mode.upper()
        self.add_prim_attr('pad_mode', pad_mode)
        if isinstance(pad, int):
            pad = (pad,) * 4
        else:
            validator.check_equal_int(len(pad), 4, 'pad size', self.name)
        self.add_prim_attr("pad", pad)
        self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
        if context.get_context("device_target") != "GPU" and self.format == "NHWC":
            raise ValueError("NHWC format is only supported on the GPU target.")
        self.add_prim_attr('data_format', self.format)
        self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=True)
        self.add_prim_attr('stride', self.stride)
        self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True)
        self.add_prim_attr('dilation', self.dilation)
        self.group = group
        self.add_prim_attr('groups', group)
        if pad_list:
            for x in pad_list:
                if x != -1:
                    validator.check_non_negative_int(x, 'element of pad_list', self.name)
        self.pad_list = pad_list

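# Shape relationship, for illustration (standard NCHW convolution conventions,
# not a statement about the kernel implementation): with pad_mode "valid" and
# stride 1, an input of shape (N, C_in, H, W) convolved with a k x k kernel
# produces dout of shape (N, out_channel, H - k + 1, W - k + 1), and the filter
# gradient returned here has the weight's shape (out_channel, C_in // group, k, k).
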

class DepthwiseConv2dNativeBackpropFilter(PrimitiveWithInfer):
    """
    Returns the gradient of filter for DepthwiseConv2dNative.

    Applies depthwise conv2d for the input, which will generate more channels with channel_multiplier.

    Refer to class DepthwiseConv2dNative for more details.

    Args:
        channel_multiplier (int): The multiplier for the original output conv.
        kernel_size (int or tuple): The size of the conv kernel.
        mode (int): Modes for different convolutions. 0 Math convolution, 1 cross-correlation convolution,
                    2 deconvolution, 3 depthwise convolution. Default: 3.
        pad_mode (str): The mode to fill padding which can be: "valid", "same" or "pad". Default: "valid".
        pad (Union(int, tuple[int])): The pad value to be filled. Default: 0. If `pad` is an integer, the paddings of
                    top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of four integers, the
                    padding of top, bottom, left and right equals pad[0], pad[1], pad[2], and pad[3] correspondingly.
        pad_list (tuple): The pad list like (top, bottom, left, right). Default: (0, 0, 0, 0).
        stride (int): The stride to be applied to the convolution filter. Default: 1.
        dilation (int): Specifies the space to use between kernel elements. Default: 1.
        group (int): Splits input into groups. Default: 1.

    Returns:
        Tensor, the value is the gradient of filter for DepthwiseConv2dNative.
    """

    @prim_attr_register
    def __init__(self,
                 channel_multiplier,
                 kernel_size,
                 pad_mode="valid",
                 pad=0,
                 pad_list=(0, 0, 0, 0),
                 mode=3,
                 stride=1,
                 dilation=1,
                 group=1):
        """Initialize Convolution"""
        self.init_prim_io_names(inputs=['input', 'filter_size', 'dout'], outputs=['output'])
        self.channel_multiplier = channel_multiplier
        self.kernel_size = kernel_size
        self.mode = mode
        self.pad_mode = pad_mode
        if isinstance(pad, int):
            pad = (pad,) * 4
        else:
            validator.check_equal_int(len(pad), 4, 'pad size', self.name)
        self.add_prim_attr("pad", pad)
        self.pad_list = pad_list
        self.stride = stride
        self.dilation = dilation
        self.group = group
        self.add_prim_attr('data_format', "NCHW")

    def __infer__(self, x, w_size, dout):
        w_size_v = w_size['value']
        args = {'x': x['dtype'], 'dout': dout['dtype']}
        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
        out = {
            'value': None,
            'shape': w_size_v,
            'dtype': dout['dtype'],
        }
        return out


class DepthwiseConv2dNativeBackpropInput(PrimitiveWithInfer):
    """
    Returns the gradient of input for DepthwiseConv2dNative.

    Applies depthwise conv2d for the input, which will generate more channels with channel_multiplier.

    Args:
        channel_multiplier (int): The multiplier for the original output conv.
        kernel_size (int or tuple): The size of the conv kernel.
        mode (int): Modes for different convolutions. 0 Math convolution, 1 cross-correlation convolution,
                    2 deconvolution, 3 depthwise convolution. Default: 3.
        pad_mode (str): Modes to fill padding. It could be "valid", "same", or "pad". Default: "valid".
        pad (Union(int, tuple[int])): The pad value to be filled. Default: 0. If `pad` is an integer, the paddings of
                    top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of four integers, the
                    padding of top, bottom, left and right equals pad[0], pad[1], pad[2], and pad[3] correspondingly.
        pad_list (tuple): The pad list like (top, bottom, left, right). Default: (0, 0, 0, 0).
        stride (int): The stride to be applied to the convolution filter. Default: 1.
        dilation (int): Specifies the space to use between kernel elements. Default: 1.
        group (int): Splits input into groups. Default: 1.

    Returns:
        Tensor, the value is the gradient of input for DepthwiseConv2dNative.
    """

    @prim_attr_register
    def __init__(self,
                 channel_multiplier,
                 kernel_size,
                 pad_mode="valid",
                 pad=0,
                 pad_list=(0, 0, 0, 0),
                 mode=3,
                 stride=1,
                 dilation=1,
                 group=1):
        """Initialize Convolution"""
        self.init_prim_io_names(inputs=['input_size', 'filter', 'dout'], outputs=['output'])
        self.channel_multiplier = channel_multiplier
        self.kernel_size = kernel_size
        self.mode = mode
        self.pad_mode = pad_mode
        if isinstance(pad, int):
            pad = (pad,) * 4
        else:
            validator.check_equal_int(len(pad), 4, 'pad size', self.name)
        self.add_prim_attr("pad", pad)
        self.pad_list = pad_list
        self.stride = stride
        self.dilation = dilation
        self.group = group
        self.add_prim_attr('data_format', "NCHW")

    def __infer__(self, x_size, w, dout):
        args = {'w': w['dtype'], 'dout': dout['dtype']}
        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
        x_size_v = x_size['value']
        out = {
            'value': None,
            'shape': x_size_v,
            'dtype': dout['dtype'],
        }
        return out


class DropoutGrad(Primitive):
    """
    The gradient of Dropout. During training, randomly zeroes some of the elements
    of the input tensor with probability.

    Args:
        keep_prob (float): The keep rate, between 0 and 1, e.g. keep_prob = 0.9
          means dropping out 10% of input units. Default: 0.5.

    Inputs:
        - **shape** (tuple[int]) - The shape of target mask.

    Outputs:
        Tensor, the value of generated mask for input shape.

    Examples:
        >>> dropout_grad = ops.DropoutGrad(keep_prob=0.5)
        >>> in_shape = Tensor((20, 16, 50, 50))
        >>> out = dropout_grad(in_shape)
    """

    @prim_attr_register
    def __init__(self, keep_prob=0.5):
        self.keep_prob = validator.check_float_range(keep_prob, 0, 1, validator.INC_RIGHT, "keep_prob", self.name)

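# Rough reference semantics (an assumption for illustration, not the kernel
# implementation): with inverted dropout the forward pass computes
# y = x * mask / keep_prob, so the matching backward pass would scale the
# incoming gradient the same way, e.g.
#
#     def _dropout_grad_reference(dy, mask, keep_prob):
#         return dy * mask / keep_prob
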

class FlattenGrad(PrimitiveWithInfer):
    """Performs gradients of Flatten."""

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['x', 'shape'], outputs=['output'])

    def __infer__(self, *args):
        out = {
            'value': None,
            'shape': args[1]['value'],
            'dtype': args[0]['dtype'],
        }
        return out


class InstanceNormGrad(PrimitiveWithInfer):
    """Gradients of InstanceNorm operation."""

    @prim_attr_register
    def __init__(self, epsilon=0.0, momentum=0.1):
        self.init_prim_io_names(inputs=['dy', 'x', 'gamma', 'save_mean', 'save_variance'],
                                outputs=['dx', 'bn_gamma', 'bn_beta'])


class InstanceNormV2Grad(Primitive):
    """Gradients of InstanceNormV2 operation."""

    @prim_attr_register
    def __init__(self, is_training=True, epsilon=1e-5):
        self.init_prim_io_names(inputs=['dy', 'x', 'gamma', 'mean', 'variance', 'save_mean', 'save_variance'],
                                outputs=['pd_x', 'pd_gamma', 'pd_beta'])
        validator.check_is_float(epsilon, 'epsilon', self.name)
        validator.check_float_range(epsilon, 0, 1, validator.INC_RIGHT, 'epsilon', self.name)
        validator.check_bool(is_training, "is_training", self.name)


class EinsumGrad(PrimitiveWithInfer):
    """Gradients of Einsum."""

    @prim_attr_register
    def __init__(self, equation):
        pass

    def infer_shape(self, x_shapes, dout_shape):
        out_shape = ()
        for dim in x_shapes:
            out_shape += (dim,)
        return out_shape

    def infer_dtype(self, x_types, dout_shape):
        out_type = ()
        for cur_type in x_types:
            out_type += (cur_type,)
        return out_type


class UniqueGrad(Primitive):
    """Gradients of Unique operation."""

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['dy', 'y'], outputs=['dx'])


class BNTrainingReduceGrad(Primitive):
    """Gradients of FusedBatchNorm operation."""

    @prim_attr_register
    def __init__(self, epsilon=0.0001, data_format='NCHW'):
        self.data_format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
        _inputs = ['grads', 'x', 'diff_scale', 'diff_offset', 'scale', 'batch_mean', 'batch_variance']
        self.init_prim_io_names(inputs=_inputs, outputs=['y'])


class BNTrainingUpdateGrad(Primitive):
    """Gradients of FusedBatchNorm operation."""

    @prim_attr_register
    def __init__(self, epsilon=0.0001, data_format='NCHW'):
        self.data_format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
        self.init_prim_io_names(inputs=['grads', 'x', 'batch_mean', 'batch_variance'],
                                outputs=['diff_scale', 'diff_offset'])


class NeighborExchangeV2Grad(PrimitiveWithInfer):
    """Gradients of NeighborExchangeV2 operation."""

    @prim_attr_register
    def __init__(self, send_rank_ids, send_lens, recv_rank_ids, recv_lens, data_format,
                 group=GlobalComm.WORLD_COMM_GROUP):
        self.init_prim_io_names(inputs=['dy'], outputs=['dx'])
        self.send_rank_ids = send_rank_ids
        self.recv_rank_ids = recv_rank_ids
        self.send_lens = send_lens
        self.recv_lens = recv_lens
        self.format = validator.check_string(data_format, ['NCHW'], 'format', self.name)
        self.add_prim_attr('no_elimilate', True)

    def __infer__(self, dy):
        dy_shape = dy['shape']
        validator.check('dy_shape.size()', len(dy_shape), '4', 4, validator.EQ, self.name)
        if self.send_rank_ids[5] != -1 or self.send_rank_ids[6] != -1 or self.send_rank_ids[7] != -1:
            dy_shape[3] -= self.send_lens[2]

        if self.send_rank_ids[1] != -1 or self.send_rank_ids[2] != -1 or self.send_rank_ids[3] != -1:
            dy_shape[3] -= self.send_lens[3]

        if self.send_rank_ids[0] != -1 or self.send_rank_ids[1] != -1 or self.send_rank_ids[7] != -1:
            dy_shape[2] -= self.send_lens[0]

        if self.send_rank_ids[3] != -1 or self.send_rank_ids[4] != -1 or self.send_rank_ids[5] != -1:
            dy_shape[2] -= self.send_lens[1]

        return {'shape': dy_shape,
                'dtype': dy['dtype'],
                'value': None}

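# Worked example of the shape inference above (illustration only): with
# send_rank_ids = [0, -1, -1, -1, 1, -1, -1, -1] (only top and bottom neighbors)
# and send_lens = [2, 3, 0, 0], a dy of shape [N, C, H, W] loses 2 rows for the
# top halo and 3 rows for the bottom halo, giving a dx shape of [N, C, H - 5, W].
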

class _PoolGrad(PrimitiveWithInfer):
    """Gradients of the max/avg pool operation."""

    @prim_attr_register
    def __init__(self, kernel_size, strides, pad_mode="VALID", data_format="NCHW"):
        self.init_prim_io_names(inputs=['x_origin', 'out_origin', 'grad'], outputs=['output'])

        validator.check_value_type('kernel_size', kernel_size, [int, tuple], self.name)
        validator.check_value_type('strides', strides, [int, tuple], self.name)
        validator.check_value_type('pad_mode', pad_mode, [str], self.name)
        self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.name)
        self.add_prim_attr("pad_mode", self.pad_mode)
        self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
        if context.get_context("device_target") != "GPU" and self.format == "NHWC":
            raise ValueError("NHWC format is only supported on the GPU target.")
        self.is_maxpoolgradwithargmax = (self.name == "MaxPoolGradWithArgmax")
        if not self.is_maxpoolgradwithargmax:
            self.add_prim_attr('data_format', self.format)

        def _grad_check_int_or_tuple(arg_name, arg_val, is_argmax):
            validator.check_value_type(arg_name, arg_val, (int, tuple), self.name)
            error_msg = ValueError(f"For '{self.name}' the '{arg_name}' must be a positive int number "
                                   f"or a tuple of two or four positive int numbers, but got {arg_val}")
            if isinstance(arg_val, int):
                ret = (1, arg_val, arg_val, 1) if is_argmax else (1, 1, arg_val, arg_val)
            elif len(arg_val) == 2:
                ret = (1, arg_val[0], arg_val[1], 1) if is_argmax else (1, 1, arg_val[0], arg_val[1])
            elif len(arg_val) == 4:
                ret = arg_val
            else:
                raise error_msg
            # whether all elements of tuple are positive integers
            for item in ret:
                if not isinstance(item, int) or item <= 0:
                    raise error_msg
            return ret

        kernel_size = _grad_check_int_or_tuple("kernel_size", kernel_size, self.is_maxpoolgradwithargmax)
        strides = _grad_check_int_or_tuple("strides", strides, self.is_maxpoolgradwithargmax)
        if self.format == "NCHW":
            self.kernel_size = kernel_size
            self.strides = strides
        else:
            self.kernel_size = [kernel_size[0], kernel_size[2], kernel_size[3], kernel_size[1]]
            self.strides = [strides[0], strides[2], strides[3], strides[1]]
        self.add_prim_attr("kernel_size", self.kernel_size)
        self.add_prim_attr("strides", self.strides)

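# Normalization performed by _grad_check_int_or_tuple above, for illustration: a
# scalar kernel_size of 3 becomes (1, 1, 3, 3) for the plain pooling grads and
# (1, 3, 3, 1) for MaxPoolGradWithArgmax; a pair (2, 3) becomes (1, 1, 2, 3) or
# (1, 2, 3, 1) respectively, and a 4-tuple is used as given.
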

class AvgPoolGradVm(_PoolGrad):
    """Gradients of the avg pool operation for vm."""

    @prim_attr_register
    def __init__(self, kernel_size=1, strides=1, pad_mode="VALID"):
        super(AvgPoolGradVm, self).__init__(kernel_size, strides, pad_mode)
        self.init_prim_io_names(inputs=['x_origin', 'grad', 'mean_matrix', 'kernel_matrix'], outputs=['output'])

    def __infer__(self, origin_input, dout, mean_matrix, kernel_matrix):
        out = {
            'value': None,
            'shape': tuple(origin_input['value']),
            'dtype': dout['dtype'],
        }

        return out


class AvgPoolGradGe(_PoolGrad):
    """Gradients of the avg pool operation for ge."""

    @prim_attr_register
    def __init__(self, kernel_size=1, strides=1, pad_mode="VALID", data_format="NCHW"):
        super(AvgPoolGradGe, self).__init__(kernel_size, strides, pad_mode, data_format)

    def __infer__(self, origin_input, dout):
        out = {
            'value': None,
            'shape': tuple(origin_input['value']),
            'dtype': dout['dtype'],
        }

        return out


class AvgPoolGradV1(Primitive):
    """Gradients of the AvgPoolV1 operation."""

    @prim_attr_register
    def __init__(self, kernel_size=1, strides=1, pad_mode="VALID", data_format="NCHW"):
        validator.check_value_type('kernel_size', kernel_size, [int, tuple], self.name)
        validator.check_value_type('strides', strides, [int, tuple], self.name)
        self.pad_mode = validator.check_string(
            pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.name)
        self.add_prim_attr("pad_mode", self.pad_mode)
        self.format = validator.check_string(
            data_format, ['NCHW', 'NHWC'], 'format', self.name)
        self.add_prim_attr('data_format', self.format)

        def _avgpoolgrad_check_int_or_tuple(argname, argval):
            validator.check_value_type(argname, argval, (int, tuple), self.name)
            errormsg = ValueError(f"For '{self.name}' the '{argname}' should be a positive int number "
                                  f"or a tuple of two or four positive int numbers, but got {argval}")
            if isinstance(argval, int):
                ret = (1, 1, argval, argval)
            elif len(argval) == 2:
                ret = (1, 1, argval[0], argval[1])
            elif len(argval) == 4:
                ret = argval
            else:
                raise errormsg
            # whether all elements of tuple are positive integers?
            for it in ret:
                if not isinstance(it, int) or it <= 0:
                    raise errormsg
            return ret

        self.kernel_size = _avgpoolgrad_check_int_or_tuple(
            "kernel_size", kernel_size)
        self.strides = _avgpoolgrad_check_int_or_tuple("strides", strides)

        self.kernel_size_adapt = self.kernel_size if self.format == "NCHW" else (
            self.kernel_size[0], self.kernel_size[2], self.kernel_size[3], self.kernel_size[1])
        self.strides_adapt = self.strides if self.format == "NCHW" else (
            self.strides[0], self.strides[2], self.strides[3], self.strides[1])

        # If length of some attrs is 4 we regard it as legal, either by using the op directly,
        # or passed from an instance of forward op AvgPoolV1.
        if len(self.kernel_size) == 4:
            self.kernel_size_adapt = self.kernel_size
        if len(self.strides) == 4:
            self.strides_adapt = self.strides

        self.add_prim_attr("kernel_size", self.kernel_size_adapt)
        self.add_prim_attr("strides", self.strides_adapt)


class AdaptiveAvgPool2DGrad(Primitive):
    """Gradients of the adaptive avg pool 2D operation."""

    @prim_attr_register
    def __init__(self):
        """Initialize AdaptiveAvgPool2DGrad"""
        self.init_prim_io_names(inputs=['input_grad', 'orig_input_shape'], outputs=['output_grad'])


class AdaptiveAvgPool3DGrad(Primitive):
    """Performs grad of AdaptiveAvgPool3D operation."""
    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['y_grad', 'orig_input_shape'], outputs=['x_grad'])


class AvgPool3DGrad(Primitive):
    """Gradients of the avg pool3d operation."""

    @prim_attr_register
    def __init__(self, kernel_size=1, strides=1, pads=0, ceil_mode=False,
                 count_include_pad=True, divisor_override=0, data_format="NCDHW", pad_mode="pad"):
        self.init_prim_io_names(inputs=['origin_input_shape', 'grads'], outputs=['output'])
        self.kernel_size = _check_3d_int_or_tuple('kernel_size', kernel_size, self.name,
                                                  allow_five=True, ret_five=True)
        self.add_prim_attr('kernel_size', self.kernel_size)
        self.strides = _check_3d_int_or_tuple('strides', strides, self.name, allow_five=True, ret_five=True)
        self.add_prim_attr('strides', self.strides)
        validator.check_value_type('pad_mode', pad_mode, [str], self.name)
        self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME', 'PAD'], 'pad_mode', self.name)
        validator.check_value_type('pads', pads, (int, tuple), self.name)
        if isinstance(pads, int):
            pads = (pads,) * 6
        validator.check_equal_int(len(pads), 6, 'pad size', self.name)
        for item in pads:
            validator.check_non_negative_int(item, 'pad item', self.name)
        self.add_prim_attr('pad_list', pads)
        self.ceil_mode = validator.check_value_type('ceil_mode', ceil_mode, bool, self.name)
        self.count_include_pad = validator.check_value_type('count_include_pad', count_include_pad, bool, self.name)
        self.divisor_override = validator.check_value_type('divisor_override', divisor_override, int, self.name)
        self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)


class AdaptiveMaxPool2DGrad(Primitive):
    """Gradients of the adaptive max pool 2D operation."""
    @prim_attr_register
    def __init__(self):
        """Initialize AdaptiveMaxPool2DGrad"""
        self.init_prim_io_names(inputs=['y_grad', 'x', 'argmax'], outputs=['x_grad'])


class MaxPoolGrad(_PoolGrad):
    """Performs gradients of the max pool operation."""

    @prim_attr_register
    def __init__(self, kernel_size=1, strides=1, pad_mode="VALID", data_format="NCHW"):
        super(MaxPoolGrad, self).__init__(kernel_size, strides, pad_mode, data_format)

    def infer_shape(self, x1_shape, x2_shape, grad_shape):
        return x1_shape

    def infer_dtype(self, x1_dtype, x2_dtype, grad_dtype):
        return x1_dtype


class MaxPoolGradV1(Primitive):
    """Performs gradients of the MaxPoolV1 operation."""

    @prim_attr_register
    def __init__(self, kernel_size=1, strides=1, pad_mode="VALID", data_format="NCHW"):
        self.init_prim_io_names(
            inputs=['x_origin', 'out_origin', 'grad'], outputs=['output'])

        validator.check_value_type('kernel_size', kernel_size, [int, tuple], self.name)
        validator.check_value_type('strides', strides, [int, tuple], self.name)
        self.pad_mode = validator.check_string(
            pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.name)
        self.add_prim_attr("pad_mode", self.pad_mode)
        self.format = validator.check_string(
            data_format, ['NCHW', 'NHWC'], 'format', self.name)
        self.add_prim_attr('data_format', self.format)

        def _grad_check_int_or_tuple(arg_name, arg_val):
            validator.check_value_type(
                arg_name, arg_val, (int, tuple), self.name)
            error_msg = ValueError(f"For '{self.name}' the '{arg_name}' should be a positive int number "
                                   f"or a tuple of two or four positive int numbers, but got {arg_val}")
            if isinstance(arg_val, int):
                ret = (1, 1, arg_val, arg_val)
            elif len(arg_val) == 2:
                ret = (1, 1, arg_val[0], arg_val[1])
            elif len(arg_val) == 4:
                ret = arg_val
            else:
                raise error_msg
            # whether all elements of tuple are positive integers
            for item in ret:
                if not isinstance(item, int) or item <= 0:
                    raise error_msg
            return ret

        self.kernel_size = _grad_check_int_or_tuple("kernel_size", kernel_size)
        self.strides = _grad_check_int_or_tuple("strides", strides)

        kernel_size_adapted = self.kernel_size if self.format == 'NCHW' else (
            self.kernel_size[0], self.kernel_size[2], self.kernel_size[3], self.kernel_size[1])
        strides_adapted = self.strides if self.format == 'NCHW' else (
            self.strides[0], self.strides[2], self.strides[3], self.strides[1])

        if len(kernel_size) == 4:
            kernel_size_adapted = kernel_size
        if len(strides) == 4:
            strides_adapted = strides

        self.add_prim_attr("kernel_size", kernel_size_adapted)
        self.add_prim_attr("strides", strides_adapted)


class MaxPoolGradGrad(_PoolGrad):
    r"""
    Performs gradients of the MaxPoolGrad operation.

    Args:
        kernel_size (Union[int, tuple[int]]): The size of kernel used to take the maximum value,
            is an int number that represents height and width are both kernel_size, or a tuple
            of two int numbers that represent height and width respectively. Default: 1.
        strides (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
            the height and width of movement are both strides, or a tuple of two int numbers that
            represent height and width of movement respectively. Default: 1.
        pad_mode (str): The optional value for pad mode, is "same" or "valid", not case sensitive.
            Default: "valid".

            - same: Adopts the way of completion. The height and width of the output will be the same as
              the input. The total number of padding will be calculated in horizontal and vertical
              directions and evenly distributed to top and bottom, left and right if possible.
              Otherwise, the last extra padding will be done from the bottom and the right side.

            - valid: Adopts the way of discarding. The possible largest height and width of output
              will be returned without padding. Extra pixels will be discarded.

    Inputs:
        - **origin_input** (Tensor) - Tensor with data format "NCHW".
          For Ascend, data type must be float16. For CPU and GPU, data type support float16 and float32.
        - **origin_output** (Tensor) - Data type same as `origin_input`.
        - **grad** (Tensor) - Data type and shape same as `origin_input`.

    Outputs:
        Tensor, with data type same as `origin_input`. Shape same as `origin_output`.

    Raises:
        TypeError: If kernel_size is neither int nor a tuple of 2/4 int numbers.
        TypeError: If strides is neither int nor a tuple of 2/4 int numbers.
        TypeError: If pad_mode is not string.
        ValueError: If pad_mode is neither "same" nor "valid" (not case sensitive).
        TypeError: For Ascend, input data type is not float16. For CPU or GPU, input data type is neither
            float16 nor float32.
        ValueError: If the rank of `origin_input`, `origin_output` or `grad` is not equal to 4.
        ValueError: If data types of all inputs are not equal.
        ValueError: If the shapes of `origin_input` and `grad` are not equal.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """

    @prim_attr_register
    def __init__(self, kernel_size=1, strides=1, pad_mode="VALID"):
        super(MaxPoolGradGrad, self).__init__(kernel_size, strides, pad_mode)

    def infer_shape(self, x1_shape, x2_shape, grad_shape):
        return x2_shape

    def infer_dtype(self, x1_dtype, x2_dtype, grad_dtype):
        args = {'x1_dtype': x1_dtype, 'x2_dtype': x2_dtype, 'grad_dtype': grad_dtype}
        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16], self.name)
        return x2_dtype


def _get_max_pool3d_grad_pads_by_pad_mode(input_shape, kernel_size, strides, pad_mode):
    """
    Helper for computing the max pool3d grad pads by pad_mode.
    """

    def get_pad(origin_shape, ksize, stride):
        tail = origin_shape % stride
        pad = (ksize - tail) if tail > 0 else (ksize - stride)
        pad = max(pad, 0)
        pad1 = int(pad / 2)
        pad2 = int(pad / 2) + pad % 2
        return pad1, pad2

    _, _, d, h, w = input_shape
    _, _, kd, kh, kw = kernel_size
    _, _, strd, strh, strw = strides

    pads = (0, 0, 0, 0, 0, 0)
    if pad_mode == 'SAME':
        pads_d = get_pad(d, kd, strd)
        pads_h = get_pad(h, kh, strh)
        pads_w = get_pad(w, kw, strw)
        pads = pads_d + pads_h + pads_w
    return pads

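# Worked example of get_pad above (illustration only): for a dimension of size 7
# with ksize 3 and stride 2, tail = 7 % 2 = 1, so pad = 3 - 1 = 2 and the padding
# is split as (1, 1); for size 8 with ksize 3 and stride 2, tail = 0, so
# pad = 3 - 2 = 1 and the split is (0, 1).
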
933
934class MaxPool3DGrad(Primitive):
935    """Gradients of the max pool3d operation."""
936
937    @prim_attr_register
938    def __init__(self, kernel_size=(1, 1, 1, 1, 1), strides=(1, 1, 1, 1, 1),
939                 pad_mode='VALID', pad_list=0, data_format="NCDHW"):
940        self.init_prim_io_names(inputs=['x_origin', 'out_origin', 'grad'], outputs=['output'])
941        validator.check_value_type('kernel_size', kernel_size, [int, tuple], self.name)
942        validator.check_value_type('strides', strides, [int, tuple], self.name)
943        validator.check_value_type('pad_mode', pad_mode, [str], self.name)
944        self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
945        if pad_mode.upper() == 'PAD':
946            pad_mode = 'CALCULATED'
947        self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME', 'CALCULATED'], 'pad_mode', self.name)
948        self.kernel_size = _check_3d_int_or_tuple("kernel_size", kernel_size, self.name,
949                                                  allow_five=True, ret_five=True)
950        self.add_prim_attr("kernel_size", self.kernel_size)
951        self.strides = _check_3d_int_or_tuple("strides", strides, self.name, allow_five=True, ret_five=True)
952        self.add_prim_attr("strides", self.strides)
953        validator.check_value_type('pad_list', pad_list, (int, tuple), self.name)
954        self.pad_list = pad_list
955        if isinstance(self.pad_list, int):
956            self.pad_list = (self.pad_list,) * 6
957        if len(self.pad_list) == 3:
958            self.pad_list = (pad_list[0], pad_list[0], pad_list[1], pad_list[1], pad_list[2], pad_list[3])
959        if len(self.pad_list) != 3 and len(self.pad_list) != 6:
960            raise ValueError(f"For `maxpool3d` attr 'pad_list' must be an positive int number or a tuple of "
961                             f"three or six positive int numbers, but got `{len(self.pad_list)}` numbers.")
962        if self.pad_mode != 'CALCULATED' and self.pad_list != (0, 0, 0, 0, 0, 0):
963            raise ValueError(f"For '{self.name}', when pad_list is not 0, pad_mode must be set as 'pad'.")
964        if self.pad_mode == 'CALCULATED':
965            for item in self.pad_list:
966                validator.check_non_negative_int(item, 'pad_list item', self.name)
967        self.add_prim_attr("pad_list", self.pad_list)
968
969
970class MaxPool3DGradGrad(PrimitiveWithInfer):
971    r"""Gradients of the max pool3d grad operation.
972
973    Args:
974        kernel_size (Union[int, tuple[int]]): The size of kernel used to take the maximum value,
975            is an int number that represents depth, height and width are both kernel_size, or a tuple
976            of two int numbers that represent depth, height and width respectively. Default: 1.
977        strides (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
978            the depth, height and width of movement are both strides, or a tuple of two int numbers that
979            represent depth, height and width of movement respectively. Default: 1.
980        pad_mode (str): The optional value for pad mode, is "same" or "valid", not case sensitive.
981            Default: "valid".
982
983            - same: Adopts the way of completion. The depth, height and width of the output will be the
984              same as the input. The total number of padding will be calculated in depth, horizontal and
985              vertical directions and evenly distributed to front and back, top and bottom, left and
986              right if possible. Otherwise, the last extra padding will be done from the back, the bottom
987              and the right side.
988
989            - valid: Adopts the way of discarding. The possible largest height and width of output
990              will be returned without padding. Extra pixels will be discarded.
991
992    Inputs:
993        - **origin_input** (Tensor) - Tensor with data format "NCDHW".
994          For Ascend, data type must be float16. For CPU and GPU, data type support float16 and float32.
995        - **origin_output** (Tensor) - Data type same as `origin_input`.
996        - **grad** (Tensor) - Data type and shape same as `origin_input`.
997
998    Outputs:
999        Tensor, with data type same as `origin_input`. Shape same as `origin_output`.
1000
1001    Raises:
1002        TypeError: If kernel_size is neither int nor a tuple of 3/5 int numbers.
1003        TypeError: If strides is neither int nor a tuple of 3/5 int numbers.
1004        TypeError: If pad_mode is not string.
1005        ValueError: If pad_mode is neither "same" nor "valid"(not case sensitive).
1006        TypeError: For Ascend, input data type is not float16. For CPU or GPU, input data type is neither
1007        float16 nor float32.
1008        ValueError: If the rank of `origin_input`, `origin_output` or `grad` is not equal to 5.
1009        ValueError: If data types of all inputs are not equal.
1010        ValueError: If the shapes of `origin_input` and `grad` are not equal.
1011
1012    Supported Platforms:
1013        ``Ascend`` ``GPU`` ``CPU``
1014    """
1015
1016    @prim_attr_register
1017    def __init__(self, kernel_size=(1, 1, 1, 1, 1), strides=(1, 1, 1, 1, 1), pad_mode='VALID', data_format="NCDHW"):
1018        validator.check_value_type('kernel_size', kernel_size, [int, tuple], self.name)
1019        validator.check_value_type('strides', strides, [int, tuple], self.name)
1020        validator.check_value_type('pad_mode', pad_mode, [str], self.name)
1021        self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
1022        self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.name)
1023        self.kernel_size = _check_3d_int_or_tuple("kernel_size", kernel_size, self.name,
1024                                                  allow_five=True, ret_five=True)
1025        self.add_prim_attr("kernel_size", self.kernel_size)
1026        self.strides = _check_3d_int_or_tuple("strides", strides, self.name, allow_five=True, ret_five=True)
1027        self.add_prim_attr("strides", self.strides)
1028
1029    def infer_shape(self, x_shape, y_shape, grad_shape):
1030        validator.check_equal_int(len(x_shape), 5, "x rank", self.name)
1031        validator.check('x_shape', x_shape, 'grad_shape', grad_shape, prim_name=self.name)
1032        pad_list = _get_max_pool3d_grad_pads_by_pad_mode(x_shape, self.kernel_size, self.strides, self.pad_mode)
1033        for pad in pad_list:
1034            validator.check_non_negative_int(pad, 'element of pad_list', self.name)
1035        self.add_prim_attr("pad_list", pad_list)
1036        return y_shape
1037
1038    def infer_dtype(self, x_dtype, y_dtype, grad_dtype):
1039        args = {'x_dtype': x_dtype, 'y_dtype': y_dtype}
1040        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
1041        validator.check_tensor_dtype_valid('grad_dtype', grad_dtype, [mstype.float16, mstype.float32], self.name)
1042        return x_dtype
1043
1044
1045class MaxPoolGradWithArgmax(Primitive):
1046    """Computes the gradients of MaxPoolWithArgmax."""
1047    @prim_attr_register
1048    def __init__(self, kernel_size=1, strides=1, pad_mode="VALID", data_format="NCHW"):
1049        self.init_prim_io_names(inputs=['x_origin', 'out_origin', 'grad'], outputs=['output'])
1050        validator.check_value_type('kernel_size', kernel_size, [int, tuple], self.name)
1051        validator.check_value_type('strides', strides, [int, tuple], self.name)
1052        validator.check_value_type('pad_mode', pad_mode, [str], self.name)
1053        self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.name)
1054        self.add_prim_attr("pad_mode", self.pad_mode)
1055        self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
1056        if context.get_context("device_target") != "GPU" and self.format == "NHWC":
1057            raise ValueError("NHWC format only support in GPU target.")
1058        self.is_maxpoolgradwithargmax = (self.name == "MaxPoolGradWithArgmax")
1059        if not self.is_maxpoolgradwithargmax:
1060            self.add_prim_attr('data_format', self.format)
1061
1062        def _grad_check_int_or_tuple(arg_name, arg_val):
1063            validator.check_value_type(arg_name, arg_val, (int, tuple), self.name)
1064            error_msg = ValueError(f"For '{self.name}' the '{arg_name}' must be an positive int number "
1065                                   f"or a tuple of two or four positive int numbers, but got {arg_val}")
1066            if isinstance(arg_val, int):
1067                ret = (1, arg_val, arg_val, 1)
1068            elif len(arg_val) == 2:
1069                ret = (1, arg_val[0], arg_val[1], 1)
1070            elif len(arg_val) == 4:
1071                ret = arg_val
1072            else:
1073                raise error_msg
1074            # whether all elements of tuple are positive integers
1075            for item in ret:
1076                if not isinstance(item, int) or item <= 0:
1077                    raise error_msg
1078            return ret
1079
1080        kernel_size = _grad_check_int_or_tuple("kernel_size", kernel_size)
1081        self.kernel_size = kernel_size
1082        self.add_prim_attr("kernel_size", self.kernel_size)
1083
1084        strides = _grad_check_int_or_tuple("strides", strides)
1085        self.strides = strides
1086        self.add_prim_attr("strides", self.strides)
1087
1088
1089class MaxPoolGradWithArgmaxV2(Primitive):
1090    """Gradients of the MaxPoolWithArgmaxV2 operation."""
1091
1092    @prim_attr_register
1093    def __init__(self, kernel_size, strides=None, pads=0, dilation=(1, 1), ceil_mode=False, argmax_type=mstype.int64):
1094        self.init_prim_io_names(inputs=['x', 'grad', 'argmax'], outputs=['y'])
1095        self.kernel_size = _check_positive_int_or_tuple("kernel_size", kernel_size, self.name, allow_four=True,
1096                                                        ret_four=True)
1097        self.add_prim_attr('kernel_size', self.kernel_size)
1098        if strides is None:
1099            strides = kernel_size
1100        self.strides = _check_positive_int_or_tuple("strides", strides, self.name, allow_four=True, ret_four=True)
1101        self.add_prim_attr('strides', self.strides)
1102        self.pads = _check_positive_int_or_tuple("pads", pads, self.name, allow_four=True, ret_four=True,
1103                                                 strict_positive=False)
1104        self.add_prim_attr('pads', self.pads)
1105        validator.check_value_type('ceil_mode', ceil_mode, bool, self.name)
1106        self.add_prim_attr('ceil_mode', self.ceil_mode)
1107        self.dilation = _check_positive_int_or_tuple("dilation", dilation, self.name, allow_four=True, ret_four=True)
1108        self.add_prim_attr('dilation', self.dilation)
1109        self.add_prim_attr('argmax_type', self.argmax_type)
1110
1111
1112class MaxPool3DGradWithArgmax(Primitive):
1113    """Gradients of the maxpool3Dwithargmax operation."""
1114
1115    @prim_attr_register
1116    def __init__(self, ksize, strides, pads, dilation=(1, 1, 1), ceil_mode=False, data_format="NCDHW"):
1117        self.init_prim_io_names(inputs=['x', 'grads', 'argmax'], outputs=['y'])
1118        validator.check_value_type('ceil_mode', ceil_mode, bool, self.name)
1119        validator.check_value_type('data_format', data_format, str, self.name)
1120        self.data_format = validator.check_string(data_format, ['NCDHW'], 'data_format', self.name)
1121        self.ksize = _check_3d_int_or_tuple("ksize", ksize, self.name, ret_five=False)
1122        self.add_prim_attr('ksize', self.ksize)
1123        self.strides = _check_3d_int_or_tuple("strides", strides, self.name, ret_five=False)
1124        self.add_prim_attr('strides', self.strides)
1125        self.pads = _check_3d_int_or_tuple("pads", pads, self.name, greater_zero=False, ret_five=False)
1126        self.add_prim_attr('pads', self.pads)
1127        self.dilation = _check_3d_int_or_tuple("dilation", dilation, self.name, allow_five=True, ret_five=False)
1128        self.add_prim_attr('dilation', self.dilation)
1129
1130
1131class MaxPoolGradGradWithArgmax(_PoolGrad):
1132    r"""
1133    Computes the gradients of MaxPoolGradWithArgmax.
1134
1135    Args:
1136        kernel_size (Union[int, tuple[int]]): The size of kernel used to take the maximum value,
1137            is an int number that represents height and width are both kernel_size, or a tuple
1138            of two int numbers that represent height and width respectively. Default: 1.
1139        strides (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
1140            the height and width of movement are both strides, or a tuple of two int numbers that
1141            represent height and width of movement respectively. Default: 1.
1142        pad_mode (str): The optional value for pad mode is "same" or "valid" (not case sensitive).
1143            Default: "valid".
1144
1145            - same: Adopts the way of completion. The height and width of the output will be the same as
1146              the input. The total number of padding will be calculated in horizontal and vertical
1147              directions and evenly distributed to top and bottom, left and right if possible.
1148              Otherwise, the last extra padding will be done from the bottom and the right side.
1149
1150            - valid: Adopts the way of discarding. The possible largest height and width of output
1151              will be returned without padding. Extra pixels will be discarded.
1152
1153    Inputs:
1154        - **x** (Tensor) - Tensor with data format "NCHW".
1155          For Ascend, data type must be float16. For CPU and GPU, data type support float16 and float32.
1156        - **grad** (Tensor) - Data type and shape same as `x`.
1157        - **argmax** (Tensor) - Data type must be int32 or int64.
1158
1159    Outputs:
1160        Tensor, with data type same as `x`. Shape same as `argmax`.
1161
1162    Raises:
1163        TypeError: If kernel_size is neither int nor a tuple of 2/4 int numbers.
1164        TypeError: If strides is neither int nor a tuple of 2/4 int numbers.
1165        TypeError: If pad_mode is not string.
1166        ValueError: If pad_mode is neither "same" nor "valid" (not case sensitive).
1167        TypeError: If, for Ascend, the data types of `x` and `grad` are not float16; or, for CPU or GPU,
1168            the data types of `x` and `grad` are neither float16 nor float32.
1169        TypeError: If the data type of `argmax` is neither int32 nor int64.
1170        ValueError: If the rank of `x`, `grad` or `argmax` is not equal to 4.
1171        ValueError: If the shapes of `x` and `grad` are not equal.
1172
1173    Supported Platforms:
1174        ``Ascend`` ``GPU`` ``CPU``
1175    """
1176
1177    @prim_attr_register
1178    def __init__(self, kernel_size=1, strides=1, pad_mode="VALID"):
1179        self.init_prim_io_names(inputs=['x', 'grad', 'argmax'], outputs=['output'])
1180        super(MaxPoolGradGradWithArgmax, self).__init__(kernel_size, strides, pad_mode)
1181
1182    def infer_shape(self, x_shape, grad_shape, argmax_shape):
1183        if not grad_shape:
1184            raise TypeError("The dout of MaxPoolGradGradWithArgmax must be a Tensor.")
1185        return x_shape
1186
1187    def infer_dtype(self, x_dtype, grad_dtype, argmax_dtype):
1188        args = {'x_dtype': x_dtype, 'grad_dtype': grad_dtype}
1189        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
1190        return grad_dtype
1191
1192
1193class MinimumGradGrad(Primitive):
1194    """Grad for minimum_grad."""
1195    @prim_attr_register
1196    def __init__(self):
1197        """Initialize MinimumGradGrad"""
1198        super().__init__("MinimumGradGrad")
1199        self.init_prim_io_names(inputs=['x1', 'x2', 'grad_y1', 'grad_y2'],
1200                                outputs=['sopd_x1', 'sopd_x2', 'sopd_grads'])
1201
1202
1203class L2NormalizeGrad(Primitive):
1204    r"""
1205    Gradients of L2 normalize.
1206
1207    Args:
1208        axis (Union[list(int), tuple(int), int]): The begin axis for the input to apply L2 normalize. Default: 0.
1209        epsilon (float): A small value added for numerical stability. Default: 1e-4.
1210
1211    Inputs:
1212        - **input_x** (Tensor) - Must be the input `weight` of forward operator L2Normalize.
1213        - **out** (Tensor) - Must be the output of forward operator L2Normalize.
1214        - **dout** (Tensor) - The backprop of the next layer.
1215
1216    Outputs:
1217        Tensor, gradients of L2Normalize `input_x`.
1218    """
1219
1220    @prim_attr_register
1221    def __init__(self, axis=0, epsilon=1e-4):
1222        axis = [axis] if isinstance(axis, int) else axis
1223        validator.check_value_type('axis', axis, [list, tuple], self.name)
1224        validator.check_value_type('epsilon', epsilon, [int, float], self.name)
1225        self.add_prim_attr('axis', axis)
1226        self.init_attrs['axis'] = axis
1227        if len(axis) != 1:
1228            raise TypeError("The length of axis must be 1, later will support multiple axis!")
1229
1230
1231class LSTMGradData(Primitive):
1232    """Computes the data gradients of LSTM."""
1233
1234    @prim_attr_register
1235    def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
1236        self.input_size = validator.check_positive_int(input_size, 'input_size', self.name)
1237        self.hidden_size = validator.check_positive_int(hidden_size, 'hidden_size', self.name)
1238        self.num_layers = validator.check_positive_int(num_layers, 'num_layers', self.name)
1239        self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
1240        self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
1241        self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
1242        self.dropout = validator.check_float_range(dropout, 0, 1, validator.INC_BOTH, 'dropout', self.name)
1243
1244        if bidirectional:
1245            self.num_directions = 2
1246        else:
1247            self.num_directions = 1
1248
1249
1250class LSTMGradWeight(Primitive):
1251    """Computes the weight gradients of LSTM."""
1252
1253    @prim_attr_register
1254    def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
1255        self.input_size = validator.check_positive_int(input_size, 'input_size', self.name)
1256        self.hidden_size = validator.check_positive_int(hidden_size, 'hidden_size', self.name)
1257        self.num_layers = validator.check_positive_int(num_layers, 'num_layers', self.name)
1258        self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
1259        self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
1260        self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
1261        self.dropout = validator.check_float_range(dropout, 0, 1, validator.INC_BOTH, 'dropout', self.name)
1262
1263        if bidirectional:
1264            self.num_directions = 2
1265        else:
1266            self.num_directions = 1
1267
1268
1269class LSTMGrad(Primitive):
1270    """Computes the data and weight gradients of LSTM."""
1271
1272    @prim_attr_register
1273    def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout, proj_size=0):
1274        self.input_size = validator.check_positive_int(input_size, 'input_size', self.name)
1275        self.hidden_size = validator.check_positive_int(hidden_size, 'hidden_size', self.name)
1276        self.proj_size = validator.check_int_range(proj_size, 0, hidden_size, validator.INC_LEFT,
1277                                                   'proj_size', self.name)
1278        self.num_layers = validator.check_positive_int(num_layers, 'num_layers', self.name)
1279        self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
1280        self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
1281        self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
1282        self.dropout = validator.check_float_range(dropout, 0, 1, validator.INC_BOTH, 'dropout', self.name)
1286
1287        if bidirectional:
1288            self.num_directions = 2
1289        else:
1290            self.num_directions = 1
1291
1292    def infer_shape(self, x_shape, hx_shape, cx_shape, w_shape, y_shape, hy_shape, cy_shape, dy_shape, dhy_shape,
1293                    dcy_shape, reserve_shape):
1294        # dhy and dcy should be same shape
1295        validator.check_equal_int(len(dhy_shape), 3, "h_shape", self.name)
1296        validator.check_equal_int(len(dhy_shape), len(dcy_shape), "h_shape", self.name)
1297        if self.proj_size == 0:
1298            validator.check_equal_int(dhy_shape[0], dcy_shape[0], "h_shape[0]", self.name)
1299            validator.check_equal_int(dhy_shape[1], dcy_shape[1], "h_shape[1]", self.name)
1300            validator.check_equal_int(dhy_shape[2], dcy_shape[2], "h_shape[2]", self.name)
1301
1302        real_hidden_size = self.proj_size if self.proj_size > 0 else self.hidden_size
1303        validator.check_int(dhy_shape[0], self.num_layers * self.num_directions, validator.EQ, "h_shape[0]", self.name)
1304        validator.check_equal_int(dhy_shape[2], real_hidden_size, "h_shape[2]", self.name)
1305
1306        validator.check_equal_int(len(dy_shape), 3, "dy_shape", self.name)
1307        validator.check_equal_int(dy_shape[1], dhy_shape[1], "dy[1]", self.name)
1308        validator.check_int(dy_shape[2], real_hidden_size * self.num_directions, validator.EQ, "dy[2]", self.name)
1309
1310        dx_shape = (y_shape[0], y_shape[1], self.input_size)
1311        dhx_shape = dhy_shape
1312        dcx_shape = dcy_shape
1313        weight_size = 0
1314        gate_size = 4 * self.hidden_size
1315        for layer in range(self.num_layers):
1316            for _ in range(self.num_directions):
1317                input_layer_size = self.input_size if layer == 0 else self.hidden_size * self.num_directions
1318                weight_size += gate_size * input_layer_size
1319                weight_size += gate_size * real_hidden_size
1320                if self.proj_size > 0:
1321                    weight_size += self.proj_size * self.hidden_size
1322                if self.has_bias:
1323                    weight_size += gate_size
1324
1325        return (dx_shape, dhx_shape, dcx_shape, (weight_size, 1, 1))
1326
1327    def infer_dtype(self, x_dtype, hx_dtype, cx_dtype, w_dtype, y_dtype, hy_dtype, cy_dtype, dy_dtype, dhy_dtype,
1328                    dcy_dtype, reserve_dtype):
1329        return (dy_dtype, dy_dtype, dy_dtype, hx_dtype)
1330
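# NOTE: a worked example of the flat weight size computed by LSTMGrad.infer_shape above, assuming
# a single-layer, unidirectional cell with input_size=10, hidden_size=8, has_bias=True, proj_size=0:
#     gate_size   = 4 * 8                  # 32
#     weight_size = 32 * 10 + 32 * 8 + 32  # input-hidden + hidden-hidden + bias = 608
# so the returned dw shape is (608, 1, 1).
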
1331class DynamicRNNGrad(Primitive):
1332    """Computes the input gradients of DynamicRNN."""
1333
1334    @prim_attr_register
1335    def __init__(self,
1336                 cell_type='LSTM',
1337                 direction='UNIDIRECTIONAL',
1338                 cell_depth=1,
1339                 use_peephole=False,
1340                 keep_prob=1.0,
1341                 cell_clip=-1.0,
1342                 num_proj=0,
1343                 time_major=True,
1344                 forget_bias=0.0):
1345        self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
1346
1347
1348class GruGradData(PrimitiveWithInfer):
1349    """Computes the data gradients of GRU."""
1350
1351    @prim_attr_register
1352    def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
1353        self.input_size = validator.check_positive_int(input_size, 'input_size', self.name)
1354        self.hidden_size = validator.check_positive_int(hidden_size, 'hidden_size', self.name)
1355        self.num_layers = validator.check_positive_int(num_layers, 'num_layers', self.name)
1356        self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
1357        self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
1358        self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
1359        self.dropout = validator.check_float_range(dropout, 0, 1, validator.INC_BOTH, 'dropout', self.name)
1360
1361        if bidirectional:
1362            self.num_directions = 2
1363        else:
1364            self.num_directions = 1
1365
1366    def infer_shape(self, y_shape, dy_shape, dhy_shape, w_shape,
1367                    hx_shape, reserve_shape, state_shape):
1368        # dhy and dcy should be same shape
1369        validator.check_equal_int(len(dhy_shape), 3, "h_shape", self.name)
1370
1371        validator.check_int(dhy_shape[0], self.num_layers * self.num_directions, validator.EQ, "h_shape[0]", self.name)
1372        validator.check_equal_int(dhy_shape[2], self.hidden_size, "h_shape[2]", self.name)
1373
1374        validator.check_equal_int(len(dy_shape), 3, "dy_shape", self.name)
1375        validator.check_equal_int(dy_shape[1], dhy_shape[1], "dy[1]", self.name)
1376        validator.check_int(dy_shape[2], self.hidden_size * self.num_directions, validator.EQ, "dy[2]", self.name)
1377
1378        dx_shape = (y_shape[0], y_shape[1], self.input_size)
1379        dhx_shape = dhy_shape
1380
1381        return (dx_shape, dhx_shape)
1382
1383    def infer_dtype(self, y_dtype, dy_dtype, dhy_dtype, w_dtype,
1384                    hx_dtype, reserve_dtype, state_dtype):
1385        args = {"dy": dy_dtype, "dhy": dhy_dtype}
1386        validator.check_tensors_dtypes_same_and_valid(args, (mstype.float32, mstype.float16), self.name)
1387        return (dy_dtype, dy_dtype)
1388
1389
1390class GruGradWeight(PrimitiveWithInfer):
1391    """Computes the weight gradients of GRU."""
1392
1393    @prim_attr_register
1394    def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
1395        self.input_size = validator.check_positive_int(input_size, 'input_size', self.name)
1396        self.hidden_size = validator.check_positive_int(hidden_size, 'hidden_size', self.name)
1397        self.num_layers = validator.check_positive_int(num_layers, 'num_layers', self.name)
1398        self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
1399        self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
1400        self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
1401        self.dropout = validator.check_float_range(dropout, 0, 1, validator.INC_BOTH, 'dropout', self.name)
1402
1403        if bidirectional:
1404            self.num_directions = 2
1405        else:
1406            self.num_directions = 1
1407
1408    def infer_shape(self, x_shape, hx_shape, y_shape, reserve_shape, state_shape):
1409        weight_size = 0
1410        gate_size = 3 * self.hidden_size
1411        for layer in range(self.num_layers):
1412            for _ in range(self.num_directions):
1413                input_layer_size = self.input_size if layer == 0 else self.hidden_size * self.num_directions
1414                weight_size += gate_size * input_layer_size
1415                weight_size += gate_size * self.hidden_size
1416                if self.has_bias:
1417                    weight_size += 2 * gate_size
1418
1419        return (weight_size, 1, 1)
1420
1421    def infer_dtype(self, x_dtype, hx_dtype, y_dtype, reserve_dtype, state_dtype):
1422        return hx_dtype
1423
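# NOTE: a worked example of GruGradWeight.infer_shape above, assuming a single-layer,
# unidirectional GRU with input_size=10, hidden_size=8 and has_bias=True:
#     gate_size   = 3 * 8                      # 24
#     weight_size = 24 * 10 + 24 * 8 + 2 * 24  # = 480, i.e. a dw shape of (480, 1, 1)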
1424
1425class GRUV2Grad(Primitive):
1426    """Computes the grad gradients of GRU."""
1427
1428    @prim_attr_register
1429    def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
1430        self.input_size = validator.check_positive_int(input_size, 'input_size', self.name)
1431        self.hidden_size = validator.check_positive_int(hidden_size, 'hidden_size', self.name)
1432        self.num_layers = validator.check_positive_int(num_layers, 'num_layers', self.name)
1433        self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
1434        self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
1435        self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
1436        self.dropout = validator.check_float_range(dropout, 0, 1, validator.INC_BOTH, 'dropout', self.name)
1437
1438        if bidirectional:
1439            self.num_directions = 2
1440        else:
1441            self.num_directions = 1
1442
1443
1444class DynamicGRUV2Grad(Primitive):
1445    r"""
1446    Computes the input gradients of DynamicGRUV2.
1447
1448    Args:
1449        direction (str): A string identifying the direction in the op. Default: 'UNIDIRECTIONAL'.
1450            Only 'UNIDIRECTIONAL' is currently supported.
1451        cell_depth (int): An integer identifying the cell depth in the op. Default: 1.
1452        keep_prob (float): A float identifying the keep prob in the op. Default: 1.0.
1453        cell_clip (float): A float identifying the cell clip in the op. Default: -1.0.
1454        num_proj (int): An integer identifying the num proj in the op. Default: 0.
1455        time_major (bool): A bool identifying the time major in the op. Default: ``True``.
1456        gate_order (str): A string identifying the gate order in weight and bias. Default: 'rzh'.
1457            'zrh' is another option.
1458        reset_after (bool): A bool identifying whether to apply the reset gate after matrix multiplication.
1459            Default: ``True``.
1460
1461    Inputs:
1462        - **x** (Tensor) - Current words. Tensor of shape :math:`(num\_step, batch\_size, input\_size)`.
1463          The data type must be float16 or float32.
1464        - **weight_input** (Tensor) - Weight. Tensor of shape :math:`(input\_size, 3 x hidden\_size)`.
1465          The data type must be float16 or float32.
1466        - **weight_hidden** (Tensor) - Hidden-layer weight. Tensor of shape :math:`(hidden\_size, 3 x hidden\_size)`.
1467          The data type must be float16 or float32.
1468        - **y** (Tensor) - A Tensor of shape
1469          :math:`(num\_step, batch\_size, \min(hidden\_size, num\_proj))` if num_proj > 0, or
1470          :math:`(num\_step, batch\_size, hidden\_size)` if num_proj == 0.
1471          The data type must be float16 or float32.
1472        - **init_h** (Tensor) - Hidden state of initial time.
1473          Tensor of shape :math:`(batch\_size, hidden\_size)`.
1474          The data type must be float16 or float32.
1475        - **h** (Tensor) - A Tensor of shape :math:`(num\_step, batch\_size, hidden\_size)`.
1476          The data type must be float16 or float32.
1477        - **dy** (Tensor) - Gradient of `y`, has the same shape and data type as `y`.
1478        - **dh** (Tensor) - Gradient of `h`, has the same shape and data type as `init_h`.
1479        - **update** (Tensor) - A Tensor of shape :math:`(num\_step, batch\_size, hidden\_size)`.
1480          The data type must be float16 or float32.
1481        - **reset** (Tensor) - A Tensor of shape :math:`(num\_step, batch\_size, hidden\_size)`.
1482          The data type must be float16 or float32.
1483        - **new** (Tensor) - A Tensor of shape :math:`(num\_step, batch\_size, hidden\_size)`.
1484          The data type must be float16 or float32.
1485        - **hidden_new** (Tensor) - A Tensor of shape :math:`(num\_step, batch\_size, hidden\_size)`.
1486          The data type must be float16 or float32.
1487        - **seq_length** (Tensor) - The length of each batch. Tensor of shape :math:`(batch\_size)`.
1488          Only `None` is currently supported.
1489        - **mask** (Tensor) - A 4-D Tensor. The data type must be float16 or float32.
1490
1491    Outputs:
1492        - **dw_input** (Tensor) - A Tensor has the same shape as `weight_input`.
1493          Has the same type with input `x`.
1494        - **dw_hidden** (Tensor) - A Tensor has the same shape as `weight_hidden`.
1495          Has the same type with input `x`.
1496        - **db_input** (Tensor) - A Tensor of shape :math:`(3 x hidden\_size)`.
1497          Has the same type with input `init\_h`.
1498        - **db_hidden** (Tensor) - A Tensor of shape :math:`(3 x hidden\_size)`.
1499          Has the same type with input `init\_h`.
1500        - **dx** (Tensor) - A Tensor of shape :math:`(num\_step, batch\_size, hidden\_size)`.
1501          Has the same type with input `x`.
1502        - **dh_prev** (Tensor) - A Tensor of shape :math:`(batch\_size, hidden\_size)`.
1503          Has the same type with input `init\_h`.
1504    """
1505
1506    @prim_attr_register
1507    def __init__(self,
1508                 direction='UNIDIRECTIONAL',
1509                 cell_depth=1,
1510                 keep_prob=1.0,
1511                 cell_clip=-1.0,
1512                 num_proj=0,
1513                 time_major=True,
1514                 gate_order="rzh",
1515                 reset_after=True):
1516        self.cell_depth = validator.check_value_type("cell_depth", cell_depth, [int], self.name)
1517        self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
1518        self.cell_clip = validator.check_value_type("cell_clip", cell_clip, [float], self.name)
1519        self.num_proj = validator.check_non_negative_int(num_proj, "num_proj", self.name)
1520        self.time_major = validator.check_value_type("time_major", time_major, [bool], self.name)
1521        self.direction = validator.check_string(direction, ['UNIDIRECTIONAL'], "direction", self.name)
1522        self.gate_order = validator.check_string(gate_order, ['zrh', 'rzh'], "gate_order", self.name)
1523        self.reset_after = validator.check_value_type("reset_after", reset_after, [bool], self.name)
1524        self.init_prim_io_names(inputs=[
1525            "x", "weight_input", "weight_hidden", "y", "init_h", "h", "dy",
1526            "dh", "update", "reset", "new", "hidden_new", "seq_length", "mask"
1527        ],
1528                                outputs=[
1529                                    "dw_input", "dw_hidden", "db_input",
1530                                    "db_hidden", "dx", "dh_prev"
1531                                ])
1532
1533
1534class RandomGammaGrad(Primitive):
1535    r"""
1536    Computes the derivative of a random sample of Gamma with respect to alpha.
1537
1538    Inputs:
1539        - **alpha** (Tensor) - α is the shape parameter of RandomGamma distribution.
1540          It must be greater than 0. Must be one of the following types: float32, float64.
1541        - **sample** (Tensor) - The sample of random gamma tensor. Must be one of the
1542          following types: float32, float64.
1543
1544    Outputs:
1545        The dtype is the same type as alpha.
1546        The output shape is derived from the input through broadcasting.
1547
1548    Raises:
1549        TypeError: If data type of `alpha` and `sample` is not float32 or float64.
1550        TypeError: If data types of `alpha` and `sample` are not the same.
1551        ValueError: If the last dimensions of the shapes of `sample` and `alpha` are not equal.
1552
1553    Supported Platforms:
1554        ``GPU``
1555
1556    Examples:
1557        >>> alpha = Tensor(np.array([1., 0.6, 3., 26.]), mstype.float32)
1558        >>> sample = Tensor(np.array([6., 7, 11., 0.5]), mstype.float32)
1559        >>> randomgammagrad = ops.RandomGammaGrad()
1560        >>> output = randomgammagrad(alpha, sample)
1561        >>> print(output)
1562        [2.5142431 3.4334087 1.8847835 0.07780622]
1563    """
1564
1565    @prim_attr_register
1566    def __init__(self):
1567        """Initialize RandomGammaGrad"""
1568        self.init_prim_io_names(inputs=['alpha', 'sample'], outputs=['output'])
1569        self.add_prim_attr("side_effect_hidden", True)
1570
1571
1572class ROIAlignGrad(Primitive):
1573    """
1574    ROIAlignGrad operator.
1575
1576    Args:
1577       pooled_height (int): The output feature height.
1578       pooled_width (int): The output feature width.
1579       spatial_scale (float): The feature stride.
1580       sample_num (int): Number of sampling points. Default: 2.
1581    """
1582
1583    @prim_attr_register
1584    def __init__(self, pooled_height, pooled_width, spatial_scale, sample_num=2):
1585        """Initialize ROIAlignGrad"""
1586        self.init_prim_io_names(inputs=["dy", "rois", "xdiff_shape"], outputs=["dx"])
1587        validator.check_value_type("pooled_height", pooled_height, [int], self.name)
1588        validator.check_value_type("pooled_width", pooled_width, [int], self.name)
1589        validator.check_value_type("spatial_scale", spatial_scale, [float], self.name)
1590        validator.check_value_type("sample_num", sample_num, [int], self.name)
1591        self.pooled_height = pooled_height
1592        self.pooled_width = pooled_width
1593        self.spatial_scale = spatial_scale
1594        self.sample_num = sample_num
1595
1596
1597class PsROIPoolingGrad(PrimitiveWithInfer):
1598    """
1599    PsROIPoolingGrad operator.
1600    """
1601
1602    @prim_attr_register
1603    def __init__(self, batch_size, channels, height, width, num_rois,
1604                 pooled_height, pooled_width, spatial_scale, out_dim):
1605        """Initialize PsROIPoolingGrad"""
1606        validator.check_value_type("batch_size", batch_size, [int], self.name)
1607        validator.check_value_type("channels", channels, [int], self.name)
1608        validator.check_value_type("height", height, [int], self.name)
1609        validator.check_value_type("width", width, [int], self.name)
1610        validator.check_value_type("num_rois", num_rois, [int], self.name)
1611        validator.check_value_type("pooled_height", pooled_height, [int], self.name)
1612        validator.check_value_type("pooled_width", pooled_width, [int], self.name)
1613        validator.check_value_type("spatial_scale", spatial_scale, [float], self.name)
1614        validator.check_value_type("out_dim", out_dim, [int], self.name)
1615        self.batch_size = batch_size
1616        self.channels = channels
1617        self.height = height
1618        self.width = width
1619        self.num_rois = num_rois
1620        self.pooled_height = pooled_height
1621        self.pooled_width = pooled_width
1622        self.spatial_scale = spatial_scale
1623        self.out_dim = out_dim
1624
1625    def infer_shape(self, ydiff_shape, rois_shape, mapping_channel_shape):
1626        return [self.batch_size, self.channels, self.height, self.width]
1627
1628    def infer_dtype(self, ydiff_type, rois_type, mapping_channel_type):
1629        return ydiff_type
1630
1631
1632class _ActivationGrad(PrimitiveWithInfer):
1633    """_ActivationGrad base class."""
1634
1635    @prim_attr_register
1636    def __init__(self):
1637        self.init_prim_io_names(inputs=['y_grad', 'x'], outputs=['output'])
1638
1639    def infer_shape(self, y_grad_shape, x_shape):
1640        return x_shape
1641
1642    def infer_dtype(self, y_grad_dtype, x_dtype):
1643        valid_dtypes = (mstype.float16, mstype.float32)
1644        validator.check_tensor_dtype_valid("y_grad", y_grad_dtype, valid_dtypes, self.name)
1645        validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name)
1646        return x_dtype
1647
1648
1649class SigmoidCrossEntropyWithLogitsGrad(Primitive):
1650    """Computes the gradients of `SigmoidCrossEntropyWithLogits`."""
1651
1652    @prim_attr_register
1653    def __init__(self):
1654        """Initialize SigmoidCrossEntropyWithLogitsGrad"""
1655        self.init_prim_io_names(inputs=['x', 'y', 'dout'], outputs=['x_grad'])
1656
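# NOTE: the conventional formula behind the ['x', 'y', 'dout'] -> ['x_grad'] signature above is
#     x_grad = (sigmoid(x) - y) * dout
# This is stated as the usual sigmoid-cross-entropy derivative, not as a description of the
# backend kernel.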
1657
1658class SliceGrad(PrimitiveWithInfer):
1659    """Reverse of slice."""
1660
1661    @prim_attr_register
1662    def __init__(self):
1663        """Initialize SliceGrad"""
1664        self.init_prim_io_names(inputs=['dy', 'x', 'begin', 'size'], outputs=['dx'])
1665
1666    def __infer__(self, dy, x, begin, size):
1667        dy_shape, x_shape, size_value, begin_v = dy['shape'], x['shape'], size['value'], begin['value']
1668        dy_shape_len = len(dy_shape)
1669        if size_value is not None and not is_shape_unknown(x_shape) and not is_shape_unknown(dy_shape):
1670            size_value = list(size_value)
1671            for i in range(dy_shape_len):
1672                if size_value[i] == -1:
1673                    size_value[i] = x_shape[i] - begin_v[i]
1674                validator.check(f'dy_shape[{i}]', dy_shape[i], f'x_shape[{i}]', x_shape[i], validator.LE, self.name)
1675                validator.check(f'dy_shape[{i}]', dy_shape[i], f'size_shape[{i}]',
1676                                size_value[i], validator.EQ, self.name)
1677
1678        return {'shape': x_shape,
1679                'dtype': x['dtype'],
1680                'value': None}
1681
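# NOTE: a worked example of the __infer__ logic above, with assumed values: for
#     x.shape = (5, 3), begin = (1, 0), size = (-1, 3)
# the -1 entry resolves to x_shape[0] - begin[0] = 4, so dy is checked against shape (4, 3) and
# the returned gradient dx always has the full input shape (5, 3).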
1682
1683class SmoothL1LossGrad(Primitive):
1684    """Computes gradient for prediction on SmoothL1Loss."""
1685
1686    @prim_attr_register
1687    def __init__(self, beta=1.0, reduction='none'):
1688        self.add_prim_attr('sigma', self.beta)
1689        self.reduction = validator.check_string(
1690            reduction, ['none', 'sum', 'mean'], 'reduction', self.name)
1691
1692
1693class SoftMarginLossGrad(Primitive):
1694    """Computes gradient for prediction on SoftMarginLoss."""
1695
1696    @prim_attr_register
1697    def __init__(self, reduction="mean"):
1698        self.init_prim_io_names(inputs=['predict', 'label', "dout"], outputs=['gradient'])
1699        self.reduction = validator.check_string(reduction, ['none', 'sum', 'mean'], 'reduction', self.name)
1700
1701
1702class StridedSliceV2Grad(Primitive):
1703    """
1704    Performs grad of StridedSliceV2 operation.
1705
1706    Inputs:
1707        - **shapex** (Tensor) - StridedSliceV2 shape of input
1708        - **begin** (tuple[int]) - A tuple which represents the location where to start. Only
1709          constant value is allowed.
1710        - **end** (tuple[int]) - A tuple which represents the maximum location where to end.
1711          Only constant value is allowed.
1712        - **strides** (tuple[int]) - A tuple which represents the stride that is continuously added
1713          before reaching the maximum location. Only constant value is allowed.
1714        - **dy** (Tensor) - The output of StridedSliceV2
1715
1716    Outputs:
1717        Tensor, the shape same as the input of StridedSliceV2
1718    """
1719
1720    @prim_attr_register
1721    def __init__(self,
1722                 begin_mask=0,
1723                 end_mask=0,
1724                 ellipsis_mask=0,
1725                 new_axis_mask=0,
1726                 shrink_axis_mask=0):
1727        """Initialize StridedSliceV2Grad"""
1728        self.init_prim_io_names(inputs=['shapex', 'begin', 'end', 'strides', 'dy'], outputs=['output'])
1729
1730
1731class StridedSliceGrad(Primitive):
1732    """
1733    Performs grad of StridedSlice operation.
1734
1735    Args:
1736        begin_mask (int): Start indexing the slice. Default: 0.
1737        end_mask (int): End indexing the slice. Default: 0.
1738        ellipsis_mask (int): An int32 mask. Default: 0.
1739        new_axis_mask (int): An int32 mask. Default: 0.
1740        shrink_axis_mask (int): An int32 mask. Default: 0.
1741
1742    Returns:
1743        Tensor, has the same shape of input.
1744    """
1745
1746    @prim_attr_register
1747    def __init__(self,
1748                 begin_mask=0,
1749                 end_mask=0,
1750                 ellipsis_mask=0,
1751                 new_axis_mask=0,
1752                 shrink_axis_mask=0):
1753        """Initialize StridedSliceGrad"""
1754        validator.check_value_type('begin_mask', begin_mask, [int], self.name)
1755        validator.check_value_type('end_mask', end_mask, [int], self.name)
1756        validator.check_value_type('ellipsis_mask', ellipsis_mask, [int], self.name)
1757        validator.check_value_type('new_axis_mask', new_axis_mask, [int], self.name)
1758        validator.check_value_type('shrink_axis_mask', shrink_axis_mask, [int], self.name)
1759        self.init_prim_io_names(inputs=['dy', 'shapex', 'begin', 'end', 'strides'], outputs=['output'])
1760
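# NOTE: a conceptual sketch of StridedSliceGrad (not the kernel implementation): the incoming
# gradient is scattered back into a zero tensor of the original shape, roughly
#     dx = zeros(shapex); dx[begin:end:strides] = dy
# with begin/end/ellipsis/new_axis/shrink_axis masks interpreted as in StridedSlice.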
1761
1762class SoftplusGrad(Primitive):
1763    """Computes gradient for the Softplus activation."""
1764
1765    @prim_attr_register
1766    def __init__(self):
1767        self.init_prim_io_names(inputs=['gradients', 'features'], outputs=['backprops'])
1768
1769
1770class TanhGrad(Primitive):
1771    """Computes gradient of hyperbolic tangent of input element-wise."""
1772
1773    @prim_attr_register
1774    def __init__(self):
1775        """Initialize TanhGrad"""
1776        self.init_prim_io_names(inputs=['y', 'dy'], outputs=['z'])
1777
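# NOTE: a minimal usage sketch for TanhGrad, assuming this module is imported as `G` and that the
# printed values are approximate:
#     >>> import numpy as np
#     >>> from mindspore import Tensor
#     >>> y = Tensor(np.tanh(np.array([0.5, -1.0], np.float32)))   # forward output of Tanh
#     >>> dy = Tensor(np.ones(2, np.float32))
#     >>> G.TanhGrad()(y, dy)   # elementwise dy * (1 - y ** 2), roughly [0.786, 0.420]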
1778
1779class MirrorPadGrad(Primitive):
1780    """Gradients of MirrorPad operation."""
1781
1782    @prim_attr_register
1783    def __init__(self, mode="REFLECT"):
1784        """Initialize MirrorPad"""
1785        self.init_prim_io_names(inputs=['dy', 'paddings'], outputs=['output'])
1786        validator.check_string(mode, ['REFLECT', 'SYMMETRIC'], 'mode', self.name)
1787        self.mode = mode
1788
1789
1790class PadV3Grad(Primitive):
1791    """Gradients of PadV3 operation."""
1792
1793    @prim_attr_register
1794    def __init__(self, mode='reflect', paddings_contiguous=True):
1795        """Initialize Padv3Grad"""
1796        self.add_prim_attr("cust_aicpu", self.name)
1797        self.init_prim_io_names(inputs=['x', 'paddings'], outputs=['y'])
1798        validator.check_string(mode, ['reflect', 'edge', 'circular'], 'mode', self.name)
1799        validator.check_bool(paddings_contiguous, "paddings_contiguous", self.name)
1800        self.mode = mode
1801        self.paddings_contiguous = paddings_contiguous
1802
1803
1804class EmbeddingLookupCommGrad(PrimitiveWithInfer):
1805    """
1806    Performs the gradient for the communication part of EmbeddingLookup operator.
1807
1808    This works ONLY when 'reduce_scatter_flag' is True in 'EmbeddingLookup'. Roughly speaking,
1809    this primitive is implemented by StridedSlice --> _HostAllGather --> Concat. This primitive runs on host.
1810    """
1811
1812    @prim_attr_register
1813    def __init__(self):
1814        self.init_prim_io_names(inputs=['dy', 'split_num'], outputs=['output'])
1815        self.set_device('CPU')
1816        self.tuple_setitem = Primitive('tuple_setitem')
1817
1818    def __infer__(self, dy, split_num):
1819        """
1820        This primitive is implemented by three steps:
1821            1) Splits the 'dy' along dimension 0 into 'split_num' parts.
1822            2) For each part, perform _HostAllGather((0, 1, 2, 3, 4, 5, 6, 7)) on the host.
1823            3) After _HostAllGather, there are still 'split_num' parts in each process. Then, perform Concat on them
1824              along dimension 0.
1825
1826        The output shape of this primitive: shape(output)[0] == shape(dy)[0] * 8
1827        """
1828        dy_shape = tuple(dy['shape'])
1829        split_num_value = split_num['value']
1830        validator.check_value_type("split_num_value", split_num_value, [int], self.name)
1831        dy_shape_all = self.tuple_setitem(dy_shape, 0, dy_shape[0] * 8)
1832        return {'shape': dy_shape_all,
1833                'dtype': dy['dtype'],
1834                'value': None}
1835
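# NOTE: a worked example of the shape rule documented in __infer__ above, with assumed values:
# for dy of shape (16, 32) and a constant int split_num, the host all-gather over 8 devices
# produces an output of shape (16 * 8, 32) = (128, 32).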
1836
1837class RefToEmbed(Primitive):
1838    r"""
1839    Make a key from Ref.
1840
1841    The key is a symbolic_key, an embedding of a Parameter, which is used as a key of the variable in env_type;
1842    items can be fetched by the `EnvironGet` operation with the symbolic_key instance. The `Parameter` is a ref.
1843
1844    Inputs:
1845        - **input** (Ref) - Target ref, ref is short for reference. The value of a Parameter is a ref.
1846
1847    Outputs:
1848        symbolic_key, made from the Ref.
1849
1850    Examples:
1851        >>> class Net(nn.Cell):
1852        >>>     def __init__(self):
1853        >>>         super(Net, self).__init__()
1854        >>>         self.weight = mindspore.Parameter(1.0, name='weight')
1855        >>>
1856        >>>     def construct(self):
1857        >>>         key = RefToEmbed()(self.weight)
1858        >>>         return key, self.weight
1859    """
1860    __mindspore_signature__ = (
1861        sig.make_sig('variable', sig.sig_rw.RW_REF),
1862    )
1863
1864    @prim_attr_register
1865    def __init__(self):
1866        pass
1867
1868
1869class BasicLSTMCellCStateGrad(PrimitiveWithInfer):
1870    """Computes the state gradients of BasicLSTMCell."""
1871
1872    @prim_attr_register
1873    def __init__(self, forget_bias, activation):
1874        self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
1875        self.activation = validator.check_string(activation, ['tanh'], "activation", self.name)
1876
1877    def infer_shape(self, c_shape, dht_shape, dct_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape):
1878        # dhy and dcy should be same shape
1879        validator.check_equal_int(len(c_shape), 2, "c rank", self.name)
1880        validator.check("dht rank", len(dht_shape), "c rank", len(c_shape), validator.EQ, self.name)
1881        validator.check("dct rank", len(dct_shape), "c rank", len(c_shape), validator.EQ, self.name)
1882        validator.check("it rank", len(it_shape), "c rank", len(c_shape), validator.EQ, self.name)
1883        validator.check("jt rank", len(jt_shape), "c rank", len(c_shape), validator.EQ, self.name)
1884        validator.check("ft rank", len(ft_shape), "c rank", len(c_shape), validator.EQ, self.name)
1885        validator.check("ot rank", len(ot_shape), "c rank", len(c_shape), validator.EQ, self.name)
1886        validator.check("tanhct rank", len(tanhct_shape), "c rank", len(c_shape), validator.EQ, self.name)
1887        validator.check("dht shape", dht_shape, "c shape", c_shape, validator.EQ, self.name)
1888        validator.check("dct shape", dct_shape, "c shape", c_shape, validator.EQ, self.name)
1889        validator.check("it shape", it_shape, "c shape", c_shape, validator.EQ, self.name)
1890        validator.check("jt shape", jt_shape, "c shape", c_shape, validator.EQ, self.name)
1891        validator.check("ft shape", ft_shape, "c shape", c_shape, validator.EQ, self.name)
1892        validator.check("ot shape", ot_shape, "c shape", c_shape, validator.EQ, self.name)
1893        validator.check("tanhct shape", tanhct_shape, "c shape", c_shape, validator.EQ, self.name)
1894
1895        dgate_shape = (c_shape[0], 4 * c_shape[1])
1896        dct_1_shape = c_shape
1897
1898        return (dgate_shape, dct_1_shape)
1899
1900    def infer_dtype(self, c_dtype, dht_dtype, dct_dtype, it_dtype, jt_dtype, ft_dtype, ot_dtype, tanhct_dtype):
1901        validator.check_subclass("c", c_dtype, [mstype.tensor_type], self.name)
1902        validator.check_subclass("dht", dht_dtype, [mstype.tensor_type], self.name)
1903        validator.check_subclass("dct", dct_dtype, [mstype.tensor_type], self.name)
1904        validator.check_subclass("it", it_dtype, [mstype.tensor_type], self.name)
1905        validator.check_subclass("jt", jt_dtype, [mstype.tensor_type], self.name)
1906        validator.check_subclass("ft", ft_dtype, [mstype.tensor_type], self.name)
1907        validator.check_subclass("ot", ot_dtype, [mstype.tensor_type], self.name)
1908        validator.check_subclass("tanhct", tanhct_dtype, [mstype.tensor_type], self.name)
1909        validator.check_type_name("c", c_dtype, [mstype.float16, mstype.float32], self.name)
1910        validator.check_type_name("dht", dht_dtype, [mstype.float16, mstype.float32], self.name)
1911        validator.check_type_name("dct", dct_dtype, [mstype.float16, mstype.float32], self.name)
1912        validator.check_type_name("it", it_dtype, [mstype.float16, mstype.float32], self.name)
1913        validator.check_type_name("jt", jt_dtype, [mstype.float16, mstype.float32], self.name)
1914        validator.check_type_name("ft", ft_dtype, [mstype.float16, mstype.float32], self.name)
1915        validator.check_type_name("ot", ot_dtype, [mstype.float16, mstype.float32], self.name)
1916        validator.check_type_name("tanhct", tanhct_dtype, [mstype.float16, mstype.float32], self.name)
1917        return (c_dtype, c_dtype)
1918
1919
1920class BasicLSTMCellWeightGrad(PrimitiveWithInfer):
1921    """Computes the weight gradients of BasicLSTM."""
1922
1923    @prim_attr_register
1924    def __init__(self):
1925        pass
1926
1927    def infer_shape(self, x_shape, h_shape, dgate_shape):
1928        validator.check_equal_int(len(x_shape), 2, "x rank", self.name)
1929        validator.check("h rank", len(h_shape), " x rank", len(x_shape), validator.EQ, self.name)
1930        validator.check("dgate rank", len(dgate_shape), "x rank", len(x_shape), validator.EQ, self.name)
1931        validator.check("h_shape[0]", h_shape[0], "x_shape[0]", x_shape[0], validator.EQ, self.name)
1932        validator.check("dgate_shape[0]", dgate_shape[0], "h_shape[0]", h_shape[0], validator.EQ, self.name)
1933        validator.check("dgate_shape[1]", dgate_shape[1], "4*h_shape[1]", 4 * h_shape[1], validator.EQ, self.name)
1934        input_size = x_shape[1]
1935        hidden_size = h_shape[1]
1936        dw_shape = (input_size + hidden_size, 4 * hidden_size)
1937        db_shape = (4 * hidden_size,)
1938        return (dw_shape, db_shape)
1939
1940    def infer_dtype(self, x_dtype, h_dtype, dgate_dtype):
1941        validator.check_subclass("x", x_dtype, mstype.tensor_type, self.name)
1942        validator.check_subclass("h", h_dtype, mstype.tensor_type, self.name)
1943        validator.check_subclass("dgate", dgate_dtype, mstype.tensor_type, self.name)
1944        validator.check_type_name("x", x_dtype, [mstype.float16, mstype.float32], self.name)
1945        validator.check_type_name("h", h_dtype, [mstype.float16, mstype.float32], self.name)
1946        validator.check_type_name("dgate", dgate_dtype, [mstype.float16, mstype.float32], self.name)
1947        return (x_dtype, x_dtype)
1948
1949
1950class BasicLSTMCellInputGrad(PrimitiveWithInfer):
1951    """Computes the input gradients of BasicLSTM."""
1952
1953    @prim_attr_register
1954    def __init__(self, keep_prob):
1955        self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
1956        self.keep_prob = validator.check_float_range(keep_prob, 0.0, 1.0, validator.INC_BOTH, "keep_prob", self.name)
1957
1958    def infer_shape(self, dgate_shape, w_shape):
1959        validator.check_equal_int(len(dgate_shape), 2, "dgate rank", self.name)
1960        validator.check_equal_int(len(w_shape), 2, "w rank", self.name)
1961        validator.check("dgate_shape[1]", dgate_shape[1], "w_shape[1]", w_shape[1], validator.EQ, self.name)
1962        batch_size = dgate_shape[0]
1963        hidden_size = dgate_shape[1] // 4
1964        input_size = w_shape[0] - hidden_size
1965        dxt_shape = (batch_size, input_size)
1966        dht_shape = (batch_size, hidden_size)
1967        return (dxt_shape, dht_shape)
1968
1969    def infer_dtype(self, dgate_dtype, w_dtype):
1970        validator.check_subclass("dgate", dgate_dtype, mstype.tensor_type, self.name)
1971        validator.check_subclass("w", w_dtype, mstype.tensor_type, self.name)
1972        validator.check_type_name("dgate", dgate_dtype, [mstype.float16, mstype.float32], self.name)
1973        validator.check_type_name("w", w_dtype, [mstype.float16, mstype.float32], self.name)
1974        return (dgate_dtype, dgate_dtype)
1975
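# NOTE: a worked example of BasicLSTMCellInputGrad.infer_shape above, with assumed values:
#     dgate.shape = (2, 32), w.shape = (18, 32)
# gives hidden_size = 32 // 4 = 8 and input_size = 18 - 8 = 10, so the outputs are
#     dxt.shape = (2, 10) and dht.shape = (2, 8).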
1976
1977class InvGrad(Primitive):
1978    """Computes gradients for inv operation."""
1979
1980    @prim_attr_register
1981    def __init__(self):
1982        self.init_prim_io_names(inputs=['x', 'grad'], outputs=['y'])
1983
1984
1985class LRNGrad(Primitive):
1986    """Computes gradients for LRN operation."""
1987
1988    @prim_attr_register
1989    def __init__(self, depth_radius=5, bias=1.0, alpha=1.0, beta=0.5):
1990        self.init_prim_io_names(inputs=['grads', 'x', 'y'], outputs=['z'])
1991        validator.check_value_type("depth_radius", depth_radius, [int], self.name)
1992        validator.check_value_type("bias", bias, [float], self.name)
1993        validator.check_value_type("alpha", alpha, [float], self.name)
1994        validator.check_value_type("beta", beta, [float], self.name)
1995
1996
1997class MvlgammaGrad(Primitive):
1998    r"""
1999    Computes gradients for Mvlgamma.
2000
2001    The following formula shows the mathematical calculation process of Mvlgamma:
2002
2003    .. math::
2004
2005        \log (\Gamma_{p}(a))=C+\sum_{i=1}^{p} \log (\Gamma(a-\frac{i-1}{2}))
2006
2007    where :math:`C = \log(\pi) \times \frac{p(p-1)}{4}` and :math:`\Gamma(\cdot)` is the Gamma function.
2008
2009    Args:
2010        p (int): The number of dimensions. The value of `p` must be greater than or equal to 1.
2011
2012    Inputs:
2013        - **y_grad** (Tensor) - The input gradient.
2014        - **x** (Tensor) - The input of Mvlgamma with data type of float32 or float64.
2015
2016    Outputs:
2017        Tensor, has the same shape and type as `x`.
2018
2019    Raises:
2020        TypeError: If dtype of `y_grad` or `x` is neither float32 nor float64.
2021        TypeError: If `p` is not an int.
2022        ValueError: If p is not greater than or equal to 1.
2023        ValueError: If all elements of `x` are not greater than (p-1)/2.
2024
2025    Supported Platforms:
2026        ``Ascend`` ``CPU``
2027    """
2028
2029    @prim_attr_register
2030    def __init__(self, p):
2031        self.init_prim_io_names(inputs=['y_grad', 'x'], outputs=['x_grad'])
2032        self.p = validator.check_value_type('p', p, [int], self.name)
2033
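# NOTE: a mathematical sketch that follows from the log-Gamma formula in the docstring above
# (stated as the usual derivative of the multivariate log-gamma, not as the backend kernel):
# since d/da log(Gamma(a)) is the digamma function psi(a), the expected elementwise gradient is
#     x_grad = y_grad * sum_{i=1..p} psi(x - (i - 1) / 2)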
2034
2035class MaskedSelectGrad(PrimitiveWithInfer):
2036    """Computes gradient for MaskedSelect."""
2037
2038    @prim_attr_register
2039    def __init__(self):
2040        pass
2041
2042    def infer_shape(self, x, mask, grad):
2043        return x
2044
2045    def infer_dtype(self, x, mask, grad):
2046        return x
2047
2048
2049class SoftShrinkGrad(Primitive):
2050    r"""
2051          Gradients for SoftShrink operation.
2052
2053          Args:
2054              lambd – The \lambdaλ (must be no less than zero) value for the Softshrink formulation. Default: 0.5.
2055
2056          Inputs:
2057              - **input_grad** (Tensor) - The input gradient.
2058              - **input_x** (Tensor) - The input of SoftShrink with data type of float16 or float32.
2059                Any number of additional dimensions.
2060
2061          Outputs:
2062              output - Tensor, has the same shape and data type as input_x.
2063
2064          Raises:
2065              TypeError: If lambd is not a float.
2066              TypeError: If dtype of input_x is neither float16 nor float32.
2067              ValueError: If lambd is less than to 0.
2068
2069          Supported Platforms:
2070              ``Ascend``
2071      """
2072
2073    @prim_attr_register
2074    def __init__(self, lambd=0.5):
2075        self.init_prim_io_names(inputs=['input_grad', 'input_x'], outputs=['output'])
2076        validator.check_value_type("lambd", lambd, [float], self.name)
2077        validator.check_number("lambd", lambd, 0, validator.GE, self.name)
2078
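# NOTE: the conventional SoftShrink derivative (not a statement about the kernel): SoftShrink is
# the identity shifted by +/- lambd outside [-lambd, lambd] and zero inside, so
#     output = input_grad where |input_x| > lambd, and 0 elsewhere.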
2079
2080class CdistGrad(Primitive):
2081    """Computes gradient for Cdist."""
2082
2083    @prim_attr_register
2084    def __init__(self, p=2.0):
2085        validator.check_value_type("p", p, [float], self.name)
2086        self.init_prim_io_names(inputs=['grad', 'input_x', 'input_y', 'cdist'], outputs=['output'])
2087
2088
2089class PdistGrad(Primitive):
2090    """Computes gradient for Pdist operation.
2091
2092    Args:
2093        p (float): the p value for the Pdist formulation. Default: 2.0.
2094
2095    Inputs:
2096        - **y_grad** (Tensor) - The gradients of loss to output of Pdist function.
2097        - **x** (Tensor) - Input tensor of shape :math:`(N, M)`.
2098          Must be the input `x` of the forward operator Pdist.
2099        - **y** (Tensor) - Input tensor of shape :math:`(N*(N-1)/2)`.
2100          Must be the output `y` of the forward operator Pdist.
2101
2102    Outputs:
2103        Tensor, with the same shape and dtype as `x`.
2104
2105    Raises:
2106        TypeError: If one of `y_grad`, `x` and `y` is not a Tensor.
2107        TypeError: If dtype of `y_grad`, `x` and `y` are not all float16, float32 or float64.
2108        TypeError: If `p` is not a float.
2109        ValueError: If `p` is a negative float.
2110        ValueError: If shape of `y_grad` is not same as `y`.
2111        ValueError: If dimension of `x` is not 2.
2112
2113    Supported Platforms:
2114        ``Ascend`` ``GPU`` ``CPU``
2115    """
2116
2117    @prim_attr_register
2118    def __init__(self, p=2.0):
2119        validator.check_value_type("p", p, [float], self.name)
2120        if p < 0:
2121            raise ValueError(f"Pdist p must be a non-negative value, but got {p}.")
2122        self.init_prim_io_names(inputs=['y_grad', 'x', 'y'], outputs=['x_grad'])
2123
2124
2125class MultilabelMarginLossGrad(Primitive):
2126    """
2127    Compute the gradients of MultilabelMarginLoss operation.
2128
2129    Args:
2130        reduction (str, optional): Apply specific reduction method to the output: ``'none'`` , ``'mean'`` ,
2131            ``'sum'`` . Default: ``'mean'`` .
2132
2133            - ``'none'``: no reduction will be applied.
2134            - ``'mean'``: compute and return the mean of elements in the output.
2135            - ``'sum'``: the output elements will be summed.
2136
2137    Inputs:
2138        - **y_grad** (Tensor) - The gradients of loss to output of MultilabelMarginLoss function, with
2139          the same shape and data type as forward output `y`.
2140        - **x** (Tensor) - Predict data. Tensor of shape :math:`(C)` or :math:`(N, C)`, where :math:`N`
2141          is the batch size and :math:`C` is the number of classes. Data type must be float16 or float32.
2142        - **target** (Tensor) - Ground truth data, with the same shape as `x`, data type must be int32 and
2143          label targets padded by -1.
2144        - **is_target** (Tensor) - Forward output tensor for backward input, with the same shape and
2145          data type as `target`.
2146
2147    Outputs:
2148        The shape of output :math:`(C)` or :math:`(N, C)`, with the same shape and data type as `x`.
2149
2150    Raises:
2151        TypeError: If `x` or `target` or `y_grad` is not a Tensor.
2152        TypeError: If dtype of `x` is neither float16 nor float32.
2153        TypeError: If dtype of `target` is not int32.
2154        TypeError: If dtype of `y_grad` is not the same as `x`.
2155        ValueError: If length of shape of `x` is neither 1 nor 2.
2156        ValueError: If shape of `x` is not the same as `target`.
2157        ValueError: If `reduction` is not one of ``'none'``, ``'mean'``, ``'sum'``.
2158        ValueError: If shape of `y_grad` is not the same as forward output `y`.
2159
2160    Supported Platforms:
2161        ``Ascend``
2162    """
2163
2164    @prim_attr_register
2165    def __init__(self, reduction="mean"):
2166        """Initialize MultilabelMarginLossGrad"""
2167        self.reduction = validator.check_string(reduction, ['none', 'sum', 'mean'], 'reduction', self.name)
2168        self.init_prim_io_names(inputs=['y_grad', 'x', 'target', 'is_target'], outputs=['x_grad'])
2169
2170
2171class Dilation2DBackpropInput(Primitive):
2172    """
2173    Computes the gradient of morphological 2-D dilation with respect to the input.
2174
2175    .. warning::
2176        This operator is an experimental operator, which has some accuracy problems for some inputs.
2177
2178    Args:
2179        stride (Union[int, tuple[int]]): The distance of filter moving, an int number that represents
2180            the height and width of movement are both strides, a tuple of two int numbers that
2181            represent height and width of movement respectively, or a tuple of four int numbers which
2182            should be :math:`(1, 1, H_{stride}, W_{stride})`.
2183        dilation (Union[int, tuple[int]]): The input stride for atrous morphological dilation. The data
2184            type is int or a tuple of 2 or 4 integers. Its value must be greater or equal to 1 and bounded
2185            by the height and width of the input `x`.
2186        pad_mode (str): Specifies padding mode. The optional values are "same", "valid".
2187            Default: "same". Both upper and lower case are supported.
2188        data_format (str): The format of input and output data. Only NCHW format is supported at present.
2189            Default: 'NCHW'.
2190
2191    Inputs:
2192        - **x** (Tensor) - Input data. A four dimension tensor with float16 or float32 data type. The shape must be
2193          :math:`(N, C_{in}, H_{in}, W_{in})`.
2194        - **filter** (Tensor) - A three dimension tensor with the same type as input. The shape must be
2195          :math:`(C_{in}, H_{filter}, W_{filter})`.
2196        - **out_backprop** (Tensor) - The gradients with respect to the output of the convolution.
2197          A four dimension tensor with float16 or float32 data type. The shape must be
2198          :math:`(N, C_{in}, H_{out}, W_{out})`.
2199
2200    Outputs:
2201        Tensor, the gradients with respect to the input of convolution. It has the same shape and type as the input `x`.
2202
2203    Raises:
2204        TypeError: If type of `x` or `filter` is not the type in [uint8, uint16, uint32, uint64, int8, int16,
2205                                  int32, int64, float16, float32, float64].
2206        TypeError: If type of `out_backprop` is not the type in [uint8, uint16, uint32, uint64, int8, int16,
2207                                  int32, int64, float16, float32, float64].
2208        TypeError: If `stride` or `dilation` is not an int number or a tuple of two or four int numbers.
2209        ValueError: If the length of `stride` or `dilation` is neither two nor four when they are tuples.
2210        ValueError: If `stride` or `dilation` is not (1, 1, height, width) when it is a tuple of four int numbers.
2211        ValueError: If `stride` is not in the range of [1, 255].
2212        ValueError: If `dilation` is less than 1.
2213        ValueError: If `pad_mode` is not a str of 'same', 'valid', 'SAME' or 'VALID'.
2214        ValueError: If `data_format` is not the str of 'NCHW'.
2215
2216    Supported Platforms:
2217        ``Ascend`` ``GPU`` ``CPU``
2218
2219    Examples:
2220        (pad_mode="SAME", data_format="NCHW")
2221        >>> out_backprop = Tensor(np.ones([1, 3, 4, 4]), mstype.float32)
2222        >>> filter = Tensor(np.ones([3 , 2 , 2]), mstype.float32)
2223        >>> x = Tensor(np.ones([1, 3, 4, 4]), mstype.float32)
2224        >>> dilation_backprop_input = G.Dilation2DBackpropInput(stride=1, dilation=1)
2225        >>> output = dilation_backprop_input(x, filter, out_backprop)
2226        >>> print(output)
2227        [[[[1. 1. 1. 1.]
2228           [1. 1. 1. 1.]
2229           [1. 1. 1. 1.]
2230           [1. 1. 1. 1.]]
2231          [[1. 1. 1. 1.]
2232           [1. 1. 1. 1.]
2233           [1. 1. 1. 1.]
2234           [1. 1. 1. 1.]]
2235          [[1. 1. 1. 1.]
2236           [1. 1. 1. 1.]
2237           [1. 1. 1. 1.]
2238           [1. 1. 1. 1.]]]]
2239    """
2240
2241    @prim_attr_register
2242    def __init__(self, stride, dilation, pad_mode="SAME", data_format="NCHW"):
2243        """Initialize Dilation2DBackpropInput"""
2244
2245        def _check_format_stride_or_dilation(arg_name, arg_value, prim_name, data_format):
2246            validator.check_value_type(arg_name, arg_value, (int, tuple), prim_name)
2247            if isinstance(arg_value, int):
2248                ret_value = (1, arg_value, arg_value, 1) if data_format == "NHWC" else (1, 1, arg_value, arg_value)
2249            elif len(arg_value) == 2:
2250                ret_value = (1, arg_value[0], arg_value[1], 1) if data_format == "NHWC" else \
2251                    (1, 1, arg_value[0], arg_value[1])
2252            elif len(arg_value) == 4:
2253                if data_format == "NHWC" and (arg_value[0] != 1 or arg_value[3] != 1):
2254                    raise ValueError(f"For '{prim_name}' attr '{arg_name}' should be "
2255                                     f"[1, {arg_name}_height, {arg_name}_weigth, 1] when data_format is 'NHWC', "
2256                                     f"but got {arg_value}")
2257                if data_format == "NCHW" and (arg_value[0] != 1 or arg_value[1] != 1):
2258                    raise ValueError(
2259                        f"For '{prim_name}' attr '{arg_name}' should be [1, 1, {arg_name}_height, {arg_name}_weigth]"
2260                        f"when data_format is 'NCHW', but got {arg_value}")
2261                ret_value = arg_value
2262            else:
2263                raise ValueError(
2264                    f"For '{prim_name}' attr '{arg_name}' should be an positive int number or a tuple of two "
2265                    f"or four positive int numbers, but got {arg_value}")
2266            for item in ret_value:
2267                if isinstance(item, int) and not isinstance(item, bool) and item > 0:
2268                    continue
2269                raise ValueError(
2270                    f"For '{prim_name}' attr '{arg_name}' should be an positive int number or a tuple of two "
2271                    f"or four positive int numbers, but got {arg_value}")
2272            return ret_value
2273
2274        if data_format == 'NHWC':
2275            raise ValueError(f"For '{self.name}', NHWC format is not supported at present.")
2276        self.data_format = validator.check_string(self.data_format, ['NCHW', 'NHWC'], 'data_format', self.name)
2277        self.add_prim_attr("data_format", self.data_format)
2278        self.pad_mode = validator.check_string(self.pad_mode, ["SAME", "VALID", 'same', "valid"], "pad_mode", self.name)
2279        self.add_prim_attr("pad_mode", self.pad_mode.upper())
2280        self.stride = _check_format_stride_or_dilation("stride", stride, self.name, self.data_format)
2281        self.add_prim_attr("stride", self.stride)
2282        self.dilation = _check_format_stride_or_dilation("dilation", dilation, self.name, self.data_format)
2283        self.add_prim_attr("dilation", self.dilation)
2284
2285
2286class Dilation2DBackpropFilter(Primitive):
2287    """
2288    Computes the gradient of morphological 2-D dilation with respect to the filter.
2289
2290    .. warning::
2291        This operator is an experimental operator, which has some accuracy problems for some inputs.
2292
2293    Args:
2294        stride (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
2295            the height and width of movement are both strides, a tuple of two int numbers that
2296            represent height and width of movement respectively, or a tuple of four int numbers which
2297            should be :math:`(1, 1, H_{stride}, W_{stride})`.
2298        dilation (Union(int, tuple[int])): The data type is int or a tuple of 2 integers or a tuple of 4 integers.
2299            Specifies the dilation rate to use for dilated convolution.
2300            If set to be :math:`k > 1`, there will be :math:`k - 1` pixels skipped for each sampling location.
2301            Its value must be greater or equal to 1 and bounded by the height and width of the input `x`.
2302        pad_mode (str): Specifies padding mode. The optional values are "same", "valid".
2303            Default: "same". Both upper and lower case are supported.
        data_format (str): The format of input and output data. Only NCHW format is supported at present.
            Default: "NCHW".
2306
2307    Inputs:
2308        - **x** (Tensor) - Input data. A four dimension tensor with float16 or float32 data type. The shape must be
2309          :math:`(N, C_{in}, H_{in}, W_{in})`.
2310        - **filter** (Tensor) - A three dimension tensor with the same type as input. The shape must be
2311          :math:`(C_{in}, H_{filter}, W_{filter})`.
2312        - **out_backprop** (Tensor) - The gradients with respect to the output of the convolution.
2313          A four dimension tensor with float16 or float32 data type. The shape must be
2314          :math:`(N, C_{in}, H_{out}, W_{out})`.
2315
    Outputs:
        Tensor, the gradients with respect to `filter`, with the same data type as the input `x`.
2318
2319    Raises:
        TypeError: If the type of `x` or `filter` is not in [uint8, uint16, uint32, uint64, int8, int16,
            int32, int64, float16, float32, float64].
        TypeError: If the type of `out_backprop` is not in [uint8, uint16, uint32, uint64, int8, int16,
            int32, int64, float16, float32, float64].
        TypeError: If `stride` or `dilation` is not an int number or a tuple of two or four int numbers.
        ValueError: If the length of `stride` or `dilation` is neither two nor four when they are tuples.
        ValueError: If `stride` or `dilation` is not (1, 1, height, width) when it is a tuple of four int numbers.
        ValueError: If `stride` is not in the range of [1, 255].
        ValueError: If `dilation` is less than 1.
        ValueError: If `pad_mode` is not a str of 'same', 'valid', 'SAME' or 'VALID'.
        ValueError: If `data_format` is not the string 'NCHW'.
2331
2333    Supported Platforms:
2334        ``Ascend`` ``GPU`` ``CPU``
2335
2336    Examples:
        With the default attributes pad_mode="SAME" and data_format="NCHW":

        >>> x = Tensor(np.ones([2, 3, 4, 4]), mstype.float32)
        >>> filter = Tensor(np.ones([3, 2, 2]), mstype.float32)
        >>> out_backprop = Tensor(np.ones([2, 3, 2, 2]), mstype.float32)
2341        >>> dilation_backprop_filter = G.Dilation2DBackpropFilter(stride=2, dilation=1)
2342        >>> output = dilation_backprop_filter(x, filter, out_backprop)
2343        >>> print(output)
2344        [[[8. 8. 8.]
2345          [0. 0. 0.]]
2346         [[0. 0. 0.]
2347          [0. 0. 0.]]]
2348    """
2349
2350    @prim_attr_register
2351    def __init__(self, stride, dilation, pad_mode="SAME", data_format="NCHW"):
2352        """Initialize Dilation2DBackpropFilter"""
2353
2354        def _check_format_stride_or_dilation(arg_name, arg_value, prim_name, data_format):
2355            validator.check_value_type(arg_name, arg_value, (int, tuple), prim_name)
2356            if isinstance(arg_value, int):
2357                ret_value = (1, arg_value, arg_value, 1) if data_format == "NHWC" else (1, 1, arg_value, arg_value)
2358            elif len(arg_value) == 2:
2359                ret_value = (1, arg_value[0], arg_value[1], 1) if data_format == "NHWC" else \
2360                    (1, 1, arg_value[0], arg_value[1])
2361            elif len(arg_value) == 4:
2362                if data_format == "NHWC" and (arg_value[0] != 1 or arg_value[3] != 1):
                    raise ValueError(
                        f"For '{prim_name}', attr '{arg_name}' should be [1, {arg_name}_height, {arg_name}_width, 1] "
                        f"when data_format is 'NHWC', but got {arg_value}")
                if data_format == "NCHW" and (arg_value[0] != 1 or arg_value[1] != 1):
                    raise ValueError(
                        f"For '{prim_name}', attr '{arg_name}' should be [1, 1, {arg_name}_height, {arg_name}_width] "
                        f"when data_format is 'NCHW', but got {arg_value}")
                ret_value = arg_value
            else:
                raise ValueError(
                    f"For '{prim_name}', attr '{arg_name}' should be a positive int number or a tuple of two "
                    f"or four positive int numbers, but got {arg_value}")
            for item in ret_value:
                if isinstance(item, int) and not isinstance(item, bool) and item > 0:
                    continue
                raise ValueError(
                    f"For '{prim_name}', attr '{arg_name}' should be a positive int number or a tuple of two "
                    f"or four positive int numbers, but got {arg_value}")
2381            return ret_value
2382
2383        if data_format == 'NHWC':
2384            raise ValueError(f"For '{self.name}', NHWC format is not supported at present.")
2385        self.data_format = validator.check_string(self.data_format, ['NCHW', 'NHWC'], 'data_format', self.name)
2386        self.add_prim_attr("data_format", self.data_format)
2387        self.pad_mode = validator.check_string(self.pad_mode, ["SAME", "VALID", 'same', "valid"], "pad_mode", self.name)
2388        self.add_prim_attr("pad_mode", self.pad_mode.upper())
2389        self.stride = _check_format_stride_or_dilation("stride", stride, self.name, self.data_format)
2390        def is_in_range(x):
2391            return 1 <= x <= 255
2392        if not is_in_range(self.stride[2]) or not is_in_range(self.stride[3]):
2393            raise ValueError(f"For '{self.name}', size of stride is not supported, "
2394                             f'stride should be in the range of [1, 255], '
2395                             f'but got stride_h: `{self.stride[2]}`, stride_w: `{self.stride[3]}`.')
2396        self.add_prim_attr("stride", self.stride)
2397        self.dilation = _check_format_stride_or_dilation("dilation", dilation, self.name, self.data_format)
2398        self.add_prim_attr("dilation", self.dilation)
2399
2400
2401class ParallelResizeBilinearGrad(PrimitiveWithInfer):
2402    """ParallelResizeBilinearGrad ops"""
2403
2404    @prim_attr_register
2405    def __init__(self, ori_image_size, src_start_w, dst_start_w, align_corners):
2406        """Initialize ParallelResizeBilinearGrad."""
2407        self.init_prim_io_names(inputs=["grad", "x", "size"], outputs=['y'])
2408        validator.check_value_type("ori_image_size", ori_image_size, [tuple, list], self.name)
2409        validator.check_value_type("src_start_w", src_start_w, [int], self.name)
2410        validator.check_value_type("dst_start_w", dst_start_w, [int], self.name)
2411        validator.check_value_type("align_corners", align_corners, [bool], self.name)
2412        self.ori_image_size = list(ori_image_size)
2413        self.src_start_w = src_start_w
2414        self.dst_start_w = dst_start_w
2415        self.align_corners = align_corners
2416        self.half_pixel_centers = False
2417        self.add_prim_attr('ori_image_size', self.ori_image_size)
2418        self.add_prim_attr('src_start_w', self.src_start_w)
2419        self.add_prim_attr('dst_start_w', self.dst_start_w)
2420        self.add_prim_attr('align_corners', self.align_corners)
2421        self.add_prim_attr('half_pixel_centers', self.half_pixel_centers)
2422
2423    def __infer__(self, grad, x, size):
2424        size_val = size['value']
2425        grad_shape = grad['shape']
2426        grad_dtype = grad['dtype']
2427        x_shape = x['shape']
2428        x_dtype = x['dtype']
2429        validator.check_tensor_dtype_valid("grad_dtype", grad_dtype, [mstype.float16, mstype.float32], self.name)
2430        validator.check_tensor_dtype_valid("x_dtype", x_dtype, [mstype.float16, mstype.float32], self.name)
2431        if size_val is None:
2432            raise ValueError("size must be const input")
2433        output_shape = [grad_shape[0], grad_shape[1], x_shape[2], x_shape[3]]
2434
2435        return {'shape': output_shape,
2436                'dtype': x_dtype,
2437                'value': None}
2438
2439
2440class MultiMarginLossGrad(Primitive):
2441    """
    Computes the gradients of the MultiMarginLoss operation.
2443
2444    Args:
        p (int): Optional. The norm degree for pairwise distance. Should be 1 or 2. Default: 1.
2446        margin (float): Optional. A parameter to change pairwise distance. Default: 1.0.
2447        reduction (str, optional): Apply specific reduction method to the output: ``'none'`` , ``'mean'`` ,
2448            ``'sum'`` . Default: ``'mean'`` .
2449
2450            - ``'none'``: no reduction will be applied.
2451            - ``'mean'``: compute and return the weighted mean of elements in the output.
2452            - ``'sum'``: the output elements will be summed.
2453
2454    Inputs:
        - **y_grad** (Tensor) - If it is not a scalar, the shape of `y_grad` is :math:`(N, C)`.
          Data type only supports float16, float32 or float64.
        - **x** (Tensor) - Input x, with shape :math:`(N, C)`. Data type only supports float16, float32 or float64.
        - **target** (Tensor) - Ground truth labels, with shape :math:`(N,)`. Data type only supports int64. The
          value of `target` should be non-negative and less than C.
        - **weight** (Tensor, optional) - The rescaling weight to each class with shape :math:`(C,)`. Data type only
          supports float16, float32 or float64. Default: ``None``.
2462
2463    Outputs:
        Tensor, with shape :math:`(N, C)` and the same data type as `x`. Data type only supports
        float16, float32 or float64.
2466
2467    Raises:
2468        TypeError: If dtype of `p` and `target` is not int.
2469        TypeError: If dtype of `margin` is not float.
2470        TypeError: If dtype of `reduction` is not str.
2471        TypeError: If dtype of `x` is not float16, float or float64.
2472        TypeError: If dtype of `weight` and `x` is not the same.
        ValueError: If `p` is not 1 or 2.
        ValueError: If `reduction` is not one of {'none', 'sum', 'mean'}.
        ValueError: If shape[0] of `x` is not equal to shape[0] of `target`.
        ValueError: If shape[1] of `x` is not equal to shape[0] of `weight`.
        ValueError: If rank of `weight` is not 1.
        ValueError: If rank of `x` is not 2 or rank of `target` is not 1.
2479
2480    Supported Platforms:
2481        ``Ascend``  ``CPU``
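
    Examples:
        An illustrative call sketch (the input values below are arbitrary; shapes follow the descriptions
        above, and the printed shape is the documented :math:`(N, C)` output shape):

        >>> import numpy as np
        >>> from mindspore import Tensor
        >>> from mindspore.ops.operations import _grad_ops as G
        >>> y_grad = Tensor(np.array(1.0).astype(np.float32))
        >>> x = Tensor(np.ones([3, 4]).astype(np.float32))
        >>> target = Tensor(np.array([0, 1, 2]).astype(np.int64))
        >>> multi_margin_loss_grad = G.MultiMarginLossGrad(p=1, margin=1.0, reduction="mean")
        >>> x_grad = multi_margin_loss_grad(y_grad, x, target)
        >>> print(x_grad.shape)
        (3, 4)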
2482    """
2483    __mindspore_signature__ = (
2484        sig.make_sig('y_grad'),
2485        sig.make_sig('x'),
2486        sig.make_sig('target'),
2487        sig.make_sig('weight', default=None)
2488    )
2489
2490    @prim_attr_register
2491    def __init__(self, p=1, margin=1.0, reduction="mean"):
2492        """Initialize MultiMarginLossGrad"""
2493        self.p = validator.check_value_type('p', p, [int], self.name)
2494        validator.check_int(p, {1, 2}, validator.IN, 'p', self.name)
2495        self.margin = validator.check_value_type('margin', margin, [float], self.name)
2496        self.reduction = validator.check_string(reduction, ['none', 'sum', 'mean'], 'reduction', self.name)
2497        self.init_prim_io_names(inputs=['y_grad', 'x', 'target', 'weight'], outputs=['x_grad'])
2498
2499    def __call__(self, y_grad, x, target, weight=None):
2500        return super().__call__(y_grad, x, target, weight)
2501
2502
2503class SparseSegmentMeanGrad(Primitive):
2504    """
    Computes gradients for the SparseSegmentMean operation.
2506
2507    Inputs:
2508        - **x** (Tensor) - A Tensor of the first input of SparseSegmentMeanGrad.
2509        - **indices** (Tensor) - Indices is a 1-D tensor with indices into `x`. Must be one of the following
2510          types: int32, int64. Has same rank as `segment_ids`. The shape should be :math:`(N,)`.
2511        - **segment_ids** (Tensor) - Segment_ids is a 1-D tensor with indices into the output `y`. Must be one of the
2512          following types: int32, int64. Values should be sorted and can be repeated. The shape should be :math:`(N,)`.
2513        - **output_dim0** (Tensor) - Output_dim0 is a 0-D tensor. Dimension 0 of `x` passed to SparseSegmentMean op.
2514
2515    Outputs:
2516        A Tensor. Has the same type as `x` .
2517        Has same shape as `x`, except for dimension 0 which is the value of `output_dim0`.
2518
2519    Raises:
2520        TypeError: If `x` or `indices` or `segment_ids` is not a tensor.
2521        TypeError: If the dtype of `x` is not any of the following data types: {float32, float64}.
2522        TypeError: If the dtype of `indices` is not int32.
2523        TypeError: If the dtype of `segment_ids` is not int32.
2524        TypeError: If the dtype of `output_dim0` is not int32.
2525        ValueError: If dimension size of `x` is less than 1.
2526        ValueError: If rank of `indices` or `segment_ids` is not 1.
2527        ValueError: If dimension size of `output_dim0` is not 0.
2528        ValueError: If the first dimension of `indices` is not equal to the first dimension of `segment_ids`.
2529        ValueError: If `segment_ids` is not sorted.
2530        ValueError: If `indices` is out of range of `output_dim0`.
2531
2532    Supported Platforms:
2533        ``Ascend`` ``GPU`` ``CPU``
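
    Examples:
        A minimal call sketch (the values below are arbitrary; the printed shape follows the description
        above, i.e. the shape of `x` with dimension 0 replaced by `output_dim0`):

        >>> import numpy as np
        >>> from mindspore import Tensor
        >>> from mindspore.ops.operations import _grad_ops as G
        >>> x = Tensor(np.array([[1, 2, 4], [2, 4, 5], [2, 2, 6]]).astype(np.float32))
        >>> indices = Tensor(np.array([0, 1, 2]).astype(np.int32))
        >>> segment_ids = Tensor(np.array([0, 1, 2]).astype(np.int32))
        >>> output_dim0 = Tensor(np.array(4).astype(np.int32))
        >>> sparse_segment_mean_grad = G.SparseSegmentMeanGrad()
        >>> y = sparse_segment_mean_grad(x, indices, segment_ids, output_dim0)
        >>> print(y.shape)
        (4, 3)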
2534    """
2535
2536    @prim_attr_register
2537    def __init__(self):
2538        """Initialize SparseSegmentMeanGrad"""
2539        self.init_prim_io_names(inputs=['x', 'indices', 'segment_ids', 'output_dim0'], outputs=['y'])
2540
2541
2542class FractionalMaxPoolGrad(Primitive):
2543    """Computes gradients for FractionalMaxPool operation."""
2544
2545    @prim_attr_register
2546    def __init__(self, overlapping=False):
2547        self.init_prim_io_names(inputs=["orig_input", "orig_output", "out_backprop",
2548                                        "row_pooling_sequence", "col_pooling_sequence"],
2549                                outputs=["y"])
2550        validator.check_value_type("overlapping", overlapping, [bool], self.name)
2551
2552
2553class FractionalMaxPool3DGradWithFixedKsize(Primitive):
2554    """Computes gradients for FractionalMaxPool3DWithFixedKsize operation."""
2555
2556    @prim_attr_register
2557    def __init__(self, data_format="NCDHW"):
2558        self.init_prim_io_names(inputs=["origin_input", "out_backprop", "argmax"], outputs=["y"])
2559        self.data_format = validator.check_string(data_format, ['NCDHW', "NDHWC"], 'data_format', self.name)
2560
2561
2562class MaxUnpool2DGrad(Primitive):
2563    r"""
2564    Gradients for MaxUnpool2D operation.
2565    """
2566
2567    @prim_attr_register
2568    def __init__(self, ksize, strides=0, pads=0, output_shape=(), data_format="NCHW"):
2569        """Initialize MaxUnpool2DGrad."""
2570        self.init_prim_io_names(inputs=['x', 'grads', 'argmax'], outputs=['y'])
2571        validator.check_value_type("ksize", ksize, [int, tuple], self.name)
2572        validator.check_value_type("strides", strides, [int, tuple], self.name)
2573        validator.check_value_type("pads", pads, [int, tuple], self.name)
2574        validator.check_value_type("output_shape", output_shape, [tuple], self.name)
2575        validator.check_string(data_format, ['NCHW', 'NHWC'], 'data_format', self.name)
2576        validator.check_int(len(ksize), 4, validator.EQ, "ksize rank", self.name)
2577        validator.check_int(len(strides), 4, validator.EQ, "strides rank", self.name)
2578        validator.check_int(len(pads), 4, validator.EQ, "pads rank", self.name)
2579
2580
2581class MaxUnpool3DGrad(Primitive):
2582    r"""
2583    Gradients for MaxUnpool3D operation.
2584    """
2585
2586    @prim_attr_register
2587    def __init__(self, ksize, strides=0, pads=0, output_shape=(), data_format="NCDHW"):
2588        """Initialize MaxUnpool3DGrad."""
2589        self.init_prim_io_names(inputs=['x', 'grads', 'argmax'], outputs=['y'])
2590        validator.check_value_type("ksize", ksize, [int, tuple], self.name)
2591        validator.check_value_type("strides", strides, [int, tuple], self.name)
2592        validator.check_value_type("pads", pads, [int, tuple], self.name)
2593        validator.check_value_type("output_shape", output_shape, [tuple], self.name)
2594        validator.check_string(data_format, ['NCDHW', 'NDHWC'], 'data_format', self.name)
2595        validator.check_int(len(ksize), 5, validator.EQ, "ksize rank", self.name)
2596        validator.check_int(len(strides), 5, validator.EQ, "strides rank", self.name)
2597        validator.check_int(len(pads), 5, validator.EQ, "pads rank", self.name)
2598
2599
2600class FractionalAvgPoolGrad(Primitive):
2601    """Computes gradients for FractionalAvgPool operation."""
2602
2603    @prim_attr_register
2604    def __init__(self, overlapping=False):
2605        self.add_prim_attr("max_length", 1000000)
2606        self.init_prim_io_names(inputs=["orig_input_tensor_shape", "out_backprop", "row_pooling_sequence",
2607                                        "col_pooling_sequence"],
2608                                outputs=["y"])
2609        validator.check_value_type("overlapping", overlapping, [bool], self.name)
2610
2611
2612class PSROIPoolingGrad(Primitive):
2613    """Computes gradients for PSROIPooling operation."""
2614
2615    @prim_attr_register
2616    def __init__(self, input_size, spatial_scale, group_size, output_dim):
2617        """Initialize PSROIPoolingGrad."""
2618        self.init_prim_io_names(inputs=["x", "rois"], outputs=['y'])
2619        validator.check_value_type("input_size", input_size, [int, tuple], self.name)
2620        validator.check_positive_float(spatial_scale, "spatial_scale", self.name)
2621        validator.check_positive_int(group_size, "group_size", self.name)
2622        validator.check_positive_int(output_dim, "output_dim", self.name)
2623
2624        if isinstance(input_size, int):
2625            self.input_size = [input_size, input_size]
2626        else:
2627            self.input_size = list(input_size)
2628
2629        validator.check_positive_int_sequence(self.input_size, "input_size", self.name)
2630        self.spatial_scale = spatial_scale
2631        self.group_size = group_size
2632        self.output_dim = output_dim
2633
2634        self.add_prim_attr('input_size', self.input_size)
2635        self.add_prim_attr('spatial_scale', self.spatial_scale)
2636        self.add_prim_attr('group_size', self.group_size)
2637        self.add_prim_attr('output_dim', self.output_dim)
2638
2639
2640class AdaptiveMaxPool3DGrad(Primitive):
2641    """Computes gradients for AdaptiveMaxPool3D operation."""
2642
2643    @prim_attr_register
2644    def __init__(self):
2645        """Initialize AdaptiveMaxPool3DGrad"""
2646        self.init_prim_io_names(inputs=['input_grad', 'x', 'argmax'], outputs=['output_grad'])
2647
2648
2649class TraceGrad(Primitive):
2650    """
2651    Computes grad for Trace operation.
2652
2653    Inputs:
        - **y_grad** (Tensor) - The gradient with respect to the output of the Trace function.
          Currently the grad data type supports float16, float32, int8, int16, int32, int64,
          uint8, uint16, uint32, uint64, float64.
        - **x_shape** (Tensor) - The shape of the input of the Trace function.
          Currently the shape data type supports int32, int64.

    Outputs:
        x_grad - Tensor, with the same data type as `y_grad` and a shape given by `x_shape`.
2662
2663    Raises:
2664        TypeError: If `x_shape` is not a Tensor.
2665        TypeError: If the dtype of `x_shape` is neither int32 nor int64.
2666        ValueError: If `x_shape` is not a 1D Tensor.
        ValueError: If the length of `x_shape` is not equal to 2.
2668
    Supported Platforms:
2670        ``Ascend`` ``GPU`` ``CPU``
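
    Examples:
        A small call sketch (the values below are arbitrary; the printed shape equals the value of `x_shape`):

        >>> import numpy as np
        >>> from mindspore import Tensor
        >>> from mindspore.ops.operations import _grad_ops as G
        >>> y_grad = Tensor(np.array(3).astype(np.float32))
        >>> x_shape = Tensor(np.array([2, 3]).astype(np.int64))
        >>> trace_grad = G.TraceGrad()
        >>> x_grad = trace_grad(y_grad, x_shape)
        >>> print(x_grad.shape)
        (2, 3)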
2671    """
2672
2673    @prim_attr_register
2674    def __init__(self):
2675        self.init_prim_io_names(inputs=['y_grad', 'x_shape'], outputs=['x_grad'])
2676
2677
2678class IgammaGradA(Primitive):
2679    r"""
2680    Computes the gradient of igamma(a, x) wrt a.
2681
2682    Inputs:
2683        - **a** (Tensor) - The input tensor. With float32 or float64 data type.
        - **x** (Tensor) - The input tensor. With float32 or float64 data type. `x` should have
          the same dtype as `a`.
2686
2687    Outputs:
2688        Tensor, has the same dtype as `a` and `x`.
2689
2690    Raises:
        TypeError: If `a` or `x` is not a Tensor.
        TypeError: If the dtype of `a` and `x` is neither float32 nor float64.
        TypeError: If `x` has a different dtype from `a`.
2694        ValueError: If `a` could not be broadcast to a tensor with shape of `x`.
2695
2696    Supported Platforms:
2697        ``Ascend`` ``GPU`` ``CPU``
2698
2699    Examples:
2700        >>> a = Tensor(np.array([2.0, 4.0, 6.0, 8.0]).astype(np.float32))
2701        >>> x = Tensor(np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32))
2702        >>> igammagrada = G.IgammaGradA()
2703        >>> output = igammagrada(a, x)
        >>> print(output)
2705        [-0.2940046  -0.20153049 -0.13028376 -0.08352186]
2706    """
2707
2708    @prim_attr_register
2709    def __init__(self):
2710        """Initialize IgammaGradA"""
2711        self.init_prim_io_names(inputs=['a', 'x'], outputs=['z'])
2712
2713
2714class DeformableOffsetsGrad(Primitive):
2715    r"""
    Computes gradients of the DeformableOffsets operation.

    Args:
        strides (tuple[int, int, int, int]): A tuple of 4 integers. The stride of the sliding window for the
            height and width of the H/W dimension.
        pads (tuple[int, int, int, int]): A tuple of 4 integers. Padding added to the H/W dimension of the input,
            i.e. the number of pixels to add to each (top, bottom, left, right) side of the input.
        kernel_size (tuple[int, int]): Kernel size, a tuple of 2 integers.
        dilations (tuple[int, int, int, int]): A tuple of 4 integers. The dilation factor for each dimension of
            the input. Default: (1, 1, 1, 1).
        data_format (str): An optional string from "NCHW", "NHWC". Specifies the data format of the input `x`.
            Default: "NCHW".
        deformable_groups (int): Specifies the C-axis grouping number of the input `x`. Default: 1.
        modulated (bool): Specifies the version of DeformableOffsetsGrad; true means v2, false means v1.
            Default: ``True``.
2729
2730    Inputs:
2731        - **grad** (Tensor) - The input grad tensor. With float16 or float32 data type.
2732        - **x** (Tensor) - The input `x` of DeformableOffsets with data type of float16 or float32.
2733        - **offsets** (Tensor) - The input 'offsets' of DeformableOffsets with data type of float16 or float32.
2734
2735    Outputs:
        - **grad_x** (Tensor) - The output grad of the input `x`, with the same dtype and shape as `x`.
        - **grad_offsets** (Tensor) - The output grad of the input `offsets`, with the same dtype and shape as
          `offsets`.
2738
2739    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
2741    """
2742
2743    @prim_attr_register
2744    def __init__(self,
2745                 strides,
2746                 pads,
2747                 kernel_size,
2748                 dilations=(1, 1, 1, 1),
2749                 data_format="NCHW",
2750                 deformable_groups=1,
2751                 modulated=True):
2752        """Initialize DeformableOffsetsGrad"""
2753        self.init_prim_io_names(inputs=['out_backprop', 'input', 'offsets'], outputs=['out_grad'])
2754
2755        self.strides = _check_positive_int_or_tuple('strides', strides, self.name, allow_four=True, ret_four=True)
2756        self.add_prim_attr('strides', self.strides)
2757
2758        self.pads = pads
2759        self.add_prim_attr('pads', self.pads)
2760
2761        self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name, allow_four=True,
2762                                                        ret_four=False)
2763        self.add_prim_attr('ksize', self.kernel_size)
2764
2765        self.dilations = _check_positive_int_or_tuple('dilations', dilations, self.name, allow_four=True, ret_four=True)
        self.add_prim_attr('dilations', self.dilations)
2767
2768        self.data_format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
2769        self.add_prim_attr('data_format', self.data_format)
2770
2771        self.deformable_groups = validator.check_positive_int(deformable_groups, 'deformable_groups', self.name)
2772        self.add_prim_attr('deformable_groups', self.deformable_groups)
2773
2774        self.modulated = validator.check_bool(modulated, 'modulated', self.name)
2775        self.add_prim_attr('modulated', self.modulated)
2776
2777
2778class MedianGrad(Primitive):
2779    """
2780    Computes gradient for Median operation.
2781
2782    .. warning::
        When attr `global_median` is True, the value of Median's second output Tensor `indices` is meaningless.
2784
2785    Args:
2786        global_median (bool): Whether the output tensor is the global median of all input tensor elements
2787            or not in Median operation.
2788        axis (int): The dimension need to reduce in Median operation.
2789        keep_dims (bool): Whether the output tensor need to retain `axis` dimension or not in Median operation.
2790
2791    Inputs:
2792        - **y_grad** (Tensor) - The gradients of loss to output of Median function.
2793        - **x** (Tensor) - The first input is a tensor whose data type is number.
2794          The dtype is one of the following: int16, int32, int64, float32, double.
2795        - **y** (Tensor) - The first output of Median function, which datatype is same as `x`.
2796        - **indices** (Tensor) - The second output of Median function, which datatype is int64.
2797
2798    Outputs:
        x_grad - Tensor, has the same shape as `x`. Its dtype is double only when the dtype of `x` is double;
        otherwise, the dtype of `x_grad` is float32.
2801
2802    Raises:
2803        TypeError: If dtype of `y_grad` is not the same as `x`.
2804        ValueError: If shape of `y_grad` is not the same as `y`.
2805
2806    Supported Platforms:
2807        ``Ascend`` ``CPU``
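
    Examples:
        An illustrative call sketch (values are arbitrary but self-consistent: `y` and `indices` are the
        column-wise medians of `x` along `axis=0` and their positions; the printed shape equals the shape of `x`):

        >>> import numpy as np
        >>> from mindspore import Tensor
        >>> from mindspore.ops.operations import _grad_ops as G
        >>> y_grad = Tensor(np.array([1.0, 1.0, 1.0]).astype(np.float32))
        >>> x = Tensor(np.array([[2.0, 1.0, 5.0], [3.0, 4.0, 0.0]]).astype(np.float32))
        >>> y = Tensor(np.array([2.0, 1.0, 0.0]).astype(np.float32))
        >>> indices = Tensor(np.array([0, 0, 1]).astype(np.int64))
        >>> median_grad = G.MedianGrad(global_median=False, axis=0, keep_dims=False)
        >>> x_grad = median_grad(y_grad, x, y, indices)
        >>> print(x_grad.shape)
        (2, 3)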
2808    """
2809
2810    @prim_attr_register
2811    def __init__(self, global_median=False, axis=0, keep_dims=False):
2812        validator.check_value_type("global_median", global_median, [bool], self.name)
2813        self.global_median = global_median
2814        if global_median is False:
2815            validator.check_value_type("axis", axis, [int], self.name)
2816            validator.check_value_type("keep_dims", keep_dims, [bool], self.name)
2817        self.init_prim_io_names(inputs=['y_grad', 'x', 'y', 'indices'], outputs=['x_grad'])
2818
2819
2820class SparseSegmentSumGrad(Primitive):
2821    """
    Computes gradients for the SparseSegmentSum operation.
2823
2824    Inputs:
2825        - **grad** (Tensor) - A tensor.
2826        - **indices** (Tensor) - Indices is a 1-D tensor. Must be one of the following types: int32, int64.
2827          Has same rank as segment_ids. The shape should be :math:`(N,)`.
2828        - **segment_ids** (Tensor) - Segment_ids is a 1-D tensor. Must be one of the following types: int32, int64.
2829          Values should be sorted and can be repeated. The shape should be :math:`(N,)`.
2830        - **output_dim0** (Tensor) - Output_dim0 is a 0-D tensor. Dimension 0 of `x` passed to SparseSegmentSum op.
2831
2832    Outputs:
2833        A Tensor. Has the same type as `grad` .
2834        Has same shape as `grad`, except for dimension 0 which is the value of `output_dim0`.
2835
2836    Raises:
2837        TypeError: If `grad` or `indices` or `segment_ids` or `output_dim0` is not a tensor.
2838        TypeError: If the dtype of `grad` is not any of the following data types: {float16, float32, float64}.
2839        TypeError: If the dtype of `indices` and `segment_ids` and `output_dim0` is not int32 or int64.
        ValueError: If dimension size of `grad` is less than 1.
2841        ValueError: If rank of `indices` or `segment_ids` is not 1.
2842        ValueError: If dimension size of `output_dim0` is not 0.
2843        ValueError: If shape[0] of `indices` is not corresponding to shape[0] of `segment_ids`.
2844        ValueError: If `segment_ids` is not sorted.
2845        ValueError: If the last number of `segment_ids` is out of range of grad's first shape.
2846        ValueError: If `indices` is bigger than or equal to `output_dim0`.
2847
2848    Supported Platforms:
2849        ``GPU``
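
    Examples:
        A minimal call sketch (GPU only; the values below are arbitrary and the printed shape follows the
        description above, i.e. the shape of `grad` with dimension 0 replaced by `output_dim0`):

        >>> import numpy as np
        >>> from mindspore import Tensor
        >>> from mindspore.ops.operations import _grad_ops as G
        >>> grad = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32))
        >>> indices = Tensor(np.array([0, 2]).astype(np.int32))
        >>> segment_ids = Tensor(np.array([0, 1]).astype(np.int32))
        >>> output_dim0 = Tensor(np.array(4).astype(np.int32))
        >>> sparse_segment_sum_grad = G.SparseSegmentSumGrad()
        >>> y = sparse_segment_sum_grad(grad, indices, segment_ids, output_dim0)
        >>> print(y.shape)
        (4, 3)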
2850    """
2851    __mindspore_signature__ = (
2852        sig.make_sig('grad', dtype=sig.sig_dtype.T1),
2853        sig.make_sig('indices', dtype=sig.sig_dtype.T),
2854        sig.make_sig('segment_ids', dtype=sig.sig_dtype.T),
2855        sig.make_sig('output_dim0', dtype=sig.sig_dtype.T)
2856    )
2857
2858    @prim_attr_register
2859    def __init__(self):
2860        """Initialize SparseSegmentSumGrad"""
2861        self.init_prim_io_names(inputs=['grad', 'indices', 'segment_ids', 'output_dim0'], outputs=['y'])
2862
2863
2864class SparseSegmentSqrtNGrad(Primitive):
2865    """
    Computes gradients for the SparseSegmentSqrtN operation.
2867
2868    Inputs:
        - **x** (Tensor) - A tensor. Its rank must be greater than or equal to one.
2870        - **indices** (Tensor) - Indices is a 1-D tensor with indices into `x`. Must be one of the following
2871          types: int32, int64. Has same rank as segment_ids. The shape should be :math:`(N,)`.
2872        - **segment_ids** (Tensor) - Segment_ids is a 1-D tensor with indices into the output `y`. Must be one
2873          of the following types: int32, int64. Values should be sorted and can be repeated. The shape should
2874          be :math:`(N,)`.
2875        - **output_dim0** (Tensor) - Output_dim0 is a 0-D tensor. Dimension 0 of `x` passed to SparseSegmentSqrtN op.
2876
2877    Outputs:
2878        A Tensor. Has the same type as `x` .
2879        Has same shape as `x`, except for dimension 0 which is the value of `output_dim0`.
2880
2881    Raises:
2882        TypeError: If `x` or `indices` or `segment_ids` or `output_dim0` is not a tensor.
2883        TypeError: If the dtype of `x` is not any of the following data types: {float16, float32, float64}.
2884        TypeError: If the dtype of `indices` is not int32.
2885        TypeError: If the dtype of `segment_ids` is not int32.
2886        TypeError: If the dtype of `output_dim0` is not int32.
2887        ValueError: If dimension size of `x` is less than 1.
2888        ValueError: If rank of `indices` or `segment_ids` is not 1.
2889        ValueError: If dimension size of `output_dim0` is not 0.
2890        ValueError: If shape[0] of `indices` is not corresponding to shape[0] of `segment_ids`.
2891        ValueError: If `segment_ids` is not sorted.
2892        ValueError: If the last number of `segment_ids` is out of range of x's first shape.
2893        ValueError: If `indices` is bigger than or equal to `output_dim0`.
2894
2895    Supported Platforms:
2896        ``Ascend`` ``GPU`` ``CPU``
2897    """
2898
2899    @prim_attr_register
2900    def __init__(self):
2901        """Initialize SparseSegmentSqrtNGrad"""
2902        self.init_prim_io_names(inputs=['x', 'indices', 'segment_ids', 'output_dim0'], outputs=['y'])
2903
2904
2905class SparseSliceGrad(Primitive):
2906    r"""
2907    Computes gradients for SparseSlice operation.
2908
2909    Inputs:
2910        - **backprop_val_grad** (Tensor) - A 1D Tensor.
2911          The shape should be :math:`(N,)`.
2912        - **indices** (Tensor) - A 2D Tensor (N x R matrix) of type int64. The indices of the SparseTensor.
2913          Support int64, each element value should be a non-negative int number. This tensor should be sorted.
2914          The shape is :math:`(N, R)`.
2915        - **start** (Tensor) - A 1D Tensor of type int64, represents the start of the indices.
2916          The shape should be :math:`(R,)`.
2917        - **new_indices** (Tensor) - A 2D Tensor (N x C matrix) of type int64. The indices of the SparseTensor.
2918          Support int64, each element value should be a non-negative int number. This tensor should be sorted.
2919          The shape is :math:`(N, C)`.
2920
2921    Outputs:
        - **y_grad** (Tensor) - Has the same type as `backprop_val_grad`.
          Has the same number of elements as the number of rows of `indices`.
2924
2925    Raises:
        TypeError: If the dtype of `indices`, `start` or `new_indices` is not int64.
        ValueError: If `indices` or `new_indices` is not a 2-D tensor.
        ValueError: If `backprop_val_grad` or `start` is not a 1-D tensor.
        ValueError: If the number of `backprop_val_grad` does not correspond to the number of `new_indices`.
        ValueError: If the shape of `indices[1]` does not correspond to `start[1]`.
        ValueError: If the shape of `indices[1]` does not correspond to `new_indices[1]`.
        RuntimeError: If the `backprop_val_grad` is not all backpropagated because `indices` or `new_indices`
            is not sorted.
2934
2935    Supported Platforms:
2936        ``Ascend`` ``GPU`` ``CPU``

    Examples:
2938        >>> backprop_val_grad = Tensor(np.array([1, 2, 3, 4]).astype(np.int64))
2939        >>> indices = Tensor(np.array([[0, 0], [0, 2], [1, 2], [1, 3], [2, 3], [2, 4]]).astype(np.int64))
2940        >>> start = Tensor(np.array([0, 0]).astype(np.int64))
2941        >>> new_indices = Tensor(np.array([[0, 2], [1, 2], [1, 3], [2, 4]]).astype(np.int64))
2942        >>> grad = SparseSliceGrad()
2943        >>> output = grad(backprop_val_grad, indices, start, new_indices)
2944        >>> print(output)
2945        [0 1 2 3 0 4]
2946    """
2947
2948    @prim_attr_register
2949    def __init__(self):
2950        """Initialize SparseSliceGrad."""
2951        self.init_prim_io_names(inputs=['backprop_val_grad', 'indices', 'start', 'new_indices'], outputs=['y_grad'])
2952
2953
2954class FractionalMaxPoolGradWithFixedKsize(Primitive):
2955    """
2956    Computes the gradients of FractionalMaxPoolWithFixedKsize.
2957
2958    Args:
        data_format (str): The data format of the input. Only 'NCHW' is supported. Default: "NCHW".
2960
2961    Inputs:
2962        - **origin_input** (Tensor) - Tensor with data format "NCHW", data type must be int32 or int64.
        - **out_backprop** (Tensor) - The gradients with respect to the output of FractionalMaxPoolWithFixedKsize
          function. Tensor with data format "NCHW", whose data type is float16, float32, float64, int32 or int64.
        - **argmax** (Tensor) - The second output of FractionalMaxPoolWithFixedKsize function, whose data
          type is int64.

    Outputs:
        - **y** (Tensor) - Tensor, with the same shape as `origin_input`, and the same data type as
          the input `out_backprop`.
2971
2972    Raises:
2973        TypeError: If data type of `out_backprop` is not one of the following: float16, float32, float64, int32, int64.
2974        TypeError: If data type of `argmax` is not int64.
2975        ValueError: If the shape of `out_backprop` and `argmax` is not equal.
2976        ValueError: If the first dimension size of `origin_input` and `out_backprop` is not equal.
2977        ValueError: If the second dimension size of `origin_input` and `out_backprop` is not equal.
2978
2979    Supported Platforms:
2980        ``Ascend`` ``GPU`` ``CPU``
2981    """
2982
2983    @prim_attr_register
2984    def __init__(self, data_format="NCHW"):
2985        self.data_format = validator.check_string(data_format, ['NCHW'], 'data_format', self.name)
2986        self.add_prim_attr("data_format", self.data_format)
2987        self.init_prim_io_names(inputs=['origin_input', 'out_backprop', 'argmax'], outputs=['y'])
2988
2989
2990class AffineGridGrad(Primitive):
2991    r"""
2992    Computes gradients for AffineGrid operation.
2993
2994    Args:
2995        align_corners (bool): if True, consider -1 and 1 to refer to the centers
2996            of the corner pixels rather than the image corners. Default: ``False``.
2997
2998    Inputs:
2999        - **y_grad** (Tensor) - Data type must be float16 or float32.
3000        - **x_size** (tuple) - Data type must be int32 or int64.
3001
3002    Outputs:
3003        Tensor, with data type same as `y_grad`.
3004
3005    Supported Platforms:
3006        ``CPU``
3007
3008    Examples:
3009        >>> import mindspore.ops.operations._grad_ops as _grad_ops
3010        >>> affinegridgrad = _grad_ops.AffineGridGrad()
3011        >>> y_grad = Tensor(np.ones([1, 2, 2, 2]), mindspore.float32)
3012        >>> x_size = (1, 2, 2, 2)
3013        >>> x_grad = affinegridgrad(y_grad, x_size)
3014        >>> print(x_grad)
3015        [[[0. 0. 4.]
3016          [0. 0. 4.]]]
3017    """
3018
3019    @prim_attr_register
3020    def __init__(self, align_corners=False):
3021        """Initialize AffineGridGrad."""
3022        validator.check_value_type("align_corners", align_corners, [bool], self.name)
3023        self.init_prim_io_names(inputs=['y_grad', 'x_size'], outputs=['x_grad'])
3024
3025
3027class GluGrad(Primitive):
3028    """
3029    Computes grad for Glu operation.
3030    """
3031
3032    @prim_attr_register
3033    def __init__(self, axis):
3034        self.add_prim_attr("cust_aicpu", self.name)
3035        self.init_prim_io_names(inputs=["grads", "x"], outputs=["y"])
3036        validator.check_value_type("axis", axis, [int], self.name)
3037
3038
3039class MapTensorGetGrad(Primitive):
3040    """
3041    Computes gradients for MapTensorGet operation.
3042
3043    Inputs:
3044        - **map_tensor** (MapTensor) - The input `map_tensor` of the forward operator MapTensorGet.
3045        - **key_tensor** (Tensor) - The input `key_tensor` of the forward operator MapTensorGet.
3046        - **default_value** (Scalar) - The input `default_value` of the forward operator MapTensorGet.
3047        - **grad** (Tensor) - The grad value according the forward operator MapTensorGet.
3048
3049    Outputs:
3050        - **output** (MapTensor) -  MapTensor with grad values.
3051    """
3052    @prim_attr_register
3053    def __init__(self):
3054        """Initialize MapTensorGetGrad"""
3055        self.init_prim_io_names(inputs=['map_tensor', 'key_tensor', 'default_value', 'grad'], outputs=['output'])
3056        self.add_prim_attr('side_effect_mem', True)
3057
3058
3059class ResizeV2Grad(Primitive):
3060    r"""
3061    Calculates the gradient of ResizeV2 operation.
3062
3063    Supported Platforms:
3064        ``CPU``
3065    """
3066
3067    @prim_attr_register
3068    def __init__(self, coordinate_transformation_mode="half_pixel", mode="nearest"):
3069        """Initialize ResizeV2Grad."""
3070        self.init_prim_io_names(inputs=["grads", "roi", "scales", "original_size"], outputs=["y"])
3071        self.add_prim_attr("nearest_mode", "floor")
3072        self.add_prim_attr("cubic_coeff_a", -0.75)
3073        validator.check_value_type(
3074            "coordinate_transformation_mode", coordinate_transformation_mode, [str], self.name)
3075        validator.check_string(coordinate_transformation_mode,
3076                               ["align_corners", "half_pixel"], "coordinate_transformation_mode", self.name)
3077        validator.check_value_type("mode", mode, [str], self.name)
3078        validator.check_string(mode, ["nearest", "linear", "cubic"], "mode", self.name)
3079
3080
3081class WKVGrad(Primitive):
3082    r"""
3083    Calculates the gradient of WKV operation.
3084
3085    Supported Platforms:
3086        ``Ascend``
3087    """
3088
3089    @prim_attr_register
3090    def __init__(self):
3091        """Initialize WKVGrad."""
3092        self.init_prim_io_names(inputs=["time_first", "time_decay", "key", "value", "gy"],
3093                                outputs=["gw", "gu", "gk", "gv"])
3094