• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
"""comm_helper: helpers for querying ranks/sizes and creating/destroying communication groups on the HCCL, HCCL_MPI and NCCL backends."""
16
17from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched
18from mindspore import log as logger
19from ._hccl_management import load_lib as hccl_load_lib
20
# Probe, at import time, which collective-communication backends this build provides.
# The resulting flags are returned by is_hccl_available / is_mpi_available / is_nccl_available.
_HCCL_AVAILABLE = False
_NCCL_AVAILABLE = False
_MPI_AVAILABLE = False
# NCCL support is inferred from whether the ``mindspore._ms_mpi`` extension imports; it is
# bound to the name ``mpi``, which the NCCL code paths below call.
try:
    import mindspore._ms_mpi as mpi
    _NCCL_AVAILABLE = True
except ImportError:
    _NCCL_AVAILABLE = False


# HCCL counts as available when loading its library does not raise RuntimeError.
try:
    hccl_load_lib()
    _HCCL_AVAILABLE = True
except RuntimeError:
    _HCCL_AVAILABLE = False

if _HCCL_AVAILABLE:
    from . import _hccl_management as hccl
    # NOTE(review): this rebinds ``mpi`` to the Ascend MPI module, replacing ``_ms_mpi`` if that
    # was imported above; the NCCL helpers then also call this module — confirm that is intended.
    try:
        import mindspore._ascend_mpi as mpi
        _MPI_AVAILABLE = True
    except ImportError:
        _MPI_AVAILABLE = False
else:
    # Fallback to the ``hccl_test`` API (presumably a test/simulation stub — TODO confirm);
    # if it imports, HCCL is reported as available.
    # NOTE(review): if neither this nor the branch above binds ``hccl`` (and no ``mpi`` module
    # imported), later helper calls will fail with NameError rather than a clear error.
    try:
        import hccl_test.manage.api as hccl
        _HCCL_AVAILABLE = True
    except ImportError:
        _HCCL_AVAILABLE = False
50
51
# Name of the world (all-ranks) communication group on the HCCL side.
HCCL_WORLD_COMM_GROUP = "hccl_world_group"
# Name of the world (all-ranks) communication group on the NCCL side.
NCCL_WORLD_COMM_GROUP = "nccl_world_group"
54
55
class Backend:
    """
    Class for available backends.

    Calling ``Backend(name)`` does not build an instance: it validates ``name``
    case-insensitively and returns the matching backend string constant.

    Note:
        The backends' value should be string, e.g., "hccl".
        If backend is set to Backend.UNDEFINED, it will be seen as invalid.

    Args:
        name (str): The name of backend.

    Raises:
        TypeError: If name is not a string.
        ValueError: If backend is invalid (including "undefined").

    Examples:
        >>> Backend("hccl")
        'hccl'
        >>> Backend("HCCL_MPI")
        'hccl_mpi'
    """
    UNDEFINED = "undefined"
    HCCL = "hccl"
    NCCL = "nccl"
    HCCL_MPI = "hccl_mpi"

    def __new__(cls, name):
        """Resolve ``name`` to one of the backend string constants of this class."""
        if not isinstance(name, str):
            raise TypeError("Backend name must be a string, but got {}".format(type(name)))
        # The constant attribute names are the upper-cased backend names.
        resolved = getattr(Backend, name.upper(), Backend.UNDEFINED)
        if resolved != Backend.UNDEFINED:
            return resolved
        raise ValueError("Invalid backend: '{}'".format(name))
88
# The backend GlobalComm starts with: HCCL.
DEFAULT_BACKEND = Backend(Backend.HCCL)
90
91
class GlobalComm:
    """
    World communication information shared across this module (class-level globals).

    Attributes:
        BACKEND: Active backend name; initially ``DEFAULT_BACKEND``.
        WORLD_COMM_GROUP: Name of the world group; initially ``HCCL_WORLD_COMM_GROUP``.
        INITED: Whether distributed communication has been initialised;
            ``check_parameter_available`` refuses to run helpers while this is False.
        CHECK_ENVS: Flag initially True. It is not read in this module; presumably it
            toggles environment checks during initialisation (TODO confirm).
    """
    BACKEND = DEFAULT_BACKEND  # default backend until changed elsewhere
    WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP  # world group matching the default backend
    INITED = False  # set once communication is initialised
    CHECK_ENVS = True
100
101
102class _ExistingGroup:
103    """
104    The communication groups which exist in the progress.
105    """
106    ITEMS = {}
107
108
def is_hccl_available():
    """
    Report whether the HCCL API was found when this module was imported.

    Returns:
        Boolean. True if HCCL is usable, False otherwise.
    """
    hccl_found = _HCCL_AVAILABLE
    return hccl_found
117
118
def is_mpi_available():
    """
    Report whether both HCCL and its MPI companion module were found at import time.

    Returns:
        Boolean. True if the HCCL & MPI API is usable, False otherwise.
    """
    mpi_found = _MPI_AVAILABLE
    return mpi_found
127
128
def is_nccl_available():
    """
    Report whether the NCCL API was found when this module was imported.

    Returns:
        Boolean. True if NCCL is usable, False otherwise.
    """
    nccl_found = _NCCL_AVAILABLE
    return nccl_found
137
138
def check_parameter_available(func):
    """
    Decorator that validates a communication helper's keyword arguments before calling it.

    For parameter-server and scheduler roles the wrapped function is called directly,
    without checks. Otherwise:

    * distributed communication must be initialised (``GlobalComm.INITED``);
    * a ``group`` keyword argument must be a str or None;
    * a ``backend`` keyword argument must name a backend built into this process;
    * an explicit ``group=None`` keyword argument is replaced by the backend's world
      group (``HCCL_WORLD_COMM_GROUP`` for HCCL/HCCL_MPI, ``NCCL_WORLD_COMM_GROUP`` for NCCL).

    Only keyword arguments are inspected; positional arguments pass through unchanged.

    Args:
        func (Function): The function to be run.

    Raises:
        RuntimeError: If communication is not initialised or the requested backend is not built in.
        TypeError: If ``group`` is neither a str nor None.

    Returns:
        Wrapper. The validating wrapper around ``func``.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kargs):
        if _is_role_pserver() or _is_role_sched():
            return func(*args, **kargs)
        if not GlobalComm.INITED:
            raise RuntimeError("Distributed Communication has not been inited")

        # Both default to None when not passed by keyword, so every check below is safe
        # regardless of how the helper was called.
        group = kargs.get("group")
        if group is not None and not isinstance(group, str):
            raise TypeError("Group should be str or None, "
                            "but got group {}".format(type(group)))

        backend = kargs.get("backend")
        # Equality (not identity) so any equal backend string is validated.
        if backend == Backend.HCCL and not is_hccl_available():
            raise RuntimeError("Distributed Communication doesn't have HCCL built in")
        if backend == Backend.HCCL_MPI and not is_mpi_available():
            raise RuntimeError("Distributed Communication doesn't have MPI built in")
        if backend == Backend.NCCL and not is_nccl_available():
            raise RuntimeError("Distributed Communication doesn't have NCCL built in")

        # Substitute the world group for an explicit ``group=None``. Only done for the
        # keyword form: adding ``group`` when it was passed positionally would raise
        # "got multiple values for argument".
        if "group" in kargs and group is None:
            if backend in (Backend.HCCL, Backend.HCCL_MPI):
                kargs["group"] = HCCL_WORLD_COMM_GROUP
            elif backend == Backend.NCCL:
                kargs["group"] = NCCL_WORLD_COMM_GROUP
        return func(*args, **kargs)
    return wrapper
180
181
@check_parameter_available
def _get_rank_helper(group, backend):
    """
    The Helper to do get_rank_id.

    Args:
        group (str): The communication group.
        backend (str): The backend, like "hccl".

    Raises:
        ValueError: If backend is invalid.

    Returns:
        Integer. The rank id of the calling process in ``group``; always 0 for
        parameter-server and scheduler roles.
    """
    if _is_role_pserver() or _is_role_sched():
        return 0
    # HCCL_MPI and NCCL both query the ``mpi`` module.
    if backend in (Backend.HCCL_MPI, Backend.NCCL):
        return mpi.get_rank_id(group)
    if backend == Backend.HCCL:
        # The world group is queried through hccl's default argument.
        if group == HCCL_WORLD_COMM_GROUP:
            return hccl.get_rank_id()
        return hccl.get_rank_id(group)
    raise ValueError("Invalid backend: '{}'".format(backend))
213
214
@check_parameter_available
def _get_local_rank_helper(group, backend):
    """
    The Helper to do get_local_rank_id.

    Args:
        group (str): The communication group.
        backend (str): The backend, like "hccl".

    Raises:
        RuntimeError: If backend is NCCL, which does not support this query.
        ValueError: If backend is invalid.

    Returns:
        Integer. The local rank id of the calling process.
    """
    if backend == Backend.NCCL:
        raise RuntimeError("Nccl doesn't support get_local_rank_id now.")
    if backend == Backend.HCCL_MPI:
        # NOTE(review): this calls mpi.get_rank_id, not a local-rank query, so it may
        # return the rank in the group rather than the local rank — confirm intended.
        return mpi.get_rank_id(group)
    if backend != Backend.HCCL:
        raise ValueError("Invalid backend: '{}'".format(backend))
    if group == HCCL_WORLD_COMM_GROUP:
        return hccl.get_local_rank_id()
    return hccl.get_local_rank_id(group)
243
244
@check_parameter_available
def _get_size_helper(group, backend):
    """
    The Helper to do get_rank_size.

    Args:
        group (str): The communication group.
        backend (str): The backend, like "hccl".

    Raises:
        ValueError: If backend is invalid.

    Returns:
        Integer. The rank size of the specified group; always 1 for
        parameter-server and scheduler roles.
    """
    if _is_role_pserver() or _is_role_sched():
        return 1
    # HCCL_MPI and NCCL both query the ``mpi`` module.
    if backend in (Backend.HCCL_MPI, Backend.NCCL):
        return mpi.get_rank_size(group)
    if backend == Backend.HCCL:
        if group == HCCL_WORLD_COMM_GROUP:
            return hccl.get_rank_size()
        return hccl.get_rank_size(group)
    raise ValueError("Invalid backend: '{}'".format(backend))
276
277
@check_parameter_available
def _get_local_size_helper(group, backend):
    """
    The Helper to do get_local_rank_size.

    Args:
        group (str): The communication group.
        backend (str): The backend, like "hccl".

    Raises:
        RuntimeError: If backend is NCCL, which does not support this query.
        ValueError: If backend is anything other than HCCL or NCCL (including "hccl_mpi").

    Returns:
        Integer. The local rank size where the calling process is being within specified group.
    """
    if backend == Backend.NCCL:
        raise RuntimeError("Nccl doesn't support get_local_rank_size now.")
    if backend != Backend.HCCL:
        raise ValueError("Invalid backend: '{}'".format(backend))
    if group == HCCL_WORLD_COMM_GROUP:
        return hccl.get_local_rank_size()
    return hccl.get_local_rank_size(group)
304
305
@check_parameter_available
def _get_world_rank_from_group_rank_helper(group, group_rank_id, backend):
    """
    The Helper to do get_world_rank_from_group_rank.

    Args:
        group (str): The user communication group.
        group_rank_id (int): A rank id in user communication group.
        backend (str): The backend, like "hccl".

    Raises:
        TypeError: If group_rank_id is not int (checked before the backend).
        RuntimeError: If backend is NCCL, which does not support this query.
        ValueError: If group is "hccl_world_group" or backend is invalid.

    Returns:
        Integer. A rank id in world communication group.
    """
    if not isinstance(group_rank_id, int):
        raise TypeError("group_rank_id should be int, but got type {}".format(type(group_rank_id)))
    if backend == Backend.NCCL:
        raise RuntimeError("Nccl doesn't support get_world_rank_from_group_rank now.")
    if backend != Backend.HCCL:
        raise ValueError("Invalid backend: '{}'".format(backend))
    # Translating within the world group itself is rejected.
    if group == HCCL_WORLD_COMM_GROUP:
        raise ValueError("Group cannot be 'hccl_world_group'. ")
    return hccl.get_world_rank_from_group_rank(group, group_rank_id)
335
336
@check_parameter_available
def _get_group_rank_from_world_rank_helper(world_rank_id, group, backend):
    """
    The Helper to do get_group_rank_from_world_rank.

    Args:
        world_rank_id (int): A rank id in world communication group.
        group (str): The user communication group.
        backend (str): The backend, like "hccl".

    Raises:
        TypeError: If world_rank_id is not int (checked before the backend).
        RuntimeError: If backend is NCCL, which does not support this query.
        ValueError: If group is 'hccl_world_group' or backend is invalid.

    Returns:
        Integer. A rank id in user communication group.
    """
    if not isinstance(world_rank_id, int):
        raise TypeError("world_rank_id should be int, but got type {}".format(type(world_rank_id)))
    if backend == Backend.NCCL:
        raise RuntimeError("Nccl doesn't support get_group_rank_from_world_rank now.")
    if backend != Backend.HCCL:
        raise ValueError("Invalid backend: '{}'".format(backend))
    # Translating within the world group itself is rejected.
    if group == HCCL_WORLD_COMM_GROUP:
        raise ValueError("Group cannot be 'hccl_world_group'. ")
    return hccl.get_group_rank_from_world_rank(world_rank_id, group)
366
367
@check_parameter_available
def _create_group_helper(group, rank_ids, backend):
    """
    The Helper to do create_group.

    If ``group`` is already registered with equal ``rank_ids``, only a warning is
    logged; registering it again with different ``rank_ids`` is an error.

    Args:
        group (str): The communication group.
        rank_ids (list): Rank ids in the group.
        backend (str): The backend, like "hccl".

    Raises:
        TypeError: If backend is HCCL and rank_ids is not a list.
        ValueError: If group already exists with different rank_ids; if backend is HCCL and
            rank_ids is empty or has duplicate data; or if backend is invalid.
        RuntimeError: If backend is NCCL, which does not support group creation.
    """
    import copy

    if group in _ExistingGroup.ITEMS:
        if rank_ids != _ExistingGroup.ITEMS[group]:
            raise ValueError("The group {} has been created, the rank_list is {}, "
                             "but current rank_list for the group is {}".
                             format(group, _ExistingGroup.ITEMS[group], rank_ids))
        logger.warning("%r group has existed.", group)
        return
    if backend == Backend.HCCL:
        if not isinstance(rank_ids, list):
            raise TypeError("Rank_ids {} should be list".format(rank_ids))
        rank_size = len(rank_ids)
        if rank_size < 1:
            raise ValueError("Rank_ids size {} should be larger than 0".format(rank_size))
        if len(set(rank_ids)) != rank_size:
            raise ValueError("List rank_ids in Group {} has duplicate data!".format(group))
        hccl.create_group(group, rank_size, rank_ids)
    elif backend == Backend.HCCL_MPI:
        mpi.create_group(group, rank_ids)
    elif backend == Backend.NCCL:
        raise RuntimeError("Nccl doesn't support create_group now.")
    else:
        raise ValueError("Invalid backend: '{}'".format(backend))
    # Store a copy: keeping the caller's object would let later mutations of it silently
    # change the membership that the "already created" check above compares against.
    _ExistingGroup.ITEMS[group] = copy.copy(rank_ids)
405
406
@check_parameter_available
def _destroy_group_helper(group, backend):
    """
    The Helper to do destroy_group.

    On success the group is also removed from the ``_ExistingGroup`` registry, so a
    later create with the same name builds the group again instead of only warning
    that it exists.

    Args:
        group (str): The user communication group.
        backend (str): The backend, like "hccl".

    Raises:
        ValueError: If group is "hccl_world_group" or backend is invalid (including "hccl_mpi").
        RuntimeError: If backend is NCCL, which does not support group destruction.
    """
    if backend == Backend.HCCL:
        if group == HCCL_WORLD_COMM_GROUP:
            raise ValueError("The hccl_world_group does not support destruction.")
        hccl.destroy_group(group)
        # Drop the registry entry only after the backend destroyed the group.
        _ExistingGroup.ITEMS.pop(group, None)
    elif backend == Backend.NCCL:
        raise RuntimeError("Nccl doesn't support destroy_group now.")
    else:
        raise ValueError("Invalid backend: '{}'".format(backend))
427