# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""adasum"""
from __future__ import absolute_import

import copy
import hashlib
import math
from mindspore.nn.cell import Cell
from mindspore.communication.management import create_group
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops import Send, Receive


__all__ = ["AdaSum"]


_update_parameters = C.MultitypeFuncGraph("update_parameters")


@_update_parameters.register("Tensor", "Tensor", "Tensor", "Tensor")
def _update_parameters_after_broadcast(delta_weight, update_delta_weight, parameter, old_parameter):
    shape = F.shape(delta_weight)
    update_delta_weight = P.Reshape()(update_delta_weight, shape)
    new_parameter = old_parameter - update_delta_weight
    P.Assign()(parameter, new_parameter)
    return parameter


def _send_before_receive(send_part, send, recv):
    send_ok = send(send_part)
    return recv(send_ok)


def _receive_before_send(send_part, send, recv):
    receive_ok = recv(send_part)
    send_part = F.depend(send_part, receive_ok)
    return F.depend(receive_ok, send(send_part))


def _send_recv_res(left_send, recv_part, local_part, allreduce, parameter_divisibility, allreduce_node_num):
    """send result and receive result."""
    if parameter_divisibility:
        recv_part = P.Squeeze()(recv_part)
        local_part = F.depend(local_part, recv_part)
        eps = 1e-12
        value_0 = P.ReduceSum()(local_part * recv_part) + eps
        if left_send:
            value_1 = P.ReduceSum()(local_part * local_part) + eps
            value_2 = P.ReduceSum()(recv_part * recv_part) + eps
        else:
            value_1 = P.ReduceSum()(recv_part * recv_part) + eps
            value_2 = P.ReduceSum()(local_part * local_part) + eps
        value_0 = allreduce(value_0)
        value_1 = F.depend(allreduce(value_1), value_0)
        value_2 = F.depend(allreduce(value_2), value_1)
        if left_send:
            res = (1 - (value_0 / (2 * value_1))) * local_part + (1 - (value_0 / (2 * value_2))) * recv_part
        else:
            res = (1 - (value_0 / (2 * value_1))) * recv_part + (1 - (value_0 / (2 * value_2))) * local_part
    else:
        res = allreduce(local_part)
        res /= allreduce_node_num
    return res
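

# The divisible branch of ``_send_recv_res`` implements the AdaSum combination
# rule. Writing a for the half kept locally and b for the half received from
# the partner rank, the three inner products are allreduced over the step's
# group and the fused result is
#
#     res = (1 - <a, b> / (2 * <a, a>)) * a + (1 - <a, b> / (2 * <b, b>)) * b
#
# which reduces to a plain sum when a and b are orthogonal and to their average
# when they are parallel; ``eps`` guards against zero denominators. Parameters
# that cannot be split fall back to an ordinary AllReduce divided by
# ``allreduce_node_num``, the number of ranks contributing at this step.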


_adasum_opt_forward = C.MultitypeFuncGraph("adasum_opt_forward")


@_adasum_opt_forward.register("Bool", "Function", "Bool", "Int64", "Function", "Function", "Tensor")
def _adasum_opt_forward_process(left_send, allreduce, parameter_divisibility, allreduce_node_num, send, recv, delta_w):
    """adasum optimizer process."""
    if parameter_divisibility:
        delta_w = P.Squeeze()(delta_w)
        ori_len = F.shape(delta_w)[0]
        divide_len = ori_len / 2
        left_part = delta_w[:divide_len]
        right_part = delta_w[divide_len:]
    else:
        left_part = delta_w
        right_part = delta_w

    if left_send:
        if parameter_divisibility:
            recv_part = _send_before_receive(left_part, send, recv)
        else:
            recv_part = right_part
        update_delta_w = _send_recv_res(left_send, recv_part, right_part, allreduce, parameter_divisibility,
                                        allreduce_node_num)
    else:
        if parameter_divisibility:
            recv_part = _receive_before_send(right_part, send, recv)
        else:
            recv_part = left_part
        update_delta_w = _send_recv_res(left_send, recv_part, left_part, allreduce, parameter_divisibility,
                                        allreduce_node_num)

    return update_delta_w


_adasum_opt_rollback = C.MultitypeFuncGraph("adasum_opt_rollback")


@_adasum_opt_rollback.register("Bool", "Bool", "Tensor", "Function", "Function")
def _adasum_opt_rollback_process(left_send, parameter_divisibility, delta_w, send, recv):
    """adasum optimizer rollback process."""
    if parameter_divisibility:
        if left_send:
            recv_part = _send_before_receive(delta_w, send, recv)
        else:
            recv_part = _receive_before_send(delta_w, send, recv)

        recv_part = P.Squeeze()(recv_part)
        recv_part = P.Reshape()(recv_part, (-1,))
        delta_w = P.Reshape()(delta_w, (-1,))

        if left_send:
            res = P.Concat()((recv_part, delta_w))
        else:
            res = P.Concat()((delta_w, recv_part))
    else:
        res = delta_w
    return res
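

# Together, the two passes above form a recursive-halving / recursive-doubling
# exchange. In the forward pass each step splits the current slice in half
# along its first axis and exchanges one half with the partner rank (ranks
# whose group index is even send the left half and keep the right; the others
# send the right half and keep the left), then fuses the kept half with the
# received one via ``_send_recv_res``. The rollback pass replays the steps in
# reverse, returning the fused piece to the partner and concatenating it with
# the partner's piece, so every rank ends up with a full-sized delta weight
# again.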


class AdaSum(Cell):
    r"""
    The Adaptive Summation, or AdaSum, is a novel algorithm for improving distributed data-parallel
    training of Deep Learning models.

    Args:
        rank (int): Global rank of the local device.
        device_number (int): Number of devices in each group.
        group_number (int): Number of groups taking part in AdaSum.
        parameter_tuple (Tuple(Parameter)): Tuple of trainable parameters.

    Inputs:
        - **delta_weights** (Tuple(Tensor)) - Tuple of gradients.
        - **parameters** (Tuple(Parameter)) - Tuple of current parameters.
        - **old_parameters** (Tuple(Parameter)) - Tuple of parameters before the current update.

    Outputs:
        - **adasum_parameters** (Tuple(Tensor)) - Tuple of parameters after the adasum process.
    """
    def __init__(self, rank, device_number, group_number, parameter_tuple):
        super(AdaSum, self).__init__()
        self.rank = rank
        self.device_number = device_number
        self.group_number = group_number
        self.parameter_tuple = parameter_tuple
        self._generate_communication_op()
        self.hyper_map = C.HyperMap()

    @staticmethod
    def _hash(step, target, weights_index):
        target = "tag" + str(step) + str(target) + str(weights_index)
        target_hash = hashlib.sha1(target.encode()).hexdigest()
        max_num_hash = 2 ** 31
        hash_res = int(int(target_hash, 16) % max_num_hash)
        return hash_res

    def construct(self, delta_weights, parameters, old_parameters):
        forward_weights = [delta_weights]
        for i in range(self.calc_times):
            process_weights = self.hyper_map(F.partial(_adasum_opt_forward, self.send_node[i], self.allreduce_list[i]),
                                             self.parameter_divisibility_list[i], self.allreduce_node_num_list[i],
                                             self.send_list_forward[i], self.recv_list_forward[i], forward_weights[-1])
            forward_weights.append(process_weights)
        for i in range(self.calc_times):
            j = self.calc_times - i - 1
            process_weights = self.hyper_map(F.partial(_adasum_opt_rollback, self.send_node[j]),
                                             self.parameter_divisibility_list[j], forward_weights[j + 1],
                                             self.send_list_rollback[j], self.recv_list_rollback[j])
            forward_weights[j] = process_weights
        adasum_parameters = self.hyper_map(F.partial(_update_parameters), delta_weights, forward_weights[0],
                                           parameters, old_parameters)
        return adasum_parameters
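
    # ``_generate_communication_op`` wires up the per-step communication. At
    # step ``step`` the partner rank sits ``device_number * 2 ** step`` ranks
    # away, and a Send/Receive pair is built for every weight half with SHA-1
    # derived tags (``_hash``) so that this rank's receive tag matches the
    # partner's send tag. An AllReduce group spanning the ``2 ** (step + 1)``
    # participating nodes handles the inner products and the indivisible
    # parameters. A per-server broadcast group (one Broadcast op per local rank
    # plus an AllReduce sync barrier) is also created for use by the caller.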

    def _generate_communication_op(self):
        """generate communication op."""
        self.calc_times = int(math.log(self.group_number, 2))
        self.send_node = []
        self.send_list_forward = []
        self.recv_list_forward = []
        self.send_list_rollback = []
        self.recv_list_rollback = []
        self.allreduce_list = []
        self.broadcast_list = []
        self.parameter_divisibility_list = []
        self.allreduce_node_num_list = []
        last_delta_weights = []
        group_start_rank = (self.rank // self.device_number) * self.device_number

        for step in range(self.calc_times):
            current_group = self.device_number * (2 ** step)
            sr_target = self.rank
            if (sr_target // current_group) % 2 == 0:
                dest_target = sr_target + current_group
                self.send_node.append(True)
            else:
                dest_target = sr_target - current_group
                self.send_node.append(False)

            neighbor_ids = []
            group_name_last = 0
            for index in range(2 ** (step + 1)):
                node_rank = self.rank // self.device_number
                double_d = 2 ** (step + 1)
                neighbor_id = (node_rank // double_d * double_d + index) * self.device_number + \
                              self.rank % self.device_number
                neighbor_ids.append(neighbor_id)
                group_name_last += neighbor_id
            group_name = "adasum_{}_{}".format(str(step), str(group_name_last))
            create_group(group_name, neighbor_ids)

            send_left = []
            send_right = []
            recv_left = []
            recv_right = []
            allreduce_node_num = ()
            left_delta_weights, right_delta_weights, delta_weights_divisibility = \
                self._get_delta_weights_info(last_delta_weights)
            self.parameter_divisibility_list.append(delta_weights_divisibility)
            weights_index = 0
            fusion_id = (step + 1) * 3
            for shape, dtype in left_delta_weights:
                send_tag = AdaSum._hash(step, sr_target, weights_index)
                send = Send(sr_tag=send_tag, dest_rank=dest_target, group="hccl_world_group")
                send.add_prim_attr("fusion", fusion_id)
                recv_tag = AdaSum._hash(step, dest_target, weights_index)
                recv = Receive(sr_tag=recv_tag, src_rank=dest_target, shape=shape, dtype=dtype,
                               group="hccl_world_group")
                recv.add_prim_attr("fusion", fusion_id)
                send_left.append(send)
                recv_left.append(recv)
                weights_index += 1
            for shape, dtype in right_delta_weights:
                send_tag = AdaSum._hash(step, sr_target, weights_index)
                send = Send(sr_tag=send_tag, dest_rank=dest_target, group="hccl_world_group")
                send.add_prim_attr("fusion", fusion_id + 1)
                recv_tag = AdaSum._hash(step, dest_target, weights_index)
                recv = Receive(sr_tag=recv_tag, src_rank=dest_target, shape=shape, dtype=dtype,
                               group="hccl_world_group")
                recv.add_prim_attr("fusion", fusion_id + 1)
                send_right.append(send)
                recv_right.append(recv)
                weights_index += 1

            if self.send_node and self.send_node[-1]:
                self.send_list_forward.append(send_left)
                self.send_list_rollback.append(send_right)
                self.recv_list_forward.append(recv_right)
                self.recv_list_rollback.append(recv_left)
                last_delta_weights = right_delta_weights
            else:
                self.send_list_forward.append(send_right)
                self.send_list_rollback.append(send_left)
                self.recv_list_forward.append(recv_left)
                self.recv_list_rollback.append(recv_right)
                last_delta_weights = left_delta_weights

            server_all_reduce = P.AllReduce("sum", group_name)
            server_all_reduce.add_prim_attr("fusion", fusion_id + 2)
            self.allreduce_list.append(server_all_reduce)

            for param_divisibility in delta_weights_divisibility:
                if param_divisibility:
                    allreduce_node_num += (0,)
                else:
                    allreduce_node_num += (2 ** (step + 1),)
            self.allreduce_node_num_list.append(allreduce_node_num)

        broadcast_group = list(range(group_start_rank, group_start_rank + self.device_number))
        broadcast_group_name = "broadcast_group_" + str(group_start_rank)
        create_group(broadcast_group_name, broadcast_group)
        for b_rank in range(len(broadcast_group)):
            self.broadcast_list.append(P.Broadcast(b_rank, group=broadcast_group_name))
        self.sync_barrier = P.AllReduce("sum", group=broadcast_group_name)

    def _get_delta_weights_info(self, last_delta_weights):
        """get delta weights info."""
        half_delta_weights = []
        if last_delta_weights:
            half_delta_weights = last_delta_weights
        else:
            for parameter in self.parameter_tuple:
                new_shape = [int(x) for x in parameter.shape]
                half_delta_weights.append((new_shape, parameter.dtype))
        left_delta_weights = []
        right_delta_weights = []
        delta_weights_divisibility = ()
        for shape, dtype in half_delta_weights:
            left_shape = copy.deepcopy(shape)
            right_shape = copy.deepcopy(shape)
            divisibility_flag = False
            for i, _ in enumerate(shape):
                if shape[i] > 1:
                    left_shape[i] = int(shape[i] // 2)
                    right_shape[i] = shape[i] - int(shape[i] // 2)
                    divisibility_flag = True
                    break
            left_delta_weights.append((left_shape, dtype))
            right_delta_weights.append((right_shape, dtype))
            delta_weights_divisibility += (divisibility_flag,)
        return left_delta_weights, right_delta_weights, delta_weights_divisibility
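

# Shape-splitting example for ``_get_delta_weights_info`` (illustrative values
# only): a parameter of shape [9, 32] yields a left half of shape [4, 32] and a
# right half of shape [5, 32], split along the first axis whose size exceeds 1,
# while a shape such as [1] cannot be split and is flagged as indivisible so
# that the AllReduce fallback in ``_send_recv_res`` is used for it. On later
# steps the routine receives ``last_delta_weights``, the halves kept at the
# previous step, so the exchanged slices shrink by roughly a factor of two per
# step.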