1# Copyright 2020-2023 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================ 15 16"""Communication APIs. 17""" 18from __future__ import absolute_import 19from __future__ import division 20 21from mindspore.common import Tensor 22from mindspore import _checkparam as validator 23from mindspore.communication.management import get_rank, get_group_size, GlobalComm, _get_group, _host_distribute 24from mindspore.common import dtype as mstype 25from mindspore.ops.primitive import PrimitiveWithInfer, PrimitiveWithCheck, Primitive, prim_attr_register 26from mindspore.common.api import context 27 28 29class ReduceOp: 30 """ 31 Operation options for reducing tensors. This is an enumerated type, not an operator. 32 33 The main calling methods are as follows: 34 35 - SUM: ReduceOp.SUM. 36 - MAX: ReduceOp.MAX. 37 - MIN: ReduceOp.MIN. 38 - PROD: ReduceOp.PROD. 39 40 There are four kinds of operation options, "SUM", "MAX", "MIN", and "PROD". 41 42 - SUM: Take the sum. 43 - MAX: Take the maximum. 44 - MIN: Take the minimum. 45 - PROD: Take the product. 46 47 Supported Platforms: 48 ``Ascend`` ``GPU`` 49 50 Examples: 51 .. note:: 52 Before running the following examples, you need to configure the communication environment variables. 53 54 For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method 55 without any third-party or configuration file dependencies. 56 Please see the `msrun start up 57 <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_ 58 for more details. 59 60 This example should be run with multiple devices. 61 62 >>> import numpy as np 63 >>> import mindspore 64 >>> from mindspore.communication import init 65 >>> from mindspore import Tensor, ops, nn 66 >>> from mindspore.ops import ReduceOp 67 >>> 68 >>> init() 69 >>> class Net(nn.Cell): 70 ... def __init__(self): 71 ... super(Net, self).__init__() 72 ... self.allreduce_sum = ops.AllReduce(ReduceOp.SUM) 73 ... 74 ... def construct(self, x): 75 ... return self.allreduce_sum(x) 76 ... 77 >>> input_ = Tensor(np.ones([2, 8]).astype(np.float32)) 78 >>> net = Net() 79 >>> output = net(input_) 80 >>> print(output) 81 [[2. 2. 2. 2. 2. 2. 2. 2.] 82 [2. 2. 2. 2. 2. 2. 2. 2.]] 83 """ 84 SUM = "sum" 85 MAX = "max" 86 MIN = "min" 87 PROD = "prod" 88 89 90def check_collective_target_dtype(data_name, data_dtype, prim_name): 91 """Check if data type is valid.""" 92 default_target_dtypes = (mstype.int8, mstype.int32, mstype.float16, mstype.float32, mstype.bfloat16) 93 gpu_target_dtypes = (mstype.bool_, mstype.int8, mstype.int32, mstype.int64, mstype.uint32, mstype.uint64, 94 mstype.float16, mstype.float32, mstype.float64) 95 96 valid_dtype = gpu_target_dtypes if context.get_context("device_target") == "GPU" else default_target_dtypes 97 validator.check_tensor_dtype_valid(data_name, data_dtype, valid_dtype, prim_name) 98 99 100def check_hcom_group_valid(group, prim_name=None): 101 """Check if hcom group is valid.""" 102 msg_prefix = f"For '{prim_name}', the" if prim_name else "The" 103 if not _host_distribute() and context.get_context("mode") == context.PYNATIVE_MODE and \ 104 group != GlobalComm.WORLD_COMM_GROUP: 105 raise RuntimeError(f"{msg_prefix} 'group' only support 'hccl_world_group' in pynative mode, but got " 106 f"'group': {group}. Please start by using mpi-run.") 107 108 109class AllReduce(Primitive): 110 """ 111 Reduces tensors across all devices in such a way that all devices will get the same final result, 112 returns the tensor which is all reduced. 113 114 Note: 115 The tensors must have the same shape and format in all processes of the collection. 116 117 Args: 118 op (str, optional): Specifies an operation used for element-wise reductions, like sum, prod, max, and min. 119 On the CPU, only 'sum' is supported. Default: ``ReduceOp.SUM`` . 120 group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP`` , which 121 means ``"hccl_world_group"`` in Ascend, and ``"nccl_world_group"`` in GPU. 122 123 Inputs: 124 - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. 125 126 Outputs: 127 Tensor, has the same shape of the input, i.e., :math:`(x_1, x_2, ..., x_R)`. 128 The contents depend on the specified operation. 129 130 Raises: 131 TypeError: If any of `op` and `group` is not a str or the input's dtype is bool. 132 RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails. 133 134 Supported Platforms: 135 ``Ascend`` ``GPU`` ``CPU`` 136 137 Examples: 138 .. note:: 139 Before running the following examples, you need to configure the communication environment variables. 140 141 For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method 142 without any third-party or configuration file dependencies. 143 Please see the `msrun start up 144 <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_ 145 for more details. 146 147 This example should be run with 2 devices. 148 149 >>> import numpy as np 150 >>> from mindspore.communication import init 151 >>> from mindspore import Tensor 152 >>> from mindspore.ops import ReduceOp 153 >>> import mindspore.nn as nn 154 >>> from mindspore import ops 155 >>> 156 >>> init() 157 >>> class Net(nn.Cell): 158 ... def __init__(self): 159 ... super(Net, self).__init__() 160 ... self.allreduce_sum = ops.AllReduce(ReduceOp.SUM) 161 ... 162 ... def construct(self, x): 163 ... return self.allreduce_sum(x) 164 ... 165 >>> input_ = Tensor(np.ones([2, 8]).astype(np.float32)) 166 >>> net = Net() 167 >>> output = net(input_) 168 >>> print(output) 169 [[2. 2. 2. 2. 2. 2. 2. 2.] 170 [2. 2. 2. 2. 2. 2. 2. 2.]] 171 172 Tutorial Examples: 173 - `Distributed Set Communication Primitives - AllReduce 174 <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#allreduce>`_ 175 176 """ 177 178 @prim_attr_register 179 def __init__(self, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP): 180 """Initialize AllReduce.""" 181 if not isinstance(op, type(ReduceOp.SUM)): 182 raise TypeError(f"For '{self.name}', the 'op' must be str, but got {type(op).__name__}.") 183 if not isinstance(_get_group(group), str): 184 raise TypeError(f"For '{self.name}', the 'group' must be str, " 185 f"but got {type(_get_group(group)).__name__}.") 186 check_hcom_group_valid(group, prim_name=self.name) 187 self.op = op 188 self.add_prim_attr('group', _get_group(group)) 189 self.add_prim_attr('fusion', 0) 190 self.add_prim_attr('index', 0) 191 self.add_prim_attr('no_eliminate', True) 192 193 194class Reduce(PrimitiveWithInfer): 195 """ 196 Reduces tensors across the processes in the specified communication group, sends the result 197 to the target dest_rank(local rank), and returns the tensor which is sent to the target process. 198 199 Note: 200 Only process with destination rank receives the reduced output. 201 Support PyNative mode and Graph mode, but Graph mode only supports scenes with a graph compilation level of O0. 202 Other processes only get a tensor with shape [1], which has no mathematical meaning. 203 204 Args: 205 dest_rank (int): The target process(local rank) in the specific group that receives the reduced output. 206 op (str, optional): Specifies an operation used for element-wise reductions, like sum, prod, max, and min. 207 On the CPU, only 'sum' is supported. Default: ``ReduceOp.SUM`` . 208 group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP`` , which 209 means ``"hccl_world_group"`` in Ascend, and ``"nccl_world_group"`` in GPU. 210 211 Inputs: 212 - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. 213 214 Outputs: 215 Tensor. Return the tensor in the specific rank of the process after reduction. 216 The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. 217 218 Raises: 219 TypeError: If the type of the first input parameter is not Tensor, 220 or any of `op` and `group` is not a str. 221 RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails. 222 223 Supported Platforms: 224 ``Ascend`` 225 226 Examples: 227 .. note:: 228 Before running the following examples, you need to configure the communication environment variables. 229 230 For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method without any third-party 231 or configuration file dependencies. 232 Please see the `msrun start up 233 <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_ 234 for more details. 235 236 This example should be run with 4 devices. 237 238 >>> from mindspore import ops 239 >>> import mindspore.nn as nn 240 >>> from mindspore.communication import init 241 >>> from mindspore import Tensor 242 >>> import numpy as np 243 >>> # Launch 4 processes. 244 >>> init() 245 >>> class ReduceNet(nn.Cell): 246 >>> def __init__(self): 247 >>> super(Net, self).__init__() 248 >>> self.reduce = ops.Reduce(dest_rank=1) 249 >>> 250 >>> def construct(self, x): 251 >>> out = self.reduce(x) 252 >>> return out 253 >>> input = Tensor(np.ones([2, 8]).astype(np.float32)) 254 >>> net = ReduceNet() 255 >>> output = net(input) 256 >>> print(output) 257 Process with rank 1: [[4. 4. 4. 4. 4. 4. 4. 4.] 258 [4. 4. 4. 4. 4. 4. 4. 4.]], 259 Other proesses: [0.]. 260 """ 261 262 @prim_attr_register 263 def __init__(self, dest_rank, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP): 264 validator.check_value_type('group', _get_group(group), (str,), self.name) 265 validator.check_value_type('op', op, (type(ReduceOp.SUM),), self.name) 266 self.dest_rank = dest_rank 267 self.op = op 268 self.group = _get_group(group) 269 self.add_prim_attr('group', _get_group(group)) 270 self.add_prim_attr('dest_rank', dest_rank) 271 272 def infer_shape(self, x_shape): 273 # The process with dest_rank returns the reduced output. 274 # Other processes only gets a tensor with shape [1], which has no mathematical meaning. 275 if self.dest_rank == get_rank(): 276 return x_shape 277 return [1] 278 279 def infer_dtype(self, x_dtype): 280 return x_dtype 281 282 283class AllGather(PrimitiveWithInfer): 284 """ 285 Gathers tensors from the specified communication group and returns the tensor which is all gathered. 286 287 Note: 288 - The tensors must have the same shape and format in all processes of the collection. 289 290 Args: 291 group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP`` , which 292 means ``"hccl_world_group"`` in Ascend, and ``"nccl_world_group"`` in GPU. 293 294 Inputs: 295 - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. 296 297 Outputs: 298 Tensor. If the number of devices in the group is N, 299 then the shape of output is :math:`(N, x_1, x_2, ..., x_R)`. 300 301 Raises: 302 TypeError: If `group` is not a str. 303 ValueError: If the local rank id of the calling process in the group 304 is larger than the group's rank size. 305 RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails. 306 307 Supported Platforms: 308 ``Ascend`` ``GPU`` 309 310 Examples: 311 .. note:: 312 Before running the following examples, you need to configure the communication environment variables. 313 314 For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method 315 without any third-party or configuration file dependencies. 316 Please see the `msrun start up 317 <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_ 318 for more details. 319 320 This example should be run with 2 devices. 321 322 >>> import numpy as np 323 >>> import mindspore as ms 324 >>> from mindspore import ops 325 >>> import mindspore.nn as nn 326 >>> from mindspore.communication import init 327 >>> from mindspore import Tensor 328 >>> 329 >>> ms.set_context(mode=ms.GRAPH_MODE) 330 >>> init() 331 >>> class Net(nn.Cell): 332 ... def __init__(self): 333 ... super(Net, self).__init__() 334 ... self.allgather = ops.AllGather() 335 ... 336 ... def construct(self, x): 337 ... return self.allgather(x) 338 ... 339 >>> input_x = Tensor(np.ones([2, 8]).astype(np.float32)) 340 >>> net = Net() 341 >>> output = net(input_x) 342 >>> print(output) 343 [[1. 1. 1. 1. 1. 1. 1. 1.] 344 [1. 1. 1. 1. 1. 1. 1. 1.] 345 [1. 1. 1. 1. 1. 1. 1. 1.] 346 [1. 1. 1. 1. 1. 1. 1. 1.]] 347 348 Tutorial Examples: 349 - `Distributed Set Communication Primitives - AllGather 350 <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#allgather>`_ 351 352 """ 353 354 @prim_attr_register 355 def __init__(self, group=GlobalComm.WORLD_COMM_GROUP): 356 """Initialize AllGather.""" 357 validator.check_value_type('group', _get_group(group), (str,), self.name) 358 self.rank = get_rank(_get_group(group)) 359 self.rank_size = get_group_size(_get_group(group)) 360 validator.check('rank', self.rank, 'rank_size', self.rank_size, validator.LT, self.name) 361 self.add_prim_attr('rank_size', self.rank_size) 362 self.add_prim_attr('group', _get_group(group)) 363 self.add_prim_attr('fusion', 0) 364 self.add_prim_attr('mean_flag', False) 365 self.add_prim_attr('no_eliminate', True) 366 367 def infer_shape(self, x_shape): 368 validator.check_positive_int(len(x_shape), "x shape", self.name) 369 if x_shape[0] > 0: 370 x_shape[0] = x_shape[0] * self.rank_size 371 return x_shape 372 373 def infer_dtype(self, x_dtype): 374 check_collective_target_dtype('x', x_dtype, self.name) 375 return x_dtype 376 377 378class AShardIdentity(PrimitiveWithInfer): 379 """ 380 Auto parallel virtual operator. Identity operator only for shard function. 381 Do nothing in terms of infer_shape, infer_dtype, and the tensor. 382 383 It is only for internal use of parallel modules and cannot be called by users. 384 """ 385 386 @prim_attr_register 387 def __init__(self): 388 pass 389 390 def infer_shape(self, x_shape): 391 return x_shape 392 393 def infer_dtype(self, x_dtype): 394 return x_dtype 395 396 397class _MiniStepAllGather(PrimitiveWithInfer): 398 """ 399 Auto parallel virtual operator. Do nothing in forward, do reducescatter in backward in mini-step. It is only for 400 internal use of parallel modules and cannot be called by users. 401 402 Args: 403 group (str): The communication group to work on. Default: ``None`` . 404 grad_accumulation_step (int): The grad accumulation step. Default: ``None`` . 405 """ 406 407 @prim_attr_register 408 def __init__(self, group=GlobalComm.WORLD_COMM_GROUP, grad_accumulation_step=None, mean_flag=None): 409 """Initialize _MiniStepAllGather.""" 410 validator.check_value_type('group', _get_group(group), (str,), self.name) 411 self.rank = get_rank(_get_group(group)) 412 self.rank_size = get_group_size(_get_group(group)) 413 validator.check('rank', self.rank, 'rank_size', self.rank_size, validator.LT, self.name) 414 self.add_prim_attr('rank_size', self.rank_size) 415 self.add_prim_attr('group', _get_group(group)) 416 self.add_prim_attr('fusion', 1) 417 self.grad_accumulation_step = grad_accumulation_step 418 self.mean_flag = mean_flag 419 self.add_prim_attr('order_enforce_skip', True) 420 self.add_prim_attr('side_effect_backprop_mem', True) 421 422 def infer_shape(self, x_shape, z_shape): 423 validator.check_positive_int(len(x_shape), "x shape", self.name) 424 if x_shape[0] > 0: 425 x_shape[0] = x_shape[0] * self.rank_size 426 return x_shape 427 428 def infer_dtype(self, x_dtype, z_shape): 429 check_collective_target_dtype('x', x_dtype, self.name) 430 return x_dtype 431 432 433class _MicroStepAllGather(PrimitiveWithInfer): 434 """ 435 Auto parallel virtual operator. Do nothing in forward, do reducescatter in backward in mini-step. It is only for 436 internal use of parallel modules and cannot be called by users. 437 438 Args: 439 group (str): The communication group to work on. Default: ``None`` . 440 """ 441 442 @prim_attr_register 443 def __init__(self, group=GlobalComm.WORLD_COMM_GROUP, mean_flag=None): 444 validator.check_value_type('group', _get_group(group), (str,), self.name) 445 self.rank_size = 1 446 if group != "": 447 self.rank = get_rank(_get_group(group)) 448 self.rank_size = get_group_size(_get_group(group)) 449 validator.check('rank', self.rank, 'rank_size', self.rank_size, validator.LT, self.name) 450 self.add_prim_attr('rank_size', self.rank_size) 451 self.add_prim_attr('group', _get_group(group)) 452 self.add_prim_attr('fusion', 1) 453 self.add_prim_attr('do_mirror', False) 454 self.mean_flag = mean_flag 455 self.add_prim_attr('order_enforce_skip', True) 456 457 def infer_shape(self, x_shape, z_shape): 458 validator.check_positive_int(len(x_shape), "x shape", self.name) 459 if x_shape[0] > 0: 460 x_shape[0] = x_shape[0] * self.rank_size 461 return x_shape 462 463 def infer_dtype(self, x_dtype, z_dtype): 464 check_collective_target_dtype('x', x_dtype, self.name) 465 return x_dtype 466 467 468class _HostAllGather(PrimitiveWithInfer): 469 """ 470 Gathers tensors from the specified communication group on host. 471 472 Note: 473 The tensors must have the same shape and format in all processes of the collection. 474 _HostAllGather is a host-side operator, it depends on OpenMPI and must use build option -M on 475 to enable it. Using mpirun command to run it: 476 mpirun -output-filename log -merge-stderr-to-stdout -np 3 python test_host_all_gather.py 477 478 Args: 479 group (Union[tuple[int],list[int]]): The rand_ids of communication group to work on. Default: ``None`` . 480 481 Raises: 482 TypeError: If group is not a list nor tuple, or elements of group are not int. 483 ValueError: If group is not set, or rank_id from group not in [0, 7]. 484 485 Inputs: 486 - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. 487 488 Outputs: 489 Tensor. If the number of devices in the group is N, 490 then the shape of output is :math:`(N, x_1, x_2, ..., x_R)`. 491 """ 492 493 @prim_attr_register 494 def __init__(self, group=None): 495 """Initialize _HostAllGather.""" 496 if group is None: 497 raise ValueError(f"For '{self.name}', the 'group' cannot be None, but got {group}.") 498 validator.check_value_type('group', group, (tuple, list), self.name) 499 validator.check_int(len(group), 2, validator.GE, "group size", self.name) 500 for r in group: 501 validator.check_int_range(r, 0, 7, validator.INC_BOTH, "rank_id", self.name) 502 validator.check_value_type("rank_id", r, (int,), self.name) 503 self.group_size = len(group) 504 self.add_prim_attr('group', group) 505 self.add_prim_attr('no_eliminate', True) 506 self.add_prim_attr('order_enforce_skip', True) 507 508 def infer_shape(self, x_shape): 509 validator.check_positive_int(len(x_shape), "x shape", self.name) 510 if x_shape[0] > 0: 511 x_shape[0] = x_shape[0] * self.group_size 512 return x_shape 513 514 def infer_dtype(self, x_dtype): 515 check_collective_target_dtype('x', x_dtype, self.name) 516 return x_dtype 517 518 def __call__(self, tensor): 519 raise NotImplementedError 520 521 522class ReduceScatter(Primitive): 523 r""" 524 Reduces and scatters tensors from the specified communication group 525 and returns the tensor which is reduced and scattered. 526 527 Note: 528 The tensors must have the same shape and format in all processes of the collection. 529 530 Args: 531 op (str, optional): Specifies an operation used for element-wise reductions, 532 like SUM and MAX. Default: ``ReduceOp.SUM`` . 533 group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP`` . 534 535 Inputs: 536 - **input_x** (Tensor) - Input Tensor, suppose it has a shape :math:`(N, *)`, where `*` 537 means any number of additional dimensions. N must be divisible by rank_size. 538 rank_size refers to the number of cards in the communication group. 539 540 Outputs: 541 Tensor, it has the same dtype as `input_x` with a shape of :math:`(N/rank\_size, *)`. 542 543 Raises: 544 TypeError: If any of operation and group is not a string. 545 ValueError: If the first dimension of the input cannot be divided by the rank_size. 546 RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails. 547 548 Supported Platforms: 549 ``Ascend`` ``GPU`` 550 551 Examples: 552 .. note:: 553 Before running the following examples, you need to configure the communication environment variables. 554 555 For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method 556 without any third-party or configuration file dependencies. 557 Please see the `msrun start up 558 <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_ 559 for more details. 560 561 This example should be run with 2 devices. 562 563 >>> import mindspore as ms 564 >>> from mindspore import Tensor 565 >>> from mindspore.communication import init 566 >>> from mindspore.ops import ReduceOp 567 >>> import mindspore.nn as nn 568 >>> from mindspore import ops 569 >>> import numpy as np 570 >>> 571 >>> ms.set_context(mode=ms.GRAPH_MODE) 572 >>> init() 573 >>> class Net(nn.Cell): 574 ... def __init__(self): 575 ... super(Net, self).__init__() 576 ... self.reducescatter = ops.ReduceScatter(ReduceOp.SUM) 577 ... 578 ... def construct(self, x): 579 ... return self.reducescatter(x) 580 ... 581 >>> input_ = Tensor(np.ones([8, 8]).astype(np.float32)) 582 >>> net = Net() 583 >>> output = net(input_) 584 >>> print(output) 585 [[2. 2. 2. 2. 2. 2. 2. 2.] 586 [2. 2. 2. 2. 2. 2. 2. 2.] 587 [2. 2. 2. 2. 2. 2. 2. 2.] 588 [2. 2. 2. 2. 2. 2. 2. 2.]] 589 590 Tutorial Examples: 591 - `Distributed Set Communication Primitives - ReduceScatter 592 <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#reducescatter>`_ 593 594 """ 595 596 @prim_attr_register 597 def __init__(self, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP): 598 """Initialize ReduceScatter.""" 599 validator.check_value_type('op', op, (type(ReduceOp.SUM),), self.name) 600 validator.check_value_type('group', _get_group(group), (str,), self.name) 601 self.op = op 602 self.rank_size = get_group_size(_get_group(group)) 603 self.add_prim_attr('rank_size', self.rank_size) 604 self.add_prim_attr('group', _get_group(group)) 605 self.add_prim_attr('fusion', 0) 606 self.add_prim_attr('no_eliminate', True) 607 608 609class _HostReduceScatter(PrimitiveWithInfer): 610 """ 611 Reduces and scatters tensors from the specified communication group on host. 612 613 Note: 614 The tensors must have the same shape and format in all processes of the collection. 615 _HostReduceScatter is a host-side operator, it depends on OpenMPI and must use build option 616 -M on to enable it. Using mpirun command to run it: 617 mpirun -output-filename log -merge-stderr-to-stdout -np 3 python test_host_reduce_scatter.py 618 619 Args: 620 op (str): Specifies an operation used for element-wise reductions, 621 like sum, max, avg. Default: ``ReduceOp.SUM`` . 622 group (Union[tuple[int],list[int]]): The rand_ids of communication group to work on. Default: ``None`` . 623 624 Raises: 625 TypeError: If op is not a string and group is not a list nor tuple, 626 or elements of group are not int. 627 ValueError: If the first dimension of input can not be divided by group size, 628 or group is not set, or rank_id not in [0, 7]. 629 """ 630 631 @prim_attr_register 632 def __init__(self, op=ReduceOp.SUM, group=None): 633 """Initialize _HostReduceScatter.""" 634 if group is None: 635 raise ValueError(f"For '{self.name}', the 'group' cannot be None, but got {group}.") 636 validator.check_value_type('op', op, (type(ReduceOp.SUM),), self.name) 637 validator.check_value_type('group', group, (tuple, list), self.name) 638 validator.check_int(len(group), 2, validator.GE, "group size", self.name) 639 for r in group: 640 validator.check_int_range(r, 0, 7, validator.INC_BOTH, "rank_id", self.name) 641 validator.check_value_type("rank_id", r, (int,), self.name) 642 self.op = op 643 self.group_size = len(group) 644 self.add_prim_attr('group', group) 645 self.add_prim_attr('no_eliminate', True) 646 self.add_prim_attr('order_enforce_skip', True) 647 648 def infer_shape(self, x_shape): 649 if x_shape[0] % self.group_size != 0: 650 raise ValueError(f"For '{self.name}', the first dimension of 'x_shape' must be divided by 'group_size', " 651 f"but got 'x_shape[0]': {x_shape[0]}, 'rank_size': {self.group_size}.") 652 x_shape[0] = int(x_shape[0] / self.group_size) 653 return x_shape 654 655 def infer_dtype(self, x_dtype): 656 check_collective_target_dtype('x', x_dtype, self.name) 657 return x_dtype 658 659 def __call__(self, tensor): 660 raise NotImplementedError 661 662 663class Broadcast(PrimitiveWithInfer): 664 """ 665 Broadcasts the tensor to the whole group. 666 667 Note: 668 The tensors must have the same shape and format in all processes of the collection. 669 670 Args: 671 root_rank (int): Specifies the rank(global rank) of the process that broadcast the tensor. 672 And only process `root_rank` will broadcast the tensor. 673 group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP`` . 674 675 Inputs: 676 - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. 677 678 Outputs: 679 tuple[Tensor], Tensor has the same shape of the input, i.e., :math:`(x_1, x_2, ..., x_R)`. 680 The contents depend on the data of the `root_rank` device. 681 682 Raises: 683 TypeError: If root_rank is not an integer or group is not a string. 684 685 Supported Platforms: 686 ``Ascend`` ``GPU`` 687 688 Examples: 689 .. note:: 690 Before running the following examples, you need to configure the communication environment variables. 691 692 For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method 693 without any third-party or configuration file dependencies. 694 Please see the `msrun start up 695 <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_ 696 for more details. 697 698 This example should be run with 2 devices. 699 700 >>> import mindspore as ms 701 >>> from mindspore import Tensor 702 >>> from mindspore.communication import init 703 >>> import mindspore.nn as nn 704 >>> from mindspore import ops 705 >>> import numpy as np 706 >>> 707 >>> ms.set_context(mode=ms.GRAPH_MODE) 708 >>> init() 709 >>> class Net(nn.Cell): 710 ... def __init__(self): 711 ... super(Net, self).__init__() 712 ... self.broadcast = ops.Broadcast(1) 713 ... 714 ... def construct(self, x): 715 ... return self.broadcast((x,)) 716 ... 717 >>> input_x = Tensor(np.ones([2, 4]).astype(np.int32)) 718 >>> net = Net() 719 >>> output = net(input_x) 720 >>> print(output) 721 (Tensor(shape[2,4], dtype=Int32, value= 722 [[1, 1, 1, 1], 723 [1, 1, 1, 1]]),) 724 725 Tutorial Examples: 726 - `Distributed Set Communication Primitives - Broadcast 727 <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#broadcast>`_ 728 729 """ 730 731 @prim_attr_register 732 def __init__(self, root_rank, group=GlobalComm.WORLD_COMM_GROUP): 733 """Initialize Broadcast.""" 734 validator.check_value_type('root_rank', root_rank, (int,), self.name) 735 validator.check_value_type('group', _get_group(group), (str,), self.name) 736 check_hcom_group_valid(group, prim_name=self.name) 737 self.add_prim_attr('group', _get_group(group)) 738 self.add_prim_attr('no_eliminate', True) 739 740 741class _AllSwap(PrimitiveWithCheck): 742 """ 743 _AllSwap is a collective operation. 744 745 _AllSwap sends data from the all processes to the all processes in the specified group. It has two phases: 746 747 - The scatter phase: On each process, the operand is split into the send size of blocks along the 748 0-th axis, and the blocks are scattered to all processes, e.g., the ith block is send to the ith process. 749 - The gather phase: Each process concatenates the received blocks along the 0-th axis. 750 751 Note: 752 The tensors must have the same format in all processes of the collection. 753 754 Args: 755 group (str): The communication group name. 756 757 Inputs: 758 tensor_in (tensor): A 2-D tensor. On each process, divide blocks into number of the send size. 759 send_size (tensor): A 1-D int64 tensor. The element is the send data size for each process. 760 recv_size (tensor): A 1-D int64 tensor. The element is the receive data size for each process. 761 762 Returns: 763 tensor_out (tensor): The result tensor. 764 765 Raises: 766 TypeError: If group is not a string. 767 """ 768 769 @prim_attr_register 770 def __init__(self, group=GlobalComm.WORLD_COMM_GROUP): 771 """Initialize _AllSwap""" 772 validator.check_value_type('group', _get_group(group), (str,), self.name) 773 self.init_prim_io_names(inputs=['tensor_in', 'send_size', 'recv_size'], outputs=['tensor_out']) 774 self.add_prim_attr('group', _get_group(group)) 775 self.add_prim_attr('no_eliminate', True) 776 self.add_prim_attr('order_enforce_skip', True) 777 778 def __check__(self, tensor_in, send_size, recv_size): 779 validator.check_subclass("tensor_in", tensor_in['dtype'], mstype.tensor_type, self.name) 780 validator.check_tensor_dtype_valid("send_size", send_size['dtype'], [mstype.int64], 781 self.name) 782 validator.check_tensor_dtype_valid("recv_size", recv_size['dtype'], [mstype.int64], 783 self.name) 784 785 validator.check_equal_int(len(tensor_in['shape']), 2, "tensor_in", self.name) 786 validator.check_equal_int(len(send_size['shape']), 1, "send_size", self.name) 787 validator.check_equal_int(len(recv_size['shape']), 1, "recv_size", self.name) 788 789 out_shape = [-1] + [tensor_in['shape'][1]] 790 out = {'shape': out_shape, 791 'dtype': tensor_in['dtype'], 792 'value': None} 793 return out 794 795 796class NeighborExchange(Primitive): 797 """ 798 NeighborExchange is a collective operation. 799 800 NeighborExchange sends data from the local rank to ranks in the send_rank_ids, 801 as while receive data from recv_rank_ids. 802 803 Note: 804 The user needs to preset 805 communication environment variables before running the following example, please check the details on the 806 official website of `MindSpore \ 807 <https://www.mindspore.cn/docs/en/master/api_python/mindspore.ops.primitive.html#communication-operator>`_. 808 809 This operator requires a full-mesh network topology, each device has the same vlan id, and the ip & mask are 810 in the same subnet, please check the `details \ 811 <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#notes>`_. 812 813 Args: 814 send_rank_ids (list(int)): Ranks which the data is sent to. 815 recv_rank_ids (list(int)): Ranks which the data is received from. 816 recv_shapes (tuple(list(int))): Data shape which received from recv_rank_ids. 817 send_shapes (tuple(list(int))): Data shape which send to the send_rank_ids. 818 recv_type (type): Data type which received from recv_rank_ids 819 group (str): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP`` . 820 821 Inputs: 822 - **input_x** (tuple[Tensor]) - Shapes are same as args of send_shapes. 823 824 Outputs: 825 Tuple tensor, shapes are same as args of recv_shapes. 826 827 Supported Platforms: 828 ``Ascend`` 829 830 Examples: 831 >>> # This example should be run with 2 devices. Refer to the tutorial > Distributed Training on mindspore.cn 832 >>> import os 833 >>> import mindspore as ms 834 >>> from mindspore import Tensor 835 >>> from mindspore.communication import init 836 >>> import mindspore.nn as nn 837 >>> from mindspore import ops 838 >>> import numpy as np 839 >>> class Net(nn.Cell): 840 ... def __init__(self): 841 ... super(Net, self).__init__() 842 ... self.neighborexchange = ops.NeighborExchange(send_rank_ids=[1], recv_rank_ids=[1], 843 ... recv_shapes=([2, 2],), send_shapes=([3, 3],), 844 ... recv_type=ms.float32) 845 ... 846 ... 847 ... def construct(self, x): 848 ... out = self.neighborexchange((x,)) 849 ... 850 >>> ms.set_context(mode=ms.GRAPH_MODE) 851 >>> init() 852 >>> net = Net() 853 >>> input_x = Tensor(np.ones([3, 3]), dtype = ms.float32) 854 >>> output = net(input_x) 855 >>> print(output) 856 [[2. 2.], [2. 2.]] 857 858 Tutorial Examples: 859 - `Distributed Set Communication Primitives - NeighborExchange 860 <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#neighborexchange>`_ 861 862 """ 863 864 @prim_attr_register 865 def __init__(self, send_rank_ids, recv_rank_ids, recv_shapes, send_shapes, recv_type, 866 group=GlobalComm.WORLD_COMM_GROUP): 867 self.init_prim_io_names(inputs=['x'], outputs=['output']) 868 self.send_rank_ids = send_rank_ids 869 self.recv_rank_ids = recv_rank_ids 870 self.recv_shapes = recv_shapes 871 self.send_shapes = send_shapes 872 self.recv_type = recv_type 873 self.add_prim_attr('group', _get_group(group)) 874 self.add_prim_attr('no_eliminate', True) 875 876 def __call__(self, tensor): 877 raise NotImplementedError 878 879 880class AlltoAll(PrimitiveWithInfer): 881 r""" 882 AlltoAll is a collective operation. 883 884 AlltoAll sends data from the all processes to the all processes in the specified group. It has two phases: 885 886 - The scatter phase: On each process, the operand is split into split_count number of blocks along the 887 split_dimensions, and the blocks are scattered to all processes, e.g., the ith block is send to the ith process. 888 - The gather phase: Each process concatenates the received blocks along the concat_dimension. 889 890 Note: 891 This operator requires a full-mesh network topology, each device has the same vlan id, and the ip & mask are 892 in the same subnet, please check the `details \ 893 <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#notes>`_. 894 895 Args: 896 split_count (int): On each process, divide blocks into split_count number. 897 split_dim (int): On each process, split blocks along the split_dim. 898 concat_dim (int): On each process, gather the received blocks along the concat_dimension. 899 group (str): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP`` . 900 901 Inputs: 902 - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. 903 904 Outputs: 905 Tensor. If the shape of input tensor is :math:`(x_1, x_2, ..., x_R)`, then the shape of output tensor is 906 :math:`(y_1, y_2, ..., y_R)`, where: 907 908 - :math:`y_{split\_dim} = x_{split\_dim} / split\_count` 909 - :math:`y_{concat\_dim} = x_{concat\_dim} * split\_count` 910 - :math:`y_{other} = x_{other}`. 911 912 Raises: 913 TypeError: If group is not a string. 914 915 Supported Platforms: 916 ``Ascend`` 917 918 Examples: 919 .. note:: 920 Before running the following examples, you need to configure the communication environment variables. 921 922 For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method 923 without any third-party or configuration file dependencies. 924 Please see the `msrun start up 925 <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_ 926 for more details. 927 928 This example should be run with 8 devices. 929 930 >>> import os 931 >>> import mindspore as ms 932 >>> from mindspore import Tensor 933 >>> from mindspore.communication import init 934 >>> import mindspore.nn as nn 935 >>> from mindspore import ops 936 >>> import numpy as np 937 >>> class Net(nn.Cell): 938 ... def __init__(self): 939 ... super(Net, self).__init__() 940 ... self.alltoall = ops.AlltoAll(split_count = 8, split_dim = -2, concat_dim = -1) 941 ... 942 ... def construct(self, x): 943 ... out = self.alltoall(x) 944 ... return out 945 ... 946 >>> ms.set_context(mode=ms.GRAPH_MODE) 947 >>> init() 948 >>> net = Net() 949 >>> rank_id = int(os.getenv("RANK_ID")) 950 >>> input_x = Tensor(np.ones([1, 1, 8, 1]) * rank_id, dtype = ms.float32) 951 >>> output = net(input_x) 952 >>> print(output) 953 [[[[0. 1. 2. 3. 4. 5. 6. 7.]]]] 954 955 Tutorial Examples: 956 - `Distributed Set Communication Primitives - AlltoAll 957 <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#alltoall>`_ 958 959 """ 960 961 @prim_attr_register 962 def __init__(self, split_count, split_dim, concat_dim, group=GlobalComm.WORLD_COMM_GROUP): 963 """Initialize AlltoAll""" 964 validator.check_value_type('group', _get_group(group), (str,), self.name) 965 validator.check_is_int(split_count, int) 966 validator.check_is_int(split_dim, int) 967 validator.check_is_int(concat_dim, int) 968 self.split_count = split_count 969 self.split_dim = split_dim 970 self.concat_dim = concat_dim 971 self.add_prim_attr('group', _get_group(group)) 972 self.add_prim_attr('no_eliminate', True) 973 974 def infer_shape(self, x_shape): 975 rank_size = get_group_size(_get_group(self.group)) 976 if self.split_count != rank_size: 977 raise ValueError(f"For '{self.name}', the 'split_count' must be equal to 'rank_size', " 978 f"but got 'split_count': {self.split_count}, 'rank_size': {rank_size}.") 979 if x_shape[self.split_dim] >= 0 and x_shape[self.split_dim] % self.split_count != 0: 980 raise ValueError(f"For '{self.name}', the 'x_shape[self.split_dim]' must be divisible by 'split_count', " 981 f"but got 'x_shape[self.split_dim]' {x_shape[self.split_dim]}, " 982 f"'split_count' {self.split_count}.") 983 if x_shape[self.concat_dim] >= 0: 984 x_shape[self.concat_dim] = x_shape[self.concat_dim] * self.split_count 985 if x_shape[self.split_dim] >= 0: 986 x_shape[self.split_dim] = int(x_shape[self.split_dim] / self.split_count) 987 return x_shape 988 989 def infer_dtype(self, x_dtype): 990 check_collective_target_dtype('x', x_dtype, self.name) 991 return x_dtype 992 993 994class NeighborExchangeV2(Primitive): 995 r""" 996 NeighborExchangeV2 is a collective communication operation. 997 998 NeighborExchangeV2 sends data from the local rank to ranks in the `send_rank_ids`, 999 as while receive data from `recv_rank_ids`. Please refer to the tutorial examples 1000 below to learn about how the data is exchanged between neighborhood devices. 1001 1002 Note: 1003 This operator requires a full-mesh network topology, each device has the same vlan id, and the ip & mask are 1004 in the same subnet, please check the `details \ 1005 <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#notes>`_. 1006 1007 Args: 1008 send_rank_ids (list(int)): Ranks which the data is sent to. 8 rank_ids represents 8 directions, if one 1009 direction is not send to , set it -1. 1010 recv_rank_ids (list(int)): Ranks which the data is received from. 8 rank_ids represents 8 directions, 1011 if one direction is not recv from , set it -1. 1012 send_lens (list(int)): Data lens which send to the send_rank_ids, 4 numbers represent the lens of 1013 [send_top, send_bottom, send_left, send_right]. 1014 recv_lens (list(int)): Data lens which received from recv_rank_ids, 4 numbers represent the lens of 1015 [recv_top, recv_bottom, recv_left, recv_right]. 1016 data_format (str): Data format, only support NCHW now. 1017 group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP`` , which 1018 means ``"hccl_world_group"`` in Ascend, and ``"nccl_world_group"`` in GPU. 1019 1020 Inputs: 1021 - **input_x** (Tensor) - The Tensor before being exchanged. It has a shape of :math:`(N, C, H, W)`. 1022 1023 Outputs: 1024 The Tensor after being exchanged. If input shape is :math:`(N, C, H, W)`, output shape is 1025 :math:`(N, C, H+recv\_top+recv\_bottom, W+recv\_left+recv\_right)`. 1026 1027 Raises: 1028 TypeError: If `group` is not a string or any one of `send_rank_ids`, 1029 `recv_rank_ids`, `send_lens`, `recv_lens` is not a list. 1030 ValueError: If `send_rank_ids` or `recv_rank_ids` has value less than -1 or has repeated values. 1031 ValueError: If `send_lens`, `recv_lens` has value less than 0. 1032 ValueError: If `data_format` is not "NCHW". 1033 1034 Supported Platforms: 1035 ``Ascend`` 1036 1037 Examples: 1038 .. note:: 1039 Before running the following examples, you need to configure the communication environment variables. 1040 1041 For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method 1042 without any third-party or configuration file dependencies. 1043 Please see the `msrun start up 1044 <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_ 1045 for more details. 1046 1047 This example should be run with 2 devices. 1048 1049 >>> import os 1050 >>> import mindspore as ms 1051 >>> from mindspore.communication import init 1052 >>> import mindspore.nn as nn 1053 >>> from mindspore import ops 1054 >>> import numpy as np 1055 >>> 1056 >>> class Net0(nn.Cell): 1057 ... def __init__(self): 1058 ... super(Net0, self).__init__() 1059 ... self.neighbor_exchangev2 = ops.NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1], 1060 ... send_lens=[0, 1, 0, 0], 1061 ... recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1], 1062 ... recv_lens=[0, 1, 0, 0], data_format="NCHW") 1063 ... 1064 ... def construct(self, x): 1065 ... out = self.neighbor_exchangev2(x) 1066 ... return out 1067 ... class Net1(nn.Cell): 1068 ... def __init__(self): 1069 ... super(Net1, self).__init__() 1070 ... self.neighbor_exchangev2 = ops.NeighborExchangeV2(send_rank_ids=[0, -1, -1, -1, -1, -1, -1, -1], 1071 ... send_lens=[1, 0, 0, 0], 1072 ... recv_rank_ids=[0, -1, -1, -1, -1, -1, -1, -1], 1073 ... recv_lens=[1, 0, 0, 0], data_format="NCHW") 1074 ... 1075 ... def construct(self, x): 1076 ... out = self.neighbor_exchangev2(x) 1077 ... return out 1078 >>> 1079 >>> ms.set_context(mode=ms.GRAPH_MODE) 1080 >>> init() 1081 >>> rank_id = int(os.getenv("RANK_ID")) 1082 >>> if (rank_id % 2 == 0): 1083 >>> input_x = ms.Tensor(np.ones([1, 1, 2, 2]), dtype = ms.float32) 1084 >>> net = Net0() 1085 >>> output = net(input_x) 1086 >>> print(output) 1087 >>> else: 1088 >>> input_x = ms.Tensor(np.ones([1, 1, 2, 2]) * 2, dtype = ms.float32) 1089 >>> net = Net1() 1090 >>> output = net(input_x) 1091 >>> print(output) 1092 [[[[1. 1.], [1. 1.], [2. 2.]]]] 1093 1094 Tutorial Examples: 1095 - `Distributed Set Communication Primitives - NeighborExchangeV2 1096 <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#neighborexchangev2>`_ 1097 1098 """ 1099 1100 @prim_attr_register 1101 def __init__(self, send_rank_ids, send_lens, recv_rank_ids, recv_lens, data_format, 1102 group=GlobalComm.WORLD_COMM_GROUP): 1103 self.init_prim_io_names(inputs=['x'], outputs=['output']) 1104 self.send_rank_ids = send_rank_ids 1105 self.recv_rank_ids = recv_rank_ids 1106 self.send_lens = send_lens 1107 self.recv_lens = recv_lens 1108 self.format = data_format 1109 self.add_prim_attr('group', _get_group(group)) 1110 self.add_prim_attr('no_eliminate', True) 1111 self.rank_size = get_group_size(_get_group(group)) 1112 for rank_id in send_rank_ids: 1113 if rank_id != -1: 1114 validator.check_number_range(rank_id, 0, self.rank_size, validator.INC_LEFT, int, 1115 "rank_id in send_rank_ids") 1116 for rank_id in recv_rank_ids: 1117 if rank_id != -1: 1118 validator.check_number_range(rank_id, 0, self.rank_size, validator.INC_LEFT, int, 1119 "rank_id in recv_rank_ids") 1120 1121 def __call__(self, tensor): 1122 raise NotImplementedError 1123 1124 1125class CollectiveScatter(Primitive): 1126 r""" 1127 Scatter tensor evently across the processes in the specified communication group. 1128 1129 Note: 1130 The interface behavior only support Tensor input and scatter evenly. 1131 Only the tensor in process `src_rank` (global rank) will do scatter. 1132 1133 Args: 1134 src_rank (int, optional): Specifies the rank of the process that send the tensor. 1135 And only process `src_rank` will send the tensor. 1136 group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP``. 1137 1138 Inputs: 1139 - **input_x** (Tensor) - The input tensor to be scattered. The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. 1140 1141 Outputs: 1142 Tensor, the shape of output is :math:`(x_1/src\_rank, x_2, ..., x_R)`. The dimension 0 of data is equal to 1143 the dimension of input tensor divided by `src`, and the other dimension keep the same. 1144 1145 Raises: 1146 TypeError: If `group` is not a str. 1147 RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails. 1148 ValueError: If the local rank id of the calling process in the group 1149 is larger than the group's rank size. 1150 1151 Supported Platforms: 1152 ``Ascend`` 1153 1154 Examples: 1155 .. note:: 1156 Before running the following examples, you need to configure the communication environment variables. 1157 1158 For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method 1159 without any third-party or configuration file dependencies. 1160 Please see the `msrun start up 1161 <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_ 1162 for more details. 1163 1164 This example should be run with 2 devices. 1165 1166 >>> import numpy as np 1167 >>> import mindspore.nn as nn 1168 >>> from mindspore import Tensor 1169 >>> from mindspore.communication.management import init, get_rank 1170 >>> from mindspore import ops 1171 >>> # Launch 2 processes. 1172 >>> init() 1173 >>> class CollectiveScatterNet(nn.Cell): 1174 >>> def __init__(self): 1175 >>> super(CollectiveScatter, self).__init__() 1176 >>> self.collective_scatter = ops.CollectiveScatter(src_rank=0) 1177 >>> 1178 >>> def construct(self, x): 1179 >>> return self.collective_scatter(x) 1180 >>> 1181 >>> input = Tensor(np.arange(8).reshape([4, 2]).astype(np.float32)) 1182 >>> net = CollectiveScatterNet() 1183 >>> output = net(input) 1184 >>> print(output) 1185 Process with rank 0: [[0. 1.], 1186 [2. 3.]] 1187 Process with rank 1: [[4. 5.], 1188 [6. 7.]] 1189 1190 Tutorial Examples: 1191 - `Distributed Set Communication Primitives - CollectiveScatter 1192 <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#reducescatter>`_ 1193 1194 """ 1195 1196 @prim_attr_register 1197 def __init__(self, src_rank=0, group=GlobalComm.WORLD_COMM_GROUP): 1198 validator.check_value_type('group', _get_group(group), (str,), self.name) 1199 self.rank_id = get_rank(_get_group(group)) 1200 self.src_rank = src_rank 1201 self.rank_size = get_group_size(_get_group(group)) 1202 validator.check('rank', self.rank_id, 'rank_size', self.rank_size, validator.LT, self.name) 1203 self.add_prim_attr('rank_id', self.rank_id) 1204 self.add_prim_attr('src_rank', self.src_rank) 1205 self.add_prim_attr('rank_size', self.rank_size) 1206 self.add_prim_attr('group', _get_group(group)) 1207 1208 1209class CollectiveGather(Primitive): 1210 r""" 1211 Gathers tensors from the specified communication group. The operation will gather the tensor 1212 from processes according to dimension 0. 1213 1214 Note: 1215 Only the tensor in process `dest_rank` (global rank) will keep the gathered tensor. The other process 1216 will keep a tensor with shape [1], which has no mathematical meaning. 1217 1218 Args: 1219 dest_rank(int): Specifies the rank of the process that receive the tensor. 1220 And only process `dest_rank` will receive the gathered tensor. 1221 group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP``. 1222 1223 Inputs: 1224 - **input_x** (Tensor) - The tensor to be gathered. The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. 1225 1226 Outputs: 1227 Tensor, the shape of output is :math:`(\sum x_1, x_2, ..., x_R)`. The dimension 0 of data is equal to 1228 sum of the dimension of input tensor, and the other dimension keep the same. 1229 1230 Raises: 1231 TypeError: If `group` is not a str. 1232 RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails. 1233 ValueError: If the local rank id of the calling process in the group 1234 is larger than the group's rank size. 1235 1236 Supported Platforms: 1237 ``Ascend`` 1238 1239 Examples: 1240 .. note:: 1241 Before running the following examples, you need to configure the communication environment variables. 1242 1243 For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method 1244 without any third-party or configuration file dependencies. 1245 Please see the `msrun start up 1246 <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_ 1247 for more details. 1248 1249 This example should be run with 4 devices. 1250 1251 >>> import numpy as np 1252 >>> import mindspore as ms 1253 >>> import mindspore.nn as nn 1254 >>> from mindspore.communication import init 1255 >>> from mindspore import Tensor 1256 >>> from mindspore import ops 1257 >>> # Launch 2 processes. 1258 >>> 1259 >>> ms.set_context(mode=ms.GRAPH_MODE) 1260 >>> init() 1261 >>> class CollectiveGatherNet(nn.Cell): 1262 ... def __init__(self): 1263 ... super(CollectiveGatherNet, self).__init__() 1264 ... self.collective_gather = ops.CollectiveGather(dest_rank=0) 1265 ... 1266 ... def construct(self, x): 1267 ... return self.collective_gather(x) 1268 ... 1269 >>> input = Tensor(np.arange(4).reshape([2, 2]).astype(np.float32)) 1270 >>> net = CollectiveGatherNet() 1271 >>> output = net(input) 1272 >>> print(output) 1273 Process with rank 0: [[0. 1.], 1274 [2. 3.], 1275 [0. 1.], 1276 [2. 3.]] 1277 Process with rank 1: [0.] 1278 1279 Tutorial Examples: 1280 - `Distributed Set Communication Primitives - CollectiveGather 1281 <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#collectivegather>`_ 1282 1283 """ 1284 1285 @prim_attr_register 1286 def __init__(self, dest_rank, group=GlobalComm.WORLD_COMM_GROUP): 1287 """Initialize Gather.""" 1288 validator.check_value_type('group', _get_group(group), (str,), self.name) 1289 self.rank_id = get_rank(_get_group(group)) 1290 self.dest_rank = dest_rank 1291 self.rank_size = get_group_size(_get_group(group)) 1292 validator.check('rank', self.rank_id, 'rank_size', self.rank_size, validator.LT, self.name) 1293 self.add_prim_attr('rank_size', self.rank_size) 1294 self.add_prim_attr('group', _get_group(group)) 1295 self.add_prim_attr('dest_rank', self.dest_rank) 1296 self.add_prim_attr('rank_id', self.rank_id) 1297 1298 1299class Barrier(PrimitiveWithInfer): 1300 """ 1301 Synchronizes all processes in the specified group. Once the process call this operation, it will be blocked until 1302 all processes call this operation. After all processes finish calling the operations, the blocked processes 1303 will be waken and continue their task. 1304 1305 Args: 1306 group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP``. 1307 1308 Raises: 1309 TypeError: If `group` is not a str. 1310 RuntimeError: If backend is invalid, or distributed initialization fails. 1311 ValueError: If the local rank id of the calling process in the group 1312 is larger than the group's rank size. 1313 1314 Supported Platforms: 1315 ``Ascend`` 1316 1317 Examples: 1318 .. note:: 1319 Before running the following examples, you need to configure the communication environment variables. 1320 1321 For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method 1322 without any third-party or configuration file dependencies. 1323 Please see the `msrun start up 1324 <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_ 1325 for more details. 1326 1327 This example should be run with 2 devices. 1328 1329 >>> import numpy as np 1330 >>> import mindspore.nn as nn 1331 >>> from mindspore.communication import init 1332 >>> from mindspore import Tensor 1333 >>> from mindspore import ops 1334 >>> # Launch 4 processes. 1335 >>> init() 1336 >>> class BarrierNet(nn.Cell): 1337 >>> def __init__(self): 1338 >>> super(BarrierNet, self).__init__() 1339 >>> self.barrier = ops.Barrier() 1340 >>> 1341 >>> def construct(self): 1342 >>> self.barrier() 1343 >>> net = BarrierNet() 1344 >>> net() 1345 1346 Tutorial Examples: 1347 - `Distributed Set Communication Primitives - Barrier 1348 <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#barrier>`_ 1349 1350 """ 1351 1352 @prim_attr_register 1353 def __init__(self, group=GlobalComm.WORLD_COMM_GROUP): 1354 self.group = group 1355 self.add_prim_attr("side_effect_mem", True) 1356 1357 def infer_shape(self): 1358 return [1] 1359 1360 def infer_dtype(self): 1361 return mstype.float32 1362 1363 1364class Send(PrimitiveWithInfer): 1365 """ 1366 Send tensors to the specified dest_rank. 1367 1368 Note: 1369 Send and Receive must be used in combination and have same sr_tag. 1370 1371 Args: 1372 sr_tag (int): The tag to identify the send/recv message. The message will 1373 be received by the Receive op with the same "sr_tag". 1374 dest_rank (int): A required integer identifying the destination rank. 1375 group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP``. 1376 group_back (str, optional): The communication group for backpropagation. 1377 Default: ``GlobalComm.WORLD_COMM_GROUP``. 1378 1379 Inputs: 1380 - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. 1381 1382 Raises: 1383 TypeError: If `group` is not a str. 1384 RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails. 1385 ValueError: If the local rank id of the calling process in the group 1386 is larger than the group's rank size. 1387 1388 Supported Platforms: 1389 ``Ascend`` ``GPU`` 1390 1391 Examples: 1392 .. note:: 1393 Before running the following examples, you need to configure the communication environment variables. 1394 1395 For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method 1396 without any third-party or configuration file dependencies. 1397 Please see the `msrun start up 1398 <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_ 1399 for more details. 1400 1401 This example should be run with 2 devices. 1402 1403 >>> import numpy as np 1404 >>> import mindspore.nn as nn 1405 >>> from mindspore.communication import init 1406 >>> from mindspore import Tensor 1407 >>> from mindspore import ops 1408 >>> 1409 >>> init() 1410 >>> class SendNet(nn.Cell): 1411 >>> def __init__(self): 1412 >>> super(SendNet, self).__init__() 1413 >>> self.depend = ops.Depend() 1414 >>> self.send = ops.Send(st_tag=0, dest_rank=8, group="hccl_world_group") 1415 >>> 1416 >>> def construct(self, x): 1417 >>> out = self.depend(x, self.send(x)) 1418 >>> return out 1419 >>> 1420 >>> input_ = Tensor(np.ones([2, 8]).astype(np.float32)) 1421 >>> net = Net() 1422 >>> output = net(input_) 1423 1424 Tutorial Examples: 1425 - `Distributed Set Communication Primitives - Send 1426 <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#send>`_ 1427 1428 """ 1429 1430 @prim_attr_register 1431 def __init__(self, sr_tag, dest_rank, group=GlobalComm.WORLD_COMM_GROUP, group_back=GlobalComm.WORLD_COMM_GROUP): 1432 self.rank = dest_rank 1433 self.sr_tag = sr_tag 1434 self.group = group 1435 self.add_prim_attr("no_eliminate", True) 1436 1437 def infer_shape(self, x_shape): 1438 self.add_prim_attr("shape", x_shape) 1439 return x_shape 1440 1441 def infer_dtype(self, x_dtype): 1442 return x_dtype 1443 1444 1445class Receive(PrimitiveWithInfer): 1446 """ 1447 Receive tensors from src_rank. 1448 1449 Note: 1450 Send and Receive must be used in combination and have same sr_tag. 1451 1452 Args: 1453 sr_tag (int): A required integer identifying the send/recv message tag. The message will 1454 will be send by the Send op with the same "sr_tag". 1455 src_rank (int): A required integer identifying the source rank. 1456 shape (list[int]): A required list identifying the shape of the tensor to be received. 1457 dtype (Type): A required Type identifying the type of the tensor to be received. The supported types: 1458 int8/int16/int32/float16/float32. 1459 group (str, optional): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP``. 1460 group_back (str, optional): The communication group for backpropagation. 1461 Default: ``GlobalComm.WORLD_COMM_GROUP``. 1462 1463 Outputs: 1464 Tensor, output has the same shape as the Tensor sent by `Send` operation. 1465 1466 Raises: 1467 TypeError: If `group` is not a str. 1468 RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails. 1469 ValueError: If the local rank id of the calling process in the group 1470 is larger than the group's rank size. 1471 1472 Supported Platforms: 1473 ``Ascend`` ``GPU`` 1474 1475 Examples: 1476 .. note:: 1477 Before running the following examples, you need to configure the communication environment variables. 1478 1479 For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method 1480 without any third-party or configuration file dependencies. 1481 Please see the `msrun start up 1482 <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_ 1483 for more details. 1484 1485 This example should be run with 2 devices. 1486 1487 >>> import numpy as np 1488 >>> import mindspore.nn as nn 1489 >>> from mindspore.communication import init 1490 >>> from mindspore import Tensor 1491 >>> from mindspore import ops 1492 >>> 1493 >>> init() 1494 >>> class ReceiveNet(nn.Cell): 1495 >>> def __init__(self): 1496 >>> super(ReceiveNet, self).__init__() 1497 >>> self.recv = ops.Receive(sr_tag=0, src_rank=0, shape=[2, 8], dtype=ms.float32, 1498 >>> group="hccl_world_group") 1499 >>> 1500 >>> def construct(self): 1501 >>> out = self.recv() 1502 >>> return out 1503 >>> 1504 >>> net = Net() 1505 >>> output = net() 1506 1507 Tutorial Examples: 1508 - `Distributed Set Communication Primitives - Receive 1509 <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#receive>`_ 1510 1511 """ 1512 1513 @prim_attr_register 1514 def __init__(self, sr_tag, src_rank, shape, dtype, group=GlobalComm.WORLD_COMM_GROUP, 1515 group_back=GlobalComm.WORLD_COMM_GROUP): 1516 self.rank = src_rank 1517 self.tag = sr_tag 1518 self.shape = shape 1519 self.dtype = dtype 1520 self.group = group 1521 self.add_prim_attr("no_eliminate", True) 1522 valid_type = [mstype.float16, mstype.float32, mstype.float64, mstype.bfloat16, 1523 mstype.int8, mstype.int16, mstype.int32, mstype.int64, 1524 mstype.uint8, mstype.uint16, mstype.uint32, mstype.uint64] 1525 args = {"dtype": dtype} 1526 validator.check_scalar_or_tensor_types_same(args, valid_type, self.name) 1527 1528 def infer_shape(self, x_shape=None): 1529 return self.get_attr_dict()['shape'] 1530 1531 def infer_dtype(self, x_dtype=None): 1532 return self.get_attr_dict()['dtype'] 1533 1534 1535class _MirrorOperator(PrimitiveWithInfer): 1536 """ 1537 Auto parallel virtual operator. Do nothing in forward, do all reduce and mean in backward. It is only for 1538 internal use of parallel modules and cannot be called by users. 1539 1540 Args: 1541 group (str): The communication group to work on. Default: ``None`` . 1542 dev_num (int): The device number of the group. Default: ``None`` . 1543 mean_flag (bool): Whether use mean in backward. Default: ``None`` . 1544 """ 1545 1546 @prim_attr_register 1547 def __init__(self, group=None, dev_num=None, mean_flag=None): 1548 """Initialize _MirrorOperator.""" 1549 self.group = group 1550 self.dev_num = dev_num 1551 self.mean_flag = mean_flag 1552 self.add_prim_attr("fusion", 1) 1553 self.add_prim_attr('order_enforce_skip', True) 1554 1555 def infer_shape(self, x_shape): 1556 return x_shape 1557 1558 def infer_dtype(self, x_dtype): 1559 return x_dtype 1560 1561 1562mirror = _MirrorOperator() 1563 1564 1565class _MirrorMiniStepOperator(PrimitiveWithInfer): 1566 """ 1567 Auto parallel virtual operator. Do nothing in forward, do all reduce and mean in backward. It is only for 1568 internal use of parallel modules and cannot be called by users. 1569 1570 Args: 1571 group (str): The communication group to work on. Default: ``None`` . 1572 dev_num (int): The device number of the group. Default: ``None`` . 1573 mean_flag (bool): Whether use mean in backward. Default: ``None`` . 1574 grad_accumulation_step (int): The grad accumulation step. Default: ``None`` . 1575 """ 1576 1577 @prim_attr_register 1578 def __init__(self, group=None, dev_num=None, mean_flag=None, grad_accumulation_step=None): 1579 """Initialize _MirrorMiniStepOperator.""" 1580 self.group = group 1581 self.dev_num = dev_num 1582 self.mean_flag = mean_flag 1583 self.grad_accumulation_step = grad_accumulation_step 1584 self.add_prim_attr('order_enforce_skip', True) 1585 self.add_prim_attr('side_effect_backprop_mem', True) 1586 1587 def infer_shape(self, x_shape, z_shape): 1588 return x_shape 1589 1590 def infer_dtype(self, x_dtype, z_shape): 1591 return x_dtype 1592 1593 1594mirror_mini_step = _MirrorMiniStepOperator() 1595 1596 1597class _VirtualDiv(PrimitiveWithInfer): 1598 """ 1599 Auto parallel virtual operator. Do nothing in forward, do Div in backward. 1600 1601 Args: 1602 divisor: float32 1603 """ 1604 1605 @prim_attr_register 1606 def __init__(self, divisor=None): 1607 """Initialize _VirtualDiv.""" 1608 self.divisor = divisor 1609 self.add_prim_attr('order_enforce_skip', True) 1610 1611 def infer_shape(self, x_shape): 1612 return x_shape 1613 1614 def infer_dtype(self, x_dtype): 1615 return x_dtype 1616 1617 1618virtual_div = _VirtualDiv() 1619 1620 1621class _VirtualPipelineEnd(PrimitiveWithInfer): 1622 """ 1623 Auto parallel virtual operator. Do nothing in forward and backward, mark end node in pipeline parallel. 1624 1625 Args: 1626 divisor: float32 1627 """ 1628 1629 @prim_attr_register 1630 def __init__(self): 1631 """Initialize _VirtualPipelineEnd.""" 1632 1633 def infer_shape(self, x_shape): 1634 return x_shape 1635 1636 def infer_dtype(self, x_dtype): 1637 return x_dtype 1638 1639 1640virtual_pipeline_end = _VirtualPipelineEnd() 1641 1642 1643class _VirtualAdd(PrimitiveWithInfer): 1644 """Auto parallel virtual operator. Do nothing in forward, do Add in backward.""" 1645 1646 @prim_attr_register 1647 def __init__(self): 1648 """Initialize _VirtualAdd.""" 1649 self.add_prim_attr('order_enforce_skip', True) 1650 1651 def infer_shape(self, x_shape, y_shape): 1652 return x_shape 1653 1654 def infer_dtype(self, x_dtype, y_dtype): 1655 return x_dtype 1656 1657 1658class _VirtualDataset(PrimitiveWithInfer): 1659 """ 1660 Auto parallel virtual dataset operator. 1661 1662 It would insert VirtualDataset operator in forward computation and be deleted before backward computation. 1663 """ 1664 1665 @prim_attr_register 1666 def __init__(self): 1667 """Initialize _VirtualDataset.""" 1668 self.add_prim_attr('order_enforce_skip', True) 1669 1670 def infer_shape(self, *args): 1671 return args 1672 1673 def infer_dtype(self, *args): 1674 return args 1675 1676 1677virtual_dataset = _VirtualDataset() 1678 1679 1680class _VirtualAssignAdd(PrimitiveWithInfer): 1681 """ 1682 Auto parallel virtual operator. Do nothing in forward, do AssignAdd in backward. It is only for 1683 internal use of parallel modules and cannot be called by users. 1684 1685 """ 1686 1687 @prim_attr_register 1688 def __init__(self): 1689 """Initialize _VirtualAssignAdd.""" 1690 self.add_prim_attr('order_enforce_skip', True) 1691 self.add_prim_attr('side_effect_backprop_mem', True) 1692 1693 def infer_shape(self, x_shape, y_shape): 1694 return x_shape 1695 1696 def infer_dtype(self, x_dtype, y_dtype): 1697 return x_dtype 1698virtual_assign_add = _VirtualAssignAdd() 1699 1700 1701class _VirtualAccuGrad(PrimitiveWithInfer): 1702 """ 1703 Auto parallel virtual operator. Do nothing in forward, return y in backward. It is only for 1704 internal use of parallel modules and cannot be called by users. 1705 """ 1706 1707 @prim_attr_register 1708 def __init__(self): 1709 """Initialize _VirtualAccuGrad.""" 1710 self.add_prim_attr('order_enforce_skip', True) 1711 1712 def infer_shape(self, x_shape, y_shape): 1713 return x_shape 1714 1715 def infer_dtype(self, x_dtype, y_dtype): 1716 return x_dtype 1717 1718 1719virtual_accu_grad = _VirtualAccuGrad() 1720 1721 1722class _MirrorMicroStepOperator(PrimitiveWithInfer): 1723 """ 1724 Auto parallel virtual operator. Do nothing in forward, do all reduce and mean in backward. It is only for 1725 internal use of parallel modules and cannot be called by users. 1726 1727 Args: 1728 group (str): The communication group to work on. Default: ``None`` . 1729 dev_num (int): The device number of the group. Default: ``None`` . 1730 mean_flag (bool): Whether use mean in backward. Default: ``None`` . 1731 """ 1732 1733 @prim_attr_register 1734 def __init__(self, group=None, dev_num=None, mean_flag=None): 1735 """Initialize _MirrorMicroStepOperator.""" 1736 self.group = group 1737 self.dev_num = dev_num 1738 self.mean_flag = mean_flag 1739 self.add_prim_attr('order_enforce_skip', True) 1740 self.add_prim_attr('side_effect_backprop_mem', True) 1741 1742 def infer_shape(self, x_shape, z_shape): 1743 return x_shape 1744 1745 def infer_dtype(self, x_dtype, z_shape): 1746 return x_dtype 1747 1748 1749class _VirtualOutput(PrimitiveWithInfer): 1750 """ 1751 Auto parallel virtual out operator. 1752 1753 It would insert VirtualOutput operator in forward computation and be deleted before backward computation. 1754 """ 1755 1756 @prim_attr_register 1757 def __init__(self): 1758 """Initialize _VirtualOutput.""" 1759 self.add_prim_attr('order_enforce_skip', True) 1760 1761 def infer_shape(self, x_shape): 1762 return x_shape 1763 1764 def infer_dtype(self, x_dtype): 1765 return x_dtype 1766 1767 1768class _GetTensorSlice(PrimitiveWithInfer): 1769 """ 1770 Gets tensor slice by device matrix and tensor map. 1771 1772 Args: 1773 dev_mat (tuple): The device matrix of the slice tensor. 1774 tensor_map (tuple): The tensor map of the slice tensor. 1775 """ 1776 1777 @prim_attr_register 1778 def __init__(self): 1779 """Initialize _GetTensorSlice.""" 1780 self.add_prim_attr('order_enforce_skip', True) 1781 1782 def infer_value(self, x, dev_mat, tensor_map, slice_shape, full_shape): 1783 from mindspore.parallel._tensor import _load_tensor 1784 validator.check_value_type("dev_mat", dev_mat, [tuple], self.name) 1785 validator.check_value_type("tensor_map", tensor_map, [tuple], self.name) 1786 tensor_slice = _load_tensor(x, dev_mat, tensor_map, full_shape) 1787 if tensor_slice.shape != slice_shape: 1788 tensor_slice = tensor_slice.reshape(slice_shape) 1789 return Tensor(tensor_slice, x.dtype) 1790 1791 1792class BatchISendIRecv(PrimitiveWithInfer): 1793 """ 1794 Batch send and recv tensors asynchronously. 1795 1796 Note: 1797 - The ``isend`` and ``irecv`` in ``op_types`` between ranks need to match each other. 1798 - ``isend`` and ``irecv`` in a batch can only be used in the same communication group. 1799 1800 Args: 1801 op_types(Union[tuple[str], list[str]]): "isend" or "irecv" to indicate the order and number of communication. 1802 remote_ranks(Union[tuple[int], list[int]]): src or dst rank that matches the op_types. 1803 receive_shapes(Union[tuple[int], list[int]]): receive tensor shapes that matches "irecv" in op_types. 1804 receive_types(Union[tuple[mindspore.dtype], list[mindspore.dtype]]): receive tensor dtype 1805 that matches "irecv" in op_types. 1806 group (str): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP``, which 1807 means ``"hccl_world_group"`` in Ascend, and ``"nccl_world_group"`` in GPU. 1808 1809 Inputs: 1810 - **input_x** (Union[tuple[Tensor], list[Tensor], tuple(None)]) - 1811 The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. 1812 1813 Outputs: 1814 tuple(Tensor). Output tensors is corresponding to ``op_types``: 1815 At ``"isend"`` position, output tensor is a fake tensor with scalar, which has no meaning. 1816 At ``"irecv"`` position, output tensor is a tensor received from remote end. 1817 1818 1819 Raises: 1820 TypeError: If ``group`` is not a str. 1821 TypeError: If ``op_types``, ``receive_shapes``, ``receive_dtypes``, ``remote_ranks`` are not tuple or list. 1822 ValueError: If the length of ``receive_shapes`` and ``receive_dtypes`` are not the same. 1823 ValueError: If the length of ``op_types`` and ``remote_ranks`` are not the same. 1824 RuntimeError: If the length of input tensors and ``"isend"`` count are not the same. 1825 1826 Supported Platforms: 1827 ``Ascend`` 1828 1829 Examples: 1830 .. note:: 1831 Before running the following examples, you need to configure the communication environment variables. 1832 1833 For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method 1834 without any third-party or configuration file dependencies. 1835 1836 Please see the `msrun start up 1837 <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_ 1838 for more details. 1839 1840 This example should be run with 2 devices. 1841 1842 >>> import numpy as np 1843 >>> import mindspore as ms 1844 >>> from mindspore import ops 1845 >>> import mindspore.nn as nn 1846 >>> from mindspore.communication import init, get_rank 1847 >>> from mindspore import Tensor 1848 >>> 1849 >>> init() 1850 >>> rank = get_rank() 1851 >>> class Net(nn.Cell): 1852 ... def __init__(self): 1853 ... super(Net, self).__init__() 1854 ... if rank == 0: 1855 ... remote_rank = [1, 1] 1856 ... else: 1857 ... remote_rank = [0, 0] 1858 ... self.batchisendirecv = ops.BatchISendIRecv(("isend", "irecv"), remote_rank, [()], (ms.float32,)) 1859 ... 1860 ... def construct(self, x): 1861 ... if isinstance(x, Tensor): 1862 ... x = (x,) 1863 ... return self.batchisendirecv(x) 1864 ... 1865 >>> send_x = Tensor(rank + 1).astype(ms.float32) 1866 >>> net = Net() 1867 >>> output = net(send_x) 1868 >>> print(output) 1869 rank 0: 1870 (Tensor(shape=[], dtype=Float32, value= 0), Tensor(shape=[], dtype=Float32, value= 2)) 1871 rank 1: 1872 (Tensor(shape=[], dtype=Float32, value= 0), Tensor(shape=[], dtype=Float32, value= 1)) 1873 1874 Tutorial Examples: 1875 - `Distributed Set Communication Primitives - BatchISendIRecv 1876 <https://www.mindspore.cn/docs/en/master/api_python/samples/ops/communicate_ops.html#allgather>`_ 1877 1878 """ 1879 1880 @prim_attr_register 1881 def __init__(self, op_types, remote_ranks, receive_shapes=None, 1882 receive_dtypes=None, group=GlobalComm.WORLD_COMM_GROUP): 1883 if receive_shapes is None: 1884 receive_shapes = () 1885 else: 1886 validator.check_value_type("receive_shapes", receive_shapes, [tuple, list], self.name) 1887 1888 if receive_dtypes is None: 1889 receive_dtypes = () 1890 else: 1891 validator.check_value_type("receive_dtypes", receive_dtypes, [tuple, list], self.name) 1892 1893 validator.check_value_type("op_types", op_types, [tuple, list], self.name) 1894 validator.check_value_type("remote_ranks", remote_ranks, [tuple, list], self.name) 1895 1896 if len(receive_shapes) != len(receive_dtypes): 1897 raise ValueError("length of receive_shapes and receive_shapes must be the same, " 1898 f"but got receive_shapes: {len(receive_shapes)} " 1899 f" and receive_shapes: {receive_dtypes}") 1900 1901 if len(op_types) != len(remote_ranks): 1902 raise ValueError("length of op_types and remote_ranks must be the same.") 1903 1904 if group is None: 1905 group = GlobalComm.WORLD_COMM_GROUP 1906 self.add_prim_attr('group', group) 1907 self.add_prim_attr('op_types', op_types) 1908 self.add_prim_attr('remote_ranks', remote_ranks) 1909 self.add_prim_attr('receive_shapes', receive_shapes) 1910 self.add_prim_attr('receive_dtypes', receive_dtypes) 1911 self.add_prim_attr('no_eliminate', True) 1912 1913 1914class AlltoAllV(PrimitiveWithInfer): 1915 """ 1916 AllToAll which support uneven split. 1917 1918 Note: 1919 - Only support flatten tensor as input. input tensor should be flattened and 1920 concatenated before call this primitive. 1921 1922 Args: 1923 send_numel_list(Union[tuple[int], list[int]]): split numel to scatter to different remote rank. 1924 recv_numel_list(Union[tuple[int], list[int]]): split numel to gather from different remote rank. 1925 group (str): The communication group to work on. Default: ``GlobalComm.WORLD_COMM_GROUP``, which 1926 means ``"hccl_world_group"`` in Ascend, and ``"nccl_world_group"`` in GPU. 1927 1928 Inputs: 1929 - **input_x** (Tensor) - flatten tensor to scatter. The shape of tensor is :math:`(x_1)`. 1930 1931 Outputs: 1932 Tensor. flattened and concatenated tensor gather from remote ranks. 1933 If gather result is empty, it will return a Tensor with value 0, which has no actual meaning. 1934 1935 Raises: 1936 TypeError: If 'send_numel_list' or 'recv_numel_list' is not type of tuple and list. 1937 1938 Supported Platforms: 1939 ``Ascend`` 1940 1941 Examples: 1942 .. note:: 1943 Before running the following examples, you need to configure the communication environment variables. 1944 1945 For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method 1946 without any third-party or configuration file dependencies. 1947 1948 Please see the `msrun start up 1949 <https://www.mindspore.cn/tutorials/experts/zh-CN/master/parallel/msrun_launcher.html>`_ 1950 for more details. 1951 1952 This example should be run with 2 devices. 1953 1954 >>> import numpy as np 1955 >>> import mindspore as ms 1956 >>> from mindspore import ops 1957 >>> import mindspore.nn as nn 1958 >>> from mindspore.communication import init, get_rank 1959 >>> from mindspore import Tensor 1960 >>> 1961 >>> init() 1962 >>> rank = get_rank() 1963 >>> class Net(nn.Cell): 1964 ... def __init__(self): 1965 ... super(Net, self).__init__() 1966 ... if rank == 0: 1967 ... self.all_to_all = ops.AlltoAllV([1, 2], [1, 2]) 1968 ... else: 1969 ... self.all_to_all = ops.AlltoAllV([2, 1], [2, 1]) 1970 ... 1971 ... def construct(self, x): 1972 ... return self.all_to_all(x) 1973 ... 1974 >>> if rank == 0: 1975 >>> send_tensor = Tensor([0, 1, 2.]) 1976 >>> elif rank == 1: 1977 >>> send_tensor = Tensor([3, 4, 5.]) 1978 >>> net = Net() 1979 >>> output = net(send_tensor) 1980 >>> print(output) 1981 rank 0: 1982 [0. 3. 4] 1983 rank 1: 1984 [1. 2. 5] 1985 1986 """ 1987 1988 @prim_attr_register 1989 def __init__(self, send_numel_list, recv_numel_list, group=None): 1990 validator.check_value_type("send_numel_list", send_numel_list, [tuple, list], self.name) 1991 validator.check_value_type("recv_numel_list", recv_numel_list, [tuple, list], self.name) 1992 if group is None: 1993 group = GlobalComm.WORLD_COMM_GROUP 1994 self.add_prim_attr('group', group) 1995 self.add_prim_attr('send_numel_list', send_numel_list) 1996 self.add_prim_attr('recv_numel_list', recv_numel_list) 1997