1# Copyright 2020-2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15
16"""Operators for nn."""
17
18import math
19import operator
20from functools import reduce, partial
21import numpy as np
22from mindspore import log as logger
23from mindspore._checkparam import _check_3d_int_or_tuple
24from ... import context
25from .. import signature as sig
26from ..._checkparam import Validator as validator
27from ..._checkparam import Rel
28from ...common import dtype as mstype
29from ...common._decorator import deprecated
30from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register
31
32
33def _check_positive_int_or_tuple(arg_name, arg_value, prim_name, allow_four=False, ret_four=False):
34    """
35    Checks whether an argument is a positive int or a tuple with 2 (or 4, when allow_four is True) positive int elements.
36    """
37
38    def _raise_message():
39        raise ValueError(f"For '{prim_name}', the attr '{arg_name}' should be a positive int number or a tuple of two "
40                         f"{'or four ' if allow_four else ''}positive int numbers, but got {arg_value}")
41
42    def _get_return_value():
43        if isinstance(arg_value, int):
44            ret = (1, 1, arg_value, arg_value) if ret_four else (arg_value, arg_value)
45        elif len(arg_value) == 2:
46            ret = (1, 1, arg_value[0], arg_value[1]) if ret_four else arg_value
47        elif len(arg_value) == 4:
48            if not allow_four:
49                _raise_message()
50            ret = arg_value if ret_four else (arg_value[2], arg_value[3])
51        else:
52            _raise_message()
53        return ret
54
55    validator.check_value_type(arg_name, arg_value, (int, tuple), prim_name)
56    ret_value = _get_return_value()
57    for item in ret_value:
58        if isinstance(item, int) and not isinstance(item, bool) and item > 0:
59            continue
60        _raise_message()
61    return ret_value
62
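# Illustrative usage sketch for the checker above (comments only; the calls and values below are hypothetical
# examples of the documented int/tuple mapping, not taken from the original file):
#     _check_positive_int_or_tuple("kernel_size", 3, "Conv2D", ret_four=True)          -> (1, 1, 3, 3)
#     _check_positive_int_or_tuple("stride", (2, 3), "Conv2D")                         -> (2, 3)
#     _check_positive_int_or_tuple("stride", (1, 1, 2, 3), "Conv2D", allow_four=True)  -> (2, 3)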
63
64def _check_shape(arg_name, arg_value, prim_name):
65    """
66    Checks whether all dims of a shape are positive int elements.
67    """
68
69    def _raise_message():
70        raise ValueError(f"For '{prim_name}', the dim elements of attr '{arg_name}' should be positive int numbers, "
71                         f"but got {arg_value}")
72
73    validator.check_value_type(arg_name, arg_value, (list, tuple), prim_name)
74    for item in arg_value:
75        if isinstance(item, int) and item > 0:
76            continue
77        _raise_message()
78    return arg_value
79
80
81def _update_attr_by_format(arg_value, arg_format):
82    """
83    If the format is NHWC, the strides or dilation shape should be modified accordingly.
84    """
85    ret = arg_value
86    if len(arg_value) == 4 and arg_format == "NHWC":
87        ret = arg_value[1:] + (1,)
88
89    return ret
90
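# Illustrative sketch of the NHWC adjustment above (hypothetical values, comments only):
#     _update_attr_by_format((1, 1, 2, 2), "NCHW")  -> (1, 1, 2, 2)   # unchanged
#     _update_attr_by_format((1, 1, 2, 2), "NHWC")  -> (1, 2, 2, 1)   # arg_value[1:] + (1,)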
91
92class Flatten(PrimitiveWithInfer):
93    r"""
94    Flattens a tensor without changing its batch size on the 0-th axis.
95
96    Inputs:
97        - **input_x** (Tensor) - Tensor of shape :math:`(N, \ldots)` to be flattened, where :math:`N` is batch size.
98
99    Outputs:
100        Tensor, the shape of the output tensor is :math:`(N, X)`, where :math:`X` is
101        the product of the remaining dimensions.
102
103    Raises:
104        TypeError: If `input_x` is not a Tensor.
105        ValueError: If length of shape of `input_x` is less than 1.
106
107    Supported Platforms:
108        ``Ascend`` ``GPU`` ``CPU``
109
110    Examples:
111        >>> input_x = Tensor(np.ones(shape=[1, 2, 3, 4]), mindspore.float32)
112        >>> flatten = ops.Flatten()
113        >>> output = flatten(input_x)
114        >>> print(output.shape)
115        (1, 24)
116    """
117
118    @prim_attr_register
119    def __init__(self):
120        pass
121
122    def infer_shape(self, input_x):
123        validator.check_int(len(input_x), 1, Rel.GE, 'input_x rank', self.name)
124        prod = 1 if len(input_x) == 1 else reduce(operator.mul, input_x[1:])
125        return input_x[0], prod
126
127    def infer_dtype(self, input_x):
128        validator.check_subclass("input_x", input_x, mstype.tensor, self.name)
129        return input_x
130
131
132class AdaptiveAvgPool2D(PrimitiveWithInfer):
133    r"""
134    AdaptiveAvgPool2D operation.
135
136    This operator applies a 2D adaptive average pooling to an input signal composed of multiple input planes.
137    That is, for any input size, the size of the specified output is H x W.
138    The number of output features is equal to the number of input planes.
139
140    Args:
141        output_size (Union[int, tuple]): The target output size is H x W.
142            output_size can be a tuple (H, W), or a single int H for H x H, and H and W can be int or None,
143            where None means the output size will be the same as that of the corresponding input dimension.
144
145    Inputs:
146        - **input_x** (Tensor) - The input of AdaptiveAvgPool2D, which is a 3D or 4D tensor,
147          with float16, float32, float64 data type.
148
149    Outputs:
150        Tensor, with the same type as the `input_x`.
151
152        Shape of the output is `input_x_shape[:len(input_x_shape) - len(out_shape)] + out_shape`.
153
154        If `output_size` contains `None`:
155
156        - `out_shape = (input_x_shape[-2], w)`: If `output_size` is `(None, w)`
157        - `out_shape = (h, input_x_shape[-1])`: If `output_size` is `(h, None)`
158        - `out_shape = input_x_shape[-2:]`: If `output_size` is `(None, None)`
159
160        If `output_size` does not contain `None`:
161
162        - `out_shape = (h, h)`: If `output_size` is `h`
163        - `out_shape = (h, w)`: If `output_size` is `(h, w)`
164
165    Raises:
166        ValueError: If `output_size` is a tuple and its length is not 2.
167        TypeError: If `input_x` is not a tensor.
168        TypeError: If dtype of `input_x` is not float16, float32, float64.
169        ValueError: If `input_x` dimension is less than or equal to output_size dimension.
170
171    Supported Platforms:
172        ``GPU``
173
174    Examples:
175        >>> # case 1: output_size=(None, 2)
176        >>> input_x = Tensor(np.array([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]],
177        ...                            [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]],
178        ...                            [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]]), mindspore.float32)
179        >>> adaptive_avg_pool_2d = ops.AdaptiveAvgPool2D((None, 2))
180        >>> output = adaptive_avg_pool_2d(input_x)
181        >>> print(output)
182        [[[1.5 2.5]
183          [4.5 5.5]
184          [7.5 8.5]]
185         [[1.5 2.5]
186          [4.5 5.5]
187          [7.5 8.5]]
188         [[1.5 2.5]
189          [4.5 5.5]
190          [7.5 8.5]]]
191        >>> # case 2: output_size=2
192        >>> adaptive_avg_pool_2d = ops.AdaptiveAvgPool2D(2)
193        >>> output = adaptive_avg_pool_2d(input_x)
194        >>> print(output)
195        [[[3. 4.]
196          [6. 7.]]
197         [[3. 4.]
198          [6. 7.]]
199         [[3. 4.]
200          [6. 7.]]]
201        >>> # case 3: output_size=(1, 2)
202        >>> adaptive_avg_pool_2d = ops.AdaptiveAvgPool2D((1, 2))
203        >>> output = adaptive_avg_pool_2d(input_x)
204        >>> print(output)
205        [[[4.5 5.5]]
206         [[4.5 5.5]]
207         [[4.5 5.5]]]
208    """
209
210    @prim_attr_register
211    def __init__(self, output_size):
212        """Initialize AdaptiveAvgPool2D."""
213        validator.check_value_type("output_size", output_size, [int, tuple], self.name)
214        if isinstance(output_size, tuple):
215            validator.check_int(len(output_size), 2, Rel.EQ, 'length of output_size', self.name)
216        self.output_size = (output_size, output_size) if isinstance(self.output_size, int) else output_size
217
218    def infer_shape(self, x_shape):
219        if len(x_shape) <= len(self.output_size):
220            raise ValueError("input_x {} dimension should be larger than output_size {} "
221                             "dimension".format(x_shape, self.output_size))
222        validator.check_int(len(x_shape), 5, Rel.LT, 'input_x_dimensions', self.name)
223        for input_x_dimension in x_shape:
224            validator.check_int(input_x_dimension, 0, Rel.GT, 'input_x dimension', self.name)
225        zipped = zip(self.output_size, x_shape[-len(self.output_size):])
226        out_size = [i if i is not None else j for i, j in zipped]
227        for item in out_size:
228            validator.check_value_type("item of output_size", item, [int], self.name)
229        self.add_prim_attr('output_size', out_size)
230        output_shape = x_shape[:len(x_shape) - len(out_size)] + out_size
231        return output_shape
232
233    def infer_dtype(self, x_dtype):
234        validator.check_tensor_dtype_valid("x_dtype", x_dtype, [mstype.float16, mstype.float32, mstype.float64],
235                                           self.name)
236        return x_dtype
237
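# Plain-Python sketch of the output-shape rule used by AdaptiveAvgPool2D.infer_shape above
# (illustrative values matching case 1 of the docstring example; comments only):
#     x_shape, output_size = [3, 3, 3], (None, 2)
#     out_size = [i if i is not None else j for i, j in zip(output_size, x_shape[-len(output_size):])]  # [3, 2]
#     x_shape[:len(x_shape) - len(out_size)] + out_size                                                 # [3, 3, 2]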
238
239class Softmax(Primitive):
240    r"""
241    Softmax operation.
242
243    Applies the Softmax operation to the input tensor on the specified axis.
244    Supposes a slice in the given axis :math:`x`, then for each element :math:`x_i`,
245    the Softmax function is shown as follows:
246
247    .. math::
248        \text{output}(x_i) = \frac{\exp(x_i)}{\sum_{j = 0}^{N-1}\exp(x_j)},
249
250    where :math:`N` is the length of the tensor.
251
252    Args:
253        axis (Union[int, tuple]): The axis to perform the Softmax operation. Default: -1.
254
255    Inputs:
256        - **logits** (Tensor) - Tensor of shape :math:`(N, *)`, where :math:`*` means, any number of
257          additional dimensions, with float16 or float32 data type.
258
259    Outputs:
260        Tensor, with the same type and shape as the logits.
261
262    Raises:
263        TypeError: If `axis` is neither an int nor a tuple.
264        TypeError: If dtype of `logits` is neither float16 nor float32.
265        ValueError: If `axis` is a tuple whose length is less than 1.
266        ValueError: If `axis` is a tuple whose elements are not all in range [-len(logits.shape), len(logits.shape)).
267
268    Supported Platforms:
269        ``Ascend`` ``GPU`` ``CPU``
270
271    Examples:
272        >>> logits = Tensor(np.array([1, 2, 3, 4, 5]), mindspore.float32)
273        >>> softmax = ops.Softmax()
274        >>> output = softmax(logits)
275        >>> print(output)
276        [0.01165623 0.03168492 0.08612854 0.23412167 0.6364086 ]
277    """
278
279    @prim_attr_register
280    def __init__(self, axis=-1):
281        """Initialize Softmax."""
282        self.init_prim_io_names(inputs=['x'], outputs=['output'])
283        validator.check_value_type("axis", axis, [int, tuple], self.name)
284        if isinstance(axis, int):
285            self.add_prim_attr('axis', (axis,))
286        for item in self.axis:
287            validator.check_value_type("item of axis", item, [int], self.name)
288
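# Quick numpy check of the Softmax formula documented above (illustrative only; reproduces the
# docstring example up to print precision):
#     import numpy as np
#     x = np.array([1., 2., 3., 4., 5.])
#     np.exp(x) / np.exp(x).sum()   # ~[0.01165623 0.03168492 0.08612854 0.23412166 0.63640865]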
289
290class LogSoftmax(Primitive):
291    r"""
292    Log Softmax activation function.
293
294    Applies the Log Softmax function to the input tensor on the specified axis.
295    Supposes a slice in the given axis :math:`x`, then for each element :math:`x_i`,
296    the Log Softmax function is shown as follows:
297
298    .. math::
299        \text{output}(x_i) = \log \left(\frac{\exp(x_i)} {\sum_{j = 0}^{N-1}\exp(x_j)}\right),
300
301    where :math:`N` is the length of the Tensor.
302
303    Args:
304        axis (int): The axis to perform the Log softmax operation. Default: -1.
305
306    Inputs:
307        - **logits** (Tensor) - Tensor of shape :math:`(N, *)`, where :math:`*` means, any number of
308          additional dimensions, with float16 or float32 data type.
309
310    Outputs:
311        Tensor, with the same type and shape as the logits.
312
313    Raises:
314        TypeError: If `axis` is not an int.
315        TypeError: If dtype of `logits` is neither float16 nor float32.
316        ValueError: If `axis` is not in range [-len(logits.shape), len(logits.shape)).
317
318    Supported Platforms:
319        ``Ascend`` ``GPU`` ``CPU``
320
321    Examples:
322        >>> logits = Tensor(np.array([1, 2, 3, 4, 5]), mindspore.float32)
323        >>> log_softmax = ops.LogSoftmax()
324        >>> output = log_softmax(logits)
325        >>> print(output)
326        [-4.4519143 -3.4519143 -2.4519143 -1.4519144 -0.4519144]
327    """
328
329    @prim_attr_register
330    def __init__(self, axis=-1):
331        """Initialize LogSoftmax."""
332        validator.check_value_type("axis", axis, [int], self.name)
333
334
335class Softplus(Primitive):
336    r"""
337    Softplus activation function.
338
339    Softplus is a smooth approximation to the ReLU function.
340    It can be used to constrain the output of a machine to always be positive.
341    The function is shown as follows:
342
343    .. math::
344
345        \text{output} = \log(1 + \exp(\text{x})),
346
347    Inputs:
348        - **input_x** (Tensor) - Tensor of shape :math:`(N, *)`, where :math:`*` means, any number of
349          additional dimensions, with float16 or float32 data type.
350
351    Outputs:
352        Tensor, with the same type and shape as the `input_x`.
353
354    Raises:
355        TypeError: If `input_x` is not a Tensor.
356        TypeError: If the dtype of `input_x` is neither float16 nor float32.
357
358    Supported Platforms:
359        ``Ascend``  ``GPU`` ``CPU``
360
361    Examples:
362        >>> input_x = Tensor(np.array([1, 2, 3, 4, 5]), mindspore.float32)
363        >>> softplus = ops.Softplus()
364        >>> output = softplus(input_x)
365        >>> print(output)
366        [1.3132615 2.126928  3.0485873 4.01815   5.0067153]
367    """
368
369    @prim_attr_register
370    def __init__(self):
371        """Initialize Softplus"""
372        self.init_prim_io_names(inputs=['x'], outputs=['output'])
373
374
375class Softsign(PrimitiveWithInfer):
376    r"""
377    Softsign activation function.
378
379    The function is shown as follows:
380
381    .. math::
382
383        \text{SoftSign}(x) = \frac{x}{ 1 + |x|}
384
385    Inputs:
386        - **input_x** (Tensor) - Tensor of shape :math:`(N, *)`, where :math:`*` means, any number of
387          additional dimensions, with float16 or float32 data type.
388
389    Outputs:
390        Tensor, with the same type and shape as the `input_x`.
391
392    Raises:
393        TypeError: If `input_x` is not a Tensor.
394        TypeError: If dtype of `input_x` is neither float16 nor float32.
395
396    Supported Platforms:
397        ``Ascend``
398
399    Examples:
400        >>> input_x = Tensor(np.array([0, -1, 2, 30, -30]), mindspore.float32)
401        >>> softsign = ops.Softsign()
402        >>> output = softsign(input_x)
403        >>> print(output)
404        [ 0.        -0.5         0.6666667  0.9677419 -0.9677419]
405    """
406
407    @prim_attr_register
408    def __init__(self):
409        """Initialize Softsign"""
410        self.init_prim_io_names(inputs=['x'], outputs=['output'])
411
412    def infer_shape(self, input_x):
413        return input_x
414
415    def infer_dtype(self, input_x):
416        validator.check_tensor_dtype_valid('input_x', input_x, [mstype.float16, mstype.float32], self.name)
417        return input_x
418
419
420class ReLU(Primitive):
421    r"""
422    Computes ReLU (Rectified Linear Unit) of input tensors element-wise.
423
424    It returns :math:`\max(x,\  0)` element-wise.
425
426    Note:
427        In general, this operator is more commonly used. The difference from `ReLUV2` is that `ReLUV2` will
428        output one more Mask.
429
430    Inputs:
431        - **input_x** (Tensor) - Tensor of shape :math:`(N, *)`, where :math:`*` means, any number of
432          additional dimensions, with number data type.
433
434    Outputs:
435        Tensor, with the same type and shape as the `input_x`.
436
437    Raises:
438        TypeError: If dtype of `input_x` is not number.
439        TypeError: If `input_x` is not a Tensor.
440
441    Supported Platforms:
442        ``Ascend`` ``GPU`` ``CPU``
443
444    Examples:
445        >>> input_x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
446        >>> relu = ops.ReLU()
447        >>> output = relu(input_x)
448        >>> print(output)
449        [[0. 4. 0.]
450         [2. 0. 9.]]
451    """
452
453    @prim_attr_register
454    def __init__(self):
455        """Initialize ReLU"""
456        self.init_prim_io_names(inputs=['x'], outputs=['output'])
457
458
459class Mish(PrimitiveWithInfer):
460    r"""
461    Computes MISH (A Self Regularized Non-Monotonic Neural Activation Function) of input tensors element-wise.
462
463    The function is shown as follows:
464
465    .. math::
466
467        \text{output} = x * \tanh(\log(1 + \exp(\text{x})))
468
469    See more details in `A Self Regularized Non-Monotonic Neural Activation Function
470    <https://arxiv.org/abs/1908.08681>`_.
471
472    Inputs:
473        - **x** (Tensor) - Tensor of shape :math:`(N, *)`, where :math:`*` means, any number of
474          additional dimensions, with float16 or float32 data type.
475
476    Outputs:
477        Tensor, with the same type and shape as the `x`.
478
479    Supported Platforms:
480        ``Ascend``
481
482    Raises:
483        TypeError: If dtype of `x` is neither float16 nor float32.
484
485    Examples:
486        >>> x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
487        >>> mish = ops.Mish()
488        >>> output = mish(x)
489        >>> print(output)
490        [[-0.30273438  3.9974136 -0.015625]
491         [ 1.9439697  -0.02929688 8.999999]]
492    """
493
494    @prim_attr_register
495    def __init__(self):
496        """Initialize Mish"""
497        self.init_prim_io_names(inputs=['x'], outputs=['output'])
498
499    def infer_shape(self, x_shape):
500        return x_shape
501
502    def infer_dtype(self, x_dtype):
503        validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float16, mstype.float32], self.name)
504        return x_dtype
505
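# Numerical sanity check of the Mish formula above, x * tanh(softplus(x)) (illustrative only;
# exact float64 values differ slightly from the reduced-precision device output in the example):
#     import numpy as np
#     x = np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]])
#     x * np.tanh(np.log1p(np.exp(x)))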
506
507class SeLU(PrimitiveWithInfer):
508    r"""
509    Computes SeLU (scaled exponential Linear Unit) of input tensors element-wise.
510
511    The activation function is defined as:
512
513    .. math::
514        E_{i} =
515        scale *
516        \begin{cases}
517        x_{i}, &\text{if } x_{i} \geq 0; \cr
518        \text{alpha} * (\exp(x_i) - 1), &\text{otherwise.}
519        \end{cases}
520
521    where :math:`alpha` and :math:`scale` are pre-defined constants (:math:`alpha=1.67326324`
522    and :math:`scale=1.05070098`).
523
524    See more details in `Self-Normalizing Neural Networks <https://arxiv.org/abs/1706.02515>`_.
525
526    Inputs:
527        - **input_x** (Tensor) - Tensor of shape :math:`(N, *)`, where :math:`*` means, any number of
528          additional dimensions, with float16 or float32 data type.
529
530    Outputs:
531        Tensor, with the same type and shape as the `input_x`.
532
533    Supported Platforms:
534        ``Ascend``
535
536    Raises:
537        TypeError: If dtype of `input_x` is neither float16 nor float32.
538
539    Examples:
540        >>> input_x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
541        >>> selu = ops.SeLU()
542        >>> output = selu(input_x)
543        >>> print(output)
544        [[-1.1113307 4.202804 -1.7575096]
545         [ 2.101402 -1.7462534 9.456309 ]]
546    """
547
548    @prim_attr_register
549    def __init__(self):
550        """Initialize SeLU"""
551        self.init_prim_io_names(inputs=['x'], outputs=['output'])
552
553    def infer_shape(self, x_shape):
554        return x_shape
555
556    def infer_dtype(self, x_dtype):
557        valid_dtypes = [mstype.float16, mstype.float32]
558        validator.check_tensor_dtype_valid('x', x_dtype, valid_dtypes, self.name)
559        return x_dtype
560
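# Numpy sketch of the SeLU definition above using the stated constants (illustrative only;
# matches the docstring example up to print precision):
#     import numpy as np
#     scale, alpha = 1.05070098, 1.67326324
#     x = np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]])
#     np.where(x >= 0, scale * x, scale * alpha * (np.exp(x) - 1))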
561
562class ReLU6(PrimitiveWithCheck):
563    r"""
564    Computes ReLU (Rectified Linear Unit) upper bounded by 6 of input tensors element-wise.
565
566    .. math::
567
568        \text{ReLU6}(x) = \min(\max(0,x), 6)
569
570    It returns :math:`\min(\max(0,x), 6)` element-wise.
571
572    Inputs:
573        - **input_x** (Tensor) - Tensor of shape :math:`(N, *)`, where :math:`*` means, any number of
574          additional dimensions, with float16 or float32 data type.
575
576    Outputs:
577        Tensor, with the same type and shape as the `input_x`.
578
579    Raises:
580        TypeError: If dtype of `input_x` is neither float16 nor float32.
581        TypeError: If `input_x` is not a Tensor.
582
583    Supported Platforms:
584        ``Ascend`` ``GPU`` ``CPU``
585
586    Examples:
587        >>> input_x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
588        >>> relu6 = ops.ReLU6()
589        >>> result = relu6(input_x)
590        >>> print(result)
591        [[0. 4. 0.]
592         [2. 0. 6.]]
593    """
594
595    @prim_attr_register
596    def __init__(self):
597        """Initialize ReLU6"""
598        self.init_prim_io_names(inputs=['x'], outputs=['output'])
599
600    def check_shape(self, input_x):
601        pass
602
603    def check_dtype(self, input_x):
604        validator.check_tensor_dtype_valid('input_x', input_x, (mstype.float16, mstype.float32), self.name)
605
606
607class ReLUV2(Primitive):
608    r"""
609    Computes ReLU (Rectified Linear Unit) of input tensors element-wise.
610
611    It returns :math:`\max(x,\  0)` element-wise.
612
613    Note:
614        The difference from `ReLu` is that the operator will output one more Mask,
615        and the kernel of the operator is different from `ReLu`.
616
617    Inputs:
618        - **input_x** (Tensor) - The input tensor must be a 4-D tensor.
619
620    Outputs:
621        - **output** (Tensor) - Has the same type and shape as the `input_x`.
622        - **mask** (Tensor) - A tensor whose data type must be uint8.
623
624    Raises:
625        TypeError: If `input_x` is not a Tensor.
626        ValueError: If shape of `input_x` is not 4-D.
627
628    Supported Platforms:
629        ``Ascend``
630
631    Examples:
632        >>> input_x = Tensor(np.array([[[[1, -2], [-3, 4]], [[-5, 6], [7, -8]]]]), mindspore.float32)
633        >>> relu_v2 = ops.ReLUV2()
634        >>> output, mask= relu_v2(input_x)
635        >>> print(output)
636        [[[[1. 0.]
637           [0. 4.]]
638          [[0. 6.]
639           [7. 0.]]]]
640        >>> print(mask)
641        [[[[[1 0]
642            [2 0]]
643           [[2 0]
644            [1 0]]]]]
645    """
646
647    @prim_attr_register
648    def __init__(self):
649        """Initialize ReLUV2"""
650        self.init_prim_io_names(inputs=['x'], outputs=['output', 'mask'])
651
652
653class Elu(PrimitiveWithInfer):
654    r"""
655    Computes exponential linear:
656
657    .. math::
658
659        \text{ELU}(x)= \left\{
660        \begin{array}{align}
661            \alpha(e^{x}  - 1) & \text{if } x \le 0\\
662            x & \text{if } x \gt 0\\
663        \end{array}\right.
664
665    The data type of input tensor must be float.
666
667    Args:
668        alpha (float): The coefficient of the negative factor, whose type is float.
669            Only 1.0 is supported currently. Default: 1.0.
670
671    Inputs:
672        - **input_x** (Tensor) - Tensor of shape :math:`(N, *)`, where :math:`*` means, any number of
673          additional dimensions, with float16 or float32 data type.
674
675    Outputs:
676        Tensor, has the same shape and data type as `input_x`.
677
678    Raises:
679        TypeError: If `alpha` is not a float.
680        TypeError: If dtype of `input_x` is neither float16 nor float32.
681        ValueError: If `alpha` is not equal to 1.0.
682
683    Supported Platforms:
684        ``Ascend`` ``GPU`` ``CPU``
685
686    Examples:
687        >>> input_x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
688        >>> elu = ops.Elu()
689        >>> output = elu(input_x)
690        >>> print(output)
691        [[-0.63212055  4.         -0.99966455]
692         [ 2.         -0.99326205  9.        ]]
693    """
694
695    @prim_attr_register
696    def __init__(self, alpha=1.0):
697        """Initialize Elu"""
698        validator.check_value_type("alpha", alpha, [float], self.name)
699        validator.check_number("alpha", alpha, 1.0, Rel.EQ, self.name)
700
701    def infer_shape(self, input_x):
702        return input_x
703
704    def infer_dtype(self, input_x):
705        validator.check_tensor_dtype_valid('input_x', input_x, mstype.float_type, self.name)
706        return input_x
707
708
709class HSwish(PrimitiveWithInfer):
710    r"""
711    Hard swish activation function.
712
713    Applies hswish-type activation element-wise. The input is a Tensor with any valid shape.
714
715    Hard swish is defined as:
716
717    .. math::
718
719        \text{hswish}(x_{i}) = x_{i} * \frac{ReLU6(x_{i} + 3)}{6},
720
721    where :math:`x_i` is an element of the input Tensor.
722
723    Inputs:
724        - **input_x** (Tensor) - Tensor of shape :math:`(N, *)`, where :math:`*` means, any number of
725          additional dimensions, with float16 or float32 data type.
726
727    Outputs:
728        Tensor, with the same type and shape as the `input_x`.
729
730    Raises:
731        TypeError: If `input_x` is not a Tensor.
732        TypeError: If dtype of `input_x` is neither float16 nor float32.
733
734    Supported Platforms:
735        ``GPU`` ``CPU``
736
737    Examples:
738        >>> hswish = ops.HSwish()
739        >>> input_x = Tensor(np.array([-1, -2, 0, 2, 1]), mindspore.float16)
740        >>> result = hswish(input_x)
741        >>> print(result)
742        [-0.3333  -0.3333  0  1.666  0.6665]
743    """
744
745    @prim_attr_register
746    def __init__(self):
747        """Initialize HSwish."""
748        self.init_prim_io_names(inputs=['x'], outputs=['output'])
749
750    def infer_shape(self, xshape):
751        return xshape
752
753    def infer_dtype(self, x_dtype):
754        validator.check_tensor_dtype_valid("x", x_dtype, (mstype.float16, mstype.float32), self.name)
755        return x_dtype
756
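# Quick check of the HSwish formula above, x * relu6(x + 3) / 6 (illustrative only; reproduces the
# float16 example values up to rounding):
#     import numpy as np
#     x = np.array([-1., -2., 0., 2., 1.])
#     x * np.clip(x + 3., 0., 6.) / 6.   # ~[-0.3333 -0.3333  0.      1.6667  0.6667]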
757
758class Sigmoid(PrimitiveWithInfer):
759    r"""
760    Sigmoid activation function.
761
762    Computes Sigmoid of input element-wise. The Sigmoid function is defined as:
763
764    .. math::
765
766        \text{sigmoid}(x_i) = \frac{1}{1 + \exp(-x_i)},
767
768    where :math:`x_i` is an element of the input Tensor.
769
770    Inputs:
771        - **input_x** (Tensor) - Tensor of shape :math:`(N, *)`, where :math:`*` means, any number of
772          additional dimensions, with float16 or float32 data type.
773
774    Outputs:
775        Tensor, with the same type and shape as the input_x.
776
777    Raises:
778        TypeError: If dtype of `input_x` is neither float16 nor float32.
779        TypeError: If `input_x` is not a Tensor.
780
781    Supported Platforms:
782        ``Ascend`` ``GPU`` ``CPU``
783
784    Examples:
785        >>> input_x = Tensor(np.array([1, 2, 3, 4, 5]), mindspore.float32)
786        >>> sigmoid = ops.Sigmoid()
787        >>> output = sigmoid(input_x)
788        >>> print(output)
789        [0.7310586  0.880797   0.95257413 0.98201376 0.9933072 ]
790    """
791
792    @prim_attr_register
793    def __init__(self):
794        """Initialize Sigmoid."""
795        self.init_prim_io_names(inputs=['x'], outputs=['output'])
796
797    def infer_shape(self, input_x):
798        return input_x
799
800    def infer_dtype(self, input_x):
801        validator.check_tensor_dtype_valid("input_x", input_x, (mstype.float16, mstype.float32), self.name)
802        return input_x
803
804
805class HSigmoid(Primitive):
806    r"""
807    Hard sigmoid activation function.
808
809    Applies hard sigmoid activation element-wise. The input is a Tensor with any valid shape.
810
811    Hard sigmoid is defined as:
812
813    .. math::
814
815        \text{hsigmoid}(x_{i}) = max(0, min(1, \frac{x_{i} + 3}{6})),
816
817    where :math:`x_i` is an element of the input Tensor.
818
819    Inputs:
820        - **input_x** (Tensor) - Tensor of shape :math:`(N, *)`, where :math:`*` means, any number of
821          additional dimensions.
822
823    Outputs:
824        Tensor, with the same type and shape as the `input_x`.
825
826    Raises:
827        TypeError: If `input_x` is not a Tensor.
828
829    Supported Platforms:
830        ``Ascend`` ``GPU`` ``CPU``
831
832    Examples:
833        >>> hsigmoid = ops.HSigmoid()
834        >>> input_x = Tensor(np.array([-1, -2, 0, 2, 1]), mindspore.float16)
835        >>> result = hsigmoid(input_x)
836        >>> print(result)
837        [0.3333 0.1666 0.5    0.8335 0.6665]
838    """
839    @prim_attr_register
840    def __init__(self):
841        """Initialize HSigmoid."""
842        self.init_prim_io_names(inputs=['input_x'], outputs=['output'])
843
844
845class Tanh(PrimitiveWithInfer):
846    r"""
847    Tanh activation function.
848
849    Computes hyperbolic tangent of input element-wise. The Tanh function is defined as:
850
851    .. math::
852
853        tanh(x_i) = \frac{\exp(x_i) - \exp(-x_i)}{\exp(x_i) + \exp(-x_i)} = \frac{\exp(2x_i) - 1}{\exp(2x_i) + 1},
854
855    where :math:`x_i` is an element of the input Tensor.
856
857    Inputs:
858        - **input_x** (Tensor) - Tensor of shape :math:`(N, *)`, where :math:`*` means, any number of
859          additional dimensions, with float16 or float32 data type.
860
861    Outputs:
862        Tensor, with the same type and shape as the `input_x`.
863
864    Raises:
865        TypeError: If dtype of `input_x` is neither float16 nor float32.
866        TypeError: If `input_x` is not a Tensor.
867
868    Supported Platforms:
869        ``Ascend`` ``GPU``  ``CPU``
870
871    Examples:
872        >>> input_x = Tensor(np.array([1, 2, 3, 4, 5]), mindspore.float32)
873        >>> tanh = ops.Tanh()
874        >>> output = tanh(input_x)
875        >>> print(output)
876        [0.7615941 0.9640276 0.9950547 0.9993293 0.9999092]
877    """
878
879    @prim_attr_register
880    def __init__(self):
881        pass
882
883    def infer_shape(self, input_x):
884        return input_x
885
886    def infer_dtype(self, input_x):
887        validator.check_tensor_dtype_valid("input_x", input_x, mstype.float_type, self.name)
888        return input_x
889
890
891class FusedBatchNorm(Primitive):
892    r"""
893    The FusedBatchNorm interface is deprecated, please use the BatchNorm interface.
894    """
895
896    def __init__(self, mode=0, epsilon=1e-5, momentum=0.1):
897        raise TypeError("The FusedBatchNorm interface is deprecated, please use the BatchNorm interface.")
898
899
900class FusedBatchNormEx(PrimitiveWithCheck):
901    r"""
902    The FusedBatchNormEx interface is deprecated, please use the BatchNorm interface.
903    """
904
905    def __init__(self, mode=0, epsilon=1e-5, momentum=0.1, data_format="NCHW"):
906        raise TypeError("The FusedBatchNormEx interface is deprecated, please use the BatchNorm interface.")
907
908
909class InstanceNorm(PrimitiveWithInfer):
910    r"""
911    Instance Normalization over a 4D input.
912
913    This operator applies Instance Normalization over a 4D input (a mini-batch of 2D inputs with
914    additional channel dimension) as described in the paper `Instance Normalization: The Missing Ingredient for
915    Fast Stylization <https://arxiv.org/abs/1607.08022>`_. It rescales and recenters the feature using a mini-batch
916    of data and the learned parameters which can be described in the following formula.
917
918    .. math::
919
920        y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta
921
922    where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon.
923
924    Args:
925        epsilon (float): A small value added for numerical stability. Default: 1e-5.
926        momentum (float): The hyper parameter to compute moving average for running_mean and running_var
927            (e.g. :math:`new\_running\_mean = momentum * running\_mean + (1 - momentum) * current\_mean`).
928            Momentum value must be [0, 1]. Default: 0.1.
929
930    Inputs:
931        - **input_x** (Tensor) - The input of InstanceNorm, Tensor of shape :math:`(N, C)`,
932          data type: float16 or float32.
933        - **gamma** (Parameter) - Scale, Tensor of shape :math:`(C,)`,
934          data type: float32.
935        - **beta** (Parameter) - Bias, Tensor of shape :math:`(C,)`,
936          data type: float32.
937        - **mean** (Parameter) - Mean value, Tensor of shape :math:`(C,)`, data type: float32.
938        - **variance** (Parameter) - Variance value, Tensor of shape :math:`(C,)`, data type: float32.
939
940    Outputs:
941        Tuple of 3 Tensors, the normalized input, the updated parameters.
942
943        - **output_x** (Tensor) - The output of InstanceNorm, same type and shape as the `input_x`.
944        - **updated_moving_mean** (Tensor) - Updated mean value, Tensor of shape :math:`(NC,)`, data type: float32.
945        - **updated_moving_variance** (Tensor) - Updated variance value, Tensor of shape :math:`(NC,)`,
946          data type: float32.
947
948    Supported Platforms:
949        ``GPU``
950
951    Raises:
952        TypeError: If `epsilon` or `momentum` is not a float.
953        TypeError: If dtype of `input_x` is neither float16 nor float32.
954        TypeError: If dtype of `gamma`, `beta` or `mean` is not float32.
955        ValueError: If `epsilon` is not in the range of [0, 1).
956        ValueError: If `momentum` is not in the range of [0, 1].
957
958    Examples:
959        >>> class InstanceNormNet(nn.Cell):
960        >>>     def __init__(self):
961        >>>         super(InstanceNormNet, self).__init__()
962        >>>         self.instance_norm = ops.InstanceNorm()
963        >>>         self.gamma = Parameter(Tensor(np.ones([64]), mindspore.float32), name="gamma")
964        >>>         self.beta = Parameter(Tensor(np.ones([64]), mindspore.float32), name="beta")
965        >>>         self.mean = Parameter(Tensor(np.ones([64]), mindspore.float32), name="mean")
966        >>>         self.variance = Parameter(Tensor(np.ones([64]), mindspore.float32), name="variance")
967        >>>
968        >>>     def construct(self, input_x):
969        >>>         out = self.instance_norm(input_x, self.gamma, self.beta, self.mean, self.variance)
970        >>>         return out
971        >>>
972        >>> input_x = Tensor(np.ones([128, 64, 32, 64]), mindspore.float32)
973        >>> net = InstanceNormNet()
974        >>> output = net(input_x)
975        >>> result = output[0].shape
976        >>> print(result)
977        (128, 64, 32, 64)
978    """
979    __mindspore_signature__ = (
980        sig.make_sig('input_x', dtype=sig.sig_dtype.T2),
981        sig.make_sig('gamma', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
982        sig.make_sig('beta', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
983        sig.make_sig('mean', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
984        sig.make_sig('variance', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
985    )
986
987    @prim_attr_register
988    def __init__(self, epsilon=1e-5, momentum=0.1):
989        """Initialize InstanceNorm."""
990        self.init_prim_io_names(inputs=['x', 'gamma', 'beta', 'mean', 'variance'],
991                                outputs=['y', 'save_mean', 'save_variance'])
992        self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
993        self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
994        self._update_parameter = True
995
996    def infer_shape(self, input_x, gamma, beta, mean, variance):
997        input_shape_norm = input_x
998        validator.check_equal_int(len(gamma), 1, "gamma rank", self.name)
999        validator.check("gamma shape", gamma, "beta shape", beta, Rel.EQ, self.name)
1000        validator.check("gamma shape[0]", gamma[0], "input channel", input_shape_norm[1], Rel.EQ, self.name)
1001        validator.check_equal_int(len(mean), 1, "mean rank", self.name)
1002
1003        validator.check("mean shape", mean, "variance shape", variance, Rel.EQ, self.name)
1004        validator.check("mean shape", mean, "gamma shape", gamma, Rel.EQ, self.name)
1005        save_mean_shape = gamma
1006        save_mean_shape[0] = save_mean_shape[0] * input_shape_norm[0]
1007        return input_x, save_mean_shape, save_mean_shape
1008
1009    def infer_dtype(self, input_x, gamma, beta, mean, variance):
1010        validator.check_tensor_dtype_valid("input_x", input_x, [mstype.float16, mstype.float32], self.name)
1011        args = {"gamma": gamma, "beta": beta}
1012        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float32], self.name)
1013        args_moving = {"mean": mean, "variance": variance}
1014        valid_dtypes = [mstype.tensor_type(mstype.float32)]
1015        validator.check_types_same_and_valid(args_moving, valid_dtypes, self.name)
1016        return input_x, gamma, gamma
1017
1018
1019class BNTrainingReduce(PrimitiveWithInfer):
1020    """
1021    For the BatchNorm operation, this operator computes the per-channel sum and square sum of the input, which are
1022    consumed by BNTrainingUpdate to update the moving averages during training.
1023
1024    Inputs:
1025        - **x** (Tensor) - A 4-D Tensor with float16 or float32 data type. Tensor of shape :math:`(N, C, A, B)`.
1026
1027    Outputs:
1028        - **sum** (Tensor) - A 1-D Tensor with float32 data type. Tensor of shape :math:`(C,)`.
1029        - **square_sum** (Tensor) - A 1-D Tensor with float16 or float32 data type. Tensor of shape :math:`(C,)`.
1030
1031    Raises:
1032        TypeError: If `x` is not a Tensor.
1033        TypeError: If dtype of `x` is neither float16 nor float32.
1034
1035    Supported Platforms:
1036        ``Ascend``
1037
1038    Examples:
1039        >>> x = Tensor(np.ones([128, 3, 32, 3]), mindspore.float32)
1040        >>> bn_training_reduce = ops.BNTrainingReduce()
1041        >>> output = bn_training_reduce(x)
1042        >>> print(output)
1043        (Tensor(shape=[3], dtype=Float32, value=
1044        [ 1.22880000e+04, 1.22880000e+04, 1.22880000e+04]), Tensor(shape=[3], dtype=Float32, value=
1045        [ 1.22880000e+04, 1.22880000e+04, 1.22880000e+04]))
1046    """
1047
1048    @prim_attr_register
1049    def __init__(self):
1050        """Initialize BNTrainingReduce."""
1051        self.init_prim_io_names(inputs=['x'], outputs=['sum', 'square_sum'])
1052
1053    def infer_shape(self, x_shape):
1054        validator.check_equal_int(len(x_shape), 4, "x rank", self.name)
1055        return [x_shape[1]], [x_shape[1]]
1056
1057    def infer_dtype(self, x_type):
1058        validator.check_tensor_dtype_valid("x", x_type, [mstype.float16, mstype.float32], self.name)
1059        return x_type, x_type
1060
1061
1062class BNTrainingUpdate(PrimitiveWithInfer):
1063    """
1064    For the BatchNorm operation, this operator updates the moving averages for training and is used in conjunction with
1065    BNTrainingReduce. A moving average is a method of analyzing data points by creating a series of averages
1066    of different subsets of the entire data set.
1067
1068    .. warning::
1069        For Ascend 310, the result accuracy fails to reach 1‰ due to the square root instruction.
1070
1071    Args:
1072        isRef (bool): Whether to enable ref mode, i.e. whether the outputs reuse the input addresses. Default: True.
1073        epsilon (float): A small value added to variance to avoid dividing by zero. Default: 1e-5.
1074        factor (float): A weight for updating the mean and variance. Default: 0.1.
1075
1076    Inputs:
1077        - **input_x** (Tensor) - A 4-D Tensor with float16 or float32 data type. Tensor of shape :math:`(N, C, A, B)`.
1078        - **sum** (Tensor) - A 1-D Tensor with float16 or float32 data type for the output of operator BNTrainingReduce.
1079          Tensor of shape :math:`(C,)`.
1080        - **square_sum** (Tensor) - A 1-D Tensor with float16 or float32 data type for the output of operator
1081          BNTrainingReduce. Tensor of shape :math:`(C,)`.
1082        - **scale** (Tensor) - A 1-D Tensor with float16 or float32, for the scaling factor.
1083          Tensor of shape :math:`(C,)`.
1084        - **offset** (Tensor) - A 1-D Tensor with float16 or float32, for the scaling offset.
1085          Tensor of shape :math:`(C,)`.
1086        - **mean** (Tensor) - A 1-D Tensor with float16 or float32, for the scaling mean. Tensor of shape :math:`(C,)`.
1087        - **variance** (Tensor) - A 1-D Tensor with float16 or float32, for the update variance.
1088          Tensor of shape :math:`(C,)`.
1089
1090    Outputs:
1091        - **y** (Tensor) - Tensor, has the same shape and data type as `input_x`.
1092        - **mean** (Tensor) - Tensor for the updated mean, with float32 data type.
1093          Has the same shape as `variance`.
1094        - **variance** (Tensor) - Tensor for the updated variance, with float32 data type.
1095          Has the same shape as `variance`.
1096        - **batch_mean** (Tensor) - Tensor for the mean of `input_x`, with float32 data type.
1097          Has the same shape as `variance`.
1098        - **batch_variance** (Tensor) - Tensor for the variance of `input_x`, with float32 data type.
1099          Has the same shape as `variance`.
1100
1101    Raises:
1102        TypeError: If `isRef` is not a bool.
1103        TypeError: If dtype of `epsilon` or `factor` is not float.
1104        TypeError: If `input_x`, `sum`, `square_sum`, `scale`, `offset`, `mean` or `variance` is not a Tensor.
1105        TypeError: If dtype of `input_x`, `sum`, `square_sum`, `scale`, `offset`, `mean` or `variance` is neither
1106                   float16 nor float32.
1107
1108    Supported Platforms:
1109        ``Ascend``
1110
1111    Examples:
1112        >>> input_x = Tensor(np.ones([1, 2, 2, 2]), mindspore.float32)
1113        >>> sum_val = Tensor(np.ones([2]), mindspore.float32)
1114        >>> square_sum = Tensor(np.ones([2]), mindspore.float32)
1115        >>> scale = Tensor(np.ones([2]), mindspore.float32)
1116        >>> offset = Tensor(np.ones([2]), mindspore.float32)
1117        >>> mean = Tensor(np.ones([2]), mindspore.float32)
1118        >>> variance = Tensor(np.ones([2]), mindspore.float32)
1119        >>> bn_training_update = ops.BNTrainingUpdate()
1120        >>> output = bn_training_update(input_x, sum_val, square_sum, scale, offset, mean, variance)
1121        >>> print(output)
1122        (Tensor(shape=[1, 2, 2, 2], dtype=Float32, value=
1123        [[[[ 2.73200464e+00,  2.73200464e+00],
1124           [ 2.73200464e+00,  2.73200464e+00]],
1125          [[ 2.73200464e+00,  2.73200464e+00],
1126           [ 2.73200464e+00,  2.73200464e+00]]]]), Tensor(shape=[2], dtype=Float32, value= [9.24999952e-01,
1127        9.24999952e-01]), Tensor(shape=[2], dtype=Float32, value= [ 9.24999952e-01, 9.24999952e-01]),
1128        Tensor(shape=[2], dtype=Float32, value= [ 2.50000000e-01, 2.50000000e-01]), Tensor(shape=[2], dtype=Float32,
1129        value= [ 1.87500000e-01, 1.87500000e-01]))
1130    """
1131
1132    @prim_attr_register
1133    def __init__(self, isRef=True, epsilon=1e-5, factor=0.1):
1134        """Initialize BNTrainingUpdate."""
1135        self.init_prim_io_names(inputs=['x', 'sum', 'square_sum', 'scale', 'b', 'mean', 'variance'],
1136                                outputs=['y', 'running_mean', 'running_variance', 'save_mean', 'save_inv_variance'])
1137        validator.check_value_type("isRef", isRef, [bool], self.name)
1138        validator.check_value_type("epsilon", epsilon, [float], self.name)
1139        validator.check_value_type("factor", factor, [float], self.name)
1140        self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', 'BNTrainingUpdate')
1141        self.factor = validator.check_float_range(factor, 0, 1, Rel.INC_BOTH, 'factor', 'BNTrainingUpdate')
1142
1143    def infer_shape(self, x, sum, square_sum, scale, b, mean, variance):
1144        validator.check_equal_int(len(x), 4, "x rank", self.name)
1145        validator.check_equal_int(len(sum), 1, "sum rank", self.name)
1146        validator.check_equal_int(len(square_sum), 1, "square_sum rank", self.name)
1147        validator.check_equal_int(len(scale), 1, "scale rank", self.name)
1148        validator.check_equal_int(len(b), 1, "b rank", self.name)
1149        validator.check_equal_int(len(mean), 1, "mean rank", self.name)
1150        validator.check_equal_int(len(variance), 1, "variance rank", self.name)
1151        validator.check("sum shape", sum[0], "x_shape[1]", x[1], Rel.EQ, self.name)
1152        validator.check("square_sum shape", square_sum, "sum", sum, Rel.EQ, self.name)
1153        validator.check("scale shape", scale[0], "x_shape[1]", x[1], Rel.EQ, self.name)
1154        validator.check("offset shape", b[0], "x_shape[1]", x[1], Rel.EQ, self.name)
1155        validator.check("mean shape", mean[0], "x_shape[1]", x[1], Rel.EQ, self.name)
1156        validator.check("variance shape", variance[0], "x_shape[1]", x[1], Rel.EQ, self.name)
1157        return x, variance, variance, variance, variance
1158
1159    def infer_dtype(self, x, sum, square_sum, scale, b, mean, variance):
1160        tuple(map(partial(validator.check_tensor_dtype_valid,
1161                          valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
1162                  ("x", "sum", "square_sum", "scale", "b", "mean", "variance"),
1163                  (x, sum, square_sum, scale, b, mean, variance)))
1164        return x, variance, variance, variance, variance
1165
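# Rough sketch of the update rule, reverse-engineered from the docstring example above (an illustration
# of how the example numbers arise, not the actual device kernel; count = N * H * W = 4, factor = 0.1,
# scale = offset = mean = variance = 1):
#     batch_mean     = sum / count                                   # 1 / 4 = 0.25
#     batch_variance = square_sum / count - batch_mean ** 2          # 0.25 - 0.0625 = 0.1875
#     y = (x - batch_mean) / np.sqrt(batch_variance + eps) * scale + offset               # ~2.732
#     mean     = (1 - factor) * mean + factor * batch_mean                                 # 0.925
#     variance = (1 - factor) * variance + factor * batch_variance * count / (count - 1)   # 0.925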
1166
1167class BatchNorm(PrimitiveWithInfer):
1168    r"""
1169    Batch Normalization for input data and updated parameters.
1170
1171    Batch Normalization is widely used in convolutional neural networks. This operation
1172    applies Batch Normalization over inputs to avoid internal covariate shift as described
1173    in the paper `Batch Normalization: Accelerating Deep Network Training by Reducing Internal
1174    Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It rescales and recenters the
1175    features using a mini-batch of data and the learned parameters can be described
1176    in the following formula,
1177
1178    .. math::
1179
1180        y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta
1181
1182    where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon, :math:`mean` is the mean of x,
1183    :math:`variance` is the variance of x.
1184
1185    .. warning::
1186        - If the operation is used for inference, and outputs "reserve_space_1" and "reserve_space_2" are available,
1187          then "reserve_space_1" has the same value as "mean" and "reserve_space_2" has the same value as "variance".
1188        - For Ascend 310, the result accuracy fails to reach 1‰ due to the square root instruction.
1189
1190    Args:
1191        is_training (bool): If `is_training` is True, `mean` and `variance` are computed during training.
1192            If `is_training` is False, they're loaded from checkpoint during inference. Default: False.
1193        epsilon (float): A small value added for numerical stability. Default: 1e-5.
1194        momentum (float): The hyper parameter to compute moving average for running_mean and running_var
1195            (e.g. :math:`new\_running\_mean = (1 - momentum) * running\_mean + momentum * current\_mean`).
1196            Momentum value must be [0, 1]. Default: 0.1.
1197        data_format (str): The optional value for data format, is 'NHWC' or 'NCHW'.
1198            Default: "NCHW".
1199
1200    Inputs:
1201        If `is_training` is False, inputs are Tensors.
1202
1203        - **input_x** (Tensor) - Tensor of shape :math:`(N, C)`, with float16 or float32 data type.
1204        - **scale** (Tensor) - Tensor of shape :math:`(C,)`, with float16 or float32 data type.
1205        - **bias** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type with `scale`.
1206        - **mean** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type with `scale`.
1207        - **variance** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type with `scale`.
1208
1209        If `is_training` is True, `scale`, `bias`, `mean` and `variance` are Parameters.
1210
1211        - **input_x** (Tensor) - Tensor of shape :math:`(N, C)`, with float16 or float32 data type.
1212        - **scale** (Parameter) - Parameter of shape :math:`(C,)`, with float16 or float32 data type.
1213        - **bias** (Parameter) - Parameter of shape :math:`(C,)`, has the same data type with `scale`.
1214        - **mean** (Parameter) - Parameter of shape :math:`(C,)`, has the same data type with `scale`.
1215        - **variance** (Parameter) - Parameter of shape :math:`(C,)`, has the same data type with `scale`.
1216
1217    Outputs:
1218        Tuple of 5 Tensors, the normalized inputs and the updated parameters.
1219
1220        - **output_x** (Tensor) - The same type and shape as the input_x. The shape is :math:`(N, C)`.
1221        - **updated_scale** (Tensor) - Tensor of shape :math:`(C,)`.
1222        - **updated_bias** (Tensor) - Tensor of shape :math:`(C,)`.
1223        - **reserve_space_1** (Tensor) - Tensor of shape :math:`(C,)`.
1224        - **reserve_space_2** (Tensor) - Tensor of shape :math:`(C,)`.
1225
1226    Raises:
1227        TypeError: If `is_training` is not a bool.
1228        TypeError: If dtype of `epsilon` or `momentum` is not float.
1229        TypeError: If `data_format` is not a str.
1230        TypeError: If `input_x`, `scale`, `bias`, `mean` or `variance` is not a Tensor.
1231        TypeError: If dtype of `input_x`, `scale` is neither float16 nor float32.
1232
1233    Supported Platforms:
1234        ``Ascend`` ``CPU`` ``GPU``
1235
1236    Examples:
1237        >>> input_x = Tensor(np.ones([2, 2]), mindspore.float32)
1238        >>> scale = Tensor(np.ones([2]), mindspore.float32)
1239        >>> bias = Tensor(np.ones([2]), mindspore.float32)
1240        >>> mean = Tensor(np.ones([2]), mindspore.float32)
1241        >>> variance = Tensor(np.ones([2]), mindspore.float32)
1242        >>> batch_norm = ops.BatchNorm()
1243        >>> output = batch_norm(input_x, scale, bias, mean, variance)
1244        >>> print(output[0])
1245        [[1. 1.]
1246         [1. 1.]]
1247    """
1248
1249    __mindspore_signature__ = (
1250        sig.make_sig('input_x', dtype=sig.sig_dtype.T1),
1251        sig.make_sig('scale', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T2),
1252        sig.make_sig('bias', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T2),
1253        sig.make_sig('mean', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T3),
1254        sig.make_sig('variance', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T3)
1255    )
1256
1257    @prim_attr_register
1258    def __init__(self, is_training=False, epsilon=1e-5, momentum=0.1, data_format="NCHW"):
1259        """Initialize BatchNorm."""
1260        if is_training is False:
1261            self.set_signatures(tuple())
1262        validator.check_value_type('is_training', is_training, (bool,), self.name)
1263        validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
1264        validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
1265        self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
1266        if context.get_context("device_target") != "GPU" and self.format == "NHWC":
1267            raise ValueError(f"For '{self.name}', the \"NHWC\" format only support in GPU target, "
1268                             f"but got the format is {self.format} and "
1269                             f"the platform is {context.get_context('device_target')}.")
1270        self.add_prim_attr('data_format', self.format)
1271        self.init_prim_io_names(inputs=['x', 'scale', 'offset', 'mean', 'variance'],
1272                                outputs=['y', 'batch_mean', 'batch_variance', 'reserve_space_1', 'reserve_space_2'])
1273
1274    def infer_shape(self, input_x, scale, bias, mean, variance):
1275        input_x_channel = input_x[-1] if self.format == "NHWC" else input_x[1]
1276        validator.check_equal_int(len(scale), 1, "scale rank", self.name)
1277        validator.check("scale shape", scale, "bias shape", bias, Rel.EQ, self.name)
1278        validator.check("scale shape[0]", scale[0], "input_x channel", input_x_channel, Rel.EQ, self.name)
1279        if not self.is_training:
1280            validator.check_equal_int(len(mean), 1, "mean rank", self.name)
1281            validator.check("mean shape", mean, "variance shape", variance, Rel.EQ, self.name)
1282            validator.check("mean shape", mean, "scale shape", scale, Rel.EQ, self.name)
1283        return input_x, scale, scale, scale, scale
1284
1285    def infer_dtype(self, input_x, scale, bias, mean, variance):
1286        validator.check_tensor_dtype_valid("input_x", input_x, [mstype.float16, mstype.float32], self.name)
1287        args = {"scale": scale, "bias": bias, "mean": mean, "variance": variance}
1288        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
1289        return input_x, mstype.float32, mstype.float32, mstype.float32, mstype.float32
1290
1291
1292class Conv2D(Primitive):
1293    r"""
1294    2D convolution layer.
1295
1296    Applies a 2D convolution over an input tensor which is typically of shape :math:`(N, C_{in}, H_{in}, W_{in})`,
1297    where :math:`N` is batch size, :math:`C` is channel number, :math:`H` is height, :math:`W` is width, :math:`X_i` is
1298    the :math:`i^{th}` input value and :math:`b_i` indicates the deviation value of the :math:`i^{th}` input value.
1299    For each batch of shape :math:`(C_{in}, H_{in}, W_{in})`, the formula is defined as:
1300
1301    .. math::
1302
1303        out_j = \sum_{i=0}^{C_{in} - 1} ccor(W_{ij}, X_i) + b_j,
1304
1305    where :math:`ccor` is the cross correlation operator, :math:`C_{in}` is the input channel number, :math:`j` ranges
1306    from :math:`0` to :math:`C_{out} - 1`, :math:`W_{ij}` corresponds to the :math:`i`-th channel of the :math:`j`-th
1307    filter and :math:`out_{j}` corresponds to the :math:`j`-th channel of the output. :math:`W_{ij}` is a slice
1308    of kernel and it has shape :math:`(\text{kernel_size[0]}, \text{kernel_size[1]})`,
1309    where :math:`\text{kernel_size[0]}` and :math:`\text{kernel_size[1]}` are the height and width of the
1310    convolution kernel. The full kernel has shape
1311    :math:`(C_{out}, C_{in} // \text{group}, \text{kernel_size[0]}, \text{kernel_size[1]})`,
1312    where group is the group number to split the input in the channel dimension.
1313
1314    If the 'pad_mode' is set to be "valid", the output height and width will be
1315    :math:`\left \lfloor{1 + \frac{H_{in} + \text{padding[0]} + \text{padding[1]} - \text{kernel_size[0]} -
1316    (\text{kernel_size[0]} - 1) \times (\text{dilation[0]} - 1) }{\text{stride[0]}}} \right \rfloor` and
1317    :math:`\left \lfloor{1 + \frac{W_{in} + \text{padding[2]} + \text{padding[3]} - \text{kernel_size[1]} -
1318    (\text{kernel_size[1]} - 1) \times (\text{dilation[1]} - 1) }{\text{stride[1]}}} \right \rfloor` respectively.
1319    where :math:`dilation` is the spacing between kernel elements, :math:`stride` is the step length of each step,
1320    and :math:`padding` is the zero-padding added to both sides of the input.
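    As a concrete illustration of the formula above (an arbitrary example, not part of the original description):
    with :math:`H_{in} = 32`, :math:`\text{kernel_size[0]} = 3`, :math:`\text{stride[0]} = 1`,
    :math:`\text{dilation[0]} = 1` and no padding, the "valid" output height is
    :math:`\left \lfloor{1 + \frac{32 - 3}{1}} \right \rfloor = 30`.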
1321
1322
1323    The first introduction can be found in paper `Gradient Based Learning Applied to Document Recognition
1324    <http://vision.stanford.edu/cs598_spring07/papers/Lecun98.pdf>`_. More detailed introduction can be found here:
1325    http://cs231n.github.io/convolutional-networks/.
1326
1327    Args:
1328        out_channel (int): The number of output channel :math:`C_{out}`.
1329        kernel_size (Union[int, tuple[int]]): The data type is int or a tuple of 2 integers. Specifies the height
1330            and width of the 2D convolution window. Single int means the value is for both the height and the width of
1331            the kernel. A tuple of 2 ints means the first value is for the height and the other is for the
1332            width of the kernel.
1333        mode (int): Modes for different convolutions. 0: math convolution, 1: cross-correlation convolution,
1334                       2: deconvolution, 3: depthwise convolution. Default: 1.
1335        pad_mode (str): Specifies padding mode. The optional values are
1336            "same", "valid", "pad". Default: "valid".
1337
1338            - same: Adopts the way of completion. The height and width of the output will be the same as
1339              the input `x`. The total number of padding will be calculated in horizontal and vertical
1340              directions and evenly distributed to top and bottom, left and right if possible. Otherwise, the
1341              last extra padding will be done from the bottom and the right side. If this mode is set, `pad`
1342              must be 0.
1343
1344            - valid: Adopts the way of discarding. The possible largest height and width of output will be returned
1345              without padding. Extra pixels will be discarded. If this mode is set, `pad` must be 0.
1346
1347            - pad: Implicit paddings on both sides of the input `x`. The number of `pad` will be padded to the input
1348              Tensor borders. `pad` must be greater than or equal to 0.
1349        pad (Union(int, tuple[int])): Implicit paddings on both sides of the input `x`. If `pad` is one integer,
1350                    the paddings of top, bottom, left and right are the same, equal to pad. If `pad` is a tuple
1351                    with four integers, the paddings of top, bottom, left and right will be equal to pad[0],
1352                    pad[1], pad[2], and pad[3] accordingly. Default: 0.
1353        stride (Union(int, tuple[int])): The distance of kernel moving, an int number that represents
1354            the height and width of movement are both strides, or a tuple of two int numbers that
1355            represent height and width of movement respectively. Default: 1.
1356        dilation (Union(int, tuple[int])): The data type is int or a tuple of 2 integers. Specifies the dilation rate
1357                                      to use for dilated convolution. If set to be :math:`k > 1`, there will
1358                                      be :math:`k - 1` pixels skipped for each sampling location. Its value must
1359                                      be greater or equal to 1 and bounded by the height and width of the
1360                                      input `x`. Default: 1.
1361        group (int): Splits input into groups. Default: 1.
1362        data_format (str): The optional value for data format, is 'NHWC' or 'NCHW'. Default: "NCHW".
1363
1364    Inputs:
1365        - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
1366        - **weight** (Tensor) - Set size of kernel is :math:`(\text{kernel_size[0]}, \text{kernel_size[1]})`,
1367          then the shape is :math:`(C_{out}, C_{in}, \text{kernel_size[0]}, \text{kernel_size[1]})`.
1368
1369    Outputs:
1370        Tensor, the value that applied 2D convolution. The shape is :math:`(N, C_{out}, H_{out}, W_{out})`.
1371
1372    Raises:
1373        TypeError: If `kernel_size`, `stride`, `pad` or `dilation` is neither an int nor a tuple.
1374        TypeError: If `out_channel` or `group` is not an int.
1375        ValueError: If `kernel_size`, `stride` or `dilation` is less than 1.
1376        ValueError: If `pad_mode` is not one of 'same', 'valid', 'pad'.
1377        ValueError: If `pad` is a tuple whose length is not equal to 4.
1378        ValueError: If `pad_mode` is not equal to 'pad' and `pad` is not equal to (0, 0, 0, 0).
1379        ValueError: If `data_format` is neither 'NCHW' nor 'NHWC'.
1380
1381    Supported Platforms:
1382        ``Ascend`` ``GPU`` ``CPU``
1383
1384    Examples:
1385        >>> x = Tensor(np.ones([10, 32, 32, 32]), mindspore.float32)
1386        >>> weight = Tensor(np.ones([32, 32, 3, 3]), mindspore.float32)
1387        >>> conv2d = ops.Conv2D(out_channel=32, kernel_size=3)
1388        >>> output = conv2d(x, weight)
1389        >>> print(output.shape)
1390        (10, 32, 30, 30)
1391    """
1392
1393    @prim_attr_register
1394    def __init__(self,
1395                 out_channel,
1396                 kernel_size,
1397                 mode=1,
1398                 pad_mode="valid",
1399                 pad=0,
1400                 stride=1,
1401                 dilation=1,
1402                 group=1,
1403                 data_format="NCHW"):
1404        """Initialize Conv2D"""
1405        self.init_prim_io_names(inputs=['x', 'w'], outputs=['output'])
1406        self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name)
1407        self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=True)
1408        self.add_prim_attr('stride', self.stride)
1409        self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True)
1410        self.add_prim_attr('dilation', self.dilation)
1411        validator.check_value_type('pad', pad, (int, tuple), self.name)
1412        if isinstance(pad, int):
1413            pad = (pad,) * 4
1414        else:
1415            validator.check_equal_int(len(pad), 4, 'pad size', self.name)
1416        self.add_prim_attr("pad", pad)
1417        self.padding = pad
1418        self.pad_mode = validator.check_string(pad_mode, ['valid', 'same', 'pad'], 'pad_mode', self.name)
1419
1420        if pad_mode != 'pad' and pad != (0, 0, 0, 0):
1421            raise ValueError(f"For '{self.name}', the 'pad' must be zero when 'pad_mode' is not \"pad\", "
1422                             f"but got 'pad' is {pad} and 'pad_mode' is {pad_mode}.")
1423        if self.pad_mode == 'pad':
1424            for item in pad:
1425                validator.check_non_negative_int(item, 'pad item', self.name)
1426
1427        self.mode = validator.check_equal_int(mode, 1, 'mode', self.name)
1428        self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
1429        if context.get_context("device_target") != "GPU" and self.format == "NHWC":
1430            raise ValueError(f"For '{self.name}', the \"NHWC\" format is only supported on the GPU target, "
1431                             f"but got the format is {self.format} "
1432                             f"and platform is {context.get_context('device_target')}.")
1433        self.add_prim_attr('data_format', self.format)
1434        self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name)
1435        self.group = validator.check_positive_int(group, 'group', self.name)
1436        self.add_prim_attr('groups', self.group)
1437
1438
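# Editorial note: the helper below is an illustrative sketch only and is not used by the
# operators in this file. Assuming the output-size formula quoted in the Conv2D docstring
# above, it shows how the spatial output length along one dimension is derived for the
# 'valid' and 'pad' modes. It relies on the module-level `math` import.
def _conv2d_out_size_sketch(size, kernel, stride, dilation, pad_before=0, pad_after=0):
    """Sketch: output length of Conv2D along one spatial dimension."""
    # Effective kernel extent once dilation is taken into account.
    effective_kernel = kernel + (kernel - 1) * (dilation - 1)
    # floor(1 + (size + pads - effective_kernel) / stride), matching the docstring formula.
    return math.floor(1 + (size + pad_before + pad_after - effective_kernel) / stride)
# For the Conv2D example above: _conv2d_out_size_sketch(32, 3, 1, 1) == 30.

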
1439class DepthwiseConv2dNative(PrimitiveWithInfer):
1440    r"""
1441    Returns the depth-wise convolution value for the input.
1442
1443    Applies depthwise conv2d for the input, which will generate more channels with channel_multiplier.
1444    Given an input tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})` where :math:`N` is the batch size,
1445    :math:`C` is the channels, :math:`H` is height, :math:`W` is width and a filter tensor with kernel size
1446    :math:`(\text{kernel_size[0]}, \text{kernel_size[1]})`, where :math:`\text{kernel_size[0]}` indicates the
1447    kernel_size of height, :math:`\text{kernel_size[1]}` indicates the kernel_size of width, containing
1448    :math:`C_{in} * \text{channel_multiplier}` convolutional filters of depth 1;
1449    it applies a different set of filters to each input channel (`channel_multiplier` filters
1450    per input channel, 1 by default), then concatenates the results together. The output has
1451    :math:`C_{in} * \text{channel_multiplier}` channels.
1452
1453    Args:
1454        channel_multiplier (int): The multiplier for the original output convolution. Its value must be greater than 0.
1455        kernel_size (Union[int, tuple[int]]): The data type is int or a tuple of 2 integers. Specifies the height
1456            and width of the 2D convolution window. Single int means the value is for both the height and the width of
1457            the kernel. A tuple of 2 ints means the first value is for the height and the other is for the
1458            width of the kernel.
1459        mode (int): Modes for different convolutions. 0: math convolution, 1: cross-correlation convolution,
1460                       2: deconvolution, 3: depthwise convolution. Default: 3.
1461        pad_mode (str): Specifies padding mode. The optional values are
1462            "same", "valid", "pad". Default: "valid".
1463
1464            - same: Adopts the way of completion. The height and width of the output will be the same as
1465              the input `x`. The total number of padding will be calculated in horizontal and vertical
1466              directions and evenly distributed to top and bottom, left and right if possible. Otherwise, the
1467              last extra padding will be done from the bottom and the right side. If this mode is set, `pad`
1468              must be 0.
1469
1470            - valid: Adopts the way of discarding. The possible largest height and width of output will be returned
1471              without padding. Extra pixels will be discarded. If this mode is set, `pad` must be 0.
1472
1473            - pad: Implicit paddings on both sides of the input `x`. The number of `pad` will be padded to the input
1474              Tensor borders. `pad` must be greater than or equal to 0.
1475        pad (Union[int, tuple[int]]): Implicit paddings on both sides of the input `x`. If `pad` is one integer,
1476                    the paddings of top, bottom, left and right are the same, equal to pad. If `pad` is a tuple
1477                    with four integers, the paddings of top, bottom, left and right will be equal to pad[0],
1478                    pad[1], pad[2], and pad[3] accordingly. Default: 0.
1479        stride (Union(int, tuple[int])): The distance of kernel moving, an int number that represents
1480            the height and width of movement are both strides, or a tuple of two int numbers that
1481            represent height and width of movement respectively. Default: 1.
1482        dilation (Union(int, tuple[int])): The data type is int or a tuple of 2 integers. Specifies the dilation rate
1483                                      to use for dilated convolution. If set to be :math:`k > 1`, there will
1484                                      be :math:`k - 1` pixels skipped for each sampling location. Its value must
1485                                      be greater or equal to 1 and bounded by the height and width of the
1486                                      input `x`. Default: 1.
1487        group (int): Splits input into groups. Default: 1.
1488
1489    Inputs:
1490        - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
1491        - **weight** (Tensor) - Set the size of kernel as :math:`(\text{kernel_size[0]}, \text{kernel_size[1]})`,
1492          then the shape is :math:`(K, C_{in}, \text{kernel_size[0]}, \text{kernel_size[1]})`, `K` must be 1.
1493
1494    Outputs:
1495        Tensor of shape :math:`(N, C_{in} * \text{channel_multiplier}, H_{out}, W_{out})`.
1496
1497    Raises:
1498        TypeError: If `kernel_size`, `stride`, `pad` or `dilation` is neither an int nor a tuple.
1499        TypeError: If `channel_multiplier` or `group` is not an int.
1500        ValueError: If `stride` or `dilation` is less than 1.
1501        ValueError: If `pad_mode` is not one of the following: 'same', 'valid' or 'pad'.
1502        ValueError: If `pad_mode` is not equal to 'pad' and `pad` is not equal to (0, 0, 0, 0).
1503
1504    Supported Platforms:
1505        ``Ascend``
1506
1507    Examples:
1508        >>> x = Tensor(np.ones([10, 32, 32, 32]), mindspore.float32)
1509        >>> weight = Tensor(np.ones([1, 32, 3, 3]), mindspore.float32)
1510        >>> depthwise_conv2d = ops.DepthwiseConv2dNative(channel_multiplier=3, kernel_size=(3, 3))
1511        >>> output = depthwise_conv2d(x, weight)
1512        >>> print(output.shape)
1513        (10, 96, 30, 30)
1514    """
1515
1516    @prim_attr_register
1517    def __init__(self,
1518                 channel_multiplier,
1519                 kernel_size,
1520                 mode=3,
1521                 pad_mode="valid",
1522                 pad=0,
1523                 stride=1,
1524                 dilation=1,
1525                 group=1):
1526        """Initialize DepthwiseConv2dNative"""
1527        logger.warning("WARN_DEPRECATED: The usage of DepthwiseConv2dNative is deprecated."
1528                       " Please use nn.Conv2D.")
1529        self.init_prim_io_names(inputs=['x', 'w'], outputs=['output'])
1530        self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name)
1531        self.stride = _check_positive_int_or_tuple('stride', stride, self.name)
1532        if self.stride[0] != self.stride[1]:
1533            raise ValueError("The height and width of stride should be equal, "
1534                             f"but got height: {self.stride[0]}, width: {self.stride[1]}")
1535        self.add_prim_attr('stride', (1, 1, self.stride[0], self.stride[1]))
1536
1537        self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name)
1538        if self.dilation[0] != self.dilation[1]:
1539            raise ValueError("The height and width of dilation should be equal, "
1540                             f"but got height: {self.dilation[0]}, width: {self.dilation[1]}")
1541        self.add_prim_attr('dilation', (1, 1, self.dilation[0], self.dilation[1]))
1542        validator.check_value_type('pad', pad, (int, tuple), self.name)
1543        if isinstance(pad, int):
1544            pad = (pad,) * 4
1545        else:
1546            validator.check_equal_int(len(pad), 4, 'pad size', self.name)
1547        self.add_prim_attr("pad", pad)
1548        self.padding = pad
1549        self.pad_mode = validator.check_string(pad_mode, ['valid', 'same', 'pad'], 'pad_mode', self.name)
1550        if pad_mode != 'pad' and pad != (0, 0, 0, 0):
1551            raise ValueError(f"For '{self.name}', the 'pad' must be zero when 'pad_mode' is not \"pad\", "
1552                             f"but got 'pad' is {pad} and 'pad_mode' is {pad_mode}")
1553        if self.pad_mode == 'pad':
1554            for item in pad:
1555                validator.check_non_negative_int(item, 'pad item', self.name)
1556        self.mode = validator.check_equal_int(mode, 3, "mode", self.name)
1557        self.add_prim_attr('data_format', "NCHW")
1558        self.channel_multiplier = validator.check_positive_int(channel_multiplier, "channel_multiplier", self.name)
1559        self.group = validator.check_positive_int(group, "group", self.name)
1560        self.add_prim_attr('offset_a', 0)
1561
1562    def infer_shape(self, x_shape, w_shape, b_shape=None):
1563        validator.check_equal_int(len(w_shape), 4, "weight rank", self.name)
1564        validator.check_equal_int(len(x_shape), 4, "x rank", self.name)
1565        validator.check("x_shape[1]", x_shape[1], "w_shape[1]", w_shape[1], Rel.EQ, self.name)
1566        validator.check('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape[2:4]), Rel.EQ, self.name)
1567
1568        kernel_size_n, _, kernel_size_h, kernel_size_w = w_shape
1569        _, _, stride_h, stride_w = self.stride
1570        _, _, dilation_h, dilation_w = self.dilation
1571        if kernel_size_n != 1:
1572            raise ValueError(f"For '{self.name}', the batch of 'weight' should be 1, but got {kernel_size_n}")
1573        if self.pad_mode == "valid":
1574            h_out = math.ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h)
1575            w_out = math.ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w)
1576            pad_top, pad_bottom, pad_left, pad_right = 0, 0, 0, 0
1577        elif self.pad_mode == "same":
1578            h_out = math.ceil(x_shape[2] / stride_h)
1579            w_out = math.ceil(x_shape[3] / stride_w)
1580
1581            pad_needed_h = max(0, (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape[2])
1582            pad_top = math.floor(pad_needed_h / 2)
1583            pad_bottom = pad_needed_h - pad_top
1584
1585            pad_needed_w = max(0, (w_out - 1) * stride_w + dilation_w * (kernel_size_w - 1) + 1 - x_shape[3])
1586            pad_left = math.floor(pad_needed_w / 2)
1587            pad_right = pad_needed_w - pad_left
1588        elif self.pad_mode == 'pad':
1589            pad_top, pad_bottom, pad_left, pad_right = self.padding
1590
1591            h_out = 1 + (x_shape[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) \
1592                    / stride_h
1593            w_out = 1 + (x_shape[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) \
1594                    / stride_w
1595            h_out = math.floor(h_out)
1596            w_out = math.floor(w_out)
1597
1598        self.pad_list = (pad_top, pad_bottom, pad_left, pad_right)
1599        self.add_prim_attr('pad_list', self.pad_list)
1600
1601        out_channel = self.channel_multiplier * x_shape[1]
1602        out_shape = [x_shape[0], out_channel, h_out, w_out]
1603        return out_shape
1604
1605    def infer_dtype(self, x_dtype, w_dtype, b_dtype=None):
1606        args = {'x': x_dtype, 'w': w_dtype}
1607        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
1608        if x_dtype.element_type() == mstype.int8:
1609            return mstype.tensor_type(mstype.int32)
1610        return x_dtype
1611
1612
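# Editorial note: illustrative sketch only, not called by DepthwiseConv2dNative. It mirrors the
# 'same' branch of infer_shape above: the total padding needed along one dimension is computed
# first, then split between the leading and trailing side, with the extra pixel (if any) going
# to the trailing side. Uses the module-level `math` import.
def _same_pad_split_sketch(size, kernel, stride, dilation):
    """Sketch: (output_length, pad_before, pad_after) for pad_mode='same' along one dimension."""
    out = math.ceil(size / stride)
    pad_needed = max(0, (out - 1) * stride + dilation * (kernel - 1) + 1 - size)
    pad_before = pad_needed // 2
    return out, pad_before, pad_needed - pad_before

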
1613class _Pool(PrimitiveWithInfer):
1614    r"""
1615    Performs max/avg pooling operation.
1616
1617    Args:
1618        kernel_size (Union[int, tuple[int]]): The size of the kernel, that must be a tuple
1619           of two `int` for height and width. Default: 1.
1620        strides (Union[int, tuple[int]]): The stride of the window, that must be
1621            a tuple of two `int` for height and width. Default: 1.
1622        pad_mode (str): The optional value for pad mode, is "same" or "valid", not case sensitive.
1623            Default: "valid".
1624        data_format (str): The optional value for data format, is 'NHWC' or 'NCHW'.
1625            Default: "NCHW".
1626    """
1627
1628    @prim_attr_register
1629    def __init__(self, kernel_size=1, strides=1, pad_mode="valid", data_format="NCHW"):
1630        """Initialize _Pool."""
1631        self.init_prim_io_names(inputs=['x'], outputs=['output'])
1632        validator.check_value_type('kernel_size', kernel_size, [int, tuple], self.name)
1633        validator.check_value_type('strides', strides, [int, tuple], self.name)
1634        self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.name)
1635        self.add_prim_attr("pad_mode", self.pad_mode)
1636        self.is_maxpoolwithargmax = (self.name == "MaxPoolWithArgmax")
1637        self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
1638        if context.get_context("device_target") != "GPU" and self.format == "NHWC":
1639            raise ValueError(f"For '{self.name}', the \"NHWC\" format is only supported on the GPU target, "
1640                             f"but got the format is {self.format} and "
1641                             f"the platform is {context.get_context('device_target')}.")
1642        if not self.is_maxpoolwithargmax:
1643            self.add_prim_attr('data_format', self.format)
1644
1645        self.kernel_size = _check_positive_int_or_tuple(
1646            "kernel_size", kernel_size, self.name, allow_four=False, ret_four=True)
1647        if self.is_maxpoolwithargmax:
1648            self.kernel_size = (1, self.kernel_size[-2], self.kernel_size[-1], 1)
1649        self.add_prim_attr("kernel_size", self.kernel_size)
1650
1651        self.strides = _check_positive_int_or_tuple("strides", strides, self.name, allow_four=False, ret_four=True)
1652        if self.is_maxpoolwithargmax:
1653            self.strides = (1, self.strides[-2], self.strides[-1], 1)
1654        self.add_prim_attr("strides", self.strides)
1655
1656    def infer_shape(self, x_shape):
1657        x_shape_norm = x_shape if self.format == "NCHW" else [x_shape[0], x_shape[3], x_shape[1], x_shape[2]]
1658        validator.check_equal_int(len(x_shape_norm), 4, "x rank", self.name)
1659        batch, channel, input_h, input_w = x_shape_norm
1660        if self.is_maxpoolwithargmax:
1661            _, kernel_h, kernel_w, _ = self.kernel_size
1662            _, stride_h, stride_w, _ = self.strides
1663        else:
1664            _, _, kernel_h, kernel_w = self.kernel_size
1665            _, _, stride_h, stride_w = self.strides
1666
1667        if self.pad_mode == "VALID":
1668            out_h = math.ceil((input_h - (kernel_h - 1)) / stride_h)
1669            out_w = math.ceil((input_w - (kernel_w - 1)) / stride_w)
1670        elif self.pad_mode == "SAME":
1671            out_h = math.ceil(input_h / stride_h)
1672            out_w = math.ceil(input_w / stride_w)
1673        out_shape = [batch, channel, out_h, out_w] if self.format == "NCHW" else [batch, out_h, out_w, channel]
1674
1675        for shape_value in out_shape:
1676            if shape_value <= 0:
1677                raise ValueError(f"For '{self.name}', each element of the output shape must be larger than 0, "
1678                                 f"but got output shape: {out_shape}. The input shape: {x_shape}, "
1679                                 f"kernel size: {self.kernel_size}, strides: {self.strides}. "
1680                                 f"Please check the official api documents for "
1681                                 f"more information about the output.")
1682        return out_shape
1683
1684    def infer_dtype(self, x_dtype):
1685        validator.check_subclass("input", x_dtype, mstype.tensor, self.name)
1686        return x_dtype
1687
1688
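# Editorial note: illustrative sketch only, not used by _Pool. It restates the output-size rule
# applied in _Pool.infer_shape for the two supported pad modes along one spatial dimension.
def _pool_out_size_sketch(size, kernel, stride, pad_mode="VALID"):
    """Sketch: pooled output length for pad_mode 'VALID' or 'SAME'."""
    if pad_mode == "VALID":
        # No padding: only windows that fit entirely inside the input are kept.
        return math.ceil((size - (kernel - 1)) / stride)
    # 'SAME': the output covers the whole input, padding as needed.
    return math.ceil(size / stride)

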
1689class MaxPool(_Pool):
1690    r"""
1691    Max pooling operation.
1692
1693    Applies a 2D max pooling over an input Tensor which can be regarded as a composition of 2D planes.
1694
1695    Typically the input is of shape :math:`(N_{in}, C_{in}, H_{in}, W_{in})`, MaxPool outputs
1696    regional maximum in the :math:`(H_{in}, W_{in})`-dimension. Given kernel size
1697    :math:`ks = (h_{ker}, w_{ker})` and stride :math:`s = (s_0, s_1)`, the operation is as follows.
1698
1699    .. math::
1700        \text{output}(N_i, C_j, h, w) = \max_{m=0, \ldots, h_{ker}-1} \max_{n=0, \ldots, w_{ker}-1}
1701        \text{input}(N_i, C_j, s_0 \times h + m, s_1 \times w + n)
1702
1703    Args:
1704        kernel_size (Union[int, tuple[int]]): The size of kernel used to take the maximum value,
1705            is an int number that represents height and width of the kernel, or a tuple
1706            of two int numbers that represent height and width respectively. Default: 1.
1707        strides (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
1708            the height and width of movement are both strides, or a tuple of two int numbers that
1709            represent height and width of movement respectively. Default: 1.
1710        pad_mode (str): The optional value for pad mode, is "same" or "valid", not case sensitive.
1711            Default: "valid".
1712
1713            - same: Adopts the way of completion. The height and width of the output will be the same as
1714              the input. The total number of padding will be calculated in horizontal and vertical
1715              directions and evenly distributed to top and bottom, left and right if possible.
1716              Otherwise, the last extra padding will be done from the bottom and the right side.
1717
1718            - valid: Adopts the way of discarding. The possible largest height and width of output
1719              will be returned without padding. Extra pixels will be discarded.
1720        data_format (str): The optional value for data format, is 'NHWC' or 'NCHW'.
1721            Default: 'NCHW'.
1722
1723    Inputs:
1724        - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
1725
1726    Outputs:
1727        Tensor, with shape :math:`(N, C_{out}, H_{out}, W_{out})`.
1728
1729    Raises:
1730        TypeError: If `kernel_size` or `strides` is neither int nor tuple.
1731        ValueError: If `pad_mode` is neither 'valid' nor 'same' (case insensitive).
1732        ValueError: If `data_format` is neither 'NCHW' nor 'NHWC'.
1733        ValueError: If `kernel_size` or `strides` is less than 1.
1734        ValueError: If length of shape of `x` is not equal to 4.
1735
1736    Supported Platforms:
1737        ``Ascend`` ``GPU`` ``CPU``
1738
1739    Examples:
1740        >>> x = Tensor(np.arange(1 * 3 * 3 * 4).reshape((1, 3, 3, 4)), mindspore.float32)
1741        >>> maxpool_op = ops.MaxPool(pad_mode="VALID", kernel_size=2, strides=1)
1742        >>> output = maxpool_op(x)
1743        >>> print(output)
1744        [[[[ 5.  6.  7.]
1745           [ 9. 10. 11.]]
1746          [[17. 18. 19.]
1747           [21. 22. 23.]]
1748          [[29. 30. 31.]
1749           [33. 34. 35.]]]]
1750    """
1751
1752    @prim_attr_register
1753    def __init__(self, kernel_size=1, strides=1, pad_mode="valid", data_format="NCHW"):
1754        """Initialize MaxPool."""
1755        super(MaxPool, self).__init__(kernel_size, strides, pad_mode, data_format)
1756
1757
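# Editorial note: illustrative NumPy sketch only, not part of MaxPool. It assumes an NCHW input
# and pad_mode='valid', and is a direct transcription of the max-pooling formula in the MaxPool
# docstring below. Uses the module-level `np` import.
def _max_pool2d_reference_sketch(x, kernel, stride):
    """Sketch: x is an ndarray of shape (N, C, H, W); kernel and stride are (h, w) tuples."""
    n, c, h, w = x.shape
    out_h = (h - kernel[0]) // stride[0] + 1
    out_w = (w - kernel[1]) // stride[1] + 1
    out = np.empty((n, c, out_h, out_w), dtype=x.dtype)
    for i in range(out_h):
        for j in range(out_w):
            window = x[:, :, i * stride[0]:i * stride[0] + kernel[0],
                       j * stride[1]:j * stride[1] + kernel[1]]
            # Maximum over the spatial window, kept per batch element and per channel.
            out[:, :, i, j] = window.max(axis=(2, 3))
    return out

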
1758class MaxPoolWithArgmax(_Pool):
1759    r"""
1760    Performs max pooling on the input Tensor and returns both max values and indices.
1761
1762    Typically the input is of shape :math:`(N_{in}, C_{in}, H_{in}, W_{in})`, MaxPool outputs
1763    regional maximum in the :math:`(H_{in}, W_{in})`-dimension. Given kernel size
1764    :math:`ks = (h_{ker}, w_{ker})` and stride :math:`s = (s_0, s_1)`, the operation is as follows.
1765
1766    .. math::
1767        \text{output}(N_i, C_j, h, w) = \max_{m=0, \ldots, h_{ker}-1} \max_{n=0, \ldots, w_{ker}-1}
1768        \text{input}(N_i, C_j, s_0 \times h + m, s_1 \times w + n)
1769
1770    Args:
1771        kernel_size (Union[int, tuple[int]]): The size of kernel used to take the maximum value and arg
1772            value, is an int number that represents height and width of the kernel, or a tuple of
1773            two int numbers that represent height and width respectively. Default: 1.
1774        strides (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
1775            the height and width of movement are both strides, or a tuple of two int numbers that
1776            represent height and width of movement respectively. Default: 1.
1777        pad_mode (str): The optional value for pad mode, is "same" or "valid", not case sensitive.
1778            Default: "valid".
1779
1780            - same: Adopts the way of completion. The height and width of the output will be the same as
1781              the input. The total number of padding will be calculated in horizontal and vertical
1782              directions and evenly distributed to top and bottom, left and right if possible.
1783              Otherwise, the last extra padding will be done from the bottom and the right side.
1784
1785            - valid: Adopts the way of discarding. The possible largest height and width of output
1786              will be returned without padding. Extra pixels will be discarded.
1787
1788    Inputs:
1789        - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
1790          Data type must be float16 or float32.
1791
1792    Outputs:
1793        Tuple of 2 Tensors, representing the maxpool result and where the max values are generated.
1794
1795        - **output** (Tensor) -  Maxpooling result, with shape :math:`(N, C_{out}, H_{out}, W_{out})`.
1796          It has the same data type as `x`.
1797        - **mask** (Tensor) -  Max values' index represented by the mask. Data type is int32.
1798
1799    Raises:
1800        TypeError: If the data type of `x` is neither float16 nor float32.
1801        TypeError: If `kernel_size` or `strides` is neither an int nor a tuple.
1802        TypeError: If `x` is not a Tensor.
1803
1804    Supported Platforms:
1805        ``Ascend`` ``GPU``
1806
1807    Examples:
1808        >>> x = Tensor(np.arange(1 * 3 * 3 * 4).reshape((1, 3, 3, 4)), mindspore.float32)
1809        >>> maxpool_arg_op = ops.MaxPoolWithArgmax(pad_mode="VALID", kernel_size=2, strides=1)
1810        >>> output_tensor, argmax = maxpool_arg_op(x)
1811        >>> print(output_tensor)
1812        [[[[ 5.  6.  7.]
1813           [ 9. 10. 11.]]
1814          [[17. 18. 19.]
1815           [21. 22. 23.]]
1816          [[29. 30. 31.]
1817           [33. 34. 35.]]]]
1818    """
1819
1820    @prim_attr_register
1821    def __init__(self, kernel_size=1, strides=1, pad_mode="valid", data_format="NCHW"):
1822        """Initialize MaxPoolWithArgmax."""
1823        super(MaxPoolWithArgmax, self).__init__(kernel_size, strides, pad_mode, data_format)
1824
1825    def infer_shape(self, x_shape):
1826        out_shape = _Pool.infer_shape(self, x_shape)
1827        return out_shape, out_shape
1828
1829    def infer_dtype(self, x_dtype):
1830        validator.check_tensor_dtype_valid("x", x_dtype, (mstype.float16, mstype.float32), self.name)
1831        argmax_dtype = mstype.int32
1832        return x_dtype, argmax_dtype
1833
1834
1835class MaxPool3D(PrimitiveWithInfer):
1836    r"""
1837    3D max pooling operation.
1838
1839    Applies a 3D max pooling over an input Tensor which can be regarded as a composition of 3D planes.
1840
1841    Typically the input is of shape :math:`(N_{in}, C_{in}, D_{in}, H_{in}, W_{in})`, MaxPool outputs
1842    regional maximum in the :math:`(D_{in}, H_{in}, W_{in})`-dimension. Given kernel size
1843    :math:`ks = (d_{ker}, h_{ker}, w_{ker})` and stride :math:`s = (s_0, s_1, s_2)`, the operation is as follows.
1844
1845    .. math::
1846        \text{output}(N_i, C_j, d, h, w) =
1847        \max_{l=0, \ldots, d_{ker}-1} \max_{m=0, \ldots, h_{ker}-1} \max_{n=0, \ldots, w_{ker}-1}
1848        \text{input}(N_i, C_j, s_0 \times d + l, s_1 \times h + m, s_2 \times w + n)
1849
1850    Args:
1851        kernel_size (Union[int, tuple[int]]): The size of kernel used to take the maximum value,
1852            is an int number that represents depth, height and width of the kernel, or a tuple
1853            of three int numbers that represent depth, height and width respectively. Default: 1.
1854        strides (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
1855            the depth, height and width of movement are both strides, or a tuple of three int numbers that
1856            represent depth, height and width of movement respectively. Default: 1.
1857        pad_mode (str): The optional value for pad mode, is "same" or "valid", not case sensitive.
1858            Default: "valid".
1859
1860            - same: Adopts the way of completion. The height and width of the output will be the same as
1861              the input. The total number of padding will be calculated in horizontal and vertical
1862              directions and evenly distributed to top and bottom, left and right if possible.
1863              Otherwise, the last extra padding will be done from the bottom and the right side.
1864
1865            - valid: Adopts the way of discarding. The possible largest height and width of output
1866              will be returned without padding. Extra pixels will be discarded.
1867
1868            - pad: Implicit paddings on both sides of the input in depth, height, width. The number of "pad" will
1869              be padded to the input Tensor borders. "pad" must be greater than or equal to 0.
1870
1871        pad_list (Union(int, tuple[int])): The pad value to be filled. Default: 0. If `pad_list` is an integer,
1872            the paddings of head, tail, top, bottom, left and right are the same, equal to `pad_list`. If `pad_list`
1873            is a tuple of six integers, the paddings of head, tail, top, bottom, left and right equal to pad_list[0],
1874            pad_list[1], pad_list[2], pad_list[3], pad_list[4] and pad_list[5] correspondingly.
1875        ceil_mode (bool): Whether to use ceil instead of floor to calculate output shape. Only effective in "pad" mode.
1876            When "pad_mode" is "pad" and "ceil_mode" is "None", "ceil_mode" will be set as "False". Default: None.
1877        data_format (str): The optional value for data format. Currently only 'NCDHW' is supported. Default: 'NCDHW'.
1878
1879    Inputs:
1880        - **x** (Tensor) - Tensor of shape :math:`(N, C, D_{in}, H_{in}, W_{in})`.
1881          Data type must be float16 or float32.
1882
1883    Outputs:
1884        Tensor, with shape :math:`(N, C, D_{out}, H_{out}, W_{out})`. Has the same data type as `x`.
1885
1886    Raises:
1887        TypeError: If `kernel_size` or `strides` is neither an int nor a tuple.
1888        TypeError: If `pad_mode` or `data_format` is not a string.
1889        ValueError: If numbers in `kernel_size` or `strides` are not positive.
1890        ValueError: If `pad_mode` is not one of 'same', 'valid', 'pad'.
1891        ValueError: If `pad_mode` is 'same' or 'valid' and `ceil_mode` is not None.
1892        ValueError: If `kernel_size` or `strides` is a tuple whose length is not equal to 3.
1893        ValueError: If `data_format` is not 'NCDHW'.
1894
1895    Supported Platforms:
1896        ``Ascend`` ``GPU``
1897
1898    Examples:
1899        >>> x = Tensor(np.arange(1 * 2 * 2 * 2 * 3).reshape((1, 2, 2, 2, 3)), mindspore.float32)
1900        >>> max_pool3d = ops.MaxPool3D(kernel_size=2, strides=1, pad_mode="valid")
1901        >>> output = max_pool3d(x)
1902        >>> print(output)
1903        [[[[[10. 11.]]]
1904          [[[22. 23.]]]]]
1905    """
1906
1907    @prim_attr_register
1908    def __init__(self, kernel_size=1, strides=1, pad_mode="VALID", pad_list=0, ceil_mode=None, data_format="NCDHW"):
1909        """Initialize MaxPool3D."""
1910        self.init_prim_io_names(inputs=['x'], outputs=['output'])
1911        validator.check_value_type('kernel_size', kernel_size, [int, tuple], self.name)
1912        validator.check_value_type('strides', strides, [int, tuple], self.name)
1913        validator.check_value_type('pad_mode', pad_mode, [str], self.name)
1914        self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME', 'PAD'], 'pad_mode', self.name)
1915        if pad_mode.upper() == "PAD":
1916            self.pad_mode = "CALCULATED"
1917        self.add_prim_attr("pad_mode", self.pad_mode)
1918        self.data_format = validator.check_string(data_format, ['NCDHW'], 'data_format', self.name)
1919        self.kernel_size = _check_3d_int_or_tuple("kernel_size", kernel_size, self.name,
1920                                                  allow_five=False, ret_five=True)
1921        self.add_prim_attr("kernel_size", self.kernel_size)
1922        self.strides = _check_3d_int_or_tuple("strides", strides, self.name, allow_five=False, ret_five=True)
1923        self.add_prim_attr("strides", self.strides)
1924        if ceil_mode is None:
1925            self.ceil_mode = not self.pad_mode == "CALCULATED"
1926        else:
1927            self.ceil_mode = validator.check_value_type('ceil_mode', ceil_mode, [bool], self.name)
1928            if self.pad_mode != "CALCULATED":
1929                raise ValueError("When pad_mode is 'same' or 'valid', ceil_mode only supports 'None'.")
1930        self.add_prim_attr("ceil_mode", int(self.ceil_mode))
1931
1932        validator.check_value_type('pad_list', pad_list, (int, tuple), self.name)
1933        self.pad_list = pad_list
1934        if isinstance(self.pad_list, int):
1935            self.pad_list = (self.pad_list,) * 6
1936        if len(self.pad_list) == 3:
1937            self.pad_list = (pad_list[0], pad_list[0], pad_list[1], pad_list[1], pad_list[2], pad_list[2])
1938        if len(self.pad_list) != 3 and len(self.pad_list) != 6:
1939            raise ValueError(f"For '{self.name}', attr 'pad_list' should be a positive int number or a tuple of "
1940                             f"three or six positive int numbers, but got {len(self.pad_list)} numbers.")
1941        if self.pad_mode != 'CALCULATED' and self.pad_list != (0, 0, 0, 0, 0, 0):
1942            raise ValueError(f"For '{self.name}', the 'pad_list' must be zero when 'pad_mode' is not \"pad\", "
1943                             f"but got 'pad_list' is {self.pad_list} and 'pad_mode' is {pad_mode}.")
1944        if self.pad_mode == 'CALCULATED':
1945            for item in self.pad_list:
1946                validator.check_non_negative_int(item, 'pad_list item', self.name)
1947        self.add_prim_attr("pad_list", self.pad_list)
1948
1949    def infer_shape(self, x_shape):
1950        validator.check_equal_int(len(x_shape), 5, "x rank", self.name)
1951        batch, channel, input_d, input_h, input_w = x_shape
1952        self.add_prim_attr("x_shape", x_shape)
1953        _, _, kernel_d, kernel_h, kernel_w = self.kernel_size
1954        _, _, stride_d, stride_h, stride_w = self.strides
1955
1956        if self.pad_mode == "VALID":
1957            out_d = math.ceil((input_d - (kernel_d - 1)) / stride_d)
1958            out_h = math.ceil((input_h - (kernel_h - 1)) / stride_h)
1959            out_w = math.ceil((input_w - (kernel_w - 1)) / stride_w)
1960        elif self.pad_mode == "SAME":
1961            out_d = math.ceil(input_d / stride_d)
1962            out_h = math.ceil(input_h / stride_h)
1963            out_w = math.ceil(input_w / stride_w)
1964        else:
1965            out_d = ((input_d + self.pad_list[0] + self.pad_list[1] -
1966                      (kernel_d - 1) - 1) / stride_d) + 1
1967            out_h = ((input_h + self.pad_list[2] + self.pad_list[3] -
1968                      (kernel_h - 1) - 1) / stride_h) + 1
1969            out_w = ((input_w + self.pad_list[4] + self.pad_list[5] -
1970                      (kernel_w - 1) - 1) / stride_w) + 1
1971            if self.ceil_mode:
1972                out_d = math.ceil(out_d)
1973                out_h = math.ceil(out_h)
1974                out_w = math.ceil(out_w)
1975            else:
1976                out_d = math.floor(out_d)
1977                out_h = math.floor(out_h)
1978                out_w = math.floor(out_w)
1979        out_shape = [batch, channel, out_d, out_h, out_w]
1980
1981        _check_shape('output', out_shape, self.name)
1982        return out_shape
1983
1984    def infer_dtype(self, x_dtype):
1985        validator.check_tensor_dtype_valid("x", x_dtype, [mstype.float16, mstype.float32], self.name)
1986        return x_dtype
1987
1988
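# Editorial note: illustrative sketch only, not called by MaxPool3D. It isolates the 'pad'
# (CALCULATED) branch of MaxPool3D.infer_shape for a single dimension, including the ceil_mode
# switch between rounding up and rounding down. Uses the module-level `math` import.
def _maxpool3d_pad_out_size_sketch(size, kernel, stride, pad_before, pad_after, ceil_mode):
    """Sketch: output length along one dimension for pad_mode='pad'."""
    out = (size + pad_before + pad_after - (kernel - 1) - 1) / stride + 1
    return math.ceil(out) if ceil_mode else math.floor(out)

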
1989class AvgPool(_Pool):
1990    r"""
1991    Average pooling operation.
1992
1993    Applies a 2D average pooling over an input Tensor which can be regarded as a composition of 2D input planes.
1994    Typically the input is of shape :math:`(N_{in}, C_{in}, H_{in}, W_{in})`, AvgPool outputs
1995    regional average in the :math:`(H_{in}, W_{in})`-dimension. Given kernel size
1996    :math:`ks = (h_{ker}, w_{ker})` and stride :math:`s = (s_0, s_1)`, the operation is as follows.
1997
1998    .. math::
1999        \text{output}(N_i, C_j, h, w) = \frac{1}{h_{ker} * w_{ker}} \sum_{m=0}^{h_{ker}-1} \sum_{n=0}^{w_{ker}-1}
2000        \text{input}(N_i, C_j, s_0 \times h + m, s_1 \times w + n)
2001
2002    .. warning::
2003        - Only single input and single output are supported.
2004        - Global pooling is supported.
2005        - The height of "kernel_size" and the width of "kernel_size" are positive integers within the range [1, 255].
2006          ksize_h * ksize_w < 256.
2007        - Due to instruction restrictions, the values of "strides_h" and "strides_w" are
2008          positive integers within the range [1, 63].
2009
2010    Args:
2011        kernel_size (Union[int, tuple[int]]): The size of kernel used to take the average value,
2012            is an int number that represents height and width of the kernel, or a tuple
2013            of two int numbers that represent height and width respectively. Default: 1.
2014        strides (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
2015            the height and width of movement are both strides, or a tuple of two int numbers that
2016            represent height and width of movement respectively. Default: 1.
2017        pad_mode (str): The optional value for pad mode, is "same" or "valid", not case sensitive.
2018            Default: "valid".
2019
2020            - same: Adopts the way of completion. The height and width of the output will be the same as
2021              the input. The total number of padding will be calculated in horizontal and vertical
2022              directions and evenly distributed to top and bottom, left and right if possible.
2023              Otherwise, the last extra padding will be done from the bottom and the right side.
2024
2025            - valid: Adopts the way of discarding. The possible largest height and width of output
2026              will be returned without padding. Extra pixels will be discarded.
2027        data_format (str): The format of input and output data. It should be 'NHWC' or 'NCHW'.
2028            Default: 'NCHW'.
2029
2030    Inputs:
2031        - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
2032
2033    Outputs:
2034        Tensor, with shape :math:`(N, C_{out}, H_{out}, W_{out})`.
2035
2036    Raises:
2037        TypeError: If `kernel_size` or `strides` is neither int nor tuple.
2038        ValueError: If `pad_mode` is neither 'valid' nor 'same' (case insensitive).
2039        ValueError: If `data_format` is neither 'NCHW' nor 'NHWC'.
2040        ValueError: If `kernel_size` or `strides` is less than 1.
2041        ValueError: If length of shape of `x` is not equal to 4.
2042
2043    Supported Platforms:
2044        ``Ascend`` ``GPU`` ``CPU``
2045
2046    Examples:
2047        >>> class Net(nn.Cell):
2048        ...     def __init__(self):
2049        ...         super(Net, self).__init__()
2050        ...         self.avgpool_op = ops.AvgPool(pad_mode="VALID", kernel_size=2, strides=1)
2051        ...
2052        ...     def construct(self, x):
2053        ...         result = self.avgpool_op(x)
2054        ...         return result
2055        ...
2056        >>> x = Tensor(np.arange(1 * 3 * 3 * 4).reshape(1, 3, 3, 4), mindspore.float32)
2057        >>> net = Net()
2058        >>> output = net(x)
2059        >>> print(output)
2060        [[[[ 2.5   3.5   4.5]
2061           [ 6.5   7.5   8.5]]
2062          [[14.5  15.5  16.5]
2063           [18.5  19.5  20.5]]
2064          [[26.5  27.5  28.5]
2065           [30.5  31.5  32.5]]]]
2066    """
2067
2068    @prim_attr_register
2069    def __init__(self, kernel_size=1, strides=1, pad_mode="valid", data_format="NCHW"):
2070        """Initialize AvgPool."""
2071        super(AvgPool, self).__init__(kernel_size, strides, pad_mode, data_format)
2072
2073
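# Editorial note: illustrative NumPy sketch only, not part of AvgPool. Same windowing as the
# max-pooling sketch above, with the mean taken over each window instead of the maximum,
# matching the averaging formula in the AvgPool docstring. Assumes an NCHW, floating-point
# input and pad_mode='valid'.
def _avg_pool2d_reference_sketch(x, kernel, stride):
    """Sketch: x is an ndarray of shape (N, C, H, W); kernel and stride are (h, w) tuples."""
    n, c, h, w = x.shape
    out_h = (h - kernel[0]) // stride[0] + 1
    out_w = (w - kernel[1]) // stride[1] + 1
    out = np.empty((n, c, out_h, out_w), dtype=x.dtype)
    for i in range(out_h):
        for j in range(out_w):
            out[:, :, i, j] = x[:, :, i * stride[0]:i * stride[0] + kernel[0],
                                j * stride[1]:j * stride[1] + kernel[1]].mean(axis=(2, 3))
    return out

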
2074class Conv2DBackpropInput(Primitive):
2075    r"""
2076    Computes the gradients of convolution with respect to the input.
2077
2078    Args:
2079        out_channel (int): The number of output channel :math:`C_{out}`.
2080        kernel_size (Union[int, tuple[int]]): The data type is int or a tuple of 2 integers. Specifies the height
2081            and width of the 2D convolution window. Single int means the value is for both the height and the width of
2082            the kernel. A tuple of 2 ints means the first value is for the height and the other is for the
2083            width of the kernel.
2084        pad_mode (str): Modes to fill padding. It could be "valid", "same", or "pad". Default: "valid".
2085        pad (Union[int, tuple[int]]): The pad value to be filled. Default: 0. If `pad` is an integer, the paddings of
2086                    top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of four integers, the
2087                    padding of top, bottom, left and right equal to pad[0], pad[1], pad[2], and pad[3] correspondingly.
2088        mode (int): Modes for different convolutions. 0: math convolution, 1: cross-correlation convolution,
2089                       2: deconvolution, 3: depthwise convolution. Default: 1.
2090        stride (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
2091            the height and width of movement are both strides, or a tuple of two int numbers that
2092            represent height and width of movement respectively. Default: 1.
2093        dilation (Union[int, tuple[int]]): Specifies the dilation rate to be used for the dilated convolution.
2094            Default: 1.
2095        group (int): Splits input into groups. Default: 1.
2096        data_format (str): The format of input and output data. It should be 'NHWC' or 'NCHW'.
2097            Default: 'NCHW'.
2098
2099    Inputs:
2100        - **dout** (Tensor) - The gradients with respect to the output of the convolution. The shape conforms
2101          to the default data_format :math:`(N, C_{out}, H_{out}, W_{out})`.
2102        - **weight** (Tensor) - Set size of kernel is :math:`(\text{ks_h}, \text{ks_w})`, where :math:`\text{ks_h}`
2103          and :math:`\text{ks_w}` are the height and width of the convolution kernel, then the shape is
2104          :math:`(C_{out}, C_{in}, \text{ks_h}, \text{ks_w})`.
2105        - **input_size** (Tensor) - A tuple that describes the shape of the input, which conforms to the format
2106          :math:`(N, C_{in}, H_{in}, W_{in})`.
2107
2108    Outputs:
2109        Tensor, the gradients with respect to the input of convolution. It has the same shape as the input.
2110
2111    Raises:
2112        TypeError: If `kernel_size`, `stride`, `pad` or `dilation` is neither an int nor a tuple.
2113        TypeError: If `out_channel` or `group` is not an int.
2114        ValueError: If `kernel_size`, `stride` or `dilation` is less than 1.
2115        ValueError: If `pad_mode` is not one of 'same', 'valid', 'pad'.
2116        ValueError: If `pad` is a tuple whose length is not equal to 4.
2117        ValueError: If `pad_mode` is not equal to 'pad' and `pad` is not equal to (0, 0, 0, 0).
2118        ValueError: If `data_format` is neither 'NCHW' nor 'NHWC'.
2119
2120    Supported Platforms:
2121        ``Ascend`` ``GPU`` ``CPU``
2122
2123    Examples:
2124        >>> import numpy as np
2125        >>> import mindspore
2126        >>> from mindspore import Tensor
2127        >>> import mindspore.ops as ops
2128        >>> dout = Tensor(np.ones([10, 32, 30, 30]), mindspore.float32)
2129        >>> weight = Tensor(np.ones([32, 32, 3, 3]), mindspore.float32)
2130        >>> input_x = Tensor(np.ones([10, 32, 32, 32]))
2131        >>> conv2d_backprop_input = ops.Conv2DBackpropInput(out_channel=32, kernel_size=3)
2132        >>> output = conv2d_backprop_input(dout, weight, ops.shape(input_x))
2133        >>> print(output.shape)
2134        (10, 32, 32, 32)
2135    """
2136    __mindspore_signature__ = (
2137        sig.make_sig('out_backprop', dtype=sig.sig_dtype.T),
2138        sig.make_sig('filter', dtype=sig.sig_dtype.T1),
2139        sig.make_sig('input_sizes', dtype=sig.sig_dtype.T2)
2140    )
2141
2142    @prim_attr_register
2143    def __init__(self,
2144                 out_channel,
2145                 kernel_size,
2146                 pad_mode="valid",
2147                 pad=0,
2148                 pad_list=None,
2149                 mode=1,
2150                 stride=1,
2151                 dilation=1,
2152                 group=1,
2153                 data_format="NCHW"):
2154        """Initialize Conv2DBackpropInput"""
2155        self.init_prim_io_names(inputs=['out_backprop', 'filter', 'input_sizes'], outputs=['output'])
2156        self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name)
2157        self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name)
2158        self.add_prim_attr('kernel_size', self.kernel_size)
2159        self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
2160        if context.get_context("device_target") != "GPU" and self.format == "NHWC":
2161            raise ValueError(f"For '{self.name}', the \"NHWC\" format is only supported on the GPU target, "
2162                             f"but got the format is {self.format} and "
2163                             f"the platform is {context.get_context('device_target')}.")
2164        self.add_prim_attr('data_format', self.format)
2165        self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=True)
2166        self.stride = _update_attr_by_format(self.stride, self.format)
2167        self.add_prim_attr('stride', self.stride)
2168        self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True)
2169        self.dilation = _update_attr_by_format(self.dilation, self.format)
2170        self.add_prim_attr('dilation', self.dilation)
2171        validator.check_value_type('pad', pad, (int, tuple), self.name)
2172        if isinstance(pad, int):
2173            pad = (pad,) * 4
2174        else:
2175            validator.check_equal_int(len(pad), 4, 'pad size', self.name)
2176        self.add_prim_attr("pad", pad)
2177        self.padding = pad
2178        self.pad_mode = validator.check_string(pad_mode, ['valid', 'same', 'pad'], 'pad_mode', self.name)
2179        if pad_mode != 'pad' and pad != (0, 0, 0, 0):
2180            raise ValueError(f"For '{self.name}', the 'pad' must be zero when 'pad_mode' is not \"pad\", "
2181                             f"but got 'pad' is {pad} and 'pad_mode' is {pad_mode}.")
2182        if self.pad_mode == 'pad':
2183            for item in pad:
2184                validator.check_non_negative_int(item, 'pad item', self.name)
2185
2186        pad_mode = pad_mode.upper()
2187        self.add_prim_attr('pad_mode', pad_mode)
2188        self.mode = validator.check_equal_int(mode, 1, 'mode', self.name)
2189        self.group = validator.check_positive_int(group, 'group', self.name)
2190        self.add_prim_attr('groups', self.group)
2191        if pad_list:
2192            for x in pad_list:
2193                validator.check_non_negative_int(x, 'element of pad_list', self.name)
2194            self.pad_list = pad_list
2195
2196
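# Editorial note: illustrative sketch only, not used by Conv2DBackpropInput. The operator takes
# `input_sizes` explicitly because, once stride > 1, several input sizes can produce the same
# forward output size, so the input shape cannot always be recovered from `dout` alone. The
# hypothetical check below verifies that consistency for a 'valid' forward convolution with
# dilation 1.
def _backprop_input_size_is_consistent_sketch(input_size, kernel, stride, dout_size):
    """Sketch: True if dout_size matches the forward 'valid' output for input_size."""
    expected = (input_size - kernel) // stride + 1
    return expected == dout_size

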
2197class Conv2DTranspose(Conv2DBackpropInput):
2198    """
2199    Compute a 2D transposed convolution, which is also known as a deconvolution
2200    (although it is not an actual deconvolution).
2201
2202    Args:
2203        out_channel (int): The dimensionality of the output space.
2204        kernel_size (Union[int, tuple[int]]): The size of the convolution window.
2205        pad_mode (str): Modes to fill padding. It could be "valid", "same", or "pad". Default: "valid".
2206        pad (Union[int, tuple[int]]): The pad value to be filled. Default: 0. If `pad` is an integer, the paddings of
2207                    top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of four integers, the
2208                    padding of top, bottom, left and right equal to pad[0], pad[1], pad[2], and pad[3] correspondingly.
2209        mode (int): Modes for different convolutions. 0: math convolution, 1: cross-correlation convolution,
2210                       2: deconvolution, 3: depthwise convolution. Default: 1.
2211        stride (Union[int, tuple[int]]): The stride to be applied to the convolution filter. Default: 1.
2212        dilation (Union[int, tuple[int]]): Specifies the dilation rate to be used for the dilated convolution.
2213            Default: 1.
2214        group (int): Splits input into groups. Default: 1.
2215        data_format (str): The format of input and output data. It should be 'NHWC' or 'NCHW'.
2216            Default: 'NCHW'.
2217
2218    Inputs:
2219        - **dout** (Tensor) - the gradients w.r.t the output of the convolution. The shape conforms to the default
2220          data_format :math:`(N, C_{out}, H_{out}, W_{out})`.
2221        - **weight** (Tensor) - Set size of kernel is :math:`(K_1, K_2)`, then the shape is
2222          :math:`(C_{out}, C_{in}, K_1, K_2)`.
2223        - **input_size** (Tensor) - A tuple that describes the shape of the input, which conforms to the format
2224          :math:`(N, C_{in}, H_{in}, W_{in})`.
2225
2226    Outputs:
2227        Tensor, the gradients w.r.t the input of convolution. It has the same shape as the input.
2228
2229    Raises:
2230        TypeError: If `kernel_size`, `stride`, `pad` or `dilation` is neither an int nor a tuple.
2231        TypeError: If `out_channel` or `group` is not an int.
2232        ValueError: If `kernel_size`, `stride` or `dilation` is less than 1.
2233        ValueError: If `pad_mode` is not one of 'same', 'valid', 'pad'.
2234        ValueError: If `pad` is a tuple whose length is not equal to 4.
2235        ValueError: If `pad_mode` is not equal to 'pad' and `pad` is not equal to (0, 0, 0, 0).
2236        ValueError: If `data_format` is neither 'NCHW' nor 'NHWC'.
2237
2238    Supported Platforms:
2239        ``Ascend`` ``GPU`` ``CPU``
2240
2241    Examples:
2242        >>> dout = Tensor(np.ones([10, 32, 30, 30]), mindspore.float32)
2243        >>> weight = Tensor(np.ones([32, 32, 3, 3]), mindspore.float32)
2244        >>> x = Tensor(np.ones([10, 32, 32, 32]))
2245        >>> conv2d_transpose_input = ops.Conv2DTranspose(out_channel=32, kernel_size=3)
2246        >>> output = conv2d_transpose_input(dout, weight, ops.shape(x))
2247        >>> print(output.shape)
2248        (10, 32, 32, 32)
2249    """
2250
2251    @prim_attr_register
2252    def __init__(self, out_channel, kernel_size, pad_mode="valid", pad=0,
2253                 pad_list=None, mode=1, stride=1, dilation=1, group=1, data_format="NCHW"):
2254        """Initialize Conv2DTranspose."""
2255        super(Conv2DTranspose, self).__init__(out_channel, kernel_size, pad_mode, pad,
2256                                              pad_list, mode, stride, dilation, group, data_format)
2257
2258
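# Editorial note: illustrative sketch only, not used by Conv2DTranspose. For pad_mode='valid'
# and dilation 1, the spatial size reconstructed by the transposed convolution from a gradient
# of length `dout_size` is given below; when stride > 1 this is only one of the admissible
# sizes, which is why the operator still requires `input_size` as an input.
def _conv2d_transpose_out_size_sketch(dout_size, kernel, stride):
    """Sketch: one spatial output length of Conv2DTranspose for 'valid' mode."""
    return (dout_size - 1) * stride + kernel
# For the Conv2DTranspose example above: _conv2d_transpose_out_size_sketch(30, 3, 1) == 32.

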
2259class BiasAdd(Primitive):
2260    r"""
2261    Returns sum of input and bias tensor.
2262
2263    Adds the 1-D bias tensor to the input tensor, and broadcasts it over all axes
2264    except for the channel axis.
2265
2266    Args:
2267        data_format (str): The format of input and output data. It should be 'NHWC', 'NCHW' or 'NCDHW'.
2268            Default is 'NCHW'.
2269
2270    Inputs:
2271        - **input_x** (Tensor) - The input tensor. The shape can be 2-5 dimensions.
2272          The data type should be float16 or float32.
2273        - **bias** (Tensor) - The bias tensor, with shape :math:`(C)`. The shape of
2274          `bias` must be the same as `input_x`'s channel dimension. The data type should be float16 or float32.
2275
2276    Outputs:
2277        Tensor, with the same shape and data type as `input_x`.
2278
2279    Raises:
2280        TypeError: If `data_format` is not a str.
2281        TypeError: If `input_x` or `bias` is not a Tensor.
2282        TypeError: If dtype of `input_x` or `bias` is neither float16 nor float32.
2283
2284    Supported Platforms:
2285        ``Ascend`` ``GPU`` ``CPU``
2286
2287    Examples:
2288        >>> input_x = Tensor(np.arange(6).reshape((2, 3)), mindspore.float32)
2289        >>> bias = Tensor(np.random.random(3).reshape((3,)), mindspore.float32)
2290        >>> bias_add = ops.BiasAdd()
2291        >>> output = bias_add(input_x, bias)
2292        >>> print(output.shape)
2293        (2, 3)
2294    """
2295
2296    @prim_attr_register
2297    def __init__(self, data_format="NCHW"):
2298        """Initialize BiasAdd."""
2299        self.init_prim_io_names(inputs=['x', 'b'], outputs=['output'])
2300        self.format = validator.check_string(data_format, ['NCHW', 'NHWC', 'NCDHW'], 'format', self.name)
2301        if context.get_context("device_target") != "GPU" and self.format == "NHWC":
2302            raise ValueError(f"For '{self.name}', the \"NHWC\" format is only supported on the GPU target, "
2303                             f"but got the format is {self.format} and "
2304                             f"the platform is {context.get_context('device_target')}.")
2305        self.add_prim_attr('data_format', self.format)
2306
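# A minimal NumPy sketch (illustration only, not the backend kernel) of the BiasAdd
# semantics documented above: the 1-D `bias` is broadcast over every axis except the
# channel axis. The helper name and the default channel_axis=1 (NCHW-style layouts)
# are assumptions made for this sketch; it relies on the module-level numpy import.
def _bias_add_reference(x, b, channel_axis=1):
    """Reference sketch: add a 1-D bias `b` along `channel_axis` of `x`."""
    shape = [1] * x.ndim
    shape[channel_axis] = -1          # keep only the channel axis un-broadcast
    return x + b.reshape(shape)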
2307
2308class TopK(PrimitiveWithInfer):
2309    """
2310    Finds values and indices of the `k` largest entries along the last dimension.
2311
2312    .. warning::
2313        - If `sorted` is set to 'False', it will use the AICPU operator and performance may be reduced.
2314
2315    If `input_x` is a one-dimensional Tensor, finds the `k` largest entries in the Tensor,
2316    and outputs their values and indices as Tensors. Therefore, values[`k-1`] is the `k`-th largest item in
2317    `input_x`, and its index is indices[`k-1`].
2318
2319    For a multi-dimensional matrix,
2320    calculates the first `k` entries in each row (corresponding vector along the last dimension), therefore:
2321
2322    .. math::
2323
2324        values.shape = indices.shape = input.shape[:-1] + [k].
2325
2326    If the two compared elements are the same, the one with the smaller index value is returned first.
2327
2328    Args:
2329        sorted (bool): If true, the obtained elements will
2330            be sorted by the values in descending order. Default: True.
2331
2332    Inputs:
2333        - **input_x** (Tensor) - Input to be computed, data type must be float16, float32 or int32.
2334        - **k** (int) - The number of top elements to be computed along the last dimension, constant input is needed.
2335
2336    Outputs:
2337        Tuple of 2 tensors, the values and the indices.
2338
2339        - **values** (Tensor) - The `k` largest elements in each slice along the last dimension.
2340        - **indices** (Tensor) - The indices of values within the last dimension of input.
2341
2342    Raises:
2343        TypeError: If `sorted` is not a bool.
2344        TypeError: If `input_x` is not a Tensor.
2345        TypeError: If `k` is not an int.
2346        TypeError: If dtype of `input_x` is not one of the following: float16, float32 or int32.
2347
2348    Supported Platforms:
2349        ``Ascend`` ``GPU`` ``CPU``
2350
2351    Examples:
2352        >>> topk = ops.TopK(sorted=True)
2353        >>> input_x = Tensor([1, 2, 3, 4, 5], mindspore.float16)
2354        >>> k = 3
2355        >>> values, indices = topk(input_x, k)
2356        >>> print((values, indices))
2357        (Tensor(shape=[3], dtype=Float16, value= [ 5.0000e+00,  4.0000e+00,  3.0000e+00]), Tensor(shape=[3],
2358          dtype=Int32, value= [4, 3, 2]))
2359    """
2360
2361    @prim_attr_register
2362    def __init__(self, sorted=True):
2363        """Initialize TopK."""
2364        self.sorted = validator.check_value_type("sorted", sorted, [bool], self.name)
2365        self.add_prim_attr("sorted", self.sorted)
2366        self.init_prim_io_names(inputs=['input', 'k'],
2367                                outputs=['values', 'indices'])
2368
2369    def __infer__(self, input_x, k):
2370        x_dtype = input_x['dtype']
2371        valid_dtypes = (mstype.int32, mstype.float16, mstype.float32)
2372        validator.check_tensor_dtype_valid('x', x_dtype, valid_dtypes, self.name)
2373        k_v = k['value']
2374        validator.check_value_type('k', k_v, (int,), self.name)
2375        x_shape = list(input_x['shape'])
2376        ndim = len(x_shape) - 1
2377        x_shape[ndim] = k_v
2378        return {'shape': (x_shape, x_shape),
2379                'dtype': (x_dtype, mstype.int32),
2380                'value': None}
2381
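# A rough NumPy sketch (illustration only) of the TopK semantics documented above:
# take the `k` largest values along the last dimension in descending order, breaking
# ties in favour of the smaller index. The helper name is an assumption of this sketch.
def _topk_reference(x, k):
    """Reference sketch: values and indices of the k largest entries on the last axis."""
    # A stable sort of -x keeps the smaller index first among equal values.
    order = np.argsort(-x, axis=-1, kind='stable')[..., :k]
    values = np.take_along_axis(x, order, axis=-1)
    return values, order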
2382
2383class NLLLoss(PrimitiveWithInfer):
2384    r"""
2385    Gets the negative log likelihood loss between logits and labels.
2386
2387    The nll loss with reduction=none can be described as:
2388
2389    .. math::
2390
2391        \ell(x, t)=L=\left\{l_{1}, \ldots, l_{N}\right\}^{\top},
2392        \quad l_{n}=-w_{t_{n}} x_{n, t_{n}},
2393        \quad w_{c}=\text { weight }[c] \cdot 1
2394
2395    where :math:`x` is the logits, :math:`t` is the labels, :math:`w` is the weight,
2396    :math:`N` is the batch size, :math:`c` belonging to :math:`[0, C-1]` is the class index, and :math:`C` is the number of classes.
2397
2398    If reduction is not 'none' (default 'mean'), then
2399
2400    .. math::
2401
2402        \ell(x, t)=\left\{\begin{array}{ll}
2403        \sum_{n=1}^{N} \frac{1}{\sum_{n=1}^{N} w_{t n}} l_{n}, & \text { if reduction }=\text { 'mean'; } \\
2404        \sum_{n=1}^{N} l_{n}, & \text { if reduction }=\text { 'sum' }
2405        \end{array}\right.
2406
2407    Args:
2408        reduction (str): Apply specific reduction method to the output: 'none', 'mean', 'sum', Default: "mean".
2409
2410    Inputs:
2411        - **logits** (Tensor) - Input logits, with shape :math:`(N, C)`. Data type only support float32 or float16.
2412        - **labels** (Tensor) - Ground truth labels, with shape :math:`(N,)`. Data type only support int32.
2413        - **weight** (Tensor) - The rescaling weight to each class, with shape :math:`(C,)` and data type only
2414          support float32 or float16.
2415
2416    Outputs:
2417        Tuple of 2 tensors composed with `loss` and `total_weight`.
2418
2419        - **loss** (Tensor) - When `reduction` is 'none' and `logits` is 2D tensor, the `loss` shape is :math:`(N,)`.
2420          Otherwise, the `loss` is a scalar. The data type is the same as that of `logits`.
2421        - **total_weight** (Tensor) - The `total_weight` is a scalar. The data type is the same as that of `weight`.
2422
2423    Raises:
2424        TypeError: If dtype of `logits` or `weight` is neither float16 nor float32, or dtype of `labels` is not int32.
2425        ValueError: If `logits` is not a one- or two-dimensional tensor, or `labels` and `weight` are not
2426                    one-dimensional tensors. When `logits` is a two-dimensional tensor, the first dimension of `logits`
2427                    must be equal to that of `labels`, and the second dimension of `logits` must be equal to the size
2428                    of `weight`. When `logits` is a one-dimensional tensor, the dimensions of `logits`, `labels`
2429                    and `weight` must be equal to each other.
2430
2431    Supported Platforms:
2432        ``Ascend`` ``GPU``
2433
2434    Examples:
2435        >>> logits = Tensor(np.array([[0.5488135, 0.71518934],
2436        ...                           [0.60276335, 0.5448832],
2437        ...                           [0.4236548, 0.6458941]]).astype(np.float32))
2438        >>> labels = Tensor(np.array([0, 0, 0]).astype(np.int32))
2439        >>> weight = Tensor(np.array([0.3834415, 0.79172504]).astype(np.float32))
2440        >>> nll_loss = ops.NLLLoss(reduction="mean")
2441        >>> loss, total_weight = nll_loss(logits, labels, weight)
2442        >>> print(loss)
2443        -0.52507716
2444        >>> print(total_weight)
2445        1.1503246
2446    """
2447
2448    @prim_attr_register
2449    def __init__(self, reduction="mean"):
2450        """Initialize NLLLoss"""
2451        self.init_prim_io_names(inputs=['x', 'target', "weight"], outputs=['loss'])
2452        self.reduction = validator.check_string(reduction, ['none', 'sum', 'mean'], 'reduction', self.name)
2453        self.add_prim_attr('reduction', self.reduction)
2454
2455    def infer_shape(self, x_shape, t_shape, w_shape):
2456        validator.check_int(len(x_shape), [1, 2], Rel.IN, "x rank", self.name)
2457        validator.check_int(len(t_shape), 1, Rel.EQ, "target rank", self.name)
2458        validator.check_int(len(w_shape), 1, Rel.EQ, "weight rank", self.name)
2459        validator.check(f"input_shape[0]", x_shape[0], "target_shape", t_shape[0], Rel.EQ, self.name)
2460        if len(x_shape) == 1:
2461            validator.check(f"input_shape[0]", x_shape[0], "weight_shape", w_shape[0], Rel.EQ, self.name)
2462        else:
2463            validator.check(f"input_shape[1]", x_shape[1], "weight_shape", w_shape[0], Rel.EQ, self.name)
2464        if self.reduction == "none":
2465            return t_shape, ()
2466        return (), ()
2467
2468    def infer_dtype(self, x_dtype, t_dtype, w_dtype):
2469        valid_dtypes = (mstype.float16, mstype.float32)
2470        validator.check_tensor_dtype_valid("x_dtype", x_dtype, valid_dtypes, self.name)
2471        validator.check_tensor_dtype_valid("t_dtype", t_dtype, mstype.int32, self.name)
2472        validator.check_tensor_dtype_valid("w_dtype", w_dtype, valid_dtypes, self.name)
2473        return x_dtype, w_dtype
2474
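# A small NumPy sketch (illustration only, not the device kernel) of the NLLLoss
# formulas above for 2-D logits. The logits are assumed to already be log-probabilities,
# and the helper name is an assumption of this sketch.
def _nll_loss_reference(logits, labels, weight, reduction="mean"):
    """Reference sketch of NLLLoss with logits of shape (N, C) and int labels of shape (N,)."""
    n = logits.shape[0]
    picked_w = weight[labels]                           # w_{t_n}
    losses = -picked_w * logits[np.arange(n), labels]   # l_n = -w_{t_n} * x_{n, t_n}
    total_weight = picked_w.sum()
    if reduction == "mean":
        return losses.sum() / total_weight, total_weight
    if reduction == "sum":
        return losses.sum(), total_weight
    return losses, total_weight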
2475
2476class SoftmaxCrossEntropyWithLogits(PrimitiveWithInfer):
2477    r"""
2478    Gets the softmax cross-entropy value between logits and labels with one-hot encoding.
2479
2480    The updating formulas of SoftmaxCrossEntropyWithLogits algorithm are as follows,
2481
2482    .. math::
2483        \begin{array}{ll} \\
2484            p_{ij} = softmax(X_{ij}) = \frac{\exp(x_i)}{\sum_{j = 0}^{N-1}\exp(x_j)} \\
2485            loss_{ij} = -\sum_j{Y_{ij} * ln(p_{ij})}
2486        \end{array}
2487
2488    where :math:`X` represents `logits`.
2489    :math:`Y` represents `label`.
2490    :math:`loss` represents `output`.
2491
2492    Inputs:
2493        - **logits** (Tensor) - Input logits, with shape :math:`(N, C)`. Data type must be float16 or float32.
2494        - **labels** (Tensor) - Ground truth labels, with shape :math:`(N, C)`, has the same data type as `logits`.
2495
2496    Outputs:
2497        Tuple of 2 tensors(loss, dlogits), the `loss` shape is :math:`(N,)`,
2498        and the `dlogits` with the same shape as `logits`.
2499
2500    Raises:
2501        TypeError: If dtype of `logits` or `labels` is neither float16 nor float32.
2502        TypeError: If `logits` or `labels` is not a Tensor.
2503        ValueError: If shape of `logits` is not the same as `labels`.
2504
2505    Supported Platforms:
2506        ``Ascend`` ``GPU`` ``CPU``
2507
2508    Examples:
2509        >>> logits = Tensor([[2, 4, 1, 4, 5], [2, 1, 2, 4, 3]], mindspore.float32)
2510        >>> labels = Tensor([[0, 0, 0, 0, 1], [0, 0, 0, 1, 0]], mindspore.float32)
2511        >>> softmax_cross = ops.SoftmaxCrossEntropyWithLogits()
2512        >>> loss, dlogits = softmax_cross(logits, labels)
2513        >>> print(loss)
2514        [0.5899297  0.52374405]
2515        >>> print(dlogits)
2516        [[ 0.02760027  0.20393994  0.01015357  0.20393994 -0.44563377]
2517         [ 0.08015892  0.02948882  0.08015892 -0.4077012   0.21789455]]
2518    """
2519
2520    @prim_attr_register
2521    def __init__(self):
2522        pass
2523
2524    def infer_shape(self, logits_shape, labels_shape):
2525        validator.check("logits_shape", logits_shape, "labels_shape", labels_shape, Rel.EQ, self.name)
2526        loss_shape = [logits_shape[0]]
2527        dlogits_shape = logits_shape
2528        return loss_shape, dlogits_shape
2529
2530    def infer_dtype(self, logits_type, labels_type):
2531        args = {"logits": logits_type, "labels": labels_type}
2532        validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
2533        return logits_type, logits_type
2534
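# A NumPy sketch (illustration only) of the loss/gradient pair documented above:
# p = softmax(logits) computed row-wise, loss_i = -sum_j Y_ij * ln(p_ij), and the
# returned dlogits equals p - Y for one-hot labels. The helper name is an assumption.
def _softmax_cross_entropy_reference(logits, labels):
    """Reference sketch for one-hot `labels` with shape (N, C)."""
    shifted = logits - logits.max(axis=-1, keepdims=True)   # for numerical stability
    p = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)
    loss = -(labels * np.log(p)).sum(axis=-1)
    return loss, p - labels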
2535
2536class SparseSoftmaxCrossEntropyWithLogits(PrimitiveWithInfer):
2537    r"""
2538    Computes the softmax cross-entropy value between logits and sparse encoding labels.
2539
2540    Sets input logits as `X`, input label as `Y`, output as `loss`. Then,
2541
2542    .. math::
2543        \begin{array}{ll} \\
2544            p_{ij} = softmax(X_{ij}) = \frac{\exp(x_i)}{\sum_{j = 0}^{N-1}\exp(x_j)} \\
2545            loss_{ij} = \begin{cases} -ln(p_{ij}), &j = y_i \cr -ln(1 - p_{ij}), & j \neq y_i \end{cases} \\
2546            loss = \sum_{ij} loss_{ij}
2547        \end{array}
2548
2549    Args:
2550        is_grad (bool): If true, this operation returns the computed gradient. Default: False.
2551
2552    Inputs:
2553        - **logits** (Tensor) - Input logits, with shape :math:`(N, C)`. Data type must be float16 or float32.
2554        - **labels** (Tensor) - Ground truth labels, with shape :math:`(N)`.
2555          Data type must be int32 or int64.
2556
2557    Outputs:
2558        Tensor, if `is_grad` is False, the output tensor is the value of loss which is a scalar tensor;
2559        if `is_grad` is True, the output tensor is the gradient of input with the same shape as `logits`.
2560
2561    Raises:
2562        TypeError: If `is_grad` is not a bool.
2563        TypeError: If dtype of `logits` is neither float16 nor float32.
2564        TypeError: If dtype of `labels` is neither int32 nor int64.
2565        ValueError: If logits.shape[0] != labels.shape[0].
2566
2567    Supported Platforms:
2568        ``GPU`` ``CPU``
2569
2570    Examples:
2571        >>> logits = Tensor([[2, 3, 1, 4, 5], [2, 1, 2, 4, 3]], mindspore.float32)
2572        >>> labels = Tensor([0, 1], mindspore.int32)
2573        >>> sparse_softmax_cross = ops.SparseSoftmaxCrossEntropyWithLogits()
2574        >>> loss = sparse_softmax_cross(logits, labels)
2575        >>> print(loss)
2576        3.4878292
2577        >>> sparse_softmax_cross_grad = ops.SparseSoftmaxCrossEntropyWithLogits(is_grad=True)
2578        >>> loss_grad = sparse_softmax_cross_grad(logits, labels)
2579        >>> print(loss_grad)
2580        [[-0.48415753  0.04306427  0.00582811  0.11706084  0.3182043 ]
2581         [ 0.04007946 -0.4852556   0.04007946  0.2961494   0.10894729]]
2582    """
2583
2584    @prim_attr_register
2585    def __init__(self, is_grad=False):
2586        """Initialize SparseSoftmaxCrossEntropyWithLogits."""
2587        validator.check_value_type('is_grad', is_grad, [bool], self.name)
2588        self.init_prim_io_names(inputs=['features', 'labels'], outputs=['output'])
2589        self.is_grad = is_grad
2590        self.add_prim_attr('sens', 1.0)
2591
2592    def infer_shape(self, logits_shape, labels_shape):
2593        validator.check("logits_shape[0]", logits_shape[0], "labels_shape[0]", labels_shape[0], Rel.EQ, self.name)
2594        loss_shape = []
2595        if self.is_grad:
2596            return logits_shape
2597        return loss_shape
2598
2599    def infer_dtype(self, logits_type, labels_type):
2600        validator.check_tensor_dtype_valid("logits", logits_type, (mstype.float16, mstype.float32),
2601                                           self.name)
2602        validator.check_tensor_dtype_valid("labels", labels_type, (mstype.int32, mstype.int64), self.name)
2603        return logits_type
2604
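# A NumPy sketch (illustration only) consistent with the numerical example above:
# the loss is the batch mean of -ln(p[i, y_i]) and the gradient is (softmax - one_hot) / N.
# The helper name and the mean reduction are assumptions made for this sketch.
def _sparse_softmax_cross_entropy_reference(logits, labels, is_grad=False):
    """Reference sketch with logits of shape (N, C) and int labels of shape (N,)."""
    n = logits.shape[0]
    shifted = logits - logits.max(axis=-1, keepdims=True)
    p = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)
    if is_grad:
        grad = p.copy()
        grad[np.arange(n), labels] -= 1.0
        return grad / n
    return -np.log(p[np.arange(n), labels]).mean()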
2605
2606class ApplyMomentum(PrimitiveWithInfer):
2607    """
2608    Optimizer that implements the Momentum algorithm.
2609
2610    Refer to the paper `On the importance of initialization and momentum in deep
2611    learning <https://dl.acm.org/doi/10.5555/3042817.3043064>`_  for more details.
2612
2613    Refer to :class:`mindspore.nn.Momentum` for more details about the formula and usage.
2614
2615    Inputs of `variable`, `accumulation` and `gradient` comply with the implicit type conversion rules
2616    to make the data types consistent.
2617    If they have different data types, the lower-priority data type will be converted to
2618    the relatively highest-priority data type.
2619    Data type conversion of Parameter is not supported; a RuntimeError will be thrown.
2620
2621    Args:
2622        use_locking (bool): Whether to enable a lock to protect the variable and accumulation tensors
2623                            from being updated. Default: False.
2624        use_nesterov (bool): Enable Nesterov momentum. Default: False.
2625        gradient_scale (float): The scale of the gradient. Default: 1.0.
2626
2627    Inputs:
2628        - **variable** (Parameter) - Weights to be updated. Data type must be float.
2629        - **accumulation** (Parameter) - Accumulated gradient value by moment weight.
2630          Has the same data type with `variable`.
2631        - **learning_rate** (Union[Number, Tensor]) - The learning rate value, must be a float number or
2632          a scalar tensor with float data type.
2633        - **gradient** (Tensor) - Gradient, has the same data type as `variable`.
2634        - **momentum** (Union[Number, Tensor]) - Momentum, must be a float number or
2635          a scalar tensor with float data type.
2636
2637    Outputs:
2638        Tensor, parameters to be updated.
2639
2640    Raises:
2641        TypeError: If `use_locking` or `use_nesterov` is not a bool, or `gradient_scale` is not a float.
2642
2643    Supported Platforms:
2644        ``Ascend`` ``GPU`` ``CPU``
2645
2646    Examples:
2647        Please refer to the usage in :class:`mindspore.nn.Momentum`.
2648    """
2649    __mindspore_signature__ = (
2650        sig.make_sig('variable', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
2651        sig.make_sig('accumulation', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
2652        sig.make_sig('learning_rate', dtype=sig.sig_dtype.T1),
2653        sig.make_sig('gradient', dtype=sig.sig_dtype.T),
2654        sig.make_sig('momentum', dtype=sig.sig_dtype.T2)
2655    )
2656
2657    @prim_attr_register
2658    def __init__(self, use_nesterov=False, use_locking=False, gradient_scale=1.0):
2659        """Initialize ApplyMomentum."""
2660        self.use_nesterov = validator.check_bool(use_nesterov, "use_nesterov", self.name)
2661        self.use_locking = validator.check_bool(use_locking, "use_locking", self.name)
2662        validator.check_value_type('gradient_scale', gradient_scale, [float], self.name)
2663        self.init_prim_io_names(inputs=['variable', 'accumulation', 'learning_rate', 'gradient', 'momentum'],
2664                                outputs=['output'])
2665        self.add_prim_attr('side_effect_mem', True)
2666
2667    def infer_shape(self, v_shape, a_shape, l_shape, g_shape, m_shape):
2668        return v_shape
2669
2670    def infer_dtype(self, v_dtype, a_dtype, l_dtype, g_dtype, m_dtype):
2671        valid_dtypes = [mstype.float16, mstype.float32, mstype.float64]
2672        if v_dtype != mstype.type_refkey and a_dtype != mstype.type_refkey:
2673            validator.check_tensor_dtype_valid("v", v_dtype, valid_dtypes, self.name)
2674            validator.check_tensor_dtype_valid("a", a_dtype, valid_dtypes, self.name)
2675        validator.check_scalar_or_tensor_types_same({"l_dtype": l_dtype}, valid_dtypes, self.name)
2676        validator.check_scalar_or_tensor_types_same({"g_dtype": g_dtype}, valid_dtypes, self.name)
2677        validator.check_scalar_or_tensor_types_same({"m_dtype": m_dtype}, valid_dtypes, self.name)
2678        return v_dtype
2679
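# A rough NumPy sketch (illustration only, not the device kernel) of the conventional
# momentum update that this operator refers to (see :class:`mindspore.nn.Momentum`).
# The helper name and the exact update order are assumptions of this sketch.
def _apply_momentum_reference(variable, accumulation, learning_rate, gradient,
                              momentum, use_nesterov=False):
    """Reference sketch: returns the updated (variable, accumulation)."""
    accumulation = accumulation * momentum + gradient
    if use_nesterov:
        variable = variable - (gradient + accumulation * momentum) * learning_rate
    else:
        variable = variable - learning_rate * accumulation
    return variable, accumulation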
2680
2681class SmoothL1Loss(PrimitiveWithInfer):
2682    r"""
2683    Computes smooth L1 loss, a robust L1 loss.
2684
2685    SmoothL1Loss is a Loss similar to MSELoss but less sensitive to outliers as described in the
2686    `Fast R-CNN <https://arxiv.org/abs/1504.08083>`_ by Ross Girshick.
2687
2688    Given two inputs :math:`x,\  y` of length :math:`N`, the unreduced SmoothL1Loss can be described
2689    as follows:
2690
2691    .. math::
2692        L_{i} =
2693        \begin{cases}
2694        \frac{0.5 (x_i - y_i)^{2}}{\text{beta}}, & \text{if } |x_i - y_i| < \text{beta} \\
2695        |x_i - y_i| - 0.5 \text{beta}, & \text{otherwise. }
2696        \end{cases}
2697
2698    Here :math:`\text{beta}` controls the point where the loss function changes from quadratic to linear.
2699    Its default value is 1.0. :math:`N` is the batch size. This function returns an
2700    unreduced loss Tensor.
2701
2702    .. warning::
2703        This operator does not perform the "reduce" operation on the loss value.
2704        Call other reduce operators to perform "reduce" operation on the loss if required.
2705
2706    Args:
2707        beta (float): A parameter used to control the point where the function will change from
2708            quadratic to linear. Default: 1.0.
2709
2710    Inputs:
2711        - **logits** (Tensor) - Tensor of shape :math:`(N, *)` where :math:`*` means, any number of
2712          additional dimensions. Data type must be float16 or float32.
2713        - **labels** (Tensor) - Ground truth data, tensor of shape :math:`(N, *)`,
2714          same shape and dtype as the `logits`.
2715
2716    Outputs:
2717        Tensor, loss float tensor, same shape and dtype as the `logits`.
2718
2719    Raises:
2720        TypeError: If `beta` is not a float.
2721        TypeError: If dtype of `logits` or `labels` is neither float16 nor float32.
2722        ValueError: If `beta` is less than or equal to 0.
2723        ValueError: If shape of `logits` is not the same as `labels`.
2724
2725    Supported Platforms:
2726        ``Ascend`` ``GPU`` ``CPU``
2727
2728    Examples:
2729        >>> loss = ops.SmoothL1Loss()
2730        >>> logits = Tensor(np.array([1, 2, 3]), mindspore.float32)
2731        >>> labels = Tensor(np.array([1, 2, 2]), mindspore.float32)
2732        >>> output = loss(logits, labels)
2733        >>> print(output)
2734        [0.  0.  0.5]
2735    """
2736
2737    @prim_attr_register
2738    def __init__(self, beta=1.0):
2739        """Initialize SmoothL1Loss."""
2740        validator.check_value_type('beta', beta, [float], self.name)
2741        validator.check('beta', beta, '', 0, Rel.GT, self.name)
2742        self.init_prim_io_names(inputs=['prediction', 'target'], outputs=['output'])
2743
2744    def infer_shape(self, prediction, target):
2745        validator.check('prediction shape', prediction, 'target shape', target, Rel.EQ, self.name)
2746        return prediction
2747
2748    def infer_dtype(self, prediction, target):
2749        args = {"prediction": prediction, "target": target}
2750        validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
2751        return prediction
2752
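# A NumPy sketch (illustration only) of the piecewise definition above: quadratic
# where |x - y| < beta, linear elsewhere, with no reduction applied. For the docstring
# example ([1, 2, 3] vs [1, 2, 2], beta=1.0) it yields [0., 0., 0.5].
def _smooth_l1_loss_reference(logits, labels, beta=1.0):
    """Reference sketch of the unreduced smooth L1 loss."""
    diff = np.abs(logits - labels)
    return np.where(diff < beta, 0.5 * diff ** 2 / beta, diff - 0.5 * beta)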
2753
2754class SoftMarginLoss(Primitive):
2755    r"""
2756    SoftMarginLoss operation.
2757
2758    Creates a criterion that optimizes a two-class classification
2759    logistic loss between input tensor :math:`x` and target tensor :math:`y`
2760    (containing 1 or -1).
2761
2762    .. math::
2763        \text{loss}(x, y) = \sum_i \frac{\log(1 + \exp(-y[i]*x[i]))}{\text{x.nelement}()}
2764
2765    Args:
2766        reduction (str): Apply specific reduction method to the output: 'none', 'mean', 'sum'. Default: "mean".
2767
2768    Inputs:
2769        - **logits** (Tensor) - Predict data. Data type must be float16 or float32.
2770        - **labels** (Tensor) - Ground truth data, with the same type and shape as `logits`.
2771
2772    Outputs:
2773        Tensor or Scalar, if `reduction` is "none", its shape is the same as `logits`.
2774        Otherwise, a scalar value will be returned.
2775
2776    Raises:
2777        TypeError: If `logits` or `labels` is not a Tensor.
2778        TypeError: If dtype of `logits` or `labels` is neither float16 nor float32.
2779        ValueError: If shape of `logits` is not the same as `labels`.
2780        ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.
2781
2782    Supported Platforms:
2783        ``Ascend``
2784
2785    Examples:
2786        >>> loss = ops.SoftMarginLoss()
2787        >>> logits = Tensor(np.array([[0.3, 0.7], [0.5, 0.5]]), mindspore.float32)
2788        >>> labels = Tensor(np.array([[-1, 1], [1, -1]]), mindspore.float32)
2789        >>> output = loss(logits, labels)
2790        >>> print(output)
2791        0.6764238
2792    """
2793
2794    @prim_attr_register
2795    def __init__(self, reduction="mean"):
2796        """Initialize SoftMarginLoss"""
2797        self.init_prim_io_names(inputs=['predict', 'label'], outputs=['loss'])
2798        self.reduction = validator.check_string(reduction, ['none', 'sum', 'mean'], 'reduction', self.name)
2799
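# A NumPy sketch (illustration only) of the soft margin loss above: element-wise
# log(1 + exp(-y * x)) followed by the requested reduction. The helper name is an
# assumption of this sketch.
def _soft_margin_loss_reference(logits, labels, reduction="mean"):
    """Reference sketch of SoftMarginLoss for labels containing 1 or -1."""
    losses = np.log1p(np.exp(-labels * logits))
    if reduction == "mean":
        return losses.mean()
    if reduction == "sum":
        return losses.sum()
    return losses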
2800
2801class L2Loss(PrimitiveWithInfer):
2802    """
2803    Calculates half of the L2 norm of a tensor without using the `sqrt`.
2804
2805    Set `input_x` as x and output as loss.
2806
2807    .. math::
2808        loss = sum(x ** 2) / 2
2809
2810    Inputs:
2811        - **input_x** (Tensor) - An input Tensor. Data type must be float16 or float32.
2812
2813    Outputs:
2814        Tensor, has the same dtype as `input_x`. The output tensor is the value of loss which is a scalar tensor.
2815
2816    Raises:
2817        TypeError: If `input_x` is not a Tensor.
2818        TypeError: If dtype of `input_x` is neither float16 nor float32.
2819
2820    Supported Platforms:
2821        ``Ascend`` ``GPU`` ``CPU``
2822
2823    Examples:
2824        >>> input_x = Tensor(np.array([1, 2, 3]), mindspore.float16)
2825        >>> l2_loss = ops.L2Loss()
2826        >>> output = l2_loss(input_x)
2827        >>> print(output)
2828        7.0
2829    """
2830
2831    @prim_attr_register
2832    def __init__(self):
2833        """Initialize L2Loss"""
2834
2835    def infer_shape(self, input_x):
2836        loss_shape = []
2837        return loss_shape
2838
2839    def infer_dtype(self, x_type):
2840        valid_dtypes = [mstype.float16, mstype.float32]
2841        validator.check_tensor_dtype_valid('x_type', x_type, valid_dtypes, self.name)
2842        return x_type
2843
2844
2845class DataFormatDimMap(PrimitiveWithInfer):
2846    """
2847    Returns the dimension index in the destination data format given in the source data format.
2848
2849    Args:
2850        src_format (str): An optional value for the source data format. The format can be 'NHWC' or 'NCHW'.
2851            Default: 'NHWC'.
2852        dst_format (str): An optional value for the destination data format. The format can be 'NHWC' or 'NCHW'.
2853            Default: 'NCHW'.
2854
2855    Inputs:
2856        - **input_x** (Tensor) - A Tensor with each element as a dimension index in source data format.
2857          The suggested values are in the range [-4, 4). Only int32 is supported.
2858
2859    Outputs:
2860        Tensor, the dimension index in the destination data format,
2861        with the same data type and shape as `input_x`.
2862
2863    Raises:
2864        TypeError: If `src_format` or `dst_format` is not a str.
2865        TypeError: If `input_x` is not a Tensor or its dtype is not int32.
2866
2867    Supported Platforms:
2868        ``Ascend``
2869
2870    Examples:
2871        >>> input_x = Tensor([0, 1, 2, 3], mindspore.int32)
2872        >>> dfdm = ops.DataFormatDimMap()
2873        >>> output = dfdm(input_x)
2874        >>> print(output)
2875        [0 3 1 2]
2876    """
2877
2878    @prim_attr_register
2879    def __init__(self, src_format='NHWC', dst_format='NCHW'):
2880        """Initialize DataFormatDimMap."""
2881        valid_values = ['NHWC', 'NCHW']
2882        self.src_format = validator.check_string(src_format, valid_values, "src_format", self.name)
2883        self.dst_format = validator.check_string(dst_format, valid_values, "dst_format", self.name)
2884        self.init_prim_io_names(inputs=['input_x'], outputs=['output'])
2885
2886    def infer_shape(self, x_shape):
2887        return x_shape
2888
2889    def infer_dtype(self, x_dtype):
2890        valid_dtypes = [mstype.int32]
2891        validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name)
2892        return x_dtype
2893
2894
2895class RNNTLoss(PrimitiveWithInfer):
2896    """
2897    Computes the RNNTLoss and its gradient with respect to the softmax outputs.
2898
2899    Args:
2900        blank_label (int): blank label. Default: 0.
2901
2902    Inputs:
2903        - **acts** (Tensor) - Tensor of shape :math:`(B, T, U, V)`. Data type must be float16 or float32.
2904        - **labels** (Tensor) - Tensor of shape :math:`(B, U-1)`. Data type is int32.
2905        - **input_lengths** (Tensor) - Tensor of shape :math:`(B,)`. Data type is int32.
2906        - **label_lengths** (Tensor) - Tensor of shape :math:`(B,)`. Data type is int32.
2907
2908    Outputs:
2909        - **costs** (Tensor) - Tensor of shape :math:`(B,)`. Has the same data type as `acts`.
2910        - **grads** (Tensor) - Has the same shape and dtype as `acts`.
2911
2912    Raises:
2913        TypeError: If `acts`, `labels`, `input_lengths` or `label_lengths` is not a Tensor.
2914        TypeError: If dtype of `acts` is neither float16 nor float32.
2915        TypeError: If dtype of `labels`, `input_lengths` or `label_lengths` is not int32.
2916
2917    Supported Platforms:
2918        ``Ascend``
2919
2920    Examples:
2921        >>> B, T, U, V = 1, 2, 3, 5
2922        >>> blank = 0
2923        >>> acts = np.random.random((B, T, U, V)).astype(np.float32)
2924        >>> labels = np.array([[1, 2]]).astype(np.int32)
2925        >>> input_length = np.array([T] * B).astype(np.int32)
2926        >>> label_length = np.array([len(l) for l in labels]).astype(np.int32)
2927        >>> rnnt_loss = ops.RNNTLoss(blank_label=0)
2928        >>> costs, grads = rnnt_loss(Tensor(acts), Tensor(labels), Tensor(input_length), Tensor(label_length))
2929        >>> print(costs.shape)
2930        (1,)
2931        >>> print(grads.shape)
2932        (1, 2, 3, 5)
2933    """
2934
2935    @prim_attr_register
2936    def __init__(self, blank_label=0):
2937        """Initialize RNNTLoss."""
2938        validator.check_value_type('blank_label', blank_label, [int], self.name)
2939        self.init_prim_io_names(inputs=['acts', 'labels', 'input_length', 'label_length'],
2940                                outputs=['costs', 'grads'])
2941
2942    def infer_shape(self, acts_shape, labels_shape, input_length_shape, label_length_shape):
2943        validator.check_equal_int(len(acts_shape), 4, 'acts_rank', self.name)
2944        validator.check_equal_int(len(labels_shape), 2, 'labels_rank', self.name)
2945        validator.check_equal_int(len(input_length_shape), 1, 'input_length_rank', self.name)
2946        validator.check_equal_int(len(label_length_shape), 1, 'label_length_rank', self.name)
2947        validator.check('labels shape[0]', labels_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name)
2948        validator.check('labels shape[1]', labels_shape[1], 'acts shape[2]-1', acts_shape[2] - 1, Rel.EQ, self.name)
2949        validator.check('input_length size', input_length_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name)
2950        validator.check('label_length size', label_length_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name)
2951        costs_shape = (acts_shape[0],)
2952        return costs_shape, acts_shape
2953
2954    def infer_dtype(self, acts_type, labels_type, input_length_type, label_length_type):
2955        validator.check_tensor_dtype_valid("acts_type", acts_type, [mstype.float32, mstype.float16], self.name)
2956        tuple(map(partial(validator.check_tensor_dtype_valid,
2957                          valid_dtypes=(mstype.int32,), prim_name=self.name),
2958                  ("labels", "input_length", "label_length"),
2959                  (labels_type, input_length_type, label_length_type)))
2960        return acts_type, acts_type
2961
2962
2963class SGD(PrimitiveWithCheck):
2964    """
2965    Computes the stochastic gradient descent. Momentum is optional.
2966
2967    Nesterov momentum is based on the formula from paper `On the importance of
2968    initialization and momentum in deep learning <http://proceedings.mlr.press/v28/sutskever13.html>`_.
2969
2970    Note:
2971        For more details, please refer to :class:`nn.SGD`.
2972
2973    Args:
2974        dampening (float): The dampening for momentum. Default: 0.0.
2975        weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
2976        nesterov (bool): Enable Nesterov momentum. Default: False.
2977
2978    Inputs:
2979        - **parameters** (Tensor) - Parameters to be updated. With float16 or float32 data type.
2980        - **gradient** (Tensor) - Gradient, with float16 or float32 data type.
2981        - **learning_rate** (Tensor) - Learning rate, a scalar tensor with float16 or float32 data type.
2982          e.g. Tensor(0.1, mindspore.float32)
2983        - **accum** (Tensor) - Accum(velocity) to be updated. With float16 or float32 data type.
2984        - **momentum** (Tensor) - Momentum, a scalar tensor with float16 or float32 data type.
2985          e.g. Tensor(0.1, mindspore.float32).
2986        - **stat** (Tensor) - States to be updated with the same shape as gradient, with float16 or float32 data type.
2987
2988    Outputs:
2989        Tensor, parameters to be updated.
2990
2991    Raises:
2992        TypeError: If `dampening` or `weight_decay` is not a float.
2993        TypeError: If `nesterov` is not a bool.
2994        TypeError: If `parameters`, `gradient`, `learning_rate`, `accum`, `momentum` or `stat` is not a Tensor.
2995        TypeError: If dtype of `parameters`, `gradient`, `learning_rate`, `accum`, `momentum` or `stat` is neither
2996                   float16 nor float32.
2997
2998    Supported Platforms:
2999        ``Ascend`` ``GPU`` ``CPU``
3000
3001    Examples:
3002        >>> sgd = ops.SGD()
3003        >>> parameters = Tensor(np.array([2, -0.5, 1.7, 4]), mindspore.float32)
3004        >>> gradient = Tensor(np.array([1, -1, 0.5, 2]), mindspore.float32)
3005        >>> learning_rate = Tensor(0.01, mindspore.float32)
3006        >>> accum = Tensor(np.array([0.1, 0.3, -0.2, -0.1]), mindspore.float32)
3007        >>> momentum = Tensor(0.1, mindspore.float32)
3008        >>> stat = Tensor(np.array([1.5, -0.3, 0.2, -0.7]), mindspore.float32)
3009        >>> output = sgd(parameters, gradient, learning_rate, accum, momentum, stat)
3010        >>> print(output)
3011        (Tensor(shape=[4], dtype=Float32,
3012         value= [ 1.98989999e+00, -4.90300000e-01,  1.69520009e+00,  3.98009992e+00]),)
3013    """
3014
3015    @prim_attr_register
3016    def __init__(self, dampening=0.0, weight_decay=0.0, nesterov=False):
3017        """Initialize SGD."""
3018        validator.check_value_type("nesterov", nesterov, [bool], self.name)
3019        if nesterov and dampening != 0:
3020            raise ValueError(f"For '{self.name}', the 'dampening' must be 0 when 'nesterov' is True, "
3021                             f"but got 'dampening' is {dampening} and 'nesterov' is {nesterov}.")
3022        self.init_prim_io_names(inputs=['parameters', 'gradient', 'learning_rate', 'accum', 'momentum', 'stat'],
3023                                outputs=['output'])
3024        self.add_prim_attr('side_effect_mem', True)
3025
3026    def check_shape(self, parameters_shape, gradient_shape, learning_rate_shape,
3027                    accum_shape, momentum_shape, stat_shape):
3028        validator.check_positive_int(len(parameters_shape), "parameters rank", self.name)
3029        validator.check_int(len(gradient_shape), 0, Rel.GE, f'gradient rank', self.name)
3030        validator.check_int(len(learning_rate_shape), 0, Rel.GE, f'learning rate rank', self.name)
3031        validator.check_positive_int(len(accum_shape), "accumulation rank", self.name)
3032        validator.check_int(len(momentum_shape), 0, Rel.GE, f'momentum rank', self.name)
3033        validator.check_int(len(stat_shape), 0, Rel.GE, f'stat rank', self.name)
3034        validator.check("gradient shape", gradient_shape, "stat shape", stat_shape, Rel.EQ, self.name)
3035
3036    def check_dtype(self, parameters_dtype, gradient_dtype, learning_rate_dtype,
3037                    accum_dtype, momentum_dtype, stat_dtype):
3038        tuple(map(partial(validator.check_tensor_dtype_valid,
3039                          valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
3040                  ("parameters", "gradient", "learning_rate", "accum", "momentum", "stat"),
3041                  (parameters_dtype, gradient_dtype, learning_rate_dtype, accum_dtype, momentum_dtype, stat_dtype)))
3042
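# A NumPy sketch (illustration only) of the plain SGD-with-momentum step that matches
# the numerical example above (dampening=0, weight_decay=0, nesterov=False). The `stat`
# input, which tracks the first update for Nesterov momentum, is intentionally left out
# of this simplified sketch; the helper name is an assumption.
def _sgd_reference(parameters, gradient, learning_rate, accum, momentum,
                   dampening=0.0, weight_decay=0.0):
    """Reference sketch of one SGD-with-momentum step; returns (parameters, accum)."""
    if weight_decay != 0.0:
        gradient = gradient + weight_decay * parameters
    accum = accum * momentum + (1.0 - dampening) * gradient
    parameters = parameters - learning_rate * accum
    return parameters, accum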
3043
3044class ApplyRMSProp(PrimitiveWithInfer):
3045    r"""
3046    Optimizer that implements the Root Mean Square Propagation (RMSProp) algorithm.
3047    Please refer to the usage in source code of :class:`nn.RMSProp`.
3048
3049    The updating formulas of ApplyRMSProp algorithm are as follows,
3050
3051    .. math::
3052        \begin{array}{ll} \\
3053            s_{t+1} = \rho s_{t} + (1 - \rho)(\nabla Q_{i}(w))^2 \\
3054            m_{t+1} = \beta m_{t} + \frac{\eta} {\sqrt{s_{t+1} + \epsilon}} \nabla Q_{i}(w) \\
3055            w = w - m_{t+1}
3056        \end{array}
3057
3058    where :math:`w` represents `var`, which will be updated.
3059    :math:`s_{t+1}` represents `mean_square`, :math:`s_{t}` is the last moment of :math:`s_{t+1}`,
3060    :math:`m_{t+1}` represents `moment`, :math:`m_{t}` is the last moment of :math:`m_{t+1}`.
3061    :math:`\rho` represents `decay`. :math:`\beta` is the momentum term, represents `momentum`.
3062    :math:`\epsilon` is a smoothing term to avoid division by zero, represents `epsilon`.
3063    :math:`\eta` represents `learning_rate`. :math:`\nabla Q_{i}(w)` represents `grad`.
3064
3065    .. warning::
3066        Note that in the dense implementation of this algorithm, "mean_square" and "moment" will update even if "grad" is 0,
3067        but in the sparse implementation, "mean_square" and "moment" will not update
3068        in iterations during which "grad" is 0.
3069
3070    Args:
3071        use_locking (bool): Whether to enable a lock to protect the variable and accumulation tensors
3072                            from being updated. Default: False.
3073
3074    Inputs:
3075        - **var** (Tensor) - Weights to be updated.
3076        - **mean_square** (Tensor) - Mean square gradients, must have the same type as `var`.
3077        - **moment** (Tensor) - Delta of `var`, must have the same type as `var`.
3078        - **learning_rate** (Union[Number, Tensor]) - Learning rate. Must be a float number or
3079          a scalar tensor with float16 or float32 data type.
3080        - **grad** (Tensor) - Gradient, must have the same type as `var`.
3081        - **decay** (float) - Decay rate. Only constant value is allowed.
3082        - **momentum** (float) - Momentum. Only constant value is allowed.
3083        - **epsilon** (float) - Ridge term. Only constant value is allowed.
3084
3085    Outputs:
3086        Tensor, parameters to be updated.
3087
3088    Raises:
3089        TypeError: If `use_locking` is not a bool.
3090        TypeError: If `var`, `mean_square`, `moment` or `decay` is not a Tensor.
3091        TypeError: If `learning_rate` is neither a Number nor a Tensor.
3092        TypeError: If dtype of `decay`, `momentum` or `epsilon` is not float.
3093        TypeError: If dtype of `learning_rate` is neither float16 nor float32.
3094        ValueError: If `decay`, `momentum` or `epsilon` is not a constant value.
3095
3096    Supported Platforms:
3097        ``Ascend`` ``GPU`` ``CPU``
3098
3099    Examples:
3100        >>> class Net(nn.Cell):
3101        ...     def __init__(self):
3102        ...         super(Net, self).__init__()
3103        ...         self.apply_rms_prop = ops.ApplyRMSProp()
3104        ...         self.var = Parameter(Tensor(np.ones([2, 2]).astype(np.float32)), name="var")
3105        ...
3106        ...     def construct(self, mean_square, moment, grad, decay, momentum, epsilon, lr):
3107        ...         out = self.apply_rms_prop(self.var, mean_square, moment, lr, grad, decay, momentum, epsilon)
3108        ...         return out
3109        ...
3110        >>> net = Net()
3111        >>> mean_square = Tensor(np.ones([2, 2]).astype(np.float32))
3112        >>> moment = Tensor(np.ones([2, 2]).astype(np.float32))
3113        >>> grad = Tensor(np.ones([2, 2]).astype(np.float32))
3114        >>> output = net(mean_square, moment, grad, 0.0, 1e-10, 0.001, 0.01)
3115        >>> print(net.var.asnumpy())
3116        [[0.990005  0.990005]
3117         [0.990005  0.990005]]
3118    """
3119
3120    @prim_attr_register
3121    def __init__(self, use_locking=False):
3122        """Initialize ApplyRMSProp."""
3123        self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
3124        self.init_prim_io_names(inputs=['var', 'mean_square', 'moment', 'learning_rate', 'grad',
3125                                        'rho', 'momentum', 'epsilon'], outputs=['output'])
3126        self.add_prim_attr('side_effect_mem', True)
3127
3128    def infer_shape(self, var_shape, mean_square_shape, moment_shape, learning_rate_shape, grad_shape, decay_shape,
3129                    momentum_shape, epsilon_shape):
3130        validator.check("var_shape", var_shape, "mean_square_shape", mean_square_shape, Rel.EQ, self.name)
3131        validator.check("var_shape", var_shape, "moment_shape", moment_shape, Rel.EQ, self.name)
3132        validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name)
3133        return var_shape
3134
3135    def infer_dtype(self, var_dtype, mean_square_dtype, moment_dtype, learning_rate_dtype, grad_dtype, decay_dtype,
3136                    momentum_dtype, epsilon_dtype):
3137        args = {"var": var_dtype, "mean_square": mean_square_dtype, "moment": moment_dtype, "grad": grad_dtype}
3138        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
3139
3140        valid_dtypes = [mstype.float16, mstype.float32]
3141        args_decay = {"decay": decay_dtype, 'momentum': momentum_dtype, "epsilon": epsilon_dtype}
3142        validator.check_types_same_and_valid(args_decay, valid_dtypes, self.name)
3143        args_lr = {"learning_rate": learning_rate_dtype, "decay": decay_dtype}
3144        validator.check_scalar_or_tensor_types_same(args_lr, valid_dtypes, self.name, allow_mix=True)
3145        return var_dtype
3146
3147    def infer_value(self, var, mean_square, moment, learning_rate, grad, decay, momentum, epsilon):
3148        if decay is None or momentum is None or epsilon is None:
3149            raise ValueError(f"For '{self.name}', 'decay', 'momentum' and 'epsilon' can not be None, "
3150                             f"but got 'decay': {decay}, 'momentum': {momentum} and 'epsilon':{epsilon}.")
3151
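# A NumPy sketch (illustration only, not the device kernel) of the three update
# formulas above; mean_square, moment and var are returned in updated form. The
# helper name is an assumption of this sketch.
def _apply_rms_prop_reference(var, mean_square, moment, learning_rate, grad,
                              decay, momentum, epsilon):
    """Reference sketch of one RMSProp step."""
    mean_square = decay * mean_square + (1.0 - decay) * grad * grad
    moment = momentum * moment + learning_rate * grad / np.sqrt(mean_square + epsilon)
    var = var - moment
    return var, mean_square, moment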
3152
3153class ApplyCenteredRMSProp(PrimitiveWithInfer):
3154    r"""
3155    Optimizer that implements the centered RMSProp algorithm.
3156    Please refer to the usage in source code of :class:`nn.RMSProp`.
3157
3158    The updating formulas of ApplyCenteredRMSProp algorithm are as follows,
3159
3160    .. math::
3161        \begin{array}{ll} \\
3162            g_{t+1} = \rho g_{t} + (1 - \rho)\nabla Q_{i}(w) \\
3163            s_{t+1} = \rho s_{t} + (1 - \rho)(\nabla Q_{i}(w))^2 \\
3164            m_{t+1} = \beta m_{t} + \frac{\eta} {\sqrt{s_{t+1} - g_{t+1}^2 + \epsilon}} \nabla Q_{i}(w) \\
3165            w = w - m_{t+1}
3166        \end{array}
3167
3168    where :math:`w` represents `var`, which will be updated.
3169    :math:`g_{t+1}` represents `mean_gradient`, :math:`g_{t}` is the last moment of :math:`g_{t+1}`.
3170    :math:`s_{t+1}` represents `mean_square`, :math:`s_{t}` is the last moment of :math:`s_{t+1}`,
3171    :math:`m_{t+1}` represents `moment`, :math:`m_{t}` is the last moment of :math:`m_{t+1}`.
3172    :math:`\rho` represents `decay`. :math:`\beta` is the momentum term, represents `momentum`.
3173    :math:`\epsilon` is a smoothing term to avoid division by zero, represents `epsilon`.
3174    :math:`\eta` represents `learning_rate`. :math:`\nabla Q_{i}(w)` represents `grad`.
3175
3176    Note:
3177        The difference between `ApplyCenteredRMSProp` and `ApplyRMSProp` is that the former
3178        uses the centered RMSProp algorithm, which uses an estimate of the centered second
3179        moment (i.e., the variance) for normalization, as opposed to regular RMSProp, which uses the (uncentered)
3180        second moment. This often helps with training, but is slightly more expensive in terms of computation and memory.
3181
3182    .. warning::
3183        In dense implementation of this algorithm, `mean_gradient`, `mean_square`, and `moment` will update
3184        even if the `grad` is zero. But in this sparse implementation, `mean_gradient`, `mean_square`, and `moment`
3185        will not update in iterations during which the `grad` is zero.
3186
3187    Args:
3188        use_locking (bool): Whether to enable a lock to protect the variable and accumulation tensors
3189                            from being updated. Default: False.
3190
3191    Inputs:
3192        - **var** (Tensor) - Weights to be updated.
3193        - **mean_gradient** (Tensor) - Mean gradients, must have the same type as `var`.
3194        - **mean_square** (Tensor) - Mean square gradients, must have the same type as `var`.
3195        - **moment** (Tensor) - Delta of `var`, must have the same type as `var`.
3196        - **grad** (Tensor) - Gradient, must have the same type as `var`.
3197        - **learning_rate** (Union[Number, Tensor]) - Learning rate. Must be a float number or
3198          a scalar tensor with float16 or float32 data type.
3199        - **decay** (float) - Decay rate.
3200        - **momentum** (float) - Momentum.
3201        - **epsilon** (float) - Ridge term.
3202
3203    Outputs:
3204        Tensor, parameters to be updated.
3205
3206    Raises:
3207        TypeError: If `use_locking` is not a bool.
3208        TypeError: If `var`, `mean_gradient`, `mean_square`, `moment` or `grad` is not a Tensor.
3209        TypeError: If `learning_rate` is neither a Number nor a Tensor.
3210        TypeError: If dtype of `learning_rate` is neither float16 nor float32.
3211        TypeError: If `decay`, `momentum` or `epsilon` is not a float.
3212
3213    Supported Platforms:
3214        ``Ascend`` ``GPU`` ``CPU``
3215
3216    Examples:
3217        >>> class Net(nn.Cell):
3218        ...     def __init__(self):
3219        ...         super(Net, self).__init__()
3220        ...         self.apply_centerd_rms_prop = ops.ApplyCenteredRMSProp()
3221        ...         self.var = Parameter(Tensor(np.ones([2, 2]).astype(np.float32)), name="var")
3222        ...
3223        ...     def construct(self, mean_grad, mean_square, moment, grad, decay, momentum, epsilon, lr):
3224        ...         out = self.apply_centerd_rms_prop(self.var, mean_grad, mean_square, moment, grad,
3225        ...                                           lr, decay, momentum, epsilon)
3226        ...         return out
3227        ...
3228        >>> net = Net()
3229        >>> mean_grad = Tensor(np.ones([2, 2]).astype(np.float32))
3230        >>> mean_square = Tensor(np.ones([2, 2]).astype(np.float32))
3231        >>> moment = Tensor(np.ones([2, 2]).astype(np.float32))
3232        >>> grad = Tensor(np.ones([2, 2]).astype(np.float32))
3233        >>> output = net(mean_grad, mean_square, moment, grad, 0.0, 1e-10, 0.001, 0.01)
3234        >>> print(net.var.asnumpy())
3235        [[0.68377227  0.68377227]
3236         [0.68377227  0.68377227]]
3237    """
3238
3239    @prim_attr_register
3240    def __init__(self, use_locking=False):
3241        """Initialize ApplyCenteredRMSProp."""
3242        self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
3243        self.add_prim_attr('side_effect_mem', True)
3244
3245    def infer_shape(self, var_shape, mean_gradient_shape, mean_square_shape, moment_shape, grad_shape,
3246                    learning_rate_shape, decay_shape, momentum_shape, epsilon_shape):
3247        validator.check("var_shape", var_shape, "mean_gradient_shape", mean_gradient_shape, Rel.EQ, self.name)
3248        validator.check("var_shape", var_shape, "mean_square_shape", mean_square_shape, Rel.EQ, self.name)
3249        validator.check("var_shape", var_shape, "moment_shape", moment_shape, Rel.EQ, self.name)
3250        validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name)
3251        return var_shape
3252
3253    def infer_dtype(self, var_dtype, mean_gradient_dtype, mean_square_dtype, moment_dtype, grad_dtype,
3254                    learning_rate_dtype, rho_dtype, momentum_dtype, epsilon_dtype):
3255        args = {"var": var_dtype, "mean_gradient": mean_gradient_dtype,
3256                "mean_square": mean_square_dtype, "moment": moment_dtype, "grad": grad_dtype}
3257        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
3258
3259        valid_dtypes = [mstype.float16, mstype.float32]
3260        args_rho = {"rho": rho_dtype, 'momentum': momentum_dtype, "epsilon": epsilon_dtype}
3261        validator.check_types_same_and_valid(args_rho, valid_dtypes, self.name)
3262        args_lr = {"learning_rate": learning_rate_dtype, "rho": rho_dtype}
3263        validator.check_scalar_or_tensor_types_same(args_lr, valid_dtypes, self.name, allow_mix=True)
3264        return var_dtype
3265
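# A NumPy sketch (illustration only) of the centered variant above: the second moment
# is centered by subtracting the squared running mean of the gradients before it is
# used for normalization. The helper name is an assumption of this sketch.
def _apply_centered_rms_prop_reference(var, mean_gradient, mean_square, moment, grad,
                                       learning_rate, decay, momentum, epsilon):
    """Reference sketch of one centered RMSProp step."""
    mean_gradient = decay * mean_gradient + (1.0 - decay) * grad
    mean_square = decay * mean_square + (1.0 - decay) * grad * grad
    denom = np.sqrt(mean_square - mean_gradient ** 2 + epsilon)
    moment = momentum * moment + learning_rate * grad / denom
    var = var - moment
    return var, mean_gradient, mean_square, moment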
3266
3267class LayerNorm(Primitive):
3268    r"""
3269    Applies the Layer Normalization to the input tensor.
3270
3271    This operator will normalize the input tensor on the given axis. LayerNorm is described in the paper
3272    `Layer Normalization <https://arxiv.org/abs/1607.06450>`_.
3273
3274    .. math::
3275        y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta
3276
3277    where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon.
3278
3279    Args:
3280        begin_norm_axis (int): The begin axis of the `input_x` to apply LayerNorm,
3281            the value must be in [-1, rank(input)). Default: 1.
3282        begin_params_axis (int): The begin axis of the parameter input (`gamma`, `beta`) to
3283            apply LayerNorm, the value must be in [-1, rank(input)). Default: 1.
3284        epsilon (float): A value added to the denominator for numerical stability. Default: 1e-7.
3285
3286    Inputs:
3287        - **input_x** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
3288          The input of LayerNorm.
3289        - **gamma** (Tensor) - Tensor of shape :math:`(P_0, \ldots, P_\text{begin_params_axis})`.
3290          The learnable parameter `gamma` as the scale on norm.
3291        - **beta** (Tensor) - Tensor of shape :math:`(P_0, \ldots, P_\text{begin_params_axis})`.
3292          The learnable parameter `beta` as the offset on norm.
3293
3294    Outputs:
3295        tuple[Tensor], tuple of 3 tensors, the normalized input and the updated parameters.
3296
3297        - **output_x** (Tensor) - The normalized input, has the same type and shape as the `input_x`.
3298          The shape is :math:`(N, C)`.
3299        - **mean** (Tensor) - Tensor of shape :math:`(C,)`.
3300        - **variance** (Tensor) - Tensor of shape :math:`(C,)`.
3301
3302    Raises:
3303        TypeError: If `begin_norm_axis` or `begin_params_axis` is not an int.
3304        TypeError: If `epsilon` is not a float.
3305        TypeError: If `input_x`, `gamma` or `beta` is not a Tensor.
3306
3307    Supported Platforms:
3308        ``Ascend`` ``GPU`` ``CPU``
3309
3310    Examples:
3311        >>> input_x = Tensor(np.array([[1, 2, 3], [1, 2, 3]]), mindspore.float32)
3312        >>> gamma = Tensor(np.ones([3]), mindspore.float32)
3313        >>> beta = Tensor(np.ones([3]), mindspore.float32)
3314        >>> layer_norm = ops.LayerNorm()
3315        >>> output, mean, variance = layer_norm(input_x, gamma, beta)
3316        >>> print(output)
3317        [[-0.2247448  1.         2.2247448]
3318         [-0.2247448  1.         2.2247448]]
3319        >>> print(mean)
3320        [[2.]
3321         [2.]]
3322        >>> print(variance)
3323        [[0.6666667]
3324         [0.6666667]]
3325    """
3326
3327    @prim_attr_register
3328    def __init__(self, begin_norm_axis=1, begin_params_axis=1, epsilon=1e-7):
3329        """Initialize LayerNorm."""
3330        validator.check_value_type('begin_norm_axis', begin_norm_axis, [int], self.name)
3331        validator.check_value_type('begin_params_axis', begin_params_axis, [int], self.name)
3332        validator.check_value_type('epsilon', epsilon, [float], self.name)
3333
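# A NumPy sketch (illustration only) of the normalization formula above, assuming the
# default case where begin_norm_axis and begin_params_axis are non-negative and the
# gamma/beta shapes broadcast against the normalized axes. The helper name is an assumption.
def _layer_norm_reference(input_x, gamma, beta, begin_norm_axis=1, epsilon=1e-7):
    """Reference sketch: normalize over the trailing axes starting at begin_norm_axis."""
    axes = tuple(range(begin_norm_axis, input_x.ndim))
    mean = input_x.mean(axis=axes, keepdims=True)
    variance = input_x.var(axis=axes, keepdims=True)
    normalized = (input_x - mean) / np.sqrt(variance + epsilon)
    return normalized * gamma + beta, mean, variance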
3334
3335class L2Normalize(PrimitiveWithInfer):
3336    r"""
3337    L2 Normalization Operator.
3338
3339    This operator will normalize the input using the given axis. The function is shown as follows:
3340
3341    .. math::
3342        \text{output} = \frac{x}{\sqrt{\text{max}(\text{sum} (\text{x}^2), \epsilon)}},
3343
3344    where :math:`\epsilon` is epsilon.
3345
3346    Args:
3347        axis (Union[list(int), tuple(int), int]): The starting axis for the input to apply the L2 Normalization.
3348                                                  Default: 0.
3349        epsilon (float): A small value added for numerical stability. Default: 1e-4.
3350
3351    Inputs:
3352        - **x** (Tensor) - Input to compute the normalization. Tensor of shape :math:`(N, \ldots)`.
3353          Data type must be float16 or float32.
3354
3355    Outputs:
3356        Tensor, with the same type and shape as the `x`.
3357
3358    Raises:
3359        TypeError: If `axis` is not one of the following: list, tuple or int.
3360        TypeError: If `epsilon` is not a float.
3361        TypeError: If `x` is not a Tensor.
3362        TypeError: If dtype of `x` is neither float16 nor float32.
3363
3364    Supported Platforms:
3365        ``Ascend`` ``GPU`` ``CPU``
3366
3367    Examples:
3368        >>> l2_normalize = ops.L2Normalize()
3369        >>> x = Tensor(np.random.randint(-256, 256, (2, 3, 4)), mindspore.float32)
3370        >>> output = l2_normalize(x)
3371        >>> print(output.shape)
3372        (2, 3, 4)
3373    """
3374
3375    @prim_attr_register
3376    def __init__(self, axis=0, epsilon=1e-4):
3377        """Initialize L2Normalize."""
3378        axis = [axis] if isinstance(axis, int) else axis
3379        validator.check_value_type('axis', axis, [list, tuple], self.name)
3380        validator.check_value_type('epsilon', epsilon, [int, float], self.name)
3381        self.add_prim_attr('axis', axis)
3382        self.init_attrs['axis'] = axis
3383        if len(axis) != 1:
3384            raise TypeError(f"For '{self.name}', the length of 'axis' must be 1, but got {len(axis)}, "
3385                            f"later will support multiple axis!")
3386        self.axis = axis
3387
3388    def infer_shape(self, input_x):
3389        dim = len(input_x)
3390        validator.check_int_range(self.axis[0], -dim, dim, Rel.INC_LEFT, 'axis value', self.name)
3391        return input_x
3392
3393    def infer_dtype(self, input_x):
3394        validator.check_tensor_dtype_valid("input_x", input_x, [mstype.float16, mstype.float32], self.name)
3395        return input_x
3396
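# A NumPy sketch (illustration only) of the formula above: the squared sum is taken
# along a single axis and clamped from below by epsilon. The helper name is an
# assumption of this sketch.
def _l2_normalize_reference(x, axis=0, epsilon=1e-4):
    """Reference sketch of L2 normalization along one axis."""
    squared_sum = np.sum(x * x, axis=axis, keepdims=True)
    return x / np.sqrt(np.maximum(squared_sum, epsilon))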
3397
3398class DropoutGenMask(Primitive):
3399    """
3400    Generates the mask value for the input shape.
3401
3402    Dropout means that neural network units are temporarily dropped from the network according to a certain probability
3403    during deep learning network training. Generally, the effect of Dropout is achieved by combining DropoutGenMask
3404    and DropoutDoMask. DropoutGenMask generates a mask for the shape specified by the input. Next,
3405    DropoutDoMask applies the mask generated by DropoutGenMask to the input tensor,
3406    randomly setting elements to zero based on the probability p.
3407
3408
3409    Args:
3410        Seed0 (int): Seed0 value for random generating. Default: 0.
3411        Seed1 (int): Seed1 value for random generating. Default: 0.
3412
3413    Inputs:
3414        - **shape** (tuple[int]) - The shape of target mask.
3415        - **keep_prob** (Tensor) - The keep rate, greater than 0 and less than or equal to 1, e.g. keep_prob = 0.9,
3416          means dropping out 10% of input units.
3417
3418    Outputs:
3419        Tensor, the value of generated mask for Inputs `shape`.
3420
3421    Raises:
3422        TypeError: If `Seed0` or `Seed1` is not an int.
3423        TypeError: If `shape` is not a tuple.
3424        TypeError: If `keep_prob` is not a Tensor.
3425
3426    Supported Platforms:
3427        ``Ascend``
3428
3429    Examples:
3430        >>> dropout_gen_mask = ops.DropoutGenMask()
3431        >>> shape = (2, 4, 5)
3432        >>> keep_prob = Tensor(0.5, mindspore.float32)
3433        >>> output = dropout_gen_mask(shape, keep_prob)
3434        >>> print(output.shape)
3435        (16,)
3436    """
3437
3438    @prim_attr_register
3439    def __init__(self, Seed0=0, Seed1=0):
3440        """Initialize DropoutGenMask."""
3441        self.init_prim_io_names(inputs=['shape', 'keep_prob'], outputs=['output'])
3442        validator.check_value_type("Seed0", Seed0, [int], self.name)
3443        validator.check_value_type("Seed1", Seed1, [int], self.name)
3444        self.add_prim_attr("_random_effect", True)
3445
3446
3447class DropoutDoMask(Primitive):
3448    r"""
3449    Applies dropout mask on the input tensor.
3450
3451    Take the mask output of DropoutGenMask as input, and apply dropout on the input.
3452
3453    Dropout means that neural network units are temporarily dropped from the network according to a certain probability
3454    during deep learning network training. Generally, the effect of Dropout is achieved by combining DropoutGenMask
3455    and DropoutDoMask. DropoutGenMask generates a mask for the shape specified by the input. Next,
3456    DropoutDoMask applies the mask generated by DropoutGenMask to the input tensor,
3457    randomly setting elements to zero based on the probability p.
3458
3459    Inputs:
3460        - **input_x** (Tensor) - The input tensor. Tensor of shape :math:`(N, \ldots)`.
3461          The data type should be float32, float16 or int32.
3462        - **mask** (Tensor) - The mask to be applied on `input_x`, which is the output of `DropoutGenMask`. The
3463          shape of `input_x` must be the same as the `shape` input of `DropoutGenMask`. If a wrong `mask` is
3464          passed, the output of `DropoutDoMask` is unpredictable.
3465        - **keep_prob** (Union[Tensor, float]) - The keep rate, greater than 0 and less than or equal to 1. For
3466          example, keep_prob = 0.9 means dropping out 10% of the input units. The value of `keep_prob` must be the
3467          same as the `keep_prob` input of the operator `DropoutGenMask`.
3468
3469    Outputs:
3470        Tensor, the output after dropout is applied, with the same data type and shape as `input_x`.
3471
3472    Raises:
3473        TypeError: If `input_x`, `mask` or `keep_prob` is not a Tensor.
3474        TypeError: If `keep_prob` is not a float.
3475        ValueError: If the value of `keep_prob` is not the same as that of `DropoutGenMask`.
3476
3477    Supported Platforms:
3478        ``Ascend``
3479
3480    Examples:
3481        >>> input_x = Tensor(np.ones([2, 2, 3]), mindspore.float32)
3482        >>> shape = (2, 2, 3)
3483        >>> keep_prob = Tensor(0.5, mindspore.float32)
3484        >>> dropout_gen_mask = ops.DropoutGenMask()
3485        >>> dropout_do_mask = ops.DropoutDoMask()
3486        >>> mask = dropout_gen_mask(shape, keep_prob)
3487        >>> output = dropout_do_mask(input_x, mask, keep_prob)
3488        >>> print(output.shape)
3489        (2, 2, 3)
3490    """
3491
3492    @prim_attr_register
3493    def __init__(self):
3494        pass
3495
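# NOTE: Illustrative sketch only, not the operator implementation. Conceptually, dropout with a
# precomputed keep-mask zeroes the dropped positions and, in the usual inverted-dropout
# formulation, rescales the kept ones by 1 / keep_prob; the helper name and the rescaling
# assumption are ours, not a statement about DropoutDoMask's internal mask format.
def _dropout_with_mask_sketch(x, keep_mask, keep_prob):
    """NumPy sketch: apply a 0/1 keep-mask to `x` and rescale the kept elements."""
    x = np.asarray(x, dtype=np.float32)
    keep_mask = np.asarray(keep_mask, dtype=np.float32)
    return x * keep_mask / keep_prob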
3496
3497class ResizeBilinear(PrimitiveWithInfer):
3498    r"""
3499    Resizes an image to a certain size using the bilinear interpolation.
3500
3501    The resizing only affects the lower two dimensions, which represent the height and width. The input images
3502    can be of data type float16 or float32, and the output has the same data type as the input.
3503
3504    Args:
3505        size (Union[tuple[int], list[int]]): A tuple or list of 2 int elements :math:`(new\_height, new\_width)`,
3506            the new size of the images.
3507        align_corners (bool): If true, rescale input by :math:`(new\_height - 1) / (height - 1)`,
3508                       which exactly aligns the 4 corners of images and resized images. If false,
3509                       rescale by :math:`new\_height / height`. Default: False.
3510
3511    Inputs:
3512        - **x** (Tensor) - Image to be resized. Input images must be a 4-D tensor with shape
3513          :math:`(batch, channels, height, width)`, with data type of float32 or float16.
3514
3515    Outputs:
3516        Tensor, resized image. 4-D with shape :math:`(batch, channels, new\_height, new\_width)`,
3517        with the same data type as input `x`.
3518
3519    Raises:
3520        TypeError: If `size` is neither a tuple nor list.
3521        TypeError: If `align_corners` is not a bool.
3522        TypeError: If dtype of `x` is neither float16 nor float32.
3523        TypeError: If `x` is not a Tensor.
3524        ValueError: If length of shape of `x` is not equal to 4.
3525
3526    Supported Platforms:
3527        ``Ascend`` ``CPU`` ``GPU``
3528
3529    Examples:
3530        >>> x = Tensor([[[[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]]], mindspore.float32)
3531        >>> resize_bilinear = ops.ResizeBilinear((5, 5))
3532        >>> output = resize_bilinear(x)
3533        >>> print(output)
3534        [[[[1. 2. 3. 4. 5.]
3535           [1. 2. 3. 4. 5.]
3536           [1. 2. 3. 4. 5.]
3537           [1. 2. 3. 4. 5.]
3538           [1. 2. 3. 4. 5.]]]]
3539    """
3540
3541    @prim_attr_register
3542    def __init__(self, size, align_corners=False):
3543        """Initialize ResizeBilinear."""
3544        validator.check_value_type("size", size, [tuple, list], self.name)
3545        validator.check_equal_int(len(size), 2, "size len", self.name)
3546        for i, value in enumerate(size):
3547            validator.check_value_type(f"size[{i}]", value, int, self.name)
3548            validator.check_positive_int(value, f'{i}th value of size', self.name)
3549        validator.check_value_type("align_corners", align_corners, [bool], self.name)
3552
3553    def infer_shape(self, input_shape):
3554        validator.check("input shape rank", len(input_shape), "", 4, Rel.EQ, self.name)
3555        input_shape = list(input_shape)
3556        batch, channel, _, _ = input_shape
3557        out_shape = [batch, channel]
3558        for i in self.size:
3559            out_shape.append(int(i))
3560        return out_shape
3561
3562    def infer_dtype(self, input_dtype):
3563        validator.check_tensor_dtype_valid('input_dtype', input_dtype, [mstype.float16, mstype.float32],
3564                                           self.name)
3565        return input_dtype
3566
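# NOTE: Hedged sketch, not a MindSpore API. It only illustrates the two rescaling factors
# described in the ResizeBilinear docstring above (`align_corners` True vs. False) along one
# spatial dimension; the helper name is illustrative.
def _resize_rescale_factor_sketch(height, new_height, align_corners=False):
    """Rescaling factor along one spatial dimension, as documented above."""
    if align_corners and height > 1:
        # Exactly aligns the corners of the input and the resized output.
        return (new_height - 1) / (height - 1)
    return new_height / height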
3567
3568class OneHot(Primitive):
3569    r"""
3570    Computes a one-hot tensor.
3571
3572    Makes a new tensor, whose locations represented by indices in `indices` take value `on_value`, while all
3573    other locations take value `off_value`.
3574
3575    Note:
3576        If the input indices is rank `N`, the output will have rank `N+1`. The new axis is created at dimension `axis`.
3577
3578    Args:
3579        axis (int): Position to insert the value. For example, if the shape of `indices` is :math:`(N, C)` and
3580            `axis` is -1, the output shape will be :math:`(N, C, D)`; if `axis` is 0, it will be :math:`(D, N, C)`.
3581            Default: -1.
3582
3583    Inputs:
3584        - **indices** (Tensor) - A tensor of indices. Tensor of shape :math:`(X_0, \ldots, X_n)`.
3585          Data type must be int32 or int64.
3586        - **depth** (int) - A scalar defining the depth of the one hot dimension.
3587        - **on_value** (Tensor) - A value to fill in output when `indices[j] = i`.
3588          With data type of float16 or float32.
3589        - **off_value** (Tensor) - A value to fill in output when `indices[j] != i`.
3590          Has the same data type as `on_value`.
3591
3592    Outputs:
3593        Tensor, one-hot tensor. Tensor of shape :math:`(X_0, \ldots, X_{axis}, \text{depth} ,X_{axis+1}, \ldots, X_n)`.
3594
3595    Raises:
3596        TypeError: If `axis` or `depth` is not an int.
3597        TypeError: If dtype of `indices` is neither int32 nor int64.
3598        TypeError: If `indices`, `on_value` or `off_value` is not a Tensor.
3599        ValueError: If `axis` is not in range [-1, len(indices_shape)].
3600        ValueError: If `depth` is less than 0.
3601
3602    Supported Platforms:
3603        ``Ascend`` ``GPU`` ``CPU``
3604
3605    Examples:
3606        >>> indices = Tensor(np.array([0, 1, 2]), mindspore.int32)
3607        >>> depth, on_value, off_value = 3, Tensor(1.0, mindspore.float32), Tensor(0.0, mindspore.float32)
3608        >>> onehot = ops.OneHot()
3609        >>> output = onehot(indices, depth, on_value, off_value)
3610        >>> print(output)
3611        [[1. 0. 0.]
3612         [0. 1. 0.]
3613         [0. 0. 1.]]
3614    """
3615
3616    @prim_attr_register
3617    def __init__(self, axis=-1):
3618        """Initialize OneHot."""
3619        self.init_prim_io_names(inputs=['indices', 'depth', 'on_value', 'off_value'], outputs=['output'])
3620        validator.check_value_type("axis", axis, [int], self.name)
3621
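# NOTE: Hedged NumPy sketch (not a MindSpore API) of the one-hot construction documented above,
# for the default axis=-1 case only; the helper name is illustrative.
def _one_hot_sketch(indices, depth, on_value=1.0, off_value=0.0):
    """Build a one-hot array with the new axis appended as the last dimension (axis=-1)."""
    indices = np.asarray(indices)
    positions = np.arange(depth)
    # Compare every index against every position on the new last axis.
    return np.where(positions == indices[..., None], on_value, off_value).astype(np.float32)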
3622
3623class Gelu(PrimitiveWithInfer):
3624    """
3625    Same as operator GeLU. Gelu will be deprecated in the future.
3626    Please use GeLU instead.
3627    """
3628
3629    @deprecated("1.1", "GeLU", True)
3630    @prim_attr_register
3631    def __init__(self):
3632        """Initialize Gelu"""
3633        self.init_prim_io_names(inputs=['x'], outputs=['output'])
3634
3635    def infer_shape(self, input_x):
3636        return input_x
3637
3638    def infer_dtype(self, input_x):
3639        validator.check_tensor_dtype_valid("input_x", input_x, (mstype.float16, mstype.float32), self.name)
3640        return input_x
3641
3642
3643class GeLU(Primitive):
3644    r"""
3645    Gaussian Error Linear Units activation function.
3646
3647    GeLU is described in the paper `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_.
3648    And also please refer to `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding
3649    <https://arxiv.org/abs/1810.04805>`_.
3650
3651    GeLU is defined as follows:
3652
3653    .. math::
3654        \text{output} = 0.5 * x * (1 + erf(x / \sqrt{2})),
3655
3656    where :math:`erf` is the Gauss error function.
3657
3658    Inputs:
3659        - **x** (Tensor) - Input to compute the GeLU with data type of float16 or float32.
3660
3661    Outputs:
3662        Tensor, with the same type and shape as `x`.
3663
3664    Raises:
3665        TypeError: If `x` is not a Tensor.
3666        TypeError: If dtype of `x` is neither float16 nor float32.
3667
3668    Supported Platforms:
3669        ``Ascend`` ``GPU`` ``CPU``
3670
3671    Examples:
3672        >>> x = Tensor(np.array([1.0, 2.0, 3.0]), mindspore.float32)
3673        >>> gelu = ops.GeLU()
3674        >>> result = gelu(x)
3675        >>> print(result)
3676        [0.841192  1.9545976  2.9963627]
3677    """
3678
3679    @prim_attr_register
3680    def __init__(self):
3681        """Initialize GeLU"""
3682        self.init_prim_io_names(inputs=['x'], outputs=['output'])
3683
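# NOTE: Hedged NumPy/math cross-check of the GeLU formula documented above,
# 0.5 * x * (1 + erf(x / sqrt(2))); the helper is illustrative and not a MindSpore API.
# For x = [1.0, 2.0, 3.0] it yields values close to the Examples output of the GeLU operator.
def _gelu_reference_sketch(x):
    """Element-wise GeLU computed with math.erf, for illustration only."""
    x = np.asarray(x, dtype=np.float32)
    flat = np.array([0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0))) for v in x.ravel()],
                    dtype=np.float32)
    return flat.reshape(x.shape)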
3684
3685class FastGelu(PrimitiveWithInfer):
3686    """
3687    Same as operator FastGeLU. FastGelu will be deprecated in the future.
3688    Please use FastGeLU instead.
3689    """
3690
3691    @deprecated("1.1", "FastGeLU", True)
3692    @prim_attr_register
3693    def __init__(self):
3694        """Initialize FastGelu."""
3695        self.init_prim_io_names(inputs=['x'], outputs=['output'])
3696
3697    def infer_shape(self, input_x):
3698        return input_x
3699
3700    def infer_dtype(self, input_x):
3701        validator.check_tensor_dtype_valid("input_x", input_x, (mstype.float16, mstype.float32), self.name)
3702        return input_x
3703
3704
3705class FastGeLU(PrimitiveWithInfer):
3706    r"""
3707    Fast Gaussian Error Linear Units activation function.
3708
3709    FastGeLU is defined as follows:
3710
3711    .. math::
3712        \text{output} = \frac {x} {1 + \exp(-1.702 * \left| x \right|)} * \exp(0.851 * (x - \left| x \right|)),
3713
3714    where :math:`x` is the element of the input.
3715
3716    Inputs:
3717        - **x** (Tensor) - Input to compute the FastGeLU with data type of float16 or float32.
3718
3719    Outputs:
3720        Tensor, with the same type and shape as `x`.
3721
3722    Raises:
3723        TypeError: If dtype of `x` is neither float16 nor float32.
3724
3725    Supported Platforms:
3726        ``Ascend``
3727
3728    Examples:
3729        >>> x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
3730        >>> fast_gelu = ops.FastGeLU()
3731        >>> output = fast_gelu(x)
3732        >>> print(output)
3733        [[-1.5418735e-01  3.9921875e+00 -9.7473649e-06]
3734         [ 1.9375000e+00 -1.0052517e-03  8.9824219e+00]]
3735    """
3736
3737    @prim_attr_register
3738    def __init__(self):
3739        """Initialize FastGeLU."""
3740        self.init_prim_io_names(inputs=['x'], outputs=['output'])
3741
3742    def infer_shape(self, input_x):
3743        return input_x
3744
3745    def infer_dtype(self, input_x):
3746        validator.check_tensor_dtype_valid("input_x", input_x, (mstype.float16, mstype.float32), self.name)
3747        return input_x
3748
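# NOTE: Hedged NumPy sketch of the FastGeLU formula documented above:
# x / (1 + exp(-1.702 * |x|)) * exp(0.851 * (x - |x|)). Illustrative only, not a MindSpore API.
def _fast_gelu_reference_sketch(x):
    """Element-wise FastGeLU computed with NumPy, for illustration only."""
    x = np.asarray(x, dtype=np.float32)
    abs_x = np.abs(x)
    return x / (1.0 + np.exp(-1.702 * abs_x)) * np.exp(0.851 * (x - abs_x))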
3749
3750class GetNext(Primitive):
3751    """
3752    Returns the next element in the dataset queue.
3753
3754    Note:
3755        The GetNext operation needs to be associated with a network and also depends on the init_dataset
3756        interface; it cannot be used directly as a standalone operation.
3757        For details, please refer to the `connect_network_with_dataset` source code.
3758
3759    Args:
3760        types (list[:class:`mindspore.dtype`]): The type of the outputs.
3761        shapes (list[tuple[int]]): The dimensionality of the outputs.
3762        output_num (int): The output number, length of `types` and `shapes`.
3763        shared_name (str): The queue name of `init_dataset` interface.
3764
3765    Inputs:
3766        No inputs.
3767
3768    Outputs:
3769        tuple[Tensor], the output of Dataset. The shape is described in `shapes`
3770        and the type is described in `types`.
3771
3772    Supported Platforms:
3773        ``Ascend`` ``GPU``
3774
3775    Examples:
3776        >>> train_dataset = create_custom_dataset()
3777        >>> dataset_helper = mindspore.DatasetHelper(train_dataset, dataset_sink_mode=True)
3778        >>> dataset = dataset_helper.iter.dataset
3779        >>> dataset_types, dataset_shapes = dataset_helper.types_shapes()
3780        >>> queue_name = dataset.__transfer_dataset__.queue_name
3781        >>> get_next = ops.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name)
3782        >>> data, label = get_next()
3783        >>> relu = ops.ReLU()
3784        >>> result = relu(data).asnumpy()
3785        >>> print(result.shape)
3786        (32, 1, 32, 32)
3787    """
3788
3789    @prim_attr_register
3790    def __init__(self, types, shapes, output_num, shared_name):
3791        """Initialize GetNext."""
3792        validator.check_value_type("types", types, [list, tuple], self.name)
3793        validator.check_value_type("shapes", shapes, [list, tuple], self.name)
3794        validator.check("types length", len(types), "shapes length", len(shapes), Rel.EQ, self.name)
3795        validator.check_value_type("output_num", output_num, [int], self.name)
3796
3797
3798class PReLU(PrimitiveWithInfer):
3799    r"""
3800    Parametric Rectified Linear Unit activation function.
3801
3802    PReLU is described in the paper `Delving Deep into Rectifiers: Surpassing Human-Level Performance on
3803    ImageNet Classification <https://arxiv.org/abs/1502.01852>`_. Defined as follows:
3804
3805    .. math::
3806        prelu(x_i)= \max(0, x_i) + \min(0, w * x_i),
3807
3808    where :math:`x_i` is an element of a channel of the input and :math:`w` is the weight of that channel.
3809
3810    Note:
3811        0-D or 1-D input_x is not supported on Ascend.
3812
3813    Inputs:
3814        - **x** (Tensor) - The first input tensor, representing the output of the previous layer,
3815          with data type of float16 or float32.
3816          The shape is :math:`(N, C, *)` where :math:`*` means any number of additional dimensions.
3817        - **weight** (Tensor) - The second input tensor. The data type is float16 or float32.
3818          Only two shapes are legitimate: 1 or the number of channels of `x`.
3819          The channel dimension is the second dimension of the input. When the input is a 0-D or 1-D tensor,
          the number of channels is 1.
3820
3821    Outputs:
3822        Tensor, with the same type as `x`.
3823
3824    For detailed information, please refer to :class:`nn.PReLU`.
3825
3826    Raises:
3827        TypeError: If dtype of `x` or `weight` is neither float16 nor float32.
3828        TypeError: If the `x` or the `weight` is not a Tensor.
3829        ValueError: If the `x` is a 0-D or 1-D Tensor on Ascend.
3830        ValueError: If the `weight` is not a 1-D Tensor.
3831
3832    Supported Platforms:
3833        ``Ascend`` ``GPU``
3834
3835    Examples:
3836        >>> class Net(nn.Cell):
3837        ...     def __init__(self):
3838        ...         super(Net, self).__init__()
3839        ...         self.prelu = ops.PReLU()
3840        ...     def construct(self, x, weight):
3841        ...         result = self.prelu(x, weight)
3842        ...         return result
3843        ...
3844        >>> x = Tensor(np.arange(-6, 6).reshape((2, 3, 2)), mindspore.float32)
3845        >>> weight = Tensor(np.array([0.1, 0.6, -0.3]), mindspore.float32)
3846        >>> net = Net()
3847        >>> output = net(x, weight)
3848        >>> print(output)
3849        [[[-0.60 -0.50]
3850          [-2.40 -1.80]
3851          [ 0.60  0.30]]
3852         [[ 0.00  1.00]
3853          [ 2.00  3.00]
3854          [ 4.0   5.00]]]
3855    """
3856
3857    @prim_attr_register
3858    def __init__(self):
3859        pass
3860
3861    def infer_shape(self, input_x_shape, weight_shape):
3862        input_x_dim = len(input_x_shape)
3863        if input_x_dim in (0, 1):
3864            if context.get_context("device_target") == "Ascend":
3865                raise ValueError(f"For '{self.name}', the dimension of 'x' can not be 0-D or 1-D when the platform is "
3866                                 f"\"Ascend\", but got dimension of 'x' is {input_x_dim}.")
3867            channel_num = 1
3868        else:
3869            channel_num = input_x_shape[1]
3870
3871        weight_dim = len(weight_shape)
3872        if weight_dim != 1:
3873            raise ValueError(f"For '{self.name}', the dimension of 'weight' should be 1, but got {weight_dim}.")
3874        if weight_shape[0] != 1 and weight_shape[0] != channel_num:
3875            raise ValueError(f"For '{self.name}', the first dimension of 'weight' should be 1 or equal to the "
3876                             f"number of channels: {channel_num}, but got {weight_shape}.")
3877        return input_x_shape
3878
3879    def infer_dtype(self, input_x_dtype, weight_dtype):
3880        valid_dtypes = (mstype.float16, mstype.float32)
3881        args = {"input_x": input_x_dtype, "weight": weight_dtype}
3882        if context.get_context("device_target") == "GPU":
3883            validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
3884        else:
3885            validator.check_tensor_dtype_valid("input_x", input_x_dtype, valid_dtypes, self.name)
3886            validator.check_tensor_dtype_valid("weight", weight_dtype, valid_dtypes, self.name)
3887        return input_x_dtype
3888
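# NOTE: Hedged NumPy sketch, not a MindSpore API. It follows the branch form used by the PReLU
# Examples above (x if x > 0 else w * x), which matches the documented formula
# max(0, x_i) + min(0, w * x_i) whenever the weights are non-negative. The per-channel weight is
# broadcast over the channel axis (the 2nd dimension); the helper name is illustrative.
def _prelu_reference_sketch(x, weight):
    """Element-wise PReLU with a per-channel weight, for illustration only."""
    x = np.asarray(x, dtype=np.float32)
    weight = np.asarray(weight, dtype=np.float32)
    # Reshape the weight so it broadcasts along the channel dimension of an (N, C, ...) input.
    shape = [1] * x.ndim
    if x.ndim > 1:
        shape[1] = weight.size
    w = weight.reshape(shape)
    return np.where(x > 0, x, w * x).astype(np.float32)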
3889
3890class LSTM(PrimitiveWithInfer):
3891    """
3892    Performs the Long Short-Term Memory (LSTM) on the input.
3893
3894    For detailed information, please refer to :class:`nn.LSTM`.
3895
3896    Args:
3897        input_size (int): Number of features of input.
3898        hidden_size (int):  Number of features of hidden layer.
3899        num_layers (int): Number of layers of stacked LSTM.
3900        has_bias (bool): Whether the cell has bias `b_ih` and `b_hh`.
3901        bidirectional (bool): Specifies whether it is a bidirectional LSTM.
3902        dropout (float): If not 0, append `Dropout` layer on the outputs of each
3903            LSTM layer except the last layer. The range of dropout is [0.0, 1.0].
3904
3905    Inputs:
3906        - **input** (Tensor) - Tensor of shape (seq_len, batch_size, `input_size`) or
3907          (batch_size, seq_len, `input_size`).
3908        - **h** (Tensor) - Tensor of shape (num_directions * `num_layers`, batch_size, `hidden_size`).
3909        - **c** (Tensor) - Tensor of shape (num_directions * `num_layers`, batch_size, `hidden_size`).
3910
3911    Outputs:
3912        Tuple, a tuple contains (`output`, `h_n`, `c_n`, `reserve`, `state`).
3913
3914        - **output** (Tensor) - Tensor of shape (seq_len, batch_size, num_directions * `hidden_size`).
3915        - **h_n** (Tensor) - Tensor of shape (num_directions * `num_layers`, batch_size, `hidden_size`).
3916        - **c_n** (Tensor) - Tensor of shape (num_directions * `num_layers`, batch_size, `hidden_size`).
3917        - **reserve** (Tensor) - Tensor of shape (r, 1).
3918        - **state** (Tensor) - Random number generator state and its shape is (s, 1).
3919
3920    Raises:
3921        TypeError: If `input_size`, `hidden_size` or `num_layers` is not an int.
3922        TypeError: If `has_bias` or `bidirectional` is not a bool.
3923        TypeError: If `dropout` is not a float.
3924        ValueError: If `dropout` is not in range [0.0, 1.0].
3925
3926    Supported Platforms:
3927        ``GPU`` ``CPU``
3928
3929    Examples:
3930        >>> input_size = 10
3931        >>> hidden_size = 2
3932        >>> num_layers = 1
3933        >>> seq_len = 5
3934        >>> batch_size = 2
3935        >>>
3936        >>> net = ops.LSTM(input_size, hidden_size, num_layers, True, False, 0.0)
3937        >>> input_tensor = Tensor(np.ones([seq_len, batch_size, input_size]).astype(np.float32))
3938        >>> h0 = Tensor(np.ones([num_layers, batch_size, hidden_size]).astype(np.float32))
3939        >>> c0 = Tensor(np.ones([num_layers, batch_size, hidden_size]).astype(np.float32))
3940        >>> w = Tensor(np.ones([112, 1, 1]).astype(np.float32))
3941        >>> output, hn, cn, _, _ = net(input_tensor, h0, c0, w)
3942        >>> print(output)
3943        [[[0.9640267  0.9640267 ]
3944          [0.9640267  0.9640267 ]]
3945         [[0.9950539  0.9950539 ]
3946          [0.9950539  0.9950539 ]]
3947         [[0.99932843 0.99932843]
3948          [0.99932843 0.99932843]]
3949         [[0.9999084  0.9999084 ]
3950          [0.9999084  0.9999084 ]]
3951         [[0.9999869  0.9999869 ]
3952          [0.9999869  0.9999869 ]]]
3953    """
3954
3955    @prim_attr_register
3956    def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
3957        """Initialize LSTM."""
3958        self.input_size = validator.check_positive_int(input_size, "input_size", self.name)
3959        self.hidden_size = validator.check_positive_int(hidden_size, "hidden_size", self.name)
3960        self.num_layers = validator.check_positive_int(num_layers, "num_layers", self.name)
3961        self.has_bias = validator.check_value_type("has_bias", has_bias, (bool,), self.name)
3962        self.bidirectional = validator.check_value_type("bidirectional", bidirectional, (bool,), self.name)
3963        self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
3964        self.dropout = validator.check_float_range(dropout, 0, 1, Rel.INC_BOTH, 'dropout', self.name)
3965
3966        if bidirectional:
3967            self.num_directions = 2
3968        else:
3969            self.num_directions = 1
3970
3971    def infer_shape(self, x_shape, h_shape, c_shape, w_shape):
3972        validator.check_equal_int(len(x_shape), 3, "x rank", self.name)
3973        validator.check_equal_int(x_shape[2], self.input_size, "x[2]", self.name)
3974
3975        # h and c should be same shape
3976        validator.check_equal_int(len(h_shape), 3, "h rank", self.name)
3977        validator.check("h_shape", h_shape, "c_shape", c_shape, Rel.EQ, self.name)
3978
3979        validator.check_int(h_shape[0], self.num_layers * self.num_directions, Rel.EQ, "h[0]", self.name)
3980        validator.check_equal_int(h_shape[1], x_shape[1], "h[1]", self.name)
3981        validator.check_int(h_shape[2], self.hidden_size, Rel.EQ, "h[2]", self.name)
3982
3983        y_shape = (x_shape[0], x_shape[1], self.hidden_size * self.num_directions)
3984
3985        # set arbitrary shape for reserved space
3986        reserved_shape = (1, 1)
3987        state_shape = (1, 1)
3988        return y_shape, h_shape, c_shape, reserved_shape, state_shape
3989
3990    def infer_dtype(self, x_dtype, h_dtype, c_dtype, w_dtype):
3991        args = {'x': x_dtype, 'h': h_dtype, 'c': c_dtype, 'w': w_dtype}
3992        validator.check_tensors_dtypes_same_and_valid(args, (mstype.float32, mstype.float16), self.name)
3993        return x_dtype, x_dtype, x_dtype, x_dtype, x_dtype
3994
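# NOTE: Hedged sketch only. One accounting of the flattened LSTM weight size that is consistent
# with the Examples above (112 elements for input_size=10, hidden_size=2, num_layers=1,
# has_bias=True, unidirectional): four gate matrices of shape (hidden, layer_input) and
# (hidden, hidden) per layer and direction, plus 4 * hidden biases for both the input-hidden and
# hidden-hidden parts. This is an assumption about the layout, not the operator's internal format.
def _lstm_flat_weight_size_sketch(input_size, hidden_size, num_layers, has_bias, bidirectional):
    """Estimate the number of elements in the flattened LSTM weight tensor."""
    num_directions = 2 if bidirectional else 1
    size = 0
    for layer in range(num_layers):
        # The first layer consumes the raw input; deeper layers consume the previous layer's output.
        layer_input = input_size if layer == 0 else hidden_size * num_directions
        size += num_directions * 4 * hidden_size * (layer_input + hidden_size)
        if has_bias:
            size += num_directions * 2 * 4 * hidden_size
    return size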
3995
3996class SigmoidCrossEntropyWithLogits(PrimitiveWithInfer):
3997    r"""
3998    Uses the given logits to compute sigmoid cross entropy between the logits and the label.
3999
4000    Measures the distribution error, using cross entropy loss, in discrete classification tasks where
4001    each class is independent and not mutually exclusive.
4002
4003    Sets input logits as :math:`X`, input label as :math:`Y`, output as :math:`loss`. Then,
4004
4005    .. math::
4006
4007        \begin{array}{ll} \\
4008            p_{ij} = sigmoid(X_{ij}) = \frac{1}{1 + e^{-X_{ij}}} \\
4009            loss_{ij} = -[Y_{ij} * ln(p_{ij}) + (1 - Y_{ij})ln(1 - p_{ij})]
4010        \end{array}
4011
4012    Inputs:
4013        - **logits** (Tensor) - Input logits. Tensor of shape :math:`(N, *)` where :math:`*` means, any number
4014          of additional dimensions.
4015        - **label** (Tensor) - Ground truth label. With the same shape and type as `logits`.
4016
4017    Outputs:
4018        Tensor, with the same shape and type as input `logits`.
4019
4020    Raises:
4021        TypeError: If `logits` or `label` is not a Tensor.
4022
4023    Supported Platforms:
4024        ``Ascend`` ``GPU`` ``CPU``
4025
4026    Examples:
4027        >>> logits = Tensor(np.array([[-0.8, 1.2, 0.7], [-0.1, -0.4, 0.7]]).astype(np.float32))
4028        >>> labels = Tensor(np.array([[0.3, 0.8, 1.2], [-0.6, 0.1, 2.2]]).astype(np.float32))
4029        >>> sigmoid = ops.SigmoidCrossEntropyWithLogits()
4030        >>> output = sigmoid(logits, labels)
4031        >>> print(output)
4032        [[ 0.6111007   0.5032824   0.26318604]
4033         [ 0.58439666  0.5530153  -0.4368139 ]]
4034    """
4035
4036    @prim_attr_register
4037    def __init__(self):
4038        """Initialize SigmoidCrossEntropyWithLogits"""
4039        self.init_prim_io_names(inputs=['predict', 'target'], outputs=['loss'])
4040
4041    def infer_shape(self, x_shape, y_shape):
4042        validator.check("x_shape", x_shape, "y_shape", y_shape, Rel.EQ, self.name)
4043        return x_shape
4044
4045    def infer_dtype(self, x_dtype, y_dtype):
4046        args = {"x_dtype": x_dtype, "y_dtype": y_dtype}
4047        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
4048        return x_dtype
4049
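# NOTE: Hedged NumPy sketch of the element-wise loss documented above:
# p = sigmoid(X), loss = -[Y * ln(p) + (1 - Y) * ln(1 - p)]. Illustrative only, not a MindSpore API.
def _sigmoid_cross_entropy_sketch(logits, label):
    """Element-wise sigmoid cross entropy, for illustration only."""
    logits = np.asarray(logits, dtype=np.float32)
    label = np.asarray(label, dtype=np.float32)
    p = 1.0 / (1.0 + np.exp(-logits))
    return -(label * np.log(p) + (1.0 - label) * np.log(1.0 - p))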
4050
4051class BCEWithLogitsLoss(PrimitiveWithInfer):
4052    r"""
4053    Adds sigmoid activation function to input `logits`, and uses the given logits to compute binary cross entropy
4054    between the logits and the label.
4055
4056    Sets input logits as :math:`X`, input label as :math:`Y`, input weight as :math:`W`, output as :math:`L`. Then,
4057
4058    .. math::
4059
4060        \begin{array}{ll} \\
4061            p_{ij} = sigmoid(X_{ij}) = \frac{1}{1 + e^{-X_{ij}}} \\
4062            L_{ij} = -[Y_{ij} * log(p_{ij}) + (1 - Y_{ij})log(1 - p_{ij})]
4063        \end{array}
4064
4065    :math:`i` indicates the :math:`i^{th}` sample, :math:`j` indicates the category. Then,
4066
4067    .. math::
4068        \ell(x, y) = \begin{cases}
4069        L, & \text{if reduction} = \text{'none';}\\
4070        \operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\
4071        \operatorname{sum}(L),  & \text{if reduction} = \text{'sum'.}
4072        \end{cases}
4073
4074    :math:`\ell` indicates the method of calculating the loss. There are three methods:
4075    'none' returns the element-wise loss directly,
4076    'mean' returns the average of all losses,
4077    and 'sum' returns the sum of all losses.
4078
4079    This operator will multiply the output by the corresponding weight.
4080    The tensor weight assigns different weights to each piece of data in the batch,
4081    and the tensor pos_weight adds corresponding weights to the positive examples of each category.
4082
4083    In addition, it can trade off recall and precision by adding weights to positive examples.
4084    In the case of multi-label classification the loss can be described as:
4085
4086    .. math::
4087        \begin{array}{ll} \\
4088            p_{ij,c} = sigmoid(X_{ij,c}) = \frac{1}{1 + e^{-X_{ij,c}}} \\
4089            L_{ij,c} = -[P_{c}Y_{ij,c} * log(p_{ij,c}) + (1 - Y_{ij,c})log(1 - p_{ij,c})]
4090        \end{array}
4091
4092    where :math:`c` is the class number (:math:`c>1` for multi-label binary classification, :math:`c=1` for
4093    single-label binary classification), :math:`n` is the number of samples in the batch and :math:`P_c` is the
4094    weight of the positive answer for class :math:`c`. :math:`P_c>1` increases the recall, :math:`P_c<1` increases the precision.
4095
4096    Args:
4097        reduction (str): Type of reduction to be applied to loss. The optional values are 'mean', 'sum', and 'none',
4098             not case sensitive. If 'none', do not perform reduction. Default: 'mean'.
4099
4100    Inputs:
4101        - **logits** (Tensor) - Input logits. Data type must be float16 or float32.
4102          Tensor of shape :math:`(N, *)` where :math:`*` means, any number of additional dimensions.
4103        - **label** (Tensor) - Ground truth label, has the same shape as `logits`.
4104          Data type must be float16 or float32.
4105        - **weight** (Tensor) - A rescaling weight applied to the loss of each batch element. It can be
4106          broadcast to a tensor with shape of `logits`. Data type must be float16 or float32.
4107        - **pos_weight** (Tensor) - A weight of positive examples. Must be a vector with length equal to the
4108          number of classes. It can be broadcast to a tensor with shape of `logits`.
4109          Data type must be float16 or float32.
4110
4111    Outputs:
4112        Tensor or Scalar, if `reduction` is 'none', it's a tensor with the same shape and type as input `logits`.
4113        Otherwise, the output is a scalar.
4114
4115    Raises:
4116        TypeError: If data type of any input is neither float16 nor float32.
4117        ValueError: If `weight` or `pos_weight` can not be broadcast to a tensor with shape of `logits`.
4118        ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.
4119
4120    Supported Platforms:
4121        ``Ascend`` ``GPU``
4122
4123    Examples:
4124        >>> logits = Tensor(np.array([[-0.8, 1.2, 0.7], [-0.1, -0.4, 0.7]]), mindspore.float32)
4125        >>> label = Tensor(np.array([[0.3, 0.8, 1.2], [-0.6, 0.1, 2.2]]), mindspore.float32)
4126        >>> weight = Tensor(np.array([1.0, 1.0, 1.0]), mindspore.float32)
4127        >>> pos_weight = Tensor(np.array([1.0, 1.0, 1.0]), mindspore.float32)
4128        >>> loss = ops.BCEWithLogitsLoss()
4129        >>> output = loss(logits, label, weight, pos_weight)
4130        >>> print(output)
4131        0.3463612
4132    """
4133
4134    @prim_attr_register
4135    def __init__(self, reduction='mean'):
4136        """Initialize BCEWithLogitsLoss"""
4137        self.reduction = validator.check_string(reduction, ['none', 'sum', 'mean'], 'reduction', self.name)
4138
4139    def infer_shape(self, logits, label, weight, pos_weight):
4140        validator.check('logits_shape', logits, 'label_shape', label, Rel.EQ, self.name)
4141        reversed_weight_shape = tuple(reversed(weight))
4142        reversed_logits_shape = tuple(reversed(logits))
4143        for i, v in enumerate(reversed_weight_shape):
4144            if v not in (reversed_logits_shape[i], 1):
4145                raise ValueError(f"For {self.name}, the shapes of 'logits' and 'weight' can not broadcast. "
4146                                 f"'logits': {tuple(logits)}, 'weight' shape {tuple(weight)}.")
4147
4148        reversed_pos_shape = tuple(reversed(pos_weight))
4149        for i, v in enumerate(reversed_pos_shape):
4150            if v not in (reversed_logits_shape[i], 1):
4151                raise ValueError(f"For {self.name}, the shapes of 'logits' and 'pos_weight' can not broadcast. "
4152                                 f"'logits': {tuple(logits)}, 'pos_weight' shape {tuple(pos_weight)}.")
4154
4155        if self.reduction in ('mean', 'sum'):
4156            shape = []
4157        else:
4158            shape = logits
4159        return shape
4160
4161    def infer_dtype(self, logits, label, weight, pos_weight):
4162        validator.check_tensor_dtype_valid('logits dtype', logits, [mstype.float16, mstype.float32], self.name)
4163        validator.check_tensor_dtype_valid('label dtype', label, [mstype.float16, mstype.float32], self.name)
4164        validator.check_tensor_dtype_valid('weight dtype', weight, [mstype.float16, mstype.float32], self.name)
4165        validator.check_tensor_dtype_valid('pos_weight dtype', pos_weight, [mstype.float16, mstype.float32], self.name)
4166        return logits
4167
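# NOTE: Hedged NumPy sketch of the weighted loss documented above: p = sigmoid(X),
# L = -weight * [pos_weight * Y * log(p) + (1 - Y) * log(1 - p)], followed by the chosen
# reduction. The exact placement of `weight` is our reading of "multiply the output by the
# corresponding weight"; illustrative only, not a MindSpore API. With all weights equal to 1 and
# reduction='mean', it reproduces the Examples value above (about 0.3463612).
def _bce_with_logits_sketch(logits, label, weight, pos_weight, reduction='mean'):
    """Binary cross entropy with logits, per-class pos_weight and per-element weight."""
    logits = np.asarray(logits, dtype=np.float32)
    label = np.asarray(label, dtype=np.float32)
    p = 1.0 / (1.0 + np.exp(-logits))
    loss = -(pos_weight * label * np.log(p) + (1.0 - label) * np.log(1.0 - p)) * weight
    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    return loss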
4168
4169class Pad(PrimitiveWithInfer):
4170    r"""
4171    Pads the input tensor according to the paddings.
4172    For example,
4173    to pad a 1-D input tensor, `paddings` has the form ((padding_left, padding_right),);
4174    to pad a 2-D input tensor, use
4175    ((padding_top, padding_bottom), (padding_left, padding_right));
4176    to pad a 3-D input tensor, use
4177    ((padding_front, padding_back), (padding_top, padding_bottom), (padding_left, padding_right)).
4178
4179    .. math::
4180        \begin{aligned}
4181            &\text{ input_x_shape} = (N_{1},N_{2},...,N_{n}) \\
4182            &\begin{aligned}
4183                \text{output_shape = }(&N_{1}+paddings[0,0]+paddings[0,1], \\
4184                                 & N_{2}+paddings[1,0]+paddings[1,1], \\
4185                                 &... , \\
4186                                 & N_{n}+paddings[n-1,0]+paddings[n-1,1])
4187            \end{aligned}
4188        \end{aligned}
4189
4190    Args:
4191        paddings (tuple): The shape of parameter `paddings` is (N, 2). N is the rank of input data. All elements of
4192            paddings are int type. For the `D` th dimension of the input, paddings[D, 0] indicates how many values
4193            to pad before the input tensor in the `D` th dimension, and paddings[D, 1] indicates how many values to
4194            pad after the input tensor in the `D` th dimension.
4195
4196    Inputs:
4197        - **input_x** (Tensor) - Tensor of shape :math:`(N, *)`, where :math:`*` means, any number of
4198          additional dimensions.
4199
4200    Outputs:
4201        Tensor, the tensor after padding.
4202
4203    Raises:
4204        TypeError: If `paddings` is not a tuple.
4205        TypeError: If `input_x` is not a Tensor.
4206        ValueError: If shape of `paddings` is not :math:`(N, 2)`.
4207        ValueError: If paddings.size is not equal to 2 * the rank of `input_x`.
4208
4209    Supported Platforms:
4210        ``Ascend`` ``GPU`` ``CPU``
4211
4212    Examples:
4213        >>> input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
4214        >>> pad_op = ops.Pad(((1, 2), (2, 1)))
4215        >>> output = pad_op(input_x)
4216        >>> print(output)
4217        [[ 0.   0.   0.   0.   0.   0. ]
4218         [ 0.   0.  -0.1  0.3  3.6  0. ]
4219         [ 0.   0.   0.4  0.5 -3.2  0. ]
4220         [ 0.   0.   0.   0.   0.   0. ]
4221         [ 0.   0.   0.   0.   0.   0. ]]
4222    """
4223
4224    @prim_attr_register
4225    def __init__(self, paddings):
4226        """Initialize Pad"""
4227        self.init_prim_io_names(inputs=['x'], outputs=['y'])
4228        if not isinstance(paddings, tuple):
4229            raise TypeError(f"For '{self.name}', the type of 'paddings' must be tuple, "
4230                            f"but got {type(paddings)}.")
4231        for item in paddings:
4232            if len(item) != 2:
4233                raise ValueError(f"For '{self.name}', the shape of 'paddings' must be (n, 2), "
4234                                 f"but got {paddings}.")
4235        self.paddings = paddings
4236
4237    def infer_shape(self, x_shape):
4238        validator.check_int(len(self.paddings), len(x_shape), Rel.EQ, 'paddings.shape', self.name)
4239        paddings = np.array(self.paddings)
4240        if not np.all(paddings >= 0):
4241            raise ValueError(f"For '{self.name}', all elements of paddings must be >= 0.")
4242        y_shape = ()
4243        for i in range(int(paddings.size / 2)):
4244            y_shape += ((x_shape[i] + paddings[i, 0] + paddings[i, 1]),)
4245        return y_shape
4246
4247    def infer_dtype(self, x_dtype):
4248        validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name)
4249        return x_dtype
4250
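# NOTE: Hedged sketch, not a MindSpore API: the constant zero padding documented above
# corresponds to NumPy's np.pad with mode='constant', which reproduces the Examples output.
def _pad_reference_sketch(x, paddings):
    """Zero-pad `x` according to an (N, 2) paddings tuple, for illustration only."""
    return np.pad(np.asarray(x), pad_width=paddings, mode='constant', constant_values=0)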
4251
4252class MirrorPad(PrimitiveWithInfer):
4253    """
4254    Pads the input tensor according to the paddings and mode.
4255
4256    Args:
4257        mode (str): Specifies the padding mode. The optional values are "REFLECT" and "SYMMETRIC".
4258            Default: "REFLECT".
4259
4260    Inputs:
4261        - **input_x** (Tensor) - Tensor of shape :math:`(N, *)`, where :math:`*` means, any number of
4262          additional dimensions.
4263        - **paddings** (Tensor) - The paddings tensor. The value of `paddings` is a matrix(list),
4264          and its shape is (N, 2). N is the rank of input data. All elements of paddings
4265          are int type. For the `D` th dimension of the input, paddings[D, 0] indicates how many values to pad
4266          before the input tensor in the `D` th dimension, and paddings[D, 1] indicates how many values to pad
4267          after the input tensor in the `D` th dimension.
4268
4269    Outputs:
4270        Tensor, the tensor after padding.
4271
4272        - If `mode` is "REFLECT", it uses a way of symmetrical copying through the axis of symmetry to fill in.
4273          If the `input_x` is [[1,2,3], [4,5,6], [7,8,9]] and `paddings` is [[1,1], [2,2]], then the
4274          Outputs is [[6,5,4,5,6,5,4], [3,2,1,2,3,2,1], [6,5,4,5,6,5,4], [9,8,7,8,9,8,7], [6,5,4,5,6,5,4]].
4275        - If `mode` is "SYMMETRIC", the filling method is similar to the "REFLECT". It is also copied
4276          according to the symmetry axis, except that it includes the symmetry axis. If the `input_x`
4277          is [[1,2,3], [4,5,6], [7,8,9]] and `paddings` is [[1,1], [2,2]], then the Outputs is
4278          [[2,1,1,2,3,3,2], [2,1,1,2,3,3,2], [5,4,4,5,6,6,5], [8,7,7,8,9,9,8], [8,7,7,8,9,9,8]].
4279
4280    Raises:
4281        TypeError: If `input_x` or `paddings` is not a Tensor.
4282        TypeError: If `mode` is not a str.
4283        ValueError: If paddings.size is not equal to 2 * the rank of `input_x`.
4284
4285    Supported Platforms:
4286        ``Ascend`` ``GPU`` ``CPU``
4287
4288    Examples:
4289        >>> # case1: mode="REFLECT"
4290        >>> class Net(nn.Cell):
4291        ...    def __init__(self, mode):
4292        ...        super(Net, self).__init__()
4293        ...        self.pad = ops.MirrorPad(mode=mode)
4294        ...        self.paddings = Tensor([[1, 1], [2, 2]])
4295        ...    def construct(self, input_x):
4296        ...        return self.pad(input_x, self.paddings)
4297        ...
4298        >>> input_x = Tensor([[1,2,3], [4,5,6], [7,8,9]])
4299        >>> pad = Net("REFLECT")
4300        >>> output = pad(input_x)
4301        >>> print(output)
4302        [[6 5 4 5 6 5 4]
4303         [3 2 1 2 3 2 1]
4304         [6 5 4 5 6 5 4]
4305         [9 8 7 8 9 8 7]
4306         [6 5 4 5 6 5 4]]
4307        >>> # case2: mode="SYMMETRIC"
4308        >>> pad = Net("SYMMETRIC")
4309        >>> output = pad(input_x)
4310        >>> print(output)
4311        [[2 1 1 2 3 3 2]
4312         [2 1 1 2 3 3 2]
4313         [5 4 4 5 6 6 5]
4314         [8 7 7 8 9 9 8]
4315         [8 7 7 8 9 9 8]]
4316    """
4317
4318    @prim_attr_register
4319    def __init__(self, mode='REFLECT'):
4320        """Initialize MirrorPad."""
4321        validator.check_string(mode, ['REFLECT', 'SYMMETRIC'], 'mode', self.name)
4322        self.mode = mode
4323        self.set_const_input_indexes([1])
4324
4325    def __infer__(self, input_x, paddings):
4326        validator.check_subclass("input_x", input_x['dtype'], mstype.tensor, self.name)
4327        validator.check_subclass("paddings", paddings['dtype'], mstype.tensor, self.name)
4328        x_shape = list(input_x['shape'])
4329        paddings_value = paddings['value'].asnumpy()
4330        paddings_size = paddings_value.size
4331        validator.check_int(paddings_size, len(x_shape) * 2, Rel.EQ, 'paddings.shape', self.name)
4332        if not np.all(paddings_value >= 0):
4333            raise ValueError(f"For '{self.name}', all elements of 'paddings' must be >= 0.")
4334        adjust = 0
4335        if self.mode == 'SYMMETRIC':
4336            adjust = 1
4337        for i in range(0, int(paddings_size / 2)):
4338            if (paddings_value[i, 0] >= x_shape[i] + adjust) or (paddings_value[i, 1] >= x_shape[i] + adjust):
4339                msg = "x_shape[D] + 1" if adjust == 1 else "x_shape[D]"
4340                paddings_info_value = paddings['value']
4341                raise ValueError(f"For '{self.name}', both paddings[D, 0] and paddings[D, 1] must be less than {msg}, "
4342                                 f"but got paddings[{i}, 0]: {paddings_info_value[i, 0]}, "
4343                                 f"paddings[{i}, 1]: {paddings_info_value[i, 1]}, x_shape[{i}]: {x_shape[i]}.")
4344        y_shape = ()
4345        for i in range(0, int(paddings_size / 2)):
4346            y_shape += ((x_shape[i] + paddings_value[i, 0] + paddings_value[i, 1]),)
4347        return {'shape': y_shape,
4348                'dtype': input_x['dtype'],
4349                'value': None}
4350
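# NOTE: Hedged sketch, not a MindSpore API: the REFLECT / SYMMETRIC behaviours documented above
# match NumPy's np.pad modes 'reflect' and 'symmetric', respectively, and reproduce the Examples.
def _mirror_pad_reference_sketch(x, paddings, mode='REFLECT'):
    """Mirror-pad `x` according to an (N, 2) paddings array, for illustration only."""
    np_mode = 'reflect' if mode == 'REFLECT' else 'symmetric'
    return np.pad(np.asarray(x), pad_width=paddings, mode=np_mode)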
4351
4352class ComputeAccidentalHits(PrimitiveWithCheck):
4353    r"""
4354    Compute accidental hits of sampled classes which match target classes.
4355
4356    When a target class matches a sampled class, we call it an "accidental hit".
4357    The result of calculating accidental hits contains three parts (index, id, weight),
4358    where index represents the row number in true_classes, id represents the position in sampled_candidates,
4359    and the weight is -FLOAT_MAX, where FLOAT_MAX indicates the maximum value of the float type.
4360
4361    Args:
4362        num_true (int): The number of target classes per training example. Default: 1.
4363
4364    Inputs:
4365        - **true_classes** (Tensor) - The target classes. With data type of int32 or int64
4366          and shape :math:`(batch\_size, num\_true)`.
4367        - **sampled_candidates** (Tensor) - The candidate sampling results produced by a sampling operator,
4368          i.e. the classes of training samples, with data type of int32 or int64 and shape :math:`(num\_sampled, )`.
4369
4370    Outputs:
4371        Tuple of 3 Tensors.
4372
4373        - **indices** (Tensor) - A Tensor with shape :math:`(num\_accidental\_hits, )`,
4374          with the same type as `true_classes`.
4375        - **ids** (Tensor) - A Tensor with shape :math:`(num\_accidental\_hits, )`,
4376          with the same type as `true_classes`.
4377        - **weights** (Tensor) - A Tensor with shape :math:`(num\_accidental\_hits, )`, with the type float32.
4378
4379    Raises:
4380        TypeError: If `num_true` is not an int.
4381        TypeError: If `true_classes` or `sampled_candidates` is not a Tensor.
4382        TypeError: If dtype of `true_classes` or `sampled_candidates` is neither int32 nor int64.
4383
4384    Supported Platforms:
4385        ``Ascend``
4386
4387    Examples:
4388        >>> true_classes = np.array([[1, 2], [0, 4], [3, 3]])
4389        >>> sampled_candidates = np.array([0, 1, 2, 3, 4])
4390        >>> sampler = ops.ComputeAccidentalHits(2)
4391        >>> indices, ids, weights = sampler(Tensor(true_classes), Tensor(sampled_candidates))
4392        >>> print(indices, ids, weights)
4393        [0 0 1 1 2 2]
4394        [1 2 0 4 3 3]
4395        [-3.4028235e+38 -3.4028235e+38 -3.4028235e+38 -3.4028235e+38 -3.4028235e+38 -3.4028235e+38]
4396
4397    """
4398
4399    @prim_attr_register
4400    def __init__(self, num_true=1):
4401        """Initialize ComputeAccidentalHits"""
4402        self.init_prim_io_names(inputs=['true_classes', 'sampled_candidates'],
4403                                outputs=['indices', 'ids', 'weights'])
4404        validator.check_value_type("num_true", num_true, [int], self.name)
4405        validator.check_number("num_true", num_true, 1, Rel.GE, self.name)
4406        self.num_true = num_true
4407
4408    def check_shape(self, true_classes_shape, sampled_candidates_shape):
4409        validator.check_int(len(true_classes_shape), 2, Rel.EQ, 'dim of true_classes', self.name)
4410        validator.check_int(len(sampled_candidates_shape), 1, Rel.EQ, 'dim of sampled_candidates', self.name)
4411        validator.check("true_classes shape[1]", true_classes_shape[1], "num_true", self.num_true, Rel.EQ, self.name)
4412
4413        indices_len = -1
4414        return (indices_len,), (indices_len,), (indices_len,)
4415
4416    def check_dtype(self, true_classes_type, sampled_candidates_type):
4417        validator.check_subclass("true_classes_type", true_classes_type, mstype.tensor, self.name)
4418        validator.check_subclass("sampled_candidates_type", sampled_candidates_type, mstype.tensor, self.name)
4419        valid_types = (mstype.int32, mstype.int64)
4420        validator.check_tensor_dtype_valid("true_classes_type", true_classes_type, valid_types, self.name)
4421        validator.check_tensor_dtype_valid("sampled_candidates_type", sampled_candidates_type, valid_types, self.name)
4422        weights_type = mstype.float32
4423        return true_classes_type, true_classes_type, weights_type
4424
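# NOTE: Hedged NumPy sketch of the "accidental hit" bookkeeping described above: for every
# (row, class) pair in true_classes that also appears in sampled_candidates, emit the row index,
# the position of that class in sampled_candidates, and a weight of -FLOAT_MAX. Illustrative
# only, not a MindSpore API; it reproduces the Examples output for the inputs shown there.
def _accidental_hits_sketch(true_classes, sampled_candidates):
    """Return (indices, ids, weights) for target classes that were also sampled."""
    true_classes = np.asarray(true_classes)
    sampled_candidates = np.asarray(sampled_candidates)
    indices, ids = [], []
    for row, classes in enumerate(true_classes):
        for cls in classes:
            hit = np.nonzero(sampled_candidates == cls)[0]
            if hit.size:
                indices.append(row)
                ids.append(int(hit[0]))
    weights = np.full(len(indices), -np.finfo(np.float32).max, dtype=np.float32)
    return np.array(indices, dtype=np.int32), np.array(ids, dtype=np.int32), weights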
4425
4426class ROIAlign(PrimitiveWithInfer):
4427    r"""
4428    Computes the Region of Interest (RoI) Align operator.
4429
4430    The operator computes the value of each sampling point by bilinear interpolation from the nearby grid points on the
4431    feature map. No quantization is performed on any coordinates involved in the RoI, its bins, or the sampling
4432    points. The details of (RoI) Align operator are described in `Mask R-CNN <https://arxiv.org/abs/1703.06870>`_.
4433
4434    Args:
4435        pooled_height (int): The output features height.
4436        pooled_width (int): The output features width.
4437        spatial_scale (float): A scaling factor that maps the raw image coordinates to the input
4438            feature map coordinates. Suppose the height of a RoI is `ori_h` in the raw image and `fea_h` in the
4439            input feature map, the `spatial_scale` must be `fea_h / ori_h`.
4440        sample_num (int): Number of sampling points. Default: 2.
4441        roi_end_mode (int): Number must be 0 or 1. Default: 1.
4442
4443    Inputs:
4444        - **features** (Tensor) - The input features, whose shape must be :math:`(N, C, H, W)`.
4445        - **rois** (Tensor) - The shape is :math:`(rois\_n, 5)`. With data type of float16 or float32.
4446          `rois_n` represents the number of RoIs. The size of the second dimension must be `5` and the `5` columns
4447          are :math:`(image\_index, top\_left\_x, top\_left\_y, bottom\_right\_x, bottom\_right\_y)`.
4448          `image_index` represents the index of image. `top_left_x` and `top_left_y` represent the `x, y`
4449          coordinates of the top left corner of corresponding RoI, respectively. `bottom_right_x` and `bottom_right_y`
4450          represent the `x, y` coordinates of the bottom right corner of corresponding RoI, respectively.
4451
4452    Outputs:
4453        Tensor, the shape is :math:`(rois\_n, C, pooled\_height, pooled\_width)`.
4454
4455    Raises:
4456        TypeError: If `pooled_height`, `pooled_width`, `sample_num` or `roi_end_mode` is not an int.
4457        TypeError: If `spatial_scale` is not a float.
4458        TypeError: If `features` or `rois` is not a Tensor.
4459
4460    Supported Platforms:
4461        ``Ascend`` ``GPU`` ``CPU``
4462
4463    Examples:
4464        >>> features = Tensor(np.array([[[[1., 2.], [3., 4.]]]]), mindspore.float32)
4465        >>> rois = Tensor(np.array([[0, 0.2, 0.3, 0.2, 0.3]]), mindspore.float32)
4466        >>> roi_align = ops.ROIAlign(2, 2, 0.5, 2)
4467        >>> output = roi_align(features, rois)
4468        >>> print(output)
4469        [[[[1.775 2.025]
4470           [2.275 2.525]]]]
4471    """
4472
4473    @prim_attr_register
4474    def __init__(self, pooled_height, pooled_width, spatial_scale, sample_num=2, roi_end_mode=1):
4475        """Initialize ROIAlign"""
4476        validator.check_value_type("pooled_height", pooled_height, [int], self.name)
4477        validator.check_value_type("pooled_width", pooled_width, [int], self.name)
4478        validator.check_value_type("spatial_scale", spatial_scale, [float], self.name)
4479        validator.check_value_type("sample_num", sample_num, [int], self.name)
4480        validator.check_value_type("roi_end_mode", roi_end_mode, [int], self.name)
4481        validator.check_int_range(roi_end_mode, 0, 1, Rel.INC_BOTH, "roi_end_mode", self.name)
4482        self.pooled_height = pooled_height
4483        self.pooled_width = pooled_width
4484        self.spatial_scale = spatial_scale
4485        self.sample_num = sample_num
4486        self.roi_end_mode = roi_end_mode
4487
4488    def infer_shape(self, inputs_shape, rois_shape):
4489        validator.check("input shape rank", len(inputs_shape), "", 4, Rel.LE, self.name)
4490        return [rois_shape[0], inputs_shape[1], self.pooled_height, self.pooled_width]
4491
4492    def infer_dtype(self, inputs_type, rois_type):
4493        valid_dtypes = (mstype.float16, mstype.float32)
4494        validator.check_tensor_dtype_valid("inputs_type", inputs_type, valid_dtypes, self.name)
4495        validator.check_tensor_dtype_valid("rois_type", rois_type, valid_dtypes, self.name)
4496        return inputs_type
4497
4498
4499class Adam(PrimitiveWithInfer):
4500    r"""
4501    Updates gradients by the Adaptive Moment Estimation (Adam) algorithm.
4502
4503    The Adam algorithm is proposed in `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_.
4504
4505    For more details, please refer to :class:`nn.Adam`.
4506
4507    The updating formulas are as follows,
4508
4509    .. math::
4510        \begin{array}{ll} \\
4511            m = \beta_1 * m + (1 - \beta_1) * g \\
4512            v = \beta_2 * v + (1 - \beta_2) * g * g \\
4513            l = \alpha * \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \\
4514            w = w - l * \frac{m}{\sqrt{v} + \epsilon}
4515        \end{array}
4516
4517    :math:`m` represents the 1st moment vector, :math:`v` represents the 2nd moment vector, :math:`g` represents
4518    `gradient`, :math:`l` represents scaling factor `lr`, :math:`\beta_1, \beta_2` represent `beta1` and `beta2`,
4519    :math:`t` represents updating step while :math:`beta_1^t(\beta_1^{t})` and :math:`beta_2^t(\beta_2^{t})`
4520    represent `beta1_power` and `beta2_power`, :math:`\alpha` represents `learning_rate`, :math:`w` represents `var`,
4521    :math:`\epsilon` represents
4522    `epsilon`.
4523
4524    Args:
4525        use_locking (bool): Whether to enable a lock to protect variable tensors from being updated.
4526            If true, updates of the var, m, and v tensors will be protected by a lock.
4527            If false, the result is unpredictable. Default: False.
4528        use_nesterov (bool): Whether to use Nesterov Accelerated Gradient (NAG) algorithm to update the gradients.
4529            If true, update the gradients using NAG.
4530            If false, update the gradients without using NAG. Default: False.
4531
4532    Inputs:
4533        - **var** (Tensor) - Weights to be updated. The shape is :math:`(N, *)` where :math:`*` means,
4534          any number of additional dimensions. The data type can be float16 or float32.
4535        - **m** (Tensor) - The 1st moment vector in the updating formula,
4536          the shape and data type value should be the same as `var`.
4537        - **v** (Tensor) - the 2nd moment vector in the updating formula,
4538          the shape and data type value should be the same as `var`. Mean square gradients with the same type as `var`.
4539        - **beta1_power** (float) - :math:`beta_1^t(\beta_1^{t})` in the updating formula,
4540          the data type value should be the same as `var`.
4541        - **beta2_power** (float) - :math:`beta_2^t(\beta_2^{t})` in the updating formula,
4542          the data type value should be the same as `var`.
4543        - **lr** (float) - :math:`l` in the updating formula. The paper suggested value is :math:`10^{-3}`,
4544          the data type should be the same as `var`.
4545        - **beta1** (float) - The exponential decay rate for the 1st moment estimations,
4546          the data type value should be the same as `var`. The paper suggested value is :math:`0.9`
4547        - **beta2** (float) - The exponential decay rate for the 2nd moment estimations,
4548          the data type value should be the same as `var`. The paper suggested value is :math:`0.999`
4549        - **epsilon** (float) - Term added to the denominator to improve numerical stability.
4550        - **gradient** (Tensor) - Gradient, has the same shape and data type as `var`.
4551
4552    Outputs:
4553        Tuple of 3 Tensor, the updated parameters.
4554
4555        - **var** (Tensor) - The same shape and data type as Inputs `var`.
4556        - **m** (Tensor) - The same shape and data type as Inputs `m`.
4557        - **v** (Tensor) - The same shape and data type as Inputs `v`.
4558
4559    Raises:
4560        TypeError: If `use_locking` or `use_nesterov` is not a bool.
4561        TypeError: If `var`, `m` or `v` is not a Tensor.
4562        TypeError: If `beta1_power`, `beta2_power`, `lr`, `beta1`, `beta2`, `epsilon` or `gradient` is not a Tensor.
4563
4564    Supported Platforms:
4565        ``Ascend`` ``GPU`` ``CPU``
4566
4567    Examples:
4568        >>> class Net(nn.Cell):
4569        ...     def __init__(self):
4570        ...         super(Net, self).__init__()
4571        ...         self.apply_adam = ops.Adam()
4572        ...         self.var = Parameter(Tensor(np.ones([2, 2]).astype(np.float32)), name="var")
4573        ...         self.m = Parameter(Tensor(np.ones([2, 2]).astype(np.float32)), name="m")
4574        ...         self.v = Parameter(Tensor(np.ones([2, 2]).astype(np.float32)), name="v")
4575        ...     def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
4576        ...         out = self.apply_adam(self.var, self.m, self.v, beta1_power, beta2_power, lr, beta1, beta2,
4577        ...                               epsilon, grad)
4578        ...         return out
4579        ...
4580        >>> net = Net()
4581        >>> gradient = Tensor(np.ones([2, 2]).astype(np.float32))
4582        >>> output = net(0.9, 0.999, 0.001, 0.9, 0.999, 1e-8, gradient)
4583        >>> print(net.var.asnumpy())
4584        [[0.9996838 0.9996838]
4585         [0.9996838 0.9996838]]
4586    """
4587
4588    @prim_attr_register
4589    def __init__(self, use_locking=False, use_nesterov=False):
4590        """Initialize Adam."""
4591        validator.check_value_type("use_locking", use_locking, [bool], self.name)
4592        validator.check_value_type("use_nesterov", use_nesterov, [bool], self.name)
4593        self.add_prim_attr('side_effect_mem', True)
4594
4595    def infer_shape(self, var_shape, m_shape, v_shape, beta1_power_shape, beta2_power_shape, lr_shape,
4596                    beta1_shape, beta2_shape, epsilon_shape, grad_shape):
4597        validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name)
4598        validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name)
4599        validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name)
4600        return var_shape, m_shape, v_shape
4601
4602    def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, beta2_power_dtype, lr_dtype,
4603                    beta1_dtype, beta2_dtype, epsilon_dtype, grad_dtype):
4604        args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": grad_dtype}
4605        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
4606
4607        args = {"beta1_power": beta1_power_dtype, "beta2_power": beta2_power_dtype, 'lr': lr_dtype,
4608                "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype}
4609        validator.check_scalar_or_tensor_types_same(args, [mstype.float16, mstype.float32], self.name, True)
4610        return var_dtype, m_dtype, v_dtype
4611
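# NOTE: Hedged NumPy sketch of a single Adam update step, following exactly the formulas in the
# docstring above (including the bias-corrected scaling factor l). Illustrative only, not the
# operator implementation; with the Examples' inputs it reproduces var values close to 0.9996838.
def _adam_step_sketch(var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
    """One Adam step on NumPy arrays; returns the updated (var, m, v)."""
    m = beta1 * m + (1.0 - beta1) * grad
    v = beta2 * v + (1.0 - beta2) * grad * grad
    # Bias-corrected step size, l = lr * sqrt(1 - beta2^t) / (1 - beta1^t).
    lr_t = lr * np.sqrt(1.0 - beta2_power) / (1.0 - beta1_power)
    var = var - lr_t * m / (np.sqrt(v) + epsilon)
    return var, m, v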
4612
4613class AdamWeightDecay(PrimitiveWithInfer):
4614    r"""
4615    Updates gradients by the Adaptive Moment Estimation (AdamWeightDecay) algorithm with weight decay.
4616
4617    The Adam algorithm is proposed in `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_.
4618    The AdamWeightDecay variant was proposed in `Decoupled Weight Decay Regularization
4619    <https://arxiv.org/abs/1711.05101>`_.
4620
4621    The updating formulas are as follows,
4622
4623    .. math::
4624        \begin{array}{ll} \\
4625            m = \beta_1 * m + (1 - \beta_1) * g \\
4626            v = \beta_2 * v + (1 - \beta_2) * g * g \\
4627            update = \frac{m}{\sqrt{v} + eps} \\
4628            update =
4629            \begin{cases}
4630                update + weight\_decay * w
4631                    & \text{ if } weight\_decay > 0 \\
4632                update
4633                    & \text{ otherwise }
4634            \end{cases} \\
4635            w  = w - lr * update
4636        \end{array}
4637
4638    :math:`m` represents the 1st moment vector, :math:`v` represents the 2nd moment vector, :math:`g` represents
4639    `gradient`, :math:`\beta_1, \beta_2` represent `beta1` and `beta2`,
4640    :math:`lr` represents the input `lr`, :math:`w` represents `var`, :math:`weight\_decay` represents the input
4641    `decay`, :math:`eps` represents `epsilon`.
4642
4643    Args:
4644        use_locking (bool): Whether to enable a lock to protect variable tensors from being updated.
4645            If true, updates of the var, m, and v tensors will be protected by a lock.
4646            If false, the result is unpredictable. Default: False.
4647
4648    Inputs:
4649        - **var** (Tensor) - Weights to be updated. The shape is :math:`(N, *)` where :math:`*` means,
4650          any number of additional dimensions. The data type can be float16 or float32.
4651        - **m** (Tensor) - The 1st moment vector in the updating formula,
4652          the shape and data type should be the same as `var`.
4653        - **v** (Tensor) - The 2nd moment vector in the updating formula,
4654          the shape and data type should be the same as `var`.
4655        - **lr** (float) - :math:`lr` in the updating formula,
4656          the data type should be the same as `var`.
4657        - **beta1** (float) - The exponential decay rate for the 1st moment estimations,
4658          the data type should be the same as `var`. The paper suggested value is :math:`0.9`.
4659        - **beta2** (float) - The exponential decay rate for the 2nd moment estimations,
4660          the data type should be the same as `var`. The paper suggested value is :math:`0.999`.
4661        - **epsilon** (float) - Term added to the denominator to improve numerical stability. The paper suggested value is :math:`10^{-8}`.
4662        - **decay** (float) - The weight decay value, must be a scalar with float data type.
4663          Default: 0.0.
4664        - **gradient** (Tensor) - Gradient, has the same shape and data type as `var`.

4665    Outputs:
4666        Tuple of 3 Tensors, the updated parameters.
4667
4668        - **var** (Tensor) - The same shape and data type as `var`.
4669        - **m** (Tensor) - The same shape and data type as `m`.
4670        - **v** (Tensor) - The same shape and data type as `v`.
4671
4672    Supported Platforms:
4673        ``GPU`` ``CPU``
4674
4675    Examples:
4676        >>> import numpy as np
4677        >>> import mindspore.nn as nn
4678        >>> from mindspore import Tensor, Parameter, ops
4679        >>> class Net(nn.Cell):
4680        ...     def __init__(self):
4681        ...         super(Net, self).__init__()
4682        ...         self.adam_weight_decay = ops.AdamWeightDecay()
4683        ...         self.var = Parameter(Tensor(np.ones([2, 2]).astype(np.float32)), name="var")
4684        ...         self.m = Parameter(Tensor(np.ones([2, 2]).astype(np.float32)), name="m")
4685        ...         self.v = Parameter(Tensor(np.ones([2, 2]).astype(np.float32)), name="v")
4686        ...     def construct(self, lr, beta1, beta2, epsilon, decay, grad):
4687        ...         out = self.adam_weight_decay(self.var, self.m, self.v, lr, beta1, beta2,
4688        ...                               epsilon, decay, grad)
4689        ...         return out
4690        >>> net = Net()
4691        >>> gradient = Tensor(np.ones([2, 2]).astype(np.float32))
4692        >>> output = net(0.001, 0.9, 0.999, 1e-8, 0.0, gradient)
4693        >>> print(net.var.asnumpy())
4694    """
4695
4696    @prim_attr_register
4697    def __init__(self, use_locking=False):
4698        """Initialize AdamWeightDecay."""
4699        self.add_prim_attr('side_effect_mem', True)
4700        validator.check_value_type("use_locking", use_locking, [bool], self.name)
4701
4702    def infer_shape(self, var_shape, m_shape, v_shape, lr_shape, beta1_shape, beta2_shape,
4703                    epsilon_shape, decay_shape, grad_shape):
4704        validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name)
4705        validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name)
4706        validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name)
4707        return var_shape, m_shape, v_shape
4708
4709    def infer_dtype(self, var_dtype, m_dtype, v_dtype, lr_dtype, beta1_dtype, beta2_dtype,
4710                    epsilon_dtype, decay_dtype, grad_dtype):
4711        args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": grad_dtype}
4712        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
4713
4714        args = {"lr": lr_dtype, "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype,
4715                "decay": decay_dtype}
4716        validator.check_scalar_or_tensor_types_same(args, [mstype.float32], self.name, True)
4717        return var_dtype, m_dtype, v_dtype
4718
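# A minimal NumPy sketch of the AdamWeightDecay step written directly from the updating formulas
# in the docstring above. The helper name `_np_adam_weight_decay_sketch` is illustrative only and
# not part of this module's API; `np` is the module-level numpy import.
def _np_adam_weight_decay_sketch(var, m, v, lr, beta1, beta2, epsilon, decay, grad):
    """Decoupled weight decay: the decay term is added to the update, not folded into the gradient."""
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad * grad
    update = m / (np.sqrt(v) + epsilon)
    if decay > 0:
        update = update + decay * var
    var = var - lr * update
    return var, m, v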
4719
4720class AdamNoUpdateParam(PrimitiveWithInfer):
4721    r"""
4722    Updates gradients by the Adaptive Moment Estimation (Adam) algorithm. This operator does not update the
4723    parameter, but calculates the value that should be added to the parameter instead.
4724
4725    The Adam algorithm is proposed in `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_.
4726
4727    The updating formulas are as follows,
4728
4729    .. math::
4730        \begin{array}{ll} \\
4731            m = \beta_1 * m + (1 - \beta_1) * g \\
4732            v = \beta_2 * v + (1 - \beta_2) * g * g \\
4733            l = \alpha * \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \\
4734            \Delta{w} = - l * \frac{m}{\sqrt{v} + \epsilon}
4735        \end{array}
4736
4737    :math:`m` represents the 1st moment vector, :math:`v` represents the 2nd moment vector, :math:`g` represents
4738    `gradient`, :math:`l` represents scaling factor `lr`, :math:`\beta_1, \beta_2` represent `beta1` and `beta2`,
4739    :math:`t` represents updating step while :math:`beta_1^t(\beta_1^{t})` and :math:`beta_2^t(\beta_2^{t})`
4740    represent `beta1_power` and `beta2_power`, :math:`\alpha` represents `learning_rate`,
4741    :math:`w` represents the parameter to be updated, :math:`\epsilon` represents `epsilon`.
4742
4743    Args:
4744        use_locking (bool): Whether to enable a lock to protect variable tensors from being updated.
4745            If true, updates of the var, m, and v tensors will be protected by a lock.
4746            If false, the result is unpredictable. Default: False.
4747        use_nesterov (bool): Whether to use Nesterov Accelerated Gradient (NAG) algorithm to update the gradients.
4748            If true, update the gradients using NAG.
4749            If false, update the gradients without using NAG. Default: False.
4750
4751    Inputs:
4752        - **m** (Tensor) - The 1st moment vector in the updating formula. The shape is :math:`(N, *)`
4753          where :math:`*` means, any number of additional dimensions. The data type must be float32.
4754        - **v** (Tensor) - the 2nd moment vector in the updating formula. The shape must be the same as `m`.
4755          The data type must be float32.
4756        - **beta1_power** (Tensor) - :math:`beta_1^t(\beta_1^{t})` in the updating formula.
4757          The shape is :math:`(1, )` and the data type must be float32.
4758        - **beta2_power** (Tensor) - :math:`beta_2^t(\beta_2^{t})` in the updating formula.
4759          The shape is :math:`(1, )` and the data type must be float32.
4760        - **lr** (Tensor) - :math:`l` in the updating formula.
4761          The shape is :math:`(1, )` and the data type must be float32.
4762        - **beta1** (Tensor) - The exponential decay rate for the 1st moment estimations.
4763          The shape is :math:`(1, )` and the data type must be float32.
4764        - **beta2** (Tensor) - The exponential decay rate for the 2nd moment estimations.
4765          The shape is :math:`(1, )` and the data type must be float32.
4766        - **epsilon** (Tensor) - Term added to the denominator to improve numerical stability.
4767          The shape is :math:`(1, )` and the data type must be float32.
4768        - **gradient** (Tensor) - Gradient, the shape must be the same as `m`, the data type must be float32.
4769
4770    Outputs:
4771        Tensor, whose shape and data type are the same as the input `gradient`. It is the value that should be
4772        added to the parameter to be updated.
4773
4774    Raises:
4775        TypeError: If `use_locking` or `use_nesterov` is not a bool.
4776        TypeError: If `m`, `v`, `beta1_power`, `beta2_power`, `lr`, `beta1`, `beta2`, `epsilon` or `gradient`
4777                   is not a Tensor.
4778
4779    Supported Platforms:
4780        ``CPU``
4781
4782    Examples:
4783        >>> class Net(nn.Cell):
4784        ...     def __init__(self):
4785        ...         super(Net, self).__init__()
4786        ...         self.adam = ops.AdamNoUpdateParam()
4787        ...         self.m = Parameter(Tensor(np.array([[0.1, 0.1, 0.1], [0.2, 0.2, 0.2]]).astype(np.float32)),
4788        ...                            name="m")
4789        ...         self.v = Parameter(Tensor(np.array([[0.1, 0.1, 0.1], [0.2, 0.2, 0.2]]).astype(np.float32)),
4790        ...                            name="v")
4791        ...     def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
4792        ...         out = self.adam(self.m, self.v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
4793        ...         return out
4794        >>> net = Net()
4795        >>> beta1_power = Tensor(0.9, ms.float32)
4796        >>> beta2_power = Tensor(0.999, ms.float32)
4797        >>> lr = Tensor(0.001, ms.float32)
4798        >>> beta1 = Tensor(0.9, ms.float32)
4799        >>> beta2 = Tensor(0.999, ms.float32)
4800        >>> epsilon = Tensor(1e-8, ms.float32)
4801        >>> gradient = Tensor(np.array([[0.1, 0.1, 0.1], [0.1, 0.1, 0.1]]).astype(np.float32))
4802        >>> result = net(beta1_power, beta2_power, lr, beta1, beta2, epsilon, gradient)
4803        >>> print(result)
4804        [[-0.00010004 -0.00010004 -0.00010004]
4805        [-0.00013441 -0.00013441 -0.00013441]]
4806
4807    """
4808
4809    @prim_attr_register
4810    def __init__(self, use_locking=False, use_nesterov=False):
4811        """Initialize AdamNoUpdateParam."""
4812        validator.check_value_type("use_locking", use_locking, [bool], self.name)
4813        validator.check_value_type("use_nesterov", use_nesterov, [bool], self.name)
4814
4815    def infer_shape(self, m_shape, v_shape, beta1_power_shape, beta2_power_shape, lr_shape,
4816                    beta1_shape, beta2_shape, epsilon_shape, grad_shape):
4817        validator.check("grad_shape", grad_shape, "m_shape", m_shape, Rel.EQ, self.name)
4818        validator.check("grad_shape", grad_shape, "v_shape", v_shape, Rel.EQ, self.name)
4819        return grad_shape
4820
4821    def infer_dtype(self, m_dtype, v_dtype, beta1_power_dtype, beta2_power_dtype, lr_dtype,
4822                    beta1_dtype, beta2_dtype, epsilon_dtype, grad_dtype):
4823        args = {"m": m_dtype, "v": v_dtype, "grad": grad_dtype,
4824                "beta1_power": beta1_power_dtype, "beta2_power": beta2_power_dtype, 'lr': lr_dtype,
4825                "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype}
4826        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float32], self.name)
4827        return grad_dtype
4828
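# A minimal NumPy sketch of the delta computed by AdamNoUpdateParam: the moments are advanced and
# only the value to be added to the parameter is returned; the parameter itself is untouched.
# The helper name `_np_adam_delta_sketch` is illustrative only and not part of this module's API;
# `np` is the module-level numpy import.
def _np_adam_delta_sketch(m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
    """With the docstring example's inputs, the first row of the result is about -0.00010004."""
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad * grad
    l = lr * np.sqrt(1 - beta2_power) / (1 - beta1_power)  # bias-corrected step size
    return -l * m / (np.sqrt(v) + epsilon)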
4829
4830class FusedSparseAdam(PrimitiveWithInfer):
4831    r"""
4832    Merges the duplicate values of the gradient and then updates parameters by the Adaptive Moment Estimation (Adam)
4833    algorithm. This operator is used when the gradient is sparse.
4834
4835    The Adam algorithm is proposed in `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_.
4836
4837    The updating formulas are as follows,
4838
4839    .. math::
4840        \begin{array}{ll} \\
4841            m = \beta_1 * m + (1 - \beta_1) * g \\
4842            v = \beta_2 * v + (1 - \beta_2) * g * g \\
4843            l = \alpha * \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \\
4844            w = w - l * \frac{m}{\sqrt{v} + \epsilon}
4845        \end{array}
4846
4847    :math:`m` represents the 1st moment vector, :math:`v` represents the 2nd moment vector, :math:`g` represents
4848    `gradient`, :math:`l` represents scaling factor `lr`, :math:`\beta_1, \beta_2` represent `beta1` and `beta2`,
4849    :math:`t` represents updating step while :math:`beta_1^t` and :math:`beta_2^t` represent `beta1_power` and
4850    `beta2_power`, :math:`\alpha` represents `learning_rate`, :math:`w` represents `var`, :math:`\epsilon` represents
4851    `epsilon`.
4852
4853    All of the inputs except `indices` comply with the implicit type conversion rules to make the data types
4854    consistent. If they have different data types, the lower-priority data type will be converted to the
4855    relatively highest-priority data type.
4856    RuntimeError exception will be thrown when the data type conversion of Parameter is required.
4857
4858    Args:
4859        use_locking (bool): Whether to enable a lock to protect variable tensors from being updated.
4860            If true, updates of the var, m, and v tensors will be protected by a lock.
4861            If false, the result is unpredictable. Default: False.
4862        use_nesterov (bool): Whether to use Nesterov Accelerated Gradient (NAG) algorithm to update the gradients.
4863            If true, update the gradients using NAG.
4864            If false, update the gradients without using NAG. Default: False.
4865
4866    Inputs:
4867        - **var** (Parameter) - Parameters to be updated with float32 data type. The shape is :math:`(N, *)`
4868          where :math:`*` means, any number of additional dimensions.
4869        - **m** (Parameter) - The 1st moment vector in the updating formula, has the same shape and data type as `var`.
4870        - **v** (Parameter) - The 2nd moment vector in the updating formula (the mean of the squared gradients),
4871          has the same shape and float32 data type as `var`.
4872        - **beta1_power** (Tensor) - :math:`beta_1^t` in the updating formula with float32 data type.
4873          The shape is :math:`(1, )`.
4874        - **beta2_power** (Tensor) - :math:`beta_2^t` in the updating formula with float32 data type.
4875          The shape is :math:`(1, )`.
4876        - **lr** (Tensor) - :math:`l` in the updating formula. With float32 data type.
4877          The shape is :math:`(1, )`.
4878        - **beta1** (Tensor) - The exponential decay rate for the 1st moment estimations with float32 data type.
4879          The shape is :math:`(1, )`.
4880        - **beta2** (Tensor) - The exponential decay rate for the 2nd moment estimations with float32 data type.
4881          The shape is :math:`(1, )`.
4882        - **epsilon** (Tensor) - Term added to the denominator to improve numerical stability with float32 data type.
4883          The shape is :math:`(1, )`.
4884        - **gradient** (Tensor) - Gradient, has the same data type as `var` and
4885          gradient.shape[1:] = var.shape[1:] if the rank of `var` is greater than 1.
4886        - **indices** (Tensor) - Gradient indices with int32 data type and indices.shape[0] = gradient.shape[0].
4887
4888    Outputs:
4889        Tuple of 3 Tensors. This operator will update the input parameters directly; the outputs are useless.
4890
4891        - **var** (Tensor) - A Tensor with shape :math:`(1, )`.
4892        - **m** (Tensor) - A Tensor with shape :math:`(1, )`.
4893        - **v** (Tensor) - A Tensor with shape :math:`(1, )`.
4894
4895    Raises:
4896        TypeError: If `use_locking` or `use_nesterov` is not a bool.
4897        TypeError: If dtype of `var`, `m`, `v`, `beta1_power`, `beta2_power`, `lr`, `beta1`, `beta2`, `epsilon`,
4898                   `gradient` or `indices` is not float32.
4899
4900    Supported Platforms:
4901        ``Ascend`` ``CPU``
4902
4903    Examples:
4904        >>> class Net(nn.Cell):
4905        ...     def __init__(self):
4906        ...         super(Net, self).__init__()
4907        ...         self.sparse_apply_adam = ops.FusedSparseAdam()
4908        ...         self.var = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="var")
4909        ...         self.m = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="m")
4910        ...         self.v = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="v")
4911        ...     def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, indices):
4912        ...         out = self.sparse_apply_adam(self.var, self.m, self.v, beta1_power, beta2_power, lr, beta1, beta2,
4913        ...                                      epsilon, grad, indices)
4914        ...         return out
4915        ...
4916        >>> net = Net()
4917        >>> beta1_power = Tensor(0.9, mindspore.float32)
4918        >>> beta2_power = Tensor(0.999, mindspore.float32)
4919        >>> lr = Tensor(0.001, mindspore.float32)
4920        >>> beta1 = Tensor(0.9, mindspore.float32)
4921        >>> beta2 = Tensor(0.999, mindspore.float32)
4922        >>> epsilon = Tensor(1e-8, mindspore.float32)
4923        >>> gradient = Tensor(np.array([[[0.1, 0.1]], [[0.1, 0.1]]]), mindspore.float32)
4924        >>> indices = Tensor([0, 1], mindspore.int32)
4925        >>> output = net(beta1_power, beta2_power, lr, beta1, beta2, epsilon, gradient, indices)
4926        >>> print(net.var.asnumpy())
4927        [[[0.9997121  0.9997121 ]]
4928         [[0.9997121  0.9997121 ]]
4929         [[0.99971527 0.99971527]]]
4930    """
4931    __mindspore_signature__ = (
4932        sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
4933        sig.make_sig('m', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
4934        sig.make_sig('v', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
4935        sig.make_sig('beta1_power', dtype=sig.sig_dtype.T),
4936        sig.make_sig('beta2_power', dtype=sig.sig_dtype.T),
4937        sig.make_sig('lr', dtype=sig.sig_dtype.T),
4938        sig.make_sig('beta1', dtype=sig.sig_dtype.T),
4939        sig.make_sig('beta2', dtype=sig.sig_dtype.T),
4940        sig.make_sig('epsilon', dtype=sig.sig_dtype.T),
4941        sig.make_sig('grad', dtype=sig.sig_dtype.T),
4942        sig.make_sig('indices', dtype=sig.sig_dtype.T1)
4943    )
4944
4945    @prim_attr_register
4946    def __init__(self, use_locking=False, use_nesterov=False):
4947        """Initialize FusedSparseAdam."""
4948        validator.check_value_type("use_locking", use_locking, [bool], self.name)
4949        validator.check_value_type("use_nesterov", use_nesterov, [bool], self.name)
4950        self.init_prim_io_names(inputs=['var', 'm', 'v', 'beta1_power', 'beta2_power', 'lr', 'beta1', 'beta2',
4951                                        'epsilon', 'grad', 'indices'],
4952                                outputs=['var', 'm', 'v'])
4953        self.add_prim_attr('side_effect_mem', True)
4954
4955    def infer_shape(self, var_shape, m_shape, v_shape, beta1_power_shape, beta2_power_shape, lr_shape,
4956                    beta1_shape, beta2_shape, epsilon_shape, grad_shape, indices_shape):
4957        validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name)
4958        validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name)
4959        validator.check_int(len(indices_shape), 1, Rel.EQ, "indices rank", self.name)
4960        validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
4961        if len(var_shape) > 1 and grad_shape != indices_shape + var_shape[1:]:
4962            raise ValueError(f"For '{self.name}', the shape of updates should be [] or "
4963                             f"grad_shape = indices_shape + var_shape[1:], but got var_shape: {var_shape}, "
4964                             f"indices_shape: {indices_shape}, grad_shape: {grad_shape}.")
4965        return [1], [1], [1]
4966
4967    def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, beta2_power_dtype, lr_dtype,
4968                    beta1_dtype, beta2_dtype, epsilon_dtype, grad_dtype, indices_dtype):
4969        args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": grad_dtype}
4970        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
4971
4972        args = {"beta1_power": beta1_power_dtype, "beta2_power": beta2_power_dtype, 'lr': lr_dtype,
4973                "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype}
4974        validator.check_scalar_or_tensor_types_same(args, [mstype.float16, mstype.float32], self.name, True)
4975        validator.check_tensor_dtype_valid("indices_dtype", indices_dtype, [mstype.int32], self.name)
4976        return var_dtype, m_dtype, v_dtype
4977
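# A minimal NumPy sketch of FusedSparseAdam: duplicate indices are merged by summation into a
# dense gradient (zero for rows that receive no gradient), after which the ordinary Adam step is
# applied to every row, so even untouched rows still decay through `m` and `v` (which is why row 2
# of the docstring example moves to ~0.99971527 while rows 0 and 1 move to ~0.9997121). The helper
# name is illustrative only and not part of this module's API; `np` is the module-level numpy import.
def _np_fused_sparse_adam_sketch(var, m, v, beta1_power, beta2_power, lr, beta1, beta2,
                                 epsilon, grad, indices):
    dense_grad = np.zeros_like(var)
    np.add.at(dense_grad, indices, grad)  # merge duplicate indices by summing their gradients
    m = beta1 * m + (1 - beta1) * dense_grad
    v = beta2 * v + (1 - beta2) * dense_grad * dense_grad
    lr_t = lr * np.sqrt(1 - beta2_power) / (1 - beta1_power)
    var = var - lr_t * m / (np.sqrt(v) + epsilon)
    return var, m, v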
4978
4979class FusedSparseLazyAdam(PrimitiveWithInfer):
4980    r"""
4981    Merges the duplicate values of the gradient and then updates parameters by the Adaptive Moment Estimation
4982    (LazyAdam) algorithm. This operator is used when the gradient is sparse. The behavior is not equivalent to the
4983    original Adam algorithm, as only the parameters at the current indices will be updated.
4984
4985    The Adam algorithm is proposed in `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_.
4986
4987    The updating formulas are as follows,
4988
4989    .. math::
4990        \begin{array}{ll} \\
4991            m = \beta_1 * m + (1 - \beta_1) * g \\
4992            v = \beta_2 * v + (1 - \beta_2) * g * g \\
4993            l = \alpha * \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \\
4994            w = w - l * \frac{m}{\sqrt{v} + \epsilon}
4995        \end{array}
4996
4997    :math:`m` represents the 1st moment vector, :math:`v` represents the 2nd moment vector, :math:`g` represents
4998    `gradient`, :math:`l` represents scaling factor `lr`, :math:`\beta_1, \beta_2` represent `beta1` and `beta2`,
4999    :math:`t` represents updating step while :math:`beta_1^t` and :math:`beta_2^t` represent `beta1_power` and
5000    `beta2_power`, :math:`\alpha` represents `learning_rate`, :math:`w` represents `var`, :math:`\epsilon` represents
5001    `epsilon`.
5002
5003    All of the inputs except `indices` comply with the implicit type conversion rules to make the data types
5004    consistent. If they have different data types, the lower-priority data type will be converted to the
5005    relatively highest-priority data type.
5006    RuntimeError exception will be thrown when the data type conversion of Parameter is required.
5007
5008    Args:
5009        use_locking (bool): Whether to enable a lock to protect variable tensors from being updated.
5010            If true, updates of the var, m, and v tensors will be protected by a lock.
5011            If false, the result is unpredictable. Default: False.
5012        use_nesterov (bool): Whether to use Nesterov Accelerated Gradient (NAG) algorithm to update the gradients.
5013            If true, update the gradients using NAG.
5014            If false, update the gradients without using NAG. Default: False.
5015
5016    Inputs:
5017        - **var** (Parameter) - Parameters to be updated with float32 data type. The shape is :math:`(N, *)`
5018          where :math:`*` means, any number of additional dimensions.
5019        - **m** (Parameter) - The 1st moment vector in the updating formula, has the same shape and data type as `var`.
5020        - **v** (Parameter) - The 2nd moment vector in the updating formula (the mean of the squared gradients),
5021          has the same shape and float32 data type as `var`.
5022        - **beta1_power** (Tensor) - :math:`beta_1^t` in the updating formula with float32 data type.
5023          The shape is :math:`(1, )`.
5024        - **beta2_power** (Tensor) - :math:`beta_2^t` in the updating formula with float32 data type.
5025          The shape is :math:`(1, )`.
5026        - **lr** (Tensor) - :math:`l` in the updating formula with float32 data type.
5027          The shape is :math:`(1, )`.
5028        - **beta1** (Tensor) - The exponential decay rate for the 1st moment estimations with float32 data type.
5029          The shape is :math:`(1, )`.
5030        - **beta2** (Tensor) - The exponential decay rate for the 2nd moment estimations with float32 data type.
5031          The shape is :math:`(1, )`.
5032        - **epsilon** (Tensor) - Term added to the denominator to improve numerical stability with float32 data type.
5033          The shape is :math:`(1, )`.
5034        - **gradient** (Tensor) - Gradient value with float32 data type and
5035          gradient.shape[1:] = var.shape[1:] if the rank of `var` is greater than 1.
5036        - **indices** (Tensor) - Gradient indices with int32 data type and indices.shape[0] = gradient.shape[0].
5037
5038    Outputs:
5039        Tuple of 3 Tensors. This operator will update the input parameters directly; the outputs are useless.
5040
5041        - **var** (Tensor) - A Tensor with shape :math:`(1, )`.
5042        - **m** (Tensor) - A Tensor with shape :math:`(1, )`.
5043        - **v** (Tensor) - A Tensor with shape :math:`(1, )`.
5044
5045    Raises:
5046        TypeError: If `use_locking` or `use_nesterov` is not a bool.
5047        TypeError: If dtype of `var`, `m`, `v`, `beta1_power`, `beta2_power`, `lr`, `beta1`, `beta2`, `epsilon` or
5048                   `gradient` is not float32.
5049        TypeError: If dtype of `indices` is not int32.
5050
5051    Supported Platforms:
5052        ``Ascend`` ``CPU``
5053
5054    Examples:
5055        >>> class Net(nn.Cell):
5056        ...     def __init__(self):
5057        ...         super(Net, self).__init__()
5058        ...         self.sparse_apply_lazyadam = ops.FusedSparseLazyAdam()
5059        ...         self.var = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="var")
5060        ...         self.m = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="m")
5061        ...         self.v = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="v")
5062        ...     def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, indices):
5063        ...         out = self.sparse_apply_lazyadam(self.var, self.m, self.v, beta1_power, beta2_power, lr, beta1,
5064        ...                                          beta2, epsilon, grad, indices)
5065        ...         return out
5066        ...
5067        >>> net = Net()
5068        >>> beta1_power = Tensor(0.9, mindspore.float32)
5069        >>> beta2_power = Tensor(0.999, mindspore.float32)
5070        >>> lr = Tensor(0.001, mindspore.float32)
5071        >>> beta1 = Tensor(0.9, mindspore.float32)
5072        >>> beta2 = Tensor(0.999, mindspore.float32)
5073        >>> epsilon = Tensor(1e-8, mindspore.float32)
5074        >>> gradient = Tensor(np.array([[[0.1, 0.1]], [[0.1, 0.1]]]), mindspore.float32)
5075        >>> indices = Tensor([0, 1], mindspore.int32)
5076        >>> output = net(beta1_power, beta2_power, lr, beta1, beta2, epsilon, gradient, indices)
5077        >>> print(net.var.asnumpy())
5078        [[[0.9997121  0.9997121 ]]
5079         [[0.9997121  0.9997121 ]]
5080         [[1.         1.        ]]]
5081    """
5082    __mindspore_signature__ = (
5083        sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
5084        sig.make_sig('m', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
5085        sig.make_sig('v', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
5086        sig.make_sig('beta1_power', dtype=sig.sig_dtype.T),
5087        sig.make_sig('beta2_power', dtype=sig.sig_dtype.T),
5088        sig.make_sig('lr', dtype=sig.sig_dtype.T),
5089        sig.make_sig('beta1', dtype=sig.sig_dtype.T),
5090        sig.make_sig('beta2', dtype=sig.sig_dtype.T),
5091        sig.make_sig('epsilon', dtype=sig.sig_dtype.T),
5092        sig.make_sig('grad', dtype=sig.sig_dtype.T),
5093        sig.make_sig('indices', dtype=sig.sig_dtype.T1)
5094    )
5095
5096    @prim_attr_register
5097    def __init__(self, use_locking=False, use_nesterov=False):
5098        """Initialize FusedSparseLazyAdam."""
5099        validator.check_value_type("use_locking", use_locking, [bool], self.name)
5100        validator.check_value_type("use_nesterov", use_nesterov, [bool], self.name)
5101        self.init_prim_io_names(inputs=['var', 'm', 'v', 'beta1_power', 'beta2_power', 'lr', 'beta1', 'beta2',
5102                                        'epsilon', 'grad', 'indices'],
5103                                outputs=['var', 'm', 'v'])
5104        self.add_prim_attr('side_effect_mem', True)
5105
5106    def infer_shape(self, var_shape, m_shape, v_shape, beta1_power_shape, beta2_power_shape, lr_shape,
5107                    beta1_shape, beta2_shape, epsilon_shape, grad_shape, indices_shape):
5108        validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name)
5109        validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name)
5110        validator.check_int(len(indices_shape), 1, Rel.EQ, "indices rank", self.name)
5111        validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
5112        if len(var_shape) > 1 and grad_shape != indices_shape + var_shape[1:]:
5113            raise ValueError(f"For '{self.name}', the shape of updates should be [] or "
5114                             f"grad_shape = indices_shape + var_shape[1:], but got var_shape: {var_shape}, "
5115                             f"indices_shape: {indices_shape}, grad_shape: {grad_shape}.")
5116        return [1], [1], [1]
5117
5118    def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, beta2_power_dtype, lr_dtype,
5119                    beta1_dtype, beta2_dtype, epsilon_dtype, grad_dtype, indices_dtype):
5120        args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": grad_dtype}
5121        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
5122
5123        args = {"beta1_power": beta1_power_dtype, "beta2_power": beta2_power_dtype, 'lr': lr_dtype,
5124                "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype}
5125        validator.check_scalar_or_tensor_types_same(args, [mstype.float16, mstype.float32], self.name, True)
5126
5127        validator.check_tensor_dtype_valid("indices_dtype", indices_dtype, [mstype.int32], self.name)
5128        return var_dtype, m_dtype, v_dtype
5129
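# A minimal NumPy sketch of FusedSparseLazyAdam: in contrast to FusedSparseAdam above, only the
# rows named in `indices` are touched, so a row that receives no gradient keeps its old `var`,
# `m` and `v` (row 2 of the docstring example stays at 1.0). The arrays are modified in place.
# The helper name is illustrative only and not part of this module's API; `np` is the
# module-level numpy import.
def _np_fused_sparse_lazy_adam_sketch(var, m, v, beta1_power, beta2_power, lr, beta1, beta2,
                                      epsilon, grad, indices):
    merged = np.zeros_like(var)
    np.add.at(merged, indices, grad)  # merge duplicate indices by summing their gradients
    rows = np.unique(indices)
    lr_t = lr * np.sqrt(1 - beta2_power) / (1 - beta1_power)
    m[rows] = beta1 * m[rows] + (1 - beta1) * merged[rows]
    v[rows] = beta2 * v[rows] + (1 - beta2) * merged[rows] * merged[rows]
    var[rows] = var[rows] - lr_t * m[rows] / (np.sqrt(v[rows]) + epsilon)
    return var, m, v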
5130
5131class FusedSparseFtrl(PrimitiveWithInfer):
5132    """
5133    Merges the duplicate values of the gradient and then updates relevant entries according to the FTRL-proximal scheme.
5134
5135    All of the inputs except `indices` comply with the implicit type conversion rules to make the data types
5136    consistent. If they have different data types, the lower-priority data type will be converted to the
5137    relatively highest-priority data type.
5138    RuntimeError exception will be thrown when the data type conversion of Parameter is required.
5139
5140    Args:
5141        lr (float): The learning rate value, must be positive.
5142        l1 (float): l1 regularization strength, must be greater than or equal to zero.
5143        l2 (float): l2 regularization strength, must be greater than or equal to zero.
5144        lr_power (float): Learning rate power controls how the learning rate decreases during training,
5145            must be less than or equal to zero. Use fixed learning rate if `lr_power` is zero.
5146        use_locking (bool): Use locks for the updating operation if true. Default: False.
5147
5148    Inputs:
5149        - **var** (Parameter) - The variable to be updated. The data type must be float32. The shape is :math:`(N, *)`
5150          where :math:`*` means, any number of additional dimensions.
5151        - **accum** (Parameter) - The accumulation to be updated, must be same type and shape as `var`.
5152        - **linear** (Parameter) - the linear coefficient to be updated, must be same type and shape as `var`.
5153        - **grad** (Tensor) - A tensor of the same type as `var` and
5154          grad.shape[1:] = var.shape[1:] if the rank of `var` is greater than 1.
5155        - **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`.
5156          The type must be int32 and indices.shape[0] = grad.shape[0].
5157
5158    Outputs:
5159        Tuple of 3 Tensors. This operator will update the input parameters directly; the outputs are useless.
5160
5161        - **var** (Tensor) - A Tensor with shape :math:`(1, )`.
5162        - **accum** (Tensor) - A Tensor with shape :math:`(1, )`.
5163        - **linear** (Tensor) - A Tensor with shape :math:`(1, )`.
5164
5165    Raises:
5166        TypeError: If `lr`, `l1`, `l2` or `lr_power` is not a float.
5167        ValueError: If `lr` is not positive, `l1` or `l2` is negative, or `lr_power` is greater than zero.
5168        TypeError: If dtype of `var` is not float32.
5169        TypeError: If dtype of `indices` is not int32.
5170        ValueError: If shape of `accum`, `linear` or `grad` is not the same as `var`.
5171        ValueError: If shape of `indices` is not the same as the shape of the first dimension of `grad`.
5172
5173    Supported Platforms:
5174        ``Ascend`` ``CPU``
5175
5176    Examples:
5177        >>> class SparseApplyFtrlNet(nn.Cell):
5178        ...     def __init__(self):
5179        ...         super(SparseApplyFtrlNet, self).__init__()
5180        ...         self.sparse_apply_ftrl = ops.FusedSparseFtrl(lr=0.01, l1=0.0, l2=0.0, lr_power=-0.5)
5181        ...         self.var = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="var")
5182        ...         self.accum = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="accum")
5183        ...         self.linear = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="linear")
5184        ...
5185        ...     def construct(self, grad, indices):
5186        ...         out = self.sparse_apply_ftrl(self.var, self.accum, self.linear, grad, indices)
5187        ...         return out
5188        ...
5189        >>> net = SparseApplyFtrlNet()
5190        >>> grad = Tensor(np.array([[[0.1, 0.1]], [[0.1, 0.1]]]).astype(np.float32))
5191        >>> indices = Tensor(np.array([0, 1]).astype(np.int32))
5192        >>> output = net(grad, indices)
5193        >>> print(net.var.asnumpy())
5194        [[[-0.00598256 -0.00598256]]
5195         [[-0.00598256 -0.00598256]]
5196         [[ 1.          1.        ]]]
5197    """
5198    __mindspore_signature__ = (
5199        sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
5200        sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
5201        sig.make_sig('linear', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
5202        sig.make_sig('grad', dtype=sig.sig_dtype.T),
5203        sig.make_sig('indices', dtype=sig.sig_dtype.T1)
5204    )
5205
5206    @prim_attr_register
5207    def __init__(self, lr, l1, l2, lr_power, use_locking=False):
5208        """Initialize FusedSparseFtrl."""
5209        self.init_prim_io_names(inputs=['var', 'accum', 'linear', 'grad', 'indices'],
5210                                outputs=['output'])
5211        self.add_prim_attr('side_effect_mem', True)
5212
5213        validator.check_value_type("lr", lr, [float], self.name)
5214        validator.check_value_type("l1", l1, [float], self.name)
5215        validator.check_value_type("l2", l2, [float], self.name)
5216        validator.check_value_type("lr_power", lr_power, [float], self.name)
5217        self.lr = validator.check_positive_float(lr, "lr", self.name)
5218        self.l1 = validator.check_non_negative_float(l1, "l1", self.name)
5219        self.l2 = validator.check_non_negative_float(l2, "l2", self.name)
5220        self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name)
5221        self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
5222
5223    def infer_shape(self, var_shape, accum_shape, linear_shape, grad_shape, indices_shape):
5224        validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
5225        validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name)
5226        if len(var_shape) > 1:
5227            validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name)
5228        validator.check_int(len(indices_shape), 1, Rel.EQ, "indices rank", self.name)
5229        validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
5230        return [1], [1], [1]
5231
5232    def infer_dtype(self, var_dtype, accum_dtype, linear_dtype, grad_dtype, indices_dtype):
5233        args = {"var_dtype": var_dtype, "accum_dtype": accum_dtype,
5234                "linear_dtype": linear_dtype, "grad_dtype": grad_dtype}
5235        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float32], self.name)
5236        validator.check_tensor_dtype_valid("indices_dtype", indices_dtype, [mstype.int32], self.name)
5237        return var_dtype, accum_dtype, linear_dtype
5238
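# A minimal NumPy sketch of the per-row update applied by FusedSparseFtrl, written against the
# standard FTRL-proximal rule (the docstring above lists the constraints but not the formula, so
# treat the exact form as an assumption); with the docstring example's inputs it reproduces
# var[0] = var[1] = ~-0.00598256. The helper name is illustrative only and not part of this
# module's API; `np` is the module-level numpy import.
def _np_fused_sparse_ftrl_sketch(var, accum, linear, grad, indices, lr, l1, l2, lr_power):
    merged = np.zeros_like(var)
    np.add.at(merged, indices, grad)  # merge duplicate indices by summing their gradients
    for i in np.unique(indices):
        g = merged[i]
        new_accum = accum[i] + g * g
        sigma = (new_accum ** (-lr_power) - accum[i] ** (-lr_power)) / lr
        linear[i] = linear[i] + g - sigma * var[i]
        quadratic = new_accum ** (-lr_power) / lr + 2 * l2
        var[i] = np.where(np.abs(linear[i]) > l1,
                          (np.sign(linear[i]) * l1 - linear[i]) / quadratic, 0.0)
        accum[i] = new_accum
    return var, accum, linear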
5239
5240class FusedSparseProximalAdagrad(PrimitiveWithInfer):
5241    r"""
5242    Merges the duplicate values of the gradient and then updates relevant entries according to the proximal adagrad
5243    algorithm.
5244
5245    .. math::
5246        \begin{array}{ll} \\
5247            accum += grad * grad \\
5248            \text{prox_v} = var - lr * grad * \frac{1}{\sqrt{accum}} \\
5249            var = \frac{sign(\text{prox_v})}{1 + lr * l2} * \max(\left| \text{prox_v} \right| - lr * l1, 0)
5250        \end{array}
5251
5252    All of the inputs except `indices` comply with the implicit type conversion rules to make the data types
5253    consistent. If they have different data types, the lower-priority data type will be converted to the
5254    relatively highest-priority data type.
5255    RuntimeError exception will be thrown when the data type conversion of Parameter is required.
5256
5257    Args:
5258        use_locking (bool): If true, the variable and accumulation tensors will be protected from being updated.
5259            Default: False.
5260
5261    Inputs:
5262        - **var** (Parameter) - Variable tensor to be updated. The data type must be float32.
5263          The shape is :math:`(N, *)` where :math:`*` means, any number of additional dimensions.
5264        - **accum** (Parameter) - Variable tensor to be updated, has the same shape and data type as `var`.
5265        - **lr** (Tensor) - The learning rate value. The data type must be float32. The shape is :math:`(1, )`.
5266        - **l1** (Tensor) - l1 regularization strength. The data type must be float32. The shape is :math:`(1, )`.
5267        - **l2** (Tensor) - l2 regularization strength. The data type must be float32. The shape is :math:`(1, )`.
5268        - **grad** (Tensor) - A tensor of the same data type as `var` and
5269          grad.shape[1:] = var.shape[1:] if the rank of `var` is greater than 1.
5270        - **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`.
5271          The type must be int32 and indices.shape[0] = grad.shape[0].
5272
5273    Outputs:
5274        Tuple of 2 Tensors. This operator will update the input parameters directly; the outputs are useless.
5275
5276        - **var** (Tensor) - A Tensor with shape :math:`(1, )`.
5277        - **accum** (Tensor) - A Tensor with shape :math:`(1, )`.
5278
5279    Raises:
5280        TypeError: If `use_locking` is not a bool.
5281        TypeError: If dtype of `var`, `accum`, `lr`, `l1`, `l2` or `grad` is not float32.
5282        TypeError: If dtype of `indices` is not int32.
5283
5284    Supported Platforms:
5285        ``Ascend`` ``CPU``
5286
5287    Examples:
5288        >>> class Net(nn.Cell):
5289        ...     def __init__(self):
5290        ...         super(Net, self).__init__()
5291        ...         self.sparse_apply_proximal_adagrad = ops.FusedSparseProximalAdagrad()
5292        ...         self.var = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="var")
5293        ...         self.accum = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="accum")
5294        ...         self.lr = Tensor(0.01, mindspore.float32)
5295        ...         self.l1 = Tensor(0.0, mindspore.float32)
5296        ...         self.l2 = Tensor(0.0, mindspore.float32)
5297        ...     def construct(self, grad, indices):
5298        ...         out = self.sparse_apply_proximal_adagrad(self.var, self.accum, self.lr, self.l1,
5299        ...                                                  self.l2, grad, indices)
5300        ...         return out
5301        ...
5302        >>> net = Net()
5303        >>> grad = Tensor(np.array([[[0.1, 0.1]], [[0.1, 0.1]]]).astype(np.float32))
5304        >>> indices = Tensor(np.array([0, 1]).astype(np.int32))
5305        >>> output = net(grad, indices)
5306        >>> print(net.var.asnumpy())
5307        [[[0.99900496 0.99900496]]
5308         [[0.99900496 0.99900496]]
5309         [[1.         1.        ]]]
5310    """
5311    __mindspore_signature__ = (
5312        sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
5313        sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
5314        sig.make_sig('lr', dtype=sig.sig_dtype.T),
5315        sig.make_sig('l1', dtype=sig.sig_dtype.T),
5316        sig.make_sig('l2', dtype=sig.sig_dtype.T),
5317        sig.make_sig('grad', dtype=sig.sig_dtype.T),
5318        sig.make_sig('indices', dtype=sig.sig_dtype.T1)
5319    )
5320
5321    @prim_attr_register
5322    def __init__(self, use_locking=False):
5323        """Initialize FusedSparseProximalAdagrad"""
5324        self.init_prim_io_names(inputs=['var', 'accum', 'lr', 'l1', 'l2', 'grad', 'indices'],
5325                                outputs=['output'])
5326        self.add_prim_attr('side_effect_mem', True)
5327        self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
5328
5329    def infer_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape,
5330                    grad_shape, indices_shape):
5331        validator.check_int(len(indices_shape), 1, Rel.EQ, "indices rank", self.name)
5332        return [1], [1]
5333
5334    def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype,
5335                    grad_dtype, indices_dtype):
5336        args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype}
5337        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float32], self.name)
5338        validator.check_scalar_or_tensor_types_same({"lr": lr_dtype}, [mstype.float32], self.name)
5339        validator.check_scalar_or_tensor_types_same({"l1": l1_dtype}, [mstype.float32], self.name)
5340        validator.check_scalar_or_tensor_types_same({"l2": l2_dtype}, [mstype.float32], self.name)
5341        valid_dtypes = [mstype.int16, mstype.int32, mstype.int64,
5342                        mstype.uint16, mstype.uint32, mstype.uint64]
5343        validator.check_tensor_dtype_valid('indices', indices_dtype, valid_dtypes, self.name)
5344        return var_dtype, accum_dtype
5345
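# A minimal NumPy sketch of the per-row proximal adagrad update shown in the docstring above,
# applied after merging duplicate indices by summation; with the docstring example's inputs,
# rows 0 and 1 of `var` become ~0.99900496. The helper name is illustrative only and not part
# of this module's API; `np` is the module-level numpy import.
def _np_fused_sparse_proximal_adagrad_sketch(var, accum, lr, l1, l2, grad, indices):
    merged = np.zeros_like(var)
    np.add.at(merged, indices, grad)  # merge duplicate indices by summing their gradients
    for i in np.unique(indices):
        g = merged[i]
        accum[i] = accum[i] + g * g
        prox_v = var[i] - lr * g / np.sqrt(accum[i])
        var[i] = np.sign(prox_v) / (1 + lr * l2) * np.maximum(np.abs(prox_v) - lr * l1, 0)
    return var, accum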
5346
5347class KLDivLoss(PrimitiveWithInfer):
5348    r"""
5349    Computes the Kullback-Leibler divergence between the logits and the labels.
5350
5351    The KLDivLoss is computed as follows,
5352
5353    .. math::
5354        L = \{l_1,\dots,l_N\}^\top, \quad
5355        l_n = y_n \cdot (\log y_n - x_n)
5356
5357    Then,
5358
5359    .. math::
5360        \ell(x, y) = \begin{cases}
5361        L, & \text{if reduction} = \text{'none';}\\
5362        \operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\
5363        \operatorname{sum}(L),  & \text{if reduction} = \text{'sum'.}
5364        \end{cases}
5365
5366    where :math:`x` represents `logits`.
5367    :math:`y` represents `labels`.
5368    :math:`\ell(x, y)` represents `output`.
5369
5370    Args:
5371        reduction (str): Specifies the reduction to be applied to the output.
5372            Its value must be one of 'none', 'mean', 'sum'. Default: 'mean'.
5373
5374    Inputs:
5375        - **logits** (Tensor) - The input Tensor. The data type must be float32.
5376        - **labels** (Tensor) - The label Tensor which has the same shape and data type as `logits`.
5377
5378    Outputs:
5379        Tensor or Scalar, if `reduction` is 'none', then output is a tensor and has the same shape as `logits`.
5380        Otherwise it is a scalar.
5381
5382    Raises:
5383        TypeError: If `reduction` is not a str.
5384        TypeError: If `logits` or `labels` is not a Tensor.
5385        TypeError: If dtype of `logits` or `labels` is not float32.
5386
5387    Supported Platforms:
5388        ``GPU``
5389
5390    Examples:
5391        >>> class Net(nn.Cell):
5392        ...     def __init__(self):
5393        ...         super(Net, self).__init__()
5394        ...         self.kldiv_loss = ops.KLDivLoss()
5395        ...     def construct(self, logits, labels):
5396        ...         result = self.kldiv_loss(logits, labels)
5397        ...         return result
5398        ...
5399        >>> net = Net()
5400        >>> logits = Tensor(np.array([0.2, 0.7, 0.1]), mindspore.float32)
5401        >>> labels = Tensor(np.array([0., 1., 0.]), mindspore.float32)
5402        >>> output = net(logits, labels)
5403        >>> print(output)
5404        -0.23333333
5405    """
5406
5407    @prim_attr_register
5408    def __init__(self, reduction='mean'):
5409        """Initialize KLDivLoss."""
5410        self.reduction = validator.check_string(reduction, ['none', 'mean', 'sum'], 'reduction', self.name)
5411
5412    def infer_shape(self, x_shape, y_shape):
5413        validator.check('x_shape', x_shape, 'y_shape', y_shape, Rel.EQ, self.name)
5414        if self.reduction in ('mean', 'sum'):
5415            shape = []
5416        else:
5417            shape = x_shape
5418        return shape
5419
5420    def infer_dtype(self, x_type, y_type):
5421        args = {'x': x_type, 'y': y_type}
5422        valid_dtypes = (mstype.float16, mstype.float32)
5423        validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
5424        return x_type
5425
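# A minimal NumPy sketch of the KLDivLoss formula above; terms whose label is zero contribute
# nothing, which is how the docstring example arrives at -0.23333333 under reduction='mean'.
# The helper name is illustrative only and not part of this module's API; `np` is the
# module-level numpy import.
def _np_kldiv_loss_sketch(logits, labels, reduction='mean'):
    safe_log = np.log(np.where(labels > 0, labels, 1.0))  # avoid log(0); those terms are zeroed below
    loss = np.where(labels > 0, labels * (safe_log - logits), 0.0)
    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    return loss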
5426
5427class BinaryCrossEntropy(PrimitiveWithInfer):
5428    r"""
5429    Computes the binary cross entropy between the logits and the labels.
5430
5431    Sets logits as :math:`x`, labels as :math:`y`, output as :math:`\ell(x, y)`.
5432    Let,
5433
5434    .. math::
5435        L = \{l_1,\dots,l_N\}^\top, \quad
5436        l_n = - w_n \left[ y_n \cdot \log x_n + (1 - y_n) \cdot \log (1 - x_n) \right]
5437
5438    In which, :math:`L` indicates the loss over the whole batch, :math:`l_n` indicates the loss of the n-th sample,
5439    and :math:`n` ranges from 1 to N. Then,
5440
5441    .. math::
5442        \ell(x, y) = \begin{cases}
5443        L, & \text{if reduction} = \text{'none';}\\
5444        \operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\
5445        \operatorname{sum}(L),  & \text{if reduction} = \text{'sum'.}
5446        \end{cases}
5447
5448    .. warning::
5449        - The value of `logits` (:math:`x`) must range from 0 to 1.
5450        - The value of `labels` (:math:`y`) must be 0 or 1.
5451
5452    Args:
5453        reduction (str): Specifies the reduction to be applied to the output.
5454            Its value must be one of 'none', 'mean', 'sum'. Default: 'mean'.
5455
5456    Inputs:
5457        - **logits** (Tensor) - The input Tensor. The data type must be float16 or float32,
5458          The shape is :math:`(N, *)` where :math:`*` means, any number of additional dimensions.
5459        - **labels** (Tensor) - The label Tensor which has the same shape and data type as `logits`.
5460        - **weight** (Tensor, optional) - A rescaling weight applied to the loss of each batch element.
5461          It must have the same shape and data type as `logits`. Default: None.
5462
5463    Outputs:
5464        Tensor or Scalar, if `reduction` is 'none', then output is a tensor and has the same shape as `logits`.
5465        Otherwise, the output is a scalar.
5466
5467    Raises:
5468        TypeError: If dtype of `logits`, `labels` or `weight` (if given) is neither float16 nor float32.
5469        ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.
5470        ValueError: If shape of `labels` is not the same as `logits` or `weight` (if given).
5471        TypeError: If `logits`, `labels` or `weight` is not a Tensor.
5472
5473    Supported Platforms:
5474        ``Ascend`` ``GPU`` ``CPU``
5475
5476    Examples:
5477        >>> class Net(nn.Cell):
5478        ...     def __init__(self):
5479        ...         super(Net, self).__init__()
5480        ...         self.binary_cross_entropy = ops.BinaryCrossEntropy()
5481        ...     def construct(self, logits, labels, weight):
5482        ...         result = self.binary_cross_entropy(logits, labels, weight)
5483        ...         return result
5484        ...
5485        >>> net = Net()
5486        >>> logits = Tensor(np.array([0.2, 0.7, 0.1]), mindspore.float32)
5487        >>> labels = Tensor(np.array([0., 1., 0.]), mindspore.float32)
5488        >>> weight = Tensor(np.array([1, 2, 2]), mindspore.float32)
5489        >>> output = net(logits, labels, weight)
5490        >>> print(output)
5491        0.38240486
5492    """
5493
5494    @prim_attr_register
5495    def __init__(self, reduction='mean'):
5496        """Initialize BinaryCrossEntropy."""
5497        self.reduction = validator.check_string(reduction, ['none', 'mean', 'sum'], 'reduction', self.name)
5498
5499    def infer_shape(self, x_shape, y_shape, weight_shape):
5500        validator.check('x_shape', x_shape, 'y_shape', y_shape, Rel.EQ, self.name)
5501        if weight_shape:
5502            validator.check('y_shape', y_shape, 'weight_shape', weight_shape, Rel.EQ, self.name)
5503        if self.reduction in ('mean', 'sum'):
5504            shape = []
5505        else:
5506            shape = x_shape
5507        return shape
5508
5509    def infer_dtype(self, x_type, y_type, weight_type):
5510        args = {'x': x_type, 'y': y_type}
5511        valid_dtypes = (mstype.float16, mstype.float32)
5512        validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
5513        if weight_type:
5514            validator.check_tensors_dtypes_same_and_valid({'x': x_type, 'weight': weight_type}, valid_dtypes,
5515                                                          self.name)
5516        return x_type
5517
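# A minimal NumPy sketch of the BinaryCrossEntropy formula above, including the optional
# per-element rescaling weight; with the docstring example's inputs and reduction='mean' it
# yields ~0.38240486. The helper name is illustrative only and not part of this module's API;
# `np` is the module-level numpy import.
def _np_binary_cross_entropy_sketch(logits, labels, weight=None, reduction='mean'):
    loss = -(labels * np.log(logits) + (1 - labels) * np.log(1 - logits))
    if weight is not None:
        loss = weight * loss
    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    return loss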
5518
5519class ApplyAdaMax(PrimitiveWithInfer):
5520    r"""
5521    Updates relevant entries according to the adamax scheme.
5522
5523    The updating formulas are as follows,
5524
5525    .. math::
5526        \begin{array}{ll} \\
5527            m_{t+1} = \beta_1 * m_{t} + (1 - \beta_1) * g \\
5528            v_{t+1} = \max(\beta_2 * v_{t}, \left| g \right|) \\
5529            var = var - \frac{l}{1 - \beta_1^{t+1}} * \frac{m_{t+1}}{v_{t+1} + \epsilon}
5530        \end{array}
5531
5532    :math:`t` represents the updating step while :math:`m` represents the 1st moment vector, :math:`m_{t}`
5533    is the previous value of :math:`m_{t+1}`, :math:`v` represents the 2nd moment vector, :math:`v_{t}`
5534    is the previous value of :math:`v_{t+1}`, :math:`l` represents scaling factor `lr`,
5535    :math:`g` represents `grad`, :math:`\beta_1, \beta_2` represent `beta1` and `beta2`,
5536    :math:`beta_1^{t+1}` represents `beta1_power`, :math:`var` represents the variable to be updated,
5537    :math:`\epsilon` represents `epsilon`.
5538
5539    Inputs of `var`, `m`, `v` and `grad` comply with the implicit type conversion rules
5540    to make the data types consistent.
5541    If they have different data types, the lower-priority data type will be converted to the
5542    relatively highest-priority data type.
5543    RuntimeError exception will be thrown when the data type conversion of Parameter is required.
5544
5545    Inputs:
5546        - **var** (Parameter) - Variable to be updated. With float32 or float16 data type.
5547          The shape is :math:`(N, *)` where :math:`*` means, any number of additional dimensions.
5548        - **m** (Parameter) - The 1st moment vector in the updating formula, has the same shape and type as `var`.
5549          With float32 or float16 data type.
5550        - **v** (Parameter) - The 2nd moment vector in the updating formula, with the same shape and type
5551          as `var`. With float32 or float16 data type.
5552        - **beta1_power** (Union[Number, Tensor]) - :math:`beta_1^{t+1}` in the updating formula, must be scalar.
5553          With float32 or float16 data type.
5554        - **lr** (Union[Number, Tensor]) - Learning rate, :math:`l` in the updating formula, must be scalar.
5555          With float32 or float16 data type.
5556        - **beta1** (Union[Number, Tensor]) - The exponential decay rate for the 1st moment estimations,
5557          must be scalar. With float32 or float16 data type.
5558        - **beta2** (Union[Number, Tensor]) - The exponential decay rate for the 2nd moment estimations,
5559          must be scalar. With float32 or float16 data type.
5560        - **epsilon** (Union[Number, Tensor]) - A small value added for numerical stability, must be scalar.
5561          With float32 or float16 data type.
5562        - **grad** (Tensor) - A tensor for gradient, has the same shape and type as `var`.
5563          With float32 or float16 data type.
5564
5565    Outputs:
5566        Tuple of 3 Tensors, the updated parameters.
5567
5568        - **var** (Tensor) - The same shape and data type as `var`.
5569        - **m** (Tensor) - The same shape and data type as `m`.
5570        - **v** (Tensor) - The same shape and data type as `v`.
5571
5572    Raises:
5573        TypeError: If dtype of `var`, `m`, `v`, `beta1_power`, `lr`, `beta1`, `beta2`, `epsilon` or `grad` is neither
5574                   float16 nor float32.
5575        TypeError: If `beta1_power`, `lr`, `beta1`, `beta2` or `epsilon` is neither a Number nor a Tensor.
5576        TypeError: If `grad` is not a Tensor.
5577
5578    Supported Platforms:
5579        ``Ascend``
5580
5581    Examples:
5582        >>> class Net(nn.Cell):
5583        ...     def __init__(self):
5584        ...         super(Net, self).__init__()
5585        ...         self.apply_ada_max = ops.ApplyAdaMax()
5586        ...         self.var = Parameter(Tensor(np.array([[0.6, 0.4],
5587        ...                                               [0.1, 0.5]]).astype(np.float32)), name="var")
5588        ...         self.m = Parameter(Tensor(np.array([[0.6, 0.5],
5589        ...                                             [0.2, 0.6]]).astype(np.float32)), name="m")
5590        ...         self.v = Parameter(Tensor(np.array([[0.9, 0.1],
5591        ...                                             [0.7, 0.8]]).astype(np.float32)), name="v")
5592        ...     def construct(self, beta1_power, lr, beta1, beta2, epsilon, grad):
5593        ...         out = self.apply_ada_max(self.var, self.m, self.v, beta1_power, lr, beta1, beta2, epsilon, grad)
5594        ...         return out
5595        ...
5596        >>> net = Net()
5597        >>> beta1_power =Tensor(0.9, mindspore.float32)
5598        >>> lr = Tensor(0.001, mindspore.float32)
5599        >>> beta1 = Tensor(0.9, mindspore.float32)
5600        >>> beta2 = Tensor(0.99, mindspore.float32)
5601        >>> epsilon = Tensor(1e-10, mindspore.float32)
5602        >>> grad = Tensor(np.array([[0.3, 0.7], [0.1, 0.8]]).astype(np.float32))
5603        >>> output = net(beta1_power, lr, beta1, beta2, epsilon, grad)
5604        >>> print(output)
5605        (Tensor(shape=[2, 2], dtype=Float32, value=
5606        [[ 5.93602717e-01,  3.92571449e-01],
5607         [ 9.72582996e-02,  4.92249995e-01]]), Tensor(shape=[2, 2], dtype=Float32, value=
5608        [[ 5.69999993e-01,  5.19999981e-01],
5609         [ 1.89999998e-01,  6.20000005e-01]]), Tensor(shape=[2, 2], dtype=Float32, value=
5610        [[ 8.90999973e-01,  6.99999988e-01],
5611         [ 6.93000019e-01,  8.00000012e-01]]))
5612    """
5613
5614    __mindspore_signature__ = (
5615        sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
5616        sig.make_sig('m', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
5617        sig.make_sig('v', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
5618        sig.make_sig('beta1_power', dtype=sig.sig_dtype.T1),
5619        sig.make_sig('lr', dtype=sig.sig_dtype.T2),
5620        sig.make_sig('beta1', dtype=sig.sig_dtype.T3),
5621        sig.make_sig('beta2', dtype=sig.sig_dtype.T4),
5622        sig.make_sig('epsilon', dtype=sig.sig_dtype.T5),
5623        sig.make_sig('grad', dtype=sig.sig_dtype.T)
5624    )
5625
5626    @prim_attr_register
5627    def __init__(self):
5628        """Initialize ApplyAdaMax"""
5629        self.add_prim_attr('side_effect_mem', True)
5630
5631    def infer_shape(self, var_shape, m_shape, v_shape, beta1_power_shape, lr_shape,
5632                    beta1_shape, beta2_shape, epsilon_shape, grad_shape):
5633        validator.check("m_shape", m_shape, "var_shape", var_shape, Rel.EQ, self.name)
5634        validator.check("v_shape", v_shape, "var_shape", var_shape, Rel.EQ, self.name)
5635        validator.check("grad_shape", grad_shape, "var_shape", var_shape, Rel.EQ, self.name)
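        # The scalar inputs (beta1_power, lr, beta1, beta2, epsilon) may arrive as rank-0 tensors
        # or as rank-1 tensors with a single element, so each check below accepts rank <= 1 and,
        # for rank 1, requires exactly one element.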
5636        beta1_power_shp_len = len(beta1_power_shape)
5637        validator.check_int(beta1_power_shp_len, 1, Rel.LE, "beta1 power's rank", self.name)
5638        if beta1_power_shp_len == 1:
5639            validator.check_int(beta1_power_shape[0], 1, Rel.EQ, "beta1_power_shape[0]", self.name)
5640        lr_shp_len = len(lr_shape)
5641        validator.check_int(lr_shp_len, 1, Rel.LE, "lr's rank", self.name)
5642        if lr_shp_len == 1:
5643            validator.check_int(lr_shape[0], 1, Rel.EQ, "lr_shape[0]", self.name)
5644        beta1_shp_len = len(beta1_shape)
5645        validator.check_int(beta1_shp_len, 1, Rel.LE, "beta1's rank", self.name)
5646        if beta1_shp_len == 1:
5647            validator.check_int(beta1_shape[0], 1, Rel.EQ, "beta1_shape[0]", self.name)
5648        beta2_shp_len = len(beta2_shape)
5649        validator.check_int(beta2_shp_len, 1, Rel.LE, "beta2's rank", self.name)
5650        if beta2_shp_len == 1:
5651            validator.check_int(beta2_shape[0], 1, Rel.EQ, "beta2_shape[0]", self.name)
5652        epsilon_shp_len = len(epsilon_shape)
5653        validator.check_int(epsilon_shp_len, 1, Rel.LE, "epsilon's rank", self.name)
5654        if epsilon_shp_len == 1:
5655            validator.check_int(epsilon_shape[0], 1, Rel.EQ, "epsilon_shape[0]", self.name)
5656        return var_shape, m_shape, v_shape
5657
5658    def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, lr_dtype,
5659                    beta1_dtype, beta2_dtype, epsilon_dtype, grad_dtype):
5660        valid_dtypes = [mstype.float16, mstype.float32]
5661        args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": grad_dtype}
5662        validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
5663        validator.check_scalar_or_tensor_types_same({"beta1_power": beta1_power_dtype}, valid_dtypes, self.name)
5664        validator.check_scalar_or_tensor_types_same({"lr": lr_dtype}, valid_dtypes, self.name)
5665        validator.check_scalar_or_tensor_types_same({"beta1": beta1_dtype}, valid_dtypes, self.name)
5666        validator.check_scalar_or_tensor_types_same({"beta2": beta2_dtype}, valid_dtypes, self.name)
5667        validator.check_scalar_or_tensor_types_same({"epsilon": epsilon_dtype}, valid_dtypes, self.name)
5668        return var_dtype, m_dtype, v_dtype
5669
5670
5671class ApplyAdadelta(PrimitiveWithInfer):
5672    r"""
5673    Updates relevant entries according to the adadelta scheme.
5674
5675    .. math::
5676        \begin{array}{ll} \\
5677            accum = \rho * accum + (1 - \rho) * grad^2 \\
5678            \text{update} = \sqrt{\text{accum_update} + \epsilon} * \frac{grad}{\sqrt{accum + \epsilon}} \\
5679            \text{accum_update} = \rho * \text{accum_update} + (1 - \rho) * update^2 \\
5680            var -= lr * update
5681        \end{array}
5682
5683    where :math:`\rho` represents `rho`, :math:`\epsilon` represents `epsilon`.
5684
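    A minimal NumPy sketch of this update rule (an illustration of the formula above, not the
    operator's on-device implementation; it reproduces the example below up to float32 rounding):

    .. code-block:: python

        import numpy as np

        def adadelta_step(var, accum, accum_update, lr, rho, epsilon, grad):
            accum = rho * accum + (1 - rho) * grad ** 2                  # decayed mean of squared gradients
            update = np.sqrt(accum_update + epsilon) * grad / np.sqrt(accum + epsilon)
            accum_update = rho * accum_update + (1 - rho) * update ** 2  # decayed mean of squared updates
            var = var - lr * update
            return var, accum, accum_update
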
5685    Inputs of `var`, `accum`, `accum_update` and `grad` comply with the implicit type conversion rules
5686    to make the data types consistent.
5687    If they have different data types, lower priority data type will be converted to
5688    relatively highest priority data type.
5689    RuntimeError exception will be thrown when the data type conversion of Parameter is required.
5690
5691    Inputs:
5692        - **var** (Parameter) - Weights to be updated. With float32 or float16 data type.
5693          The shape is :math:`(N, *)` where :math:`*` means any number of additional dimensions.
5694        - **accum** (Parameter) - Accumulation to be updated, has the same shape and data type as `var`.
5695        - **accum_update** (Parameter) - Accum_update to be updated, has the same shape and data type as `var`.
5696        - **lr** (Union[Number, Tensor]) - Learning rate, must be scalar. With float32 or float16 data type.
5697        - **rho** (Union[Number, Tensor]) - Decay rate, must be scalar. With float32 or float16 data type.
5698        - **epsilon** (Union[Number, Tensor]) - A small value added for numerical stability, must be scalar.
5699          With float32 or float16 data type.
5700        - **grad** (Tensor) - Gradients, has the same shape and data type as `var`.
5701
5702    Outputs:
5703        Tuple of 3 Tensors, the updated parameters.
5704
5705        - **var** (Tensor) - The same shape and data type as `var`.
5706        - **accum** (Tensor) - The same shape and data type as `accum`.
5707        - **accum_update** (Tensor) - The same shape and data type as `accum_update`.
5708
5709    Raises:
5710        TypeError: If dtype of `var`, `accum`, `accum_update`, `lr`, `rho`, `epsilon` or `grad` is neither float16 nor
5711                   float32.
5712        TypeError: If `accum_update`, `lr`, `rho` or `epsilon` is neither a Number nor a Tensor.
5713
5714    Supported Platforms:
5715        ``Ascend``
5716
5717    Examples:
5718        >>> class Net(nn.Cell):
5719        ...     def __init__(self):
5720        ...         super(Net, self).__init__()
5721        ...         self.apply_adadelta = ops.ApplyAdadelta()
5722        ...         self.var = Parameter(Tensor(np.array([[0.6, 0.4],
5723        ...                                               [0.1, 0.5]]).astype(np.float32)), name="var")
5724        ...         self.accum = Parameter(Tensor(np.array([[0.6, 0.5],
5725        ...                                                 [0.2, 0.6]]).astype(np.float32)), name="accum")
5726        ...         self.accum_update = Parameter(Tensor(np.array([[0.9, 0.1],
5727        ...                                                        [0.7, 0.8]]).astype(np.float32)),
5728        ...                                                             name="accum_update")
5729        ...     def construct(self, lr, rho, epsilon, grad):
5730        ...         out = self.apply_adadelta(self.var, self.accum, self.accum_update, lr, rho, epsilon, grad)
5731        ...         return out
5732        ...
5733        >>> net = Net()
5734        >>> lr = Tensor(0.001, mindspore.float32)
5735        >>> rho = Tensor(0.0, mindspore.float32)
5736        >>> epsilon = Tensor(1e-6, mindspore.float32)
5737        >>> grad = Tensor(np.array([[0.3, 0.7], [0.1, 0.8]]).astype(np.float32))
5738        >>> output = net(lr, rho, epsilon, grad)
5739        >>> print(output)
5740        (Tensor(shape=[2, 2], dtype=Float32, value=
5741        [[ 5.99051356e-01,  3.99683774e-01],
5742         [ 9.91633832e-02,  4.99105573e-01]]), Tensor(shape=[2, 2], dtype=Float32, value=
5743        [[ 9.00000036e-02,  4.89999980e-01],
5744         [ 1.00000007e-02,  6.40000045e-01]]), Tensor(shape=[2, 2], dtype=Float32, value=
5745        [[ 8.99990976e-01,  1.00000791e-01],
5746         [ 6.99930906e-01,  7.99999654e-01]]))
5747    """
5748
5749    __mindspore_signature__ = (
5750        sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
5751        sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
5752        sig.make_sig('accum_update', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
5753        sig.make_sig('lr', dtype=sig.sig_dtype.T1),
5754        sig.make_sig('rho', dtype=sig.sig_dtype.T2),
5755        sig.make_sig('epsilon', dtype=sig.sig_dtype.T3),
5756        sig.make_sig('grad', dtype=sig.sig_dtype.T)
5757    )
5758
5759    @prim_attr_register
5760    def __init__(self):
5761        """Initialize ApplyAdadelta"""
5762        self.add_prim_attr('side_effect_mem', True)
5763
5764    def infer_shape(self, var_shape, accum_shape, accum_update_shape, lr_shape, rho_shape,
5765                    epsilon_shape, grad_shape):
5766        validator.check("accum_shape", accum_shape, "var_shape", var_shape, Rel.EQ, self.name)
5767        validator.check("accum_update_shape", accum_update_shape, "var_shape", var_shape, Rel.EQ, self.name)
5768        validator.check("grad_shape", grad_shape, "var_shape", var_shape, Rel.EQ, self.name)
5769        lr_shp_len = len(lr_shape)
5770        validator.check_int(lr_shp_len, 1, Rel.LE, "lr's rank", self.name)
5771        if lr_shp_len == 1:
5772            validator.check_int(lr_shape[0], 1, Rel.EQ, "lr_shape[0]", self.name)
5773        rho_shp_len = len(rho_shape)
5774        validator.check_int(rho_shp_len, 1, Rel.LE, "rho's rank", self.name)
5775        if rho_shp_len == 1:
5776            validator.check_int(rho_shape[0], 1, Rel.EQ, "rho_shape[0]", self.name)
5777        epsilon_shp_len = len(epsilon_shape)
5778        validator.check_int(epsilon_shp_len, 1, Rel.LE, "epsilon's rank", self.name)
5779        if epsilon_shp_len == 1:
5780            validator.check_int(epsilon_shape[0], 1, Rel.EQ, "epsilon_shape[0]", self.name)
5781        return var_shape, accum_shape, accum_update_shape
5782
5783    def infer_dtype(self, var_dtype, accum_dtype, accum_update_dtype, lr_dtype, rho_dtype,
5784                    epsilon_dtype, grad_dtype):
5785        valid_dtypes = [mstype.float16, mstype.float32]
5786        args = {"var": var_dtype, "accum": accum_dtype, "accum_update": accum_update_dtype, "grad": grad_dtype}
5787        validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
5788        validator.check_scalar_or_tensor_types_same({"lr": lr_dtype}, valid_dtypes, self.name)
5789        validator.check_scalar_or_tensor_types_same({"rho": rho_dtype}, valid_dtypes, self.name)
5790        validator.check_scalar_or_tensor_types_same({"epsilon": epsilon_dtype}, valid_dtypes, self.name)
5791        return var_dtype, accum_dtype, accum_update_dtype
5792
5793
5794class ApplyAdagrad(PrimitiveWithInfer):
5795    r"""
5796    Updates relevant entries according to the adagrad scheme.
5797
5798    .. math::
5799        \begin{array}{ll} \\
5800            accum += grad * grad \\
5801            var -= lr * grad * \frac{1}{\sqrt{accum}}
5802        \end{array}
5803
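    A minimal NumPy sketch of this update rule (an illustration of the formula above, not the
    operator's on-device implementation; it reproduces the example below up to float32 rounding):

    .. code-block:: python

        import numpy as np

        def adagrad_step(var, accum, lr, grad):
            accum = accum + grad ** 2              # accumulate squared gradients
            var = var - lr * grad / np.sqrt(accum)
            return var, accum
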
5804    Inputs of `var`, `accum` and `grad`  comply with the implicit type conversion rules
5805    to make the data types consistent.
5806    If they have different data types, lower priority data type will be converted to
5807    relatively highest priority data type.
5808    RuntimeError exception will be thrown when the data type conversion of Parameter is required.
5809
5810    Args:
5811        update_slots (bool): If `True`, `accum` will be updated. Default: True.
5812
5813    Inputs:
5814        - **var** (Parameter) - Variable to be updated. With float32 or float16 data type.
5815          The shape is :math:`(N, *)` where :math:`*` means any number of additional dimensions.
5816        - **accum** (Parameter) - Accumulation to be updated. The shape and data type must be the same as `var`.
5817        - **lr** (Union[Number, Tensor]) - The learning rate value, must be scalar. With float32 or float16 data type.
5818        - **grad** (Tensor) - A tensor for gradient. The shape and data type must be the same as `var`.
5819
5820    Outputs:
5821        Tuple of 2 Tensors, the updated parameters.
5822
5823        - **var** (Tensor) - The same shape and data type as `var`.
5824        - **accum** (Tensor) - The same shape and data type as `accum`.
5825
5826    Raises:
5827        TypeError: If dtype of `var`, `accum`, `lr` or `grad` is neither float16 nor float32.
5828        TypeError: If `lr` is neither a Number nor a Tensor.
5829
5830    Supported Platforms:
5831        ``Ascend`` ``CPU`` ``GPU``
5832
5833    Examples:
5834        >>> class Net(nn.Cell):
5835        ...     def __init__(self):
5836        ...         super(Net, self).__init__()
5837        ...         self.apply_adagrad = ops.ApplyAdagrad()
5838        ...         self.var = Parameter(Tensor(np.array([[0.6, 0.4],
5839        ...                                               [0.1, 0.5]]).astype(np.float32)), name="var")
5840        ...         self.accum = Parameter(Tensor(np.array([[0.6, 0.5],
5841        ...                                                 [0.2, 0.6]]).astype(np.float32)), name="accum")
5842        ...     def construct(self, lr, grad):
5843        ...         out = self.apply_adagrad(self.var, self.accum, lr, grad)
5844        ...         return out
5845        ...
5846        >>> net = Net()
5847        >>> lr = Tensor(0.001, mindspore.float32)
5848        >>> grad = Tensor(np.array([[0.3, 0.7], [0.1, 0.8]]).astype(np.float32))
5849        >>> output = net(lr, grad)
5850        >>> print(output)
5851        (Tensor(shape=[2, 2], dtype=Float32, value=
5852        [[ 5.99638879e-01,  3.99296492e-01],
5853         [ 9.97817814e-02,  4.99281585e-01]]), Tensor(shape=[2, 2], dtype=Float32, value=
5854        [[ 6.90000057e-01,  9.90000010e-01],
5855         [ 2.10000008e-01,  1.24000001e+00]]))
5856    """
5857
5858    __mindspore_signature__ = (
5859        sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
5860        sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
5861        sig.make_sig('lr', dtype=sig.sig_dtype.T1),
5862        sig.make_sig('grad', dtype=sig.sig_dtype.T)
5863    )
5864
5865    @prim_attr_register
5866    def __init__(self, update_slots=True):
5867        """Initialize ApplyAdagrad."""
5868        validator.check_value_type("update_slots", update_slots, [bool], self.name)
5869        self.add_prim_attr('side_effect_mem', True)
5870
5871    def infer_shape(self, var_shape, accum_shape, lr_shape, grad_shape):
5872        validator.check('accum shape', accum_shape, 'var shape', var_shape, Rel.EQ, self.name)
5873        validator.check('grad shape', grad_shape, 'var shape', var_shape, Rel.EQ, self.name)
5874        lr_shp_len = len(lr_shape)
5875        validator.check_int(lr_shp_len, 1, Rel.LE, "lr's rank", self.name)
5876        if lr_shp_len == 1:
5877            validator.check_int(lr_shape[0], 1, Rel.EQ, "lr_shape[0]", self.name)
5878        return var_shape, accum_shape
5879
5880    def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, grad_dtype):
5881        args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype}
5882        valid_dtypes = [mstype.float16, mstype.float32]
5883        validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
5884        validator.check_scalar_or_tensor_types_same({'lr': lr_dtype}, valid_dtypes, self.name)
5885        return var_dtype, accum_dtype
5886
5887
5888class ApplyAdagradV2(PrimitiveWithInfer):
5889    r"""
5890    Updates relevant entries according to the adagradv2 scheme.
5891
5892    .. math::
5893        \begin{array}{ll} \\
5894            accum += grad * grad \\
5895            var -= lr * grad * \frac{1}{\sqrt{accum} + \epsilon}
5896        \end{array}
5897
5898    where :math:`\epsilon` represents `epsilon`.
5899
5900    Inputs of `var`, `accum` and `grad` comply with the implicit type conversion rules
5901    to make the data types consistent.
5902    If they have different data types, lower priority data type will be converted to
5903    relatively highest priority data type.
5904    RuntimeError exception will be thrown when the data type conversion of Parameter is required.
5905
5906    Note:
5907        The difference from `ApplyAdagrad` is that `ApplyAdagradV2` adds the small constant `epsilon` to the denominator.
5908
5909    Args:
5910        epsilon (float): A small value added for numerical stability.
5911        update_slots (bool): If `True`, `accum` will be updated. Default: True.
5912
5913    Inputs:
5914        - **var** (Parameter) - Variable to be updated. With float16 or float32 data type.
5915          The shape is :math:`(N, *)` where :math:`*` means any number of additional dimensions.
5916        - **accum** (Parameter) - Accumulation to be updated. The shape and data type must be the same as `var`.
5917        - **lr** (Union[Number, Tensor]) - The learning rate value, must be a float number or
5918          a scalar tensor with float16 or float32 data type.
5919        - **grad** (Tensor) - A tensor for gradient. The shape and data type must be the same as `var`.
5920
5921    Outputs:
5922        Tuple of 2 Tensors, the updated parameters.
5923
5924        - **var** (Tensor) - The same shape and data type as `var`.
5925        - **accum** (Tensor) - The same shape and data type as `accum`.
5926
5927    Raises:
5928        TypeError: If dtype of `var`, `accum`, `lr` or `grad` is neither float16 nor float32.
5929        TypeError: If `lr` is neither a Number nor a Tensor.
5930
5931    Supported Platforms:
5932        ``Ascend``
5933
5934    Examples:
5935        >>> class Net(nn.Cell):
5936        ...     def __init__(self):
5937        ...         super(Net, self).__init__()
5938        ...         self.apply_adagrad_v2 = ops.ApplyAdagradV2(epsilon=1e-6)
5939        ...         self.var = Parameter(Tensor(np.array([[0.6, 0.4],
5940        ...                                               [0.1, 0.5]]).astype(np.float32)), name="var")
5941        ...         self.accum = Parameter(Tensor(np.array([[0.6, 0.5],
5942        ...                                                 [0.2, 0.6]]).astype(np.float32)), name="accum")
5943        ...     def construct(self, lr, grad):
5944        ...         out = self.apply_adagrad_v2(self.var, self.accum, lr, grad)
5945        ...         return out
5946        ...
5947        >>> net = Net()
5948        >>> lr = Tensor(0.001, mindspore.float32)
5949        >>> grad = Tensor(np.array([[0.3, 0.7], [0.1, 0.8]]).astype(np.float32))
5950        >>> output = net(lr, grad)
5951        >>> print(output)
5952        (Tensor(shape=[2, 2], dtype=Float32, value=
5953        [[ 5.99638879e-01,  3.99296492e-01],
5954         [ 9.97817814e-02,  4.99281585e-01]]), Tensor(shape=[2, 2], dtype=Float32, value=
5955        [[ 6.90000057e-01,  9.90000010e-01],
5956         [ 2.10000008e-01,  1.24000001e+00]]))
5957    """
5958
5959    __mindspore_signature__ = (
5960        sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
5961        sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
5962        sig.make_sig('lr', dtype=sig.sig_dtype.T1),
5963        sig.make_sig('grad', dtype=sig.sig_dtype.T)
5964    )
5965
5966    @prim_attr_register
5967    def __init__(self, epsilon, update_slots=True):
5968        """Initialize ApplyAdagradV2."""
5969        validator.check_value_type("epsilon", epsilon, [float], self.name)
5970        validator.check_value_type("update_slots", update_slots, [bool], self.name)
5971        self.add_prim_attr('side_effect_mem', True)
5972
5973    def infer_shape(self, var_shape, accum_shape, lr_shape, grad_shape):
5974        validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
5975        validator.check('var shape', var_shape, 'grad shape', grad_shape, Rel.EQ, self.name)
5976        lr_shp_len = len(lr_shape)
5977        validator.check_int(lr_shp_len, 1, Rel.LE, "lr's rank", self.name)
5978        if lr_shp_len == 1:
5979            validator.check_int(lr_shape[0], 1, Rel.EQ, "lr_shape[0]", self.name)
5980        return var_shape, accum_shape
5981
5982    def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, grad_dtype):
5983        args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype}
5984        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
5985        validator.check_scalar_or_tensor_types_same({'lr': lr_dtype}, [mstype.float16, mstype.float32], self.name)
5986        return var_dtype, accum_dtype
5987
5988
5989class SparseApplyAdagrad(PrimitiveWithInfer):
5990    r"""
5991    Updates relevant entries according to the adagrad scheme.
5992
5993    .. math::
5994        \begin{array}{ll} \\
5995            accum += grad * grad \\
5996            var -= lr * grad * \frac{1}{\sqrt{accum}}
5997        \end{array}
5998
5999    Inputs of `var`, `accum` and `grad` comply with the implicit type conversion rules
6000    to make the data types consistent.
6001    If they have different data types, lower priority data type will be converted to
6002    relatively highest priority data type.
6003    RuntimeError exception will be thrown when the data type conversion of Parameter is required.
6004
6005    Args:
6006        lr (float): Learning rate.
6007        update_slots (bool): If `True`, `accum` will be updated. Default: True.
6008        use_locking (bool): If true, the `var` and `accum` tensors will be protected from being updated.
6009            Default: False.
6010
6011    Inputs:
6012        - **var** (Parameter) - Variable to be updated. The data type must be float16 or float32.
6013          The shape is :math:`(N, *)` where :math:`*` means any number of additional dimensions.
6014        - **accum** (Parameter) - Accumulation to be updated. The shape and data type must be the same as `var`.
6015        - **grad** (Tensor) - Gradient with the same data type as `var`, and
6016          grad.shape[1:] = var.shape[1:] if len(var.shape) > 1.
6017        - **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`.
6018          The type must be int32 and indices.shape[0] = grad.shape[0].
6019
6020    Outputs:
6021        Tuple of 2 tensors, the updated parameters.
6022
6023        - **var** (Tensor) - The same shape and data type as `var`.
6024        - **accum** (Tensor) - The same shape and data type as `accum`.
6025
6026    Raises:
6027        TypeError: If `lr` is not a float.
6028        TypeError: If `update_slots` or `use_locking` is not a bool.
6029        TypeError: If dtype of `var`, `accum` or `grad` is neither float16 nor float32.
6030        TypeError: If dtype of `indices` is not int32.
6031
6033    Supported Platforms:
6034        ``Ascend``
6035
6036    Examples:
6037        >>> class Net(nn.Cell):
6038        ...     def __init__(self):
6039        ...         super(Net, self).__init__()
6040        ...         self.sparse_apply_adagrad = ops.SparseApplyAdagrad(lr=1e-8)
6041        ...         self.var = Parameter(Tensor(np.array([[[0.2]]]).astype(np.float32)), name="var")
6042        ...         self.accum = Parameter(Tensor(np.array([[[0.1]]]).astype(np.float32)), name="accum")
6043        ...     def construct(self, grad, indices):
6044        ...         out = self.sparse_apply_adagrad(self.var, self.accum, grad, indices)
6045        ...         return out
6046        ...
6047        >>> net = Net()
6048        >>> grad = Tensor(np.array([[[0.7]]]).astype(np.float32))
6049        >>> indices = Tensor([0], mindspore.int32)
6050        >>> output = net(grad, indices)
6051        >>> print(output)
6052        (Tensor(shape=[1, 1, 1], dtype=Float32, value=
6053        [[[1.99999988e-01]]]), Tensor(shape=[1, 1, 1], dtype=Float32, value=
6054        [[[1.00000001e-01]]]))
6055    """
6056
6057    __mindspore_signature__ = (
6058        sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
6059        sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
6060        sig.make_sig('grad', dtype=sig.sig_dtype.T),
6061        sig.make_sig('indices', dtype=sig.sig_dtype.T1)
6062    )
6063
6064    @prim_attr_register
6065    def __init__(self, lr, update_slots=True, use_locking=False):
6066        """Initialize SparseApplyAdagrad."""
6067        validator.check_is_float(lr, "lr", self.name)
6068        validator.check_value_type("update_slots", update_slots, [bool], self.name)
6069        validator.check_value_type("use_locking", use_locking, [bool], self.name)
6070        self.add_prim_attr('side_effect_mem', True)
6071
6072    def infer_shape(self, var_shape, accum_shape, grad_shape, indices_shape):
6073        validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
6074        validator.check('len of var shape', len(var_shape), 'len of grad shape', len(grad_shape), Rel.EQ, self.name)
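        # Rows of grad are applied to the rows of var/accum selected by indices, so the trailing
        # dimensions must match and indices must be a vector with one entry per grad row.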
6075        if len(var_shape) > 1:
6076            validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name)
6077        validator.check_int(len(indices_shape), 1, Rel.EQ, "indices rank", self.name)
6078        validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
6079        return var_shape, accum_shape
6080
6081    def infer_dtype(self, var_type, accum_type, grad_type, indices_type):
6082        args = {'var': var_type, 'accum': accum_type, 'grad': grad_type}
6083        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
6084        validator.check_tensor_dtype_valid('indices', indices_type, [mstype.int32], self.name)
6085        return var_type, accum_type
6086
6087
6088class SparseApplyAdagradV2(PrimitiveWithInfer):
6089    r"""
6090    Updates relevant entries according to the adagrad scheme. Compared with SparseApplyAdagrad, it has one more
6091    `epsilon` attribute.
6091
6092    .. math::
6093        \begin{array}{ll} \\
6094            accum += grad * grad \\
6095            var -= lr * grad * \frac{1}{\sqrt{accum} + \epsilon}
6096        \end{array}
6097
6098    where :math:`\epsilon` represents `epsilon`.
6099
6100    Inputs of `var`, `accum` and `grad` comply with the implicit type conversion rules
6101    to make the data types consistent.
6102    If they have different data types, lower priority data type will be converted to
6103    relatively highest priority data type.
6104    RuntimeError exception will be thrown when the data type conversion of Parameter is required.
6105
6106    Args:
6107        lr (float): Learning rate.
6108        epsilon (float): A small value added for numerical stability.
6109        use_locking (bool): If `True`, the `var` and `accum` tensors will be protected from being updated.
6110            Default: False.
6111        update_slots (bool): If `True`, `accum` will be updated. Default: True.
6112
6113    Inputs:
6114        - **var** (Parameter) - Variable to be updated. The data type must be float16 or float32.
6115          The shape is :math:`(N, *)` where :math:`*` means any number of additional dimensions.
6116        - **accum** (Parameter) - Accumulation to be updated. The shape and data type must be the same as `var`.
6117        - **grad** (Tensor) - Gradient with the same data type as `var`, and
6118          grad.shape[1:] = var.shape[1:] if len(var.shape) > 1.
6119        - **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`.
6120          The type must be int32 and indices.shape[0] = grad.shape[0].
6121
6122    Outputs:
6123        Tuple of 2 tensors, the updated parameters.
6124
6125        - **var** (Tensor) - The same shape and data type as `var`.
6126        - **accum** (Tensor) - The same shape and data type as `accum`.
6127
6128    Raises:
6129        TypeError: If `lr` or `epsilon` is not a float.
6130        TypeError: If `update_slots` or `use_locking` is not a bool.
6131        TypeError: If dtype of `var`, `accum` or `grad` is neither float16 nor float32.
6132        TypeError: If dtype of `indices` is not int32.
6133
6134    Supported Platforms:
6135        ``Ascend``
6136
6137    Examples:
6138        >>> class Net(nn.Cell):
6139        ...     def __init__(self):
6140        ...         super(Net, self).__init__()
6141        ...         self.sparse_apply_adagrad_v2 = ops.SparseApplyAdagradV2(lr=1e-8, epsilon=1e-6)
6142        ...         self.var = Parameter(Tensor(np.array([[0.2]]).astype(np.float32)), name="var")
6143        ...         self.accum = Parameter(Tensor(np.array([[0.1]]).astype(np.float32)), name="accum")
6144        ...
6145        ...     def construct(self, grad, indices):
6146        ...         out = self.sparse_apply_adagrad_v2(self.var, self.accum, grad, indices)
6147        ...         return out
6148        ...
6149        >>> net = Net()
6150        >>> grad = Tensor(np.array([[0.7]]).astype(np.float32))
6151        >>> indices = Tensor(np.ones([1]), mindspore.int32)
6152        >>> output = net(grad, indices)
6153        >>> print(output)
6154        (Tensor(shape=[1, 1], dtype=Float32, value=
6155        [[ 2.00000003e-01]]), Tensor(shape=[1, 1], dtype=Float32, value=
6156        [[ 1.00000001e-01]]))
6157    """
6158
6159    __mindspore_signature__ = (
6160        sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
6161        sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
6162        sig.make_sig('grad', dtype=sig.sig_dtype.T),
6163        sig.make_sig('indices', dtype=sig.sig_dtype.T1)
6164    )
6165
6166    @prim_attr_register
6167    def __init__(self, lr, epsilon, use_locking=False, update_slots=True):
6168        """Initialize SparseApplyAdagradV2."""
6169        self.lr = validator.check_value_type("lr", lr, [float], self.name)
6170        self.epsilon = validator.check_value_type("epsilon", epsilon, [float], self.name)
6171        self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
6172        self.update_slots = validator.check_value_type("update_slots", update_slots, [bool], self.name)
6173        self.add_prim_attr('side_effect_mem', True)
6174
6175    def infer_shape(self, var_shape, accum_shape, grad_shape, indices_shape):
6176        validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
6177        validator.check('len of var shape', len(var_shape), 'len of grad shape', len(grad_shape), Rel.EQ, self.name)
6178        if len(var_shape) > 1:
6179            validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name)
6180        validator.check_int(len(indices_shape), 1, Rel.EQ, "indices rank", self.name)
6181        validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
6182        return var_shape, accum_shape
6183
6184    def infer_dtype(self, var_type, accum_type, grad_type, indices_type):
6185        args = {'var': var_type, 'accum': accum_type, 'grad': grad_type}
6186        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
6187        validator.check_tensor_dtype_valid('indices', indices_type, [mstype.int32], self.name)
6188        return var_type, accum_type
6189
6190
6191class ApplyProximalAdagrad(PrimitiveWithInfer):
6192    r"""
6193    Updates relevant entries according to the proximal adagrad algorithm.
6194
6195    .. math::
6196        \begin{array}{ll} \\
6197            accum += grad * grad \\
6198            \text{prox_v} = var - lr * grad * \frac{1}{\sqrt{accum}} \\
6199            var = \frac{sign(\text{prox_v})}{1 + lr * l2} * \max(\left| \text{prox_v} \right| - lr * l1, 0)
6200        \end{array}
6201
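    A minimal NumPy sketch of this update rule (an illustration of the formula above, not the
    operator's on-device implementation; it reproduces the example below up to float32 rounding):

    .. code-block:: python

        import numpy as np

        def proximal_adagrad_step(var, accum, lr, l1, l2, grad):
            accum = accum + grad ** 2
            prox_v = var - lr * grad / np.sqrt(accum)
            # Soft-threshold prox_v by lr * l1, then shrink by the l2 term.
            var = np.sign(prox_v) / (1 + lr * l2) * np.maximum(np.abs(prox_v) - lr * l1, 0)
            return var, accum
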
6202    Inputs of `var`, `accum` and `grad` comply with the implicit type conversion rules
6203    to make the data types consistent.
6204    If they have different data types, lower priority data type will be converted to
6205    relatively highest priority data type.
6206    RuntimeError exception will be thrown when the data type conversion of Parameter is required.
6207
6208    Args:
6209        use_locking (bool): If true, the var and accumulation tensors will be protected from being updated.
6210            Default: False.
6211
6212    Inputs:
6213        - **var** (Parameter) - Variable to be updated. The data type must be float16 or float32.
6214          The shape is :math:`(N, *)` where :math:`*` means any number of additional dimensions.
6215        - **accum** (Parameter) - Accumulation to be updated. Must have the same shape and dtype as `var`.
6216        - **lr** (Union[Number, Tensor]) - The learning rate value, must be scalar. The data type must be
6217          float16 or float32.
6218        - **l1** (Union[Number, Tensor]) - l1 regularization strength, must be scalar. The data type must be
6219          float16 or float32.
6220        - **l2** (Union[Number, Tensor]) - l2 regularization strength, must be scalar. The data type must be
6221          float16 or float32.
6222        - **grad** (Tensor) - Gradient with the same shape and dtype as `var`.
6223
6224    Outputs:
6225        Tuple of 2 Tensors, the updated parameters.
6226
6227        - **var** (Tensor) - The same shape and data type as `var`.
6228        - **accum** (Tensor) - The same shape and data type as `accum`.
6229
6230    Raises:
6231        TypeError: If `use_locking` is not a bool.
6232        TypeError: If dtype of `var`, `lr`, `l1` or `l2` is neither float16 nor float32.
6233        TypeError: If `lr`, `l1` or `l2` is neither a Number nor a Tensor.
6234        TypeError: If `grad` is not a Tensor.
6235
6236    Supported Platforms:
6237        ``Ascend``
6238
6239    Examples:
6240        >>> class Net(nn.Cell):
6241        ...     def __init__(self):
6242        ...         super(Net, self).__init__()
6243        ...         self.apply_proximal_adagrad = ops.ApplyProximalAdagrad()
6244        ...         self.var = Parameter(Tensor(np.array([[0.6, 0.4],
6245        ...                                               [0.1, 0.5]]).astype(np.float32)), name="var")
6246        ...         self.accum = Parameter(Tensor(np.array([[0.6, 0.5],
6247        ...                                                 [0.2, 0.6]]).astype(np.float32)), name="accum")
6248        ...         self.lr = 0.01
6249        ...         self.l1 = 0.0
6250        ...         self.l2 = 0.0
6251        ...     def construct(self, grad):
6252        ...         out = self.apply_proximal_adagrad(self.var, self.accum, self.lr, self.l1, self.l2, grad)
6253        ...         return out
6254        ...
6255        >>> net = Net()
6256        >>> grad = Tensor(np.array([[0.3, 0.7], [0.1, 0.8]]).astype(np.float32))
6257        >>> output = net(grad)
6258        >>> print(output)
6259        (Tensor(shape=[2, 2], dtype=Float32, value=
6260        [[ 5.96388459e-01,  3.92964751e-01],
6261         [ 9.78178233e-02,  4.92815793e-01]]), Tensor(shape=[2, 2], dtype=Float32, value=
6262        [[ 6.90000057e-01,  9.90000010e-01],
6263         [ 2.10000008e-01,  1.24000001e+00]]))
6264    """
6265
6266    __mindspore_signature__ = (
6267        sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
6268        sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
6269        sig.make_sig('lr', dtype=sig.sig_dtype.T1),
6270        sig.make_sig('l1', dtype=sig.sig_dtype.T2),
6271        sig.make_sig('l2', dtype=sig.sig_dtype.T3),
6272        sig.make_sig('grad', dtype=sig.sig_dtype.T)
6273    )
6274
6275    @prim_attr_register
6276    def __init__(self, use_locking=False):
6277        """Initialize ApplyProximalAdagrad."""
6278        self.init_prim_io_names(inputs=['var', 'accum', 'lr', 'l1', 'l2', 'grad'],
6279                                outputs=['var', 'accum'])
6280        self.add_prim_attr('side_effect_mem', True)
6281        self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
6282
6283    def infer_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape):
6284        validator.check('accum shape', accum_shape, 'var shape', var_shape, Rel.EQ, self.name)
6285        validator.check('grad shape', grad_shape, 'var shape', var_shape, Rel.EQ, self.name)
6286        lr_shp_len = len(lr_shape)
6287        validator.check_int(lr_shp_len, 1, Rel.LE, "lr's rank", self.name)
6288        if lr_shp_len == 1:
6289            validator.check_int(lr_shape[0], 1, Rel.EQ, "lr_shape[0]", self.name)
6290        l1_shp_len = len(l1_shape)
6291        validator.check_int(l1_shp_len, 1, Rel.LE, "l1's rank", self.name)
6292        if l1_shp_len == 1:
6293            validator.check_int(l1_shape[0], 1, Rel.EQ, "l1_shape[0]", self.name)
6294        l2_shp_len = len(l2_shape)
6295        validator.check_int(l2_shp_len, 1, Rel.LE, "l2's rank", self.name)
6296        if l2_shp_len == 1:
6297            validator.check_int(l2_shape[0], 1, Rel.EQ, "l2_shape[0]", self.name)
6298        return var_shape, accum_shape
6299
6300    def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype):
6301        valid_dtypes = [mstype.float16, mstype.float32]
6302        args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype}
6303        validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
6304        validator.check_scalar_or_tensor_types_same({"lr": lr_dtype}, valid_dtypes, self.name)
6305        validator.check_scalar_or_tensor_types_same({"l1": l1_dtype}, valid_dtypes, self.name)
6306        validator.check_scalar_or_tensor_types_same({"l2": l2_dtype}, valid_dtypes, self.name)
6307        return var_dtype, accum_dtype
6308
6309
6310class SparseApplyProximalAdagrad(PrimitiveWithCheck):
6311    r"""
6312    Updates relevant entries according to the proximal adagrad algorithm. Compared with ApplyProximalAdagrad,
6313    an additional index tensor is input.
6314
6315    .. math::
6316        \begin{array}{ll} \\
6317            accum += grad * grad \\
6318            \text{prox_v} = var - lr * grad * \frac{1}{\sqrt{accum}} \\
6319            var = \frac{sign(\text{prox_v})}{1 + lr * l2} * \max(\left| \text{prox_v} \right| - lr * l1, 0)
6320        \end{array}
6321
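    A minimal NumPy sketch of the row-wise sparse update (an illustration only, not the operator's
    on-device implementation; only the rows of `var` and `accum` selected by `indices` are touched,
    and it reproduces the example below up to float32 rounding):

    .. code-block:: python

        import numpy as np

        def sparse_proximal_adagrad_step(var, accum, lr, l1, l2, grad, indices):
            # Updates the selected rows in place, mirroring how the operator updates the Parameters.
            for row, idx in zip(grad, indices):
                accum[idx] = accum[idx] + row ** 2
                prox_v = var[idx] - lr * row / np.sqrt(accum[idx])
                var[idx] = np.sign(prox_v) / (1 + lr * l2) * np.maximum(np.abs(prox_v) - lr * l1, 0)
            return var, accum
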
6322    Inputs of `var`, `accum` and `grad` comply with the implicit type conversion rules
6323    to make the data types consistent.
6324    If they have different data types, lower priority data type will be converted to
6325    relatively highest priority data type.
6326    RuntimeError exception will be thrown when the data type conversion of Parameter is required.
6327
6328    Args:
6329        use_locking (bool): If true, the `var` and `accum` tensors will be protected from being updated.
6330            Default: False.
6331
6332    Inputs:
6333        - **var** (Parameter) - Variable tensor to be updated. The data type must be float16 or float32.
6334          The shape is :math:`(N, *)` where :math:`*` means any number of additional dimensions.
6335        - **accum** (Parameter) - Variable tensor to be updated, has the same shape and dtype as `var`.
6336        - **lr** (Union[Number, Tensor]) - The learning rate value, must be a float number or
6337          a scalar tensor with float16 or float32 data type.
6338        - **l1** (Union[Number, Tensor]) - l1 regularization strength, must be a float number or
6339          a scalar tensor with float16 or float32 data type.
6340        - **l2** (Union[Number, Tensor]) - l2 regularization strength, must be a float number or
6341          a scalar tensor with float16 or float32 data type.
6342        - **grad** (Tensor) - A tensor of the same type as `var` and
6343          grad.shape[1:] = var.shape[1:] if len(var.shape) > 1.
6344        - **indices** (Tensor) - A tensor of indices in the first dimension of `var` and `accum`.
6345          If there are duplicates in `indices`, the behavior is undefined. The data type must be
6346          int32 or int64, and indices.shape[0] = grad.shape[0].
6347
6348    Outputs:
6349        Tuple of 2 tensors, the updated parameters.
6350
6351        - **var** (Tensor) - The same shape and data type as `var`.
6352        - **accum** (Tensor) - The same shape and data type as `accum`.
6353
6354    Raises:
6355        TypeError: If `use_locking` is not a bool.
6356        TypeError: If dtype of `var`, `accum`, `lr`, `l1`, `l2` or `grad` is neither float16 nor float32.
6357        TypeError: If dtype of `indices` is neither int32 nor int64.
6358
6359    Supported Platforms:
6360        ``Ascend`` ``GPU``
6361
6362    Examples:
6363        >>> class Net(nn.Cell):
6364        ...     def __init__(self):
6365        ...         super(Net, self).__init__()
6366        ...         self.sparse_apply_proximal_adagrad = ops.SparseApplyProximalAdagrad()
6367        ...         self.var = Parameter(Tensor(np.array([[4.1, 7.2], [1.1, 3.0]], np.float32)), name="var")
6368        ...         self.accum = Parameter(Tensor(np.array([[0, 0], [0, 0]], np.float32)), name="accum")
6369        ...         self.lr = 1.0
6370        ...         self.l1 = 1.0
6371        ...         self.l2 = 0.0
6372        ...     def construct(self, grad, indices):
6373        ...         out = self.sparse_apply_proximal_adagrad(self.var, self.accum, self.lr, self.l1,
6374        ...                                                  self.l2, grad, indices)
6375        ...         return out
6376        ...
6377        >>> net = Net()
6378        >>> grad = Tensor(np.array([[1, 1], [1, 1]], np.float32))
6379        >>> indices = Tensor(np.array([0, 1], np.int32))
6380        >>> output = net(grad, indices)
6381        >>> print(output)
6382        (Tensor(shape=[2, 2], dtype=Float32, value=
6383        [[ 2.09999990e+00,  5.19999981e+00],
6384         [ 0.00000000e+00,  1.00000000e+00]]), Tensor(shape=[2, 2], dtype=Float32, value=
6385        [[ 1.00000000e+00,  1.00000000e+00],
6386         [ 1.00000000e+00,  1.00000000e+00]]))
6387    """
6388
6389    __mindspore_signature__ = (
6390        sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
6391        sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
6392        sig.make_sig('lr', dtype=sig.sig_dtype.T1),
6393        sig.make_sig('l1', dtype=sig.sig_dtype.T2),
6394        sig.make_sig('l2', dtype=sig.sig_dtype.T3),
6395        sig.make_sig('grad', dtype=sig.sig_dtype.T),
6396        sig.make_sig('indices', dtype=sig.sig_dtype.T4)
6397    )
6398
6399    @prim_attr_register
6400    def __init__(self, use_locking=False):
6401        """Initialize SparseApplyProximalAdagrad."""
6402        self.init_prim_io_names(inputs=['var', 'accum', 'lr', 'l1', 'l2', 'grad', 'indices'],
6403                                outputs=['var', 'accum'])
6404        self.add_prim_attr('side_effect_mem', True)
6405        self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
6406
6407    def check_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape,
6408                    grad_shape, indices_shape):
6409        validator.check_int(len(indices_shape), 1, Rel.EQ, "indices rank", self.name)
6410
6411    def check_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype,
6412                    grad_dtype, indices_dtype):
6413        args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype}
6414        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
6415        validator.check_scalar_or_tensor_types_same({"lr": lr_dtype}, [mstype.float16, mstype.float32], self.name)
6416        validator.check_scalar_or_tensor_types_same({"l1": l1_dtype}, [mstype.float16, mstype.float32], self.name)
6417        validator.check_scalar_or_tensor_types_same({"l2": l2_dtype}, [mstype.float16, mstype.float32], self.name)
6418        valid_dtypes = [mstype.int32, mstype.int64]
6419        validator.check_tensor_dtype_valid('indices', indices_dtype, valid_dtypes, self.name)
6420
6421
6422class ApplyAddSign(PrimitiveWithInfer):
6423    r"""
6424    Updates relevant entries according to the AddSign algorithm.
6425
6426    .. math::
6427        \begin{array}{ll} \\
6428            m_{t+1} = \beta * m_{t} + (1 - \beta) * g \\
6429            \text{update} = (\alpha + \text{sign_decay} * sign(g) * sign(m)) * g \\
6430            var = var - lr_{t+1} * \text{update}
6431        \end{array}
6432
6433    :math:`t` represents the updating step, :math:`m` represents the 1st moment vector and :math:`m_{t}`
6434    is its value at the previous step, :math:`lr` represents the scaling factor `lr`, :math:`g` represents `grad`,
6435    :math:`\alpha` represents `alpha`, :math:`\beta` represents `beta`.
6436
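    A minimal NumPy sketch of this update rule (an illustration of the formula above, not the
    operator's on-device implementation; it reproduces the example below up to float32 rounding):

    .. code-block:: python

        import numpy as np

        def add_sign_step(var, m, lr, alpha, sign_decay, beta, grad):
            m = beta * m + (1 - beta) * grad
            update = (alpha + sign_decay * np.sign(grad) * np.sign(m)) * grad
            var = var - lr * update
            return var, m
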
6437    Inputs of `var`, `m` and `grad` comply with the implicit type conversion rules
6438    to make the data types consistent.
6439    If they have different data types, lower priority data type will be converted to
6440    relatively highest priority data type.
6441    RuntimeError exception will be thrown when the data type conversion of Parameter is required.
6442
6443    Inputs:
6444        - **var** (Parameter) - Variable tensor to be updated. With float32 or float16 data type.
6445          The shape is :math:`(N, *)` where :math:`*` means any number of additional dimensions.
6446        - **m** (Parameter) - Variable tensor to be updated, has the same shape and data type as `var`.
6447        - **lr** (Union[Number, Tensor]) - The learning rate value, must be a scalar.
6448          With float32 or float16 data type.
6449        - **alpha** (Union[Number, Tensor]) - Must be a scalar. With float32 or float16 data type.
6450        - **sign_decay** (Union[Number, Tensor]) - Must be a scalar. With float32 or float16 data type.
6451        - **beta** (Union[Number, Tensor]) - The exponential decay rate, must be a scalar.
6452          With float32 or float16 data type.
6453        - **grad** (Tensor) - A tensor of the same shape and data type as `var`, for the gradient.
6454
6455    Outputs:
6456        Tuple of 2 Tensors, the updated parameters.
6457
6458        - **var** (Tensor) - The same shape and data type as `var`.
6459        - **m** (Tensor) - The same shape and data type as `m`.
6460
6461    Raises:
6462        TypeError: If dtype of `var`, `lr`, `alpha`, `sign_decay` or `beta` is neither float16 nor float32.
6463        TypeError: If `lr`, `alpha` or `sign_decay` is neither a Number nor a Tensor.
6464        TypeError: If `grad` is not a Tensor.
6465
6466    Supported Platforms:
6467        ``Ascend``
6468
6469    Examples:
6470        >>> class Net(nn.Cell):
6471        ...     def __init__(self):
6472        ...         super(Net, self).__init__()
6473        ...         self.apply_add_sign = ops.ApplyAddSign()
6474        ...         self.var = Parameter(Tensor(np.array([[0.6, 0.4],
6475        ...                                               [0.1, 0.5]]).astype(np.float32)), name="var")
6476        ...         self.m = Parameter(Tensor(np.array([[0.6, 0.5],
6477        ...                                             [0.2, 0.6]]).astype(np.float32)), name="m")
6478        ...         self.lr = 0.001
6479        ...         self.alpha = 1.0
6480        ...         self.sign_decay = 0.99
6481        ...         self.beta = 0.9
6482        ...     def construct(self, grad):
6483        ...         out = self.apply_add_sign(self.var, self.m, self.lr, self.alpha, self.sign_decay, self.beta, grad)
6484        ...         return out
6485        ...
6486        >>> net = Net()
6487        >>> grad = Tensor(np.array([[0.3, 0.7], [0.1, 0.8]]).astype(np.float32))
6488        >>> output = net(grad)
6489        >>> print(output)
6490        (Tensor(shape=[2, 2], dtype=Float32, value=
6491        [[ 5.99403024e-01,  3.98607016e-01],
6492         [ 9.98010039e-02,  4.98407990e-01]]), Tensor(shape=[2, 2], dtype=Float32, value=
6493        [[ 5.70000052e-01,  5.19999981e-01],
6494         [ 1.89999998e-01,  6.20000064e-01]]))
6495    """
6496
6497    __mindspore_signature__ = (
6498        sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
6499        sig.make_sig('m', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
6500        sig.make_sig('lr', dtype=sig.sig_dtype.T1),
6501        sig.make_sig('alpha', dtype=sig.sig_dtype.T2),
6502        sig.make_sig('sign_decay', dtype=sig.sig_dtype.T3),
6503        sig.make_sig('beta', dtype=sig.sig_dtype.T3),
6504        sig.make_sig('grad', dtype=sig.sig_dtype.T)
6505    )
6506
6507    @prim_attr_register
6508    def __init__(self):
6509        """Initialize ApplyAddSign."""
6510        self.add_prim_attr('side_effect_mem', True)
6511
6512    def infer_shape(self, var_shape, m_shape, lr_shape, alpha_shape, sign_decay_shape,
6513                    beta_shape, grad_shape):
6514        validator.check('m_shape', m_shape, 'var_shape', var_shape, Rel.EQ, self.name)
6515        validator.check('grad_shape', grad_shape, 'var_shape', var_shape, Rel.EQ, self.name)
6516        lr_shape_len = len(lr_shape)
6517        validator.check_int(lr_shape_len, 1, Rel.LE, "lr's rank", self.name)
6518        if lr_shape_len == 1:
6519            validator.check_int(lr_shape[0], 1, Rel.EQ, "lr_shape[0]", self.name)
6520        alpha_shape_len = len(alpha_shape)
6521        validator.check_int(alpha_shape_len, 1, Rel.LE, "alpha's rank", self.name)
6522        if alpha_shape_len == 1:
6523            validator.check_int(alpha_shape[0], 1, Rel.EQ, "alpha_shape[0]", self.name)
6524        sign_decay_shape_len = len(sign_decay_shape)
6525        validator.check_int(sign_decay_shape_len, 1, Rel.LE, "sign_decay's rank", self.name)
6526        if sign_decay_shape_len == 1:
6527            validator.check_int(sign_decay_shape[0], 1, Rel.EQ, "sign_decay_shape[0]", self.name)
6528        beta_shape_len = len(beta_shape)
6529        validator.check_int(beta_shape_len, 1, Rel.LE, "beta's rank", self.name)
6530        if beta_shape_len == 1:
6531            validator.check_int(beta_shape[0], 1, Rel.EQ, "beta_shape[0]", self.name)
6532        return var_shape, m_shape
6533
6534    def infer_dtype(self, var_dtype, m_dtype, lr_dtype, alpha_dtype, sign_decay_dtype,
6535                    beta_dtype, grad_dtype):
6536        valid_dtypes = [mstype.float16, mstype.float32]
6537        args = {'var': var_dtype, 'm': m_dtype, 'grad': grad_dtype}
6538        validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
6539        validator.check_scalar_or_tensor_types_same({"lr": lr_dtype}, valid_dtypes, self.name)
6540        validator.check_scalar_or_tensor_types_same({"alpha": alpha_dtype}, valid_dtypes, self.name)
6541        validator.check_scalar_or_tensor_types_same({"sign_decay": sign_decay_dtype}, valid_dtypes, self.name)
6542        validator.check_scalar_or_tensor_types_same({"beta": beta_dtype}, valid_dtypes, self.name)
6543        return var_dtype, m_dtype
6544
6545
6546class ApplyPowerSign(PrimitiveWithInfer):
6547    r"""
6548    Updates relevant entries according to the PowerSign algorithm.
6549
6550    .. math::
6551        \begin{array}{ll} \\
6552            m_{t+1} = \beta * m_{t} + (1 - \beta) * g \\
6553            \text{update} = \exp(\text{logbase} * \text{sign_decay} * sign(g) * sign(m)) * g \\
6554            var = var - lr_{t+1} * \text{update}
6555        \end{array}
6556
6557    :math:`t` represents the updating step, :math:`m` represents the 1st moment vector and :math:`m_{t}`
6558    is its value at the previous step, :math:`lr` represents the scaling factor `lr`, :math:`g` represents `grad`,
6559    :math:`\beta` represents `beta`.
6560
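    A minimal NumPy sketch of this update rule (an illustration of the formula above, not the
    operator's on-device implementation; it differs from ApplyAddSign only in the exponential
    scaling factor and reproduces the example below up to float32 rounding):

    .. code-block:: python

        import numpy as np

        def power_sign_step(var, m, lr, logbase, sign_decay, beta, grad):
            m = beta * m + (1 - beta) * grad
            update = np.exp(logbase * sign_decay * np.sign(grad) * np.sign(m)) * grad
            var = var - lr * update
            return var, m
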
6561    All of inputs comply with the implicit type conversion rules to make the data types consistent.
6562    If `lr`, `logbase`, `sign_decay` or `beta` is a number, the number is automatically converted to Tensor,
6563    and the data type is consistent with the Tensor data type involved in the operation.
6564    If inputs are tensors and have different data types, lower priority data type will be converted to
6565    relatively highest priority data type.
6566    RuntimeError exception will be thrown when the data type conversion of Parameter is required.
6567
6568    Inputs:
6569        - **var** (Parameter) - Variable tensor to be updated. With float32 or float16 data type.
6570          If data type of `var` is float16, all inputs must have the same data type as `var`.
6571          The shape is :math:`(N, *)` where :math:`*` means any number of additional dimensions.
6572        - **m** (Parameter) - Variable tensor to be updated, has the same shape and data type as `var`.
6573        - **lr** (Union[Number, Tensor]) - The learning rate value, must be a scalar.
6574          With float32 or float16 data type.
6575        - **logbase** (Union[Number, Tensor]) - Must be a scalar. With float32 or float16 data type.
6576        - **sign_decay** (Union[Number, Tensor]) - Must be a scalar. With float32 or float16 data type.
6577        - **beta** (Union[Number, Tensor]) - The exponential decay rate, must be a scalar.
6578          With float32 or float16 data type.
6579        - **grad** (Tensor) - A tensor of the same shape and data type as `var`, for the gradient.
6580
6581    Outputs:
6582        Tuple of 2 Tensors, the updated parameters.
6583
6584        - **var** (Tensor) - The same shape and data type as `var`.
6585        - **m** (Tensor) - The same shape and data type as `m`.
6586
6587    Raises:
6588        TypeError: If dtype of `var`, `lr`, `logbase`, `sign_decay`, `beta` or `grad` is neither float16 nor float32.
6589        TypeError: If `lr`, `logbase`, `sign_decay` or `beta` is neither a Number nor a Tensor.
6590        TypeError: If `grad` is not a Tensor.
6591
6592    Supported Platforms:
6593        ``Ascend``
6594
6595    Examples:
6596        >>> class Net(nn.Cell):
6597        ...     def __init__(self):
6598        ...         super(Net, self).__init__()
6599        ...         self.apply_power_sign = ops.ApplyPowerSign()
6600        ...         self.var = Parameter(Tensor(np.array([[0.6, 0.4],
6601        ...                                               [0.1, 0.5]]).astype(np.float32)), name="var")
6602        ...         self.m = Parameter(Tensor(np.array([[0.6, 0.5],
6603        ...                                             [0.2, 0.6]]).astype(np.float32)), name="m")
6604        ...         self.lr = 0.001
6605        ...         self.logbase = np.e
6606        ...         self.sign_decay = 0.99
6607        ...         self.beta = 0.9
6608        ...     def construct(self, grad):
6609        ...         out = self.apply_power_sign(self.var, self.m, self.lr, self.logbase,
6610        ...                                        self.sign_decay, self.beta, grad)
6611        ...         return out
6612        ...
6613        >>> net = Net()
6614        >>> grad = Tensor(np.array([[0.3, 0.7], [0.1, 0.8]]).astype(np.float32))
6615        >>> output = net(grad)
6616        >>> print(output)
6617        (Tensor(shape=[2, 2], dtype=Float32, value=
6618        [[ 5.95575690e-01,  3.89676481e-01],
6619         [ 9.85252112e-02,  4.88201708e-01]]), Tensor(shape=[2, 2], dtype=Float32, value=
6620        [[ 5.70000052e-01,  5.19999981e-01],
6621         [ 1.89999998e-01,  6.20000064e-01]]))
6622    """
6623
6624    __mindspore_signature__ = (
6625        sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
6626        sig.make_sig('m', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
6627        sig.make_sig('lr', dtype=sig.sig_dtype.T),
6628        sig.make_sig('logbase', dtype=sig.sig_dtype.T),
6629        sig.make_sig('sign_decay', dtype=sig.sig_dtype.T),
6630        sig.make_sig('beta', dtype=sig.sig_dtype.T),
6631        sig.make_sig('grad', dtype=sig.sig_dtype.T)
6632    )
6633
6634    @prim_attr_register
6635    def __init__(self):
6636        """Initialize ApplyPowerSign."""
6637        self.add_prim_attr('side_effect_mem', True)
6638
6639    def infer_shape(self, var_shape, m_shape, lr_shape, logbase_shape, sign_decay_shape,
6640                    beta_shape, grad_shape):
6641        validator.check('m_shape', m_shape, 'var_shape', var_shape, Rel.EQ, self.name)
6642        validator.check('grad_shape', grad_shape, 'var_shape', var_shape, Rel.EQ, self.name)
6643        lr_shape_len = len(lr_shape)
6644        validator.check_int(lr_shape_len, 1, Rel.LE, "lr's rank", self.name)
6645        if lr_shape_len == 1:
6646            validator.check_int(lr_shape[0], 1, Rel.EQ, "lr_shape[0]", self.name)
6647        logbase_shape_len = len(logbase_shape)
6648        validator.check_int(logbase_shape_len, 1, Rel.LE, "logbase's rank", self.name)
6649        if logbase_shape_len == 1:
6650            validator.check_int(logbase_shape[0], 1, Rel.EQ, "logbase_shape[0]", self.name)
6651        sign_decay_shape_len = len(sign_decay_shape)
6652        validator.check_int(sign_decay_shape_len, 1, Rel.LE, "sign_decay's rank", self.name)
6653        if sign_decay_shape_len == 1:
6654            validator.check_int(sign_decay_shape[0], 1, Rel.EQ, "sign_decay_shape[0]", self.name)
6655        beta_shape_len = len(beta_shape)
6656        validator.check_int(beta_shape_len, 1, Rel.LE, "beta's rank", self.name)
6657        if beta_shape_len == 1:
6658            validator.check_int(beta_shape[0], 1, Rel.EQ, "beta_shape[0]", self.name)
6659        return var_shape, m_shape
6660
6661    def infer_dtype(self, var_dtype, m_dtype, lr_dtype, logbase_dtype, sign_decay_dtype,
6662                    beta_dtype, grad_dtype):
6663        valid_dtypes = [mstype.float16, mstype.float32]
6664        args = {'var': var_dtype, 'm': m_dtype, 'grad': grad_dtype}
6665        validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
6666        validator.check_scalar_or_tensor_types_same({"lr": lr_dtype}, valid_dtypes, self.name)
6667        validator.check_scalar_or_tensor_types_same({"logbase": logbase_dtype}, valid_dtypes, self.name)
6668        validator.check_scalar_or_tensor_types_same({"sign_decay": sign_decay_dtype}, valid_dtypes, self.name)
6669        validator.check_scalar_or_tensor_types_same({"beta": beta_dtype}, valid_dtypes, self.name)
6670        return var_dtype, m_dtype
6671
6672
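# Illustration only (not used by ApplyPowerSign above): a minimal NumPy sketch of the PowerSign
# update the operator documents, written under the assumption that `m` is refreshed first and then
# scales the step as exp(logbase * sign_decay * sign(grad) * sign(m)); it reproduces the example
# output in the docstring above. It relies on the module-level numpy import, and the helper name
# is ours, not part of the operator API.
def _power_sign_update_sketch(var, m, lr, logbase, sign_decay, beta, grad):
    """Illustrative-only PowerSign step on plain NumPy arrays; returns (new_var, new_m)."""
    m = beta * m + (1.0 - beta) * grad
    update = np.exp(logbase * sign_decay * np.sign(grad) * np.sign(m)) * grad
    return var - lr * update, m

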
6673class ApplyGradientDescent(PrimitiveWithInfer):
6674    r"""
6675    Updates relevant entries according to the following.
6676
6677    .. math::
6678        var = var - \alpha * \delta
6679
6680    where :math:`\alpha` represents `alpha`, :math:`\delta` represents `delta`.
6681
6682    Inputs of `var` and `delta` comply with the implicit type conversion rules to make the data types consistent.
6683    If they have different data types, lower priority data type will be converted to
6684    relatively highest priority data type.
6685    RuntimeError exception will be thrown when the data type conversion of Parameter is required.
6686
6687    Inputs:
6688        - **var** (Parameter) - Variable tensor to be updated. With float32 or float16 data type.
6689          The shape is :math:`(N, *)` where :math:`*` means, any number of additional dimensions.
6690        - **alpha** (Union[Number, Tensor]) - Scaling factor, must be a scalar. With float32 or float16 data type.
6691        - **delta** (Tensor) - A tensor for the change, has the same shape and data type as `var`.
6692
6693    Outputs:
6694        Tensor, represents the updated `var`.
6695
6696    Raises:
6697        TypeError: If dtype of `var` or `alpha` is neither float16 nor float32.
6698        TypeError: If `delta` is not a Tensor.
6699        TypeError: If `alpha` is neither a Number nor a Tensor.
6700
6701    Supported Platforms:
6702        ``Ascend``  ``GPU``
6703
6704    Examples:
6705        >>> class Net(nn.Cell):
6706        ...     def __init__(self):
6707        ...         super(Net, self).__init__()
6708        ...         self.apply_gradient_descent = ops.ApplyGradientDescent()
6709        ...         self.var = Parameter(Tensor(np.ones([2, 2]).astype(np.float32)), name="var")
6710        ...         self.alpha = 0.001
6711        ...     def construct(self, delta):
6712        ...         out = self.apply_gradient_descent(self.var, self.alpha, delta)
6713        ...         return out
6714        ...
6715        >>> net = Net()
6716        >>> delta = Tensor(np.array([[0.1, 0.1], [0.1, 0.1]]).astype(np.float32))
6717        >>> output = net(delta)
6718        >>> print(output)
6719        [[0.9999 0.9999]
6720         [0.9999 0.9999]]
6721    """
6722
6723    __mindspore_signature__ = (
6724        sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
6725        sig.make_sig('alpha', dtype=sig.sig_dtype.T1),
6726        sig.make_sig('delta', dtype=sig.sig_dtype.T)
6727    )
6728
6729    @prim_attr_register
6730    def __init__(self):
6731        """Initialize ApplyGradientDescent."""
6732        self.add_prim_attr('side_effect_mem', True)
6733
6734    def infer_shape(self, var_shape, alpha_shape, delta_shape):
6735        validator.check('delta shape', delta_shape, 'var shape', var_shape, Rel.EQ, self.name)
6736        alpha_shape_len = len(alpha_shape)
6737        validator.check_int(alpha_shape_len, 1, Rel.LE, "alpha's rank", self.name)
6738        if alpha_shape_len == 1:
6739            validator.check_int(alpha_shape[0], 1, Rel.EQ, "alpha_shape[0]", self.name)
6740        return var_shape
6741
6742    def infer_dtype(self, var_dtype, alpha_dtype, delta_dtype):
6743        valid_dtypes = [mstype.float16, mstype.float32]
6744        args = {'var': var_dtype, 'delta': delta_dtype}
6745        validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
6746        validator.check_scalar_or_tensor_types_same({"alpha": alpha_dtype}, valid_dtypes, self.name)
6747        return var_dtype
6748
6749
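# Illustration only (not used by ApplyGradientDescent above): a minimal NumPy sketch of the
# documented rule var <- var - alpha * delta, assuming `var` and `delta` are plain ndarrays and
# `alpha` is a Python float. The helper name is ours, not part of the operator API.
def _gradient_descent_update_sketch(var, alpha, delta):
    """Illustrative-only update: e.g. np.ones((2, 2)) with alpha=0.001, delta=0.1 gives ~0.9999."""
    return var - alpha * delta

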
6750class ApplyProximalGradientDescent(PrimitiveWithInfer):
6751    r"""
6752    Updates relevant entries according to the FOBOS(Forward Backward Splitting) algorithm.
6753
6754    .. math::
6755        \begin{array}{ll} \\
6756            \text{prox_v} = var - \alpha * \delta \\
6757            var = \frac{sign(\text{prox_v})}{1 + \alpha * l2} * \max(\left| \text{prox_v} \right| - \alpha * l1, 0)
6758        \end{array}
6759
6760    where :math:`\alpha` represents `alpha`, :math:`\delta` represents `delta`.
6761
6762    Inputs of `var` and `delta` comply with the implicit type conversion rules to make the data types consistent.
6763    If they have different data types, lower priority data type will be converted to
6764    relatively highest priority data type.
6765    RuntimeError exception will be thrown when the data type conversion of Parameter is required.
6766
6767    Inputs:
6768        - **var** (Parameter) - Variable tensor to be updated. With float32 or float16 data type.
6769          The shape is :math:`(N, *)` where :math:`*` means, any number of additional dimensions.
6770        - **alpha** (Union[Number, Tensor]) - Scaling factor, must be a scalar. With float32 or float16 data type.
6771        - **l1** (Union[Number, Tensor]) - l1 regularization strength, must be scalar.
6772          With float32 or float16 data type.
6773        - **l2** (Union[Number, Tensor]) - l2 regularization strength, must be scalar.
6774          With float32 or float16 data type.
6775        - **delta** (Tensor) - A tensor for the change, has the same shape and data type as `var`.
6776
6777    Outputs:
6778        Tensor, represents the updated `var`.
6779
6780    Raises:
6781        TypeError: If dtype of `var`, `alpha`, `l1` or `l2` is neither float16 nor float32.
6782        TypeError: If `alpha`, `l1` or `l2` is neither a Number nor a Tensor.
6783        TypeError: If `delta` is not a Tensor.
6784
6785    Supported Platforms:
6786        ``Ascend``
6787
6788    Examples:
6789        >>> class Net(nn.Cell):
6790        ...     def __init__(self):
6791        ...         super(Net, self).__init__()
6792        ...         self.apply_proximal_gradient_descent = ops.ApplyProximalGradientDescent()
6793        ...         self.var = Parameter(Tensor(np.ones([2, 2]).astype(np.float32)), name="var")
6794        ...         self.alpha = 0.001
6795        ...         self.l1 = 0.1
6796        ...         self.l2 = 0.1
6797        ...     def construct(self, delta):
6798        ...         out = self.apply_proximal_gradient_descent(self.var, self.alpha, self.l1, self.l2, delta)
6799        ...         return out
6800        ...
6801        >>> net = Net()
6802        >>> delta = Tensor(np.array([[0.1, 0.1], [0.1, 0.1]]).astype(np.float32))
6803        >>> output = net(delta)
6804        >>> print(output)
6805        [[0.99969995 0.99969995]
6806         [0.99969995 0.99969995]]
6807    """
6808
6809    __mindspore_signature__ = (
6810        sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
6811        sig.make_sig('alpha', dtype=sig.sig_dtype.T1),
6812        sig.make_sig('l1', dtype=sig.sig_dtype.T2),
6813        sig.make_sig('l2', dtype=sig.sig_dtype.T3),
6814        sig.make_sig('delta', dtype=sig.sig_dtype.T)
6815    )
6816
6817    @prim_attr_register
6818    def __init__(self):
6819        """Initialize ApplyProximalGradientDescent."""
6820        self.add_prim_attr('side_effect_mem', True)
6821
6822    def infer_shape(self, var_shape, alpha_shape, l1_shape, l2_shape, delta_shape):
6823        validator.check('delta shape', delta_shape, 'var shape', var_shape, Rel.EQ, self.name)
6824        alpha_shape_len = len(alpha_shape)
6825        validator.check_int(alpha_shape_len, 1, Rel.LE, "alpha's rank", self.name)
6826        if alpha_shape_len == 1:
6827            validator.check_int(alpha_shape[0], 1, Rel.EQ, "alpha_shape[0]", self.name)
6828        l1_shape_len = len(l1_shape)
6829        validator.check_int(l1_shape_len, 1, Rel.LE, "l1's rank", self.name)
6830        if l1_shape_len == 1:
6831            validator.check_int(l1_shape[0], 1, Rel.EQ, "l1_shape[0]", self.name)
6832        l2_shape_len = len(l2_shape)
6833        validator.check_int(l2_shape_len, 1, Rel.LE, "l2's rank", self.name)
6834        if l2_shape_len == 1:
6835            validator.check_int(l2_shape[0], 1, Rel.EQ, "l2_shape[0]", self.name)
6836        return var_shape
6837
6838    def infer_dtype(self, var_dtype, alpha_dtype, l1_dtype, l2_dtype, delta_dtype):
6839        valid_dtypes = [mstype.float16, mstype.float32]
6840        args = {'var': var_dtype, 'delta': delta_dtype}
6841        validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
6842        validator.check_scalar_or_tensor_types_same({"alpha": alpha_dtype}, valid_dtypes, self.name)
6843        validator.check_scalar_or_tensor_types_same({"l1": l1_dtype}, valid_dtypes, self.name)
6844        validator.check_scalar_or_tensor_types_same({"l2": l2_dtype}, valid_dtypes, self.name)
6845        return var_dtype
6846
6847
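# Illustration only (not used by ApplyProximalGradientDescent above): a minimal NumPy sketch of the
# FOBOS rule written in the class docstring, on plain ndarrays and Python floats. The helper name
# is ours, not part of the operator API.
def _proximal_gradient_descent_update_sketch(var, alpha, l1, l2, delta):
    """Illustrative-only FOBOS step: soft-threshold the intermediate prox_v = var - alpha * delta."""
    prox_v = var - alpha * delta
    return np.sign(prox_v) / (1.0 + alpha * l2) * np.maximum(np.abs(prox_v) - alpha * l1, 0.0)

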
6848class LARSUpdate(PrimitiveWithInfer):
6849    """
6850    Conducts the LARS (layer-wise adaptive rate scaling) update on the gradient.
6851
6852    For more details, please refer to :class:`nn.LARS`.
6853
6854    Args:
6855        epsilon (float): Term added to the denominator to improve numerical stability. Default: 1e-05.
6856        hyperpara (float): Trust coefficient for calculating the local learning rate. Default: 0.001.
6857        use_clip (bool): Whether to use clip operation for calculating the local learning rate. Default: False.
6858
6859    Inputs:
6860        - **weight** (Tensor) - A tensor, representing the weight.
6861          The shape is :math:`(N, *)` where :math:`*` means, any number of additional dimensions.
6862        - **gradient** (Tensor) - The gradient of `weight`, which has the same shape and dtype as `weight`.
6863        - **norm_weight** (Tensor) - A scalar tensor, representing the sum of squares of weight.
6864        - **norm_gradient** (Tensor) - A scalar tensor, representing the sum of squares of gradient.
6865        - **weight_decay** (Union[Number, Tensor]) - Weight decay. It must be a scalar tensor or number.
6866        - **learning_rate** (Union[Number, Tensor]) - Learning rate. It must be a scalar tensor or number.
6867
6868    Outputs:
6869        Tensor, represents the new gradient.
6870
6871    Raises:
6872        TypeError: If `epsilon` or `hyperpara` is not a float.
6873        TypeError: If `use_clip` is not a bool.
6874        TypeError: If `weight`, `gradient`, `norm_weight` or `norm_gradient` is not a Tensor.
6875        TypeError: If `weight_decay` or `learning_rate` is neither a Number nor a Tensor.
6876        ValueError: If the shape of `gradient` is not the same as that of `weight`.
6877
6878    Supported Platforms:
6879        ``Ascend``
6880
6881    Examples:
6882        >>> class Net(nn.Cell):
6883        ...     def __init__(self):
6884        ...         super(Net, self).__init__()
6885        ...         self.lars = ops.LARSUpdate()
6886        ...         self.reduce = ops.ReduceSum()
6887        ...         self.square = ops.Square()
6888        ...     def construct(self, weight, gradient):
6889        ...         w_square_sum = self.reduce(self.square(weight))
6890        ...         grad_square_sum = self.reduce(self.square(gradient))
6891        ...         grad_t = self.lars(weight, gradient, w_square_sum, grad_square_sum, 0.0, 1.0)
6892        ...         return grad_t
6893        ...
6894        >>> weight = Tensor(np.array([[0.5, 0.8, 0.2], [0.6, 0.4, 0.2]]).astype(np.float32))
6895        >>> gradient = Tensor(np.array([[0.4, 0.4, 0.5], [0.2, 0.4, 0.3]]).astype(np.float32))
6896        >>> net = Net()
6897        >>> output = net(Tensor(weight), Tensor(gradient))
6898        >>> print(output)
6899        [[0.0005265  0.0005265 0.00065813]
6900         [0.00026325 0.0005265 0.00039488]]
6901    """
6902
6903    @prim_attr_register
6904    def __init__(self, epsilon=1e-05, hyperpara=0.001, use_clip=False):
6905        """Initialize LARSUpdate."""
6906        validator.check_value_type("epsilon", epsilon, [float], self.name)
6907        validator.check_value_type("hyperpara", hyperpara, [float], self.name)
6908        validator.check_value_type("use_clip", use_clip, [bool], self.name)
6909
6910    def infer_shape(self, weight_shape, gradient_shape, norm_weight_shape, norm_gradient_shape, weight_decay_shape,
6911                    learning_rate_shape):
6912        validator.check("weight shape", weight_shape, "gradient shape", gradient_shape, Rel.EQ, self.name)
6913        validator.check("norm weight shape", norm_weight_shape, "norm gradient shape", norm_gradient_shape, Rel.EQ,
6914                        self.name)
6915        shp_len = len(weight_decay_shape)
6916        validator.check_int(shp_len, 1, Rel.LE, "weight decay's rank", self.name)
6917        if shp_len == 1:
6918            validator.check_int(weight_decay_shape[0], 1, Rel.EQ, "weight_decay_shape[0]", self.name)
6919        shp_len = len(learning_rate_shape)
6920        validator.check_int(shp_len, 1, Rel.LE, "learning rate's rank", self.name)
6921        if shp_len == 1:
6922            validator.check_int(learning_rate_shape[0], 1, Rel.EQ, "learning_rate_shape[0]", self.name)
6923        return weight_shape
6924
6925    def infer_dtype(self, weight_dtype, gradient_dtype, norm_weight_dtype, norm_gradient_dtype,
6926                    weight_decay_dtype, learning_rate_dtype):
6927        args = {"Weight dtype": weight_dtype, "gradient dtype": gradient_dtype, "norm weight dtype": norm_weight_dtype,
6928                "norm gradient dtype": norm_gradient_dtype}
6929        validator.check_tensors_dtypes_same_and_valid(args,
6930                                                      [mstype.float16, mstype.float32, mstype.int16, mstype.int32],
6931                                                      self.name)
6932        validator.check_scalar_or_tensor_types_same({"weight_decay": weight_decay_dtype},
6933                                                    [mstype.float16, mstype.float32, mstype.float64], self.name)
6934        validator.check_scalar_or_tensor_types_same({"learning_rate": learning_rate_dtype},
6935                                                    [mstype.float16, mstype.float32, mstype.float64], self.name)
6936        return weight_dtype
6937
6938
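# Illustration only (not used by LARSUpdate above): a minimal NumPy sketch of a standard LARS
# trust-ratio update. The exact denominator form and the way `learning_rate` enters are assumptions
# taken from the LARS formulation, not read from the kernel; with weight_decay=0.0 and
# learning_rate=1.0 it reproduces the example output in the docstring above. The helper name is ours.
def _lars_update_sketch(weight, gradient, norm_weight, norm_gradient, weight_decay, learning_rate,
                        epsilon=1e-05, hyperpara=0.001):
    """Illustrative-only LARS step: scale the (weight-decayed) gradient by a layer-wise trust ratio."""
    w_norm = np.sqrt(norm_weight)
    g_norm = np.sqrt(norm_gradient)
    trust_ratio = hyperpara * w_norm / (g_norm + weight_decay * w_norm + epsilon)
    return learning_rate * trust_ratio * (gradient + weight_decay * weight)

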
6939class ApplyFtrl(PrimitiveWithInfer):
6940    """
6941    Updates relevant entries according to the FTRL scheme.
6942
6943    For more details, please refer to :class:`nn.FTRL`.
6944
6945    Args:
6946        use_locking (bool): If true, use locks for the update operation. Default: False.
6947
6948    Inputs:
6949        - **var** (Parameter) - The variable to be updated. The data type must be float16 or float32.
6950          The shape is :math:`(N, *)` where :math:`*` means, any number of additional dimensions.
6951        - **accum** (Parameter) - The accumulation to be updated, must be same shape and data type as `var`.
6952        - **linear** (Parameter) - The linear coefficient to be updated, must be same shape and data type as `var`.
6953        - **grad** (Tensor) - Gradient. The data type must be float16 or float32.
6954        - **lr** (Union[Number, Tensor]) - The learning rate value, must be positive. Default: 0.001.
6955          It must be a float number or a scalar tensor with float16 or float32 data type.
6956        - **l1** (Union[Number, Tensor]) - l1 regularization strength, must be greater than or equal to zero.
6957          Default: 0.0. It must be a float number or a scalar tensor with float16 or float32 data type.
6958        - **l2** (Union[Number, Tensor]) - l2 regularization strength, must be greater than or equal to zero.
6959          Default: 0.0. It must be a float number or a scalar tensor with float16 or float32 data type.
6960        - **lr_power** (Union[Number, Tensor]) - Learning rate power controls how the learning rate decreases
6961          during training, must be less than or equal to zero. Use fixed learning rate if lr_power is zero.
6962          Default: -0.5. It must be a float number or a scalar tensor with float16 or float32 data type.
6963
6964    Outputs:
6965        - **var** (Tensor) - Represents the updated `var`. As the input parameters are updated in-place, this
6966          value is always zero when the platform is GPU.
6967
6968    Raises:
6969        TypeError: If `use_locking` is not a bool.
6970        TypeError: If dtype of `var`, `grad`, `lr`, `l1`, `l2` or `lr_power` is neither float16 nor float32.
6971        TypeError: If `lr`, `l1`, `l2` or `lr_power` is neither a Number nor a Tensor.
6972        TypeError: If `grad` is not a Tensor.
6973
6974    Supported Platforms:
6975        ``Ascend`` ``GPU``
6976
6977    Examples:
6978        >>> class ApplyFtrlNet(nn.Cell):
6979        ...     def __init__(self):
6980        ...         super(ApplyFtrlNet, self).__init__()
6981        ...         self.apply_ftrl = ops.ApplyFtrl()
6982        ...         self.lr = 0.001
6983        ...         self.l1 = 0.0
6984        ...         self.l2 = 0.0
6985        ...         self.lr_power = -0.5
6986        ...         self.var = Parameter(Tensor(np.array([[0.6, 0.4],
6987        ...                                               [0.1, 0.5]]).astype(np.float32)), name="var")
6988        ...         self.accum = Parameter(Tensor(np.array([[0.6, 0.5],
6989        ...                                                 [0.2, 0.6]]).astype(np.float32)), name="accum")
6990        ...         self.linear = Parameter(Tensor(np.array([[0.9, 0.1],
6991        ...                                                  [0.7, 0.8]]).astype(np.float32)), name="linear")
6992        ...
6993        ...     def construct(self, grad):
6994        ...         out = self.apply_ftrl(self.var, self.accum, self.linear, grad, self.lr, self.l1, self.l2,
6995        ...                               self.lr_power)
6996        ...         return out
6997        ...
6998        >>> net = ApplyFtrlNet()
6999        >>> input_x = Tensor(np.array([[0.3, 0.7], [0.1, 0.8]]).astype(np.float32))
7000        >>> output = net(input_x)
7001        >>> print(net.var.asnumpy())
7002        [[ 0.0390525  0.11492836]
7003         [ 0.00066425 0.15075898]]
7004    """
7005
7006    @prim_attr_register
7007    def __init__(self, use_locking=False):
7008        """Initialize ApplyFtrl."""
7009        self.init_prim_io_names(inputs=['var', 'accum', 'linear', 'grad', 'lr', 'l1', 'l2', 'lr_power'],
7010                                outputs=['output'])
7011        self.add_prim_attr('side_effect_mem', True)
7012        self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
7013
7014    def infer_shape(self, var_shape, accum_shape, linear_shape, grad_shape, lr_shape, l1_shape, l2_shape,
7015                    lr_power_shape):
7016        validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
7017        validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name)
7018        return var_shape
7019
7020    def infer_dtype(self, var_type, accum_type, linear_type, grad_type, lr_type, l1_type, l2_type,
7021                    lr_power_type):
7022        valid_dtypes = [mstype.float16, mstype.float32]
7023        args = {'var': var_type, 'accum': accum_type, 'linear': linear_type, 'grad': grad_type}
7024        validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
7025
7026        validator.check_scalar_or_tensor_types_same({"lr": lr_type}, valid_dtypes, self.name)
7027        validator.check_scalar_or_tensor_types_same({"l1": l1_type}, valid_dtypes, self.name)
7028        validator.check_scalar_or_tensor_types_same({"l2": l2_type}, valid_dtypes, self.name)
7029        validator.check_scalar_or_tensor_types_same({"lr_power": lr_power_type}, valid_dtypes, self.name)
7030        return var_type
7031
7032
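# Illustration only (not used by ApplyFtrl above): a minimal NumPy sketch of a dense FTRL-proximal
# step, assuming positive `accum` and the usual accumulator/linear/quadratic formulation; with
# l1 = l2 = 0.0 it reproduces the example output in the ApplyFtrl docstring above. The helper name
# is ours, not part of the operator API.
def _ftrl_update_sketch(var, accum, linear, grad, lr, l1, l2, lr_power):
    """Illustrative-only FTRL-proximal step; returns (new_var, new_accum, new_linear)."""
    accum_new = accum + grad * grad
    sigma = (accum_new ** (-lr_power) - accum ** (-lr_power)) / lr
    linear_new = linear + grad - sigma * var
    quadratic = accum_new ** (-lr_power) / lr + 2.0 * l2
    var_new = np.where(np.abs(linear_new) > l1,
                       (np.sign(linear_new) * l1 - linear_new) / quadratic, 0.0)
    return var_new, accum_new, linear_new

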
7033class SparseApplyFtrl(PrimitiveWithCheck):
7034    """
7035    Updates relevant entries according to the FTRL-proximal scheme.
7036
7037    For more details, please refer to :class:`nn.FTRL`.
7038
7039    All of inputs except `indices` comply with the implicit type conversion rules to make the data types consistent.
7040    If they have different data types, lower priority data type will be converted to
7041    relatively highest priority data type.
7042    RuntimeError exception will be thrown when the data type conversion of Parameter is required.
7043
7044    Args:
7045        lr (float): The learning rate value, must be positive.
7046        l1 (float): l1 regularization strength, must be greater than or equal to zero.
7047        l2 (float): l2 regularization strength, must be greater than or equal to zero.
7048        lr_power (float): Learning rate power controls how the learning rate decreases during training,
7049            must be less than or equal to zero. Use fixed learning rate if `lr_power` is zero.
7050        use_locking (bool): If true, use locks for the update operation. Default: False.
7051
7052    Inputs:
7053        - **var** (Parameter) - The variable to be updated. The data type must be float16 or float32.
7054          The shape is :math:`(N, *)` where :math:`*` means, any number of additional dimensions.
7055        - **accum** (Parameter) - The accumulation to be updated, must be same data type and shape as `var`.
7056        - **linear** (Parameter) - The linear coefficient to be updated, must be the same data type and shape as `var`.
7057        - **grad** (Tensor) - A tensor of the same type as `var`, with grad.shape[1:] = var.shape[1:]
          if the rank of `var` is greater than 1.
7058        - **indices** (Tensor) - A tensor of indices in the first dimension of `var` and `accum`.
7059          If there are duplicates in `indices`, the behavior is undefined.
7060          The type must be int32 or int64 and indices.shape[0] = grad.shape[0].
7061
7062    Outputs:
7063        - **var** (Tensor) - Tensor, has the same shape and data type as `var`.
7064        - **accum** (Tensor) - Tensor, has the same shape and data type as `accum`.
7065        - **linear** (Tensor) - Tensor, has the same shape and data type as `linear`.
7066
7067    Raises:
7068        TypeError: If `lr`, `l1`, `l2` or `lr_power` is not a float.
7069        TypeError: If `use_locking` is not a bool.
7070        TypeError: If dtype of `var`, `accum`, `linear` or `grad` is neither float16 nor float32.
7071        TypeError: If dtype of `indices` is neither int32 nor int64.
7072
7073    Supported Platforms:
7074        ``Ascend`` ``GPU``
7075
7076    Examples:
7077        >>> class SparseApplyFtrlNet(nn.Cell):
7078        ...     def __init__(self):
7079        ...         super(SparseApplyFtrlNet, self).__init__()
7080        ...         self.sparse_apply_ftrl = ops.SparseApplyFtrl(lr=0.01, l1=0.0, l2=0.0, lr_power=-0.5)
7081        ...         self.var = Parameter(Tensor(np.array([[0.2]]).astype(np.float32)), name="var")
7082        ...         self.accum = Parameter(Tensor(np.array([[0.1]]).astype(np.float32)), name="accum")
7083        ...         self.linear = Parameter(Tensor(np.array([[0.6]]).astype(np.float32)), name="linear")
7084        ...
7085        ...     def construct(self, grad, indices):
7086        ...         out = self.sparse_apply_ftrl(self.var, self.accum, self.linear, grad, indices)
7087        ...         return out
7088        ...
7089        >>> net = SparseApplyFtrlNet()
7090        >>> grad = Tensor(np.array([[0.7]]).astype(np.float32))
7091        >>> indices = Tensor(np.ones([1]), mindspore.int32)
7092        >>> output = net(grad, indices)
7093        >>> print(output)
7094        (Tensor(shape=[1, 1], dtype=Float32, value=
7095        [[2.00000003e-01]]), Tensor(shape=[1, 1], dtype=Float32, value=
7096        [[1.00000001e-01]]), Tensor(shape=[1, 1], dtype=Float32, value=
7097        [[6.00000024e-01]]))
7098    """
7099
7100    __mindspore_signature__ = (
7101        sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
7102        sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
7103        sig.make_sig('linear', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
7104        sig.make_sig('grad', dtype=sig.sig_dtype.T),
7105        sig.make_sig('indices', dtype=sig.sig_dtype.T1)
7106    )
7107
7108    @prim_attr_register
7109    def __init__(self, lr, l1, l2, lr_power, use_locking=False):
7110        """Initialize SparseApplyFtrl."""
7111        validator.check_value_type("lr", lr, [float], self.name)
7112        validator.check_value_type("l1", l1, [float], self.name)
7113        validator.check_value_type("l2", l2, [float], self.name)
7114        validator.check_value_type("lr_power", lr_power, [float], self.name)
7115        self.lr = validator.check_positive_float(lr, "lr", self.name)
7116        self.l1 = validator.check_non_negative_float(l1, "l1", self.name)
7117        self.l2 = validator.check_non_negative_float(l2, "l2", self.name)
7118        self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name)
7119        self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
7120        self.init_prim_io_names(inputs=['var', 'accum', 'linear', 'grad', 'indices'],
7121                                outputs=['var', 'accum', 'linear'])
7122        self.add_prim_attr('side_effect_mem', True)
7123
7124    def check_shape(self, var_shape, accum_shape, linear_shape, grad_shape, indices_shape):
7125        validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
7126        validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name)
7127        if len(var_shape) > 1:
7128            validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name)
7129        validator.check_int(len(indices_shape), 1, Rel.EQ, "indices rank", self.name)
7130        validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
7131
7132    def check_dtype(self, var_dtype, accum_dtype, linear_dtype, grad_dtype, indices_dtype):
7133        args = {"var_dtype": var_dtype, "accum_dtype": accum_dtype,
7134                "linear_dtype": linear_dtype, "grad_dtype": grad_dtype}
7135        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
7136        validator.check_tensor_dtype_valid("indices_dtype", indices_dtype, [mstype.int32, mstype.int64], self.name)
7137
7138
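# Illustration only (not used by SparseApplyFtrl above): a minimal NumPy sketch of the sparse update
# pattern the docstring describes -- row i of `grad` updates row indices[i] of `var`, `accum` and
# `linear`. `row_update_fn` is a hypothetical stand-in for any dense per-row rule (e.g. an FTRL
# step); the behavior for duplicate indices is left undefined, as in the docstring above.
def _sparse_row_update_sketch(var, accum, linear, grad, indices, row_update_fn):
    """Illustrative-only sparse pattern: apply a dense per-row update to the selected rows."""
    for i, row in enumerate(indices):
        var[row], accum[row], linear[row] = row_update_fn(var[row], accum[row], linear[row], grad[i])
    return var, accum, linear

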
7139class SparseApplyFtrlV2(PrimitiveWithInfer):
7140    """
7141    Updates relevant entries according to the FTRL-proximal scheme. Compared with class SparseApplyFtrl,
7142    this class has one more attribute, namely `l2_shrinkage`.
7143
7144    All of inputs except `indices` comply with the implicit type conversion rules to make the data types consistent.
7145    If they have different data types, lower priority data type will be converted to
7146    relatively highest priority data type.
7147    RuntimeError exception will be thrown when the data type conversion of Parameter is required.
7148
7149    Args:
7150        lr (float): The learning rate value, must be positive.
7151        l1 (float): l1 regularization strength, must be greater than or equal to zero.
7152        l2 (float): l2 regularization strength, must be greater than or equal to zero.
7153        l2_shrinkage (float): L2 shrinkage regularization.
7154        lr_power (float): Learning rate power controls how the learning rate decreases during training,
7155            must be less than or equal to zero. Use fixed learning rate if `lr_power` is zero.
7156        use_locking (bool): If `True`, the var and accumulation tensors will be protected from being updated.
7157            Default: False.
7158
7159    Inputs:
7160        - **var** (Parameter) - The variable to be updated. The data type must be float16 or float32.
7161          The shape is :math:`(N, *)` where :math:`*` means, any number of additional dimensions.
7162        - **accum** (Parameter) - The accumulation to be updated, must be same data type and shape as `var`.
7163        - **linear** (Parameter) - The linear coefficient to be updated, must be same data type and shape as `var`.
7164        - **grad** (Tensor) - A tensor of the same type as `var`, with grad.shape[1:] = var.shape[1:]
          if the rank of `var` is greater than 1.
7165        - **indices** (Tensor) - A vector of indices in the first dimension of `var` and `accum`.
7166          The type must be int32 and indices.shape[0] = grad.shape[0].
7167
7168    Outputs:
7169        Tuple of 3 Tensor, the updated parameters.
7170
7171        - **var** (Tensor) - Tensor, has the same shape and data type as `var`.
7172        - **accum** (Tensor) - Tensor, has the same shape and data type as `accum`.
7173        - **linear** (Tensor) - Tensor, has the same shape and data type as `linear`.
7174
7175    Raises:
7176        TypeError: If `lr`, `l1`, `l2`, `l2_shrinkage` or `lr_power` is not a float.
7177        TypeError: If `use_locking` is not a bool.
7178        TypeError: If dtype of `var`, `accum`, `linear` or `grad` is neither float16 nor float32.
7179        TypeError: If dtype of `indices` is not int32.
7180
7181    Supported Platforms:
7182        ``Ascend``
7183
7184    Examples:
7185        >>> class SparseApplyFtrlV2Net(nn.Cell):
7186        ...     def __init__(self):
7187        ...         super(SparseApplyFtrlV2Net, self).__init__()
7188        ...         self.sparse_apply_ftrl_v2 = ops.SparseApplyFtrlV2(lr=0.01, l1=0.0, l2=0.0,
7189        ...                                                         l2_shrinkage=0.0, lr_power=-0.5)
7190        ...         self.var = Parameter(Tensor(np.array([[0.2, 0.3]]).astype(np.float32)), name="var")
7191        ...         self.accum = Parameter(Tensor(np.array([[0.5, 0.9]]).astype(np.float32)), name="accum")
7192        ...         self.linear = Parameter(Tensor(np.array([[0.7, 0.5]]).astype(np.float32)), name="linear")
7193        ...
7194        ...     def construct(self, grad, indices):
7195        ...         out = self.sparse_apply_ftrl_v2(self.var, self.accum, self.linear, grad, indices)
7196        ...         return out
7197        ...
7198        >>> net = SparseApplyFtrlV2Net()
7199        >>> grad = Tensor(np.array([[0.8, 0.5]]).astype(np.float32))
7200        >>> indices = Tensor(np.ones([1]), mindspore.int32)
7201        >>> output = net(grad, indices)
7202        >>> print(output)
7203        (Tensor(shape=[1, 2], dtype=Float32, value=
7204        [[ 2.00000003e-01,  3.00000012e-01]]), Tensor(shape=[1, 2], dtype=Float32, value=
7205        [[ 5.00000000e-01,  8.99999976e-01]]), Tensor(shape=[1, 2], dtype=Float32, value=
7206        [[ 6.99999988e-01,  5.00000000e-01]]))
7207    """
7208
7209    __mindspore_signature__ = (
7210        sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
7211        sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
7212        sig.make_sig('linear', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
7213        sig.make_sig('grad', dtype=sig.sig_dtype.T),
7214        sig.make_sig('indices', dtype=sig.sig_dtype.T1)
7215    )
7216
7217    @prim_attr_register
7218    def __init__(self, lr, l1, l2, l2_shrinkage, lr_power, use_locking=False):
7219        """Initialize SparseApplyFtrlV2."""
7220        validator.check_value_type("lr", lr, [float], self.name)
7221        validator.check_value_type("l1", l1, [float], self.name)
7222        validator.check_value_type("l2", l2, [float], self.name)
7223        validator.check_value_type("lr_power", lr_power, [float], self.name)
7224        self.lr = validator.check_positive_float(lr, "lr", self.name)
7225        self.l1 = validator.check_non_negative_float(l1, "l1", self.name)
7226        self.l2 = validator.check_non_negative_float(l2, "l2", self.name)
7227        self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name)
7228        self.l2_shrinkage = validator.check_value_type("l2_shrinkage", l2_shrinkage, [float], self.name)
7229        self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
7230        self.add_prim_attr('side_effect_mem', True)
7231
7232    def infer_shape(self, var_shape, accum_shape, linear_shape, grad_shape, indices_shape):
7233        validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
7234        validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name)
7235        if len(var_shape) > 1:
7236            validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name)
7237        validator.check_int(len(indices_shape), 1, Rel.EQ, "indices rank", self.name)
7238        validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
7239        return var_shape, accum_shape, linear_shape
7240
7241    def infer_dtype(self, var_dtype, accum_dtype, linear_dtype, grad_dtype, indices_dtype):
7242        args = {"var_dtype": var_dtype, "accum_dtype": accum_dtype,
7243                "linear_dtype": linear_dtype, "grad_dtype": grad_dtype}
7244        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
7245        validator.check_tensor_dtype_valid("indices_dtype", indices_dtype, [mstype.int32], self.name)
7246        return var_dtype, accum_dtype, linear_dtype
7247
7248
7249class Dropout(PrimitiveWithCheck):
7250    """
7251    During training, randomly zeroes some of the elements of the input tensor
7252    with probability 1-`keep_prob` from a Bernoulli distribution.
7253
7254    Args:
7255        keep_prob (float): The keep rate, between 0 and 1, e.g. keep_prob = 0.9
7256            means dropping out 10% of input units. Default: 0.5.
7257        Seed0 (int): The first seed value for random number generation. Default: 0.
7258        Seed1 (int): The second seed value for random number generation. Default: 0.
7259
7260    Inputs:
7261        - **x** (Tensor) - Tensor of shape :math:`(N, *)`, where :math:`*` means, any number of
7262          additional dimensions, with float16 or float32 data type.
7263
7264    Outputs:
7265        - **output** (Tensor) - With the same shape and data type as `x`.
7266        - **mask** (Tensor) - With the same shape as `x`.
7267
7268    Raises:
7269        TypeError: If `keep_prob` is not a float.
7270        TypeError: If `Seed0` or `Seed1` is not an int.
7271        TypeError: If dtype of `x` is neither float16 nor float32.
7272        TypeError: If `x` is not a Tensor.
7273
7274    Supported Platforms:
7275        ``Ascend`` ``GPU`` ``CPU``
7276
7277    Examples:
7278        >>> dropout = ops.Dropout(keep_prob=0.5)
7279        >>> x = Tensor(((20, 16), (50, 50)), mindspore.float32)
7280        >>> output, mask = dropout(x)
7281        >>> print(output.shape)
7282        (2, 2)
7283    """
7284
7285    @prim_attr_register
7286    def __init__(self, keep_prob=0.5, Seed0=0, Seed1=0):
7287        """Initialize Dropout."""
7288        self.seed0 = validator.check_value_type("Seed0", Seed0, [int], self.name)
7289        self.seed1 = validator.check_value_type("Seed1", Seed1, [int], self.name)
7290        self.keep_prob = validator.check_float_range(keep_prob, 0, 1, Rel.INC_RIGHT, "keep_prob", self.name)
7291
7292    def check_shape(self, x_shape):
7293        validator.check_int(len(x_shape), 1, Rel.GE, "x_shape", self.name)
7294
7295    def check_dtype(self, x_dtype):
7296        valid_dtypes = (mstype.float16, mstype.float32)
7297        validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name)
7298
7299
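# Illustration only (not used by Dropout above): a minimal NumPy sketch of inverted dropout --
# each element is kept with probability `keep_prob` and kept values are rescaled by 1 / keep_prob.
# The 1 / keep_prob rescaling is an assumption about the kernel's contract, and the unseeded
# Bernoulli mask here stands in for the seeded generator the primitive uses. The helper name is ours.
def _dropout_sketch(x, keep_prob=0.5):
    """Illustrative-only inverted dropout; returns (output, mask)."""
    mask = (np.random.uniform(size=x.shape) < keep_prob).astype(x.dtype)
    return x * mask / keep_prob, mask

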
7300class Dropout2D(PrimitiveWithInfer):
7301    """
7302    During training, randomly zeroes some of the channels of the input tensor with probability 1-`keep_prob`
7303    from a Bernoulli distribution (for a 4-dimensional tensor with a shape of NCHW, the channel feature map
7304    refers to a 2-dimensional feature map with the shape of HW).
7305
7306    For example, the :math:`j_th` channel of the :math:`i_th` sample in the batched input is a 2D tensor input[i,j].
7307    Each channel will be zeroed out independently on every forward call with probability 1-`keep_prob` using samples
7308    from a Bernoulli distribution.
7309
7310    Dropout2D can improve the independence between channel feature maps.
7311
7312    Args:
7313        keep_prob (float): The keep probability of a channel, between 0 and 1, e.g. `keep_prob` = 0.8,
7314            means dropping out 20% of channels. Default: 0.5.
7315
7316    Inputs:
7317        - **x** (Tensor) - A 4-D tensor with shape :math:`(N, C, H, W)`. The data type should be int8, int16,
7318          int32, int64, float16 or float32.
7319
7320    Outputs:
7321        - **output** (Tensor) - With the same shape and data type as `x`.
7322        - **mask** (Tensor) - With the same shape as `x` and the data type is bool.
7323
7324    Raises:
7325        TypeError: If the data type of `keep_prob` is not float.
7326        ValueError: If `keep_prob` is out of the range [0.0, 1.0],
7327                    or if the input is not a 4-D tensor.
7328
7329    Supported Platforms:
7330        ``Ascend``
7331
7332    Examples:
7333        >>> dropout = ops.Dropout2D(keep_prob=0.5)
7334        >>> x = Tensor(np.ones([2, 1, 2, 3]), mindspore.float32)
7335        >>> output, mask = dropout(x)
7336        >>> print(output.shape)
7337        (2, 1, 2, 3)
7338    """
7339
7340    @prim_attr_register
7341    def __init__(self, keep_prob=0.5):
7342        """Initialize Dropout2D."""
7343        self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
7344        self.keep_prob = validator.check_float_range(keep_prob, 0.0, 1.0, Rel.INC_BOTH, "keep_prob", self.name)
7345
7346    def infer_shape(self, x_shape):
7347        validator.check_int(len(x_shape), 4, Rel.EQ, "dim of input", self.name)
7348        return x_shape, x_shape
7349
7350    def infer_dtype(self, x_dtype):
7351        valid_dtypes = mstype.int_type + (mstype.float16, mstype.float32)
7352        validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name)
7353        mask_dtype = mstype.tensor_type(mstype.bool_)
7354        return x_dtype, mask_dtype
7355
7356
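# Illustration only (not used by Dropout2D above): a minimal NumPy sketch of the channel-wise
# masking pattern on an NCHW tensor -- whole (H, W) feature maps are zeroed with probability
# 1 - keep_prob. Whether the kernel additionally rescales the kept channels by 1 / keep_prob is
# not asserted here; this shows the masking pattern only. The helper name is ours.
def _dropout2d_mask_sketch(x, keep_prob=0.5):
    """Illustrative-only channel dropout; returns (masked_output, boolean_mask)."""
    n, c = x.shape[0], x.shape[1]
    keep = np.random.uniform(size=(n, c, 1, 1)) < keep_prob
    return x * keep.astype(x.dtype), np.broadcast_to(keep, x.shape)

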
7357class Dropout3D(PrimitiveWithInfer):
7358    """
7359    During training, randomly zeroes some of the channels of the input tensor
7360    with probability 1-`keep_prob` from a Bernoulli distribution (for a 5-dimensional tensor with a shape of NCDHW,
7361    the channel feature map refers to a 3-dimensional feature map with a shape of DHW).
7362
7363    For example, the :math:`j_th` channel of the :math:`i_th` sample in the batched input is a 3D tensor input[i,j].
7364    Each channel will be zeroed out independently on every forward call with probability 1-`keep_prob`
7365    using samples from a Bernoulli distribution.
7366
7367    Dropout3D can improve the independence between channel feature maps.
7368
7369    Args:
7370        keep_prob (float): The keep probability of a channel, between 0 and 1, e.g. `keep_prob` = 0.8,
7371            means dropping out 20% of channels. Default: 0.5.
7372
7373    Inputs:
7374        - **x** (Tensor) - A 5-D tensor with shape :math:`(N, C, D, H, W)`. The data type should be int8, int16,
7375          int32, int64, float16 or float32.
7376
7377    Outputs:
7378        - **output** (Tensor) - With the same shape and data type as `x`.
7379        - **mask** (Tensor) - With the same shape as `x` and the data type is bool.
7380
7381    Raises:
7382        TypeError: If the data type of `keep_prob` is not float.
7383        ValueError: If `keep_prob` is out of the range [0.0, 1.0],
7384                    or if the input is not a 5-D tensor.
7385
7386    Supported Platforms:
7387        ``Ascend`` ``GPU``
7388
7389    Examples:
7390        >>> dropout = ops.Dropout3D(keep_prob=0.5)
7391        >>> x = Tensor(np.ones([2, 1, 2, 1, 2]), mindspore.float32)
7392        >>> output, mask = dropout(x)
7393        >>> print(output.shape)
7394        (2, 1, 2, 1, 2)
7395    """
7396
7397    @prim_attr_register
7398    def __init__(self, keep_prob=0.5):
7399        """Initialize Dropout3D."""
7400        self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
7401        self.keep_prob = validator.check_float_range(keep_prob, 0.0, 1.0, Rel.INC_BOTH, "keep_prob", self.name)
7402
7403    def infer_shape(self, x_shape):
7404        validator.check_int(len(x_shape), 5, Rel.EQ, "dim of input", self.name)
7405        return x_shape, x_shape
7406
7407    def infer_dtype(self, x_dtype):
7408        valid_dtypes = mstype.int_type + (mstype.float16, mstype.float32)
7409        validator.check_tensor_dtype_valid("x", x_dtype, valid_dtypes, self.name)
7410        mask_dtype = mstype.tensor_type(mstype.bool_)
7411        return x_dtype, mask_dtype
7412
7413
7414class CTCLoss(Primitive):
7415    r"""
7416    Calculates the CTC (Connectionist Temporal Classification) loss and the gradient.
7417
7418    The CTC algorithm is proposed in `Connectionist Temporal Classification: Labeling Unsegmented Sequence Data with
7419    Recurrent Neural Networks <http://www.cs.toronto.edu/~graves/icml_2006.pdf>`_.
7420
7421    Args:
7422        preprocess_collapse_repeated (bool): If true, repeated labels will be collapsed prior to the CTC calculation.
7423                                             Default: False.
7424        ctc_merge_repeated (bool): If false, during CTC calculation, repeated non-blank labels will not be merged
7425                                   and these labels will be interpreted as individual ones. This is a simplified
7426                                   version of CTC. Default: True.
7427        ignore_longer_outputs_than_inputs (bool): If true, sequences with longer outputs than inputs will be ignored.
7428                                                  Default: False.
7429
7430    Inputs:
7431        - **x** (Tensor) - The input Tensor must be a `3-D` tensor whose shape is
7432          :math:`(max\_time, batch\_size, num\_classes)`. `num_classes` must be `num_labels + 1` classes, `num_labels`
7433          indicates the number of actual labels. Blank labels are reserved. Default blank label is `num_classes - 1`.
7434          Data type must be float16, float32 or float64.
7435        - **labels_indices** (Tensor) - The indices of labels. `labels_indices[i, :] = [b, t]` means
7436          `labels_values[i]` stores the id for `(batch b, time t)`. The type must be int64 and rank must be 2.
7437        - **labels_values** (Tensor) - A `1-D` input tensor. The values are associated with the given batch and time.
7438          The type must be int32. `labels_values[i]` must be in the range of `[0, num_classes)`.
7439        - **sequence_length** (Tensor) - A tensor containing sequence lengths with the shape of :math:`(batch\_size, )`.
7440          The type must be int32. Each value in the tensor must not be greater than `max_time`.
7441
7442    Outputs:
7443        - **loss** (Tensor) - A tensor containing log-probabilities, the shape is :math:`(batch\_size, )`.
7444          The tensor has the same data type as `x`.
7445        - **gradient** (Tensor) - The gradient of `loss`, has the same shape and data type as `x`.
7446
7447    Raises:
7448        TypeError: If `preprocess_collapse_repeated`, `ctc_merge_repeated` or `ignore_longer_outputs_than_inputs`
7449                   is not a bool.
7450        TypeError: If `x`, `labels_indices`, `labels_values` or `sequence_length` is not a Tensor.
7451        ValueError: If rank of `labels_indices` is not equal to 2.
7452        TypeError: If dtype of `x` is not one of the following: float16, float32 or float64.
7453        TypeError: If dtype of `labels_indices` is not int64.
7454        TypeError: If dtype of `labels_values` or `sequence_length` is not int32.
7455
7456    Supported Platforms:
7457        ``Ascend`` ``GPU`` ``CPU``
7458
7459    Examples:
7460        >>> x = Tensor(np.array([[[0.3, 0.6, 0.6],
7461        ...                       [0.4, 0.3, 0.9]],
7462        ...
7463        ...                      [[0.9, 0.4, 0.2],
7464        ...                       [0.9, 0.9, 0.1]]]).astype(np.float32))
7465        >>> labels_indices = Tensor(np.array([[0, 0], [1, 0]]), mindspore.int64)
7466        >>> labels_values = Tensor(np.array([2, 2]), mindspore.int32)
7467        >>> sequence_length = Tensor(np.array([2, 2]), mindspore.int32)
7468        >>> ctc_loss = ops.CTCLoss()
7469        >>> loss, gradient = ctc_loss(x, labels_indices, labels_values, sequence_length)
7470        >>> print(loss)
7471        [ 0.79628  0.5995158 ]
7472        >>> print(gradient)
7473        [[[ 0.27029088  0.36485454  -0.6351454  ]
7474          [ 0.28140804  0.25462854  -0.5360366 ]]
7475         [[ 0.47548494  0.2883962    0.04510255 ]
7476          [ 0.4082751   0.4082751    0.02843709 ]]]
7477    """
7478
7479    @prim_attr_register
7480    def __init__(self, preprocess_collapse_repeated=False, ctc_merge_repeated=True,
7481                 ignore_longer_outputs_than_inputs=False):
7482        """Initialize CTCLoss."""
7483        self.init_prim_io_names(inputs=["inputs", "labels_indices", "labels_values", "sequence_length"],
7484                                outputs=["loss", "gradient"])
7485        validator.check_value_type("preprocess_collapse_repeated", preprocess_collapse_repeated, [bool], self.name)
7486        self.preprocess_collapse_repeated_ = preprocess_collapse_repeated
7487        self.ctc_merge_repeated_ = validator.check_value_type("ctc_merge_repeated", ctc_merge_repeated,
7488                                                              [bool], self.name)
7489        validator.check_value_type("ignore_longer_outputs_than_inputs",
7490                                   ignore_longer_outputs_than_inputs, [bool], self.name)
7491        self.ignore_longer_outputs_than_inputs_ = ignore_longer_outputs_than_inputs
7492
7493
7494class CTCGreedyDecoder(PrimitiveWithCheck):
7495    r"""
7496    Performs greedy decoding on the logits given in inputs.
7497
7498    Args:
7499        merge_repeated (bool): If true, merge repeated classes in output. Default: True.
7500
7501    Inputs:
7502        - **inputs** (Tensor) - The input Tensor must be a 3-D tensor whose shape is
7503          :math:`(max\_time, batch\_size, num\_classes)`. `num_classes` must be `num_labels + 1` classes,
7504          `num_labels` indicates the number of actual labels. Blank labels are reserved.
7505          Default blank label is `num_classes - 1`. Data type must be float32 or float64.
7506        - **sequence_length** (Tensor) - A tensor containing sequence lengths with the shape of :math:`(batch\_size, )`.
7507          The type must be int32. Each value in the tensor must be equal to or less than `max_time`.
7508
7509    Outputs:
7510        - **decoded_indices** (Tensor) - A tensor with shape of :math:`(total\_decoded\_outputs, 2)`.
7511          Data type is int64.
7512        - **decoded_values** (Tensor) - A tensor with shape of :math:`(total\_decoded\_outputs, )`,
7513          it stores the decoded classes. Data type is int64.
7514        - **decoded_shape** (Tensor) - A tensor with shape of :math:`(batch\_size, max\_decoded\_length)`.
7515          Data type is int64.
7516        - **log_probability** (Tensor) - A tensor with shape of :math:`(batch\_size, 1)`,
7517          containing sequence log-probability, has the same type as `inputs`.
7518
7519    Raises:
7520        TypeError: If `merge_repeated` is not a bool.
7521        ValueError: If length of shape of `inputs` is not equal to 3.
7522        ValueError: If length of shape of `sequence_length` is not equal to 1.
7523
7524    Supported Platforms:
7525        ``Ascend``
7526
7527    Examples:
7528        >>> inputs = Tensor(np.array([[[0.6, 0.4, 0.2], [0.8, 0.6, 0.3]],
7529        ...                           [[0.0, 0.6, 0.0], [0.5, 0.4, 0.5]]]), mindspore.float32)
7530        >>> sequence_length = Tensor(np.array([2, 2]), mindspore.int32)
7531        >>> ctc_greedyDecoder = ops.CTCGreedyDecoder()
7532        >>> decoded_indices, decoded_values, decoded_shape, log_probability = ctc_greedyDecoder(inputs, sequence_length)
7533        >>> print(decoded_indices)
7534        [[0 0]
7535         [0 1]
7536         [1 0]]
7537        >>> print(decoded_values)
7538        [0 1 0]
7539        >>> print(decoded_shape)
7540        [2 2]
7541        >>> print(log_probability)
7542        [[-1.2]
7543         [-1.3]]
7544    """
7545
7546    @prim_attr_register
7547    def __init__(self, merge_repeated=True):
7548        """Initialize CTCGreedyDecoder."""
7549        self.merge_repeated = validator.check_value_type("merge_repeated", merge_repeated, [bool], self.name)
7550
7551    def check_shape(self, inputs_shape, sequence_length_shape):
7552        validator.check_int(len(inputs_shape), 3, Rel.EQ, "inputs rank", self.name)
7553        validator.check_int(len(sequence_length_shape), 1, Rel.EQ, "sequence_length rank", self.name)
7554        validator.check('inputs batch_size', inputs_shape[1], 'sequence_length batch_size',
7555                        sequence_length_shape[0], Rel.EQ, self.name)
7556        total_decoded_outputs = -1
7557        decoded_indices_shape = [total_decoded_outputs, 2]
7558        decoded_values = [total_decoded_outputs]
7559        decoded_shape = [2]
7560        log_probability_shape = [inputs_shape[1], 1]
7561        return decoded_indices_shape, decoded_values, decoded_shape, log_probability_shape
7562
7563    def check_dtype(self, inputs_dtype, sequence_length_dtype):
7564        validator.check_tensor_dtype_valid("inputs_dtype", inputs_dtype, [mstype.float32, mstype.double], self.name)
7565        validator.check_tensor_dtype_valid("sequence_length_dtype", sequence_length_dtype, [mstype.int32], self.name)
7566        decoded_type = mstype.tensor_type(mstype.int64)
7567        return decoded_type, decoded_type, decoded_type, inputs_dtype
7568
7569
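# Illustration only (not used by CTCGreedyDecoder above): a minimal NumPy sketch of greedy CTC
# decoding as documented -- take the argmax class at each time step, optionally merge consecutive
# repeats, then drop the blank label (num_classes - 1). It returns a plain list of decoded
# sequences per batch entry rather than the sparse-tensor layout the primitive outputs, and it
# reproduces decoded classes [0, 1] and [0] for the docstring example above. The helper name is ours.
def _ctc_greedy_decode_sketch(inputs, sequence_length, merge_repeated=True):
    """Illustrative-only greedy CTC decode on a (max_time, batch_size, num_classes) array."""
    blank = inputs.shape[2] - 1
    decoded = []
    for b in range(inputs.shape[1]):
        best = np.argmax(inputs[:sequence_length[b], b, :], axis=-1)
        seq, prev = [], None
        for cls in best:
            if not (merge_repeated and cls == prev) and cls != blank:
                seq.append(int(cls))
            prev = cls
        decoded.append(seq)
    return decoded

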
7570class BasicLSTMCell(PrimitiveWithInfer):
7571    """
7572    It's similar to operator :class:`DynamicRNN`. BasicLSTMCell will be deprecated in the future.
7573    Please use DynamicRNN instead.
7574
7575    Supported Platforms:
7576        Deprecated
7577    """
7578
7579    @prim_attr_register
7580    def __init__(self, keep_prob=1.0, forget_bias=1.0, state_is_tuple=True, activation='tanh'):
7581        """Initialize BasicLSTMCell."""
7582        self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
7583        self.keep_prob = validator.check_float_range(keep_prob, 0.0, 1.0, Rel.INC_BOTH, "keep_prob", self.name)
7584        self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
7585        self.state_is_tuple = validator.check_value_type("state_is_tuple", state_is_tuple, [bool], self.name)
7586        self.activation = validator.check_string(activation, ['tanh'], "activation", self.name)
7587
7588    def infer_shape(self, x_shape, h_shape, c_shape, w_shape, b_shape):
7589        validator.check_int(len(x_shape), 2, Rel.EQ, "x rank", self.name)
7590        validator.check_int(len(h_shape), 2, Rel.EQ, "h rank", self.name)
7591        validator.check_int(len(c_shape), 2, Rel.EQ, "c rank", self.name)
7592        validator.check_int(len(w_shape), 2, Rel.EQ, "w rank", self.name)
7593        validator.check_int(len(b_shape), 1, Rel.EQ, "b rank", self.name)
7594        validator.check("x_shape[0]", x_shape[0], "h_shape[0]", h_shape[0], Rel.EQ, self.name)
7595        validator.check("c_shape[0]", c_shape[0], "h_shape[0]", h_shape[0], Rel.EQ, self.name)
7596        validator.check("c_shape[1]", c_shape[1], "h_shape[1]", h_shape[1], Rel.EQ, self.name)
7597        validator.check("w_shape[1]", w_shape[1], "4*h_shape[1]", 4 * h_shape[1], Rel.EQ, self.name)
7598        validator.check("w_shape[0]", w_shape[0], "x_shape[1]+h_shape[1]", x_shape[1] + h_shape[1], Rel.EQ, self.name)
7599        validator.check("b_shape[0]", b_shape[0], "4*h_shape[1]", 4 * h_shape[1], Rel.EQ, self.name)
7600        ct_shape = c_shape
7601        ht_shape = c_shape
7602        it_shape = c_shape
7603        jt_shape = c_shape
7604        ft_shape = c_shape
7605        ot_shape = c_shape
7606        tanhct_shape = c_shape
7607
7608        return ct_shape, ht_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape
7609
7610    def infer_dtype(self, x_dtype, h_dtype, c_dtype, w_dtype, b_dtype):
7611        tuple(map(partial(validator.check_tensor_dtype_valid,
7612                          valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
7613                  ("x_dtype", "h_dtype", "w_dtype"),
7614                  (x_dtype, h_dtype, w_dtype)))
7615        args = {"c_dtype": c_dtype, "b_dtype": b_dtype}
7616        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
7617        return c_dtype, mstype.float16, c_dtype, c_dtype, c_dtype, c_dtype, c_dtype
7618
7619
7620class DynamicRNN(PrimitiveWithInfer):
7621    r"""
7622    Applies a recurrent neural network to the input.
7623    Only long short-term memory (LSTM) currently supported.
7624
7625    .. math::
7626        \begin{array}{ll} \\
7627            i_{t+1} = \sigma(W_{ix} x_{t+1} + b_{ix} + W_{ih} h_{(t)} + b_{ih}) \\
7628            f_{t+1} = \sigma(W_{fx} x_{t+1} + b_{fx} + W_{fh} h_{(t)} + b_{fh}) \\
7629            \tilde{c}_{t+1} = \tanh(W_{cx} x_{t+1} + b_{cx} + W_{ch} h_{(t)} + b_{ch}) \\
7630            o_{t+1} = \sigma(W_{ox} x_{t+1} + b_{ox} + W_{oh} h_{(t)} + b_{oh}) \\
7631            c_{t+1} = f_{t+1} * c_{(t)} + i_t * \tilde{c}_{t+1} \\
7632            h_{t+1} = o_{t+1} * \tanh(c_{t+1}) \\
7633        \end{array}
7634
7635    where :math:`h_{t+1}` is the hidden state at time `t+1`, :math:`x_{t+1}` is the input
7636    at time `t+1`, :math:`h_{t}` is the hidden state of the layer
7637    at time `t` or the initial hidden state at time `0`,
7638    :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product. :math:`W, b`
7639    are learnable weights between the output and the input in the formula. For instance,
7640    :math:`W_{ix}, b_{ix}` are the weight and bias used to transform from input :math:`x` to :math:`i`.
7641
7642    Args:
7643        cell_type (str): A string identifying the cell type in the op. Default: 'LSTM'.
7644            Only 'LSTM' is currently supported.
7645        direction (str): A string identifying the direction in the op. Default: 'UNIDIRECTIONAL'.
7646            Only 'UNIDIRECTIONAL' is currently supported.
7647        cell_depth (int): An integer identifying the cell depth in the op. Default: 1.
7648        use_peephole (bool): A bool identifying whether to use peephole connections in the op. Default: False.
7649        keep_prob (float): A float identifying the keep prob in the op. Default: 1.0.
7650        cell_clip (float): A float identifying the cell clip in the op. Default: -1.0.
7651        num_proj (int): An integer identifying the num proj in the op. Default: 0.
7652        time_major (bool): A bool identifying the time major in the op. Default: True.
7653            Only `True` is currently supported.
7654        activation (str): A string identifying the type of activation function in the op. Default: 'tanh'.
7655            Only 'tanh' is currently supported.
7656        forget_bias (float): A float identifying the forget bias in the op. Default: 0.0.
7657        is_training (bool): A bool identifying whether the op is in training mode. Default: True.
7658
7659    Inputs:
7660        - **x** (Tensor) - Current words. Tensor of shape :math:`(num\_step, batch\_size, input\_size)`.
7661          The data type must be float16.
7662        - **w** (Tensor) - Weight. Tensor of shape :math:`(input\_size + hidden\_size, 4 \times hidden\_size)`.
7663          The data type must be float16.
7664        - **b** (Tensor) - Bias. Tensor of shape :math:`(4 \times hidden\_size)`.
7665          The data type must be float16 or float32.
7666        - **seq_length** (Tensor) - The length of each batch. Tensor of shape :math:`(batch\_size, )`.
7667          Only `None` is currently supported.
7668        - **init_h** (Tensor) - Hidden state of initial time. Tensor of shape :math:`(1, batch\_size, hidden\_size)`.
7669          The data type must be float16.
7670        - **init_c** (Tensor) - Cell state of initial time. Tensor of shape :math:`(1, batch\_size, hidden\_size)`.
7671          The data type must be float16.
7672
7673    Outputs:
7674        - **y** (Tensor) - A Tensor of shape :math:`(num\_step, batch\_size, hidden\_size)`.
7675          Has the same data type as input `b`.
7676        - **output_h** (Tensor) - A Tensor of shape :math:`(num\_step, batch\_size, hidden\_size)`.
7677          The data type is float16.
7678        - **output_c** (Tensor) - A Tensor of shape :math:`(num\_step, batch\_size, hidden\_size)`.
7679          Has the same data type as input `b`.
7680        - **i** (Tensor) - A Tensor of shape :math:`(num\_step, batch\_size, hidden\_size)`.
7681          Has the same data type as input `b`.
7682        - **j** (Tensor) - A Tensor of shape :math:`(num\_step, batch\_size, hidden\_size)`.
7683          Has the same data type as input `b`.
7684        - **f** (Tensor) - A Tensor of shape :math:`(num\_step, batch\_size, hidden\_size)`.
7685          Has the same data type as input `b`.
7686        - **o** (Tensor) - A Tensor of shape :math:`(num\_step, batch\_size, hidden\_size)`.
7687          Has the same data type as input `b`.
7688        - **tanhct** (Tensor) - A Tensor of shape :math:`(num\_step, batch\_size, hidden\_size)`.
7689          Has the same data type as input `b`.
7690
7691    Raises:
7692        TypeError: If `cell_type`, `direction` or `activation` is not a str.
7693        TypeError: If `cell_depth` or `num_proj` is not an int.
7694        TypeError: If `keep_prob`, `cell_clip` or `forget_bias` is not a float.
7695        TypeError: If `use_peephole`, `time_major` or `is_training` is not a bool.
7696        TypeError: If `x`, `w`, `b`, `seq_length`, `init_h` or `init_c` is not a Tensor.
7697        TypeError: If dtype of `x`, `w`, `init_h` or `init_c` is not float16.
7698        TypeError: If dtype of `b` is neither float16 nor float32.
7699
7700    Supported Platforms:
7701        ``Ascend``
7702
7703    Examples:
7704        >>> x = Tensor(np.random.rand(2, 16, 64).astype(np.float16))
7705        >>> w = Tensor(np.random.rand(96, 128).astype(np.float16))
7706        >>> b = Tensor(np.random.rand(128).astype(np.float16))
7707        >>> init_h = Tensor(np.random.rand(1, 16, 32).astype(np.float16))
7708        >>> init_c = Tensor(np.random.rand(1, 16, 32).astype(np.float16))
7709        >>> dynamic_rnn = ops.DynamicRNN()
7710        >>> output = dynamic_rnn(x, w, b, None, init_h, init_c)
7711        >>> print(output[0].shape)
7712        (2, 16, 32)
7713    """
7714
7715    @prim_attr_register
7716    def __init__(self,
7717                 cell_type='LSTM',
7718                 direction='UNIDIRECTIONAL',
7719                 cell_depth=1,
7720                 use_peephole=False,
7721                 keep_prob=1.0,
7722                 cell_clip=-1.0,
7723                 num_proj=0,
7724                 time_major=True,
7725                 activation='tanh',
7726                 forget_bias=0.0,
7727                 is_training=True):
7728        """Initialize DynamicRNN."""
7729        self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
7730        self.cell_depth = validator.check_value_type("cell_depth", cell_depth, [int], self.name)
7731        self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
7732        self.cell_clip = validator.check_value_type("cell_clip", cell_clip, [float], self.name)
7733        self.num_proj = validator.check_non_negative_int(num_proj, "num_proj", self.name)
7735        self.use_peephole = validator.check_value_type("use_peephole", use_peephole, [bool], self.name)
7736        self.time_major = validator.check_value_type("time_major", time_major, [bool], self.name)
7737        self.is_training = validator.check_value_type("is_training", is_training, [bool], self.name)
7738        validator.check_value_type("cell_type", cell_type, [str], self.name)
7739        self.cell_type = validator.check_string(cell_type, ['LSTM'], "cell_type", self.name)
7740        validator.check_value_type("direction", direction, [str], self.name)
7741        self.direction = validator.check_string(direction, ['UNIDIRECTIONAL'], "direction", self.name)
7742        validator.check_value_type("activation", activation, [str], self.name)
7743        self.activation = validator.check_string(activation, ['tanh'], "activation", self.name)
7744
7745    def infer_shape(self, x_shape, w_shape, b_shape, seq_shape, h_shape, c_shape):
7746        validator.check_int(len(x_shape), 3, Rel.EQ, "x_shape", self.name)
7747        validator.check_int(len(w_shape), 2, Rel.EQ, "w rank", self.name)
7748        validator.check_int(len(b_shape), 1, Rel.EQ, "b rank", self.name)
7749        validator.check_int(len(h_shape), 3, Rel.EQ, "h_shape", self.name)
7750        validator.check_int(len(c_shape), 3, Rel.EQ, "c_shape", self.name)
7751        if seq_shape is not None:
7752            raise ValueError(f"For '{self.name}', the input 'seq_length' should be None, but got {seq_shape}.")
7753
7754        num_step, batch_size, input_size = x_shape
7755        hidden_size = w_shape[-1] // 4
7756
7757        validator.check("b_shape[-1]", b_shape[-1], "w_shape[-1]", w_shape[-1], Rel.EQ, self.name)
7758        if w_shape[-1] % 4 != 0:
7759            raise ValueError(f"For '{self.name}', the last dimension of 'w' should be a multiple of 4, "
7760                             f"but got {w_shape[-1]}.")
7761        validator.check("w_shape[0]", w_shape[0], "input_size + hidden_size",
7762                        input_size + hidden_size, Rel.EQ, self.name)
7763        validator.check("b_shape[0]", b_shape[0], "w_shape[1]", w_shape[1], Rel.EQ, self.name)
7764        validator.check_int(h_shape[0], 1, Rel.EQ, "h_shape[0]", self.name)
7765        validator.check("h_shape[1]", h_shape[1], "batch_size", batch_size, Rel.EQ, self.name)
7766        validator.check("h_shape[2]", h_shape[2], "hidden_size", hidden_size, Rel.EQ, self.name)
7767        validator.check("c_shape", c_shape, "h_shape", h_shape, Rel.EQ, self.name)
7768        self.placeholder_index = [3]
7769        self.add_prim_attr("placeholder_index", self.placeholder_index)
7770        self.add_prim_attr("input_size", input_size)
7771        self.add_prim_attr("hidden_size", hidden_size)
7772        y_shape = (num_step, batch_size, hidden_size)
7773        return y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape
7774
7775    def infer_dtype(self, x_dtype, w_dtype, b_dtype, seq_dtype, h_dtype, c_dtype):
7776        tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=[mstype.float16], prim_name=self.name),
7777                  ("x", "w", "h", "c"),
7778                  (x_dtype, w_dtype, h_dtype, c_dtype)))
7779        validator.check_tensor_dtype_valid("b", b_dtype, (mstype.float16, mstype.float32), self.name)
7780        return b_dtype, x_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype
7781
7782
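# The block below is an illustrative, framework-free sketch of a single LSTM step that
# follows the equations in the DynamicRNN docstring above. It is NOT the Ascend kernel:
# the fused weight layout and the (i, j, f, o) gate order are assumptions made only to
# show how the documented shapes fit together.
def _np_lstm_step_sketch(x_t, h_t, c_t, w, b):
    """One LSTM time step in numpy; shapes follow the DynamicRNN docstring."""
    def _sigmoid(v):
        return 1.0 / (1.0 + np.exp(-v))
    # x_t: (batch_size, input_size), h_t/c_t: (batch_size, hidden_size)
    # w: (input_size + hidden_size, 4 * hidden_size), b: (4 * hidden_size,)
    gates = np.concatenate([x_t, h_t], axis=1) @ w + b
    i, j, f, o = np.split(gates, 4, axis=1)  # assumed gate order, for illustration only
    c_next = _sigmoid(f) * c_t + _sigmoid(i) * np.tanh(j)
    h_next = _sigmoid(o) * np.tanh(c_next)
    return h_next, c_next

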
7783class DynamicGRUV2(PrimitiveWithInfer):
7784    r"""
7785    Applies a single-layer gated recurrent unit (GRU) to an input sequence.
7786
7787    .. math::
7788
7789        \begin{array}{ll}
7790            r_{t+1} = \sigma(W_{ir} x_{t+1} + b_{ir} + W_{hr} h_{(t)} + b_{hr}) \\
7791            z_{t+1} = \sigma(W_{iz} x_{t+1} + b_{iz} + W_{hz} h_{(t)} + b_{hz}) \\
7792            n_{t+1} = \tanh(W_{in} x_{t+1} + b_{in} + r_{t+1} * (W_{hn} h_{(t)}+ b_{hn})) \\
7793            h_{t+1} = (1 - z_{t+1}) * n_{t+1} + z_{t+1} * h_{(t)}
7794        \end{array}
7795
7796    where :math:`h_{t+1}` is the hidden state at time `t+1`, :math:`x_{t+1}` is the input
7797    at time `t+1`, :math:`h_{t}` is the hidden state of the layer
7798    at time `t` or the initial hidden state at time `0`, and :math:`r_{t+1}`,
7799    :math:`z_{t+1}`, :math:`n_{t+1}` are the reset, update, and new gates, respectively.
7800    :math:`W`, :math:`b` are the weight parameter and the deviation parameter respectively.
7801    :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product.
7802
7803    Args:
7804        direction (str): A string identifying the direction in the op. Default: 'UNIDIRECTIONAL'.
7805            Only 'UNIDIRECTIONAL' is currently supported.
7806        cell_depth (int): An integer identifying the cell depth in the op. Default: 1.
7807        keep_prob (float): A float identifying the keep prob in the op. Default: 1.0.
7808        cell_clip (float): A float identifying the cell clip in the op. Default: -1.0.
7809        num_proj (int): An integer identifying the num proj in the op. Default: 0.
7810        time_major (bool): A bool identifying the time major in the op. Default: True.
7811        activation (str) : A string identifying the type of activation function in the op. Default: 'tanh'.
7812            Only 'tanh' is currently supported.
7813        gate_order (str): A string identifying the gate order in weight and bias. Default: 'rzh'.
7814            'zrh' is another option.
7815        reset_after (bool): A bool identifying whether to apply reset gate after matrix multiplication. Default: True.
7816        is_training (bool): A bool identifying whether the op is in training mode. Default: True.
7817
7818    Inputs:
7819        - **x** (Tensor) - Current words.
7820          Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{input_size})`.
7821          The data type must be float16.
7822        - **weight_input** (Tensor) - Input-hidden weight.
7823          Tensor of shape :math:`(\text{input_size}, 3 \times \text{hidden_size})`.
7824          The data type must be float16.
7825        - **weight_hidden** (Tensor) - Hidden-hidden weight.
7826          Tensor of shape :math:`(\text{hidden_size}, 3 \times \text{hidden_size})`.
7827          The data type must be float16.
7828        - **init_h** (Tensor) - Hidden state of initial time.
7829          Tensor of shape :math:`(\text{batch_size}, \text{hidden_size})`.
7830          The data type must be float16 or float32.
7831        - **bias_input** (Tensor) - Input-hidden bias. Tensor of shape :math:`(3 \times \text{hidden_size})`, or None.
7832          Has the same data type with input `init_h`.
7833        - **bias_hidden** (Tensor) - Hidden-hidden bias. Tensor of shape :math:`(3 \times \text{hidden_size})`,
7834          or None. Has the same data type with input `init_h`.
7835        - **seq_length** (Tensor) - The length of each batch. Tensor of shape :math:`(\text{batch_size})`.
7836          Only `None` is currently supported.
7837
7838    Outputs:
7839        - **y** (Tensor) - A Tensor of shape:
7840
7841          - y_shape = :math:`(num\_step, batch\_size, min(hidden\_size, num\_proj))` if `num_proj` > 0,
7842          - y_shape = :math:`(num\_step, batch\_size, hidden\_size)` if `num_proj` = 0.
7843
7844          Has the same data type as `bias_type`.
7845        - **output_h** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`.
7846          Has the same data type as `bias_type`.
7847        - **update** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`.
7848          Has the same data type as `bias_type`.
7849        - **reset** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`.
7850          Has the same data type as `bias_type`.
7851        - **new** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`.
7852          Has the same data type as `bias_type`.
7853        - **hidden_new** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`.
7854          Has the same data type as `bias_type`.
7855
7856        A note about the bias_type:
7857
7858        - If `bias_input` and `bias_hidden` are both `None`, `bias_type` is the data type of `init_h`.
7859        - If `bias_input` is not `None`, `bias_type` is the data type of `bias_input`.
7860        - If `bias_input` is `None` and `bias_hidden` is not `None`, `bias_type` is the data type of `bias_hidden`.
7861
7862    Raises:
7863        TypeError: If `direction`, `activation` or `gate_order` is not a str.
7864        TypeError: If `cell_depth` or `num_proj` is not an int.
7865        TypeError: If `keep_prob` or `cell_clip` is not a float.
7866        TypeError: If `time_major`, `reset_after` or `is_training` is not a bool.
7867        TypeError: If `x`, `weight_input`, `weight_hidden`, `bias_input`, `bias_hidden`, `seq_length` or `init_h` is not
7868                   a Tensor.
7869        TypeError: If dtype of `x`, `weight_input` or `weight_hidden` is not float16.
7870        TypeError: If dtype of `init_h` is neither float16 nor float32.
7871
7872    Supported Platforms:
7873        ``Ascend``
7874
7875    Examples:
7876        >>> x = Tensor(np.random.rand(2, 8, 64).astype(np.float16))
7877        >>> weight_i = Tensor(np.random.rand(64, 48).astype(np.float16))
7878        >>> weight_h = Tensor(np.random.rand(16, 48).astype(np.float16))
7879        >>> bias_i = Tensor(np.random.rand(48).astype(np.float16))
7880        >>> bias_h = Tensor(np.random.rand(48).astype(np.float16))
7881        >>> init_h = Tensor(np.random.rand(8, 16).astype(np.float16))
7882        >>> dynamic_gru_v2 = ops.DynamicGRUV2()
7883        >>> output = dynamic_gru_v2(x, weight_i, weight_h, bias_i, bias_h, None, init_h)
7884        >>> print(output[0].shape)
7885        (2, 8, 16)
7886    """
7887
7888    @prim_attr_register
7889    def __init__(self,
7890                 direction='UNIDIRECTIONAL',
7891                 cell_depth=1,
7892                 keep_prob=1.0,
7893                 cell_clip=-1.0,
7894                 num_proj=0,
7895                 time_major=True,
7896                 activation="tanh",
7897                 gate_order="rzh",
7898                 reset_after=True,
7899                 is_training=True):
7900        """Initialize DynamicGRUV2."""
7901        self.cell_depth = validator.check_value_type("cell_depth", cell_depth, [int], self.name)
7902        self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
7903        self.cell_clip = validator.check_value_type("cell_clip", cell_clip, [float], self.name)
7904        self.num_proj = validator.check_non_negative_int(num_proj, "num_proj", self.name)
7905        self.time_major = validator.check_value_type("time_major", time_major, [bool], self.name)
7906        self.is_training = validator.check_value_type("is_training", is_training, [bool], self.name)
7907        self.direction = validator.check_string(direction, ['UNIDIRECTIONAL'], "direction", self.name)
7908        self.activation = validator.check_string(activation, ['tanh'], "activation", self.name)
7909        self.gate_order = validator.check_string(gate_order, ['zrh', 'rzh'], "gate_order", self.name)
7910        self.reset_after = validator.check_value_type("reset_after", reset_after, [bool], self.name)
7911
7912    def infer_shape(self, x_shape, winput_shape, whidden_shape, binput_shape, bhidden_shape, seq_shape, h_shape):
7913        validator.check_int(len(x_shape), 3, Rel.EQ, "x shape", self.name)
7914        validator.check_int(len(winput_shape), 2, Rel.EQ, "weight input shape rank", self.name)
7915        validator.check_int(len(whidden_shape), 2, Rel.EQ, "weight hidden shape rank", self.name)
7916
7917        num_step, batch_size, input_size = x_shape
7918        hidden_size = winput_shape[-1] // 3
7919        if winput_shape[-1] % 3 != 0:
7920            raise ValueError(f"For '{self.name}', the last dimension of 'weight_input' should be a multiple of 3, "
7921                             f"but got {winput_shape[-1]}.")
7922
7923        self.placeholder_index = [3, 4, 5]
7924        if binput_shape is not None:
7925            validator.check_int(len(binput_shape), 1, Rel.EQ, "bias input shape rank", self.name)
7926            validator.check("bias_input_shape", binput_shape, "3 * hidden_shape", [3 * hidden_size], Rel.EQ, self.name)
7927            self.placeholder_index.remove(3)
7928        if bhidden_shape is not None:
7929            validator.check_int(len(bhidden_shape), 1, Rel.EQ, "bias hidden shape rank", self.name)
7930            validator.check("bias_hidden_shape", bhidden_shape,
7931                            "3 * hidden_shape", [3 * hidden_size], Rel.EQ, self.name)
7932            self.placeholder_index.remove(4)
7933        if seq_shape is not None:
7934            raise ValueError(f"For '{self.name}', the input 'seq_length' should be None, "
7935                             f"but got {seq_shape}.")
7936
7937        validator.check_int(len(h_shape), 2, Rel.EQ, "init_h shape rank", self.name)
7938        validator.check("init_h_shape[0]", h_shape[0], "batch_size", batch_size, Rel.EQ, self.name)
7939        validator.check("init_h_shape[1]", h_shape[1], "hidden_size", hidden_size, Rel.EQ, self.name)
7940        validator.check("weight_input_shape[-1]", winput_shape[-1], "weight_hidden_shape[-1]",
7941                        whidden_shape[-1], Rel.EQ, self.name)
7942        validator.check("weight_input_shape[0]", winput_shape[0], "input_size", input_size, Rel.EQ, self.name)
7943        validator.check("weight_hidden_shape[0]", whidden_shape[0], "hidden_size", hidden_size, Rel.EQ, self.name)
7944        if self.num_proj > 0:
7945            y_shape = (num_step, batch_size, min(hidden_size, self.num_proj))
7946        else:
7947            y_shape = (num_step, batch_size, hidden_size)
7948        out_shape = (num_step, batch_size, hidden_size)
7949        self.add_prim_attr("placeholder_index", self.placeholder_index)
7950        return y_shape, out_shape, out_shape, out_shape, out_shape, out_shape
7951
7952    def infer_dtype(self, x_dtype, winput_dtype, whidden_dtype, binput_dtype, bhidden_dtype, seq_dtype, h_dtype):
7953        validator.check_tensor_dtype_valid("x dtype", x_dtype, [mstype.float16], self.name)
7954        validator.check_tensor_dtype_valid("weight input dtype", winput_dtype, [mstype.float16], self.name)
7955        validator.check_tensor_dtype_valid("weight hidden dtype", whidden_dtype, [mstype.float16], self.name)
7956        valid_dtypes = [mstype.float16, mstype.float32]
7957        validator.check_tensor_dtype_valid("init_h dtype", h_dtype, valid_dtypes, self.name)
7958        b_dtype = h_dtype
7959        if binput_dtype is not None:
7960            args = {'init_h': h_dtype, 'bias_input': binput_dtype}
7961            validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
7962            b_dtype = binput_dtype
7963        if bhidden_dtype is not None:
7964            args = {'init_h': h_dtype, 'bias_hidden': bhidden_dtype}
7965            validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
7966            b_dtype = bhidden_dtype
7967
7968        return b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype
7969
7970
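# Illustrative numpy sketch of a single GRU step following the DynamicGRUV2 equations
# above, assuming gate_order='rzh' and reset_after=True. It is not the Ascend kernel;
# it only shows how the documented weight, bias and hidden-state shapes interact.
def _np_gru_step_sketch(x_t, h_t, w_input, w_hidden, b_input, b_hidden):
    """One GRU time step in numpy; shapes follow the DynamicGRUV2 docstring."""
    def _sigmoid(v):
        return 1.0 / (1.0 + np.exp(-v))
    # x_t: (batch_size, input_size), h_t: (batch_size, hidden_size)
    # w_input: (input_size, 3 * hidden_size), w_hidden: (hidden_size, 3 * hidden_size)
    gi = x_t @ w_input + b_input    # input-hidden projections
    gh = h_t @ w_hidden + b_hidden  # hidden-hidden projections
    i_r, i_z, i_n = np.split(gi, 3, axis=1)  # 'rzh' split order assumed
    h_r, h_z, h_n = np.split(gh, 3, axis=1)
    r = _sigmoid(i_r + h_r)
    z = _sigmoid(i_z + h_z)
    n = np.tanh(i_n + r * h_n)      # reset gate applied after the matmul (reset_after=True)
    return (1.0 - z) * n + z * h_t

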
7971class InTopK(PrimitiveWithInfer):
7972    r"""
7973    Determines whether the targets are in the top `k` predictions.
7974
7975    Args:
7976        k (int): Specifies the number of top elements to be used for computing precision.
7977
7978    Inputs:
7979        - **x1** (Tensor) - A 2D Tensor that defines the predictions of a batch of samples with float16 or float32
7980          data type.
7981        - **x2** (Tensor) - A 1D Tensor that defines the labels of a batch of samples with int32 data type. The size of `x2`
7982          must be equal to the first dimension of `x1`. The values of `x2` cannot be negative and
7983          must be less than the second dimension of `x1`, since they index into it.
7984
7985    Outputs:
7986        A 1-D Tensor of type bool with the same shape as `x2`. For each sample `i` in `x2`,
7987        the value is True if the label `x2[i]` is among the top `k` predictions of `x1[i]`, otherwise it is False.
7988
7989    Raises:
7990        TypeError: If `k` is not an int.
7991        TypeError: If `x1` or `x2` is not a Tensor.
7992        TypeError: If dtype of `x1` is neither float16 nor float32.
7993
7994    Supported Platforms:
7995        ``Ascend`` ``GPU``
7996
7997    Examples:
7998        >>> x1 = Tensor(np.array([[1, 8, 5, 2, 7], [4, 9, 1, 3, 5]]), mindspore.float32)
7999        >>> x2 = Tensor(np.array([1, 3]), mindspore.int32)
8000        >>> in_top_k = ops.InTopK(3)
8001        >>> output = in_top_k(x1, x2)
8002        >>> print(output)
8003        [ True  False]
8004    """
8005
8006    @prim_attr_register
8007    def __init__(self, k):
8008        """Initialize InTopK"""
8009        self.init_prim_io_names(inputs=['x1', 'x2', 'k'], outputs=['y'])
8010        validator.check_value_type("k", k, [int], self.name)
8011
8012    def infer_dtype(self, x1_dtype, x2_dtype):
8013        validator.check_tensor_dtype_valid("x1", x1_dtype, (mstype.float16, mstype.float32,), self.name)
8014        validator.check_tensor_dtype_valid("x2", x2_dtype, (mstype.int32,), self.name)
8015
8016        return mstype.tensor_type(mstype.bool_)
8017
8018    def infer_shape(self, x1_shape, x2_shape):
8019        validator.check("x1 shape", len(x1_shape), "", 2, Rel.EQ, self.name)
8020        validator.check("x2 shape", len(x2_shape), "", 1, Rel.EQ, self.name)
8021        validator.check("size of x2", x2_shape[0], "x1's first dimension", x1_shape[0], Rel.EQ, self.name)
8022        return x2_shape
8023
8024
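# Minimal numpy sketch of the InTopK semantics documented above (illustrative only,
# not the device kernel): output[i] is True when x2[i] is among the k largest entries
# of row x1[i]. For the docstring example it yields [True, False].
def _np_in_top_k_sketch(x1, x2, k):
    top_k_idx = np.argsort(-x1, axis=1)[:, :k]  # indices of the k largest values per row
    return np.array([x2[i] in top_k_idx[i] for i in range(x1.shape[0])])

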
8025class LRN(PrimitiveWithInfer):
8026    r"""
8027    Local Response Normalization.
8028
8029    .. math::
8030
8031        b_{c} = a_{c}\left(k + \frac{\alpha}{n}
8032        \sum_{c'=\max(0, c-n/2)}^{\min(N-1,c+n/2)}a_{c'}^2\right)^{-\beta}
8033
8034    where :math:`a_{c}` is the value of the pixel at channel :math:`c` in the feature map,
8035    :math:`n/2` is the `depth_radius`, :math:`k` is the `bias`,
8036    :math:`\alpha` is the `alpha`, and :math:`\beta` is the `beta`.
8037
8038    Args:
8039        depth_radius (int): Half-width of the 1-D normalization window with the shape of 0-D. Default: 5.
8040        bias (float): An offset (usually positive to avoid dividing by 0). Default: 1.0.
8041        alpha (float): A scale factor, usually positive. Default: 1.0.
8042        beta (float): An exponent. Default: 0.5.
8043        norm_region (str): Specifies normalization region. Options: "ACROSS_CHANNELS". Default: "ACROSS_CHANNELS".
8044
8045    Inputs:
8046        - **x** (Tensor) - A 4D Tensor with float16 or float32 data type.
8047
8048    Outputs:
8049        Tensor, with the same shape and data type as `x`.
8050
8051    Raises:
8052        TypeError: If `depth_radius` is not an int.
8053        TypeError: If `bias`, `alpha` or `beta` is not a float.
8054        TypeError: If `norm_region` is not a str.
8055        TypeError: If `x` is not a Tensor.
8056
8057    Supported Platforms:
8058        ``Ascend`` ``GPU``
8059
8060    Examples:
8061        >>> x = Tensor(np.array([[[[0.1], [0.2]],
8062        ...                       [[0.3], [0.4]]]]), mindspore.float32)
8063        >>> lrn = ops.LRN()
8064        >>> output = lrn(x)
8065        >>> print(output)
8066        [[[[0.09534626]
8067           [0.1825742 ]]
8068          [[0.2860388 ]
8069           [0.3651484 ]]]]
8070    """
8071
8072    @prim_attr_register
8073    def __init__(self, depth_radius=5, bias=1.0, alpha=1.0, beta=0.5, norm_region="ACROSS_CHANNELS"):
8074        """Initialize LRN"""
8075        self.init_prim_io_names(inputs=['x'], outputs=['y'])
8076        validator.check_value_type("depth_radius", depth_radius, [int], self.name)
8077        validator.check_value_type("bias", bias, [float], self.name)
8078        validator.check_value_type("alpha", alpha, [float], self.name)
8079        validator.check_value_type("beta", beta, [float], self.name)
8080        validator.check_value_type("norm_region", norm_region, [str], self.name)
8081        validator.check_string(norm_region, ['ACROSS_CHANNELS'], 'norm_region', self.name)
8082        validator.check_non_negative_int(depth_radius, "depth_radius", self.name)
8083
8084    def infer_dtype(self, x_dtype):
8085        validator.check_tensor_dtype_valid("x", x_dtype, (mstype.float16, mstype.float32,), self.name)
8086        return x_dtype
8087
8088    def infer_shape(self, x_shape):
8089        validator.check_int(len(x_shape), 4, Rel.EQ, "x_shape", self.name)
8090        return x_shape
8091
8092
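# Illustrative numpy sketch of across-channel LRN for an NCHW input (not the device
# kernel). Each value is scaled by (bias + alpha * sum of squares over the neighbouring
# channels) ** (-beta); with the default attributes this reproduces the example output
# in the LRN docstring above.
def _np_lrn_sketch(x, depth_radius=5, bias=1.0, alpha=1.0, beta=0.5):
    channels = x.shape[1]
    out = np.empty_like(x)
    for c in range(channels):
        lo, hi = max(0, c - depth_radius), min(channels - 1, c + depth_radius)
        sq_sum = np.sum(x[:, lo:hi + 1, :, :] ** 2, axis=1)
        out[:, c, :, :] = x[:, c, :, :] * (bias + alpha * sq_sum) ** (-beta)
    return out

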
8093class AvgPool3D(Primitive):
8094    r"""
8095    3D Average pooling operation.
8096
8097    Applies a 3D average pooling over an input Tensor which can be regarded as a composition of 3D input planes.
8098    Typically the input is of shape :math:`(N, C, D_{in}, H_{in}, W_{in})`, AvgPool3D outputs
8099    regional average in the :math:`(D_{in}, H_{in}, W_{in})`-dimension. Given kernel size
8100    :math:`ks = (d_{ker}, h_{ker}, w_{ker})` and stride :math:`s = (s_0, s_1, s_2)`, the operation is as follows.
8101
8102    .. warning::
8103        "kernel_size" is in the range [1, 255]. "strides" is in the range [1, 63].
8104
8105    .. math::
8106        \text{output}(N_i, C_j, d, h, w) =
8107        \frac{1}{d_{ker} * h_{ker} * w_{ker}} \sum_{l=0}^{d_{ker}-1} \sum_{m=0}^{h_{ker}-1} \sum_{n=0}^{w_{ker}-1}
8108        \text{input}(N_i, C_j, s_0 \times d + l, s_1 \times h + m, s_2 \times w + n)
8109
8110    Args:
8111        kernel_size (Union[int, tuple[int]]): The size of the kernel used to take the average value. It can be
8112            an int number that represents the depth, height and width of the kernel, or a tuple
8113            of three int numbers that represent depth, height and width respectively. Default: 1.
8114        strides (Union[int, tuple[int]]): The distance of kernel moving. It can be an int number that represents
8115            the depth, height and width of movement, or a tuple of three int numbers that
8116            represent depth, height and width of movement respectively. Default: 1.
8117        pad_mode (str): The optional value for pad mode, is "SAME", "VALID", "PAD", not case sensitive.
8118            Default: "VALID".
8119
8120            - same: Adopts the way of completion. The depth, height and width of the output will be the same as
8121              the input. The total number of padding will be calculated in depth, horizontal and vertical
8122              directions and evenly distributed to head and tail, top and bottom, left and right if possible.
8123              Otherwise, the last extra padding will be done from the tail, bottom and the right side.
8124              If this mode is set, `pad` must be 0.
8125
8126            - valid: Adopts the way of discarding. The possible largest depth, height and width of output
8127              will be returned without padding. Extra pixels will be discarded. If this mode is set, `pad`
8128              must be 0.
8129
8130            - pad: Implicit paddings on both sides of the input in depth, height, width. The number of `pad` will
8131              be padded to the input Tensor borders. `pad` must be greater than or equal to 0.
8132        pad (Union(int, tuple[int])): The pad value to be filled. Default: 0. If `pad` is an integer, the paddings of
8133                    head, tail, top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of six
8134                    integers, the padding of head, tail, top, bottom, left and right equal to pad[0], pad[1], pad[2],
8135                    pad[3], pad[4] and pad[5] correspondingly.
8136        ceil_mode (bool): If True, ceil instead of floor to compute the output shape. Default: False.
8137        count_include_pad (bool): If True, averaging calculation will include the zero-padding. Default: True.
8138        divisor_override (int): If specified, it will be used as divisor in the averaging calculation,
8139            otherwise kernel_size will be used. Default: 0.
8140        data_format (str) : The optional value for data format. Currently only support 'NCDHW'. Default: 'NCDHW'.
8141
8142    Inputs:
8143        - **x** (Tensor) - Tensor of shape :math:`(N, C, D_{in}, H_{in}, W_{in})`.
8144          Currently support float16 and float32 data type.
8145
8146    Outputs:
8147        Tensor, with shape :math:`(N, C, D_{out}, H_{out}, W_{out})`. Has the same data type with `x`.
8148
8149    Raises:
8150        TypeError: If `kernel_size`, `strides` or `pad` is neither an int nor a tuple.
8151        TypeError: If `ceil_mode` or `count_include_pad` is not a bool.
8152        TypeError: If `pad_mode` or `data_format` is not a string.
8153        TypeError: If `divisor_override` is not an int.
8154        ValueError: If numbers in `kernel_size` or `strides` are not positive.
8155        ValueError: If `kernel_size` or `strides` is a tuple whose length is not equal to 3.
8156        ValueError: If `pad_mode` is not one of 'same', 'valid' or 'pad'.
8157        ValueError: If `pad` is a tuple whose length is not equal to 6.
8158        ValueError: If element of `pad` is less than 0.
8159        ValueError: If `pad_mode` is not equal to 'pad' and `pad` is not equal to 0 or (0, 0, 0, 0, 0, 0).
8160        ValueError: If `data_format` is not 'NCDHW'.
8161
8162    Supported Platforms:
8163        ``Ascend``
8164
8165    Examples:
8166        >>> x = Tensor(np.arange(1 * 2 * 2 * 2 * 3).reshape((1, 2, 2, 2, 3)), mindspore.float16)
8167        >>> avg_pool3d = ops.AvgPool3D(kernel_size=2, strides=1, pad_mode="valid")
8168        >>> output = avg_pool3d(x)
8169        >>> print(output)
8170        [[[[[ 5.  6.]]]
8171          [[[17. 18.]]]]]
8172    """
8173
8174    @prim_attr_register
8175    def __init__(self, kernel_size=1, strides=1, pad_mode="valid", pad=0, ceil_mode=False,
8176                 count_include_pad=True, divisor_override=0, data_format="NCDHW"):
8177        """Initialize AvgPool3D"""
8178        self.init_prim_io_names(inputs=['input'], outputs=['output'])
8179        self.kernel_size = _check_3d_int_or_tuple('kernel_size', kernel_size, self.name)
8180        self.add_prim_attr('kernel_size', self.kernel_size)
8181        self.strides = _check_3d_int_or_tuple('strides', strides, self.name)
8182        validator.check_value_type('pad', pad, (int, tuple), self.name)
8183        self.add_prim_attr('strides', self.strides)
8184        if isinstance(pad, int):
8185            pad = (pad,) * 6
8186        if len(pad) != 6:
8187            raise ValueError(f"For '{self.name}', attr 'pad' should be a positive int number or a tuple of "
8188                             f"six positive int numbers, but got `{pad}`.")
8189        self.pad_list = pad
8190        self.add_prim_attr('pad_list', self.pad_list)
8191        validator.check_value_type('pad_mode', pad_mode, [str], self.name)
8192        self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME', 'PAD'], 'pad_mode', self.name)
8193        self.add_prim_attr('pad_mode', self.pad_mode)
8194
8195        if self.pad_mode != 'PAD' and pad != (0, 0, 0, 0, 0, 0):
8196            raise ValueError(f"For '{self.name}', the 'pad' must be (0, 0, 0, 0, 0, 0) when 'pad_mode' is not \"pad\", "
8197                             f"but got 'pad' is {pad} and 'pad_mode' is {self.pad_mode}.")
8198        if self.pad_mode == 'PAD':
8199            for item in pad:
8200                validator.check_non_negative_int(item, 'pad or item of pad', self.name)
8201        self.ceil_mode = validator.check_value_type('ceil_mode', ceil_mode, bool, self.name)
8202        self.count_include_pad = validator.check_value_type('count_include_pad', count_include_pad, bool, self.name)
8203        self.divisor_override = validator.check_non_negative_int(divisor_override, 'divisor_override', self.name)
8204        self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
8205
8206
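# Rough sketch of the output-size arithmetic implied by the AvgPool3D pad_mode notes
# above (ceil_mode=False assumed; illustrative only). For the docstring example,
# (D, H, W) = (2, 2, 3) with kernel_size=2 and strides=1 in 'valid' mode this gives
# (1, 1, 2), consistent with the printed output.
def _avg_pool3d_out_dim_sketch(in_dim, kernel, stride, pad_mode, pad_before=0, pad_after=0):
    if pad_mode == "valid":
        return math.ceil((in_dim - kernel + 1) / stride)
    if pad_mode == "same":
        return math.ceil(in_dim / stride)
    # pad_mode == "pad": floor convention with ceil_mode=False
    return (in_dim + pad_before + pad_after - kernel) // stride + 1

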
8207class Conv3D(PrimitiveWithInfer):
8208    r"""
8209    3D convolution layer.
8210
8211    Applies a 3D convolution over an input tensor which is typically of shape
8212    :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})` and output shape
8213    :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})`. Where :math:`N` is batch size, :math:`C` is channel number,
8214    :math:`D` is depth, :math:`H` is height, :math:`W` is width.
8215    the formula is defined as:
8216
8217    .. math::
8218
8219        \operatorname{out}\left(N_{i}, C_{\text {out}_j}\right)=\operatorname{bias}\left(C_{\text {out}_j}\right)+
8220        \sum_{k=0}^{C_{in}-1} ccor(\text {weight}\left(C_{\text {out}_j}, k\right),
8221        \operatorname{input}\left(N_{i}, k\right))
8222
8223    where :math:`k` is kernel, :math:`ccor` is the cross-correlation operator.
8224
8225    If the 'pad_mode' is set to be "valid", the output depth, height and width will be
8226    :math:`\left \lfloor{1 + \frac{D_{in} + 2 \times \text{padding} - \text{ks_d} -
8227    (\text{ks_d} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` and
8228    :math:`\left \lfloor{1 + \frac{H_{in} + 2 \times \text{padding} - \text{ks_h} -
8229    (\text{ks_h} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` and
8230    :math:`\left \lfloor{1 + \frac{W_{in} + 2 \times \text{padding} - \text{ks_w} -
8231    (\text{ks_w} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` respectively. Where
8232    :math:`dilation` is the spacing between kernel elements, :math:`stride` is the step length of each step,
8233    and :math:`padding` is the zero-padding added to both sides of the input.
8234
8235    Args:
8236        out_channel (int): The number of output channel :math:`C_{out}`.
8237        kernel_size (Union[int, tuple[int]]): The data type is int or a tuple of 3 integers. Specifies the depth, height
8238            and width of the 3D convolution window. A single int means the value is used for the depth, height and width
8239            of the kernel. A tuple of 3 ints means the first value is for the depth, the second is for the height and the
8240            third is for the width of the kernel.
8241        mode (int): Modes for different convolutions. It is currently not used. Default: 1.
8242        stride (Union[int, tuple[int]]): The distance of kernel moving. It can be an int number that represents
8243            the depth, height and width of movement, or a tuple of three int numbers that
8244            represent depth, height and width of movement respectively. Default: 1.
8245        pad_mode (str): Specifies padding mode. The optional values are
8246            "same", "valid", "pad". Default: "valid".
8247
8248            - same: Adopts the way of completion. The depth, height and width of the output will be the same as
8249              the input. The total number of padding will be calculated in depth, horizontal and vertical
8250              directions and evenly distributed to head and tail, top and bottom, left and right if possible.
8251              Otherwise, the last extra padding will be done from the tail, bottom and the right side.
8252              If this mode is set, `pad` must be 0.
8253
8254            - valid: Adopts the way of discarding. The possible largest depth, height and width of output
8255              will be returned without padding. Extra pixels will be discarded. If this mode is set, `pad`
8256              must be 0.
8257
8258            - pad: Implicit paddings on both sides of the input in depth, height, width. The number of `pad` will
8259              be padded to the input Tensor borders. `pad` must be greater than or equal to 0.
8260
8261        pad (Union(int, tuple[int])): The pad value to be filled. Default: 0. If `pad` is an integer, the paddings of
8262                    head, tail, top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of six
8263                    integers, the padding of head, tail, top, bottom, left and right equal to pad[0], pad[1], pad[2],
8264                    pad[3], pad[4] and pad[5] correspondingly.
8265        dilation (Union[int, tuple[int]]): The data type is int or a tuple of 3 integers
8266                                      :math:`(dilation_d, dilation_h, dilation_w)`.
8267                                      Currently, dilation on depth only supports the case of 1.
8268                                      Specifies the dilation rate to use for dilated convolution.
8269                                      If set :math:`k > 1`, there will be :math:`k - 1` pixels skipped
8270                                      for each sampling location. Its value must be greater or equal to 1 and
8271                                      bounded by the height and width of the input. Default: 1.
8272        group (int): Splits filter into groups, `in_channels` and `out_channels` must be
8273            divisible by the number of groups. Default: 1. Only 1 is currently supported.
8274        data_format (str): The optional value for data format. Currently only support "NCDHW".
8275
8276    Inputs:
8277        - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`.
8278          Currently the input data type only supports float16 and float32.
8279        - **weight** (Tensor) - If the size of the kernel is :math:`(k_d, k_h, k_w)`, then the shape is
8280          :math:`(C_{out}, C_{in}//groups, k_d, k_h, k_w)`.
8281          Currently the weight data type only supports float16 and float32.
8282        - **bias** (Tensor) - Tensor of shape :math:`C_{out}`. Currently, only None is supported.
8283
8284    Outputs:
8285        Tensor, the value that applied 3D convolution. The shape is :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})`.
8286
8287    Raises:
8288        TypeError: If `out_channel` or `group` is not an int.
8289        TypeError: If `kernel_size`, `stride`, `pad` or `dilation` is neither an int nor a tuple.
8290        ValueError: If `out_channel`, `kernel_size`, `stride` or `dilation` is less than 1.
8291        ValueError: If `pad` is less than 0.
8292        ValueError: If `pad_mode` is not one of 'same', 'valid', 'pad'.
8293        ValueError: If `pad` is a tuple whose length is not equal to 6.
8294        ValueError: If `pad_mode` is not equal to 'pad' and `pad` is not equal to (0, 0, 0, 0, 0, 0).
8295        ValueError: If `data_format` is not 'NCDHW'.
8296
8297    Supported Platforms:
8298        ``Ascend`` ``GPU`` ``CPU``
8299
8300    Examples:
8301        >>> x = Tensor(np.ones([16, 3, 10, 32, 32]), mindspore.float16)
8302        >>> weight = Tensor(np.ones([32, 3, 4, 3, 3]), mindspore.float16)
8303        >>> conv3d = ops.Conv3D(out_channel=32, kernel_size=(4, 3, 3))
8304        >>> output = conv3d(x, weight)
8305        >>> print(output.shape)
8306        (16, 32, 7, 30, 30)
8307    """
8308
8309    @prim_attr_register
8310    def __init__(self,
8311                 out_channel,
8312                 kernel_size,
8313                 mode=1,
8314                 pad_mode="valid",
8315                 pad=0,
8316                 stride=1,
8317                 dilation=1,
8318                 group=1,
8319                 data_format="NCDHW"):
8320        """Initialize Conv3D"""
8321        self.init_prim_io_names(inputs=['x', 'w'], outputs=['output'])
8322        self.kernel_size = _check_3d_int_or_tuple('kernel_size', kernel_size, self.name)
8323        self.stride = _check_3d_int_or_tuple('stride', stride, self.name, allow_five=False, ret_five=True)
8324        self.add_prim_attr('strides', self.stride)
8325        self.dilation = _check_3d_int_or_tuple('dilation', dilation, self.name, allow_five=False,
8326                                               ret_five=True, third_one=True)
8327        self.add_prim_attr('dilations', self.dilation)
8328        validator.check_value_type('pad', pad, (int, tuple), self.name)
8329        if isinstance(pad, int):
8330            pad = (pad,) * 6
8331        if len(pad) != 6:
8332            raise ValueError(f"For '{self.name}', attr 'pad' should be a positive int number or a tuple of "
8333                             f"six positive int numbers, but got `{pad}`.")
8334        self.add_prim_attr("pad", pad)
8335        self.padding = pad
8336        validator.check_value_type('pad_mode', pad_mode, [str], self.name)
8337        self.pad_mode = validator.check_string(pad_mode.lower(), ['valid', 'same', 'pad'], 'pad_mode', self.name)
8338        self.add_prim_attr('pad_mode', self.pad_mode)
8339
8340        if self.pad_mode != 'pad' and pad != (0, 0, 0, 0, 0, 0):
8341            raise ValueError(f"For '{self.name}', the 'pad' must be (0, 0, 0, 0, 0, 0) when 'pad_mode' is not \"pad\", "
8342                             f"but got 'pad' is {pad} and 'pad_mode' is {self.pad_mode}.")
8343        if self.pad_mode == 'pad':
8344            for item in pad:
8345                validator.check_non_negative_int(item, 'pad item', self.name)
8346
8347        self.mode = validator.check_equal_int(mode, 1, 'mode', self.name)
8348        self.add_prim_attr('mode', self.mode)
8349        self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
8350        self.add_prim_attr('data_format', self.format)
8351        self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name)
8352        self.group = validator.check_equal_int(group, 1, 'group', self.name)
8353        self.add_prim_attr('groups', self.group)
8354        self.add_prim_attr('offset_x', 0)
8355
8356    def infer_shape(self, x_shape, w_shape, b_shape=None):
8357        validator.check_equal_int(len(w_shape), 5, "weight rank", self.name)
8358        validator.check_equal_int(len(x_shape), 5, "x rank", self.name)
8359        if b_shape is not None:
8360            raise ValueError(f"For '{self.name}', the 'bias' currently only support None, but got {b_shape}.")
8361        validator.check(f"x_shape[1] // group", x_shape[1] // self.group, "w_shape[1]", w_shape[1], Rel.EQ, self.name)
8362        validator.check('out_channel', self.out_channel, 'w_shape[0]', w_shape[0], Rel.EQ, self.name)
8363        validator.check('kernel_size', self.kernel_size, 'w_shape[2:]', tuple(w_shape[2:]), Rel.EQ, self.name)
8364
8365        kernel_size_d = w_shape[2]
8366        kernel_size_h = w_shape[3]
8367        kernel_size_w = w_shape[4]
8368
8369        stride_d = self.stride[2]
8370        stride_h = self.stride[3]
8371        stride_w = self.stride[4]
8372
8373        dilation_d = self.dilation[2]
8374        dilation_h = self.dilation[3]
8375        dilation_w = self.dilation[4]
8376
8377        if self.pad_mode == "valid":
8378            d_out = math.ceil((x_shape[2] - dilation_d * (kernel_size_d - 1)) / stride_d)
8379            h_out = math.ceil((x_shape[3] - dilation_h * (kernel_size_h - 1)) / stride_h)
8380            w_out = math.ceil((x_shape[4] - dilation_w * (kernel_size_w - 1)) / stride_w)
8381            pad_head, pad_tail, pad_top, pad_bottom, pad_left, pad_right = 0, 0, 0, 0, 0, 0
8382
8383        elif self.pad_mode == "same":
8384            d_out = math.ceil(x_shape[2] / stride_d)
8385            h_out = math.ceil(x_shape[3] / stride_h)
8386            w_out = math.ceil(x_shape[4] / stride_w)
8387
8388            pad_needed_d = max(0, (d_out - 1) * stride_d + dilation_d * (kernel_size_d - 1) + 1 - x_shape[2])
8389            pad_head = math.floor(pad_needed_d / 2)
8390            pad_tail = pad_needed_d - pad_head
8391
8392            pad_needed_h = max(0, (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape[3])
8393            pad_top = math.floor(pad_needed_h / 2)
8394            pad_bottom = pad_needed_h - pad_top
8395
8396            pad_needed_w = max(0, (w_out - 1) * stride_w + dilation_w * (kernel_size_w - 1) + 1 - x_shape[4])
8397            pad_left = math.floor(pad_needed_w / 2)
8398            pad_right = pad_needed_w - pad_left
8399
8400        elif self.pad_mode == 'pad':
8401            pad_head, pad_tail, pad_top, pad_bottom, pad_left, pad_right = self.padding
8402            d_out = 1 + (x_shape[2] + pad_head + pad_tail - kernel_size_d - (kernel_size_d - 1)
8403                         * (dilation_d - 1)) / stride_d
8404            h_out = 1 + (x_shape[3] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1)
8405                         * (dilation_h - 1)) / stride_h
8406            w_out = 1 + (x_shape[4] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1)
8407                         * (dilation_w - 1)) / stride_w
8408            d_out = math.floor(d_out)
8409            h_out = math.floor(h_out)
8410            w_out = math.floor(w_out)
8411
8412        self.pad_list = [pad_head, pad_tail, pad_top, pad_bottom, pad_left, pad_right]
8413        filter_d = (self.kernel_size[0] - 1) * dilation_d + 1
8414        filter_h = (self.kernel_size[1] - 1) * dilation_h + 1
8415        filter_w = (self.kernel_size[2] - 1) * dilation_w + 1
8416        validator.check_int_range(self.pad_list[0], 0, filter_d, Rel.INC_LEFT,
8417                                  'pad_d belonging [0, filter_d)', self.name)
8418        validator.check_int_range(self.pad_list[1], 0, filter_d, Rel.INC_LEFT,
8419                                  'pad_d belonging [0, filter_d)', self.name)
8420        validator.check_int_range(self.pad_list[2], 0, filter_h, Rel.INC_LEFT,
8421                                  'pad_h belonging [0, filter_h)', self.name)
8422        validator.check_int_range(self.pad_list[3], 0, filter_h, Rel.INC_LEFT,
8423                                  'pad_h belonging [0, filter_h)', self.name)
8424        validator.check_int_range(self.pad_list[4], 0, filter_w, Rel.INC_LEFT,
8425                                  'pad_w belonging [0, filter_w)', self.name)
8426        validator.check_int_range(self.pad_list[5], 0, filter_w, Rel.INC_LEFT,
8427                                  'pad_w belonging [0, filter_w)', self.name)
8428        self.add_prim_attr('pad_list', (pad_head, pad_tail, pad_top, pad_bottom, pad_left, pad_right))
8429        out_channel = self.out_channel
8430        out_shape = [x_shape[0], out_channel, d_out, h_out, w_out]
8431        _check_shape('output', out_shape, self.name)
8432        return out_shape
8433
8434    def infer_dtype(self, x_dtype, w_dtype, b_dtype=None):
8435        args = {'x': x_dtype, 'w': w_dtype}
8436        valid_dtypes = [mstype.float16, mstype.float32]
8437        validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
8438        return x_dtype
8439
8440
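# Worked check of the 'valid' output-size formula in the Conv3D docstring (illustrative
# only; padding is zero in 'valid' mode). For the docstring example, (D, H, W) =
# (10, 32, 32) with kernel_size=(4, 3, 3), stride=1 and dilation=1 gives (7, 30, 30),
# matching the printed output shape (16, 32, 7, 30, 30).
def _conv3d_valid_out_dim_sketch(in_dim, kernel, stride=1, dilation=1):
    return math.floor(1 + (in_dim - kernel - (kernel - 1) * (dilation - 1)) / stride)

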
8441class Conv3DBackpropInput(PrimitiveWithInfer):
8442    """
8443    Computes the gradients of convolution 3D with respect to the input.
8444
8445    Args:
8446        out_channel (int): The number of output channels of the convolution.
8447        kernel_size (Union[int, tuple[int]]): The kernel size of the 3D convolution.
8448        mode (int): Modes for different convolutions. Not currently used.
8449        pad_mode (str): Modes to fill padding. It could be "valid", "same", or "pad", not case sensitive.
8450            Default: "valid".
8451        pad (Union(int, tuple[int])): The pad value to be filled. Default: 0. If `pad` is an integer, the paddings of
8452                    head, tail, top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of six
8453                    integers, the padding of head, tail, top, bottom, left and right equal to pad[0], pad[1], pad[2],
8454                    pad[3], pad[4] and pad[5] correspondingly.
8455        stride (Union(int, tuple[int])): The stride to be applied to the convolution filter. Default: 1.
8456        dilation (Union(int, tuple[int])): Specifies the space to use between kernel elements. Default: 1.
8457        group (int): Splits input into groups. Default: 1.
8458        data_format (str): The optional value for data format. Currently only support 'NCDHW'.
8459
8460    Inputs:
8461        - **weight** (Tensor) - If the size of the kernel is :math:`(K_d, K_h, K_w)`, then the shape is
8462          :math:`(C_{out}, C_{in}, K_d, K_h, K_w)`. Currently the weight data type only supports float16 and float32.
8463        - **dout** (Tensor) - The gradients with respect to the output of the convolution.
8464          The shape conforms to the default data_format
8465          :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})`. Currently the dout data type only supports float16
8466          and float32.
8467        - **input_size** (tuple(int)) - A tuple describes the shape of the input which conforms to the format
8468          :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`.
8469
8470    Outputs:
8471        Tensor, the gradients with respect to the input of convolution 3D. It has the same shape as the input.
8472
8473    Raises:
8474        TypeError: If `out_channel` or `group` is not an int.
8475        TypeError: If `kernel_size`, `stride`, `pad` or `dilation` is neither an int nor a tuple.
8476        ValueError: If `out_channel`, `kernel_size`, `stride` or `dilation` is less than 1.
8477        ValueError: If `pad` is less than 0.
8478        ValueError: If `pad_mode` is not one of 'same', 'valid', 'pad'.
8479        ValueError: If `pad` is a tuple whose length is not equal to 6.
8480        ValueError: If `pad_mode` is not equal to 'pad' and `pad` is not equal to (0, 0, 0, 0, 0, 0).
8481        ValueError: If `data_format` is not 'NCDHW'.
8482
8483    Supported Platforms:
8484        ``Ascend``
8485
8486    Examples:
8487        >>> import numpy as np
8488        >>> import mindspore
8489        >>> from mindspore import Tensor, ops
8490        >>> dout = Tensor(np.ones([16, 32, 10, 32, 32]), mindspore.float16)
8491        >>> weight = Tensor(np.ones([32, 32, 4, 6, 2]), mindspore.float16)
8492        >>> x = Tensor(np.ones([16, 32, 13, 37, 33]))
8493        >>> conv3d_backprop_input = ops.Conv3DBackpropInput(out_channel=4, kernel_size=(4, 6, 2))
8494        >>> output = conv3d_backprop_input(dout, weight, ops.shape(x))
8495        >>> print(output.shape)
8496        (16, 32, 13, 37, 33)
8497    """
8498
8499    @prim_attr_register
8500    def __init__(self,
8501                 out_channel,
8502                 kernel_size,
8503                 mode=1,
8504                 pad_mode="valid",
8505                 pad=0,
8506                 stride=1,
8507                 dilation=1,
8508                 group=1,
8509                 data_format="NCDHW"):
8510        """Initialize Conv3DBackpropInput"""
8511        self.init_prim_io_names(inputs=['filter', 'out_backprop', 'input_size'], outputs=['y'])
8512        self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name)
8513        self.kernel_size = _check_3d_int_or_tuple('kernel_size', kernel_size, self.name)
8514        self.stride = _check_3d_int_or_tuple('stride', stride, self.name, allow_five=True, ret_five=True)
8515        self.add_prim_attr('strides', self.stride)
8516        self.dilation = _check_3d_int_or_tuple('dilation', dilation, self.name, allow_five=True, ret_five=True)
8517        self.add_prim_attr('dilations', self.dilation)
8518        validator.check_value_type('pad', pad, (int, tuple), self.name)
8519        if isinstance(pad, int):
8520            pad = (pad,) * 6
8521        validator.check_equal_int(len(pad), 6, 'pad size', self.name)
8522        self.add_prim_attr("pad", pad)
8523        self.pad_list = pad
8524
8525        self.pad_mode = validator.check_string(pad_mode.lower(), ['valid', 'same', 'pad'], 'pad_mode', self.name)
8526        if self.pad_mode != 'pad' and self.pad_list != (0, 0, 0, 0, 0, 0):
8527            raise ValueError(f"For '{self.name}', the 'pad' must be (0, 0, 0, 0, 0, 0) "
8528                             f"when 'pad_mode' is not \"pad\", "
8529                             f"but got 'pad' is {self.pad_list} and 'pad_mode' is {self.pad_mode}.")
8530        if self.pad_mode == 'pad':
8531            for item in pad:
8532                validator.check_non_negative_int(item, 'pad item', self.name)
8533        self.add_prim_attr('pad_mode', self.pad_mode)
8534
8535        self.mode = validator.check_equal_int(mode, 1, 'mode', self.name)
8536        self.add_prim_attr('mode', self.mode)
8537        self.group = validator.check_positive_int(group, 'group', self.name)
8538        self.add_prim_attr('groups', self.group)
8539        self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
8540        self.add_prim_attr('data_format', self.format)
8541
8542    def __infer__(self, w, doutput, x_size):
8543        validator.check_equal_int(len(w['shape']), 5, 'The dimension of weight ', self.name)
8544        validator.check_equal_int(len(doutput['shape']), 5, 'The dimension of dout', self.name)
8545        x_size_v = x_size['value']
8546        validator.check_equal_int(len(x_size_v), 5, 'The dimension of input_size', self.name)
8547        validator.check_value_type('x_size', x_size_v, [tuple], self.name)
8548        for i, dim_len in enumerate(x_size_v):
8549            validator.check_value_type("x_size[%d]" % i, dim_len, [int], self.name)
8550        args = {'doutput': doutput['dtype'], 'w': w['dtype']}
8551        valid_dtypes = [mstype.float16, mstype.float32]
8552        validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
8553        validator.check("filter's batch", w['shape'][0], "dout's channel", doutput['shape'][1], Rel.EQ, self.name)
8554        validator.check("filter's channel", w['shape'][1], "input_size's channel", x_size_v[1], Rel.EQ, self.name)
8555        validator.check("input_size's batch", x_size_v[0], "dout's batch", doutput['shape'][0], Rel.EQ, self.name)
8556
8557        # infer shape
8558        dout_shape = doutput['shape']
8559        kernel_d = self.kernel_size[0]
8560        kernel_h = self.kernel_size[1]
8561        kernel_w = self.kernel_size[2]
8562        stride_d = self.stride[2]
8563        stride_h = self.stride[3]
8564        stride_w = self.stride[4]
8565        dilation_d = self.dilation[2]
8566        dilation_h = self.dilation[3]
8567        dilation_w = self.dilation[4]
8568        # The pad_mode is valid by default. If pad_mode is not valid or same, then pad.
8569        if self.pad_mode == "valid":
8570            self.pad_list = (0, 0, 0, 0, 0, 0)
8571        if self.pad_mode == "same":
8572            pad_needed_d = max(0, (dout_shape[2] - 1) * stride_d + dilation_d * (kernel_d - 1) + 1 - x_size_v[2])
8573            pad_head = math.floor(pad_needed_d / 2)
8574            pad_tail = pad_needed_d - pad_head
8575
8576            pad_needed_h = max(0, (dout_shape[3] - 1) * stride_h + dilation_h * (kernel_h - 1) + 1 - x_size_v[3])
8577            pad_top = math.floor(pad_needed_h / 2)
8578            pad_bottom = pad_needed_h - pad_top
8579
8580            pad_needed_w = max(0, (dout_shape[4] - 1) * stride_w + dilation_w * (kernel_w - 1) + 1 - x_size_v[4])
8581            pad_left = math.floor(pad_needed_w / 2)
8582            pad_right = pad_needed_w - pad_left
8583            self.pad_list = (pad_head, pad_tail, pad_top, pad_bottom, pad_left, pad_right)
8584
8585        self.add_prim_attr('pad_list', self.pad_list)
8586        out = {
8587            'value': None,
8588            'shape': x_size_v,
8589            'dtype': doutput['dtype'],
8590        }
8591        return out
8592
8593
8594def _deconv_output_length(input_length, kernel_size, stride_size, dilation_size):
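    # The effective filter size accounts for dilation: kernel_size + (kernel_size - 1) * (dilation_size - 1).
    # Worked example (illustrative values): input_length=10, kernel_size=4, stride_size=2, dilation_size=1
    # gives filter_size=4; since 4 - 2 > 0, the returned length is 10 * 2 + 4 - 2 = 22.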
8595    filter_size = kernel_size + (kernel_size - 1) * (dilation_size - 1)
8596    if filter_size - stride_size > 0:
8597        length = input_length * stride_size + filter_size - stride_size
8598    else:
8599        length = input_length * stride_size
8600    return length
8601
8602
8603class CTCLossV2(Primitive):
8604    """
8605    Calculates the CTC (Connectionist Temporal Classification) loss and the gradient.
8606
8607    The CTC algorithm is proposed in `Connectionist Temporal Classification: Labeling Unsegmented Sequence Data with
8608    Recurrent Neural Networks <http://www.cs.toronto.edu/~graves/icml_2006.pdf>`_.
8609
8610    Args:
        blank (int): The blank label.
        reduction (str): Apply specific reduction method to the output. Currently only 'none' is supported,
            not case sensitive. Default: "none".
8614        zero_infinity (bool): Whether to set infinite loss and correlation gradient to zero. Default: False.
8615
8616    Inputs:
        - **log_probs** (Tensor) - A tensor of shape (T, N, C), where T is input length, N is batch size and C is
          the number of classes (including blank).
        - **targets** (Tensor) - A tensor of shape (N, S), where S is max target length; it contains the target sequences.
        - **input_lengths** (Union(Tuple, Tensor)) - A tuple or Tensor of shape (N), containing the lengths of the inputs.
        - **target_lengths** (Union(Tuple, Tensor)) - A tuple or Tensor of shape (N), containing the lengths of the targets.
8622
8623    Outputs:
8624        - **neg_log_likelihood** (Tensor) - A loss value which is differentiable with respect to each input node.
8625        - **log_alpha** (Tensor) - The probability of possible trace of input to target.
8626
8627    Raises:
        TypeError: If `zero_infinity` is not a bool or `reduction` is not a str.
8629
8630    Supported Platforms:
8631
8632    """
8633
8634    @prim_attr_register
8635    def __init__(self, blank, reduction="none", zero_infinity=False):
8636        """Initialize CTCLossV2"""
8637        self.init_prim_io_names(inputs=["log_probs", "targets", "input_lengths", "target_lengths"],
8638                                outputs=["neg_log_likelihood", "log_alpha"])
8639        validator.check_value_type("blank", blank, [int], self.name)
8640        self.add_prim_attr("blank", blank)
8641        validator.check_value_type("reduction", reduction, [str], self.name)
8642        self.reduction = reduction.lower()
8643        validator.check_string(self.reduction, ['none'], 'reduction', self.name)
8644        self.add_prim_attr("reduction", self.reduction)
8645        validator.check_value_type("zero_infinity", zero_infinity, [bool], self.name)
8646        self.add_prim_attr("zero_infinity", zero_infinity)
8647
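
# A minimal shape-level usage sketch for CTCLossV2 (illustrative only: the helper name, the dummy
# input values and the availability of this primitive on a given backend are assumptions rather than
# part of the operator's documented API).
def _ctc_loss_v2_usage_sketch():
    """Build dummy inputs with the shapes documented above and run CTCLossV2 once."""
    from mindspore import Tensor
    t, n, c, s = 10, 2, 5, 4  # input length, batch size, classes (including blank), max target length
    log_probs = Tensor(np.log(np.full((t, n, c), 1.0 / c, np.float32)))  # uniform log-probabilities
    targets = Tensor(np.ones((n, s), np.int32))  # target labels, avoiding the blank index 0
    input_lengths = Tensor(np.full((n,), t, np.int32))
    target_lengths = Tensor(np.full((n,), s, np.int32))
    ctc_loss = CTCLossV2(blank=0, reduction="none", zero_infinity=False)
    return ctc_loss(log_probs, targets, input_lengths, target_lengths)  # (neg_log_likelihood, log_alpha)
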
8648
8649class CTCLossV2Grad(Primitive):
8650    """
8651    Calculates the gradient of CTC (Connectionist Temporal Classification) loss.
8652
8653    The CTC algorithm is proposed in `Connectionist Temporal Classification: Labeling Unsegmented Sequence Data with
8654    Recurrent Neural Networks <http://www.cs.toronto.edu/~graves/icml_2006.pdf>`_.
8655
8656    Args:
        blank (int): The blank label.
        reduction (str): Apply specific reduction method to the output. Currently only 'none' is supported.
          Default: "none".
8660        zero_infinity (bool): Whether to set infinite loss and correlation gradient to zero. Default: False.
8661
8662    Inputs:
        - **grad_out** (Tensor) - Gradient renewal coefficient, a tensor of shape (N), where N is batch size.
        - **log_probs** (Tensor) - A tensor of shape (T, N, C), where T is input length, N is batch size and C is
          the number of classes (including blank).
        - **targets** (Tensor) - A tensor of shape (N, S), where S is max target length; it contains the target sequences.
        - **input_lengths** (Union(tuple, Tensor)) - A tuple or Tensor of shape (N), containing the lengths of the inputs.
        - **target_lengths** (Union(tuple, Tensor)) - A tuple or Tensor of shape (N), containing the lengths of the targets.
8669        - **log_alpha** (Tensor) - The probability of possible trace of input to target.
8670        - **neg_log_likelihood** (Tensor) - A loss value which is differentiable with respect to each input node.
8671
8672    Outputs:
        - **grad** (Tensor) - The gradient of the Connectionist Temporal Classification loss.
8674
8675    Raises:
        TypeError: If `zero_infinity` is not a bool or `reduction` is not a str.
8677
8678    Supported Platforms:
8679        ``Ascend``
8680    """
8681
8682    @prim_attr_register
8683    def __init__(self, blank, reduction="none", zero_infinity=False):
8684        """Initialize CTCLossV2Grad"""
8685        self.init_prim_io_names(inputs=["grad_out", "log_probs", "targets", "input_lengths", "target_lengths",
8686                                        "neg_log_likelihood", "log_alpha"],
8687                                outputs=["grad"])
8688        validator.check_value_type("blank", blank, [int], self.name)
8689        self.add_prim_attr("blank", blank)
8690        validator.check_value_type("reduction", reduction, [str], self.name)
8691        self.add_prim_attr("reduction", reduction)
8692        validator.check_value_type("zero_infinity", zero_infinity, [bool], self.name)
8693        self.add_prim_attr("zero_infinity", zero_infinity)
8694
8695
8696class Conv3DTranspose(PrimitiveWithInfer):
8697    r"""
8698    Computes a 3D transposed convolution, which is also known as a deconvolution
8699    (although it is not an actual deconvolution).
8700
8701    Input is typically of shape :math:`(N, C, D, H, W)`, where :math:`N` is batch size, :math:`C` is channel number,
8702    :math:`D` is depth, :math:`H` is height, :math:`W` is width.
8703
8704    If the 'pad_mode' is set to be "pad", the depth, height and width of output are defined as:
8705
8706    .. math::
8707        D_{out} = (D_{in} - 1) \times \text{stride}[0] - 2 \times \text{pad}[0] + \text{dilation}[0]
8708        \times (\text{kernel\_size}[0] - 1) + \text{output\_padding}[0] + 1
8709
8710        H_{out} = (H_{in} - 1) \times \text{stride}[1] - 2 \times \text{pad}[1] + \text{dilation}[1]
8711        \times (\text{kernel\_size}[1] - 1) + \text{output\_padding}[1] + 1
8712
8713        W_{out} = (W_{in} - 1) \times \text{stride}[2] - 2 \times \text{pad}[2] + \text{dilation}[2]
8714        \times (\text{kernel\_size}[2] - 1) + \text{output\_padding}[2] + 1
8715
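    For example, with :math:`D_{in} = 10`, stride 2, pad 1, dilation 1, kernel size 4 and output_padding 0,
    the first formula gives :math:`D_{out} = (10 - 1) \times 2 - 2 \times 1 + 1 \times 3 + 0 + 1 = 20`.
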
8716    Args:
8717        in_channel (int): The channel of the input x.
        out_channel (int): The channel of the output.
8719        kernel_size (Union[int, tuple[int]]): The data type is int or a tuple of 3 integers.
8720            Specifies the depth, height and width of the 3D convolution window.
8721            Single int means the value is for the depth, height and the width of the kernel.
8722            A tuple of 3 ints means the first value is for the depth, second value is for height and the
8723            other is for the width of the kernel.
8724        mode (int): Modes for different convolutions. Default is 1. It is currently not used.
8725        pad_mode (str): Specifies padding mode. The optional values are
8726            "same", "valid", "pad", not case sensitive. Default: "valid".
8727
8728            - same: Adopts the way of completion. The depth, height and width of the output will be the same as
8729              the input. The total number of padding will be calculated in depth, horizontal and vertical
8730              directions and evenly distributed to head and tail, top and bottom, left and right if possible.
8731              Otherwise, the last extra padding will be done from the tail, bottom and the right side.
8732              If this mode is set, `pad` and `output_padding` must be 0.
8733
8734            - valid: Adopts the way of discarding. The possible largest depth, height and width of output
8735              will be returned without padding. Extra pixels will be discarded. If this mode is set, `pad`
8736              and `output_padding` must be 0.
8737
8738            - pad: Implicit paddings on both sides of the input in depth, height, width. The number of `pad` will
8739              be padded to the input Tensor borders. `pad` must be greater than or equal to 0.
8740
8741        pad (Union(int, tuple[int])): The pad value to be filled. Default: 0. If `pad` is an integer, the paddings of
8742             head, tail, top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of six integers,
8743             the padding of head, tail, top, bottom, left and right equal to pad[0], pad[1], pad[2], pad[3], pad[4]
8744             and pad[5] correspondingly.
8745        stride (Union(int, tuple[int])): The distance of kernel moving, an int number that represents
8746            the depth, height and width of movement are both strides, or a tuple of three int numbers that
8747            represent depth, height and width of movement respectively. Default: 1.
8748        dilation (Union(int, tuple[int])): Specifies the space to use between kernel elements. Default: 1.
8749        group (int): Splits input into groups. Default: 1. Only 1 is currently supported.
8750        output_padding (Union(int, tuple[int])): Add extra size to each dimension of the output. Default: 0.
8751        data_format (str): The optional value for data format. Currently only 'NCDHW' is supported.
8752
8753    Inputs:
        - **dout** (Tensor) - The gradients with respect to the output of the convolution.
          The shape conforms to the default data_format :math:`(N, C_{in}, D_{out}, H_{out}, W_{out})`.
          Currently the dout data type only supports float16 and float32.
8758        - **weight** (Tensor) - Set size of kernel is :math:`(K_d, K_h, K_w)`, then the shape is
8759          :math:`(C_{in}, C_{out}//group, K_d, K_h, K_w)`. Where :math:`group` is the Args parameter.
8760          Currently weight data type only supports float16 and float32.
        - **bias** (Tensor) - Tensor of shape :math:`C_{out}`. Currently, only None is supported.
8762
8763    Outputs:
8764        Tensor, the gradients with respect to the input of convolution 3D.
8765        Tensor of shape :math:`(N, C_{out}//group, D_{out}, H_{out}, W_{out})`,
8766        where :math:`group` is the Args parameter.
8767
8768    Supported Platforms:
8769        ``Ascend`` ``GPU``
8770
8771    Raises:
8772        TypeError: If `in_channel`, `out_channel` or `group` is not an int.
        TypeError: If `kernel_size`, `stride`, `pad`, `dilation` or `output_padding` is neither an int nor a tuple.
8774        ValueError: If `in_channel`, `out_channel`, `kernel_size`, `stride` or `dilation` is less than 1.
8775        ValueError: If `pad` is less than 0.
8776        ValueError: If `pad_mode` is not one of 'same', 'valid', 'pad'.
8777        ValueError: If `pad` is a tuple whose length is not equal to 6.
8778        ValueError: If `pad_mode` is not equal to 'pad' and `pad` is not equal to (0, 0, 0, 0, 0, 0).
8779        ValueError: If `data_format` is not 'NCDHW'.
        TypeError: If the dout and weight data type is neither float16 nor float32.
        ValueError: If bias is not None or the rank of dout and weight is not 5.
8782
8783    Examples:
8784        >>> dout = Tensor(np.ones([32, 16, 10, 32, 32]), mindspore.float16)
8785        >>> weight = Tensor(np.ones([16, 3, 4, 6, 2]), mindspore.float16)
8786        >>> conv3d_transpose = ops.Conv3DTranspose(in_channel=16, out_channel=3, kernel_size=(4, 6, 2))
8787        >>> output = conv3d_transpose(dout, weight)
8788        >>> print(output.shape)
8789        (32, 3, 13, 37, 33)
8790    """
8791
8792    @prim_attr_register
8793    def __init__(self,
8794                 in_channel,
8795                 out_channel,
8796                 kernel_size,
8797                 mode=1,
8798                 pad_mode='valid',
8799                 pad=0,
8800                 stride=1,
8801                 dilation=1,
8802                 group=1,
8803                 output_padding=0,
8804                 data_format="NCDHW"):
8805        """Initialize Conv3DTranspose"""
8806        self.init_prim_io_names(inputs=['x', 'filter'], outputs=['output'])
8807        self.in_channel = validator.check_positive_int(in_channel, 'in_channel', self.name)
8808        self.add_prim_attr('in_channel', self.in_channel)
8809        self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name)
8810        self.add_prim_attr('out_channel', self.out_channel)
8811        self.kernel_size = _check_3d_int_or_tuple('kernel_size', kernel_size, self.name)
8812        self.stride = _check_3d_int_or_tuple('stride', stride, self.name, allow_five=False,
8813                                             ret_five=True)
8814        self.add_prim_attr('strides', self.stride)
8815        self.dilation = _check_3d_int_or_tuple('dilation', dilation, self.name, allow_five=False,
8816                                               ret_five=True, third_one=True)
8817        self.add_prim_attr('dilations', self.dilation)
8818        validator.check_value_type('pad', pad, (int, tuple), self.name)
8819        if isinstance(pad, int):
8820            pad = (pad,) * 6
8821        if len(pad) != 6:
            raise ValueError(f"For '{self.name}', attr 'pad' should be a non-negative int number or a tuple of "
                             f"six non-negative int numbers, but got `{pad}`.")
8824        self.pad_list = pad
8825        validator.check_value_type('pad_mode', pad_mode, [str], self.name)
8826        self.pad_mode = validator.check_string(pad_mode.lower(), ['valid', 'same', 'pad'], 'pad_mode', self.name)
8827        self.add_prim_attr('pad_mode', self.pad_mode)
8828
8829        if self.pad_mode != 'pad' and pad != (0, 0, 0, 0, 0, 0):
8830            raise ValueError(f"For '{self.name}', the 'pad' must be (0, 0, 0, 0, 0, 0) when 'pad_mode' is not \"pad\", "
8831                             f"but got 'pad' is {pad} and 'pad_mode' is {self.pad_mode}.")
8832
8833        if self.pad_mode == 'pad':
8834            for item in self.pad_list:
8835                validator.check_non_negative_int(item, 'pad item', self.name)
8836        self.mode = validator.check_equal_int(mode, 1, 'mode', self.name)
8837        self.add_prim_attr('mode', self.mode)
8838        self.group = validator.check_equal_int(group, 1, 'group', self.name)
8839        self.add_prim_attr('groups', self.group)
8840        self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
8841        self.add_prim_attr('data_format', self.format)
8842
8843        self.output_padding = _check_3d_int_or_tuple('output_padding', output_padding, self.name,
8844                                                     allow_five=False, ret_five=True, greater_zero=False)
8845        output_padding = (self.output_padding[2], self.output_padding[3], self.output_padding[4])
8846        if self.pad_mode != 'pad' and output_padding != (0, 0, 0):
8847            raise ValueError(f"For '{self.name}', the 'output_padding' must be (0, 0, 0) "
8848                             f"when 'pad_mode' is not \"pad\", "
8849                             f"but got 'output_padding' is {output_padding} and 'pad_mode' is {self.pad_mode}.")
8850        validator.check_int_range(self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2], 1, 343, Rel.INC_BOTH,
8851                                  'The product of height, width and depth of kernel_size belonging [1, 343]', self.name)
8852        validator.check_int_range(self.stride[0] * self.stride[1] * self.stride[2], 1, 343, Rel.INC_BOTH,
8853                                  'The product of height, width and depth of stride belonging [1, 343]', self.name)
8854        validator.check_int_range(self.stride[1] * self.stride[2], 1, 256, Rel.INC_BOTH,
8855                                  'The product of height, width and depth of stride belonging [1, 256]', self.name)
8856        validator.check_int_range(self.output_padding[2], 0, max(self.dilation[2], self.stride[2]), Rel.INC_LEFT,
8857                                  'output_padding_d belonging [0, max(stride_d, dilation_d))', self.name)
8858        validator.check_int_range(self.output_padding[3], 0, max(self.dilation[3], self.stride[3]), Rel.INC_LEFT,
8859                                  'output_padding_h belonging [0, max(stride_h,dilation_h))', self.name)
8860        validator.check_int_range(self.output_padding[4], 0, max(self.dilation[4], self.stride[4]), Rel.INC_LEFT,
8861                                  'output_padding_w belonging [0, max(stride_w,dilation_w))', self.name)
8862
8863    def __infer__(self, x, w, b=None):
8864        args = {'x': x['dtype'], 'w': w['dtype']}
8865        if b is not None:
8866            raise ValueError(f"For '{self.name}', the 'bias' currently only support None, but got {b}.")
8867        valid_dtypes = [mstype.float16, mstype.float32]
8868        validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
8869
8870        # infer shape
8871        x_shape = x['shape']
8872        w_shape = w['shape']
8873        validator.check_equal_int(len(w_shape), 5, "weight rank", self.name)
8874        validator.check_equal_int(len(x_shape), 5, "x rank", self.name)
8875        validator.check("filter's batch", w_shape[0], "input x's channel",
8876                        x_shape[1], Rel.EQ, self.name)
8877
8878        kernel_d, kernel_h, kernel_w = self.kernel_size
8879        _, _, stride_d, stride_h, stride_w = self.stride
8880        _, _, dilation_d, dilation_h, dilation_w = self.dilation
8881
8882        if self.pad_mode == "valid":
8883            d_out = _deconv_output_length(x_shape[2], kernel_d, stride_d, dilation_d)
8884            h_out = _deconv_output_length(x_shape[3], kernel_h, stride_h, dilation_h)
8885            w_out = _deconv_output_length(x_shape[4], kernel_w, stride_w, dilation_w)
8886            self.pad_list = (0, 0, 0, 0, 0, 0)
8887            self.output_padding = (0, 0, 0, 0, 0)
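            # Consistency check against the docstring example above (valid mode, stride=1, dilation=1):
            # depth 10 -> 10 + 4 - 1 = 13, height 32 -> 32 + 6 - 1 = 37, width 32 -> 32 + 2 - 1 = 33.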
8888
8889        elif self.pad_mode == "same":
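            # In 'same' mode the transposed convolution yields exactly input_size * stride along each
            # spatial axis; the padding below is chosen so that a forward convolution of that output
            # with the same parameters would reproduce x_shape.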
8890            d_out = x_shape[2] * stride_d
8891            h_out = x_shape[3] * stride_h
8892            w_out = x_shape[4] * stride_w
8893
8894            pad_needed_d = max(0, (x_shape[2] - 1) * stride_d + dilation_d * (kernel_d - 1) + 1 - d_out)
8895            pad_head = math.floor(pad_needed_d / 2)
8896            pad_tail = pad_needed_d - pad_head
8897
8898            pad_needed_h = max(0, (x_shape[3] - 1) * stride_h + dilation_h * (kernel_h - 1) + 1 - h_out)
8899            pad_top = math.floor(pad_needed_h / 2)
8900            pad_bottom = pad_needed_h - pad_top
8901
8902            pad_needed_w = max(0, (x_shape[4] - 1) * stride_w + dilation_w * (kernel_w - 1) + 1 - w_out)
8903            pad_left = math.floor(pad_needed_w / 2)
8904            pad_right = pad_needed_w - pad_left
8905            self.pad_list = (pad_head, pad_tail, pad_top, pad_bottom, pad_left, pad_right)
8906            self.output_padding = (0, 0, 0, 0, 0)
8907
8908        elif self.pad_mode == 'pad':
8909            pad_head, pad_tail, pad_top, pad_bottom, pad_left, pad_right = self.pad_list
8910            d_out = (x_shape[2] - 1) * self.stride[2] - (pad_head + pad_tail) + self.dilation[2] * \
8911                    (self.kernel_size[0] - 1) + self.output_padding[2] + 1
8912            h_out = (x_shape[3] - 1) * self.stride[3] - (pad_top + pad_bottom) + self.dilation[3] * \
8913                    (self.kernel_size[1] - 1) + self.output_padding[3] + 1
8914            w_out = (x_shape[4] - 1) * self.stride[4] - (pad_left + pad_right) + self.dilation[4] * \
8915                    (self.kernel_size[2] - 1) + self.output_padding[4] + 1
8916
8917        self.add_prim_attr('pad_list', self.pad_list)
8918        self.add_prim_attr('output_padding', self.output_padding)
8919        output_shape = (x_shape[0], w_shape[1] * self.group, d_out, h_out, w_out)
8920        self.add_prim_attr('input_size', output_shape)
8921        out = {
8922            'value': None,
8923            'shape': output_shape,
8924            'dtype': x['dtype'],
8925        }
8926        return out
8927
8928
8929class SoftShrink(Primitive):
8930    r"""
8931    Applies the soft shrinkage function elementwise.
8932
8933    .. math::
8934        \text{SoftShrink}(x) =
8935        \begin{cases}
8936        x - \lambda, & \text{ if } x > \lambda \\
8937        x + \lambda, & \text{ if } x < -\lambda \\
8938        0, & \text{ otherwise }
8939        \end{cases}
8940
8941    Args:
        lambd (float): The :math:`\lambda` value for the Softshrink formulation, which must be no less than zero.
            Default: 0.5.
8943
8944    Inputs:
8945        - **input_x** (Tensor) - The input of SoftShrink with data type of float16 or float32.
8946          Any number of additional dimensions.
8947
8948    Outputs:
8949        Tensor, has the same shape and data type as `input_x`.
8950
8951    Raises:
8952        TypeError: If lambd is not a float.
8953        TypeError: If input_x is not a Tensor.
8954        TypeError: If dtype of input_x is neither float16 nor float32.
8955        ValueError: If lambd is less than 0.
8956
8957    Supported Platforms:
8958        ``Ascend``
8959
8960    Examples:
8961        >>> input_x = Tensor(np.array([[ 0.5297,  0.7871,  1.1754], [ 0.7836,  0.6218, -1.1542]]), mindspore.float16)
8962        >>> softshrink = ops.SoftShrink()
8963        >>> output = softshrink(input_x)
8964        >>> print(output)
8965        [[ 0.02979  0.287    0.676  ]
8966         [ 0.2837   0.1216  -0.6543 ]]
8967    """
8968
8969    @prim_attr_register
8970    def __init__(self, lambd=0.5):
8971        """Initialize SoftShrink"""
8972        validator.check_value_type("lambd", lambd, [float], self.name)
8973        validator.check_number("lambd", lambd, 0, Rel.GE, self.name)
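        # Worked example of the formula with the default lambd=0.5: an input of 1.2 maps to 1.2 - 0.5 = 0.7,
        # an input of -0.9 maps to -0.9 + 0.5 = -0.4, and any input in [-0.5, 0.5] maps to 0.0.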
8974
8975
8976class HShrink(Primitive):
8977    r"""
    Applies the hard shrinkage function element-wise; each element complies with the following function:
8979
8980    .. math::
8981        \text{HardShrink}(x) =
8982        \begin{cases}
8983        x, & \text{ if } x > \lambda \\
8984        x, & \text{ if } x < -\lambda \\
8985        0, & \text{ otherwise }
8986        \end{cases}
8987
8988    Args:
        lambd (float): The :math:`\lambda` value for the HardShrink formulation. Default: 0.5.
8990
8991    Inputs:
8992        - **input_x** (Tensor) - The input of HardShrink with data type of float16 or float32.
8993
8994    Outputs:
8995        Tensor, the same shape and data type as the input.
8996
8997    Supported Platforms:
8998        ``Ascend``
8999
9000    Raises:
9001        TypeError: If `lambd` is not a float.
9002        TypeError: If dtype of `input_x` is neither float16 nor float32.
9003
9004    Examples:
9005        >>> input_x = Tensor(np.array([[ 0.5,  1,  2.0],[0.0533,0.0776,-2.1233]]),mstype.float32)
9006        >>> hshrink = P.HShrink()
9007        >>> output = hshrink(input_x)
9008        >>> print(output)
9009        [[ 0.      1.      2.    ]
9010        [ 0.      0.     -2.1233]]
9011    """
9012
9013    @prim_attr_register
9014    def __init__(self, lambd=0.5):
9015        """Initialize HShrink"""
9016        validator.check_value_type('lambd', lambd, [float], self.name)
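        # A negative lambd is silently clamped to 0.0 and re-registered as the 'lambd' primitive attribute.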
9017        if lambd < 0.0:
9018            lambd = 0.0
9019            self.add_prim_attr('lambd', lambd)
9020
9021
9022class ApplyAdagradDA(Primitive):
9023    r"""
9024    Update `var` according to the proximal adagrad scheme.
9025
    .. math::
        \begin{array}{ll} \\
            grad\_accum += grad \\
            grad\_squared\_accum += grad * grad \\
            tmp\_val = \begin{cases}
                sign(grad\_accum) * \max\{|grad\_accum| - l1 * global\_step, 0\} & \text{ if } l1 > 0 \\
                grad\_accum & \text{ otherwise }
            \end{cases} \\
            x\_value = -1 * lr * tmp\_val \\
            y\_value = l2 * global\_step * lr + \sqrt{grad\_squared\_accum} \\
            var = x\_value / y\_value
        \end{array}
9036
9037    Inputs of `var`, `gradient_accumulator`, `gradient_squared_accumulator` and `grad`
9038    comply with the implicit type conversion rules to make the data types consistent.
9039    If they have different data types, lower priority data type will be converted to
9040    relatively highest priority data type.
9041    RuntimeError exception will be thrown when the data type conversion of Parameter is required.
9042
9043    Args:
9044        use_locking (bool): If `True`, updating of the `var` and `accum` tensors will be protected by a lock.
9045                            Otherwise the behavior is undefined, but may exhibit less contention. Default: False.
9046
9047    Inputs:
9048        - **var** (Parameter) - Variable to be updated. The data type must be float16 or float32.
9049          The shape is :math:`(N, *)` where :math:`*` means, any number of additional dimensions.
        - **gradient_accumulator** (Parameter) - The gradient accumulator to be updated. Must have the same
          shape and dtype as `var`.
        - **gradient_squared_accumulator** (Parameter) - The squared gradient accumulator to be updated.
          Must have the same shape and dtype as `var`.
9054        - **grad** (Tensor) - A tensor for gradient. Must have the same shape and dtype as `var`.
9055        - **lr** ([Number, Tensor]) - Scaling factor. Must be a scalar. With float32 or float16 data type.
9056        - **l1** ([Number, Tensor]) -  L1 regularization. Must be a scalar. With float32 or float16 data type.
9057        - **l2** ([Number, Tensor]) -  L2 regularization. Must be a scalar. With float32 or float16 data type.
9058        - **global_step** ([Number, Tensor]) - Training step number. Must be a scalar. With int32 or int64 data type.
9059
9060    Outputs:
9061        Tuple of 3 Tensors, the updated parameters.
9062
9063        - **var** (Tensor) - The same shape and data type as `var`.
9064        - **gradient_accumulator** (Tensor) - The same shape and data type as `gradient_accumulator`.
9065        - **gradient_squared_accumulator** (Tensor) - The same shape and data type as `gradient_squared_accumulator`.
9066
9067    Raises:
        TypeError: If `var`, `gradient_accumulator` or `gradient_squared_accumulator` is not a Parameter.
9069        TypeError: If `grad` is not a Tensor.
9070        TypeError: If `lr`, `l1`, `l2` or `global_step` is neither a Number nor a Tensor.
9071        TypeError: If use_locking is not a bool.
        TypeError: If dtype of `var`, `gradient_accumulator`, `gradient_squared_accumulator`, `lr`, `l1` or `l2`
                   is neither float16 nor float32.
        TypeError: If dtype of `gradient_accumulator` or `gradient_squared_accumulator` is not the same as `var`.
9076        TypeError: If dtype of `global_step` is not int32 or int64.
9077        ValueError: If the shape size of `lr`, `l1`, `l2` and `global_step` is not 0.
9078
9079    Supported Platforms:
9080        ``Ascend``
9081
9082    Examples:
9083        >>> class ApplyAdagradDANet(nn.Cell):
9084        ...     def __init__(self, use_locking=False):
9085        ...         super(ApplyAdagradDANet, self).__init__()
9086        ...         self.apply_adagrad_d_a = P.ApplyAdagradDA(use_locking)
9087        ...         self.var = Parameter(Tensor(np.array([[0.6, 0.4], [0.1, 0.5]]).astype(np.float32)), name="var")
9088        ...         self.gradient_accumulator = Parameter(Tensor(np.array([[0.1, 0.3],
9089        ...                                                                [0.1, 0.5]]).astype(np.float32)),
9090        ...                                               name="gradient_accumulator")
9091        ...         self.gradient_squared_accumulator = Parameter(Tensor(np.array([[0.2, 0.1],
9092        ...                                                                        [0.1, 0.2]]).astype(np.float32)),
9093        ...                                                       name="gradient_squared_accumulator")
9097        ...     def construct(self, grad, lr, l1, l2, global_step):
9098        ...         out = self.apply_adagrad_d_a(self.var, self.gradient_accumulator,
9099        ...                                      self.gradient_squared_accumulator, grad, lr, l1, l2, global_step)
9100        ...         return out
9101        ...
9102        >>> net = ApplyAdagradDANet()
9103        >>> grad = Tensor(np.array([[0.3, 0.4], [0.1, 0.2]]).astype(np.float32))
9104        >>> lr = Tensor(0.001, mstype.float32)
9105        >>> l1 = Tensor(0.001, mstype.float32)
9106        >>> l2 = Tensor(0.001, mstype.float32)
9107        >>> global_step = Tensor(2, mstype.int32)
9108        >>> output = net(grad, lr, l1, l2, global_step)
9109        >>> print(output)
9110        (Tensor(shape=[2, 2], dtype=Float32, value=
9111        [[-7.39064650e-04, -1.36888528e-03],
9112         [-5.96988888e-04, -1.42478070e-03]]), Tensor(shape=[2, 2], dtype=Float32, value=
9113        [[ 4.00000006e-01,  7.00000048e-01],
9114         [ 2.00000003e-01,  6.99999988e-01]]), Tensor(shape=[2, 2], dtype=Float32, value=
9115        [[ 2.90000021e-01,  2.60000020e-01],
9116         [ 1.09999999e-01,  2.40000010e-01]]))
9117    """
9118
9119    __mindspore_signature__ = (
9120        sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
9121        sig.make_sig('gradient_accumulator', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
9122        sig.make_sig('gradient_squared_accumulator', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
9123        sig.make_sig('grad', dtype=sig.sig_dtype.T),
9124        sig.make_sig('lr', dtype=sig.sig_dtype.T1),
9125        sig.make_sig('l1', dtype=sig.sig_dtype.T2),
9126        sig.make_sig('l2', dtype=sig.sig_dtype.T3),
9127        sig.make_sig('global_step', dtype=sig.sig_dtype.T4)
9128    )
9129
9130    @prim_attr_register
9131    def __init__(self, use_locking=False):
9132        """Initialize ApplyAdagradDA"""
9133        validator.check_value_type("use_locking", use_locking, [bool], self.name)
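

# A NumPy reference sketch of the ApplyAdagradDA update rule documented above (illustrative only:
# the helper name and the NumPy re-implementation are assumptions for clarity, not part of the
# operator's API; the real primitive updates its Parameter inputs in place on device).
def _apply_adagrad_da_reference(var, grad_accum, grad_sq_accum, grad, lr, l1, l2, global_step):
    """Pure-NumPy mirror of the ApplyAdagradDA docstring formulas; `var` is replaced, not read."""
    grad_accum = grad_accum + grad
    grad_sq_accum = grad_sq_accum + grad * grad
    if l1 > 0:
        tmp_val = np.sign(grad_accum) * np.maximum(np.abs(grad_accum) - l1 * global_step, 0)
    else:
        tmp_val = grad_accum
    x_value = -1 * lr * tmp_val
    y_value = l2 * global_step * lr + np.sqrt(grad_sq_accum)
    var = x_value / y_value
    return var, grad_accum, grad_sq_accum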
9134
9135
9136class SparseApplyRMSProp(Primitive):
9137    r"""
9138    Update relevant entries according to the rmsprop algorithm.
9139
9140    .. math::
9141        \begin{array}{ll} \\
9142            ms = rho * ms_{t-1} + (1 - rho) * grad * grad \\
            mom = momentum * mom_{t-1} + lr * grad / \sqrt{ms + epsilon} \\
9144            var = var - mom
9145        \end{array}
9146
9147    Inputs of `var`, `ms`, `mom` and `grad` comply with the implicit type conversion rules
9148    to make the data types consistent.
9149    If they have different data types, lower priority data type will be converted to
9150    relatively highest priority data type.
9151    RuntimeError exception will be thrown when the data type conversion of Parameter is required.
9152
9153    Args:
        rho (float): Decay rate. The value should be between 0 and 1, otherwise the behavior is undefined.
9155        momentum (float): Momentum. The value should be greater or equal to 0, otherwise the behavior is undefined.
9156        epsilon (float): A small value added for numerical stability. The value should be greater than 0,
9157                         otherwise the behavior is undefined.
9158        use_locking (bool): If `True`, updating of the var, ms, and mom tensors is protected by a lock;
9159                            otherwise the behavior is undefined, but may exhibit less contention. Default: False.
9160
9161    Inputs:
9162        - **var** (Parameter) - Variable to be updated. The data type must be float16 or float32.
9163          The shape is :math:`(N, *)` where :math:`*` means, any number of additional dimensions.
        - **ms** (Parameter) - The mean square gradient accumulator to be updated. Must have the same shape
          and dtype as `var`.
        - **mom** (Parameter) - The momentum accumulator to be updated. Must have the same shape and dtype as `var`.
9166        - **lr** ([Number, Tensor]) - Learning rate. Must be a scalar. With float16 or float32 data type.
9167        - **grad** (Tensor) - A tensor for gradient. Must have the same shape and dtype as `var`.
        - **indices** (Tensor) - A tensor of indices into the first dimension of `var`, `ms` and `mom`.
          If there are duplicates in `indices`, the behavior is undefined. Must be one of the
          following types: int32 or int64, and indices.shape[0] = var.shape[0].
9171
9172    Outputs:
9173        Tuple of 3 Tensors, the updated parameters.
9174
9175        - **var** (Tensor) -  The same shape and data type as `var`.
9176        - **ms** (Tensor) - The same shape and data type as `ms`.
9177        - **mom** (Tensor) - The same shape and data type as `mom`.
9178
9179    Raises:
9180        TypeError: If `var`, `ms` or `mom` is not a Parameter.
9181        TypeError: If `grad` or `indices` is not a Tensor.
9182        TypeError: If dtype of `var`, `ms`, `mom`, `lr`, `grad` is neither float16 nor float32.
9183        TypeError: If dtype of `indices` is neither int32 nor int64.
        TypeError: If `lr` is neither a Number nor a Tensor.
        TypeError: If `use_locking` is not a bool.
        TypeError: If `epsilon`, `rho` or `momentum` is not a float.
        ValueError: If the shape of `ms`, `mom` or `grad` is not the same as `var`.
9188        ValueError: If the shape size of `lr` is not 0.
9189        ValueError: If shape of `indices` is not same as shape of first dimension of `var`.
9190        ValueError: If `epsilon` is less than or equal to 0.
9191        ValueError: If `momentum` is less than 0.
9192        ValueError: If `rho` is less than 0 or greater than 1.
9193        ValueError: If dimension of `var` is less than 1.
9194
9195    Supported Platforms:
9196        ``Ascend``
9197
9198    Examples:
9199        >>> class SparseApplyRMSPropNet(nn.Cell):
9200        ...     def __init__(self, rho, momentum, epsilon, use_locking=False):
9201        ...         super(SparseApplyRMSPropNet, self).__init__()
9202        ...         self.sparse_apply_r_m_s_prop = P.SparseApplyRMSProp(rho, momentum, epsilon, use_locking)
9203        ...         self.var = Parameter(Tensor(np.array([[0.6, 0.3], [0.1, 0.5]]).astype(np.float32)), name="var")
9204        ...         self.ms = Parameter(Tensor(np.array([[0.2, 0.4], [0.1, 0.3]]).astype(np.float32)), name="ms")
9205        ...         self.mom = Parameter(Tensor(np.array([[0.3, 0.1], [0.3, 0.6]]).astype(np.float32)), name="mom")
9206        ...     def construct(self, lr, grad, indices):
9207        ...         out = self.sparse_apply_r_m_s_prop(self.var, self.ms, self.mom, lr, grad, indices)
9208        ...         return out
9209        ...
9210        >>> rho = 0.2
9211        >>> momentum = 0.01
9212        >>> epsilon = 1e-6
9213        >>> net = SparseApplyRMSPropNet(rho, momentum, epsilon)
9214        >>> lr = 0.01
9215        >>> grad = Tensor(np.array([[0.3, 0.7], [0.1, 0.8]]).astype(np.float32))
9216        >>> indices = Tensor(np.array([0, 1], dtype=np.int32))
9217        >>> out = net(lr, grad, indices)
9218        >>> print(out)
9219        (Tensor(shape=[2, 2], dtype=Float32, value=
9220        [[ 5.88035822e-01,  2.88811117e-01],
9221         [ 9.10239667e-02,  4.83422279e-01]]), Tensor(shape=[2, 2], dtype=Float32, value=
9222        [[ 1.12000003e-01,  4.72000003e-01],
9223         [ 2.80000009e-02,  5.72000027e-01]]), Tensor(shape=[2, 2], dtype=Float32, value=
9224        [[ 1.19641740e-02,  1.11888833e-02],
9225         [ 8.97603668e-03,  1.65777095e-02]]))
9226    """
9227
9228    __mindspore_signature__ = (
9229        sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
9230        sig.make_sig('ms', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
9231        sig.make_sig('mom', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
9232        sig.make_sig('lr', dtype=sig.sig_dtype.T1),
9233        sig.make_sig('grad', dtype=sig.sig_dtype.T),
9234        sig.make_sig('indices', dtype=sig.sig_dtype.T2)
9235    )
9236
9237    @prim_attr_register
9238    def __init__(self, rho, momentum, epsilon, use_locking=False):
9239        """"Initialize SparseApplyRMSProp"""
9240        validator.check_value_type("rho", rho, [float], self.name)
9241        validator.check_value_type("momentum", momentum, [float], self.name)
9242        validator.check_value_type("epsilon", epsilon, [float], self.name)
9243        validator.check_value_type("use_locking", use_locking, [bool], self.name)
9244        self.epsilon = validator.check_number("epsilon", epsilon, 0.0, Rel.GT, self.name)
9245        self.momentum = validator.check_number("momentum", momentum, 0.0, Rel.GE, self.name)
9246        self.rho = validator.check_float_range(rho, 0.0, 1.0, Rel.INC_BOTH, "rho", self.name)
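

# A NumPy reference sketch of the sparse RMSProp update rule documented above (illustrative only:
# the helper name is an assumption, the per-row application driven by `indices` is inferred from the
# input description, and the real primitive updates its Parameters in place on device).
def _sparse_apply_rms_prop_reference(var, ms, mom, lr, grad, indices, rho, momentum, epsilon):
    """Pure-NumPy mirror of the SparseApplyRMSProp docstring formulas, for illustration only."""
    for row, g in zip(indices, grad):
        ms[row] = rho * ms[row] + (1 - rho) * g * g
        mom[row] = momentum * mom[row] + lr * g / np.sqrt(ms[row] + epsilon)
        var[row] = var[row] - mom[row]
    return var, ms, mom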
9247