# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""adasum"""
import copy
import hashlib
import math

from mindspore.nn.cell import Cell
from mindspore.communication.management import create_group
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops.operations._inner_ops import Send, Receive
from mindspore.common.tensor import Tensor

__all__ = ["AdaSum"]

MAX_NUM_HASH = 2 ** 31

_update_parameters = C.MultitypeFuncGraph("update_parameters")


@_update_parameters.register("Tensor", "Tensor", "Tensor", "Tensor")
def _update_parameters_after_broadcast(delta_weight, update_delta_weight, parameter, old_parameter):
    shape = F.shape(delta_weight)
    update_delta_weight = P.Reshape()(update_delta_weight, shape)
    new_parameter = old_parameter - update_delta_weight
    return P.Assign()(parameter, new_parameter)


def _send_before_receive(send_part, send, recv):
    send_ok = send(send_part)
    return recv(send_ok)


def _receive_before_send(send_part, send, recv):
    receive_ok = recv(send_part)
    send_part = F.depend(send_part, receive_ok)
    return F.depend(receive_ok, send(send_part))


def _send_recv_res(left_send, recv_part, local_part, allreduce, parameter_divisibility, allreduce_node_num):
    """send result and receive result."""
    if parameter_divisibility:
        recv_part = P.Squeeze()(recv_part)
        if F.shape(recv_part) is None:
            recv_part = Tensor([recv_part])
        local_part = F.depend(local_part, recv_part)
        eps = 1e-12
        value_0 = P.ReduceSum()(local_part * recv_part) + eps
        if left_send:
            value_1 = P.ReduceSum()(local_part * local_part) + eps
            value_2 = P.ReduceSum()(recv_part * recv_part) + eps
        else:
            value_1 = P.ReduceSum()(recv_part * recv_part) + eps
            value_2 = P.ReduceSum()(local_part * local_part) + eps
        value_0 = allreduce(value_0)
        value_1 = F.depend(allreduce(value_1), value_0)
        value_2 = F.depend(allreduce(value_2), value_1)
        if left_send:
            res = (1 - (value_0 / (2 * value_1))) * local_part + (1 - (value_0 / (2 * value_2))) * recv_part
        else:
            res = (1 - (value_0 / (2 * value_1))) * recv_part + (1 - (value_0 / (2 * value_2))) * local_part
    else:
        res = allreduce(local_part)
        res /= allreduce_node_num
    return res


_adasum_opt_forward = C.MultitypeFuncGraph("adasum_opt_forward")


@_adasum_opt_forward.register("Bool", "Function", "Bool", "Int64", "Function", "Function", "Tensor")
def _adasum_opt_forward_process(left_send, allreduce, parameter_divisibility, allreduce_node_num, send, recv, delta_w):
    """adasum optimizer process."""
    if parameter_divisibility:
        delta_w = P.Squeeze()(delta_w)
        ori_len = F.shape(delta_w)[0]
        divide_len = ori_len / 2
        left_part = delta_w[:divide_len]
        right_part = delta_w[divide_len:]
    else:
        left_part = delta_w
        right_part = delta_w

    if left_send:
        if parameter_divisibility:
            recv_part = _send_before_receive(left_part, send, recv)
        else:
            recv_part = right_part
        update_delta_w = _send_recv_res(left_send, recv_part, right_part, allreduce,
                                        parameter_divisibility, allreduce_node_num)
    else:
        if parameter_divisibility:
            recv_part = _receive_before_send(right_part, send, recv)
        else:
            recv_part = left_part
        update_delta_w = _send_recv_res(left_send, recv_part, left_part, allreduce,
                                        parameter_divisibility, allreduce_node_num)
    return update_delta_w


_adasum_opt_rollback = C.MultitypeFuncGraph("adasum_opt_rollback")


@_adasum_opt_rollback.register("Bool", "Bool", "Tensor", "Function", "Function")
def _adasum_opt_rollback_process(left_send, parameter_divisibility, delta_w, send, recv):
    """adasum optimizer rollback process."""
    if parameter_divisibility:
        if left_send:
            recv_part = _send_before_receive(delta_w, send, recv)
        else:
            recv_part = _receive_before_send(delta_w, send, recv)

        recv_part = P.Squeeze()(recv_part)
        if F.shape(recv_part) is None:
            recv_part = Tensor([recv_part])
        if F.shape(delta_w) is None:
            delta_w = Tensor([delta_w])
        recv_part = P.Reshape()(recv_part, (-1,))
        delta_w = P.Reshape()(delta_w, (-1,))

        if left_send:
            res = P.Concat()((recv_part, delta_w))
        else:
            res = P.Concat()((delta_w, recv_part))
    else:
        res = delta_w
    return res


class AdaSum(Cell):
    r"""
    The Adaptive Summation, or AdaSum, is a novel algorithm for improving distributed data parallel
    training of Deep Learning models.

    Args:
        rank (int): Global rank id of the current device.
        device_number (int): Number of devices in each server group.
        group_number (int): Number of server groups taking part in the adasum exchange; the number
            of exchange steps is log2 of this value.
        parameter_tuple (Tuple(Parameter)): Tuple of the network parameters whose delta weights
            are processed.

    Inputs:
        - **delta_weights** (Tuple(Tensor)) - Tuple of gradients.
        - **parameters** (Tuple(Parameter)) - Tuple of current parameters.
        - **old_parameters** (Tuple(Parameter)) - Tuple of last parameters.

    Outputs:
        - **adasum_parameters** (Tuple(Tensor)) - Tuple of parameters after the adasum process.
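
    Note:
        The example below is only a construction sketch; ``net``, the ``init()`` call and the
        device/group counts are illustrative assumptions rather than values required by this class.

    Examples:
        >>> from mindspore import ParameterTuple
        >>> from mindspore.communication.management import init, get_rank
        >>> init()                                            # assumes a launched distributed job
        >>> params = ParameterTuple(net.trainable_params())   # ``net`` is an assumed training Cell
        >>> adasum = AdaSum(rank=get_rank(), device_number=8, group_number=2,
        ...                 parameter_tuple=params)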
""" def __init__(self, rank, device_number, group_number, parameter_tuple): super(AdaSum, self).__init__() self.rank = rank self.device_number = device_number self.group_number = group_number self.parameter_tuple = parameter_tuple self._generate_communication_op() self.hyper_map = C.HyperMap() def _generate_communication_op(self): """generate communication op.""" self.calc_times = int(math.log(self.group_number, 2)) self.send_node = [] self.send_list_forward = [] self.recv_list_forward = [] self.send_list_rollback = [] self.recv_list_rollback = [] self.allreduce_list = [] self.broadcast_list = [] self.parameter_divisibility_list = [] self.allreduce_node_num_list = [] last_delta_weights = [] group_start_rank = (self.rank // self.device_number) * self.device_number for step in range(self.calc_times): current_group = self.device_number * (2 ** step) sr_target = self.rank if (sr_target // current_group) % 2 == 0: dest_target = sr_target + current_group self.send_node.append(True) else: dest_target = sr_target - current_group self.send_node.append(False) neighbor_ids = [] group_name_last = 0 for index in range(2 ** (step + 1)): node_rank = self.rank // self.device_number double_d = 2 ** (step + 1) neighbor_id = (node_rank // double_d * double_d + index) * self.device_number + \ self.rank % self.device_number neighbor_ids.append(neighbor_id) group_name_last += neighbor_id group_name = "adasum_" + str(step) + "_" + str(group_name_last) create_group(group_name, neighbor_ids) send_left = [] send_right = [] recv_left = [] recv_right = [] allreduce_node_num = () left_delta_weights, right_delta_weights, delta_weights_divisibility = \ self._get_delta_weights_info(last_delta_weights) self.parameter_divisibility_list.append(delta_weights_divisibility) weights_index = 0 fusion_id = (step + 1) * 3 for shape, dtype in left_delta_weights: send_tag = self._hash(step, sr_target, weights_index) send = Send(sr_tag=send_tag, dest_rank=dest_target, group="hccl_world_group") send.add_prim_attr("fusion", fusion_id) recv_tag = self._hash(step, dest_target, weights_index) recv = Receive(sr_tag=recv_tag, src_rank=dest_target, shape=shape, dtype=dtype, group="hccl_world_group") recv.add_prim_attr("fusion", fusion_id) send_left.append(send) recv_left.append(recv) weights_index += 1 for shape, dtype in right_delta_weights: send_tag = self._hash(step, sr_target, weights_index) send = Send(sr_tag=send_tag, dest_rank=dest_target, group="hccl_world_group") send.add_prim_attr("fusion", fusion_id + 1) recv_tag = self._hash(step, dest_target, weights_index) recv = Receive(sr_tag=recv_tag, src_rank=dest_target, shape=shape, dtype=dtype, group="hccl_world_group") recv.add_prim_attr("fusion", fusion_id + 1) send_right.append(send) recv_right.append(recv) weights_index += 1 if self.send_node and self.send_node[-1]: self.send_list_forward.append(send_left) self.send_list_rollback.append(send_right) self.recv_list_forward.append(recv_right) self.recv_list_rollback.append(recv_left) last_delta_weights = right_delta_weights else: self.send_list_forward.append(send_right) self.send_list_rollback.append(send_left) self.recv_list_forward.append(recv_left) self.recv_list_rollback.append(recv_right) last_delta_weights = left_delta_weights server_all_reduce = P.AllReduce("sum", group_name) server_all_reduce.add_prim_attr("fusion", fusion_id + 2) self.allreduce_list.append(server_all_reduce) for param_divisibility in delta_weights_divisibility: if param_divisibility: allreduce_node_num += (0,) else: allreduce_node_num += (2 ** (step + 1),) 
            self.allreduce_node_num_list.append(allreduce_node_num)

        broadcast_group = [x for x in range(group_start_rank, group_start_rank + self.device_number)]
        broadcast_group_name = "broadcast_group_" + str(group_start_rank)
        create_group(broadcast_group_name, broadcast_group)
        for b_rank in range(len(broadcast_group)):
            self.broadcast_list.append(P.Broadcast(b_rank, group=broadcast_group_name))
        self.sync_barrier = P.AllReduce("sum", group=broadcast_group_name)

    def _get_delta_weights_info(self, last_delta_weights):
        """get delta weights info."""
        half_delta_weights = []
        if last_delta_weights:
            half_delta_weights = last_delta_weights
        else:
            for parameter in self.parameter_tuple:
                new_shape = [int(x) for x in parameter.shape]
                half_delta_weights.append((new_shape, parameter.dtype))

        left_delta_weights = []
        right_delta_weights = []
        delta_weights_divisibility = ()
        for shape, dtype in half_delta_weights:
            left_shape = copy.deepcopy(shape)
            right_shape = copy.deepcopy(shape)
            divisibility_flag = False
            for i in range(len(shape)):
                if shape[i] > 1:
                    left_shape[i] = int(shape[i] // 2)
                    right_shape[i] = shape[i] - int(shape[i] // 2)
                    divisibility_flag = True
                    break
            left_delta_weights.append((left_shape, dtype))
            right_delta_weights.append((right_shape, dtype))
            delta_weights_divisibility += (divisibility_flag,)
        return left_delta_weights, right_delta_weights, delta_weights_divisibility

    def _hash(self, step, target, weights_index):
        target = "tag" + str(step) + str(target) + str(weights_index)
        target_hash = hashlib.sha1(target.encode()).hexdigest()
        hash_res = int(int(target_hash, 16) % MAX_NUM_HASH)
        return hash_res

    def construct(self, delta_weights, parameters, old_parameters):
        forward_weights = [delta_weights]
        for i in range(self.calc_times):
            process_weights = self.hyper_map(F.partial(_adasum_opt_forward, self.send_node[i],
                                                       self.allreduce_list[i]),
                                             self.parameter_divisibility_list[i],
                                             self.allreduce_node_num_list[i],
                                             self.send_list_forward[i], self.recv_list_forward[i],
                                             forward_weights[-1])
            forward_weights.append(process_weights)
        for i in range(self.calc_times):
            j = self.calc_times - i - 1
            process_weights = self.hyper_map(F.partial(_adasum_opt_rollback, self.send_node[j]),
                                             self.parameter_divisibility_list[j],
                                             forward_weights[j + 1], self.send_list_rollback[j],
                                             self.recv_list_rollback[j])
            forward_weights[j] = process_weights
        adasum_parameters = self.hyper_map(F.partial(_update_parameters), delta_weights, forward_weights[0],
                                           parameters, old_parameters)
        return adasum_parameters