1# Copyright 2020-2022 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15
16"""Inner operators."""
17from types import FunctionType, MethodType
18from collections.abc import Iterable
19import os
20import numpy as np
21
22from mindspore.common import Tensor
23from mindspore.common._stub_tensor import StubTensor
24from mindspore.ops import composite as C
25from mindspore.ops.operations.array_ops import Cast
26from mindspore.ops.operations._scalar_ops import bit_or, bit_and
27from mindspore.ops import signature as sig
28from mindspore.ops.operations.math_ops import _infer_shape_reduce
29from mindspore.ops.primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register, Primitive, \
30    _run_op, _check_contains_variable
31from mindspore._c_expression import Tensor as Tensor_
32from mindspore._c_expression import typing
33from mindspore import _checkparam as validator
34from mindspore.common import dtype as mstype
35from mindspore.common.parameter import Parameter
36from mindspore.communication.management import GlobalComm, get_rank, _get_group, get_group_size
37from mindspore.common.api import _pynative_executor
38from mindspore.common._register_for_adapter import ms_adapter_registry
39from mindspore import ops
40from ..auto_generate import TensorCopySlices, SiLU, Cummin, TopKRouter, ExtractImagePatches, DecoderKVCache, \
41    PromptKVCache, ApplyCamePart1, ApplyCamePart2, ApplyCamePart3, ApplyCamePart4
42
43# Bit operation
44bit_and = bit_and()
45bit_or = bit_or()
46bit_xor = Primitive("bit_xor")
47bit_left_shift = Primitive("bit_left_shift")
48bit_right_shift = Primitive("bit_right_shift")
49# String operation
50string_lt = Primitive("string_lt")
51string_gt = Primitive("string_gt")
52string_le = Primitive("string_le")
53string_ge = Primitive("string_ge")
54string_not = Primitive("string_not")
55string_in = Primitive("string_in")
56string_mul = Primitive("string_mul")
57string_getitem = Primitive("string_getitem")
58
59
60class Generator(Primitive):
61    r"""
62    Manage the state of random number generation.
63
64    Inputs:
65        - **cmd** (int) : operation to be executed.
66        - **inputs** (tuple[tensor]) : inputs for the operation.
67
68    Outputs:
69        - **seed** (Tensor): Seed for the random number generation algorithm.
70        - **offset** (Tensor): Offset of the random number sequence.
71        - **state** (Tensor): State tensor, can be used to restore current state.
72    """
73
74    @prim_attr_register
75    def __init__(self):
76        self.add_prim_attr("side_effect_mem", True)
77
78    def __call__(self, cmd, inputs):
79        if cmd == 0:  # step cmd
80            return inputs[0], inputs[1]
81        return super().__call__(cmd, inputs)
82
83
84class Quant(PrimitiveWithInfer):
85    r"""
86    Returns the quantized value of input_x.
87
88    If `sqrt_mode` is False:
89
90    .. math::
91        y = round(scale * x + offset)
92
93    If `sqrt_mode` is True:
94
95    .. math::
96        y = round(scale * x * scale + offset)
97
98    Note:
        This operation only supports Atlas 200/300/500 inference products.
100
101    Args:
102        scale (float) : Specifies the scaling ratio.
103        offset (float): Specifies the offset.
104        sqrt_mode (bool) : Specifies whether to perform square root on `scale`. Default: ``False``.
105        round_mode (str): Specifies the way to round. Must be one of ["Round", "Floor", "Ceil", "Trunc"].
106          Default: "Round".
107
108    Inputs:
109        - **input_x** (Tensor) : Input tensor. Its data type must be mindspore.float16 or mindspore.float32.
110
111    Outputs:
112        - Tensor: The quantized output tensor of type mindspore.int8.
113
114    Examples:
115        >>> input_x = Tensor([100.0, 150.0], mstype.float32)
116        >>> quant = ops.Quant(80.0, 0.0, False, "Round")
117        >>> y = quant(input_x)
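        >>> # Illustrative note: with sqrt_mode=False, y = round(scale * x + offset), returned as int8.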
118    """
119
120    @prim_attr_register
121    def __init__(self, scale, offset, sqrt_mode=False, round_mode="Round"):
122        self.scale = validator.check_value_type("scale", scale, [float], self.name)
123        self.offset = validator.check_value_type("offset", offset, [float], self.name)
124        self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name)
125        self.round_mode = validator.check_string(round_mode, ["Round", "Floor", "Ceil", "Trunc"],
126                                                 "round_mode", self.name)
127        self.add_prim_attr("dst_type", mstype.int8)
128
129    def infer_shape(self, x_shape):
130        return x_shape
131
132    def infer_dtype(self, x_type):
133        validator.check_subclass("input_x", x_type, mstype.tensor_type, self.name)
134        validator.check_type_name("input_x", x_type, [mstype.float16, mstype.float32], self.name)
135        return self.get_attr_dict()['dst_type']
136
137
138class Lamb(PrimitiveWithInfer):
139    r"""
140    LAMB optimizer algorithm.
141
142    The Lamb optimizer is proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes
143    <https://arxiv.org/abs/1904.00962>`_.
144
145    Inputs:
146        - **var** (Tensor) - Weights to be updated. The shape is :math:`(N, *)` where :math:`*` means,
147          any number of additional dimensions. The data type can be float16 or float32.
148        - **m** (Tensor) - The 1st moment vector in the updating formula,
149          the shape and data type value should be the same as `var`.
150        - **v** (Tensor) - the 2nd moment vector in the updating formula,
151          the shape and data type value should be the same as `var`. Mean square gradients with the same type as `var`.
        - **lr** (float) - :math:`l` in the updating formula, the learning rate,
          the data type value should be the same as `var`.
154        - **beta1** (float) - The exponential decay rate for the 1st moment estimations,
155          the data type value should be the same as `var`. The paper suggested value is :math:`0.9`
156        - **beta2** (float) - The exponential decay rate for the 2nd moment estimations,
157          the data type value should be the same as `var`. The paper suggested value is :math:`0.999`
158        - **epsilon** (float) - Term added to the denominator to improve numerical stability.
159        - **decay** (float) - The weight decay value, must be a scalar tensor with float data type.
160          Default: 0.0.
161        - **global_step** (Tensor) - Tensor to record current global step.
162        - **gradient** (Tensor) - Gradient, has the same shape and data type as `var`.
163
164    Outputs:
165        Tensor, the updated parameters.
166
167        - **var** (Tensor) - The same shape and data type as `var`.
168
169    Supported Platforms:
        ``Ascend`` ``GPU``
171    """
172
173    @prim_attr_register
174    def __init__(self):
175        """Initialize Lamb."""
176        self.add_prim_attr('side_effect_mem', True)
177
178    def infer_shape(self, var_shape, m_shape, v_shape, lr_shape, beta1_shape, beta2_shape,
179                    epsilon_shape, decay_shape, global_step_shape, gradient_shape):
180        validator.check("var_shape", var_shape, "m_shape", m_shape, validator.EQ, self.name)
181        validator.check("var_shape", var_shape, "v_shape", v_shape, validator.EQ, self.name)
182        validator.check("var_shape", var_shape, "gradient_shape", gradient_shape, validator.EQ, self.name)
183        return var_shape
184
185    def infer_dtype(self, var_dtype, m_dtype, v_dtype, lr_dtype, beta1_dtype, beta2_dtype,
186                    epsilon_dtype, decay_dtype, global_step_dtype, gradient_dtype):
187        args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": gradient_dtype}
188        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
189
190        args = {"lr": lr_dtype, "decay": decay_dtype, "beta1": beta1_dtype, "beta2": beta2_dtype,
191                "epsilon": epsilon_dtype}
192        validator.check_scalar_or_tensor_types_same(args, [mstype.float32], self.name, True)
193        return var_dtype
194
195
196class Dequant(PrimitiveWithInfer):
197    r"""
198    Returns the dequantized value of input_x.
199    This operation will do ReLU to the dequantized value if `relu_flag` is True.
200
201    If `sqrt_mode` is False:
202
203    .. math::
204        y = x * deq\_scale
205
206    If `sqrt_mode` is True:
207
208    .. math::
209        y = x * deq\_scale * deq\_scale
210
211    Note:
        This operation only supports Atlas 200/300/500 inference products.
213
214    Args:
215        sqrt_mode (bool) : Specifies whether to perform square root on `scale`. Default: ``False``.
216        relu_flag (bool): Specifies whether to perform ReLU. Default: ``False``.
217
218    Inputs:
219        - **input_x** (Tensor) : Input tensor. Must be mindspore.int32.
220        - **deq_scale** (Tensor) : Specifies the scaling ratio.
221          Data type must be mindspore.float16 or mindspore.uint64
222
223    Outputs:
        - Tensor: The dequantized output tensor of type mindspore.float16.
225
    Examples:
        >>> input_x = Tensor([100, 150], mstype.int32)
        >>> deq_scale = Tensor([1.0], mstype.float16)
        >>> dequant = ops.Dequant(False, False)
        >>> y = dequant(input_x, deq_scale)
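        >>> # Illustrative values only: with sqrt_mode=False, y = x * deq_scale, returned as float16.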
230    """
231
232    @prim_attr_register
233    def __init__(self, sqrt_mode=False, relu_flag=False, dtype=mstype.float16):
234        self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name)
235        self.relu_flag = validator.check_value_type("relu_flag", relu_flag, [bool], self.name)
236        self.dtype = dtype
237
238    def infer_shape(self, x_shape, deq_scale_shape):
239        return x_shape
240
241    def infer_dtype(self, x_type, deq_scale_type):
242        validator.check_subclass("x", x_type, mstype.tensor_type, self.name)
243        validator.check_type_name("x", x_type, [mstype.int32], self.name)
244        validator.check_type_name("deq_scale", deq_scale_type, [mstype.float16, mstype.uint64], self.name)
245        return mstype.float16
246
247
248class AntiQuant(Primitive):
249    r"""
250    Returns the antiquantized value of input_x.
251
252    If `sqrt_mode` is False:
253
254    .. math::
255        y = scale * (x + offset)
256
257    If `sqrt_mode` is True:
258
259    .. math::
260        y = scale * scale * (x + offset)
261
262    Note:
        This operation only supports Atlas 200/300/500 inference products.
264
265    Args:
        sqrt_mode (bool) : Specifies whether to perform square root on `scale`. Default: ``False``.
        dtype (mindspore.dtype): The data type of the output. Default: ``mindspore.float16``.

    Inputs:
        - **x** (Tensor) : Input tensor. Must be mindspore.int8.
        - **scale** (Tensor) : Specifies the scaling ratio.
        - **offset** (Tensor) : Specifies the offset.
272
273    Outputs:
274        - Tensor: The antiquantized output tensor of type mindspore.float32.
275
    Examples:
        >>> from mindspore.ops.operations._inner_ops import AntiQuant
        >>> input_x = Tensor([50, 20], mstype.int8)
        >>> scale = Tensor(2.0, mstype.float32)
        >>> offset = Tensor(1.0, mstype.float32)
        >>> antiquant = AntiQuant(False)
        >>> y = antiquant(input_x, scale, offset)
        >>> print(y)
        [102. 42.]
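        >>> # Illustrative arithmetic: y = scale * (x + offset), i.e. 2.0 * (50 + 1) = 102. and 2.0 * (20 + 1) = 42.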
283    """
284
285    @prim_attr_register
286    def __init__(self, sqrt_mode=False, dtype=mstype.float16):
287        super().__init__("AntiQuant")
288        self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name)
289        self.dtype = dtype
290
291        self.init_prim_io_names(inputs=['x', 'scale', 'offset'],
292                                outputs=['y'])
293
294
295class MatrixDiag(PrimitiveWithInfer):
296    """
297    Returns a batched diagonal tensor with a given batched diagonal values.
298
299    Inputs:
        - **x** (Tensor) - A tensor to be element-wise multiplied by `assist`. It can be one of the following data
          types: float32, float16, int32, int8, and uint8.
        - **assist** (Tensor) - An eye tensor of the same type as `x`. Its rank must be greater than or equal to 2 and
          its last dimension must be equal to the second to last dimension.
304
305    Outputs:
306        Tensor, has the same type and shape as input `assist`.
307
308    Examples:
309        >>> x = Tensor(np.array([1, -1]), mstype.float32)
310        >>> assist = Tensor(np.arange(-12, 0).reshape(3, 2, 2), mindspore.float32)
311        >>> matrix_diag = ops.MatrixDiag()
312        >>> result = matrix_diag(x, assist)
313        >>> print(result)
314        [[[-12.   11.]
315          [-10.    9.]]
316         [[ -8.    7.]
317          [ -6.    5.]]
318         [[ -4.    3.]
319          [ -2.    1.]]]
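        >>> # Each batch of the result is `assist` multiplied element-wise by `x`, broadcast along the last dimension.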
320    """
321
322    @prim_attr_register
323    def __init__(self):
324        """Initialize MatrixDiag"""
325
326    def infer_dtype(self, x_dtype, assist_dtype):
327        valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
328        args = {"x": x_dtype, "assist": assist_dtype}
329        validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
330        return x_dtype
331
332    def infer_shape(self, x_shape, assist_shape):
333        validator.check_int(len(assist_shape), 2, validator.GE, "assist rank", self.name)
334        validator.check('rank of x', len(x_shape) + 1,
335                        'rank of assist', len(assist_shape), validator.LE, self.name)
336        validator.check('assist\'s penultimate dimension', assist_shape[-2], 'assist\'s last dimension',
337                        assist_shape[-1], validator.EQ, self.name)
338
339        r_end_dim = -len(x_shape)
340        r_idx = -1
341        while r_idx >= r_end_dim:
342            if x_shape[r_idx] != 1:
343                validator.check("reverse x dim %d" % r_idx, x_shape[r_idx], "reverse assist dim %d" %
344                                assist_shape[r_idx - 1], assist_shape[r_idx - 1], validator.EQ, self.name)
345            r_idx = r_idx - 1
346
347        return assist_shape
348
349
350class MatrixDiagPart(PrimitiveWithInfer):
351    r"""
352    Returns the batched diagonal part of a batched tensor.
353
354    Inputs:
355        - **x** (Tensor) - The batched tensor. It can be one of the following data types:
356          float32, float16, int32, int8, uint8.
        - **assist** (Tensor) - An eye tensor of the same type and shape as `x`.
358
359    Outputs:
360        Tensor, data type same as input `x`. The shape must be x.shape[:-2] + [min(x.shape[-2:])].
361
362    Examples:
363        >>> x = Tensor([[[-1, 0], [0, 1]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32)
364        >>> assist = Tensor(np.arange(-12, 0).reshape(3, 2, 2), mindspore.float32)
365        >>> matrix_diag_part = ops.MatrixDiagPart()
366        >>> result = matrix_diag_part(x, assist)
367        >>> print(result)
368        [[12., -9.], [8., -5.], [4., -1.]]
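        >>> # Each output row is the diagonal of the element-wise product x * assist for the corresponding batch.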
369    """
370
371    @prim_attr_register
372    def __init__(self):
373        """Initialize MatrixDiagPart"""
374
375    def infer_dtype(self, x_dtype, assist_dtype):
376        valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
377        args = {"x": x_dtype, "assist": assist_dtype}
378        validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
379        return x_dtype
380
381    def infer_shape(self, x_shape, assist_shape):
382        validator.check_int(len(x_shape), 2, validator.GE, "x rank", self.name)
383        validator.check("x shape", x_shape, "assist shape", assist_shape, validator.EQ, self.name)
384
385        if assist_shape[-2] < assist_shape[-1]:
386            out_shape = assist_shape[:-1]
387        else:
388            out_shape = assist_shape[:-2] + assist_shape[-1:]
389        return out_shape
390
391
392class MatrixSetDiag(PrimitiveWithInfer):
393    r"""
394    Modifies the batched diagonal part of a batched tensor.
395
396    Inputs:
397        - **x** (Tensor) - The batched tensor. Rank k+1, where k >= 1. It can be one of the following data types:
398          float32, float16, int32, int8, uint8.
399        - **diagonal** (Tensor) - The diagonal values. Must have the same type as input `x`. Rank k, where k >= 1.
        - **assist** (Tensor) - An eye tensor of the same type and shape as `x`.
401
402    Outputs:
403        Tensor, data type same as input `x`. The shape same as `x`.
404
405    Examples:
406        >>> x = Tensor([[[-1, 0], [0, 1]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32)
        >>> diagonal = Tensor([[-1., 2.], [-1., 1.], [-1., 1.]], mindspore.float32)
        >>> assist = Tensor(np.tile(np.eye(2), (3, 1, 1)), mindspore.float32)
        >>> matrix_set_diag = ops.MatrixSetDiag()
        >>> result = matrix_set_diag(x, diagonal, assist)
410        >>> print(result)
411        [[[-1, 0], [0, 2]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]]
412
413    """
414
415    @prim_attr_register
416    def __init__(self):
417        """Initialize MatrixSetDiag"""
418
419    def infer_dtype(self, x_dtype, diagonal_dtype, assist_dtype):
420        valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
421        args = {"x": x_dtype, "diagonal": diagonal_dtype, "assist": assist_dtype}
422        validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
423        return x_dtype
424
425    def infer_shape(self, x_shape, diagonal_shape, assist_shape):
426        validator.check_int(len(x_shape), 2, validator.GE, "x rank", self.name)
427        validator.check("x shape", x_shape, "assist shape", assist_shape, validator.EQ, self.name)
428
429        if x_shape[-2] < x_shape[-1]:
430            validator.check("diagonal shape", diagonal_shape, "x shape excluding the last dimension",
431                            x_shape[:-1], validator.EQ, self.name)
432        else:
433            validator.check("diagonal shape", diagonal_shape, "x shape excluding the second last dimension",
434                            x_shape[:-2] + x_shape[-1:], validator.EQ, self.name)
435
436        return assist_shape
437
438
439class ConfusionMulGrad(PrimitiveWithInfer):
440    """
441    `output0` is the dot product result of input0 and input1.
442
443    `output1` is the dot product result of input0 and input1, then apply the reducesum operation on it.
444
445    Args:
446        axis (Union[int, tuple[int], list[int]]): The dimensions to reduce.
            Default: (), reduce all dimensions. Only constant value is allowed.
        keep_dims (bool):

            - If true, keep these reduced dimensions and the length as 1.
            - If false, don't keep these dimensions. Default: ``False``.
452
453    Inputs:
454        - **input_0** (Tensor) - The input Tensor.
455        - **input_1** (Tensor) - The input Tensor.
456        - **input_2** (Tensor) - The input Tensor.
457
458    Outputs:
459        - **output_0** (Tensor) - The same shape as `input0`.
460        - **output_1** (Tensor)
461
462            - If axis is (), and keep_dims is false, the output is a 0-D array representing
463              the sum of all elements in the input array.
464            - If axis is int, set as 2, and keep_dims is false,
465              the shape of output is :math:`(x_1,x_3,...,x_R)`.
466            - If axis is tuple(int), set as (2,3), and keep_dims is false,
467              the shape of output is :math:`(x_1,x_4,...x_R)`.
468
469    Examples:
470        >>> confusion_mul_grad = ops.ConfusionMulGrad()
471        >>> input_0 = Tensor(np.random.randint(-2, 2, (2, 3)), mindspore.float32)
472        >>> input_1 = Tensor(np.random.randint(0, 4, (2, 3)), mindspore.float32)
473        >>> input_2 = Tensor(np.random.randint(-4, 0, (2, 3)), mindspore.float32)
474        >>> output_0, output_1 = confusion_mul_grad(input_0, input_1, input_2)
475        output_0:
476            [[ 3.   1.   0.]
477             [-6.   2.  -2.]]
478        output_1:
479            -3.0
480    """
481
482    @prim_attr_register
483    def __init__(self, axis=(), keep_dims=False):
484        self.init_prim_io_names(inputs=["input0", "input1", "input2"], outputs=["output0", "output1"])
485        self.axis_ = validator.check_value_type("axis", axis, [int, tuple, list], self.name)
486        self.keep_dims_ = validator.check_value_type("keep_dims", keep_dims, [bool], self.name)
487
488    def infer_shape(self, input0_shape, input1_shape, input2_shape):
489        outshape0 = input0_shape
490        outshape1 = _infer_shape_reduce(input1_shape, self.axis_, self.keep_dims_, self.name)
491        return outshape0, outshape1
492
493    def infer_dtype(self, input0_dtype, input1_dtype, input2_dtype):
494        validator.check_subclass("input0_dtype", input0_dtype, mstype.tensor_type, self.name)
495        validator.check_subclass("input1_dtype", input1_dtype, mstype.tensor_type, self.name)
496        validator.check_subclass("input2_dtype", input2_dtype, mstype.tensor_type, self.name)
497        return input0_dtype, input1_dtype
498
499
500class ConvertToDynamic(PrimitiveWithCheck):
501    """
502    This op is used for dynamic rank testing. Its inferred shape will be unknown
503    during compile time, so that its output will appear to be dynamically ranked.
504    The input will not be altered in any way. Put this operator before the operator
505    being tested for dynamic rank support.
506
507    Args:
508        is_dynamic_rank (bool): If true, convert to dynamic rank.
509                                If false, convert to dynamic shape. Default: ``False``.
510
511    Inputs:
512        - **input** (Tensor) - The tensor used for testing.
513
514    Outputs:
515        - **output** (Tensor) - Same shape, type and value as `input`.
516
517    Supported Platforms:
518        ``CPU``
519
520    Examples:
          >>> import numpy as np
          >>> import mindspore as ms
          >>> import mindspore.nn as nn
          >>> from mindspore import Tensor
523          >>> from mindspore.ops.operations import _inner_ops as inner
524          >>> from mindspore.ops import operations as P
525          >>> class TestDynamicNet(nn.Cell):
526          >>>     def __init__(self):
527          >>>         super(TestDynamicNet, self).__init__()
528          >>>         self.convert_to_dynamic = inner.ConvertToDynamic()
529          >>>         # suppose we are testing Reshape op
530          >>>         self.reshape = P.Reshape()
531          >>>
532          >>>     def construct(self, input, new_shape):
533          >>>         dynamic_input = self.convert_to_dynamic(input)
          >>>         reshaped_input = self.reshape(dynamic_input, new_shape)
          >>>         return reshaped_input
          >>>
          >>> ms.set_context(mode=ms.GRAPH_MODE, device_target="CPU")
          >>> input = Tensor(np.array([0, 1, 2, 3]))
          >>> new_shape = (2, 2)
          >>> net = TestDynamicNet()
          >>> output = net(input, new_shape)
          >>> print(output)
          [[0, 1], [2, 3]]
543    """
544
545    @prim_attr_register
546    def __init__(self, is_dynamic_rank=False):
547        validator.check_value_type('is_dynamic_rank', is_dynamic_rank, [bool], self.name)
548        self.init_prim_io_names(inputs=["input"], outputs=["output"])
549
550    def check_shape(self, input_shape):
551        validator.check("input_shape rank", len(input_shape), "", 0, validator.GT, self.name)
552
553    def check_dtype(self, input_dtype):
554        validator.check_subclass("input_dtype", input_dtype, mstype.tensor_type, self.name)
555
556
557class GpuConvertToDynamicShape(PrimitiveWithCheck):
558    """
559    This op is used for dynamic shape testing. Its inferred shape will be unknown
560    during compile time, so that its output will appear to be dynamically shaped.
561    The input will not be altered in any way. Put this operator before the operator
562    being tested for dynamic shape support.
563
564    Inputs:
565        - **input** (Tensor) - The tensor used for testing.
566
567    Outputs:
568        - **output** (Tensor) - Same shape, type and value as `input`.
569
570    Examples:
571          >>> # make a model, since dynamic shape operators must be in GRAPH_MODE
          >>> import numpy as np
          >>> import mindspore as ms
          >>> import mindspore.nn as nn
          >>> from mindspore import Tensor
574          >>> from mindspore.ops.operations import _inner_ops as inner
575          >>> from mindspore.ops import operations as P
576          >>> class TestDynamicShapeReshapeNet(nn.Cell):
577          >>>     def __init__(self):
578          >>>         super(TestDynamicShapeReshapeNet, self).__init__()
579          >>>         self.convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()
580          >>>         # suppose we are testing Reshape op
581          >>>         self.reshape = P.Reshape()
582          >>>
583          >>>     def construct(self, input, new_shape):
584          >>>         dynamic_shape_input = self.convert_to_dynamic_shape(input)
          >>>         reshaped_input = self.reshape(dynamic_shape_input, new_shape)
          >>>         return reshaped_input
          >>>
          >>> ms.set_context(mode=ms.GRAPH_MODE, device_target="GPU")
          >>> input = Tensor(np.array([0, 1, 2, 3]))
          >>> new_shape = (2, 2)
          >>> net = TestDynamicShapeReshapeNet()
          >>> output = net(input, new_shape)
          >>> print(output)
          [[0, 1], [2, 3]]
594    """
595
596    @prim_attr_register
597    def __init__(self):
598        self.init_prim_io_names(inputs=["input"], outputs=["output"])
599
600    def check_shape(self, input_shape):
601        validator.check("input_shape rank", len(input_shape), "", 0, validator.GT, self.name)
602
603    def check_dtype(self, input_dtype):
604        validator.check_subclass("input_dtype", input_dtype, mstype.tensor_type, self.name)
605
606
607class ErrorOnDynamicShapeInput(PrimitiveWithInfer):
608    """
609    This op is used for dynamic shape testing. The only purpose of this operator is
610    that it will throw a value error if the input is dynamically shaped.
611
612    Inputs:
613        - **input** (Tensor) - The tensor used for testing.
614
615    Outputs:
616        - **output** (Tensor) - Same shape, type and value as `input`.
617
618    Examples:
619          >>> # make a model, since dynamic shape operators must be in GRAPH_MODE
          >>> import numpy as np
          >>> import mindspore as ms
          >>> import mindspore.nn as nn
          >>> from mindspore import Tensor
622          >>> from mindspore.ops.operations import _inner_ops as inner
623          >>> from mindspore.ops import operations as P
624          >>> class AssertDynamicShapeNet(nn.Cell):
625          >>>     def __init__(self):
626          >>>         super(AssertDynamicShapeNet, self).__init__()
627          >>>         self.convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()
628          >>>         self.error_on_dynamic_shape_input = inner.ErrorOnDynamicShapeInput()
629          >>>
630          >>>     def construct(self, input, new_shape):
631          >>>         dynamic_shape_input = self.convert_to_dynamic_shape(input)
632          >>>         self.error_on_dynamic_shape_input(dynamic_shape_input)
633          >>>
634          >>> ms.set_context(mode=ms.GRAPH_MODE, device_target="GPU")
          >>> input = Tensor(np.array([0]))
          >>> new_shape = (1,)
          >>> net = AssertDynamicShapeNet()
          >>> output = net(input, new_shape)
638          ValueError: Input is dynamically shaped.
639    """
640
641    @prim_attr_register
642    def __init__(self):
643        self.init_prim_io_names(inputs=["input"], outputs=["output"])
644
645    def infer_shape(self, input_shape):
646        shape = list(input_shape)
647
648        for dim in shape:
649            if dim == -1:
650                raise ValueError("Input is dynamically shaped.")
651
652        return input_shape
653
654    def infer_type(self, input_dtype):
655        """Infer the dtype of input for ErrorOnDynamicShapeInput."""
656        validator.check_subclass("input_dtype", input_dtype, mstype.tensor_type, self.name)
657        return input_dtype
658
659    def infer_value(self, input_tensor):
660        return input_tensor
661
662
663class SequenceMask(PrimitiveWithCheck):
664    """
665    Returns a mask tensor representing the first N positions of each cell.
666
    If `lengths` has shape [d_1, d_2, ..., d_n], then the resulting tensor mask has dtype bool and shape
    [d_1, d_2, ..., d_n, maxlen], with mask[i_1, i_2, ..., i_n, j] = (j < lengths[i_1, i_2, ..., i_n]).
669
670    Inputs:
671        - **lengths** (Tensor) - Tensor to calculate the mask for. All values in this tensor should be
672          less than or equal to `maxlen`. Values greater than `maxlen` will be treated as `maxlen`.
673          Must be type int32 or int64.
674
675        - **maxlen** (int) - size of the last dimension of returned tensor. Must be positive and same
676          type as elements in `lengths`.
677
678    Outputs:
679        One mask tensor of shape lengths.shape + (maxlen,).
680
681    Supported Platforms:
682        ``GPU`` ``CPU``
683
684    Examples:
685        >>> from mindspore import ops
686        >>> import numpy as np
687        >>> x = Tensor(np.array([[1, 3], [2, 0]]))
688        >>> sequence_mask = ops.SequenceMask()
689        >>> output = sequence_mask(x, 3)
690        >>> print(output)
691        [[[True False False]
692          [True True True]]
693         [[True True False]
694          [False False False]]]
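        >>> # Sanity check against the defining formula, using plain numpy (illustrative only):
        >>> np.array_equal(output.asnumpy(), np.arange(3) < np.array([[1, 3], [2, 0]])[..., None])
        True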
695    """
696
697    @prim_attr_register
698    def __init__(self):
699        self.init_prim_io_names(inputs=["lengths", "maxlen"], outputs=["mask"])
700
701    def check_shape(self, lengths_shape, maxlen_shape):
702        validator.check("lengths_shape", len(lengths_shape), "", 0, validator.GT, self.name)
703        validator.check("maxlen_shape", len(maxlen_shape), "", 0, validator.EQ, self.name)
704
705    def check_dtype(self, lengths_dtype, maxlen_dtype):
706        validator.check_subclass("lengths_dtype", lengths_dtype, mstype.tensor_type, self.name)
707        validator.check_subclass("maxlen", maxlen_dtype, mstype.number, self.name)
708
709
710class SyncBatchNorm(Primitive):
711    r"""
712    Sync Batch Normalization for input data and updated parameters.
713
714    Sync Batch Normalization is cross device synchronized Batch Normalization. Batch Normalization is
715    widely used in convolutional neural networks. This operation applies Batch Normalization over input
716    to avoid internal covariate shift as described in the paper `Batch Normalization: Accelerating
717    Deep Network Training by Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_.
718    It rescales and recenters the features using a mini-batch of data and the learned parameters which
719    can be described in the following formula,
720
721    .. math::
722        y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta
723
724    where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon.
725
726    Args:
727        epsilon (float): A small value added for numerical stability. Default: 1e-5.
728        momentum (float): The hyper parameter to compute moving average for running_mean and running_var
729            (e.g. :math:`new\_running\_mean = (1 - momentum) * running\_mean + momentum * current\_mean`).
730            Momentum value must be [0, 1]. Default: 0.1.
731        group (str): The communication group to work on. Default: "sync_bn_group0".
732        device_num (int): The number of devices in each group. Default: 2.
733
734    Inputs:
735        - **input_x** (Tensor) - Tensor of shape :math:`(N, C)`, with float16 or float32 data type.
736        - **scale** (Tensor) - Tensor of shape :math:`(C,)`, with float16 or float32 data type.
737        - **bias** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type with `scale`.
738        - **mean** (Tensor) - Tensor of shape :math:`(C,)`, with float16 or float32 data type.
739        - **variance** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type with `mean`.
740
741    Outputs:
742        Tuple of 5 Tensor, the normalized inputs and the updated parameters.
743
744        - **output_x** (Tensor) - The same type and shape as the input_x. The shape is :math:`(N, C)`.
745        - **updated_scale** (Tensor) - Tensor of shape :math:`(C,)`.
746        - **updated_bias** (Tensor) - Tensor of shape :math:`(C,)`.
747        - **updated_moving_mean** (Tensor) - Tensor of shape :math:`(C,)`.
748        - **updated_moving_variance** (Tensor) - Tensor of shape :math:`(C,)`.
749
750    Supported Platforms:
751        ``Ascend``
752
753    Examples:
754        >>> # This example should be run with multiple processes.
755        >>> # Please refer to nn.SyncBatchNorm for direct use.
756        >>> input_x = Tensor(np.ones([2, 2]), mindspore.float32)
757        >>> scale = Tensor(np.ones([2]), mindspore.float32)
758        >>> bias = Tensor(np.ones([2]), mindspore.float32)
759        >>> mean = Tensor(np.ones([2]), mindspore.float32)
760        >>> variance = Tensor(np.ones([2]), mindspore.float32)
761        >>> sync_batch_norm = ops._inner_ops.SyncBatchNorm()
762        >>> output = sync_batch_norm(input_x, scale, bias, mean, variance)
763        >>> print(output)
764        (Tensor(shape=[2, 2], dtype=Float32, value=
765        [[ 1.00000000e+00, 1.00000000e+00],
766         [ 1.00000000e+00, 1.00000000e+00]]), Tensor(shape=[2], dtype=Float32, value=
767         [ 1.00000000e+00, 1.00000000e+00]), Tensor(shape=[2], dtype=Float32, value=
768         [ 1.00000000e+00, 1.00000000e+00]), Tensor(shape=[2], dtype=Float32, value=
769         [ 1.00000000e+00, 1.00000000e+00]), Tensor(shape=[2], dtype=Float32, value=
770         [ 1.00000000e+00, 1.00000000e+00]))
771    """
772
773    @prim_attr_register
774    def __init__(self, epsilon=1e-5, momentum=0.1, group="sync_bn_group0", device_num=2):
775        validator.check_float_range(epsilon, 0, 1, validator.INC_RIGHT, 'epsilon', self.name)
776        validator.check_float_range(momentum, 0, 1, validator.INC_BOTH, 'momentum', self.name)
777        validator.check_isinstance("group", group, str)
778        validator.check_int(device_num, 2, validator.GE, "device_num", self.name)
779        self.init_prim_io_names(inputs=['x', 'scale', 'offset', 'mean', 'variance'],
780                                outputs=['y', 'batch_mean', 'batch_variance', 'reserve_space_1', 'reserve_space_2'])
781        self.add_prim_attr('side_effect_mem', True)
782        self.add_prim_attr('format', 'NCHW')
783
784
785class Centralization(PrimitiveWithInfer):
786    """
787    Computes centralization. y = x - mean(x, axis).
788
789    Note:
790        The dimension index starts at 0 and must be in the range `[-input.ndim, input.ndim)`.
791
792    Inputs:
        - **input_x** (Tensor) - The input tensor. The data type must be float16 or float32.
794        - **axis** (Union[int, Tuple(int), List(int)]) - The dimensions to reduce. Default: (), reduce all dimensions.
795          Only constant value is allowed. Must be in the range [-rank(input_x), rank(input_x)).
796
797    Outputs:
798        Tensor, has the same shape and dtype as the `input_x`.
799
800    Raises:
801        TypeError: If `axis` is not one of the following types: int, list, tuple, NoneType.
802        TypeError: If `axis` has non-Int elements.
803
804    Supported Platforms:
805        ``Ascend``
806
807    Examples:
808        >>> mindspore.set_seed(1)
809        >>> input_x = Tensor(np.random.randn(2, 2).astype(np.float32))
810        >>> centralization = ops.Centralization()
811        >>> output = centralization(input_x, -1)
812        >>> print(output)
813        [[ 1.1180509 -1.1180508]
814         [ 0.2723984 -0.2723984]]
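        >>> # After centralization along the last axis, each output row sums to (approximately) zero.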
815    """
816
817    __mindspore_signature__ = (
818        sig.make_sig('input_x'),
819        sig.make_sig('axis', default=())
820    )
821
822    @prim_attr_register
823    def __init__(self):
824        """Initialize Centralization"""
825        self.init_prim_io_names(inputs=['input_x', 'axis'], outputs=['output'])
826
827    def __infer__(self, input_x, axis):
828        x_shape = list(input_x['shape'])
829        x_dtype = input_x['dtype']
830        axis_v = axis['value']
831        rank = len(x_shape)
832
833        args = {'input_x': input_x['dtype']}
834        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
835
836        if axis_v is None:
837            raise ValueError(f"For {self.name}, axis must be const.")
838        validator.check_value_type('axis', axis_v, [int, list, tuple], self.name)
839
840        if isinstance(axis_v, int):
841            validator.check_int_range(axis_v, -rank, rank, validator.INC_LEFT, 'axis', self.name)
842        elif axis:
843            for index, one_axis in enumerate(axis_v):
844                validator.check_value_type('axis[%d]' % index, one_axis, [int], self.name)
845
846        out = {'shape': x_shape,
847               'dtype': x_dtype,
848               'value': None}
849        return out
850
851
852class StackInit(PrimitiveWithInfer):
853    """
854    Create a stack that produces tensors in first-in last-out order.
855
856    After `StackInit`, a tensor can be pushed onto the stack using `StackPush`, and popped
857    at the top of the stack using `StackPop`. Finally, the stack should be destroyed with `StackDestroy`.
858
859    Args:
860        index (int): The index of the stack. Default: 1.
861
862    Supported Platforms:
863        ``Ascend``
864
865    Examples:
866        >>> x = Tensor(np.array([[1, 3], [2, 0]]))
867        >>> index = 0
868        >>> stack = ops.StackInit(index)
869        >>> push = ops.StackPush(index)
870        >>> pop = ops.StackPop(index, x.shape, x.dtype)
871        >>> destroy = ops.StackDestroy(index)
872        >>> stack()
873        >>> push(x)
874        >>> y = pop()
875        >>> destroy()
876        >>> print(y)
877        [[1 3]
878         [2 0]]
879    """
880
881    @prim_attr_register
882    def __init__(self, index=1):
883        """StackInit"""
884        validator.check_value_type("index", index, [int], self.name)
885
886
887class StackPush(PrimitiveWithInfer):
888    """
889    Push a tensor onto the stack.
890
891    Before `StackPush`, the stack should be created using `StackInit`.
892    Please refer to the usage in source code of `StackInit`.
893
894    Args:
895        index (int): The index of the stack. Default: 1.
896
897    Inputs:
898        - **input** (Tensor) - A tensor to be pushed onto the stack.
899
900    Supported Platforms:
901        ``Ascend``
902
903    Examples:
904        Please refer to the usage of `StackInit`.
905    """
906
907    @prim_attr_register
908    def __init__(self, index=1):
909        """StackPush"""
910        validator.check_value_type("index", index, [int], self.name)
911        self.init_prim_io_names(inputs=['input'], outputs=[])
912
913
914class StackPop(PrimitiveWithInfer):
915    """
916    Pop the tensor at the top of the stack.
917
918     Before `StackPop`, the stack should be created using `StackInit`.
919     Please refer to the usage in source code of `StackInit`.
920
921    Args:
922        index (int): The index of the stack. Default: 1.
923        shape (tuple): The shape of the tensor at the top of the stack. Default: (1,).
924        dtype (mindspore.dtype): The type of the tensor at the top of the stack. Default: mindspore.float32.
925
926    Outputs:
927        - **output** (Tensor) - The tensor at the top of the stack.
928
929    Supported Platforms:
930        ``Ascend``
931
932    Examples:
933        Please refer to the usage of `StackInit`.
934    """
935
936    @prim_attr_register
937    def __init__(self, index=1, shape=(1,), dtype=mstype.float32):
938        """StackPop"""
939        validator.check_value_type("index", index, [int], self.name)
940
941        validator.check_value_type('shape type', shape, [list, tuple], self.name)
942        validator.check_int(len(np.array(shape).shape), 1, validator.EQ, "dim of shape", self.name)
943        for elem in shape:
944            validator.check_int(elem, 1, validator.GE, 'shape element', self.name)
945            validator.check_value_type('type of shape element', elem, [int], self.name)
946
947        validator.check_type_name("dtype", dtype, (mstype.bool_,) + mstype.number_type, self.name)
948        self.shape = shape
949        self.dtype = dtype
950
951        self.init_prim_io_names(inputs=[], outputs=['output'])
952
953    def __infer__(self):
954        return {'shape': (list(self.shape)),
955                'dtype': (self.dtype),
956                'value': None}
957
958
959class StackDestroy(PrimitiveWithInfer):
960    """
961    Destroy the stack.
962
963     Before `StackDestroy`, the stack should be created using `StackInit`.
964     Please refer to the usage in source code of `StackInit`.
965
966    Args:
967        index (int): The index of the stack. Default: 1.
968
969    Supported Platforms:
970        ``Ascend``
971
972    Examples:
973        Please refer to the usage of `StackInit`.
974    """
975
976    @prim_attr_register
977    def __init__(self, index=1):
978        """StackDestroy"""
979        validator.check_value_type("index", index, [int], self.name)
980
981
982class DynamicStitch(PrimitiveWithCheck):
983    r"""
984    Interleave the values from the data tensors into a single tensor.
985
986    Inputs:
987        - **indices** (Union[tuple, list]) - A Tuple or list of Tensor objects with the same shape and type.
988        - **data** (Union[tuple, list]) - A Tuple or list of Tensor objects with the same shape and type.
989
990    Outputs:
991        Tensor. A stacked Tensor with the same type as `data`.
992
993    Raises:
994        TypeError: If the data types of elements in `data` or `indices` are not the same.
995        ValueError: If the length of `data` or `indices` is not greater than 1.
996
997    Supported Platforms:
998        ``Ascend``
999
1000    Examples:
1001        >>> x1 = Tensor([6], mstype.int32)
1002        >>> x2 = Tensor(np.array([4, 1]), mstype.int32)
1003        >>> x3 = Tensor(np.array([[5, 2], [0, 3]]), mstype.int32)
        >>> y1 = Tensor(np.array([[61, 62]]), mstype.int32)
1005        >>> y2 = Tensor(np.array([[41, 42], [11, 12]]), mstype.int32)
1006        >>> y3 = Tensor(np.array([[[51, 52], [21, 22]], [[1, 2], [31, 32]]]), mstype.int32)
1007        >>> stitch = ops.DynamicStitch()
1008        >>> output = stitch([x1, x2, x3], [y1, y2, y3])
1009        >>> print(output)
1010        [[ 1  2]
1011         [11 12]
1012         [21 22]
1013         [31 32]
1014         [41 42]
1015         [51 52]
1016         [61 62]]
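        >>> # Row j of the output comes from the data tensor whose indices contain j;
        >>> # e.g. x1 = [6] routes y1 = [[61, 62]] to output row 6.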
1017    """
1018
1019    @prim_attr_register
1020    def __init__(self):
1021        """Initialize DynamicStitch"""
1022
1023    def check_shape(self, indices_shape, data_shape):
1024        validator.check_value_type("shape of indices", indices_shape, [tuple, list], self.name)
1025        validator.check_int(len(indices_shape), 1, validator.GE, "len of indices_shape", self.name)
1026        indices_dim0 = len(indices_shape[0])
1027        indices_num = len(indices_shape)
1028
1029        validator.check_value_type("shape of data", data_shape, [tuple, list], self.name)
1030        validator.check_int(len(data_shape), 1, validator.GE, "len of data_shape", self.name)
1031        data_dim0 = len(data_shape[0])
        data_num = len(data_shape)
1033
1034        validator.check("size of indices", indices_num, 'size of data', data_num, validator.EQ, self.name)
1035
1036        # shape of `data` must start with shape of `indices`
1037        for i in range(0, indices_num):
1038            indices_dim = len(indices_shape[i])
1039            data_dim = len(data_shape[i])
1040            validator.check(f"dim of indices[{i}]", indices_dim, f"dim of data[{i}]", data_dim, validator.LE, self.name)
            if list(data_shape[i][:indices_dim]) != list(indices_shape[i]):
                raise ValueError(f"data[{i}].shape: {data_shape[i]} does not start with "
                                 f"indices[{i}].shape: {indices_shape[i]}")
1043
1044        # the last-(data_dim0-indices_dim0)-dim of data shape must end with same shape.
1045        base_extra = data_dim0 - indices_dim0
1046        for i in range(0, data_num):
1047            indices_dim = len(indices_shape[i])
1048            data_dim = len(data_shape[i])
1049            extra = data_dim - indices_dim
1050            validator.check(f"extra dim of data[{i}]", extra,
1051                            f"extra dim of data[0]", base_extra, validator.EQ, self.name)
1052            validator.check(f"data[0].shape[{indices_dim0}:]", data_shape[0][indices_dim0:],
1053                            f"data[{i}].shape[{len(indices_shape[i])}:]",
1054                            data_shape[i][indices_dim:], validator.EQ, self.name)
1055
1056        out_shape = [-1] + data_shape[0][indices_dim0:]
1057        return out_shape
1058
1059    def check_dtype(self, indices_type, data_type):
1060        validator.check_subclass("indices[0]", indices_type[0], mstype.tensor_type, self.name)
1061        validator.check_subclass("data[0]", data_type[0], mstype.tensor_type, self.name)
1062        indices_num = len(indices_type)
1063        for i in range(0, indices_num):
1064            validator.check_tensor_dtype_valid(f'indices[{i}]', indices_type[i], mstype.int32, self.name)
1065            validator.check_tensor_dtype_valid(f'data[{i}]', data_type[i],
1066                                               mstype.number_type + (mstype.bool_,), self.name)
1067            validator.check(f"type of data[{i}]", data_type[i], f"type of data[0]",
1068                            data_type[0], validator.EQ, self.name)
1069        return data_type[0]
1070
1071
1072class DynamicBroadcastGradientArgs(Primitive):
1073    """
1074    Broadcast the two input shapes, return the dimensions that each need to be broadcast.
1075
    Input shape `s0` and shape `s1` can be broadcast to a common shape if for each dimension pair they are either
    equal, one of them is 1, or the target dimension is -1. In case of -1 in the target shape, it will be replaced by
    the input shape's value in that dimension.
1079
1080    Inputs:
1081        - **s0** (Tensor) - A `1-D` tensor. The data type should be one of the following types: int32, int64,
1082          uint32, uint64.
1083        - **s1** (Tensor) - A `1-D` tensor with the same type as `s0`.
1084
1085    Outputs:
1086        Tuple(Tensor), tuple of 2 tensors, r0 and r1. The first one is the index tensor and the other one is the mask
1087        tensor.
1088
1089        - **r0** (Tensor) - The output shape is 1-D with the same type as s0.
1090        - **r1** (Tensor) - The output shape is 1-D with the same type as s0.
1091
1092    Raises:
        ValueError: if `s0` and `s1` are incompatible, or if a -1 in the target shape is in an invalid
                    location.
1095
1096    Supported Platforms:
1097        ``Ascend``
1098
1099    Examples:
1100        >>> shape0 = (4, 2, 1)
1101        >>> shape1 = (2, 7)
1102        >>> from mindspore.ops.operations import _inner_ops
1103        >>> args = _inner_ops.DynamicBroadcastGradientArgs()
1104        >>> r0, r1 = args(Tensor(shape0), Tensor(shape1))
1105        >>> print(r0, r1)
1106        [2], [0]
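        >>> # In this example r0 = [2] because dim 2 of s0 = (4, 2, 1) has size 1 and is broadcast,
        >>> # while r1 = [0] because s1 = (2, 7) lacks the leading dimension.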
1107    """
1108
1109    @prim_attr_register
1110    def __init__(self):
1111        """Init BroadcastGradientArgs"""
1112
1113
1114class DSDMatmul(PrimitiveWithInfer):
1115    """
    The definition of the DSDMatmul primitive.
1117    """
1118
1119    @prim_attr_register
1120    def __init__(self):
1121        self.init_prim_io_names(inputs=['input_w1', 'input_w2', 'input_v'], outputs=['output_y'])
1122
1123    def infer_shape(self, input_w1_shape, input_w2_shape, input_v_shape):
1124        batch_size = input_w1_shape[0]
1125        head = input_w1_shape[1]
1126        v_embedding = input_v_shape[1] * 16 // head
1127        seq_len = input_v_shape[0] * 16 // batch_size
1128        return (batch_size, head, v_embedding // 16, seq_len // 16, 16, 16)
1129
1130    def infer_dtype(self, data_dtype1, data_dtype2, data_dtype3):
1131        return data_dtype1
1132
1133
1134class MatmulDDS(PrimitiveWithInfer):
1135    """MatmulDDS definition"""
1136
1137    @prim_attr_register
1138    def __init__(self, bs, heads):
1139        """init MatmulDDS"""
1140        self.init_prim_io_names(inputs=['q', 'k', 'local_mask', 'global_mask'],
1141                                outputs=['local_prob', 'global_prob'])
1142
1143        self.heads = heads
1144
1145    def infer_shape(self, q, k, local_mask, global_mask):
1146        seq_len = local_mask[0] * local_mask[-1]
1147        bs = q[1] * q[2] // seq_len
1148        global_size = seq_len // 4
1149        size_per_head = q[0] * q[-1] // self.heads
1150        heads = q[0] * q[-1] // size_per_head
1151        block_size = local_mask[1] * local_mask[2] // bs
1152        block_num = seq_len // block_size
1153        l_size = (bs, heads, block_num, block_size // 16, block_size // 16, 16, 16)
1154        g_size = (bs, heads, block_num, global_size // 16, block_size // 16, 16, 16)
1155
1156        return l_size, g_size
1157
1158    def infer_dtype(self, q, k, local_mask, global_mask):
1159        return q, q
1160
1161
1162class DSDGrad(PrimitiveWithInfer):
1163    """
    The definition of the DSDGrad primitive.
1165    """
1166
1167    @prim_attr_register
1168    def __init__(self):
1169        self.init_prim_io_names(inputs=['w1_gm', 'w2_gm', 'v_gm', 'a_gm', 'd_a_gm'],
1170                                outputs=['d_w1_gm', 'd_w2_gm', 'd_v_gm'])
1171
1172    def infer_shape(self, input_w1_shape, input_w2_shape, input_v_shape, input_a_shape, input_da_shape):
1173        return input_w1_shape, input_w2_shape, input_v_shape
1174
1175    def infer_dtype(self, data_dtype1, data_dtype2, data_dtype3, data_dtype4, data_dtype5):
1176        return data_dtype1, data_dtype1, data_dtype1
1177
1178
1179class MatmulDDSGrad(PrimitiveWithInfer):
    """MatmulDDSGrad definition"""
1181
1182    @prim_attr_register
1183    def __init__(self):
        """init MatmulDDSGrad"""
1185        self.init_prim_io_names(inputs=['q', 'k', 'local_prob', 'global_prob', 'local_prob_grad', 'global_prob_grad'],
1186                                outputs=['dq', 'dk'])
1187
1188    def infer_shape(self, q, k, local_prob, global_prob, local_prob_grad, global_prob_grad):
1189        k_size = (q[1], q[0], q[3], q[2])
1190
1191        return q, k_size
1192
1193    def infer_dtype(self, q, k, local_prob, global_prob, local_prob_grad, global_prob_grad):
1194        return q, k
1195
1196
1197class NonZeroWithValue(Primitive):
1198    """
1199    Returns the value of elements that are non-zero (in row-major order - by dimension).
1200
1201    Inputs:
1202        - **x** (Tensor), input array of rank >= 2.
1203
1204    Outputs:
1205         elements that are non-zero.
1206
1207    Supported Platforms:
1208        ``Ascend``
1209
1210    Examples:
1211        >>> op = NonZeroWithValue()
1212        >>> data = Tensor(np.array([[1, 0, 0], [0, 0, 1]]), mindspore.float32)
1213        >>> value, index, count = op(data)
1214        >>> print(value)
1215        [1.0, 1.0]
1216    """
1217
1218    @prim_attr_register
1219    def __init__(self, transpose=False):
1220        """Initialize NonZeroWithValue"""
1221        validator.check_value_type("transpose", transpose, [bool], self.name)
1222        self.init_prim_io_names(inputs=['x'], outputs=['value', 'index', 'count'])
1223
1224
1225class NonZeroWithValueShape(Primitive):
1226    """
1227    Returns the value and index of elements that are non-zero (in row-major order - by dimension).
1228
1229    Inputs:
1230        - **x** (Tensor), input array of rank >= 2.
1231
1232    Outputs:
1233         elements that are non-zero.
1234
1235    Supported Platforms:
1236        ``Ascend``
1237
1238    Examples:
1239        >>> non_zero = NonZeroWithValue()
1240        >>> op = NonZeroWithValueShape()
1241        >>> data = Tensor(np.array([[1, 0, 0], [0, 0, 1]]), mindspore.float32)
1242        >>> value, index, count = non_zero(data)
1243        >>> out_value, out_index = op(value, index, count)
1244        >>> print(out_index)
1245        [[0, 1], [0, 2]]
1246    """
1247
1248    @prim_attr_register
1249    def __init__(self):
1250        """Initialize NonZeroWithValueShape"""
1251        self.init_prim_io_names(inputs=['value', 'index', 'count'], outputs=['out_value', 'out_index'])
1252
1253
1254class DecodeImage(PrimitiveWithInfer):
1255    """
    Returns image data parsed from a string Tensor.
1257
1258    Inputs:
        - **x** (Tensor), a Tensor of type string. 0-D. The JPEG, GIF, PNG, or BMP-encoded image.
1260
1261    Outputs:
1262         A Tensor of type uint8, uint16, float.
1263
1264    Supported Platforms:
1265        ``Ascend``
1266
1267    Examples:
1268    """
1269
1270    @prim_attr_register
1271    def __init__(self, channels=0, dtype=mstype.uint8, expand_animations=False, _op_max_shape="8192,8192,3",
1272                 _op_max_size=[8000000]):
1273        self.init_prim_io_names(inputs=["contents"], outputs=["image"])
1274        self.res_type = dtype
1275
1276    def infer_shape(self, x):
1277        return (-1, -1, 3)
1278
1279    def infer_dtype(self, x):
1280        return self.res_type
1281
1282
1283class SliceGetItem(Primitive):
1284    """
    Use SliceGetItem to get a slice's 'start', 'stop' or 'step' attribute.
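
    Examples:
        >>> # Plain-Python usage sketch; SliceGetItem.__call__ handles Python slice objects directly.
        >>> s = slice(1, 10, 2)
        >>> SliceGetItem()(s, "start")
        1
        >>> SliceGetItem()(s, "step")
        2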
1286    """
1287
1288    @prim_attr_register
1289    def __init__(self):
        """Initialize SliceGetItem"""
1291        self.init_prim_io_names(inputs=['slice', 'attr'], outputs=['slice_item'])
1292
1293    def __call__(self, slice_value, value):
1294        if not isinstance(slice_value, slice):
1295            raise TypeError(
1296                "Primitive[SliceGetItem] only support to get a slice type element but got {}".format(slice_value))
1297        if value == "start":
1298            if hasattr(slice_value.start, "ndim") and slice_value.start.ndim == 1:
1299                return slice_value.start.item()
1300            return slice_value.start
1301        if value == "stop":
1302            if hasattr(slice_value.stop, "ndim") and slice_value.stop.ndim == 1:
1303                return slice_value.stop.item()
1304            return slice_value.stop
1305        if value == "step":
1306            if hasattr(slice_value.step, "ndim") and slice_value.step.ndim == 1:
1307                return slice_value.step.item()
1308            return slice_value.step
1309        raise AttributeError("\'slice\' object has no attribute {}".format(value))
1310
1311
1312class DynamicBroadcastTo(Primitive):
1313    """
1314    Broadcasts input tensor to a given shape.
1315
1316    Inputs:
1317        - **input_x** (Tensor) - The input tensor. The data type should be one of the following types:
1318          float16, float32, int32, int8, uint8.
1319          The shape is :math:`(N,*)` where :math:`*` means any number of additional dimensions.
1320        - **shape** (Tensor): The target shape to broadcast.
1321
1322    Outputs:
1323        Tensor, with the given `shape` and the same data type as `input_x`.
1324
1325    Raises:
1326        ValueError: if the target and input shapes are incompatible.
1327
1328    Supported Platforms:
1329        ``Ascend`` ``GPU`` ``CPU``
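
    Examples:
        >>> # Illustrative sketch only; values, dtypes and the import path follow the other inner-op examples.
        >>> from mindspore.ops.operations import _inner_ops as inner
        >>> x = Tensor(np.array([1, 2, 3]), mstype.float32)
        >>> shape = Tensor(np.array([2, 3]), mstype.int64)
        >>> output = inner.DynamicBroadcastTo()(x, shape)
        >>> print(output.shape)
        (2, 3)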
1330    """
1331
1332    @prim_attr_register
1333    def __init__(self):
1334        """Initialize DynamicBroadcastTo"""
1335        self.init_prim_io_names(inputs=['x', 'shape'], outputs=['y'])
1336
1337
1338class DynamicResizeNearestNeighbor(Primitive):
1339    r"""
1340    Resizes the input tensor by using the nearest neighbor algorithm.
1341
1342    Resizes the input tensor to a given size by using the nearest neighbor algorithm. The nearest
1343    neighbor algorithm selects the value of the nearest point and does not consider the
1344    values of neighboring points at all, yielding a piecewise-constant interpolant.
1345
1346    Note:
1347        The operator supports dynamic shape.
1348
1349    Args:
1350        align_corners (bool): Whether the centers of the 4 corner pixels of the input
1351                              and output tensors are aligned. Default: ``False``.
1352
1353    Inputs:
1354        - **input_x** (Tensor) - The input tensor. The shape of the tensor is :math:`(N, C, H, W)`.
        - **size** (Union[tuple, list]) - The target size. The dimension of `size` must be 2.
1356
1357    Outputs:
1358        Tensor, the shape of the output tensor is  :math:`(N, C, NEW\_H, NEW\_W)`.
1359        The data type is the same as the `input_x`.
1360    """
1361
1362    @prim_attr_register
1363    def __init__(self, align_corners=False):
        """Initialize DynamicResizeNearestNeighbor"""
1365        validator.check_value_type("align_corners", align_corners, [bool], self.name)
1366        self.init_prim_io_names(inputs=['image_in'], outputs=['image_out'])
1367
1368
1369class PsROIPooling(PrimitiveWithInfer):
1370    r"""
    Position Sensitive ROI-Pooling.

    Inputs:
1376        - **features** (Tensor) - The input features, whose shape must be :math:`(N, C, H, W)`.
1377        - **rois** (Tensor) - The shape is :math:`(rois\_n, 5)`. With data type of float16 or float32.
          `rois_n` represents the number of RoI. The size of the second dimension must be `5` and the `5` columns
1379          are :math:`(image\_index, top\_left\_x, top\_left\_y, bottom\_right\_x, bottom\_right\_y)`.
1380          `image_index` represents the index of image. `top_left_x` and `top_left_y` represent the `x, y`
1381          coordinates of the top left corner of corresponding RoI, respectively. `bottom_right_x` and `bottom_right_y`
1382          represent the `x, y` coordinates of the bottom right corner of corresponding RoI, respectively.
1383
1384    Outputs:
        - **out** (Tensor) - The pooling result, with shape :math:`(rois\_n, out\_dim, pooled\_height, pooled\_width)`.
        - **channel_map** (Tensor) - The mapping channel tensor with the same shape as `out`, used by the backward
          pass to compute gradients.

1387    Supported Platforms:
1388        ``GPU``
1389
1390    Examples:
1391        >>> import mindspore
1392        >>> import numpy as np
1393        >>> from mindspore import Tensor
1394        >>> from mindspore.ops.operations import _inner_ops as inner
1395        >>> features = np.random.randn(4, 21 * 7 * 7, 80, 48)
1396        >>> features = Tensor.from_numpy(features).astype(mindspore.float32)
        >>> rois = Tensor.from_numpy(
        ...     np.array([
        ...        [0.0000, 150.3563, 200.1320, 579.3563, 602.3452],
        ...        [1.0000, 657.1263, 302.8564, 762.4214, 567.9854],
        ...        [2.0000, 321.3122, 232.2410, 679.0281, 587.6346],
        ...        [3.0000, 664.1630, 387.4919, 778.7322, 562.7321],
        ...     ])).astype(mindspore.float32)
        >>> ps_roi_pooling = inner.PsROIPooling(pooled_height=7, pooled_width=7, num_rois=4,
        ...                                     spatial_scale=1.0/16, out_dim=21,
        ...                                     group_size=7)
        >>> out, channel_map = ps_roi_pooling(features, rois)
        >>> print(out.shape)
        (4, 21, 7, 7)
        >>> print(channel_map.shape)
        (4, 21, 7, 7)
1412    """
1413
1414    @prim_attr_register
1415    def __init__(self, pooled_height, pooled_width, num_rois, spatial_scale, out_dim, group_size):
1416        """Initialize PsROIPooling"""
1417        validator.check_value_type("pooled_height", pooled_height, [int], self.name)
1418        validator.check_value_type("pooled_width", pooled_width, [int], self.name)
        validator.check_value_type("num_rois", num_rois, [int], self.name)
1420        validator.check_value_type("spatial_scale", spatial_scale, [float], self.name)
1421        validator.check_value_type("out_dim", out_dim, [int], self.name)
1422        validator.check_value_type("group_size", group_size, [int], self.name)
1423        self.pooled_height = pooled_height
1424        self.pooled_width = pooled_width
1425        self.num_rois = num_rois
1426        self.spatial_scale = spatial_scale
1427        self.out_dim = out_dim
1428        self.group_size = group_size
1429
1430    def infer_shape(self, inputs_shape, rois_shape):
1431        output_shape = [self.num_rois, self.out_dim, self.pooled_height, self.pooled_width]
1432        output_map_shape = [self.num_rois, self.out_dim, self.pooled_height, self.pooled_width]
1433        return output_shape, output_map_shape
1434
1435    def infer_dtype(self, inputs_type, rois_type):
1436        map_type = mstype.TensorType(mstype.int32)
1437        return inputs_type, map_type
1438
1439
1440class ParallelResizeBilinear(PrimitiveWithInfer):
1441    """ParallelResizeBilinear ops"""
1442
1443    @prim_attr_register
1444    def __init__(self, ori_image_size, split_size, src_start_w, dst_start_w, align_corners):
1445        """Initialize ParallelResizeBilinear."""
1446        validator.check_value_type("ori_image_size", ori_image_size, [list, tuple], self.name)
1447        validator.check_value_type("split_size", split_size, [list, tuple], self.name)
1448        validator.check_int(len(split_size), 2, validator.EQ, "len of split_size", self.name)
1449        validator.check_value_type("src_start_w", src_start_w, [int], self.name)
1450        validator.check_value_type("dst_start_w", dst_start_w, [int], self.name)
1451        validator.check_value_type("align_corners", align_corners, [bool], self.name)
1452        self.ori_image_size = list(ori_image_size)
1453        self.split_size = list(split_size)
1454        self.src_start_w = src_start_w
1455        self.dst_start_w = dst_start_w
1456        self.align_corners = align_corners
1457        self.half_pixel_centers = False
1458        self.add_prim_attr('ori_image_size', self.ori_image_size)
1459        self.add_prim_attr('split_size', self.split_size)
1460        self.add_prim_attr('src_start_w', self.src_start_w)
1461        self.add_prim_attr('dst_start_w', self.dst_start_w)
1462        self.add_prim_attr('align_corners', self.align_corners)
1463        self.add_prim_attr('half_pixel_centers', self.half_pixel_centers)
1464
1465    def __infer__(self, x, size):
1466        size_val = size['value']
1467        x_shape = x['shape']
1468        x_dtype = x['dtype']
1469        validator.check_tensor_dtype_valid("x_dtype", x_dtype, [mstype.float16, mstype.float32], self.name)
1470        if size_val is None:
1471            raise ValueError("size must be const input")
1472        output_shape = [x_shape[0], x_shape[1], self.split_size[0], self.split_size[1]]
1473
1474        return {'shape': output_shape,
1475                'dtype': x_dtype,
1476                'value': None}
1477
1478
1479class PartitionedCall(PrimitiveWithInfer):
1480    """
1481    Pass the input tensors to the subgraph and return the output tensors.
1482
1483    Inputs:
1484        - **inputs** (Tuple), the input tensors, which will be passed to subgraph.
1485
1486    Outputs:
1487        - outputs(Tuple), the output tensor returned by subgraph.
1488
1489    Supported Platforms:
1490        ``Ascend``
1491
1492    Examples:
1493    """
1494
1495    @prim_attr_register
1496    def __init__(self, graph, executor_type=""):
1497        super(PartitionedCall, self).__init__(self.__class__.__name__)
1498        self.add_prim_attr("executor_type", executor_type)
1499        self.graph = graph
1500
    def infer_shape(self, *inputs):
        raise NotImplementedError

    def infer_dtype(self, *inputs):
        raise NotImplementedError
1506
1507
1508class CellBackwardHook(PrimitiveWithInfer):
1509    r"""
1510    This operator is used to hook input gradient and output gradient of Cell object.
1511
1512    Note:
1513        This operator is only used in backward hook function of Cell object in pynative mode.
1514
1515    Args:
        cell_id (str): Used to identify which Cell object the hook function is registered on. For example,
            'nn.Add()' is a cell object.
1518
1519    Inputs:
1520        - **input** - The variable to hook.
1521
1522    Outputs:
1523        - **output** - Returns `input` directly. `CellBackwardHook` does not affect the forward result.
1524
1525    Supported Platforms:
1526        ``Ascend`` ``GPU`` ``CPU``
1527
1528    Examples:
1529        >>> import mindspore as ms
1530        >>> from mindspore import Tensor
1531        >>> from mindspore.ops import GradOperation
1532        >>> from mindspore.ops.operations import _inner_ops as inner
1533        >>> ms.set_context(mode=ms.PYNATIVE_MODE)
1534        >>> def hook_fn(grad):
1535        ...     print(grad)
1536        ...
1537        >>> hook = inner.CellBackwardHook()
1538        >>> hook_fn_key = hook.register_backward_hook(hook_fn)
1539        >>> def hook_test(x, y):
1540        ...     z = x * y
1541        ...     z = hook(z)
1542        ...     z = z * y
1543        ...     return z
1544        ...
1545        >>> grad_all = GradOperation(get_all=True)
1546        >>> def backward(x, y):
1547        ...     return grad_all(hook_test)(x, y)
1548        ...
        >>> output = backward(Tensor(1, ms.float32), Tensor(2, ms.float32))
1550        (Tensor(shape=[], dtype=Float32, value= 2),)
1551        >>> print(output)
1552        (Tensor(shape=[], dtype=Float32, value= 4), Tensor(shape=[], dtype=Float32, value= 4))
1553        >>> hook.remove_backward_hook(hook_fn_key)
        >>> output = backward(Tensor(1, ms.float32), Tensor(2, ms.float32))
1555        >>> print(output)
1556        (Tensor(shape=[], dtype=Float32, value= 4), Tensor(shape=[], dtype=Float32, value= 4))
1557    """
1558
1559    def __init__(self, cell_id=""):
1560        """Initialize CellBackwardHook"""
1561        super(CellBackwardHook, self).__init__(self.__class__.__name__)
1562        self.cell_id = cell_id
1563        self.add_prim_attr("cell_id", cell_id)
1564        self.init_attrs["cell_id"] = cell_id
1565
1566    def __call__(self, args):
1567        if not isinstance(args, tuple):
1568            args = (args,)
1569        return _run_op(self, self.name, args)
1570
1571    def infer_shape(self, *inputs_shape):
1572        if len(inputs_shape) == 1:
1573            return inputs_shape[0]
1574        return inputs_shape
1575
1576    def infer_dtype(self, *inputs_type):
1577        if len(inputs_type) == 1:
1578            return inputs_type[0]
1579        return inputs_type
1580
1581    def register_backward_hook(self, hook_fn):
1582        r"""
1583        This function is used to register backward hook function. Note that this function is only supported in pynative
1584        mode.
1585
1586        Note:
            The 'hook_fn' must have the following signature:
            hook_fn(cell_id, grad_input, grad_output) -> new output gradient or None.
            `cell_id` identifies the registered cell. `grad_input` is the gradient passed to the cell.
            `grad_output` is the gradient computed and passed to the next cell or primitive; it can be replaced
            by returning a new output gradient from the hook.
            The 'hook_fn' is executed in the Python environment.
1594
1595        Args:
1596            hook_fn (Function): Python function. Backward hook function.
1597
1598        Returns:
1599            - **key** (int) - The key of 'hook_fn'.
1600
1601        Raises:
1602            TypeError: If the `hook_fn` is not a function of python.
1603        """
1604        if not isinstance(hook_fn, (FunctionType, MethodType)):
1605            raise TypeError(f"When using 'register_backward_hook(hook_fn)', the type of 'hook_fn' must be python "
1606                            f"function, but got {type(hook_fn)}.")
1607        key = self.add_backward_hook_fn(hook_fn)
1608        return key
1609
1610    def remove_backward_hook(self, key):
1611        r"""
1612        This function is used to remove backward hook function. Note that this operation is only supported in pynative
1613        mode.
1614
1615        Note:
1616            The 'key' is the object returned by 'register_backward_hook' function of the same CellBackwardHook
1617            operator.
1618
1619        Args:
1620            key (int): The key corresponding to the 'hook_fn'.
1621
1622        Returns:
1623            None.
1624        """
1625        self.remove_backward_hook_fn(key)
1626
1627
1628class Format(PrimitiveWithInfer):
1629    r"""
1630    This operator is used to format a string.
1631
    Note:
        Not intended to be called directly by users.
        Only `str.format()` calls in user code are supported; they are converted to the Format
        operation by the ME compiler automatically.

    Inputs:
        - **string** (str) - The format string.
        - **args** - The arguments to be formatted into the string.

    Outputs:
        - **output** (str) - The formatted string.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
1648    """
1649
1650    @prim_attr_register
1651    def __init__(self):
1652        self.init_prim_io_names(inputs=['string', 'args'], outputs=['string'])
1653
1654    def __infer__(self, str_, *var):
1655        def check_variable(str_, var):
1656            if _check_contains_variable(str_['dtype'], str_['value']):
1657                return True
1658
1659            for item in var:
1660                if _check_contains_variable(item['dtype'], item['value']):
1661                    return True
1662            return False
1663
1664        if check_variable(str_, var):
1665            return {'dtype': mstype.string, 'shape': [], 'value': None}
1666
1667        str_value = str_['value']
1668        kwargs = dict()
1669        var_value = list()
1670
1671        for item in var:
1672            if isinstance(item["dtype"], typing.Keyword):
1673                kwargs.update(item["value"])
1674            var_value.append(item["value"])
1675
1676        value = str_value.format(*var_value, **kwargs)
1677        return {'dtype': mstype.string, 'shape': [], 'value': value}
1678
1679
1680class FlattenConcat(Primitive):
1681    """
1682    Flatten input tensors and concatenate them into several chunk tensors grouped by data types.
1683
1684    Args:
1685        fusion_size (int): Maximum memory chunk size in bytes, 0 for unlimited. Default: 0.
1686
1687    Inputs:
1688        - **tensors** (tuple[Tensor], list[Tensor]) - The input Tensors to be flattened and concatenated.
1689
1690    Outputs:
1691        tuple[Tensor], result chunk tensors.
1692
1693    Supported Platforms:
1694        ``Ascend`` ``GPU`` ``CPU``
1695
1696    Examples:
        >>> import numpy as np
        >>> from mindspore import Tensor
        >>> from mindspore.ops.operations import _inner_ops as inner
        >>> t1 = Tensor(np.array([1]).astype(np.float32))
        >>> t2 = Tensor(np.array([2]).astype(np.float32))
        >>> t3 = Tensor(np.array([3]).astype(np.float64))
        >>> t4 = Tensor(np.array([4]).astype(np.float32))
        >>> t5 = Tensor(np.array([5]).astype(np.float64))
        >>> chunks = inner.FlattenConcat()([t1, t2, t3, t4, t5])
        >>> print(chunks[0].asnumpy())
        [1. 2. 4.]
        >>> print(chunks[1].asnumpy())
        [3. 5.]
1708    """
1709
1710    @prim_attr_register
1711    def __init__(self, fusion_size=0):
1712        """Initialize FlattenConcat"""
1713        validator.check_non_negative_int(fusion_size, 'fusion_size', self.name)
1714        self.fusion_size = fusion_size
1715        self.add_prim_attr('fusion_size', fusion_size)
1716
1717
1718class KMeansCentroids(PrimitiveWithInfer):
1719    """
    Calculates the segment_sum, segment_count and kmean_total_sum of the clustering result.
1721
1722    Args:
1723        use_actual_distance (bool): A bool value to decide whether do complete calculation of distance.
1724
1725    Inputs:
1726        - **x** (Tensor(float32)) - Input data used for clustering
        - **y** (Tensor(float32)) - Initial centroids of clustering
1728        - **sum_square_y** (Tensor(float32)) - The result of preprocessing such as square, reduce and transpose of y
1729        - **sum_square_x** (Tensor(float32)) - The result of preprocessing such as square and reduce of x
1730
1731    Outputs:
1732        - **segment_sum** (Tensor(float32)) - Clustering result w.r.t. each centroid
1733        - **segment_count** (Tensor(float32)) - Clustering count w.r.t. each centroid
        - **kmean_total_sum** (Tensor(float32)) - The sum of the distances from all vectors to their nearest centroid
1735
1736    Supported Platforms:
        ``Ascend``
1738
1739    Examples:
1740        >>> import numpy as np
1741        >>> import mindspore as ms
1742        >>> import mindspore.common.dtype as mstype
1743        >>> import mindspore.nn as nn
1744        >>> from mindspore import Tensor
1745        >>> from mindspore.ops import operations as P
1746        >>> ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend")
1747
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.reduce_sum = P.ReduceSum(keep_dims=True)
        ...         self.square = P.Square()
        ...         self.transpose = P.Transpose()
        ...         self.k_means_centroids = P.KMeansCentroids(True)
        ...     def construct(self, x, y):
        ...         p1 = self.reduce_sum(self.square(x), -1)
        ...         p2 = self.transpose(self.reduce_sum(self.square(y), -1), (1, 0))
        ...         return self.k_means_centroids(x, y, p2, p1)
        ...
        >>> def test_net():
        ...     data_type = np.float32
        ...     x = Tensor(np.random.uniform(-10, 10, (65536, 128)).astype(data_type))
        ...     y = P.Ones()((1048576, 128), mstype.float32)
        ...     net = Net()
        ...     local_sum, local_count, local_avg_distance = net(x, y)
1767    """
1768
1769    @prim_attr_register
1770    def __init__(self, use_actual_distance):
1771        validator.check_value_type('use_actual_distance', use_actual_distance, [bool], self.name)
1772        self.init_prim_io_names(inputs=['x', 'y', 'sum_square_y', 'sum_square_x'],
1773                                outputs=['segment_sum', 'segment_count', 'kmean_total_sum'])
1774
1775    def infer_shape(self, x_shape, y_shape, sum_square_y_shape, sum_square_x_shape):
1776        """infer shape of primitive"""
1777        expected_shape_size = 2
1778        validator.check_int(len(x_shape), expected_shape_size, validator.EQ, "dims of x", self.name)
1779        validator.check_int(len(y_shape), expected_shape_size, validator.EQ, "dims of y", self.name)
1780        validator.check_int(len(sum_square_y_shape), expected_shape_size, validator.EQ,
1781                            "dims of sum_square_y", self.name)
1782        validator.check_int(len(sum_square_x_shape), expected_shape_size, validator.EQ,
1783                            "dims of sum_square_x", self.name)
1784
1785        validator.check_int(x_shape[1], y_shape[1], validator.EQ,
1786                            "the second dim of x and the second dim of y", self.name)
1787        validator.check_int(y_shape[0], sum_square_y_shape[1], validator.EQ,
1788                            "the first dim of y and the second dim of sum_square_y", self.name)
1789        validator.check_int(x_shape[0], sum_square_x_shape[0], validator.EQ,
1790                            "the first dim of x and the first dim of sum_square_x", self.name)
1791        validator.check_int(sum_square_y_shape[0], sum_square_x_shape[1], validator.EQ,
1792                            "the first dim of sum_square_y and the first dim of sum_square_x",
1793                            self.name)
1794        validator.check_int(sum_square_y_shape[0], 1, validator.EQ,
1795                            "the first dim of sum_square_y", self.name)
1796
1797        k = y_shape[0]
1798        em_size = x_shape[1]
        return (k, em_size), (k, 1), (1,)
1800
1801
1802class ClipByNorm(PrimitiveWithInfer):
1803    r"""
1804    Clips tensor values to a maximum :math:`L_2`-norm.
1805
1806    Note:
1807        The output tensor of this operator remains the same with input tensor if the :math:`L_2`-norm of the input
1808        tensor is not greater than the argument `clip_norm`. Otherwise the output tensor will be normalized as:
1809
1810        .. math::
1811            \text{output}(X) = \frac{\text{clip_norm} * X}{L_2(X)},
1812
1813        where :math:`L_2(X)` is the :math:`L_2`-norm of :math:`X`.
1814
1815    Args:
        axis (Union[None, int, tuple(int), list(int)]): Compute the :math:`L_2`-norm along the specified dimensions.
                                                       Default: ``None``, which means all dimensions are used.
1818
1819    Inputs:
1820        - **x** (Tensor) - Tensor of shape N-D. The type must be float16 or float32.
        - **clip_norm** (Tensor) - A scalar Tensor of shape :math:`()` or :math:`(1)`,
          or a Tensor whose shape can be broadcast to the shape of `x`. The type must be float16 or float32.
1823
1824    Outputs:
1825        Tensor, clipped Tensor with the same shape as the `x`, whose type is float32.
1826
1827    Raises:
1828        TypeError: If `axis` is not one of None, int, tuple(int) and list(int).
1829        TypeError: If dtype of `x` is neither float16 nor float32.
1830        TypeError: If dtype of `clip_norm` is neither float16 nor float32.
1831
1832    Supported Platforms:
1833        ``Ascend`` ``GPU`` ``CPU``
1834
1835    Examples:
1836        >>> import numpy as np
1837        >>> import mindspore
1838        >>> from mindspore import Tensor
1839        >>> from mindspore.ops.operations import _inner_ops as inner
1840        >>> clip_by_norm = inner.ClipByNorm()
1841        >>> x = Tensor(np.random.randint(0, 10, [4, 16]), mindspore.float32)
1842        >>> clip_norm = Tensor(np.array([100]).astype(np.float32))
1843        >>> output = clip_by_norm(x, clip_norm)
1844        >>> print(output.shape)
1845        (4, 16)
1846    """
1847
1848    @prim_attr_register
1849    def __init__(self, axis=None):
1850        """Initialize ClipByNorm"""
1851        self.axis = () if axis is None else axis
1852        validator.check_value_type('axis', self.axis, [int, tuple, list], self.name)
1853        axis_check = self.axis if isinstance(self.axis, Iterable) else (self.axis,)
1854        for i, value in enumerate(axis_check):
1855            validator.check_value_type('axis[%d]' % i, value, [int], self.name)
1856        self.init_attrs['axis'] = self.axis
1857        self.add_prim_attr('axis', self.axis)
1858        self.init_prim_io_names(inputs=['x', 'clip_norm'], outputs=['output'])
1859
1860    def infer_shape(self, x_shape, clip_norm_shape):
1861        """Infer shape for ClipByNorm"""
1862        x_dim = len(x_shape)
1863        axis = self.axis if isinstance(self.axis, Iterable) else (self.axis,)
1864        for _, value in enumerate(axis):
1865            validator.check_int_range(value, -x_dim, x_dim, validator.INC_LEFT, 'axis', self.name)
1866        return x_shape
1867
1868    def infer_dtype(self, x_type, clip_norm_type):
1869        """Infer data type for ClipByNorm"""
1870        validator.check_tensor_dtype_valid("x_type", x_type, [mstype.float16, mstype.float32], self.name)
1871        validator.check_tensor_dtype_valid("clip_norm_type", clip_norm_type,
1872                                           [mstype.float16, mstype.float32], self.name)
1873        return mstype.float32
1874
1875
1876class TopTypeof(Primitive):
1877    """
1878        Internal primitive method, to speed up mindspore.ops.typeof.
1879
1880        Returns the top type of the input data.
1881
1882        In Pynative mode, returns the top type in cache.
1883
1884        Supported Platforms:
1885            ``Ascend`` ``GPU`` ``CPU``
1886    """
1887
1888    @prim_attr_register
1889    def __init__(self):
1890        self.prim = Primitive('TopTypeof')
1891        self.typeof_cache = {
1892            'slice': mstype.Slice(),
1893            'list': mstype.List(),
1894            'tuple': mstype.Tuple(),
1895            'Tensor': mstype.tensor_type,
1896            'NoneType': mstype.NoneType(),
1897            'int': mstype.Int(),
1898            'bool': mstype.Bool(),
1899            'ellipsis': mstype.Ellipsis_(),
1900            'dict': mstype.Dict()
1901        }
1902
1903    def __call__(self, x):
1904        index_type = type(x).__name__
1905        if 'Tensor' in index_type:
1906            index_type = 'Tensor'
1907        if index_type in self.typeof_cache:
1908            return self.typeof_cache.get(index_type)
1909        return _pynative_executor.constant_folding(self.prim, x)
1910
1911
1912class MixedPrecisionCast(Primitive):
1913    r"""
    Internal primitive method, used to implement mindspore.functional.mixed_precision_cast.
1915
1916    Note:
1917        This internal primitive method used to do mixed precision conversion.
1918        Only the input object with float dtype will be cast.
1919
1920    Inputs:
1921        - **dtype** (Union[Float16, Float32]) - The data type of the output object.
1922        - **input** (Union[Tensor, Tuple, Dictionary, KeywordArg]) - The object to be cast.
1923
1924    Outputs:
1925        Object, its dtype is the same as `dtype` and shape is the same as 'input'.
1926
1927    Supported Platforms:
1928        ``Ascend`` ``GPU`` ``CPU``
1929
1930    Examples:
1931        >>> import numpy as np
1932        >>> from mindspore import Tensor
1933        >>> from mindspore import dtype as mstype
1934        >>> from mindspore.ops.operations import _inner_ops as inner
1935        >>> x = Tensor(np.ones([2, 3], dtype=np.float32))
        >>> out = inner.MixedPrecisionCast()(mstype.float16, x)
1937        >>> print(out.dtype)
1938        Float16
1939    """
1940
1941    @prim_attr_register
1942    def __init__(self):
1943        """Initialize MixedPrecisionCast"""
1944        self.init_prim_io_names(inputs=['dst_dtype', 'input_x'], outputs=['output'])
1945        self.cast = Cast()
1946        self.hyper_map = C.HyperMap()
1947
1948    def __call__(self, dst_dtype, x):
1949        def cast_inner(data):
1950            if isinstance(data, Tensor) and data.dtype in (mstype.float16, mstype.float32,
1951                                                           mstype.float64, mstype.bfloat16):
1952                return self.cast(data, dst_dtype)
1953            return data
1954
1955        return self.hyper_map(cast_inner, x)
1956
1957
1958class CheckBprop(PrimitiveWithInfer):
1959    """
1960    Checks whether the data type and the shape of corresponding elements from tuples x and y are the same.
1961
1962    Args:
1963        prim_to_check (str): The name of the primitive being checked. Default: ''.
1964
1965    Inputs:
1966        - **input_x** (tuple[Tensor]) - The `input_x` contains the outputs of bprop to be checked.
1967        - **input_y** (tuple[Tensor]) - The `input_y` contains the inputs of bprop to check against.
1968
1969    Outputs:
1970        Tuple[Tensor], the `input_x`,
1971        if data type and shape of corresponding elements from `input_x` and `input_y` are the same.
1972
1973    Raises:
1974        TypeError: If `input_x` or `input_y` is not a Tensor.
1975
1976    Supported Platforms:
1977        ``Ascend`` ``GPU`` ``CPU``
1978
1979    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> from mindspore import Tensor, nn, ops
        >>> class Net(nn.Cell):
1981        ...     def __init__(self):
1982        ...         super(Net, self).__init__()
1983        ...         self.op = ops.CheckBprop()
1984        ...     def construct(self, x, y):
1985        ...         return self.op(x, y)
1986        ...
1987        >>> net = Net()
1988        >>> input_x = (Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32),)
1989        >>> input_y = (Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32),)
1990        >>> output = net(input_x, input_y)
1991        >>> print(output)
1992        (Tensor(shape=[2, 2], dtype=Float32, value=
1993        [[ 2.00000000e+00,  2.00000000e+00],
1994         [ 2.00000000e+00,  2.00000000e+00]]),)
1995    """
1996
1997    @prim_attr_register
1998    def __init__(self, prim_to_check=""):
1999        """Initialize CheckBprop"""
2000        self.prim_to_check = prim_to_check
2001
2002    def infer_shape(self, xshapes, yshapes):
2003        """infer shape"""
2004        tips = f"user defined method 'bprop'"
2005        validator.check_value_type('grads', xshapes, (tuple,), tips)
2006        validator.check_value_type('params', yshapes, (tuple,), tips)
2007        if not len(xshapes) == len(yshapes):
2008            raise ValueError(f"For {tips} the number of return values(gradients) must be equal to "
2009                             f"the number of input arguments except 'out' and 'dout', "
2010                             f"which is:{len(yshapes)} but got {len(xshapes)}.")
2011
2012        def shape_equal(shape1, shape2):
2013            if len(shape1) != len(shape2):
2014                return False
2015            for shape_axis1, shape_axis2 in zip(shape1, shape2):
2016                if shape_axis1 == -1 or shape_axis2 == -1:
2017                    continue
2018                if shape_axis1 != shape_axis2:
2019                    return False
2020            return True
2021
2022        for i, (xshape, yshape) in enumerate(zip(xshapes, yshapes)):
2023            if not xshape or not yshape:
2024                continue
2025
2026            if not shape_equal(xshape, yshape):
2027                raise ValueError(f"For {tips}, the {i}th return value(gradient of the {i}th argument) "
2028                                 f"should have the same shape as the {i}th argument, "
2029                                 f"which is:{yshape}, but got: {xshape}.")
2030        return xshapes
2031
2032    def infer_dtype(self, xdtypes, ydtypes):
2033        """infer dtype"""
2034        tips = f"user defined method 'bprop'"
2035        validator.check_value_type('grads', xdtypes, (tuple,), tips)
2036        validator.check_value_type('params', ydtypes, (tuple,), tips)
2037        if not len(xdtypes) == len(ydtypes):
2038            raise ValueError(f"For {tips}, the number of return values(gradients) must be equal to "
2039                             f"the number of input arguments except 'out' and 'dout', "
2040                             f"which is:{len(ydtypes)} but got {len(xdtypes)}.")
2041        checking_range = len(ydtypes)
2042        for i in range(checking_range):
2043            xdtype = xdtypes[i]
2044            ydtype = ydtypes[i]
2045            if isinstance(xdtype, mstype.AnythingType) or isinstance(ydtype, mstype.AnythingType):
2046                continue
2047            if isinstance(ydtype, mstype.FunctionType):
2048                if not isinstance(xdtype, mstype.EnvType):
2049                    raise TypeError(f"For {tips}, the {i}th return value(gradient of the {i}th argument) type "
2050                                    f"should be {mstype.EnvType}, but got {xdtype}.")
2051            if xdtype != ydtype:
2052                raise TypeError(f"For {tips}, the {i}th return value(gradient of the {i}th argument) "
2053                                f"should have the same dtype as the {i}th argument, "
2054                                f"which is:{ydtype}, but got: {xdtype}.")
2055        return xdtypes
2056
2057
2058check_bprop = CheckBprop()
2059
2060
2061class SameTypeShape(PrimitiveWithInfer):
2062    """
2063    Checks whether the data type and shape of two tensors are the same.
2064
2065    Refer to :func:`mindspore.ops.same_type_shape` for more detail.
2066
2067    Supported Platforms:
2068        ``Ascend`` ``GPU`` ``CPU``
2069
2070    Examples:
2071        >>> input_x = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
2072        >>> input_y = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
2073        >>> output = ops.SameTypeShape()(input_x, input_y)
2074        >>> print(output)
2075        [[2. 2.]
2076         [2. 2.]]
2077    """
2078
2079    @prim_attr_register
2080    def __init__(self):
2081        """Initialize Same"""
2082
2083    def __call__(self, x, y):
2084        """run in PyNative mode"""
2085        validator.check_value_type('x', x, Tensor, self.name)
2086        validator.check_value_type('y', y, Tensor, self.name)
2087        validator.check('x dtype', x.dtype, 'y dtype', y.dtype, validator.EQ, self.name, TypeError)
2088        validator.check('x shape', x.shape, 'y shape', y.shape, validator.EQ, self.name)
2089        return x
2090
2091    def __infer__(self, x, y):
2092        validator.check_subclass('x', x['dtype'], mstype.tensor_type, self.name)
2093        validator.check_subclass('y', y['dtype'], mstype.tensor_type, self.name)
2094        validator.check('x dtype', x['dtype'], 'y dtype', y['dtype'], validator.EQ, self.name, TypeError)
2095        validator.check('x shape', x['shape'], 'y shape', y['shape'], validator.EQ, self.name)
2096        return x
2097
2098
2099same_type_shape_ = SameTypeShape()
2100
2101
2102def _is_subclass_(type_, dtype):
2103    if not isinstance(type_, typing.Type):
2104        return False
2105    return typing.is_subclass(type_, dtype)
2106
2107
2108class IsSubClass(PrimitiveWithInfer):
2109    """
2110    Checks whether this type is a sub-class of another type.
2111
2112    Inputs:
2113        - **sub_type** (mindspore.dtype) - The type to be checked. Only constant value is allowed.
2114        - **type_** (mindspore.dtype) - The target type. Only constant value is allowed.
2115
2116    Outputs:
2117        bool, the check result.
2118
2119    Raises:
2120        TypeError: If `sub_type` or `type_` is not a Type.
2121
2122    Supported Platforms:
2123        ``Ascend`` ``GPU`` ``CPU``
2124
2125    Examples:
2126        >>> output = ops.IsSubClass()(mindspore.int32,  mindspore.intc)
2127        >>> print(output)
2128        True
2129    """
2130
2131    @prim_attr_register
2132    def __init__(self):
2133        pass
2134
2135    def __infer__(self, sub_type, type_):
2136        sub_type_t = sub_type['value']
2137        type_v = type_['value']
2138
2139        validator.check_value_type("sub_type", sub_type_t, [mstype.Type], self.name)
2140        validator.check_value_type("type_", type_v, [mstype.Type], self.name)
2141
2142        value = _is_subclass_(sub_type_t, type_v)
2143
2144        out = {'shape': (),
2145               'dtype': mstype.type_type,
2146               'value': value}
2147        return out
2148
2149
2150issubclass_ = IsSubClass()
2151
2152
2153class IsInstance(PrimitiveWithInfer):
2154    """
2155    Checks whether an object is an instance of a target type.
2156
2157    Inputs:
2158        - **inst** (Any Object) - The instance to be checked. Only constant value is allowed.
2159        - **type_** (mindspore.dtype) - The target type. Only constant value is allowed.
2160
2161    Outputs:
2162        bool, the check result.
2163
2164    Raises:
2165        TypeError: If `type_` is not a Type.
2166
2167    Supported Platforms:
2168        ``Ascend`` ``GPU`` ``CPU``
2169
2170    Examples:
2171        >>> inst = 1
2172        >>> output = ops.IsInstance()(inst, mindspore.int32)
2173        >>> print(output)
2174        False
2175    """
2176
2177    @prim_attr_register
2178    def __init__(self):
2179        pass
2180
2181    def __infer__(self, inst, type_):
2182        sub_type_t = inst['dtype']
2183        type_v = type_['value']
2184
2185        validator.check_value_type("type_", type_v, [mstype.Type], self.name)
2186
2187        if type_v == mstype.list_:
2188            value = isinstance(sub_type_t, list)
2189        elif type_v == mstype.tuple_:
2190            value = isinstance(sub_type_t, tuple)
2191        else:
2192            value = _is_subclass_(sub_type_t, type_v)
2193
2194        out = {'shape': (),
2195               'dtype': mstype.type_type,
2196               'value': value}
2197        return out
2198
2199
2200class ConvertToAdapterTensor(Primitive):
2201    """
2202    Convert a tensor from MindSpore's Tensor type to MSAdapter's Tensor type,
2203    where MSAdapter's Tensor is a subclass of MindSpore's Tensor.
2204
2205    Inputs:
2206        - **x** (Tensor) - The input tensor.
2207
2208    Outputs:
2209        A tensor, whose type is MSAdapter's Tensor.
2210
2211    Supported Platforms:
2212        ``Ascend`` ``GPU`` ``CPU``
2213
2214    Examples:
        >>> x = Tensor([1, 2, 3])
2216        >>> x = ops.ConvertToAdapterTensor()(x)
2217        >>> print(x)
2218        [1 2 3]
2219    """
2220
2221    @prim_attr_register
2222    def __init__(self):
2223        """Initialize"""
2224
2225    def __call__(self, x):
2226        """Run in PyNative mode"""
2227        return ms_adapter_registry.tensor(x, cast_tensor=True)
2228
2229
2230convert_to_adapter_tensor = ConvertToAdapterTensor()
2231
2232
2233class ConvertToMsTensor(Primitive):
2234    """
2235    Convert a tensor from MSAdapter's Tensor type to MindSpore's Tensor type,
2236    where MSAdapter's Tensor is a subclass of MindSpore's Tensor.
2237
2238    Inputs:
2239        - **x** (Tensor) - The input tensor.
2240
2241    Outputs:
2242        A tensor, whose type is MindSpore's Tensor.
2243
2244    Supported Platforms:
2245        ``Ascend`` ``GPU`` ``CPU``
2246
2247    Examples:
        >>> x = Tensor([1, 2, 3])
2249        >>> x = ops.ConvertToMsTensor()(x)
2250        >>> print(x)
2251        [1 2 3]
2252    """
2253
2254    @prim_attr_register
2255    def __init__(self):
2256        """Initialize"""
2257
2258    def __call__(self, x):
2259        """Run in PyNative mode"""
2260        if isinstance(x, StubTensor):
2261            return StubTensor(stub=x.stub, tensor=x.tensor)
2262        return ops.auto_generate.deepcopy(x)
2263
2264
2265convert_to_ms_tensor = ConvertToMsTensor()
2266
2267
2268class GetGrad(Primitive):
2269    """
        Use the position id or the Parameter object to get the corresponding gradient from the output
        returned by :func:`mindspore.ops.grad`.
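
        Examples:
            A hedged sketch of the pure-Python lookup over a (position, gradient) structure such as the one
            produced by :func:`mindspore.ops.grad`; the nested tuple here is hand-built for illustration:

            >>> import numpy as np
            >>> from mindspore import Tensor
            >>> from mindspore.ops.operations import _inner_ops as inner
            >>> grads = ((0, Tensor(np.array([1.0], np.float32))), (1, Tensor(np.array([2.0], np.float32))))
            >>> print(inner.GetGrad()(grads, 1))
            [2.]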
2272    """
2273
2274    @prim_attr_register
2275    def __init__(self):
        """Initialize GetGrad"""
2277        self.init_prim_io_names(
2278            inputs=['gradients', 'x'], outputs=['gradient'])
2279
2280    def __call__(self, gradients, x):
2281        if not isinstance(x, int) and not isinstance(x, Parameter):
2282            raise TypeError(
2283                f"For `get_grad`, the `x` should be an integer or a Parameter, but got {x}")
2284        hash_id = x
2285        if isinstance(x, Parameter):
2286            hash_id = x.name
2287        output = None
2288
2289        def _get_grad(grads, identifier):
2290            if isinstance(grads, tuple):
2291                if len(grads) != 2 or identifier != grads[0]:
2292                    for gradient in grads:
2293                        _get_grad(gradient, identifier)
2294                else:
2295                    nonlocal output
2296                    output = grads[1]
2297                    return
2298
2299        _get_grad(gradients, hash_id)
2300        if output is None:
2301            raise RuntimeError(
2302                f"Can not find the gradient for position or Parameter {x}")
2303        return output
2304
2305
2306class IsParameter(PrimitiveWithInfer):
2307    """
2308        Check if input is `Parameter`
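
        Examples:
            A minimal sketch of the pure-Python (PyNative) call path, assuming the imports below:

            >>> import numpy as np
            >>> from mindspore import Parameter, Tensor
            >>> from mindspore.ops.operations import _inner_ops as inner
            >>> p = Parameter(Tensor(np.ones((2,), np.float32)), name="w")
            >>> print(inner.IsParameter()(p))
            True
            >>> print(inner.IsParameter()(3))
            False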
2309    """
2310
2311    @prim_attr_register
2312    def __init__(self):
2313        """Initialize IsParameter"""
2314
2315    def __call__(self, x):
2316        return isinstance(x, Parameter)
2317
2318    def __infer__(self, x):
2319        return {'shape': [],
2320                'dtype': mstype.bool_,
2321                'value': isinstance(x['dtype'], mstype.RefType)}
2322
2323
2324class TileSize(Primitive):
2325    r"""
2326        Tile size for matmul
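
        Examples:
            A minimal sketch of the pure-Python call path; the result is the per-axis repeat count needed
            to tile `shape` into `out_shape`:

            >>> from mindspore.ops.operations import _inner_ops as inner
            >>> print(inner.TileSize()((2, 1, 3), (2, 4, 3), 3))
            (1, 4, 1)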
2327    """
2328
2329    @prim_attr_register
2330    def __init__(self):
2331        """Initialize TileSize"""
2332        self.init_prim_io_names(inputs=['shape', 'out_shape', 'ndim'], outputs=['output'])
2333
2334    def __call__(self, shape, out_shape, ndim):
2335        size = [1] * ndim
2336        for idx, (i, j) in enumerate(zip(shape, out_shape)):
2337            if i != j:
2338                size[idx] = j
2339        return tuple(size)
2340
2341
2342class GetitemTensorIndexInfo(Primitive):
2343    r"""
2344        Get getitem tensor index info
2345    """
2346
2347    @prim_attr_register
2348    def __init__(self, is_ascend):
2349        """Initialize GetitemTensorIndexInfo"""
2350        self.init_prim_io_names(inputs=['data', 'index'],
2351                                outputs=["new_index", "tensor_update_types", "tensor_update_args"])
2352        validator.check_value_type('is_ascend', is_ascend, [bool], self.name)
2353        self.is_ascend = is_ascend
2354
2355    def __call__(self, data, index):
2356        return Tensor_.getitem_index_info(data, index, self.is_ascend)
2357
2358
2359class SetitemTensorIndexInfo(Primitive):
2360    r"""
2361        Get setitem tensor index info
2362    """
2363
2364    @prim_attr_register
2365    def __init__(self, is_ascend):
        """Initialize SetitemTensorIndexInfo"""
2367        self.init_prim_io_names(
2368            inputs=['data', 'index', 'value'], outputs=['new_index',
2369                                                        'v_transfer_types',
2370                                                        'v_transfer_args',
2371                                                        'tensor_update_types',
2372                                                        'tensor_update_args'])
2373        validator.check_value_type('is_ascend', is_ascend, [bool], self.name)
2374        self.is_ascend = is_ascend
2375
2376    def __call__(self, data, index, value):
2377        return Tensor_.setitem_index_info(data, index, value, self.is_ascend)
2378
2379
2380class IsConstant(Primitive):
2381    r"""
2382        Check if the input is constant
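
        Examples:
            In PyNative mode the Python call path simply returns ``True`` for any input:

            >>> from mindspore.ops.operations import _inner_ops as inner
            >>> print(inner.IsConstant()(3))
            True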
2383    """
2384
2385    @prim_attr_register
2386    def __init__(self):
2387        """Initialize IsConstant"""
2388
2389    def __call__(self, x):
2390        return True
2391
2392
2393class SelectView(Primitive):
2394    r"""
2395        Select tensor of view
2396    """
2397
2398    @prim_attr_register
2399    def __init__(self):
2400        self.init_prim_io_names(inputs=['input_tensor', 'input_indices', 'axis'], outputs=['output'])
2401
2402
2403class CopyWithSlice(Primitive):
2404    r"""
2405        Copy data to discontinuous tensor
2406    """
2407
2408    @prim_attr_register
2409    def __init__(self):
2410        self.add_prim_attr('side_effect_mem', True)
2411        self.init_prim_io_names(inputs=['x', 'y'], outputs=['x'])
2412
2413
2414class FFN(Primitive):
2415    r"""
    The FFN computation is similar to a feed-forward network; it contains matmul + gelu + matmul.
2417
2418    Args:
        activation (str): The activation type, either 'fastgelu' or 'gelu'.
            Only 'fastgelu' is supported for now. Default: "fastgelu".
        inner_precise (int): The precision mode, 0 for high precision or 1 for high performance.
            Only 1 is supported for now. Default: 0.
2423
2424    Inputs:
2425        - **x** (Tensor) - The input tensor with data type of int8, float16.
2426          Input tensor of shape :math:`(batch\_size * seq\_length, hidden\_size)`.
2427        - **weight1** (Tensor) - The weight1 tensor with data type of float16.
2428          Weight1 tensor of shape :math:`(expert\_num, hidden\_size, ffn\_hidden\_size)`.
2429        - **weight2** (Tensor) - The weight2 tensor with data type of float16.
2430          Weight2 tensor of shape :math:`(expert\_num, ffn\_hidden\_size, hidden\_size)`.
        - **expert_tokens** (Tensor) - The expert tokens tensor with data type of int64.
          Expert tokens tensor of shape :math:`(16,)`. For example, `(2, 1, 0, .., 9)`
          indicates that the 0th expert deals with 2 tokens, the 1st expert deals with 1 token,
          the 2nd expert does nothing, and so on.
2435        - **bias1** (Tensor) - The bias1 tensor with data type of float16.
2436          Bias1 tensor of shape :math:`(expert\_num, ffn\_hidden\_size)`.
2437        - **bias2** (Tensor) - The bias2 tensor with data type of float16.
2438          Bias2 tensor of shape :math:`(expert\_num, hidden\_size)`.
        - **scale** (Tensor) - The scale tensor with data type of float16. Not enabled for now.
        - **offset** (Tensor) - The offset tensor with data type of float16. Not enabled for now.
        - **deq_scale1** (Tensor) - The deq_scale1 tensor with data type of float16. Not enabled for now.
        - **deq_scale2** (Tensor) - The deq_scale2 tensor with data type of float16. Not enabled for now.
2443
2444    Outputs:
2445        Tensor of shape :math:`(batch\_size * seq\_length, hidden\_size)`. With data type of float16.
2446
2447    Supported Platforms:
2448        ``Ascend``
2449
2450    Examples:
        >>> import numpy as np
        >>> from mindspore import Tensor
        >>> from mindspore.ops.operations import _inner_ops
2452        >>> b = 4
2453        >>> s = 128
2454        >>> h = 1024
2455        >>> h_f = 4 * h
2456        >>> e = 16
2457        >>> x = Tensor(np.random.randn(s, h).astype(np.float16))
2458        >>> w1 = Tensor(np.random.randn(e, h, h_f).astype(np.float16))
2459        >>> w2 = Tensor(np.random.randn(e, h_f, h).astype(np.float16))
2460        >>> expert_tokens = Tensor(np.full(e, 8))
2461        >>> bias1 = Tensor(np.random.randn(e, h_f).astype(np.float16))
2462        >>> bias2 = Tensor(np.random.randn(e, h).astype(np.float16))
2463        >>> ffn = _inner_ops.FFN("fastgelu", 1)
2464        >>> output = ffn(x, w1, w2, expert_tokens, bias1, bias2)
2465        >>> print(output)
2466    """
2467
2468    @prim_attr_register
2469    def __init__(self, activation, inner_precise):
2470        """Initialize FFN."""
2471        self.init_prim_io_names(inputs=["x", "weight1", "weight2", "expert_tokens", "bias1",
2472                                        "bias2", "scale", "offset", "deq_scale1", "deq_scale2",
2473                                        "antiquant_scale1", "antiquant_scale2",
2474                                        "antiquant_offset1", "antiquant_offset2"],
2475                                outputs=["y"])
2476        cls_name = self.name
2477        validator.check_value_type("activation", activation, [str], cls_name)
2478        validator.check_value_type("inner_precise", inner_precise, [int], cls_name)
2479
2480
2481class _MirrorSilentCheck(PrimitiveWithInfer):
2482    """
    The operator _MirrorSilentCheck implements accuracy-sensitive detection on a tensor during backpropagation.
    Call _MirrorSilentCheck in the __call__ method of a derived class to implement accuracy-sensitive detection.
2485
2486    Inputs:
2487        - **input** (Tensor) : The tensor used for detection.
2488          Its data type must be mindspore.float16, mindspore.float32 or mindspore.bfloat16.
        - **pre_val** (Parameter(Tensor)) : Support parameter in accuracy-sensitive detection.
          Should only be generated by the generate_params() method of ASDBase.
        - **min_val** (Parameter(Tensor)) : Support parameter in accuracy-sensitive detection.
          Should only be generated by the generate_params() method of ASDBase.
        - **max_val** (Parameter(Tensor)) : Support parameter in accuracy-sensitive detection.
          Should only be generated by the generate_params() method of ASDBase.
        - **cnt** (Parameter(Tensor)) : Support parameter in accuracy-sensitive detection.
          Should only be generated by the generate_params() method of ASDBase.
          The value of cnt is incremented by one after each invocation of _MirrorSilentCheck.
2498
2499    Outputs:
2500        - **output** (Tensor) - Same shape, type and value as `input`.
2501    """
2502    @prim_attr_register
2503    def __init__(self, min_steps=8):
2504        upper_thresh, sigma_thresh = self.get_thresh()
2505        self.min_steps = min_steps
2506        self.thresh_l1 = upper_thresh[0]
2507        self.coeff_l1 = sigma_thresh[0]
2508        self.thresh_l2 = upper_thresh[1]
2509        self.coeff_l2 = sigma_thresh[1]
2510        self.add_prim_attr('side_effect_mem', True)
2511
2512    def parse_thresh(self, env_var_name, default_value, min_value):
2513        env_var = os.environ.get(env_var_name, default=default_value)
2514        thresh = [value.strip() for value in env_var.split(",")]
2515        if len(thresh) != 2 or not all(value.isdigit() for value in thresh):
2516            thresh = default_value.split(",")
2517        thresh = [float(max(int(value), min_value)) for value in thresh]
2518        if thresh[0] <= thresh[1]:
2519            thresh = [float(value) for value in default_value.split(",")]
2520
2521        return thresh
2522
2523    def get_thresh(self):
2524        upper_thresh = self.parse_thresh("NPU_ASD_UPPER_THRESH", "1000000,10000", 3)
2525        sigma_thresh = self.parse_thresh("NPU_ASD_SIGMA_THRESH", "100000,5000", 3)
2526        return upper_thresh, sigma_thresh
2527
2528    def infer_shape(self, x_shape, pre_shape, min_shape, max_shape, n_step, loss_scale_shape):
2529        return x_shape
2530
2531    def infer_dtype(self, x_dtype, pre_dtype, min_dtype, max_dtype, n_dtype, loss_scale_dtype):
2532        return x_dtype
2533
2534
2535class _VirtualConverterEnd(PrimitiveWithInfer):
2536    """
2537    Auto parallel virtual operator.
2538    """
2539
2540    @prim_attr_register
2541    def __init__(self, input_nums):
2542        """Initialize _VirtualConverterEnd."""
2543        self.input_nums = input_nums
2544
2545    def infer_shape(self, *args):
2546        return (args[0][0] * self.input_nums,) + tuple(args[0][1:])
2547
2548    def infer_dtype(self, *args):
2549        return args[0]
2550
2551
2552class _VirtualConverterBegin(PrimitiveWithInfer):
2553    """
2554    Auto parallel virtual operator.
2555    """
2556
2557    @prim_attr_register
2558    def __init__(self, output_nums):
2559        """Initialize _VirtualConverterBegin."""
2560        self.output_nums = output_nums
2561
2562    def infer_shape(self, arg):
2563        new_arg = (arg[0] / self.output_nums,) + tuple(arg[1:])
2564        return (new_arg,) * self.output_nums
2565
2566    def infer_dtype(self, arg):
2567        return (arg,) * self.output_nums
2568