# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""comm_helper"""

import os
import glob
import ctypes

import sys
# Keep a reference to the original excepthook so _set_elegant_exit_handle can chain it.
from sys import excepthook

from mindspore import context
from mindspore.parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched, _is_ps_mode,\
    _get_ps_context
from mindspore import log as logger
from mindspore._c_expression import CollectiveManager, set_cluster_exit_with_exception, MSContext
from mindspore.common._utils import load_lib

HCCL_LIB = 'libhccl_plugin.so'


def hccl_load_lib():
    """Load the HCCL plugin library."""
    try:
        base_dir = os.path.dirname(os.path.realpath(__file__))
        lib_path = os.path.join(base_dir, "../lib/plugin/ascend", HCCL_LIB)
        ctypes.CDLL(lib_path)
    except Exception as exc:
        raise RuntimeError('Get hccl lib error.') from exc


_HCCL_TEST_AVAILABLE = False

try:
    if MSContext.get_instance().is_ascend_plugin_loaded():
        hccl_load_lib()
except RuntimeError:
    # The real HCCL plugin could not be loaded; fall back to the HCCL test stub
    # if it is importable.
    _HCCL_TEST_AVAILABLE = True

if _HCCL_TEST_AVAILABLE:
    try:
        import hccl_test.manage.api as hccl
    except ImportError:
        _HCCL_TEST_AVAILABLE = False


HCCL_WORLD_COMM_GROUP = "hccl_world_group"
NCCL_WORLD_COMM_GROUP = "nccl_world_group"
MCCL_WORLD_COMM_GROUP = "mccl_world_group"

DEVICE_TO_BACKEND = {
    "Ascend": "hccl",
    "GPU": "nccl",
    "CPU": "mccl"
}


class Backend:
    """
    Class for available backends.

    Note:
        The backends' value should be a string, e.g., "hccl".
        If backend is set to Backend.UNDEFINED, it will be seen as invalid.

    Args:
        name (str): The name of backend.

    Raises:
        TypeError: If name is not a string.
        ValueError: If backend is invalid.

    Examples:
        >>> hccl = Backend("hccl")
    """
    UNDEFINED = "undefined"
    HCCL = "hccl"
    NCCL = "nccl"
    MCCL = "mccl"

    @staticmethod
    def __new__(cls, name):
        """Create instance object of Backend."""
        if not isinstance(name, str):
            raise TypeError("For 'Backend', the class variable 'name' must be a string, "
                            "but got the type : {}".format(type(name)))
        value = getattr(Backend, name.upper(), Backend.UNDEFINED)
        if value == Backend.UNDEFINED:
            raise ValueError("For 'Backend', the class variable 'name' {} is not supported, "
                             "please use hccl, nccl or mccl.".format(name))
        return value


DEFAULT_BACKEND = Backend("hccl")
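
# A minimal illustrative sketch (comment only, not executed): Backend.__new__
# normalizes the name and returns the matching string constant, so instances
# compare equal to plain strings, while unknown names raise ValueError.
#
#     Backend("HCCL") == "hccl"        # True: lookup is case-insensitive
#     Backend("nccl") == Backend.NCCL  # True: the instance is the constant itself
#     Backend("abc")                   # raises ValueError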


class GlobalComm:
    """
    World communication information. The GlobalComm is a global class. The members contain:

    - ``BACKEND`` : The communication library used, using ``"hccl"`` / ``"nccl"`` / ``"mccl"`` .
      ``"hccl"`` means Huawei Collective Communication Library(HCCL),
      ``"nccl"`` means NVIDIA Collective Communication Library(NCCL),
      ``"mccl"`` means MindSpore Collective Communication Library(MCCL).
    - ``WORLD_COMM_GROUP`` : Global communication domain,
      using ``"hccl_world_group"`` / ``"nccl_world_group"`` / ``"mccl_world_group"`` .
    """
    BACKEND = DEFAULT_BACKEND
    WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
    INITED = False
    CHECK_ENVS = True


class _ExistingGroup:
    """
    The communication groups which exist in the process.
    """
    ITEMS = {}


def _hccl_test():
    return _HCCL_TEST_AVAILABLE and GlobalComm.BACKEND == Backend.HCCL


def _check_mpi_envs():
    """
    Check whether MPI environment variables have been exported or not.

    Returns True if MPI environment variables have been exported, False otherwise.
    """
    ompi_command_env = os.getenv("OMPI_COMMAND")
    pmix_rank_env = os.getenv("PMIX_RANK")
    if ompi_command_env and pmix_rank_env:
        return True
    return False


def _check_bypass_rank_id_and_size():
    """
    Whether to bypass calling the C++ API to get rank id and size and, instead, use fake rank id 0 and rank size 1.
    This returns True when this process is the Scheduler node, or is a Server node in the old Parameter Server
    training mode.
    """
    if _is_role_sched():
        return True
    device_target = context.get_context("device_target")
    if _is_ps_mode() and _get_ps_context("worker_num") == 1 and device_target == "Ascend":
        return True
    return False


def _set_elegant_exit_handle():
    # Chain the original excepthook so the process flags the cluster with the
    # exception before the default exception handling runs.
    if _is_role_worker() or _is_role_pserver() or _is_role_sched():
        sys.excepthook = lambda *args: (set_cluster_exit_with_exception(), excepthook(*args))


def check_parameter_available(func):
    """
    Check whether the parameters are available. If not, raise an error.

    Args:
        func (Function): The function to be wrapped.

    Raises:
        RuntimeError: If distributed communication has not been initialized.
        TypeError: If 'group' is neither str nor None.

    Returns:
        Function. The wrapped function.
    """
    def wrapper(*args, **kargs):
        if not GlobalComm.INITED:
            raise RuntimeError("Distributed Communication has not been inited")
        group = None
        if "group" in kargs.keys():
            group = kargs.get("group")
        if group is not None and not isinstance(group, str):
            raise TypeError("The parameter 'group' should be str or None, "
                            "but got the type : {}".format(type(group)))
        if group is None:
            group = GlobalComm.WORLD_COMM_GROUP
        return func(*args, **kargs)
    return wrapper


def _is_available():
    """
    Returns `True` if distributed module is available.

    Note:
        Always returns `True` because MindSpore always has distributed ability on all platforms.
    """
    return True


def _is_initialized():
    """
    Checks if distributed module is successfully initialized.
    """
    return CollectiveManager.get_instance().initialized()


def _get_backend():
    """
    Returns the backend of communication process groups.

    Note:
        Only one communication backend is supported by MindSpore for each process.
        It should be one of `hccl`/`nccl`/`mccl`.
    """
    return GlobalComm.BACKEND


def _is_hccl_available():
    """
    Checks if `hccl` backend is available.
    """
    return _HCCL_TEST_AVAILABLE
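
# An illustrative sketch (comment only) of how callers typically combine these
# queries; `init` from mindspore.communication is assumed to be what sets
# GlobalComm.INITED before any decorated helper below may be called:
#
#     if _is_available() and _is_initialized():
#         backend = _get_backend()  # one of "hccl" / "nccl" / "mccl"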
233 """ 234 base_dir = os.path.dirname(os.path.realpath(__file__)) 235 lib_path = os.path.join(base_dir, "../lib/plugin/gpu*/libnvidia_collective.so") 236 file_paths = glob.glob(lib_path) 237 return all(list(load_lib(f) for f in file_paths)) 238 239 240def _is_mpi_available(): 241 """ 242 Checks if OpenMPI's library is available. 243 """ 244 base_dir = os.path.dirname(os.path.realpath(__file__)) 245 lib_path = os.path.join(base_dir, "../lib/libmpi_collective.so") 246 return load_lib(lib_path) 247 248 249@check_parameter_available 250def _get_rank_helper(group): 251 """ 252 The Helper to do get_rank_id. 253 254 Args: 255 group (str): The communication group. 256 backend (str): The backend, like "hccl". 257 258 Raises: 259 ValueError: If backend is invalid. 260 261 Returns: 262 Integer. The local rank id of the calling process. 263 """ 264 if _check_bypass_rank_id_and_size(): 265 rank_id = 0 266 return rank_id 267 if _hccl_test(): 268 return hccl.get_rank_id(group) 269 rank_id = CollectiveManager.get_instance().get_rank_id(group) 270 return rank_id 271 272 273@check_parameter_available 274def _get_local_rank_helper(group): 275 """ 276 The Helper to do get_local_rank_id. 277 278 Args: 279 group (str): The communication group. 280 backend (str): The backend, like "hccl". 281 282 Raises: 283 ValueError: If backend is invalid. 284 285 Returns: 286 Integer. The local rank id of the calling process. 287 """ 288 if _check_bypass_rank_id_and_size(): 289 local_rank_id = 0 290 return local_rank_id 291 if _hccl_test(): 292 return hccl.get_local_rank_id(group) 293 rank_id = CollectiveManager.get_instance().get_local_rank_id(group) 294 return rank_id 295 296 297@check_parameter_available 298def _get_size_helper(group): 299 """ 300 The Helper to do get_rank_size. 301 302 Args: 303 group (str): The communication group. 304 backend (str): The backend, like "hccl". 305 306 Raises: 307 ValueError: If backend is invalid. 308 309 Returns: 310 Integer. The rank size of specified group. 311 """ 312 if _check_bypass_rank_id_and_size(): 313 size = 1 314 return size 315 if _hccl_test(): 316 return hccl.get_rank_size(group) 317 size = CollectiveManager.get_instance().get_group_size(group) 318 return size 319 320 321@check_parameter_available 322def _get_local_size_helper(group): 323 """ 324 The Helper to do get_local_rank_size. 325 326 Args: 327 group (str): The communication group. 328 backend (str): The backend, like "hccl". 329 330 Raises: 331 ValueError: If backend is invalid. 332 333 Returns: 334 Integer. The local rank size where the calling process is being within specified group. 335 """ 336 size = CollectiveManager.get_instance().get_local_group_size(group) 337 return size 338 339 340@check_parameter_available 341def _get_world_rank_from_group_rank_helper(group, group_rank_id): 342 """ 343 The Helper to do get_world_rank_from_group_rank. 344 345 Args: 346 group (str): The user communication group. 347 group_rank_id (int): A rank id in user communication group. 348 backend (str): The backend, like "hccl". 349 350 Raises: 351 TypeError: If group_rank_id is not int. 352 ValueError: If group is "hccl_world_group" or backend is invalid. 353 354 Returns: 355 Integer. A rank id in world communication group. 
356 """ 357 if not isinstance(group_rank_id, int): 358 raise TypeError("For 'get_world_rank_from_group_rank', the argument 'group_rank_id' must be" 359 " type of int, but got 'group_rank_id' type : {}.".format(type(group_rank_id))) 360 if _hccl_test(): 361 return hccl.get_world_rank_from_group_rank(group, group_rank_id) 362 world_rank_id = CollectiveManager.get_instance().get_world_rank_from_group_rank(group, group_rank_id) 363 return world_rank_id 364 365 366@check_parameter_available 367def _get_group_rank_from_world_rank_helper(world_rank_id, group): 368 """ 369 The Helper to do get_group_rank_from_world_rank. 370 371 Args: 372 world_rank_id (int): A rank id in world communication group. 373 group (str): The user communication group. 374 backend (str): The backend, like "hccl". 375 376 Raises: 377 TypeError: If world_rank_id is not int. 378 ValueError: If group is 'hccl_world_group' or backend is invalid. 379 380 Returns: 381 Integer. A rank id in user communication group. 382 """ 383 group_rank_id = None 384 if not isinstance(world_rank_id, int): 385 raise TypeError("For 'get_group_rank_from_world_rank', the argument 'world_rank_id' must be type of int, " 386 "but got 'world_rank_id' type : {}.".format(type(world_rank_id))) 387 if _hccl_test(): 388 return hccl.get_group_rank_from_world_rank(world_rank_id, group) 389 group_rank_id = CollectiveManager.get_instance().get_group_rank_from_world_rank(world_rank_id, group) 390 return group_rank_id 391 392 393@check_parameter_available 394def _get_group_ranks(group): 395 """ 396 The Helper to do get_group_ranks. 397 398 Args: 399 group (str): The communication group. 400 401 Returns: 402 List. The ranks of specified group. 403 """ 404 return CollectiveManager.get_instance().get_group_ranks(group) 405 406 407@check_parameter_available 408def _create_group_helper(group, rank_ids): 409 """ 410 The Helper to do create_group. 411 412 Args: 413 group (str): The communication group. 414 rank_ids (list): Rank ids in the group. 415 backend (str): The backend, like "hccl". 416 417 Raises: 418 TypeError: If rank_ids is not a list. 419 ValueError: If rank_ids size is not larger than 1 or rank_ids has duplicate data or backend is invalid. 420 """ 421 if group in _ExistingGroup.ITEMS.keys(): 422 if rank_ids != _ExistingGroup.ITEMS.get(group): 423 raise ValueError("The group {} has been created, the rank_list is {}, " 424 "but current rank_list for the group is {}". 425 format(group, _ExistingGroup.ITEMS[group], rank_ids)) 426 logger.warning("%r group has existed.", group) 427 return 428 if not isinstance(rank_ids, list): 429 raise TypeError("For 'create_group', the argument 'rank_ids' must be type of list, " 430 "but got 'rank_ids' type : {}.".format(type(rank_ids))) 431 rank_size = len(rank_ids) 432 if rank_size < 1: 433 raise ValueError("For 'create_group', the argument 'rank_ids' size should be greater than 1, " 434 "but got 'rank_ids' size : {}.".format(len(rank_ids))) 435 if len(rank_ids) - len(list(set(rank_ids))) > 0: 436 raise ValueError("List rank_ids in Group {} has duplicate data!".format(group)) 437 if _hccl_test(): 438 hccl.create_group(group, rank_size, rank_ids) 439 else: 440 result = CollectiveManager.get_instance().create_group(group, rank_ids) 441 if not result: 442 raise RuntimeError("Failed to create communication group for {} with rank ids {}. 
" 443 "If NCCL is used, 'export NCCL_DEBUG=INFO' " 444 "is suggested before launching jobs.".format(group, rank_ids)) 445 446 _ExistingGroup.ITEMS[group] = rank_ids 447 448 449@check_parameter_available 450def _destroy_group_helper(group): 451 """ 452 The Helper to do destroy_group. 453 454 Args: 455 group (str): The user communication group. 456 backend (str): The backend, like "hccl". 457 458 Raises: 459 ValueError: If group is "hccl_world_group" or backend is invalid. 460 """ 461 if group == GlobalComm.WORLD_COMM_GROUP: 462 raise ValueError("The world_group does not support destruction.") 463 if _hccl_test(): 464 hccl.create_group(group) 465 else: 466 CollectiveManager.get_instance().destroy_group(group) 467