# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""comm_helper"""
16
17import os
18import glob
19import ctypes
20
21import sys
22from sys import excepthook
23
24from mindspore import context
25from mindspore.parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched, _is_ps_mode,\
26                                           _get_ps_context
27from mindspore import log as logger
28from mindspore._c_expression import CollectiveManager, set_cluster_exit_with_exception, MSContext
29from mindspore.common._utils import load_lib

HCCL_LIB = 'libhccl_plugin.so'


def hccl_load_lib():
    """Load the HCCL plugin library."""
    try:
        base_dir = os.path.dirname(os.path.realpath(__file__))
        lib_path = os.path.join(base_dir, "../lib/plugin/ascend", HCCL_LIB)
        ctypes.CDLL(lib_path)
    except Exception as exc:
        raise RuntimeError('Failed to load the HCCL library.') from exc

_HCCL_TEST_AVAILABLE = False

try:
    if MSContext.get_instance().is_ascend_plugin_loaded():
        hccl_load_lib()
except RuntimeError:
    # The real HCCL plugin could not be loaded; try the hccl_test stub instead.
    _HCCL_TEST_AVAILABLE = True

if _HCCL_TEST_AVAILABLE:
    try:
        import hccl_test.manage.api as hccl
    except ImportError:
        _HCCL_TEST_AVAILABLE = False


HCCL_WORLD_COMM_GROUP = "hccl_world_group"
NCCL_WORLD_COMM_GROUP = "nccl_world_group"
MCCL_WORLD_COMM_GROUP = "mccl_world_group"

DEVICE_TO_BACKEND = {
    "Ascend": "hccl",
    "GPU": "nccl",
    "CPU": "mccl"
}


class Backend:
    """
    Class for available backends.

    Note:
        The backends' value should be a string, e.g., "hccl".
        If backend is set to Backend.UNDEFINED, it will be regarded as invalid.

    Args:
        name (str): The name of backend.

    Raises:
        TypeError: If name is not a string.
        ValueError: If backend is invalid.

    Examples:
        >>> Backend("abc")  # raises ValueError because "abc" is not a supported backend
        >>> hccl = Backend("hccl")
    """
    UNDEFINED = "undefined"
    HCCL = "hccl"
    NCCL = "nccl"
    MCCL = "mccl"

    @staticmethod
    def __new__(cls, name):
        """Create instance object of Backend."""
        if not isinstance(name, str):
            raise TypeError("For 'Backend', the class variable 'name' must be a string, "
                            "but got the type : {}".format(type(name)))
        value = getattr(Backend, name.upper(), Backend.UNDEFINED)
        if value == Backend.UNDEFINED:
            raise ValueError("For 'Backend', the class variable 'name' {} is not supported, "
                             "please use hccl/nccl/mccl.".format(name))
        return value


DEFAULT_BACKEND = Backend("hccl")


class GlobalComm:
    """
    World communication information. The GlobalComm is a global class. The members contain:

    - ``BACKEND`` : The communication library used, using ``"hccl"`` / ``"nccl"`` / ``"mccl"`` .
      ``"hccl"`` means Huawei Collective Communication Library (HCCL),
      ``"nccl"`` means NVIDIA Collective Communication Library (NCCL),
      ``"mccl"`` means MindSpore Collective Communication Library (MCCL).
    - ``WORLD_COMM_GROUP`` : Global communication domain,
      using ``"hccl_world_group"`` / ``"nccl_world_group"`` / ``"mccl_world_group"`` .
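
    Examples:
        >>> # Illustrative sketch: these are the module-level defaults; actual
        >>> # values depend on which backend was used to initialize communication.
        >>> GlobalComm.BACKEND
        'hccl'
        >>> GlobalComm.WORLD_COMM_GROUP
        'hccl_world_group'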
    """
    BACKEND = DEFAULT_BACKEND
    WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
    INITED = False
    CHECK_ENVS = True


class _ExistingGroup:
    """
    The communication groups which exist in the process.
    """
    ITEMS = {}


def _hccl_test():
    """Whether the hccl_test stub is in use as the hccl backend."""
    return _HCCL_TEST_AVAILABLE and GlobalComm.BACKEND == Backend.HCCL


def _check_mpi_envs():
    """
    Check whether MPI environment variables have been exported.

    Return True if MPI environment variables have been exported, False otherwise.
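
    Example (illustrative sketch, not part of the original module):
        >>> # OMPI_COMMAND and PMIX_RANK are exported by OpenMPI launchers;
        >>> # the values below are hypothetical.
        >>> os.environ["OMPI_COMMAND"] = "mpirun"
        >>> os.environ["PMIX_RANK"] = "0"
        >>> _check_mpi_envs()
        True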
    """
    ompi_command_env = os.getenv("OMPI_COMMAND")
    pmix_rank_env = os.getenv("PMIX_RANK")
    if ompi_command_env and pmix_rank_env:
        return True
    return False


def _check_bypass_rank_id_and_size():
    """
    Whether to bypass calling the C++ API to get the rank id and size, and use a fake
    rank id 0 and rank size 1 instead. This returns True when this process is a
    Scheduler node, or a Server node in the old Parameter Server training mode.
    """
    if _is_role_sched():
        return True
    device_target = context.get_context("device_target")
    if _is_ps_mode() and _get_ps_context("worker_num") == 1 and device_target == "Ascend":
        return True
    return False


def _set_elegant_exit_handle():
    """Hook sys.excepthook so an abnormal cluster exit is flagged before the original handler runs."""
    if _is_role_worker() or _is_role_pserver() or _is_role_sched():
        sys.excepthook = lambda *args: (set_cluster_exit_with_exception(), excepthook(*args))


def check_parameter_available(func):
    """
    Check whether the group parameter is available. If not, raise an error.

    Args:
        func (Function): The function to be wrapped.

    Raises:
        RuntimeError: If distributed communication has not been initialized.
        TypeError: If the 'group' keyword argument is neither None nor a string.

    Returns:
        Function. The wrapped function.
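
    Example (illustrative sketch, not part of the original module):
        >>> # `_my_helper` is a hypothetical function used only to show the
        >>> # wrapper's contract.
        >>> @check_parameter_available
        ... def _my_helper(group=None):
        ...     return group
        >>> # Before initialization, _my_helper() raises RuntimeError; a
        >>> # non-string group, e.g. _my_helper(group=1), raises TypeError.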
    """
    def wrapper(*args, **kargs):
        if not GlobalComm.INITED:
            raise RuntimeError("Distributed communication has not been initialized.")
        group = None
        if "group" in kargs.keys():
            group = kargs.get("group")
            if group is not None and not isinstance(group, str):
                raise TypeError("The parameter 'group' should be str or None, "
                                "but got the type : {}".format(type(group)))
        if group is None:
            group = GlobalComm.WORLD_COMM_GROUP
        return func(*args, **kargs)
    return wrapper


def _is_available():
    """
    Returns `True` if the distributed module is available.

    Note:
        Always returns `True` because MindSpore always has distributed ability on all platforms.
    """
    return True


def _is_initialized():
    """
    Checks if the distributed module is successfully initialized.
    """
    return CollectiveManager.get_instance().initialized()


def _get_backend():
    """
    Returns the backend of communication process groups.

    Note:
        Only one communication backend is supported by MindSpore for each process.
        It should be one of `hccl`/`nccl`/`mccl`.
    """
    return GlobalComm.BACKEND


def _is_hccl_available():
    """
    Checks if the `hccl` backend is available.

    Note:
        In this module this reflects whether the hccl_test stub was imported
        successfully after the real HCCL plugin failed to load.
    """
    return _HCCL_TEST_AVAILABLE


def _is_nccl_available():
    """
    Checks if the `nccl` backend is available.
    """
    base_dir = os.path.dirname(os.path.realpath(__file__))
    lib_path = os.path.join(base_dir, "../lib/plugin/gpu*/libnvidia_collective.so")
    file_paths = glob.glob(lib_path)
    # An empty glob means no GPU plugin is shipped, so NCCL is not available.
    return bool(file_paths) and all(load_lib(f) for f in file_paths)


def _is_mpi_available():
    """
    Checks if OpenMPI's library is available.
    """
    base_dir = os.path.dirname(os.path.realpath(__file__))
    lib_path = os.path.join(base_dir, "../lib/libmpi_collective.so")
    return load_lib(lib_path)


@check_parameter_available
def _get_rank_helper(group):
    """
    The helper to do get_rank_id.

    Args:
        group (str): The communication group.

    Raises:
        RuntimeError: If distributed communication has not been initialized.

    Returns:
        Integer. The rank id of the calling process within the specified group.
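
    Example (illustrative sketch, not part of the original module):
        >>> # Assumes the communication layer is initialized across 8 processes
        >>> # and this process is global rank 2.
        >>> _get_rank_helper(group=GlobalComm.WORLD_COMM_GROUP)
        2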
    """
    if _check_bypass_rank_id_and_size():
        rank_id = 0
        return rank_id
    if _hccl_test():
        return hccl.get_rank_id(group)
    rank_id = CollectiveManager.get_instance().get_rank_id(group)
    return rank_id


@check_parameter_available
def _get_local_rank_helper(group):
    """
    The helper to do get_local_rank_id.

    Args:
        group (str): The communication group.

    Raises:
        RuntimeError: If distributed communication has not been initialized.

    Returns:
        Integer. The local rank id of the calling process.
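
    Example (illustrative sketch, not part of the original module):
        >>> # Assumes 8 global ranks over two 4-device nodes, with this process
        >>> # being global rank 6, i.e. the third device on its node.
        >>> _get_local_rank_helper(group=GlobalComm.WORLD_COMM_GROUP)
        2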
    """
    if _check_bypass_rank_id_and_size():
        local_rank_id = 0
        return local_rank_id
    if _hccl_test():
        return hccl.get_local_rank_id(group)
    rank_id = CollectiveManager.get_instance().get_local_rank_id(group)
    return rank_id


@check_parameter_available
def _get_size_helper(group):
    """
    The helper to do get_rank_size.

    Args:
        group (str): The communication group.

    Raises:
        RuntimeError: If distributed communication has not been initialized.

    Returns:
        Integer. The rank size of the specified group.
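
    Example (illustrative sketch, not part of the original module):
        >>> # Assumes the communication layer is initialized with 8 processes.
        >>> _get_size_helper(group=GlobalComm.WORLD_COMM_GROUP)
        8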
    """
    if _check_bypass_rank_id_and_size():
        size = 1
        return size
    if _hccl_test():
        return hccl.get_rank_size(group)
    size = CollectiveManager.get_instance().get_group_size(group)
    return size


@check_parameter_available
def _get_local_size_helper(group):
    """
    The helper to do get_local_rank_size.

    Args:
        group (str): The communication group.

    Raises:
        RuntimeError: If distributed communication has not been initialized.

    Returns:
        Integer. The local rank size of the node on which the calling process runs,
        within the specified group.
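
    Example (illustrative sketch, not part of the original module):
        >>> # Assumes two 4-device nodes, so each node hosts 4 local ranks.
        >>> _get_local_size_helper(group=GlobalComm.WORLD_COMM_GROUP)
        4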
    """
    size = CollectiveManager.get_instance().get_local_group_size(group)
    return size


@check_parameter_available
def _get_world_rank_from_group_rank_helper(group, group_rank_id):
    """
    The helper to do get_world_rank_from_group_rank.

    Args:
        group (str): The user communication group.
        group_rank_id (int): A rank id in the user communication group.

    Raises:
        TypeError: If group_rank_id is not an int.
        ValueError: If group is "hccl_world_group".

    Returns:
        Integer. A rank id in the world communication group.
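
    Example (illustrative sketch, not part of the original module):
        >>> # Assumes a group "group0" built from world ranks [4, 5, 6, 7].
        >>> _get_world_rank_from_group_rank_helper("group0", 1)
        5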
    """
    if not isinstance(group_rank_id, int):
        raise TypeError("For 'get_world_rank_from_group_rank', the argument 'group_rank_id' must be"
                        " type of int, but got 'group_rank_id' type : {}.".format(type(group_rank_id)))
    if _hccl_test():
        return hccl.get_world_rank_from_group_rank(group, group_rank_id)
    world_rank_id = CollectiveManager.get_instance().get_world_rank_from_group_rank(group, group_rank_id)
    return world_rank_id


@check_parameter_available
def _get_group_rank_from_world_rank_helper(world_rank_id, group):
    """
    The helper to do get_group_rank_from_world_rank.

    Args:
        world_rank_id (int): A rank id in the world communication group.
        group (str): The user communication group.

    Raises:
        TypeError: If world_rank_id is not an int.
        ValueError: If group is 'hccl_world_group'.

    Returns:
        Integer. A rank id in the user communication group.
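
    Example (illustrative sketch, not part of the original module):
        >>> # Assumes a group "group0" built from world ranks [4, 5, 6, 7].
        >>> _get_group_rank_from_world_rank_helper(6, "group0")
        2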
    """
    if not isinstance(world_rank_id, int):
        raise TypeError("For 'get_group_rank_from_world_rank', the argument 'world_rank_id' must be type of int, "
                        "but got 'world_rank_id' type : {}.".format(type(world_rank_id)))
    if _hccl_test():
        return hccl.get_group_rank_from_world_rank(world_rank_id, group)
    group_rank_id = CollectiveManager.get_instance().get_group_rank_from_world_rank(world_rank_id, group)
    return group_rank_id


@check_parameter_available
def _get_group_ranks(group):
    """
    The helper to do get_group_ranks.

    Args:
        group (str): The communication group.

    Returns:
        List. The ranks of the specified group.
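
    Example (illustrative sketch, not part of the original module):
        >>> # Assumes a group "group0" built from world ranks [4, 5, 6, 7].
        >>> _get_group_ranks("group0")
        [4, 5, 6, 7]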
    """
    return CollectiveManager.get_instance().get_group_ranks(group)


@check_parameter_available
def _create_group_helper(group, rank_ids):
    """
    The helper to do create_group.

    Args:
        group (str): The communication group.
        rank_ids (list): Rank ids in the group.

    Raises:
        TypeError: If rank_ids is not a list.
        ValueError: If rank_ids is empty, rank_ids contains duplicate entries,
            or the group has already been created with a different rank list.
        RuntimeError: If the backend fails to create the group.
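
    Example (illustrative sketch, not part of the original module):
        >>> # Assumes 8 world ranks are available; "group0" is a hypothetical name.
        >>> _create_group_helper("group0", [4, 5, 6, 7])
        >>> _get_group_ranks("group0")
        [4, 5, 6, 7]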
    """
    if group in _ExistingGroup.ITEMS.keys():
        if rank_ids != _ExistingGroup.ITEMS.get(group):
            raise ValueError("The group {} has already been created with rank list {}, "
                             "but the current rank list for the group is {}".
                             format(group, _ExistingGroup.ITEMS[group], rank_ids))
        logger.warning("%r group already exists.", group)
        return
    if not isinstance(rank_ids, list):
        raise TypeError("For 'create_group', the argument 'rank_ids' must be type of list, "
                        "but got 'rank_ids' type : {}.".format(type(rank_ids)))
    rank_size = len(rank_ids)
    if rank_size < 1:
        raise ValueError("For 'create_group', the argument 'rank_ids' must not be empty, "
                         "but got 'rank_ids' size : {}.".format(len(rank_ids)))
    if len(rank_ids) != len(set(rank_ids)):
        raise ValueError("The rank_ids list for group {} contains duplicate entries!".format(group))
    if _hccl_test():
        hccl.create_group(group, rank_size, rank_ids)
    else:
        result = CollectiveManager.get_instance().create_group(group, rank_ids)
        if not result:
            raise RuntimeError("Failed to create communication group for {} with rank ids {}. "
                               "If NCCL is used, 'export NCCL_DEBUG=INFO' "
                               "is suggested before launching jobs.".format(group, rank_ids))

    _ExistingGroup.ITEMS[group] = rank_ids


@check_parameter_available
def _destroy_group_helper(group):
    """
    The helper to do destroy_group.

    Args:
        group (str): The user communication group.

    Raises:
        ValueError: If group is the world communication group.
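
    Example (illustrative sketch, not part of the original module):
        >>> # Assumes "group0" was created earlier via _create_group_helper.
        >>> _destroy_group_helper("group0")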
    """
    if group == GlobalComm.WORLD_COMM_GROUP:
        raise ValueError("The world_group does not support destruction.")
    if _hccl_test():
        hccl.destroy_group(group)
    else:
        CollectiveManager.get_instance().destroy_group(group)