# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Data parallel allreduce fusion"""

import ctypes

_MAX_GROUP_NAME_LEN = 127
_HCCL_LIB = 'libhccl.so'


def _load_lib():
    """
    Load the hccl shared library named by ``_HCCL_LIB``.

    Returns:
        ctypes.CDLL, the loaded library handle.

    Raises:
        RuntimeError: If the library cannot be loaded. The original loader
            error is chained as ``__cause__`` to keep it visible in tracebacks.
    """
    try:
        return ctypes.CDLL(_HCCL_LIB)
    except Exception as err:
        raise RuntimeError('Get hccl lib error') from err


def _c_str(string):
    """Convert a python string to C string."""
    if not isinstance(string, str):
        # Non-str input (e.g. bytes) is decoded first so it can be re-encoded uniformly.
        string = string.decode('ascii')
    return ctypes.c_char_p(string.encode('utf-8'))


def _c_array(ctype, values):
    """Create ctypes array from a python array."""
    return (ctype * len(values))(*values)


def _check_group(group):
    """
    Validate the communication group name.

    Args:
        group (str): The hccl communication group.

    Raises:
        TypeError: If group is not a python str.
        ValueError: If group name is empty or longer than ``_MAX_GROUP_NAME_LEN``.
    """
    if not isinstance(group, str):
        raise TypeError('Group must be a python str')
    if not 0 < len(group) <= _MAX_GROUP_NAME_LEN:
        raise ValueError(f'Group name len is out of range {_MAX_GROUP_NAME_LEN}')


def _set_fusion_strategy_by_idx(idx_list, group="hccl_world_group"):
    """
    A function set gradient segment strategy according to the index list.

    Note:
        In the back propagation,
        the fusion of the allreduce operators with a fusion attribute equals 1,
        will be performed according to the idx_list,
        to achieve the effect of parallel between calculation and communication.

        If the hccl library cannot be loaded, the call is forwarded (without
        arguments and without validation) to ``hccl_test.manage.api``.

    Args:
        idx_list (list): The index list of the gradient.
        group (str): The hccl communication group.

    Raises:
        TypeError: If group is not a python str.
        TypeError: If idx_list is not a python list.
        TypeError: If type of idx_list item is not int.
        ValueError: If group name length is out of range.
        ValueError: If idx_list length is 0.
        ValueError: If idx_list item is less than 0.
        RuntimeError: If allreduce split failed.
    """
    try:
        lib_ctype = _load_lib()
    except RuntimeError:
        # NOTE(review): hccl_test is presumably a test stand-in for the real
        # library - confirm. The import is local so it is only needed here.
        import hccl_test.manage.api as hccl
        hccl.set_fusion_strategy_by_idx()
        return

    _check_group(group)

    if not isinstance(idx_list, list):
        raise TypeError('idx_list must be a python list')
    if not idx_list:
        raise ValueError('idx_list length is 0')

    for idx in idx_list:
        if not isinstance(idx, int):
            raise TypeError('Idx in idx_list is invalid')
        if idx < 0:
            raise ValueError('Idx < 0')

    c_array_idx_list = _c_array(ctypes.c_uint, idx_list)
    c_idx_num = ctypes.c_uint(len(idx_list))
    c_group = _c_str(group)
    ret = lib_ctype.hcom_set_split_strategy_by_index(c_group, c_idx_num, c_array_idx_list)
    if ret != 0:
        raise RuntimeError('Allreduce split error')


def _set_fusion_strategy_by_size(data_size_list, group="hccl_world_group"):
    """
    A function set gradient segment strategy according to the data size percentage list.

    Note:
        In the back propagation,
        the fusion of the allreduce operators with a fusion attribute equals 1,
        will be performed according to data_size_list,
        to achieve the effect of parallel between calculation and communication.

        If the hccl library cannot be loaded, the call is forwarded (without
        arguments and without validation) to ``hccl_test.manage.api``.

    Args:
        data_size_list (list): The data size percentage list of the gradient.
        group (str): The hccl communication group.

    Raises:
        TypeError: If group is not a python str.
        TypeError: If data_size_list is not a python list.
        TypeError: If type of data_size_list item is not int or float.
        ValueError: If group name length is out of range.
        ValueError: If data_size_list length is 0.
        ValueError: If data_size_list item is less than 0.
        RuntimeError: If allreduce split failed.
    """
    try:
        lib_ctype = _load_lib()
    except RuntimeError:
        # NOTE(review): hccl_test is presumably a test stand-in for the real
        # library - confirm. The import is local so it is only needed here.
        import hccl_test.manage.api as hccl
        hccl.set_fusion_strategy_by_size()
        return

    _check_group(group)

    if not isinstance(data_size_list, list):
        raise TypeError('data_size_list must be a python list')
    if not data_size_list:
        raise ValueError('data_size_list length is 0')

    for data_size in data_size_list:
        if not isinstance(data_size, (int, float)):
            raise TypeError('data_size in data_size_list is invalid')
        # Enforce the documented contract: negative sizes are rejected
        # before anything is handed to the C library.
        if data_size < 0:
            raise ValueError('data_size in data_size_list is less than 0')

    c_array_size_list = _c_array(ctypes.c_float, data_size_list)
    c_size_num = ctypes.c_uint(len(data_size_list))
    c_group = _c_str(group)
    ret = lib_ctype.hcom_set_split_strategy_by_size(c_group, c_size_num, c_array_size_list)
    if ret != 0:
        raise RuntimeError('Allreduce split error')