# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""adasum"""
import copy
import hashlib
import math
from mindspore.nn.cell import Cell
from mindspore.communication.management import create_group
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops.operations._inner_ops import Send, Receive
from mindspore.common.tensor import Tensor


__all__ = ["AdaSum"]


MAX_NUM_HASH = 2 ** 31


_update_parameters = C.MultitypeFuncGraph("update_parameters")


@_update_parameters.register("Tensor", "Tensor", "Tensor", "Tensor")
def _update_parameters_after_broadcast(delta_weight, update_delta_weight, parameter, old_parameter):
    shape = F.shape(delta_weight)
    update_delta_weight = P.Reshape()(update_delta_weight, shape)
    new_parameter = old_parameter - update_delta_weight
    return P.Assign()(parameter, new_parameter)


def _send_before_receive(send_part, send, recv):
    send_ok = send(send_part)
    return recv(send_ok)


def _receive_before_send(send_part, send, recv):
    receive_ok = recv(send_part)
    send_part = F.depend(send_part, receive_ok)
    return F.depend(receive_ok, send(send_part))


def _send_recv_res(left_send, recv_part, local_part, allreduce, parameter_divisibility, allreduce_node_num):
    """send result and receive result."""
    if parameter_divisibility:
        recv_part = P.Squeeze()(recv_part)
        if F.shape(recv_part) is None:
            recv_part = Tensor([recv_part])
        local_part = F.depend(local_part, recv_part)
        eps = 1e-12
        value_0 = P.ReduceSum()(local_part * recv_part) + eps
        if left_send:
            value_1 = P.ReduceSum()(local_part * local_part) + eps
            value_2 = P.ReduceSum()(recv_part * recv_part) + eps
        else:
            value_1 = P.ReduceSum()(recv_part * recv_part) + eps
            value_2 = P.ReduceSum()(local_part * local_part) + eps
        value_0 = allreduce(value_0)
        value_1 = F.depend(allreduce(value_1), value_0)
        value_2 = F.depend(allreduce(value_2), value_1)
        if left_send:
            res = (1 - (value_0 / (2 * value_1))) * local_part + (1 - (value_0 / (2 * value_2))) * recv_part
        else:
            res = (1 - (value_0 / (2 * value_1))) * recv_part + (1 - (value_0 / (2 * value_2))) * local_part
    else:
        res = allreduce(local_part)
        res /= allreduce_node_num
    return res


_adasum_opt_forward = C.MultitypeFuncGraph("adasum_opt_forward")


@_adasum_opt_forward.register("Bool", "Function", "Bool", "Int64", "Function", "Function", "Tensor")
def _adasum_opt_forward_process(left_send, allreduce, parameter_divisibility, allreduce_node_num, send, recv, delta_w):
    """adasum optimizer process."""
    if parameter_divisibility:
        delta_w = P.Squeeze()(delta_w)
        ori_len = F.shape(delta_w)[0]
        divide_len = ori_len / 2
        left_part = delta_w[:divide_len]
        right_part = delta_w[divide_len:]
    else:
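        # This parameter cannot be halved (every dimension is 1), so no point-to-point
        # exchange is done for it; _send_recv_res falls back to a plain AllReduce that
        # is averaged over allreduce_node_num, and both "halves" are the whole tensor.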
        left_part = delta_w
        right_part = delta_w

    if left_send:
        if parameter_divisibility:
            recv_part = _send_before_receive(left_part, send, recv)
        else:
            recv_part = right_part
        update_delta_w = _send_recv_res(left_send, recv_part, right_part, allreduce, parameter_divisibility,
                                        allreduce_node_num)
    else:
        if parameter_divisibility:
            recv_part = _receive_before_send(right_part, send, recv)
        else:
            recv_part = left_part
        update_delta_w = _send_recv_res(left_send, recv_part, left_part, allreduce, parameter_divisibility,
                                        allreduce_node_num)

    return update_delta_w


_adasum_opt_rollback = C.MultitypeFuncGraph("adasum_opt_rollback")


@_adasum_opt_rollback.register("Bool", "Bool", "Tensor", "Function", "Function")
def _adasum_opt_rollback_process(left_send, parameter_divisibility, delta_w, send, recv):
    """adasum optimizer rollback process."""
    if parameter_divisibility:
        if left_send:
            recv_part = _send_before_receive(delta_w, send, recv)
        else:
            recv_part = _receive_before_send(delta_w, send, recv)

        recv_part = P.Squeeze()(recv_part)
        if F.shape(recv_part) is None:
            recv_part = Tensor([recv_part])
        if F.shape(delta_w) is None:
            delta_w = Tensor([delta_w])
        recv_part = P.Reshape()(recv_part, (-1,))
        delta_w = P.Reshape()(delta_w, (-1,))

        if left_send:
            res = P.Concat()((recv_part, delta_w))
        else:
            res = P.Concat()((delta_w, recv_part))
    else:
        res = delta_w
    return res


class AdaSum(Cell):
    r"""
    The Adaptive Summation, or AdaSum, is a novel algorithm for improving distributed data
    parallel training of Deep Learning models.

    Args:
        rank (int): Rank id (global device id) of the current device.
        device_number (int): Number of devices in each server; it is also the size of the
            broadcast group created by this cell.
        group_number (int): Number of servers (groups) taking part in the AdaSum exchange,
            expected to be a power of two.
        parameter_tuple (Tuple(Parameter)): Tuple of the network parameters; their shapes
            and data types determine how the delta weights are split and exchanged.

    Inputs:
        - **delta_weights** (Tuple(Tensor)) - Tuple of gradients.
        - **parameters** (Tuple(Parameter)) - Tuple of current parameters.
        - **old_parameters** (Tuple(Parameter)) - Tuple of last parameters.

    Outputs:
        - **adasum_parameters** (Tuple(Tensor)) - Tuple of parameters after adasum process.
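
    Examples:
        The snippet below is an illustrative construction sketch only; ``net`` is assumed
        to be an already built training network, and the ``device_number`` and
        ``group_number`` values depend on the actual distributed launch.

        >>> from mindspore.communication.management import init, get_rank
        >>> init()
        >>> params = tuple(net.trainable_params())
        >>> adasum = AdaSum(rank=get_rank(), device_number=8, group_number=4,
        ...                 parameter_tuple=params)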
    """
    def __init__(self, rank, device_number, group_number, parameter_tuple):
        super(AdaSum, self).__init__()
        self.rank = rank
        self.device_number = device_number
        self.group_number = group_number
        self.parameter_tuple = parameter_tuple
        self._generate_communication_op()
        self.hyper_map = C.HyperMap()

    def _generate_communication_op(self):
        """generate communication op."""
        self.calc_times = int(math.log(self.group_number, 2))
        self.send_node = []
        self.send_list_forward = []
        self.recv_list_forward = []
        self.send_list_rollback = []
        self.recv_list_rollback = []
        self.allreduce_list = []
        self.broadcast_list = []
        self.parameter_divisibility_list = []
        self.allreduce_node_num_list = []
        last_delta_weights = []
        group_start_rank = (self.rank // self.device_number) * self.device_number

        for step in range(self.calc_times):
            current_group = self.device_number * (2 ** step)
            sr_target = self.rank
            if (sr_target // current_group) % 2 == 0:
                dest_target = sr_target + current_group
                self.send_node.append(True)
            else:
                dest_target = sr_target - current_group
                self.send_node.append(False)

            neighbor_ids = []
            group_name_last = 0
            for index in range(2 ** (step + 1)):
                node_rank = self.rank // self.device_number
                double_d = 2 ** (step + 1)
                neighbor_id = (node_rank // double_d * double_d + index) * self.device_number + \
                              self.rank % self.device_number
                neighbor_ids.append(neighbor_id)
                group_name_last += neighbor_id
            group_name = "adasum_" + str(step) + "_" + str(group_name_last)
            create_group(group_name, neighbor_ids)

            send_left = []
            send_right = []
            recv_left = []
            recv_right = []
            allreduce_node_num = ()
            left_delta_weights, right_delta_weights, delta_weights_divisibility = \
                self._get_delta_weights_info(last_delta_weights)
            self.parameter_divisibility_list.append(delta_weights_divisibility)
            weights_index = 0
            fusion_id = (step + 1) * 3
            for shape, dtype in left_delta_weights:
                send_tag = self._hash(step, sr_target, weights_index)
                send = Send(sr_tag=send_tag, dest_rank=dest_target, group="hccl_world_group")
                send.add_prim_attr("fusion", fusion_id)
                recv_tag = self._hash(step, dest_target, weights_index)
                recv = Receive(sr_tag=recv_tag, src_rank=dest_target, shape=shape, dtype=dtype,
                               group="hccl_world_group")
                recv.add_prim_attr("fusion", fusion_id)
                send_left.append(send)
                recv_left.append(recv)
                weights_index += 1
            for shape, dtype in right_delta_weights:
                send_tag = self._hash(step, sr_target, weights_index)
                send = Send(sr_tag=send_tag, dest_rank=dest_target, group="hccl_world_group")
                send.add_prim_attr("fusion", fusion_id + 1)
                recv_tag = self._hash(step, dest_target, weights_index)
                recv = Receive(sr_tag=recv_tag, src_rank=dest_target, shape=shape, dtype=dtype,
                               group="hccl_world_group")
                recv.add_prim_attr("fusion", fusion_id + 1)
                send_right.append(send)
                recv_right.append(recv)
                weights_index += 1

            if self.send_node and self.send_node[-1]:
                self.send_list_forward.append(send_left)
                self.send_list_rollback.append(send_right)
                self.recv_list_forward.append(recv_right)
                self.recv_list_rollback.append(recv_left)
                last_delta_weights = right_delta_weights
            else:
                self.send_list_forward.append(send_right)
                self.send_list_rollback.append(send_left)
                self.recv_list_forward.append(recv_left)
                self.recv_list_rollback.append(recv_right)
                last_delta_weights = left_delta_weights

            server_all_reduce = P.AllReduce("sum", group_name)
            server_all_reduce.add_prim_attr("fusion", fusion_id + 2)
            self.allreduce_list.append(server_all_reduce)

            for param_divisibility in delta_weights_divisibility:
                if param_divisibility:
                    allreduce_node_num += (0,)
                else:
                    allreduce_node_num += (2 ** (step + 1),)
            self.allreduce_node_num_list.append(allreduce_node_num)

        broadcast_group = [x for x in range(group_start_rank, group_start_rank + self.device_number)]
        broadcast_group_name = "broadcast_group_" + str(group_start_rank)
        create_group(broadcast_group_name, broadcast_group)
        for b_rank in range(len(broadcast_group)):
            self.broadcast_list.append(P.Broadcast(b_rank, group=broadcast_group_name))
        self.sync_barrier = P.AllReduce("sum", group=broadcast_group_name)

    def _get_delta_weights_info(self, last_delta_weights):
        """get delta weights info."""
        half_delta_weights = []
        if last_delta_weights:
            half_delta_weights = last_delta_weights
        else:
            for parameter in self.parameter_tuple:
                new_shape = [int(x) for x in parameter.shape]
                half_delta_weights.append((new_shape, parameter.dtype))
        left_delta_weights = []
        right_delta_weights = []
        delta_weights_divisibility = ()
        for shape, dtype in half_delta_weights:
            left_shape = copy.deepcopy(shape)
            right_shape = copy.deepcopy(shape)
            divisibility_flag = False
            for i in range(len(shape)):
                if shape[i] > 1:
                    left_shape[i] = int(shape[i] // 2)
                    right_shape[i] = shape[i] - int(shape[i] // 2)
                    divisibility_flag = True
                    break
            left_delta_weights.append((left_shape, dtype))
            right_delta_weights.append((right_shape, dtype))
            delta_weights_divisibility += (divisibility_flag,)
        return left_delta_weights, right_delta_weights, delta_weights_divisibility

    def _hash(self, step, target, weights_index):
        target = "tag" + str(step) + str(target) + str(weights_index)
        target_hash = hashlib.sha1(target.encode()).hexdigest()
        hash_res = int(int(target_hash, 16) % MAX_NUM_HASH)
        return hash_res

    def construct(self, delta_weights, parameters, old_parameters):
        forward_weights = [delta_weights]
        for i in range(self.calc_times):
            process_weights = self.hyper_map(F.partial(_adasum_opt_forward, self.send_node[i], self.allreduce_list[i]),
                                             self.parameter_divisibility_list[i], self.allreduce_node_num_list[i],
                                             self.send_list_forward[i], self.recv_list_forward[i], forward_weights[-1])
            forward_weights.append(process_weights)
        for i in range(self.calc_times):
            j = self.calc_times - i - 1
            process_weights = self.hyper_map(F.partial(_adasum_opt_rollback, self.send_node[j]),
                                             self.parameter_divisibility_list[j], forward_weights[j + 1],
                                             self.send_list_rollback[j], self.recv_list_rollback[j])
            forward_weights[j] = process_weights
        adasum_parameters = self.hyper_map(F.partial(_update_parameters), delta_weights, forward_weights[0],
                                           parameters, old_parameters)
        return adasum_parameters
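

# The helper below is a reference sketch only, not part of the public API (``__all__`` is
# unchanged): it restates with NumPy the combination rule that ``_send_recv_res`` applies
# to the exchanged halves, ignoring the AllReduce over the step group that the real code
# applies to the three inner products,
#
#     adasum(a, b) = (1 - <a, b> / (2 * ||a||^2)) * a + (1 - <a, b> / (2 * ||b||^2)) * b
#
# The function and argument names are illustrative; e.g.
# ``_adasum_combine_reference(numpy.ones(4), 2 * numpy.ones(4))``.
def _adasum_combine_reference(grad_a, grad_b, eps=1e-12):
    """Documentation-only NumPy restatement of the AdaSum combination rule."""
    # NumPy is imported locally so the module's import footprint stays unchanged.
    import numpy as np
    dot = float(np.sum(grad_a * grad_b)) + eps
    norm_a = float(np.sum(grad_a * grad_a)) + eps
    norm_b = float(np.sum(grad_b * grad_b)) + eps
    return (1 - dot / (2 * norm_a)) * grad_a + (1 - dot / (2 * norm_b)) * grad_b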