# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

16"""comm_ops"""
17
18from mindspore.common import Tensor
19from ..._checkparam import Validator as validator
20from ..._checkparam import Rel
21from ...communication.management import get_rank, get_group_size, GlobalComm, _get_group
22from ...common import dtype as mstype
23from ..primitive import PrimitiveWithInfer, PrimitiveWithCheck, Primitive, prim_attr_register
24from ...common.api import context
25
26
class ReduceOp:
    """
    Operation options for reducing tensors. This is an enumerated type, not an operator.
    It is mainly used in data parallel mode.

    The main calling methods are as follows:

    - SUM: ReduceOp.SUM.
    - MAX: ReduceOp.MAX.
    - MIN: ReduceOp.MIN.
    - PROD: ReduceOp.PROD.

    There are four kinds of operation options, "SUM", "MAX", "MIN", and "PROD".

    - SUM: Take the sum.
    - MAX: Take the maximum.
    - MIN: Take the minimum.
    - PROD: Take the product.

    Note:
        For more details, refer to the example below. This operator needs to run in an
        environment with multiple devices.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import numpy as np
        >>> from mindspore.communication import init
        >>> from mindspore import Tensor, ops
        >>> from mindspore.ops import ReduceOp
        >>> import mindspore.nn as nn
        >>>
        >>> init()
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.allreduce_sum = ops.AllReduce(ReduceOp.SUM, group="nccl_world_group")
        ...
        ...     def construct(self, x):
        ...         return self.allreduce_sum(x)
        ...
        >>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
        >>> net = Net()
        >>> output = net(input_)
        >>> print(output)
        [[4. 5. 6. 0. 0. 0. 0. 0.]
         [0. 0. 0. 0. 0. 0. 0. 0.]]
    """
    SUM = "sum"
    MAX = "max"
    MIN = "min"
    PROD = "prod"


target_dtypes = (mstype.int8, mstype.int32, mstype.float16, mstype.float32)


def check_hcom_group_valid(group, prim_name=None):
    """Check if the hcom group is valid."""
    msg_prefix = f"For '{prim_name}', only" if prim_name else "Only"
    if context.get_context("mode") == context.PYNATIVE_MODE and \
            context.get_context("device_target") == "Ascend" and \
            group != GlobalComm.WORLD_COMM_GROUP:
        raise RuntimeError(f"{msg_prefix} hccl_world_group is supported in PyNative mode, but got 'group': {group}.")


class AllReduce(PrimitiveWithInfer):
    """
    Reduces the tensor data across all devices in such a way that all devices will get the same final result.

    Note:
        The operation of AllReduce does not support "prod" currently.
        The tensors must have the same shape and format in all processes of the collection.

    Args:
        op (str): Specifies an operation used for element-wise reductions,
                  like sum, max, and min. Default: ReduceOp.SUM.
        group (str): The communication group to work on. Default: "hccl_world_group".

    Inputs:
        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.

    Outputs:
        Tensor, has the same shape as the input, i.e., :math:`(x_1, x_2, ..., x_R)`.
        The contents depend on the specified operation.

    Raises:
        TypeError: If either of `op` and `group` is not a str,
                   or if `fusion` is not an integer, or if the input's dtype is bool.
        ValueError: If `op` is "prod".

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> # This example should be run with two devices. Refer to the tutorial > Distributed Training on mindspore.cn
        >>> import numpy as np
        >>> from mindspore.communication import init
        >>> from mindspore import Tensor
        >>> from mindspore.ops import ReduceOp
        >>> import mindspore.nn as nn
        >>> import mindspore.ops as ops
        >>>
        >>> init()
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.allreduce_sum = ops.AllReduce(ReduceOp.SUM)
        ...
        ...     def construct(self, x):
        ...         return self.allreduce_sum(x)
        ...
        >>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
        >>> net = Net()
        >>> output = net(input_)
        >>> print(output)
        [[2. 2. 2. 2. 2. 2. 2. 2.]
         [2. 2. 2. 2. 2. 2. 2. 2.]]
    """

    @prim_attr_register
    def __init__(self, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
        """Initialize AllReduce."""
        if not isinstance(op, type(ReduceOp.SUM)):
            raise TypeError(f"For '{self.name}', the 'op' should be str, but got {type(op).__name__}.")
        if not isinstance(_get_group(group), str):
            raise TypeError(f"For '{self.name}', the 'group' should be str, "
                            f"but got {type(_get_group(group)).__name__}.")
        check_hcom_group_valid(group, prim_name=self.name)
        self.op = op
        self.add_prim_attr('group', _get_group(group))
        self.add_prim_attr('fusion', 0)
        self.add_prim_attr('index', 0)
        self.add_prim_attr('no_elimilate', True)

    def infer_shape(self, x_shape):
        return x_shape

    def infer_dtype(self, x_dtype):
        validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
        return x_dtype


class AllGather(PrimitiveWithInfer):
    """
    Gathers tensors from the specified communication group.

    Note:
        The tensors must have the same shape and format in all processes of the collection.

    Args:
        group (str): The communication group to work on. Default: "hccl_world_group".

    Inputs:
        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.

    Outputs:
        Tensor. If the number of devices in the group is N,
        then the shape of output is :math:`(N, x_1, x_2, ..., x_R)`.

    Raises:
        TypeError: If `group` is not a str.
        ValueError: If the local rank id of the calling process in the group
                    is larger than the group's rank size.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> # This example should be run with two devices. Refer to the tutorial > Distributed Training on mindspore.cn
        >>> import numpy as np
        >>> import mindspore.ops as ops
        >>> import mindspore.nn as nn
        >>> from mindspore.communication import init
        >>> from mindspore import Tensor, context
        >>>
        >>> context.set_context(mode=context.GRAPH_MODE)
        >>> init()
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.allgather = ops.AllGather()
        ...
        ...     def construct(self, x):
        ...         return self.allgather(x)
        ...
        >>> input_x = Tensor(np.ones([2, 8]).astype(np.float32))
        >>> net = Net()
        >>> output = net(input_x)
        >>> print(output)
        [[1. 1. 1. 1. 1. 1. 1. 1.]
         [1. 1. 1. 1. 1. 1. 1. 1.]
         [1. 1. 1. 1. 1. 1. 1. 1.]
         [1. 1. 1. 1. 1. 1. 1. 1.]]
    """

    @prim_attr_register
    def __init__(self, group=GlobalComm.WORLD_COMM_GROUP):
        """Initialize AllGather."""
        validator.check_value_type('group', _get_group(group), (str,), self.name)
        self.rank = get_rank(_get_group(group))
        self.rank_size = get_group_size(_get_group(group))
        validator.check('rank', self.rank, 'rank_size', self.rank_size, Rel.LT, self.name)
        self.add_prim_attr('rank_size', self.rank_size)
        self.add_prim_attr('group', _get_group(group))
        self.add_prim_attr('fusion', 0)
        self.add_prim_attr('mean_flag', False)
        self.add_prim_attr('no_elimilate', True)

    def infer_shape(self, x_shape):
        validator.check_positive_int(len(x_shape), "x shape", self.name)
        if x_shape[0] > 0:
            x_shape[0] = x_shape[0] * self.rank_size
        return x_shape

    def infer_dtype(self, x_dtype):
        validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
        return x_dtype

    def __call__(self, tensor):
        raise NotImplementedError


class _MiniStepAllGather(PrimitiveWithInfer):
    """
    Auto parallel virtual operator. Do nothing in forward, do reducescatter in backward in mini-step. It is only for
    internal use of parallel modules and cannot be called by users.

    Args:
        group (str): The communication group to work on. Default: GlobalComm.WORLD_COMM_GROUP.
        grad_accumulation_step (int): The grad accumulation step. Default: None.
        mean_flag (bool): Whether use mean in backward. Default: None.
    """

    @prim_attr_register
    def __init__(self, group=GlobalComm.WORLD_COMM_GROUP, grad_accumulation_step=None, mean_flag=None):
        """Initialize _MiniStepAllGather."""
        validator.check_value_type('group', _get_group(group), (str,), self.name)
        self.rank = get_rank(_get_group(group))
        self.rank_size = get_group_size(_get_group(group))
        validator.check('rank', self.rank, 'rank_size', self.rank_size, Rel.LT, self.name)
        self.add_prim_attr('rank_size', self.rank_size)
        self.add_prim_attr('group', _get_group(group))
        self.add_prim_attr('fusion', 1)
        self.grad_accumulation_step = grad_accumulation_step
        self.mean_flag = mean_flag

    def infer_shape(self, x_shape, z_shape):
        validator.check_positive_int(len(x_shape), "x shape", self.name)
        if x_shape[0] > 0:
            x_shape[0] = x_shape[0] * self.rank_size
        return x_shape

    def infer_dtype(self, x_dtype, z_shape):
        validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
        return x_dtype


class _MicroStepAllGather(PrimitiveWithInfer):
    """
    Auto parallel virtual operator. Do nothing in forward, do reducescatter in backward in micro-step. It is only for
    internal use of parallel modules and cannot be called by users.

    Args:
        group (str): The communication group to work on. Default: GlobalComm.WORLD_COMM_GROUP.
        mean_flag (bool): Whether use mean in backward. Default: None.
    """

    @prim_attr_register
    def __init__(self, group=GlobalComm.WORLD_COMM_GROUP, mean_flag=None):
        """Initialize _MicroStepAllGather."""
        validator.check_value_type('group', _get_group(group), (str,), self.name)
        self.rank = get_rank(_get_group(group))
        self.rank_size = get_group_size(_get_group(group))
        validator.check('rank', self.rank, 'rank_size', self.rank_size, Rel.LT, self.name)
        self.add_prim_attr('rank_size', self.rank_size)
        self.add_prim_attr('group', _get_group(group))
        self.add_prim_attr('fusion', 1)
        self.mean_flag = mean_flag

    def infer_shape(self, x_shape, z_shape):
        validator.check_positive_int(len(x_shape), "x shape", self.name)
        if x_shape[0] > 0:
            x_shape[0] = x_shape[0] * self.rank_size
        return x_shape

    def infer_dtype(self, x_dtype, z_dtype):
        validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
        return x_dtype


class _HostAllGather(PrimitiveWithInfer):
    """
    Gathers tensors from the specified communication group on host.

    Note:
        The tensors must have the same shape and format in all processes of the collection.
        _HostAllGather is a host-side operator. It depends on OpenMPI and the build option -M on
        must be used to enable it. Use the mpirun command to run it:
        mpirun -output-filename log -merge-stderr-to-stdout -np 3 python test_host_all_gather.py

    Args:
        group (Union[tuple[int],list[int]]): The rank_ids of the communication group to work on. Default: None.

    Raises:
        TypeError: If group is not a list nor tuple, or elements of group are not int.
        ValueError: If group is not set, or rank_id from group not in [0, 7].

    Inputs:
        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.

    Outputs:
        Tensor. If the number of devices in the group is N,
        then the shape of output is :math:`(N, x_1, x_2, ..., x_R)`.
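
    The example below is only an illustrative construction sketch, not a verified run: it assumes
    a 3-process OpenMPI job launched as described in the note above, and that this module is
    importable as ``mindspore.ops.operations.comm_ops``.

    Examples:
        >>> import numpy as np
        >>> import mindspore.nn as nn
        >>> from mindspore import Tensor, context
        >>> from mindspore.ops.operations.comm_ops import _HostAllGather
        >>>
        >>> context.set_context(mode=context.GRAPH_MODE)
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.hostallgather = _HostAllGather(group=(0, 1, 2))
        ...
        ...     def construct(self, x):
        ...         return self.hostallgather(x)
        ...
        >>> input_x = Tensor(np.ones([2, 8]).astype(np.float32))
        >>> net = Net()
        >>> output = net(input_x)  # expected shape with a 3-process group: (6, 8)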
    """

    @prim_attr_register
    def __init__(self, group=None):
        """Initialize _HostAllGather."""
        if group is None:
            raise ValueError(f"For '{self.name}', the 'group' cannot be None, but got {group}.")
        validator.check_value_type('group', group, (tuple, list), self.name)
        validator.check_int(len(group), 2, Rel.GE, "group size", self.name)
        for r in group:
            validator.check_int_range(r, 0, 7, Rel.INC_BOTH, "rank_id", self.name)
            validator.check_value_type("rank_id", r, (int,), self.name)
        self.group_size = len(group)
        self.add_prim_attr('group', group)
        self.add_prim_attr('no_elimilate', True)

    def infer_shape(self, x_shape):
        validator.check_positive_int(len(x_shape), "x shape", self.name)
        if x_shape[0] > 0:
            x_shape[0] = x_shape[0] * self.group_size
        return x_shape

    def infer_dtype(self, x_dtype):
        validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
        return x_dtype

    def __call__(self, tensor):
        raise NotImplementedError


class ReduceScatter(PrimitiveWithInfer):
    """
    Reduces and scatters tensors from the specified communication group.

    Note:
        The back propagation of the op is not supported yet. Stay tuned for more.
        The tensors must have the same shape and format in all processes of the collection.

    Args:
        op (str): Specifies an operation used for element-wise reductions,
                  like SUM, MAX, and MIN. Default: ReduceOp.SUM.
        group (str): The communication group to work on. Default: "hccl_world_group".

    Raises:
        TypeError: If either of `op` and `group` is not a string.
        ValueError: If the first dimension of the input cannot be divided by the rank size.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> # This example should be run with two devices. Refer to the tutorial > Distributed Training on mindspore.cn
        >>> from mindspore import Tensor, context
        >>> from mindspore.communication import init
        >>> from mindspore.ops import ReduceOp
        >>> import mindspore.nn as nn
        >>> import mindspore.ops as ops
        >>> import numpy as np
        >>>
        >>> context.set_context(mode=context.GRAPH_MODE)
        >>> init()
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.reducescatter = ops.ReduceScatter(ReduceOp.SUM)
        ...
        ...     def construct(self, x):
        ...         return self.reducescatter(x)
        ...
        >>> input_ = Tensor(np.ones([8, 8]).astype(np.float32))
        >>> net = Net()
        >>> output = net(input_)
        >>> print(output)
        [[2. 2. 2. 2. 2. 2. 2. 2.]
         [2. 2. 2. 2. 2. 2. 2. 2.]
         [2. 2. 2. 2. 2. 2. 2. 2.]
         [2. 2. 2. 2. 2. 2. 2. 2.]]
    """

    @prim_attr_register
    def __init__(self, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
        """Initialize ReduceScatter."""
        validator.check_value_type('op', op, (type(ReduceOp.SUM),), self.name)
        validator.check_value_type('group', _get_group(group), (str,), self.name)
        self.op = op
        self.rank_size = get_group_size(_get_group(group))
        self.add_prim_attr('rank_size', self.rank_size)
        self.add_prim_attr('group', _get_group(group))
        self.add_prim_attr('fusion', 0)
        self.add_prim_attr('no_elimilate', True)

    def infer_shape(self, x_shape):
        if self.rank_size == 0:
            raise ValueError(f"For '{self.name}', the 'rank_size' cannot be zero, but got {self.rank_size}.")
        if x_shape[0] % self.rank_size != 0:
            raise ValueError(f"For '{self.name}', the first dimension of 'x_shape' must be divisible by 'rank_size', "
                             f"but got 'x_shape[0]': {x_shape[0]}, 'rank_size': {self.rank_size}.")
        x_shape[0] = int(x_shape[0] / self.rank_size)
        return x_shape

    def infer_dtype(self, x_dtype):
        validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
        return x_dtype

    def __call__(self, tensor):
        raise NotImplementedError


class _HostReduceScatter(PrimitiveWithInfer):
    """
    Reduces and scatters tensors from the specified communication group on host.

    Note:
        The tensors must have the same shape and format in all processes of the collection.
        _HostReduceScatter is a host-side operator. It depends on OpenMPI and the build option
        -M on must be used to enable it. Use the mpirun command to run it:
        mpirun -output-filename log -merge-stderr-to-stdout -np 3 python test_host_reduce_scatter.py

    Args:
        op (str): Specifies an operation used for element-wise reductions,
                  like sum, max, and min. Default: ReduceOp.SUM.
        group (Union[tuple[int],list[int]]): The rank_ids of the communication group to work on. Default: None.

    Raises:
        TypeError: If op is not a string, group is not a list nor a tuple,
                   or elements of group are not int.
        ValueError: If the first dimension of the input cannot be divided by the group size,
                    or group is not set, or rank_id not in [0, 7].
    """

    @prim_attr_register
    def __init__(self, op=ReduceOp.SUM, group=None):
        """Initialize _HostReduceScatter."""
        if group is None:
            raise ValueError(f"For '{self.name}', the 'group' cannot be None, but got {group}.")
        validator.check_value_type('op', op, (type(ReduceOp.SUM),), self.name)
        validator.check_value_type('group', group, (tuple, list), self.name)
        validator.check_int(len(group), 2, Rel.GE, "group size", self.name)
        for r in group:
            validator.check_int_range(r, 0, 7, Rel.INC_BOTH, "rank_id", self.name)
            validator.check_value_type("rank_id", r, (int,), self.name)
        self.op = op
        self.group_size = len(group)
        self.add_prim_attr('group', group)
        self.add_prim_attr('no_elimilate', True)

    def infer_shape(self, x_shape):
        if x_shape[0] % self.group_size != 0:
            raise ValueError(f"For '{self.name}', the first dimension of 'x_shape' must be divisible by 'group_size', "
                             f"but got 'x_shape[0]': {x_shape[0]}, 'group_size': {self.group_size}.")
        x_shape[0] = int(x_shape[0] / self.group_size)
        return x_shape

    def infer_dtype(self, x_dtype):
        validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
        return x_dtype

    def __call__(self, tensor):
        raise NotImplementedError


class Broadcast(PrimitiveWithInfer):
    """
    Broadcasts the tensor to the whole group.

    Note:
        The tensors must have the same shape and format in all processes of the collection.

    Args:
        root_rank (int): Source rank. Required in all processes except the one
                   that is sending the data.
        group (str): The communication group to work on. Default: "hccl_world_group".

    Inputs:
        - **input_x** (tuple[Tensor]) - The shape of each tensor is :math:`(x_1, x_2, ..., x_R)`.

    Outputs:
        tuple[Tensor], with the same shape as the input, i.e., :math:`(x_1, x_2, ..., x_R)`.
        The contents depend on the data of the `root_rank` device.

    Raises:
        TypeError: If root_rank is not an integer or group is not a string.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> # This example should be run with multiple processes.
        >>> # Please refer to the tutorial > Distributed Training on mindspore.cn.
        >>> from mindspore import Tensor
        >>> from mindspore import context
        >>> from mindspore.communication import init
        >>> import mindspore.nn as nn
        >>> import mindspore.ops as ops
        >>> import numpy as np
        >>>
        >>> context.set_context(mode=context.GRAPH_MODE)
        >>> init()
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.broadcast = ops.Broadcast(1)
        ...
        ...     def construct(self, x):
        ...         return self.broadcast((x,))
        ...
        >>> input_x = Tensor(np.ones([2, 4]).astype(np.int32))
        >>> net = Net()
        >>> output = net(input_x)
        >>> print(output)
        (Tensor(shape=[2, 4], dtype=Int32, value=
        [[1, 1, 1, 1],
         [1, 1, 1, 1]]),)
    """

    @prim_attr_register
    def __init__(self, root_rank, group=GlobalComm.WORLD_COMM_GROUP):
        """Initialize Broadcast."""
        validator.check_value_type('root_rank', root_rank, (int,), self.name)
        validator.check_value_type('group', _get_group(group), (str,), self.name)
        check_hcom_group_valid(group, prim_name=self.name)
        self.add_prim_attr('group', _get_group(group))
        self.add_prim_attr('no_elimilate', True)

    def infer_shape(self, x_shape):
        return x_shape

    def infer_dtype(self, x_dtype):
        if not isinstance(x_dtype, tuple):
            raise TypeError(f"For '{self.name}', the 'input_x' should be a tuple, but got {type(x_dtype).__name__}!")
        for _ele in x_dtype:
            validator.check_tensor_dtype_valid('x', _ele, target_dtypes, self.name)
        return x_dtype


class AllSwap(PrimitiveWithCheck):
    """
    AllSwap is a collective operation.

    AllSwap sends data from all processes to all processes in the specified group. It has two phases:

    - The scatter phase: On each process, the operand is split into blocks along the 0-th axis
      according to the send size, and the blocks are scattered to all processes,
      e.g., the ith block is sent to the ith process.
    - The gather phase: Each process concatenates the received blocks along the 0-th axis.

    Note:
        The tensors must have the same format in all processes of the collection.

    Args:
        group (str): The communication group name.

    Inputs:
        tensor_in (tensor): A 2-D tensor. On each process, it is divided into blocks according to the send size.
        send_size (tensor): A 1-D int64 tensor. The element is the send data size for each process.
        recv_size (tensor): A 1-D int64 tensor. The element is the receive data size for each process.

    Returns:
        tensor_out (tensor): The result tensor.

    Raises:
        TypeError: If group is not a string.
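
    The example below is only a construction sketch and has not been verified on real devices;
    the two-process layout, the import path ``mindspore.ops.operations.comm_ops`` and the
    placeholder send/receive sizes are assumptions for illustration.

    Examples:
        >>> # This sketch should be run with two processes.
        >>> import numpy as np
        >>> import mindspore.nn as nn
        >>> from mindspore import Tensor, context
        >>> from mindspore import dtype as mstype
        >>> from mindspore.communication import init
        >>> from mindspore.ops.operations.comm_ops import AllSwap
        >>>
        >>> context.set_context(mode=context.GRAPH_MODE)
        >>> init()
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.allswap = AllSwap()
        ...
        ...     def construct(self, x, send_size, recv_size):
        ...         return self.allswap(x, send_size, recv_size)
        ...
        >>> tensor_in = Tensor(np.ones([4, 8]).astype(np.float32))
        >>> send_size = Tensor(np.array([2, 2]), mstype.int64)   # placeholder sizes
        >>> recv_size = Tensor(np.array([2, 2]), mstype.int64)   # placeholder sizes
        >>> output = Net()(tensor_in, send_size, recv_size)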
    """

    @prim_attr_register
    def __init__(self, group=GlobalComm.WORLD_COMM_GROUP):
        """Initialize AllSwap"""
        validator.check_value_type('group', _get_group(group), (str,), self.name)
        self.init_prim_io_names(inputs=['tensor_in', 'send_size', 'recv_size'], outputs=['tensor_out'])
        self.add_prim_attr('group', _get_group(group))
        self.add_prim_attr('no_elimilate', True)

    def __check__(self, tensor_in, send_size, recv_size):
        validator.check_subclass("tensor_in", tensor_in['dtype'], mstype.tensor, self.name)
        validator.check_tensor_dtype_valid("send_size", send_size['dtype'], [mstype.int64],
                                           self.name)
        validator.check_tensor_dtype_valid("recv_size", recv_size['dtype'], [mstype.int64],
                                           self.name)

        validator.check_equal_int(len(tensor_in['shape']), 2, "tensor_in", self.name)
        validator.check_equal_int(len(send_size['shape']), 1, "send_size", self.name)
        validator.check_equal_int(len(recv_size['shape']), 1, "recv_size", self.name)

        out_shape = [-1] + [tensor_in['shape'][1]]
        out = {'shape': out_shape,
               'dtype': tensor_in['dtype'],
               'value': None}
        return out


class NeighborExchange(Primitive):
    """
    NeighborExchange is a collective operation.

    NeighborExchange sends data from the local rank to the ranks in send_rank_ids,
    while receiving data from the ranks in recv_rank_ids.

    Args:
        send_rank_ids (list(int)): Ranks which the data is sent to.
        recv_rank_ids (list(int)): Ranks which the data is received from.
        recv_shapes (tuple(list(int))): Shapes of the data received from recv_rank_ids.
        send_shapes (tuple(list(int))): Shapes of the data sent to send_rank_ids.
        recv_type (type): Data type of the data received from recv_rank_ids.
        group (str): The communication group to work on. Default: "hccl_world_group".
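
    The example below is only a construction sketch, not a verified run: it assumes a two-device
    job (ranks 0 and 1), that the operator is importable from ``mindspore.ops.operations.comm_ops``,
    and that it is called with a tuple of send tensors matching `send_shapes`.

    Examples:
        >>> # This sketch is written from the point of view of rank 0: it sends a
        >>> # [16, 16] tensor to rank 1 and receives a [32, 32] tensor from rank 1.
        >>> import numpy as np
        >>> import mindspore.nn as nn
        >>> from mindspore import Tensor, context
        >>> from mindspore import dtype as mstype
        >>> from mindspore.communication import init
        >>> from mindspore.ops.operations.comm_ops import NeighborExchange
        >>>
        >>> context.set_context(mode=context.GRAPH_MODE)
        >>> init()
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.neighbor_exchange = NeighborExchange(send_rank_ids=[1], recv_rank_ids=[1],
        ...                                                   recv_shapes=([32, 32],), send_shapes=([16, 16],),
        ...                                                   recv_type=mstype.float32)
        ...
        ...     def construct(self, x):
        ...         return self.neighbor_exchange((x,))
        ...
        >>> send_tensor = Tensor(np.ones([16, 16]).astype(np.float32))
        >>> output = Net()(send_tensor)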
    """

    @prim_attr_register
    def __init__(self, send_rank_ids, recv_rank_ids, recv_shapes, send_shapes, recv_type,
                 group=GlobalComm.WORLD_COMM_GROUP):
        self.init_prim_io_names(inputs=['x'], outputs=['output'])
        self.send_rank_ids = send_rank_ids
        self.recv_rank_ids = recv_rank_ids
        self.recv_shapes = recv_shapes
        self.send_shapes = send_shapes
        self.recv_type = recv_type
        self.add_prim_attr('no_elimilate', True)

    def __call__(self, tensor):
        raise NotImplementedError


class AlltoAll(PrimitiveWithInfer):
    """
    AlltoAll is a collective operation.

    AlltoAll sends data from all processes to all processes in the specified group. It has two phases:

    - The scatter phase: On each process, the operand is split into split_count blocks along
      split_dim, and the blocks are scattered to all processes, e.g., the ith block is sent to the ith process.
    - The gather phase: Each process concatenates the received blocks along concat_dim.

    Note:
        The tensors must have the same shape and format in all processes of the collection.

    Args:
        split_count (int): On each process, divide the operand into split_count blocks.
        split_dim (int): On each process, split the operand along split_dim.
        concat_dim (int): On each process, concatenate the received blocks along concat_dim.
        group (str): The communication group to work on. Default: "hccl_world_group".

    Raises:
        TypeError: If group is not a string.
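
    The example below is a sketch rather than a verified run: it assumes an 8-device job and that
    the operator is exposed as ``ops.AlltoAll``; the shapes are chosen so that the dimension at
    `split_dim` is divisible by `split_count`.

    Examples:
        >>> # This sketch should be run with eight devices. Each device holds a [1, 1, 8, 1] tensor;
        >>> # it is split into 8 blocks along dim 2, exchanged across devices, and the received
        >>> # blocks are concatenated along dim 1, so the expected output shape is (1, 8, 1, 1).
        >>> import numpy as np
        >>> import mindspore.nn as nn
        >>> import mindspore.ops as ops
        >>> from mindspore import Tensor, context
        >>> from mindspore.communication import init
        >>>
        >>> context.set_context(mode=context.GRAPH_MODE)
        >>> init()
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.alltoall = ops.AlltoAll(split_count=8, split_dim=2, concat_dim=1)
        ...
        ...     def construct(self, x):
        ...         return self.alltoall(x)
        ...
        >>> input_x = Tensor(np.ones([1, 1, 8, 1]).astype(np.float32))
        >>> net = Net()
        >>> output = net(input_x)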
    """

    @prim_attr_register
    def __init__(self, split_count, split_dim, concat_dim, group=GlobalComm.WORLD_COMM_GROUP):
        """Initialize AlltoAll"""
        validator.check_value_type('group', _get_group(group), (str,), self.name)
        validator.check_is_int(split_count, int)
        validator.check_is_int(split_dim, int)
        validator.check_is_int(concat_dim, int)
        self.split_count = split_count
        self.split_dim = split_dim
        self.concat_dim = concat_dim
        self.add_prim_attr('group', _get_group(group))
        self.add_prim_attr('no_elimilate', True)

    def infer_shape(self, x_shape):
        rank_size = get_group_size(_get_group(self.group))
        if self.split_count != rank_size:
            raise ValueError(f"For '{self.name}', the 'split_count' must be equal to 'rank_size', "
                             f"but got 'split_count': {self.split_count}, 'rank_size': {rank_size}.")
        if x_shape[self.split_dim] % self.split_count != 0:
            raise ValueError(f"For '{self.name}', the dimension of 'x_shape' at 'split_dim' must be divisible by "
                             f"'split_count', but got 'x_shape[split_dim]': {x_shape[self.split_dim]}, "
                             f"'split_count': {self.split_count}.")
        x_shape[self.concat_dim] = x_shape[self.concat_dim] * self.split_count
        x_shape[self.split_dim] = int(x_shape[self.split_dim] / self.split_count)
        return x_shape

    def infer_dtype(self, x_dtype):
        validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
        return x_dtype

    def __call__(self, tensor):
        raise NotImplementedError


class _MirrorOperator(PrimitiveWithInfer):
    """
    Auto parallel virtual operator. Do nothing in forward, do all reduce and mean in backward. It is only for
    internal use of parallel modules and cannot be called by users.

    Args:
        group (str): The communication group to work on. Default: None.
        dev_num (int): The device number of the group. Default: None.
        mean_flag (bool): Whether use mean in backward. Default: None.
    """

    @prim_attr_register
    def __init__(self, group=None, dev_num=None, mean_flag=None):
        """Initialize _MirrorOperator."""
        self.group = group
        self.dev_num = dev_num
        self.mean_flag = mean_flag
        self.add_prim_attr("fusion", 1)

    def infer_shape(self, x_shape):
        return x_shape

    def infer_dtype(self, x_dtype):
        return x_dtype


mirror = _MirrorOperator()


class _MirrorMiniStepOperator(PrimitiveWithInfer):
    """
    Auto parallel virtual operator. Do nothing in forward, do all reduce and mean in backward. It is only for
    internal use of parallel modules and cannot be called by users.

    Args:
        group (str): The communication group to work on. Default: None.
        dev_num (int): The device number of the group. Default: None.
        mean_flag (bool): Whether use mean in backward. Default: None.
        grad_accumulation_step (int): The grad accumulation step. Default: None.
    """

    @prim_attr_register
    def __init__(self, group=None, dev_num=None, mean_flag=None, grad_accumulation_step=None):
        """Initialize _MirrorMiniStepOperator."""
        self.group = group
        self.dev_num = dev_num
        self.mean_flag = mean_flag
        self.grad_accumulation_step = grad_accumulation_step

    def infer_shape(self, x_shape, z_shape):
        return x_shape

    def infer_dtype(self, x_dtype, z_shape):
        return x_dtype


mirror_mini_step = _MirrorMiniStepOperator()


class _VirtualDiv(PrimitiveWithInfer):
    """
    Auto parallel virtual operator. Do nothing in forward, do Div in backward.

    Args:
        divisor: float32
    """

    @prim_attr_register
    def __init__(self, divisor=None):
        """Initialize _VirtualDiv."""
        self.divisor = divisor

    def infer_shape(self, x_shape):
        return x_shape

    def infer_dtype(self, x_dtype):
        return x_dtype


virtual_div = _VirtualDiv()


class _VirtualAdd(PrimitiveWithInfer):
    """Auto parallel virtual operator. Do nothing in forward, do Add in backward."""

    @prim_attr_register
    def __init__(self):
        """Initialize _VirtualAdd."""

    def infer_shape(self, x_shape, y_shape):
        return x_shape

    def infer_dtype(self, x_dtype, y_dtype):
        return x_dtype


class _VirtualDataset(PrimitiveWithInfer):
    """
    Auto parallel virtual dataset operator.

    It would insert VirtualDataset operator in forward computation and be deleted before backward computation.
    """

    @prim_attr_register
    def __init__(self):
        """Initialize _VirtualDataset."""

    def infer_shape(self, *args):
        return args

    def infer_dtype(self, *args):
        return args


virtual_dataset = _VirtualDataset()


class _VirtualAssignAdd(PrimitiveWithInfer):
    """
    Auto parallel virtual operator. Do nothing in forward, do AssignAdd in backward. It is only for
    internal use of parallel modules and cannot be called by users.
    """

    @prim_attr_register
    def __init__(self):
        """Initialize _VirtualAssignAdd."""

    def infer_shape(self, x_shape, y_shape):
        return x_shape

    def infer_dtype(self, x_dtype, y_dtype):
        return x_dtype


virtual_assign_add = _VirtualAssignAdd()


class _VirtualAccuGrad(PrimitiveWithInfer):
    """
    Auto parallel virtual operator. Do nothing in forward, return y in backward. It is only for
    internal use of parallel modules and cannot be called by users.
    """

    @prim_attr_register
    def __init__(self):
        """Initialize _VirtualAccuGrad."""

    def infer_shape(self, x_shape, y_shape):
        return x_shape

    def infer_dtype(self, x_dtype, y_dtype):
        return x_dtype


virtual_accu_grad = _VirtualAccuGrad()


class _MirrorMicroStepOperator(PrimitiveWithInfer):
    """
    Auto parallel virtual operator. Do nothing in forward, do all reduce and mean in backward. It is only for
    internal use of parallel modules and cannot be called by users.

    Args:
        group (str): The communication group to work on. Default: None.
        dev_num (int): The device number of the group. Default: None.
        mean_flag (bool): Whether use mean in backward. Default: None.
    """

    @prim_attr_register
    def __init__(self, group=None, dev_num=None, mean_flag=None):
        """Initialize _MirrorMicroStepOperator."""
        self.group = group
        self.dev_num = dev_num
        self.mean_flag = mean_flag

    def infer_shape(self, x_shape, z_shape):
        return x_shape

    def infer_dtype(self, x_dtype, z_shape):
        return x_dtype


class _VirtualOutput(PrimitiveWithInfer):
    """
    Auto parallel virtual out operator.

    It would insert VirtualOutput operator in forward computation and be deleted before backward computation.
    """

    @prim_attr_register
    def __init__(self):
        """Initialize _VirtualOutput."""

    def infer_shape(self, x_shape):
        return x_shape

    def infer_dtype(self, x_dtype):
        return x_dtype


class _GetTensorSlice(PrimitiveWithInfer):
    """
    Gets tensor slice by device matrix and tensor map.

    Args:
        dev_mat (tuple): The device matrix of the slice tensor.
        tensor_map (tuple): The tensor map of the slice tensor.
    """

    @prim_attr_register
    def __init__(self):
        """Initialize _GetTensorSlice."""

    def infer_value(self, x, dev_mat, tensor_map):
        from mindspore.parallel._tensor import _load_tensor
        validator.check_value_type("dev_mat", dev_mat, [tuple], self.name)
        validator.check_value_type("tensor_map", tensor_map, [tuple], self.name)
        return Tensor(_load_tensor(x, dev_mat, tensor_map))