• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ============================================================================
16"""HCCL management API"""
17import ctypes
18import os
19from mindspore import context
20from .._c_expression import get_hccl_rank_id, get_hccl_rank_size
21
# Upper bound on the length of a group name accepted by check_group.
MAX_GROUP_NAME_LEN = 127
# Rank limit: check_rank_num accepts 1..MAX_RANK_NUM, check_rank_id accepts 0..MAX_RANK_NUM - 1.
MAX_RANK_NUM = 4096
# File name of the HCCL plugin shared library that load_lib opens from the "../lib" directory.
HCCL_LIB = 'libhccl_plugin.so'
# Placeholder value; load_lib rebinds this name to the loaded ctypes.CDLL object.
HCCL_LIB_CTYPES = ""
26
27
def check_group(group):
    """
    Check that a collective communication group name is legal.

    The name is handed to the HCCL library as a UTF-8 encoded C string (see
    `c_str`), so its length is measured in encoded bytes, not in characters.
    For pure-ASCII names the two measures are identical.

    Args:
        group (str): Name of the communication group.

    Raises:
        TypeError: If `group` is not a python str.
        ValueError: If `group` is empty, or if its UTF-8 encoding is longer
            than MAX_GROUP_NAME_LEN bytes.

    Returns:
        None
    """
    if not isinstance(group, str):
        raise TypeError('Group must be a python str.')

    # A non-ASCII character occupies several bytes once encoded, so counting
    # characters would let an over-long byte string through to the C side.
    encoded_len = len(group.encode('utf-8'))
    if encoded_len == 0 or encoded_len > MAX_GROUP_NAME_LEN:
        raise ValueError('Group name is invalid.')
41
42
def check_rank_num(rank_num):
    """
    Validate the rank count of a collective communication group.

    Args:
        rank_num (int): Number of ranks; legal values are 1 to MAX_RANK_NUM
            inclusive.

    Raises:
        TypeError: If `rank_num` is not a python int.
        ValueError: If `rank_num` lies outside the legal range.

    Returns:
        None
    """
    if not isinstance(rank_num, int):
        raise TypeError('Rank number must be a python int.')
    if not 0 < rank_num <= MAX_RANK_NUM:
        raise ValueError('Rank number is out of range.')
55
56
def check_rank_id(rank_id):
    """
    Validate a rank id of a collective communication group.

    Args:
        rank_id (int): Rank id; legal values are 0 to MAX_RANK_NUM - 1
            inclusive.

    Raises:
        TypeError: If `rank_id` is not a python int.
        ValueError: If `rank_id` lies outside the legal range.

    Returns:
        None
    """
    if not isinstance(rank_id, int):
        raise TypeError('Rank id must be a python int.')
    if not 0 <= rank_id < MAX_RANK_NUM:
        raise ValueError('Rank id is out of range.')
69
70
def load_lib():
    """
    Load the HCCL plugin shared library and store the handle module-wide.

    The file named by HCCL_LIB is looked up in the "../lib" directory relative
    to the directory that contains this file. On success the module-level
    HCCL_LIB_CTYPES is rebound to the loaded `ctypes.CDLL` object; on failure
    it is left untouched.

    Raises:
        RuntimeError: If the shared library cannot be loaded. The underlying
            OSError is attached as the exception cause.

    Returns:
        None
    """
    global HCCL_LIB_CTYPES

    base_dir = os.path.dirname(os.path.realpath(__file__))
    lib_path = os.path.join(base_dir, "../lib", HCCL_LIB)
    # Only the load itself is guarded, so unrelated programming errors are not
    # disguised as a library-loading failure.
    try:
        hccl_lib = ctypes.CDLL(lib_path)
    except OSError as err:
        # Chain explicitly so the loader's reason (missing file, unresolved
        # symbol, ...) stays visible in the traceback.
        raise RuntimeError('Get hccl lib error.') from err

    HCCL_LIB_CTYPES = hccl_lib
82
83
def c_str(string):
    """
    Convert a python string to a C string.

    Args:
        string: A python str. Any other object is first turned into a str by
            calling its ``decode('ascii')`` method.

    Returns:
        ctypes.c_char_p, built from the UTF-8 encoding of the text.
    """
    text = string if isinstance(string, str) else string.decode('ascii')
    utf8_bytes = text.encode('utf-8')
    return ctypes.c_char_p(utf8_bytes)
89
90
def c_array(ctype, values):
    """
    Create a ctypes array from a python sequence.

    Args:
        ctype: The ctypes element type of the array.
        values: Sequence of initial values; its length fixes the array length.

    Returns:
        A ctypes array of type ``ctype * len(values)`` initialised with `values`.
    """
    array_type = ctype * len(values)
    return array_type(*values)
94
95
def create_group(group, rank_num, rank_ids):
    """
    Create group.

    Creates a collective communication group made of `rank_num` devices, where
    `rank_ids` lists the ranks of those devices.

    Note:
        The world group can not be created.

    Args:
        group (str): Name of the group, validated by `check_group`.
        rank_num (int): Number of ranks, validated by `check_rank_num`; it must
            equal ``len(rank_ids)``.
        rank_ids (list): Rank ids of the member devices, each a non-negative int.

    Raises:
        TypeError: If `rank_ids` is not a python list, or if `check_group` or
            `check_rank_num` rejects the type of its argument.
        ValueError: If `rank_num` differs from the length of `rank_ids`, if any
            rank id is not a non-negative int, or if `check_group` or
            `check_rank_num` rejects the value of its argument.
        RuntimeError: If HcomCreateGroup returns a non-zero code.

    Returns:
        None
    """
    check_group(group)
    check_rank_num(rank_num)

    if not isinstance(rank_ids, list):
        raise TypeError('Rank ids must be a python list.')
    if len(rank_ids) != rank_num:
        raise ValueError('Rank number is not equal to the length of rank_ids.')
    if any(not isinstance(one_id, int) or one_id < 0 for one_id in rank_ids):
        raise ValueError('Rank id must be unsigned integer!')

    # Convert the python values to the C types passed to the library call.
    ids_array = c_array(ctypes.c_uint, rank_ids)
    ids_count = ctypes.c_uint(rank_num)
    group_name = c_str(group)

    status = HCCL_LIB_CTYPES.HcomCreateGroup(group_name, ids_count, ids_array)
    if status != 0:
        raise RuntimeError('Create group error, the error code is ' + str(status))
125
126
def destroy_group(group):
    """
    Destroy a group that was created by the user.

    Note:
        The world group can not be destroyed.

    Args:
        group (str): Name of the group, validated by `check_group`.

    Raises:
        TypeError: If `group` is not a python str.
        ValueError: If `check_group` rejects the name.
        RuntimeError: If HcomDestroyGroup returns a non-zero code.

    Returns:
        None
    """
    check_group(group)
    status = HCCL_LIB_CTYPES.HcomDestroyGroup(c_str(group))
    if status != 0:
        raise RuntimeError('Destroy group error.')
142
143
def get_rank_size(group="hccl_world_group"):
    """
    Return the number of ranks within the given collective communication group.

    Note:
        The default group is hccl_world_group.
        When the context mode is PYNATIVE_MODE the value comes from
        get_hccl_rank_size() and `group` is neither validated nor used.

    Args:
        group (str): Name of the group. Default: "hccl_world_group".

    Raises:
        TypeError: If `group` is not a python str (non-PyNative path only).
        ValueError: If `check_group` rejects the name (non-PyNative path only).
        RuntimeError: If HcomGetRankSize returns a non-zero code.

    Returns:
        An integer scalar with the num of ranks.
    """
    in_pynative = context.get_context("mode") == context.PYNATIVE_MODE
    if in_pynative:
        return get_hccl_rank_size()

    check_group(group)
    group_name = c_str(group)
    # Output slot filled in by the library through the by-reference argument.
    size_out = ctypes.c_uint()
    status = HCCL_LIB_CTYPES.HcomGetRankSize(group_name, ctypes.byref(size_out))
    if status != 0:
        raise RuntimeError('Get rank size error.')
    return size_out.value
166
167
def get_rank_id(group="hccl_world_group"):
    """
    Return the rank id of the calling process within the given collective communication group.

    Note:
        When the context mode is PYNATIVE_MODE the value comes from
        get_hccl_rank_id() and `group` is neither validated nor used.

    Args:
        group (str): Name of the group. Default: "hccl_world_group".

    Raises:
        TypeError: If `group` is not a python str (non-PyNative path only).
        ValueError: If `check_group` rejects the name (non-PyNative path only).
        RuntimeError: If HcomGetRankId returns a non-zero code.

    Returns:
        An integer scalar with the rank id of the calling process.
    """
    in_pynative = context.get_context("mode") == context.PYNATIVE_MODE
    if in_pynative:
        return get_hccl_rank_id()

    check_group(group)
    group_name = c_str(group)
    # Output slot filled in by the library through the by-reference argument.
    rank_out = ctypes.c_uint()
    status = HCCL_LIB_CTYPES.HcomGetRankId(group_name, ctypes.byref(rank_out))
    if status != 0:
        raise RuntimeError('Get rank id error.')
    return rank_out.value
187
188
189
def get_local_rank_size(group="hccl_world_group"):
    """
    Return the number of local ranks within the given collective communication group.

    Note:
        The default group is hccl_world_group.

    Args:
        group (str): Name of the group. Default: "hccl_world_group".

    Raises:
        TypeError: If `group` is not a python str.
        ValueError: If `check_group` rejects the name.
        RuntimeError: If HcomGetLocalRankSize returns a non-zero code.

    Returns:
        An integer scalar with the num of local ranks.
    """
    check_group(group)
    group_name = c_str(group)
    # Output slot filled in by the library through the by-reference argument.
    size_out = ctypes.c_uint()
    status = HCCL_LIB_CTYPES.HcomGetLocalRankSize(group_name, ctypes.byref(size_out))
    if status != 0:
        raise RuntimeError('Get local rank size error.')
    return size_out.value
208
209
def get_local_rank_id(group="hccl_world_group"):
    """
    Get local rank id.

    Returns the local rank id of the calling process within the given
    collective communication group.

    Args:
        group (str): Name of the group. Default: "hccl_world_group".

    Raises:
        TypeError: If `group` is not a python str.
        ValueError: If `check_group` rejects the name.
        RuntimeError: If HcomGetLocalRankId returns a non-zero code.

    Returns:
        An integer scalar with the local rank id of the calling process.
    """
    check_group(group)
    group_name = c_str(group)
    # Output slot filled in by the library through the by-reference argument.
    rank_out = ctypes.c_uint()
    status = HCCL_LIB_CTYPES.HcomGetLocalRankId(group_name, ctypes.byref(rank_out))
    if status != 0:
        raise RuntimeError('Get local rank id error.')
    return rank_out.value
227
228
def get_world_rank_from_group_rank(group, group_rank_id):
    """
    Get world rank from group rank.

    Returns the rank id in the world group that corresponds to the rank whose
    id is `group_rank_id` in the user group.

    Args:
        group (str): Name of the user group, validated by `check_group`.
        group_rank_id (int): Rank id inside `group`, validated by `check_rank_id`.

    Raises:
        TypeError: If `group` is not a python str or `group_rank_id` is not a
            python int.
        ValueError: If `check_group` or `check_rank_id` rejects its argument.
        RuntimeError: If HcomGetWorldRankFromGroupRank returns a non-zero code.

    Returns:
        An integer scalar with the rank id in the world group.
    """
    check_group(group)
    check_rank_id(group_rank_id)

    group_name = c_str(group)
    rank_in_group = ctypes.c_uint(group_rank_id)
    # Output slot filled in by the library through the by-reference argument.
    world_rank_out = ctypes.c_uint()
    status = HCCL_LIB_CTYPES.HcomGetWorldRankFromGroupRank(
        group_name, rank_in_group, ctypes.byref(world_rank_out))
    if status != 0:
        raise RuntimeError('Get world rank from group rank error.')
    return world_rank_out.value
249
250
def get_group_rank_from_world_rank(world_rank_id, group):
    """
    Get group rank from world rank.

    Returns the rank id in the user group that corresponds to the rank whose
    id is `world_rank_id` in the world group.

    Args:
        world_rank_id (int): Rank id in the world group, validated by
            `check_rank_id`.
        group (str): Name of the user group, validated by `check_group`.

    Raises:
        TypeError: If `group` is not a python str or `world_rank_id` is not a
            python int.
        ValueError: If `check_group` or `check_rank_id` rejects its argument.
        RuntimeError: If HcomGetGroupRankFromWorldRank returns a non-zero code.

    Returns:
        An integer scalar with the rank id in the user group.
    """
    check_group(group)
    check_rank_id(world_rank_id)

    group_name = c_str(group)
    rank_in_world = ctypes.c_uint(world_rank_id)
    # Output slot filled in by the library through the by-reference argument.
    group_rank_out = ctypes.c_uint()
    status = HCCL_LIB_CTYPES.HcomGetGroupRankFromWorldRank(
        rank_in_world, group_name, ctypes.byref(group_rank_out))
    if status != 0:
        raise RuntimeError('Get group rank from world rank error.')
    return group_rank_out.value
271