# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Four-process test cases for the 'CollectiveGather' communication operator."""

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.communication.management import init, get_rank
from mindspore.ops.operations import comm_ops

np.random.seed(1)
# The communication service has to be initialised before the rank is queried.
init()
this_rank = get_rank()
dest_rank = 0

# These cases are written for exactly four participating ranks (0..3); each
# rank contributes one row of `_ROW_WIDTH` elements.
_NUM_RANKS = 4
_ROW_WIDTH = 3


class CollectiveGatherNet(nn.Cell):
    """Cell that applies `CollectiveGather` (targeting `dest_rank`) to its input."""

    def __init__(self):
        super().__init__()
        self.collective_gather = comm_ops.CollectiveGather(dest_rank=dest_rank)

    def construct(self, x):
        """Return the result of the gather operator applied to `x`."""
        return self.collective_gather(x)


def generate_input(dtype):
    """Build the input tensor contributed by the current rank.

    Rank ``r`` contributes a single row ``[[r + 1, r + 1, r + 1]]``.

    Args:
        dtype: NumPy dtype of the generated data (e.g. ``np.float32``).

    Returns:
        Tensor, of shape (1, 3), filled with ``this_rank + 1``.

    Raises:
        ValueError: If the current rank is outside ``0..3``. The original code
            returned ``None`` here, which only failed later inside the network
            call with a much less helpful message.
    """
    if not 0 <= this_rank < _NUM_RANKS:
        raise ValueError(
            f"This test expects ranks 0..{_NUM_RANKS - 1}, but got rank {this_rank}."
        )
    return Tensor(np.full((1, _ROW_WIDTH), this_rank + 1).astype(dtype))


def _expected_gather_result(dtype):
    """Return the rows of all ranks stacked in rank order, as a NumPy array."""
    rows = [[rank + 1] * _ROW_WIDTH for rank in range(_NUM_RANKS)]
    return np.array(rows).astype(dtype)


def _run_gather_case(dtype):
    """Run the gather network for `dtype` and verify the output on `dest_rank`."""
    ms_input = generate_input(dtype)
    net = CollectiveGatherNet()
    output = net(ms_input)

    # Only the destination rank's output is checked; the output on the other
    # ranks is deliberately left unverified, as in the original test.
    if this_rank == dest_rank:
        # `array_equal` also requires matching shapes, unlike a broadcasting
        # `(a == b).all()` comparison.
        assert np.array_equal(output.numpy(), _expected_gather_result(dtype))


def test_hccl_gather_4p_float32():
    """
    Feature: test 'CollectiveGather' communication operator.
    Description: gather one float32 row from each of four ranks onto 'dest_rank'.
    Expectation: 'dest_rank' receives the four rows stacked in rank order.
    """
    _run_gather_case(np.float32)


def test_hccl_gather_4p_float16():
    """
    Feature: test 'CollectiveGather' communication operator.
    Description: gather one float16 row from each of four ranks onto 'dest_rank'.
    Expectation: 'dest_rank' receives the four rows stacked in rank order.
    """
    _run_gather_case(np.float16)