# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""adasum"""
import copy
import hashlib
import math
from mindspore.nn.cell import Cell
from mindspore.communication.management import create_group
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops.operations._inner_ops import Send, Receive
from mindspore.common.tensor import Tensor


__all__ = ["AdaSum"]


MAX_NUM_HASH = 2 ** 31


_update_parameters = C.MultitypeFuncGraph("update_parameters")


@_update_parameters.register("Tensor", "Tensor", "Tensor", "Tensor")
def _update_parameters_after_broadcast(delta_weight, update_delta_weight, parameter, old_parameter):
    """Apply the fused delta: parameter = old_parameter - update_delta_weight."""
    shape = F.shape(delta_weight)
    update_delta_weight = P.Reshape()(update_delta_weight, shape)
    new_parameter = old_parameter - update_delta_weight
    return P.Assign()(parameter, new_parameter)


def _send_before_receive(send_part, send, recv):
    """Send the local half first, then receive the peer's half."""
    send_ok = send(send_part)
    return recv(send_ok)


def _receive_before_send(send_part, send, recv):
    """Receive the peer's half first, then send the local half."""
    receive_ok = recv(send_part)
    send_part = F.depend(send_part, receive_ok)
    return F.depend(receive_ok, send(send_part))


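# For reference, _send_recv_res below applies the AdaSum combination rule to the two
# halves it is given (g_local and g_recv are shorthand for its local_part / recv_part
# arguments, not names used elsewhere in this file):
#     g = (1 - <g_local, g_recv> / (2 * ||g_local||^2)) * g_local
#       + (1 - <g_local, g_recv> / (2 * ||g_recv||^2)) * g_recv
# where the dot product and the two squared norms are first summed across the ranks
# of the step's AllReduce group.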
def _send_recv_res(left_send, recv_part, local_part, allreduce, parameter_divisibility, allreduce_node_num):
    """Combine the local half with the received half using the AdaSum rule."""
    if parameter_divisibility:
        recv_part = P.Squeeze()(recv_part)
        if F.shape(recv_part) is None:
            recv_part = Tensor([recv_part])
        local_part = F.depend(local_part, recv_part)
        # Dot product and squared norms; eps keeps the later divisions well defined.
        eps = 1e-12
        value_0 = P.ReduceSum()(local_part * recv_part) + eps
        if left_send:
            value_1 = P.ReduceSum()(local_part * local_part) + eps
            value_2 = P.ReduceSum()(recv_part * recv_part) + eps
        else:
            value_1 = P.ReduceSum()(recv_part * recv_part) + eps
            value_2 = P.ReduceSum()(local_part * local_part) + eps
        # The scalars are accumulated over the step's AllReduce group before weighting.
        value_0 = allreduce(value_0)
        value_1 = F.depend(allreduce(value_1), value_0)
        value_2 = F.depend(allreduce(value_2), value_1)
        if left_send:
            res = (1 - (value_0 / (2 * value_1))) * local_part + (1 - (value_0 / (2 * value_2))) * recv_part
        else:
            res = (1 - (value_0 / (2 * value_1))) * recv_part + (1 - (value_0 / (2 * value_2))) * local_part
    else:
        # The parameter was too small to split; fall back to an averaging AllReduce.
        res = allreduce(local_part)
        res /= allreduce_node_num
    return res


_adasum_opt_forward = C.MultitypeFuncGraph("adasum_opt_forward")


@_adasum_opt_forward.register("Bool", "Function", "Bool", "Int64", "Function", "Function", "Tensor")
def _adasum_opt_forward_process(left_send, allreduce, parameter_divisibility, allreduce_node_num, send, recv, delta_w):
    """adasum forward process: exchange one half of delta_w with the partner and fuse it."""
    if parameter_divisibility:
        delta_w = P.Squeeze()(delta_w)
        ori_len = F.shape(delta_w)[0]
        # Integer split point, consistent with the left/right shapes built in _get_delta_weights_info.
        divide_len = ori_len // 2
        left_part = delta_w[:divide_len]
        right_part = delta_w[divide_len:]
    else:
        left_part = delta_w
        right_part = delta_w

    if left_send:
        if parameter_divisibility:
            recv_part = _send_before_receive(left_part, send, recv)
        else:
            recv_part = right_part
        update_delta_w = _send_recv_res(left_send, recv_part, right_part, allreduce, parameter_divisibility,
                                        allreduce_node_num)
    else:
        if parameter_divisibility:
            recv_part = _receive_before_send(right_part, send, recv)
        else:
            recv_part = left_part
        update_delta_w = _send_recv_res(left_send, recv_part, left_part, allreduce, parameter_divisibility,
                                        allreduce_node_num)

    return update_delta_w


_adasum_opt_rollback = C.MultitypeFuncGraph("adasum_opt_rollback")


@_adasum_opt_rollback.register("Bool", "Bool", "Tensor", "Function", "Function")
def _adasum_opt_rollback_process(left_send, parameter_divisibility, delta_w, send, recv):
    """adasum optimizer rollback process."""
    if parameter_divisibility:
        if left_send:
            recv_part = _send_before_receive(delta_w, send, recv)
        else:
            recv_part = _receive_before_send(delta_w, send, recv)

        recv_part = P.Squeeze()(recv_part)
        if F.shape(recv_part) is None:
            recv_part = Tensor([recv_part])
        if F.shape(delta_w) is None:
            delta_w = Tensor([delta_w])
        recv_part = P.Reshape()(recv_part, (-1,))
        delta_w = P.Reshape()(delta_w, (-1,))

        if left_send:
            res = P.Concat()((recv_part, delta_w))
        else:
            res = P.Concat()((delta_w, recv_part))
    else:
        res = delta_w
    return res


class AdaSum(Cell):
    r"""
    The Adaptive Summation, or AdaSum, is an algorithm for improving distributed data
    parallel training of deep learning models.

    Args:
        rank (int): Global rank id of the current device.
        device_number (int): Number of devices in each server taking part in the fusion.
        group_number (int): Number of server groups; expected to be a power of two.
        parameter_tuple (Tuple(Parameter)): Tuple of the network parameters to be fused.

    Inputs:
        - **delta_weights** (Tuple(Tensor)) - Tuple of gradients.
        - **parameters** (Tuple(Parameter)) - Tuple of current parameters.
        - **old_parameters** (Tuple(Parameter)) - Tuple of last parameters.

    Outputs:
        - **adasum_parameters** (Tuple(Tensor)) - Tuple of parameters after adasum process.
    """
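    # A minimal usage sketch (hedged: `net`, the device/group counts and the
    # surrounding distributed setup are illustrative assumptions, not part of
    # this module):
    #
    #     from mindspore.communication.management import init, get_rank
    #     init()
    #     adasum = AdaSum(rank=get_rank(), device_number=8, group_number=2,
    #                     parameter_tuple=tuple(net.trainable_params()))
    #     adasum_parameters = adasum(delta_weights, parameters, old_parameters)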
    def __init__(self, rank, device_number, group_number, parameter_tuple):
        super(AdaSum, self).__init__()
        self.rank = rank
        self.device_number = device_number
        self.group_number = group_number
        self.parameter_tuple = parameter_tuple
        self._generate_communication_op()
        self.hyper_map = C.HyperMap()

    def _generate_communication_op(self):
        """Build the Send/Receive, AllReduce and Broadcast operators used by every fusion step."""
        self.calc_times = int(math.log(self.group_number, 2))
        self.send_node = []
        self.send_list_forward = []
        self.recv_list_forward = []
        self.send_list_rollback = []
        self.recv_list_rollback = []
        self.allreduce_list = []
        self.broadcast_list = []
        self.parameter_divisibility_list = []
        self.allreduce_node_num_list = []
        last_delta_weights = []
        group_start_rank = (self.rank // self.device_number) * self.device_number

        # One exchange step per level of the binary reduction tree across servers.
        for step in range(self.calc_times):
            current_group = self.device_number * (2 ** step)
            sr_target = self.rank
            if (sr_target // current_group) % 2 == 0:
                dest_target = sr_target + current_group
                self.send_node.append(True)
            else:
                dest_target = sr_target - current_group
                self.send_node.append(False)

            # Ranks that cooperate in this step's scalar AllReduce form their own group.
            neighbor_ids = []
            group_name_last = 0
            for index in range(2 ** (step + 1)):
                node_rank = self.rank // self.device_number
                double_d = 2 ** (step + 1)
                neighbor_id = (node_rank // double_d * double_d + index) * self.device_number + \
                              self.rank % self.device_number
                neighbor_ids.append(neighbor_id)
                group_name_last += neighbor_id
            group_name = "adasum_" + str(step) + "_" + str(group_name_last)
            create_group(group_name, neighbor_ids)

            # Build fused point-to-point ops for the left and right halves of every parameter.
            send_left = []
            send_right = []
            recv_left = []
            recv_right = []
            allreduce_node_num = ()
            left_delta_weights, right_delta_weights, delta_weights_divisibility = \
                self._get_delta_weights_info(last_delta_weights)
            self.parameter_divisibility_list.append(delta_weights_divisibility)
            weights_index = 0
            fusion_id = (step + 1) * 3
            for shape, dtype in left_delta_weights:
                send_tag = self._hash(step, sr_target, weights_index)
                send = Send(sr_tag=send_tag, dest_rank=dest_target, group="hccl_world_group")
                send.add_prim_attr("fusion", fusion_id)
                recv_tag = self._hash(step, dest_target, weights_index)
                recv = Receive(sr_tag=recv_tag, src_rank=dest_target, shape=shape, dtype=dtype,
                               group="hccl_world_group")
                recv.add_prim_attr("fusion", fusion_id)
                send_left.append(send)
                recv_left.append(recv)
                weights_index += 1
            for shape, dtype in right_delta_weights:
                send_tag = self._hash(step, sr_target, weights_index)
                send = Send(sr_tag=send_tag, dest_rank=dest_target, group="hccl_world_group")
                send.add_prim_attr("fusion", fusion_id + 1)
                recv_tag = self._hash(step, dest_target, weights_index)
                recv = Receive(sr_tag=recv_tag, src_rank=dest_target, shape=shape, dtype=dtype,
                               group="hccl_world_group")
                recv.add_prim_attr("fusion", fusion_id + 1)
                send_right.append(send)
                recv_right.append(recv)
                weights_index += 1

            # A "left" sender keeps working on the right half in later steps, and vice versa.
            if self.send_node and self.send_node[-1]:
                self.send_list_forward.append(send_left)
                self.send_list_rollback.append(send_right)
                self.recv_list_forward.append(recv_right)
                self.recv_list_rollback.append(recv_left)
                last_delta_weights = right_delta_weights
            else:
                self.send_list_forward.append(send_right)
                self.send_list_rollback.append(send_left)
                self.recv_list_forward.append(recv_left)
                self.recv_list_rollback.append(recv_right)
                last_delta_weights = left_delta_weights

            server_all_reduce = P.AllReduce("sum", group_name)
            server_all_reduce.add_prim_attr("fusion", fusion_id + 2)
            self.allreduce_list.append(server_all_reduce)

            for param_divisibility in delta_weights_divisibility:
                if param_divisibility:
                    allreduce_node_num += (0,)
                else:
                    allreduce_node_num += (2 ** (step + 1),)
            self.allreduce_node_num_list.append(allreduce_node_num)

        # Devices inside one server synchronize through a local broadcast group.
        broadcast_group = [x for x in range(group_start_rank, group_start_rank + self.device_number)]
        broadcast_group_name = "broadcast_group_" + str(group_start_rank)
        create_group(broadcast_group_name, broadcast_group)
        for b_rank in range(len(broadcast_group)):
            self.broadcast_list.append(P.Broadcast(b_rank, group=broadcast_group_name))
        self.sync_barrier = P.AllReduce("sum", group=broadcast_group_name)

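    # Worked example (illustrative numbers, not defaults): with device_number=8 and
    # group_number=4, calc_times = log2(4) = 2.  Rank 0 is a "left" sender in both
    # steps: it exchanges halves with rank 8 at step 0 (current_group = 8) and with
    # rank 16 at step 1 (current_group = 16); ranks 0-7 on the same server later
    # share results through "broadcast_group_0".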
    def _get_delta_weights_info(self, last_delta_weights):
        """Split each (shape, dtype) pair into left/right halves along its first axis larger than 1."""
        half_delta_weights = []
        if last_delta_weights:
            half_delta_weights = last_delta_weights
        else:
            for parameter in self.parameter_tuple:
                new_shape = [int(x) for x in parameter.shape]
                half_delta_weights.append((new_shape, parameter.dtype))
        left_delta_weights = []
        right_delta_weights = []
        delta_weights_divisibility = ()
        for shape, dtype in half_delta_weights:
            left_shape = copy.deepcopy(shape)
            right_shape = copy.deepcopy(shape)
            divisibility_flag = False
            for i in range(len(shape)):
                if shape[i] > 1:
                    left_shape[i] = int(shape[i] // 2)
                    right_shape[i] = shape[i] - int(shape[i] // 2)
                    divisibility_flag = True
                    break
            left_delta_weights.append((left_shape, dtype))
            right_delta_weights.append((right_shape, dtype))
            delta_weights_divisibility += (divisibility_flag,)
        return left_delta_weights, right_delta_weights, delta_weights_divisibility

    def _hash(self, step, target, weights_index):
        """Derive a deterministic point-to-point tag from (step, rank, weight index)."""
        target = "tag" + str(step) + str(target) + str(weights_index)
        target_hash = hashlib.sha1(target.encode()).hexdigest()
        hash_res = int(int(target_hash, 16) % MAX_NUM_HASH)
        return hash_res

    def construct(self, delta_weights, parameters, old_parameters):
        """Run the AdaSum forward steps, roll the halves back together, and update the parameters."""
        forward_weights = [delta_weights]
        for i in range(self.calc_times):
            process_weights = self.hyper_map(F.partial(_adasum_opt_forward, self.send_node[i], self.allreduce_list[i]),
                                             self.parameter_divisibility_list[i], self.allreduce_node_num_list[i],
                                             self.send_list_forward[i], self.recv_list_forward[i], forward_weights[-1])
            forward_weights.append(process_weights)
        for i in range(self.calc_times):
            j = self.calc_times - i - 1
            process_weights = self.hyper_map(F.partial(_adasum_opt_rollback, self.send_node[j]),
                                             self.parameter_divisibility_list[j], forward_weights[j + 1],
                                             self.send_list_rollback[j], self.recv_list_rollback[j])
            forward_weights[j] = process_weights
        adasum_parameters = self.hyper_map(F.partial(_update_parameters), delta_weights, forward_weights[0],
                                           parameters, old_parameters)
        return adasum_parameters