# coding: utf-8

# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

19"""Operators for array."""
20from collections import Counter
21import copy
22import functools
23import itertools
24import numbers
25import numpy as np
26
27from mindspore import log as logger
28from mindspore.common.initializer import Zero
29from .._utils import get_broadcast_shape
30from .._utils import get_concat_offset
31from ..operations.math_ops import _infer_shape_reduce
32from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op
33from .. import signature as sig
34from ..._checkparam import Rel
35from ..._checkparam import Validator as validator
36from ...common import dtype as mstype
37from ...common._decorator import deprecated
38from ...common.parameter import Parameter
39from ...common.tensor import Tensor
40
41
class _ScatterOp(PrimitiveWithInfer):
    """
    Defines Scatter operators
    """
    __mindspore_signature__ = (
        sig.make_sig('x', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
        sig.make_sig('indices', dtype=sig.sig_dtype.T1),
        sig.make_sig('updates', dtype=sig.sig_dtype.T)
    )

    def _check_scatter_shape(self, x_shape, indices_shape, updates_shape, prim_name):
        if indices_shape != [-1] and updates_shape and updates_shape != indices_shape + x_shape[1:]:
            raise ValueError(f"For '{prim_name}', "
                             f"updates_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, "
                             f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.")

    @prim_attr_register
    def __init__(self, use_locking=False):
        """Initialize _ScatterOp"""
        validator.check_value_type('use_locking', use_locking, [bool], self.name)
        self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])
        self.add_prim_attr('side_effect_mem', True)

    def infer_shape(self, x_shape, indices_shape, updates_shape):
        self._check_scatter_shape(x_shape, indices_shape, updates_shape, self.name)
        return x_shape

    def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
        validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name)
        args = {"x": x_dtype, "updates": updates_dtype}
        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
        return x_dtype


class _ScatterOpDynamic(PrimitiveWithCheck):
    """
    Defines Scatter operators with dynamic shape
    """
    __mindspore_signature__ = (
        sig.make_sig('x', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
        sig.make_sig('indices', dtype=sig.sig_dtype.T1),
        sig.make_sig('updates', dtype=sig.sig_dtype.T)
    )

    def _check_scatter_shape(self, x_shape, indices_shape, updates_shape, prim_name):
        # x_shape cannot be dynamic
        if np.any(np.array(x_shape) == -1):
            raise ValueError(f"For '{prim_name}', the 'input_x' does not support dynamic shape, "
                             f"but got the shape of 'input_x' is {x_shape}.")
        # support indices and updates dynamic
        if np.any(np.array(indices_shape) == -1) or np.any(np.array(updates_shape) == -1):
            pass
        elif indices_shape != [-1] and updates_shape and updates_shape != indices_shape + x_shape[1:]:
            raise ValueError(f"For '{prim_name}', "
                             f"updates_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, "
                             f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.")

    @prim_attr_register
    def __init__(self, use_locking=False):
        """Initialize _ScatterOpDynamic"""
        validator.check_value_type('use_locking', use_locking, [bool], self.name)
        self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])
        self.add_prim_attr('side_effect_mem', True)

    def check_shape(self, x_shape, indices_shape, updates_shape):
        self._check_scatter_shape(x_shape, indices_shape, updates_shape, self.name)

    def check_dtype(self, x_dtype, indices_dtype, updates_dtype):
        validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name)
        args = {"x": x_dtype, "updates": updates_dtype}
        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)


class _ScatterNdOp(_ScatterOp):
    """
    Defines _ScatterNd operators
    """

    def _check_scatter_shape(self, x_shape, indices_shape, updates_shape, prim_name):
        validator.check('the dimension of x', len(x_shape),
                        'the dimension of indices', indices_shape[-1], Rel.GE)
        if indices_shape[:-1] + x_shape[indices_shape[-1]:] != updates_shape:
            raise ValueError(f"For '{prim_name}', updates_shape = "
                             f"indices_shape[:-1] + x_shape[indices_shape[-1]:], but got x_shape: {x_shape}, "
                             f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.")


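# A minimal illustration of the shape rules checked above (hypothetical values,
# not part of the public API): for _ScatterOp-style primitives, 'updates' must
# have shape indices_shape + x_shape[1:], e.g. x: (4, 3), indices: (2,)
# -> updates: (2, 3). For _ScatterNdOp-style primitives, 'updates' must have
# shape indices_shape[:-1] + x_shape[indices_shape[-1]:], e.g. x: (4, 3),
# indices: (2, 1) -> updates: (2, 3).

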
def _check_infer_attr_reduce(axis, keep_dims, prim_name):
    validator.check_value_type('keep_dims', keep_dims, [bool], prim_name)
    validator.check_value_type('axis', axis, [int, tuple], prim_name)
    if isinstance(axis, tuple):
        for index, value in enumerate(axis):
            validator.check_value_type('axis[%d]' % index, value, [int], prim_name)


class ExpandDims(PrimitiveWithInfer):
    """
    Adds an additional dimension to `input_x` at the given axis.

    Note:
        If the specified axis is a negative number, the index is counted
        backward from the end and starts at 1.

    Inputs:
        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
        - **axis** (int) - Specifies the dimension index at which to expand
          the shape of `input_x`. The value of `axis` must be in the range
          `[-input_x.ndim-1, input_x.ndim]`. Only constant value is allowed.

    Outputs:
        Tensor, the shape of tensor is :math:`(1, x_1, x_2, ..., x_R)` if the
        value of `axis` is 0. It has the same data type as `input_x`.

    Raises:
        ValueError: If `axis` is not an int or not in the valid range.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_tensor = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
        >>> expand_dims = ops.ExpandDims()
        >>> output = expand_dims(input_tensor, 0)
        >>> print(output)
        [[[2. 2.]
          [2. 2.]]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize ExpandDims"""
        self.init_prim_io_names(inputs=['x', 'axis'], outputs=['output'])

    def __infer__(self, x, axis):
        validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
        x_shape = list(x['shape'])
        axis_v = axis['value']
        rank = len(x_shape)
        validator.check_int_range(axis_v, -rank - 1, rank, Rel.INC_BOTH, 'axis', self.name)
        value = None
        if x['value'] is not None:
            value = x['value'].asnumpy()
            value = np.expand_dims(value, axis_v)
            value = Tensor(value)
        if axis_v < 0:
            # normalize a negative axis to its non-negative equivalent
            axis_v = rank + 1 + axis_v
        x_shape.insert(axis_v, 1)
        out = {'shape': x_shape,
               'dtype': x['dtype'],
               'value': value}
        if 'min_shape' in x and 'max_shape' in x:
            out['min_shape'] = x['min_shape']
            out['min_shape'].insert(axis_v, 1)
            out['max_shape'] = x['max_shape']
            out['max_shape'].insert(axis_v, 1)
        return out


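# A small sketch of the negative-axis rule above (hypothetical values): for an
# input of shape (2, 2), axis=-1 is normalized to rank + 1 + axis = 2, so the
# new dimension is inserted at the end and the output shape is (2, 2, 1).

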
class DType(Primitive):
    """
    Returns the data type of the input tensor as mindspore.dtype.

    Inputs:
        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.

    Outputs:
        mindspore.dtype, the data type of a tensor.

    Raises:
        TypeError: If `input_x` is not a Tensor.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_tensor = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
        >>> output = ops.DType()(input_tensor)
        >>> print(output)
        Float32
    """

    @prim_attr_register
    def __init__(self):
        """Initialize DType"""


class SameTypeShape(PrimitiveWithInfer):
    """
    Checks whether the data type and shape of two tensors are the same.

    Inputs:
        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
        - **input_y** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_S)`.

    Outputs:
        Tensor, the shape of tensor is :math:`(x_1, x_2, ..., x_R)`,
        if data type and shape of `input_x` and `input_y` are the same.

    Raises:
        TypeError: If the data types of `input_x` and `input_y` are not the same.
        ValueError: If the shapes of `input_x` and `input_y` are not the same.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_x = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
        >>> input_y = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
        >>> output = ops.SameTypeShape()(input_x, input_y)
        >>> print(output)
        [[2. 2.]
         [2. 2.]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize Same"""

    def __call__(self, x, y):
        """run in PyNative mode"""
        validator.check_value_type('x', x, Tensor, self.name)
        validator.check_value_type('y', y, Tensor, self.name)
        validator.check('x dtype', x.dtype, 'y dtype', y.dtype, Rel.EQ, self.name, TypeError)
        validator.check('x shape', x.shape, 'y shape', y.shape, Rel.EQ, self.name)
        return x

    def __infer__(self, x, y):
        validator.check_subclass('x', x['dtype'], mstype.tensor, self.name)
        validator.check_subclass('y', y['dtype'], mstype.tensor, self.name)
        validator.check('x dtype', x['dtype'], 'y dtype', y['dtype'], Rel.EQ, self.name, TypeError)
        validator.check('x shape', x['shape'], 'y shape', y['shape'], Rel.EQ, self.name)
        return x


class Cast(PrimitiveWithInfer):
    """
    Returns a tensor with the new specified data type.

    Inputs:
        - **input_x** (Union[Tensor, Number]) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
          The tensor to be cast.
        - **type** (dtype.Number) - The valid data type of the output tensor. Only constant value is allowed.

    Outputs:
        Tensor, the shape of tensor is the same as `input_x`, :math:`(x_1, x_2, ..., x_R)`.

    Raises:
        TypeError: If `input_x` is neither Tensor nor Number.
        TypeError: If `type` is not a Number.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
        >>> input_x = Tensor(input_np)
        >>> type_dst = mindspore.int32
        >>> cast = ops.Cast()
        >>> output = cast(input_x, type_dst)
        >>> print(output.dtype)
        Int32
        >>> print(output.shape)
        (2, 3, 4, 5)
    """

    @prim_attr_register
    def __init__(self):
        # if a primitive needs to call setattr in __infer__, this flag must be added
        """Initialize Cast"""
        self.init_prim_io_names(inputs=['x', 'dst_type'], outputs=['output'])

    def check_elim(self, x, dtype):
        if isinstance(x, (Tensor, numbers.Number, Parameter)):
            if isinstance(x, Parameter):
                data = x.data
                if data.dtype == dtype:
                    return (True, x)
            if isinstance(x, Tensor) and x.dtype == dtype:
                x = Tensor(x)
                x.set_cast_dtype()
                return (True, x)
            if isinstance(x, numbers.Number):
                return (True, Tensor(x, dtype=dtype))
        return (False, None)

    def __infer__(self, x, t):
        src_type = x['dtype']
        dst_type = t['value']

        validator.check_subclass("input_x", src_type, [mstype.tensor, mstype.number], self.name)
        validator.check_subclass("type", dst_type, mstype.number, self.name)

        if isinstance(src_type, type(mstype.tensor)):
            src_type = x['dtype'].element_type()
        if isinstance(dst_type, type(mstype.tensor)):
            dst_type = dst_type.element_type()
        self.add_prim_attr('DstT', dst_type)
        self.add_prim_attr('SrcT', src_type)
        self.add_prim_attr('dst_type', dst_type)

        value = None
        if x['value'] is not None:
            np_dst_type = mstype.dtype_to_nptype(dst_type)
            if isinstance(x['value'], (int, float)):
                value = Tensor(np.array(x['value']).astype(np_dst_type))
            else:
                value = Tensor(x['value'].asnumpy().astype(np_dst_type))

        out = {'shape': x['shape'],
               'dtype': mstype.tensor_type(t['value']),
               'value': value}
        if 'min_shape' in x and 'max_shape' in x:
            out['min_shape'] = x['min_shape']
            out['max_shape'] = x['max_shape']
        return out


class IsSubClass(PrimitiveWithInfer):
    """
    Checks whether this type is a sub-class of another type.

    Inputs:
        - **sub_type** (mindspore.dtype) - The type to be checked. Only constant value is allowed.
        - **type_** (mindspore.dtype) - The target type. Only constant value is allowed.

    Outputs:
        bool, the check result.

    Raises:
        TypeError: If `sub_type` or `type_` is not a Type.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> output = ops.IsSubClass()(mindspore.int32, mindspore.intc)
        >>> print(output)
        True
    """

    @prim_attr_register
    def __init__(self):
        pass

    def __infer__(self, sub_type, type_):
        sub_type_t = sub_type['value']
        type_v = type_['value']

        validator.check_value_type("sub_type", sub_type_t, [mstype.Type], self.name)
        validator.check_value_type("type_", type_v, [mstype.Type], self.name)

        value = mstype.issubclass_(sub_type_t, type_v)

        out = {'shape': (),
               'dtype': mstype.type_type,
               'value': value}
        return out


class IsInstance(PrimitiveWithInfer):
    """
    Checks whether an object is an instance of a target type.

    Inputs:
        - **inst** (Any Object) - The instance to be checked. Only constant value is allowed.
        - **type_** (mindspore.dtype) - The target type. Only constant value is allowed.

    Outputs:
        bool, the check result.

    Raises:
        TypeError: If `type_` is not a Type.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> inst = 1
        >>> output = ops.IsInstance()(inst, mindspore.int32)
        >>> print(output)
        False
    """

    @prim_attr_register
    def __init__(self):
        pass

    def __infer__(self, inst, type_):
        sub_type_t = inst['dtype']
        type_v = type_['value']

        validator.check_value_type("type_", type_v, [mstype.Type], self.name)

        if type_v == mstype.list_:
            value = isinstance(sub_type_t, list)
        elif type_v == mstype.tuple_:
            value = isinstance(sub_type_t, tuple)
        else:
            value = mstype.issubclass_(sub_type_t, type_v)

        out = {'shape': (),
               'dtype': mstype.type_type,
               'value': value}
        return out


class Reshape(PrimitiveWithInfer):
    """
    Reshapes the input tensor with the same values based on a given shape tuple.

    Inputs:
        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
        - **input_shape** (tuple[int]) - The input tuple is constructed by multiple
          integers, i.e., :math:`(y_1, y_2, ..., y_S)`. Only constant value is allowed.

    Raises:
        ValueError: If the shape tuple contains more than one -1; or if the product
            of its elements is less than or equal to 0 or does not evenly divide the
            number of elements of `input_x`; or if the resulting shape does not match
            the number of elements of `input_x`.

    Outputs:
        Tensor, the shape of tensor is :math:`(y_1, y_2, ..., y_S)`.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
        >>> reshape = ops.Reshape()
        >>> output = reshape(input_x, (3, 2))
        >>> print(output)
        [[-0.1  0.3]
         [ 3.6  0.4]
         [ 0.5 -3.2]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize Reshape"""
        self.init_prim_io_names(inputs=['tensor', 'shape'], outputs=['output'])

    def __infer__(self, x, shape):
        shape_v = shape['value']
        x_shp = x['shape']
        validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
        validator.check_value_type("shape", shape_v, [tuple], self.name)
        shape_v = list(shape_v)
        neg_index = -1
        dim_prod = 1
        for i, shp_i in enumerate(shape_v):
            validator.check_value_type("shape[%d]" % i, shp_i, [int], self.name)
            if shp_i == -1:
                if neg_index != -1:
                    raise ValueError(f"For '{self.name}', the 'input_shape' can have at most one -1, "
                                     f"but got {shape_v}.")
                neg_index = i
            else:
                dim_prod *= shp_i
        arr_prod = np.prod(x_shp)
        if arr_prod <= 0:
            # dynamic shape: resolve -1 against the min/max shape of x instead
            if 'max_shape' in x:
                x_max_shape = x['max_shape']
            else:
                x_max_shape = x['shape']
            if 'min_shape' in x:
                x_min_shape = x['min_shape']
            else:
                x_min_shape = x['shape']
            max_arr_prod = np.prod(x_max_shape)
            min_arr_prod = np.prod(x_min_shape)
            max_shape = list(shape_v)
            min_shape = list(shape_v)
            if neg_index != -1:
                max_shape[neg_index] = int(max_arr_prod / dim_prod)
                min_shape[neg_index] = int(min_arr_prod / dim_prod)
            else:
                raise ValueError(f"For '{self.name}', the 'input_shape' must have -1 in the case of dynamic shape, "
                                 f"but got {shape_v}.")
            out = {'shape': shape['value'],
                   'dtype': x['dtype'],
                   'value': None,
                   'max_shape': tuple(max_shape),
                   'min_shape': tuple(min_shape)}
        else:
            if dim_prod <= 0:
                raise ValueError(f"For '{self.name}', the shape of 'input_x' is {x_shp}, "
                                 f"the value of 'input_shape' is {shape_v}. "
                                 f"The product of 'input_shape' should be > 0, but got {dim_prod}.")
            if neg_index != -1:
                # resolve the single -1 so that the element counts match
                shape_v[neg_index] = int(arr_prod / dim_prod)
                dim_prod *= shape_v[neg_index]
            if dim_prod != arr_prod:
                raise ValueError(f"For '{self.name}', the shape of 'input_x' is {x_shp}, "
                                 f"the value of 'input_shape' is {shape_v}. "
                                 f"The product of the shape of 'input_x' should be equal to the product of "
                                 f"'input_shape', but product of the shape of 'input_x' is {arr_prod}, "
                                 f"product of 'input_shape' is {dim_prod}.")
            value = None
            if x['value'] is not None:
                value = Tensor(x['value'].asnumpy().reshape(shape_v))

            out = {'shape': tuple(shape_v),
                   'dtype': x['dtype'],
                   'value': value}
        return out


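# A small sketch of the -1 inference above (hypothetical values): reshaping a
# (2, 3) tensor with input_shape (3, -1) resolves the -1 to 6 // 3 = 2, giving
# an output shape of (3, 2).

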
class Shape(Primitive):
    """
    Returns the shape of the input tensor. The returned shape is a static shape.

    A static shape can be obtained without running the graph. It is an inherent property of the tensor and
    may be unknown. Static shape information can also be completed by manual setting,
    and it does not change with the input of the graph.

    Inputs:
        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.

    Outputs:
        tuple[int], the output tuple is constructed by multiple integers,
        :math:`(x_1, x_2, ..., x_R)`.

    Raises:
        TypeError: If `input_x` is not a Tensor.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_x = Tensor(np.ones(shape=[3, 2, 1]), mindspore.float32)
        >>> shape = ops.Shape()
        >>> output = shape(input_x)
        >>> print(output)
        (3, 2, 1)
    """

    @prim_attr_register
    def __init__(self):
        """Initialize Shape"""


class DynamicShape(Primitive):
    """
    Returns the shape of the input tensor. The returned shape is a dynamic shape.

    Note:
        A dynamic shape is inferred while the graph is running: as tensors flow through the graph,
        the concrete shape of the tensor at each node is derived from the structure of the graph.
        When the input shape of the graph changes, the dynamic shapes of the tensors
        in the graph change accordingly.

    Inputs:
        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.

    Outputs:
        Tensor[int], 1-dim Tensor of type int32

    Raises:
        TypeError: If `input_x` is not a Tensor.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_x = Tensor(np.ones(shape=[3, 2, 1]), mindspore.float32)
        >>> shape = ops.DynamicShape()
        >>> output = shape(input_x)
        >>> print(output)
        [3 2 1]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize DynamicShape"""
        self.init_prim_io_names(inputs=['tensor'], outputs=['output'])
        self.add_prim_attr('is_dynamic_shape', True)


class Squeeze(PrimitiveWithInfer):
    """
    Returns a tensor with the same data type but dimensions of 1 are removed based on `axis`.

    If `axis` is specified, it will remove the dimensions of size 1 in the given `axis`.
    If `axis` is None, it will remove all the dimensions of size 1.

    Note:
        The dimension index starts at 0 and must be in the range `[-input.ndim, input.ndim)`.

    Args:
        axis (Union[int, tuple(int)]): Specifies the dimension indexes of shape to be removed, which will remove
            all the dimensions that are equal to 1. If specified, it must be int32 or int64.
            Default: (), an empty tuple.

    Inputs:
        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.

    Outputs:
        Tensor, the shape of tensor is :math:`(x_1, x_2, ..., x_S)`.

    Raises:
        TypeError: If `axis` is neither an int nor tuple.
        TypeError: If `axis` is a tuple whose elements are not all int.
        ValueError: If the corresponding dimension of the specified axis is not equal to 1.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_x = Tensor(np.ones(shape=[3, 2, 1]), mindspore.float32)
        >>> squeeze = ops.Squeeze(2)
        >>> output = squeeze(input_x)
        >>> print(output)
        [[1. 1.]
         [1. 1.]
         [1. 1.]]
    """

    @prim_attr_register
    def __init__(self, axis=()):
        """Initialize Squeeze"""
        self.init_prim_io_names(inputs=['x'], outputs=['output'])
        validator.check_value_type('axis', axis, [int, tuple], self.name)
        if isinstance(axis, tuple):
            for idx, item in enumerate(axis):
                validator.check_value_type("axis[%d]" % idx, item, [int], self.name)
        else:
            self.axis = (axis,)
            self.add_prim_attr("axis", (axis,))

    def infer_shape(self, x_shape):
        axis = self.axis
        x_shape = list(x_shape)
        ndim = len(x_shape)
        if not axis:
            ret = [d for d in x_shape if d != 1]
        else:
            for a in axis:
                validator.check_int_range(a, -ndim, ndim - 1, Rel.INC_BOTH, 'axis or its elements', self.name)
                if x_shape[a] != 1:
                    raise ValueError(f"For '{self.name}', the shape of 'input_x' at {a} dimension should be 1, "
                                     f"but got {x_shape[a]}.")
            ret = [x_shape[i] for i in range(ndim) if not (i in axis or (i - ndim) in axis)]
        return ret

    def infer_dtype(self, x_dtype):
        validator.check_subclass("x", x_dtype, mstype.tensor, self.name)
        return x_dtype


class Transpose(Primitive):
    """
    Permutes the dimensions of the input tensor according to input permutation.

    Inputs:
        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
        - **input_perm** (tuple[int]) - The permutation to be converted. The elements in `input_perm` are composed of
          the indexes of each dimension of `input_x`. The length of `input_perm` and the shape of `input_x` must be
          the same. Only constant value is allowed. Must be in the range [0, rank(input_x)).

    Outputs:
        Tensor, the type of output tensor is the same as `input_x` and the shape of output tensor is decided by the
        shape of `input_x` and the value of `input_perm`.

    Raises:
        TypeError: If `input_perm` is not a tuple.
        ValueError: If length of shape of `input_x` is not equal to length of shape of `input_perm`.
        ValueError: If the same element exists in `input_perm`.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_x = Tensor(np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]), mindspore.float32)
        >>> input_perm = (0, 2, 1)
        >>> transpose = ops.Transpose()
        >>> output = transpose(input_x, input_perm)
        >>> print(output)
        [[[ 1.  4.]
          [ 2.  5.]
          [ 3.  6.]]
         [[ 7. 10.]
          [ 8. 11.]
          [ 9. 12.]]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize Transpose"""
        self.init_prim_io_names(inputs=['x', 'perm'], outputs=['output'])


class Unique(Primitive):
    """
    Returns the unique elements of the input tensor and also returns a tensor containing the index of each value of
    the input tensor corresponding to the output unique tensor.

    The output contains Tensor `y` and Tensor `idx`, returned as a tuple (`y`, `idx`).
    The shapes of Tensor `y` and Tensor `idx` differ in most cases, because Tensor `y` is deduplicated,
    while the shape of Tensor `idx` is consistent with the input.

    To get the same shape between `idx` and `y`, please refer to the 'UniqueWithPad' operator.

    Inputs:
        - **input_x** (Tensor) - The input tensor.
          The shape is :math:`(N,*)` where :math:`*` means any number of additional dimensions.

    Outputs:
        Tuple, containing Tensor objects `(y, idx)`. `y` is a tensor with the
        same type as `input_x`, and contains the unique elements in `input_x`, sorted in
        ascending order. `idx` is a tensor containing indices of elements in
        the input corresponding to the output tensor.

    Raises:
        TypeError: If `input_x` is not a Tensor.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_x = Tensor(np.array([1, 2, 5, 2]), mindspore.int32)
        >>> output = ops.Unique()(input_x)
        >>> print(output)
        (Tensor(shape=[3], dtype=Int32, value= [1, 2, 5]), Tensor(shape=[4], dtype=Int32, value= [0, 1, 2, 1]))
        >>> y = output[0]
        >>> print(y)
        [1 2 5]
        >>> idx = output[1]
        >>> print(idx)
        [0 1 2 1]
        >>> # As can be seen from the above, the shapes of y and idx differ.
        >>> # Note that for GPU, this operator must be wrapped inside a model, and executed in graph mode.
        >>> class UniqueNet(nn.Cell):
        ...     def __init__(self):
        ...         super(UniqueNet, self).__init__()
        ...         self.unique_op = ops.Unique()
        ...
        ...     def construct(self, x):
        ...         output, indices = self.unique_op(x)
        ...         return output, indices
        ...
        >>> input_x = Tensor(np.array([1, 2, 5, 2]), mindspore.int32)
        >>> net = UniqueNet()
        >>> output = net(input_x)
        >>> print(output)
        (Tensor(shape=[3], dtype=Int32, value= [1, 2, 5]), Tensor(shape=[4], dtype=Int32, value= [0, 1, 2, 1]))
    """

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['x'], outputs=['output'])


class Gather(Primitive):
    r"""
    Returns a slice of the input tensor based on the specified indices and axis.

    Slices the input tensor based on the indices at the specified axis. See the following examples for details.

    Inputs:
        - **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
          The original Tensor.
        - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
          Specifies the indices of elements of the original Tensor. Must be in the range
          `[0, input_param.shape[axis])`, which is only validated on CPU. The data type can be int32 or int64.
        - **axis** (int) - Specifies the dimension index to gather indices.

    Outputs:
        Tensor, the shape of tensor is
        :math:`input\_params.shape[:axis] + input\_indices.shape + input\_params.shape[axis + 1:]`.

    Raises:
        TypeError: If `axis` is not an int.
        TypeError: If dtype of `input_indices` is not int.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_params = Tensor(np.array([[1, 2, 7, 42], [3, 4, 54, 22], [2, 2, 55, 3]]), mindspore.float32)
        >>> input_indices = Tensor(np.array([1, 2]), mindspore.int32)
        >>> axis = 1
        >>> output = ops.Gather()(input_params, input_indices, axis)
        >>> print(output)
        [[ 2.  7.]
         [ 4. 54.]
         [ 2. 55.]]
        >>> axis = 0
        >>> output = ops.Gather()(input_params, input_indices, axis)
        >>> print(output)
        [[3. 4. 54. 22.]
         [2. 2. 55.  3.]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize Gather"""
        self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])


class GatherV2(PrimitiveWithCheck):
    """
    Same as operator Gather. GatherV2 will be deprecated in the future.
    Please use Gather instead.
    """

    @deprecated("1.1", "Gather", True)
    @prim_attr_register
    def __init__(self):
        """Initialize GatherV2"""
        self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])

    def __check__(self, params, indices, axis):
        validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
        validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name)
        validator.check_subclass("axis", axis['dtype'], [mstype.number], self.name)
        axis_v = axis['value']
        validator.check_value_type('axis', axis_v, [int], self.name)
        rank = len(params['shape'])
        validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name)


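# Migration sketch (hypothetical usage): since GatherV2 is deprecated and takes
# the same inputs as Gather, code like ops.GatherV2()(params, indices, axis)
# can be rewritten as ops.Gather()(params, indices, axis) with identical results.

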
class SparseGatherV2(PrimitiveWithCheck):
    """
    Returns a slice of input tensor based on the specified indices and axis.

    Inputs:
        - **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
        - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
          Specifies the indices of elements of the original Tensor, must be in the range
          `[0, input_param.shape[axis])`.
        - **axis** (int) - Specifies the dimension index to gather indices.

    Outputs:
        Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.

    Raises:
        TypeError: If `axis` is not an int.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> input_params = Tensor(np.array([[1, 2, 7, 42], [3, 4, 54, 22], [2, 2, 55, 3]]), mindspore.float32)
        >>> input_indices = Tensor(np.array([1, 2]), mindspore.int32)
        >>> axis = 1
        >>> out = ops.SparseGatherV2()(input_params, input_indices, axis)
        >>> print(out)
        [[2. 7.]
         [4. 54.]
         [2. 55.]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize SparseGatherV2"""
        self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])

    def __check__(self, params, indices, axis):
        validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
        validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name)
        validator.check_subclass("axis", axis['dtype'], [mstype.number], self.name)
        axis_v = axis['value']
        validator.check_value_type('axis', axis_v, [int], self.name)
        rank = len(params['shape'])
        validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name)


class Padding(PrimitiveWithInfer):
    """
    Extends the last dimension of the input tensor from 1 to `pad_dim_size`, by filling with 0.

    Args:
        pad_dim_size (int): The value of the last dimension of `x` to be extended, which must be positive. Default: 8.

    Inputs:
        - **x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. The rank of `x` must be at least 2.
          The last dimension of `x` must be 1. The data type is Number.

    Outputs:
        Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.

    Raises:
        TypeError: If `pad_dim_size` is not an int.
        ValueError: If `pad_dim_size` is less than 1.
        ValueError: If the last dim of `x` is not equal to 1.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> x = Tensor(np.array([[8], [10]]), mindspore.float32)
        >>> pad_dim_size = 4
        >>> output = ops.Padding(pad_dim_size)(x)
        >>> print(output)
        [[ 8.  0.  0.  0.]
         [10.  0.  0.  0.]]
    """

    @prim_attr_register
    def __init__(self, pad_dim_size=8):
        """Initialize Padding"""
        validator.check_value_type("pad_dim_size", pad_dim_size, [int], self.name)
        validator.check_positive_int(pad_dim_size, "pad_dim_size", self.name)
        self.pad_dim_size = pad_dim_size

    def __infer__(self, x):
        validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
        x_shape = list(x['shape'])
        validator.check_int(len(x_shape), 1, Rel.GT, "rank of x", self.name)
        validator.check_int(x_shape[-1], 1, Rel.EQ, "last dim of x", self.name)
        out_shape = x_shape
        out_shape[-1] = self.pad_dim_size
        out = {'shape': out_shape,
               'dtype': x['dtype'],
               'value': None}
        return out


class UniqueWithPad(PrimitiveWithInfer):
    """
    Returns the unique elements and relative indexes of a 1-D tensor, filled with the padding num.

    The basic function is the same as the Unique operator, but UniqueWithPad adds a padding step.
    When the input Tensor `x` is processed by the unique operation, the returned tuple (`y`, `idx`)
    usually has `y` and `idx` of unequal shapes. To resolve this, the UniqueWithPad operator
    fills the Tensor `y` with the `pad_num` specified by the user
    so that it has the same shape as the Tensor `idx`.

    Inputs:
        - **x** (Tensor) - The tensor need to be unique. Must be 1-D vector with types: int32, int64.
        - **pad_num** (int) - Pad num. The data type is an int.

    Outputs:
        tuple(Tensor), tuple of 2 tensors, `y` and `idx`.
        - y (Tensor) - The unique elements filled with pad_num, the shape and data type same as `x`.
        - idx (Tensor) - The index of each value of `x` in the unique output `y`, the shape and data type same as `x`.

    Raises:
        TypeError: If dtype of `x` is neither int32 nor int64.
        ValueError: If length of shape of `x` is not equal to 1.

    Supported Platforms:
        ``Ascend`` ``CPU``

    Examples:
        >>> x = Tensor(np.array([1, 1, 5, 5, 4, 4, 3, 3, 2, 2,]), mindspore.int32)
        >>> pad_num = 8
        >>> output = ops.UniqueWithPad()(x, pad_num)
        >>> print(output)
        (Tensor(shape=[10], dtype=Int32, value= [1, 5, 4, 3, 2, 8, 8, 8, 8, 8]),
         Tensor(shape=[10], dtype=Int32, value= [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]))
    """

    @prim_attr_register
    def __init__(self):
        """Initialize UniqueWithPad"""

    def __infer__(self, x, pad_num):
        validator.check_tensor_dtype_valid("x", x['dtype'], [mstype.int32, mstype.int64], self.name)
        validator.check_subclass("pad_num", pad_num['dtype'], [mstype.int32, mstype.int64], self.name)
        x_shape = list(x['shape'])
        validator.check("rank of x", len(x_shape), "expected", 1, Rel.EQ, self.name)
        out_shape = x_shape
        out = {'shape': (out_shape, out_shape),
               'dtype': (x['dtype'], x['dtype']),
               'value': None}
        return out


class Split(PrimitiveWithCheck):
    """
    Splits the input tensor into `output_num` tensors along the given axis.

    The `input_x` tensor will be split into equally sized sub-tensors.
    This requires that `input_x.shape[axis]` is divisible by `output_num`.

    Args:
        axis (int): Index of the split position. Default: 0.
        output_num (int): The number of output tensors. Must be positive int. Default: 1.

    Inputs:
        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.

    Outputs:
        tuple[Tensor], the shape of each output tensor is the same, which is
        :math:`(y_1, y_2, ..., y_S)`. And the data type is the same with `input_x`.

    Raises:
        TypeError: If `axis` or `output_num` is not an int.
        ValueError: If `axis` is out of the range [-len(`input_x.shape`), len(`input_x.shape`)),
            or if the `output_num` is less than or equal to 0.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> split = ops.Split(1, 2)
        >>> x = Tensor(np.array([[1, 1, 1, 1], [2, 2, 2, 2]]), mindspore.int32)
        >>> print(x)
        [[1 1 1 1]
         [2 2 2 2]]
        >>> output = split(x)
        >>> print(output)
        (Tensor(shape=[2, 2], dtype=Int32, value=
        [[1, 1],
         [2, 2]]), Tensor(shape=[2, 2], dtype=Int32, value=
        [[1, 1],
         [2, 2]]))
        >>> split = ops.Split(1, 4)
        >>> output = split(x)
        >>> print(output)
        (Tensor(shape=[2, 1], dtype=Int32, value=
        [[1],
         [2]]), Tensor(shape=[2, 1], dtype=Int32, value=
        [[1],
         [2]]), Tensor(shape=[2, 1], dtype=Int32, value=
        [[1],
         [2]]), Tensor(shape=[2, 1], dtype=Int32, value=
        [[1],
         [2]]))
    """

    @prim_attr_register
    def __init__(self, axis=0, output_num=1):
        """Initialize Split"""
        validator.check_value_type("axis", axis, [int], self.name)
        validator.check_value_type("output_num", output_num, [int], self.name)
        validator.check_positive_int(output_num, "output_num", self.name)
        self.axis = axis
        self.output_num = output_num

    def __check__(self, x):
        validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
        x_shape = list(x['shape'])
        dim = len(x_shape)
        validator.check_int_range(self.axis, -dim, dim, Rel.INC_LEFT, 'axis value', self.name)
        if -1 not in x_shape:
            # only validate when the shape is fully known
            output_valid_check = x_shape[self.axis] % self.output_num
            if output_valid_check != 0:
                raise ValueError(f"For '{self.name}', the specified axis of 'input_x' should be divided exactly by "
                                 f"'output_num', but got the shape of 'input_x' in 'axis' {self.axis} is "
                                 f"{x_shape[self.axis]}, 'output_num': {self.output_num}.")
        size_splits = [x_shape[self.axis] // self.output_num] * self.output_num
        self.add_prim_attr('size_splits', size_splits)


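# A small sketch of the 'size_splits' attribute computed above (hypothetical
# values): splitting a (2, 4) tensor with Split(axis=1, output_num=2) records
# size_splits = [2, 2], i.e. two sub-tensors of size 4 // 2 = 2 along axis 1.

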
class Rank(PrimitiveWithInfer):
    """
    Returns the rank of a tensor.

    Returns a 0-D int32 Tensor representing the rank of input; the rank of a tensor
    is the number of indices required to uniquely select each element of the tensor.

    Inputs:
        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. The data type is Number.

    Outputs:
        Tensor. 0-D int32 Tensor representing the rank of input, i.e., :math:`R`. The data type is an int.

    Raises:
        TypeError: If `input_x` is not a Tensor.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_tensor = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
        >>> rank = ops.Rank()
        >>> output = rank(input_tensor)
        >>> print(output)
        2
        >>> print(type(output))
        <class 'int'>
    """

    @prim_attr_register
    def __init__(self):
        """Initialize Rank"""

    def __infer__(self, x):
        validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
        out = {'shape': None,
               'dtype': None,
               'value': len(x['shape'])}
        return out


class TruncatedNormal(PrimitiveWithInfer):
    """
    Returns a tensor of the specified shape filled with truncated normal values.

    The generated values follow a normal distribution.

    Args:
        seed (int): An integer used to create a random seed. Default: 0.
        dtype (:class:`mindspore.dtype`): Data type. Default: mindspore.float32.

    Inputs:
        - **shape** (tuple[int]) - The shape of the output tensor. It is a tuple of positive integers.

    Outputs:
        Tensor, the data type of output tensor is the same as attribute `dtype`.

    Examples:
        >>> shape = (1, 2, 3)
        >>> truncated_normal = ops.TruncatedNormal()
        >>> output = truncated_normal(shape)
    """

    @prim_attr_register
    def __init__(self, seed=0, dtype=mstype.float32):
        """Initialize TruncatedNormal"""
        validator.check_value_type('seed', seed, [int], self.name)
        validator.check_types_same_and_valid({'dtype': dtype}, mstype.number_type, self.name)

    def __infer__(self, shape):
        shape_value = shape['value']
        validator.check_value_type("shape", shape_value, [tuple], self.name)
        for i, value in enumerate(shape_value):
            validator.check_positive_int(value, f'{i}th value of shape', self.name)
        out = {'shape': shape_value,
               'dtype': mstype.tensor_type(self.dtype),
               'value': None}
        return out


class Size(PrimitiveWithInfer):
    r"""
    Returns the size of a tensor.

    Returns an int scalar representing the size of the input, i.e., the total number of elements in the tensor.

    Inputs:
        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. The data type is Number.

    Outputs:
        int. A scalar representing the number of elements of `input_x`,
        :math:`size = x_1 * x_2 * ... * x_R`. The data type is an int.

    Raises:
        TypeError: If `input_x` is not a Tensor.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_x = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
        >>> size = ops.Size()
        >>> output = size(input_x)
        >>> print(output)
        4
    """

    @prim_attr_register
    def __init__(self):
        """Initialize Size"""

    def __infer__(self, x):
        size = 1
        validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
        shp = x['shape']
        if not shp:
            size = 0
        else:
            size = functools.reduce(lambda x, y: x * y, x['shape'])
        out = {'shape': None,
               'dtype': mstype.int64,
               'value': size}
        return out


class Fill(PrimitiveWithInfer):
    """
    Creates a tensor filled with a scalar value.

    Creates a tensor with shape described by the first argument and fills it with values in the second argument.

    Inputs:
        - **type** (mindspore.dtype) - The specified type of output tensor. Only constant value is allowed.
        - **shape** (tuple) - The specified shape of output tensor. Only constant value is allowed.
        - **value** (scalar) - Value to fill the returned tensor. Only constant value is allowed.

    Outputs:
        Tensor, has the same type and shape as input value.

    Raises:
        TypeError: If `shape` is not a tuple.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> fill = ops.Fill()
        >>> output = fill(mindspore.float32, (2, 2), 1)
        >>> print(output)
        [[1. 1.]
         [1. 1.]]
        >>> output = fill(mindspore.float32, (3, 3), 0)
        >>> print(output)
        [[0. 0. 0.]
         [0. 0. 0.]
         [0. 0. 0.]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize Fill"""

    def __infer__(self, dtype, dims, x):
        validator.check_value_type("shape", dims['value'], [tuple], self.name)
        validator.check_value_type("value", x['value'], [numbers.Number, bool], self.name)
        for i, item in enumerate(dims['value']):
            validator.check_positive_int(item, f'dims[{i}]', self.name)
        valid_dtypes = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64,
                        mstype.uint8, mstype.uint16, mstype.uint32, mstype.uint64,
                        mstype.float16, mstype.float32, mstype.float64]
        validator.check_types_same_and_valid({"value": dtype['value']}, valid_dtypes, self.name)
        x_nptype = mstype.dtype_to_nptype(dtype['value'])
        ret = np.full(dims['value'], x['value'], x_nptype)
        out = {
            'value': Tensor(ret),
            'shape': dims['value'],
            'dtype': x['dtype'],
        }
        return out


class Ones(Primitive):
    r"""
    Creates a tensor filled with value ones.

    Creates a tensor with shape described by the first argument and
    fills it with value ones in type of the second argument.

    Inputs:
        - **shape** (Union[tuple[int], int]) - The specified shape of output tensor.
          Only constant positive int is allowed.
        - **type** (mindspore.dtype) - The specified type of output tensor. Only constant value is allowed.

    Outputs:
        Tensor, with the specified shape and type, filled with ones.

    Raises:
        TypeError: If `shape` is neither tuple nor int.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> ones = ops.Ones()
        >>> output = ones((2, 2), mindspore.float32)
        >>> print(output)
        [[1. 1.]
         [1. 1.]]
        >>> output = ones((3, 3), mindspore.float32)
        >>> print(output)
        [[1. 1. 1.]
         [1. 1. 1.]
         [1. 1. 1.]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize Ones"""


class Zeros(Primitive):
    r"""
    Creates a tensor filled with value zeros.

    Creates a tensor with shape described by the first argument and
    fills it with value zeros in type of the second argument.

    Inputs:
        - **shape** (Union[tuple[int], int]) - The specified shape of output tensor.
          Only constant positive int is allowed.
        - **type** (mindspore.dtype) - The specified type of output tensor. Only constant value is allowed.

    Outputs:
        Tensor, with the specified shape and type, filled with zeros.

    Raises:
        TypeError: If `shape` is neither int nor tuple.
        TypeError: If `shape` is a tuple whose elements are not all int.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> zeros = ops.Zeros()
        >>> output = zeros((2, 2), mindspore.float32)
        >>> print(output)
        [[0. 0.]
         [0. 0.]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize Zeros"""


class OnesLike(Primitive):
    """
    Creates a new tensor. The values of all elements are 1.

    Returns a tensor of ones with the same shape and type as the input.

    Inputs:
        - **input_x** (Tensor) - Input tensor.
          The shape is :math:`(N,*)` where :math:`*` means any number of additional dimensions.

    Outputs:
        Tensor, has the same shape and type as `input_x` but filled with ones.

    Raises:
        TypeError: If `input_x` is not a Tensor.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> oneslike = ops.OnesLike()
        >>> input_x = Tensor(np.array([[0, 1], [2, 1]]).astype(np.int32))
        >>> output = oneslike(input_x)
        >>> print(output)
        [[1 1]
         [1 1]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize OnesLike"""


class ZerosLike(Primitive):
    """
    Creates a new tensor. All elements value are 0.

    Returns a tensor of zeros with the same shape and data type as the input tensor.

    Inputs:
        - **input_x** (Tensor) - Input tensor. The data type is int32, int64, float16 or float32.
          The shape is :math:`(N,*)` where :math:`*` means any number of additional dimensions.

    Outputs:
        Tensor, has the same shape and data type as `input_x` but filled with zeros.

    Raises:
        TypeError: If `input_x` is not a Tensor.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> zeroslike = ops.ZerosLike()
        >>> input_x = Tensor(np.array([[0, 1], [2, 1]]).astype(np.float32))
        >>> output = zeroslike(input_x)
        >>> print(output)
        [[0. 0.]
         [0. 0.]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize ZerosLike"""
        self.init_prim_io_names(inputs=['x'], outputs=['y'])


class TupleToArray(PrimitiveWithInfer):
    """
    Converts a tuple to a tensor.

    If the type of the first number in the tuple is int, the data type of the output tensor is int32.
    Otherwise, the data type of the output tensor is float32.

    Inputs:
        - **input_x** (tuple) - A tuple of numbers. These numbers have the same type. Only constant value is allowed.
          The shape is :math:`(N,*)` where :math:`*` means any number of additional dimensions.

    Outputs:
        Tensor, if the input tuple contains `N` numbers, then the shape of the output tensor is (N,).

    Raises:
        TypeError: If `input_x` is not a tuple.
        ValueError: If length of `input_x` is less than or equal to 0.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_x = (1,2,3)
        >>> print(type(input_x))
        <class 'tuple'>
        >>> output = ops.TupleToArray()(input_x)
        >>> print(type(output))
        <class 'mindspore.common.tensor.Tensor'>
        >>> print(output)
        [1 2 3]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize TupleToArray"""

    def infer_value(self, x):
        validator.check_value_type("x", x, [tuple], self.name)
        validator.check("size of x", len(x), '', 0, Rel.GT, self.name)
        dtype = type(x[0])
        for i, item in enumerate(x):
            validator.check_value_type(f"x[{i}]", item, [numbers.Number], self.name)
        if not all(isinstance(item, dtype) for item in x):
            raise TypeError(f"For '{self.name}', all elements of 'input_x' must have the same type.")
        if isinstance(x[0], int):
            ret = np.array(x, np.int32)
        else:
            ret = np.array(x, np.float32)
        return Tensor(ret)

    def __call__(self, x):
        args = list()
        if isinstance(x, range):
            args.append(tuple(x))
        else:
            args.append(x)
        return _run_op(self, self.name, args)


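# A minimal reference sketch of the dtype rule used by TupleToArray above,
# for illustration only. `_tuple_to_array_reference` is a hypothetical helper,
# not part of the public API.
def _tuple_to_array_reference(values):
    """Mirror TupleToArray.infer_value: int tuples become int32, else float32."""
    np_type = np.int32 if isinstance(values[0], int) else np.float32
    return Tensor(np.array(values, np_type))

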
class ScalarToArray(PrimitiveWithInfer):
    """
    Converts a scalar to a `Tensor`.

    Inputs:
        - **input_x** (Union[int, float]) - The input is a scalar. Only constant value is allowed.

    Outputs:
        Tensor, a 0-D Tensor whose value is the input.

    Raises:
        TypeError: If `input_x` is neither int nor float.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> op = ops.ScalarToArray()
        >>> input_x = 1.0
        >>> print(type(input_x))
        <class 'float'>
        >>> output = op(input_x)
        >>> print(type(output))
        <class 'mindspore.common.tensor.Tensor'>
        >>> print(output)
        1.0
    """

    @prim_attr_register
    def __init__(self):
        pass

    def infer_value(self, x):
        validator.check_value_type("x", x, [int, float], self.name)
        if isinstance(x, int):
            ret = np.array(x, np.int32)
        else:
            ret = np.array(x, np.float32)
        return Tensor(ret)


class ScalarToTensor(PrimitiveWithInfer):
    """
    Converts a scalar to a `Tensor`, and converts the data type to the specified type.

    Inputs:
        - **input_x** (Union[int, float]) - The input is a scalar. Only constant value is allowed.
        - **dtype** (mindspore.dtype) - The target data type. Default: mindspore.float32. Only
          constant value is allowed.

    Outputs:
        Tensor, a 0-D Tensor whose value is the input and whose data type is `dtype`.

    Raises:
        TypeError: If `input_x` is neither int nor float.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> op = ops.ScalarToTensor()
        >>> data = 1
        >>> output = op(data, mindspore.float32)
        >>> print(output)
        1.0
    """

    @prim_attr_register
    def __init__(self):
        pass

    def infer_value(self, x, dtype=mstype.float32):
        validator.check_value_type("x", x, [int, float], self.name)
        validator.check_subclass("dtype", dtype, mstype.number, self.name)
        data_type = mstype.dtype_to_nptype(dtype)
        return Tensor(np.array(x, data_type))


class InvertPermutation(PrimitiveWithInfer):
    r"""
    Computes the inverse of an index permutation.

    This operator is mainly used to calculate the inverse of an index permutation.
    It requires a 1-dimensional integer sequence x, which represents the indices of a zero-based array,
    and exchanges each value with its index position. In other words, for output y and input x,
    this operation calculates the following values:

    :math:`y[x[i]] = i, \quad i \in [0, 1, \ldots, \text{len}(x)-1]`.

    Note:
        These values must include 0. There must be no duplicate values and the
        values can not be negative.

    Inputs:
        - **input_x** (Union[tuple[int], list[int]]) - The input is constructed by multiple
          integers, i.e., :math:`(y_1, y_2, ..., y_S)` representing the indices.
          The values must include 0. There can be no duplicate values or negative values.
          Only constant value is allowed. The maximum value must be equal to the length of
          `input_x` minus 1.

    Outputs:
        tuple[int]. It has the same length as the input.

    Raises:
        TypeError: If `input_x` is neither tuple nor list.
        TypeError: If an element of `input_x` is not an int.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> invert = ops.InvertPermutation()
        >>> input_data = (3, 4, 0, 2, 1)
        >>> output = invert(input_data)
        >>> print(output)
        (2, 4, 3, 0, 1)
    """

    @prim_attr_register
    def __init__(self):
        """Initialize InvertPermutation"""
        self.set_const_prim(True)

    def __infer__(self, x):
        x_shp = x['shape']
        x_value = x['value']
        if x_value is None:
            raise ValueError(f"For '{self.name}', the value of 'input_x' can not be None, but got {x_value}.")
        validator.check_value_type("shape", x_shp, [tuple, list], self.name)
        if mstype.issubclass_(x['dtype'], mstype.tensor):
            raise ValueError(f"For '{self.name}', the value of 'input_x' must be non-Tensor, but got {x['dtype']}")
        for shp in x_shp:
            if shp:
                x_rank = len(np.array(x_value, np.int64).shape)
                raise ValueError(f"For '{self.name}', the dimension of 'input_x' must be 1, but got {x_rank}.")
        for i, value in enumerate(x_value):
            validator.check_value_type("input[%d]" % i, value, [int], self.name)
        z = [x_value[i] for i in range(len(x_value))]
        z.sort()

        for i in range(1, len(z)):
            if z[i - 1] == z[i]:
                raise ValueError(f"For '{self.name}', the 'input_x' can not contain duplicate values, "
                                 f"but got duplicated {z[i]} in the 'input_x'.")
        validator.check('value min', min(x_value), '', 0, Rel.EQ, self.name)
        validator.check('value max', max(x_value), '', len(x_value) - 1, Rel.EQ, self.name)

        # The sorted, duplicate-free values must be exactly 0..len(x)-1;
        # invert by writing each index i to position x_value[i].
        y = [None] * len(x_value)
        for i, value in enumerate(x_value):
            validator.check('value', z[i], 'index', i, Rel.EQ, self.name)
            y[value] = i
        return {'shape': x_shp,
                'dtype': x['dtype'],
                'value': tuple(y)}


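# A minimal reference sketch of the inversion rule above in plain Python, for
# illustration only. `_invert_permutation_reference` is a hypothetical helper,
# not part of the public API.
def _invert_permutation_reference(perm):
    """Return the inverse permutation out, where out[perm[i]] == i."""
    out = [0] * len(perm)
    for i, p in enumerate(perm):
        out[p] = i
    return tuple(out)
# e.g. _invert_permutation_reference((3, 4, 0, 2, 1)) == (2, 4, 3, 0, 1)

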
class Argmax(PrimitiveWithInfer):
    """
    Returns the indices of the maximum value of a tensor across the axis.

    If the shape of input tensor is :math:`(x_1, ..., x_N)`, the shape of the output tensor will be
    :math:`(x_1, ..., x_{axis-1}, x_{axis+1}, ..., x_N)`.

    Args:
        axis (int): Axis where the Argmax operation applies to. Default: -1.
        output_type (:class:`mindspore.dtype`): An optional data type of the output.
            Only `mindspore.dtype.int32` is supported. Default: `mindspore.dtype.int32`.

    Inputs:
        - **input_x** (Tensor) - Input tensor.
          The shape is :math:`(N,*)` where :math:`*` means any number of additional dimensions.
          Support data type list as follows:

          - Ascend: Float16, Float32.
          - GPU: Float16, Float32.
          - CPU: Float16, Float32, Float64.

    Outputs:
        Tensor, indices of the max value of input tensor across the axis.

    Raises:
        TypeError: If `axis` is not an int.
        TypeError: If `output_type` is not int32.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_x = Tensor(np.array([[1, 20, 5], [67, 8, 9], [130, 24, 15]]).astype(np.float32))
        >>> output = ops.Argmax(output_type=mindspore.int32)(input_x)
        >>> print(output)
        [1 0 0]
    """

    @prim_attr_register
    def __init__(self, axis=-1, output_type=mstype.int32):
        """Initialize Argmax"""
        self.init_prim_io_names(inputs=['x'], outputs=['output'])
        validator.check_value_type("axis", axis, [int], self.name)
        validator.check_types_same_and_valid({'output': output_type}, [mstype.int32], self.name)
        self.axis = axis
        self.add_prim_attr('output_type', output_type)

    def infer_shape(self, x_shape):
        axis = self.axis
        if axis is None:
            axis = 0
        x_rank = len(x_shape)
        validator.check_int_range(axis, -x_rank, x_rank, Rel.INC_LEFT, "axis", self.name)
        axis = axis + x_rank if axis < 0 else axis
        # The reduced axis is removed from the output shape.
        output_shape = [x_shape[i] for i in range(x_rank) if i != axis]
        return output_shape

    def infer_dtype(self, x_dtype):
        validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name)
        return mstype.tensor_type(self.output_type)


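# A minimal NumPy sketch of the shape behavior above, for illustration only.
# `_argmax_reference` is a hypothetical helper, not part of the public API.
def _argmax_reference(x, axis=-1):
    """np.argmax drops the reduced axis, matching Argmax's output shape."""
    return np.argmax(x, axis=axis).astype(np.int32)

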
class Argmin(PrimitiveWithInfer):
    """
    Returns the indices of the minimum value of a tensor across the axis.

    If the shape of input tensor is :math:`(x_1, ..., x_N)`, the shape of the output tensor is
    :math:`(x_1, ..., x_{axis-1}, x_{axis+1}, ..., x_N)`.

    Args:
        axis (int): Axis where the Argmin operation applies to. Default: -1.
        output_type (:class:`mindspore.dtype`): An optional data type of the output, either
            `mindspore.dtype.int32` or `mindspore.dtype.int64`. Default: `mindspore.dtype.int32`.

    Inputs:
        - **input_x** (Tensor) - Input tensor.
          The shape is :math:`(N,*)` where :math:`*` means any number of additional dimensions.

    Outputs:
        Tensor, indices of the min value of input tensor across the axis.

    Raises:
        TypeError: If `axis` is not an int.
        TypeError: If `output_type` is neither int32 nor int64.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> input_x = Tensor(np.array([2.0, 3.1, 1.2]), mindspore.float32)
        >>> index = ops.Argmin()(input_x)
        >>> print(index)
        2
    """

    @prim_attr_register
    def __init__(self, axis=-1, output_type=mstype.int32):
        """Initialize Argmin"""
        self.init_prim_io_names(inputs=['x'], outputs=['output'])
        validator.check_value_type("axis", axis, [int], self.name)
        validator.check_type_name("output_type", output_type, [mstype.int32, mstype.int64], self.name)
        self.axis = axis
        self.add_prim_attr('output_type', output_type)

    def infer_shape(self, x_shape):
        axis = self.axis
        if axis is None:
            axis = 0
        x_rank = len(x_shape)
        validator.check_int_range(axis, -x_rank, x_rank, Rel.INC_LEFT, "axis", self.name)
        axis = axis + x_rank if axis < 0 else axis
        output_shape = [x_shape[i] for i in range(x_rank) if i != axis]
        return output_shape

    def infer_dtype(self, x_dtype):
        validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name)
        return mstype.tensor_type(self.output_type)


class ArgMaxWithValue(PrimitiveWithInfer):
    """
    Calculates the maximum value with the corresponding index.

    Calculates the maximum value along with the given axis for the input tensor. It returns the maximum values and
    indices.

    Note:
        In auto_parallel and semi_auto_parallel mode, the first output index can not be used.

    .. warning::
        - If there are multiple maximum values, the index of the first maximum value is used.
        - The value range of "axis" is [-dims, dims - 1]. "dims" is the dimension length of "input_x".

    Args:
        axis (int): The dimension to reduce. Default: 0.
        keep_dims (bool): Whether to reduce dimension, if true, the output will keep same dimension with the input,
                          the output will reduce dimension if false. Default: False.

    Inputs:
        - **input_x** (Tensor) - The input tensor, can be any dimension. Set the shape of input tensor as
          :math:`(x_1, x_2, ..., x_N)`. The data type only supports mindspore.float16 and float32.

    Outputs:
        tuple (Tensor), tuple of 2 tensors, containing the corresponding index and the maximum value of the input
        tensor.

        - index (Tensor) - The index for the maximum value of the input tensor. If `keep_dims` is true, the shape of
          output tensors is :math:`(x_1, x_2, ..., x_{axis-1}, 1, x_{axis+1}, ..., x_N)`. Otherwise, the shape is
          :math:`(x_1, x_2, ..., x_{axis-1}, x_{axis+1}, ..., x_N)`.
        - output_x (Tensor) - The maximum value of input tensor, with the same shape as index.

    Raises:
        TypeError: If `keep_dims` is not a bool.
        TypeError: If `axis` is not an int.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_x = Tensor(np.array([0.0, 0.4, 0.6, 0.7, 0.1]), mindspore.float32)
        >>> index, output = ops.ArgMaxWithValue()(input_x)
        >>> print(index, output)
        3 0.7
        >>> index, output = ops.ArgMaxWithValue(keep_dims=True)(input_x)
        >>> print(index, output)
        [3] [0.7]
    """

    @prim_attr_register
    def __init__(self, axis=0, keep_dims=False):
        """Initialize ArgMaxWithValue"""
        self.axis = axis
        self.keep_dims = keep_dims
        validator.check_value_type('keep_dims', keep_dims, [bool], self.name)
        validator.check_value_type('axis', axis, [int], self.name)

    def infer_shape(self, x_shape):
        axis = self.axis
        x_rank = len(x_shape)
        validator.check_int_range(axis, -x_rank, x_rank, Rel.INC_LEFT, "axis", self.name)
        output_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.name)
        return output_shape, output_shape

    def infer_dtype(self, x_dtype):
        validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name)
        return mstype.tensor_type(mstype.int32), x_dtype


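# A minimal NumPy sketch of the (index, value) pair returned above, for
# illustration only. `_arg_max_with_value_reference` is a hypothetical helper,
# not part of the public API.
def _arg_max_with_value_reference(x, axis=0, keep_dims=False):
    """Return (index, value) along `axis`, optionally keeping the reduced axis."""
    index = np.argmax(x, axis=axis)
    value = np.max(x, axis=axis)
    if keep_dims:
        index = np.expand_dims(index, axis)
        value = np.expand_dims(value, axis)
    return index.astype(np.int32), value

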
class ArgMinWithValue(PrimitiveWithInfer):
    """
    Calculates the minimum value with corresponding index, and returns indices and values.

    Calculates the minimum value along with the given axis for the input tensor. It returns the minimum values and
    indices.

    Note:
        In auto_parallel and semi_auto_parallel mode, the first output index can not be used.

    .. warning::
        - If there are multiple minimum values, the index of the first minimum value is used.
        - The value range of "axis" is [-dims, dims - 1]. "dims" is the dimension length of "input_x".

    Args:
        axis (int): The dimension to reduce. Default: 0.
        keep_dims (bool): Whether to reduce dimension, if true the output will keep the same dimension as the input,
                          the output will reduce dimension if false. Default: False.

    Inputs:
        - **input_x** (Tensor) - The input tensor, can be any dimension. Set the shape of input tensor as
          :math:`(x_1, x_2, ..., x_N)`.

    Outputs:
        tuple (Tensor), tuple of 2 tensors, containing the corresponding index and the minimum value of the input
        tensor.

        - index (Tensor) - The index for the minimum value of the input tensor. If `keep_dims` is true, the shape of
          output tensors is :math:`(x_1, x_2, ..., x_{axis-1}, 1, x_{axis+1}, ..., x_N)`. Otherwise, the shape is
          :math:`(x_1, x_2, ..., x_{axis-1}, x_{axis+1}, ..., x_N)`.
        - output_x (Tensor) - The minimum value of input tensor, with the same shape as index.

    Raises:
        TypeError: If `keep_dims` is not a bool.
        TypeError: If `axis` is not an int.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_x = Tensor(np.array([0.0, 0.4, 0.6, 0.7, 0.1]), mindspore.float32)
        >>> output = ops.ArgMinWithValue()(input_x)
        >>> print(output)
        (Tensor(shape=[], dtype=Int32, value= 0), Tensor(shape=[], dtype=Float32, value= 0))
        >>> output = ops.ArgMinWithValue(keep_dims=True)(input_x)
        >>> print(output)
        (Tensor(shape=[1], dtype=Int32, value= [0]), Tensor(shape=[1], dtype=Float32, value= [ 0.00000000e+00]))
    """

    @prim_attr_register
    def __init__(self, axis=0, keep_dims=False):
        """Initialize ArgMinWithValue"""
        self.axis = axis
        self.keep_dims = keep_dims
        validator.check_value_type('keep_dims', keep_dims, [bool], self.name)
        validator.check_value_type('axis', axis, [int], self.name)

    def infer_shape(self, x_shape):
        axis = self.axis
        x_rank = len(x_shape)
        validator.check_int_range(axis, -x_rank, x_rank, Rel.INC_LEFT, "axis", self.name)
        output_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.name)
        return output_shape, output_shape

    def infer_dtype(self, x_dtype):
        validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name)
        return mstype.tensor_type(mstype.int32), x_dtype


class Tile(PrimitiveWithInfer):
    r"""
    Replicates a tensor with given multiples times.

    Creates a new tensor by replicating `input_x` `multiples` times. The i'th dimension of
    output tensor has `input_x.shape[i] * multiples[i]` elements, and the values of `input_x`
    are replicated `multiples[i]` times along the i'th dimension.

    Note:
        The length of `multiples` must be greater than or equal to the length of dimension in `input_x`.

    Inputs:
        - **input_x** (Tensor) - 1-D or higher Tensor. Set the shape of input tensor as
          :math:`(x_1, x_2, ..., x_S)`.

        - **multiples** (tuple[int]) - The input tuple is constructed by multiple
          integers, i.e., :math:`(y_1, y_2, ..., y_S)`. The length of `multiples`
          cannot be smaller than the length of the shape of `input_x`.
          Only constant value is allowed.

    Outputs:
        Tensor, has the same data type as the `input_x`.

        - If the length of `multiples` is the same as the length of shape of `input_x`,
          then the shape of their corresponding positions can be multiplied, and
          the shape of Outputs is :math:`(x_1*y_1, x_2*y_2, ..., x_S*y_S)`.
        - If the length of `multiples` is larger than the length of shape of `input_x`,
          fill in multiple 1 in the length of the shape of `input_x` until their lengths are consistent.
          Such as set the shape of `input_x` as :math:`(1, ..., x_1, x_2, ..., x_S)`,
          then the shape of their corresponding positions can be multiplied, and
          the shape of Outputs is :math:`(1*y_1, ..., x_S*y_R)`.

    Raises:
        TypeError: If `multiples` is not a tuple or its elements are not all int.
        ValueError: If the elements of `multiples` are not all greater than 0.
        ValueError: If the length of `multiples` is smaller than the length of dimension in `input_x`.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> tile = ops.Tile()
        >>> input_x = Tensor(np.array([[1, 2], [3, 4]]), mindspore.float32)
        >>> multiples = (2, 3)
        >>> output = tile(input_x, multiples)
        >>> print(output)
        [[1.  2.  1.  2.  1.  2.]
         [3.  4.  3.  4.  3.  4.]
         [1.  2.  1.  2.  1.  2.]
         [3.  4.  3.  4.  3.  4.]]
        >>> multiples = (2, 3, 2)
        >>> output = tile(input_x, multiples)
        >>> print(output)
        [[[1. 2. 1. 2.]
          [3. 4. 3. 4.]
          [1. 2. 1. 2.]
          [3. 4. 3. 4.]
          [1. 2. 1. 2.]
          [3. 4. 3. 4.]]
         [[1. 2. 1. 2.]
          [3. 4. 3. 4.]
          [1. 2. 1. 2.]
          [3. 4. 3. 4.]
          [1. 2. 1. 2.]
          [3. 4. 3. 4.]]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize Tile"""
        self.init_prim_io_names(inputs=['x', 'multiples'], outputs=['output'])

    def check_elim(self, base_tensor, multiplier):
        if (not isinstance(base_tensor, Tensor)) or (not isinstance(multiplier, tuple)):
            raise TypeError(f"For '{self.name}', the type of ('input_x', 'multiples') should be (Tensor, tuple), "
                            f"but got ({type(base_tensor).__name__}, {type(multiplier).__name__}).")
        if all(v == 1 for v in multiplier):
            return (True, base_tensor)
        return (False, None)

    def __infer__(self, x, multiples):
        multiples_v = multiples['value']
        x_shp = x['shape']
        validator.check_value_type("multiples", multiples_v, [tuple], self.name)
        for i, multiple in enumerate(multiples_v):
            validator.check_positive_int(multiple, "multiples[%d]" % i, self.name)
        validator.check_value_type("x['dtype']", x["dtype"], mstype.tensor_type, self.name)
        len_sub = len(multiples_v) - len(x_shp)
        multiples_w = None
        if len_sub == 0:
            multiples_w = multiples_v
        if len_sub > 0:
            # Left-pad the input shape with 1s so it matches len(multiples).
            for i in range(0, len_sub):
                x_shp.insert(0, 1)
            multiples_w = multiples_v
        elif len_sub < 0:
            raise ValueError(f"For '{self.name}', the length of 'multiples' can not be smaller than "
                             f"the dimension of 'input_x', but got length of 'multiples': {len(multiples_v)} "
                             f"and dimension of 'input_x': {len(x_shp)}.")
        for i, a in enumerate(multiples_w):
            x_shp[i] *= a
        value = None
        if x['value'] is not None:
            value = Tensor(np.tile(x['value'].asnumpy(), multiples_w))
        return {'shape': x_shp,
                'dtype': x['dtype'],
                'value': value}


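# A minimal sketch of the output-shape rule implemented in Tile.__infer__,
# for illustration only. `_tile_out_shape_reference` is a hypothetical helper,
# not part of the public API.
def _tile_out_shape_reference(x_shape, multiples):
    """Left-pad x_shape with 1s, then multiply element-wise (np.tile semantics)."""
    padded = [1] * (len(multiples) - len(x_shape)) + list(x_shape)
    return [d * m for d, m in zip(padded, multiples)]
# e.g. _tile_out_shape_reference([2, 2], (2, 3, 2)) == [2, 6, 4]

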
class UnsortedSegmentSum(PrimitiveWithInfer):
    r"""
    Computes the sum of a tensor along segments.

    Calculates a tensor such that :math:`\text{output}[i] = \sum_{segment\_ids[j] == i} \text{data}[j, \ldots]`, where
    :math:`j` is a tuple describing the index of element in data. `segment_ids` selects which elements in data to sum
    up. Segment_ids does not need to be sorted, and it does not need to cover all values in the entire valid value
    range.

    The following figure shows the calculation process of UnsortedSegmentSum:

    .. image:: api_img/UnsortedSegmentSum.png

    Note:
        If the segment_id i is absent in the segment_ids, then output[i] will be filled with 0.

    If the segment for a given segment_id :math:`i` is empty, then :math:`\text{output}[i] = 0`. If a given
    segment_id is negative, its value will be ignored. `num_segments` must be greater than the largest value
    in `segment_ids`.

    Inputs:
        - **input_x** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_R)`.
        - **segment_ids** (Tensor) - Set the shape as :math:`(x_1, x_2, ..., x_N)`, where 0 < N <= R.
        - **num_segments** (int) - Set :math:`z` as num_segments.

    Outputs:
        Tensor, the shape is :math:`(z, x_{N+1}, ..., x_R)`.

    Raises:
        TypeError: If `num_segments` is not an int.
        ValueError: If length of shape of `segment_ids` is less than 1.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_x = Tensor([1, 2, 3, 4], mindspore.float32)
        >>> segment_ids = Tensor([0, 0, 1, 2], mindspore.int32)
        >>> num_segments = 4
        >>> output = ops.UnsortedSegmentSum()(input_x, segment_ids, num_segments)
        >>> print(output)
        [3. 3. 4. 0.]
        >>> input_x = Tensor([1, 2, 3, 4, 2, 5], mindspore.float32)
        >>> segment_ids = Tensor([0, 0, 1, 2, 3, 4], mindspore.int32)
        >>> num_segments = 6
        >>> output = ops.UnsortedSegmentSum()(input_x, segment_ids, num_segments)
        >>> print(output)
        [3. 3. 4. 2. 5. 0.]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize UnsortedSegmentSum"""
        self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y'])

    def __infer__(self, x, segment_ids, num_segments):
        x_type = x['dtype']
        x_shp = x['shape']
        validator.check_subclass("input_x", x_type, mstype.tensor, self.name)
        validator.check_value_type("x_shape", x_shp, [list], self.name)
        x_shp_len = len(x_shp)
        validator.check_positive_int(x_shp_len, "rank of input_x", self.name)
        segment_ids_shp = segment_ids['shape']
        segment_ids_type = segment_ids['dtype']
        validator.check_subclass("segment_ids", segment_ids_type, mstype.tensor, self.name)
        validator.check_value_type("segment_ids", segment_ids_shp, [list], self.name)
        segment_ids_shp_len = len(segment_ids_shp)
        validator.check_positive_int(segment_ids_shp_len, "rank of segment_ids", self.name)
        validator.check('rank of input_x', len(x_shp),
                        'rank of segments_id', len(segment_ids_shp), Rel.GE, self.name)
        if -1 not in x_shp and -1 not in segment_ids_shp:
            # only validate when both shapes are fully known
            for i, value in enumerate(segment_ids_shp):
                validator.check("ids[%d]" % i, value, 'input[%d]' % i, x_shp[i], Rel.EQ, self.name)
        num_segments_v = num_segments['value']
        num_segments_type = num_segments['dtype']
        validator.check_subclass("num_segments", num_segments_type, [mstype.tensor, mstype.number], self.name)
        if isinstance(num_segments_type, type(mstype.tensor)):
            validator.check_tensor_dtype_valid("num_segments", num_segments_type, [mstype.int32, mstype.int64],
                                               self.name)
            shp = [-1]
        else:
            validator.check_value_type('num_segments', num_segments_v, [int], self.name)
            validator.check_positive_int(num_segments_v, "num_segments", self.name)
            shp = [num_segments_v]

        shp += x_shp[segment_ids_shp_len:]
        if "max_value" in num_segments and "min_value" in num_segments:
            output_max_shape = list(num_segments['max_value'])
            output_min_shape = list(num_segments['min_value'])
        else:
            if isinstance(num_segments_type, type(mstype.tensor)):
                raise ValueError(f"For '{self.name}', the dtype of 'num_segments' only support int type "
                                 f"when it is not a dynamic value, but got type of 'num_segments': "
                                 f"{num_segments_type}.")
            output_max_shape = [num_segments_v]
            output_min_shape = [num_segments_v]
        if 'max_shape' in x and 'min_shape' in x:
            max_output_incoming = x['max_shape']
            min_output_incoming = x['min_shape']
        else:
            max_output_incoming = x_shp
            min_output_incoming = x_shp
        output_max_shape += max_output_incoming[segment_ids_shp_len:]
        output_min_shape += min_output_incoming[segment_ids_shp_len:]
        return {'shape': shp,
                'max_shape': output_max_shape,
                'min_shape': output_min_shape,
                'dtype': mstype.tensor_type(x_type.element_type()),
                'value': None}


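# A minimal NumPy sketch of the segment-sum semantics above, for illustration
# only. `_unsorted_segment_sum_reference` is a hypothetical helper, not part
# of the public API.
def _unsorted_segment_sum_reference(data, segment_ids, num_segments):
    """Sum data slices into out[sid]; negative ids are ignored, empty segments stay 0."""
    out = np.zeros((num_segments,) + data.shape[segment_ids.ndim:], data.dtype)
    for idx in np.ndindex(segment_ids.shape):
        sid = int(segment_ids[idx])
        if 0 <= sid < num_segments:
            out[sid] += data[idx]
    return out

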
class UnsortedSegmentMin(PrimitiveWithCheck):
    r"""
    Computes the minimum of a tensor along segments.

    The following figure shows the calculation process of UnsortedSegmentMin:

    .. image:: api_img/UnsortedSegmentMin.png

    .. math::

        \text{output}_i = \min_{j \ldots} \text{data}[j \ldots]

    where :math:`\min` is over tuples :math:`j...` such that :math:`segment\_ids[j...] == i`.

    Note:
        If the segment_id i is absent in the segment_ids, then output[i] will be filled with
        the maximum value of the input_x's type.
        The `segment_ids` must be a non-negative tensor.

    Inputs:
        - **input_x** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_R)`.
          The data type must be float16, float32 or int32.
        - **segment_ids** (Tensor) - A `1-D` tensor whose shape is :math:`(x_1)`; the values must be non-negative.
          The data type must be int32.
        - **num_segments** (int) - The value specifies the number of distinct `segment_ids`.

    Outputs:
        Tensor, set the number of `num_segments` as `N`, the shape is :math:`(N, x_2, ..., x_R)`.

    Raises:
        TypeError: If `num_segments` is not an int.
        ValueError: If length of shape of `segment_ids` is not equal to 1.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> input_x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [4, 2, 1]]).astype(np.float32))
        >>> segment_ids = Tensor(np.array([0, 1, 1]).astype(np.int32))
        >>> num_segments = 2
        >>> unsorted_segment_min = ops.UnsortedSegmentMin()
        >>> output = unsorted_segment_min(input_x, segment_ids, num_segments)
        >>> print(output)
        [[1. 2. 3.]
         [4. 2. 1.]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize UnsortedSegmentMin"""
        self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y'])

    def __check__(self, x, segment_ids, num_segments):
        x_shape = x['shape']
        segment_ids_shape = segment_ids['shape']
        valid_type = [mstype.float16, mstype.float32, mstype.int32]
        validator.check_tensor_dtype_valid("x", x['dtype'], valid_type, self.name)
        validator.check_tensor_dtype_valid("segment_ids", segment_ids['dtype'], [mstype.int32], self.name)
        validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name)
        num_segments_type = num_segments['dtype']
        validator.check_subclass("num_segments", num_segments_type, [mstype.number], self.name)
        if -1 not in x_shape and -1 not in segment_ids_shape:
            # only validate when both shapes are fully known
            validator.check('first shape of input_x', x_shape[0],
                            'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name)
        num_segments_v = num_segments['value']
        validator.check_value_type('num_segments', num_segments_v, [int], self.name)
        validator.check_positive_int(num_segments_v, "num_segments", self.name)


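# A minimal NumPy sketch of the fill-value behavior noted above (empty
# segments are filled with the dtype's maximum), for illustration only.
# `_unsorted_segment_min_reference` is a hypothetical helper, not part of
# the public API.
def _unsorted_segment_min_reference(data, segment_ids, num_segments):
    fill = np.finfo(data.dtype).max if np.issubdtype(data.dtype, np.floating) else np.iinfo(data.dtype).max
    out = np.full((num_segments,) + data.shape[1:], fill, data.dtype)
    for i, sid in enumerate(segment_ids):
        out[sid] = np.minimum(out[sid], data[i])
    return out

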
class UnsortedSegmentMax(PrimitiveWithCheck):
    r"""
    Computes the maximum along segments of a tensor.

    The following figure shows the calculation process of UnsortedSegmentMax:

    .. image:: api_img/UnsortedSegmentMax.png

    .. math::

        \text{output}_i = \max_{j \ldots} \text{data}[j \ldots]

    where :math:`\max` is over tuples :math:`j...` such that :math:`segment\_ids[j...] == i`.

    Note:
        If the segment_id i is absent in the segment_ids, then output[i] will be filled with
        the minimum value of the input_x's type.

    Inputs:
        - **input_x** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_R)`.
          The data type must be float16, float32 or int32.
        - **segment_ids** (Tensor) - A `1-D` tensor whose shape is :math:`(x_1)`; the values must be non-negative.
          The data type must be int32.
        - **num_segments** (int) - The value specifies the number of distinct `segment_ids`.

    Outputs:
        Tensor, set the number of `num_segments` as `N`, the shape is :math:`(N, x_2, ..., x_R)`.

    Raises:
        TypeError: If `num_segments` is not an int.
        ValueError: If length of shape of `segment_ids` is not equal to 1.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> # case 1: There are two kinds of segment_id, 0 and 1, and segment_ids=[0, 1, 1].
        >>> # num_segments = 2 indicates that there are two types of segment_id,
        >>> # the first number '0' in [0, 1, 1] indicates input_x[0],
        >>> # the second number '1' in [0, 1, 1] indicates input_x[1],
        >>> # the third number '1' in [0, 1, 1] indicates input_x[2],
        >>> # input_x[0], which is [1, 2, 3], will not be compared to other segment_ids.
        >>> # Only values with the same segment_id are compared.
        >>> input_x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [4, 2, 1]]).astype(np.float32))
        >>> segment_ids = Tensor(np.array([0, 1, 1]).astype(np.int32))
        >>> num_segments = 2
        >>> unsorted_segment_max = ops.UnsortedSegmentMax()
        >>> output = unsorted_segment_max(input_x, segment_ids, num_segments)
        >>> print(output)
        [[1. 2. 3.]
         [4. 5. 6.]]
        >>>
        >>> # case 2: The segment_ids=[0, 0, 1, 1].
        >>> # [1, 2, 3] will compare with [4, 2, 0],
        >>> # and [4, 5, 6] will compare with [4, 2, 1].
        >>> input_x = Tensor(np.array([[1, 2, 3], [4, 2, 0], [4, 5, 6], [4, 2, 1]]).astype(np.float32))
        >>> segment_ids = Tensor(np.array([0, 0, 1, 1]).astype(np.int32))
        >>> num_segments = 2
        >>> unsorted_segment_max = ops.UnsortedSegmentMax()
        >>> output = unsorted_segment_max(input_x, segment_ids, num_segments)
        >>> print(input_x.shape)
        (4, 3)
        >>> print(output)
        [[4. 2. 3.]
         [4. 5. 6.]]
        >>> # case 3: If input_x has three or more dimensions, what will happen?
        >>> # The shape of input_x is (2, 4, 3),
        >>> # and the length of segment_ids should be the same as the first dimension of input_x.
        >>> # Because the segment_ids are different, input_x[0] will not be compared to input_x[1].
        >>> input_x = Tensor(np.array([[[1, 2, 3], [4, 2, 0], [4, 5, 6], [4, 2, 1]],
        ...                            [[1, 2, 3], [4, 2, 0], [4, 5, 6], [4, 2, 1]]]).astype(np.float32))
        >>> segment_ids = Tensor(np.array([0, 1]).astype(np.int32))
        >>> num_segments = 2
        >>> unsorted_segment_max = ops.UnsortedSegmentMax()
        >>> output = unsorted_segment_max(input_x, segment_ids, num_segments)
        >>> print(input_x.shape)
        (2, 4, 3)
        >>> print(output)
        [[[1. 2. 3.]
          [4. 2. 0.]
          [4. 5. 6.]
          [4. 2. 1.]]
         [[1. 2. 3.]
          [4. 2. 0.]
          [4. 5. 6.]
          [4. 2. 1.]]]
        >>> # case 4: It has the same input as the 3rd case.
        >>> # Because num_segments is equal to 2, there are two segment_ids, but currently only one 0 is used.
        >>> # The segment_id 1 is absent in the segment_ids, so output[1] will be filled with
        >>> # the smallest possible value of the input_x's type.
        >>> segment_ids = Tensor(np.array([0, 0]).astype(np.int32))
        >>> output = unsorted_segment_max(input_x, segment_ids, num_segments)
        >>> print(output)
        [[[ 1.0000000e+00  2.0000000e+00  3.0000000e+00]
          [ 4.0000000e+00  2.0000000e+00  0.0000000e+00]
          [ 4.0000000e+00  5.0000000e+00  6.0000000e+00]
          [ 4.0000000e+00  2.0000000e+00  1.0000000e+00]]
         [[-3.4028235e+38 -3.4028235e+38 -3.4028235e+38]
          [-3.4028235e+38 -3.4028235e+38 -3.4028235e+38]
          [-3.4028235e+38 -3.4028235e+38 -3.4028235e+38]
          [-3.4028235e+38 -3.4028235e+38 -3.4028235e+38]]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize UnsortedSegmentMax"""
        self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y'])

    def __check__(self, x, segment_ids, num_segments):
        x_shape = x['shape']
        segment_ids_shape = segment_ids['shape']
        valid_type = [mstype.float16, mstype.float32, mstype.int32]
        validator.check_tensor_dtype_valid("x", x['dtype'], valid_type, self.name)
        validator.check_tensors_dtypes_same_and_valid({"segment_ids": segment_ids['dtype']},
                                                      [mstype.int32, mstype.int64], self.name)
        validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name)
        num_segments_type = num_segments['dtype']
        validator.check_subclass("num_segments", num_segments_type, [mstype.number], self.name)
        if -1 not in x_shape and -1 not in segment_ids_shape:
            # only validate when both shapes are fully known
            validator.check('first shape of input_x', x_shape[0],
                            'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name)
        num_segments_v = num_segments['value']
        validator.check_value_type('num_segments', num_segments_v, [int], self.name)
        validator.check_positive_int(num_segments_v, "num_segments", self.name)


class UnsortedSegmentProd(PrimitiveWithInfer):
    """
    Computes the product of a tensor along segments.

    The following figure shows the calculation process of UnsortedSegmentProd:

    .. image:: api_img/UnsortedSegmentProd.png

    Inputs:
        - **input_x** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_R)`.
          With float16, float32 or int32 data type.
        - **segment_ids** (Tensor) - A `1-D` tensor whose shape is :math:`(x_1)`; the values must be non-negative.
          Data type must be int32.
        - **num_segments** (int) - The value specifies the number of distinct `segment_ids`,
          must be greater than 0.

    Outputs:
        Tensor, set the number of `num_segments` as `N`, the shape is :math:`(N, x_2, ..., x_R)`.

    Raises:
        TypeError: If `num_segments` is not an int.
        ValueError: If length of shape of `segment_ids` is not equal to 1.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> input_x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [4, 2, 1]]).astype(np.float32))
        >>> segment_ids = Tensor(np.array([0, 1, 0]).astype(np.int32))
        >>> num_segments = 2
        >>> unsorted_segment_prod = ops.UnsortedSegmentProd()
        >>> output = unsorted_segment_prod(input_x, segment_ids, num_segments)
        >>> print(output)
        [[4. 4. 3.]
         [4. 5. 6.]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize UnsortedSegmentProd"""
        self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y'])

    def __infer__(self, x, segment_ids, num_segments):
        x_type = x['dtype']
        x_shape = x['shape']
        segment_ids_shape = segment_ids['shape']
        validator.check_subclass("input_x", x_type, mstype.tensor, self.name)
        validator.check_value_type("x_shape", x_shape, [list], self.name)
        valid_type = [mstype.float16, mstype.float32, mstype.int32]
        validator.check_tensor_dtype_valid("x", x['dtype'], valid_type, self.name)
        validator.check_tensor_dtype_valid("segment_ids", segment_ids['dtype'], [mstype.int32], self.name)
        validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name)
        validator.check('first shape of input_x', x_shape[0],
                        'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name)
        num_segments_v = num_segments['value']
        validator.check_value_type('num_segments', num_segments_v, [int], self.name)
        validator.check_positive_int(num_segments_v, "num_segments", self.name)
        segment_ids_shape_len = len(segment_ids_shape)
        out_shape = [num_segments_v]
        out_shape += x_shape[segment_ids_shape_len:]
        out = {'shape': out_shape,
               'dtype': mstype.tensor_type(x_type.element_type()),
               'value': None}
        return out


class Concat(PrimitiveWithInfer):
    r"""
    Connects tensors along the specified axis.

    Connects input tensors along the given axis.

    The input data is a tuple of tensors. These tensors have the same rank `R`. Set the given axis as `m`, and
    :math:`0 \le m < R`. Set the number of input tensors as `N`. For the :math:`i`-th tensor :math:`t_i`, it has
    the shape of :math:`(x_1, x_2, ..., x_{mi}, ..., x_R)`. :math:`x_{mi}` is the :math:`m`-th dimension of the
    :math:`i`-th tensor. Then, the shape of the output tensor is

    .. math::

        (x_1, x_2, ..., \sum_{i=1}^Nx_{mi}, ..., x_R)

    .. warning::
        The value range of "axis" is [-dims, dims - 1]. "dims" is the dimension length of "input_x".

    Args:
        axis (int): The specified axis. Default: 0.

    Inputs:
        - **input_x** (tuple, list) - A tuple or a list of input tensors.
          Suppose there are two tensors in this tuple or list, namely x1 and x2.
          To perform `Concat` in the axis 0 direction, except for the 0th axis, all other axes should be equal,
          that is, :math:`x1.shape[1] == x2.shape[1], x1.shape[2] == x2.shape[2], ..., x1.shape[R] == x2.shape[R]`,
          where :math:`R` indicates the last axis.

    Outputs:
        - Tensor, the shape is :math:`(x_1, x_2, ..., \sum_{i=1}^Nx_{mi}, ..., x_R)`.
          The data type is the same as `input_x`.

    Raises:
        TypeError: If `axis` is not an int.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_x1 = Tensor(np.array([[0, 1], [2, 1]]).astype(np.float32))
        >>> input_x2 = Tensor(np.array([[0, 1], [2, 1]]).astype(np.float32))
        >>> op = ops.Concat()
        >>> output = op((input_x1, input_x2))
        >>> print(output)
        [[0. 1.]
         [2. 1.]
         [0. 1.]
         [2. 1.]]
        >>> op = ops.Concat(1)
        >>> output = op((input_x1, input_x2))
        >>> print(output)
        [[0. 1. 0. 1.]
         [2. 1. 2. 1.]]
    """

    @prim_attr_register
    def __init__(self, axis=0):
        """Initialize Concat"""
        validator.check_value_type("axis", axis, [int], self.name)

    def __infer__(self, input_x):
        axis = self.axis
        x_shp = input_x['shape']
        x_type = input_x['dtype']
        _, all_shp, _ = get_concat_offset(x_shp, x_type, axis, self.name)
        self.add_prim_attr('inputNums', len(x_shp))
        ret_shp = x_shp[0].copy()
        value = None
        if input_x['value'] is not None:
            value = Tensor(np.concatenate([x.asnumpy() for x in input_x['value']], axis=axis))
        ret_shp[axis] = all_shp
        out = {'shape': ret_shp,
               'dtype': x_type[0],
               'value': value}
        if -1 in x_shp[0]:
            # For dynamic shapes, the concat axis of the min/max shape is the
            # sum of the inputs' min/max sizes along that axis.
            x_min_shp = input_x['min_shape']
            ret_min_shp = x_min_shp[0].copy()
            ret_min_shp[axis] = 0
            for all_min_shp in x_min_shp:
                ret_min_shp[axis] += all_min_shp[axis]
            out['min_shape'] = ret_min_shp
            x_max_shp = input_x['max_shape']
            ret_max_shp = x_max_shp[0].copy()
            ret_max_shp[axis] = 0
            for all_max_shp in x_max_shp:
                ret_max_shp[axis] += all_max_shp[axis]
            out['max_shape'] = ret_max_shp
        return out


class ParallelConcat(PrimitiveWithInfer):
    r"""
    Concatenates tensors in the first dimension.

    Concatenates input tensors along the first dimension.

    The difference between Concat and ParallelConcat is that Concat requires all of the inputs to be computed
    before the operation begins, but doesn't require the input shapes to be known during graph construction.
    ParallelConcat copies pieces of the input into the output as they become available; in some situations
    this can provide a performance benefit.

    Note:
        The input tensors are all required to have size 1 in the first dimension.

    Inputs:
        - **values** (tuple, list) - A tuple or a list of input tensors. The data type and shape of these
          tensors must be the same. The data type is Number except float64.

    Outputs:
        Tensor, data type is the same as `values`.

    Raises:
        ValueError: If length of shape of `values` is less than 1.
        ValueError: If the data types and shapes of the input tensors are not all the same.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> data1 = Tensor(np.array([[0, 1]]).astype(np.int32))
        >>> data2 = Tensor(np.array([[2, 1]]).astype(np.int32))
        >>> op = ops.ParallelConcat()
        >>> output = op((data1, data2))
        >>> print(output)
        [[0 1]
         [2 1]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize ParallelConcat"""

    def __infer__(self, values):
        x_shp = values['shape']
        x_type = values['dtype']

        validator.check_int(len(x_shp), 1, Rel.GE, 'x_shp length', self.name)

        args = {f"x_type[{i}]": elem for i, elem in enumerate(x_type)}
        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type + (mstype.bool_,), self.name)

        first_elem = x_shp[0]
        for i, elem in enumerate(x_shp[1:]):
            j = i + 1
            validator.check_equal_int(elem[0], 1, f'x_shp[{j}][0]', self.name)
            validator.check("x_shp[0] shape", first_elem, f"x_shp[{j}] shape", elem, Rel.EQ, self.name)

        ret_shp = x_shp[0].copy()
        ret_shp[0] = len(x_shp)
        self.add_prim_attr('shape', ret_shp)
        self.add_prim_attr('N', len(x_shp))

        out = {'shape': ret_shp,
               'dtype': x_type[0],
               'value': None}
        return out


def _get_stack_shape(x_shape, x_type, axis, prim_name):
    """Compute the output shape for Stack."""
    validator.check_value_type("shape", x_shape, [tuple, list], prim_name)
    validator.check_int(len(x_shape), 1, Rel.GE, "len of input_x", prim_name)
    validator.check_subclass("input_x[0]", x_type[0], mstype.tensor, prim_name)
    rank_base = len(x_shape[0])
    n = len(x_shape)
    out_shape = x_shape[0]
    validator.check_int_range(axis, -rank_base - 1, rank_base, Rel.INC_BOTH, 'axis', prim_name)
    if axis < 0:
        axis = axis + rank_base + 1
    for i in range(1, n):
        validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0], Rel.EQ, prim_name, TypeError)
        if x_shape[i] != x_shape[0]:
            raise ValueError(f"For '{prim_name}', the shape of input element {i} must be the same as the "
                             f"shape of the first element, but got {x_shape[i]} and {x_shape[0]}.")
    out_shape.insert(axis, n)
    return out_shape


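# A minimal sketch of the shape rule implemented by _get_stack_shape above,
# for illustration only. `_stack_shape_reference` is a hypothetical helper,
# not part of the public API.
def _stack_shape_reference(x_shape, n, axis):
    """Normalize a negative axis against rank + 1, then insert n at that axis."""
    if axis < 0:
        axis += len(x_shape) + 1
    out = list(x_shape)
    out.insert(axis, n)
    return out
# e.g. stacking n=3 tensors of shape [4, 5]: axis=0 -> [3, 4, 5], axis=-1 -> [4, 5, 3]

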
class Pack(PrimitiveWithInfer):
    """
    Same as operator Stack. Pack will be deprecated in the future.
    Please use Stack instead.
    """

    @deprecated("1.1", "Stack", True)
    @prim_attr_register
    def __init__(self, axis=0):
        """Initialize Pack"""
        validator.check_value_type("axis", axis, [int], self.name)
        self.axis = axis

    def __infer__(self, value):
        x_shape = value['shape']
        x_type = value['dtype']
        self.add_prim_attr('num', len(x_shape))
        all_shape = _get_stack_shape(x_shape, x_type, self.axis, self.name)
        out = {'shape': all_shape,
               'dtype': x_type[0],
               'value': None}
        return out


class Stack(PrimitiveWithInfer):
    r"""
    Stacks a list of tensors in specified axis.

    Stacks the list of input tensors with the same rank `R`, output is a tensor of rank `(R+1)`.

    Given input tensors of shape :math:`(x_1, x_2, ..., x_R)`. Set the number of input tensors as `N`.
    If :math:`0 \le axis`, the shape of the output tensor is
    :math:`(x_1, x_2, ..., x_{axis}, N, x_{axis+1}, ..., x_R)`.

    Args:
        axis (int): Dimension to stack. Default: 0.
                    Negative values wrap around. The range is [-(R+1), R+1).

    Inputs:
        - **input_x** (Union[tuple, list]) - A Tuple or list of Tensor objects with the same shape and type.

    Outputs:
        Tensor. A stacked Tensor with the same type as `input_x`.

    Raises:
        TypeError: If the data types of elements in `input_x` are not the same.
        ValueError: If the length of `input_x` is less than 1;
                    or if axis is out of the range [-(R+1), R+1);
                    or if the shapes of elements in input_x are not the same.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> data1 = Tensor(np.array([0, 1]).astype(np.float32))
        >>> data2 = Tensor(np.array([2, 3]).astype(np.float32))
        >>> stack = ops.Stack()
        >>> output = stack([data1, data2])
        >>> print(output)
        [[0. 1.]
         [2. 3.]]
    """

    @prim_attr_register
    def __init__(self, axis=0):
        """Initialize Stack"""
        validator.check_value_type("axis", axis, [int], self.name)
        self.axis = axis

    def __infer__(self, value):
        x_shape = value['shape']
        x_type = value['dtype']
        self.add_prim_attr('num', len(x_shape))
        all_shape = _get_stack_shape(x_shape, x_type, self.axis, self.name)
        out = {'shape': all_shape,
               'dtype': x_type[0],
               'value': None}
        return out


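# For reference, np.stack implements the same stacking semantics as the Stack
# primitive above; a minimal sketch for illustration only. `_stack_reference`
# is a hypothetical helper, not part of the public API.
def _stack_reference(arrays, axis=0):
    """Stack same-shape arrays along a new axis, as ops.Stack does."""
    return np.stack(arrays, axis=axis)

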
class Unpack(PrimitiveWithInfer):
    """
    Same as operator Unstack. Unpack will be deprecated in the future.
    Please use Unstack instead.
    """

    @deprecated("1.1", "Unstack", True)
    @prim_attr_register
    def __init__(self, axis=0):
        """Initialize Unpack"""
        validator.check_value_type("axis", axis, [int], self.name)
        self.axis = axis

    def __infer__(self, x):
        validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
        x_shape = list(x['shape'])
        dim = len(x_shape)
        validator.check_int_range(self.axis, -dim, dim, Rel.INC_LEFT, 'axis value', self.name)
        if self.axis < 0:
            self.axis = self.axis + dim
        output_num = x_shape[self.axis]
        validator.check_value_type("num", output_num, [int], self.name)
        validator.check_positive_int(output_num, "output_num", self.name)
        self.add_prim_attr('num', output_num)
        output_valid_check = x_shape[self.axis] - output_num
        validator.check_int(output_valid_check, 0, Rel.EQ,
                            "The dimension which to unstack divides output_num", self.name)
        out_shapes = []
        out_dtypes = []
        out_shape = x_shape[:self.axis] + x_shape[self.axis + 1:]
        for _ in range(output_num):
            out_shapes.append(tuple(out_shape))
            out_dtypes.append(x['dtype'])
        out_shapes = tuple(out_shapes)
        out_dtypes = tuple(out_dtypes)
        out = {'shape': out_shapes,
               'dtype': out_dtypes,
               'value': None}
        return out


class Unstack(PrimitiveWithInfer):
    r"""
    Unstacks tensor along the specified axis.

    Unstacks a tensor of rank `R` along the axis dimension; each output tensor has rank `(R-1)`.

    Given a tensor of shape :math:`(x_1, x_2, ..., x_R)`. If :math:`0 \le axis`,
    the shape of tensor in output is :math:`(x_1, x_2, ..., x_{axis}, x_{axis+2}, ..., x_R)`.

    This is the opposite of `Stack`.

    Args:
        axis (int): Dimension along which to unstack. Default: 0.
                    Negative values wrap around. The range is [-R, R).

    Inputs:
        - **input_x** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_R)`.
          A tensor to be unstacked and the rank of the tensor must be greater than 0.

    Outputs:
        A tuple of tensors, each with the same shape.

    Raises:
        ValueError: If axis is out of the range [-len(input_x.shape), len(input_x.shape)).

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> unstack = ops.Unstack()
        >>> input_x = Tensor(np.array([[1, 1, 1, 1], [2, 2, 2, 2]]))
        >>> output = unstack(input_x)
        >>> print(output)
        (Tensor(shape=[4], dtype=Int64, value= [1, 1, 1, 1]), Tensor(shape=[4], dtype=Int64, value= [2, 2, 2, 2]))
    """

    @prim_attr_register
    def __init__(self, axis=0):
        """Initialize Unstack"""
        validator.check_value_type("axis", axis, [int], self.name)
        self.axis = axis

    def __infer__(self, x):
        validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
        x_shape = list(x['shape'])
        dim = len(x_shape)
        validator.check_int_range(self.axis, -dim, dim, Rel.INC_LEFT, 'axis value', self.name)
        if self.axis < 0:
            self.axis = self.axis + dim
        output_num = x_shape[self.axis]
        validator.check_value_type("num", output_num, [int], self.name)
        validator.check_positive_int(output_num, "output_num", self.name)
        self.add_prim_attr('num', output_num)
        output_valid_check = x_shape[self.axis] - output_num
        validator.check_int(output_valid_check, 0, Rel.EQ,
                            "The dimension which to unstack divides output_num", self.name)
        out_shapes = []
        out_dtypes = []
        out_shape = x_shape[:self.axis] + x_shape[self.axis + 1:]
        for _ in range(output_num):
            out_shapes.append(tuple(out_shape))
            out_dtypes.append(x['dtype'])
        out_shapes = tuple(out_shapes)
        out_dtypes = tuple(out_dtypes)
        out = {'shape': out_shapes,
               'dtype': out_dtypes,
               'value': None}
        return out
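

# A minimal usage sketch (illustrative, not part of the operator set): Unstack
# followed by Stack along the same axis reproduces the input. This assumes the
# `Stack` primitive defined elsewhere in this module and a working MindSpore
# runtime; the helper name below is hypothetical.
def _example_unstack_roundtrip():
    """Illustrative only: Stack(axis=0)(Unstack(axis=0)(x)) equals x."""
    x = Tensor(np.array([[1, 1, 1, 1], [2, 2, 2, 2]]))
    parts = Unstack(axis=0)(x)       # tuple of two tensors, each of shape (4,)
    restored = Stack(axis=0)(parts)  # stacks the tuple back to shape (2, 4)
    return parts, restored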


class Slice(PrimitiveWithInfer):
    """
    Slices a tensor in the specified shape.

    Slices the tensor `input_x` to a tensor of shape `size`, starting at the location specified by `begin`.
    `begin` represents the offset in each dimension of `input_x`, and
    `size` represents the size of the output tensor.

    Note that `begin` is zero-based and `size` is one-based.

    If `size[i]` is -1, all remaining elements in dimension i are included in the slice.
    This is equivalent to setting :math:`size[i] = input_x.shape(i) - begin[i]`.

    Inputs:
        - **input_x** (Tensor): The target tensor.
          The shape is :math:`(N,*)` where :math:`*` means, any number of additional dimensions.
        - **begin** (Union[tuple, list]): The beginning of the slice. Only constant value(>=0) is allowed.
        - **size** (Union[tuple, list]): The size of the slice. Only constant value is allowed.

    Outputs:
        Tensor, the shape is the same as that of the input `size`, and the data type is the same as `input_x`.

    Raises:
        TypeError: If `begin` or `size` is neither tuple nor list.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> data = Tensor(np.array([[[1, 1, 1], [2, 2, 2]],
        ...                         [[3, 3, 3], [4, 4, 4]],
        ...                         [[5, 5, 5], [6, 6, 6]]]).astype(np.int32))
        >>> slice_op = ops.Slice()
        >>> output = slice_op(data, (1, 0, 0), (1, 1, 3))
        >>> print(output)
        [[[3 3 3]]]
        >>> output = slice_op(data, (1, 0, 0), (1, 1, 2))
        >>> print(output)
        [[[3 3]]]
        >>> output = slice_op(data, (1, 0, 0), (1, 1, 1))
        >>> print(output)
        [[[3]]]
        >>> output = slice_op(data, (1, 1, 0), (1, 1, 3))
        >>> print(output)
        [[[4 4 4]]]
        >>> output = slice_op(data, (1, 0, 1), (1, 1, 2))
        >>> print(output)
        [[[3 3]]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize slice"""
        self.init_prim_io_names(inputs=['x', 'begin', 'size'], outputs=['output'])

    def __infer__(self, x, begin, size):
        x_shape = x['shape']
        x_shp_len = len(x_shape)
        validator.check_valid_input('begin', begin['value'], self.name)
        validator.check_valid_input('size', size['value'], self.name)
        begin_v, size_v = begin['value'], size['value']
        if begin_v is None or size_v is None:
            return {'shape': None,
                    'dtype': x['dtype'],
                    'value': None}
        validator.check_value_type("input begin", begin_v, [tuple, list], self.name)
        validator.check_value_type("input size", size_v, [tuple, list], self.name)
        for key, value in zip(('begin', 'size'), (begin_v, size_v)):
            validator.check(f'len of {key}', len(value),
                            'len x\'s dim', x_shp_len)
        for i in range(x_shp_len):
            validator.check_positive_int(size_v[i], f'input size[{i}]')
            validator.check_non_negative_int(begin_v[i], f'input begin[{i}]')
            if x_shape[i] < begin_v[i] + size_v[i]:
                y = begin_v[i] + size_v[i]
                raise ValueError(f"For '{self.name}', the sliced shape cannot be greater than the origin shape, "
                                 f"but got sliced shape {y} and origin shape {x_shape}.")
        return {'shape': size_v,
                'dtype': x['dtype'],
                'value': None}
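

# A hedged usage sketch (illustrative): Slice with begin=(1, 0, 0) and
# size=(1, 1, 3) mirrors the NumPy expression data[1:2, 0:1, 0:3]. Assumes a
# working MindSpore runtime; the helper name is hypothetical.
def _example_slice_vs_numpy():
    """Illustrative only: Slice(begin, size) matches data[b:b+s] per dimension."""
    np_data = np.arange(18).reshape(3, 2, 3).astype(np.int32)
    out = Slice()(Tensor(np_data), (1, 0, 0), (1, 1, 3))
    reference = np_data[1:2, 0:1, 0:3]   # same values, same (1, 1, 3) shape
    return out, reference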


class ReverseV2(PrimitiveWithInfer):
    """
    Reverses specific dimensions of a tensor.

    .. warning::
        The value range of "axis" is [-dims, dims - 1]. "dims" is the dimension length of "input_x".

    Args:
        axis (Union[tuple(int), list(int)]): The indices of the dimensions to reverse.

    Inputs:
        - **input_x** (Tensor) - The target tensor. The data type is Number except float64.
          The shape is :math:`(N,*)` where :math:`*` means, any number of additional dimensions.

    Outputs:
        Tensor, has the same shape and type as `input_x`.

    Raises:
        TypeError: If `axis` is neither list nor tuple.
        TypeError: If element of `axis` is not an int.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> input_x = Tensor(np.array([[1, 2, 3, 4], [5, 6, 7, 8]]), mindspore.int32)
        >>> op = ops.ReverseV2(axis=[1])
        >>> output = op(input_x)
        >>> print(output)
        [[4 3 2 1]
         [8 7 6 5]]
        >>> op = ops.ReverseV2(axis=[1, 0])
        >>> output = op(input_x)
        >>> print(output)
        [[8 7 6 5]
         [4 3 2 1]]
    """

    @prim_attr_register
    def __init__(self, axis):
        """Initialize ReverseV2."""
        validator.check_value_type('axis', axis, [list, tuple], self.name)
        for i, each in enumerate(axis):
            validator.check_value_type(f'axis[{i}]', each, [int], self.name)
        self.axis = axis
        self.init_prim_io_names(inputs=['x'], outputs=['output'])

    def infer_shape(self, x_shape):
        dim = len(x_shape)
        for i, each in enumerate(self.axis):
            validator.check_int_range(each, -dim, dim, Rel.INC_LEFT, f'axis[{i}]', self.name)
        normalized_axis = []
        for i, v in enumerate(self.axis):
            if v < 0:
                normalized_axis.append(v + dim)
            else:
                normalized_axis.append(v)

        if len(normalized_axis) != len(set(normalized_axis)):
            duplicated = [item for item, count in Counter(normalized_axis).items() if count > 1]
            raise ValueError(f"For '{self.name}', the 'axis' cannot contain duplicate dimensions,"
                             f" but got duplicated elements {duplicated}.")

        return x_shape

    def infer_dtype(self, x_dtype):
        validator.check_tensor_dtype_valid('x', x_dtype, (mstype.bool_,) + mstype.number_type, self.name)
        return x_dtype
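

# A hedged equivalence sketch (illustrative): ReverseV2(axis=[1]) matches
# np.flip along the same axis. Assumes a working MindSpore runtime; the helper
# name is hypothetical.
def _example_reversev2_vs_numpy():
    """Illustrative only: ReverseV2(axis) agrees with np.flip(x, axis)."""
    np_x = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.int32)
    out = ReverseV2(axis=[1])(Tensor(np_x))
    reference = np.flip(np_x, axis=1)    # [[4 3 2 1], [8 7 6 5]]
    return out, reference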


class Rint(PrimitiveWithInfer):
    """
    Rounds each element of the input to the nearest integer value, element-wise
    (the result keeps the floating-point dtype).

    Inputs:
        - **input_x** (Tensor) - The target tensor, which must be one of the following types:
          float16, float32. The shape is :math:`(N,*)` where :math:`*` means, any number of additional dimensions.

    Outputs:
        Tensor, has the same shape and type as `input_x`.

    Raises:
        TypeError: If dtype of `input_x` is neither float16 nor float32.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_x = Tensor(np.array([-1.6, -0.1, 1.5, 2.0]), mindspore.float32)
        >>> op = ops.Rint()
        >>> output = op(input_x)
        >>> print(output)
        [-2.  0.  2.  2.]
        >>> input_x = Tensor(np.array([[-2.0, -1.9, -1.8, -1.7, -1.6],
        ...                            [-2.0, -1.9, -1.8, -1.7, -1.6]]), mindspore.float32)
        >>> output = op(input_x)
        >>> print(output)
        [[-2. -2. -2. -2. -2.]
         [-2. -2. -2. -2. -2.]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize Rint."""
        self.init_prim_io_names(inputs=['x'], outputs=['output'])

    def infer_shape(self, x_shape):
        return x_shape

    def infer_dtype(self, x_dtype):
        validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float16, mstype.float32], self.name)
        return x_dtype
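

# A hedged equivalence sketch (illustrative): Rint behaves like np.rint, which
# rounds halfway cases to the nearest even value (1.5 -> 2.0, 2.5 -> 2.0), as
# the docstring example above suggests. Assumes a working MindSpore runtime;
# the helper name is hypothetical.
def _example_rint_vs_numpy():
    """Illustrative only: Rint agrees with np.rint, including half-to-even."""
    np_x = np.array([-1.6, -0.5, 0.5, 1.5, 2.5], dtype=np.float32)
    out = Rint()(Tensor(np_x))
    reference = np.rint(np_x)    # [-2. -0.  0.  2.  2.]
    return out, reference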


class Select(Primitive):
    r"""

    Returns the selected elements, either from input :math:`x` or input :math:`y`, depending on the `condition`.

    Element-wise, the output takes the value of :math:`x` where `condition` is true and the value
    of :math:`y` otherwise; it is invalid when both :math:`x` and :math:`y` are None.

    The condition tensor acts as a mask that determines whether the corresponding
    element / row in the output must be selected from :math:`x` (if true) or :math:`y`
    (if false) based on the value of each element.

    It can be defined as:

    .. math::
        out_i = \begin{cases}
        x_i, & \text{if } condition_i \\
        y_i, & \text{otherwise}
        \end{cases}

    If `condition` is a vector while :math:`x` and :math:`y` are higher-dimensional matrices, each
    entry of `condition` chooses which row (outermost dimension) to copy from :math:`x` or :math:`y`.
    If `condition` has the same shape as :math:`x` and :math:`y`, elements are chosen individually
    from :math:`x` and :math:`y`.

    Inputs:
        - **input_cond** (Tensor[bool]) - The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
          The condition tensor, decides which element is chosen.
        - **input_x** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
          The first input tensor.
        - **input_y** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
          The second input tensor.

    Outputs:
        Tensor, has the same shape as `input_x`. The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.

    Raises:
        TypeError: If `input_x` or `input_y` is not a Tensor.
        ValueError: If shape of `input_x` is not equal to shape of `input_y` or shape of `input_cond`.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> select = ops.Select()
        >>> input_cond = Tensor([True, False])
        >>> input_x = Tensor([2,3], mindspore.float32)
        >>> input_y = Tensor([1,2], mindspore.float32)
        >>> output = select(input_cond, input_x, input_y)
        >>> print(output)
        [2. 2.]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize Select."""
        self.init_prim_io_names(inputs=['condition', 'x', 'y'], outputs=['output'])
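

# A hedged equivalence sketch (illustrative): for same-shape inputs, Select
# matches np.where, which implements the formula in the docstring above.
# Assumes a working MindSpore runtime; the helper name is hypothetical.
def _example_select_vs_numpy():
    """Illustrative only: Select(cond, x, y) agrees with np.where(cond, x, y)."""
    cond = np.array([True, False])
    np_x = np.array([2.0, 3.0], dtype=np.float32)
    np_y = np.array([1.0, 2.0], dtype=np.float32)
    out = Select()(Tensor(cond), Tensor(np_x), Tensor(np_y))
    reference = np.where(cond, np_x, np_y)    # [2. 2.]
    return out, reference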


def _compute_slicing_length(begin, end, stride, x_shape, i):
    """Computes the length of the slicing."""
    if i >= len(x_shape):
        raise ValueError(f"For 'StridedSlice', the index must be less than or equal to "
                         f"the dimension of 'input_x', but got the dimension of 'input_x': {len(x_shape)} "
                         f"and the index: {i}.")
    x_dim = x_shape[i]
    if stride > 0:
        # When slicing forward, convert begin and end to positive numbers.
        if begin >= x_dim or end < -x_dim:
            # When slicing forward, if begin >= x_dim or end < -x_dim, the length of the slicing is 0.
            slicing_length = 0
        else:
            if -x_dim <= begin < 0:
                begin += x_dim
            if begin < -x_dim:
                # When slicing forward, if begin < -x_dim, set begin = 0, which means start from the 0th element.
                begin = 0
            if -x_dim <= end < 0:
                end += x_dim
            if end > x_dim:
                # When slicing forward, if end > x_dim, set end = x_dim, which means slice to the last element.
                end = x_dim
            if begin >= end:
                # When slicing forward, if begin >= end, the length of the slicing is 0.
                slicing_length = 0
            else:
                slicing_length = 1 + (end - 1 - begin) // stride
    else:
        # When slicing backward, convert begin and end to negative numbers.
        if begin < -x_dim or end >= x_dim:
            # When slicing backward, if begin < -x_dim or end >= x_dim, the length of the slicing is 0.
            slicing_length = 0
        else:
            if 0 <= begin < x_dim:
                begin += -x_dim
            if begin >= x_dim:
                begin = -1
            if 0 <= end < x_dim:
                end += -x_dim
            if end < -x_dim - 1:
                # Slicing to the 0th element.
                end = -x_dim - 1
            if begin <= end:
                slicing_length = 0
            else:
                slicing_length = 1 + (end + 1 - begin) // stride
    return slicing_length
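

# A quick sanity sketch (illustrative): for arguments within range, the helper
# above should agree with Python's own slice-length arithmetic via
# slice.indices(). The helper name is hypothetical.
def _example_check_slicing_length():
    """Illustrative only: compare against len(range(*slice(...).indices(dim)))."""
    x_shape = [10]
    for begin, end, stride in [(0, 10, 1), (2, 9, 3), (8, 1, -2)]:
        expected = len(range(*slice(begin, end, stride).indices(x_shape[0])))
        assert _compute_slicing_length(begin, end, stride, x_shape, 0) == expected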


class StridedSlice(PrimitiveWithInfer):
    r"""

    Extracts a strided slice of a tensor.

    This operation extracts a fragment of size (end-begin)/stride from the given `input_x`.
    Starting from the begin position, the fragment continues adding `strides` to the index
    until all dimensions reach the end position.

    Given an `input_x[m1, m2, ..., mn]`, `begin`, `end` and `strides` will be vectors of length n.

    In each mask field (`begin_mask`, `end_mask`, `ellipsis_mask`, `new_axis_mask`, `shrink_axis_mask`)
    the ith bit will correspond to the ith m.

    If the ith bit of `begin_mask` is set, `begin[i]` is ignored and the fullest possible range in that dimension
    is used instead. `end_mask` is analogous, except with the end range.

    As for a 5*6*7 tensor, `x[2:,:3,:]` is equivalent to `x[2:5,0:3,0:7]`.

    If the ith bit of `ellipsis_mask` is set, as many unspecified dimensions as needed will be inserted between
    other dimensions. Only one non-zero bit is allowed in `ellipsis_mask`.

    As for a 5*6*7*8 tensor, `x[2:,...,:6]` is equivalent to `x[2:5,:,:,0:6]`.
    `x[2:,...]` is equivalent to `x[2:5,:,:,:]`.

    If the ith bit of `new_axis_mask` is set, `begin`, `end` and `strides` are ignored and a new length 1
    dimension is added at the specified position in the output tensor.

    As for a 5*6*7 tensor, `x[:2, newaxis, :6]` will produce a tensor with shape (2, 1, 7).

    If the ith bit of `shrink_axis_mask` is set, the ith dimension shrinks away, taking the value
    at index `begin[i]`; `end[i]` and `strides[i]` are ignored.

    As for a 5*6*7 tensor, `x[:, 5, :]` will result in `shrink_axis_mask` equal to 2.

    Note:
        The stride may be negative value, which causes reverse slicing.
        The shape of `begin`, `end` and `strides` must be the same.
        `begin` and `end` are zero-indexed. The element of `strides` must be non-zero.

    Args:
        begin_mask (int): An int mask; set bits mean the corresponding `begin` index is ignored. Default: 0.
        end_mask (int): An int mask; set bits mean the corresponding `end` index is ignored. Default: 0.
        ellipsis_mask (int): An int mask. Default: 0.
        new_axis_mask (int): An int mask. Default: 0.
        shrink_axis_mask (int): An int mask. Default: 0.

    Inputs:
        - **input_x** (Tensor) - The input Tensor.
        - **begin** (tuple[int]) - A tuple which represents the location where to start. Only
          constant value is allowed.
        - **end** (tuple[int]) - A tuple which represents the maximum location where to end.
          Only constant value is allowed.
        - **strides** (tuple[int]) - A tuple which represents the stride is continuously added
          before reaching the maximum location. Only constant value is allowed.

    Outputs:
        Tensor, the output is explained by the following example.

        In the 0th dimension, begin is 1, end is 2, and strides is 1,
        because :math:`1+1=2\geq2`, the interval is :math:`[1,2)`.
        Thus, return the element with :math:`index = 1` in 0th dimension, i.e., [[3, 3, 3], [4, 4, 4]].

        In the 1st dimension, similarly, the interval is :math:`[0,1)`.
        Based on the return value of the 0th dimension, return the element with :math:`index = 0`,
        i.e., [3, 3, 3].

        In the 2nd dimension, similarly, the interval is :math:`[0,3)`.
        Based on the return value of the 1st dimension, return the element with :math:`index = 0,1,2`,
        i.e., [3, 3, 3].

        Finally, the output is [3, 3, 3].

    Raises:
        TypeError: If `begin_mask`, `end_mask`, `ellipsis_mask`, `new_axis_mask` or `shrink_axis_mask` is not an int.
        TypeError: If `begin`, `end` or `strides` is not a tuple.
        ValueError: If `begin_mask`, `end_mask`, `ellipsis_mask`, `new_axis_mask` or `shrink_axis_mask` is less than 0.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_x = Tensor([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]],
        ...                   [[5, 5, 5], [6, 6, 6]]], mindspore.float32)
        >>> #         [[[1. 1. 1.]
        >>> #           [2. 2. 2.]]
        >>> #
        >>> #          [[3. 3. 3.]
        >>> #           [4. 4. 4.]]
        >>> #
        >>> #          [[5. 5. 5.]
        >>> #           [6. 6. 6.]]]
        >>> # In order to visually view the multi-dimensional array, write the above as follows:
        >>> #         [
        >>> #             [
        >>> #                 [1,1,1]
        >>> #                 [2,2,2]
        >>> #             ]
        >>> #             [
        >>> #                 [3,3,3]
        >>> #                 [4,4,4]
        >>> #             ]
        >>> #             [
        >>> #                 [5,5,5]
        >>> #                 [6,6,6]
        >>> #             ]
        >>> #         ]
        >>> strided_slice = ops.StridedSlice()
        >>> output = strided_slice(input_x, (1, 0, 2), (3, 1, 3), (1, 1, 1))
        >>> # Take this " output = strided_slice(input_x, (1, 0, 2), (3, 1, 3), (1, 1, 1)) " as an example,
        >>> # start = [1, 0, 2], end = [3, 1, 3], stride = [1, 1, 1]. Find a segment of (start, end),
        >>> # note that end is an open interval.
        >>> # To facilitate understanding, this operator can be divided into three steps:
        >>> # Step 1: Calculation of the first dimension:
        >>> # start = 1, end = 3, stride = 1, so the 1st and 2nd rows can be taken,
        >>> # which gives the output at this time.
        >>> # output_1st =
        >>> # [
        >>> #     [
        >>> #         [3,3,3]
        >>> #         [4,4,4]
        >>> #     ]
        >>> #     [
        >>> #         [5,5,5]
        >>> #         [6,6,6]
        >>> #     ]
        >>> # ]
        >>> # Step 2: Calculation of the second dimension
        >>> # 2nd dimension, start = 0, end = 1, stride = 1, so only the 0th row can be taken,
        >>> # which gives the output at this time.
        >>> # output_2nd =
        >>> # [
        >>> #     [
        >>> #         [3,3,3]
        >>> #     ]
        >>> #     [
        >>> #         [5,5,5]
        >>> #     ]
        >>> # ]
        >>> # Step 3: Calculation of the third dimension
        >>> # 3rd dimension, start = 2, end = 3, stride = 1, so the 2nd col can be taken,
        >>> # which gives the final output at this time.
        >>> # output_3rd =
        >>> # [
        >>> #     [
        >>> #         [3]
        >>> #     ]
        >>> #     [
        >>> #         [5]
        >>> #     ]
        >>> # ]
        >>> # The final output after finishing is:
        >>> print(output)
        [[[3.]]
         [[5.]]]
        >>> # another example like :
        >>> output = strided_slice(input_x, (1, 0, 0), (2, 1, 3), (1, 1, 1))
        >>> print(output)
        [[[3. 3. 3.]]]
    """

    @prim_attr_register
    def __init__(self,
                 begin_mask=0,
                 end_mask=0,
                 ellipsis_mask=0,
                 new_axis_mask=0,
                 shrink_axis_mask=0):
        """Initialize StridedSlice"""
        self.init_prim_io_names(inputs=['x', 'begin', 'end', 'strides'], outputs=['output'])
        validator.check_non_negative_int(begin_mask, 'begin_mask', self.name)
        validator.check_non_negative_int(end_mask, 'end_mask', self.name)
        validator.check_non_negative_int(ellipsis_mask, 'ellipsis_mask', self.name)
        if len(tuple(filter(lambda x: x == '1', bin(ellipsis_mask)[-1:1:-1]))) > 1:
            raise ValueError(f"For '{self.name}', only support one ellipsis in the index, "
                             f"but got {ellipsis_mask}.")
        validator.check_non_negative_int(new_axis_mask, 'new_axis_mask', self.name)
        validator.check_non_negative_int(shrink_axis_mask, 'shrink_axis_mask', self.name)

    def __infer__(self, x, begin, end, strides):
        begin_v, end_v, strides_v = begin['value'], end['value'], strides['value']
        validator.check_value_type("begin", begin_v, [tuple], self.name)
        validator.check_value_type("end", end_v, [tuple], self.name)
        validator.check_value_type("strides", strides_v, [tuple], self.name)

        if tuple(filter(lambda x: not isinstance(x, int), begin_v + end_v + strides_v)):
            raise TypeError(f"For {self.name}, both the 'begin', 'end', and 'strides' must be a tuple of int, "
                            f"but got 'begin': {begin_v}, 'end': {end_v}, 'strides': {strides_v}.")

        if tuple(filter(lambda x: x == 0, strides_v)):
            raise ValueError(f"For '{self.name}', the 'strides' cannot contain 0, but got 'strides': {strides_v}.")

        if len(end_v) != len(begin_v) or len(strides_v) != len(begin_v):
            raise ValueError(f"For '{self.name}', the length of 'begin' index and the length of 'end' index "
                             f"must be the same as 'strides', but got the length of 'begin': {begin_v}, "
                             f"'end' index: {end_v} and 'strides': {strides_v}.")

        ret_shape = self._compute_slicing_shape(x['shape'], begin_v, end_v, strides_v)

        if all(ret_shape):
            value = None
        else:
            init_func = Zero()
            init_func.__enable_zero_dim__ = True
            value = Tensor(dtype=x['dtype'].element_type(), shape=ret_shape, init=init_func)

        if "max_value" in x and "min_value" in x:
            validator.check_value_type("min_value", x["min_value"], [tuple, list], self.name)
            validator.check_value_type("max_value", x["max_value"], [tuple, list], self.name)
            max_value_np = np.array(x["max_value"])
            min_value_np = np.array(x["min_value"])
            slice_index = []
            for begin_i, end_i, strides_i in zip(begin_v, end_v, strides_v):
                s = slice(begin_i, end_i, strides_i)
                slice_index.append(s)
            slice_index = tuple(slice_index)
            max_value_slice = max_value_np[slice_index]
            min_value_slice = min_value_np[slice_index]
            max_value_slice = tuple(max_value_slice.tolist())
            min_value_slice = tuple(min_value_slice.tolist())
            return {'shape': ret_shape,
                    'dtype': x['dtype'],
                    'value': value,
                    'max_value': max_value_slice,
                    'min_value': min_value_slice}

        return {'shape': ret_shape,
                'dtype': x['dtype'],
                'value': value}

    def _compute_slicing_shape(self, x_shape, begin_v, end_v, strides_v):
        """Computes the shape of the slicing."""
        x_rank = len(x_shape)
        slice_len = len(begin_v)

        # After the integer is converted to binary, it is a str and the first two chars are the flag chars '0b'.
        begin_pos = bin(self.begin_mask)[-1:1:-1]
        end_pos = bin(self.end_mask)[-1:1:-1]
        ellipsis_pos = bin(self.ellipsis_mask)[-1:1:-1]
        new_axis_pos = bin(self.new_axis_mask)[-1:1:-1]
        shrink_axis_pos = bin(self.shrink_axis_mask)[-1:1:-1]

        ret_shape = []
        i, j = 0, 0
        has_ellipsis = False
        while i < x_rank or j < slice_len:
            if j < slice_len:
                begin, end, stride = begin_v[j], end_v[j], strides_v[j]

                if j < len(ellipsis_pos) and ellipsis_pos[j] == '1':
                    # When there is an ellipsis, the latter part of the ellipsis will be processed separately.
                    has_ellipsis = True
                    break
                if j < len(begin_pos) and begin_pos[j] == '1':
                    begin = -1 if strides_v[j] < 0 else 0
                if j < len(end_pos) and end_pos[j] == '1':
                    end = -(x_shape[i] + 1) if strides_v[j] < 0 else x_shape[i]
                if j < len(new_axis_pos) and new_axis_pos[j] == '1':
                    ret_shape.append(1)
                    j += 1
                    continue
                if j < len(shrink_axis_pos) and shrink_axis_pos[j] == '1':
                    if (not -x_shape[i] <= begin < x_shape[i]) or stride < 0:
                        raise IndexError(f"For '{self.name}', the 'strides' cannot be negative number and "
                                         f"'begin' should be in [-{x_shape[i]}, {x_shape[i]}) "
                                         f"when 'shrink_axis_mask' is greater than 0, "
                                         f"but got 'shrink_axis_mask': {self.shrink_axis_mask}, 'strides': {stride}, "
                                         f"'begin': {begin}.")
                    j += 1
                    i += 1
                    continue
            else:
                begin, end, stride = 0, x_shape[i], 1

            slicing_length = _compute_slicing_length(begin, end, stride, x_shape, i)
            ret_shape.append(slicing_length)
            i += 1
            j += 1
        if has_ellipsis:
            # When there is an ellipsis, handle the second half after the ellipsis split.
            ellipsis_occupied_dims = x_rank - i - (slice_len - (j + 1)) + \
                                     len(tuple(filter(lambda x: x == '1', new_axis_pos[j + 1:slice_len])))
            ret_shape.extend(x_shape[i:i + ellipsis_occupied_dims])
            j += 1
            i += ellipsis_occupied_dims

            while i < x_rank or j < slice_len:
                begin, end, stride = begin_v[j], end_v[j], strides_v[j]

                if j < len(begin_pos) and begin_pos[j] == '1':
                    begin = -1 if strides_v[j] < 0 else 0
                if j < len(end_pos) and end_pos[j] == '1':
                    end = -(x_shape[i] + 1) if strides_v[j] < 0 else x_shape[i]
                if j < len(new_axis_pos) and new_axis_pos[j] == '1':
                    ret_shape.append(1)
                    j += 1
                    continue
                if j < len(shrink_axis_pos) and shrink_axis_pos[j] == '1':
                    if (not -x_shape[i] <= begin < x_shape[i]) or stride < 0:
                        raise IndexError(f"For '{self.name}', the 'strides' cannot be negative number and "
                                         f"'begin' should be in [-{x_shape[i]}, {x_shape[i]}) "
                                         f"when 'shrink_axis_mask' is greater than 0, "
                                         f"but got 'shrink_axis_mask': {self.shrink_axis_mask}, 'strides': {stride}, "
                                         f"'begin': {begin}.")
                    j += 1
                    i += 1
                    continue

                slicing_length = _compute_slicing_length(begin, end, stride, x_shape, i)
                ret_shape.append(slicing_length)
                i += 1
                j += 1
        return ret_shape
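

# A hedged mask sketch (illustrative): with shrink_axis_mask=0b10 the second
# sliced dimension collapses, mimicking x[1:2, 0, 0:3]. Assumes a working
# MindSpore runtime; the helper name is hypothetical.
def _example_strided_slice_shrink():
    """Illustrative only: shrink_axis_mask=2 drops the second sliced dimension."""
    x = Tensor(np.arange(18).reshape(3, 2, 3).astype(np.float32))
    op = StridedSlice(shrink_axis_mask=2)    # bit 1 set -> dimension 1 shrinks
    y = op(x, (1, 0, 0), (2, 2, 3), (1, 1, 1))
    return y                                 # expected shape: (1, 3)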


class Diag(PrimitiveWithInfer):
    r"""

    Constructs a diagonal tensor with given diagonal values.

    Assume `input_x` has dimensions :math:`[D_1,... D_k]`, the output is a tensor of
    rank 2k with dimensions :math:`[D_1,..., D_k, D_1,..., D_k]` where:
    :math:`output[i_1,..., i_k, i_1,..., i_k] = input_x[i_1,..., i_k]` and 0 everywhere else.

    Inputs:
        - **input_x** (Tensor) - The input tensor. The rank of the input shape must be less than 5.

    Outputs:
        Tensor, has the same dtype as the `input_x`.

    Raises:
        TypeError: If `input_x` is not a Tensor.
        ValueError: If rank of `input_x` is less than 1.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> input_x = Tensor([1, 2, 3, 4])
        >>> diag = ops.Diag()
        >>> output = diag(input_x)
        >>> print(output)
        [[1, 0, 0, 0],
         [0, 2, 0, 0],
         [0, 0, 3, 0],
         [0, 0, 0, 4]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize Diag"""

    def infer_dtype(self, x_type):
        validator.check_subclass('input_x', x_type, mstype.tensor, self.name)
        return x_type

    def infer_shape(self, x_shape):
        validator.check("x rank", len(x_shape), "", 1, Rel.GE)
        ret_shape = copy.deepcopy(x_shape)
        ret_shape = ret_shape + ret_shape
        return ret_shape

    def infer_value(self, x):
        if x is None:
            return None
        # do constant-folding only when x rank is 1
        if len(x.shape) != 1:
            return None
        ret = np.diag(x.asnumpy())
        return Tensor(ret)
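

# A hedged round-trip sketch (illustrative): for a rank-1 input, DiagPart
# (defined below) recovers the vector that Diag embedded on the diagonal.
# Assumes a working MindSpore runtime; the helper name is hypothetical.
def _example_diag_roundtrip():
    """Illustrative only: DiagPart()(Diag()(x)) equals x for a rank-1 x."""
    x = Tensor(np.array([1, 2, 3, 4]))
    dense = Diag()(x)              # shape (4, 4), x on the diagonal, 0 elsewhere
    recovered = DiagPart()(dense)  # shape (4,), equal to x
    return dense, recovered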


class DiagPart(PrimitiveWithInfer):
    r"""

    Extracts the diagonal part from a given tensor.

    Assume input has dimensions :math:`[D_1,..., D_k, D_1,..., D_k]`, the output is a tensor
    of rank k with dimensions :math:`[D_1,..., D_k]` where:
    :math:`output[i_1,..., i_k] = input[i_1,..., i_k, i_1,..., i_k]`.

    Inputs:
        - **input_x** (Tensor) - The input tensor of rank 2k, k is not zero.

    Outputs:
        Tensor, the extracted diagonal has the same dtype as the `input_x`.

    Raises:
        TypeError: If `input_x` is not a Tensor.
        ValueError: If rank of `input_x` is not even or zero.
        ValueError: If input_shape[i] is not equal to input_shape[i + len(input_shape)/2].

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> input_x = Tensor([[1, 0, 0, 0],
        ...                   [0, 2, 0, 0],
        ...                   [0, 0, 3, 0],
        ...                   [0, 0, 0, 4]])
        >>> diag_part = ops.DiagPart()
        >>> output = diag_part(input_x)
        >>> print(output)
        [1 2 3 4]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize DiagPart"""

    def infer_dtype(self, x_type):
        validator.check_subclass('input_x', x_type, mstype.tensor, self.name)
        return x_type

    def infer_shape(self, x_shape):
        if len(x_shape) % 2 != 0 or not x_shape:
            raise ValueError(f"For '{self.name}', the rank of 'input_x' must be non-zero and even, "
                             f"but got rank {len(x_shape)}, with shapes {x_shape}.")
        length = len(x_shape) // 2
        for i in range(length):
            validator.check('input_shape[i + len(input_shape)/2]', x_shape[i + length],
                            'input_shape[i]', x_shape[i], Rel.EQ, self.name)
        ret_shape = x_shape[0:length]
        return ret_shape

    def infer_value(self, x):
        if x is None:
            return None
        # do constant-folding only when x rank is 2
        if len(x.shape) != 2:
            return None
        ret = np.diag(x.asnumpy())
        return Tensor(ret)


class Eye(PrimitiveWithInfer):
    """

    Creates a tensor with ones on the diagonal and zeros elsewhere.

    Inputs:
        - **n** (int) - The number of rows of the returned tensor. Only constant value is allowed.
        - **m** (int) - The number of columns of the returned tensor. Only constant value is allowed.
        - **t** (mindspore.dtype) - MindSpore's dtype, the data type of the returned tensor.
          The data type can be Number.

    Outputs:
        Tensor, a tensor with ones on the diagonal and the rest of elements are zero. The shape of `output` depends on
        the user's Inputs `n` and `m`. And the data type depends on Inputs `t`.

    Raises:
        TypeError: If `m` or `n` is not an int.
        ValueError: If `m` or `n` is less than 1.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> eye = ops.Eye()
        >>> output = eye(2, 2, mindspore.int32)
        >>> print(output)
        [[1 0]
         [0 1]]
        >>> print(output.dtype)
        Int32
        >>> output = eye(1, 2, mindspore.float64)
        >>> print(output)
        [[1. 0.]]
        >>> print(output.dtype)
        Float64
        >>> # to get an anti-diagonal matrix
        >>> anti_diagonal_input = eye(2, 2, mindspore.int32)
        >>> # Note that ReverseV2 only supports "Ascend" and "GPU" at this time
        >>> reverse = ops.ReverseV2([1])
        >>> anti_diagonal_output = reverse(anti_diagonal_input)
        >>> print(anti_diagonal_output)
        [[0 1]
         [1 0]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize Eye"""

    def infer_value(self, n, m, t):
        validator.check_positive_int(n, "n", self.name)
        validator.check_positive_int(m, "m", self.name)
        args = {"dtype": t}
        validator.check_types_same_and_valid(args, mstype.number_type + (mstype.bool_,), self.name)
        np_type = mstype.dtype_to_nptype(t)
        ret = np.eye(n, m, dtype=np_type)
        return Tensor(ret)


class ScatterNd(PrimitiveWithInfer):
    r"""
    Scatters a tensor into a new tensor depending on the specified indices.

    Creates an empty tensor with the given `shape`, and sets values by scattering the update tensor
    depending on indices.

    The empty tensor has rank P and `indices` has rank Q where `Q >= 2`.

    `indices` has shape :math:`(i_0, i_1, ..., i_{Q-2}, N)` where `N <= P`.

    The last dimension of `indices` (with length `N` ) indicates slices along the `N` th dimension of the empty tensor.

    `updates` is a tensor of rank `Q-1+P-N`. Its shape is: :math:`(i_0, i_1, ..., i_{Q-2}, shape_N, ..., shape_{P-1})`.

    The following figure shows the calculation process of inserting two slices in the first dimension of a rank-3
    tensor with two matrices of new values:

    .. image:: api_img/ScatterNd.png

    Inputs:
        - **indices** (Tensor) - The index of scattering in the new tensor with int32 or int64 data type.
          The rank of indices must be at least 2 and `indices_shape[-1] <= len(shape)`.
        - **updates** (Tensor) - The source Tensor to be scattered.
          It has shape `indices_shape[:-1] + shape[indices_shape[-1]:]`.
        - **shape** (tuple[int]) - Define the shape of the output tensor, has the same data type as indices.
          The shape of `shape` is :math:`(x_1, x_2, ..., x_R)`, and the length of 'shape' is greater than or
          equal to 2. In other words, the shape of `shape` is at least :math:`(x_1, x_2)`.
          And the value of any element in `shape` must be greater than or equal to 1.
          In other words, :math:`x_1` >= 1, :math:`x_2` >= 1.

    Outputs:
        Tensor, the new tensor, has the same type as `updates` and the same shape as `shape`.

    Raises:
        TypeError: If `shape` is not a tuple.
        ValueError: If any element of `shape` is less than 1.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> op = ops.ScatterNd()
        >>> indices = Tensor(np.array([[0], [2]]), mindspore.int32)
        >>> updates = Tensor(np.array([[[1, 1, 1, 1], [2, 2, 2, 2],
        ...                             [3, 3, 3, 3], [4, 4, 4, 4]],
        ...                            [[1, 1, 1, 1], [2, 2, 2, 2],
        ...                             [3, 3, 3, 3], [4, 4, 4, 4]]]), mindspore.float32)
        >>> shape = (4, 4, 4)
        >>> output = op(indices, updates, shape)
        >>> print(output)
        [[[1. 1. 1. 1.]
          [2. 2. 2. 2.]
          [3. 3. 3. 3.]
          [4. 4. 4. 4.]]
         [[0. 0. 0. 0.]
          [0. 0. 0. 0.]
          [0. 0. 0. 0.]
          [0. 0. 0. 0.]]
         [[1. 1. 1. 1.]
          [2. 2. 2. 2.]
          [3. 3. 3. 3.]
          [4. 4. 4. 4.]]
         [[0. 0. 0. 0.]
          [0. 0. 0. 0.]
          [0. 0. 0. 0.]
          [0. 0. 0. 0.]]]
        >>> indices = Tensor(np.array([[0, 1], [1, 1]]), mindspore.int32)
        >>> updates = Tensor(np.array([3.2, 1.1]), mindspore.float32)
        >>> shape = (3, 3)
        >>> output = op(indices, updates, shape)
        >>> # In order to facilitate understanding, explain the operator pseudo-operation process step by step:
        >>> # Step 1: Generate an empty Tensor of the specified shape according to the shape
        >>> # [
        >>> #     [0. 0. 0.]
        >>> #     [0. 0. 0.]
        >>> #     [0. 0. 0.]
        >>> # ]
        >>> # Step 2: Modify the data at the specified locations according to the indices
        >>> # 0th row of indices is [0, 1], 0th row of updates is 3.2.
        >>> # This means that the element of the empty tensor in the 0th row and 1st col is set to 3.2
        >>> # [
        >>> #     [0. 3.2. 0.]
        >>> #     [0. 0.   0.]
        >>> #     [0. 0.   0.]
        >>> # ]
        >>> # 1st row of indices is [1, 1], 1st row of updates is 1.1.
        >>> # This means that the element of the empty tensor in the 1st row and 1st col is set to 1.1
        >>> # [
        >>> #     [0. 3.2. 0.]
        >>> #     [0. 1.1  0.]
        >>> #     [0. 0.   0.]
        >>> # ]
        >>> # The final result is as follows:
        >>> print(output)
        [[0. 3.2 0.]
         [0. 1.1 0.]
         [0. 0.  0.]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize ScatterNd"""
        self.init_prim_io_names(inputs=['indices', 'update', 'shape'], outputs=['output'])

    def __infer__(self, indices, update, shape):
        shp = shape['value']
        validator.check_subclass("update_dtype", update['dtype'], mstype.tensor, self.name)
        validator.check_tensor_dtype_valid("indices", indices['dtype'], [mstype.int32, mstype.int64], self.name)
        validator.check_value_type("shape", shp, [tuple], self.name)
        for i, x in enumerate(shp):
            validator.check_positive_int(x, f'shape[{i}]', self.name)

        indices_shape, update_shape = indices["shape"], update["shape"]
        if indices_shape[0] != update_shape[0]:
            raise ValueError(f"For '{self.name}', the first shape of 'indices' must be the same as the first shape of "
                             f"'updates', but got the first shape of 'indices': {indices_shape[0]}, "
                             f"the first shape of 'updates': {update_shape[0]}.")

        return {'shape': shp,
                'dtype': update['dtype'],
                'value': None}
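

# A hedged NumPy reference sketch (illustrative): ScatterNd starts from zeros
# and writes each `updates` row/slice at the position named by the matching
# index vector. The helper name is hypothetical.
def _example_scatter_nd_reference(indices, updates, shape):
    """Illustrative only: a NumPy model of ScatterNd for non-duplicate indices."""
    out = np.zeros(shape, dtype=updates.dtype)
    for idx, value in zip(indices, updates):
        out[tuple(idx)] = value    # each index vector addresses a scalar or slice
    return out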


class ResizeNearestNeighbor(PrimitiveWithInfer):
    r"""
    Resizes the input tensor by using the nearest neighbor algorithm.

    Resizes the input tensor to a given size by using the nearest neighbor algorithm. The nearest
    neighbor algorithm selects the value of the nearest point and does not consider the
    values of neighboring points at all, yielding a piecewise-constant interpolant.

    Args:
        size (Union[tuple, list]): The target size. The dimension of size must be 2.
        align_corners (bool): Whether the centers of the 4 corner pixels of the input
                              and output tensors are aligned. Default: False.

    Inputs:
        - **input_x** (Tensor) - The input tensor. The shape of the tensor is :math:`(N, C, H, W)`.

    Outputs:
        Tensor, the shape of the output tensor is :math:`(N, C, NEW\_H, NEW\_W)`.
        The data type is the same as the `input_x`.

    Raises:
        TypeError: If `size` is neither tuple nor list.
        TypeError: If `align_corners` is not a bool.
        ValueError: If length of `size` is not equal to 2.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_tensor = Tensor(np.array([[[[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]]]), mindspore.float32)
        >>> resize = ops.ResizeNearestNeighbor((2, 2))
        >>> output = resize(input_tensor)
        >>> print(output)
        [[[[-0.1  0.3]
           [ 0.4  0.5]]]]
    """

    @prim_attr_register
    def __init__(self, size, align_corners=False):
        """Initialize ResizeNearestNeighbor"""
        validator.check_value_type("size", size, [tuple, list], self.name)
        validator.check_value_type("align_corners", align_corners, [bool], self.name)
        validator.check_equal_int(len(size), 2, "length of size", self.name)
        for i, value in enumerate(size):
            validator.check_non_negative_int(value, f'{i}th value of size', self.name)
        self.init_prim_io_names(inputs=['image_in'], outputs=['image_out'])

    def infer_shape(self, x_shape):
        validator.check('the dimension of input_x', len(x_shape), '', 4, Rel.EQ, self.name)
        return tuple(x_shape)[:-2] + tuple(self.size)

    def infer_dtype(self, x_dtype):
        validator.check_tensor_dtype_valid("x", x_dtype, mstype.number_type, self.name)
        return x_dtype
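

# A hedged shape sketch (illustrative): only the last two dimensions change,
# from (H, W) to the target `size`; N and C pass through unchanged. Assumes a
# working MindSpore runtime; the helper name is hypothetical.
def _example_resize_shape():
    """Illustrative only: (1, 3, 4, 6) resized with size=(2, 2) -> (1, 3, 2, 2)."""
    x = Tensor(np.zeros((1, 3, 4, 6), dtype=np.float32))
    y = ResizeNearestNeighbor((2, 2))(x)
    return y    # expected shape: (1, 3, 2, 2)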


class GatherNd(PrimitiveWithInfer):
    r"""
    Gathers slices from a tensor by indices.

    Using given indices to gather slices from a tensor with a specified shape.

    `indices` is a K-dimensional integer tensor. Think of it as a (K-1)-dimensional tensor of
    index vectors, where each element defines a slice of `input_x`:

    .. math::
        output[(i_0, ..., i_{K-2})] = input\_x[indices[(i_0, ..., i_{K-2})]]

    The last dimension of `indices` cannot exceed the rank of `input_x`:
    :math:`indices.shape[-1] <= input\_x.rank`.

    Inputs:
        - **input_x** (Tensor) - The target tensor to gather values.
          The shape is :math:`(N,*)` where :math:`*` means, any number of additional dimensions.
        - **indices** (Tensor) - The index tensor, with int32 or int64 data type.
          The dimension of `indices` should be <= the dimension of `input_x`.

    Outputs:
        Tensor, has the same type as `input_x` and the shape is indices_shape[:-1] + x_shape[indices_shape[-1]:].

    Raises:
        ValueError: If length of shape of `input_x` is less than the last dimension of `indices`.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> op = ops.GatherNd()
        >>> input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
        >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
        >>> output = op(input_x, indices)
        >>> print(output)
        [-0.1  0.5]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize GatherNd"""
        self.init_prim_io_names(inputs=['input_x', 'indices'], outputs=['y'])

    def infer_shape(self, x_shape, indices_shape):
        validator.check('the dimension of x', len(x_shape),
                        'the dimension of indices', indices_shape[-1], Rel.GE, self.name)
        return indices_shape[:-1] + x_shape[indices_shape[-1]:]

    def infer_dtype(self, x_dtype, indices_dtype):
        validator.check_tensor_dtype_valid("indices", indices_dtype, mstype.int_type, self.name)
        return x_dtype
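

# A hedged NumPy reference sketch (illustrative): each row of a 2-D `indices`
# is an index vector into `input_x`; the gathered pieces are stacked along the
# leading dimension. The helper name is hypothetical.
def _example_gather_nd_reference(input_x, indices):
    """Illustrative only: a NumPy model of GatherNd for a 2-D `indices`."""
    return np.stack([input_x[tuple(idx)] for idx in indices])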


class TensorScatterUpdate(PrimitiveWithInfer):
    """
    Creates a new tensor by updating the positions in `input_x` indicated by
    `indices`, with values from `update`. This operation is almost equivalent to using
    ScatterNd, except that the updates are applied on `input_x` instead of a zero tensor.

    `indices` must have rank at least 2, the last axis is the depth of each index
    vector. For each index vector, there must be a corresponding value in `update`. If
    the depth of each index tensor matches the rank of `input_x`, then each index
    vector corresponds to a scalar in `input_x` and each update updates a scalar. If
    the depth of each index tensor is less than the rank of `input_x`, then each index
    vector corresponds to a slice in `input_x`, and each update updates a slice.

    The order in which updates are applied is nondeterministic, meaning that if there
    are multiple index vectors in `indices` that correspond to the same position, the
    value of that position in the output will be nondeterministic.

    Inputs:
        - **input_x** (Tensor) - The target tensor. The dimension of input_x must be no less than indices.shape[-1].
          The shape is :math:`(N,*)` where :math:`*` means, any number of additional dimensions.
          The data type is Number.
        - **indices** (Tensor) - The index of input tensor whose data type is int32 or int64.
          The rank must be at least 2.
        - **update** (Tensor) - The tensor to update the input tensor, has the same type as input,
          and update.shape = indices.shape[:-1] + input_x.shape[indices.shape[-1]:].

    Outputs:
        Tensor, has the same shape and type as `input_x`.

    Raises:
        TypeError: If dtype of `indices` is neither int32 nor int64.
        ValueError: If length of shape of `input_x` is less than the last dimension of shape of `indices`.
        ValueError: If the value of `input_x` does not match input `indices`.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
        >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
        >>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32)
        >>> op = ops.TensorScatterUpdate()
        >>> output = op(input_x, indices, update)
        >>> print(output)
        [[ 1.   0.3  3.6]
         [ 0.4  2.2 -3.2]]
    """

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['input_x', 'indices', 'updates'], outputs=['y'])

    def infer_shape(self, input_x_shape, indices_shape, updates_shape):
        if len(indices_shape) < 2:
            raise ValueError(f"For '{self.name}', the dimension of 'indices' cannot be less than 2,"
                             f" but got {len(indices_shape)}.")

        if indices_shape[-1] > len(input_x_shape):
            raise ValueError(f"For '{self.name}', the last dimension of 'indices' must be less than or equal to "
                             f"the dimension of 'input_x', but got the "
                             f"last dimension of 'indices': {indices_shape[-1]} and the dimension of 'input_x': "
                             f"{len(input_x_shape)}.")

        updates_shape_check = indices_shape[:-1] + input_x_shape[indices_shape[-1]:]
        if updates_shape_check != updates_shape:
            raise ValueError(f"For '{self.name}', the shape of 'update' must be equal to updates_shape_check, "
                             f"where updates_shape_check = indices_shape[:-1] + input_x_shape[indices_shape[-1]:] "
                             f"but got the shape of 'update': {updates_shape}, "
                             f"updates_shape_check: {updates_shape_check}, indices_shape: {indices_shape} and "
                             f"input_x_shape: {input_x_shape}. Please check input_x_shape and indices_shape.")

        return input_x_shape

    def infer_dtype(self, input_x_dtype, indices_dtype, updates_dtype):
        validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32, mstype.int64], self.name)
        args = {"input_x": input_x_dtype, "updates": updates_dtype}
        validator.check_tensors_dtypes_same_and_valid(args, (mstype.bool_,) + mstype.number_type, self.name)
        return input_x_dtype


class TensorScatterAdd(PrimitiveWithInfer):
    """
    Creates a new tensor by adding the values from the positions in `input_x` indicated by
    `indices`, with values from `updates`. When multiple values are given for the same
    index, the updated result will be the sum of all values. This operation is almost
    equivalent to using ScatterNdAdd, except that the updates are applied on a `Tensor`
    instead of a `Parameter`.

    The last axis of `indices` is the depth of each index vector. For each index vector,
    there must be a corresponding value in `updates`. The shape of `updates` should be
    equal to the shape of `input_x[indices]`. For more details, see use cases.

    Note:
        If some values of the `indices` are out of bound, instead of raising an index error,
        the corresponding `updates` will not be updated to `input_x`.

    Inputs:
        - **input_x** (Tensor) - The target tensor. The dimension of input_x must be no less than indices.shape[-1].
        - **indices** (Tensor) - The index of input tensor whose data type is int32 or int64.
          The rank must be at least 2.
        - **updates** (Tensor) - The tensor to update the input tensor, has the same type as input,
          and updates.shape should be equal to indices.shape[:-1] + input_x.shape[indices.shape[-1]:].

    Outputs:
        Tensor, has the same shape and type as `input_x`.

    Raises:
        TypeError: If dtype of `indices` is neither int32 nor int64.
        ValueError: If length of shape of `input_x` is less than the last dimension of shape of `indices`.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
        >>> indices = Tensor(np.array([[0, 0], [0, 0]]), mindspore.int32)
        >>> updates = Tensor(np.array([1.0, 2.2]), mindspore.float32)
        >>> # Next, demonstrate the approximate operation process of this operator:
        >>> # 1, indices[0] = [0, 0], indices[1] = [0, 0]
        >>> # 2, And input_x[0, 0] = -0.1
        >>> # 3, So input_x[indices] = [-0.1, -0.1]
        >>> # 4, Satisfy the above formula: input_x[indices].shape=(2) == updates.shape=(2)
        >>> op = ops.TensorScatterAdd()
        >>> # 5, Perform the addition operation for the first time:
        >>> #      first_input_x = input_x[0][0] + updates[0] = [[0.9, 0.3, 3.6], [0.4, 0.5, -3.2]]
        >>> # 6, Perform the addition operation for the second time:
        >>> #      second_input_x = first_input_x[0][0] + updates[1] = [[3.1, 0.3, 3.6], [0.4, 0.5, -3.2]]
        >>> output = op(input_x, indices, updates)
        >>> print(output)
        [[ 3.1  0.3  3.6]
         [ 0.4  0.5 -3.2]]
    """

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['input_x', 'indices', 'updates'], outputs=['y'])

    def infer_shape(self, input_x_shape, indices_shape, updates_shape):
        if len(indices_shape) < 2:
            raise ValueError(f"For '{self.name}', the dimension of 'indices' cannot be less than 2,"
                             f" but got {len(indices_shape)}.")

        if indices_shape[-1] > len(input_x_shape):
            raise ValueError(f"For '{self.name}', the last dimension of 'indices' must be less than or equal to "
                             f"the dimension of 'input_x', but got the "
                             f"last dimension of 'indices': {indices_shape[-1]} and the dimension of 'input_x': "
                             f"{len(input_x_shape)}.")

        updates_shape_check = indices_shape[:-1] + input_x_shape[indices_shape[-1]:]
        if updates_shape_check != updates_shape:
            raise ValueError(f"For '{self.name}', the shape of 'update' must be equal to updates_shape_check, "
                             f"where updates_shape_check = indices_shape[:-1] + input_x_shape[indices_shape[-1]:] "
                             f"but got the shape of 'update': {updates_shape}, "
                             f"updates_shape_check: {updates_shape_check}, indices_shape: {indices_shape} and "
                             f"input_x_shape: {input_x_shape}. Please check input_x_shape and indices_shape.")

        return input_x_shape

    def infer_dtype(self, input_x_dtype, indices_dtype, updates_dtype):
        validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32, mstype.int64], self.name)
        args = {"input_x": input_x_dtype, "updates": updates_dtype}
        validator.check_tensors_dtypes_same_and_valid(args, (mstype.bool_,) + mstype.number_type, self.name)
        return input_x_dtype
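

# A hedged NumPy reference sketch (illustrative): duplicate index vectors
# accumulate, which np.add.at models exactly for full-depth indices with
# scalar updates. The helper name is hypothetical.
def _example_tensor_scatter_add_reference(input_x, indices, updates):
    """Illustrative only: a NumPy model of TensorScatterAdd for scalar updates."""
    out = input_x.copy()
    np.add.at(out, tuple(indices.T), updates)    # duplicates sum: -0.1 + 1.0 + 2.2 = 3.1
    return out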
3935
3936
class ScatterUpdate(_ScatterOpDynamic):
    r"""
    Updates tensor values by using input indices and value.

    Using given values to update tensor value, along with the input indices.

    For each `i, ..., j` in `indices.shape`:

    .. math::

        \text{input_x}[\text{indices}[i, ..., j], :] = \text{updates}[i, ..., j, :]

    Inputs of `input_x` and `updates` comply with the implicit type conversion rules to make the data types consistent.
    If they have different data types, the lower-priority data type will be converted to
    the relatively highest-priority data type.
    A RuntimeError exception will be thrown when data type conversion of Parameter is required.

    Args:
        use_locking (bool): Whether to protect the assignment by a lock. Default: True.

    Inputs:
        - **input_x** (Parameter) - The target tensor, with data type of Parameter.
          The shape is :math:`(N, *)` where :math:`*` means any number of additional dimensions.
        - **indices** (Tensor) - The index of input tensor. With int32 data type.
          If there are duplicates in indices, the order for updating is undefined.
        - **updates** (Tensor) - The tensor to update the input tensor, has the same type as input,
          and updates.shape = indices.shape + input_x.shape[1:].

    Outputs:
        Tensor, has the same shape and type as `input_x`.

    Raises:
        TypeError: If `use_locking` is not a bool.
        TypeError: If `indices` is not an int32.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> np_x = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]])
        >>> input_x = mindspore.Parameter(Tensor(np_x, mindspore.float32), name="x")
        >>> indices = Tensor(np.array([0, 1]), mindspore.int32)
        >>> np_updates = np.array([[2.0, 1.2, 1.0], [3.0, 1.2, 1.0]])
        >>> updates = Tensor(np_updates, mindspore.float32)
        >>> op = ops.ScatterUpdate()
        >>> output = op(input_x, indices, updates)
        >>> print(output)
        [[2.  1.2 1. ]
         [3.  1.2 1. ]]
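        >>> # A NumPy sketch of the same update rule (illustrative only): each
        >>> # entry of 'indices' selects a row of the target to overwrite.
        >>> out = np_x.copy()
        >>> out[np.array([0, 1])] = np_updates
        >>> print(out)
        [[2.  1.2 1. ]
         [3.  1.2 1. ]]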
    """

    @prim_attr_register
    def __init__(self, use_locking=True):
        """Initialize ScatterUpdate"""
        validator.check_value_type('use_locking', use_locking, [bool], self.name)
        self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])
        self.add_prim_attr('side_effect_mem', True)


class ScatterNdUpdate(_ScatterNdOp):
    r"""
    Updates tensor values by using input indices and value.

    Using given values to update tensor value, along with the input indices.

    `input_x` has rank P and `indices` has rank Q where `Q >= 2`.

    `indices` has shape :math:`(i_0, i_1, ..., i_{Q-2}, N)` where `N <= P`.

    The last dimension of `indices` (with length `N` ) indicates slices along the `N` th dimension of `input_x`.

    `updates` is a tensor of rank `Q-1+P-N`. Its shape is:
    :math:`(i_0, i_1, ..., i_{Q-2}, x\_shape_N, ..., x\_shape_{P-1})`.

    Inputs of `input_x` and `updates` comply with the implicit type conversion rules to make the data types consistent.
    If they have different data types, the lower-priority data type will be converted to
    the relatively highest-priority data type.
    A RuntimeError exception will be thrown when data type conversion of Parameter is required.

    Args:
        use_locking (bool): Whether to protect the assignment by a lock. Default: True.

    Inputs:
        - **input_x** (Parameter) - The target tensor, with data type of Parameter.
          The shape is :math:`(N, *)` where :math:`*` means any number of additional dimensions.
        - **indices** (Tensor) - The index of input tensor, with int32 data type.
        - **updates** (Tensor) - The tensor to be updated to the input tensor, has the same type as input.
          The shape is `indices_shape[:-1] + x_shape[indices_shape[-1]:]`.

    Outputs:
        Tensor, has the same shape and type as `input_x`.

    Raises:
        TypeError: If `use_locking` is not a bool.
        TypeError: If `indices` is not an int32.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> np_x = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]])
        >>> input_x = mindspore.Parameter(Tensor(np_x, mindspore.float32), name="x")
        >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
        >>> updates = Tensor(np.array([1.0, 2.2]), mindspore.float32)
        >>> op = ops.ScatterNdUpdate()
        >>> output = op(input_x, indices, updates)
        >>> print(output)
        [[ 1.   0.3  3.6]
         [ 0.4  2.2 -3.2]]
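        >>> # An equivalent NumPy sketch (illustrative only): each row of
        >>> # 'indices' is one coordinate into the target.
        >>> out = np_x.copy()
        >>> out[tuple(np.array([[0, 0], [1, 1]]).T)] = np.array([1.0, 2.2])
        >>> print(out)
        [[ 1.   0.3  3.6]
         [ 0.4  2.2 -3.2]]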
    """

    @prim_attr_register
    def __init__(self, use_locking=True):
        """Initialize ScatterNdUpdate"""
        validator.check_value_type('use_locking', use_locking, [bool], self.name)
        self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y'])
        self.add_prim_attr('side_effect_mem', True)

    def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
        validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name)
        args = {"x": x_dtype, "value": value_dtype}
        validator.check_tensors_dtypes_same_and_valid(args, (mstype.bool_,) + mstype.number_type, self.name)
        return x_dtype


class ScatterMax(_ScatterOp):
    r"""
    Updates the value of the input tensor through the maximum operation.

    Using given values to update tensor value through the max operation, along with the input indices.
    This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.

    For each `i, ..., j` in `indices.shape`:

    .. math::

        \text{input_x}[\text{indices}[i, ..., j], :]
        = max(\text{input_x}[\text{indices}[i, ..., j], :], \text{updates}[i, ..., j, :])

    Inputs of `input_x` and `updates` comply with the implicit type conversion rules to make the data types consistent.
    If they have different data types, the lower-priority data type will be converted to
    the relatively highest-priority data type.
    A RuntimeError exception will be thrown when data type conversion of Parameter is required.

    Args:
        use_locking (bool): Whether to protect the assignment by a lock. Default: False.

    Inputs:
        - **input_x** (Parameter) - The target tensor, with data type of Parameter.
          The shape is :math:`(N, *)` where :math:`*` means any number of additional dimensions.
        - **indices** (Tensor) - The index to do max operation whose data type must be mindspore.int32.
        - **updates** (Tensor) - The tensor that performs the maximum operation with `input_x`,
          the data type is the same as `input_x`, the shape is `indices_shape + x_shape[1:]`.

    Outputs:
        Tensor, the updated `input_x`, has the same shape and type as `input_x`.

    Raises:
        TypeError: If `use_locking` is not a bool.
        TypeError: If `indices` is not an int32.
        ValueError: If the shape of `updates` is not equal to `indices_shape + x_shape[1:]`.

    Supported Platforms:
        ``Ascend`` ``CPU``

    Examples:
        >>> input_x = Parameter(Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), mindspore.float32),
        ...                     name="input_x")
        >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
        >>> updates = Tensor(np.ones([2, 2, 3]) * 88, mindspore.float32)
        >>> scatter_max = ops.ScatterMax()
        >>> output = scatter_max(input_x, indices, updates)
        >>> print(output)
        [[88. 88. 88.]
         [88. 88. 88.]]
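        >>> # A NumPy sketch of the same rule (illustrative only): np.maximum.at
        >>> # applies an elementwise max at the gathered rows, even for
        >>> # duplicate indices.
        >>> out = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
        >>> np.maximum.at(out, np.array([[0, 0], [1, 1]]), np.ones([2, 2, 3]) * 88)
        >>> print(out)
        [[88. 88. 88.]
         [88. 88. 88.]]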
    """


class ScatterMin(_ScatterOp):
    r"""
    Updates the value of the input tensor through the minimum operation.

    Using given values to update tensor value through the min operation, along with the input indices.
    This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.

    For each `i, ..., j` in `indices.shape`:

    .. math::

        \text{input_x}[\text{indices}[i, ..., j], :]
        = min(\text{input_x}[\text{indices}[i, ..., j], :], \text{updates}[i, ..., j, :])

    Inputs of `input_x` and `updates` comply with the implicit type conversion rules to make the data types consistent.
    If they have different data types, the lower-priority data type will be converted to
    the relatively highest-priority data type.
    A RuntimeError exception will be thrown when data type conversion of Parameter is required.

    Args:
        use_locking (bool): Whether to protect the assignment by a lock. Default: False.

    Inputs:
        - **input_x** (Parameter) - The target tensor, with data type of Parameter.
          The shape is :math:`(N, *)` where :math:`*` means any number of additional dimensions.
        - **indices** (Tensor) - The index to do min operation whose data type must be mindspore.int32.
        - **updates** (Tensor) - The tensor doing the min operation with `input_x`,
          the data type is the same as `input_x`, the shape is `indices_shape + x_shape[1:]`.

    Outputs:
        Tensor, the updated `input_x`, has the same shape and type as `input_x`.

    Raises:
        TypeError: If `use_locking` is not a bool.
        TypeError: If `indices` is not an int32.
        ValueError: If the shape of `updates` is not equal to `indices_shape + x_shape[1:]`.

    Supported Platforms:
        ``Ascend`` ``CPU``

    Examples:
        >>> input_x = Parameter(Tensor(np.array([[0.0, 1.0, 2.0], [0.0, 0.0, 0.0]]), mindspore.float32),
        ...                     name="input_x")
        >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
        >>> update = Tensor(np.ones([2, 2, 3]), mindspore.float32)
        >>> scatter_min = ops.ScatterMin()
        >>> output = scatter_min(input_x, indices, update)
        >>> print(output)
        [[0. 1. 1.]
         [0. 0. 0.]]
    """


class ScatterAdd(_ScatterOpDynamic):
    r"""
    Updates the value of the input tensor through the addition operation.

    Using given values to update tensor value through the add operation, along with the input indices.
    This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.

    For each `i, ..., j` in `indices.shape`:

    .. math::

        \text{input_x}[\text{indices}[i, ..., j], :] \mathrel{+}= \text{updates}[i, ..., j, :]

    Inputs of `input_x` and `updates` comply with the implicit type conversion rules to make the data types consistent.
    If they have different data types, the lower-priority data type will be converted to
    the relatively highest-priority data type.
    A RuntimeError exception will be thrown when data type conversion of Parameter is required.

    Note:
        This is an in-place update operator. Therefore, the `input_x` will be updated after the operation is completed.

    Args:
        use_locking (bool): Whether to protect the assignment by a lock. Default: False.

    Inputs:
        - **input_x** (Parameter) - The target tensor, with data type of Parameter.
          The shape is :math:`(N, *)` where :math:`*` means any number of additional dimensions.
        - **indices** (Tensor) - The index to do add operation whose data type must be mindspore.int32.
        - **updates** (Tensor) - The tensor doing the add operation with `input_x`,
          the data type is the same as `input_x`, the shape is `indices_shape + x_shape[1:]`.

    Outputs:
        Tensor, the updated `input_x`, has the same shape and type as `input_x`.

    Raises:
        TypeError: If `use_locking` is not a bool.
        TypeError: If `indices` is not an int32.
        ValueError: If the shape of `updates` is not equal to `indices_shape + x_shape[1:]`.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_x = Parameter(Tensor(np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]), mindspore.float32), name="x")
        >>> indices = Tensor(np.array([[0, 1], [1, 1]]), mindspore.int32)
        >>> updates = Tensor(np.ones([2, 2, 3]), mindspore.float32)
        >>> scatter_add = ops.ScatterAdd()
        >>> output = scatter_add(input_x, indices, updates)
        >>> print(output)
        [[1. 1. 1.]
         [3. 3. 3.]]
        >>> # because input_x will be updated after the operation is completed, it needs to be re-initialized.
        >>> input_x = Parameter(Tensor(np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]), mindspore.float32), name="x")
        >>> # for indices = [[0, 1], [1, 1]]
        >>> # step 1: [0, 1]
        >>> # input_x[0] = [0.0, 0.0, 0.0] + [1.0, 1.0, 1.0] = [1.0, 1.0, 1.0]
        >>> # input_x[1] = [0.0, 0.0, 0.0] + [3.0, 3.0, 3.0] = [3.0, 3.0, 3.0]
        >>> # step 2: [1, 1]
        >>> # input_x[1] = [3.0, 3.0, 3.0] + [7.0, 7.0, 7.0] = [10.0, 10.0, 10.0]
        >>> # input_x[1] = [10.0, 10.0, 10.0] + [9.0, 9.0, 9.0] = [19.0, 19.0, 19.0]
        >>> indices = Tensor(np.array([[0, 1], [1, 1]]), mindspore.int32)
        >>> updates = Tensor(np.array([[[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]],
        ...                            [[7.0, 7.0, 7.0], [9.0, 9.0, 9.0]]]), mindspore.float32)
        >>> scatter_add = ops.ScatterAdd()
        >>> output = scatter_add(input_x, indices, updates)
        >>> print(output)
        [[ 1.  1.  1.]
         [19. 19. 19.]]
        >>> # because input_x will be updated after the operation is completed, it needs to be re-initialized.
        >>> input_x = Parameter(Tensor(np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]), mindspore.float32), name="x")
        >>> # for indices = [[1, 0], [1, 1]]
        >>> # step 1: [1, 0]
        >>> # input_x[0] = [0.0, 0.0, 0.0] + [3.0, 3.0, 3.0] = [3.0, 3.0, 3.0]
        >>> # input_x[1] = [0.0, 0.0, 0.0] + [1.0, 1.0, 1.0] = [1.0, 1.0, 1.0]
        >>> # step 2: [1, 1]
        >>> # input_x[1] = [1.0, 1.0, 1.0] + [7.0, 7.0, 7.0] = [8.0, 8.0, 8.0]
        >>> # input_x[1] = [8.0, 8.0, 8.0] + [9.0, 9.0, 9.0] = [17.0, 17.0, 17.0]
        >>> indices = Tensor(np.array([[1, 0], [1, 1]]), mindspore.int32)
        >>> updates = Tensor(np.array([[[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]],
        ...                            [[7.0, 7.0, 7.0], [9.0, 9.0, 9.0]]]), mindspore.float32)
        >>> scatter_add = ops.ScatterAdd()
        >>> output = scatter_add(input_x, indices, updates)
        >>> print(output)
        [[ 3.  3.  3.]
         [17. 17. 17.]]
        >>> # because input_x will be updated after the operation is completed, it needs to be re-initialized.
        >>> input_x = Parameter(Tensor(np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]), mindspore.float32), name="x")
        >>> # for indices = [[0, 1], [0, 1]]
        >>> # step 1: [0, 1]
        >>> # input_x[0] = [0.0, 0.0, 0.0] + [1.0, 1.0, 1.0] = [1.0, 1.0, 1.0]
        >>> # input_x[1] = [0.0, 0.0, 0.0] + [3.0, 3.0, 3.0] = [3.0, 3.0, 3.0]
        >>> # step 2: [0, 1]
        >>> # input_x[0] = [1.0, 1.0, 1.0] + [7.0, 7.0, 7.0] = [8.0, 8.0, 8.0]
        >>> # input_x[1] = [3.0, 3.0, 3.0] + [9.0, 9.0, 9.0] = [12.0, 12.0, 12.0]
        >>> indices = Tensor(np.array([[0, 1], [0, 1]]), mindspore.int32)
        >>> updates = Tensor(np.array([[[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]],
        ...                            [[7.0, 7.0, 7.0], [9.0, 9.0, 9.0]]]), mindspore.float32)
        >>> scatter_add = ops.ScatterAdd()
        >>> output = scatter_add(input_x, indices, updates)
        >>> print(output)
        [[ 8.  8.  8.]
         [12. 12. 12.]]
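        >>> # A NumPy sketch of the accumulation rule (illustrative only):
        >>> # np.add.at sums contributions even when indices repeat, matching
        >>> # the step-by-step traces above.
        >>> out = np.zeros((2, 3))
        >>> np.add.at(out, np.array([[0, 1], [0, 1]]),
        ...           np.array([[[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]],
        ...                     [[7.0, 7.0, 7.0], [9.0, 9.0, 9.0]]]))
        >>> print(out)
        [[ 8.  8.  8.]
         [12. 12. 12.]]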
    """

    @prim_attr_register
    def __init__(self, use_locking=False):
        """Initialize ScatterAdd"""
        validator.check_value_type('use_locking', use_locking, [bool], self.name)
        self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])
        self.add_prim_attr('side_effect_mem', True)


class ScatterSub(_ScatterOpDynamic):
    r"""
    Updates the value of the input tensor through the subtraction operation.

    Using given values to update tensor value through the subtraction operation, along with the input indices.
    This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.

    For each `i, ..., j` in `indices.shape`:

    .. math::

        \text{input_x}[\text{indices}[i, ..., j], :] \mathrel{-}= \text{updates}[i, ..., j, :]

    Inputs of `input_x` and `updates` comply with the implicit type conversion rules to make the data types consistent.
    If they have different data types, the lower-priority data type will be converted to
    the relatively highest-priority data type.
    A RuntimeError exception will be thrown when data type conversion of Parameter is required.

    Args:
        use_locking (bool): Whether to protect the assignment by a lock. Default: False.

    Inputs:
        - **input_x** (Parameter) - The target tensor, with data type of Parameter.
          The shape is :math:`(N, *)` where :math:`*` means any number of additional dimensions.
        - **indices** (Tensor) - The index to do subtraction operation whose data type must be mindspore.int32.
        - **updates** (Tensor) - The tensor doing the subtraction operation with `input_x`,
          the data type is the same as `input_x`, the shape is `indices_shape + x_shape[1:]`.

    Outputs:
        Tensor, the updated `input_x`, has the same shape and type as `input_x`.

    Raises:
        TypeError: If `use_locking` is not a bool.
        TypeError: If `indices` is not an int32.
        ValueError: If the shape of `updates` is not equal to `indices_shape + x_shape[1:]`.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_x = Parameter(Tensor(np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]), mindspore.float32), name="x")
        >>> indices = Tensor(np.array([[0, 1]]), mindspore.int32)
        >>> updates = Tensor(np.array([[[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]]), mindspore.float32)
        >>> scatter_sub = ops.ScatterSub()
        >>> output = scatter_sub(input_x, indices, updates)
        >>> print(output)
        [[-1. -1. -1.]
         [-1. -1. -1.]]
        >>> # because input_x will be updated after the operation is completed, it needs to be re-initialized.
        >>> input_x = Parameter(Tensor(np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]), mindspore.float32), name="x")
        >>> # for indices = [[0, 1], [1, 1]]
        >>> # step 1: [0, 1]
        >>> # input_x[0] = [0.0, 0.0, 0.0] - [1.0, 1.0, 1.0] = [-1.0, -1.0, -1.0]
        >>> # input_x[1] = [0.0, 0.0, 0.0] - [3.0, 3.0, 3.0] = [-3.0, -3.0, -3.0]
        >>> # step 2: [1, 1]
        >>> # input_x[1] = [-3.0, -3.0, -3.0] - [7.0, 7.0, 7.0] = [-10.0, -10.0, -10.0]
        >>> # input_x[1] = [-10.0, -10.0, -10.0] - [9.0, 9.0, 9.0] = [-19.0, -19.0, -19.0]
        >>> indices = Tensor(np.array([[0, 1], [1, 1]]), mindspore.int32)
        >>> updates = Tensor(np.array([[[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]],
        ...                            [[7.0, 7.0, 7.0], [9.0, 9.0, 9.0]]]), mindspore.float32)
        >>> scatter_sub = ops.ScatterSub()
        >>> output = scatter_sub(input_x, indices, updates)
        >>> print(output)
        [[ -1.  -1.  -1.]
         [-19. -19. -19.]]
        >>> # because input_x will be updated after the operation is completed, it needs to be re-initialized.
        >>> input_x = Parameter(Tensor(np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]), mindspore.float32), name="x")
        >>> # for indices = [[1, 0], [1, 1]]
        >>> # step 1: [1, 0]
        >>> # input_x[0] = [0.0, 0.0, 0.0] - [3.0, 3.0, 3.0] = [-3.0, -3.0, -3.0]
        >>> # input_x[1] = [0.0, 0.0, 0.0] - [1.0, 1.0, 1.0] = [-1.0, -1.0, -1.0]
        >>> # step 2: [1, 1]
        >>> # input_x[1] = [-1.0, -1.0, -1.0] - [7.0, 7.0, 7.0] = [-8.0, -8.0, -8.0]
        >>> # input_x[1] = [-8.0, -8.0, -8.0] - [9.0, 9.0, 9.0] = [-17.0, -17.0, -17.0]
        >>> indices = Tensor(np.array([[1, 0], [1, 1]]), mindspore.int32)
        >>> updates = Tensor(np.array([[[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]],
        ...                            [[7.0, 7.0, 7.0], [9.0, 9.0, 9.0]]]), mindspore.float32)
        >>> scatter_sub = ops.ScatterSub()
        >>> output = scatter_sub(input_x, indices, updates)
        >>> print(output)
        [[ -3.  -3.  -3.]
         [-17. -17. -17.]]
        >>> # because input_x will be updated after the operation is completed, it needs to be re-initialized.
        >>> input_x = Parameter(Tensor(np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]), mindspore.float32), name="x")
        >>> # for indices = [[0, 1], [0, 1]]
        >>> # step 1: [0, 1]
        >>> # input_x[0] = [0.0, 0.0, 0.0] - [1.0, 1.0, 1.0] = [-1.0, -1.0, -1.0]
        >>> # input_x[1] = [0.0, 0.0, 0.0] - [3.0, 3.0, 3.0] = [-3.0, -3.0, -3.0]
        >>> # step 2: [0, 1]
        >>> # input_x[0] = [-1.0, -1.0, -1.0] - [7.0, 7.0, 7.0] = [-8.0, -8.0, -8.0]
        >>> # input_x[1] = [-3.0, -3.0, -3.0] - [9.0, 9.0, 9.0] = [-12.0, -12.0, -12.0]
        >>> indices = Tensor(np.array([[0, 1], [0, 1]]), mindspore.int32)
        >>> updates = Tensor(np.array([[[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]],
        ...                            [[7.0, 7.0, 7.0], [9.0, 9.0, 9.0]]]), mindspore.float32)
        >>> scatter_sub = ops.ScatterSub()
        >>> output = scatter_sub(input_x, indices, updates)
        >>> print(output)
        [[ -8.  -8.  -8.]
         [-12. -12. -12.]]
    """

    @prim_attr_register
    def __init__(self, use_locking=False):
        """Initialize ScatterSub"""
        validator.check_value_type('use_locking', use_locking, [bool], self.name)
        self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])
        self.add_prim_attr('side_effect_mem', True)


class ScatterMul(_ScatterOp):
    r"""
    Updates the value of the input tensor through the multiply operation.

    Using given values to update tensor value through the mul operation, along with the input indices.
    This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.

    For each `i, ..., j` in `indices.shape`:

    .. math::

        \text{input_x}[\text{indices}[i, ..., j], :] \mathrel{*}= \text{updates}[i, ..., j, :]

    Inputs of `input_x` and `updates` comply with the implicit type conversion rules to make the data types consistent.
    If they have different data types, the lower-priority data type will be converted to
    the relatively highest-priority data type.
    A RuntimeError exception will be thrown when data type conversion of Parameter is required.

    Args:
        use_locking (bool): Whether to protect the assignment by a lock. Default: False.

    Inputs:
        - **input_x** (Parameter) - The target tensor, with data type of Parameter.
          The shape is :math:`(N, *)` where :math:`*` means any number of additional dimensions.
        - **indices** (Tensor) - The index to do mul operation whose data type must be mindspore.int32.
        - **updates** (Tensor) - The tensor doing the mul operation with `input_x`,
          the data type is the same as `input_x`, the shape is `indices_shape + x_shape[1:]`.

    Outputs:
        Tensor, the updated `input_x`, has the same shape and type as `input_x`.

    Raises:
        TypeError: If `use_locking` is not a bool.
        TypeError: If `indices` is not an int32.
        ValueError: If the shape of `updates` is not equal to `indices_shape + x_shape[1:]`.

    Supported Platforms:
        ``Ascend`` ``CPU``

    Examples:
        >>> input_x = Parameter(Tensor(np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]), mindspore.float32), name="x")
        >>> indices = Tensor(np.array([0, 1]), mindspore.int32)
        >>> updates = Tensor(np.array([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]), mindspore.float32)
        >>> scatter_mul = ops.ScatterMul()
        >>> output = scatter_mul(input_x, indices, updates)
        >>> print(output)
        [[2. 2. 2.]
         [4. 4. 4.]]
        >>> # because input_x will be updated after the operation is completed, it needs to be re-initialized.
        >>> input_x = Parameter(Tensor(np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]), mindspore.float32), name="x")
        >>> # for indices = [[0, 1], [1, 1]]
        >>> # step 1: [0, 1]
        >>> # input_x[0] = [1.0, 1.0, 1.0] * [1.0, 1.0, 1.0] = [1.0, 1.0, 1.0]
        >>> # input_x[1] = [2.0, 2.0, 2.0] * [3.0, 3.0, 3.0] = [6.0, 6.0, 6.0]
        >>> # step 2: [1, 1]
        >>> # input_x[1] = [6.0, 6.0, 6.0] * [7.0, 7.0, 7.0] = [42.0, 42.0, 42.0]
        >>> # input_x[1] = [42.0, 42.0, 42.0] * [9.0, 9.0, 9.0] = [378.0, 378.0, 378.0]
        >>> indices = Tensor(np.array([[0, 1], [1, 1]]), mindspore.int32)
        >>> updates = Tensor(np.array([[[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]],
        ...                            [[7.0, 7.0, 7.0], [9.0, 9.0, 9.0]]]), mindspore.float32)
        >>> scatter_mul = ops.ScatterMul()
        >>> output = scatter_mul(input_x, indices, updates)
        >>> print(output)
        [[  1.   1.   1.]
         [378. 378. 378.]]
        >>> # because input_x will be updated after the operation is completed, it needs to be re-initialized.
        >>> input_x = Parameter(Tensor(np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]), mindspore.float32), name="x")
        >>> # for indices = [[1, 0], [1, 1]]
        >>> # step 1: [1, 0]
        >>> # input_x[0] = [1.0, 1.0, 1.0] * [3.0, 3.0, 3.0] = [3.0, 3.0, 3.0]
        >>> # input_x[1] = [2.0, 2.0, 2.0] * [1.0, 1.0, 1.0] = [2.0, 2.0, 2.0]
        >>> # step 2: [1, 1]
        >>> # input_x[1] = [2.0, 2.0, 2.0] * [7.0, 7.0, 7.0] = [14.0, 14.0, 14.0]
        >>> # input_x[1] = [14.0, 14.0, 14.0] * [9.0, 9.0, 9.0] = [126.0, 126.0, 126.0]
        >>> indices = Tensor(np.array([[1, 0], [1, 1]]), mindspore.int32)
        >>> updates = Tensor(np.array([[[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]],
        ...                            [[7.0, 7.0, 7.0], [9.0, 9.0, 9.0]]]), mindspore.float32)
        >>> scatter_mul = ops.ScatterMul()
        >>> output = scatter_mul(input_x, indices, updates)
        >>> print(output)
        [[  3.   3.   3.]
         [126. 126. 126.]]
        >>> # because input_x will be updated after the operation is completed, it needs to be re-initialized.
        >>> input_x = Parameter(Tensor(np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]), mindspore.float32), name="x")
        >>> # for indices = [[0, 1], [0, 1]]
        >>> # step 1: [0, 1]
        >>> # input_x[0] = [1.0, 1.0, 1.0] * [1.0, 1.0, 1.0] = [1.0, 1.0, 1.0]
        >>> # input_x[1] = [2.0, 2.0, 2.0] * [3.0, 3.0, 3.0] = [6.0, 6.0, 6.0]
        >>> # step 2: [0, 1]
        >>> # input_x[0] = [1.0, 1.0, 1.0] * [7.0, 7.0, 7.0] = [7.0, 7.0, 7.0]
        >>> # input_x[1] = [6.0, 6.0, 6.0] * [9.0, 9.0, 9.0] = [54.0, 54.0, 54.0]
        >>> indices = Tensor(np.array([[0, 1], [0, 1]]), mindspore.int32)
        >>> updates = Tensor(np.array([[[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]],
        ...                            [[7.0, 7.0, 7.0], [9.0, 9.0, 9.0]]]), mindspore.float32)
        >>> scatter_mul = ops.ScatterMul()
        >>> output = scatter_mul(input_x, indices, updates)
        >>> print(output)
        [[ 7.  7.  7.]
         [54. 54. 54.]]
    """


class ScatterDiv(_ScatterOp):
    r"""
    Updates the value of the input tensor through the divide operation.

    Using given values to update tensor value through the div operation, along with the input indices.
    This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.

    For each `i, ..., j` in `indices.shape`:

    .. math::

        \text{input_x}[\text{indices}[i, ..., j], :] \mathrel{/}= \text{updates}[i, ..., j, :]

    Inputs of `input_x` and `updates` comply with the implicit type conversion rules to make the data types consistent.
    If they have different data types, the lower-priority data type will be converted to
    the relatively highest-priority data type.
    A RuntimeError exception will be thrown when data type conversion of Parameter is required.

    Args:
        use_locking (bool): Whether to protect the assignment by a lock. Default: False.

    Inputs:
        - **input_x** (Parameter) - The target tensor, with data type of Parameter.
          The shape is :math:`(N, *)` where :math:`*` means any number of additional dimensions.
        - **indices** (Tensor) - The index to do div operation whose data type must be mindspore.int32.
        - **updates** (Tensor) - The tensor doing the div operation with `input_x`,
          the data type is the same as `input_x`, the shape is `indices_shape + x_shape[1:]`.

    Outputs:
        Tensor, the updated `input_x`, has the same shape and type as `input_x`.

    Raises:
        TypeError: If `use_locking` is not a bool.
        TypeError: If `indices` is not an int32.
        ValueError: If the shape of `updates` is not equal to `indices_shape + x_shape[1:]`.

    Supported Platforms:
        ``Ascend`` ``CPU``

    Examples:
        >>> input_x = Parameter(Tensor(np.array([[6.0, 6.0, 6.0], [2.0, 2.0, 2.0]]), mindspore.float32), name="x")
        >>> indices = Tensor(np.array([0, 1]), mindspore.int32)
        >>> updates = Tensor(np.array([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]), mindspore.float32)
        >>> scatter_div = ops.ScatterDiv()
        >>> output = scatter_div(input_x, indices, updates)
        >>> print(output)
        [[3. 3. 3.]
         [1. 1. 1.]]
        >>> # because input_x will be updated after the operation is completed, it needs to be re-initialized.
        >>> input_x = Parameter(Tensor(np.array([[105.0, 105.0, 105.0],
        ...                                      [315.0, 315.0, 315.0]]), mindspore.float32), name="x")
        >>> # for indices = [[0, 1], [1, 1]]
        >>> # step 1: [0, 1]
        >>> # input_x[0] = [105.0, 105.0, 105.0] / [1.0, 1.0, 1.0] = [105.0, 105.0, 105.0]
        >>> # input_x[1] = [315.0, 315.0, 315.0] / [3.0, 3.0, 3.0] = [105.0, 105.0, 105.0]
        >>> # step 2: [1, 1]
        >>> # input_x[1] = [105.0, 105.0, 105.0] / [5.0, 5.0, 5.0] = [21.0, 21.0, 21.0]
        >>> # input_x[1] = [21.0, 21.0, 21.0] / [7.0, 7.0, 7.0] = [3.0, 3.0, 3.0]
        >>> indices = Tensor(np.array([[0, 1], [1, 1]]), mindspore.int32)
        >>> updates = Tensor(np.array([[[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]],
        ...                            [[5.0, 5.0, 5.0], [7.0, 7.0, 7.0]]]), mindspore.float32)
        >>> scatter_div = ops.ScatterDiv()
        >>> output = scatter_div(input_x, indices, updates)
        >>> print(output)
        [[105. 105. 105.]
         [  3.   3.   3.]]
        >>> # because input_x will be updated after the operation is completed, it needs to be re-initialized.
        >>> input_x = Parameter(Tensor(np.array([[105.0, 105.0, 105.0],
        ...                                      [315.0, 315.0, 315.0]]), mindspore.float32), name="x")
        >>> # for indices = [[1, 0], [1, 1]]
        >>> # step 1: [1, 0]
        >>> # input_x[0] = [105.0, 105.0, 105.0] / [3.0, 3.0, 3.0] = [35.0, 35.0, 35.0]
        >>> # input_x[1] = [315.0, 315.0, 315.0] / [1.0, 1.0, 1.0] = [315.0, 315.0, 315.0]
        >>> # step 2: [1, 1]
        >>> # input_x[1] = [315.0, 315.0, 315.0] / [5.0, 5.0, 5.0] = [63.0, 63.0, 63.0]
        >>> # input_x[1] = [63.0, 63.0, 63.0] / [7.0, 7.0, 7.0] = [9.0, 9.0, 9.0]
        >>> indices = Tensor(np.array([[1, 0], [1, 1]]), mindspore.int32)
        >>> updates = Tensor(np.array([[[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]],
        ...                            [[5.0, 5.0, 5.0], [7.0, 7.0, 7.0]]]), mindspore.float32)
        >>> scatter_div = ops.ScatterDiv()
        >>> output = scatter_div(input_x, indices, updates)
        >>> print(output)
        [[35. 35. 35.]
         [ 9.  9.  9.]]
        >>> # because input_x will be updated after the operation is completed, it needs to be re-initialized.
        >>> input_x = Parameter(Tensor(np.array([[105.0, 105.0, 105.0],
        ...                                      [315.0, 315.0, 315.0]]), mindspore.float32), name="x")
        >>> # for indices = [[0, 1], [0, 1]]
        >>> # step 1: [0, 1]
        >>> # input_x[0] = [105.0, 105.0, 105.0] / [1.0, 1.0, 1.0] = [105.0, 105.0, 105.0]
        >>> # input_x[1] = [315.0, 315.0, 315.0] / [3.0, 3.0, 3.0] = [105.0, 105.0, 105.0]
        >>> # step 2: [0, 1]
        >>> # input_x[0] = [105.0, 105.0, 105.0] / [5.0, 5.0, 5.0] = [21.0, 21.0, 21.0]
        >>> # input_x[1] = [105.0, 105.0, 105.0] / [7.0, 7.0, 7.0] = [15.0, 15.0, 15.0]
        >>> indices = Tensor(np.array([[0, 1], [0, 1]]), mindspore.int32)
        >>> updates = Tensor(np.array([[[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]],
        ...                            [[5.0, 5.0, 5.0], [7.0, 7.0, 7.0]]]), mindspore.float32)
        >>> scatter_div = ops.ScatterDiv()
        >>> output = scatter_div(input_x, indices, updates)
        >>> print(output)
        [[21. 21. 21.]
         [15. 15. 15.]]
    """


class ScatterNdAdd(_ScatterNdOp):
    r"""
    Applies sparse addition to individual values or slices in a tensor.

    Using given values to update tensor value through the add operation, along with the input indices.
    This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.

    `input_x` has rank P and `indices` has rank Q where `Q >= 2`.

    `indices` has shape :math:`(i_0, i_1, ..., i_{Q-2}, N)` where `N <= P`.

    The last dimension of `indices` (with length `N` ) indicates slices along the `N` th dimension of `input_x`.

    `updates` is a tensor of rank `Q-1+P-N`. Its shape is:
    :math:`(i_0, i_1, ..., i_{Q-2}, x\_shape_N, ..., x\_shape_{P-1})`.

    Inputs of `input_x` and `updates` comply with the implicit type conversion rules to make the data types consistent.
    If they have different data types, the lower-priority data type will be converted to
    the relatively highest-priority data type.
    A RuntimeError exception will be thrown when data type conversion of Parameter is required.

    Args:
        use_locking (bool): Whether to protect the assignment by a lock. Default: False.

    Inputs:
        - **input_x** (Parameter) - The target tensor, with data type of Parameter.
          The shape is :math:`(N, *)` where :math:`*` means any number of additional dimensions.
        - **indices** (Tensor) - The index to do add operation whose data type must be mindspore.int32.
          The rank of indices must be at least 2 and `indices_shape[-1] <= len(x_shape)`.
        - **updates** (Tensor) - The tensor doing the add operation with `input_x`,
          the data type is the same as `input_x`, the shape is `indices_shape[:-1] + x_shape[indices_shape[-1]:]`.

    Outputs:
        Tensor, the updated `input_x`, has the same shape and type as `input_x`.

    Raises:
        TypeError: If `use_locking` is not a bool.
        TypeError: If `indices` is not an int32.
        ValueError: If the shape of `updates` is not equal to `indices_shape[:-1] + x_shape[indices_shape[-1]:]`.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> input_x = Parameter(Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8]), mindspore.float32), name="x")
        >>> indices = Tensor(np.array([[2], [4], [1], [7]]), mindspore.int32)
        >>> updates = Tensor(np.array([6, 7, 8, 9]), mindspore.float32)
        >>> scatter_nd_add = ops.ScatterNdAdd()
        >>> output = scatter_nd_add(input_x, indices, updates)
        >>> print(output)
        [ 1. 10.  9.  4. 12.  6.  7. 17.]
        >>> input_x = Parameter(Tensor(np.zeros((4, 4, 4)), mindspore.int32))
        >>> indices = Tensor(np.array([[0], [2]]), mindspore.int32)
        >>> updates = Tensor(np.array([[[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]],
        ...                            [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]]]), mindspore.int32)
        >>> scatter_nd_add = ops.ScatterNdAdd()
        >>> output = scatter_nd_add(input_x, indices, updates)
        >>> print(output)
        [[[1 1 1 1]
          [2 2 2 2]
          [3 3 3 3]
          [4 4 4 4]]
         [[0 0 0 0]
          [0 0 0 0]
          [0 0 0 0]
          [0 0 0 0]]
         [[5 5 5 5]
          [6 6 6 6]
          [7 7 7 7]
          [8 8 8 8]]
         [[0 0 0 0]
          [0 0 0 0]
          [0 0 0 0]
          [0 0 0 0]]]
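        >>> # A NumPy sketch of the 1-D case above (illustrative only): the
        >>> # (4, 1) index tensor is flattened to plain row indices, and
        >>> # np.add.at accumulates the updates.
        >>> out = np.array([1, 2, 3, 4, 5, 6, 7, 8], dtype=np.float32)
        >>> np.add.at(out, np.array([2, 4, 1, 7]), np.array([6, 7, 8, 9]))
        >>> print(out)
        [ 1. 10.  9.  4. 12.  6.  7. 17.]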
    """


class ScatterNdSub(_ScatterNdOp):
    r"""
    Applies sparse subtraction to individual values or slices in a tensor.

    Using given values to update tensor value through the subtraction operation, along with the input indices.
    This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.

    `input_x` has rank P and `indices` has rank Q where `Q >= 2`.

    `indices` has shape :math:`(i_0, i_1, ..., i_{Q-2}, N)` where `N <= P`.

    The last dimension of `indices` (with length `N` ) indicates slices along the `N` th dimension of `input_x`.

    `updates` is a tensor of rank `Q-1+P-N`. Its shape is:
    :math:`(i_0, i_1, ..., i_{Q-2}, x\_shape_N, ..., x\_shape_{P-1})`.

    Inputs of `input_x` and `updates` comply with the implicit type conversion rules to make the data types consistent.
    If they have different data types, the lower-priority data type will be converted to
    the relatively highest-priority data type.
    A RuntimeError exception will be thrown when data type conversion of Parameter is required.

    Args:
        use_locking (bool): Whether to protect the assignment by a lock. Default: False.

    Inputs:
        - **input_x** (Parameter) - The target tensor, with data type of Parameter.
          The shape is :math:`(N, *)` where :math:`*` means any number of additional dimensions.
        - **indices** (Tensor) - The index of input tensor, with int32 data type.
          The rank of indices must be at least 2 and `indices_shape[-1] <= len(x_shape)`.
        - **updates** (Tensor) - The tensor to be updated to the input tensor, has the same type as input.
          The shape is `indices_shape[:-1] + x_shape[indices_shape[-1]:]`.

    Outputs:
        Tensor, has the same shape and type as `input_x`.

    Raises:
        TypeError: If `use_locking` is not a bool.
        TypeError: If `indices` is not an int32.
        ValueError: If the shape of `updates` is not equal to `indices_shape[:-1] + x_shape[indices_shape[-1]:]`.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> input_x = Parameter(Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8]), mindspore.float32), name="x")
        >>> indices = Tensor(np.array([[2], [4], [1], [7]]), mindspore.int32)
        >>> updates = Tensor(np.array([6, 7, 8, 9]), mindspore.float32)
        >>> scatter_nd_sub = ops.ScatterNdSub()
        >>> output = scatter_nd_sub(input_x, indices, updates)
        >>> print(output)
        [ 1. -6. -3.  4. -2.  6.  7. -1.]
        >>> input_x = Parameter(Tensor(np.zeros((4, 4, 4)), mindspore.int32))
        >>> indices = Tensor(np.array([[0], [2]]), mindspore.int32)
        >>> updates = Tensor(np.array([[[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]],
        ...                            [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]]]), mindspore.int32)
        >>> scatter_nd_sub = ops.ScatterNdSub()
        >>> output = scatter_nd_sub(input_x, indices, updates)
        >>> print(output)
        [[[-1 -1 -1 -1]
          [-2 -2 -2 -2]
          [-3 -3 -3 -3]
          [-4 -4 -4 -4]]
         [[ 0  0  0  0]
          [ 0  0  0  0]
          [ 0  0  0  0]
          [ 0  0  0  0]]
         [[-5 -5 -5 -5]
          [-6 -6 -6 -6]
          [-7 -7 -7 -7]
          [-8 -8 -8 -8]]
         [[ 0  0  0  0]
          [ 0  0  0  0]
          [ 0  0  0  0]
          [ 0  0  0  0]]]
    """


class ScatterNonAliasingAdd(_ScatterNdOp):
    """
    Applies sparse addition to the input using individual values or slices.

    Using given values to update tensor value through the add operation, along with the input indices.
    This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.

    Inputs of `input_x` and `updates` comply with the implicit type conversion rules to make the data types consistent.
    If they have different data types, the lower-priority data type will be converted to
    the relatively highest-priority data type.
    A RuntimeError exception will be thrown when data type conversion of Parameter is required.

    Inputs:
        - **input_x** (Parameter) - The target parameter. The data type must be float16, float32 or int32.
        - **indices** (Tensor) - The index to perform the addition operation whose data type must be mindspore.int32.
        - **updates** (Tensor) - The tensor that performs the addition operation with `input_x`,
          the data type is the same as `input_x`, the shape is `indices_shape[:-1] + x_shape[indices_shape[-1]:]`.

    Outputs:
        Parameter, the updated `input_x`.

    Raises:
        TypeError: If dtype of `indices` is not int32.
        TypeError: If dtype of `input_x` is not one of float16, float32, int32.
        ValueError: If the shape of `updates` is not equal to `indices_shape[:-1] + x_shape[indices_shape[-1]:]`.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> input_x = Parameter(Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8]), mindspore.float32), name="x")
        >>> indices = Tensor(np.array([[2], [4], [1], [7]]), mindspore.int32)
        >>> updates = Tensor(np.array([6, 7, 8, 9]), mindspore.float32)
        >>> scatter_non_aliasing_add = ops.ScatterNonAliasingAdd()
        >>> output = scatter_non_aliasing_add(input_x, indices, updates)
        >>> print(output)
        [ 1. 10.  9.  4. 12.  6.  7. 17.]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize ScatterNonAliasingAdd"""
        self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])
        self.add_prim_attr('side_effect_mem', True)

    def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
        validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name)
        args = {"x": x_dtype, "updates": updates_dtype}
        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32, mstype.int32], self.name)
        return x_dtype


class SpaceToDepth(PrimitiveWithInfer):
    r"""
    Rearranges blocks of spatial data into depth.

    The output tensor's `height` dimension is :math:`height / block\_size`.

    The output tensor's `width` dimension is :math:`width / block\_size`.

    The depth of output tensor is :math:`block\_size * block\_size * input\_depth`.

    The input tensor's height and width must be divisible by `block_size`.
    The data format is "NCHW".

    Args:
        block_size (int): The block size used to divide spatial data. It must be >= 2.

    Inputs:
        - **x** (Tensor) - The target tensor. The data type is Number. It must be a 4-D tensor.

    Outputs:
        Tensor, the same data type as `x`. It is a 4-D tensor of shape
        :math:`(N, C_{in} * \text{block_size} * \text{block_size}, H_{in} / \text{block_size},
        W_{in} / \text{block_size})`.

    Raises:
        TypeError: If `block_size` is not an int.
        ValueError: If `block_size` is less than 2.
        ValueError: If length of shape of `x` is not equal to 4.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> x = Tensor(np.random.rand(1, 3, 2, 2), mindspore.float32)
        >>> block_size = 2
        >>> space_to_depth = ops.SpaceToDepth(block_size)
        >>> output = space_to_depth(x)
        >>> print(output.shape)
        (1, 12, 1, 1)
    """

    @prim_attr_register
    def __init__(self, block_size):
        """Initialize SpaceToDepth"""
        self.init_prim_io_names(inputs=['x'], outputs=['y'])
        validator.check_value_type('block_size', block_size, [int], self.name)
        validator.check('block_size', block_size, '', 2, Rel.GE)
        self.block_size = block_size
        self.add_prim_attr("data_format", "NCHW")

    def infer_shape(self, x_shape):
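        # Each spatial dimension (H, W) must divide evenly by 'block_size' and
        # shrinks by that factor; the channel dimension grows by
        # 'block_size' * 'block_size' to compensate.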
        validator.check('x dimension', len(x_shape), '', 4, Rel.EQ)
        out_shape = copy.deepcopy(x_shape)
        for i in range(2):
            if out_shape[i + 2] % self.block_size != 0:
                msg_prefix = "2nd" if i + 2 == 2 else "3rd"
                raise ValueError(f"For '{self.name}', the shape of output with index {i + 2} must be divided "
                                 f"exactly by 'block_size', but got the {msg_prefix} dimension "
                                 f"of output: {out_shape[i + 2]} and "
                                 f"'block_size': {self.block_size}.")
            out_shape[i + 2] //= self.block_size

        out_shape[1] *= self.block_size * self.block_size
        return out_shape

    def infer_dtype(self, x_dtype):
        validator.check_subclass("x_dtype", x_dtype, mstype.tensor, self.name)
        return x_dtype


class DepthToSpace(PrimitiveWithInfer):
    r"""
    Rearranges blocks of depth data into spatial dimensions.

    This is the reverse operation of SpaceToDepth.

    The depth of output tensor is :math:`input\_depth / (block\_size * block\_size)`.

    The output tensor's `height` dimension is :math:`height * block\_size`.

    The output tensor's `width` dimension is :math:`width * block\_size`.

    The input tensor's depth must be divisible by `block_size * block_size`.
    The data format is "NCHW".

    Args:
        block_size (int): The block size used to divide depth data. It must be >= 2.

    Inputs:
        - **x** (Tensor) - The target tensor. It must be a 4-D tensor with shape :math:`(N, C_{in}, H_{in}, W_{in})`.
          The data type is Number.

    Outputs:
        Tensor of shape :math:`(N, C_{in} / \text{block_size} ^ 2, H_{in} * \text{block_size},
        W_{in} * \text{block_size})`.

    Raises:
        TypeError: If `block_size` is not an int.
        ValueError: If `block_size` is less than 2.
        ValueError: If length of shape of `x` is not equal to 4.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> x = Tensor(np.random.rand(1, 12, 1, 1), mindspore.float32)
        >>> block_size = 2
        >>> depth_to_space = ops.DepthToSpace(block_size)
        >>> output = depth_to_space(x)
        >>> print(output.shape)
        (1, 3, 2, 2)
    """

    @prim_attr_register
    def __init__(self, block_size):
        """Initialize DepthToSpace"""
        self.init_prim_io_names(inputs=['x'], outputs=['y'])
        validator.check_value_type('block_size', block_size, [int], self.name)
        validator.check('block_size', block_size, '', 2, Rel.GE, self.name)
        self.block_size = block_size
        self.add_prim_attr("data_format", "NCHW")

    def infer_shape(self, x_shape):
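        # Inverse of SpaceToDepth: H and W grow by 'block_size', and the
        # channel dimension must divide evenly by 'block_size' * 'block_size'.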
        validator.check('x dimension', len(x_shape), '', 4, Rel.EQ)
        out_shape = copy.deepcopy(x_shape)
        for i in range(2):
            out_shape[i + 2] *= self.block_size

        validator.check_int(x_shape[1] % (self.block_size * self.block_size),
                            0, Rel.EQ, 'x_shape[1] % (block_size*block_size)', self.name)
        out_shape[1] //= self.block_size * self.block_size
        return out_shape

    def infer_dtype(self, x_dtype):
        validator.check_subclass("x_dtype", x_dtype, mstype.tensor, self.name)
        return x_dtype


class SpaceToBatch(PrimitiveWithInfer):
    r"""
    Divides spatial dimensions into blocks and combines the block size with the original batch.

    This operation will divide spatial dimensions (H, W) into blocks with `block_size`, the output tensor's H and W
    dimension is the corresponding number of blocks after division. The output tensor's batch dimension is the
    product of the original batch and the square of block_size. Before division, the spatial dimensions
    of the input are zero padded according to paddings if necessary.

    Args:
        block_size (int): The block size of dividing blocks with value greater than or equal to 2.
        paddings (Union[tuple, list]): The padding values for H and W dimension, containing 2 sub-lists.
            Each sub-list contains 2 integer values. All values must be greater than or equal to 0.
            paddings[i] specifies the paddings for the spatial dimension i, which corresponds to the
            input dimension i+2. It is required that input_shape[i+2]+paddings[i][0]+paddings[i][1]
            is divisible by block_size.

    Inputs:
        - **input_x** (Tensor) - The input tensor. It must be a 4-D tensor. The data type is Number.

    Outputs:
        Tensor, the output tensor with the same data type as input. Assume input shape is :math:`(n, c, h, w)` with
        :math:`block\_size` and :math:`paddings`. The shape of the output tensor will be :math:`(n', c', h', w')`,
        where

        :math:`n' = n*(block\_size*block\_size)`

        :math:`c' = c`

        :math:`h' = (h+paddings[0][0]+paddings[0][1])//block\_size`

        :math:`w' = (w+paddings[1][0]+paddings[1][1])//block\_size`

    Raises:
        TypeError: If `block_size` is not an int.
        ValueError: If `block_size` is less than 2.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> block_size = 2
        >>> paddings = [[0, 0], [0, 0]]
        >>> space_to_batch = ops.SpaceToBatch(block_size, paddings)
        >>> input_x = Tensor(np.array([[[[1, 2], [3, 4]]]]), mindspore.float32)
        >>> output = space_to_batch(input_x)
        >>> print(output)
        [[[[1.]]]
         [[[2.]]]
         [[[3.]]]
         [[[4.]]]]
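        >>> # Shape bookkeeping for this example (illustrative only):
        >>> # n' = 1 * (2 * 2) = 4, c' = 1,
        >>> # h' = (2 + 0 + 0) // 2 = 1, w' = (2 + 0 + 0) // 2 = 1.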
    """

    @prim_attr_register
    def __init__(self, block_size, paddings):
        """Initialize SpaceToBatch"""
        logger.warning("WARN_DEPRECATED: The usage of SpaceToBatch is deprecated."
                       " Please use SpaceToBatchND.")
        validator.check_value_type('block_size', block_size, [int], self.name)
        validator.check('block_size', block_size, '', 2, Rel.GE, self.name)
        self.block_size = block_size
        validator.check('paddings shape', np.array(paddings).shape, '', (2, 2), Rel.EQ, self.name)
        for elem in itertools.chain(*paddings):
            validator.check_non_negative_int(elem, 'paddings element', self.name)
            validator.check_value_type('paddings element', elem, [int], self.name)
        self.paddings = paddings

    def infer_dtype(self, x_dtype):
        validator.check_tensor_dtype_valid('input_x', x_dtype, mstype.number_type, self.name)
        return x_dtype

    def infer_shape(self, x_shape):
        validator.check_equal_int(len(x_shape), 4, 'rank of input_x', self.name)
        out_shape = copy.deepcopy(x_shape)
        for i in range(2):
            padded = out_shape[i + 2] + self.paddings[i][0] + self.paddings[i][1]
            if padded % self.block_size != 0:
                msg_ndim = "2nd" if i + 2 == 2 else "3rd"
                raise ValueError(f"For '{self.name}', each padded spatial dimension of the input must be "
                                 f"divisible by 'block_size', but got the {msg_ndim} dimension after "
                                 f"padding: {padded} and 'block_size': {self.block_size}. Please check the "
                                 f"official homepage for more information about the output tensor.")
            out_shape[i + 2] = padded // self.block_size
        out_shape[0] *= self.block_size * self.block_size
        return out_shape


class BatchToSpace(PrimitiveWithInfer):
    r"""
    Divides batch dimension with blocks and interleaves these blocks back into spatial dimensions.

    This operation will divide batch dimension N into blocks with block_size, the output tensor's N dimension
    is the corresponding number of blocks after division. The output tensor's H, W dimension is product of
    original H, W dimension and block_size with given amount to crop from dimension, respectively.

    Args:
        block_size (int): The block size of division, with a value not less than 2.
        crops (Union[list(int), tuple(int)]): The crop value for H and W dimension, containing 2 sub-lists.
            Each sub-list contains 2 integers.
            All values must be not less than 0. crops[i] specifies the crop values for the spatial dimension i, which
            corresponds to the input dimension i+2. It is required that
            input_shape[i+2]*block_size >= crops[i][0]+crops[i][1].

    Inputs:
        - **input_x** (Tensor) - The input tensor. It must be a 4-D tensor, dimension 0 must be divisible by
          `block_size` * `block_size`. The data type is float16 or float32.

    Outputs:
        Tensor, the output tensor with the same type as input. Assume input shape is (n, c, h, w) with block_size
        and crops. The output shape will be (n', c', h', w'), where

        :math:`n' = n//(block\_size*block\_size)`

        :math:`c' = c`

        :math:`h' = h*block\_size-crops[0][0]-crops[0][1]`

        :math:`w' = w*block\_size-crops[1][0]-crops[1][1]`

    Raises:
        TypeError: If `block_size` or element of `crops` is not an int.
        TypeError: If `crops` is neither list nor tuple.
        ValueError: If `block_size` is less than 2.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> block_size = 2
        >>> crops = [[0, 0], [0, 0]]
        >>> batch_to_space = ops.BatchToSpace(block_size, crops)
        >>> input_x = Tensor(np.array([[[[1]]], [[[2]]], [[[3]]], [[[4]]]]), mindspore.float32)
        >>> output = batch_to_space(input_x)
        >>> print(output)
        [[[[1.  2.]
           [3.  4.]]]]
5075
5076    """
5077
5078    @prim_attr_register
5079    def __init__(self, block_size, crops):
5080        """Initialize BatchToSpace"""
5081        logger.warning("WARN_DEPRECATED: The usage of BatchToSpace is deprecated."
5082                       " Please use BatchToSpaceND.")
5083        validator.check_value_type('block_size', block_size, [int], self.name)
5084        validator.check('block_size', block_size, '', 2, Rel.GE, self.name)
5085        self.block_size = block_size
5086        validator.check_value_type('crops type', crops, [list, tuple], self.name)
5087        validator.check('crops shape', np.array(crops).shape, '', (2, 2))
5088        for elem in itertools.chain(*crops):
5089            validator.check_non_negative_int(elem, 'crops element', self.name)
5090            validator.check_value_type('crops element', elem, [int], self.name)
5091        self.crops = crops
5092
5093    def infer_dtype(self, x_dtype):
5094        validator.check_tensor_dtype_valid('input_x', x_dtype, mstype.number_type, self.name)
5095        return x_dtype
5096
5097    def infer_shape(self, x_shape):
5098        validator.check('rank of input_x', len(x_shape), '', 4)
5099        out_shape = copy.deepcopy(x_shape)
5100        for i in range(2):
5101            x_block_prod = out_shape[i + 2] * self.block_size
5102            crops_sum = self.crops[i][0] + self.crops[i][1]
5103            validator.check("x block shape prod", x_block_prod, 'crops sum', crops_sum, Rel.GT, self.name)
5104            out_shape[i + 2] = x_block_prod - crops_sum
5105        block_size_prod = self.block_size * self.block_size
5106        if out_shape[0] % block_size_prod != 0:
5107            raise ValueError(f"For '{self.name}', the shape of output with index 0 must be divided exactly "
5108                             f"by block_size_prod, but got the shape of output: {out_shape} and "
5109                             f"block_size_prod: {block_size_prod}.")
5110        out_shape[0] = out_shape[0] // block_size_prod
5111        return out_shape
5112
5113
class SpaceToBatchND(PrimitiveWithInfer):
    r"""
    Divides spatial dimensions into blocks and combines the block size with the original batch.

    This operation will divide spatial dimensions (H, W) into blocks with block_shape, the output tensor's H and W
    dimension is the corresponding number of blocks after division. The output tensor's batch dimension is the
    product of the original batch and the product of `block_shape`. Before division,
    the spatial dimensions of the input are zero padded according to paddings if necessary.

    Args:
        block_shape (Union[list(int), tuple(int), int]): The block shape of dividing block, with each value no
            less than 1. If `block_shape` is a tuple or list, the length of `block_shape` is M corresponding to
            the number of spatial dimensions. If `block_shape` is an int, the block size of all M dimensions is
            the same, equal to `block_shape`. M must be 2.
        paddings (Union[tuple, list]): The padding values for the H and W dimensions, containing 2 sub-lists.
            Each contains 2 integer values. All values must be no less than 0.
            `paddings[i]` specifies the paddings for the spatial dimension i,
            which corresponds to the input dimension i+2.
            It is required that input_shape[i+2]+paddings[i][0]+paddings[i][1] is divisible by block_shape[i].

    Inputs:
        - **input_x** (Tensor) - The input tensor. It must be a 4-D tensor.

    Outputs:
        Tensor, the output tensor with the same data type as input. Assume input shape is :math:`(n, c, h, w)` with
        :math:`block\_shape` and :math:`paddings`. The shape of the output tensor will be :math:`(n', c', h', w')`,
        where

        :math:`n' = n*(block\_shape[0]*block\_shape[1])`

        :math:`c' = c`

        :math:`h' = (h+paddings[0][0]+paddings[0][1])//block\_shape[0]`

        :math:`w' = (w+paddings[1][0]+paddings[1][1])//block\_shape[1]`

    Raises:
        TypeError: If `block_shape` is not one of list, tuple, int.
        TypeError: If `paddings` is neither list nor tuple.
        ValueError: If length of shape of `block_shape` is not equal to 1.
        ValueError: If length of `block_shape` or `paddings` is not equal to 2.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> block_shape = [2, 2]
        >>> paddings = [[0, 0], [0, 0]]
        >>> space_to_batch_nd = ops.SpaceToBatchND(block_shape, paddings)
        >>> input_x = Tensor(np.array([[[[1, 2], [3, 4]]]]), mindspore.float32)
        >>> output = space_to_batch_nd(input_x)
        >>> print(output)
        [[[[1.]]]
         [[[2.]]]
         [[[3.]]]
         [[[4.]]]]
    """

    @prim_attr_register
    def __init__(self, block_shape, paddings):
        """Initialize SpaceToBatchND"""
        if isinstance(block_shape, int):
            block_shape = (block_shape,) * 2
        self.add_prim_attr("block_shape", block_shape)
        validator.check_value_type('block_shape type', block_shape, [list, tuple], self.name)
        validator.check('block_shape shape', len(np.array(block_shape).shape), '', 1, Rel.EQ, self.name)
        block_rank = len(block_shape)
        validator.check('block_shape length', block_rank, '', 2, Rel.EQ, self.name)
        for elem in block_shape:
            validator.check('block_shape element', elem, '', 1, Rel.GE, self.name)
            validator.check_value_type('block_shape element', elem, [int], self.name)
        self.block_shape = block_shape

        validator.check_value_type('paddings type', paddings, [list, tuple], self.name)
        validator.check('paddings length', len(paddings), '', 2, Rel.EQ, self.name)
        validator.check('paddings shape', np.array(paddings).shape, '', (block_rank, 2), Rel.EQ, self.name)
        for elem in itertools.chain(*paddings):
            validator.check_non_negative_int(elem, 'paddings element', self.name)
            validator.check_value_type('paddings element', elem, [int], self.name)
        self.paddings = paddings

    def infer_dtype(self, x_dtype):
        validator.check_tensor_dtype_valid('input_x', x_dtype, mstype.number_type, self.name)
        return x_dtype

    def infer_shape(self, x_shape):
        x_rank = len(x_shape)
        validator.check_equal_int(x_rank, 4, 'x_shape rank', self.name)
        out_shape = copy.deepcopy(x_shape)

        block_shape_prod = 1
        offset = 2
        for i in range(len(self.block_shape)):
            padded = out_shape[i + offset] + self.paddings[i][0] + \
                     self.paddings[i][1]
            if padded % self.block_shape[i] != 0:
                raise ValueError(f"For '{self.name}', the padded size must be divisible by 'block_shape[{i}]', "
                                 f"where padded = input_x_shape[i + 2] + paddings[i][0] + paddings[i][1], "
                                 f"but got input_x_shape[{i + 2}]: {out_shape[i + offset]}, "
                                 f"paddings[{i}][0]: {self.paddings[i][0]} and paddings[{i}][1]: "
                                 f"{self.paddings[i][1]}.")
            out_shape[i + offset] = padded // self.block_shape[i]
            block_shape_prod = block_shape_prod * self.block_shape[i]
        out_shape[0] *= block_shape_prod
        return out_shape


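# Editor's sketch (illustrative only, not part of the operator implementation):
# a minimal pure-Python reference of the SpaceToBatchND shape rule documented
# above, assuming an (n, c, h, w) input. The helper name is hypothetical; it
# mirrors the arithmetic in SpaceToBatchND.infer_shape.
def _space_to_batch_nd_shape_sketch(x_shape, block_shape, paddings):
    """E.g. ((1, 1, 2, 2), [2, 2], [[0, 0], [0, 0]]) -> (4, 1, 1, 1)."""
    out = list(x_shape)
    block_prod = 1
    for i, block in enumerate(block_shape):
        padded = out[i + 2] + paddings[i][0] + paddings[i][1]
        if padded % block != 0:
            raise ValueError("padded spatial size must be divisible by block_shape")
        out[i + 2] = padded // block
        block_prod *= block
    out[0] *= block_prod
    return tuple(out)

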
class BatchToSpaceND(PrimitiveWithInfer):
    r"""
    Divides batch dimension with blocks and interleaves these blocks back into spatial dimensions.

    This operation will divide batch dimension N into blocks with block_shape, the output tensor's N dimension
    is the corresponding number of blocks after division. The output tensor's H, W dimension is the product of
    the original H, W dimension and block_shape, minus the given amount to crop from each dimension, respectively.

    Args:
        block_shape (Union[list(int), tuple(int), int]): The block shape of dividing block, with each value no
            less than 1. If `block_shape` is a tuple or list, the length of `block_shape` is M corresponding to
            the number of spatial dimensions. If `block_shape` is an int, the block size of all M dimensions is
            the same, equal to `block_shape`. M must be 2.
        crops (Union[list(int), tuple(int)]): The crop values for the H and W dimensions, containing 2 sub-lists,
            each containing 2 integer values.
            All values must be >= 0. crops[i] specifies the crop values for spatial dimension i, which corresponds to
            input dimension i+2. It is required that input_shape[i+2]*block_shape[i] > crops[i][0]+crops[i][1].

    Inputs:
        - **input_x** (Tensor) - The input tensor. It must be a 4-D tensor, and dimension 0 must be divisible by
          the product of `block_shape`. The data type is float16 or float32.

    Outputs:
        Tensor, the output tensor with the same type as input. Assume input shape is (n, c, h, w) with block_shape
        and crops. The output shape will be (n', c', h', w'), where

        :math:`n' = n//(block\_shape[0]*block\_shape[1])`

        :math:`c' = c`

        :math:`h' = h*block\_shape[0]-crops[0][0]-crops[0][1]`

        :math:`w' = w*block\_shape[1]-crops[1][0]-crops[1][1]`

    Raises:
        TypeError: If `block_shape` is not one of list, tuple, int.
        TypeError: If `crops` is neither list nor tuple.
        ValueError: If length of `block_shape` or `crops` is not equal to 2.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> block_shape = [2, 2]
        >>> crops = [[0, 0], [0, 0]]
        >>> batch_to_space_nd = ops.BatchToSpaceND(block_shape, crops)
        >>> input_x = Tensor(np.array([[[[1]]], [[[2]]], [[[3]]], [[[4]]]]), mindspore.float32)
        >>> output = batch_to_space_nd(input_x)
        >>> print(output)
        [[[[1.  2.]
           [3.  4.]]]]

    """

    @prim_attr_register
    def __init__(self, block_shape, crops):
        """Initialize BatchToSpaceND"""
        if isinstance(block_shape, int):
            block_shape = (block_shape,) * 2
        self.add_prim_attr("block_shape", block_shape)
        validator.check_value_type('block_shape type', block_shape, [list, tuple], self.name)
        validator.check('block_shape shape', len(np.array(block_shape).shape), '', 1, Rel.EQ, self.name)
        block_rank = len(block_shape)
        validator.check('block_shape length', block_rank, '', 2, Rel.EQ, self.name)
        for elem in block_shape:
            validator.check('block_shape element', elem, '', 1, Rel.GE, self.name)
            validator.check_value_type('block_shape element', elem, [int], self.name)
        self.block_shape = block_shape

        validator.check_value_type('crops type', crops, [list, tuple], self.name)
        validator.check('crops length', len(crops), '', 2, Rel.EQ, self.name)
        validator.check('crops shape', np.array(crops).shape, '', (block_rank, 2), Rel.EQ, self.name)
        for elem in itertools.chain(*crops):
            validator.check_non_negative_int(elem, 'crops element', self.name)
            validator.check_value_type('crops element', elem, [int], self.name)
        self.crops = crops

    def infer_dtype(self, x_dtype):
        validator.check_tensor_dtype_valid('input_x', x_dtype, mstype.number_type, self.name)
        return x_dtype

    def infer_shape(self, x_shape):
        x_rank = len(x_shape)
        validator.check_int(x_rank, 4, Rel.EQ, 'x_shape rank', self.name)
        out_shape = copy.deepcopy(x_shape)

        block_shape_prod = 1
        offset = 2
        for i in range(len(self.block_shape)):
            block_shape_prod = block_shape_prod * self.block_shape[i]
            x_block_prod = out_shape[i + offset] * self.block_shape[i]
            crops_sum = self.crops[i][0] + self.crops[i][1]
            validator.check("x block shape prod", x_block_prod, 'crops sum', crops_sum, Rel.GT, self.name)
            out_shape[i + offset] = x_block_prod - crops_sum

        if out_shape[0] % block_shape_prod != 0:
            raise ValueError(f"For '{self.name}', the 0th dimension of the 'input_x' should be "
                             f"divisible by block_shape_prod, where block_shape_prod = "
                             f"'block_shape[0]' * 'block_shape[1]', "
                             f"but got the 0th dimension of the 'input_x': "
                             f"{out_shape[0]} and block_shape_prod: {block_shape_prod}.")
        out_shape[0] = out_shape[0] // block_shape_prod
        return out_shape


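# Editor's sketch (illustrative only): the inverse shape rule for BatchToSpaceND,
# matching the formulas above. Combined with _space_to_batch_nd_shape_sketch, a
# zero-padding/zero-cropping round trip restores the original shape.
def _batch_to_space_nd_shape_sketch(x_shape, block_shape, crops):
    """E.g. ((4, 1, 1, 1), [2, 2], [[0, 0], [0, 0]]) -> (1, 1, 2, 2)."""
    out = list(x_shape)
    block_prod = 1
    for i, block in enumerate(block_shape):
        block_prod *= block
        out[i + 2] = out[i + 2] * block - crops[i][0] - crops[i][1]
    if out[0] % block_prod != 0:
        raise ValueError("batch dimension must be divisible by the product of block_shape")
    out[0] //= block_prod
    return tuple(out)

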
class BroadcastTo(Primitive):
    """
    Broadcasts input tensor to a given shape.

    The input shape can be broadcast to the target shape if, for each dimension pair, they are either equal,
    the input dimension is 1, or the target dimension is -1. In case of -1 in the target shape, it will be
    replaced by the input shape's value in that dimension.

    When input shape is broadcast to target shape, it starts with the trailing
    dimensions. If there is a -1 in the target shape, the -1 cannot be in a leading,
    non-existing dimension.

    Args:
        shape (tuple): The target shape to broadcast. Can be fully specified, or have -1 in one position
            where it will be substituted by the input tensor's shape in that position, see example.

    Inputs:
        - **input_x** (Tensor) - The input tensor. The data type should be one of the following types:
          float16, float32, int32, int8, uint8.
          The shape is :math:`(N,*)` where :math:`*` means, any number of additional dimensions.

    Outputs:
        Tensor, with the given `shape` and the same data type as `input_x`.

    Raises:
        TypeError: If `shape` is not a tuple.
        ValueError: If the target and input shapes are incompatible, or if a -1 in the target shape is in an
                    invalid location.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> shape = (2, 3)
        >>> input_x = Tensor(np.array([1, 2, 3]).astype(np.float32))
        >>> broadcast_to = ops.BroadcastTo(shape)
        >>> output = broadcast_to(input_x)
        >>> print(output)
        [[1. 2. 3.]
         [1. 2. 3.]]

        >>> shape = (-1, 2)
        >>> input_x = Tensor(np.array([[1], [2]]).astype(np.float32))
        >>> broadcast_to = ops.BroadcastTo(shape)
        >>> output = broadcast_to(input_x)
        >>> print(output)
        [[1. 1.]
         [2. 2.]]
    """

    @prim_attr_register
    def __init__(self, shape):
        """Initialize BroadcastTo"""
        validator.check_value_type("shape", shape, (tuple), self.name)
        validator.check("shape length", len(shape), "", 0, Rel.GT, self.name)
        for ix, i in enumerate(shape):
            validator.check_value_type('target shape index -> ' + str(ix), i, [int], self.name)
            validator.check("shape element", i, "shape element min limit", -1, Rel.GE, self.name)
        self.shape = shape


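# Editor's sketch (illustrative only): how the -1 placeholder in `shape` is
# resolved, per the rules above. Dimensions align from the trailing end; -1
# copies the input dimension. The helper name is hypothetical.
def _resolve_broadcast_shape_sketch(input_shape, target_shape):
    """E.g. ((2, 1), (-1, 2)) -> (2, 2); -1 may not sit in a leading, non-existing dimension."""
    offset = len(target_shape) - len(input_shape)
    resolved = list(target_shape)
    for i, dim in enumerate(target_shape):
        if dim == -1:
            if i < offset:
                raise ValueError("-1 cannot be in a leading, non-existing dimension")
            resolved[i] = input_shape[i - offset]
    return tuple(resolved)

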
class Meshgrid(PrimitiveWithInfer):
    """
    Generates coordinate matrices from given coordinate tensors.

    Given N one-dimensional coordinate tensors, returns a tuple of N N-D
    coordinate tensors for evaluating expressions on an N-D grid.

    Args:
        indexing (str): Either 'xy' or 'ij'. Default: 'xy'.
          When the indexing argument is set to 'xy' (the default), the broadcasting
          instructions for the first two dimensions are swapped.

    Inputs:
        - **input** (Union[tuple]) - A tuple of N 1-D Tensor objects.
          The length of input should be greater than 1. The data type is Number.

    Outputs:
        Tensors, a tuple of N N-D Tensor objects. The data type is the same as the inputs.

    Raises:
        TypeError: If `indexing` is not a str or `input` is not a tuple.
        ValueError: If `indexing` is neither 'xy' nor 'ij'.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> x = Tensor(np.array([1, 2, 3, 4]).astype(np.int32))
        >>> y = Tensor(np.array([5, 6, 7]).astype(np.int32))
        >>> z = Tensor(np.array([8, 9, 0, 1, 2]).astype(np.int32))
        >>> inputs = (x, y, z)
        >>> meshgrid = ops.Meshgrid(indexing="xy")
        >>> output = meshgrid(inputs)
        >>> print(output)
        (Tensor(shape=[3, 4, 5], dtype=Int32, value=
         [[[1, 1, 1, 1, 1],
           [2, 2, 2, 2, 2],
           [3, 3, 3, 3, 3],
           [4, 4, 4, 4, 4]],
          [[1, 1, 1, 1, 1],
           [2, 2, 2, 2, 2],
           [3, 3, 3, 3, 3],
           [4, 4, 4, 4, 4]],
          [[1, 1, 1, 1, 1],
           [2, 2, 2, 2, 2],
           [3, 3, 3, 3, 3],
           [4, 4, 4, 4, 4]]]),
         Tensor(shape=[3, 4, 5], dtype=Int32, value=
         [[[5, 5, 5, 5, 5],
           [5, 5, 5, 5, 5],
           [5, 5, 5, 5, 5],
           [5, 5, 5, 5, 5]],
          [[6, 6, 6, 6, 6],
           [6, 6, 6, 6, 6],
           [6, 6, 6, 6, 6],
           [6, 6, 6, 6, 6]],
          [[7, 7, 7, 7, 7],
           [7, 7, 7, 7, 7],
           [7, 7, 7, 7, 7],
           [7, 7, 7, 7, 7]]]),
         Tensor(shape=[3, 4, 5], dtype=Int32, value=
         [[[8, 9, 0, 1, 2],
           [8, 9, 0, 1, 2],
           [8, 9, 0, 1, 2],
           [8, 9, 0, 1, 2]],
          [[8, 9, 0, 1, 2],
           [8, 9, 0, 1, 2],
           [8, 9, 0, 1, 2],
           [8, 9, 0, 1, 2]],
          [[8, 9, 0, 1, 2],
           [8, 9, 0, 1, 2],
           [8, 9, 0, 1, 2],
           [8, 9, 0, 1, 2]]]))
    """

    @prim_attr_register
    def __init__(self, indexing="xy"):
        """Initialize Meshgrid."""
        validator.check_value_type("indexing", indexing, (str), self.name)
        validator.check_string(indexing.lower(), ["xy", "ij"], "indexing", self.name)
        self.indexing = indexing

    def infer_shape(self, x_shape):
        validator.check_value_type("shape", x_shape, [tuple], self.name)
        validator.check_int(len(x_shape), 2, Rel.GE, "len of input", self.name)
        n = len(x_shape)
        shape_0 = []
        for s in x_shape:
            validator.check_int(len(s), 1, Rel.EQ, 'each input rank', self.name)
            shape_0.append(s[0])
        if self.indexing == "xy":
            shape_0[0], shape_0[1] = shape_0[1], shape_0[0]
        out_shape = tuple(tuple(shape_0) for _ in range(n))
        return out_shape

    def infer_dtype(self, x_type):
        validator.check_subclass("input[0]", x_type[0], mstype.tensor, self.name)
        n = len(x_type)
        for i in range(1, n):
            validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0], Rel.EQ, self.name, TypeError)
        return x_type


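# Editor's sketch (illustrative only): Meshgrid follows the same 'xy'/'ij'
# indexing convention as NumPy, so expected outputs can be sanity-checked
# against np.meshgrid. The helper name is hypothetical.
def _meshgrid_reference_sketch(*arrays, indexing="xy"):
    """E.g. three 1-D inputs of lengths 4, 3, 5 with 'xy' all yield shape (3, 4, 5)."""
    return np.meshgrid(*arrays, indexing=indexing)

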
class InplaceUpdate(PrimitiveWithInfer):
    r"""
    Updates specified rows with values in `v`.

    Args:
        indices (Union[int, tuple]): Indices into the left-most dimension of `x`, determining which rows of `x`
            to update with `v`. It is an int or a tuple, whose values are in [0, the first dimension size of `x`).

    Inputs:
        - **x** (Tensor) - The tensor to be updated in place. It can be one of the following data types:
          float32, float16 and int32.
        - **v** (Tensor) - A tensor with the same type as `x` and the same dimension sizes as `x` except
          the first dimension, which must be the same as the size of `indices`.

    Outputs:
        Tensor, with the same type and shape as the input `x`.

    Raises:
        TypeError: If `indices` is neither int nor tuple.
        TypeError: If `indices` is a tuple and its element is not an int.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> indices = (0, 1)
        >>> x = Tensor(np.array([[1, 2], [3, 4], [5, 6]]), mindspore.float32)
        >>> v = Tensor(np.array([[0.5, 1.0], [1.0, 1.5]]), mindspore.float32)
        >>> inplace_update = ops.InplaceUpdate(indices)
        >>> output = inplace_update(x, v)
        >>> print(output)
        [[0.5 1. ]
         [1.  1.5]
         [5.  6. ]]
    """

    @prim_attr_register
    def __init__(self, indices):
        """Initialize InplaceUpdate"""
        self.init_prim_io_names(inputs=['x', 'v'], outputs=['y'])
        self.indices = indices
        validator.check_value_type("indices", indices, [int, tuple], self.name)
        if isinstance(indices, int):
            self.indices = (indices,)
        for item in self.indices:
            validator.check_value_type("item of indices", item, [int], self.name)

    def infer_dtype(self, x_dtype, v_dtype):
        args = {'x': x_dtype, 'v': v_dtype}
        valid_type = [mstype.int32, mstype.float16, mstype.float32]
        validator.check_tensors_dtypes_same_and_valid(args, valid_type, self.name)
        return x_dtype

    def infer_shape(self, x_shape, v_shape):
        validator.check("x", len(x_shape), "v", len(v_shape), Rel.EQ, self.name)
        validator.check("size of indices", len(self.indices), "v's first dimension", v_shape[0],
                        Rel.EQ, self.name)
        for i in self.indices:
            if i < 0 or i >= x_shape[0]:
                raise ValueError(f"For '{self.name}', the value of indices must be in [0, {x_shape[0]}), "
                                 f"but got {i}.")
        x_rank = len(x_shape)
        for idx in range(1, x_rank):
            validator.check('v dim %d' % idx, v_shape[idx], "x dim %d" % idx, x_shape[idx], Rel.EQ, self.name)
        return x_shape


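# Editor's sketch (illustrative only): the NumPy analogue of InplaceUpdate's
# semantics for valid indices. Rows listed in `indices` are replaced by the
# corresponding rows of `v`. The helper name is hypothetical.
def _inplace_update_sketch(x, indices, v):
    """E.g. rows (0, 1) of a (3, 2) array are overwritten by the rows of v."""
    out = np.array(x, copy=True)
    out[list(indices)] = v
    return out

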
class ReverseSequence(PrimitiveWithInfer):
    """
    Reverses variable length slices.

    Args:
        seq_dim (int): The dimension where reversal is performed. Required.
        batch_dim (int): The input is sliced in this dimension. Default: 0.

    Inputs:
        - **x** (Tensor) - The input to reverse, supporting all number types including bool.
        - **seq_lengths** (Tensor) - Must be a 1-D vector with int32 or int64 types.

    Outputs:
        Reversed tensor with the same shape and data type as input.

    Raises:
        TypeError: If `seq_dim` or `batch_dim` is not an int.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
        >>> seq_lengths = Tensor(np.array([1, 2, 3]))
        >>> reverse_sequence = ops.ReverseSequence(seq_dim=1)
        >>> output = reverse_sequence(x, seq_lengths)
        >>> print(output)
        [[1. 2. 3.]
         [5. 4. 6.]
         [9. 8. 7.]]
        >>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
        >>> seq_lengths = Tensor(np.array([1, 2, 3]))
        >>> reverse_sequence = ops.ReverseSequence(seq_dim=0, batch_dim=1)
        >>> output = reverse_sequence(x, seq_lengths)
        >>> print(output)
        [[1. 5. 9.]
         [4. 2. 6.]
         [7. 8. 3.]]
        >>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
        >>> seq_lengths = Tensor(np.array([2, 2, 3]))
        >>> reverse_sequence = ops.ReverseSequence(seq_dim=1)
        >>> output = reverse_sequence(x, seq_lengths)
        >>> print(output)
        [[2. 1. 3.]
         [5. 4. 6.]
         [9. 8. 7.]]
        >>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
        >>> seq_lengths = Tensor(np.array([3, 2, 3]))
        >>> reverse_sequence = ops.ReverseSequence(seq_dim=1)
        >>> output = reverse_sequence(x, seq_lengths)
        >>> print(output)
        [[3. 2. 1.]
         [5. 4. 6.]
         [9. 8. 7.]]
        >>> x = Tensor(np.array([[1, 2, 3, 4], [5, 6, 7, 8]]), mindspore.float32)
        >>> seq_lengths = Tensor(np.array([4, 4]))
        >>> reverse_sequence = ops.ReverseSequence(seq_dim=1)
        >>> output = reverse_sequence(x, seq_lengths)
        >>> print(output)
        [[4. 3. 2. 1.]
         [8. 7. 6. 5.]]
    """

    @prim_attr_register
    def __init__(self, seq_dim, batch_dim=0):
        """Initialize ReverseSequence"""
        self.init_prim_io_names(inputs=['x', 'seq_lengths'], outputs=['y'])
        validator.check_value_type("seq_dim", seq_dim, [int], self.name)
        self.seq_dim_ = seq_dim
        validator.check_value_type("batch_dim", batch_dim, [int], self.name)
        self.batch_dim_ = batch_dim

    def infer_shape(self, x, seq_lengths):
        validator.check("seq_dim", self.seq_dim_, "x rank", len(x), Rel.LE, self.name)
        validator.check("batch_dim", self.batch_dim_, "x rank", len(x), Rel.LE, self.name)
        validator.check("batch_dim", self.batch_dim_, "seq_dim", self.seq_dim_, Rel.NE, self.name)
        validator.check("seq_lengths rank", len(seq_lengths), "expected", 1, Rel.EQ, self.name)
        validator.check("seq_lengths vector size", seq_lengths[0],
                        "input size along batch_dim", x[self.batch_dim_], Rel.EQ, self.name)
        return x

    def infer_dtype(self, x, seq_lengths):
        validator.check_tensor_dtype_valid("x_dtype", x, mstype.number_type + (mstype.bool_,), self.name)
        validator.check_tensor_dtype_valid("seq_lengths_dtype", seq_lengths, [mstype.int32, mstype.int64], self.name)
        return x


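# Editor's sketch (illustrative only): a NumPy reference of ReverseSequence for
# 2-D inputs with batch_dim=0 and seq_dim=1, matching the first example above.
# The helper name is hypothetical.
def _reverse_sequence_sketch(x, seq_lengths):
    """E.g. seq_lengths [1, 2, 3] reverses the first 1, 2, 3 elements of each row."""
    out = np.array(x, copy=True)
    for i, length in enumerate(seq_lengths):
        out[i, :length] = out[i, :length][::-1]
    return out

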
class EditDistance(PrimitiveWithInfer):
    """
    Computes the Levenshtein Edit Distance. It is used to measure the similarity of two sequences. The inputs are
    variable-length sequences provided by SparseTensors (hypothesis_indices, hypothesis_values, hypothesis_shape)
    and (truth_indices, truth_values, truth_shape).

    Args:
        normalize (bool): If true, edit distances are normalized by length of truth. Default: True.

    Inputs:
        - **hypothesis_indices** (Tensor) - The indices of the hypothesis list SparseTensor. With int64 data type.
          The shape of tensor is :math:`(N, R)`.
        - **hypothesis_values** (Tensor) - The values of the hypothesis list SparseTensor. With float32 data type.
          Must be a 1-D vector with length of N.
        - **hypothesis_shape** (Tensor) - The shape of the hypothesis list SparseTensor.
          Must be an R-length vector with int64 data type. Only constant value is allowed.
        - **truth_indices** (Tensor) - The indices of the truth list SparseTensor. With int64 data type.
          The shape of tensor is :math:`(M, R)`.
        - **truth_values** (Tensor) - The values of the truth list SparseTensor. Must be a 1-D vector with length
          of M. With float32 data type.
        - **truth_shape** (Tensor) - The shape of the truth list SparseTensor.
          Must be an R-length vector with int64 data type. Only constant value is allowed.

    Outputs:
        Tensor, a dense tensor with rank `R-1` and float32 data type.

    Raises:
        TypeError: If `normalize` is not a bool.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> import numpy as np
        >>> from mindspore import context
        >>> from mindspore import Tensor
        >>> import mindspore.nn as nn
        >>> import mindspore.ops as ops
        >>> class EditDistance(nn.Cell):
        ...     def __init__(self, hypothesis_shape, truth_shape, normalize=True):
        ...         super(EditDistance, self).__init__()
        ...         self.edit_distance = ops.EditDistance(normalize)
        ...         self.hypothesis_shape = hypothesis_shape
        ...         self.truth_shape = truth_shape
        ...
        ...     def construct(self, hypothesis_indices, hypothesis_values, truth_indices, truth_values):
        ...         return self.edit_distance(hypothesis_indices, hypothesis_values, self.hypothesis_shape,
        ...                                   truth_indices, truth_values, self.truth_shape)
        ...
        >>> hypothesis_indices = Tensor(np.array([[0, 0, 0], [1, 0, 1], [1, 1, 1]]).astype(np.int64))
        >>> hypothesis_values = Tensor(np.array([1, 2, 3]).astype(np.float32))
        >>> hypothesis_shape = Tensor(np.array([1, 1, 2]).astype(np.int64))
        >>> truth_indices = Tensor(np.array([[0, 1, 0], [0, 0, 1], [1, 1, 0], [1, 0, 1]]).astype(np.int64))
        >>> truth_values = Tensor(np.array([1, 3, 2, 1]).astype(np.float32))
        >>> truth_shape = Tensor(np.array([2, 2, 2]).astype(np.int64))
        >>> edit_distance = EditDistance(hypothesis_shape, truth_shape)
        >>> output = edit_distance(hypothesis_indices, hypothesis_values, truth_indices, truth_values)
        >>> print(output)
        [[1. 1.]
         [1. 1.]]
    """

    @prim_attr_register
    def __init__(self, normalize=True):
        """Initialize EditDistance"""
        self.normalize = validator.check_value_type("normalize", normalize, [bool], self.name)
        self.set_const_input_indexes([2, 5])

    def __infer__(self, h_indices, h_values, h_shape, truth_indices, truth_values, truth_shape):
        validator.check_valid_input('hypothesis_shape', h_shape['value'], self.name)
        validator.check_valid_input('truth_shape', truth_shape['value'], self.name)
        args_int = {"hypothesis_indices": h_indices['dtype'], "hypothesis_shape": h_shape['dtype'],
                    "truth_indices": truth_indices['dtype'], "truth_shape": truth_shape['dtype']}
        validator.check_tensors_dtypes_same_and_valid(args_int, [mstype.int64], self.name)
        args = {"hypothesis_values": h_values['dtype'], "truth_values": truth_values['dtype']}
        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)

        hypothesis_indices_shp, truth_indices_shp = h_indices['shape'], truth_indices['shape']
        validator.check("hypothesis_indices rank", len(hypothesis_indices_shp), "expected", 2, Rel.EQ, self.name)
        validator.check("truth_indices rank", len(truth_indices_shp), "expected", 2, Rel.EQ, self.name)
        validator.check("hypothesis_values rank", len(h_values['shape']), "expected", 1, Rel.EQ, self.name)
        validator.check("hypothesis_shape rank", len(h_shape['shape']), "expected", 1, Rel.EQ, self.name)
        validator.check("truth_values rank", len(truth_values['shape']), "expected", 1, Rel.EQ, self.name)
        validator.check("truth_shape rank", len(truth_shape['shape']), "expected", 1, Rel.EQ, self.name)
        validator.check("hypothesis_values shape", h_values['shape'][0],
                        "hypothesis_indices shape[0]", hypothesis_indices_shp[0], Rel.EQ, self.name)
        validator.check("hypothesis_shape", h_shape['shape'][0],
                        "hypothesis_indices shape[1]", hypothesis_indices_shp[1], Rel.EQ, self.name)
        validator.check("truth_values shape", truth_values['shape'][0],
                        "truth_indices shape[0]", truth_indices_shp[0], Rel.EQ, self.name)
        validator.check("hypothesis_shape", h_shape['shape'][0],
                        "truth_shape", truth_shape['shape'][0], Rel.EQ, self.name)
        hypothesis_shape_v = h_shape['value'].asnumpy()
        truth_shape_v = truth_shape['value'].asnumpy()
        out_shape_rank = len(hypothesis_shape_v) - 1
        out_shape = []
        for i in range(out_shape_rank):
            out_shape.append(max(hypothesis_shape_v[i], truth_shape_v[i]))

        return {'shape': tuple(out_shape),
                'dtype': mstype.tensor_type(mstype.float32),
                'value': None}


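# Editor's sketch (illustrative only): the plain Levenshtein distance between
# two Python sequences, which is the per-batch quantity EditDistance computes
# before the optional normalization by the truth length.
def _levenshtein_sketch(hypothesis, truth):
    """E.g. ("kitten", "sitting") -> 3."""
    dist = list(range(len(truth) + 1))
    for i, h in enumerate(hypothesis, 1):
        prev, dist[0] = dist[0], i
        for j, t in enumerate(truth, 1):
            prev, dist[j] = dist[j], min(dist[j] + 1, dist[j - 1] + 1, prev + (h != t))
    return dist[-1]

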
class TransShape(PrimitiveWithInfer):
    """
    Transforms the shape of the input tensor to the target shape.

    Inputs:
        - **input_x** (Tensor) - An input tensor.
        - **out_shape** (tuple[int]) - The shape of the output data.

    Outputs:
        Tensor, a tensor whose data type is the same as `input_x`, and whose shape is the same as `out_shape`.
    """

    @prim_attr_register
    def __init__(self):
        """Initialize TransShape."""
        self.__setattr_flag__ = True

    def __infer__(self, x, shape):
        shp = shape['value']
        dtype = x['dtype']
        validator.check_tensor_dtype_valid('x', dtype, mstype.number_type + (mstype.bool_,), self.name)
        self.add_prim_attr('out_shape', tuple(shp))
        return {'shape': shp,
                'dtype': dtype,
                'value': None}


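# Editor's note (illustrative only): TransShape reshapes its input to `out_shape`,
# much like Reshape; e.g. a hypothetical call TransShape()(x, (1, 4)) on a
# 2 x 2 tensor would yield a tensor of shape (1, 4) with the same dtype.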
class Sort(PrimitiveWithInfer):
    """
    Sorts the elements of the input tensor along a given dimension in ascending order by value.

    Args:
        axis (int): The dimension to sort along. Default: -1.
        descending (bool): Controls the sorting order. If descending is True then the elements
            are sorted in descending order by value. Default: False.

    Inputs:
        - **x** (Tensor) - The input to sort, with float16 or float32 data type.
          The shape is :math:`(N,*)` where :math:`*` means, any number of additional dimensions.

    Outputs:
        - **y1** (Tensor) - A tensor whose values are the sorted values, with the same shape and data type as input.
        - **y2** (Tensor) - The indices of the elements in the original input tensor. Data type is int32.

    Raises:
        TypeError: If `axis` is not an int.
        TypeError: If `descending` is not a bool.
        TypeError: If dtype of `x` is neither float16 nor float32.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> x = Tensor(np.array([[8, 2, 1], [5, 9, 3], [4, 6, 7]]), mindspore.float16)
        >>> sort = ops.Sort()
        >>> output = sort(x)
        >>> print(output)
        (Tensor(shape=[3, 3], dtype=Float16, value=
        [[ 1.0000e+00,  2.0000e+00,  8.0000e+00],
         [ 3.0000e+00,  5.0000e+00,  9.0000e+00],
         [ 4.0000e+00,  6.0000e+00,  7.0000e+00]]), Tensor(shape=[3, 3], dtype=Int32, value=
        [[2, 1, 0],
         [2, 0, 1],
         [0, 1, 2]]))
    """

    @prim_attr_register
    def __init__(self, axis=-1, descending=False):
        """Initialize Sort"""
        self.axis = validator.check_value_type("axis", axis, [int], self.name)
        self.descending = validator.check_value_type("descending", descending, [bool], self.name)

    def infer_shape(self, x_shape):
        return x_shape, x_shape

    def infer_dtype(self, x_dtype):
        validator.check_tensor_dtype_valid("x_dtype", x_dtype, [mstype.float32, mstype.float16], self.name)
        return x_dtype, mstype.tensor_type(mstype.int32)


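# Editor's sketch (illustrative only): Sort's two outputs correspond to NumPy's
# np.sort and np.argsort along the same axis (tie ordering may differ for
# descending sorts). The helper name is hypothetical.
def _sort_reference_sketch(x, axis=-1, descending=False):
    """Return (sorted values, indices), mirroring Sort's (y1, y2) outputs."""
    idx = np.argsort(-x if descending else x, axis=axis)
    return np.take_along_axis(x, idx, axis=axis), idx

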
class EmbeddingLookup(PrimitiveWithCheck):
    """
    Returns a slice of the input tensor based on the specified indices.

    This Primitive has similar functionality as GatherV2 operating on `axis = 0`, but has one more input:
    `offset`.

    Inputs:
        - **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
          This represents a Tensor slice, instead of the entire Tensor. Currently, the dimension is restricted to be 2.
        - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
          Specifies the indices of elements of the original Tensor. Values can be out of range of `input_params`,
          and the exceeding part will be filled with 0 in the output. Negative values are not supported, and the
          result is undefined if any value is negative. The data type should be int32 or int64.
        - **offset** (int) - Specifies the offset value of this `input_params` slice. Thus the real indices
          are equal to `input_indices` minus `offset`.

    Outputs:
        Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`. The data type is the same as `input_params`.

    Raises:
        TypeError: If dtype of `input_indices` is not int.
        ValueError: If length of shape of `input_params` is greater than 2.

    Supported Platforms:
        ``Ascend`` ``CPU`` ``GPU``

    Examples:
        >>> input_params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32)
        >>> input_indices = Tensor(np.array([[5, 2], [8, 5]]), mindspore.int32)
        >>> offset = 4
        >>> output = ops.EmbeddingLookup()(input_params, input_indices, offset)
        >>> print(output)
        [[[10. 11.]
          [ 0.  0.]]
         [[ 0.  0.]
          [10. 11.]]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize EmbeddingLookup."""
        self.__setattr_flag__ = True
        self.init_prim_io_names(inputs=['params', 'indices', 'offset'],
                                outputs=['output'])

    def __check__(self, params, indices, offset):
        validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
        validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name)
        validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name)
        indices_shp = indices['shape']
        if not indices_shp:
            raise ValueError(f"For '{self.name}', the 'input_indices' should not be a scalar, but got {indices_shp}.")
        params_shp = params['shape']
        if len(params_shp) > 2:
            raise ValueError(f"For '{self.name}', the dimension of 'input_params' must be less than or equal to 2, "
                             f"but got {len(params_shp)}.")


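# Editor's sketch (illustrative only): a NumPy reference of the lookup semantics
# above. Indices are shifted by `offset`; out-of-range rows come back as zeros.
# The helper name is hypothetical.
def _embedding_lookup_sketch(params, indices, offset):
    """E.g. params (4, 2), indices [[5, 2], [8, 5]], offset 4 reproduces the example output."""
    shifted = indices - offset
    valid = (shifted >= 0) & (shifted < params.shape[0])
    out = params[np.clip(shifted, 0, params.shape[0] - 1)]
    out[~valid] = 0
    return out

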
class GatherD(Primitive):
    """
    Gathers values along an axis specified by dim.

    For a 3-D tensor, the output is:

    .. code-block::

        output[i][j][k] = x[index[i][j][k]][j][k]  # if dim == 0

        output[i][j][k] = x[i][index[i][j][k]][k]  # if dim == 1

        output[i][j][k] = x[i][j][index[i][j][k]]  # if dim == 2

    If `x` is an n-D tensor with shape :math:`(z_0, z_1, ..., z_i, ..., z_{n-1})` and `dim` = i,
    the `index` must be an n-D tensor with shape :math:`(z_0, z_1, ..., y, ..., z_{n-1})`
    where `y`>=1 and the output will have the same shape as `index`.

    Inputs:
        - **x** (Tensor) - The source tensor.
          The shape is :math:`(N,*)` where :math:`*` means, any number of additional dimensions.
        - **dim** (int) - The axis along which to index. It must be int32 or int64. Only constant value is allowed.
        - **index** (Tensor) - The indices of elements to gather. It can be one of the following data types:
          int32, int64. The value range of each index element is [-x_rank[dim], x_rank[dim]).

    Outputs:
        Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`, has the same data type with `x`.

    Raises:
        TypeError: If dtype of `dim` or `index` is neither int32 nor int64.
        ValueError: If length of shape of `x` is not equal to length of shape of `index`.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> x = Tensor(np.array([[1, 2], [3, 4]]), mindspore.int32)
        >>> index = Tensor(np.array([[0, 0], [1, 0]]), mindspore.int32)
        >>> dim = 1
        >>> output = ops.GatherD()(x, dim, index)
        >>> print(output)
        [[1 1]
         [4 3]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize GatherD"""
        self.init_prim_io_names(inputs=['x', 'dim', 'index'], outputs=['output'])


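# Editor's note (illustrative only): for non-negative indices, GatherD matches
# NumPy's take_along_axis, e.g. for the example above:
#     np.take_along_axis(np.array([[1, 2], [3, 4]]), np.array([[0, 0], [1, 0]]), axis=1)
#     # -> [[1 1]
#     #     [4 3]]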
class Identity(PrimitiveWithInfer):
    """
    Returns a Tensor with the same shape and contents as input.

    Inputs:
        - **x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. The data type is Number.

    Outputs:
        Tensor, the shape of tensor and the data type are the same as `x`, :math:`(x_1, x_2, ..., x_R)`.

    Raises:
        TypeError: If `x` is not a Tensor.

    Supported Platforms:
        ``Ascend`` ``CPU`` ``GPU``

    Examples:
        >>> x = Tensor(np.array([1, 2, 3, 4]), mindspore.int64)
        >>> output = ops.Identity()(x)
        >>> print(output)
        [1 2 3 4]
    """

    # Side effect is identity with input.
    side_effect_propagate = 1

    @prim_attr_register
    def __init__(self):
        """Initialize identity"""
        self.add_prim_attr('side_effect_propagate', 1)

    def __infer__(self, x):
        validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
        validator.check_tensor_dtype_valid('x', x['dtype'], mstype.number_type + (mstype.bool_,), self.name)
        out = {'shape': x['shape'],
               'dtype': x['dtype'],
               'value': None}
        return out


class Range(PrimitiveWithCheck):
    r"""
    Creates a sequence of numbers that begins at `start` and extends by increments of
    `delta` up to but not including `limit`.

    The types of all 3 inputs must be the same. The type of the resulting tensor is
    the same as the type of the inputs.

    Args:
        maxlen (int): Memory that can fit `maxlen` many elements
            will be allocated for the output. Optional, must be positive, defaults to 1000000.
            If the output has more than `maxlen` elements, a runtime error
            will occur.

    Inputs:
        - **start** (Tensor) - A scalar Tensor. The first number in the sequence. Must have
          type: int32 or float32.
        - **limit** (Tensor) - A scalar Tensor. Upper limit of the sequence, exclusive. Must
          have type: int32 or float32.
        - **delta** (Tensor) - A scalar Tensor. Number that increments `start`. Must have
          type: int32 or float32.

    Outputs:
        A 1-D Tensor, with the same type as the inputs.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> start = Tensor(0, mstype.int32)
        >>> limit = Tensor(10, mstype.int32)
        >>> delta = Tensor(4, mstype.int32)
        >>> output = ops.Range()(start, limit, delta)
        >>> print(output)
        [0 4 8]
    """

    @prim_attr_register
    def __init__(self, maxlen=1000000):
        self.init_prim_io_names(inputs=['start', 'limit', 'delta'], outputs=['output'])
        validator.check_value_type("maxlen", maxlen, [int], self.name)
        validator.check_positive_int(maxlen, "maxlen", self.name)
        self.maxlen = maxlen
        self.add_prim_attr('maxlen', maxlen)

    def check_shape(self, start_shape, limit_shape, delta_shape):
        validator.check("start_shape", len(start_shape), "", 0, Rel.EQ, self.name)
        validator.check("limit_shape", len(limit_shape), "", 0, Rel.EQ, self.name)
        validator.check("delta_shape", len(delta_shape), "", 0, Rel.EQ, self.name)

    def check_dtype(self, start_dtype, limit_dtype, delta_dtype):
        valid_dtypes = [mstype.int32, mstype.float32]
        inputs = {"start": start_dtype, "limit": limit_dtype, "delta": delta_dtype}
        validator.check_tensors_dtypes_same_and_valid(inputs, valid_dtypes, self.name)

    def infer_value(self, start_value, limit_value, delta_value):
        """Infer the value of input for Range."""
        if start_value is not None and limit_value is not None and delta_value is not None:
            # np.asscalar was deprecated and later removed from NumPy; .item()
            # is the supported equivalent.
            start = start_value.asnumpy().item()
            limit = limit_value.asnumpy().item()
            delta = delta_value.asnumpy().item()
            return Tensor(np.arange(start, limit, delta), dtype=start_value.dtype)
        return None


class MaskedFill(Primitive):
    """
    Fills elements of the input tensor with `value` where `mask` is True.

    The shapes of `input` and `mask` need to be the same or broadcastable.

    Inputs:
        - **input** (Tensor) - The source tensor whose data type is one of float16, float32, int8, int32.
        - **mask** (Tensor[bool]) - The boolean mask.
        - **value** (Union[float, Tensor]) - The value to fill in with, which only supports
          a 0-dimensional tensor or a float number.

    Outputs:
        Tensor, has the same type and shape as `input`.

    Raises:
        TypeError: If `input` or `mask` is not a tensor.
        TypeError: If `value` is neither float number nor tensor.
        TypeError: If dtype of `input` or `value` is not one of float16, float32, int8, int32.
        TypeError: If dtype of `value` is different from that of `input`.
        TypeError: If dtype of `mask` is not bool.
        ValueError: If the shapes of `input` and `mask` could not be broadcast.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> input = Tensor(np.array([1., 2., 3., 4.]), mindspore.float32)
        >>> mask = Tensor(np.array([True, True, False, True]), mindspore.bool_)
        >>> output = ops.MaskedFill()(input, mask, 0.5)
        >>> print(output)
        [0.5 0.5 3.  0.5]
    """

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['input', 'mask', 'value'], outputs=['output'])


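# Editor's note (illustrative only): after broadcasting, MaskedFill matches
# np.where(mask, value, input), e.g. for the example above:
#     np.where(np.array([True, True, False, True]), 0.5, np.array([1., 2., 3., 4.]))
#     # -> [0.5 0.5 3.  0.5]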
class MaskedSelect(PrimitiveWithCheck):
    """
    Returns a new 1-D Tensor which indexes the input tensor according to the boolean mask.
    The shapes of the mask tensor and the input tensor don't need to match, but they must be broadcastable.

    Inputs:
        - **x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
        - **mask** (Tensor[bool]) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.

    Outputs:
        A 1-D Tensor, with the same type as x.

    Raises:
        TypeError: If `x` is not a Tensor.

    Supported Platforms:
        ``Ascend`` ``CPU``

    Examples:
        >>> x = Tensor(np.array([1, 2, 3, 4]), mindspore.int64)
        >>> mask = Tensor(np.array([1, 0, 1, 0]), mindspore.bool_)
        >>> output = ops.MaskedSelect()(x, mask)
        >>> print(output)
        [1 3]
    """

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['x', 'mask'], outputs=['output'])

    def check_shape(self, x_shape, mask_shape):
        get_broadcast_shape(x_shape, mask_shape, self.name)
        validator.check("rank of x", len(x_shape), "expected", 1, Rel.GE, self.name)

    def check_dtype(self, x_dtype, mask_dtype):
        validator.check_tensor_dtype_valid('mask', mask_dtype, [mstype.bool_], self.name)
        validator.check_tensor_dtype_valid('x', x_dtype, [mstype.int32, mstype.float32], self.name)


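# Editor's note (illustrative only): after broadcasting, MaskedSelect behaves
# like NumPy boolean indexing, e.g.:
#     np.array([1, 2, 3, 4])[np.array([True, False, True, False])]  # -> [1 3]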
class SearchSorted(PrimitiveWithInfer):
    """
    Finds the indices from the innermost dimension of `sequence` such that the order of the innermost dimension
    within `sequence` would be preserved when the corresponding values in `values` were inserted before the indices.

    Args:
        out_int32 (bool): Output datatype. Optional. If True, the output datatype will be int32;
                          if False, the output datatype will be int64. Default is False.
        right (bool): Search strategy. Optional. If True, return the last suitable index found.
                      If False, return the first such index. Default is False.

    Inputs:
        - **sequence** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R-1, x_R)` or `(x_1)`.
                                  It must contain a monotonically increasing sequence on the innermost dimension.
        - **values** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R-1, x_S)`.

    Outputs:
        Tensor containing the indices from the innermost dimension of `sequence` such that, if the corresponding
        values in `values` were inserted at those indices, the order of `sequence` would be preserved.
        The shape of the tensor is :math:`(x_1, x_2, ..., x_R-1, x_S)`, the same as the shape of `values`;
        its datatype is int32 if out_int32 is True, otherwise int64.

    Raises:
        ValueError: If `sequence` and `values` do not have proper shapes.

    Supported Platforms:
        ``CPU``

    Examples:
        >>> sequence = Tensor(np.array([[0, 1, 3, 5, 7], [2, 4, 6, 8, 10]]), mindspore.float32)
        >>> values = Tensor(np.array([[3, 6, 9], [3, 6, 9]]), mindspore.float32)
        >>> output = ops.SearchSorted()(sequence, values)
        >>> print(output)
        [[2 4 5]
         [1 2 4]]
    """

    @prim_attr_register
    def __init__(self, out_int32=False, right=False):
        """Initialize SearchSorted"""
        self.out_int32 = validator.check_value_type("out_int32", out_int32, [bool], self.name)
        self.right = validator.check_value_type("right", right, [bool], self.name)
        self.init_prim_io_names(inputs=['sequence', 'values'], outputs=['positions'])

    def infer_shape(self, sequence_shape, values_shape):
        if len(sequence_shape) != 1 and sequence_shape[:-1] != values_shape[:-1]:
            raise ValueError(f"For '{self.name}', the 'sequence' should be 1 dimensional or "
                             f"all dimensions except the last dimension of 'sequence' "
                             f"must be the same as all dimensions except the last dimension of 'values', "
                             f"but got shape of 'sequence': {sequence_shape} "
                             f"and shape of 'values': {values_shape}.")
        return values_shape

    def infer_dtype(self, sequence_dtype, values_dtype):
        args = {"sequence_dtype": sequence_dtype, "values_dtype": values_dtype}
        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
        return mstype.tensor_type(mstype.int32) if self.out_int32 else mstype.tensor_type(mstype.int64)


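# Editor's note (illustrative only): per innermost row this matches NumPy's
# np.searchsorted, with side='right' when right=True and side='left' otherwise:
#     np.searchsorted([0, 1, 3, 5, 7], [3, 6, 9])  # -> [2 4 5]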
6181class TensorScatterMax(PrimitiveWithInfer):
6182    """
6183    By comparing the value at the position indicated by the index in input_x with the value in the update,
6184    the value at the index will eventually be equal to the largest one to create a new tensor.
6185
6186    The last axis of the index is the depth of each index vector. For each index vector,
6187    there must be a corresponding value in `updates`. The shape of `updates` should be
6188    equal to the shape of input_x[indices].
6189    For more details, see use cases.
6190
6191    Note:
6192        If some values of the `indices` are out of bound, instead of raising an index error,
6193        the corresponding `updates` will not be updated to `input_x`.
6194
6195    Inputs:
6196        - **input_x** (Tensor) - The target tensor. The dimension of input_x must be no less than indices.shape[-1].
6197        - **indices** (Tensor) - The index of input tensor whose data type is int32 or int64.
6198          The rank must be at least 2.
6199        - **updates** (Tensor) - The tensor to update the input tensor, has the same type as input,
6200          and updates.shape should be equal to indices.shape[:-1] + input_x.shape[indices.shape[-1]:].
6201
6202    Outputs:
6203        Tensor, has the same shape and type as `input_x`.
6204
6205    Raises:
6206        TypeError: If dtype of `indices` is neither int32 nor int64.
6207        ValueError: If length of shape of `input_x` is less than the last dimension of shape of `indices`.
6208
6209    Supported Platforms:
6210        ``GPU``
6211
6212    Examples:
6213        >>> input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
6214        >>> indices = Tensor(np.array([[0, 0], [0, 0]]), mindspore.int32)
6215        >>> updates = Tensor(np.array([1.0, 2.2]), mindspore.float32)
6216        >>> # Next, demonstrate the approximate operation process of this operator:
6217        >>> # 1, indices[0] = [0, 0], indices[1] = [0, 0]
6218        >>> # 2, And input_x[0, 0] = -0.1
6219        >>> # 3, So input_x[indices] = [-0.1, -0.1]
6220        >>> # 4, Satisfy the above formula: input_x[indices].shape=(2) == updates.shape=(2)
6221        >>> op = ops.TensorScatterMax()
6222        >>> # 5, Perform the max operation for the first time:
6223        >>> #      first_input_x = Max(input_x[0][0], updates[0]) = [[2.2, 0.3, 3.6], [0.4, 0.5, -3.2]]
6224        >>> # 6, Perform the max operation for the second time:
6225        >>> #      second_input_x = Max(input_x[0][0], updates[0]) = [[2.2, 0.3, 3.6], [0.4, 0.5, -3.2]]
6226        >>> output = op(input_x, indices, updates)
6227        >>> print(output)
6228        [[ 2.2  0.3  3.6]
6229         [ 0.4  0.5 -3.2]]
6230    """
6231
    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['input_x', 'indices', 'updates'], outputs=['y'])

    def infer_shape(self, input_x_shape, indices_shape, updates_shape):
        if len(indices_shape) < 2:
            raise ValueError(f"For '{self.name}', the dimension of 'indices' cannot be less than 2,"
                             f" but got {len(indices_shape)}.")

        if indices_shape[-1] > len(input_x_shape):
            raise ValueError(f"For '{self.name}', the last dimension of 'indices' must be less than or equal to "
                             f"the dimension of 'input_x', but got the "
                             f"last dimension of 'indices': {indices_shape[-1]} and the dimension of 'input_x': "
                             f"{len(input_x_shape)}.")

        updates_shape_check = indices_shape[:-1] + input_x_shape[indices_shape[-1]:]
        if updates_shape_check != updates_shape:
            raise ValueError(f"For '{self.name}', the shape of 'updates' must be equal to updates_shape_check, "
                             f"where updates_shape_check = indices_shape[:-1] + input_x_shape[indices_shape[-1]:], "
                             f"but got the shape of 'updates': {updates_shape}, "
                             f"updates_shape_check: {updates_shape_check}, indices_shape: {indices_shape} and "
                             f"input_x_shape: {input_x_shape}. Please check input_x_shape and indices_shape.")

        return input_x_shape

    def infer_dtype(self, input_x_dtype, indices_dtype, updates_dtype):
        validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32, mstype.int64], self.name)
        args = {"input_x": input_x_dtype, "updates": updates_dtype}
        valid_input_types = (mstype.float16, mstype.float32, mstype.int8, mstype.uint8, mstype.int32)
        validator.check_tensors_dtypes_same_and_valid(args, valid_input_types, self.name)
        return input_x_dtype


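# A hedged NumPy reference for the semantics documented above (illustration only; the
# helper name is hypothetical and not part of the operator implementation). `np.maximum.at`
# applies duplicate indices one at a time, which reproduces the step-wise walkthrough in
# the docstring. One divergence: the real operator silently skips out-of-bound indices,
# while NumPy would raise an IndexError.
def _tensor_scatter_max_numpy_sketch(input_x, indices, updates):
    k = indices.shape[-1]
    # flatten leading index dimensions so each row of `idx` addresses one slice of the input
    idx = indices.reshape(-1, k)
    upd = updates.reshape((idx.shape[0],) + input_x.shape[k:])
    out = input_x.copy()
    np.maximum.at(out, tuple(idx.T), upd)
    return out
# e.g. _tensor_scatter_max_numpy_sketch(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]),
#                                       np.array([[0, 0], [0, 0]]), np.array([1.0, 2.2]))
# -> [[2.2, 0.3, 3.6], [0.4, 0.5, -3.2]]

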
class TensorScatterMin(PrimitiveWithInfer):
    """
    By comparing the value at the position indicated by the index in `input_x` with the value in `updates`,
    the value at the index will eventually be equal to the smallest one to create a new tensor.

    The last axis of the index is the depth of each index vector. For each index vector,
    there must be a corresponding value in `updates`. The shape of `updates` should be
    equal to the shape of `input_x[indices]`.
    For more details, see use cases.

    Note:
        If some values of the `indices` are out of bound, instead of raising an index error,
        the corresponding `updates` will not be updated to `input_x`.

    Inputs:
        - **input_x** (Tensor) - The target tensor. The dimension of `input_x` must be no less than indices.shape[-1].
        - **indices** (Tensor) - The index of the input tensor, with int32 or int64 data type.
          The rank must be at least 2.
        - **updates** (Tensor) - The tensor to update the input tensor, has the same type as the input,
          and updates.shape should be equal to indices.shape[:-1] + input_x.shape[indices.shape[-1]:].

    Outputs:
        Tensor, has the same shape and type as `input_x`.

    Raises:
        TypeError: If dtype of `indices` is neither int32 nor int64.
        ValueError: If length of shape of `input_x` is less than the last dimension of shape of `indices`.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
        >>> indices = Tensor(np.array([[0, 0], [0, 0]]), mindspore.int32)
        >>> updates = Tensor(np.array([1.0, 2.2]), mindspore.float32)
        >>> # Next, demonstrate the approximate operation process of this operator:
        >>> # 1, indices[0] = [0, 0], indices[1] = [0, 0]
        >>> # 2, And input_x[0, 0] = -0.1
        >>> # 3, So input_x[indices] = [-0.1, -0.1]
        >>> # 4, Satisfy the above formula: input_x[indices].shape=(2,) == updates.shape=(2,)
        >>> op = ops.TensorScatterMin()
        >>> # 5, Perform the min operation for the first time:
        >>> #      first_input_x = Min(input_x[0][0], updates[0]) = [[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]
        >>> # 6, Perform the min operation for the second time:
        >>> #      second_input_x = Min(first_input_x[0][0], updates[1]) = [[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]
        >>> output = op(input_x, indices, updates)
        >>> print(output)
        [[-0.1  0.3  3.6]
         [ 0.4  0.5 -3.2]]
    """

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['input_x', 'indices', 'updates'], outputs=['y'])

    def infer_shape(self, input_x_shape, indices_shape, updates_shape):
        if len(indices_shape) < 2:
            raise ValueError(f"For '{self.name}', the dimension of 'indices' cannot be less than 2,"
                             f" but got {len(indices_shape)}.")

        if indices_shape[-1] > len(input_x_shape):
            raise ValueError(f"For '{self.name}', the last dimension of 'indices' must be less than or equal to "
                             f"the dimension of 'input_x', but got the "
                             f"last dimension of 'indices': {indices_shape[-1]} and the dimension of 'input_x': "
                             f"{len(input_x_shape)}.")

        updates_shape_check = indices_shape[:-1] + input_x_shape[indices_shape[-1]:]
        if updates_shape_check != updates_shape:
            raise ValueError(f"For '{self.name}', the shape of 'updates' must be equal to updates_shape_check, "
                             f"where updates_shape_check = indices_shape[:-1] + input_x_shape[indices_shape[-1]:], "
                             f"but got the shape of 'updates': {updates_shape}, "
                             f"updates_shape_check: {updates_shape_check}, indices_shape: {indices_shape} and "
                             f"input_x_shape: {input_x_shape}. Please check input_x_shape and indices_shape.")

        return input_x_shape

    def infer_dtype(self, input_x_dtype, indices_dtype, updates_dtype):
        validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32, mstype.int64], self.name)
        args = {"input_x": input_x_dtype, "updates": updates_dtype}
        valid_input_types = (mstype.float16, mstype.float32, mstype.int8, mstype.uint8, mstype.int32)
        validator.check_tensors_dtypes_same_and_valid(args, valid_input_types, self.name)
        return input_x_dtype


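# The updates-shape rule checked by TensorScatterMax and TensorScatterMin above (and
# TensorScatterSub below) is identical; this small helper (hypothetical, for illustration
# only) states it in one line, mirroring the infer_shape computation, assuming list shapes
# as used by the infer methods.
def _expected_updates_shape(input_x_shape, indices_shape):
    # updates.shape must equal indices.shape[:-1] + input_x.shape[indices.shape[-1]:]
    return indices_shape[:-1] + input_x_shape[indices_shape[-1]:]
# e.g. _expected_updates_shape([2, 3], [2, 2]) -> [2]

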
class TensorScatterSub(PrimitiveWithInfer):
    """
    Creates a new tensor by subtracting the values in `updates` from the positions in `input_x`
    indicated by `indices`. When multiple values are provided for the same index, each of these
    values is subtracted in turn. This operation is almost equivalent to using ScatterNdSub,
    except that the updates are applied on a `Tensor` instead of a `Parameter`.

    The last axis of `indices` is the depth of each index vector. For each index vector,
    there must be a corresponding value in `updates`. The shape of `updates` should be
    equal to the shape of `input_x[indices]`. For more details, see use cases.

    Note:
        If some values of the `indices` are out of bound, instead of raising an index error,
        the corresponding `updates` will not be updated to `input_x`.

    Inputs:
        - **input_x** (Tensor) - The target tensor. The dimension of `input_x` must be no less than indices.shape[-1].
        - **indices** (Tensor) - The index of the input tensor, with int32 or int64 data type.
          The rank must be at least 2.
        - **updates** (Tensor) - The tensor to update the input tensor, has the same type as the input,
          and updates.shape should be equal to indices.shape[:-1] + input_x.shape[indices.shape[-1]:].

    Outputs:
        Tensor, has the same shape and type as `input_x`.

    Raises:
        TypeError: If dtype of `indices` is neither int32 nor int64.
        ValueError: If length of shape of `input_x` is less than the last dimension of shape of `indices`.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
        >>> indices = Tensor(np.array([[0, 0], [0, 0]]), mindspore.int32)
        >>> updates = Tensor(np.array([1.0, 2.2]), mindspore.float32)
        >>> # Next, demonstrate the approximate operation process of this operator:
        >>> # 1, indices[0] = [0, 0], indices[1] = [0, 0]
        >>> # 2, And input_x[0, 0] = -0.1
        >>> # 3, So input_x[indices] = [-0.1, -0.1]
        >>> # 4, Satisfy the above formula: input_x[indices].shape=(2,) == updates.shape=(2,)
        >>> op = ops.TensorScatterSub()
        >>> # 5, Perform the subtract operation for the first time:
        >>> #      first_input_x = input_x[0][0] - updates[0] = [[-1.1, 0.3, 3.6], [0.4, 0.5, -3.2]]
        >>> # 6, Perform the subtract operation for the second time:
        >>> #      second_input_x = first_input_x[0][0] - updates[1] = [[-3.3, 0.3, 3.6], [0.4, 0.5, -3.2]]
        >>> output = op(input_x, indices, updates)
        >>> print(output)
        [[-3.3000002  0.3        3.6      ]
         [ 0.4        0.5       -3.2      ]]
    """

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['input_x', 'indices', 'updates'], outputs=['y'])

    def infer_shape(self, input_x_shape, indices_shape, updates_shape):
        if len(indices_shape) < 2:
            raise ValueError(f"For '{self.name}', the dimension of 'indices' cannot be less than 2,"
                             f" but got {len(indices_shape)}.")

        if indices_shape[-1] > len(input_x_shape):
            raise ValueError(f"For '{self.name}', the last dimension of 'indices' must be less than or equal to "
                             f"the dimension of 'input_x', but got the "
                             f"last dimension of 'indices': {indices_shape[-1]} and the dimension of 'input_x': "
                             f"{len(input_x_shape)}.")

        updates_shape_check = indices_shape[:-1] + input_x_shape[indices_shape[-1]:]
        if updates_shape_check != updates_shape:
            raise ValueError(f"For '{self.name}', the shape of 'updates' must be equal to updates_shape_check, "
                             f"where updates_shape_check = indices_shape[:-1] + input_x_shape[indices_shape[-1]:], "
                             f"but got the shape of 'updates': {updates_shape}, "
                             f"updates_shape_check: {updates_shape_check}, indices_shape: {indices_shape} and "
                             f"input_x_shape: {input_x_shape}. Please check input_x_shape and indices_shape.")

        return input_x_shape

    def infer_dtype(self, input_x_dtype, indices_dtype, updates_dtype):
        validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32, mstype.int64], self.name)
        args = {"input_x": input_x_dtype, "updates": updates_dtype}
        valid_input_types = (mstype.float16, mstype.float32, mstype.int8, mstype.uint8, mstype.int32)
        validator.check_tensors_dtypes_same_and_valid(args, valid_input_types, self.name)
        return input_x_dtype


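# A hedged NumPy counterpart of TensorScatterSub (illustration only; the helper name is
# hypothetical). `np.subtract.at` accumulates over duplicate indices, which is exactly why
# the docstring example yields -0.1 - 1.0 - 2.2 = -3.3 at position (0, 0). As with the
# sketch above, NumPy raises on out-of-bound indices instead of skipping them.
def _tensor_scatter_sub_numpy_sketch(input_x, indices, updates):
    k = indices.shape[-1]
    # flatten leading index dimensions so each row of `idx` addresses one slice of the input
    idx = indices.reshape(-1, k)
    upd = updates.reshape((idx.shape[0],) + input_x.shape[k:])
    out = input_x.copy()
    np.subtract.at(out, tuple(idx.T), upd)
    return out

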
class SplitV(Primitive):
    r"""
    Splits the input tensor into `num_split` tensors along the given dimension.

    The `input_x` tensor will be split into sub-tensors with individual shapes given by `size_splits` along the split
    dimension. This requires that `input_x.shape[split_dim]` is equal to the sum of `size_splits`.

    The shape of `input_x` is :math:`(x_1, x_2, ..., x_M, ..., x_R)`. The rank of `input_x` is `R`. Set the given
    `split_dim` as M, and :math:`-R \le M < R`. Set the given `num_split` as `N`, the given `size_splits` as
    :math:`(x_{m_1}, x_{m_2}, ..., x_{m_N})`, :math:`x_M=\sum_{i=1}^Nx_{m_i}`. The output is a list of tensor objects,
    where the :math:`i`-th tensor has the shape :math:`(x_1, x_2, ..., x_{m_i}, ..., x_R)`, i.e. :math:`x_{m_i}` is the
    :math:`M`-th dimension of the :math:`i`-th tensor. Then, the shape of the output tensor is

    .. math::

        ((x_1, x_2, ..., x_{m_1}, ..., x_R), (x_1, x_2, ..., x_{m_2}, ..., x_R), ...,
         (x_1, x_2, ..., x_{m_N}, ..., x_R))

    Args:
        size_splits (Union[tuple, list]): A tuple or list containing the sizes of each output tensor along the split
                                          dimension. Must sum to the size of `input_x` along `split_dim`.
                                          Can contain one -1 indicating that dimension is to be inferred.
        split_dim (int): The dimension along which to split. Must be in the range [-len(input_x.shape),
                         len(input_x.shape)).
        num_split (int): The number of output tensors. Must be a positive int.

    Inputs:
        - **input_x** (Tensor) - The shape of the tensor is :math:`(x_1, x_2, ..., x_M, ..., x_R)`.

    Outputs:
        Tuple of `num_split` Tensor objects with the shape :math:`((x_1, x_2, ..., x_{m_1}, ..., x_R),
        (x_1, x_2, ..., x_{m_2}, ..., x_R), ..., (x_1, x_2, ..., x_{m_N}, ..., x_R))`, :math:`x_M=\sum_{i=1}^Nx_{m_i}`.
        The data type is the same as that of `input_x`.

    Raises:
        TypeError: If `input_x` is not a Tensor.
        TypeError: If `size_splits` is not a tuple or a list.
        TypeError: If an element of `size_splits` is not an int.
        TypeError: If `split_dim` or `num_split` is not an int.
        ValueError: If the length of `size_splits` is not equal to `num_split`.
        ValueError: If the sum of `size_splits` is not equal to the size of `input_x` along `split_dim`.
        ValueError: If `split_dim` is out of the range [-len(input_x.shape), len(input_x.shape)).
        ValueError: If `num_split` is less than or equal to 0.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> input_x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.int32)
        >>> op = ops.SplitV(size_splits=[1, -1], split_dim=1, num_split=2)
        >>> output = op(input_x)
        >>> print(output)
        (Tensor(shape=[3, 1], dtype=Int32, value=
        [[1],
         [4],
         [7]]), Tensor(shape=[3, 2], dtype=Int32, value=
        [[2, 3],
         [5, 6],
         [8, 9]]))
        >>> input_x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.int32)
        >>> op = ops.SplitV(size_splits=[2, 1], split_dim=0, num_split=2)
        >>> output = op(input_x)
        >>> print(output)
        (Tensor(shape=[2, 3], dtype=Int32, value=
        [[1, 2, 3],
         [4, 5, 6]]), Tensor(shape=[1, 3], dtype=Int32, value=
        [[7, 8, 9]]))
    """

    @prim_attr_register
    def __init__(self, size_splits, split_dim, num_split):
        """Initialize SplitV"""
        validator.check_value_type("size_splits", size_splits, [tuple, list], self.name)
        for elements_of_size_splits in size_splits:
            validator.check_value_type("elements of size_splits", elements_of_size_splits, [int], self.name)
            if elements_of_size_splits != -1 and elements_of_size_splits < 1:
                raise ValueError(f"For \'{self.name}\', all elements of size_splits must be positive (except at most "
                                 f"one default value -1), but got: {elements_of_size_splits}.")
        validator.check_value_type("split_dim", split_dim, [int], self.name)
        validator.check_value_type("num_split", num_split, [int], self.name)
        validator.check_positive_int(num_split, "num_split", self.name)
        self.init_prim_io_names(inputs=['input_x'], outputs=['output'])

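# A minimal NumPy sketch of the SplitV semantics (illustration only; the helper name is
# hypothetical): one -1 entry in `size_splits` is inferred from the remaining extent along
# `split_dim`, and `np.split` then cuts at the cumulative offsets.
def _split_v_numpy_sketch(input_x, size_splits, split_dim, num_split):
    sizes = list(size_splits)
    if -1 in sizes:
        # the explicit sizes sum to sum(sizes) + 1, since the -1 placeholder contributes -1
        sizes[sizes.index(-1)] = input_x.shape[split_dim] - (sum(sizes) + 1)
    assert len(sizes) == num_split and sum(sizes) == input_x.shape[split_dim]
    cuts = np.cumsum(sizes)[:-1]  # np.split takes cut points, not segment sizes
    return tuple(np.split(input_x, cuts, axis=split_dim))
# e.g. _split_v_numpy_sketch(np.arange(1, 10).reshape(3, 3), [1, -1], 1, 2)
# -> (array of shape (3, 1), array of shape (3, 2))

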