# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""HCCL management API"""
import ctypes
import os
from mindspore import context
from .._c_expression import get_hccl_rank_id, get_hccl_rank_size

MAX_GROUP_NAME_LEN = 127
MAX_RANK_NUM = 4096
HCCL_LIB = 'libhccl_plugin.so'
# Holds the ctypes library handle once load_lib() has succeeded; the empty
# string is the "not loaded yet" placeholder.
HCCL_LIB_CTYPES = ""


def check_group(group):
    """
    Check that a collective communication group name is legal.

    Args:
        group (str): Group name; must be 1 to MAX_GROUP_NAME_LEN characters long.

    Raises:
        TypeError: If `group` is not a str.
        ValueError: If `group` is empty or longer than MAX_GROUP_NAME_LEN.

    Returns:
        None
    """
    if not isinstance(group, str):
        raise TypeError('Group must be a python str.')
    group_len = len(group)
    if group_len > MAX_GROUP_NAME_LEN or group_len == 0:
        raise ValueError('Group name is invalid.')


def check_rank_num(rank_num):
    """
    Check that a collective communication rank number is legal.

    Args:
        rank_num (int): Number of ranks; must be in the range [1, MAX_RANK_NUM].

    Raises:
        TypeError: If `rank_num` is not an int.
        ValueError: If `rank_num` is out of range.

    Returns:
        None
    """
    if not isinstance(rank_num, int):
        raise TypeError('Rank number must be a python int.')
    if rank_num > MAX_RANK_NUM or rank_num <= 0:
        raise ValueError('Rank number is out of range.')


def check_rank_id(rank_id):
    """
    Check that a collective communication rank id is legal.

    Args:
        rank_id (int): Rank id; must be in the range [0, MAX_RANK_NUM).

    Raises:
        TypeError: If `rank_id` is not an int.
        ValueError: If `rank_id` is out of range.

    Returns:
        None
    """
    if not isinstance(rank_id, int):
        raise TypeError('Rank id must be a python int.')
    if rank_id >= MAX_RANK_NUM or rank_id < 0:
        raise ValueError('Rank id is out of range.')


def load_lib():
    """
    Load the HCCL plugin library and store its handle in HCCL_LIB_CTYPES.

    The library HCCL_LIB is looked up in the "../lib" directory relative to
    the directory containing this file.

    Raises:
        RuntimeError: If the library cannot be loaded. The original error is
            attached as the cause.

    Returns:
        None
    """
    global HCCL_LIB_CTYPES
    base_dir = os.path.dirname(os.path.realpath(__file__))
    lib_path = os.path.join(base_dir, "../lib", HCCL_LIB)
    try:
        hccl_lib = ctypes.CDLL(lib_path)
    except Exception as err:
        # Chain explicitly so the loader's own message (missing file,
        # unresolved symbol, ...) stays visible in the traceback.
        raise RuntimeError('Get hccl lib error.') from err
    HCCL_LIB_CTYPES = hccl_lib


def _get_hccl_lib():
    """Return the loaded HCCL library handle, raising RuntimeError if load_lib() has not run."""
    # Without this guard, attribute access on the "" placeholder fails with
    # an opaque AttributeError on str.
    if not HCCL_LIB_CTYPES:
        raise RuntimeError('HCCL lib is not loaded, please call load_lib() first.')
    return HCCL_LIB_CTYPES


def _query_uint(api_name, error_message, *leading_args):
    """Call HCCL function `api_name` with `leading_args` plus a trailing c_uint out-parameter and return its value."""
    result = ctypes.c_uint()
    ret = getattr(_get_hccl_lib(), api_name)(*leading_args, ctypes.byref(result))
    if ret != 0:
        raise RuntimeError(error_message)
    return result.value


def c_str(string):
    """
    Convert a python string to a C string.

    Args:
        string (Union[str, bytes]): Text to convert. A non-str value is first
            decoded as ASCII.

    Returns:
        ctypes.c_char_p, wrapping the UTF-8 encoding of `string`.
    """
    if not isinstance(string, str):
        string = string.decode('ascii')
    return ctypes.c_char_p(string.encode('utf-8'))


def c_array(ctype, values):
    """
    Create a ctypes array from a python sequence.

    Args:
        ctype: The ctypes element type, e.g. `ctypes.c_uint`.
        values: Sized sequence of values used to initialise the array.

    Returns:
        A ctypes array of `len(values)` elements of type `ctype`.
    """
    return (ctype * len(values))(*values)


def create_group(group, rank_num, rank_ids):
    """
    Create group.

    A function that creates a collective communication group which includes 'rank_num'
    devices; 'rank_ids' is the list of the ranks of these devices.

    Note:
        The world group can not be created.

    Args:
        group (str): Name of the group to create.
        rank_num (int): Number of ranks in the group.
        rank_ids (list[int]): Rank ids of the group members; its length must
            equal `rank_num` and every id must be in [0, MAX_RANK_NUM).

    Raises:
        TypeError: If `group`, `rank_num` or `rank_ids` has the wrong type.
        ValueError: If `group` or `rank_num` is invalid, if the length of
            `rank_ids` differs from `rank_num`, or if any rank id is not a
            non-negative int below MAX_RANK_NUM.
        RuntimeError: If the HCCL library is not loaded or reports a failure.

    Returns:
        None
    """
    check_group(group)
    check_rank_num(rank_num)
    if not isinstance(rank_ids, list):
        raise TypeError('Rank ids must be a python list.')
    if rank_num != len(rank_ids):
        raise ValueError('Rank number is not equal to the length of rank_ids.')
    for rank_id in rank_ids:
        if not isinstance(rank_id, int) or rank_id < 0:
            raise ValueError('Rank id must be unsigned integer!')
        # ctypes.c_uint does no overflow checking, so an oversized id would
        # be silently wrapped; apply the same bound as check_rank_id.
        if rank_id >= MAX_RANK_NUM:
            raise ValueError('Rank id is out of range.')
    hccl_lib = _get_hccl_lib()
    c_array_rank_ids = c_array(ctypes.c_uint, rank_ids)
    c_rank_num = ctypes.c_uint(rank_num)
    c_group = c_str(group)
    ret = hccl_lib.HcomCreateGroup(c_group, c_rank_num, c_array_rank_ids)
    if ret != 0:
        raise RuntimeError('Create group error, the error code is ' + str(ret))


def destroy_group(group):
    """
    A function that destroys a group which was created by the user.

    Note:
        The world group can not be destroyed.

    Args:
        group (str): Name of the group to destroy.

    Raises:
        TypeError: If `group` is not a str.
        ValueError: If `group` is not a valid group name.
        RuntimeError: If the HCCL library is not loaded or reports a failure.

    Returns:
        None
    """
    check_group(group)
    c_group = c_str(group)
    ret = _get_hccl_lib().HcomDestroyGroup(c_group)
    if ret != 0:
        raise RuntimeError('Destroy group error.')


def get_rank_size(group="hccl_world_group"):
    """
    A function that returns the number of ranks within the given collective communication group.

    Note:
        The default group is hccl_world_group.
        In PYNATIVE mode the result of `get_hccl_rank_size()` is returned
        directly and `group` is neither validated nor used.

    Args:
        group (str): Name of the group to query.

    Raises:
        TypeError: If `group` is not a str.
        ValueError: If `group` is not a valid group name.
        RuntimeError: If the HCCL library is not loaded or reports a failure.

    Returns:
        An integer scalar with the num of ranks.
    """
    if context.get_context("mode") == context.PYNATIVE_MODE:
        return get_hccl_rank_size()

    check_group(group)
    return _query_uint('HcomGetRankSize', 'Get rank size error.', c_str(group))


def get_rank_id(group="hccl_world_group"):
    """
    A function that returns the rank id of the calling process, within the given collective communication group.

    Note:
        The default group is hccl_world_group.
        In PYNATIVE mode the result of `get_hccl_rank_id()` is returned
        directly and `group` is neither validated nor used.

    Args:
        group (str): Name of the group to query.

    Raises:
        TypeError: If `group` is not a str.
        ValueError: If `group` is not a valid group name.
        RuntimeError: If the HCCL library is not loaded or reports a failure.

    Returns:
        An integer scalar with the rank id of the calling process.
    """
    if context.get_context("mode") == context.PYNATIVE_MODE:
        return get_hccl_rank_id()

    check_group(group)
    return _query_uint('HcomGetRankId', 'Get rank id error.', c_str(group))


def get_local_rank_size(group="hccl_world_group"):
    """
    A function that returns the number of local ranks within the given collective communication group.

    Note:
        The default group is hccl_world_group.

    Args:
        group (str): Name of the group to query.

    Raises:
        TypeError: If `group` is not a str.
        ValueError: If `group` is not a valid group name.
        RuntimeError: If the HCCL library is not loaded or reports a failure.

    Returns:
        An integer scalar with the num of local ranks.
    """
    check_group(group)
    return _query_uint('HcomGetLocalRankSize', 'Get local rank size error.', c_str(group))


def get_local_rank_id(group="hccl_world_group"):
    """
    Get local rank id.

    A function that returns the local rank id of the calling process, within the given collective
    communication group.

    Note:
        The default group is hccl_world_group.

    Args:
        group (str): Name of the group to query.

    Raises:
        TypeError: If `group` is not a str.
        ValueError: If `group` is not a valid group name.
        RuntimeError: If the HCCL library is not loaded or reports a failure.

    Returns:
        An integer scalar with the local rank id of the calling process.
    """
    check_group(group)
    return _query_uint('HcomGetLocalRankId', 'Get local rank id error.', c_str(group))


def get_world_rank_from_group_rank(group, group_rank_id):
    """
    Get world rank from group rank.

    A function that returns the rank id in the world group corresponding to the
    rank whose id is 'group_rank_id' in the user group.

    Args:
        group (str): Name of the user group.
        group_rank_id (int): Rank id inside `group`, in [0, MAX_RANK_NUM).

    Raises:
        TypeError: If `group` is not a str or `group_rank_id` is not an int.
        ValueError: If `group` is not a valid group name or `group_rank_id`
            is out of range.
        RuntimeError: If the HCCL library is not loaded or reports a failure.

    Returns:
        An integer scalar with the rank id in the world group.
    """
    check_group(group)
    check_rank_id(group_rank_id)
    return _query_uint('HcomGetWorldRankFromGroupRank',
                       'Get world rank from group rank error.',
                       c_str(group), ctypes.c_uint(group_rank_id))


def get_group_rank_from_world_rank(world_rank_id, group):
    """
    Get group rank from world rank.

    A function that returns the rank id in the user group corresponding to the
    rank whose id is 'world_rank_id' in the world group.

    Args:
        world_rank_id (int): Rank id in the world group, in [0, MAX_RANK_NUM).
        group (str): Name of the user group.

    Raises:
        TypeError: If `group` is not a str or `world_rank_id` is not an int.
        ValueError: If `group` is not a valid group name or `world_rank_id`
            is out of range.
        RuntimeError: If the HCCL library is not loaded or reports a failure.

    Returns:
        An integer scalar with the rank id in the user group.
    """
    check_group(group)
    check_rank_id(world_rank_id)
    # The C API takes the world rank first and the group name second.
    return _query_uint('HcomGetGroupRankFromWorldRank',
                       'Get group rank from world rank error.',
                       ctypes.c_uint(world_rank_id), c_str(group))