# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""comm_helper"""

import functools

from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched
from mindspore import log as logger
from ._hccl_management import load_lib as hccl_load_lib

_HCCL_AVAILABLE = False
_NCCL_AVAILABLE = False
_MPI_AVAILABLE = False

# NCCL support is detected by whether the `mindspore._ms_mpi` extension imports;
# the NCCL branches of the helpers below call into this `mpi` module.
try:
    import mindspore._ms_mpi as mpi
    _NCCL_AVAILABLE = True
except ImportError:
    _NCCL_AVAILABLE = False


# HCCL is considered available when its native library loads successfully.
try:
    hccl_load_lib()
    _HCCL_AVAILABLE = True
except RuntimeError:
    _HCCL_AVAILABLE = False

if _HCCL_AVAILABLE:
    from . import _hccl_management as hccl
    # NOTE: on success this rebinds `mpi` to the Ascend MPI extension, replacing
    # the `_ms_mpi` module imported above.
    try:
        import mindspore._ascend_mpi as mpi
        _MPI_AVAILABLE = True
    except ImportError:
        _MPI_AVAILABLE = False
else:
    # Fall back to the `hccl_test` package (presumably a stand-in used when the
    # real HCCL library cannot be loaded, e.g. in tests).
    try:
        import hccl_test.manage.api as hccl
        _HCCL_AVAILABLE = True
    except ImportError:
        _HCCL_AVAILABLE = False


HCCL_WORLD_COMM_GROUP = "hccl_world_group"
NCCL_WORLD_COMM_GROUP = "nccl_world_group"


class Backend:
    """
    Class for available backends.

    Calling ``Backend(name)`` does not create a ``Backend`` instance: it returns
    the matching class-level string constant. The lookup is case-insensitive
    because ``name`` is upper-cased before the attribute lookup
    (e.g. ``Backend("HCCL") == "hccl"``).

    Note:
        The backends' value should be string, e.g., "hccl".
        If backend is set to Backend.UNDEFINED, it will be seen as invalid.

    Args:
        name (str): The name of backend.

    Raises:
        TypeError: If name is not a string.
        ValueError: If backend is invalid.

    Examples:
        >>> Backend("abc")  # raises ValueError: Invalid backend: 'abc'
        >>> hccl = Backend("hccl")
    """
    UNDEFINED = "undefined"
    HCCL = "hccl"
    NCCL = "nccl"
    HCCL_MPI = "hccl_mpi"

    def __new__(cls, name):
        """Create instance object of Backend."""
        if not isinstance(name, str):
            raise TypeError("Backend name must be a string, but got {}".format(type(name)))
        value = getattr(Backend, name.upper(), Backend.UNDEFINED)
        if value == Backend.UNDEFINED:
            raise ValueError("Invalid backend: '{}'".format(name))
        return value


DEFAULT_BACKEND = Backend("hccl")


class GlobalComm:
    """
    World communication information. The GlobalComm is a global class. The members contain: BACKEND, WORLD_COMM_GROUP.
    """
    BACKEND = DEFAULT_BACKEND
    WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
    # Checked by `check_parameter_available`; helpers refuse to run while False.
    INITED = False
    CHECK_ENVS = True


class _ExistingGroup:
    """
    The communication groups which exist in the progress.
    """
    # Maps group name -> the rank_ids it was created with (written by
    # `_create_group_helper`, removed by `_destroy_group_helper`).
    ITEMS = {}


def is_hccl_available():
    """
    Check HCCL api is available.

    Returns:
        Boolean. Return whether HCCL is available or not.
    """
    return _HCCL_AVAILABLE


def is_mpi_available():
    """
    Check HCCL & MPI api is available.

    Returns:
        Boolean. Return whether HCCL & MPI is available or not.
    """
    return _MPI_AVAILABLE


def is_nccl_available():
    """
    Check NCCL api is available.

    Returns:
        Boolean. Return whether NCCL is available or not.
    """
    return _NCCL_AVAILABLE


def check_parameter_available(func):
    """
    Check parameter is available. If not available, raise Error.

    Validation only inspects the keyword arguments ``group`` and ``backend``;
    values passed positionally are forwarded unchecked. Processes running as a
    parameter server or scheduler skip every check.

    If ``group=None`` is passed by keyword, it is replaced with the world group
    of the given backend (``hccl_world_group`` for HCCL/HCCL_MPI,
    ``nccl_world_group`` for NCCL) before ``func`` is called.

    Args:
        func (Function): The function to be run.

    Raises:
        RuntimeError: If communication has not been initialized, or the requested
            backend is not built in.
        TypeError: If ``group`` is neither a str nor None.

    Returns:
        Wrapper. If not available, raise Error.
    """
    @functools.wraps(func)
    def wrapper(*args, **kargs):
        if _is_role_pserver() or _is_role_sched():
            return func(*args, **kargs)
        if not GlobalComm.INITED:
            raise RuntimeError("Distributed Communication has not been inited")

        group = None
        if "group" in kargs:
            group = kargs.get("group")
            if group is not None and not isinstance(group, str):
                raise TypeError("Group should be str or None, "
                                "but got group {}".format(type(group)))

        # Defaults to None (always bound) when backend is not a keyword argument,
        # in which case none of the availability checks below apply.
        backend = kargs.get("backend")
        # `==` rather than `is`: a caller-built string equal to the constant must
        # still be recognized.
        if backend == Backend.HCCL and not is_hccl_available():
            raise RuntimeError("Distributed Communication doesn't have HCCL built in")
        if backend == Backend.HCCL_MPI and not is_mpi_available():
            raise RuntimeError("Distributed Communication doesn't have MPI built in")
        if backend == Backend.NCCL and not is_nccl_available():
            raise RuntimeError("Distributed Communication doesn't have NCCL built in")

        # Resolve an explicit `group=None` to the backend's world group. Only
        # rewrite the keyword when the caller supplied it; injecting it otherwise
        # could collide with a positional `group` argument.
        if group is None and "group" in kargs:
            if backend in (Backend.HCCL, Backend.HCCL_MPI):
                kargs["group"] = HCCL_WORLD_COMM_GROUP
            elif backend == Backend.NCCL:
                kargs["group"] = NCCL_WORLD_COMM_GROUP
        return func(*args, **kargs)
    return wrapper


@check_parameter_available
def _get_rank_helper(group, backend):
    """
    The Helper to do get_rank_id.

    Args:
        group (str): The communication group.
        backend (str): The backend, like "hccl".

    Raises:
        ValueError: If backend is invalid.

    Returns:
        Integer. The rank id of the calling process in ``group``; always 0 on
        parameter-server and scheduler processes.
    """
    rank_id = None
    if _is_role_pserver() or _is_role_sched():
        rank_id = 0
        return rank_id
    if backend == Backend.HCCL_MPI:
        rank_id = mpi.get_rank_id(group)
    elif backend == Backend.HCCL:
        # The world group is addressed through the no-argument HCCL call.
        if group == HCCL_WORLD_COMM_GROUP:
            rank_id = hccl.get_rank_id()
        else:
            rank_id = hccl.get_rank_id(group)
    elif backend == Backend.NCCL:
        rank_id = mpi.get_rank_id(group)
    else:
        raise ValueError("Invalid backend: '{}'".format(backend))
    return rank_id


@check_parameter_available
def _get_local_rank_helper(group, backend):
    """
    The Helper to do get_local_rank_id.

    Args:
        group (str): The communication group.
        backend (str): The backend, like "hccl".

    Raises:
        RuntimeError: If backend is NCCL (not supported).
        ValueError: If backend is invalid.

    Returns:
        Integer. The local rank id of the calling process.
    """
    rank_id = None
    if backend == Backend.HCCL_MPI:
        # NOTE(review): this returns `mpi.get_rank_id(group)`, the same call used
        # by `_get_rank_helper`; confirm a rank id is the intended local rank here.
        rank_id = mpi.get_rank_id(group)
    elif backend == Backend.HCCL:
        if group == HCCL_WORLD_COMM_GROUP:
            rank_id = hccl.get_local_rank_id()
        else:
            rank_id = hccl.get_local_rank_id(group)
    elif backend == Backend.NCCL:
        raise RuntimeError("Nccl doesn't support get_local_rank_id now.")
    else:
        raise ValueError("Invalid backend: '{}'".format(backend))
    return rank_id


@check_parameter_available
def _get_size_helper(group, backend):
    """
    The Helper to do get_rank_size.

    Args:
        group (str): The communication group.
        backend (str): The backend, like "hccl".

    Raises:
        ValueError: If backend is invalid.

    Returns:
        Integer. The rank size of specified group; always 1 on parameter-server
        and scheduler processes.
    """
    size = None
    if _is_role_pserver() or _is_role_sched():
        size = 1
        return size
    if backend == Backend.HCCL_MPI:
        size = mpi.get_rank_size(group)
    elif backend == Backend.HCCL:
        if group == HCCL_WORLD_COMM_GROUP:
            size = hccl.get_rank_size()
        else:
            size = hccl.get_rank_size(group)
    elif backend == Backend.NCCL:
        size = mpi.get_rank_size(group)
    else:
        raise ValueError("Invalid backend: '{}'".format(backend))
    return size


@check_parameter_available
def _get_local_size_helper(group, backend):
    """
    The Helper to do get_local_rank_size.

    Args:
        group (str): The communication group.
        backend (str): The backend, like "hccl".

    Raises:
        RuntimeError: If backend is NCCL (not supported).
        ValueError: If backend is invalid. Note that HCCL_MPI has no branch here
            and is therefore also reported as invalid.

    Returns:
        Integer. The local rank size where the calling process is being within specified group.
    """
    size = None
    if backend == Backend.HCCL:
        if group == HCCL_WORLD_COMM_GROUP:
            size = hccl.get_local_rank_size()
        else:
            size = hccl.get_local_rank_size(group)
    elif backend == Backend.NCCL:
        raise RuntimeError("Nccl doesn't support get_local_rank_size now.")
    else:
        raise ValueError("Invalid backend: '{}'".format(backend))
    return size


@check_parameter_available
def _get_world_rank_from_group_rank_helper(group, group_rank_id, backend):
    """
    The Helper to do get_world_rank_from_group_rank.

    Args:
        group (str): The user communication group.
        group_rank_id (int): A rank id in user communication group.
        backend (str): The backend, like "hccl".

    Raises:
        TypeError: If group_rank_id is not int.
        ValueError: If group is "hccl_world_group" or backend is invalid.
        RuntimeError: If backend is NCCL (not supported).

    Returns:
        Integer. A rank id in world communication group.
    """
    world_rank_id = None
    if not isinstance(group_rank_id, int):
        raise TypeError("group_rank_id should be int, but got type {}".format(type(group_rank_id)))
    if backend == Backend.HCCL:
        if group == HCCL_WORLD_COMM_GROUP:
            raise ValueError("Group cannot be 'hccl_world_group'. ")
        world_rank_id = hccl.get_world_rank_from_group_rank(group, group_rank_id)
    elif backend == Backend.NCCL:
        raise RuntimeError("Nccl doesn't support get_world_rank_from_group_rank now.")
    else:
        raise ValueError("Invalid backend: '{}'".format(backend))
    return world_rank_id


@check_parameter_available
def _get_group_rank_from_world_rank_helper(world_rank_id, group, backend):
    """
    The Helper to do get_group_rank_from_world_rank.

    Args:
        world_rank_id (int): A rank id in world communication group.
        group (str): The user communication group.
        backend (str): The backend, like "hccl".

    Raises:
        TypeError: If world_rank_id is not int.
        ValueError: If group is 'hccl_world_group' or backend is invalid.
        RuntimeError: If backend is NCCL (not supported).

    Returns:
        Integer. A rank id in user communication group.
    """
    group_rank_id = None
    if not isinstance(world_rank_id, int):
        raise TypeError("world_rank_id should be int, but got type {}".format(type(world_rank_id)))
    if backend == Backend.HCCL:
        if group == HCCL_WORLD_COMM_GROUP:
            raise ValueError("Group cannot be 'hccl_world_group'. ")
        group_rank_id = hccl.get_group_rank_from_world_rank(world_rank_id, group)
    elif backend == Backend.NCCL:
        raise RuntimeError("Nccl doesn't support get_group_rank_from_world_rank now.")
    else:
        raise ValueError("Invalid backend: '{}'".format(backend))
    return group_rank_id


@check_parameter_available
def _create_group_helper(group, rank_ids, backend):
    """
    The Helper to do create_group.

    Creating a group name that already exists with the same rank_ids only logs a
    warning; creating it with different rank_ids raises ValueError. On success
    the group is recorded in ``_ExistingGroup.ITEMS``.

    Args:
        group (str): The communication group.
        rank_ids (list): Rank ids in the group.
        backend (str): The backend, like "hccl".

    Raises:
        TypeError: If rank_ids is not a list (HCCL backend).
        ValueError: If the group already exists with different rank_ids, if
            rank_ids is empty or has duplicate data (HCCL backend), or if backend
            is invalid.
        RuntimeError: If backend is NCCL (not supported).
    """
    if group in _ExistingGroup.ITEMS:
        if rank_ids != _ExistingGroup.ITEMS[group]:
            raise ValueError("The group {} has been created, the rank_list is {}, "
                             "but current rank_list for the group is {}".
                             format(group, _ExistingGroup.ITEMS[group], rank_ids))
        logger.warning("%r group has existed.", group)
        return
    if backend == Backend.HCCL:
        if not isinstance(rank_ids, list):
            raise TypeError("Rank_ids {} should be list".format(rank_ids))
        rank_size = len(rank_ids)
        if rank_size < 1:
            raise ValueError("Rank_ids size {} should be large than 0".format(rank_size))
        # A set collapses repeated ids, so a size mismatch means duplicates.
        if len(set(rank_ids)) != rank_size:
            raise ValueError("List rank_ids in Group {} has duplicate data!".format(group))
        hccl.create_group(group, rank_size, rank_ids)
    elif backend == Backend.HCCL_MPI:
        mpi.create_group(group, rank_ids)
    elif backend == Backend.NCCL:
        raise RuntimeError("Nccl doesn't support create_group now.")
    else:
        raise ValueError("Invalid backend: '{}'".format(backend))
    _ExistingGroup.ITEMS[group] = rank_ids


@check_parameter_available
def _destroy_group_helper(group, backend):
    """
    The Helper to do destroy_group.

    After a successful destroy the group is also removed from
    ``_ExistingGroup.ITEMS``. Without this, re-creating a group of the same name
    would be treated as "already exists" and silently skipped.

    Args:
        group (str): The user communication group.
        backend (str): The backend, like "hccl".

    Raises:
        ValueError: If group is "hccl_world_group" or backend is invalid.
        RuntimeError: If backend is NCCL (not supported).
    """
    if backend == Backend.HCCL:
        if group == HCCL_WORLD_COMM_GROUP:
            raise ValueError("The hccl_world_group does not support destruction.")
        hccl.destroy_group(group)
    elif backend == Backend.NCCL:
        raise RuntimeError("Nccl doesn't support destroy_group now.")
    else:
        raise ValueError("Invalid backend: '{}'".format(backend))
    # Only reached when the destroy call above succeeded.
    _ExistingGroup.ITEMS.pop(group, None)