• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Data parallel allreduce fusion"""
16
17import ctypes
18
# Longest communication group name (in characters) accepted by the
# fusion-strategy setters in this module; empty names are rejected as well.
_MAX_GROUP_NAME_LEN = 127
# Name of the shared library handed to ctypes.CDLL when loading HCCL.
_HCCL_LIB = 'libhccl.so'
21
22
def _load_lib():
    """
    Load the hccl shared library through ctypes.

    Returns:
        ctypes.CDLL, the handle of the loaded hccl library.

    Raises:
        RuntimeError: If the library cannot be loaded for any reason.
    """
    try:
        hccl_lib = ctypes.CDLL(_HCCL_LIB)
    except Exception as err:
        # Chain the original failure so the real reason (missing file,
        # unresolved symbol, ...) stays visible in the traceback.
        raise RuntimeError('Get hccl lib error') from err

    return hccl_lib
30
31
def _c_str(string):
    """
    Build a C string (ctypes.c_char_p) from a python string.

    A value that is not a `str` is first decoded as ASCII; the resulting
    text is then encoded as UTF-8 before being wrapped.
    """
    text = string if isinstance(string, str) else string.decode('ascii')
    encoded = text.encode('utf-8')
    return ctypes.c_char_p(encoded)
37
38
def _c_array(ctype, values):
    """
    Build a ctypes array of element type `ctype` holding `values`.

    The array length equals `len(values)` and the elements are copied in order.
    """
    array_type = ctype * len(values)
    return array_type(*values)
42
43
def _set_fusion_strategy_by_idx(idx_list, group="hccl_world_group"):
    """
    Set the gradient segment strategy according to the index list.

    Note:
        In the back propagation,
        the fusion of the allreduce operators with a fusion attribute equals 1,
        will be performed according to the idx_list,
        to achieve the effect of parallel between calculation and communication.

        If the hccl library cannot be loaded, the call is forwarded to
        `hccl_test.manage.api.set_fusion_strategy_by_idx` without arguments
        and the inputs are not validated.

    Args:
        idx_list (list): The index list of the gradient.
        group (str): The hccl communication group.

    Raises:
        TypeError: If group is not a python str.
        TypeError: If idx_list is not a python list.
        TypeError: If type of idx_list item is not int.
        ValueError: If group name length is out of range.
        ValueError: If idx_list length is 0.
        ValueError: If idx_list item is less than 0.
        RuntimeError: If allreduce split failed.
    """
    try:
        lib_ctype = _load_lib()
    except RuntimeError:
        # No hccl library available: hand over to the hccl_test replacement.
        import hccl_test.manage.api as hccl
        hccl.set_fusion_strategy_by_idx()
        return

    if not isinstance(group, str):
        raise TypeError('Group must be a python str')
    group_len = len(group)
    if group_len > _MAX_GROUP_NAME_LEN or group_len == 0:
        # f-string so the limit itself is shown instead of the placeholder text.
        raise ValueError(f'Group name len is out of range {_MAX_GROUP_NAME_LEN}')

    if not isinstance(idx_list, list):
        raise TypeError('idx_list must be a python list')
    if not idx_list:
        raise ValueError('idx_list length is 0')

    for idx in idx_list:
        if not isinstance(idx, int):
            raise TypeError('Idx in idx_list is invalid')
        # Indices are passed on as unsigned ints, so negatives are rejected here.
        if idx < 0:
            raise ValueError('Idx < 0')

    c_array_idx_list = _c_array(ctypes.c_uint, idx_list)
    c_idx_num = ctypes.c_uint(len(idx_list))
    c_group = _c_str(group)
    ret = lib_ctype.hcom_set_split_strategy_by_index(c_group, c_idx_num, c_array_idx_list)
    # A non-zero return value is treated as failure.
    if ret != 0:
        raise RuntimeError('Allreduce split error')
102
103
def _set_fusion_strategy_by_size(data_size_list, group="hccl_world_group"):
    """
    Set the gradient segment strategy according to the data size percentage list.

    Note:
        In the back propagation,
        the fusion of the allreduce operators with a fusion attribute equals 1,
        will be performed according to data_size_list,
        to achieve the effect of parallel between calculation and communication.

        If the hccl library cannot be loaded, the call is forwarded to
        `hccl_test.manage.api.set_fusion_strategy_by_size` without arguments
        and the inputs are not validated.

    Args:
        data_size_list (list): The data size percentage list of the gradient.
        group (str): The hccl communication group.

    Raises:
        TypeError: If group is not a python str.
        TypeError: If data_size_list is not a python list.
        TypeError: If type of data_size_list item is not int or float.
        ValueError: If group name length is out of range.
        ValueError: If data_size_list length is 0.
        ValueError: If data_size_list item is less than 0.
        RuntimeError: If allreduce split failed.
    """
    try:
        lib_ctype = _load_lib()
    except RuntimeError:
        # No hccl library available: hand over to the hccl_test replacement.
        import hccl_test.manage.api as hccl
        hccl.set_fusion_strategy_by_size()
        return

    if not isinstance(group, str):
        raise TypeError('Group must be a python str')
    group_len = len(group)
    if group_len > _MAX_GROUP_NAME_LEN or group_len == 0:
        # f-string so the limit itself is shown instead of the placeholder text.
        raise ValueError(f'Group name is out of range {_MAX_GROUP_NAME_LEN}')

    if not isinstance(data_size_list, list):
        raise TypeError('data_size_list must be a python list')
    if not data_size_list:
        raise ValueError('data_size_list length is 0')

    for data_size in data_size_list:
        if not isinstance(data_size, (int, float)):
            raise TypeError('data_size in data_size_list is invalid')
        # Enforce the documented contract: negative entries are rejected
        # before anything is handed to the library.
        if data_size < 0:
            raise ValueError('data_size in data_size_list is less than 0')

    c_array_size_list = _c_array(ctypes.c_float, data_size_list)
    c_size_num = ctypes.c_uint(len(data_size_list))
    c_group = _c_str(group)
    ret = lib_ctype.hcom_set_split_strategy_by_size(c_group, c_size_num, c_array_size_list)
    # A non-zero return value is treated as failure.
    if ret != 0:
        raise RuntimeError('Allreduce split error')
158