# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Operators for gradients."""
import math
from functools import partial
from mindspore._checkparam import _check_3d_int_or_tuple
from .. import signature as sig
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
from ..._checkparam import Validator as validator, Rel
from .._utils import get_concat_offset
from ...common import dtype as mstype
from ... import context


class AbsGrad(PrimitiveWithInfer):
    """Computes gradients for abs operation."""

    @prim_attr_register
    def __init__(self):
        """Initialize AbsGrad"""

    def infer_shape(self, y, dy):
        return y

    def infer_dtype(self, y, dy):
        return y


class ACosGrad(PrimitiveWithInfer):
    """
    Computes ACosGrad of input element-wise.

    Returns:
        Tensor, has the same type as input.
    """

    @prim_attr_register
    def __init__(self):
        """Initialize ACosGrad"""

    def infer_shape(self, x, dout):
        validator.check("x shape", x, "dout shape", dout, Rel.EQ, self.name)
        return x

    def infer_dtype(self, x, dout):
        args = {"x": x, "dout": dout}
        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
        return x
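
    # Illustrative note (not taken from the kernel): since d/dx arccos(x) = -1 / sqrt(1 - x**2),
    # the backend kernel of this primitive is expected to compute dout * (-1 / sqrt(1 - x**2))
    # element-wise; the infer methods above only propagate the shape and dtype of `x`.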


class AcoshGrad(PrimitiveWithInfer):
    """Performs grad of Acosh operation."""

    @prim_attr_register
    def __init__(self):
        """Initialize AcoshGrad"""

    def infer_shape(self, x, dout):
        validator.check("x shape", x, "dout shape", dout, Rel.EQ, self.name)
        return x

    def infer_dtype(self, x, dout):
        args = {"x": x, "dout": dout}
        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
        return x


class AsinGrad(PrimitiveWithInfer):
    """
    Computes AsinGrad of input element-wise.

    Returns:
        Tensor, has the same type as input.
    """

    @prim_attr_register
    def __init__(self):
        """Initialize AsinGrad"""

    def infer_shape(self, x, dout):
        validator.check("x shape", x, "dout shape", dout, Rel.EQ, self.name)
        return x

    def infer_dtype(self, x, dout):
        args = {"x": x, "dout": dout}
        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
        return x


class AsinhGrad(PrimitiveWithInfer):
    """Performs grad of Asinh operation."""

    @prim_attr_register
    def __init__(self):
        """Initialize AsinhGrad"""

    def infer_shape(self, x, dout):
        validator.check("x shape", x, "dout shape", dout, Rel.EQ, self.name)
        return x

    def infer_dtype(self, x, dout):
        args = {"x": x, "dout": dout}
        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
        return x


class ReciprocalGrad(PrimitiveWithInfer):
    """Performs grad of Reciprocal operation."""

    @prim_attr_register
    def __init__(self):
        """Initialize ReciprocalGrad"""

    def infer_shape(self, x_shape, dout_shape):
        validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name)
        return x_shape

    def infer_dtype(self, x_dtype, dout_dtype):
        args = {"x": x_dtype, "dout": dout_dtype}
        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
        return x_dtype
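
    # Illustrative note (assumed convention): grad primitives of this family typically receive the
    # forward output y rather than the original input, and for y = 1/x the derivative is -y**2,
    # so the kernel is expected to compute dout * (-y**2); only shape and dtype are inferred here.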


class RsqrtGrad(PrimitiveWithInfer):
    """Performs grad of Rsqrt operation."""

    @prim_attr_register
    def __init__(self):
        """Initialize RsqrtGrad"""

    def infer_shape(self, x_shape, dout_shape):
        validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name)
        return x_shape

    def infer_dtype(self, x_dtype, dout_dtype):
        args = {"x": x_dtype, "dout": dout_dtype}
        validator.check_tensors_dtypes_same_and_valid(args,
                                                      [mstype.float16, mstype.float32, mstype.int32, mstype.int8],
                                                      self.name)
        return x_dtype


class SoftmaxGrad(ReciprocalGrad):
    """Performs grad of Softmax operation."""


class SqrtGrad(PrimitiveWithInfer):
    """Performs grad of Sqrt operation."""

    @prim_attr_register
    def __init__(self):
        """Initialize SqrtGrad"""

    def infer_shape(self, x_shape, dout_shape):
        validator.check("x shape", x_shape, "dout shape", dout_shape, Rel.EQ, self.name)
        return x_shape

    def infer_dtype(self, x_dtype, dout_dtype):
        args = {"x": x_dtype, "dout": dout_dtype}
        valid_types = [mstype.float16, mstype.float32, mstype.float64]
        validator.check_tensors_dtypes_same_and_valid(args, valid_types, self.name)
        return x_dtype


class BatchNormGrad(PrimitiveWithInfer):
    """Performs grad of BatchNorm operation."""

    @prim_attr_register
    def __init__(self, is_training=False, epsilon=1e-5, data_format='NCHW'):
        self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
        self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
        self.data_format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)

    def infer_shape(self, y_backprop_shape, x_shape, scale_shape, save_mean_shape, save_variance_shape, reserve):
        validator.check("BatchNorm y_backprop_shape", y_backprop_shape, "BatchNorm x_shape", x_shape)
        return (x_shape, scale_shape, scale_shape)

    def infer_dtype(self, y_backprop_type, x_type, scale_type, save_mean_shape, save_variance_shape, reserve):
        return (x_type, scale_type, scale_type)


class SyncBatchNormGrad(PrimitiveWithInfer):
    """Performs grad of SyncBatchNorm operation."""

    @prim_attr_register
    def __init__(self, epsilon=1e-5, group="group0", device_num=2):
        validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
        if not isinstance(group, str):
            raise TypeError("The group attr of SyncBatchNormGrad should be str.")
        validator.check_int(device_num, 2, Rel.GE, "device_num", self.name)

    def infer_shape(self, y_backprop_shape, x_shape, scale_shape, save_mean_shape, save_variance_shape):
        validator.check("BatchNorm y_backprop_shape", y_backprop_shape, "BatchNorm x_shape", x_shape)
        return (x_shape, scale_shape, scale_shape)

    def infer_dtype(self, y_backprop_type, x_type, scale_type, save_mean_shape, save_variance_shape):
        return (x_type, scale_type, scale_type)


class BiasAddGrad(Primitive):
    """Computes gradients of BiasAdd."""

    @prim_attr_register
    def __init__(self, data_format="NCHW"):
        self.init_prim_io_names(inputs=['dout'], outputs=['output'])
        self.format = validator.check_string(data_format, ['NCHW', 'NHWC', 'NCDHW'], 'format', self.name)
        if context.get_context("device_target") != "GPU" and self.format == "NHWC":
            raise ValueError("NHWC format is only supported on the GPU target.")
        if self.format == "NCDHW":
            self.format = "NCHW"
        self.add_prim_attr('data_format', self.format)
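
        # Illustrative note (typical BiasAddGrad behavior, not enforced here): the bias gradient
        # is `dout` reduced over every axis except the channel axis, e.g. an NCHW `dout` of shape
        # (N, C, H, W) yields a gradient of shape (C,).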


class KLDivLossGrad(PrimitiveWithInfer):
    """Computes gradients for `KLDivLoss` operation."""

    @prim_attr_register
    def __init__(self, reduction='mean'):
        self.reduction = validator.check_string(reduction, ['none', 'mean', 'sum'], 'reduction', self.name)

    def infer_shape(self, x_shape, y_shape, doutput_shape):
        validator.check('x_shape', x_shape, 'y_shape', y_shape, Rel.EQ, self.name)
        return x_shape, y_shape

    def infer_dtype(self, x_type, y_type, doutput_type):
        args = {'x_type': x_type, 'y_type': y_type, 'doutput_type': doutput_type}
        validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
        return x_type, y_type


class BinaryCrossEntropyGrad(PrimitiveWithInfer):
    """Computes gradients for `BinaryCrossEntropy` operation."""

    @prim_attr_register
    def __init__(self, reduction='mean'):
        self.reduction = validator.check_string(reduction, ['none', 'mean', 'sum'], 'reduction', self.name)

    def infer_shape(self, x_shape, y_shape, doutput_shape, weight_shape):
        validator.check('x_shape', x_shape, 'y_shape', y_shape, Rel.EQ, self.name)
        if weight_shape:
            validator.check('y_shape', y_shape, 'weight_shape', weight_shape, Rel.EQ, self.name)
        return x_shape

    def infer_dtype(self, x_type, y_type, doutput_type, weight_type):
        args = {'x_type': x_type, 'y_type': y_type, 'doutput_type': doutput_type}
        validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
        if weight_type:
            validator.check('x_type', x_type, 'weight_type', weight_type, Rel.EQ, TypeError)
        return x_type
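
        # Illustrative note (standard binary cross entropy math, stated as an assumption): with
        # loss = -w * (y * log(x) + (1 - y) * log(1 - x)), the gradient w.r.t. x is
        # doutput * w * (x - y) / (x * (1 - x)), further divided by the element count when
        # reduction='mean'.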


class ConcatOffset(PrimitiveWithInfer):
    """Primitive for computing Concat's gradient."""

    @prim_attr_register
    def __init__(self, N=2, axis=0):
        """Initialize ConcatOffset"""

    def __infer__(self, input_x):
        axis = self.axis
        x_shp = input_x['shape']
        x_type = input_x['dtype']
        offset, _, axis = get_concat_offset(x_shp, x_type, axis, self.name)
        self.add_prim_attr('T', x_type[0].element_type())
        offset_values = []
        for i in range(len(x_shp)):
            values = []
            for j in range(len(x_shp[0])):
                value = 0
                if j == axis:
                    value = offset[i]
                values.append(value)
            offset_values.append(tuple(values))
        out = {'shape': None,
               'dtype': None,
               'value': tuple(offset_values)}
        return out
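
    # Illustrative example (assumed shapes): for two inputs of shapes (2, 3) and (4, 3)
    # concatenated along axis 0, the inferred value is ((0, 0), (2, 0)); each tuple is the
    # start index of the corresponding input inside the concatenated output.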


class Conv3DBackpropFilter(PrimitiveWithInfer):
    """
    Computes the gradients of convolution 3D with respect to the filter.

    Args:
        out_channel (int): The dimension of the output.
        kernel_size (Union[int, tuple[int]]): The kernel size of the 3D convolution.
        mode (int): Modes for different convolutions. Not currently used.
        pad_mode (str): Modes to fill padding. It could be "valid", "same", or "pad". Default: "valid".
        pad (Union(int, tuple[int])): The pad value to be filled. Default: 0. If `pad` is an integer, the paddings of
                    head, tail, top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of six
                    integers, the padding of head, tail, top, bottom, left and right equals pad[0], pad[1], pad[2],
                    pad[3], pad[4] and pad[5] correspondingly.
        stride (Union(int, tuple[int])): The stride to be applied to the convolution filter. Default: 1.
        dilation (Union(int, tuple[int])): Specifies the space to use between kernel elements. Default: 1.
        group (int): Splits input into groups. Default: 1.
        data_format (str): The optional value for data format. Currently only 'NCDHW' is supported.

    Inputs:
        - **x** (Tensor) - The input of the convolution with shape :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`.
          Currently the data type only supports float16 and float32.
        - **dout** (Tensor) - The gradients w.r.t the output of the convolution. The shape conforms to the default
          data_format :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})`. Currently the dout data type only supports
          float16 and float32.
        - **w_size** (tuple(int)) - A tuple describing the shape of the weight, which conforms to the format
          :math:`(C_{out}, C_{in}, D_{k}, H_{k}, W_{k})`.

    Outputs:
        Tensor, the gradients w.r.t the weight of convolution 3D. It has the same shape as the weight.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> x = Tensor(np.ones([16, 32, 13, 37, 33]), mindspore.float16)
        >>> dout = Tensor(np.ones([16, 32, 10, 32, 32]), mindspore.float16)
        >>> w = Tensor(np.ones([32, 32, 4, 6, 2]), mindspore.float16)
        >>> conv3d_backprop_filter = P.Conv3DBackpropFilter(out_channel=32, kernel_size=(4, 6, 2))
        >>> output = conv3d_backprop_filter(x, dout, F.shape(w))
        >>> print(output.shape)
        (32, 32, 4, 6, 2)
    """

    @prim_attr_register
    def __init__(self,
                 out_channel,
                 kernel_size,
                 mode=1,
                 pad_mode="valid",
                 pad=0,
                 stride=(1, 1, 1, 1, 1),
                 dilation=(1, 1, 1, 1, 1),
                 group=1,
                 data_format="NCDHW"):
        """Initialize Convolution"""
        self.init_prim_io_names(inputs=['x', 'out_backprop', 'filter_size'], outputs=['y'])
        self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name)
        self.kernel_size = _check_3d_int_or_tuple('kernel_size', kernel_size, self.name)
        self.stride = _check_3d_int_or_tuple('stride', stride, self.name, allow_five=True, ret_five=True)
        self.add_prim_attr('strides', self.stride)
        self.dilation = _check_3d_int_or_tuple('dilation', dilation, self.name, allow_five=True, ret_five=True)
        self.add_prim_attr('dilations', self.dilation)
        validator.check_value_type('pad', pad, (int, tuple), self.name)
        if isinstance(pad, int):
            pad = (pad,) * 6
        validator.check_equal_int(len(pad), 6, 'pad size', self.name)
        self.add_prim_attr('pad', self.pad)
        self.pad_list = pad
        self.add_prim_attr('pad_list', self.pad_list)

        self.pad_mode = validator.check_string(pad_mode.lower(), ['valid', 'same', 'pad'], 'pad_mode', self.name)
        if self.pad_mode != 'pad' and self.pad_list != (0, 0, 0, 0, 0, 0):
            raise ValueError(f"For '{self.name}', when pad is not 0, pad_mode should be set as 'pad'.")
        if self.pad_mode == 'pad':
            for item in pad:
                validator.check_non_negative_int(item, 'pad item', self.name)
        self.add_prim_attr('pad_mode', self.pad_mode)

        self.mode = validator.check_equal_int(mode, 1, 'mode', self.name)
        self.add_prim_attr('mode', self.mode)
        self.group = validator.check_positive_int(group, 'group', self.name)
        self.add_prim_attr('groups', self.group)
        self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
        self.add_prim_attr('data_format', self.format)

    def __infer__(self, x, doutput, w_size):
        w_size_v = w_size['value']
        validator.check_value_type('w_size', w_size_v, [tuple], self.name)
        for i, dim_len in enumerate(w_size_v):
            validator.check_value_type("w_size[%d]" % i, dim_len, [int], self.name)
        args = {"x": x['dtype'], "doutput": doutput['dtype']}
        valid_dtypes = [mstype.float16, mstype.float32]
        validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)

        validator.check("filter's batch", w_size_v[0], "dout's channel", doutput['shape'][1], Rel.EQ, self.name)
        validator.check("filter's channel", w_size_v[1], "input_size's channel", x['shape'][1], Rel.EQ, self.name)
        validator.check("input_size's batch", x['shape'][0], "dout's batch", doutput['shape'][0], Rel.EQ, self.name)

        # infer shape
        x_shape = x['shape']
        dout_shape = doutput['shape']
        kernel_d = self.kernel_size[0]
        kernel_h = self.kernel_size[1]
        kernel_w = self.kernel_size[2]
        stride_d = self.stride[2]
        stride_h = self.stride[3]
        stride_w = self.stride[4]
        dilation_d = self.dilation[2]
        dilation_h = self.dilation[3]
        dilation_w = self.dilation[4]
        # pad_mode defaults to 'valid'; compute pad_list for 'valid' and 'same', otherwise keep the pads given at init.
        if self.pad_mode == "valid":
            self.pad_list = (0, 0, 0, 0, 0, 0)
        if self.pad_mode == "same":
            pad_needed_d = max(0, (dout_shape[2] - 1) * stride_d + dilation_d * (kernel_d - 1) + 1 - x_shape[2])
            pad_head = math.floor(pad_needed_d / 2)
            pad_tail = pad_needed_d - pad_head

            pad_needed_h = max(0, (dout_shape[3] - 1) * stride_h + dilation_h * (kernel_h - 1) + 1 - x_shape[3])
            pad_top = math.floor(pad_needed_h / 2)
            pad_bottom = pad_needed_h - pad_top

            pad_needed_w = max(0, (dout_shape[4] - 1) * stride_w + dilation_w * (kernel_w - 1) + 1 - x_shape[4])
            pad_left = math.floor(pad_needed_w / 2)
            pad_right = pad_needed_w - pad_left
            self.pad_list = (pad_head, pad_tail, pad_top, pad_bottom, pad_left, pad_right)
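
            # Worked example (assumed values, depth axis only): with x_shape[2] = 13,
            # dout_shape[2] = 7, stride_d = 2, dilation_d = 1 and kernel_d = 4:
            # pad_needed_d = max(0, (7 - 1) * 2 + 1 * (4 - 1) + 1 - 13) = 3,
            # so pad_head = 1 and pad_tail = 2.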

        self.add_prim_attr('pad_list', self.pad_list)
        out = {
            'value': None,
            'shape': w_size_v,
            'dtype': mstype.float32,
        }
        return out


class Conv2DBackpropFilter(Primitive):
    """
    Computes the gradients of convolution with respect to the filter.

    Args:
        out_channel (int): The dimensionality of the output space.
        kernel_size (Union[int, tuple[int]]): The size of the convolution window.
        pad_mode (str): Modes to fill padding. It could be "valid", "same", or "pad". Default: "valid".
        pad (Union(int, tuple[int])): The pad value to be filled. Default: 0. If `pad` is an integer, the paddings of
                    top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of four integers, the
                    padding of top, bottom, left and right equal to pad[0], pad[1], pad[2], and pad[3] correspondingly.
        pad_list (tuple): The pad list like (top, bottom, left, right). Default: (0, 0, 0, 0).
        mode (int): Modes for different convolutions. 0 for math convolution, 1 for cross-correlation convolution,
                    2 for deconvolution, 3 for depthwise convolution. Default: 1.
        stride (tuple): The stride to be applied to the convolution filter. Default: (1, 1).
        dilation (tuple): Specifies the dilation rate to be used for the dilated convolution. Default: (1, 1, 1, 1).
        group (int): Splits input into groups. Default: 1.
        data_format (str): The format of input and output data. It should be 'NHWC' or 'NCHW'. Default: 'NCHW'.

    Returns:
        Tensor, the gradients of convolution.
    """

    @prim_attr_register
    def __init__(self,
                 out_channel,
                 kernel_size,
                 pad_mode="valid",
                 pad=0,
                 pad_list=(0, 0, 0, 0),
                 mode=1,
                 stride=(1, 1),
                 dilation=(1, 1, 1, 1),
                 group=1,
                 data_format="NCHW"):
        """Initialize Convolution"""
        self.init_prim_io_names(inputs=['out_backprop', 'input', 'filter_sizes'], outputs=['output'])
        self.out_channel = out_channel
        self.kernel_size = kernel_size
        self.mode = mode
        pad_mode = pad_mode.upper()
        self.add_prim_attr('pad_mode', pad_mode)
        if isinstance(pad, int):
            pad = (pad,) * 4
        else:
            validator.check_equal_int(len(pad), 4, 'pad size', self.name)
        self.add_prim_attr("pad", pad)
        if isinstance(stride, tuple) and len(stride) == 4:
            self.stride = (stride[2], stride[3])
            self.add_prim_attr('stride', self.stride)
        self.dilation = dilation
        self.group = group
        self.add_prim_attr('groups', group)
        self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
        if context.get_context("device_target") != "GPU" and self.format == "NHWC":
            raise ValueError("NHWC format is only supported on the GPU target.")
        self.add_prim_attr('data_format', self.format)


class DepthwiseConv2dNativeBackpropFilter(PrimitiveWithInfer):
    """
    Returns the gradient of filter for DepthwiseConv2dNative.

    Applies depthwise conv2d for the input, which will generate more channels with channel_multiplier.

    Refer to class DepthwiseConv2dNative for more details.

    Args:
        channel_multiplier (int): The multiplier for the original output conv.
        kernel_size (int or tuple): The size of the conv kernel.
        mode (int): Modes for different convolutions. 0 for math convolution, 1 for cross-correlation convolution,
                    2 for deconvolution, 3 for depthwise convolution. Default: 3.
        pad_mode (str): The mode to fill padding which can be: "valid", "same" or "pad". Default: "valid".
        pad (Union(int, tuple[int])): The pad value to be filled. Default: 0. If `pad` is an integer, the paddings of
                    top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of four integers, the
                    padding of top, bottom, left and right equal to pad[0], pad[1], pad[2], and pad[3] correspondingly.
        pad_list (tuple): The pad list like (top, bottom, left, right). Default: (0, 0, 0, 0).
        stride (int): The stride to be applied to the convolution filter. Default: 1.
        dilation (int): Specifies the space to use between kernel elements. Default: 1.
        group (int): Splits input into groups. Default: 1.

    Returns:
        Tensor, the value is the gradient of filter for DepthwiseConv2dNative.
    """

    @prim_attr_register
    def __init__(self,
                 channel_multiplier,
                 kernel_size,
                 pad_mode="valid",
                 pad=0,
                 pad_list=(0, 0, 0, 0),
                 mode=3,
                 stride=1,
                 dilation=1,
                 group=1):
        """Initialize Convolution"""
        self.init_prim_io_names(inputs=['input', 'filter_size', 'dout'], outputs=['output'])
        self.channel_multiplier = channel_multiplier
        self.kernel_size = kernel_size
        self.mode = mode
        self.pad_mode = pad_mode
        if isinstance(pad, int):
            pad = (pad,) * 4
        else:
            validator.check_equal_int(len(pad), 4, 'pad size', self.name)
        self.add_prim_attr("pad", pad)
        self.pad_list = pad_list
        self.stride = stride
        self.dilation = dilation
        self.group = group
        self.add_prim_attr('data_format', "NCHW")

    def __call__(self, x, w_size, dout):
        raise NotImplementedError

    def __infer__(self, x, w_size, dout):
        w_size_v = w_size['value']
        args = {'x': x['dtype'], 'dout': dout['dtype']}
        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
        out = {
            'value': None,
            'shape': w_size_v,
            'dtype': dout['dtype'],
        }
        return out


class DepthwiseConv2dNativeBackpropInput(PrimitiveWithInfer):
    """
    Returns the gradient of input for DepthwiseConv2dNative.

    Applies depthwise conv2d for the input, which will generate more channels with channel_multiplier.

    Args:
        channel_multiplier (int): The multiplier for the original output conv.
        kernel_size (int or tuple): The size of the conv kernel.
        mode (int): Modes for different convolutions. 0 for math convolution, 1 for cross-correlation convolution,
                    2 for deconvolution, 3 for depthwise convolution. Default: 3.
        pad_mode (str): Modes to fill padding. It could be "valid", "same", or "pad". Default: "valid".
        pad (Union(int, tuple[int])): The pad value to be filled. Default: 0. If `pad` is an integer, the paddings of
                    top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of four integers, the
                    padding of top, bottom, left and right equal to pad[0], pad[1], pad[2], and pad[3] correspondingly.
        pad_list (tuple): The pad list like (top, bottom, left, right). Default: (0, 0, 0, 0).
        stride (int): The stride to be applied to the convolution filter. Default: 1.
        dilation (int): Specifies the space to use between kernel elements. Default: 1.
        group (int): Splits input into groups. Default: 1.

    Returns:
        Tensor, the value is the gradient of input for DepthwiseConv2dNative.
    """

    @prim_attr_register
    def __init__(self,
                 channel_multiplier,
                 kernel_size,
                 pad_mode="valid",
                 pad=0,
                 pad_list=(0, 0, 0, 0),
                 mode=3,
                 stride=1,
                 dilation=1,
                 group=1):
        """Initialize Convolution"""
        self.init_prim_io_names(inputs=['input_size', 'filter', 'dout'], outputs=['output'])
        self.channel_multiplier = channel_multiplier
        self.kernel_size = kernel_size
        self.mode = mode
        self.pad_mode = pad_mode
        if isinstance(pad, int):
            pad = (pad,) * 4
        else:
            validator.check_equal_int(len(pad), 4, 'pad size', self.name)
        self.add_prim_attr("pad", pad)
        self.pad_list = pad_list
        self.stride = stride
        self.dilation = dilation
        self.group = group
        self.add_prim_attr('data_format', "NCHW")

    def __call__(self, x_size, w, dout):
        raise NotImplementedError

    def __infer__(self, x_size, w, dout):
        args = {'w': w['dtype'], 'dout': dout['dtype']}
        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
        x_size_v = x_size['value']
        out = {
            'value': None,
            'shape': x_size_v,
            'dtype': dout['dtype'],
        }
        return out


class DropoutGrad(Primitive):
    """
    The gradient of Dropout. During training, randomly zeroes some of the elements
    of the input tensor with probability.

    Args:
        keep_prob (float): The keep rate, between 0 and 1, e.g. keep_prob = 0.9
          means dropping out 10% of input units. Default: 0.5.

    Inputs:
        - **shape** (tuple[int]) - The shape of target mask.

    Outputs:
        Tensor, the value of generated mask for input shape.

    Examples:
        >>> dropout_grad = ops.DropoutGrad(keep_prob=0.5)
        >>> in_tensor = Tensor((20, 16, 50, 50))
        >>> out = dropout_grad(in_tensor)
    """

    @prim_attr_register
    def __init__(self, keep_prob=0.5):
        self.keep_prob = validator.check_float_range(keep_prob, 0, 1, Rel.INC_RIGHT, "keep_prob", self.name)


class FlattenGrad(PrimitiveWithInfer):
    """Performs gradients of Flatten."""

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['x', 'shape'], outputs=['output'])

    def __infer__(self, *args):
        out = {
            'value': None,
            'shape': args[1]['value'],
            'dtype': args[0]['dtype'],
        }
        return out


class InstanceNormGrad(PrimitiveWithInfer):
    """Gradients of InstanceNorm operation."""

    @prim_attr_register
    def __init__(self, epsilon=0.0, momentum=0.1):
        self.init_prim_io_names(inputs=['dy', 'x', 'gamma', 'save_mean', 'save_variance'],
                                outputs=['dx', 'bn_gamma', 'bn_beta'])

    def infer_shape(self, y_backprop_shape, x_shape, gamma_shape, save_mean_shape, save_variance_shape):
        return (x_shape, gamma_shape, gamma_shape)

    def infer_dtype(self, y_backprop_type, x_type, gamma_type, save_mean_type, save_variance_type):
        return (x_type, gamma_type, gamma_type)


class UniqueGrad(Primitive):
    """Gradients of Unique operation."""

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['dy', 'y'], outputs=['dx'])

    def __call__(self, dy, x, scale, save_mean, save_inv_variance):
        raise NotImplementedError


class BNTrainingReduceGrad(PrimitiveWithInfer):
    """Gradients of FusedBatchNorm operation."""

    @prim_attr_register
    def __init__(self, epsilon=0.0001):
        _inputs = ['grads', 'x', 'diff_scale', 'diff_offset', 'scale', 'batch_mean', 'batch_variance']
        self.init_prim_io_names(inputs=_inputs, outputs=['y'])

    def infer_shape(self, grads, x, diff_scale, diff_offset, scale, batch_mean, batch_variance):
        return grads

    def infer_dtype(self, grads, x, diff_scale, diff_offset, scale, batch_mean, batch_variance):
        return grads


class BNTrainingUpdateGrad(PrimitiveWithInfer):
    """Gradients of FusedBatchNorm operation."""

    @prim_attr_register
    def __init__(self, epsilon=0.0001):
        self.init_prim_io_names(inputs=['grads', 'x', 'batch_mean', 'batch_variance'],
                                outputs=['diff_scale', 'diff_offset'])

    def infer_shape(self, grads, x, batch_mean, batch_variance):
        return (batch_mean, batch_variance)

    def infer_dtype(self, grads, x, batch_mean, batch_variance):
        return (batch_mean, batch_variance)


class GeLUGrad(PrimitiveWithInfer):
    """Gradients of GeLU operation."""

    @prim_attr_register
    def __init__(self):
        """Initialize GeLUGrad"""

    def infer_shape(self, y_backprop_shape, x_shape, y_shape):
        return x_shape

    def infer_dtype(self, y_backprop_dtype, x_dtype, y_dtype):
        tuple(map(partial(validator.check_tensor_dtype_valid,
                          valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
                  ("y_backprop", "x", "y"),
                  (y_backprop_dtype, x_dtype, y_dtype)))
        return x_dtype


class FastGeLUGrad(PrimitiveWithInfer):
    """Gradients of FastGeLU operation."""

    @prim_attr_register
    def __init__(self):
        """Initialize FastGeLUGrad"""

    def infer_shape(self, y_backprop_shape, x_shape):
        return x_shape

    def infer_dtype(self, y_backprop_dtype, x_dtype):
        tuple(map(partial(validator.check_tensor_dtype_valid,
                          valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
                  ("y_backprop", "x"),
                  (y_backprop_dtype, x_dtype)))
        return x_dtype


class _PoolGrad(PrimitiveWithInfer):
    """Gradients of the max/avg pool operation."""

    @prim_attr_register
    def __init__(self, kernel_size, strides, pad_mode="VALID", data_format="NCHW"):
        self.init_prim_io_names(inputs=['x_origin', 'out_origin', 'grad'], outputs=['output'])

        validator.check_value_type('kernel_size', kernel_size, [int, tuple], self.name)
        validator.check_value_type('strides', strides, [int, tuple], self.name)
        self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.name)
        self.add_prim_attr("pad_mode", self.pad_mode)
        self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
        if context.get_context("device_target") != "GPU" and self.format == "NHWC":
            raise ValueError("NHWC format is only supported on the GPU target.")
        self.is_maxpoolgradwithargmax = (self.name == "MaxPoolGradWithArgmax")
        if not self.is_maxpoolgradwithargmax:
            self.add_prim_attr('data_format', self.format)

        def _grad_check_int_or_tuple(arg_name, arg_val, is_argmax):
            validator.check_value_type(arg_name, arg_val, (int, tuple), self.name)
            error_msg = ValueError(f"For '{self.name}' the '{arg_name}' should be a positive int number "
                                   f"or a tuple of two or four positive int numbers, but got {arg_val}")
            if isinstance(arg_val, int):
                ret = (1, arg_val, arg_val, 1) if is_argmax else (1, 1, arg_val, arg_val)
            elif len(arg_val) == 2:
                ret = (1, arg_val[0], arg_val[1], 1) if is_argmax else (1, 1, arg_val[0], arg_val[1])
            elif len(arg_val) == 4:
                ret = arg_val
            else:
                raise error_msg
            # check that every element of the tuple is a positive integer
            for item in ret:
                if not isinstance(item, int) or item <= 0:
                    raise error_msg
            return ret
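
        # Illustrative examples (assumed calls): _grad_check_int_or_tuple("kernel_size", 3, False)
        # returns (1, 1, 3, 3), while _grad_check_int_or_tuple("kernel_size", 3, True) returns
        # (1, 3, 3, 1); the argmax variant keeps the spatial sizes in positions 1 and 2.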

        kernel_size = _grad_check_int_or_tuple("kernel_size", kernel_size, self.is_maxpoolgradwithargmax)
        self.kernel_size = kernel_size if self.format == "NCHW" else [kernel_size[0], kernel_size[2],
                                                                      kernel_size[3], kernel_size[1]]
        self.add_prim_attr("kernel_size", self.kernel_size)

        strides = _grad_check_int_or_tuple("strides", strides, self.is_maxpoolgradwithargmax)
        self.strides = strides if self.format == "NCHW" else [strides[0], strides[2], strides[3], strides[1]]
        self.add_prim_attr("strides", self.strides)


class AvgPoolGradVm(_PoolGrad):
    """Gradients of the avg pool operation for vm."""

    @prim_attr_register
    def __init__(self, kernel_size=1, strides=1, pad_mode="VALID"):
        super(AvgPoolGradVm, self).__init__(kernel_size, strides, pad_mode)
        self.init_prim_io_names(inputs=['x_origin', 'grad', 'mean_matrix', 'kernel_matrix'], outputs=['output'])

    def __infer__(self, origin_input, dout, mean_matrix, kernel_matrix):
        out = {
            'value': None,
            'shape': tuple(origin_input['value']),
            'dtype': dout['dtype'],
        }

        return out


class AvgPoolGrad(_PoolGrad):
    """Gradients of the avg pool operation."""

    @prim_attr_register
    def __init__(self, kernel_size=1, strides=1, pad_mode="VALID", data_format="NCHW"):
        super(AvgPoolGrad, self).__init__(kernel_size, strides, pad_mode, data_format)

    def infer_shape(self, x1_shape, x2_shape, grad_shape):
        return x1_shape

    def infer_dtype(self, x1_dtype, x2_dtype, grad_dtype):
        return x1_dtype


class AdaptiveAvgPool2DGrad(PrimitiveWithInfer):
    """Gradients of the adaptive avg pool 2D operation."""

    @prim_attr_register
    def __init__(self):
        """Initialize AdaptiveAvgPool2DGrad"""

    def infer_shape(self, x1_shape, grad_shape):
        return x1_shape

    def infer_dtype(self, x1_dtype, grad_dtype):
        return x1_dtype


class AvgPool3DGrad(Primitive):
    """Gradients of the avg pool3d operation."""

    @prim_attr_register
    def __init__(self, kernel_size=1, strides=1, pads=0, ceil_mode=False,
                 count_include_pad=True, divisor_override=0, data_format="NCDHW"):
        self.init_prim_io_names(inputs=['origin_input_shape', 'grads'], outputs=['output'])
        self.kernel_size = _check_3d_int_or_tuple('kernel_size', kernel_size, self.name)
        self.add_prim_attr('kernel_size', self.kernel_size)
        self.strides = _check_3d_int_or_tuple('strides', strides, self.name)
        self.add_prim_attr('strides', self.strides)
        validator.check_value_type('pads', pads, (int, tuple), self.name)
        if isinstance(pads, int):
            pads = (pads,) * 6
        validator.check_equal_int(len(pads), 6, 'pad size', self.name)
        for item in pads:
            validator.check_non_negative_int(item, 'pad item', self.name)
        self.add_prim_attr('pad_list', pads)
        self.ceil_mode = validator.check_value_type('ceil_mode', ceil_mode, bool, self.name)
        self.count_include_pad = validator.check_value_type('count_include_pad', count_include_pad, bool, self.name)
        self.divisor_override = validator.check_value_type('divisor_override', divisor_override, int, self.name)
        self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)


class MaxPoolGrad(_PoolGrad):
    """Performs gradients of the max pool operation."""

    @prim_attr_register
    def __init__(self, kernel_size=1, strides=1, pad_mode="VALID", data_format="NCHW"):
        super(MaxPoolGrad, self).__init__(kernel_size, strides, pad_mode, data_format)

    def infer_shape(self, x1_shape, x2_shape, grad_shape):
        return x1_shape

    def infer_dtype(self, x1_dtype, x2_dtype, grad_dtype):
        return x1_dtype


class MaxPoolGradGrad(_PoolGrad):
    r"""
    Performs gradients of the MaxPoolGrad operation.

    Args:
        kernel_size (Union[int, tuple[int]]): The size of kernel used to take the maximum value,
            is an int number that represents height and width are both kernel_size, or a tuple
            of two int numbers that represent height and width respectively. Default: 1.
        strides (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
            the height and width of movement are both strides, or a tuple of two int numbers that
            represent height and width of movement respectively. Default: 1.
        pad_mode (str): The optional value for pad mode, is "same" or "valid", not case sensitive.
            Default: "valid".

            - same: Adopts the way of completion. The height and width of the output will be the same as
              the input. The total number of padding will be calculated in horizontal and vertical
              directions and evenly distributed to top and bottom, left and right if possible.
              Otherwise, the last extra padding will be done from the bottom and the right side.

            - valid: Adopts the way of discarding. The possible largest height and width of output
              will be returned without padding. Extra pixels will be discarded.

    Inputs:
        - **origin_input** (Tensor) - Tensor with data format "NCHW", data type must be float16.
        - **origin_output** (Tensor) - Data type same as `origin_input`.
        - **grad** (Tensor) - Data type same as `origin_input`.

    Outputs:
        Tensor, with data type same as `origin_input`.

    """

    @prim_attr_register
    def __init__(self, kernel_size=1, strides=1, pad_mode="VALID"):
        super(MaxPoolGradGrad, self).__init__(kernel_size, strides, pad_mode)

    def infer_shape(self, x1_shape, x2_shape, grad_shape):
        return x2_shape

    def infer_dtype(self, x1_dtype, x2_dtype, grad_dtype):
        args = {'x1_dtype': x1_dtype, 'x2_dtype': x2_dtype, 'grad_dtype': grad_dtype}
        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16], self.name)
        return x2_dtype


def _get_max_pool3d_grad_pads_by_pad_mode(input_shape, kernel_size, strides, pad_mode):
    """
    Helper to compute the MaxPool3D grad paddings for the given pad_mode.
    """
    def get_pad(origin_shape, ksize, stride):
        tail = origin_shape % stride
        pad = (ksize - tail) if tail > 0 else (ksize - stride)
        pad = max(pad, 0)
        pad1 = int(pad / 2)
        pad2 = int(pad / 2) + pad % 2
        return pad1, pad2

    _, _, d, h, w = input_shape
    _, _, kd, kh, kw = kernel_size
    _, _, strd, strh, strw = strides

    pads = (0, 0, 0, 0, 0, 0)
    if pad_mode == 'SAME':
        pads_d = get_pad(d, kd, strd)
        pads_h = get_pad(h, kh, strh)
        pads_w = get_pad(w, kw, strw)
        pads = pads_d + pads_h + pads_w
    return pads
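
# Worked example for the 'SAME' branch (assumed values): for depth d = 13 with kd = 4 and
# strd = 3, get_pad computes tail = 13 % 3 = 1, pad = 4 - 1 = 3, split as (1, 2); each spatial
# axis contributes such a pair to the returned 6-element pad tuple.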


class MaxPool3DGrad(PrimitiveWithInfer):
    """Gradients of the max pool3d operation."""

    @prim_attr_register
    def __init__(self, kernel_size=(1, 1, 1, 1, 1), strides=(1, 1, 1, 1, 1),
                 pad_mode='VALID', pad_list=0, data_format="NCDHW"):
        validator.check_value_type('kernel_size', kernel_size, [int, tuple], self.name)
        validator.check_value_type('strides', strides, [int, tuple], self.name)
        validator.check_value_type('pad_mode', pad_mode, [str], self.name)
        self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
        if pad_mode.upper() == 'PAD':
            pad_mode = 'CALCULATED'
        self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME', 'CALCULATED'],
                                               'pad_mode', self.name)
        self.kernel_size = _check_3d_int_or_tuple("kernel_size", kernel_size, self.name,
                                                  allow_five=True, ret_five=True)
        self.add_prim_attr("kernel_size", self.kernel_size)
        self.strides = _check_3d_int_or_tuple("strides", strides, self.name, allow_five=True, ret_five=True)
        self.add_prim_attr("strides", self.strides)
        validator.check_value_type('pad_list', pad_list, (int, tuple), self.name)
        self.pad_list = pad_list
        if isinstance(self.pad_list, int):
            self.pad_list = (self.pad_list,) * 6
        if len(self.pad_list) == 3:
            self.pad_list = (pad_list[0], pad_list[0], pad_list[1], pad_list[1], pad_list[2], pad_list[2])
        if len(self.pad_list) != 3 and len(self.pad_list) != 6:
            raise ValueError(f"For `maxpool3d` attr 'pad_list' should be a positive int number or a tuple of "
                             f"three or six positive int numbers, but got `{len(self.pad_list)}` numbers.")
        if self.pad_mode != 'CALCULATED' and self.pad_list != (0, 0, 0, 0, 0, 0):
            raise ValueError(f"For '{self.name}', when pad_list is not 0, pad_mode should be set as 'pad'.")
        if self.pad_mode == 'CALCULATED':
            for item in self.pad_list:
                validator.check_non_negative_int(item, 'pad_list item', self.name)
        self.add_prim_attr("pad_list", self.pad_list)

    def infer_shape(self, x_shape, y_shape, grad_shape):
        validator.check_equal_int(len(x_shape), 5, "x rank", self.name)
        return x_shape

    def infer_dtype(self, x_dtype, y_dtype, grad_dtype):
        args = {'x_dtype': x_dtype, 'y_dtype': y_dtype, 'grad_dtype': grad_dtype}
        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
        return x_dtype


class MaxPool3DGradGrad(PrimitiveWithInfer):
    """Gradients of the max pool3d grad operation."""

    @prim_attr_register
    def __init__(self, kernel_size=(1, 1, 1, 1, 1), strides=(1, 1, 1, 1, 1), pad_mode='VALID', data_format="NCDHW"):
        validator.check_value_type('kernel_size', kernel_size, [int, tuple], self.name)
        validator.check_value_type('strides', strides, [int, tuple], self.name)
        validator.check_value_type('pad_mode', pad_mode, [str], self.name)
        self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
        self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.name)
        self.kernel_size = _check_3d_int_or_tuple("kernel_size", kernel_size, self.name,
                                                  allow_five=True, ret_five=True)
        self.add_prim_attr("kernel_size", self.kernel_size)
        self.strides = _check_3d_int_or_tuple("strides", strides, self.name, allow_five=True, ret_five=True)
        self.add_prim_attr("strides", self.strides)

    def infer_shape(self, x_shape, y_shape, grad_shape):
        validator.check_equal_int(len(x_shape), 5, "x rank", self.name)
        validator.check('x_shape', x_shape, 'grad_shape', grad_shape, prim_name=self.name)
        pad_list = _get_max_pool3d_grad_pads_by_pad_mode(x_shape, self.kernel_size, self.strides, self.pad_mode)
        for pad in pad_list:
            validator.check_non_negative_int(pad, 'element of pad_list', self.name)
        self.add_prim_attr("pad_list", pad_list)
        return y_shape

    def infer_dtype(self, x_dtype, y_dtype, grad_dtype):
        args = {'x_dtype': x_dtype, 'y_dtype': y_dtype}
        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
        validator.check_tensor_dtype_valid('grad_dtype', grad_dtype, [mstype.float16, mstype.float32], self.name)
        return x_dtype


class MaximumGrad(Primitive):
    """Grad for maximum."""

    @prim_attr_register
    def __init__(self, grad_x=True, grad_y=True):
        """Initialize MaximumGrad"""

    def __call__(self, x, y, dout):
        raise NotImplementedError


class MaxPoolGradWithArgmax(_PoolGrad):
    """Computes the gradients of MaxPoolWithArgmax."""

    @prim_attr_register
    def __init__(self, kernel_size=1, strides=1, pad_mode="VALID"):
        self.init_prim_io_names(inputs=['x', 'grad', 'argmax'], outputs=['output'])
        super(MaxPoolGradWithArgmax, self).__init__(kernel_size, strides, pad_mode)

    def infer_shape(self, x_shape, grad_shape, argmax_shape):
        if not grad_shape:
            raise TypeError("The dout of MaxPoolGradWithArgmax should be a Tensor.")
        return x_shape

    def infer_dtype(self, x_dtype, grad_dtype, argmax_dtype):
        return grad_dtype


class MaxPoolGradGradWithArgmax(_PoolGrad):
    r"""
    Computes the gradients of MaxPoolGradWithArgmax.

    Args:
        kernel_size (Union[int, tuple[int]]): The size of kernel used to take the maximum value,
            is an int number that represents height and width are both kernel_size, or a tuple
            of two int numbers that represent height and width respectively. Default: 1.
        strides (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
            the height and width of movement are both strides, or a tuple of two int numbers that
            represent height and width of movement respectively. Default: 1.
        pad_mode (str): The optional value for pad mode, is "same" or "valid", not case sensitive.
            Default: "valid".

            - same: Adopts the way of completion. The height and width of the output will be the same as
              the input. The total number of padding will be calculated in horizontal and vertical
              directions and evenly distributed to top and bottom, left and right if possible.
              Otherwise, the last extra padding will be done from the bottom and the right side.

            - valid: Adopts the way of discarding. The possible largest height and width of output
              will be returned without padding. Extra pixels will be discarded.

    Inputs:
        - **x** (Tensor) - Tensor with data format "NCHW", data type must be float16.
        - **grad** (Tensor) - Data type same as `x`.
        - **argmax** (Tensor) - Data type must be uint16 or int64.

    Outputs:
        Tensor, with data type same as `x`.

    """

    @prim_attr_register
    def __init__(self, kernel_size=1, strides=1, pad_mode="VALID"):
        self.init_prim_io_names(inputs=['x', 'grad', 'argmax'], outputs=['output'])
        super(MaxPoolGradGradWithArgmax, self).__init__(kernel_size, strides, pad_mode)

    def infer_shape(self, x_shape, grad_shape, argmax_shape):
        if not grad_shape:
            raise TypeError("The dout of MaxPoolGradGradWithArgmax should be a Tensor.")
        return x_shape

    def infer_dtype(self, x_dtype, grad_dtype, argmax_dtype):
        args = {'x_dtype': x_dtype, 'grad_dtype': grad_dtype}
        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16], self.name)
        return grad_dtype


class MinimumGrad(Primitive):
    """Grad for minimum."""

    @prim_attr_register
    def __init__(self, grad_x=True, grad_y=True):
        """Initialize MinimumGrad"""

    def __call__(self, x, y, dout):
        raise NotImplementedError


class L2NormalizeGrad(PrimitiveWithInfer):
    r"""
    Gradients of L2 normalize.

    Args:
        axis (Union[list(int), tuple(int), int]): The begin axis for the input to apply L2 normalize. Default: 0.
        epsilon (float): A small value added for numerical stability. Default: 1e-4.

    Inputs:
        - **input_x** (Tensor) - Must be the input `weight` of forward operator L2Normalize.
        - **out** (Tensor) - Must be the output of forward operator L2Normalize.
        - **dout** (Tensor) - The backprop of the next layer.

    Outputs:
        Tensor, gradients of L2Normalize `input_x`.
    """

    @prim_attr_register
    def __init__(self, axis=0, epsilon=1e-4):
        axis = [axis] if isinstance(axis, int) else axis
        validator.check_value_type('axis', axis, [list, tuple], self.name)
        validator.check_value_type('epsilon', epsilon, [int, float], self.name)
        self.add_prim_attr('axis', axis)
        self.init_attrs['axis'] = axis
        if len(axis) != 1:
            raise TypeError("The length of axis must be 1; multiple axes will be supported later.")

    def infer_shape(self, input_x, out, dout):
        validator.check('input_x shape', input_x, 'out shape', out, Rel.EQ, self.name)
        validator.check('input_x shape', input_x, 'dout shape', dout, Rel.EQ, self.name)
        return input_x

    def infer_dtype(self, input_x, out, dout):
        args = {'input_x': input_x, 'out': out, 'dout': dout}
        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
        return input_x
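
    # Illustrative note (standard L2-normalize math, stated as an assumption): with
    # y = x / ||x|| along `axis`, the input gradient is (dout - y * sum(dout * y, axis)) / ||x||;
    # the infer methods above only propagate the shape and dtype of `input_x`.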


class LayerNormGrad(Primitive):
    """
    Applies the layer normalization to the input array.

    This operator will calculate the input gradients of layernorm.

    Args:
        begin_norm_axis (int): The begin axis for the input to apply layernorm. Default: 1.
        begin_params_axis (int): The begin axis for the parameter input to apply layernorm. Default: 1.

    Returns:
        tuple[Tensor], tuple of 3 values (the gradients of layernorm input, gamma, beta).
    """

    @prim_attr_register
    def __init__(self, begin_norm_axis=1, begin_params_axis=1):
        """Initialize LayerNormGrad"""
        self.begin_norm_axis = validator.check_value_type('begin_norm_axis', begin_norm_axis, [int], self.name)
        self.begin_params_axis = validator.check_value_type('begin_params_axis', begin_params_axis, [int], self.name)

    def __call__(self, x, dy, variance, mean, gamma):
        raise NotImplementedError


class LayerNormGradGrad(PrimitiveWithInfer):
    """
    Gets the gradient of LayerNormGrad operation.

    Args:
        begin_norm_axis (int): The begin axis for the input to apply layernorm. Default: 1.
        begin_params_axis (int): The begin axis for the parameter input to apply layernorm. Default: 1.

    Returns:
        tuple[Tensor], tuple of 3 values (the gradients of layernormgrad input, dy, gamma).
    """

    @prim_attr_register
    def __init__(self, begin_norm_axis=1, begin_params_axis=1):
        """Initialize LayerNormGradGrad"""
        self.begin_norm_axis = validator.check_value_type('begin_norm_axis', begin_norm_axis, [int], self.name)
        self.begin_params_axis = validator.check_value_type('begin_params_axis', begin_params_axis, [int], self.name)

    def __call__(self, x, dy, variance, mean, gamma, grad_dx, grad_dg, grad_db):
        raise NotImplementedError

    def infer_shape(self, x, dy, variance, mean, gamma, grad_dx, grad_dg, grad_db):
        return x, dy, gamma

    def infer_dtype(self, x, dy, variance, mean, gamma, grad_dx, grad_dg, grad_db):
        return x, dy, gamma

1211
1212class LogSoftmaxGrad(PrimitiveWithInfer):
1213    """Computes gradient for the Log Softmax activation."""
1214
1215    @prim_attr_register
1216    def __init__(self, axis=-1):
1217        """Initialize LogSoftmaxGrad"""
1218        validator.check_value_type("axis", axis, [int], self.name)
1219
1220    def infer_shape(self, dout, logits):
1221        rank = len(logits)
1222        validator.check_int_range(self.axis, -rank - 1, rank, Rel.INC_BOTH, 'axis', self.name)
1223        return logits
1224
1225    def infer_dtype(self, dout, logits):
1226        validator.check_subclass("logits", logits, mstype.tensor, self.name)
1227        return logits
1228
1229
1230class LSTMGradData(PrimitiveWithInfer):
1231    """Computes the data gradients of LSTM."""
1232
1233    @prim_attr_register
1234    def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
1235        self.input_size = validator.check_positive_int(input_size, 'input_size', self.name)
1236        self.hidden_size = validator.check_positive_int(hidden_size, 'hidden_size', self.name)
1237        self.num_layers = validator.check_positive_int(num_layers, 'num_layers', self.name)
1238        self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
1239        self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
1240        self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
1241        self.dropout = validator.check_float_range(dropout, 0, 1, Rel.INC_BOTH, 'dropout', self.name)
1242
1243        if bidirectional:
1244            self.num_directions = 2
1245        else:
1246            self.num_directions = 1
1247
1248    def infer_shape(self, y_shape, dy_shape, dhy_shape, dcy_shape, w_shape,
1249                    hx_shape, cx_shape, reserve_shape, state_shape):
1250        # dhy and dcy should be same shape
1251        validator.check_equal_int(len(dhy_shape), 3, "h_shape", self.name)
1252        validator.check_equal_int(len(dhy_shape), len(dcy_shape), "h_shape", self.name)
1253        validator.check_equal_int(dhy_shape[0], dcy_shape[0], "h_shape[0]", self.name)
1254        validator.check_equal_int(dhy_shape[1], dcy_shape[1], "h_shape[1]", self.name)
1255        validator.check_equal_int(dhy_shape[2], dcy_shape[2], "h_shape[2]", self.name)
1256
1257        validator.check_int(dhy_shape[0], self.num_layers * self.num_directions, Rel.EQ, "h_shape[0]", self.name)
1258        validator.check_equal_int(dhy_shape[2], self.hidden_size, "h_shape[2]", self.name)
1259
1260        validator.check_equal_int(len(dy_shape), 3, "dy_shape", self.name)
1261        validator.check_equal_int(dy_shape[1], dhy_shape[1], "dy[1]", self.name)
1262        validator.check_int(dy_shape[2], self.hidden_size * self.num_directions, Rel.EQ, "dy[2]", self.name)
1263
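        # dx keeps the sequence length and batch size of y but restores input_size
        # in the last dimension; dhx and dcx mirror the shapes of dhy and dcy.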
1264        dx_shape = (y_shape[0], y_shape[1], self.input_size)
1265        dhx_shape = dhy_shape
1266        dcx_shape = dcy_shape
1267
1268        return (dx_shape, dhx_shape, dcx_shape)
1269
1270    def infer_dtype(self, y_dtype, dy_dtype, dhy_dtype, dcy_dtype, w_dtype,
1271                    hx_dtype, cx_dtype, reserve_dtype, state_dtype):
1272        args = {"dy": dy_dtype, "dhy": dhy_dtype, "dcy": dcy_dtype}
1273        validator.check_tensors_dtypes_same_and_valid(args, (mstype.float32, mstype.float16), self.name)
1274        return (dy_dtype, dy_dtype, dy_dtype)
1275
1276
1277class LSTMGradWeight(PrimitiveWithInfer):
1278    """Computes the weight gradients of LSTM."""
1279
1280    @prim_attr_register
1281    def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
1282        self.input_size = validator.check_positive_int(input_size, 'input_size', self.name)
1283        self.hidden_size = validator.check_positive_int(hidden_size, 'hidden_size', self.name)
1284        self.num_layers = validator.check_positive_int(num_layers, 'num_layers', self.name)
1285        self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
1286        self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
1287        self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
1288        self.dropout = validator.check_float_range(dropout, 0, 1, Rel.INC_BOTH, 'dropout', self.name)
1289
1290        if bidirectional:
1291            self.num_directions = 2
1292        else:
1293            self.num_directions = 1
1294
1295    def infer_shape(self, x_shape, hx_shape, y_shape, reserve_shape, state_shape):
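        # The output is the flattened weight buffer: for every layer and direction it
        # accumulates an input-hidden matrix (gate_size x layer input size), a
        # hidden-hidden matrix (gate_size x hidden_size) and, when has_bias is set,
        # two bias vectors of length gate_size, with gate_size = 4 * hidden_size.
        # Illustrative example: a single-layer unidirectional LSTM with
        # input_size=10, hidden_size=16 and bias gives 64*10 + 64*16 + 2*64 = 1792.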
1296        weight_size = 0
1297        gate_size = 4 * self.hidden_size
1298        for layer in range(self.num_layers):
1299            for _ in range(self.num_directions):
1300                input_layer_size = self.input_size if layer == 0 else self.hidden_size * self.num_directions
1301                weight_size += gate_size * input_layer_size
1302                weight_size += gate_size * self.hidden_size
1303                if self.has_bias:
1304                    weight_size += 2 * gate_size
1305
1306        return (weight_size, 1, 1)
1307
1308    def infer_dtype(self, x_dtype, hx_dtype, y_dtype, reserve_dtype, state_dtype):
1309        return hx_dtype
1310
1311
1312class LSTMGrad(PrimitiveWithInfer):
1313    """Computes the data and weight gradients of LSTM."""
1314
1315    @prim_attr_register
1316    def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
1317        self.input_size = validator.check_positive_int(input_size, 'input_size', self.name)
1318        self.hidden_size = validator.check_positive_int(hidden_size, 'hidden_size', self.name)
1319        self.num_layers = validator.check_positive_int(num_layers, 'num_layers', self.name)
1320        self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
1321        self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
1322        self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
1323        self.dropout = validator.check_float_range(dropout, 0, 1, Rel.INC_BOTH, 'dropout', self.name)
1324
1325        if bidirectional:
1326            self.num_directions = 2
1327        else:
1328            self.num_directions = 1
1329
1330    def infer_shape(self, x_shape, hx_shape, cx_shape, w_shape, y_shape, hy_shape, cy_shape, dy_shape, dhy_shape,
1331                    dcy_shape, reserve_shape):
1332        # dhy and dcy should have the same shape
1333        validator.check_equal_int(len(dhy_shape), 3, "h_shape", self.name)
1334        validator.check_equal_int(len(dhy_shape), len(dcy_shape), "h_shape", self.name)
1335        validator.check_equal_int(dhy_shape[0], dcy_shape[0], "h_shape[0]", self.name)
1336        validator.check_equal_int(dhy_shape[1], dcy_shape[1], "h_shape[1]", self.name)
1337        validator.check_equal_int(dhy_shape[2], dcy_shape[2], "h_shape[2]", self.name)
1338
1339        validator.check_int(dhy_shape[0], self.num_layers * self.num_directions, Rel.EQ, "h_shape[0]", self.name)
1340        validator.check_equal_int(dhy_shape[2], self.hidden_size, "h_shape[2]", self.name)
1341
1342        validator.check_equal_int(len(dy_shape), 3, "dy_shape", self.name)
1343        validator.check_equal_int(dy_shape[1], dhy_shape[1], "dy[1]", self.name)
1344        validator.check_int(dy_shape[2], self.hidden_size * self.num_directions, Rel.EQ, "dy[2]", self.name)
1345
1346        dx_shape = (y_shape[0], y_shape[1], self.input_size)
1347        dhx_shape = dhy_shape
1348        dcx_shape = dcy_shape
1349        weight_size = 0
1350        gate_size = 4 * self.hidden_size
1351        for layer in range(self.num_layers):
1352            for _ in range(self.num_directions):
1353                input_layer_size = self.input_size if layer == 0 else self.hidden_size * self.num_directions
1354                weight_size += gate_size * input_layer_size
1355                weight_size += gate_size * self.hidden_size
1356                if self.has_bias:
1357                    weight_size += gate_size
1358
1359        return (dx_shape, dhx_shape, dcx_shape, (weight_size, 1, 1))
1360
1361    def infer_dtype(self, x_dtype, hx_dtype, cx_dtype, w_dtype, y_dtype, hy_dtype, cy_dtype, dy_dtype, dhy_dtype,
1362                    dcy_dtype, reserve_dtype):
1363        return (dy_dtype, dy_dtype, dy_dtype, hx_dtype)
1364
1365
1366class DynamicRNNGrad(PrimitiveWithInfer):
1367    """Computes the input gradients of DynamicRNN."""
1368
1369    @prim_attr_register
1370    def __init__(self,
1371                 cell_type='LSTM',
1372                 direction='UNIDIRECTIONAL',
1373                 cell_depth=1,
1374                 use_peephole=False,
1375                 keep_prob=1.0,
1376                 cell_clip=-1.0,
1377                 num_proj=0,
1378                 time_major=True,
1379                 forget_bias=0.0):
1380        self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
1381
1382    def infer_shape(self, x_shape, w_shape, b_shape, y_shape, init_h_shape, init_c_shape, h_shape,
1383                    c_shape, dy_shape, dh_shape, dc_shape, i_shape, j_shape, f_shape, o_shape, tanhc_shape):
1384        validator.check_equal_int(len(x_shape), 3, "x_shape", self.name)
1385        num_step, batch_size, input_size = x_shape
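        # The fused weight packs the four LSTM gates, so its last dimension must be
        # 4 * hidden_size and its first dimension input_size + hidden_size.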
1386        hidden_size = w_shape[-1] // 4
1387        if w_shape[-1] % 4 != 0:
1388            raise ValueError(f"For {self.name}, w_shape[-1] should be a multiple of 4.")
1389        validator.check("w_shape[0]", w_shape[0], "input_size + hidden_size",
1390                        input_size + hidden_size, Rel.EQ, self.name)
1391        valid_shape = [num_step, batch_size, hidden_size]
1392        validator.check("b_shape[0]", b_shape[0], "w_shape[1]", w_shape[1], Rel.EQ, self.name)
1393        validator.check("y_shape", y_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
1394        validator.check("h_shape", h_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
1395        validator.check("c_shape", c_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
1396        validator.check("i_shape", i_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
1397        validator.check("j_shape", j_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
1398        validator.check("f_shape", f_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
1399        validator.check("o_shape", o_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
1400        validator.check("tanhc_shape", tanhc_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
1401        validator.check("dy_shape", dy_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
1402        validator.check("dh_shape", dh_shape, "excepted shape", [batch_size, hidden_size], Rel.EQ, self.name)
1403        validator.check("dc_shape", dc_shape, "excepted shape", [batch_size, hidden_size], Rel.EQ, self.name)
1404
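        # Outputs: dw (same shape as w), db (one value per gate column, i.e. w_shape[1]),
        # dx (same shape as x) and the gradients of the initial h and c states.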
1405        return w_shape, (w_shape[1],), x_shape, dh_shape, dc_shape
1406
1407    def infer_dtype(self, x_dtype, w_dtype, b_dtype, y_dtype, init_h_dtype, init_c_dtype, h_dtype,
1408                    c_dtype, dy_dtype, dh_dtype, dc_dtype, i_dtype, j_dtype, f_dtype, o_dtype, tanhc_dtype):
1409        return x_dtype, x_dtype, x_dtype, x_dtype, x_dtype
1410
1411
1412class DynamicGRUV2Grad(PrimitiveWithInfer):
1413    r"""
1414    Computes the input gradients of DynamicGRUV2.
1415
1416    Args:
1417        direction (str): A string identifying the direction in the op. Default: 'UNIDIRECTIONAL'.
1418            Only 'UNIDIRECTIONAL' is currently supported.
1419        cell_depth (int): An integer identifying the cell depth in the op. Default: 1.
1420        keep_prob (float): A float identifying the keep prob in the op. Default: 1.0.
1421        cell_clip (float): A float identifying the cell clip in the op. Default: -1.0.
1422        num_proj (int): An integer identifying the num proj in the op. Default: 0.
1423        time_major (bool): A bool identifying the time major in the op. Default: True.
1424        gate_order (str): A string identifying the gate order in weight and bias. Default: 'rzh'.
1425            'zrh' is another option.
1426        reset_after (bool): A bool identifying whether to apply reset gate after matrix multiplication. Default: True.
1427
1428    Inputs:
1429        - **x** (Tensor) - Current words. Tensor of shape :math:`(num_step, batch_size, input_size)`.
1430          The data type must be float16 or float32.
1431        - **weight_input** (Tensor) - Weight. Tensor of shape :math:`(input_size, 3 x hidden_size)`.
1432          The data type must be float16 or float32.
1433        - **weight_hidden** (Tensor) - Weight. Tensor of shape :math:`(hidden_size, 3 x hidden_size)`.
1434          The data type must be float16 or float32.
1435        - **y** (Tensor) - A Tensor of shape
1436          :math:`(num_step, batch_size, min(hidden_size, num_proj))` if num_proj > 0,
1437          :math:`(num_step, batch_size, hidden_size)` if num_proj == 0.
1438          The data type must be float16 or float32.
1439        - **init_h** (Tensor) - Hidden state of initial time.
1440          Tensor of shape :math:`(batch_size, hidden_size)`.
1441          The data type must be float16 or float32.
1442        - **h** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
1443          The data type must be float16 or float32.
1444        - **dy** (Tensor) - Gradient of `y`, has the same shape and data type as `y`.
1445        - **dh** (Tensor) - Gradient of `h`, has the same shape and data type as `init_h`.
1446        - **update** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
1447          The data type must be float16 or float32.
1448        - **reset** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
1449          The data type must be float16 or float32.
1450        - **new** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
1451          The data type must be float16 or float32.
1452        - **hidden_new** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
1453          The data type must be float16 or float32.
1454        - **seq_length** (Tensor) - The length of each batch. Tensor of shape :math:`(batch_size)`.
1455          Only `None` is currently supported.
1456        - **mask** (Tensor) - A 4-D Tensor. The data type must be float16 or float32.
1457
1458    Outputs:
1459        - **dw_input** (Tensor) - A Tensor with the same shape as `weight_input`.
1460          Has the same data type as input `x`.
1461        - **dw_hidden** (Tensor) - A Tensor with the same shape as `weight_hidden`.
1462          Has the same data type as input `x`.
1463        - **db_input** (Tensor) - A Tensor of shape :math:`(3 x hidden_size)`.
1464          Has the same data type as input `x`.
1465        - **db_hidden** (Tensor) - A Tensor of shape :math:`(3 x hidden_size)`.
1466          Has the same data type as input `x`.
1467        - **dx** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, input_size)`.
1468          Has the same data type as input `x`.
1469        - **dh_prev** (Tensor) - A Tensor of shape :math:`(batch_size, hidden_size)`.
1470          Has the same data type as input `x`.
1471    """
1472
1473    @prim_attr_register
1474    def __init__(self,
1475                 direction='UNIDIRECTIONAL',
1476                 cell_depth=1,
1477                 keep_prob=1.0,
1478                 cell_clip=-1.0,
1479                 num_proj=0,
1480                 time_major=True,
1481                 gate_order="rzh",
1482                 reset_after=True):
1483        self.cell_depth = validator.check_value_type("cell_depth", cell_depth, [int], self.name)
1484        self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
1485        self.cell_clip = validator.check_value_type("cell_clip", cell_clip, [float], self.name)
1486        self.num_proj = validator.check_non_negative_int(num_proj, "num_proj", self.name)
1487        self.time_major = validator.check_value_type("time_major", time_major, [bool], self.name)
1488        self.direction = validator.check_string(direction, ['UNIDIRECTIONAL'], "direction", self.name)
1489        self.gate_order = validator.check_string(gate_order, ['zrh', 'rzh'], "gate_order", self.name)
1490        self.reset_after = validator.check_value_type("reset_after", reset_after, [bool], self.name)
1491
1492    def infer_shape(self, x_shape, winput_shape, whidden_shape, y_shape, init_h_shape, h_shape,
1493                    dy_shape, dh_shape, update_shape, reset_shape, new_shape, hnew_shape, seq_shape, mask_shape):
1494        validator.check_int(len(x_shape), 3, Rel.EQ, "x shape", self.name)
1495        validator.check_int(len(winput_shape), 2, Rel.EQ, "weight input shape rank", self.name)
1496        validator.check_int(len(whidden_shape), 2, Rel.EQ, "weight hidden shape rank", self.name)
1497        validator.check_int(len(y_shape), 3, Rel.EQ, "y shape rank", self.name)
1498        num_step, batch_size, input_size = x_shape
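        # GRU packs three gates, so both weight matrices end in 3 * hidden_size;
        # hidden_size itself is taken from the first dimension of weight_hidden.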
1499        hidden_size = whidden_shape[0]
1500        validator.check("weight_hidden_shape[-1]", whidden_shape[-1], "3 * hidden_size",
1501                        3 * hidden_size, Rel.EQ, self.name)
1502        validator.check("weight_input_shape", winput_shape, "excepted shape",
1503                        [input_size, 3 * hidden_size], Rel.EQ, self.name)
1504        if self.num_proj > 0:
1505            valid_y_shape = [num_step, batch_size, min(hidden_size, self.num_proj)]
1506        else:
1507            valid_y_shape = [num_step, batch_size, hidden_size]
1508        validator.check("y_shape", y_shape, "excepted shape", valid_y_shape, Rel.EQ, self.name)
1509
1510        validator.check("init_h_shape", init_h_shape, "excepted shape",
1511                        [batch_size, hidden_size], Rel.EQ, self.name)
1512        valid_shape = [num_step, batch_size, hidden_size]
1513        validator.check("h_shape", h_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
1514        validator.check("dy_shape", dy_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
1515        validator.check("dh_shape", dh_shape, "excepted shape",
1516                        [batch_size, hidden_size], Rel.EQ, self.name)
1517        validator.check("update_shape", update_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
1518        validator.check("reset_shape", reset_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
1519        validator.check("new_shape", new_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
1520        validator.check("hnew_shape", hnew_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
1521        if seq_shape is not None:
1522            validator.check("seq_shape", seq_shape, "batch_size", batch_size, Rel.EQ, self.name)
1523
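        # The gradients mirror the forward operands: dw_* match the weight shapes,
        # db_* hold one value per gate column, dx matches x and dh_prev matches init_h.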
1524        dx_shape = (num_step, batch_size, input_size)
1525        dh_shape = (batch_size, hidden_size)
1526        dwinput_shape = (input_size, 3 * hidden_size)
1527        dwhidden_shape = (hidden_size, 3 * hidden_size)
1528        db_shape = (3 * hidden_size,)
1529        return dwinput_shape, dwhidden_shape, db_shape, db_shape, dx_shape, dh_shape
1530
1531    def infer_dtype(self, x_dtype, winput_dtype, whidden_dtype, y_dtype, init_h_dtype, h_dtype,
1532                    dy_dtype, dh_dtype, update_dtype, reset_dtype, new_dtype, hnew_dtype, seq_dtype, mask_dtype):
1533        valid_types = (mstype.float16, mstype.float32)
1534        args = {"y_dtype": y_dtype, "h_dtype": h_dtype, "dy_dtype": dy_dtype,
1535                "dh_dtype": dh_dtype, "update_dtype": update_dtype, "reset_dtype": reset_dtype,
1536                "new_dtype": new_dtype, "hnew_dtype": hnew_dtype}
1537        validator.check_tensor_dtype_valid("x_dtype", x_dtype, valid_types, self.name)
1538        validator.check_tensor_dtype_valid("winput_dtype", winput_dtype, valid_types, self.name)
1539        validator.check_tensor_dtype_valid("whidden_dtype", whidden_dtype, valid_types, self.name)
1540        validator.check_tensor_dtype_valid("init_h_dtype", init_h_dtype, valid_types, self.name)
1541        validator.check_tensors_dtypes_same_and_valid(args, valid_types, self.name)
1542        if seq_dtype is not None:
1543            validator.check_tensor_dtype_valid("seq_dtype", seq_dtype, valid_types, self.name)
1544        if mask_dtype is not None:
1545            validator.check_tensor_dtype_valid("mask_dtype", mask_dtype, valid_types, self.name)
1546        return x_dtype, x_dtype, x_dtype, x_dtype, x_dtype, x_dtype
1547
1548
1549class PReLUGrad(PrimitiveWithInfer):
1550    r"""
1551    Gradients of PReLU operation.
1552
1553    Note:
1554        1-dimensional input_x is not supported.
1555
1556    Inputs:
1557        - **y_backprop** (Tensor) - Representing the backprop of the next layer.
1558        - **input_x** (Tensor) - Must be the input `input_x` of forward operator PRelu.
1559        - **weight** (Tensor) - Float Tensor, w > 0, must be the input `weight` of forward operator PRelu.
1560
1561    Outputs:
1562        Tuple of 2 Tensors, the gradients with respect to `input_x` and `weight`, with the same data types as the
1563        corresponding inputs.
1563    """
1564
1565    @prim_attr_register
1566    def __init__(self):
1567        pass
1568
1569    def infer_shape(self, y_backprop_shape, a_shape, w_shape):
1570        return y_backprop_shape, w_shape
1571
1572    def infer_dtype(self, y_backprop_dtype, a_dtype, w_dtype):
1573        tuple(map(partial(validator.check_tensor_dtype_valid,
1574                          valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
1575                  ('y_backprop', "input_x", "weight"),
1576                  (y_backprop_dtype, a_dtype, w_dtype)))
1577        return y_backprop_dtype, w_dtype
1578
1579
1580class ReluGrad(Primitive):
1581    """Performs grad of Relu operation."""
1582
1583    @prim_attr_register
1584    def __init__(self):
1585        """Initialize ReluGrad"""
1586        self.init_prim_io_names(inputs=['y_backprop', 'x'], outputs=['output'])
1587
1588    def __call__(self, y_backprop, x):
1589        raise NotImplementedError
1590
1591
1592class ReLU6Grad(PrimitiveWithInfer):
1593    """Performs grad of ReLU6 operation."""
1594
1595    @prim_attr_register
1596    def __init__(self):
1597        self.init_prim_io_names(inputs=['y_grad', 'x'], outputs=['output'])
1598
1599    def __call__(self, y_grad, x):
1600        raise NotImplementedError
1601
1602    def infer_shape(self, y_grad_shape, x_shape):
1603        return x_shape
1604
1605    def infer_dtype(self, y_grad_dtype, x_dtype):
1606        valid_dtypes = (mstype.float16, mstype.float32)
1607        validator.check_tensor_dtype_valid("y_grad", y_grad_dtype, valid_dtypes, self.name)
1608        validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name)
1609        return x_dtype
1610
1611
1612class ReluGradV2(Primitive):
1613    """Performs grad of ReLUV2 operation."""
1614
1615    @prim_attr_register
1616    def __init__(self):
1617        self.init_prim_io_names(inputs=['gradients', 'mask'], outputs=['output'])
1618
1619    def __call__(self, gradients, mask):
1620        raise NotImplementedError
1621
1622
1623class EluGrad(PrimitiveWithInfer):
1624    """Performs grad of Elu operation."""
1625
1626    @prim_attr_register
1627    def __init__(self):
1628        """Initialize EluGrad"""
1629
1630    def infer_shape(self, y_grad_shape, x_shape):
1631        return x_shape
1632
1633    def infer_dtype(self, y_grad_dtype, x_dtype):
1634        args = {'y_grad': y_grad_dtype, 'x': x_dtype}
1635        validator.check_tensors_dtypes_same_and_valid(args, mstype.float_type, self.name)
1636        return x_dtype
1637
1638
1639class GatherDGrad(PrimitiveWithInfer):
1640    """Performs grad of GatherD operation."""
1641
1642    @prim_attr_register
1643    def __init__(self, dim=0, shape=None):
1644        """Initialize GatherDGrad"""
1645        validator.check_is_int(dim, "dim", self.name)
1646        self.add_prim_attr("dim", dim)
1647        self.dim = dim
1648        self.out_shape = shape
1649        self.init_prim_io_names(inputs=['index', 'grad'], outputs=['output'])
1650
1651    def infer_shape(self, index_shape, grad_shape):
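        # The output shape is fixed by the `shape` attribute captured at construction
        # time (the shape of the forward GatherD input), not inferred from the inputs.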
1652        return self.out_shape
1653
1654    def infer_dtype(self, index_dtype, grad_dtype):
1655        return grad_dtype
1656
1657
1658class ResizeBilinearGrad(PrimitiveWithInfer):
1659    """Performs grad of ResizeBilinear operation."""
1660
1661    @prim_attr_register
1662    def __init__(self, align_corners=False):
1663        """init"""
1664
1665    def infer_shape(self, dout_shape, orig_shape):
1666        return orig_shape
1667
1668    def infer_dtype(self, dout_dtype, orig_type):
1669        return orig_type
1670
1671
1672class ResizeNearestNeighborGrad(PrimitiveWithInfer):
1673    """
1674    Compute gradient of `ResizeNearestNeighbor` operator.
1675
1676    Note:
1677        The shape of input parameter `size` must be (height, width).
1678
1679    Args:
1680        align_corners (bool): Whether the centers of the 4 corner pixels of the input
1681            and output tensors are aligned. Default: False.
1682    """
1683
1684    @prim_attr_register
1685    def __init__(self, align_corners=False):
1686        """Initialize ResizeNearestNeighborGrad"""
1687        self.init_prim_io_names(inputs=['grads', 'size'], outputs=['y'])
1688
1689    def __infer__(self, grads, size):
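        # The output keeps N and C from the incoming gradient and takes its spatial
        # (height, width) dimensions from the constant `size` input.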
1690        shp = (grads['shape'][0],) + (grads['shape'][1],) + size['value']
1691        return {'shape': shp,
1692                'dtype': grads['dtype'],
1693                'value': None}
1694
1695
1696class ROIAlignGrad(PrimitiveWithInfer):
1697    """
1698    ROIAlignGrad operator.
1699
1700    Args:
1701       xdiff_shape (tuple): The diff shape.
1702       pooled_height (int): The output feature height.
1703       pooled_width (int): The output feature width.
1704       spatial_scale (float): The feature stride.
1705       sample_num (int): Number of sampling points. Default: 2.
1706    """
1707
1708    @prim_attr_register
1709    def __init__(self, xdiff_shape, pooled_height, pooled_width, spatial_scale, sample_num=2):
1710        """Initialize ROIAlignGrad"""
1711        validator.check_value_type("pooled_height", pooled_height, [int], self.name)
1712        validator.check_value_type("pooled_width", pooled_width, [int], self.name)
1713        validator.check_value_type("spatial_scale", spatial_scale, [float], self.name)
1714        validator.check_value_type("sample_num", sample_num, [int], self.name)
1715        validator.check_value_type("xdiff_shape", xdiff_shape, [tuple], self.name)
1716        self.xdiff_shape = xdiff_shape
1717        self.pooled_height = pooled_height
1718        self.pooled_width = pooled_width
1719        self.spatial_scale = spatial_scale
1720        self.sample_num = sample_num
1721
1722    def infer_shape(self, ydiff_shape, rois_shape):
1723        return self.xdiff_shape
1724
1725    def infer_dtype(self, ydiff_type, rois_type):
1726        return ydiff_type
1727
1728
1729class SigmoidGrad(PrimitiveWithInfer):
1730    """Gets the gradient of Sigmoid operation."""
1731
1732    @prim_attr_register
1733    def __init__(self):
1734        pass
1735
1736    def infer_shape(self, out, dout):
1737        return out
1738
1739    def infer_dtype(self, out, dout):
1740        args = {'out': out, 'dout': dout}
1741        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
1742        return out
1743
1744
1745class _ActivationGrad(PrimitiveWithInfer):
1746    """_ActivationGrad base class."""
1747
1748    @prim_attr_register
1749    def __init__(self):
1750        self.init_prim_io_names(inputs=['y_grad', 'x'], outputs=['output'])
1751
1752    def infer_shape(self, y_grad_shape, x_shape):
1753        return x_shape
1754
1755    def infer_dtype(self, y_grad_dtype, x_dtype):
1756        valid_dtypes = (mstype.float16, mstype.float32)
1757        validator.check_tensor_dtype_valid("y_grad", y_grad_dtype, valid_dtypes, self.name)
1758        validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name)
1759        return x_dtype
1760
1761
1762class HSwishGrad(_ActivationGrad):
1763    """Gets the gradient of HSwish operation."""
1764
1765
1766class HSigmoidGrad(_ActivationGrad):
1767    """Gets the gradient of HSigmoid operation."""
1768
1769
1770class SigmoidCrossEntropyWithLogitsGrad(PrimitiveWithInfer):
1771    """Computes the gradients of `SigmoidCrossEntropyWithLogits`."""
1772
1773    @prim_attr_register
1774    def __init__(self):
1775        """Initialize SigmoidCrossEntropyWithLogitsGrad"""
1776        self.init_prim_io_names(inputs=['x', 'y', 'dout'], outputs=['x_grad'])
1777
1778    def infer_shape(self, x_shape, y_shape, dout_shape):
1779        validator.check("x_shape", x_shape, "y_shape", y_shape, Rel.EQ, self.name)
1780        validator.check("x_shape", x_shape, "dout_shape", dout_shape, Rel.EQ, self.name)
1781        return x_shape
1782
1783    def infer_dtype(self, x_dtype, y_dtype, dout_dtype):
1784        args = {"x_dtype": x_dtype, "y_dtype": y_dtype, 'dout_dtype': dout_dtype}
1785        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
1786        return dout_dtype
1787
1788
1789class SliceGrad(PrimitiveWithInfer):
1790    """Reverse of slice."""
1791
1792    @prim_attr_register
1793    def __init__(self):
1794        """Initialize SliceGrad"""
1795        self.init_prim_io_names(inputs=['dy', 'x', 'begin', 'size'], outputs=['dx'])
1796
1797    def __infer__(self, dy, x, begin, size):
1798        dy_shape, x_shape, size_value = dy['shape'], x['shape'], size['value']
1799        dy_shape_len = len(dy_shape)
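        # Every dy dimension must fit inside x and match the requested slice size;
        # the gradient itself is reported with the full shape and dtype of x.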
1800        for i in range(dy_shape_len):
1801            validator.check(f'dy_shape[{i}]', dy_shape[i], f'x_shape[{i}]', x_shape[i], Rel.LE, self.name)
1802            validator.check(f'dy_shape[{i}]', dy_shape[i], f'size_shape[{i}]', size_value[i], Rel.EQ, self.name)
1803        return {'shape': x_shape,
1804                'dtype': x['dtype'],
1805                'value': None}
1806
1807
1808class NLLLossGrad(PrimitiveWithInfer):
1809    """Computes the gradients of `NLLLoss`."""
1810
1811    @prim_attr_register
1812    def __init__(self, reduction="mean"):
1813        """Initialize NLLLoss"""
1814        self.init_prim_io_names(inputs=['x', 'target', "weight"], outputs=['loss'])
1815        self.reduction = validator.check_string(reduction, ['none', 'sum', 'mean'], 'reduction', self.name)
1816        self.add_prim_attr('reduction', self.reduction)
1817
1818    def infer_shape(self, x_shape, y_grad_shape, t_shape, w_shape, tw_shape):
1819        validator.check_int(len(x_shape), [1, 2], Rel.IN, "x rank", self.name)
1820        validator.check_int(len(t_shape), 1, Rel.EQ, "target rank", self.name)
1821        validator.check_int(len(w_shape), 1, Rel.EQ, "weight rank", self.name)
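        # x is either (N,) or (N, C); target is (N,) and weight carries one entry per
        # class, so its length is checked against the last dimension of x below.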
1822        validator.check(f"input_shape[0]", x_shape[0], "target_shape", t_shape[0], Rel.EQ, self.name)
1823        if len(x_shape) == 1:
1824            validator.check(f"input_shape[0]", x_shape[0], "weight_shape", w_shape[0], Rel.EQ, self.name)
1825        else:
1826            validator.check(f"input_shape[1]", x_shape[1], "weight_shape", w_shape[0], Rel.EQ, self.name)
1827        return x_shape
1828
1829    def infer_dtype(self, x_dtype, y_grad_dtype, t_dtype, w_dtype, tw_dtype):
1830        valid_dtypes = (mstype.float16, mstype.float32)
1831        validator.check_tensor_dtype_valid("x_dtype", x_dtype, valid_dtypes, self.name)
1832        validator.check_tensor_dtype_valid("y_grad_dtype", y_grad_dtype, valid_dtypes, self.name)
1833        validator.check_tensor_dtype_valid("t_dtype", t_dtype, mstype.int32, self.name)
1834        validator.check_tensor_dtype_valid("w_dtype", w_dtype, valid_dtypes, self.name)
1835        validator.check_tensor_dtype_valid("tw_dtype", tw_dtype, valid_dtypes, self.name)
1836        validator.check('tw_shape_dtype', tw_dtype, 'w_shape_dtype', w_dtype, Rel.EQ, self.name)
1837        return x_dtype
1838
1839
1840class SmoothL1LossGrad(PrimitiveWithInfer):
1841    """Computes gradient for prediction on SmoothL1Loss."""
1842
1843    @prim_attr_register
1844    def __init__(self, beta=1.0):
1845        pass
1846
1847    def infer_shape(self, prediction, target, dloss):
1848        validator.check('prediction shape', prediction, 'target shape', target, Rel.EQ, self.name)
1849        validator.check('prediction shape', prediction, 'dloss shape', dloss, Rel.EQ, self.name)
1850        return prediction
1851
1852    def infer_dtype(self, prediction, target, dloss):
1853        args = {"prediction": prediction, "target": target, 'dloss': dloss}
1854        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
1855        return dloss
1856
1857
1858class SoftMarginLossGrad(Primitive):
1859    """Computes gradient for prediction on SoftMarginLoss."""
1860
1861    @prim_attr_register
1862    def __init__(self, reduction="mean"):
1863        self.init_prim_io_names(inputs=['predict', 'label', "dout"], outputs=['gradient'])
1864        self.reduction = validator.check_string(reduction, ['none', 'sum', 'mean'], 'reduction', self.name)
1865
1866
1867class StridedSliceGrad(PrimitiveWithInfer):
1868    """
1869    Performs grad of StridedSlice operation.
1870
1871    Args:
1872        begin_mask (int): Start indexing the slice. Default: 0.
1873        end_mask (int): End indexing the slice. Default: 0.
1874        ellipsis_mask (int): An int32 mask. Default: 0.
1875        new_axis_mask (int): An int32 mask. Default: 0.
1876        shrink_axis_mask (int): An int32 mask. Default: 0.
1877
1878    Returns:
1879        Tensor, has the same shape as the original input of StridedSlice.
1880    """
1881
1882    @prim_attr_register
1883    def __init__(self,
1884                 begin_mask=0,
1885                 end_mask=0,
1886                 ellipsis_mask=0,
1887                 new_axis_mask=0,
1888                 shrink_axis_mask=0):
1889        """Initialize StridedSliceGrad"""
1890        validator.check_value_type('begin_mask', begin_mask, [int], self.name)
1891        validator.check_value_type('end_mask', end_mask, [int], self.name)
1892        validator.check_value_type('ellipsis_mask', ellipsis_mask, [int], self.name)
1893        validator.check_value_type('new_axis_mask', new_axis_mask, [int], self.name)
1894        validator.check_value_type('shrink_axis_mask', shrink_axis_mask, [int], self.name)
1895        self.init_prim_io_names(inputs=['dy', 'shapex', 'begin', 'end', 'strides'], outputs=['output'])
1896
1897    def __infer__(self, dy, shapex, begin, end, strides):
1898        validator.check_tensor_dtype_valid("dy", dy['dtype'], mstype.number_type + (mstype.bool_,), self.name)
1899
1900        for idx, item in enumerate(shapex['value']):
1901            validator.check_value_type("shapex[%d]" % idx, item, [int], self.name)
1902        for idx, item in enumerate(begin['value']):
1903            validator.check_value_type("begin[%d]" % idx, item, [int], self.name)
1904        for idx, item in enumerate(end['value']):
1905            validator.check_value_type("end[%d]" % idx, item, [int], self.name)
1906        for idx, item in enumerate(strides['value']):
1907            validator.check_value_type("strides[%d]" % idx, item, [int], self.name)
1908
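        # dx takes the constant `shapex` value (the shape of the tensor that was
        # sliced in the forward pass) as its shape and the dtype of dy.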
1909        return {'shape': shapex['value'],
1910                'dtype': dy['dtype'],
1911                'value': None}
1912
1913
1914class SoftplusGrad(PrimitiveWithInfer):
1915    """Computes gradient for the Softplus activation."""
1916
1917    @prim_attr_register
1918    def __init__(self):
1919        self.init_prim_io_names(inputs=['dout', 'x'], outputs=['output'])
1920
1921    def infer_shape(self, dout_shape, x_shape):
1922        validator.check("x_shape", x_shape, "dout_shape", dout_shape, Rel.EQ, self.name)
1923        return x_shape
1924
1925    def infer_dtype(self, dout_dtype, x_dtype):
1926        args = {"x_dtype": x_dtype, "dout_dtype": dout_dtype}
1927        validator.check_tensors_dtypes_same_and_valid(args, mstype.float_type, self.name)
1928        return x_dtype
1929
1930
1931class TanhGrad(PrimitiveWithInfer):
1932    """Computes gradient of hyperbolic tangent of input element-wise."""
1933
1934    @prim_attr_register
1935    def __init__(self):
1936        pass
1937
1938    def infer_shape(self, out, dout):
1939        return out
1940
1941    def infer_dtype(self, out, dout):
1942        args = {"out": out, "dout": dout}
1943        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
1944        return out
1945
1946
1947class MirrorPadGrad(PrimitiveWithInfer):
1948    """Gradients of MirrorPad operation."""
1949
1950    @prim_attr_register
1951    def __init__(self, mode="REFLECT"):
1952        """Initialize MirrorPad"""
1953        validator.check_string(mode, ['REFLECT', 'SYMMETRIC'], 'mode', self.name)
1954        self.mode = mode
1955
1956    def __infer__(self, dout, paddings):
1957        validator.check_subclass("dout", dout['dtype'], mstype.tensor, self.name)
1958        validator.check_subclass("paddings", paddings['dtype'], mstype.tensor, self.name)
1959        validator.check("paddings rank", len(paddings['shape']), "expected", 2, Rel.EQ, self.name)
1960        validator.check("paddings dim_1", paddings['shape'][1], "expected", 2, Rel.EQ, self.name)
1961
1962        if paddings['value'] is None:
1963            raise ValueError(f"For {self.name}, paddings must be const.")
1964        paddings_value = paddings['value'].asnumpy()
1965        y_shape = ()
1966        dout_shape = dout['shape']
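        # The forward MirrorPad added paddings[i][0] + paddings[i][1] elements along
        # axis i, so the gradient shape is recovered by subtracting them again.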
1967        for i, val in enumerate(dout_shape):
1968            y_shape += (val - paddings_value[i][0] - paddings_value[i][1],)
1969        return {'shape': y_shape,
1970                'dtype': dout['dtype'],
1971                'value': None}
1972
1973
1974class EmbeddingLookupCommGrad(PrimitiveWithInfer):
1975    """
1976    Performs the gradient for the communication part of EmbeddingLookup operator.
1977
1978    This works ONLY when 'reduce_scatter_flag' is True in 'EmbeddingLookup'. Roughly speaking,
1979    this primitive is implemented by StridedSlice --> _HostAllGather --> Concat. This primitive runs on host.
1980    """
1981
1982    @prim_attr_register
1983    def __init__(self):
1984        self.init_prim_io_names(inputs=['dy', 'split_num'], outputs=['output'])
1985        self.add_prim_attr('primitive_target', 'CPU')
1986        self.tuple_setitem = Primitive('tuple_setitem')
1987
1988    def __infer__(self, dy, split_num):
1989        """
1990        This primitive is implemented by three steps:
1991            1) Splits the 'dy' along dimension 0 into 'split_num' parts.
1992            2) For each part, perform _HostAllGather((0, 1, 2, 3, 4, 5, 6, 7)) on the host.
1993            3) After _HostAllGather, there are still 'split_num' parts in each process. Then, perform Concat on them
1994              along dimension 0.
1995
1996        The output shape of this primitive: shape(output)[0] == shape(dy)[0] * 8
1997        """
1998        dy_shape = tuple(dy['shape'])
1999        split_num_value = split_num['value']
2000        validator.check_value_type("split_num_value", split_num_value, [int], self.name)
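        # _HostAllGather over the 8-rank group concatenates along dimension 0,
        # so the output's first dimension is 8 times that of dy.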
2001        dy_shape_all = self.tuple_setitem(dy_shape, 0, dy_shape[0] * 8)
2002        return {'shape': dy_shape_all,
2003                'dtype': dy['dtype'],
2004                'value': None}
2005
2006
2007class RefToEmbed(Primitive):
2008    r"""
2009    Make a key from Ref.
2010
2011    The key is a symbolic_key, an embedding of a Parameter, which is used as the key of the variable in env_type;
2012    items are retrieved by the operation `env_get_item` with the symbolic_key instance. The `Parameter` is a ref.
2013
2014    Inputs:
2015        - **input** (Ref) - Target ref, ref is short for reference. The value of a Parameter is a ref.
2016
2017    Outputs:
2018        symbolic_key, made from the Ref.
2019
2020    Examples:
2021        >>> class Net(nn.Cell):
2022        >>>     def __init__(self):
2023        >>>         super(Net, self).__init__()
2024        >>>         self.weight = mindspore.Parameter(1.0, name='weight')
2025        >>>
2026        >>>     def construct(self):
2027        >>>         key = RefToEmbed()(self.weight)
2028        >>>         return key, self.weight
2029    """
2030    __mindspore_signature__ = (
2031        sig.make_sig('variable', sig.sig_rw.RW_REF),
2032    )
2033
2034    @prim_attr_register
2035    def __init__(self):
2036        pass
2037
2038
2039class AtanGrad(PrimitiveWithInfer):
2040    """
2041    Computes AtanGrad of input element-wise.
2042
2043    Returns:
2044        Tensor, has the same type as input.
2045    """
2046
2047    @prim_attr_register
2048    def __init__(self):
2049        """Initialize AtanGrad"""
2050
2051    def infer_shape(self, x, dout):
2052        validator.check("x shape", x, "dout shape", dout, Rel.EQ, self.name)
2053        return x
2054
2055    def infer_dtype(self, x, dout):
2056        args = {"x": x, "dout": dout}
2057        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
2058        return x
2059
2060
2061class BasicLSTMCellCStateGrad(PrimitiveWithInfer):
2062    """Computes the state gradients of BasicLSTMCell."""
2063
2064    @prim_attr_register
2065    def __init__(self, forget_bias, activation):
2066        self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
2067        self.activation = validator.check_string(activation, ['tanh'], "activation", self.name)
2068
2069    def infer_shape(self, c_shape, dht_shape, dct_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape):
2070        # dht and dct should have the same shape as c
2071        validator.check_equal_int(len(c_shape), 2, "c rank", self.name)
2072        validator.check("dht rank", len(dht_shape), "c rank", len(c_shape), Rel.EQ, self.name)
2073        validator.check("dct rank", len(dct_shape), "c rank", len(c_shape), Rel.EQ, self.name)
2074        validator.check("it rank", len(it_shape), "c rank", len(c_shape), Rel.EQ, self.name)
2075        validator.check("jt rank", len(jt_shape), "c rank", len(c_shape), Rel.EQ, self.name)
2076        validator.check("ft rank", len(ft_shape), "c rank", len(c_shape), Rel.EQ, self.name)
2077        validator.check("ot rank", len(ot_shape), "c rank", len(c_shape), Rel.EQ, self.name)
2078        validator.check("tanhct rank", len(tanhct_shape), "c rank", len(c_shape), Rel.EQ, self.name)
2079        validator.check("dht shape", dht_shape, "c shape", c_shape, Rel.EQ, self.name)
2080        validator.check("dct shape", dct_shape, "c shape", c_shape, Rel.EQ, self.name)
2081        validator.check("it shape", it_shape, "c shape", c_shape, Rel.EQ, self.name)
2082        validator.check("jt shape", jt_shape, "c shape", c_shape, Rel.EQ, self.name)
2083        validator.check("ft shape", ft_shape, "c shape", c_shape, Rel.EQ, self.name)
2084        validator.check("ot shape", ot_shape, "c shape", c_shape, Rel.EQ, self.name)
2085        validator.check("tanhct shape", tanhct_shape, "c shape", c_shape, Rel.EQ, self.name)
2086
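        # The i, j, f and o gate gradients are packed side by side into a
        # (batch_size, 4 * hidden_size) buffer; dct_1 keeps the cell state shape.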
2087        dgate_shape = (c_shape[0], 4 * c_shape[1])
2088        dct_1_shape = c_shape
2089
2090        return (dgate_shape, dct_1_shape)
2091
2092    def infer_dtype(self, c_dtype, dht_dtype, dct_dtype, it_dtype, jt_dtype, ft_dtype, ot_dtype, tanhct_dtype):
2093        validator.check_subclass("c", c_dtype, [mstype.tensor], self.name)
2094        validator.check_subclass("dht", dht_dtype, [mstype.tensor], self.name)
2095        validator.check_subclass("dct", dct_dtype, [mstype.tensor], self.name)
2096        validator.check_subclass("it", it_dtype, [mstype.tensor], self.name)
2097        validator.check_subclass("jt", jt_dtype, [mstype.tensor], self.name)
2098        validator.check_subclass("ft", ft_dtype, [mstype.tensor], self.name)
2099        validator.check_subclass("ot", ot_dtype, [mstype.tensor], self.name)
2100        validator.check_subclass("tanhct", tanhct_dtype, [mstype.tensor], self.name)
2101        validator.check_type_name("c", c_dtype, [mstype.float16, mstype.float32], self.name)
2102        validator.check_type_name("dht", dht_dtype, [mstype.float16, mstype.float32], self.name)
2103        validator.check_type_name("dct", dct_dtype, [mstype.float16, mstype.float32], self.name)
2104        validator.check_type_name("it", it_dtype, [mstype.float16, mstype.float32], self.name)
2105        validator.check_type_name("jt", jt_dtype, [mstype.float16, mstype.float32], self.name)
2106        validator.check_type_name("ft", ft_dtype, [mstype.float16, mstype.float32], self.name)
2107        validator.check_type_name("ot", ot_dtype, [mstype.float16, mstype.float32], self.name)
2108        validator.check_type_name("tanhct", tanhct_dtype, [mstype.float16, mstype.float32], self.name)
2109        return (c_dtype, c_dtype)
2110
2111
2112class BasicLSTMCellWeightGrad(PrimitiveWithInfer):
2113    """Computes the weight gradients of BasicLSTM."""
2114    @prim_attr_register
2115    def __init__(self):
2116        pass
2117
2118    def infer_shape(self, x_shape, h_shape, dgate_shape):
2119        validator.check_equal_int(len(x_shape), 2, "x rank", self.name)
2120        validator.check("h rank", len(h_shape), " x rank", len(x_shape), Rel.EQ, self.name)
2121        validator.check("dgate rank", len(dgate_shape), "x rank", len(x_shape), Rel.EQ, self.name)
2122        validator.check("h_shape[0]", h_shape[0], "x_shape[0]", x_shape[0], Rel.EQ, self.name)
2123        validator.check("dgate_shape[0]", dgate_shape[0], "h_shape[0]", h_shape[0], Rel.EQ, self.name)
2124        validator.check("dgate_shape[1]", dgate_shape[1], "4*h_shape[1]", 4 * h_shape[1], Rel.EQ, self.name)
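        # The weight gradient is a single (input_size + hidden_size, 4 * hidden_size)
        # matrix plus one bias gradient per gate column.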
2125        input_size = x_shape[1]
2126        hidden_size = h_shape[1]
2127        dw_shape = (input_size + hidden_size, 4 * hidden_size)
2128        db_shape = (4 * hidden_size,)
2129        return (dw_shape, db_shape)
2130
2131    def infer_dtype(self, x_dtype, h_dtype, dgate_dtype):
2132        validator.check_subclass("x", x_dtype, mstype.tensor, self.name)
2133        validator.check_subclass("h", h_dtype, mstype.tensor, self.name)
2134        validator.check_subclass("dgate", dgate_dtype, mstype.tensor, self.name)
2135        validator.check_type_name("x", x_dtype, [mstype.float16, mstype.float32], self.name)
2136        validator.check_type_name("h", h_dtype, [mstype.float16, mstype.float32], self.name)
2137        validator.check_type_name("dgate", dgate_dtype, [mstype.float16, mstype.float32], self.name)
2138        return (x_dtype, x_dtype)
2139
2140
2141class BasicLSTMCellInputGrad(PrimitiveWithInfer):
2142    """Computes the input gradients of BasicLSTM."""
2143
2144    @prim_attr_register
2145    def __init__(self, keep_prob):
2146        self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
2147        self.keep_prob = validator.check_float_range(keep_prob, 0.0, 1.0, Rel.INC_BOTH, "keep_prob", self.name)
2148
2149    def infer_shape(self, dgate_shape, w_shape):
2150        validator.check_equal_int(len(dgate_shape), 2, "dgate rank", self.name)
2151        validator.check_equal_int(len(w_shape), 2, "w rank", self.name)
2152        validator.check("dgate_shape[1]", dgate_shape[1], "w_shape[1]", w_shape[1], Rel.EQ, self.name)
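        # w stacks the input rows on top of the hidden rows, so input_size is
        # recovered by removing hidden_size (= dgate columns / 4) from w_shape[0].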
2153        batch_size = dgate_shape[0]
2154        hidden_size = dgate_shape[1] // 4
2155        input_size = w_shape[0] - hidden_size
2156        dxt_shape = (batch_size, input_size)
2157        dht_shape = (batch_size, hidden_size)
2158        return (dxt_shape, dht_shape)
2159
2160    def infer_dtype(self, dgate_dtype, w_dtype):
2161        validator.check_subclass("dgate", dgate_dtype, mstype.tensor, self.name)
2162        validator.check_subclass("w", w_dtype, mstype.tensor, self.name)
2163        validator.check_type_name("dgate", dgate_dtype, [mstype.float16, mstype.float32], self.name)
2164        validator.check_type_name("w", w_dtype, [mstype.float16, mstype.float32], self.name)
2165        return (dgate_dtype, dgate_dtype)
2166
2167
2168class InvGrad(PrimitiveWithInfer):
2169    """Computes gradients for inv operation."""
2170
2171    @prim_attr_register
2172    def __init__(self):
2173        pass
2174
2175    def infer_shape(self, x, grad):
2176        validator.check("x_shape", x, "grad_shape", grad, Rel.EQ, self.name)
2177        return x
2178
2179    def infer_dtype(self, x, grad):
2180        validator.check_type_name("dgate", x, [mstype.float16, mstype.float32, mstype.int32, mstype.int8], self.name)
2181        validator.check_type_name("grad", grad, [mstype.float16, mstype.float32, mstype.int32, mstype.int8], self.name)
2182        return x
2183
2184
2185class LRNGrad(PrimitiveWithInfer):
2186    """Computes gradients for LRN operation."""
2187
2188    @prim_attr_register
2189    def __init__(self, depth_radius=5, bias=1.0, alpha=1.0, beta=0.5):
2190        self.init_prim_io_names(inputs=['grads', 'x', 'y'], outputs=['z'])
2191        validator.check_value_type("depth_radius", depth_radius, [int], self.name)
2192        validator.check_value_type("bias", bias, [float], self.name)
2193        validator.check_value_type("alpha", alpha, [float], self.name)
2194        validator.check_value_type("beta", beta, [float], self.name)
2195
2196    def infer_dtype(self, grads, x, y):
2197        args = {"grads": grads, "x": x, "y": y}
2198        validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32,), self.name)
2199        return x
2200
2201    def infer_shape(self, grads, x, y):
2202        return x
2203
2204
2205class MaskedSelectGrad(PrimitiveWithInfer):
2206    """Computes gradient for MaskedSelect."""
2207
2208    @prim_attr_register
2209    def __init__(self):
2210        pass
2211
2212    def infer_shape(self, x, mask, grad):
2213        return x
2214
2215    def infer_dtype(self, x, mask, grad):
2216        return x
2217
2218
2219class SoftShrinkGrad(Primitive):
2220    r"""
2221    Gradients for SoftShrink operation.
2222
2223    Args:
2224        lambd (float): The \lambda (must be no less than zero) value for the Softshrink formulation. Default: 0.5.
2225
2226    Inputs:
2227        - **input_grad** (Tensor) - The input gradient.
2228        - **input_x** (Tensor) - The input of SoftShrink with data type of float16 or float32.
2229          Any number of additional dimensions.
2230
2231    Outputs:
2232        Tensor, has the same shape and data type as `input_x`.
2233
2234    Raises:
2235        TypeError: If lambd is not a float.
2236        TypeError: If dtype of input_x is neither float16 nor float32.
2237        ValueError: If lambd is less than 0.
2238
2239    Supported Platforms:
2240        ``Ascend``
2241    """
2242
2243    @prim_attr_register
2244    def __init__(self, lambd=0.5):
2245        self.init_prim_io_names(inputs=['input_grad', 'input_x'], outputs=['output'])
2246        validator.check_value_type("lambd", lambd, [float], self.name)
2247        validator.check_number("lambd", lambd, 0, Rel.GE, self.name)
2248
2249
2250class CdistGrad(Primitive):
2251    """Computes gradient for Cdist."""
2252
2253    @prim_attr_register
2254    def __init__(self, p=2.0):
2255        validator.check_value_type("p", p, [float], self.name)
2256        self.init_prim_io_names(inputs=['grad', 'input_x', 'input_y', 'cdist'], outputs=['output'])
2257
2258
2259class HShrinkGrad(Primitive):
2260    """
2261    Computes gradients for HShrinkGrad operation.
2262
2263    Args:
2264        lambd (float): The λ value for the Hardshrink formulation. Default: 0.5.
2265
2266    Inputs:
2267        - **gradients** (Tensor) - The gradients of the loss with respect to the output of the HShrink function.
2268          Currently the gradients data type only supports float16 and float32.
2269        - **features** (Tensor) - Must be the input `input_x` of the forward operator HShrink.
2270          Currently the features data type only supports float16 and float32.
2271
2272    Outputs:
2273        backprops - Tensor, with the same shape and data type as `features`.
2274
2275    Raises:
2276        TypeError: If `lambd` is not a float.
2277        ValueError: If shape of `gradients` is not the same as `features`.
2278        TypeError: If dtype of `gradients` is not the same as `features`.
2279        TypeError: If dtype of `gradients` or `features` is neither float16 nor float32.
2280
2281    Supported Platforms:
2282        ``Ascend``
2283    """
2284
2285    @prim_attr_register
2286    def __init__(self, lambd=0.5):
2287        validator.check_value_type("lambd", lambd, [float], self.name)
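        # Note: a negative lambd is silently clamped to 0.0 and written back to the
        # primitive's attributes rather than raising an error here.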
2288        if lambd < 0.0:
2289            lambd = 0.0
2290            self.add_prim_attr('lambd', lambd)
2291