# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""comm_ops"""

from mindspore.common import Tensor
from ..._checkparam import Validator as validator
from ..._checkparam import Rel
from ...communication.management import get_rank, get_group_size, GlobalComm, _get_group
from ...common import dtype as mstype
from ..primitive import PrimitiveWithInfer, PrimitiveWithCheck, Primitive, prim_attr_register
from ...common.api import context


class ReduceOp:
    """
    Operation options for reducing tensors. This is an enumerated type, not an operator.
    Mainly used in data parallel mode.

    The main calling methods are as follows:

    - SUM: ReduceOp.SUM.
    - MAX: ReduceOp.MAX.
    - MIN: ReduceOp.MIN.
    - PROD: ReduceOp.PROD.

    There are four kinds of operation options: "SUM", "MAX", "MIN" and "PROD".

    - SUM: Take the sum.
    - MAX: Take the maximum.
    - MIN: Take the minimum.
    - PROD: Take the product.

    Note:
        For more details, refer to the example. The example needs to run in an environment
        with multiple devices.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> # This example should be run with two devices. Refer to the tutorial > Distributed Training on mindspore.cn
        >>> import numpy as np
        >>> from mindspore.communication import init
        >>> from mindspore import Tensor, ops
        >>> from mindspore.ops import ReduceOp
        >>> import mindspore.nn as nn
        >>>
        >>> init()
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.allreduce_sum = ops.AllReduce(ReduceOp.SUM, group="nccl_world_group")
        ...
        ...     def construct(self, x):
        ...         return self.allreduce_sum(x)
        ...
        >>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
        >>> net = Net()
        >>> output = net(input_)
        >>> print(output)
        [[2. 2. 2. 2. 2. 2. 2. 2.]
         [2. 2. 2. 2. 2. 2. 2. 2.]]
    """
    SUM = "sum"
    MAX = "max"
    MIN = "min"
    PROD = "prod"


target_dtypes = (mstype.int8, mstype.int32, mstype.float16, mstype.float32)


def check_hcom_group_valid(group, prim_name=None):
    """Check if hcom group is valid."""
    msg_prefix = f"For '{prim_name}', only" if prim_name else "Only"
    if context.get_context("mode") == context.PYNATIVE_MODE and \
            context.get_context("device_target") == "Ascend" and \
            group != GlobalComm.WORLD_COMM_GROUP:
        raise RuntimeError(f"{msg_prefix} hccl_world_group is supported in Pynative mode, but got 'group': {group}.")
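
# A minimal sketch of the constraint enforced by check_hcom_group_valid, assuming the
# context is already PyNative mode on an Ascend device; "group_2p" is a hypothetical
# user-created group name used only for illustration:
#
#     check_hcom_group_valid(GlobalComm.WORLD_COMM_GROUP)        # passes silently
#     check_hcom_group_valid("group_2p", prim_name="AllReduce")  # raises RuntimeError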


class AllReduce(PrimitiveWithInfer):
    """
    Reduces the tensor data across all devices in such a way that all devices will get the same final result.

    Note:
        The operation of AllReduce does not support "prod" currently.
        The tensors must have the same shape and format in all processes of the collection.

    Args:
        op (str): Specifies an operation used for element-wise reductions,
            like sum, max, and min. Default: ReduceOp.SUM.
        group (str): The communication group to work on. Default: "hccl_world_group".

    Inputs:
        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.

    Outputs:
        Tensor, has the same shape as the input, i.e., :math:`(x_1, x_2, ..., x_R)`.
        The contents depend on the specified operation.

    Raises:
        TypeError: If either of `op` and `group` is not a str,
            or fusion is not an integer, or the dtype of the input is bool.
        ValueError: If the `op` is "prod".

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> # This example should be run with two devices. Refer to the tutorial > Distributed Training on mindspore.cn
        >>> import numpy as np
        >>> from mindspore.communication import init
        >>> from mindspore import Tensor
        >>> from mindspore.ops import ReduceOp
        >>> import mindspore.nn as nn
        >>> import mindspore.ops as ops
        >>>
        >>> init()
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.allreduce_sum = ops.AllReduce(ReduceOp.SUM)
        ...
        ...     def construct(self, x):
        ...         return self.allreduce_sum(x)
        ...
        >>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
        >>> net = Net()
        >>> output = net(input_)
        >>> print(output)
        [[2. 2. 2. 2. 2. 2. 2. 2.]
         [2. 2. 2. 2. 2. 2. 2. 2.]]
    """

    @prim_attr_register
    def __init__(self, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
        """Initialize AllReduce."""
        if not isinstance(op, type(ReduceOp.SUM)):
            raise TypeError(f"For '{self.name}', the 'op' should be str, but got {type(op).__name__}.")
        if not isinstance(_get_group(group), str):
            raise TypeError(f"For '{self.name}', the 'group' should be str, "
                            f"but got {type(_get_group(group)).__name__}.")
        check_hcom_group_valid(group, prim_name=self.name)
        self.op = op
        self.add_prim_attr('group', _get_group(group))
        self.add_prim_attr('fusion', 0)
        self.add_prim_attr('index', 0)
        self.add_prim_attr('no_elimilate', True)

    def infer_shape(self, x_shape):
        return x_shape

    def infer_dtype(self, x_dtype):
        validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
        return x_dtype


class AllGather(PrimitiveWithInfer):
    """
    Gathers tensors from the specified communication group.

    Note:
        The tensors must have the same shape and format in all processes of the collection.

    Args:
        group (str): The communication group to work on. Default: "hccl_world_group".

    Inputs:
        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.

    Outputs:
        Tensor. If the number of devices in the group is N,
        then the shape of output is :math:`(N*x_1, x_2, ..., x_R)`.

    Raises:
        TypeError: If `group` is not a str.
        ValueError: If the local rank id of the calling process in the group
            is larger than the group's rank size.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> # This example should be run with two devices. Refer to the tutorial > Distributed Training on mindspore.cn
        >>> import numpy as np
        >>> import mindspore.ops as ops
        >>> import mindspore.nn as nn
        >>> from mindspore.communication import init
        >>> from mindspore import Tensor, context
        >>>
        >>> context.set_context(mode=context.GRAPH_MODE)
        >>> init()
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.allgather = ops.AllGather()
        ...
        ...     def construct(self, x):
        ...         return self.allgather(x)
        ...
        >>> input_x = Tensor(np.ones([2, 8]).astype(np.float32))
        >>> net = Net()
        >>> output = net(input_x)
        >>> print(output)
        [[1. 1. 1. 1. 1. 1. 1. 1.]
         [1. 1. 1. 1. 1. 1. 1. 1.]
         [1. 1. 1. 1. 1. 1. 1. 1.]
         [1. 1. 1. 1. 1. 1. 1. 1.]]
    """

    @prim_attr_register
    def __init__(self, group=GlobalComm.WORLD_COMM_GROUP):
        """Initialize AllGather."""
        validator.check_value_type('group', _get_group(group), (str,), self.name)
        self.rank = get_rank(_get_group(group))
        self.rank_size = get_group_size(_get_group(group))
        validator.check('rank', self.rank, 'rank_size', self.rank_size, Rel.LT, self.name)
        self.add_prim_attr('rank_size', self.rank_size)
        self.add_prim_attr('group', _get_group(group))
        self.add_prim_attr('fusion', 0)
        self.add_prim_attr('mean_flag', False)
        self.add_prim_attr('no_elimilate', True)

    def infer_shape(self, x_shape):
        validator.check_positive_int(len(x_shape), "x shape", self.name)
        if x_shape[0] > 0:
            x_shape[0] = x_shape[0] * self.rank_size
        return x_shape

    def infer_dtype(self, x_dtype):
        validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
        return x_dtype

    def __call__(self, tensor):
        raise NotImplementedError


class _MiniStepAllGather(PrimitiveWithInfer):
    """
    Auto parallel virtual operator. Do nothing in forward, do reducescatter in backward in mini-step. It is only for
    internal use of parallel modules and cannot be called by users.

    Args:
        group (str): The communication group to work on. Default: GlobalComm.WORLD_COMM_GROUP.
        grad_accumulation_step (int): The grad accumulation step. Default: None.
        mean_flag (bool): Whether to use mean in backward. Default: None.
    """

    @prim_attr_register
    def __init__(self, group=GlobalComm.WORLD_COMM_GROUP, grad_accumulation_step=None, mean_flag=None):
        """Initialize _MiniStepAllGather."""
        validator.check_value_type('group', _get_group(group), (str,), self.name)
        self.rank = get_rank(_get_group(group))
        self.rank_size = get_group_size(_get_group(group))
        validator.check('rank', self.rank, 'rank_size', self.rank_size, Rel.LT, self.name)
        self.add_prim_attr('rank_size', self.rank_size)
        self.add_prim_attr('group', _get_group(group))
        self.add_prim_attr('fusion', 1)
        self.grad_accumulation_step = grad_accumulation_step
        self.mean_flag = mean_flag

    def infer_shape(self, x_shape, z_shape):
        validator.check_positive_int(len(x_shape), "x shape", self.name)
        if x_shape[0] > 0:
            x_shape[0] = x_shape[0] * self.rank_size
        return x_shape

    def infer_dtype(self, x_dtype, z_shape):
        validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
        return x_dtype


class _MicroStepAllGather(PrimitiveWithInfer):
    """
    Auto parallel virtual operator. Do nothing in forward, do reducescatter in backward in micro-step. It is only for
    internal use of parallel modules and cannot be called by users.

    Args:
        group (str): The communication group to work on. Default: GlobalComm.WORLD_COMM_GROUP.
        mean_flag (bool): Whether to use mean in backward. Default: None.
    """

    @prim_attr_register
    def __init__(self, group=GlobalComm.WORLD_COMM_GROUP, mean_flag=None):
        """Initialize _MicroStepAllGather."""
        validator.check_value_type('group', _get_group(group), (str,), self.name)
        self.rank = get_rank(_get_group(group))
        self.rank_size = get_group_size(_get_group(group))
        validator.check('rank', self.rank, 'rank_size', self.rank_size, Rel.LT, self.name)
        self.add_prim_attr('rank_size', self.rank_size)
        self.add_prim_attr('group', _get_group(group))
        self.add_prim_attr('fusion', 1)
        self.mean_flag = mean_flag

    def infer_shape(self, x_shape, z_shape):
        validator.check_positive_int(len(x_shape), "x shape", self.name)
        if x_shape[0] > 0:
            x_shape[0] = x_shape[0] * self.rank_size
        return x_shape

    def infer_dtype(self, x_dtype, z_dtype):
        validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
        return x_dtype


class _HostAllGather(PrimitiveWithInfer):
    """
    Gathers tensors from the specified communication group on host.

    Note:
        The tensors must have the same shape and format in all processes of the collection.
        _HostAllGather is a host-side operator. It depends on OpenMPI and must use the build option
        -M on to enable it. Use the mpirun command to run it:
        mpirun -output-filename log -merge-stderr-to-stdout -np 3 python test_host_all_gather.py

    Args:
        group (Union[tuple[int], list[int]]): The rank_ids of the communication group to work on. Default: None.

    Raises:
        TypeError: If group is not a list nor tuple, or elements of group are not int.
        ValueError: If group is not set, or rank_id from group not in [0, 7].

    Inputs:
        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.

    Outputs:
        Tensor. If the number of devices in the group is N,
        then the shape of output is :math:`(N*x_1, x_2, ..., x_R)`.
    """

    @prim_attr_register
    def __init__(self, group=None):
        """Initialize _HostAllGather."""
        if group is None:
            raise ValueError(f"For '{self.name}', the 'group' cannot be None, but got {group}.")
        validator.check_value_type('group', group, (tuple, list), self.name)
        validator.check_int(len(group), 2, Rel.GE, "group size", self.name)
        for r in group:
            validator.check_int_range(r, 0, 7, Rel.INC_BOTH, "rank_id", self.name)
            validator.check_value_type("rank_id", r, (int,), self.name)
        self.group_size = len(group)
        self.add_prim_attr('group', group)
        self.add_prim_attr('no_elimilate', True)

    def infer_shape(self, x_shape):
        validator.check_positive_int(len(x_shape), "x shape", self.name)
        if x_shape[0] > 0:
            x_shape[0] = x_shape[0] * self.group_size
        return x_shape

    def infer_dtype(self, x_dtype):
        validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
        return x_dtype

    def __call__(self, tensor):
        raise NotImplementedError
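
# A minimal usage sketch for _HostAllGather (graph mode only; its __call__ is not implemented),
# assuming three host processes launched with the mpirun command from the docstring above and
# `import mindspore.nn as nn`; the cell name "Net" is illustrative:
#
#     class Net(nn.Cell):
#         def __init__(self):
#             super(Net, self).__init__()
#             self.hostallgather = _HostAllGather(group=(0, 1, 2))
#
#         def construct(self, x):
#             return self.hostallgather(x)
#
# With an input of shape (2, 8) on every process, infer_shape above yields an output of
# shape (6, 8): the 0-th dimension is multiplied by the group size.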


class ReduceScatter(PrimitiveWithInfer):
    """
    Reduces and scatters tensors from the specified communication group.

    Note:
        The back propagation of the op is not supported yet. Stay tuned for more.
        The tensors must have the same shape and format in all processes of the collection.

    Args:
        op (str): Specifies an operation used for element-wise reductions,
            like SUM, MAX, and MIN. Default: ReduceOp.SUM.
        group (str): The communication group to work on. Default: "hccl_world_group".

    Raises:
        TypeError: If either of `op` and `group` is not a string.
        ValueError: If the first dimension of the input cannot be divided by the rank size.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> # This example should be run with two devices. Refer to the tutorial > Distributed Training on mindspore.cn
        >>> from mindspore import Tensor, context
        >>> from mindspore.communication import init
        >>> from mindspore.ops import ReduceOp
        >>> import mindspore.nn as nn
        >>> import mindspore.ops as ops
        >>> import numpy as np
        >>>
        >>> context.set_context(mode=context.GRAPH_MODE)
        >>> init()
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.reducescatter = ops.ReduceScatter(ReduceOp.SUM)
        ...
        ...     def construct(self, x):
        ...         return self.reducescatter(x)
        ...
        >>> input_ = Tensor(np.ones([8, 8]).astype(np.float32))
        >>> net = Net()
        >>> output = net(input_)
        >>> print(output)
        [[2. 2. 2. 2. 2. 2. 2. 2.]
         [2. 2. 2. 2. 2. 2. 2. 2.]
         [2. 2. 2. 2. 2. 2. 2. 2.]
         [2. 2. 2. 2. 2. 2. 2. 2.]]
    """

    @prim_attr_register
    def __init__(self, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
        """Initialize ReduceScatter."""
        validator.check_value_type('op', op, (type(ReduceOp.SUM),), self.name)
        validator.check_value_type('group', _get_group(group), (str,), self.name)
        self.op = op
        self.rank_size = get_group_size(_get_group(group))
        self.add_prim_attr('rank_size', self.rank_size)
        self.add_prim_attr('group', _get_group(group))
        self.add_prim_attr('fusion', 0)
        self.add_prim_attr('no_elimilate', True)

    def infer_shape(self, x_shape):
        if self.rank_size == 0:
            raise ValueError(f"For '{self.name}', the 'rank_size' cannot be zero, but got {self.rank_size}.")
        if x_shape[0] % self.rank_size != 0:
            raise ValueError(f"For '{self.name}', the first dimension of 'x_shape' should be divisible by "
                             f"'rank_size', but got 'x_shape[0]': {x_shape[0]}, 'rank_size': {self.rank_size}.")
        x_shape[0] = int(x_shape[0] / self.rank_size)
        return x_shape

    def infer_dtype(self, x_dtype):
        validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
        return x_dtype

    def __call__(self, tensor):
        raise NotImplementedError


class _HostReduceScatter(PrimitiveWithInfer):
    """
    Reduces and scatters tensors from the specified communication group on host.

    Note:
        The tensors must have the same shape and format in all processes of the collection.
        _HostReduceScatter is a host-side operator. It depends on OpenMPI and must use the build option
        -M on to enable it. Use the mpirun command to run it:
        mpirun -output-filename log -merge-stderr-to-stdout -np 3 python test_host_reduce_scatter.py

    Args:
        op (str): Specifies an operation used for element-wise reductions,
            like sum, max, and min. Default: ReduceOp.SUM.
        group (Union[tuple[int], list[int]]): The rank_ids of the communication group to work on. Default: None.

    Raises:
        TypeError: If op is not a string, group is not a list nor tuple,
            or elements of group are not int.
        ValueError: If the first dimension of the input cannot be divided by the group size,
            or group is not set, or rank_id not in [0, 7].
    """

    @prim_attr_register
    def __init__(self, op=ReduceOp.SUM, group=None):
        """Initialize _HostReduceScatter."""
        if group is None:
            raise ValueError(f"For '{self.name}', the 'group' cannot be None, but got {group}.")
        validator.check_value_type('op', op, (type(ReduceOp.SUM),), self.name)
        validator.check_value_type('group', group, (tuple, list), self.name)
        validator.check_int(len(group), 2, Rel.GE, "group size", self.name)
        for r in group:
            validator.check_int_range(r, 0, 7, Rel.INC_BOTH, "rank_id", self.name)
            validator.check_value_type("rank_id", r, (int,), self.name)
        self.op = op
        self.group_size = len(group)
        self.add_prim_attr('group', group)
        self.add_prim_attr('no_elimilate', True)

    def infer_shape(self, x_shape):
        if x_shape[0] % self.group_size != 0:
            raise ValueError(f"For '{self.name}', the first dimension of 'x_shape' should be divisible by "
                             f"'group_size', but got 'x_shape[0]': {x_shape[0]}, 'group_size': {self.group_size}.")
        x_shape[0] = int(x_shape[0] / self.group_size)
        return x_shape

    def infer_dtype(self, x_dtype):
        validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
        return x_dtype

    def __call__(self, tensor):
        raise NotImplementedError


class Broadcast(PrimitiveWithInfer):
    """
    Broadcasts the tensor to the whole group.

    Note:
        The tensors must have the same shape and format in all processes of the collection.

    Args:
        root_rank (int): Source rank. The content of the tensors on this rank
            is broadcast to all the other processes.
        group (str): The communication group to work on. Default: "hccl_world_group".

    Inputs:
        - **input_x** (tuple[Tensor]) - The shape of each tensor is :math:`(x_1, x_2, ..., x_R)`.

    Outputs:
        Tuple of tensors, each has the same shape as the corresponding input, i.e., :math:`(x_1, x_2, ..., x_R)`.
        The contents depend on the data of the `root_rank` device.

    Raises:
        TypeError: If root_rank is not an integer or group is not a string.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> # This example should be run with multiple processes.
        >>> # Please refer to the tutorial > Distributed Training on mindspore.cn.
        >>> from mindspore import Tensor
        >>> from mindspore import context
        >>> from mindspore.communication import init
        >>> import mindspore.nn as nn
        >>> import mindspore.ops as ops
        >>> import numpy as np
        >>>
        >>> context.set_context(mode=context.GRAPH_MODE)
        >>> init()
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.broadcast = ops.Broadcast(1)
        ...
        ...     def construct(self, x):
        ...         return self.broadcast((x,))
        ...
        >>> input_x = Tensor(np.ones([2, 4]).astype(np.int32))
        >>> net = Net()
        >>> output = net(input_x)
        >>> print(output)
        (Tensor(shape=[2, 4], dtype=Int32, value=
        [[1, 1, 1, 1],
         [1, 1, 1, 1]]),)
    """

    @prim_attr_register
    def __init__(self, root_rank, group=GlobalComm.WORLD_COMM_GROUP):
        """Initialize Broadcast."""
        validator.check_value_type('root_rank', root_rank, (int,), self.name)
        validator.check_value_type('group', _get_group(group), (str,), self.name)
        check_hcom_group_valid(group, prim_name=self.name)
        self.add_prim_attr('group', _get_group(group))
        self.add_prim_attr('no_elimilate', True)

    def infer_shape(self, x_shape):
        return x_shape

    def infer_dtype(self, x_dtype):
        if not isinstance(x_dtype, tuple):
            raise TypeError(f"For '{self.name}', the 'input_x' should be a tuple, but got {type(x_dtype).__name__}!")
        for _ele in x_dtype:
            validator.check_tensor_dtype_valid('x', _ele, target_dtypes, self.name)
        return x_dtype


class AllSwap(PrimitiveWithCheck):
    """
    AllSwap is a collective operation.

    AllSwap sends data from all processes to all processes in the specified group. It has two phases:

    - The scatter phase: On each process, the operand is split into blocks along the 0-th axis
      according to the send sizes, and the blocks are scattered to all processes, e.g., the ith
      block is sent to the ith process.
    - The gather phase: Each process concatenates the received blocks along the 0-th axis.

    Note:
        The tensors must have the same format in all processes of the collection.

    Args:
        group (str): The communication group name. Default: "hccl_world_group".

    Inputs:
        tensor_in (tensor): A 2-D tensor. On each process, it is split into blocks according to the send sizes.
        send_size (tensor): A 1-D int64 tensor. The element is the send data size for each process.
        recv_size (tensor): A 1-D int64 tensor. The element is the receive data size for each process.

    Returns:
        tensor_out (tensor): The result tensor.

    Raises:
        TypeError: If group is not a string.
    """

    @prim_attr_register
    def __init__(self, group=GlobalComm.WORLD_COMM_GROUP):
        """Initialize AllSwap."""
        validator.check_value_type('group', _get_group(group), (str,), self.name)
        self.init_prim_io_names(inputs=['tensor_in', 'send_size', 'recv_size'], outputs=['tensor_out'])
        self.add_prim_attr('group', _get_group(group))
        self.add_prim_attr('no_elimilate', True)

    def __check__(self, tensor_in, send_size, recv_size):
        validator.check_subclass("tensor_in", tensor_in['dtype'], mstype.tensor, self.name)
        validator.check_tensor_dtype_valid("send_size", send_size['dtype'], [mstype.int64],
                                           self.name)
        validator.check_tensor_dtype_valid("recv_size", recv_size['dtype'], [mstype.int64],
                                           self.name)

        validator.check_equal_int(len(tensor_in['shape']), 2, "tensor_in", self.name)
        validator.check_equal_int(len(send_size['shape']), 1, "send_size", self.name)
        validator.check_equal_int(len(recv_size['shape']), 1, "recv_size", self.name)

        out_shape = [-1] + [tensor_in['shape'][1]]
        out = {'shape': out_shape,
               'dtype': tensor_in['dtype'],
               'value': None}
        return out
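
# A minimal usage sketch for AllSwap, assuming two processes with communication already
# initialized via mindspore.communication.init() and the `np`/`nn` imports used in the
# docstring examples above; the cell name "Net" and the concrete size values are illustrative:
#
#     class Net(nn.Cell):
#         def __init__(self):
#             super(Net, self).__init__()
#             self.allswap = AllSwap()
#
#         def construct(self, x, send_size, recv_size):
#             return self.allswap(x, send_size, recv_size)
#
#     x = Tensor(np.ones([4, 8]).astype(np.float32))         # 2-D operand, as required by __check__
#     send_size = Tensor(np.array([2, 2], dtype=np.int64))   # 1-D int64: send data size per process
#     recv_size = Tensor(np.array([2, 2], dtype=np.int64))   # 1-D int64: receive data size per process
#     output = Net()(x, send_size, recv_size)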


class NeighborExchange(Primitive):
    """
    NeighborExchange is a collective operation.

    NeighborExchange sends data from the local rank to the ranks in send_rank_ids,
    while receiving data from the ranks in recv_rank_ids.

    Args:
        send_rank_ids (list(int)): Ranks which the data is sent to.
        recv_rank_ids (list(int)): Ranks which the data is received from.
        recv_shapes (tuple(list(int))): Shapes of the data received from recv_rank_ids.
        send_shapes (tuple(list(int))): Shapes of the data sent to send_rank_ids.
        recv_type (type): Data type of the data received from recv_rank_ids.
        group (str): The communication group to work on. Default: "hccl_world_group".
    """

    @prim_attr_register
    def __init__(self, send_rank_ids, recv_rank_ids, recv_shapes, send_shapes, recv_type,
                 group=GlobalComm.WORLD_COMM_GROUP):
        """Initialize NeighborExchange."""
        self.init_prim_io_names(inputs=['x'], outputs=['output'])
        self.send_rank_ids = send_rank_ids
        self.recv_rank_ids = recv_rank_ids
        self.recv_shapes = recv_shapes
        self.send_shapes = send_shapes
        self.recv_type = recv_type
        self.add_prim_attr('no_elimilate', True)

    def __call__(self, tensor):
        raise NotImplementedError
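
# A minimal usage sketch for NeighborExchange from the point of view of one rank, assuming
# two devices (ranks 0 and 1) that exchange a [16, 16] float32 tensor; on the peer rank the
# send/recv ids would point back at this rank. The cell name and all concrete values are
# illustrative only:
#
#     class Net(nn.Cell):
#         def __init__(self):
#             super(Net, self).__init__()
#             self.exchange = NeighborExchange(send_rank_ids=[1], recv_rank_ids=[1],
#                                              recv_shapes=([16, 16],), send_shapes=([16, 16],),
#                                              recv_type=mstype.float32)
#
#         def construct(self, x):
#             # The input is a tuple of tensors matching send_shapes.
#             return self.exchange((x,))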


class AlltoAll(PrimitiveWithInfer):
    """
    AlltoAll is a collective operation.

    AlltoAll sends data from all processes to all processes in the specified group. It has two phases:

    - The scatter phase: On each process, the operand is split into split_count number of blocks along
      split_dim, and the blocks are scattered to all processes, e.g., the ith block is sent to the ith process.
    - The gather phase: Each process concatenates the received blocks along concat_dim.

    Note:
        The tensors must have the same shape and format in all processes of the collection.

    Args:
        split_count (int): On each process, divide blocks into split_count number.
        split_dim (int): On each process, split blocks along the split_dim.
        concat_dim (int): On each process, gather the received blocks along the concat_dim.
        group (str): The communication group to work on. Default: "hccl_world_group".

    Raises:
        TypeError: If group is not a string.
    """

    @prim_attr_register
    def __init__(self, split_count, split_dim, concat_dim, group=GlobalComm.WORLD_COMM_GROUP):
        """Initialize AlltoAll."""
        validator.check_value_type('group', _get_group(group), (str,), self.name)
        validator.check_is_int(split_count, int)
        validator.check_is_int(split_dim, int)
        validator.check_is_int(concat_dim, int)
        self.split_count = split_count
        self.split_dim = split_dim
        self.concat_dim = concat_dim
        self.add_prim_attr('group', _get_group(group))
        self.add_prim_attr('no_elimilate', True)

    def infer_shape(self, x_shape):
        rank_size = get_group_size(_get_group(self.group))
        if self.split_count != rank_size:
            raise ValueError(f"For '{self.name}', the 'split_count' must be equal to 'rank_size', "
                             f"but got 'split_count': {self.split_count}, 'rank_size': {rank_size}.")
        if x_shape[self.split_dim] % self.split_count != 0:
            raise ValueError(f"For '{self.name}', the dimension of 'x_shape' at 'split_dim' must be divisible by "
                             f"'split_count', but got 'x_shape[split_dim]': {x_shape[self.split_dim]}, "
                             f"'split_count': {self.split_count}.")
        x_shape[self.concat_dim] = x_shape[self.concat_dim] * self.split_count
        x_shape[self.split_dim] = int(x_shape[self.split_dim] / self.split_count)
        return x_shape

    def infer_dtype(self, x_dtype):
        validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name)
        return x_dtype

    def __call__(self, tensor):
        raise NotImplementedError
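
# A minimal usage sketch for AlltoAll, assuming eight devices with communication already
# initialized; the cell name and the input shape are illustrative only:
#
#     class Net(nn.Cell):
#         def __init__(self):
#             super(Net, self).__init__()
#             self.alltoall = AlltoAll(split_count=8, split_dim=-2, concat_dim=-1)
#
#         def construct(self, x):
#             return self.alltoall(x)
#
#     x = Tensor(np.ones([1, 1, 8, 1]).astype(np.float32))
#     output = Net()(x)
#     # Per infer_shape above: split_dim shrinks 8 -> 1 and concat_dim grows 1 -> 8,
#     # so the output shape is (1, 1, 1, 8).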


class _MirrorOperator(PrimitiveWithInfer):
    """
    Auto parallel virtual operator. Do nothing in forward, do all reduce and mean in backward. It is only for
    internal use of parallel modules and cannot be called by users.

    Args:
        group (str): The communication group to work on. Default: None.
        dev_num (int): The device number of the group. Default: None.
        mean_flag (bool): Whether to use mean in backward. Default: None.
    """

    @prim_attr_register
    def __init__(self, group=None, dev_num=None, mean_flag=None):
        """Initialize _MirrorOperator."""
        self.group = group
        self.dev_num = dev_num
        self.mean_flag = mean_flag
        self.add_prim_attr("fusion", 1)

    def infer_shape(self, x_shape):
        return x_shape

    def infer_dtype(self, x_dtype):
        return x_dtype


mirror = _MirrorOperator()


class _MirrorMiniStepOperator(PrimitiveWithInfer):
    """
    Auto parallel virtual operator. Do nothing in forward, do all reduce and mean in backward. It is only for
    internal use of parallel modules and cannot be called by users.

    Args:
        group (str): The communication group to work on. Default: None.
        dev_num (int): The device number of the group. Default: None.
        mean_flag (bool): Whether to use mean in backward. Default: None.
        grad_accumulation_step (int): The grad accumulation step. Default: None.
    """

    @prim_attr_register
    def __init__(self, group=None, dev_num=None, mean_flag=None, grad_accumulation_step=None):
        """Initialize _MirrorMiniStepOperator."""
        self.group = group
        self.dev_num = dev_num
        self.mean_flag = mean_flag
        self.grad_accumulation_step = grad_accumulation_step

    def infer_shape(self, x_shape, z_shape):
        return x_shape

    def infer_dtype(self, x_dtype, z_shape):
        return x_dtype


mirror_mini_step = _MirrorMiniStepOperator()


class _VirtualDiv(PrimitiveWithInfer):
    """
    Auto parallel virtual operator. Do nothing in forward, do Div in backward.

    Args:
        divisor (float32): The divisor used by Div in backward. Default: None.
    """

    @prim_attr_register
    def __init__(self, divisor=None):
        """Initialize _VirtualDiv."""
        self.divisor = divisor

    def infer_shape(self, x_shape):
        return x_shape

    def infer_dtype(self, x_dtype):
        return x_dtype


virtual_div = _VirtualDiv()


class _VirtualAdd(PrimitiveWithInfer):
    """Auto parallel virtual operator. Do nothing in forward, do Add in backward."""

    @prim_attr_register
    def __init__(self):
        """Initialize _VirtualAdd."""

    def infer_shape(self, x_shape, y_shape):
        return x_shape

    def infer_dtype(self, x_dtype, y_dtype):
        return x_dtype


class _VirtualDataset(PrimitiveWithInfer):
    """
    Auto parallel virtual dataset operator.

    The VirtualDataset operator is inserted in forward computation and deleted before backward computation.
    """

    @prim_attr_register
    def __init__(self):
        """Initialize _VirtualDataset."""

    def infer_shape(self, *args):
        return args

    def infer_dtype(self, *args):
        return args


virtual_dataset = _VirtualDataset()


class _VirtualAssignAdd(PrimitiveWithInfer):
    """
    Auto parallel virtual operator. Do nothing in forward, do AssignAdd in backward. It is only for
    internal use of parallel modules and cannot be called by users.
    """

    @prim_attr_register
    def __init__(self):
        """Initialize _VirtualAssignAdd."""

    def infer_shape(self, x_shape, y_shape):
        return x_shape

    def infer_dtype(self, x_dtype, y_dtype):
        return x_dtype


virtual_assign_add = _VirtualAssignAdd()


class _VirtualAccuGrad(PrimitiveWithInfer):
    """
    Auto parallel virtual operator. Do nothing in forward, return y in backward. It is only for
    internal use of parallel modules and cannot be called by users.
    """

    @prim_attr_register
    def __init__(self):
        """Initialize _VirtualAccuGrad."""

    def infer_shape(self, x_shape, y_shape):
        return x_shape

    def infer_dtype(self, x_dtype, y_dtype):
        return x_dtype


virtual_accu_grad = _VirtualAccuGrad()


class _MirrorMicroStepOperator(PrimitiveWithInfer):
    """
    Auto parallel virtual operator. Do nothing in forward, do all reduce and mean in backward. It is only for
    internal use of parallel modules and cannot be called by users.

    Args:
        group (str): The communication group to work on. Default: None.
        dev_num (int): The device number of the group. Default: None.
        mean_flag (bool): Whether to use mean in backward. Default: None.
    """

    @prim_attr_register
    def __init__(self, group=None, dev_num=None, mean_flag=None):
        """Initialize _MirrorMicroStepOperator."""
        self.group = group
        self.dev_num = dev_num
        self.mean_flag = mean_flag

    def infer_shape(self, x_shape, z_shape):
        return x_shape

    def infer_dtype(self, x_dtype, z_shape):
        return x_dtype


class _VirtualOutput(PrimitiveWithInfer):
    """
    Auto parallel virtual out operator.

    The VirtualOutput operator is inserted in forward computation and deleted before backward computation.
    """

    @prim_attr_register
    def __init__(self):
        """Initialize _VirtualOutput."""

    def infer_shape(self, x_shape):
        return x_shape

    def infer_dtype(self, x_dtype):
        return x_dtype


class _GetTensorSlice(PrimitiveWithInfer):
    """
    Gets a tensor slice by device matrix and tensor map.

    Args:
        dev_mat (tuple): The device matrix of the slice tensor.
        tensor_map (tuple): The tensor map of the slice tensor.
    """

    @prim_attr_register
    def __init__(self):
        """Initialize _GetTensorSlice."""

    def infer_value(self, x, dev_mat, tensor_map):
        from mindspore.parallel._tensor import _load_tensor
        validator.check_value_type("dev_mat", dev_mat, [tuple], self.name)
        validator.check_value_type("tensor_map", tensor_map, [tuple], self.name)
        return Tensor(_load_tensor(x, dev_mat, tensor_map))