# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Inner operators."""

import numpy as np
from mindspore.common import Tensor
from ..._checkparam import Rel
from ..._checkparam import Validator as validator
from ... import context
from ...common import dtype as mstype
from ..primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register, Primitive
from ..operations.math_ops import _infer_shape_reduce
from ...communication.management import GlobalComm
from .. import signature as sig


class ExtractImagePatches(PrimitiveWithInfer):
    """
    Extracts patches from images.
    The input tensor must be a 4-D tensor and the data format is NHWC.

    Args:
        ksizes (Union[tuple[int], list[int]]): The size of the sliding window, must be a tuple or a list of integers,
            and the format is [1, 1, ksize_row, ksize_col].
        strides (Union[tuple[int], list[int]]): Distance between the centers of two consecutive patches,
            must be a tuple or a list of integers, and the format is [1, 1, stride_row, stride_col].
        rates (Union[tuple[int], list[int]]): In each extracted patch, the gap between the corresponding dimension
            pixel positions, must be a tuple or a list of integers, and the format is [1, 1, rate_row, rate_col].
        padding (str): The type of padding algorithm, a string whose value is "same" or "valid",
            not case sensitive. Default: "valid".

            - same: Means that the patch can take the part beyond the original image, and this part is filled with 0.

            - valid: Means that the taken patch area must be completely contained in the original image.

    Inputs:
        - **input_x** (Tensor) - A 4-D tensor whose shape is [in_batch, in_row, in_col, in_depth] and
          data type is number.

    Outputs:
        Tensor, a 4-D tensor whose data type is the same as `input_x`,
        and the shape is [out_batch, out_row, out_col, out_depth], where out_batch is the same as in_batch.
    """

    @prim_attr_register
    def __init__(self, ksizes, strides, rates, padding="valid"):
        """init"""

        def _check_tuple_or_list(arg_name, arg_val, prim_name):
            validator.check_value_type(f"{arg_name}s", arg_val, [tuple, list], self.name)
            if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[1] != 1:
                raise ValueError(f"For '{prim_name}' the format of {arg_name}s should be [1, 1, {arg_name}_row, "
                                 f"{arg_name}_col], but got {arg_val}.")
            if not isinstance(arg_val[2], int) or not isinstance(arg_val[3], int) or arg_val[2] < 1 or arg_val[3] < 1:
                raise ValueError(f"For '{prim_name}' the {arg_name}_row and {arg_name}_col in {arg_name}s should be "
                                 f"positive integers, but got {arg_name}_row is {arg_val[2]}, "
                                 f"{arg_name}_col is {arg_val[3]}")

        _check_tuple_or_list("ksize", ksizes, self.name)
        _check_tuple_or_list("stride", strides, self.name)
        _check_tuple_or_list("rate", rates, self.name)
        self.padding = validator.check_string(padding.upper(), ['VALID', 'SAME'], 'padding', self.name)
        self.add_prim_attr("padding", self.padding)
        self.is_ge = context.get_context("enable_ge")

    def infer_shape(self, input_x):
        """infer shape"""
        if len(input_x) != 4:
            raise ValueError("The `input_x` should be a 4-D tensor, "
                             f"but got a {len(input_x)}-D tensor whose shape is {input_x}")

        in_batch, in_depth, in_row, in_col = input_x
        _, _, ksize_row, ksize_col = self.ksizes
        _, _, stride_row, stride_col = self.strides
        _, _, rate_row, rate_col = self.rates

        out_batch = in_batch
        out_depth = ksize_row * ksize_col * in_depth

        if self.padding == "VALID":
            out_row = \
                (in_row - (ksize_row + (ksize_row - 1) * (rate_row - 1))) // stride_row + 1
            out_col = \
                (in_col - (ksize_col + (ksize_col - 1) * (rate_col - 1))) // stride_col + 1
        else:
            out_row = (in_row - 1) // stride_row + 1
            out_col = (in_col - 1) // stride_col + 1

        out_shape = [out_batch, out_depth, out_row, out_col]
        # avoid empty outputs
        validator.check("out_batch", out_batch, "", 0, Rel.GT, self.name)
        validator.check("out_depth", out_depth, "", 0, Rel.GT, self.name)
        validator.check("out_row", out_row, "", 0, Rel.GT, self.name)
        validator.check("out_col", out_col, "", 0, Rel.GT, self.name)
        return out_shape

    def infer_dtype(self, input_x):
        """infer dtype"""
        validator.check_tensor_dtype_valid("input_x", input_x, mstype.number_type, self.name)
        return input_x
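

# A minimal sketch of the output-length arithmetic used by
# ExtractImagePatches.infer_shape above, kept for reference. The helper name
# `_example_patch_out_len` is illustrative and not part of this module's API.
def _example_patch_out_len(in_len, ksize, stride, rate, padding="VALID"):
    """Illustrative spatial output length of a dilated sliding window."""
    if padding == "VALID":
        # the effective kernel size grows with the dilation rate
        effective_ksize = ksize + (ksize - 1) * (rate - 1)
        return (in_len - effective_ksize) // stride + 1
    # "SAME" padding depends only on the stride
    return (in_len - 1) // stride + 1

# e.g. _example_patch_out_len(32, 3, 1, 1) == 30
# and  _example_patch_out_len(32, 3, 1, 1, "SAME") == 32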


class Range(PrimitiveWithInfer):
    r"""
    Creates a sequence of numbers.
    Set `input_x` as :math:`x_i` for each element, `output` as follows:

    .. math::
        \text{output}(x_i) = x_i * \text{delta} + \text{start}

    Args:
        start (float): If `limit` is `None`, the value acts as the limit of the range and the first entry
            defaults to `0`. Otherwise, it acts as the first entry of the range.
        limit (float): Acts as the upper limit of the sequence. If `None`, it defaults to the value of `start`
            while the first entry of the range is set to `0`. It cannot be equal to `start`. Default: None.
        delta (float): Increment of the range. It cannot be equal to zero. Default: 1.0.

    Inputs:
        - **input_x** (Tensor) - The assistant data. A `1-D` tensor of type float32 or int32.

    Outputs:
        Tensor, has the same shape and dtype as `input_x`.

    Examples:
        >>> range_op = ops.Range(1.0, 8.0, 2.0)
        >>> x = Tensor(np.array([1, 2, 3, 2]), mindspore.int32)
        >>> output = range_op(x)
        >>> print(output)
        [3, 5, 7, 5]
    """

    @prim_attr_register
    def __init__(self, start, limit=None, delta=1.0):
        self.init_prim_io_names(inputs=['x'], outputs=['y'])
        self.delta = validator.check_value_type("delta", delta, [float], self.name)
        validator.check_value_type("start", start, [float], self.name)
        if limit is None:
            self.start = 0.0
            self.limit = start
            self.add_prim_attr("start", self.start)
            self.add_prim_attr("limit", self.limit)
        else:
            validator.check_value_type("limit", limit, [float], self.name)
        validator.check('start', self.start, 'limit', self.limit, Rel.NE, self.name)
        if self.delta == 0.0:
            raise ValueError("The input of `delta` cannot be equal to zero.")
        if self.delta > 0.0 and self.start > self.limit:
            raise ValueError(f"Limit should be greater than start when delta:{self.delta} is greater than zero, "
                             f"but got start:{self.start}, limit:{self.limit}")
        if self.delta < 0.0 and self.start < self.limit:
            raise ValueError(f"Start should be greater than limit when delta:{self.delta} is less than zero, "
                             f"but got start:{self.start}, limit:{self.limit}")

    def infer_shape(self, x_shape):
        return x_shape

    def infer_dtype(self, x_dtype):
        validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float32, mstype.int32], self.name)
        return x_dtype

    def infer_value(self, x_value):
        return Tensor(np.arange(self.start, self.limit, self.delta), dtype=x_value.dtype)
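

# A small NumPy sketch of what Range computes, mirroring infer_value above:
# each element of the assistant tensor is mapped through x * delta + start.
# `_example_range` is an illustrative name, not part of the public API.
def _example_range(start, delta, x):
    """Illustrative NumPy equivalent of Range: x * delta + start."""
    return np.asarray(x) * delta + start

# _example_range(1.0, 2.0, [1, 2, 3, 2]) -> array([3., 5., 7., 5.])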


class Quant(PrimitiveWithInfer):
    r"""
    Returns the quantized value of input_x.

    If `sqrt_mode` is False:

    .. math::
        y = round(scale * x + offset)

    If `sqrt_mode` is True:

    .. math::
        y = round(scale * x * scale + offset)

    Note:
        This operation only supports the Ascend 310 inference environment.

    Args:
        scale (float): Specifies the scaling ratio.
        offset (float): Specifies the offset.
        sqrt_mode (bool): Specifies whether to perform square root on `scale`. Default: False.
        round_mode (str): Specifies the way to round. Must be one of ["Round", "Floor", "Ceil", "Trunc"].
          Default: "Round".

    Inputs:
        - **input_x** (Tensor) - Input tensor. Its data type must be mindspore.float16 or mindspore.float32.

    Outputs:
        - Tensor: The quantized output tensor of type mindspore.int8.

    Examples:
        >>> input_x = Tensor([100.0, 150.0], mstype.float32)
        >>> quant = ops.Quant(80.0, 0.0, False, "Round")
        >>> y = quant(input_x)
    """

    @prim_attr_register
    def __init__(self, scale, offset, sqrt_mode=False, round_mode="Round"):
        self.scale = validator.check_value_type("scale", scale, [float], self.name)
        self.offset = validator.check_value_type("offset", offset, [float], self.name)
        self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name)
        self.round_mode = validator.check_string(round_mode, ["Round", "Floor", "Ceil", "Trunc"],
                                                 "round_mode", self.name)

    def infer_shape(self, x_shape):
        return x_shape

    def infer_dtype(self, x_type):
        validator.check_subclass("input_x", x_type, mstype.tensor, self.name)
        validator.check_type_name("input_x", x_type, [mstype.float16, mstype.float32], self.name)
        return mstype.int8
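

# A hedged NumPy sketch of the Quant formula documented above, assuming
# round-to-nearest and saturation to the int8 range; the exact rounding and
# overflow behavior of the Ascend kernel may differ. Illustrative only.
def _example_quant(x, scale, offset, sqrt_mode=False):
    """Illustrative quantization: y = round(scale * x [* scale] + offset)."""
    x = np.asarray(x, np.float32)
    y = scale * x * scale + offset if sqrt_mode else scale * x + offset
    # saturating cast to int8 (assumption, not taken from the kernel)
    return np.clip(np.round(y), -128, 127).astype(np.int8)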


class Dequant(PrimitiveWithInfer):
    r"""
    Returns the dequantized value of input_x.
    This operation applies ReLU to the dequantized value if `relu_flag` is True.

    If `sqrt_mode` is False:

    .. math::
        y = x * deq\_scale

    If `sqrt_mode` is True:

    .. math::
        y = x * deq\_scale * deq\_scale

    Note:
        This operation only supports the Ascend 310 inference environment.

    Args:
        sqrt_mode (bool): Specifies whether to perform square root on `scale`. Default: False.
        relu_flag (bool): Specifies whether to perform ReLU. Default: False.

    Inputs:
        - **input_x** (Tensor) - Input tensor. Must be mindspore.int32.
        - **deq_scale** (Tensor) - Specifies the scaling ratio.
          Data type must be mindspore.float16 or mindspore.uint64.

    Outputs:
        - Tensor: The dequantized output tensor of type mindspore.float16.

    Examples:
        >>> input_x = Tensor([100, 150], mstype.int32)
        >>> deq_scale = Tensor([1.0], mstype.float16)
        >>> dequant = ops.Dequant(False, False)
        >>> y = dequant(input_x, deq_scale)
    """

    @prim_attr_register
    def __init__(self, sqrt_mode=False, relu_flag=False):
        self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name)
        self.relu_flag = validator.check_value_type("relu_flag", relu_flag, [bool], self.name)
        self.add_prim_attr("dtype", mstype.float16)

    def infer_shape(self, x_shape, deq_scale_shape):
        return x_shape

    def infer_dtype(self, x_type, deq_scale_type):
        validator.check_subclass("x", x_type, mstype.tensor, self.name)
        validator.check_type_name("x", x_type, [mstype.int32], self.name)
        validator.check_type_name("deq_scale", deq_scale_type, [mstype.float16, mstype.uint64], self.name)
        return mstype.float16
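

# A hedged NumPy sketch of the Dequant formula documented above; the optional
# ReLU mirrors `relu_flag`. Illustrative only, not the Ascend kernel.
def _example_dequant(x, deq_scale, sqrt_mode=False, relu_flag=False):
    """Illustrative dequantization: y = x * deq_scale [* deq_scale]."""
    y = np.asarray(x, np.float32) * deq_scale
    if sqrt_mode:
        y = y * deq_scale
    if relu_flag:
        y = np.maximum(y, 0)
    return y.astype(np.float16)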


class MatrixDiag(PrimitiveWithInfer):
    """
    Returns a batched diagonal tensor with given batched diagonal values.

    Inputs:
        - **x** (Tensor) - A tensor to be element-wise multiplied by `assist`. It can be one of the following data
          types: float32, float16, int32, int8, and uint8.
        - **assist** (Tensor) - An eye tensor of the same type as `x`. Its rank must be greater than or equal to 2,
          and its last dimension must be equal to its second-to-last dimension.

    Outputs:
        Tensor, has the same type and shape as input `assist`.

    Examples:
        >>> x = Tensor(np.array([1, -1]), mindspore.float32)
        >>> assist = Tensor(np.arange(-12, 0).reshape(3, 2, 2), mindspore.float32)
        >>> matrix_diag = ops.MatrixDiag()
        >>> result = matrix_diag(x, assist)
        >>> print(result)
        [[[-12.   11.]
          [-10.    9.]]
         [[ -8.    7.]
          [ -6.    5.]]
         [[ -4.    3.]
          [ -2.    1.]]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize MatrixDiag"""

    def infer_dtype(self, x_dtype, assist_dtype):
        valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
        args = {"x": x_dtype, "assist": assist_dtype}
        validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
        return x_dtype

    def infer_shape(self, x_shape, assist_shape):
        validator.check_int(len(assist_shape), 2, Rel.GE, "assist rank", self.name)
        validator.check('rank of x', len(x_shape) + 1,
                        'rank of assist', len(assist_shape), Rel.LE, self.name)
        validator.check('assist\'s penultimate dimension', assist_shape[-2], 'assist\'s last dimension',
                        assist_shape[-1], Rel.EQ, self.name)

        r_end_dim = -len(x_shape)
        r_idx = -1
        while r_idx >= r_end_dim:
            if x_shape[r_idx] != 1:
                validator.check("reverse x dim %d" % r_idx, x_shape[r_idx], "reverse assist dim %d" %
                                assist_shape[r_idx - 1], assist_shape[r_idx - 1], Rel.EQ, self.name)
            r_idx = r_idx - 1

        return assist_shape
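

# Judging from the docstring and example above, MatrixDiag broadcasts `x`
# against the eye-like `assist` tensor and multiplies elementwise. A hedged
# NumPy sketch of that reading, not the kernel implementation:
def _example_matrix_diag(x, assist):
    """Illustrative equivalent: broadcast-multiply x with the assist tensor."""
    return np.asarray(x) * np.asarray(assist)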


class MatrixDiagPart(PrimitiveWithInfer):
    r"""
    Returns the batched diagonal part of a batched tensor.

    Inputs:
        - **x** (Tensor) - The batched tensor. It can be one of the following data types:
          float32, float16, int32, int8, uint8.
        - **assist** (Tensor) - An eye tensor of the same type and shape as `x`.

    Outputs:
        Tensor, with the same data type as input `x`. The shape must be x.shape[:-2] + [min(x.shape[-2:])].

    Examples:
        >>> x = Tensor([[[-1, 0], [0, 1]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32)
        >>> assist = Tensor(np.arange(-12, 0).reshape(3, 2, 2), mindspore.float32)
        >>> matrix_diag_part = ops.MatrixDiagPart()
        >>> result = matrix_diag_part(x, assist)
        >>> print(result)
        [[12., -9.], [8., -5.], [4., -1.]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize MatrixDiagPart"""

    def infer_dtype(self, x_dtype, assist_dtype):
        valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
        args = {"x": x_dtype, "assist": assist_dtype}
        validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
        return x_dtype

    def infer_shape(self, x_shape, assist_shape):
        validator.check_int(len(x_shape), 2, Rel.GE, "x rank", self.name)
        validator.check("x shape", x_shape, "assist shape", assist_shape, Rel.EQ, self.name)

        if assist_shape[-2] < assist_shape[-1]:
            out_shape = assist_shape[:-1]
        else:
            out_shape = assist_shape[:-2] + assist_shape[-1:]
        return out_shape
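

# Judging from the example above, MatrixDiagPart multiplies `x` by the
# eye-like `assist` elementwise and then reads off the batched main diagonal.
# A hedged NumPy sketch of that reading:
def _example_matrix_diag_part(x, assist):
    """Illustrative equivalent: diagonal of the elementwise product."""
    return np.diagonal(np.asarray(x) * np.asarray(assist), axis1=-2, axis2=-1)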


class Send(PrimitiveWithInfer):
    """
    Send tensors from src_rank to the specified dest_rank.

    Note:
        Send and Receive must be used in combination and must have the same sr_tag.
        Send must be used between servers.

    Args:
        sr_tag (int): A required integer identifying the send/recv message tag. The message
                      will be received by the Receive op with the same "sr_tag".
        dest_rank (int): A required integer identifying the destination rank.
        group (str): The communication group to work on. Default: "hccl_world_group/nccl_world_group".

    Inputs:
        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.

    Examples:
        >>> import mindspore.ops as ops
        >>> import mindspore.nn as nn
        >>> from mindspore.communication import init
        >>> from mindspore import Tensor
        >>> import numpy as np
        >>>
        >>> init()
        >>> class Net(nn.Cell):
        >>>     def __init__(self):
        >>>         super(Net, self).__init__()
        >>>         self.depend = ops.Depend()
        >>>         self.send = ops.Send(sr_tag=0, dest_rank=8, group="hccl_world_group")
        >>>
        >>>     def construct(self, x):
        >>>         out = self.depend(x, self.send(x))
        >>>         return out
        >>>
        >>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
        >>> net = Net()
        >>> output = net(input_)
    """

    @prim_attr_register
    def __init__(self, sr_tag, dest_rank, group=GlobalComm.WORLD_COMM_GROUP, group_back=GlobalComm.WORLD_COMM_GROUP):
        self.rank = dest_rank
        self.sr_tag = sr_tag
        self.group = group

    def infer_shape(self, x_shape):
        self.add_prim_attr("shape", x_shape)
        return x_shape

    def infer_dtype(self, x_dtype):
        return x_dtype


class Receive(PrimitiveWithInfer):
    """
    Receive tensors from src_rank.

    Note:
        Send and Receive must be used in combination and must have the same sr_tag.
        Receive must be used between servers.

    Args:
        sr_tag (int): A required integer identifying the send/recv message tag. The message
                      will be sent by the Send op with the same "sr_tag".
        src_rank (int): A required integer identifying the source rank.
        shape (list[int]): A required list identifying the shape of the tensor to be received.
        dtype (Type): A required Type identifying the type of the tensor to be received. The supported types:
                      int8, uint8, int32, float16, float32.
        group (str): The communication group to work on. Default: "hccl_world_group/nccl_world_group".

    Inputs:
        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.

    Examples:
        >>> import mindspore
        >>> import mindspore.ops as ops
        >>> import mindspore.nn as nn
        >>> from mindspore.communication import init
        >>> from mindspore import Tensor
        >>> import numpy as np
        >>>
        >>> init()
        >>> class Net(nn.Cell):
        >>>     def __init__(self):
        >>>         super(Net, self).__init__()
        >>>         self.recv = ops.Receive(sr_tag=0, src_rank=0, shape=[2, 8], dtype=mindspore.float32,
        >>>                                 group="hccl_world_group")
        >>>
        >>>     def construct(self):
        >>>         out = self.recv()
        >>>         return out
        >>>
        >>> net = Net()
        >>> output = net()
    """

    @prim_attr_register
    def __init__(self, sr_tag, src_rank, shape, dtype, group=GlobalComm.WORLD_COMM_GROUP,
                 group_back=GlobalComm.WORLD_COMM_GROUP):
        self.rank = src_rank
        self.tag = sr_tag
        self.shape = shape
        self.dtype = dtype
        self.group = group
        valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
        args = {"dtype": dtype}
        validator.check_scalar_or_tensor_types_same(args, valid_type, self.name)

    def infer_shape(self, x_shape=None):
        return self.shape

    def infer_dtype(self, x_dtype=None):
        return self.dtype


class MatrixSetDiag(PrimitiveWithInfer):
    r"""
    Modifies the batched diagonal part of a batched tensor.

    Inputs:
        - **x** (Tensor) - The batched tensor. Rank k+1, where k >= 1. It can be one of the following data types:
          float32, float16, int32, int8, uint8.
        - **diagonal** (Tensor) - The diagonal values. Must have the same type as input `x`. Rank k, where k >= 1.
        - **assist** (Tensor) - An eye tensor of the same type and shape as `x`.

    Outputs:
        Tensor, with the same data type and shape as input `x`.

    Examples:
        >>> x = Tensor([[[-1, 0], [0, 1]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32)
        >>> diagonal = Tensor([[-1., 2.], [-1., 1.], [-1., 1.]], mindspore.float32)
        >>> assist = Tensor(np.tile(np.eye(2), (3, 1, 1)), mindspore.float32)
        >>> matrix_set_diag = ops.MatrixSetDiag()
        >>> result = matrix_set_diag(x, diagonal, assist)
        >>> print(result)
        [[[-1, 0], [0, 2]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]]

    """

    @prim_attr_register
    def __init__(self):
        """Initialize MatrixSetDiag"""

    def infer_dtype(self, x_dtype, diagonal_dtype, assist_dtype):
        valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
        args = {"x": x_dtype, "diagonal": diagonal_dtype, "assist": assist_dtype}
        validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
        return x_dtype

    def infer_shape(self, x_shape, diagonal_shape, assist_shape):
        validator.check_int(len(x_shape), 2, Rel.GE, "x rank", self.name)
        validator.check("x shape", x_shape, "assist shape", assist_shape, Rel.EQ, self.name)

        if x_shape[-2] < x_shape[-1]:
            validator.check("diagonal shape", diagonal_shape, "x shape excluding the last dimension",
                            x_shape[:-1], Rel.EQ, self.name)
        else:
            validator.check("diagonal shape", diagonal_shape, "x shape excluding the second last dimension",
                            x_shape[:-2] + x_shape[-1:], Rel.EQ, self.name)

        return assist_shape


class ConfusionMulGrad(PrimitiveWithInfer):
    """
    `output0` is the elementwise product of input0 and input1.

    `output1` is the elementwise product of input0 and input1, with the ReduceSum operation then applied to it.

    Args:
        axis (Union[int, tuple[int], list[int]]): The dimensions to reduce.
            Default: (), reduce all dimensions. Only constant value is allowed.
        keep_dims (bool):

            - If true, keep these reduced dimensions and the length is 1.
            - If false, don't keep these dimensions. Default: False.

    Inputs:
        - **input_0** (Tensor) - The input Tensor.
        - **input_1** (Tensor) - The input Tensor.
        - **input_2** (Tensor) - The input Tensor.

    Outputs:
        - **output_0** (Tensor) - The same shape as `input0`.
        - **output_1** (Tensor)

            - If axis is (), and keep_dims is false, the output is a 0-D tensor representing
              the sum of all elements in the input tensor.
            - If axis is int, set as 2, and keep_dims is false,
              the shape of output is :math:`(x_1, x_3, ..., x_R)`.
            - If axis is tuple(int), set as (2, 3), and keep_dims is false,
              the shape of output is :math:`(x_1, x_4, ..., x_R)`.

    Examples:
        >>> confusion_mul_grad = ops.ConfusionMulGrad()
        >>> input_0 = Tensor(np.random.randint(-2, 2, (2, 3)), mindspore.float32)
        >>> input_1 = Tensor(np.random.randint(0, 4, (2, 3)), mindspore.float32)
        >>> input_2 = Tensor(np.random.randint(-4, 0, (2, 3)), mindspore.float32)
        >>> output_0, output_1 = confusion_mul_grad(input_0, input_1, input_2)
        output_0:
            [[ 3.   1.   0.]
             [-6.   2.  -2.]]
        output_1:
            -3.0
    """

    @prim_attr_register
    def __init__(self, axis=(), keep_dims=False):
        self.init_prim_io_names(inputs=["input0", "input1", "input2"], outputs=["output0", "output1"])
        self.axis_ = validator.check_value_type("axis", axis, [int, tuple, list], self.name)
        self.keep_dims_ = validator.check_value_type("keep_dims", keep_dims, [bool], self.name)

    def infer_shape(self, input0_shape, input1_shape, input2_shape):
        outshape0 = input0_shape
        outshape1 = _infer_shape_reduce(input1_shape, self.axis_, self.keep_dims_, self.name)
        return outshape0, outshape1

    def infer_dtype(self, input0_dtype, input1_dtype, input2_dtype):
        validator.check_subclass("input0_dtype", input0_dtype, mstype.tensor, self.name)
        validator.check_subclass("input1_dtype", input1_dtype, mstype.tensor, self.name)
        validator.check_subclass("input2_dtype", input2_dtype, mstype.tensor, self.name)
        return input0_dtype, input1_dtype


class GpuConvertToDynamicShape(PrimitiveWithCheck):
    """
    This op is used for dynamic shape testing. Its inferred shape will be unknown
    during compile time, so that its output will appear to be dynamically shaped.
    The input will not be altered in any way. Put this operator before the operator
    being tested for dynamic shape support.

    Inputs:
        - **input** (Tensor) - The tensor used for testing.

    Outputs:
        - **output** (Tensor) - Same shape, type and value as `input`.

    Examples:
          >>> # make a model, since dynamic shape operators must be in GRAPH_MODE
          >>> class TestDynamicShapeReshapeNet(nn.Cell):
          >>>     def __init__(self):
          >>>         super(TestDynamicShapeReshapeNet, self).__init__()
          >>>         self.convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()
          >>>         # suppose we are testing the Reshape op
          >>>         self.reshape = P.Reshape()
          >>>
          >>>     def construct(self, input, new_shape):
          >>>         dynamic_shape_input = self.convert_to_dynamic_shape(input)
          >>>         return self.reshape(dynamic_shape_input, new_shape)
          >>>
          >>> context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
          >>> input = Tensor(np.array([0, 1, 2, 3]))
          >>> new_shape = (2, 2)
          >>> net = TestDynamicShapeReshapeNet()
          >>> output = net(input, new_shape)
          >>> print(output)
          [[0, 1], [2, 3]]
    """

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=["input"], outputs=["output"])

    def check_shape(self, input_shape):
        validator.check("input_shape rank", len(input_shape), "", 0, Rel.GT, self.name)

    def check_dtype(self, input_dtype):
        validator.check_subclass("input_dtype", input_dtype, mstype.tensor, self.name)


class ErrorOnDynamicShapeInput(PrimitiveWithInfer):
    """
    This op is used for dynamic shape testing. Its only purpose is to throw a
    ValueError if the input is dynamically shaped.

    Inputs:
        - **input** (Tensor) - The tensor used for testing.

    Outputs:
        - **output** (Tensor) - Same shape, type and value as `input`.

    Examples:
          >>> # make a model, since dynamic shape operators must be in GRAPH_MODE
          >>> class AssertDynamicShapeNet(nn.Cell):
          >>>     def __init__(self):
          >>>         super(AssertDynamicShapeNet, self).__init__()
          >>>         self.convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()
          >>>         self.error_on_dynamic_shape_input = inner.ErrorOnDynamicShapeInput()
          >>>
          >>>     def construct(self, input):
          >>>         dynamic_shape_input = self.convert_to_dynamic_shape(input)
          >>>         self.error_on_dynamic_shape_input(dynamic_shape_input)
          >>>
          >>> context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
          >>> input = Tensor(np.array([0]))
          >>> net = AssertDynamicShapeNet()
          >>> output = net(input)
          ValueError: Input is dynamically shaped.
    """

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=["input"], outputs=["output"])

    def infer_shape(self, input_shape):
        shape = list(input_shape)

        for dim in shape:
            if dim == -1:
                raise ValueError("Input is dynamically shaped.")

        return input_shape

    def infer_type(self, input_dtype):
        """Infer the dtype of input for ErrorOnDynamicShapeInput."""
        validator.check_subclass("input_dtype", input_dtype, mstype.tensor, self.name)
        return input_dtype

    def infer_value(self, input_tensor):
        return input_tensor


class SequenceMask(PrimitiveWithCheck):
    """
    Returns a mask tensor representing the first N positions of each cell.

    If `lengths` has shape [d_1, d_2, ..., d_n], then the resulting tensor mask has type dtype and shape
    [d_1, d_2, ..., d_n, maxlen], with mask[i_1, i_2, ..., i_n, j] = (j < lengths[i_1, i_2, ..., i_n]).

    Inputs:
        - **lengths** (Tensor) - Tensor to calculate the mask for. All values in this tensor should be
          less than or equal to `maxlen`. Values greater than `maxlen` will be treated as `maxlen`.
          Must be type int32 or int64.

        - **maxlen** (int) - The size of the last dimension of the returned tensor. Must be positive and
          of the same type as the elements of `lengths`.

    Outputs:
        One mask tensor of shape lengths.shape + (maxlen,).

    Supported Platforms:
        ``GPU``

    Examples:
        >>> x = Tensor(np.array([[1, 3], [2, 0]]))
        >>> sequence_mask = ops.SequenceMask()
        >>> output = sequence_mask(x, 3)
        >>> print(output)
        [[[True False False]
          [True True True]]
         [[True True False]
          [False False False]]]
    """

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=["lengths", "maxlen"], outputs=["mask"])

    def check_shape(self, lengths_shape, maxlen_shape):
        validator.check("lengths_shape", len(lengths_shape), "", 0, Rel.GT, self.name)
        validator.check("maxlen_shape", len(maxlen_shape), "", 0, Rel.EQ, self.name)

    def check_dtype(self, lengths_dtype, maxlen_dtype):
        validator.check_subclass("lengths_dtype", lengths_dtype, mstype.tensor, self.name)
        validator.check_subclass("maxlen", maxlen_dtype, mstype.number, self.name)
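

# A minimal NumPy sketch of the mask SequenceMask produces, mirroring the
# definition in the docstring: mask[..., j] = (j < lengths[...]).
def _example_sequence_mask(lengths, maxlen):
    """Illustrative equivalent of SequenceMask via broadcasting."""
    return np.arange(maxlen) < np.asarray(lengths)[..., None]

# _example_sequence_mask([[1, 3], [2, 0]], 3) reproduces the example above.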


class SyncBatchNorm(PrimitiveWithInfer):
    r"""
    Sync Batch Normalization for input data and updated parameters.

    Sync Batch Normalization is cross-device synchronized Batch Normalization. Batch Normalization is
    widely used in convolutional neural networks. This operation applies Batch Normalization over input
    to avoid internal covariate shift as described in the paper `Batch Normalization: Accelerating
    Deep Network Training by Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_.
    It rescales and recenters the features using a mini-batch of data and the learned parameters, which
    can be described by the following formula,

    .. math::
        y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta

    where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon.

    Args:
        epsilon (float): A small value added for numerical stability. Default: 1e-5.
        momentum (float): The hyper parameter to compute the moving average for running_mean and running_var
            (e.g. :math:`new\_running\_mean = (1 - momentum) * running\_mean + momentum * current\_mean`).
            Momentum value must be in [0, 1]. Default: 0.1.
        group (str): The communication group to work on. Default: "sync_bn_group0".
        device_num (int): The number of devices in each group. Default: 2.

    Inputs:
        - **input_x** (Tensor) - Tensor of shape :math:`(N, C)`, with float16 or float32 data type.
        - **scale** (Tensor) - Tensor of shape :math:`(C,)`, with float16 or float32 data type.
        - **bias** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type as `scale`.
        - **mean** (Tensor) - Tensor of shape :math:`(C,)`, with float16 or float32 data type.
        - **variance** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type as `mean`.

    Outputs:
        Tuple of 5 Tensors, the normalized inputs and the updated parameters.

        - **output_x** (Tensor) - The same type and shape as `input_x`. The shape is :math:`(N, C)`.
        - **updated_scale** (Tensor) - Tensor of shape :math:`(C,)`.
        - **updated_bias** (Tensor) - Tensor of shape :math:`(C,)`.
        - **updated_moving_mean** (Tensor) - Tensor of shape :math:`(C,)`.
        - **updated_moving_variance** (Tensor) - Tensor of shape :math:`(C,)`.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> # This example should be run with multiple processes.
        >>> # Please refer to nn.SyncBatchNorm for direct use.
        >>> input_x = Tensor(np.ones([2, 2]), mindspore.float32)
        >>> scale = Tensor(np.ones([2]), mindspore.float32)
        >>> bias = Tensor(np.ones([2]), mindspore.float32)
        >>> mean = Tensor(np.ones([2]), mindspore.float32)
        >>> variance = Tensor(np.ones([2]), mindspore.float32)
        >>> sync_batch_norm = ops._inner_ops.SyncBatchNorm()
        >>> output = sync_batch_norm(input_x, scale, bias, mean, variance)
        >>> print(output)
        (Tensor(shape=[2, 2], dtype=Float32, value=
        [[ 1.00000000e+00, 1.00000000e+00],
         [ 1.00000000e+00, 1.00000000e+00]]), Tensor(shape=[2], dtype=Float32, value=
         [ 1.00000000e+00, 1.00000000e+00]), Tensor(shape=[2], dtype=Float32, value=
         [ 1.00000000e+00, 1.00000000e+00]), Tensor(shape=[2], dtype=Float32, value=
         [ 1.00000000e+00, 1.00000000e+00]), Tensor(shape=[2], dtype=Float32, value=
         [ 1.00000000e+00, 1.00000000e+00]))
    """

    @prim_attr_register
    def __init__(self, epsilon=1e-5, momentum=0.1, group="sync_bn_group0", device_num=2):
        validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
        validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
        validator.check_isinstance("group", group, str)
        validator.check_int(device_num, 2, Rel.GE, "device_num", self.name)
        self.init_prim_io_names(inputs=['x', 'scale', 'offset', 'mean', 'variance'],
                                outputs=['y', 'batch_mean', 'batch_variance', 'reserve_space_1', 'reserve_space_2'])

    def infer_shape(self, input_x, scale, bias, mean, variance):
        validator.check_equal_int(len(scale), 1, "scale rank", self.name)
        validator.check("scale shape", scale, "bias shape", bias, Rel.EQ, self.name)
        validator.check("scale shape[0]", scale[0], "input_x channel", input_x[1], Rel.EQ, self.name)
        validator.check_equal_int(len(mean), 1, "mean rank", self.name)
        validator.check("mean shape", mean, "variance shape", variance, Rel.EQ, self.name)
        validator.check("mean shape", mean, "scale shape", scale, Rel.EQ, self.name)
        return (input_x, scale, scale, scale, scale)

    def infer_dtype(self, input_x, scale, bias, mean, variance):
        validator.check_tensor_dtype_valid("input_x", input_x, [mstype.float16, mstype.float32], self.name)
        args = {"scale": scale, "bias": bias}
        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
        args_moving = {"mean": mean, "variance": variance}
        validator.check_tensors_dtypes_same_and_valid(args_moving, [mstype.float16, mstype.float32], self.name)
        return (input_x, scale, bias, input_x, input_x)
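

# A hedged NumPy sketch of the normalization formula documented above, seen
# from a single device; the real op also synchronizes the batch statistics
# across the devices in `group`, which is not modeled here.
def _example_batch_norm(x, scale, bias, mean, variance, epsilon=1e-5):
    """Illustrative: y = (x - mean) / sqrt(variance + epsilon) * scale + bias."""
    x = np.asarray(x, np.float32)
    return (x - mean) / np.sqrt(np.asarray(variance) + epsilon) * scale + bias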


class Centralization(PrimitiveWithInfer):
    """
    Computes centralization. y = x - mean(x, axis).

    Note:
        The dimension index starts at 0 and must be in the range `[-input.ndim, input.ndim)`.

    Inputs:
        - **input_x** (Tensor) - The input tensor. The data type must be float16 or float32.
        - **axis** (Union[Int, Tuple(Int), List(Int)]) - The dimensions to reduce. Default: (), reduce all dimensions.
          Only constant value is allowed. Must be in the range [-rank(input_x), rank(input_x)).

    Outputs:
        Tensor, has the same shape and dtype as `input_x`.

    Raises:
        TypeError: If `axis` is not one of the following types: int, list, tuple, NoneType.
        TypeError: If `axis` has non-Int elements.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> mindspore.set_seed(1)
        >>> input_x = Tensor(np.random.randn(2, 2).astype(np.float32))
        >>> centralization = ops.Centralization()
        >>> output = centralization(input_x, -1)
        >>> print(output)
        [[ 1.1180509 -1.1180508]
         [ 0.2723984 -0.2723984]]
    """

    __mindspore_signature__ = (
        sig.make_sig('input_x'),
        sig.make_sig('axis', default=())
    )

    @prim_attr_register
    def __init__(self):
        """Initialize Centralization"""
        self.init_prim_io_names(inputs=['input_x', 'axis'], outputs=['output'])

    def __infer__(self, input_x, axis):
        x_shape = list(input_x['shape'])
        x_dtype = input_x['dtype']
        axis_v = axis['value']
        rank = len(x_shape)

        args = {'input_x': input_x['dtype']}
        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)

        if axis_v is None:
            raise ValueError(f"For {self.name}, axis must be const.")
        validator.check_value_type('axis', axis_v, [int, list, tuple], self.name)

        if isinstance(axis_v, int):
            validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, 'axis', self.name)
        elif axis_v:
            for index, one_axis in enumerate(axis_v):
                validator.check_value_type('axis[%d]' % index, one_axis, [int], self.name)

        out = {'shape': x_shape,
               'dtype': x_dtype,
               'value': None}
        return out
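

# A minimal NumPy sketch of Centralization: subtract the mean over `axis`,
# keeping the reduced dimensions so the result broadcasts back over `x`.
def _example_centralization(x, axis):
    """Illustrative equivalent: y = x - mean(x, axis)."""
    x = np.asarray(x, np.float32)
    return x - np.mean(x, axis=axis, keepdims=True)

# _example_centralization(x, -1) matches the example above.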


class StackInit(PrimitiveWithInfer):
    """
    Create a stack that produces tensors in first-in last-out order.

    After `StackInit`, a tensor can be pushed onto the stack using `StackPush`, and popped
    at the top of the stack using `StackPop`. Finally, the stack should be destroyed with `StackDestroy`.

    Args:
        index (int): The index of the stack. Default: 1.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> x = Tensor(np.array([[1, 3], [2, 0]]))
        >>> index = 0
        >>> stack = ops.StackInit(index)
        >>> push = ops.StackPush(index)
        >>> pop = ops.StackPop(index, x.shape, x.dtype)
        >>> destroy = ops.StackDestroy(index)
        >>> stack()
        >>> push(x)
        >>> y = pop()
        >>> destroy()
        >>> print(y)
        [[1 3]
         [2 0]]
    """

    @prim_attr_register
    def __init__(self, index=1):
        """StackInit"""
        validator.check_value_type("index", index, [int], self.name)


class StackPush(PrimitiveWithInfer):
    """
    Push a tensor onto the stack.

    Before `StackPush`, the stack should be created using `StackInit`.
    Please refer to the usage in the source code of `StackInit`.

    Args:
        index (int): The index of the stack. Default: 1.

    Inputs:
        - **input** (Tensor) - A tensor to be pushed onto the stack.

    Supported Platforms:
        ``Ascend``

    Examples:
        Please refer to the usage of `StackInit`.
    """

    @prim_attr_register
    def __init__(self, index=1):
        """StackPush"""
        validator.check_value_type("index", index, [int], self.name)
        self.init_prim_io_names(inputs=['input'], outputs=[])


class StackPop(PrimitiveWithInfer):
    """
    Pop the tensor at the top of the stack.

    Before `StackPop`, the stack should be created using `StackInit`.
    Please refer to the usage in the source code of `StackInit`.

    Args:
        index (int): The index of the stack. Default: 1.
        shape (tuple): The shape of the tensor at the top of the stack. Default: (1,).
        dtype (mindspore.dtype): The type of the tensor at the top of the stack. Default: mindspore.float32.

    Outputs:
        - **output** (Tensor) - The tensor at the top of the stack.

    Supported Platforms:
        ``Ascend``

    Examples:
        Please refer to the usage of `StackInit`.
    """

    @prim_attr_register
    def __init__(self, index=1, shape=(1,), dtype=mstype.float32):
        """StackPop"""
        validator.check_value_type("index", index, [int], self.name)

        validator.check_value_type('shape type', shape, [list, tuple], self.name)
        validator.check_int(len(np.array(shape).shape), 1, Rel.EQ, "dim of shape", self.name)
        for elem in shape:
            validator.check_int(elem, 1, Rel.GE, 'shape element', self.name)
            validator.check_value_type('type of shape element', elem, [int], self.name)

        validator.check_type_name("dtype", dtype, (mstype.bool_,) + mstype.number_type, self.name)
        self.shape = shape
        self.dtype = dtype

        self.init_prim_io_names(inputs=[], outputs=['output'])

    def __infer__(self):
        return {'shape': (list(self.shape)),
                'dtype': (self.dtype),
                'value': None}


class StackDestroy(PrimitiveWithInfer):
    """
    Destroy the stack.

    Before `StackDestroy`, the stack should be created using `StackInit`.
    Please refer to the usage in the source code of `StackInit`.

    Args:
        index (int): The index of the stack. Default: 1.

    Supported Platforms:
        ``Ascend``

    Examples:
        Please refer to the usage of `StackInit`.
    """

    @prim_attr_register
    def __init__(self, index=1):
        """StackDestroy"""
        validator.check_value_type("index", index, [int], self.name)


class DynamicStitch(PrimitiveWithCheck):
    r"""
    Interleave the values from the data tensors into a single tensor.

    Inputs:
        - **indices** (Union[tuple, list]) - A tuple or list of Tensor objects with the same type.
        - **data** (Union[tuple, list]) - A tuple or list of Tensor objects with the same type.

    Outputs:
        Tensor. A stacked Tensor with the same type as `data`.

    Raises:
        TypeError: If the data types of elements in `data` or `indices` are not the same.
        ValueError: If the length of `data` or `indices` is less than 1.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> x1 = Tensor([6], mstype.int32)
        >>> x2 = Tensor(np.array([4, 1]), mstype.int32)
        >>> x3 = Tensor(np.array([[5, 2], [0, 3]]), mstype.int32)
        >>> y1 = Tensor(np.array([[61, 62]]), mstype.int32)
        >>> y2 = Tensor(np.array([[41, 42], [11, 12]]), mstype.int32)
        >>> y3 = Tensor(np.array([[[51, 52], [21, 22]], [[1, 2], [31, 32]]]), mstype.int32)
        >>> stitch = ops.DynamicStitch()
        >>> output = stitch([x1, x2, x3], [y1, y2, y3])
        >>> print(output)
        [[ 1  2]
         [11 12]
         [21 22]
         [31 32]
         [41 42]
         [51 52]
         [61 62]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize DynamicStitch"""

    def check_shape(self, indices_shape, data_shape):
        validator.check_value_type("shape of indices", indices_shape, [tuple, list], self.name)
        validator.check_int(len(indices_shape), 1, Rel.GE, "len of indices_shape", self.name)
        indices_dim0 = len(indices_shape[0])
        indices_num = len(indices_shape)

        validator.check_value_type("shape of data", data_shape, [tuple, list], self.name)
        validator.check_int(len(data_shape), 1, Rel.GE, "len of data_shape", self.name)
        data_dim0 = len(data_shape[0])
        data_num = len(data_shape)

        validator.check("size of indices", indices_num, 'size of data', data_num, Rel.EQ, self.name)

        # shape of `data` must start with shape of `indices`
        for i in range(0, indices_num):
            indices_dim = len(indices_shape[i])
            data_dim = len(data_shape[i])
            validator.check(f"dim of indices[{i}]", indices_dim, f"dim of data[{i}]", data_dim, Rel.LE, self.name)
            if indices_shape[i] != data_shape[i][:indices_dim]:
                raise ValueError(f"data[{i}].shape: {data_shape[i]} does not start with "
                                 f"indices[{i}].shape: {indices_shape[i]}")

        # the last (data_dim0 - indices_dim0) dims of every data shape must be the same
        base_extra = data_dim0 - indices_dim0
        for i in range(0, data_num):
            indices_dim = len(indices_shape[i])
            data_dim = len(data_shape[i])
            extra = data_dim - indices_dim
            validator.check(f"extra dim of data[{i}]", extra,
                            "extra dim of data[0]", base_extra, Rel.EQ, self.name)
            validator.check(f"data[0].shape[{indices_dim0}:]", data_shape[0][indices_dim0:],
                            f"data[{i}].shape[{indices_dim}:]",
                            data_shape[i][indices_dim:], Rel.EQ, self.name)

        out_shape = [-1] + data_shape[0][indices_dim0:]
        return out_shape

    def check_dtype(self, indices_type, data_type):
        validator.check_subclass("indices[0]", indices_type[0], mstype.tensor, self.name)
        validator.check_subclass("data[0]", data_type[0], mstype.tensor, self.name)
        indices_num = len(indices_type)
        for i in range(0, indices_num):
            validator.check_tensor_dtype_valid(f'indices[{i}]', indices_type[i], mstype.int32, self.name)
            validator.check_tensor_dtype_valid(f'data[{i}]', data_type[i],
                                               mstype.number_type + (mstype.bool_,), self.name)
            validator.check(f"type of data[{i}]", data_type[i], "type of data[0]", data_type[0], Rel.EQ, self.name)
        return data_type[0]
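

# A hedged NumPy sketch of DynamicStitch's semantics: each index/data pair
# scatters rows of `data` into the output at the positions named by `indices`.
# `_example_dynamic_stitch` is illustrative, not the device kernel.
def _example_dynamic_stitch(indices, data):
    """Illustrative merge: out[indices[i][j]] = data[i][j]."""
    indices = [np.asarray(i) for i in indices]
    data = [np.asarray(d) for d in data]
    # the per-row payload shape is whatever trails the indices shape
    row_shape = data[0].shape[indices[0].ndim:]
    n = max(int(i.max()) for i in indices) + 1
    out = np.zeros((n,) + row_shape, dtype=data[0].dtype)
    for idx, dat in zip(indices, data):
        out[idx.reshape(-1)] = dat.reshape((-1,) + row_shape)
    return out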


class DynamicBroadcastGradientArgs(Primitive):
    """
    Broadcast the two input shapes and return the dimensions that each needs to be broadcast over.

    Input shape `s0` and shape `s1` can be broadcast to a common shape if for each dimension pair they are either
    equal, or one of them is one, or the target dimension is -1. In case of -1 in the target shape, it will be
    replaced by the input shape's value in that dimension.

    Inputs:
        - **s0** (Tensor) - A `1-D` tensor. The data type should be one of the following types: int32, int64,
          uint32, uint64.
        - **s1** (Tensor) - A `1-D` tensor with the same type as `s0`.

    Outputs:
        Tuple(Tensor), tuple of 2 tensors, r0 and r1. The first one holds the broadcast dimensions of `s0` and the
        other one holds the broadcast dimensions of `s1`.

        - **r0** (Tensor) - The output shape is 1-D with the same type as s0.
        - **r1** (Tensor) - The output shape is 1-D with the same type as s0.

    Raises:
        ValueError: If the `s0` and `s1` are incompatible, or if a -1 in the target shape is in an invalid
                    location.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> shape0 = (4, 2, 1)
        >>> shape1 = (2, 7)
        >>> from mindspore.ops.operations import _inner_ops
        >>> args = _inner_ops.DynamicBroadcastGradientArgs()
        >>> r0, r1 = args(Tensor(shape0), Tensor(shape1))
        >>> print(r0, r1)
        [2], [0]
    """

    @prim_attr_register
    def __init__(self):
        """Init DynamicBroadcastGradientArgs"""
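

# A hedged sketch of how the two outputs can be read: for each input shape,
# the axes where it was broadcast (size 1, or missing after right-alignment)
# are the axes a gradient must be summed over. Illustrative only.
def _example_broadcast_gradient_args(s0, s1):
    """Illustrative reduction axes for the gradients of a broadcast op."""
    ndim = max(len(s0), len(s1))
    s0 = (1,) * (ndim - len(s0)) + tuple(s0)
    s1 = (1,) * (ndim - len(s1)) + tuple(s1)
    r0 = [i for i in range(ndim) if s0[i] == 1 and s1[i] > 1]
    r1 = [i for i in range(ndim) if s1[i] == 1 and s0[i] > 1]
    return r0, r1

# _example_broadcast_gradient_args((4, 2, 1), (2, 7)) -> ([2], [0])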


class TensorCopySlices(Primitive):
    """
    Copy continuous memory.

    Inputs:
        - **x** (Tensor) - The target Tensor.
        - **value** (Tensor) - The tensor to update x.
        - **begin** (tuple[int]) - A tuple which represents the location where to start. Only
          constant value is allowed.
        - **end** (tuple[int]) - A tuple which represents the maximum location where to end.
          Only constant value is allowed.
        - **strides** (tuple[int]) - A tuple which represents the stride that is continuously added
          before reaching the maximum location. Only constant value is allowed.

    Outputs:
        - **y** (Tensor), has the same shape and data type as `x`.

    Examples:
        >>> import numpy as np
        >>> from mindspore.ops.operations import _inner_ops
        >>> copy_slices = _inner_ops.TensorCopySlices()
        >>> out = copy_slices(Tensor(np.zeros((5, 5))), Tensor(np.ones((2, 5))), (3, 0), (5, 5), (1, 1))
        >>> print(out)
            [[0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0.],
             [1., 1., 1., 1., 1.],
             [1., 1., 1., 1., 1.]]

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """

    @prim_attr_register
    def __init__(self):
        """Initialize TensorCopySlices"""
        self.init_prim_io_names(inputs=['x', 'value', 'begin', 'end', 'strides'], outputs=['y'])
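

# A minimal NumPy sketch of TensorCopySlices: a strided-slice assignment on a
# copy of the target tensor. Illustrative only.
def _example_copy_slices(x, value, begin, end, strides):
    """Illustrative equivalent: out[begin:end:strides] = value on a copy of x."""
    out = np.array(x, copy=True)
    out[tuple(slice(b, e, s) for b, e, s in zip(begin, end, strides))] = value
    return out

# _example_copy_slices(np.zeros((5, 5)), np.ones((2, 5)), (3, 0), (5, 5), (1, 1))
# fills rows 3 and 4 with ones, matching the example above.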


class Roll(Primitive):
    """
    Rolls the elements of a tensor along an axis.

    The elements are shifted positively (towards larger indices) by the offset of `shift` along the dimension of
    `axis`. Negative `shift` values will shift elements in the opposite direction. Elements that roll past the last
    position will wrap around to the first, and vice versa. Multiple shifts along multiple axes may be specified.

    Note:
        This inner operation is valid only if the axis is equal to 0. If the shift and the axis are tuples or lists,
        this inner operation is valid only for the first pair of elements.

    Args:
        shift (Union[list(int), tuple(int), int]): Specifies the number of places by which elements are shifted
            positively (towards larger indices) along the specified dimension. Negative shifts will roll the elements
            in the opposite direction.
        axis (Union[list(int), tuple(int), int]): Specifies the dimension indexes of shape to be rolled. The value is
            forced to be zero in this operation.

    Inputs:
        - **input_x** (Tensor) - Input tensor.

    Outputs:
        Tensor, has the same shape and type as `input_x`.

    Raises:
        TypeError: If `shift` is not an int, a tuple or a list.
        TypeError: If `axis` is not an int, a tuple or a list.
        TypeError: If an element of `shift` is not an int.
        TypeError: If an element of `axis` is not an int.
        ValueError: If axis is not equal to 0.
        ValueError: If the size of `shift` is not equal to 1.
        ValueError: If the size of `axis` is not equal to 1.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> from mindspore.ops.operations import _inner_ops as inner
        >>> input_x = Tensor(np.array([0, 1, 2, 3, 4]).astype(np.float32))
        >>> op = inner.Roll(shift=2, axis=0)
        >>> output = op(input_x)
        >>> print(output)
        [3. 4. 0. 1. 2.]
        >>> input_x = Tensor(np.array([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]).astype(np.float32))
        >>> op = inner.Roll(shift=-1, axis=0)
        >>> output = op(input_x)
        >>> print(output)
        [[5. 6. 7. 8. 9.]
         [0. 1. 2. 3. 4.]]
    """

    @prim_attr_register
    def __init__(self, shift, axis):
        """Initialize Roll"""
        validator.check_value_type("shift", shift, [int, tuple, list], self.name)
        validator.check_value_type("axis", axis, [int, tuple, list], self.name)
        if isinstance(shift, (tuple, list)) and isinstance(axis, (tuple, list)):
            validator.check_equal_int(len(shift), 1, "shift size", self.name)
            validator.check_equal_int(len(axis), 1, "axis size", self.name)
            validator.check_equal_int(axis[0], 0, "axis", self.name)
        elif isinstance(shift, int) and isinstance(axis, int):
            validator.check_equal_int(axis, 0, "axis", self.name)
        self.init_prim_io_names(inputs=['input_x'], outputs=['output'])
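

# A minimal NumPy sketch of this inner Roll, which only rolls along axis 0
# (see the note in the docstring above). Illustrative only.
def _example_roll(x, shift):
    """Illustrative equivalent: np.roll along the first axis."""
    return np.roll(np.asarray(x), shift, axis=0)

# _example_roll([0, 1, 2, 3, 4], 2) -> array([3, 4, 0, 1, 2])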


class DSDMatmul(PrimitiveWithInfer):
    """
    The definition of the DSDMatmul primitive.
    """

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['input_w1', 'input_w2', 'input_v'], outputs=['output_y'])

    def infer_shape(self, input_w1_shape, input_w2_shape, input_v_shape):
        batch_size = input_w1_shape[0]
        head = input_w1_shape[1]
        v_embedding = input_v_shape[1] * 16 // head
        seq_len = input_v_shape[0] * 16 // batch_size
        return (batch_size, head, v_embedding // 16, seq_len // 16, 16, 16)

    def infer_dtype(self, data_dtype1, data_dtype2, data_dtype3):
        return data_dtype1


class MatmulDDS(PrimitiveWithInfer):
    """MatmulDDS definition"""

    @prim_attr_register
    def __init__(self, bs, heads):
        """init MatmulDDS"""
        self.init_prim_io_names(inputs=['q', 'k', 'local_mask', 'global_mask'],
                                outputs=['local_prob', 'global_prob'])

        self.heads = heads

    def infer_shape(self, q, k, local_mask, global_mask):
        seq_len = local_mask[0] * local_mask[-1]
        bs = q[1] * q[2] // seq_len
        global_size = seq_len // 4
        size_per_head = q[0] * q[-1] // self.heads
        heads = q[0] * q[-1] // size_per_head
        block_size = local_mask[1] * local_mask[2] // bs
        block_num = seq_len // block_size
        l_size = (bs, heads, block_num, block_size // 16, block_size // 16, 16, 16)
        g_size = (bs, heads, block_num, global_size // 16, block_size // 16, 16, 16)

        return l_size, g_size

    def infer_dtype(self, q, k, local_mask, global_mask):
        return q, q


class DSDGrad(PrimitiveWithInfer):
    """
    The definition of the DSDGrad primitive.
    """
    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['w1_gm', 'w2_gm', 'v_gm', 'a_gm', 'd_a_gm'],
                                outputs=['d_w1_gm', 'd_w2_gm', 'd_v_gm'])

    def infer_shape(self, input_w1_shape, input_w2_shape, input_v_shape, input_a_shape, input_da_shape):
        return input_w1_shape, input_w2_shape, input_v_shape

    def infer_dtype(self, data_dtype1, data_dtype2, data_dtype3, data_dtype4, data_dtype5):
        return data_dtype1, data_dtype1, data_dtype1


class MatmulDDSGrad(PrimitiveWithInfer):
    """MatmulDDSGrad definition"""

    @prim_attr_register
    def __init__(self):
        """init MatmulDDSGrad"""
        self.init_prim_io_names(inputs=['q', 'k', 'local_prob', 'global_prob', 'local_prob_grad', 'global_prob_grad'],
                                outputs=['dq', 'dk'])

    def infer_shape(self, q, k, local_prob, global_prob, local_prob_grad, global_prob_grad):
        k_size = (q[1], q[0], q[3], q[2])

        return q, k_size

    def infer_dtype(self, q, k, local_prob, global_prob, local_prob_grad, global_prob_grad):
        return q, k