1# Copyright 2023 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Operators for nn."""
16from __future__ import absolute_import
17from __future__ import division
18
19import numbers
20import math
21import numpy as np
22from mindspore.ops import signature as sig
23from mindspore.ops.primitive import Primitive, prim_attr_register, prim_arg_register, PrimitiveWithInfer
24from mindspore.ops._primitive_cache import _get_cache_prim
25from mindspore.ops.auto_generate import gen_arg_handler as handler
26from mindspore.common import Tensor, CSRTensor, COOTensor
27from mindspore.common._stub_tensor import _convert_stub
28from mindspore._c_expression import typing
29from mindspore._c_expression import Tensor as Tensor_
30from mindspore._c_expression import pyboost_cast, pyboost_tile, pyboost_zeros, pyboost_ones
31from mindspore.common import dtype as mstype
32from mindspore.common._utils import is_shape_unknown
33from mindspore import _checkparam as validator
34from mindspore.ops.operations.manually_defined._inner import ScalarCast
35from mindspore.ops_generate.gen_ops_inner_prim import DtypeToEnum
36from mindspore.common.initializer import Zero
37from mindspore.common.parameter import Parameter
38from mindspore.ops.auto_generate.gen_ops_prim import FlashAttentionScore
39
40
41dtype_to_type_id = DtypeToEnum()
42
46
47class ScalarDiv(Primitive):
48    r"""
49    Computes the quotient of dividing the first input scalar by the second input scalar element-wise.
50
51    .. math::
52
53        out_{i} = \frac{x_i}{y_i}
54
55    .. note::
56        The inputs can be constant or variable values. Usage is the same as '/' in Python.
57        This primitive only has a 'CPU' implementation; on other platforms, it runs in heterogeneous mode.
58
59    Inputs:
60        - **x** (Scalar) - A constant or variable scalar.
61        - **y** (Scalar) - A constant or variable scalar.
62
63    Outputs:
64        Scalar, the type of scalar is float.
65
66    Raises:
67        TypeError: If `x` and `y` are not scalar.
68        ValueError: If `y` is 0.
69
70    Supported Platforms:
71        ``Ascend`` ``GPU`` ``CPU``
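
    Examples:
        >>> # A minimal illustrative sketch (not part of the original docstring): ScalarDiv
        >>> # follows Python's '/' operator, so integer inputs yield a float quotient.
        >>> div = ScalarDiv()
        >>> print(div(5, 2))
        2.5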
72    """
73    @prim_attr_register
74    def __init__(self):
75        """Initialize ScalarDiv"""
76
77    def __call__(self, x, y):
78        if y == 0:
79            raise ValueError('The divisor cannot be zero.')
80        return x / y
81
82
83class ScalarFloorDiv(Primitive):
84    r"""
85    Computes the floor quotient of dividing the first input scalar by the second input scalar.
86
87    .. math::
88
89        out_{i} = \left\lfloor \frac{x_i}{y_i} \right\rfloor
90
91    .. note::
92        The inputs can be constant or variable values. Usage is the same as '//' in Python.
93        This primitive only has a 'CPU' implementation; on other platforms, it runs in heterogeneous mode.
94
95    Inputs:
96        - **x** (Scalar) - A constant or variable scalar.
97        - **y** (Scalar) - A constant or variable scalar.
98
99    Outputs:
100        Scalar, the floor quotient of `x` divided by `y`; an int if both inputs are int, otherwise a float.
101
102    Raises:
103        TypeError: If `x` and `y` are not scalar.
104        ValueError: If `y` is 0.
105
106    Supported Platforms:
107        ``Ascend`` ``GPU`` ``CPU``
108    """
109    @prim_attr_register
110    def __init__(self):
111        """Initialize ScalarFloorDiv"""
112        self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])
113
114    def __call__(self, x, y):
115        if y == 0:
116            raise ValueError('The divisor cannot be zero.')
117        return x // y
118
119
120class ScalarAdd(Primitive):
121    r"""
122    Adds two input scalars.
123
124    .. note::
125        The inputs can be constant or variable values. Usage is the same as '+' in Python.
126        This primitive only has a 'CPU' implementation; on other platforms, it runs in heterogeneous mode.
127
128    Inputs:
129        - **x** (Scalar) - A constant or variable scalar.
130        - **y** (Scalar) - A constant or variable scalar.
131
132    Outputs:
133        Scalar, and the data type is the one with higher precision or higher digits among the two inputs.
134
135    Raises:
136        TypeError: If `x` and `y` are not scalar.
137
138    Supported Platforms:
139        ``Ascend`` ``GPU`` ``CPU``
140    """
141    @prim_attr_register
142    def __init__(self):
143        """Initialize ScalarAdd"""
144
145    def __call__(self, x, y):
146        return x + y
147
148
149class ScalarPow(Primitive):
150    r"""
151    Computes the first input scalar raised to the power of the second input scalar.
152
153    .. note::
154        The inputs can be constant or variable values. Usage is the same as `pow()` in Python.
155        This primitive only has a 'CPU' implementation; on other platforms, it runs in heterogeneous mode.
156
157    Inputs:
158        - **x** (Scalar) - A constant or variable scalar.
159        - **y** (Scalar) - A constant or variable scalar.
160
161    Outputs:
162        Scalar, and the data type is the one with higher precision or higher digits among the two inputs.
163
164    Raises:
165        TypeError: If `x` and `y` are not scalar.
166
167    Supported Platforms:
168        ``Ascend`` ``GPU`` ``CPU``
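
    Examples:
        >>> # A minimal illustrative sketch (not part of the original docstring): ScalarPow
        >>> # mirrors Python's built-in pow().
        >>> print(ScalarPow()(2, 10))
        1024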
169    """
170    @prim_attr_register
171    def __init__(self):
172        """Initialize ScalarPow"""
173
174    def __call__(self, x, y):
175        return pow(x, y)
176
177
178class ScalarLog(Primitive):
179    r"""
180    Computes the natural logarithm of the input scalar.
181
182    .. note::
183        The input can be a constant or variable value. Usage is the same as `math.log()` in Python.
184        This primitive only has a 'CPU' implementation; on other platforms, it runs in heterogeneous mode.
185
186    Inputs:
187        - **x** (Scalar) - A constant or variable scalar.
188
189    Outputs:
190        Scalar, the type is float.
191
192    Raises:
193        TypeError: If `x` is not a scalar.
194
195    Supported Platforms:
196        ``Ascend`` ``GPU`` ``CPU``
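
    Examples:
        >>> # A minimal illustrative sketch (not part of the original docstring): ScalarLog
        >>> # computes the natural logarithm, like math.log().
        >>> print(ScalarLog()(1.0))
        0.0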
197    """
198    @prim_attr_register
199    def __init__(self):
200        """Initialize ScalarLog"""
201
202    def __call__(self, x):
203        return math.log(x)
204
205
206class ScalarUadd(Primitive):
207    r"""
208    Applies the unary plus operator to the input scalar, returning it unchanged.
209
210    .. note::
211        The input can be a constant or variable value. Usage is the same as unary '+' in Python.
212        This primitive only has a 'CPU' implementation; on other platforms, it runs in heterogeneous mode.
213
214    Inputs:
215        - **x** (Scalar) - A constant or variable scalar.
216
217    Outputs:
218        Scalar, with the same value and type as `x`.
219
220    Raises:
221        TypeError: If `x` is not a scalar.
222
223    Supported Platforms:
224        ``Ascend`` ``GPU`` ``CPU``
225    """
226    @prim_attr_register
227    def __init__(self):
228        """Initialize ScalarUadd"""
229
230    def __call__(self, x):
231        return x
232
233
234class ScalarUsub(Primitive):
235    r"""
236    Negates the input scalar.
237
238    .. note::
239        The input can be a constant or variable value. Usage is the same as unary '-' in Python.
240        This primitive only has a 'CPU' implementation; on other platforms, it runs in heterogeneous mode.
241
242    Inputs:
243        - **x** (Scalar) - A constant or variable scalar.
244
245    Outputs:
246        Scalar, with the same type as `x`.
247
248    Raises:
249        TypeError: If `x` is not a scalar.
251
252    Supported Platforms:
253        ``Ascend`` ``GPU`` ``CPU``
254    """
255    @prim_attr_register
256    def __init__(self):
257        """Initialize ScalarUsub"""
258
259    def __call__(self, x):
260        return -x
261
262
263class ScalarSub(Primitive):
264    r"""
265    Subtracts the second input Scalar from the first input Scalar.
266
267    .. note::
268        The inputs can be constant or variable values. Usage is the same as '-' in Python.
269        This primitive only has a 'CPU' implementation; on other platforms, it runs in heterogeneous mode.
270
271    Inputs:
272        - **x** (Scalar) - A constant or variable scalar.
273        - **y** (Scalar) - A constant or variable scalar.
274
275    Outputs:
276        Scalar, and the data type is the one with higher precision or higher digits among the two inputs.
277
278    Raises:
279        TypeError: If `x` and `y` are not scalar.
280
281    Supported Platforms:
282        ``Ascend`` ``GPU`` ``CPU``
283    """
284    @prim_attr_register
285    def __init__(self):
286        """Initialize ScalarSub"""
287
288    def __call__(self, x, y):
289        return x - y
290
291
292class ScalarMul(Primitive):
293    r"""
294    Multiplies two input scalars.
295
296    .. note::
297        The inputs can be constant or variable values. Usage is the same as '*' in Python.
298        This primitive only has a 'CPU' implementation; on other platforms, it runs in heterogeneous mode.
299
300    Inputs:
301        - **x** (Scalar) - A constant or variable scalar.
302        - **y** (Scalar) - A constant or variable scalar.
303
304    Outputs:
305        Scalar, and the data type is the one with higher precision or higher digits among the two inputs.
306
307    Raises:
308        TypeError: If `x` and `y` are not scalar.
309
310    Supported Platforms:
311        ``Ascend`` ``GPU`` ``CPU``
312    """
313    @prim_attr_register
314    def __init__(self):
315        """Initialize ScalarMul"""
316
317    def __call__(self, x, y):
318        return x * y
319
320
321class ScalarEq(Primitive):
322    r"""
323    Computes the equivalence between two Scalars.
324
325    .. note::
326        The inputs can be constant or variable values. Usage is the same as '==' in Python.
327        This primitive only has a 'CPU' implementation; on other platforms, it runs in heterogeneous mode.
328
329    Inputs:
330        - **x** (Scalar) - A constant or variable scalar.
331        - **y** (Scalar) - A constant or variable scalar.
332
333    Outputs:
334        Scalar, the type of scalar is bool.
335
336    Raises:
337        TypeError: If `x` and `y` are not scalar.
338
339    Supported Platforms:
340        ``Ascend`` ``GPU`` ``CPU``
341    """
342    @prim_attr_register
343    def __init__(self):
344        """Initialize ScalarEq"""
345
346    def __call__(self, x, y):
347        return x == y
348
349
350class ScalarGt(Primitive):
351    r"""
352    Computes the boolean value of :math:`x > y`.
353
354    .. note::
355        The inputs can be constant or variable values. Usage is the same as '>' in Python.
356        This primitive only has a 'CPU' implementation; on other platforms, it runs in heterogeneous mode.
357
358    Inputs:
359        - **x** (Scalar) - A constant or variable scalar.
360        - **y** (Scalar) - A constant or variable scalar.
361
362    Outputs:
363        Scalar, the type of scalar is bool.
364
365    Raises:
366        TypeError: If `x` and `y` are not scalar.
367
368    Supported Platforms:
369        ``Ascend`` ``GPU`` ``CPU``
370    """
371    @prim_attr_register
372    def __init__(self):
373        """Initialize scalar_gt"""
374
375    def __call__(self, x, y):
376        return x > y
377
378
379class ScalarLt(Primitive):
380    r"""
381    Computes the boolean value of :math:`x < y`.
382
383    .. note::
384        The inputs can be constant or variable values. Usage is the same as '<' in Python.
385        This primitive only has a 'CPU' implementation; on other platforms, it runs in heterogeneous mode.
386
387    Inputs:
388        - **x** (Scalar) - A constant or variable scalar.
389        - **y** (Scalar) - A constant or variable scalar.
390
391    Outputs:
392        Scalar, the type of scalar is bool.
393
394    Raises:
395        TypeError: If `x` and `y` are not scalar.
396
397    Supported Platforms:
398        ``Ascend`` ``GPU`` ``CPU``
399    """
400    @prim_attr_register
401    def __init__(self):
402        """Initialize scalar_lt"""
403
404    def __call__(self, x, y):
405        return x < y
406
407
408class ScalarGe(Primitive):
409    r"""
410    Computes the boolean value of :math:`x >= y`.
411
412    .. note::
413        The inputs can be constant or variable values. Usage is the same as '>=' in Python.
414        This primitive only has a 'CPU' implementation; on other platforms, it runs in heterogeneous mode.
415
416    Inputs:
417        - **x** (Scalar) - A constant or variable scalar.
418        - **y** (Scalar) - A constant or variable scalar.
419
420    Outputs:
421        Scalar, the type of scalar is bool.
422
423    Raises:
424        TypeError: If `x` and `y` are not scalar.
425
426    Supported Platforms:
427        ``Ascend`` ``GPU`` ``CPU``
428    """
429    @prim_attr_register
430    def __init__(self):
431        """Initialize scalar_ge"""
432
433    def __call__(self, x, y):
434        return x >= y
435
436
437class ScalarLe(Primitive):
438    r"""
439    Computes the boolean value of :math:`x <= y`.
440
441    .. note::
442        The inputs can be constant or variable values. Usage is the same as '<=' in Python.
443        This primitive only has a 'CPU' implementation; on other platforms, it runs in heterogeneous mode.
444
445    Inputs:
446        - **x** (Scalar) - A constant or variable scalar.
447        - **y** (Scalar) - A constant or variable scalar.
448
449    Outputs:
450        Scalar, the type of scalar is bool.
451
452    Raises:
453        TypeError: If `x` and `y` are not scalar.
454
455    Supported Platforms:
456        ``Ascend`` ``GPU`` ``CPU``
457    """
458    @prim_attr_register
459    def __init__(self):
460        """Initialize scalar_le"""
461
462    def __call__(self, x, y):
463        return x <= y
464
465
466class ScalarMod(Primitive):
467    r"""
468    Computes the remainder of dividing the first input scalar by the second input scalar element-wise.
469
470    .. math::
471
472        out_{i} = x_{i} \text{ % } y_{i}
473
474    .. note::
475        The inputs can be constant or variable values. Usage is the same as '%' in Python.
476        This primitive only has a 'CPU' implementation; on other platforms, it runs in heterogeneous mode.
477
478    Inputs:
479        - **x** (Scalar) - A constant or variable scalar.
480        - **y** (Scalar) - A constant or variable scalar.
481
482    Outputs:
483        Scalar, the type is the one with higher precision or higher digits among the two inputs.
484
485    Raises:
486        TypeError: If `x` and `y` are not scalar.
487
488    Supported Platforms:
489        ``Ascend`` ``GPU`` ``CPU``
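
    Examples:
        >>> # A minimal illustrative sketch (not part of the original docstring): ScalarMod
        >>> # mirrors Python's '%' operator.
        >>> print(ScalarMod()(7, 3))
        1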
490    """
491    @prim_attr_register
492    def __init__(self):
493        """Initialize ScalarMod"""
494
495    def __call__(self, x, y):
496        if y == 0:
497            raise ValueError('Cannot perform modulo operation on zero.')
498        return x % y
499
500
501class ScalarBool(Primitive):
502    r"""
503    Converts the input scalar to a bool value.
504
505    .. note::
506        The input can be a constant or variable value.
507        This primitive only has a 'CPU' implementation; on other platforms, it runs in heterogeneous mode.
508
509    Inputs:
510        - **x** (Scalar) - A constant or variable scalar.
511
512    Outputs:
513        Scalar, the type is bool.
514
515    Raises:
516        TypeError: If `x` is not a scalar.
517
518    Supported Platforms:
519        ``Ascend`` ``GPU`` ``CPU``
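
    Examples:
        >>> # A minimal illustrative sketch (not part of the original docstring): ScalarBool
        >>> # mirrors Python's built-in bool().
        >>> print(ScalarBool()(0.0))
        False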
520    """
521    @prim_attr_register
522    def __init__(self):
523        """Initialize ScalarBool"""
524
525    def __call__(self, x):
526        return bool(x)
527
528
529scalar_div = ScalarDiv()
530scalar_mod = ScalarMod()
531scalar_add = ScalarAdd()
532scalar_mul = ScalarMul()
533scalar_sub = ScalarSub()
534scalar_gt = ScalarGt()
535scalar_ge = ScalarGe()
536scalar_le = ScalarLe()
537scalar_lt = ScalarLt()
538scalar_eq = ScalarEq()
539scalar_bool = ScalarBool()
540scalar_floordiv = ScalarFloorDiv()
541scalar_log = ScalarLog()
542scalar_pow = ScalarPow()
543scalar_uadd = ScalarUadd()
544scalar_usub = ScalarUsub()
545
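# A hedged usage sketch (illustrative, not part of the original source): the module-level
# instances above behave like the corresponding Python operators on plain scalars, e.g.
#   scalar_add(1, 2)       # -> 3
#   scalar_floordiv(7, 2)  # -> 3
#   scalar_pow(2, 10)      # -> 1024
#   scalar_eq(3, 3.0)      # -> True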
546
547class BatchNorm(Primitive):
548    r"""
549    Batch Normalization for input data and updated parameters.
550
551    Batch Normalization is widely used in convolutional neural networks. This operation
552    applies Batch Normalization over inputs to avoid internal covariate shift as described
553    in the paper `Batch Normalization: Accelerating Deep Network Training by Reducing Internal
554    Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It rescales and recenters the
555    features using a mini-batch of data and the learned parameters can be described
556    in the following formula,
557
558    .. math::
559
560        y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta
561
562    where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon,
563    :math:`mean` is the mean of :math:`x`,
564    :math:`variance` is the variance of :math:`x`.
565
566    .. warning::
567        - If the operation is used for inference, and outputs "reserve_space_1" and "reserve_space_2" are available,
568          then "reserve_space_1" has the same value as "mean" and "reserve_space_2" has the same value as "variance".
569        - For Ascend 310, the result accuracy fails to reach 1‰ due to the square root instruction.
570
571    Args:
572        is_training (bool): If `is_training` is ``True`` , `mean` and `variance` are computed during training.
573            If `is_training` is ``False`` , they're loaded from checkpoint during inference. Default: ``False`` .
574        epsilon (float): A small value added for numerical stability. Default: ``1e-5``. The value must be in (0, 1].
575        momentum (float): The hyper parameter to compute moving average for running_mean and running_var
576            (e.g. :math:`new\_running\_mean = (1 - momentum) * running\_mean + momentum * current\_mean`).
577            Momentum value must be in [0, 1]. Default: ``0.1`` .
578        data_format (str): The optional value for data format, is ``'NHWC'`` or ``'NCHW'``, and the ``'NHWC'`` format
579            is only supported in GPU target. Default: ``"NCHW"`` .
580
581    Inputs:
582        If `is_training` is ``False`` , inputs are Tensors.
583
584        - **input_x** (Tensor) - Tensor of shape :math:`(N, C)`, with float16 or float32 data type.
585        - **scale** (Tensor) - Tensor of shape :math:`(C,)`, with float16 or float32 data type.
586        - **bias** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type as `scale`.
587        - **mean** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type as `scale`.
588        - **variance** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type as `scale`.
589
590        If `is_training` is ``True`` , `scale`, `bias`, `mean` and `variance` are Parameters.
591
592        - **input_x** (Tensor) - Tensor of shape :math:`(N, C)`, with float16 or float32 data type.
593        - **scale** (Parameter) - Parameter of shape :math:`(C,)`, with float16 or float32 data type.
594        - **bias** (Parameter) - Parameter of shape :math:`(C,)`, has the same data type as `scale`.
595        - **mean** (Parameter) - Parameter of shape :math:`(C,)`, has the same data type as `scale`.
596        - **variance** (Parameter) - Parameter of shape :math:`(C,)`, has the same data type as `scale`.
597
598    Outputs:
599        Tuple of 5 Tensors, the normalized inputs and the updated parameters.
600
601        - **output_x** (Tensor) - The same type and shape as the input_x. The shape is :math:`(N, C)`.
602        - **batch_mean** (Tensor) - The mean calculated per-dimension over the mini-batches,
603          shape is :math:`(C,)`.
604        - **batch_variance** (Tensor) - The variance calculated per-dimension over the mini-batches,
605          shape is :math:`(C,)`.
606        - **reserve_space_1** (Tensor) - The mean that needs to be reused when calculating gradients,
607          one-dimensional Tensor. The shape is :math:`(C,)`.
608        - **reserve_space_2** (Tensor) - The variance that needs to be reused when calculating gradients,
609          one-dimensional Tensor. The shape is :math:`(C,)`.
610
611    Raises:
612        TypeError: If `is_training` is not a bool.
613        TypeError: If dtype of `epsilon` or `momentum` is not float.
614        TypeError: If `data_format` is not a str.
615        TypeError: If `input_x`, `scale`, `bias`, `mean` or `variance` is not a Tensor.
616        TypeError: If dtype of `input_x`, `scale` is neither float16 nor float32.
617
618    Supported Platforms:
619        ``Ascend`` ``GPU`` ``CPU``
620
621    Examples:
622        >>> import mindspore
623        >>> import numpy as np
624        >>> from mindspore import Tensor, ops
625        >>> input_x = Tensor(np.ones([2, 2]), mindspore.float32)
626        >>> scale = Tensor(np.ones([2]), mindspore.float32)
627        >>> bias = Tensor(np.ones([2]), mindspore.float32)
628        >>> mean = Tensor(np.ones([2]), mindspore.float32)
629        >>> variance = Tensor(np.ones([2]), mindspore.float32)
630        >>> batch_norm = ops.BatchNorm()
631        >>> output = batch_norm(input_x, scale, bias, mean, variance)
632        >>> print(output[0])
633        [[1. 1.]
634         [1. 1.]]
635    """
636    __mindspore_signature__ = (sig.make_sig('input_x', dtype=sig.sig_dtype.T1),
637                               sig.make_sig('scale',
638                                            sig.sig_rw.RW_WRITE,
639                                            dtype=sig.sig_dtype.T2),
640                               sig.make_sig('bias',
641                                            sig.sig_rw.RW_WRITE,
642                                            dtype=sig.sig_dtype.T2),
643                               sig.make_sig('mean',
644                                            sig.sig_rw.RW_WRITE,
645                                            dtype=sig.sig_dtype.T3),
646                               sig.make_sig('variance',
647                                            sig.sig_rw.RW_WRITE,
648                                            dtype=sig.sig_dtype.T3))
649
650    @prim_arg_register
651    def __init__(self,
652                 is_training=False,
653                 epsilon=1e-5,
654                 momentum=0.1,
655                 data_format="NCHW"):
656        """Initialize BatchNorm."""
657        if is_training is False:
658            self.set_signatures(tuple())
659        else:
660            self.add_prim_attr('side_effect_mem', True)
661        self.is_training = is_training
662        self.epsilon = epsilon
663        self.momentum = momentum
664        self.data_format = handler.str_to_enum("BatchNorm", "data_format", data_format)
665
666    def __call__(self, *args):
667        return super().__call__(*args, self.is_training, self.epsilon,
668                                self.momentum, self.data_format)
669
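# A hedged sketch of training-mode usage (illustrative, not part of the original source):
# with is_training=True the statistics arguments are expected to be Parameters, matching
# the RW_WRITE entries in __mindspore_signature__ so they can be updated in place, e.g.
#   bn = BatchNorm(is_training=True, epsilon=1e-5, momentum=0.1)
#   scale = Parameter(Tensor(np.ones([2]), mstype.float32), name="scale")
#   # bias, mean and variance are defined analogously before calling
#   # bn(input_x, scale, bias, mean, variance)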
670
671def batch_norm_(input_x,
672                scale,
673                bias,
674                mean,
675                variance,
676                is_training=False,
677                epsilon=1e-5,
678                momentum=0.1,
679                data_format="NCHW"):
680    r"""
681    Batch Normalization for input data and updated parameters.
682
683    Batch Normalization is widely used in convolutional neural networks. This operation
684    applies Batch Normalization over inputs to avoid internal covariate shift as described
685    in the paper `Batch Normalization: Accelerating Deep Network Training by Reducing Internal
686    Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It rescales and recenters the
687    features using a mini-batch of data and the learned parameters can be described
688    in the following formula,
689
690    .. math::
691
692        y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta
693
694    where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon,
695    :math:`mean` is the mean of :math:`x`,
696    :math:`variance` is the variance of :math:`x`.
697
698    .. warning::
699        - If the operation is used for inference, and outputs "reserve_space_1" and "reserve_space_2" are available,
700          then "reserve_space_1" has the same value as "mean" and "reserve_space_2" has the same value as "variance".
701        - For Atlas 200/300/500 inference product,
702          the result accuracy fails to reach 1‰ due to the square root instruction.
703
704    Note:
705        - If `is_training` is `False`, `scale`, `bias`, `mean` and `variance` are Tensors.
706        - If `is_training` is `True`, `scale`, `bias`, `mean` and `variance` are Parameters.
707
708    Args:
709        input_x (Tensor): Tensor of shape :math:`(N, C)`, with float16 or float32 data type.
710        scale (Union[Tensor, Parameter]): The shape :math:`(C,)`, with float16 or float32 data type.
711        bias (Union[Tensor, Parameter]): The shape :math:`(C,)`, has the same data type as `scale`.
712        mean (Union[Tensor, Parameter]): The shape :math:`(C,)`, has the same data type as `scale`.
713        variance (Union[Tensor, Parameter]): The shape :math:`(C,)`, has the same data type as `scale`.
714        is_training (bool, optional): If `is_training` is `True`, `mean` and `variance` are computed during training.
715            If `is_training` is `False`, they're loaded from checkpoint during inference. Default: ``False``.
716        epsilon (float): A small value added for numerical stability.
717            Default: ``1e-5``. The value must be in (0, 1].
718        momentum (float): The hyper parameter to compute moving average for running_mean and running_var
719            (e.g. :math:`new\_running\_mean = (1 - momentum) * running\_mean + momentum * current\_mean`).
720            Momentum value must be in [0, 1].
721            Default: ``0.1`` .
722        data_format (str): The optional value for data format, is ``'NHWC'`` or ``'NCHW'``,
723            and the ``'NHWC'`` format is only supported in GPU target.
724            Default: ``"NCHW"`` .
725
726    Returns:
727        output_x (Tensor): The same type and shape as the input_x. The shape is :math:`(N, C)`.
728        batch_mean (Tensor): Tensor of shape :math:`(C,)`.
729        batch_variance (Tensor): Tensor of shape :math:`(C,)`.
730        reserve_space_1 (Tensor): Tensor of shape :math:`(C,)`.
731        reserve_space_2 (Tensor): Tensor of shape :math:`(C,)`.
732
733    Raises:
734        TypeError: If `is_training` is not a bool.
735        TypeError: If dtype of `epsilon` or `momentum` is not float.
736        TypeError: If `data_format` is not a str.
737        TypeError: If `input_x`, `scale`, `bias`, `mean` or `variance` is not a Tensor.
738        TypeError: If dtype of `input_x`, `scale` is neither float16 nor float32.
739
740    Supported Platforms:
741        ``Ascend`` ``GPU`` ``CPU``
742
743    Examples:
744        >>> import mindspore
745        >>> import numpy as np
746        >>> from mindspore import Tensor, ops
747        >>> input_x = Tensor(np.ones([2, 2]), mindspore.float32)
748        >>> scale = Tensor(np.ones([2]), mindspore.float32)
749        >>> bias = Tensor(np.ones([2]), mindspore.float32)
750        >>> mean = Tensor(np.ones([2]), mindspore.float32)
751        >>> variance = Tensor(np.ones([2]), mindspore.float32)
752        >>> output = ops.batch_norm_(input_x, scale, bias, mean, variance, is_training=False)
753        >>> print(output[0])
754        [[1. 1.]
755         [1. 1.]]
756    """
757    batch_norm_op = _get_cache_prim(BatchNorm)(is_training, epsilon, momentum,
758                                               data_format)
759    return batch_norm_op(input_x, scale, bias, mean, variance)
760
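# A hedged note on the functional wrappers in this file (illustrative, not part of the
# original source): _get_cache_prim caches primitive instances by their constructor
# arguments, so repeated functional calls reuse the same Primitive object, e.g.
#   op1 = _get_cache_prim(BatchNorm)(False, 1e-5, 0.1, "NCHW")
#   op2 = _get_cache_prim(BatchNorm)(False, 1e-5, 0.1, "NCHW")
#   # op1 is op2 -> True (assuming identical arguments)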
761
762class Rank(Primitive):
763    """
764    Returns the rank of a tensor.
765
766    Refer to :func:`mindspore.ops.rank` for more details.
767
768    Supported Platforms:
769        ``Ascend`` ``GPU`` ``CPU``
770
771    Examples:
772        >>> import mindspore
773        >>> import numpy as np
774        >>> from mindspore import Tensor, ops
775        >>> input_tensor = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
776        >>> rank = ops.Rank()
777        >>> output = rank(input_tensor)
778        >>> print(output)
779        2
780        >>> print(type(output))
781        <class 'int'>
782    """
783
784    @prim_attr_register
785    def __init__(self):
786        """Initialize Rank"""
787
788    def __call__(self, x):
789        if not isinstance(x, (Tensor, Tensor_)):
790            raise TypeError("the input x must be Tensor!")
791        return len(x.shape)
792
793
794def rank(input_x):
795    """
796    Returns the rank of a tensor.
797
798    Returns a 0-D int32 Tensor representing the rank of input; the rank of a tensor
799    is the number of indices required to uniquely select each element of the tensor.
800
801    Args:
802        input_x (Tensor): The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. The data type is Number.
803
804    Returns:
805        Tensor. 0-D int32 Tensor representing the rank of input, i.e., :math:`R`. The data type is an int.
806
807    Raises:
808        TypeError: If `input_x` is not a Tensor.
809
810    Supported Platforms:
811        ``Ascend`` ``GPU`` ``CPU``
812
813    Examples:
814        >>> import mindspore
815        >>> import numpy as np
816        >>> from mindspore import Tensor, ops
817        >>> input_tensor = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
818        >>> output = ops.rank(input_tensor)
819        >>> print(output)
820        2
821        >>> print(type(output))
822        <class 'int'>
823
824    """
825    rank_op = _get_cache_prim(Rank)()
826    return rank_op(input_x)
827
828
829class Shape(Primitive):
830    """
831    Returns the shape of the input tensor.
832
833    Refer to :func:`mindspore.ops.shape` for more details.
834
835    Inputs:
836        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
837
838    Outputs:
839        tuple[int], the output tuple is constructed by multiple integers,
840        :math:`(x_1, x_2, ..., x_R)`.
841
842    Supported Platforms:
843        ``Ascend`` ``GPU`` ``CPU``
844
845    Examples:
846        >>> import mindspore
847        >>> import numpy as np
848        >>> from mindspore import Tensor, ops
849        >>> input_x = Tensor(np.ones(shape=[3, 2, 1]), mindspore.float32)
850        >>> shape = ops.Shape()
851        >>> output = shape(input_x)
852        >>> print(output)
853        (3, 2, 1)
854    """
855
856    @prim_attr_register
857    def __init__(self):
858        """Initialize Shape"""
859
860    def __call__(self, x):
861        if isinstance(x, (Tensor, COOTensor, CSRTensor, Tensor_)):
862            return x.shape
863        raise TypeError(f"For primitive[{self.name}], the input argument must be Tensor, but got {type(x)}.")
864
865
866def shape_(input_x):
867    """
868    Returns the shape of the input tensor.
869
870    Args:
871        input_x (Tensor): The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
872
873    Returns:
874        tuple[int], the output tuple is constructed by multiple integers,
875        :math:`(x_1, x_2, ..., x_R)`.
876
877    Raises:
878        TypeError: If `input_x` is not a Tensor.
879
880    Supported Platforms:
881        ``Ascend`` ``GPU`` ``CPU``
882
883    Examples:
884        >>> import mindspore
885        >>> import numpy as np
886        >>> from mindspore import Tensor, ops
887        >>> input_x = Tensor(np.ones(shape=[3, 2, 1]), mindspore.float32)
888        >>> output = ops.shape(input_x)
889        >>> print(output)
890        (3, 2, 1)
891    """
892    shape_op = _get_cache_prim(Shape)()
893    return shape_op(input_x)
894
895
896class ScalarToTensor(PrimitiveWithInfer):
897    """
898    Converts a scalar to a `Tensor`, and converts the data type to the specified type.
899
900    Refer to :func:`mindspore.ops.scalar_to_tensor` for more details.
901
902    Inputs:
903        - **input_x** (Union[int, float]) - The input is a scalar. Only constant value is allowed.
904        - **dtype** (mindspore.dtype) - The target data type. Default: ``mindspore.float32`` . Only
905          constant value is allowed.
906
907    Outputs:
908        Tensor. 0-D Tensor and the content is the input.
909
910    Supported Platforms:
911        ``Ascend`` ``GPU`` ``CPU``
912
913    Examples:
914        >>> import mindspore
915        >>> from mindspore import ops
916        >>> op = ops.ScalarToTensor()
917        >>> data = 1
918        >>> output = op(data, mindspore.float32)
919        >>> print(output)
920        1.0
921    """
922
923    @prim_attr_register
924    def __init__(self):
925        self.init_prim_io_names(inputs=['input_scalar', 'dtype'], outputs=['output_data'])
926
927    def __call__(self, x, dtype=mstype.float32):
928        validator.check_value_type("x", x, [bool, int, float], self.name)
929        validator.check_subclass("dtype", dtype, mstype.number, self.name)
930        data_type = mstype.dtype_to_nptype(dtype)
931        return Tensor(np.array(x, data_type), dtype=dtype)
932
933
934class Tile(Primitive):
935    r"""
936    Replicates an input tensor the given number of times along each dimension.
937
938    Refer to :func:`mindspore.ops.tile` for more details.
939
940    Inputs:
941        - **input** (Tensor) - The tensor whose elements need to be repeated. Set the shape of input tensor as
942          :math:`(x_1, x_2, ..., x_S)` .
943        - **dims** (tuple[int]) - The parameter that specifies the number of replications,
944          the parameter type is tuple, and the data type is int, i.e., :math:`(y_1, y_2, ..., y_S)`.
945          Only constant value is allowed.
946
947    Outputs:
948        Tensor, has the same data type as the `input`. Suppose the length of `dims` is `d`,
949        the dimension of `input` is `input.dim`, and the shape of `input` is :math:`(x_1, x_2, ..., x_S)`.
950
951        - If `input.dim = d`, then the shape of their corresponding positions can be multiplied, and
952          the shape of Outputs is :math:`(x_1*y_1, x_2*y_2, ..., x_S*y_S)`.
953        - If `input.dim < d`, prepend 1 to the shape of `input` until their lengths are consistent.
954          Such as set the shape of `input` as :math:`(1, ..., x_1, x_2, ..., x_S)`,
955          then the shape of their corresponding positions can be multiplied, and the shape of Outputs is
956          :math:`(1*y_1, ..., x_R*y_R, x_S*y_S)`.
957        - If `input.dim > d`, prepend 1 to `dims` until their lengths are consistent. Such as set the
958          `dims` as :math:`(1, ..., y_1, y_2, ..., y_S)`, then the shape of their corresponding positions
959          can be multiplied, and the shape of Outputs is :math:`(x_1*1, ..., x_R*y_R, x_S*y_S)`.
960
961    Raises:
962        TypeError: If `dims` is not a tuple or its elements are not all int.
963        ValueError: If the elements of `dims` are not all greater than or equal to 0.
964
965    Supported Platforms:
966        ``Ascend`` ``GPU`` ``CPU``
967
968    Examples:
969        >>> import mindspore
970        >>> import numpy as np
971        >>> from mindspore import Tensor, ops
972        >>> tile = ops.Tile()
973        >>> input = Tensor(np.array([[1, 2], [3, 4]]), mindspore.float32)
974        >>> dims = (2, 3)
975        >>> output = tile(input, dims)
976        >>> print(output)
977        [[1.  2.  1.  2.  1.  2.]
978         [3.  4.  3.  4.  3.  4.]
979         [1.  2.  1.  2.  1.  2.]
980         [3.  4.  3.  4.  3.  4.]]
981        >>> dims = (2, 3, 2)
982        >>> output = tile(input, dims)
983        >>> print(output)
984        [[[1. 2. 1. 2.]
985          [3. 4. 3. 4.]
986          [1. 2. 1. 2.]
987          [3. 4. 3. 4.]
988          [1. 2. 1. 2.]
989          [3. 4. 3. 4.]]
990         [[1. 2. 1. 2.]
991          [3. 4. 3. 4.]
992          [1. 2. 1. 2.]
993          [3. 4. 3. 4.]
994          [1. 2. 1. 2.]
995          [3. 4. 3. 4.]]]
996    """
997
998    @prim_attr_register
999    def __init__(self):
1000        """Initialize."""
1001
1002    def __call__(self, input, dims):
1003        return _convert_stub(pyboost_tile(self, [input, dims]))
1004
1005    # pylint: disable=missing-docstring
1006    def check_elim(self, *args):
1007        base_tensor, dims = args
1008        if not isinstance(base_tensor, Tensor):
1009            raise TypeError(f"For '{self.name}', the type of 'input' must be Tensor, "
1010                            f"but got {type(base_tensor).__name__}.")
1011        if not isinstance(dims, tuple):
1012            raise TypeError(f"For '{self.name}', the type of 'dims' must be tuple, "
1013                            f"but got {type(dims).__name__}.")
1014
1015        if all(v == 1 for v in dims) and len(base_tensor.shape) >= len(dims):
1016            from mindspore.ops.auto_generate.gen_ops_def import Identity
1017            ret = Identity()(base_tensor)
1018            return (True, ret)
1019        return (False, None)
1020
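# A hedged sketch of the elimination above (illustrative, not part of the original source):
# when every entry of `dims` is 1 and the tensor already has at least len(dims) dimensions,
# tiling is a no-op, so check_elim short-circuits through Identity, e.g.
#   x = Tensor(np.ones((2, 3)), mstype.float32)
#   Tile().check_elim(x, (1, 1))  # -> (True, <tensor equal to x>)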
1021
1022def tile(input, dims):
1023    r"""
1024    Creates a new tensor by replicating `input` `dims` times. The i'th dimension of
1025    output tensor has `input.shape[i] * dims[i]` elements, and the values of `input`
1026    are replicated `dims[i]` times along the i'th dimension.
1027
1028    Args:
1029        input (Tensor): The tensor whose elements need to be repeated. Set the shape of input tensor as
1030            :math:`(x_1, x_2, ..., x_S)` .
1031
1032        dims (tuple[int]): The parameter that specifies the number of replications,
1033            the parameter type is tuple, and the data type is int, i.e., :math:`(y_1, y_2, ..., y_S)`.
1034            Only constant value is allowed.
1035
1036    Returns:
1037        Tensor, has the same data type as the `input`. Suppose the length of `dims` is `d`,
1038        the dimension of `input` is `input.dim`, and the shape of `input` is :math:`(x_1, x_2, ..., x_S)`.
1039
1040        - If `input.dim = d`, then the shape of their corresponding positions can be multiplied, and
1041          the shape of Outputs is :math:`(x_1*y_1, x_2*y_2, ..., x_S*y_S)`.
1042        - If `input.dim < d`, prepend 1 to the shape of `input` until their lengths are consistent.
1043          Such as set the shape of `input` as :math:`(1, ..., x_1, x_2, ..., x_S)`,
1044          then the shape of their corresponding positions can be multiplied, and the shape of Outputs is
1045          :math:`(1*y_1, ..., x_R*y_R, x_S*y_S)`.
1046        - If `input.dim > d`, prepend 1 to `dims` until their lengths are consistent. Such as set the
1047          `dims` as :math:`(1, ..., y_1, y_2, ..., y_S)`, then the shape of their corresponding positions
1048          can be multiplied, and the shape of Outputs is :math:`(x_1*1, ..., x_R*y_R, x_S*y_S)`.
1049
1050    Raises:
1051        TypeError: If `dims` is not a tuple or its elements are not all int.
1052        ValueError: If the elements of `dims` are not all greater than or equal to 0.
1053
1054    Supported Platforms:
1055        ``Ascend`` ``GPU`` ``CPU``
1056
1057    Examples:
1058        >>> import mindspore
1059        >>> import numpy as np
1060        >>> from mindspore import Tensor, ops
1061        >>> input = Tensor(np.array([[1, 2], [3, 4]]), mindspore.float32)
1062        >>> dims = (2, 3)
1063        >>> output = ops.tile(input, dims)
1064        >>> print(output)
1065        [[1.  2.  1.  2.  1.  2.]
1066         [3.  4.  3.  4.  3.  4.]
1067         [1.  2.  1.  2.  1.  2.]
1068         [3.  4.  3.  4.  3.  4.]]
1069        >>> dims = (2, 3, 2)
1070        >>> output = ops.tile(input, dims)
1071        >>> print(output)
1072        [[[1. 2. 1. 2.]
1073          [3. 4. 3. 4.]
1074          [1. 2. 1. 2.]
1075          [3. 4. 3. 4.]
1076          [1. 2. 1. 2.]
1077          [3. 4. 3. 4.]]
1078         [[1. 2. 1. 2.]
1079          [3. 4. 3. 4.]
1080          [1. 2. 1. 2.]
1081          [3. 4. 3. 4.]
1082          [1. 2. 1. 2.]
1083          [3. 4. 3. 4.]]]
1084    """
1085    tile_op = _get_cache_prim(Tile)()
1086    return tile_op(input, dims)
1087
1088
1089def scalar_cast(input_x, input_y):
1090    r"""
1091    The interface is deprecated from version 2.3 and will be removed in a future version,
1092    please use `int(x)` or `float(x)` instead.
1093
1094    Casts the input scalar to another type.
1095
1096    Args:
1097        input_x (scalar): The input scalar.
1098        input_y (mindspore.dtype): The type to be cast. Only constant value is allowed.
1099            The value should only be mindspore.int64, mindspore.float64, or mindspore.bool\_.
1100
1101    Returns:
1102        Scalar, the type is the same as the python type corresponding to `input_y`.
1103
1104    Raises:
1105        ValueError: if input_y's value is invalid.
1106
1107    Supported Platforms:
1108        Deprecated
1109
1110    Examples:
1111        >>> import mindspore
1112        >>> from mindspore import ops
1113        >>> output = ops.scalar_cast(255.0, mindspore.int64)
1114        >>> print(output)
1115        255
1116    """
1117    scalar_cast_op = _get_cache_prim(ScalarCast)()
1118    return scalar_cast_op(input_x, input_y)
1119
1120
1121class Cast(Primitive):
1122    """
1123    Returns a tensor with the new specified data type.
1124
1125    Note:
1126        When converting complex numbers to boolean type, the imaginary part of the complex number is not
1127        taken into account. As long as the real part is non-zero, it returns True; otherwise, it returns False.
1128
1129    Inputs:
1130        - **input_x** (Union[Tensor, Number]) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
1131          The tensor to be cast.
1132        - **type** (dtype.Number) - The valid data type of the output tensor. Only constant value is allowed.
1133
1134    Outputs:
1135        Tensor, the shape of tensor is the same as `input_x`, :math:`(x_1, x_2, ..., x_R)`.
1136
1137    Raises:
1138        TypeError: If `input_x` is neither Tensor nor Number.
1139        TypeError: If `type` is not a Number.
1140
1141    Supported Platforms:
1142        ``Ascend`` ``GPU`` ``CPU``
1143
1144    Examples:
1145        >>> import mindspore
1146        >>> import numpy as np
1147        >>> from mindspore import Tensor, ops
1148        >>> input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
1149        >>> input_x = Tensor(input_np)
1150        >>> type_dst = mindspore.int32
1151        >>> cast = ops.Cast()
1152        >>> output = cast(input_x, type_dst)
1153        >>> print(output.dtype)
1154        Int32
1155        >>> print(output.shape)
1156        (2, 3, 4, 5)
1157    """
1158
1159    @prim_attr_register
1160    def __init__(self):
1161        """Initialize Cast"""
1162        self.init_prim_io_names(inputs=['x', 'dst_type'], outputs=['output'])
1163
1164    def check_elim(self, x, dtype):
1165        if isinstance(x, (Tensor, numbers.Number, Parameter)):
1166            if isinstance(x, Parameter):
1167                data = x.data
1168                if data.dtype == dtype:
1169                    return (True, x)
1170            if isinstance(x, Tensor) and x.dtype == dtype:
1171                x = Tensor(x)
1172                x.set_cast_dtype()
1173                return (True, x)
1174            if isinstance(x, numbers.Number):
1175                return (True, Tensor(x, dtype=dtype))
1176        return (False, None)
1177
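    # A hedged sketch of the check_elim elimination above (illustrative, not part of the
    # original source): casting a Tensor to the dtype it already has is folded away, and a
    # Python number is folded into a constant Tensor of the target dtype, e.g.
    #   Cast().check_elim(3, mstype.int32)  # -> (True, Tensor(3, dtype=int32))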
1178    def __call__(self, input_x, dtype):
1179        should_elim, output = self.check_elim(input_x, dtype)
1180        if should_elim:
1181            return output
1182        return _convert_stub(pyboost_cast(self, [input_x, dtype_to_type_id('Cast', 'dtype', dtype)]))
1183
1184
1185def to_sequence(val):
1186    """
1187    Wrap `val` in a tuple if it is not already a tuple or list.
1188    """
1189    if isinstance(val, (tuple, list)):
1190        return val
1191    return (val,)
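    # e.g. to_sequence(0.5) -> (0.5,) while to_sequence((1, 2)) -> (1, 2)  (illustrative)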
1192
1193
1194class EmbeddingTableExport(Primitive):
1195    """
1196    EmbeddingTableExport
1197    """
1198
1199    @prim_attr_register
1200    def __init__(self, embedding_dim, value_total_len, export_mode="all",
1201                 only_var_flag=False, file_type="bin", table_name=(),
1202                 filter_export_flag=False, steps_to_live_list=()):
1203        """Initialize EmbeddingTableExport"""
1204        self.add_prim_attr("_process_node_engine_id", "PS")
1205
1206
1207class EmbeddingTableImport(Primitive):
1208    """
1209    EmbeddingTableImport
1210    """
1211
1212    @prim_attr_register
1213    def __init__(self, embedding_dim, value_total_len,
1214                 only_var_flag=False, file_type="bin", table_name=()):
1215        """Initialize EmbeddingTableImport"""
1216        self.add_prim_attr("_process_node_engine_id", "PS")
1217
1218
1219class EmbeddingComputeVarImport(Primitive):
1220    """
1221    EmbeddingComputeVarImport
1222    """
1223
1224    @prim_attr_register
1225    def __init__(self, table_name=()):
1226        """Initialize EmbeddingComputeVarImport"""
1227        self.add_prim_attr("_process_node_engine_id", "PS")
1228
1229
1230class EmbeddingComputeVarExport(Primitive):
1231    """
1232    EmbeddingComputeVarExport
1233    """
1234
1235    @prim_attr_register
1236    def __init__(self, table_name=()):
1237        """Initialize EmbeddingComputeVarExport"""
1238        self.add_prim_attr("_process_node_engine_id", "PS")
1239
1240
1241class InitEmbeddingHashmap(Primitive):
1242    """
1243    InitEmbeddingHashmap
1244    """
1245    @prim_attr_register
1246    def __init__(self, value_total_len, embedding_dim, _table_id,
1247                 bucket_size=0, dtype=mstype.float32, initializer_mode="",
1248                 constant_valu=0., min=-2., max=2., mu=0., sigma=1., seed=0,
1249                 seed2=0, filter_mode="no_filter", optimizer_mode="",
1250                 optimizer_params=()):
1251        self.add_prim_attr("_process_node_engine_id", "PS")
1252
1253
1254def init_embedding_hashmap(table_id, value_total_len, embedding_dim, _table_id,
1255                           bucket_size=0, dtype=mstype.float32, initializer_mode='',
1256                           constant_value=0.0, min=-2.0, max=2.0, mu=0.0, sigma=1.0,
1257                           seed=0, seed2=0, filter_mode='no_filter',
1258                           optimizer_mode='', optimizer_params=()):
1259    """
1260    init_embedding_hashmap
1261    """
1262    op = _get_cache_prim(InitEmbeddingHashmap)(value_total_len, embedding_dim, _table_id,
1263                                               bucket_size, dtype, initializer_mode,
1264                                               constant_value, min, max, mu, sigma, seed,
1265                                               seed2, filter_mode, optimizer_mode, optimizer_params)
1266    return op(table_id)
1267
1268
1269class InitPartitionMap(Primitive):
1270    """
1271    InitPartitionMap
1272    """
1273    @prim_attr_register
1274    def __init__(self, _embedding_dim, _max_key_num,
1275                 _ps_num=1, partition_num=65537):
1276        self.add_prim_attr("_process_node_engine_id", "PS")
1277
1278
1279def init_partition_map(ps_num, ps_ids, _embedding_dim, _max_key_num,
1280                       _ps_num=1, partition_num=65537):
1281    """
1282    init_partition_map
1283    """
1284    op = _get_cache_prim(InitPartitionMap)(_embedding_dim, _max_key_num, _ps_num, partition_num)
1285    return op(ps_num, ps_ids)
1286
1287
1288class EmbeddingApplyAdam(Primitive):
1289    """
1290    EmbeddingApplyAdam
1291    """
1292    @prim_attr_register
1293    def __init__(self, embedding_dim, _max_key_num, mask_zero=(0,),
1294                 padding_key=(0,), padding_key_mask=(1,),
1295                 completion_key=(0,), completion_key_mask=(1,)):
1296        self.add_prim_attr("_process_node_engine_id", "PS")
1297
1298
1299class EmbeddingApplyAdamW(Primitive):
1300    """
1301    EmbeddingApplyAdamW
1302    """
1303    @prim_attr_register
1304    def __init__(self, embedding_dim, _max_key_num, amsgrad=(0,),
1305                 maximize=(0,), mask_zero=(0,), padding_key=(0,),
1306                 padding_key_mask=(1,), completion_key=(0,), completion_key_mask=(1,)):
1307        self.add_prim_attr("_process_node_engine_id", "PS")
1308
1309
1310class EmbeddingApplyAdaGrad(Primitive):
1311    """
1312    EmbeddingApplyAdaGrad
1313    """
1314    @prim_attr_register
1315    def __init__(self, embedding_dim, _max_key_num, mask_zero=(0,),
1316                 padding_key=(0,), padding_key_mask=(1,),
1317                 completion_key=(0,), completion_key_mask=(1,)):
1318        self.add_prim_attr("_process_node_engine_id", "PS")
1319
1320
1321class EmbeddingApplyFtrl(Primitive):
1322    """
1323    EmbeddingApplyFtrl
1324    """
1325    @prim_attr_register
1326    def __init__(self, embedding_dim, _max_key_num, mask_zero=(0,),
1327                 padding_key=(0,), padding_key_mask=(1,),
1328                 completion_key=(0,), completion_key_mask=(1,)):
1329        self.add_prim_attr("_process_node_engine_id", "PS")
1330
1331
1332class EmbeddingTableFind(Primitive):
1333    """
1334    EmbeddingTableFind
1335    """
1336    @prim_attr_register
1337    def __init__(self, embedding_dim, _embedding_dim, _max_key_num,
1338                 _table_id, default_value=(-1.,), _use_counter_filter=0):
1339        self.add_prim_attr("_process_node_engine_id", "PS")
1340        self.add_prim_attr("_execute_times", 2)
1341
1342
1343def embedding_table_find(table_id, keys, embedding_dim, _max_key_num,
1344                         _table_id, default_value=(-1.0,), _use_counter_filter=0):
1345    r"""
1346    embedding_table_find
1347    """
1348    _embedding_dim = embedding_dim if isinstance(embedding_dim, int) else embedding_dim[_table_id]
1349    op = _get_cache_prim(EmbeddingTableFind)(to_sequence(embedding_dim), _embedding_dim,
1350                                             _max_key_num, _table_id,
1351                                             to_sequence(default_value),
1352                                             _use_counter_filter)
1353    return op(table_id, keys)
1354
1355
1356class EmbeddingTableFindAndInit(Primitive):
1357    """
1358    EmbeddingTableFindAndInit
1359    """
1360    @prim_attr_register
1361    def __init__(self, embedding_dim, value_total_len, _embedding_dim, _table_id,
1362                 _max_key_num, initializer_mode=("random_uniform",),
1363                 constant_value=(0.,), min=(-2.,), max=(2.,), mu=(0.,),
1364                 sigma=(1.,), seed=(0,), seed2=(0,),
1365                 filter_mode=("no_filter",), filter_freq=(0,),
1366                 default_key_or_value=(0,), default_key=(0,),
1367                 default_value=(0.,), completion_key=(0,),
1368                 completion_key_mask=(1,), optimizer_mode=(),
1369                 optimizer_params=(), _use_counter_filter=0,
1370                 backward_mode="adam",
1371                 backward_int_params=((0,), (0,), (0,), (1,)),
1372                 backward_float_params=(0.9, 0.99, 0.001, 0.9, 0.999, 1e-08)):
1373        self.add_prim_attr("_process_node_engine_id", "PS")
1374        self.add_prim_attr("_execute_times", 2)
1375
1376
1377def embedding_table_find_and_init(table_id, keys, max_grad_norm, parameter, embedding_dim,
1378                                  value_total_len, _table_id, _max_key_num,
1379                                  initializer_mode=('random_uniform',), constant_value=(0.,),
1380                                  min=(-2.,), max=(2.,), mu=(0.,), sigma=(1.,), seed=(0,),
1381                                  seed2=(0,), filter_mode=("no_filter",),
1382                                  filter_freq=(0,), default_key_or_value=(0,),
1383                                  default_key=(0,), default_value=(0.,),
1384                                  completion_key=(0,), completion_key_mask=(1,),
1385                                  optimizer_mode=(), optimizer_params=(), _use_counter_filter=0,
1386                                  backward_mode="adam", backward_int_params=((0,), (0,), (0,), (1,)),
1387                                  backward_float_params=(0.9, 0.99, 0.001, 0.9, 0.999, 1e-08)):
1388    """
1389    embedding_table_find_and_init
1390
1391    backward_int_params (Union[tuple[tuple[int]], list[list[int]]]):
1392        - when the backward_mode is 'adam', 'ftrl' or 'adagrad',
1393          it means [[global_step], mask_zero, padding_key, padding_key_mask]
1394        - when the backward_mode is 'adamw', it means:
1395          [[global_step], amsgrad, maximize, mask_zero, padding_key, padding_key_mask]
1396    backward_float_params (Union[tuple[float], list[float]]):
1397        - when the backward_mode is 'adam', it means:
1398          [beta1_power, beta2_power, lr, beta1, beta2, epsilon]
1399        - when the backward_mode is 'ftrl', it means:
1400          [lr, lr_power, lambda1, lambda2]
1401        - when the backward_mode is 'adamw', it means:
1402          [beta1_power, beta2_power, lr, weight_decay, beta1, beta2, epsilon]
1403        - when the backward_mode is 'adagrad', it means [lr,]
1404    """
1405    _embedding_dim = embedding_dim if isinstance(embedding_dim, int) else embedding_dim[_table_id]
1406    op = _get_cache_prim(EmbeddingTableFindAndInit)(to_sequence(embedding_dim), to_sequence(value_total_len),
1407                                                    _embedding_dim, _table_id, _max_key_num,
1408                                                    to_sequence(initializer_mode),
1409                                                    to_sequence(constant_value), to_sequence(min),
1410                                                    to_sequence(max), to_sequence(mu),
1411                                                    to_sequence(sigma), to_sequence(seed),
1412                                                    to_sequence(seed2), to_sequence(filter_mode),
1413                                                    to_sequence(filter_freq), to_sequence(default_key_or_value),
1414                                                    to_sequence(default_key), to_sequence(default_value),
1415                                                    to_sequence(completion_key), to_sequence(completion_key_mask),
1416                                                    to_sequence(optimizer_mode), to_sequence(optimizer_params),
1417                                                    _use_counter_filter,
1418                                                    backward_mode, backward_int_params, backward_float_params)
1419    return op(table_id, keys, max_grad_norm, parameter)
1420
1421
1422class FakeRemoteLookupUniqued(Primitive):
1424    """
1425    FakeRemoteLookupUniqued
1426    """
1427    @prim_attr_register
1428    def __init__(self, embedding_dim, value_total_len, _embedding_dim, _table_id,
1429                 _max_key_num, initializer_mode=('random_uniform',), constant_value=(0.,),
1430                 min=(-2.,), max=(2.,), mu=(0.,), sigma=(1.,), seed=(0,), seed2=(0,),
1431                 filter_mode=("no_filter",), filter_freq=(0,),
1432                 default_key_or_value=(0,), default_key=(0,), default_value=(0.,),
1433                 completion_key=(0,), completion_key_mask=(1,),
1434                 optimizer_mode=(), optimizer_params=(), _use_counter_filter=0,
1435                 backward_mode="adam", backward_int_params=((0,), (0,), (0,), (1,)),
1436                 backward_float_params=(0.9, 0.99, 0.001, 0.9, 0.999, 1e-08)):
1437        self.add_prim_attr("_process_node_engine_id", "PS")
1438        self.add_prim_attr("_execute_times", 2)
1439
1440
1441def fake_remote_lookup_uniqued(table_id, keys, actual_keys_num, unique_indices,
1442                               key_count, max_grad_norm, parameter,
1443                               embedding_dim, value_total_len, _table_id, _max_key_num,
1444                               initializer_mode=('random_uniform',), constant_value=(0.,),
1445                               min=(-2.,), max=(2.,), mu=(0.,), sigma=(1.,), seed=(0,),
1446                               seed2=(0,), filter_mode=("no_filter",),
1447                               filter_freq=(0,), default_key_or_value=(0,),
1448                               default_key=(0,), default_value=(0.,),
1449                               completion_key=(0,), completion_key_mask=(1,),
1450                               optimizer_mode=(), optimizer_params=(), _use_counter_filter=0,
1451                               backward_mode='adam', backward_int_params=((0,), (0,), (0,), (1,)),
1452                               backward_float_params=(0.9, 0.99, 0.001, 0.9, 0.999, 1e-08)):
1453    """
1454    fake_remote_lookup_uniqued: remote embedding-table lookup for uniqued keys via the FakeRemoteLookupUniqued primitive.
1455
1456    backward_mode (str): determines the optimizer used by backpropagation;
1457        valid values are ["adam", "adamw", "adagrad", "ftrl"].
1458    backward_int_params (Union[tuple[tuple[int]], list[list[int]]]):
1459        - when the backward_mode is 'adam', 'ftrl' or 'adagrad',
1460          it means [[global_step], mask_zero, padding_key, padding_key_mask]
1461        - when the backward_mode is 'adamw', it means:
1462          [[global_step], amsgrad, maximize, mask_zero, padding_key, padding_key_mask]
1463    backward_float_params (Union[tuple[float], list[float]]):
1464        - when the backward_mode is 'adam', it means:
1465          [beta1_power, beta2_power, lr, beta1, beta2, epsilon]
1466        - when the backward_mode is 'ftrl', it means:
1467          [lr, lr_power, lambda1, lambda2]
1468        - when the backward_mode is 'adamw', it means:
1469          [beta1_power, beta2_power, lr, weight_decay, beta1, beta2, epsilon]
1470        - when the backward_mode is 'adagrad', it means [lr,]
1471    """
1472    _embedding_dim = embedding_dim if isinstance(embedding_dim, int) else embedding_dim[_table_id]
1473    op = _get_cache_prim(FakeRemoteLookupUniqued)(to_sequence(embedding_dim), to_sequence(value_total_len),
1474                                                  _embedding_dim, _table_id, _max_key_num,
1475                                                  to_sequence(initializer_mode), to_sequence(constant_value),
1476                                                  to_sequence(min), to_sequence(max), to_sequence(mu),
1477                                                  to_sequence(sigma), to_sequence(seed), to_sequence(seed2),
1478                                                  to_sequence(filter_mode), to_sequence(filter_freq),
1479                                                  to_sequence(default_key_or_value), to_sequence(default_key),
1480                                                  to_sequence(default_value), to_sequence(completion_key),
1481                                                  to_sequence(completion_key_mask), to_sequence(optimizer_mode),
1482                                                  to_sequence(optimizer_params), _use_counter_filter,
1483                                                  backward_mode, backward_int_params, backward_float_params)
1484    return op(table_id, keys, actual_keys_num, unique_indices, key_count, max_grad_norm, parameter)
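
# Note: with the default backward_mode 'adam', the default tuples above line up with the
# docstring as follows:
#   backward_int_params   = ((0,), (0,), (0,), (1,))
#                           # [[global_step], mask_zero, padding_key, padding_key_mask]
#   backward_float_params = (0.9, 0.99, 0.001, 0.9, 0.999, 1e-08)
#                           # [beta1_power, beta2_power, lr, beta1, beta2, epsilon]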
1485
1486
1487# The following are Python infer-value functions.
1488# A valid infer-value function should:
1489#
1490# 1. be named infer_value_for_OpName;
1491# 2. take all of its inputs explicitly, without default values;
1492# 3. return None if any non-constant input is given (for now).
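#
# For illustration only (a hypothetical "Neg" op, not an operator registered by this module),
# such a function would look like:
#
#   def infer_value_for_Neg(input_x):
#       """Infer value for Neg op."""
#       if input_x is None:  # non-constant input -> no constant folding
#           return None
#       return Tensor(np.negative(input_x.asnumpy()))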
1493
1494
1495def infer_value_for_Tile(input, dims):
1496    """Infer value for Tile op."""
1497    if input is None or dims is None or None in dims:
1498        return None
1499    return Tensor(np.tile(input.asnumpy(), dims))
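# For example, a constant input Tensor([1, 2]) with dims=(2, 1) folds to
# Tensor([[1, 2], [1, 2]]), mirroring np.tile.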
1500
1501
1502def infer_value_for_Concat(tensors, axis):
1503    """Infer value for Concat op."""
1504    if not tensors or None in tensors or axis is None:
1505        return None
1506
1507    tensor_to_concat = [x.asnumpy() if x.dtype != mstype.bfloat16 else x.float().asnumpy() for x in tensors]
1508    return Tensor(np.concatenate(tensor_to_concat, axis), dtype=tensors[0].dtype)
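# For example, constant tensors [[1, 2]] and [[3, 4]] concatenated along axis 0 fold to
# [[1, 2], [3, 4]]; bfloat16 inputs go through Tensor.float() first because NumPy has no
# bfloat16 dtype.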
1509
1510
1511def infer_value_for_ReduceSum(input_x, axis, keep_dims, skip_mode):
1512    """Infer value for ReduceSum op."""
1513    value = None
1514    if input_x is not None and axis is not None:
1515        value = input_x.asnumpy()
1516        if isinstance(axis, int):
1517            pass
1518        elif axis:
1519            axis = tuple(set(axis))
1520        elif axis in ((), []) and skip_mode:
1521            return input_x
1522        else:
1523            axis = tuple(range(len(value.shape)))
1524        value = np.sum(value, axis, keepdims=keep_dims)
1525        value = np.array(value)
1526        value = Tensor(value)
1527    return value
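# For example, for a constant 2x3 input, axis=() with skip_mode=True returns the input
# unchanged, while axis=() with skip_mode=False (and keep_dims=False) sums over all axes
# and folds to a 0-d Tensor.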
1528
1529
1530def _infer_value_for_Reduce(input_x, axis, keep_dims, prim_name):
1531    """Infer value for Common Reduce op."""
1532    value = None
1533    if input_x is not None and axis is not None:
1534        prim_map = {
1535            'ReduceMax': np.max,
1536            'ReduceMin': np.min,
1537            'ReduceProd': np.prod,
1538            'ReduceMean': np.mean,
1539            'ReduceAll': np.all,
1540            'ReduceAny': np.any,
1541        }
1542        np_reduce_func = prim_map.get(prim_name, None)
1543
1544        if np_reduce_func is not None:
1545            value = input_x.asnumpy()
1546            if isinstance(axis, int):
1547                pass
1548            elif axis:
1549                axis = tuple(set(axis))
1550            else:
1551                axis = tuple(range(len(value.shape)))
1552            value = np_reduce_func(value, axis, keepdims=keep_dims)
1553            value = np.array(value)
1554            value = Tensor(value)
1555    return value
1556
1557
1558def _infer_value_for_ReduceExtand(input_x, axis, keep_dims, dtype, prim_name):
1559    """Infer value for Common ReduceExtand op."""
1560    value = None
1561    if input_x is not None:
1562        prim_map = {
1563            'MeanExt': np.mean,
1564            'SumExt': np.sum,
1565            'ProdExt': np.prod,
1566        }
1567        np_reduce_extand_func = prim_map.get(prim_name, None)
1568
1569        if np_reduce_extand_func is not None:
1570            value = input_x.asnumpy()
1571            if isinstance(axis, int):
1572                pass
1573            elif axis:
1574                axis = tuple(set(axis))
1575            else:
1576                axis = tuple(range(len(value.shape)))
1577            if dtype is not None:
1578                np_dtype = mstype.dtype_to_nptype(typing.type_id_to_type(dtype))
1579                value = np_reduce_extand_func(value, axis, dtype=np_dtype, keepdims=keep_dims)
1580            else:
1581                value = np_reduce_extand_func(value, axis, keepdims=keep_dims)
1582
1583            value = np.array(value)
1584            value = Tensor(value)
1585    return value
1586
1587
1588def _infer_value_for_max_min(input_x, prim_name):
1589    """Infer value for Max/Min op."""
1590    value = None
1591    if input_x is not None:
1592        prim_map = {
1593            'Max': np.max,
1594            'Min': np.min,
1595        }
1596        np_reduce_func = prim_map.get(prim_name, None)
1597
1598        if np_reduce_func is not None:
1599            value = input_x.asnumpy()
1600            value = np_reduce_func(value, None, keepdims=False)
1601            value = np.array(value)
1602            value = Tensor(value)
1603    return value
1604
1605
1606def infer_value_for_Cast(x, dst_type_enum=None):
1607    """Infer value for Cast op."""
1608    if x is None or dst_type_enum is None:
1609        return None
1610    dst_type = typing.type_id_to_type(dst_type_enum)
1611    src_type = mstype.get_py_obj_dtype(x)
1612    validator.check_subclass("input_x", src_type, [mstype.tensor_type, mstype.number], "Cast")
1613    validator.check_subclass("type", dst_type, mstype.number, "Cast")
1614
1615    if isinstance(src_type, type(mstype.tensor_type)):
1616        src_type = src_type.element_type()
1617    if isinstance(dst_type, type(mstype.tensor_type)):
1618        dst_type = dst_type.element_type()
1619
1620    value = None
1621    np_dst_type = mstype.dtype_to_nptype(dst_type)
1622    if isinstance(x, (int, float)):
1623        value = Tensor(np.array(x).astype(np_dst_type), dtype=dst_type)
1624    else:
1625        value = Tensor_(x.asnumpy().astype(np_dst_type), dtype=dst_type)
1626    return value
1627
1628
1629def infer_value_for_ReduceMax(input_x, axis, keep_dims):
1630    """Infer value for ReduceMax op."""
1631    return _infer_value_for_Reduce(input_x, axis, keep_dims, 'ReduceMax')
1632
1633
1634def infer_value_for_Max(input_x):
1635    """Infer value for Max op."""
1636    return _infer_value_for_max_min(input_x, 'Max')
1637
1638
1639def infer_value_for_ReduceMin(input_x, axis, keep_dims):
1640    """Infer value for ReduceMin op."""
1641    return _infer_value_for_Reduce(input_x, axis, keep_dims, 'ReduceMin')
1642
1643
1644def infer_value_for_Min(input_x):
1645    """Infer value for Max op."""
1646    return _infer_value_for_max_min(input_x, 'Min')
1647
1648
1649def infer_value_for_ReduceProd(input_x, axis, keep_dims):
1650    """Infer value for ReduceProd op."""
1651    return _infer_value_for_Reduce(input_x, axis, keep_dims, 'ReduceProd')
1652
1653
1654def infer_value_for_ReduceMean(input_x, axis, keep_dims):
1655    """Infer value for ReduceMean op."""
1656    return _infer_value_for_Reduce(input_x, axis, keep_dims, 'ReduceMean')
1657
1658
1659def infer_value_for_ReduceAll(input_x, axis, keep_dims):
1660    """Infer value for ReduceAll op."""
1661    return _infer_value_for_Reduce(input_x, axis, keep_dims, 'ReduceAll')
1662
1663
1664def infer_value_for_ReduceAny(input_x, axis, keep_dims):
1665    """Infer value for ReduceAny op."""
1666    return _infer_value_for_Reduce(input_x, axis, keep_dims, 'ReduceAny')
1667
1668
1669def infer_value_for_MeanExt(input_x, axis, keep_dims, dtype):
1670    """Infer value for MeanExt op."""
1671    return _infer_value_for_ReduceExtand(input_x, axis, keep_dims, dtype, 'MeanExt')
1672
1673
1674def infer_value_for_SumExt(input_x, axis, keep_dims, dtype):
1675    """Infer value for SumExt op."""
1676    return _infer_value_for_ReduceExtand(input_x, axis, keep_dims, dtype, 'SumExt')
1677
1678
1679def infer_value_for_ProdExt(input_x, axis, keep_dims, dtype):
1680    """Infer value for ProdExt op."""
1681    return _infer_value_for_ReduceExtand(input_x, axis, keep_dims, dtype, 'ProdExt')
1682
1683
1684def infer_value_for_Diag(input_x):
1685    """Infer value for Diag op."""
1686    if input_x is None:
1687        return None
1688    # do constant-folding only when x rank is 1
1689    if len(input_x.shape) != 1:
1690        return None
1691    ret = np.diag(input_x.asnumpy())
1692    return Tensor(ret)
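# For example, a rank-1 constant input [1., 2.] folds to [[1., 0.], [0., 2.]]; higher-rank
# inputs are left to runtime computation.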
1693
1694
1695def infer_value_for_BroadcastTo(x, shape):
1696    """Infer value for BroadcastTo op."""
1697    def none_in_tuple_or_list(x):
1698        return isinstance(x, (tuple, list)) and None in x
1699    if shape is None or none_in_tuple_or_list(shape) or x is None:
1700        return None
1701
1702    if isinstance(shape, (Tensor, Tensor_)):
1703        validator.check_tensor_dtype_valid("shape", mstype.TensorType(shape.dtype),
1704                                           [mstype.int32, mstype.int64], "BroadcastTo")
1705        shape = shape.asnumpy().tolist()
1706    else:
1707        validator.check_value_type("shape", shape, [tuple], "BroadcastTo")
1708        shape = list(shape)
1709
1710    np_data = np.broadcast_to(x.asnumpy(), shape)
1711    if 0 in shape:
1712        init_func = Zero()
1713        init_func.__enable_zero_dim__ = True
1714        out = Tensor(shape=shape, dtype=x.dtype, init=init_func)
1715        return out
1716    return Tensor(np_data)
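# For example, a constant x of shape (1, 3) with shape=(2, 3) folds via np.broadcast_to to a
# (2, 3) Tensor; if the target shape contains a 0, a zero-initialized Tensor of that shape is
# returned instead of a materialized array.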
1717
1718
1719def infer_value_for_Reshape(x, shape):
1720    """Infer value for Reshape op."""
1721    def none_in_tuple_or_list(x):
1722        return isinstance(x, (tuple, list)) and None in x
1723    # for shape is not constant
1724    if shape is None or none_in_tuple_or_list(shape) or x is None:
1725        return None
1726
1727    if isinstance(shape, (Tensor, Tensor_)):
1728        validator.check_tensor_dtype_valid("shape", mstype.TensorType(shape.dtype),
1729                                           [mstype.int32, mstype.int64], "Reshape")
1730        shape = shape.asnumpy().tolist()
1731    else:
1732        validator.check_value_type("shape", shape, [tuple], "Reshape")
1733        shape = list(shape)
1734
1735    neg_index = -1
1736    dim_prod = 1
1737    for i, shp_i in enumerate(shape):
1738        validator.check_value_type("shape[%d]" % i, shp_i, [int], "Reshape")
1739        if shp_i == -1:
1740            if neg_index != -1:
1741                raise ValueError(f"For 'Reshape', there can be at most one '-1' in 'input_shape', "
1742                                 f"but got {shape}.")
1743            neg_index = i
1744        else:
1745            dim_prod *= shp_i
1746    out = None
1747    if not is_shape_unknown(x.shape):
1748        x_shp = x.shape
1749        if dim_prod < 0:
1750            raise ValueError(f"For 'Reshape', the shape of 'input_x' is {x_shp}, "
1751                             f"the value of 'input_shape' is {shape}. "
1752                             f"The product of 'input_shape' should > 0, but got {dim_prod}.")
1753        arr_prod = np.prod(x_shp)
1754        if neg_index != -1:
1755            shape[neg_index] = int(arr_prod // dim_prod)
1756            dim_prod *= shape[neg_index]
1757        if dim_prod != arr_prod:
1758            raise ValueError(f"For 'Reshape', the product of the 'input_x' shape "
1759                             f"should be equal to product of 'input_shape', but got product of the"
1760                             f" shape of 'input_x': {arr_prod}, product of 'input_shape': {dim_prod}.")
1761        if 0 in shape:
1762            init_func = Zero()
1763            init_func.__enable_zero_dim__ = True
1764            out = Tensor(shape=shape, dtype=x.dtype, init=init_func)
1765        else:
1766            out = Tensor(x.asnumpy().reshape(shape))
1767    return out
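# For example, a constant x of shape (2, 6) with shape=(-1, 4) folds to a (3, 4) Tensor:
# the -1 entry is resolved as prod(x.shape) // prod(known dims) = 12 // 4 = 3.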
1768
1769
1770class Ones(Primitive):
1771    r"""
1772    Creates a tensor filled with value ones.
1773
1774    Refer to :func:`mindspore.ops.ones` for more details.
1775
1776    .. warning::
1777        For argument `size`, Tensor type input will be deprecated in the future version.
1778
1779    Inputs:
1780        - **shape** (Union[tuple[int], List[int], int, Tensor]) - The specified shape of output tensor.
1781        - **type** (:class:`mindspore.dtype`) - The specified type of output tensor.
1782
1783    Outputs:
1784        Tensor, whose dtype and size are defined by input.
1785
1786    Raises:
1787        TypeError: If `shape` is neither an int nor a tuple/list/Tensor of int.
1788
1789    Supported Platforms:
1790        ``Ascend`` ``GPU`` ``CPU``
1791
1792    Examples:
1793        >>> import mindspore
1794        >>> from mindspore import ops
1795        >>> ones = ops.Ones()
1796        >>> output = ones((2, 2), mindspore.float32)
1797        >>> print(output)
1798        [[1. 1.]
1799         [1. 1.]]
1800        >>> output = ones((3, 3), mindspore.float32)
1801        >>> print(output)
1802        [[1. 1. 1.]
1803         [1. 1. 1.]
1804         [1. 1. 1.]]
1805    """
1806
1807    __mindspore_signature__ = (
1808        sig.make_sig('size'),
1809        sig.make_sig('type', default=None),
1810    )
1811
1812    @prim_arg_register
1813    def __init__(self):
1814        pass
1815
1816    def __call__(self, size, type=None):
1817        return _convert_stub(pyboost_ones(self, [size, type if type is None \
1818            else handler.dtype_to_type_id('Ones', 'type', type)]))
1819
1820
1821class Zeros(Primitive):
1822    r"""
1823    Zeros will be deprecated in the future. Please use :func:`mindspore.ops.zeros` instead.
1824
1825    Creates a tensor filled with value zeros.
1826
1827    Creates a tensor with shape described by the first argument and
1828    fills it with value zeros in type of the second argument.
1829
1830    .. warning::
1831        For argument `size`, Tensor type input will be deprecated in the future version.
1832
1833    Inputs:
1834        - **shape** (tuple[int], List[int], int, Tensor) - The specified shape of output tensor.
1835        - **type** (mindspore.dtype) - The specified type of output tensor.
1836
1837    Outputs:
1838        Tensor, whose dtype and size are defined by input.
1839
1840    Raises:
1841        TypeError: If `shape` is neither an int nor a tuple/list/Tensor of int.
1842
1843    Supported Platforms:
1844        ``Ascend`` ``GPU`` ``CPU``
1845
1846    Examples:
1847        >>> import mindspore
1848        >>> from mindspore import ops
1849        >>> zeros = ops.Zeros()
1850        >>> output = zeros((2, 2), mindspore.float32)
1851        >>> print(output)
1852        [[0. 0.]
1853         [0. 0.]]
1854
1855    """
1856
1857    __mindspore_signature__ = (
1858        sig.make_sig('size'),
1859        sig.make_sig('type', default=None),
1860    )
1861
1862    @prim_arg_register
1863    def __init__(self):
1864        pass
1865
1866    def __call__(self, size, type=None):
1867        return _convert_stub(pyboost_zeros(self, [size, type if type is None else \
1868            handler.dtype_to_type_id('Zeros', 'type', type)]))
1869
1870
1871def flash_attention_score(query, key, value, head_num, real_shift=None, drop_mask=None, padding_mask=None,
1872                          attn_mask=None, prefix=None, actual_seq_qlen=None, actual_seq_kvlen=None, keep_prob=1.0,
1873                          scalar_value=1.0, pre_tokens=2147483647, next_tokens=2147483647, inner_precise=0,
1874                          input_layout='BSH', sparse_mode=0):
1875    r"""
1876    The interface is not open to the public; it is for internal use only.
1877
1878    .. math::
1879        \begin{array}{ll} \\
1880            y = Dropout(Softmax(Mask(scalar\_value \cdot (real\_shift + query * key), attn\_mask), -1), keep\_prob) \\
1881            \cdot value \\
1882        \end{array}
1883
1884    B -- Batch size. Value range 1 to 2k.
1885    S1 -- Sequence length of query. Value range 1 to 512k.
1886    S2 -- Sequence length of key and value. Value range 1 to 512k.
1887    N1 -- Num heads of query. Value range 1 to 256.
1888    N2 -- Num heads of key and value, and N2 must be a factor of N1.
1889    D -- Head size. The value must be a multiple of 16, with a maximum of 512.
1890    H1 -- Hidden size of query, which equals N1 * D.
1891    H2 -- Hidden size of key and value, which equals N2 * D.
1892
1893    .. warning::
1894        This is an experimental API that is subject to change or deletion. Only supported on Atlas training series.
1895
1896    Args:
1897        query (Tensor[float16, bfloat16]): The query tensor. Input tensor of shape :math:`(B, S1, H1)`,
1898            `(B, N1, S1, D)`, `(S1, B, H1)`, `(B, S1, N1, D)` or `(T1, N1, D)`.
1899        key (Tensor[float16, bfloat16]): The key tensor. Input tensor of shape :math:`(B, S2, H2)`,
1900            `(B, N2, S2, D)`, `(S2, B, H2)`, `(B, S2, N2, D)` or `(T2, N2, D)`.
1901        value (Tensor[float16, bfloat16]): The value tensor. Input tensor of shape :math:`(B, S2, H2)`,
1902            `(B, N2, S2, D)`, `(S2, B, H2)`, `(B, S2, N2, D)` or `(T2, N2, D)`. The key and value have the same shape.
1903        head_num (int): The head num of query, equal to N1.
1904        real_shift (Union[Tensor[float16, bfloat16], None]): Also known as pse. The position embedding code. If S
1905            is greater than 1024 and the mask of the lower triangle is used, enter only the inverse 1024 lines of
1906            the lower triangle for memory optimization. Input tensor of shape :math:`(B, N1, S1, S2)`,
1907            `(1, N1, S1, S2)`, `(B, N1, 1024, S2)`, `(1, N1, 1024, S2)`.
1908
1909            - ALiBi scenario: real_shift must meet the ALiBi rule, and sparse_mode is 2 or 3 for the lower triangle.
1910              In this scenario, real_shift is `(B, N1, 1024, S2)`, `(1, N1, 1024, S2)`.
1911            - Non-ALiBi scenario: real_shift is `(B, N1, S1, S2)`, `(1, N1, S1, S2)`.
1912
1913            The shape of `real_shift` should be `(B, N1, 1024, S2)` and `(1, N1, 1024, S2)` when input_layout is
1914            `TND`.
1915        drop_mask (Union[Tensor[uint8], None]): The dropout mask tensor. Input tensor of shape
1916            :math:`(B, N1, S1, S2 // 8)` or None. S2 is a multiple of 8 when not None.
1917        padding_mask (None): Reserved parameter. Not implemented yet.
1918        attn_mask (Union[Tensor[uint8], Tensor[bool], None]): The attention mask tensor. For each element, 0
1919            indicates retention and 1 indicates discard. Input tensor of shape :math:`(B, N1, S1, S2)`,
1920            `(B, 1, S1, S2)`, `(S1, S2)` or `(2048, 2048)`. In compression scenario, sparse_mode is 2, 3, or 4,
1921            attn_mask must be `(2048, 2048)`. When sparse_mode is 5, attn_mask must be `(B, N1, S1, S2)`,
1922            `(B, 1, S1, S2)`. When sparse_mode is 0 and 1, attn_mask should be `(B, N1, S1, S2)`, `(B, 1, S1, S2)`,
1923            `(S1, S2)`.
1924        prefix (Union[List[int64], Tuple[int64], None]): N value of each Batch in the prefix sparse calculation
1925            scenario. Input tensor of shape :math:`(B,)`. The max value of B is 32. It is not None only when sparse_mode is 5.
1926            If S1 > S2, N ranges from 0 to S2. If S1 <= S2, N ranges from S2 - S1 to S2.
1927        actual_seq_qlen (Union[List[int64], Tuple[int64], None]): Size of query corresponding to each batch, array
1928            with increasing values and the last value equal to T1.
1929        actual_seq_kvlen (Union[List[int64], Tuple[int64], None]): Size of key and value corresponding to each batch,
1930            array with increasing values and the last value equal to T2.
1931        keep_prob (float): The keep probability of dropout. Value range is (0.0, 1.0]. Default: 1.0. When keep_prob
1932            is 1.0, drop_mask should be None.
1933        scalar_value (float): The scale factor of the score. Generally, the value is 1.0 / (D ** 0.5). Default: 1.0.
1934        pre_tokens (int): Parameter for sparse computation, represents how many tokens are counted forward.
1935            When sparse_mode is set to 1, 2, 3, or 5, this parameter does not take effect. Default: 2147483647.
1936        next_tokens (int): Parameter for sparse computation, represents how many tokens are counted backward.
1937            When sparse_mode is set to 1, 2, 3, or 5, this parameter does not take effect. Default: 2147483647.
1938            The value of pre_tokens corresponds to S1, and the value of next_tokens corresponds to S2. They define the
1939            valid band on the attn_mask matrix, which must not be empty.
1940            The following values are not allowed:
1941
1942            - pre_tokens < 0 and next_tokens < 0.
1943            - (pre_tokens < 0 and next_tokens >= 0) and (next_tokens < abs(pre_tokens) or abs(pre_tokens) >= S2).
1944            - (pre_tokens >= 0 and next_tokens < 0) and (abs(next_tokens) > pre_tokens or abs(next_tokens) >= S1).
1945
1946        inner_precise (int): The parameter is reserved and not implemented yet. Default: 0.
1947        input_layout (str): Specifies the layout of input `query`, key and value. The value can be "BSH", "BNSD",
1948            "SBH", "BSND" or "TND". "TND" is an experimental format. Default: "BSH".
1949            When input_layout is "TND", the following restrictions must be met.
1950            There are two lists that represent the length of the input sequence: list_seq_q and list_seq_k. Each
1951            value in the list indicates the length of the sequence in the batch. For example, list_seq_q = [4, 2, 6],
1952            list_seq_k = [10, 3, 9]. The elements of the lists indicate S. T1 is sum(list_seq_q) = 12, T2 is
1953            sum(list_seq_k) = 22.
1954            max_seqlen_q = max(list_seq_q), max_seqlen_k = max(list_seq_k).
1955            qk_pointer = sum(list_seq_q * list_seq_k), i.e. the sum of the element-wise products.
1956
1957            - The two lists have the same length, which is the batch size; batch is less than or equal to 1024.
1958            - When input_layout is "TND", actual_seq_qlen and actual_seq_kvlen must not be None.
1959              Otherwise, they are None.
1960            - actual_seq_qlen and actual_seq_kvlen are cumulative sums of the query and key/value sequence lengths,
1961              so they must be non-decreasing.
1962            - If real_shift is not None, list_seq_q and list_seq_k must be the same. The maximum value of list_seq_q
1963              and list_seq_k is greater than 1024. real_shift should be `(B, N1, 1024, S2)` or `(1, N1, 1024, S2)`,
1964              where S2 equals max_seqlen_k.
1965            - attn_mask must be a lower triangular matrix, so sparse_mode should be 2 or 3. The shape of attn_mask
1966              should be `(2048, 2048)`.
1967            - The shape of drop_mask is (qk_pointer * N1 // 8,).
1968            - prefix is None.
1969            - next_tokens is 0, and pre_tokens is not less than max_seqlen_q.
1970            - When sparse_mode is 3, S1 of each batch should be less than or equal to S2.
1971            - 0 should not exist in list_seq_k.
1972
1973        sparse_mode (int): Indicates sparse mode. Default 0.
1974
1975            - 0: Indicates the defaultMask mode. If attn_mask is not passed, the mask operation is not performed,
1976              and preTokens and nextTokens (internally assigned as INT_MAX) are ignored. If passed in, the full
1977              attn_mask matrix (S1 * S2) needs to be passed in, indicating that the part between preTokens and
1978              nextTokens needs to be calculated.
1979            - 1: Represents allMask, that is, passing in the complete attn_mask matrix.
1980            - 2: Represents the leftUpCausal mode, corresponding to the lower-triangle scenario divided from the left
1981              vertex; the optimized attn_mask matrix (2048*2048) is required.
1982            - 3: Represents the rightDownCausal mode, corresponding to the lower-triangle scenario divided from the
1983              lower-right vertex; the optimized attn_mask matrix (2048*2048) is required.
1984            - 4: Represents the band scenario, that is, the part between counting preTokens and nextTokens, and the
1985              optimized attn_mask matrix (2048*2048) is required.
1986            - 5: Represents the prefix scenario, that is, on the basis of rightDownCausal, a matrix with length S1 and
1987              width N is added to the left side. The value of N is obtained by the new input prefix, and the N value
1988              of each Batch axis is different, not implemented yet.
1989            - 6: Represents the global scenario, not implemented yet.
1990            - 7: Represents the dilated scenario, not implemented yet.
1991            - 8: Represents the block_local scenario, not implemented yet.
1992
1993    Returns:
1994        attention_out (Tensor[float16, bfloat16]) - The output of attention; its shape and data type are the same
1995        as those of the query.
1996
1997    Supported Platforms:
1998        ``Ascend``
1999
2000    Examples:
2001        >>> import mindspore
2002        >>> import mindspore.common.dtype as mstype
2003        >>> import numpy as np
2004        >>> from mindspore import ops, Tensor
2005        >>> query = Tensor(np.ones([2, 4, 64]), dtype=mstype.float16)
2006        >>> key = Tensor(np.ones([2, 4, 64]), dtype=mstype.float16)
2007        >>> value = Tensor(np.ones([2, 4, 64]), dtype=mstype.float16)
2008        >>> head_num = 4
2009        >>> output = ops.flash_attention_score(query, key, value, head_num)
2010        >>> print(output.shape)
2011        (2, 4, 64)
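        >>> # A BNSD-layout variant with illustrative shapes; scalar_value is typically set to
        >>> # 1.0 / (D ** 0.5), here with D = 64.
        >>> query = Tensor(np.ones([1, 4, 16, 64]), dtype=mstype.float16)
        >>> key = Tensor(np.ones([1, 4, 16, 64]), dtype=mstype.float16)
        >>> value = Tensor(np.ones([1, 4, 16, 64]), dtype=mstype.float16)
        >>> output = ops.flash_attention_score(query, key, value, head_num=4,
        ...                                    scalar_value=1.0 / (64 ** 0.5), input_layout='BNSD')
        >>> print(output.shape)
        (1, 4, 16, 64)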
2012    """
2013    rank_op = _get_cache_prim(FlashAttentionScore)(head_num, keep_prob, scalar_value, pre_tokens, next_tokens,
2014                                                   inner_precise, input_layout, sparse_mode)
2015    return rank_op(query, key, value, real_shift, drop_mask, padding_mask, attn_mask, prefix, actual_seq_qlen,
2016                   actual_seq_kvlen)[3]
2017