• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020-2023 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15
16"""Communication APIs.
17"""
18from __future__ import absolute_import
19from __future__ import division
20
21from mindspore.common import Tensor
22from mindspore import _checkparam as validator
23from mindspore.communication.management import get_rank, get_group_size, GlobalComm, _get_group, _host_distribute
24from mindspore.common import dtype as mstype
25from mindspore.ops.primitive import PrimitiveWithInfer, PrimitiveWithCheck, Primitive, prim_attr_register
26from mindspore.common.api import context
27
28
29class ReduceOp:
30    """
31    Operation options for reducing tensors. This is an enumerated type, not an operator.
32
33    The main calling methods are as follows:
34
35    - SUM: ReduceOp.SUM.
36    - MAX: ReduceOp.MAX.
37    - MIN: ReduceOp.MIN.
38    - PROD: ReduceOp.PROD.
39
40    There are four kinds of operation options, "SUM", "MAX", "MIN", and "PROD".
41
42    - SUM: Take the sum.
43    - MAX: Take the maximum.
44    - MIN: Take the minimum.
45    - PROD: Take the product.
46
47    Supported Platforms:
48        ``Ascend`` ``GPU``
49
50    Examples:
51        .. note::
52            Before running the following examples, you need to configure the communication environment variables.
53
54            For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
55            without any third-party or configuration file dependencies.
56            Please see the `msrun start up
57            <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_
58            for more details.
59
60            This example should be run with multiple devices.
61
62        >>> import numpy as np
63        >>> import mindspore
64        >>> from mindspore.communication import init
65        >>> from mindspore import Tensor, ops, nn
66        >>> from mindspore.ops import ReduceOp
67        >>>
68        >>> init()
69        >>> class Net(nn.Cell):
70        ...     def __init__(self):
71        ...         super(Net, self).__init__()
72        ...         self.allreduce_sum = ops.AllReduce(ReduceOp.SUM)
73        ...
74        ...     def construct(self, x):
75        ...         return self.allreduce_sum(x)
76        ...
77        >>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
78        >>> net = Net()
79        >>> output = net(input_)
80        >>> print(output)
81        [[2. 2. 2. 2. 2. 2. 2. 2.]
82         [2. 2. 2. 2. 2. 2. 2. 2.]]
83    """
84    SUM = "sum"
85    MAX = "max"
86    MIN = "min"
87    PROD = "prod"
88
89
90def check_collective_target_dtype(data_name, data_dtype, prim_name):
91    """Check if data type is valid."""
92    default_target_dtypes = (mstype.int8, mstype.int32, mstype.float16, mstype.float32, mstype.bfloat16)
93    gpu_target_dtypes = (mstype.bool_, mstype.int8, mstype.int32, mstype.int64, mstype.uint32, mstype.uint64,
94                         mstype.float16, mstype.float32, mstype.float64)
95
96    valid_dtype = gpu_target_dtypes if context.get_context("device_target") == "GPU" else default_target_dtypes
97    validator.check_tensor_dtype_valid(data_name, data_dtype, valid_dtype, prim_name)
98
99
100def check_hcom_group_valid(group, prim_name=None):
101    """Check if hcom group is valid."""
102    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
103    if not _host_distribute() and context.get_context("mode") == context.PYNATIVE_MODE and \
104            group != GlobalComm.WORLD_COMM_GROUP:
105        raise RuntimeError(f"{msg_prefix} 'group' only support 'hccl_world_group' in pynative mode, but got "
106                           f"'group': {group}. Please start by using mpi-run.")
107
108
109class AllReduce(Primitive):
110    """
111    Reduces tensors across all devices in such a way that all devices will get the same final result,
112    returns the tensor which is all reduced.
113
114    Note:
115        The tensors must have the same shape and format in all processes of the collection.
116
117    Args:
118        op (str, optional): Specifies an operation used for element-wise reductions, like sum, prod, max, and min.
119                  On the CPU, only 'sum' is supported. Default: ``ReduceOp.SUM`` .
120        group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP`` , which
121                  means ``"hccl_world_group"`` in Ascend, and ``"nccl_world_group"`` in GPU.
122
123    Inputs:
124        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
125
126    Outputs:
127        Tensor, has the same shape of the input, i.e., :math:`(x_1, x_2, ..., x_R)`.
128        The contents depend on the specified operation.
129
130    Raises:
131        TypeError: If any of `op` and `group` is not a str or the input's dtype is bool.
132        RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
133
134    Supported Platforms:
135        ``Ascend`` ``GPU`` ``CPU``
136
137    Examples:
138        .. note::
139            Before running the following examples, you need to configure the communication environment variables.
140
141            For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
142            without any third-party or configuration file dependencies.
143            Please see the `msrun start up
144            <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_
145            for more details.
146
147            This example should be run with 2 devices.
148
149        >>> import numpy as np
150        >>> from mindspore.communication import init
151        >>> from mindspore import Tensor
152        >>> from mindspore.ops import ReduceOp
153        >>> import mindspore.nn as nn
154        >>> from mindspore import ops
155        >>>
156        >>> init()
157        >>> class Net(nn.Cell):
158        ...     def __init__(self):
159        ...         super(Net, self).__init__()
160        ...         self.allreduce_sum = ops.AllReduce(ReduceOp.SUM)
161        ...
162        ...     def construct(self, x):
163        ...         return self.allreduce_sum(x)
164        ...
165        >>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
166        >>> net = Net()
167        >>> output = net(input_)
168        >>> print(output)
169        [[2. 2. 2. 2. 2. 2. 2. 2.]
170         [2. 2. 2. 2. 2. 2. 2. 2.]]
171
172    Tutorial Examples:
173        - `Distributed Set Communication Primitives - AllReduce
174          <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#allreduce>`_
175
176    """
177
178    @prim_attr_register
179    def __init__(self, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
180        """Initialize AllReduce."""
181        if not isinstance(op, type(ReduceOp.SUM)):
182            raise TypeError(f"For '{self.name}', the 'op' must be str, but got {type(op).__name__}.")
183        if not isinstance(_get_group(group), str):
184            raise TypeError(f"For '{self.name}', the 'group' must be str, "
185                            f"but got {type(_get_group(group)).__name__}.")
186        check_hcom_group_valid(group, prim_name=self.name)
187        self.op = op
188        self.add_prim_attr('group', _get_group(group))
189        self.add_prim_attr('fusion', 0)
190        self.add_prim_attr('index', 0)
191        self.add_prim_attr('no_eliminate', True)
192
193
194class Reduce(PrimitiveWithInfer):
195    """
196    Reduces tensors across the processes in the specified communication group, sends the result
197    to the target dest_rank(local rank), and returns the tensor which is sent to the target process.
198
199    Note:
200        Only process with destination rank receives the reduced output.
201        Support PyNative mode and Graph mode, but Graph mode only supports scenes with a graph compilation level of O0.
202        Other processes only get a tensor with shape [1], which has no mathematical meaning.
203
204    Args:
205        dest_rank (int): The target process(local rank) in the specific group that receives the reduced output.
206        op (str, optional): Specifies an operation used for element-wise reductions, like sum, prod, max, and min.
207                  On the CPU, only 'sum' is supported. Default: ``ReduceOp.SUM`` .
208        group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP`` , which
209                  means ``"hccl_world_group"`` in Ascend, and ``"nccl_world_group"`` in GPU.
210
211    Inputs:
212        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
213
214    Outputs:
215        Tensor. Return the tensor in the specific rank of the process after reduction.
216        The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
217
218    Raises:
219        TypeError: If the type of the first input parameter is not Tensor,
220                or any of `op` and `group` is not a str.
221        RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
222
223    Supported Platforms:
224        ``Ascend``
225
226    Examples:
227        .. note::
228            Before running the following examples, you need to configure the communication environment variables.
229
230            For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method without any third-party
231            or configuration file dependencies.
232            Please see the `msrun start up
233            <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_
234            for more details.
235
236            This example should be run with 4 devices.
237
238        >>> from mindspore import ops
239        >>> import mindspore.nn as nn
240        >>> from mindspore.communication import init
241        >>> from mindspore import Tensor
242        >>> import numpy as np
243        >>> # Launch 4 processes.
244        >>> init()
245        >>> class ReduceNet(nn.Cell):
246        >>>     def __init__(self):
247        >>>         super(Net, self).__init__()
248        >>>         self.reduce = ops.Reduce(dest_rank=1)
249        >>>
250        >>>     def construct(self, x):
251        >>>         out = self.reduce(x)
252        >>>         return out
253        >>> input = Tensor(np.ones([2, 8]).astype(np.float32))
254        >>> net = ReduceNet()
255        >>> output = net(input)
256        >>> print(output)
257        Process with rank 1: [[4. 4. 4. 4. 4. 4. 4. 4.]
258                             [4. 4. 4. 4. 4. 4. 4. 4.]],
259        Other proesses: [0.].
260    """
261
262    @prim_attr_register
263    def __init__(self, dest_rank, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
264        validator.check_value_type('group', _get_group(group), (str,), self.name)
265        validator.check_value_type('op', op, (type(ReduceOp.SUM),), self.name)
266        self.dest_rank = dest_rank
267        self.op = op
268        self.group = _get_group(group)
269        self.add_prim_attr('group', _get_group(group))
270        self.add_prim_attr('dest_rank', dest_rank)
271
272    def infer_shape(self, x_shape):
273        # The process with dest_rank returns the reduced output.
274        # Other processes only gets a tensor with shape [1], which has no mathematical meaning.
275        if self.dest_rank == get_rank():
276            return x_shape
277        return [1]
278
279    def infer_dtype(self, x_dtype):
280        return x_dtype
281
282
283class AllGather(PrimitiveWithInfer):
284    """
285    Gathers tensors from the specified communication group and returns the tensor which is all gathered.
286
287    Note:
288        - The tensors must have the same shape and format in all processes of the collection.
289
290    Args:
291        group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP`` , which
292            means ``"hccl_world_group"`` in Ascend, and ``"nccl_world_group"`` in GPU.
293
294    Inputs:
295        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
296
297    Outputs:
298        Tensor. If the number of devices in the group is N,
299        then the shape of output is :math:`(N, x_1, x_2, ..., x_R)`.
300
301    Raises:
302        TypeError: If `group` is not a str.
303        ValueError: If the local rank id of the calling process in the group
304                    is larger than the group's rank size.
305        RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
306
307    Supported Platforms:
308        ``Ascend`` ``GPU``
309
310    Examples:
311        .. note::
312            Before running the following examples, you need to configure the communication environment variables.
313
314            For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
315            without any third-party or configuration file dependencies.
316            Please see the `msrun start up
317            <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_
318            for more details.
319
320            This example should be run with 2 devices.
321
322        >>> import numpy as np
323        >>> import mindspore as ms
324        >>> from mindspore import ops
325        >>> import mindspore.nn as nn
326        >>> from mindspore.communication import init
327        >>> from mindspore import Tensor
328        >>>
329        >>> ms.set_context(mode=ms.GRAPH_MODE)
330        >>> init()
331        >>> class Net(nn.Cell):
332        ...     def __init__(self):
333        ...         super(Net, self).__init__()
334        ...         self.allgather = ops.AllGather()
335        ...
336        ...     def construct(self, x):
337        ...         return self.allgather(x)
338        ...
339        >>> input_x = Tensor(np.ones([2, 8]).astype(np.float32))
340        >>> net = Net()
341        >>> output = net(input_x)
342        >>> print(output)
343        [[1. 1. 1. 1. 1. 1. 1. 1.]
344         [1. 1. 1. 1. 1. 1. 1. 1.]
345         [1. 1. 1. 1. 1. 1. 1. 1.]
346         [1. 1. 1. 1. 1. 1. 1. 1.]]
347
348    Tutorial Examples:
349        - `Distributed Set Communication Primitives - AllGather
350          <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#allgather>`_
351
352    """
353
354    @prim_attr_register
355    def __init__(self, group=GlobalComm.WORLD_COMM_GROUP):
356        """Initialize AllGather."""
357        validator.check_value_type('group', _get_group(group), (str,), self.name)
358        self.rank = get_rank(_get_group(group))
359        self.rank_size = get_group_size(_get_group(group))
360        validator.check('rank', self.rank, 'rank_size', self.rank_size, validator.LT, self.name)
361        self.add_prim_attr('rank_size', self.rank_size)
362        self.add_prim_attr('group', _get_group(group))
363        self.add_prim_attr('fusion', 0)
364        self.add_prim_attr('mean_flag', False)
365        self.add_prim_attr('no_eliminate', True)
366
367    def infer_shape(self, x_shape):
368        validator.check_positive_int(len(x_shape), "x shape", self.name)
369        if x_shape[0] > 0:
370            x_shape[0] = x_shape[0] * self.rank_size
371        return x_shape
372
373    def infer_dtype(self, x_dtype):
374        check_collective_target_dtype('x', x_dtype, self.name)
375        return x_dtype
376
377
378class AShardIdentity(PrimitiveWithInfer):
379    """
380    Auto parallel virtual operator. Identity operator only for shard function.
381    Do nothing in terms of infer_shape, infer_dtype, and the tensor.
382
383    It is only for internal use of parallel modules and cannot be called by users.
384    """
385
386    @prim_attr_register
387    def __init__(self):
388        pass
389
390    def infer_shape(self, x_shape):
391        return x_shape
392
393    def infer_dtype(self, x_dtype):
394        return x_dtype
395
396
397class _MiniStepAllGather(PrimitiveWithInfer):
398    """
399    Auto parallel virtual operator. Do nothing in forward, do reducescatter in backward in mini-step. It is only for
400    internal use of parallel modules and cannot be called by users.
401
402    Args:
403        group (str): The communication group to work on. Default: ``None`` .
404        grad_accumulation_step (int): The grad accumulation step. Default: ``None`` .
405    """
406
407    @prim_attr_register
408    def __init__(self, group=GlobalComm.WORLD_COMM_GROUP, grad_accumulation_step=None, mean_flag=None):
409        """Initialize _MiniStepAllGather."""
410        validator.check_value_type('group', _get_group(group), (str,), self.name)
411        self.rank = get_rank(_get_group(group))
412        self.rank_size = get_group_size(_get_group(group))
413        validator.check('rank', self.rank, 'rank_size', self.rank_size, validator.LT, self.name)
414        self.add_prim_attr('rank_size', self.rank_size)
415        self.add_prim_attr('group', _get_group(group))
416        self.add_prim_attr('fusion', 1)
417        self.grad_accumulation_step = grad_accumulation_step
418        self.mean_flag = mean_flag
419        self.add_prim_attr('order_enforce_skip', True)
420        self.add_prim_attr('side_effect_backprop_mem', True)
421
422    def infer_shape(self, x_shape, z_shape):
423        validator.check_positive_int(len(x_shape), "x shape", self.name)
424        if x_shape[0] > 0:
425            x_shape[0] = x_shape[0] * self.rank_size
426        return x_shape
427
428    def infer_dtype(self, x_dtype, z_shape):
429        check_collective_target_dtype('x', x_dtype, self.name)
430        return x_dtype
431
432
433class _MicroStepAllGather(PrimitiveWithInfer):
434    """
435    Auto parallel virtual operator. Do nothing in forward, do reducescatter in backward in mini-step. It is only for
436    internal use of parallel modules and cannot be called by users.
437
438    Args:
439        group (str): The communication group to work on. Default: ``None`` .
440    """
441
442    @prim_attr_register
443    def __init__(self, group=GlobalComm.WORLD_COMM_GROUP, mean_flag=None):
444        validator.check_value_type('group', _get_group(group), (str,), self.name)
445        self.rank_size = 1
446        if group != "":
447            self.rank = get_rank(_get_group(group))
448            self.rank_size = get_group_size(_get_group(group))
449            validator.check('rank', self.rank, 'rank_size', self.rank_size, validator.LT, self.name)
450            self.add_prim_attr('rank_size', self.rank_size)
451            self.add_prim_attr('group', _get_group(group))
452            self.add_prim_attr('fusion', 1)
453            self.add_prim_attr('do_mirror', False)
454            self.mean_flag = mean_flag
455            self.add_prim_attr('order_enforce_skip', True)
456
457    def infer_shape(self, x_shape, z_shape):
458        validator.check_positive_int(len(x_shape), "x shape", self.name)
459        if x_shape[0] > 0:
460            x_shape[0] = x_shape[0] * self.rank_size
461        return x_shape
462
463    def infer_dtype(self, x_dtype, z_dtype):
464        check_collective_target_dtype('x', x_dtype, self.name)
465        return x_dtype
466
467
468class _HostAllGather(PrimitiveWithInfer):
469    """
470    Gathers tensors from the specified communication group on host.
471
472    Note:
473        The tensors must have the same shape and format in all processes of the collection.
474        _HostAllGather is a host-side operator, it depends on OpenMPI and must use build option -M on
475        to enable it. Using mpirun command to run it:
476        mpirun -output-filename log -merge-stderr-to-stdout -np 3 python test_host_all_gather.py
477
478    Args:
479        group (Union[tuple[int],list[int]]): The rand_ids of communication group to work on. Default: ``None`` .
480
481    Raises:
482        TypeError: If group is not a list nor tuple, or elements of group are not int.
483        ValueError: If group is not set, or rank_id from group not in [0, 7].
484
485    Inputs:
486        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
487
488    Outputs:
489        Tensor. If the number of devices in the group is N,
490        then the shape of output is :math:`(N, x_1, x_2, ..., x_R)`.
491    """
492
493    @prim_attr_register
494    def __init__(self, group=None):
495        """Initialize _HostAllGather."""
496        if group is None:
497            raise ValueError(f"For '{self.name}', the 'group' cannot be None, but got {group}.")
498        validator.check_value_type('group', group, (tuple, list), self.name)
499        validator.check_int(len(group), 2, validator.GE, "group size", self.name)
500        for r in group:
501            validator.check_int_range(r, 0, 7, validator.INC_BOTH, "rank_id", self.name)
502            validator.check_value_type("rank_id", r, (int,), self.name)
503        self.group_size = len(group)
504        self.add_prim_attr('group', group)
505        self.add_prim_attr('no_eliminate', True)
506        self.add_prim_attr('order_enforce_skip', True)
507
508    def infer_shape(self, x_shape):
509        validator.check_positive_int(len(x_shape), "x shape", self.name)
510        if x_shape[0] > 0:
511            x_shape[0] = x_shape[0] * self.group_size
512        return x_shape
513
514    def infer_dtype(self, x_dtype):
515        check_collective_target_dtype('x', x_dtype, self.name)
516        return x_dtype
517
518    def __call__(self, tensor):
519        raise NotImplementedError
520
521
522class ReduceScatter(Primitive):
523    r"""
524    Reduces and scatters tensors from the specified communication group
525    and returns the tensor which is reduced and scattered.
526
527    Note:
528        The tensors must have the same shape and format in all processes of the collection.
529
530    Args:
531        op (str, optional): Specifies an operation used for element-wise reductions,
532                  like SUM and MAX. Default: ``ReduceOp.SUM`` .
533        group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP`` .
534
535    Inputs:
536        - **input_x** (Tensor) - Input Tensor, suppose it has a shape :math:`(N, *)`, where `*`
537          means any number of additional dimensions. N must be divisible by rank_size.
538          rank_size refers to the number of cards in the communication group.
539
540    Outputs:
541        Tensor, it has the same dtype as `input_x` with a shape of :math:`(N/rank\_size, *)`.
542
543    Raises:
544        TypeError: If any of operation and group is not a string.
545        ValueError: If the first dimension of the input cannot be divided by the rank_size.
546        RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
547
548    Supported Platforms:
549        ``Ascend`` ``GPU``
550
551    Examples:
552        .. note::
553            Before running the following examples, you need to configure the communication environment variables.
554
555            For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
556            without any third-party or configuration file dependencies.
557            Please see the `msrun start up
558            <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_
559            for more details.
560
561            This example should be run with 2 devices.
562
563        >>> import mindspore as ms
564        >>> from mindspore import Tensor
565        >>> from mindspore.communication import init
566        >>> from mindspore.ops import ReduceOp
567        >>> import mindspore.nn as nn
568        >>> from mindspore import ops
569        >>> import numpy as np
570        >>>
571        >>> ms.set_context(mode=ms.GRAPH_MODE)
572        >>> init()
573        >>> class Net(nn.Cell):
574        ...     def __init__(self):
575        ...         super(Net, self).__init__()
576        ...         self.reducescatter = ops.ReduceScatter(ReduceOp.SUM)
577        ...
578        ...     def construct(self, x):
579        ...         return self.reducescatter(x)
580        ...
581        >>> input_ = Tensor(np.ones([8, 8]).astype(np.float32))
582        >>> net = Net()
583        >>> output = net(input_)
584        >>> print(output)
585        [[2. 2. 2. 2. 2. 2. 2. 2.]
586         [2. 2. 2. 2. 2. 2. 2. 2.]
587         [2. 2. 2. 2. 2. 2. 2. 2.]
588         [2. 2. 2. 2. 2. 2. 2. 2.]]
589
590    Tutorial Examples:
591        - `Distributed Set Communication Primitives - ReduceScatter
592          <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#reducescatter>`_
593
594    """
595
596    @prim_attr_register
597    def __init__(self, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
598        """Initialize ReduceScatter."""
599        validator.check_value_type('op', op, (type(ReduceOp.SUM),), self.name)
600        validator.check_value_type('group', _get_group(group), (str,), self.name)
601        self.op = op
602        self.rank_size = get_group_size(_get_group(group))
603        self.add_prim_attr('rank_size', self.rank_size)
604        self.add_prim_attr('group', _get_group(group))
605        self.add_prim_attr('fusion', 0)
606        self.add_prim_attr('no_eliminate', True)
607
608
609class _HostReduceScatter(PrimitiveWithInfer):
610    """
611    Reduces and scatters tensors from the specified communication group on host.
612
613    Note:
614        The tensors must have the same shape and format in all processes of the collection.
615        _HostReduceScatter is a host-side operator, it depends on OpenMPI and must use build option
616        -M on to enable it. Using mpirun command to run it:
617        mpirun -output-filename log -merge-stderr-to-stdout -np 3 python test_host_reduce_scatter.py
618
619    Args:
620        op (str): Specifies an operation used for element-wise reductions,
621                  like sum, max, avg. Default: ``ReduceOp.SUM`` .
622        group (Union[tuple[int],list[int]]): The rand_ids of communication group to work on. Default: ``None`` .
623
624    Raises:
625        TypeError: If op is not a string and group is not a list nor tuple,
626                   or elements of group are not int.
627        ValueError: If the first dimension of input can not be divided by group size,
628                    or group is not set, or rank_id not in [0, 7].
629    """
630
631    @prim_attr_register
632    def __init__(self, op=ReduceOp.SUM, group=None):
633        """Initialize _HostReduceScatter."""
634        if group is None:
635            raise ValueError(f"For '{self.name}', the 'group' cannot be None, but got {group}.")
636        validator.check_value_type('op', op, (type(ReduceOp.SUM),), self.name)
637        validator.check_value_type('group', group, (tuple, list), self.name)
638        validator.check_int(len(group), 2, validator.GE, "group size", self.name)
639        for r in group:
640            validator.check_int_range(r, 0, 7, validator.INC_BOTH, "rank_id", self.name)
641            validator.check_value_type("rank_id", r, (int,), self.name)
642        self.op = op
643        self.group_size = len(group)
644        self.add_prim_attr('group', group)
645        self.add_prim_attr('no_eliminate', True)
646        self.add_prim_attr('order_enforce_skip', True)
647
648    def infer_shape(self, x_shape):
649        if x_shape[0] % self.group_size != 0:
650            raise ValueError(f"For '{self.name}', the first dimension of 'x_shape' must be divided by 'group_size', "
651                             f"but got 'x_shape[0]': {x_shape[0]}, 'rank_size': {self.group_size}.")
652        x_shape[0] = int(x_shape[0] / self.group_size)
653        return x_shape
654
655    def infer_dtype(self, x_dtype):
656        check_collective_target_dtype('x', x_dtype, self.name)
657        return x_dtype
658
659    def __call__(self, tensor):
660        raise NotImplementedError
661
662
663class Broadcast(PrimitiveWithInfer):
664    """
665    Broadcasts the tensor to the whole group.
666
667    Note:
668        The tensors must have the same shape and format in all processes of the collection.
669
670    Args:
671        root_rank (int): Specifies the rank(global rank) of the process that broadcast the tensor.
672            And only process `root_rank` will broadcast the tensor.
673        group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP`` .
674
675    Inputs:
676        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
677
678    Outputs:
679        tuple[Tensor], Tensor has the same shape of the input, i.e., :math:`(x_1, x_2, ..., x_R)`.
680        The contents depend on the data of the `root_rank` device.
681
682    Raises:
683        TypeError: If root_rank is not an integer or group is not a string.
684
685    Supported Platforms:
686        ``Ascend`` ``GPU``
687
688    Examples:
689        .. note::
690            Before running the following examples, you need to configure the communication environment variables.
691
692            For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
693            without any third-party or configuration file dependencies.
694            Please see the `msrun start up
695            <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_
696            for more details.
697
698            This example should be run with 2 devices.
699
700        >>> import mindspore as ms
701        >>> from mindspore import Tensor
702        >>> from mindspore.communication import init
703        >>> import mindspore.nn as nn
704        >>> from mindspore import ops
705        >>> import numpy as np
706        >>>
707        >>> ms.set_context(mode=ms.GRAPH_MODE)
708        >>> init()
709        >>> class Net(nn.Cell):
710        ...     def __init__(self):
711        ...         super(Net, self).__init__()
712        ...         self.broadcast = ops.Broadcast(1)
713        ...
714        ...     def construct(self, x):
715        ...         return self.broadcast((x,))
716        ...
717        >>> input_x = Tensor(np.ones([2, 4]).astype(np.int32))
718        >>> net = Net()
719        >>> output = net(input_x)
720        >>> print(output)
721        (Tensor(shape[2,4], dtype=Int32, value=
722        [[1, 1, 1, 1],
723         [1, 1, 1, 1]]),)
724
725    Tutorial Examples:
726        - `Distributed Set Communication Primitives - Broadcast
727          <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#broadcast>`_
728
729    """
730
731    @prim_attr_register
732    def __init__(self, root_rank, group=GlobalComm.WORLD_COMM_GROUP):
733        """Initialize Broadcast."""
734        validator.check_value_type('root_rank', root_rank, (int,), self.name)
735        validator.check_value_type('group', _get_group(group), (str,), self.name)
736        check_hcom_group_valid(group, prim_name=self.name)
737        self.add_prim_attr('group', _get_group(group))
738        self.add_prim_attr('no_eliminate', True)
739
740
741class _AllSwap(PrimitiveWithCheck):
742    """
743    _AllSwap is a collective operation.
744
745    _AllSwap sends data from the all processes to the all processes in the specified group. It has two phases:
746
747    - The scatter phase: On each process, the operand is split into the send size of blocks along the
748      0-th axis, and the blocks are scattered to all processes, e.g., the ith block is send to the ith process.
749    - The gather phase: Each process concatenates the received blocks along the 0-th axis.
750
751    Note:
752        The tensors must have the same format in all processes of the collection.
753
754    Args:
755        group (str): The communication group name.
756
757    Inputs:
758        tensor_in (tensor): A 2-D tensor. On each process, divide blocks into number of the send size.
759        send_size (tensor): A 1-D int64 tensor. The element is the send data size for each process.
760        recv_size (tensor): A 1-D int64 tensor. The element is the receive data size for each process.
761
762    Returns:
763        tensor_out (tensor): The result tensor.
764
765    Raises:
766        TypeError: If group is not a string.
767    """
768
769    @prim_attr_register
770    def __init__(self, group=GlobalComm.WORLD_COMM_GROUP):
771        """Initialize _AllSwap"""
772        validator.check_value_type('group', _get_group(group), (str,), self.name)
773        self.init_prim_io_names(inputs=['tensor_in', 'send_size', 'recv_size'], outputs=['tensor_out'])
774        self.add_prim_attr('group', _get_group(group))
775        self.add_prim_attr('no_eliminate', True)
776        self.add_prim_attr('order_enforce_skip', True)
777
778    def __check__(self, tensor_in, send_size, recv_size):
779        validator.check_subclass("tensor_in", tensor_in['dtype'], mstype.tensor_type, self.name)
780        validator.check_tensor_dtype_valid("send_size", send_size['dtype'], [mstype.int64],
781                                           self.name)
782        validator.check_tensor_dtype_valid("recv_size", recv_size['dtype'], [mstype.int64],
783                                           self.name)
784
785        validator.check_equal_int(len(tensor_in['shape']), 2, "tensor_in", self.name)
786        validator.check_equal_int(len(send_size['shape']), 1, "send_size", self.name)
787        validator.check_equal_int(len(recv_size['shape']), 1, "recv_size", self.name)
788
789        out_shape = [-1] + [tensor_in['shape'][1]]
790        out = {'shape': out_shape,
791               'dtype': tensor_in['dtype'],
792               'value': None}
793        return out
794
795
796class NeighborExchange(Primitive):
797    """
798    NeighborExchange is a collective operation.
799
800    NeighborExchange sends data from the local rank to ranks in the send_rank_ids,
801    as while receive data from recv_rank_ids.
802
803    Note:
804        The user needs to preset
805        communication environment variables before running the following example, please check the details on the
806        official website of `MindSpore \
807        <https://www.mindspore.cn/docs/en/master/api_python/mindspore.ops.primitive.html#communication-operator>`_.
808
809        This operator requires a full-mesh network topology, each device has the same vlan id, and the ip & mask are
810        in the same subnet, please check the `details \
811        <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#notes>`_.
812
813    Args:
814        send_rank_ids (list(int)): Ranks which the data is sent to.
815        recv_rank_ids (list(int)): Ranks which the data is received from.
816        recv_shapes (tuple(list(int))): Data shape which received from recv_rank_ids.
817        send_shapes (tuple(list(int))): Data shape which send to the send_rank_ids.
818        recv_type (type): Data type which received from recv_rank_ids
819        group (str): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP`` .
820
821    Inputs:
822        - **input_x** (tuple[Tensor]) - Shapes are same as args of send_shapes.
823
824    Outputs:
825        Tuple tensor, shapes are same as args of recv_shapes.
826
827    Supported Platforms:
828        ``Ascend``
829
830    Examples:
831        >>> # This example should be run with 2 devices. Refer to the tutorial > Distributed Training on mindspore.cn
832        >>> import os
833        >>> import mindspore as ms
834        >>> from mindspore import Tensor
835        >>> from mindspore.communication import init
836        >>> import mindspore.nn as nn
837        >>> from mindspore import ops
838        >>> import numpy as np
839        >>> class Net(nn.Cell):
840        ...     def __init__(self):
841        ...         super(Net, self).__init__()
842        ...         self.neighborexchange = ops.NeighborExchange(send_rank_ids=[1], recv_rank_ids=[1],
843        ...                                                      recv_shapes=([2, 2],), send_shapes=([3, 3],),
844        ...                                                      recv_type=ms.float32)
845        ...
846        ...
847        ...     def construct(self, x):
848        ...         out = self.neighborexchange((x,))
849        ...
850        >>> ms.set_context(mode=ms.GRAPH_MODE)
851        >>> init()
852        >>> net = Net()
853        >>> input_x = Tensor(np.ones([3, 3]), dtype = ms.float32)
854        >>> output = net(input_x)
855        >>> print(output)
856        [[2. 2.], [2. 2.]]
857
858    Tutorial Examples:
859        - `Distributed Set Communication Primitives - NeighborExchange
860          <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#neighborexchange>`_
861
862    """
863
864    @prim_attr_register
865    def __init__(self, send_rank_ids, recv_rank_ids, recv_shapes, send_shapes, recv_type,
866                 group=GlobalComm.WORLD_COMM_GROUP):
867        self.init_prim_io_names(inputs=['x'], outputs=['output'])
868        self.send_rank_ids = send_rank_ids
869        self.recv_rank_ids = recv_rank_ids
870        self.recv_shapes = recv_shapes
871        self.send_shapes = send_shapes
872        self.recv_type = recv_type
873        self.add_prim_attr('group', _get_group(group))
874        self.add_prim_attr('no_eliminate', True)
875
876    def __call__(self, tensor):
877        raise NotImplementedError
878
879
880class AlltoAll(PrimitiveWithInfer):
881    r"""
882    AlltoAll is a collective operation.
883
884    AlltoAll sends data from the all processes to the all processes in the specified group. It has two phases:
885
886    - The scatter phase: On each process, the operand is split into split_count number of blocks along the
887      split_dimensions, and the blocks are scattered to all processes, e.g., the ith block is send to the ith process.
888    - The gather phase: Each process concatenates the received blocks along the concat_dimension.
889
890    Note:
891        This operator requires a full-mesh network topology, each device has the same vlan id, and the ip & mask are
892        in the same subnet, please check the `details \
893        <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#notes>`_.
894
895    Args:
896        split_count (int): On each process, divide blocks into split_count number.
897        split_dim (int): On each process, split blocks along the split_dim.
898        concat_dim (int): On each process, gather the received blocks along the concat_dimension.
899        group (str): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP`` .
900
901    Inputs:
902        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
903
904    Outputs:
905        Tensor. If the shape of input tensor is :math:`(x_1, x_2, ..., x_R)`, then the shape of output tensor is
906        :math:`(y_1, y_2, ..., y_R)`, where:
907
908        - :math:`y_{split\_dim} = x_{split\_dim} / split\_count`
909        - :math:`y_{concat\_dim} = x_{concat\_dim} * split\_count`
910        - :math:`y_{other} = x_{other}`.
911
912    Raises:
913        TypeError: If group is not a string.
914
915    Supported Platforms:
916        ``Ascend``
917
918    Examples:
919        .. note::
920            Before running the following examples, you need to configure the communication environment variables.
921
922            For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
923            without any third-party or configuration file dependencies.
924            Please see the `msrun start up
925            <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_
926            for more details.
927
928            This example should be run with 8 devices.
929
930        >>> import os
931        >>> import mindspore as ms
932        >>> from mindspore import Tensor
933        >>> from mindspore.communication import init
934        >>> import mindspore.nn as nn
935        >>> from mindspore import ops
936        >>> import numpy as np
937        >>> class Net(nn.Cell):
938        ...     def __init__(self):
939        ...         super(Net, self).__init__()
940        ...         self.alltoall = ops.AlltoAll(split_count = 8, split_dim = -2, concat_dim = -1)
941        ...
942        ...     def construct(self, x):
943        ...         out = self.alltoall(x)
944        ...         return out
945        ...
946        >>> ms.set_context(mode=ms.GRAPH_MODE)
947        >>> init()
948        >>> net = Net()
949        >>> rank_id = int(os.getenv("RANK_ID"))
950        >>> input_x = Tensor(np.ones([1, 1, 8, 1]) * rank_id, dtype = ms.float32)
951        >>> output = net(input_x)
952        >>> print(output)
953        [[[[0. 1. 2. 3. 4. 5. 6. 7.]]]]
954
955    Tutorial Examples:
956        - `Distributed Set Communication Primitives - AlltoAll
957          <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#alltoall>`_
958
959    """
960
961    @prim_attr_register
962    def __init__(self, split_count, split_dim, concat_dim, group=GlobalComm.WORLD_COMM_GROUP):
963        """Initialize AlltoAll"""
964        validator.check_value_type('group', _get_group(group), (str,), self.name)
965        validator.check_is_int(split_count, int)
966        validator.check_is_int(split_dim, int)
967        validator.check_is_int(concat_dim, int)
968        self.split_count = split_count
969        self.split_dim = split_dim
970        self.concat_dim = concat_dim
971        self.add_prim_attr('group', _get_group(group))
972        self.add_prim_attr('no_eliminate', True)
973
974    def infer_shape(self, x_shape):
975        rank_size = get_group_size(_get_group(self.group))
976        if self.split_count != rank_size:
977            raise ValueError(f"For '{self.name}', the 'split_count' must be equal to 'rank_size', "
978                             f"but got 'split_count': {self.split_count}, 'rank_size': {rank_size}.")
979        if x_shape[self.split_dim] >= 0 and x_shape[self.split_dim] % self.split_count != 0:
980            raise ValueError(f"For '{self.name}', the 'x_shape[self.split_dim]' must be divisible by 'split_count', "
981                             f"but got 'x_shape[self.split_dim]' {x_shape[self.split_dim]}, "
982                             f"'split_count' {self.split_count}.")
983        if x_shape[self.concat_dim] >= 0:
984            x_shape[self.concat_dim] = x_shape[self.concat_dim] * self.split_count
985        if x_shape[self.split_dim] >= 0:
986            x_shape[self.split_dim] = int(x_shape[self.split_dim] / self.split_count)
987        return x_shape
988
989    def infer_dtype(self, x_dtype):
990        check_collective_target_dtype('x', x_dtype, self.name)
991        return x_dtype
992
993
994class NeighborExchangeV2(Primitive):
995    r"""
996    NeighborExchangeV2 is a collective communication operation.
997
998    NeighborExchangeV2 sends data from the local rank to ranks in the `send_rank_ids`,
999    as while receive data from `recv_rank_ids`. Please refer to the tutorial examples
1000    below to learn about how the data is exchanged between neighborhood devices.
1001
1002    Note:
1003        This operator requires a full-mesh network topology, each device has the same vlan id, and the ip & mask are
1004        in the same subnet, please check the `details \
1005        <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#notes>`_.
1006
1007    Args:
1008        send_rank_ids (list(int)): Ranks which the data is sent to. 8 rank_ids represents 8 directions, if one
1009                                   direction is not send to , set it -1.
1010        recv_rank_ids (list(int)): Ranks which the data is received from. 8 rank_ids represents 8 directions,
1011                                   if one direction is not recv from , set it -1.
1012        send_lens (list(int)): Data lens which send to the send_rank_ids, 4 numbers represent the lens of
1013                               [send_top, send_bottom, send_left, send_right].
1014        recv_lens (list(int)): Data lens which received from recv_rank_ids, 4 numbers represent the lens of
1015                               [recv_top, recv_bottom, recv_left, recv_right].
1016        data_format (str): Data format, only support NCHW now.
1017        group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP`` , which
1018                     means ``"hccl_world_group"`` in Ascend, and ``"nccl_world_group"`` in GPU.
1019
1020    Inputs:
1021        - **input_x** (Tensor) - The Tensor before being exchanged. It has a shape of :math:`(N, C, H, W)`.
1022
1023    Outputs:
1024        The Tensor after being exchanged. If input shape is :math:`(N, C, H, W)`, output shape is
1025        :math:`(N, C, H+recv\_top+recv\_bottom, W+recv\_left+recv\_right)`.
1026
1027    Raises:
1028        TypeError: If `group` is not a string or any one of `send_rank_ids`,
1029            `recv_rank_ids`, `send_lens`, `recv_lens` is not a list.
1030        ValueError: If `send_rank_ids` or `recv_rank_ids` has value less than -1 or has repeated values.
1031        ValueError: If `send_lens`, `recv_lens` has value less than 0.
1032        ValueError: If `data_format` is not "NCHW".
1033
1034    Supported Platforms:
1035        ``Ascend``
1036
1037    Examples:
1038        .. note::
1039            Before running the following examples, you need to configure the communication environment variables.
1040
1041            For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
1042            without any third-party or configuration file dependencies.
1043            Please see the `msrun start up
1044            <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_
1045            for more details.
1046
1047            This example should be run with 2 devices.
1048
1049        >>> import os
1050        >>> import mindspore as ms
1051        >>> from mindspore.communication import init
1052        >>> import mindspore.nn as nn
1053        >>> from mindspore import ops
1054        >>> import numpy as np
1055        >>>
1056        >>> class Net0(nn.Cell):
1057        ...     def __init__(self):
1058        ...         super(Net0, self).__init__()
1059        ...         self.neighbor_exchangev2 = ops.NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
1060        ...                                                           send_lens=[0, 1, 0, 0],
1061        ...                                                           recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
1062        ...                                                           recv_lens=[0, 1, 0, 0], data_format="NCHW")
1063        ...
1064        ...     def construct(self, x):
1065        ...         out = self.neighbor_exchangev2(x)
1066        ...         return out
1067        ... class Net1(nn.Cell):
1068        ...     def __init__(self):
1069        ...         super(Net1, self).__init__()
1070        ...         self.neighbor_exchangev2 = ops.NeighborExchangeV2(send_rank_ids=[0, -1, -1, -1, -1, -1, -1, -1],
1071        ...                                                           send_lens=[1, 0, 0, 0],
1072        ...                                                           recv_rank_ids=[0, -1, -1, -1, -1, -1, -1, -1],
1073        ...                                                           recv_lens=[1, 0, 0, 0], data_format="NCHW")
1074        ...
1075        ...     def construct(self, x):
1076        ...         out = self.neighbor_exchangev2(x)
1077        ...         return out
1078        >>>
1079        >>> ms.set_context(mode=ms.GRAPH_MODE)
1080        >>> init()
1081        >>> rank_id = int(os.getenv("RANK_ID"))
1082        >>> if (rank_id % 2 == 0):
1083        >>>     input_x = ms.Tensor(np.ones([1, 1, 2, 2]), dtype = ms.float32)
1084        >>>     net = Net0()
1085        >>>     output = net(input_x)
1086        >>>     print(output)
1087        >>> else:
1088        >>>     input_x = ms.Tensor(np.ones([1, 1, 2, 2]) * 2, dtype = ms.float32)
1089        >>>     net = Net1()
1090        >>>     output = net(input_x)
1091        >>>     print(output)
1092        [[[[1. 1.], [1. 1.], [2. 2.]]]]
1093
1094    Tutorial Examples:
1095        - `Distributed Set Communication Primitives - NeighborExchangeV2
1096          <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#neighborexchangev2>`_
1097
1098    """
1099
1100    @prim_attr_register
1101    def __init__(self, send_rank_ids, send_lens, recv_rank_ids, recv_lens, data_format,
1102                 group=GlobalComm.WORLD_COMM_GROUP):
1103        self.init_prim_io_names(inputs=['x'], outputs=['output'])
1104        self.send_rank_ids = send_rank_ids
1105        self.recv_rank_ids = recv_rank_ids
1106        self.send_lens = send_lens
1107        self.recv_lens = recv_lens
1108        self.format = data_format
1109        self.add_prim_attr('group', _get_group(group))
1110        self.add_prim_attr('no_eliminate', True)
1111        self.rank_size = get_group_size(_get_group(group))
1112        for rank_id in send_rank_ids:
1113            if rank_id != -1:
1114                validator.check_number_range(rank_id, 0, self.rank_size, validator.INC_LEFT, int,
1115                                             "rank_id in send_rank_ids")
1116        for rank_id in recv_rank_ids:
1117            if rank_id != -1:
1118                validator.check_number_range(rank_id, 0, self.rank_size, validator.INC_LEFT, int,
1119                                             "rank_id in recv_rank_ids")
1120
1121    def __call__(self, tensor):
1122        raise NotImplementedError
1123
1124
1125class CollectiveScatter(Primitive):
1126    r"""
1127    Scatter tensor evently across the processes in the specified communication group.
1128
1129    Note:
1130        The interface behavior only support Tensor input and scatter evenly.
1131        Only the tensor in process `src_rank` (global rank) will do scatter.
1132
1133    Args:
1134        src_rank (int, optional): Specifies the rank of the process that send the tensor.
1135            And only process `src_rank` will send the tensor.
1136        group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP``.
1137
1138    Inputs:
1139        - **input_x** (Tensor) - The input tensor to be scattered. The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
1140
1141    Outputs:
1142        Tensor, the shape of output is :math:`(x_1/src\_rank, x_2, ..., x_R)`. The dimension 0 of data is equal to
1143        the dimension of input tensor divided by `src`, and the other dimension keep the same.
1144
1145    Raises:
1146        TypeError: If `group` is not a str.
1147        RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
1148        ValueError: If the local rank id of the calling process in the group
1149            is larger than the group's rank size.
1150
1151    Supported Platforms:
1152        ``Ascend``
1153
1154    Examples:
1155        .. note::
1156            Before running the following examples, you need to configure the communication environment variables.
1157
1158            For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
1159            without any third-party or configuration file dependencies.
1160            Please see the `msrun start up
1161            <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_
1162            for more details.
1163
1164            This example should be run with 2 devices.
1165
1166        >>> import numpy as np
1167        >>> import mindspore.nn as nn
1168        >>> from mindspore import Tensor
1169        >>> from mindspore.communication.management import init, get_rank
1170        >>> from mindspore import ops
1171        >>> # Launch 2 processes.
1172        >>> init()
1173        >>> class CollectiveScatterNet(nn.Cell):
1174        >>>     def __init__(self):
1175        >>>         super(CollectiveScatter, self).__init__()
1176        >>>         self.collective_scatter = ops.CollectiveScatter(src_rank=0)
1177        >>>
1178        >>>     def construct(self, x):
1179        >>>         return self.collective_scatter(x)
1180        >>>
1181        >>> input = Tensor(np.arange(8).reshape([4, 2]).astype(np.float32))
1182        >>> net = CollectiveScatterNet()
1183        >>> output = net(input)
1184        >>> print(output)
1185        Process with rank 0: [[0. 1.],
1186                              [2. 3.]]
1187        Process with rank 1: [[4. 5.],
1188                              [6. 7.]]
1189
1190    Tutorial Examples:
1191        - `Distributed Set Communication Primitives - CollectiveScatter
1192          <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#reducescatter>`_
1193
1194    """
1195
1196    @prim_attr_register
1197    def __init__(self, src_rank=0, group=GlobalComm.WORLD_COMM_GROUP):
1198        validator.check_value_type('group', _get_group(group), (str,), self.name)
1199        self.rank_id = get_rank(_get_group(group))
1200        self.src_rank = src_rank
1201        self.rank_size = get_group_size(_get_group(group))
1202        validator.check('rank', self.rank_id, 'rank_size', self.rank_size, validator.LT, self.name)
1203        self.add_prim_attr('rank_id', self.rank_id)
1204        self.add_prim_attr('src_rank', self.src_rank)
1205        self.add_prim_attr('rank_size', self.rank_size)
1206        self.add_prim_attr('group', _get_group(group))
1207
1208
1209class CollectiveGather(Primitive):
1210    r"""
1211    Gathers tensors from the specified communication group. The operation will gather the tensor
1212    from processes according to dimension 0.
1213
1214    Note:
1215        Only the tensor in process `dest_rank` (global rank) will keep the gathered tensor. The other process
1216        will keep a tensor with shape [1], which has no mathematical meaning.
1217
1218    Args:
1219        dest_rank(int): Specifies the rank of the process that receive the tensor.
1220            And only process `dest_rank` will receive the gathered tensor.
1221        group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP``.
1222
1223    Inputs:
1224        - **input_x** (Tensor) - The tensor to be gathered. The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
1225
1226    Outputs:
1227        Tensor, the shape of output is :math:`(\sum x_1, x_2, ..., x_R)`. The dimension 0 of data is equal to
1228        sum of the dimension of input tensor, and the other dimension keep the same.
1229
1230    Raises:
1231        TypeError: If `group` is not a str.
1232        RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
1233        ValueError: If the local rank id of the calling process in the group
1234                    is larger than the group's rank size.
1235
1236    Supported Platforms:
1237        ``Ascend``
1238
1239    Examples:
1240        .. note::
1241            Before running the following examples, you need to configure the communication environment variables.
1242
1243            For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
1244            without any third-party or configuration file dependencies.
1245            Please see the `msrun start up
1246            <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_
1247            for more details.
1248
1249            This example should be run with 4 devices.
1250
1251        >>> import numpy as np
1252        >>> import mindspore as ms
1253        >>> import mindspore.nn as nn
1254        >>> from mindspore.communication import init
1255        >>> from mindspore import Tensor
1256        >>> from mindspore import ops
1257        >>> # Launch 2 processes.
1258        >>>
1259        >>> ms.set_context(mode=ms.GRAPH_MODE)
1260        >>> init()
1261        >>> class CollectiveGatherNet(nn.Cell):
1262        ...     def __init__(self):
1263        ...         super(CollectiveGatherNet, self).__init__()
1264        ...         self.collective_gather = ops.CollectiveGather(dest_rank=0)
1265        ...
1266        ...     def construct(self, x):
1267        ...         return self.collective_gather(x)
1268        ...
1269        >>> input = Tensor(np.arange(4).reshape([2, 2]).astype(np.float32))
1270        >>> net = CollectiveGatherNet()
1271        >>> output = net(input)
1272        >>> print(output)
1273        Process with rank 0: [[0. 1.],
1274                              [2. 3.],
1275                              [0. 1.],
1276                              [2. 3.]]
1277        Process with rank 1: [0.]
1278
1279    Tutorial Examples:
1280        - `Distributed Set Communication Primitives - CollectiveGather
1281          <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#collectivegather>`_
1282
1283    """
1284
1285    @prim_attr_register
1286    def __init__(self, dest_rank, group=GlobalComm.WORLD_COMM_GROUP):
1287        """Initialize Gather."""
1288        validator.check_value_type('group', _get_group(group), (str,), self.name)
1289        self.rank_id = get_rank(_get_group(group))
1290        self.dest_rank = dest_rank
1291        self.rank_size = get_group_size(_get_group(group))
1292        validator.check('rank', self.rank_id, 'rank_size', self.rank_size, validator.LT, self.name)
1293        self.add_prim_attr('rank_size', self.rank_size)
1294        self.add_prim_attr('group', _get_group(group))
1295        self.add_prim_attr('dest_rank', self.dest_rank)
1296        self.add_prim_attr('rank_id', self.rank_id)
1297
1298
1299class Barrier(PrimitiveWithInfer):
1300    """
1301    Synchronizes all processes in the specified group. Once the process call this operation, it will be blocked until
1302    all processes call this operation. After all processes finish calling the operations, the blocked processes
1303    will be waken and continue their task.
1304
1305    Args:
1306        group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP``.
1307
1308    Raises:
1309        TypeError: If `group` is not a str.
1310        RuntimeError: If backend is invalid, or distributed initialization fails.
1311        ValueError: If the local rank id of the calling process in the group
1312                    is larger than the group's rank size.
1313
1314    Supported Platforms:
1315        ``Ascend``
1316
1317    Examples:
1318        .. note::
1319            Before running the following examples, you need to configure the communication environment variables.
1320
1321            For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
1322            without any third-party or configuration file dependencies.
1323            Please see the `msrun start up
1324            <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_
1325            for more details.
1326
1327            This example should be run with 2 devices.
1328
1329        >>> import numpy as np
1330        >>> import mindspore.nn as nn
1331        >>> from mindspore.communication import init
1332        >>> from mindspore import Tensor
1333        >>> from mindspore import ops
1334        >>> # Launch 4 processes.
1335        >>> init()
1336        >>> class BarrierNet(nn.Cell):
1337        >>>     def __init__(self):
1338        >>>         super(BarrierNet, self).__init__()
1339        >>>         self.barrier = ops.Barrier()
1340        >>>
1341        >>>     def construct(self):
1342        >>>         self.barrier()
1343        >>> net = BarrierNet()
1344        >>> net()
1345
1346    Tutorial Examples:
1347        - `Distributed Set Communication Primitives - Barrier
1348          <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#barrier>`_
1349
1350    """
1351
1352    @prim_attr_register
1353    def __init__(self, group=GlobalComm.WORLD_COMM_GROUP):
1354        self.group = group
1355        self.add_prim_attr("side_effect_mem", True)
1356
1357    def infer_shape(self):
1358        return [1]
1359
1360    def infer_dtype(self):
1361        return mstype.float32
1362
1363
1364class Send(PrimitiveWithInfer):
1365    """
1366    Send tensors to the specified dest_rank.
1367
1368    Note:
1369        Send and Receive must be used in combination and have same sr_tag.
1370
1371    Args:
1372        sr_tag (int): The tag to identify the send/recv message. The message will
1373                      be received by the Receive op with the same "sr_tag".
1374        dest_rank (int): A required integer identifying the destination rank.
1375        group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP``.
1376        group_back (str, optional): The communication group for backpropagation.
1377                                    Default: ``GlobalComm.WORLD_COMM_GROUP``.
1378
1379    Inputs:
1380        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
1381
1382    Raises:
1383        TypeError: If `group` is not a str.
1384        RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
1385        ValueError: If the local rank id of the calling process in the group
1386                    is larger than the group's rank size.
1387
1388    Supported Platforms:
1389        ``Ascend`` ``GPU``
1390
1391    Examples:
1392        .. note::
1393            Before running the following examples, you need to configure the communication environment variables.
1394
1395            For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
1396            without any third-party or configuration file dependencies.
1397            Please see the `msrun start up
1398            <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_
1399            for more details.
1400
1401            This example should be run with 2 devices.
1402
1403        >>> import numpy as np
1404        >>> import mindspore.nn as nn
1405        >>> from mindspore.communication import init
1406        >>> from mindspore import Tensor
1407        >>> from mindspore import ops
1408        >>>
1409        >>> init()
1410        >>> class SendNet(nn.Cell):
1411        >>>     def __init__(self):
1412        >>>         super(SendNet, self).__init__()
1413        >>>         self.depend = ops.Depend()
1414        >>>         self.send = ops.Send(st_tag=0, dest_rank=8, group="hccl_world_group")
1415        >>>
1416        >>>     def construct(self, x):
1417        >>>         out = self.depend(x, self.send(x))
1418        >>>         return out
1419        >>>
1420        >>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
1421        >>> net = Net()
1422        >>> output = net(input_)
1423
1424    Tutorial Examples:
1425        - `Distributed Set Communication Primitives - Send
1426          <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#send>`_
1427
1428    """
1429
1430    @prim_attr_register
1431    def __init__(self, sr_tag, dest_rank, group=GlobalComm.WORLD_COMM_GROUP, group_back=GlobalComm.WORLD_COMM_GROUP):
1432        self.rank = dest_rank
1433        self.sr_tag = sr_tag
1434        self.group = group
1435        self.add_prim_attr("no_eliminate", True)
1436
1437    def infer_shape(self, x_shape):
1438        self.add_prim_attr("shape", x_shape)
1439        return x_shape
1440
1441    def infer_dtype(self, x_dtype):
1442        return x_dtype
1443
1444
1445class Receive(PrimitiveWithInfer):
1446    """
1447    Receive tensors from src_rank.
1448
1449    Note:
1450        Send and Receive must be used in combination and have same sr_tag.
1451
1452    Args:
1453        sr_tag (int): A required integer identifying the send/recv message tag. The message will
1454                      will be send by the Send op with the same "sr_tag".
1455        src_rank (int): A required integer identifying the source rank.
1456        shape (list[int]): A required list identifying the shape of the tensor to be received.
1457        dtype (Type): A required Type identifying the type of the tensor to be received. The supported types:
1458                       int8/int16/int32/float16/float32.
1459        group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP``.
1460        group_back (str, optional): The communication group for backpropagation.
1461                                    Default: ``GlobalComm.WORLD_COMM_GROUP``.
1462
1463    Outputs:
1464        Tensor, output has the same shape as the Tensor sent by `Send` operation.
1465
1466    Raises:
1467        TypeError: If `group` is not a str.
1468        RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
1469        ValueError: If the local rank id of the calling process in the group
1470                    is larger than the group's rank size.
1471
1472    Supported Platforms:
1473        ``Ascend`` ``GPU``
1474
1475    Examples:
1476        .. note::
1477            Before running the following examples, you need to configure the communication environment variables.
1478
1479            For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
1480            without any third-party or configuration file dependencies.
1481            Please see the `msrun start up
1482            <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_
1483            for more details.
1484
1485            This example should be run with 2 devices.
1486
1487        >>> import numpy as np
1488        >>> import mindspore.nn as nn
1489        >>> from mindspore.communication import init
1490        >>> from mindspore import Tensor
1491        >>> from mindspore import ops
1492        >>>
1493        >>> init()
1494        >>> class ReceiveNet(nn.Cell):
1495        >>>     def __init__(self):
1496        >>>         super(ReceiveNet, self).__init__()
1497        >>>         self.recv = ops.Receive(sr_tag=0, src_rank=0, shape=[2, 8], dtype=ms.float32,
1498        >>>                               group="hccl_world_group")
1499        >>>
1500        >>>     def construct(self):
1501        >>>         out = self.recv()
1502        >>>         return out
1503        >>>
1504        >>> net = Net()
1505        >>> output = net()
1506
1507    Tutorial Examples:
1508        - `Distributed Set Communication Primitives - Receive
1509          <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#receive>`_
1510
1511    """
1512
1513    @prim_attr_register
1514    def __init__(self, sr_tag, src_rank, shape, dtype, group=GlobalComm.WORLD_COMM_GROUP,
1515                 group_back=GlobalComm.WORLD_COMM_GROUP):
1516        self.rank = src_rank
1517        self.tag = sr_tag
1518        self.shape = shape
1519        self.dtype = dtype
1520        self.group = group
1521        self.add_prim_attr("no_eliminate", True)
1522        valid_type = [mstype.float16, mstype.float32, mstype.float64, mstype.bfloat16,
1523                      mstype.int8, mstype.int16, mstype.int32, mstype.int64,
1524                      mstype.uint8, mstype.uint16, mstype.uint32, mstype.uint64]
1525        args = {"dtype": dtype}
1526        validator.check_scalar_or_tensor_types_same(args, valid_type, self.name)
1527
1528    def infer_shape(self, x_shape=None):
1529        return self.get_attr_dict()['shape']
1530
1531    def infer_dtype(self, x_dtype=None):
1532        return self.get_attr_dict()['dtype']
1533
1534
1535class _MirrorOperator(PrimitiveWithInfer):
1536    """
1537    Auto parallel virtual operator. Do nothing in forward, do all reduce and mean in backward. It is only for
1538    internal use of parallel modules and cannot be called by users.
1539
1540    Args:
1541        group (str): The communication group to work on. Default: ``None`` .
1542        dev_num (int): The device number of the group. Default: ``None`` .
1543        mean_flag (bool): Whether use mean in backward. Default: ``None`` .
1544    """
1545
1546    @prim_attr_register
1547    def __init__(self, group=None, dev_num=None, mean_flag=None):
1548        """Initialize _MirrorOperator."""
1549        self.group = group
1550        self.dev_num = dev_num
1551        self.mean_flag = mean_flag
1552        self.add_prim_attr("fusion", 1)
1553        self.add_prim_attr('order_enforce_skip', True)
1554
1555    def infer_shape(self, x_shape):
1556        return x_shape
1557
1558    def infer_dtype(self, x_dtype):
1559        return x_dtype
1560
1561
1562mirror = _MirrorOperator()
1563
1564
1565class _MirrorMiniStepOperator(PrimitiveWithInfer):
1566    """
1567    Auto parallel virtual operator. Do nothing in forward, do all reduce and mean in backward. It is only for
1568    internal use of parallel modules and cannot be called by users.
1569
1570    Args:
1571        group (str): The communication group to work on. Default: ``None`` .
1572        dev_num (int): The device number of the group. Default: ``None`` .
1573        mean_flag (bool): Whether use mean in backward. Default: ``None`` .
1574        grad_accumulation_step (int): The grad accumulation step. Default: ``None`` .
1575    """
1576
1577    @prim_attr_register
1578    def __init__(self, group=None, dev_num=None, mean_flag=None, grad_accumulation_step=None):
1579        """Initialize _MirrorMiniStepOperator."""
1580        self.group = group
1581        self.dev_num = dev_num
1582        self.mean_flag = mean_flag
1583        self.grad_accumulation_step = grad_accumulation_step
1584        self.add_prim_attr('order_enforce_skip', True)
1585        self.add_prim_attr('side_effect_backprop_mem', True)
1586
1587    def infer_shape(self, x_shape, z_shape):
1588        return x_shape
1589
1590    def infer_dtype(self, x_dtype, z_shape):
1591        return x_dtype
1592
1593
1594mirror_mini_step = _MirrorMiniStepOperator()
1595
1596
1597class _VirtualDiv(PrimitiveWithInfer):
1598    """
1599    Auto parallel virtual operator. Do nothing in forward, do Div in backward.
1600
1601    Args:
1602        divisor: float32
1603    """
1604
1605    @prim_attr_register
1606    def __init__(self, divisor=None):
1607        """Initialize _VirtualDiv."""
1608        self.divisor = divisor
1609        self.add_prim_attr('order_enforce_skip', True)
1610
1611    def infer_shape(self, x_shape):
1612        return x_shape
1613
1614    def infer_dtype(self, x_dtype):
1615        return x_dtype
1616
1617
1618virtual_div = _VirtualDiv()
1619
1620
1621class _VirtualPipelineEnd(PrimitiveWithInfer):
1622    """
1623    Auto parallel virtual operator. Do nothing in forward and backward, mark end node in pipeline parallel.
1624
1625    Args:
1626        divisor: float32
1627    """
1628
1629    @prim_attr_register
1630    def __init__(self):
1631        """Initialize _VirtualPipelineEnd."""
1632
1633    def infer_shape(self, x_shape):
1634        return x_shape
1635
1636    def infer_dtype(self, x_dtype):
1637        return x_dtype
1638
1639
1640virtual_pipeline_end = _VirtualPipelineEnd()
1641
1642
1643class _VirtualAdd(PrimitiveWithInfer):
1644    """Auto parallel virtual operator. Do nothing in forward, do Add in backward."""
1645
1646    @prim_attr_register
1647    def __init__(self):
1648        """Initialize _VirtualAdd."""
1649        self.add_prim_attr('order_enforce_skip', True)
1650
1651    def infer_shape(self, x_shape, y_shape):
1652        return x_shape
1653
1654    def infer_dtype(self, x_dtype, y_dtype):
1655        return x_dtype
1656
1657
1658class _VirtualDataset(PrimitiveWithInfer):
1659    """
1660    Auto parallel virtual dataset operator.
1661
1662    It would insert VirtualDataset operator in forward computation and be deleted before backward computation.
1663    """
1664
1665    @prim_attr_register
1666    def __init__(self):
1667        """Initialize _VirtualDataset."""
1668        self.add_prim_attr('order_enforce_skip', True)
1669
1670    def infer_shape(self, *args):
1671        return args
1672
1673    def infer_dtype(self, *args):
1674        return args
1675
1676
1677virtual_dataset = _VirtualDataset()
1678
1679
1680class _VirtualAssignAdd(PrimitiveWithInfer):
1681    """
1682    Auto parallel virtual operator. Do nothing in forward, do AssignAdd in backward. It is only for
1683    internal use of parallel modules and cannot be called by users.
1684
1685    """
1686
1687    @prim_attr_register
1688    def __init__(self):
1689        """Initialize _VirtualAssignAdd."""
1690        self.add_prim_attr('order_enforce_skip', True)
1691        self.add_prim_attr('side_effect_backprop_mem', True)
1692
1693    def infer_shape(self, x_shape, y_shape):
1694        return x_shape
1695
1696    def infer_dtype(self, x_dtype, y_dtype):
1697        return x_dtype
1698virtual_assign_add = _VirtualAssignAdd()
1699
1700
1701class _VirtualAccuGrad(PrimitiveWithInfer):
1702    """
1703    Auto parallel virtual operator. Do nothing in forward, return y in backward. It is only for
1704    internal use of parallel modules and cannot be called by users.
1705    """
1706
1707    @prim_attr_register
1708    def __init__(self):
1709        """Initialize _VirtualAccuGrad."""
1710        self.add_prim_attr('order_enforce_skip', True)
1711
1712    def infer_shape(self, x_shape, y_shape):
1713        return x_shape
1714
1715    def infer_dtype(self, x_dtype, y_dtype):
1716        return x_dtype
1717
1718
1719virtual_accu_grad = _VirtualAccuGrad()
1720
1721
1722class _MirrorMicroStepOperator(PrimitiveWithInfer):
1723    """
1724    Auto parallel virtual operator. Do nothing in forward, do all reduce and mean in backward. It is only for
1725    internal use of parallel modules and cannot be called by users.
1726
1727    Args:
1728        group (str): The communication group to work on. Default: ``None`` .
1729        dev_num (int): The device number of the group. Default: ``None`` .
1730        mean_flag (bool): Whether use mean in backward. Default: ``None`` .
1731    """
1732
1733    @prim_attr_register
1734    def __init__(self, group=None, dev_num=None, mean_flag=None):
1735        """Initialize _MirrorMicroStepOperator."""
1736        self.group = group
1737        self.dev_num = dev_num
1738        self.mean_flag = mean_flag
1739        self.add_prim_attr('order_enforce_skip', True)
1740        self.add_prim_attr('side_effect_backprop_mem', True)
1741
1742    def infer_shape(self, x_shape, z_shape):
1743        return x_shape
1744
1745    def infer_dtype(self, x_dtype, z_shape):
1746        return x_dtype
1747
1748
1749class _VirtualOutput(PrimitiveWithInfer):
1750    """
1751    Auto parallel virtual out operator.
1752
1753    It would insert VirtualOutput operator in forward computation and be deleted before backward computation.
1754    """
1755
1756    @prim_attr_register
1757    def __init__(self):
1758        """Initialize _VirtualOutput."""
1759        self.add_prim_attr('order_enforce_skip', True)
1760
1761    def infer_shape(self, x_shape):
1762        return x_shape
1763
1764    def infer_dtype(self, x_dtype):
1765        return x_dtype
1766
1767
1768class _GetTensorSlice(PrimitiveWithInfer):
1769    """
1770    Gets tensor slice by device matrix and tensor map.
1771
1772    Args:
1773        dev_mat (tuple): The device matrix of the slice tensor.
1774        tensor_map (tuple): The tensor map of the slice tensor.
1775    """
1776
1777    @prim_attr_register
1778    def __init__(self):
1779        """Initialize _GetTensorSlice."""
1780        self.add_prim_attr('order_enforce_skip', True)
1781
1782    def infer_value(self, x, dev_mat, tensor_map, slice_shape, full_shape):
1783        from mindspore.parallel._tensor import _load_tensor
1784        validator.check_value_type("dev_mat", dev_mat, [tuple], self.name)
1785        validator.check_value_type("tensor_map", tensor_map, [tuple], self.name)
1786        tensor_slice = _load_tensor(x, dev_mat, tensor_map, full_shape)
1787        if tensor_slice.shape != slice_shape:
1788            tensor_slice = tensor_slice.reshape(slice_shape)
1789        return Tensor(tensor_slice, x.dtype)
1790
1791
1792class BatchISendIRecv(PrimitiveWithInfer):
1793    """
1794    Batch send and recv tensors asynchronously.
1795
1796    Note:
1797        - The ``isend`` and ``irecv`` in ``op_types`` between ranks need to match each other.
1798        - ``isend`` and ``irecv`` in a batch can only be used in the same communication group.
1799
1800    Args:
1801        op_types(Union[tuple[str], list[str]]): "isend" or "irecv" to indicate the order and number of communication.
1802        remote_ranks(Union[tuple[int], list[int]]): src or dst rank that matches the op_types.
1803        receive_shapes(Union[tuple[int], list[int]]): receive tensor shapes that matches "irecv" in op_types.
1804        receive_types(Union[tuple[mindspore.dtype], list[mindspore.dtype]]): receive tensor dtype
1805          that matches "irecv" in op_types.
1806        group (str): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP``, which
1807          means ``"hccl_world_group"`` in Ascend, and ``"nccl_world_group"`` in GPU.
1808
1809    Inputs:
1810        - **input_x** (Union[tuple[Tensor], list[Tensor], tuple(None)]) -
1811          The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
1812
1813    Outputs:
1814        tuple(Tensor). Output tensors is corresponding to ``op_types``:
1815        At ``"isend"`` position, output tensor is a fake tensor with scalar, which has no meaning.
1816        At ``"irecv"`` position, output tensor is a tensor received from remote end.
1817
1818
1819    Raises:
1820        TypeError: If ``group`` is not a str.
1821        TypeError: If ``op_types``, ``receive_shapes``, ``receive_dtypes``, ``remote_ranks`` are not tuple or list.
1822        ValueError: If the length of ``receive_shapes`` and ``receive_dtypes`` are not the same.
1823        ValueError: If the length of ``op_types`` and ``remote_ranks`` are not the same.
1824        RuntimeError: If the length of input tensors and ``"isend"`` count are not the same.
1825
1826    Supported Platforms:
1827        ``Ascend``
1828
1829    Examples:
1830        .. note::
1831            Before running the following examples, you need to configure the communication environment variables.
1832
1833            For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
1834            without any third-party or configuration file dependencies.
1835
1836            Please see the `msrun start up
1837            <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_
1838            for more details.
1839
1840            This example should be run with 2 devices.
1841
1842        >>> import numpy as np
1843        >>> import mindspore as ms
1844        >>> from mindspore import ops
1845        >>> import mindspore.nn as nn
1846        >>> from mindspore.communication import init, get_rank
1847        >>> from mindspore import Tensor
1848        >>>
1849        >>> init()
1850        >>> rank = get_rank()
1851        >>> class Net(nn.Cell):
1852        ...     def __init__(self):
1853        ...         super(Net, self).__init__()
1854        ...         if rank == 0:
1855        ...             remote_rank = [1, 1]
1856        ...         else:
1857        ...             remote_rank = [0, 0]
1858        ...         self.batchisendirecv = ops.BatchISendIRecv(("isend", "irecv"), remote_rank, [()], (ms.float32,))
1859        ...
1860        ...     def construct(self, x):
1861        ...         if isinstance(x, Tensor):
1862        ...             x = (x,)
1863        ...         return self.batchisendirecv(x)
1864        ...
1865        >>> send_x = Tensor(rank + 1).astype(ms.float32)
1866        >>> net = Net()
1867        >>> output = net(send_x)
1868        >>> print(output)
1869        rank 0:
1870        (Tensor(shape=[], dtype=Float32, value= 0), Tensor(shape=[], dtype=Float32, value= 2))
1871        rank 1:
1872        (Tensor(shape=[], dtype=Float32, value= 0), Tensor(shape=[], dtype=Float32, value= 1))
1873
1874    Tutorial Examples:
1875        - `Distributed Set Communication Primitives - BatchISendIRecv
1876          <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#allgather>`_
1877
1878    """
1879
1880    @prim_attr_register
1881    def __init__(self, op_types, remote_ranks, receive_shapes=None,
1882                 receive_dtypes=None, group=GlobalComm.WORLD_COMM_GROUP):
1883        if receive_shapes is None:
1884            receive_shapes = ()
1885        else:
1886            validator.check_value_type("receive_shapes", receive_shapes, [tuple, list], self.name)
1887
1888        if receive_dtypes is None:
1889            receive_dtypes = ()
1890        else:
1891            validator.check_value_type("receive_dtypes", receive_dtypes, [tuple, list], self.name)
1892
1893        validator.check_value_type("op_types", op_types, [tuple, list], self.name)
1894        validator.check_value_type("remote_ranks", remote_ranks, [tuple, list], self.name)
1895
1896        if len(receive_shapes) != len(receive_dtypes):
1897            raise ValueError("length of receive_shapes and receive_shapes must be the same, "
1898                             f"but got receive_shapes: {len(receive_shapes)} "
1899                             f" and receive_shapes: {receive_dtypes}")
1900
1901        if len(op_types) != len(remote_ranks):
1902            raise ValueError("length of op_types and remote_ranks must be the same.")
1903
1904        if group is None:
1905            group = GlobalComm.WORLD_COMM_GROUP
1906        self.add_prim_attr('group', group)
1907        self.add_prim_attr('op_types', op_types)
1908        self.add_prim_attr('remote_ranks', remote_ranks)
1909        self.add_prim_attr('receive_shapes', receive_shapes)
1910        self.add_prim_attr('receive_dtypes', receive_dtypes)
1911        self.add_prim_attr('no_eliminate', True)
1912
1913
1914class AlltoAllV(PrimitiveWithInfer):
1915    """
1916    AllToAll which support uneven split.
1917
1918    Note:
1919        - Only support flatten tensor as input. input tensor should be flattened and
1920          concatenated before call this primitive.
1921
1922    Args:
1923        send_numel_list(Union[tuple[int], list[int]]): split numel to scatter to different remote rank.
1924        recv_numel_list(Union[tuple[int], list[int]]): split numel to gather from different remote rank.
1925        group (str): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP``, which
1926          means ``"hccl_world_group"`` in Ascend, and ``"nccl_world_group"`` in GPU.
1927
1928    Inputs:
1929        - **input_x** (Tensor) - flatten tensor to scatter. The shape of tensor is :math:`(x_1)`.
1930
1931    Outputs:
1932        Tensor. flattened and concatenated tensor gather from remote ranks.
1933        If gather result is empty, it will return a Tensor with value 0, which has no actual meaning.
1934
1935    Raises:
1936        TypeError: If 'send_numel_list'  or 'recv_numel_list' is not type of tuple and list.
1937
1938    Supported Platforms:
1939        ``Ascend``
1940
1941    Examples:
1942        .. note::
1943            Before running the following examples, you need to configure the communication environment variables.
1944
1945            For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method
1946            without any third-party or configuration file dependencies.
1947
1948            Please see the `msrun start up
1949            <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_
1950            for more details.
1951
1952            This example should be run with 2 devices.
1953
1954        >>> import numpy as np
1955        >>> import mindspore as ms
1956        >>> from mindspore import ops
1957        >>> import mindspore.nn as nn
1958        >>> from mindspore.communication import init, get_rank
1959        >>> from mindspore import Tensor
1960        >>>
1961        >>> init()
1962        >>> rank = get_rank()
1963        >>> class Net(nn.Cell):
1964        ...     def __init__(self):
1965        ...         super(Net, self).__init__()
1966        ...         if rank == 0:
1967        ...             self.all_to_all = ops.AlltoAllV([1, 2], [1, 2])
1968        ...         else:
1969        ...             self.all_to_all = ops.AlltoAllV([2, 1], [2, 1])
1970        ...
1971        ...     def construct(self, x):
1972        ...         return self.all_to_all(x)
1973        ...
1974        >>> if rank == 0:
1975        >>>    send_tensor = Tensor([0, 1, 2.])
1976        >>> elif rank == 1:
1977        >>>    send_tensor = Tensor([3, 4, 5.])
1978        >>> net = Net()
1979        >>> output = net(send_tensor)
1980        >>> print(output)
1981        rank 0:
1982        [0. 3. 4]
1983        rank 1:
1984        [1. 2. 5]
1985
1986    """
1987
1988    @prim_attr_register
1989    def __init__(self, send_numel_list, recv_numel_list, group=None):
1990        validator.check_value_type("send_numel_list", send_numel_list, [tuple, list], self.name)
1991        validator.check_value_type("recv_numel_list", recv_numel_list, [tuple, list], self.name)
1992        if group is None:
1993            group = GlobalComm.WORLD_COMM_GROUP
1994        self.add_prim_attr('group', group)
1995        self.add_prim_attr('send_numel_list', send_numel_list)
1996        self.add_prim_attr('recv_numel_list', recv_numel_list)
1997