# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""NCCL point-to-point Send/Receive test.

The group is split in two halves: each rank in the lower half sends its
buffer to rank ``rank + size // 2``, and each rank in the upper half
receives from rank ``rank - size // 2`` and checks the payload.
"""
import numpy as np

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from mindspore.communication.management import init, NCCL_WORLD_COMM_GROUP, get_rank, get_group_size
from mindspore.ops import operations as P
from mindspore.ops.operations._inner_ops import Send, Receive
from mindspore.common import dtype as mstype

context.set_context(mode=context.GRAPH_MODE, device_target='GPU')

init()
rank = get_rank()
size = get_group_size()
if size % 2 != 0:
    # Ranks are paired lower-half -> upper-half, so the group must split evenly.
    raise RuntimeError("Group size should be divided by 2 exactly.")

# Shape of the tensor exchanged between each sender/receiver pair.
_SHAPE = (3, 3, 3, 3)


def _payload(src_rank):
    """Return the float32 buffer that rank `src_rank` sends (all values 0.01 * (src_rank + 1))."""
    return np.ones(_SHAPE).astype(np.float32) * 0.01 * (src_rank + 1)


# This rank's own buffer; sender ranks transmit it, receiver ranks never use it.
x = _payload(rank)


class SendNet(nn.Cell):
    """Send this rank's `x` to rank ``rank + size // 2`` and return `x`."""

    def __init__(self):
        super(SendNet, self).__init__()
        self.x = Parameter(initializer(Tensor(x), x.shape), name='x')
        self.depend = P.Depend()
        self.send = Send(sr_tag=0, dest_rank=rank + size // 2, group=NCCL_WORLD_COMM_GROUP)

    def construct(self):
        # Depend is expected to yield self.x while tying the send into the
        # graph so it is actually executed (see MindSpore Depend docs).
        out = self.depend(self.x, self.send(self.x))
        return out


class RecvNet(nn.Cell):
    """Receive a float32 tensor of shape `_SHAPE` from rank ``rank - size // 2``."""

    def __init__(self):
        super(RecvNet, self).__init__()
        self.recv = Receive(sr_tag=0, src_rank=rank - size // 2, shape=list(_SHAPE), dtype=mstype.float32,
                            group=NCCL_WORLD_COMM_GROUP)

    def construct(self):
        out = self.recv()
        return out


def test_send_recv():
    """Run one Send/Receive exchange and verify the result on every rank.

    Sender ranks (lower half) expect SendNet to hand back their own buffer `x`;
    receiver ranks (upper half) expect the buffer sent by rank ``rank - size // 2``.
    """
    if rank < size // 2:
        # Previously unset on this path, so the comparison below could not run
        # for senders.
        expect_output = x
        send_net = SendNet()
        output = send_net()
    else:
        expect_output = _payload(rank - size // 2)
        recv_net = RecvNet()
        output = recv_net()

    actual = output.asnumpy()
    # Check the shape first so a mismatch fails clearly instead of broadcasting.
    assert actual.shape == expect_output.shape
    assert np.all(np.abs(actual - expect_output) < 1.0e-5)