# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""adasum"""
from __future__ import absolute_import

import copy
import hashlib
import math
from mindspore.nn.cell import Cell
from mindspore.communication.management import create_group
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops import Send, Receive


__all__ = ["AdaSum"]


_update_parameters = C.MultitypeFuncGraph("update_parameters")


@_update_parameters.register("Tensor", "Tensor", "Tensor", "Tensor")
def _update_parameters_after_broadcast(delta_weight, update_delta_weight, parameter, old_parameter):
    """Reshape the AdaSum delta and apply it: parameter <- old_parameter - delta."""
    shape = F.shape(delta_weight)
    update_delta_weight = P.Reshape()(update_delta_weight, shape)
    new_parameter = old_parameter - update_delta_weight
    P.Assign()(parameter, new_parameter)
    return parameter


def _send_before_receive(send_part, send, recv):
    """Send the local half first, then receive the peer's half."""
    send_ok = send(send_part)
    return recv(send_ok)


def _receive_before_send(send_part, send, recv):
    """Receive the peer's half first, then send the local half."""
    receive_ok = recv(send_part)
    send_part = F.depend(send_part, receive_ok)
    return F.depend(receive_ok, send(send_part))

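# Ordering note: the two helpers above are deliberately asymmetric. Within a
# pair, the rank flagged as the sender posts Send before Receive while its
# peer does the opposite, so the point-to-point exchange cannot deadlock.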

def _send_recv_res(left_send, recv_part, local_part, allreduce, parameter_divisibility, allreduce_node_num):
    """Combine the local half with the half received from the paired rank."""
    if parameter_divisibility:
        recv_part = P.Squeeze()(recv_part)
        local_part = F.depend(local_part, recv_part)
        eps = 1e-12
        value_0 = P.ReduceSum()(local_part * recv_part) + eps
        if left_send:
            value_1 = P.ReduceSum()(local_part * local_part) + eps
            value_2 = P.ReduceSum()(recv_part * recv_part) + eps
        else:
            value_1 = P.ReduceSum()(recv_part * recv_part) + eps
            value_2 = P.ReduceSum()(local_part * local_part) + eps
        value_0 = allreduce(value_0)
        value_1 = F.depend(allreduce(value_1), value_0)
        value_2 = F.depend(allreduce(value_2), value_1)
        if left_send:
            res = (1 - (value_0 / (2 * value_1))) * local_part + (1 - (value_0 / (2 * value_2))) * recv_part
        else:
            res = (1 - (value_0 / (2 * value_1))) * recv_part + (1 - (value_0 / (2 * value_2))) * local_part
    else:
        res = allreduce(local_part)
        res /= allreduce_node_num
    return res

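# For divisible tensors, the branch above implements the AdaSum combination
# rule
#
#     res = (1 - <g1, g2> / (2 * ||g1||^2)) * g1
#         + (1 - <g1, g2> / (2 * ||g2||^2)) * g2
#
# where g1 and g2 are the local and received halves; sanity check: identical
# halves average back to g, while orthogonal halves simply add. The three
# AllReduce calls share the dot product and squared norms group-wide so every
# rank applies the same scaling. Indivisible tensors fall back to a plain
# mean over allreduce_node_num ranks.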

_adasum_opt_forward = C.MultitypeFuncGraph("adasum_opt_forward")


@_adasum_opt_forward.register("Bool", "Function", "Bool", "Int64", "Function", "Function", "Tensor")
def _adasum_opt_forward_process(left_send, allreduce, parameter_divisibility, allreduce_node_num, send, recv, delta_w):
    """AdaSum forward: swap one half of delta_w with the paired rank and combine it with the retained half."""
    if parameter_divisibility:
        delta_w = P.Squeeze()(delta_w)
        ori_len = F.shape(delta_w)[0]
        divide_len = ori_len // 2
        left_part = delta_w[:divide_len]
        right_part = delta_w[divide_len:]
    else:
        left_part = delta_w
        right_part = delta_w

    if left_send:
        if parameter_divisibility:
            recv_part = _send_before_receive(left_part, send, recv)
        else:
            recv_part = right_part
        update_delta_w = _send_recv_res(left_send, recv_part, right_part, allreduce, parameter_divisibility,
                                        allreduce_node_num)
    else:
        if parameter_divisibility:
            recv_part = _receive_before_send(right_part, send, recv)
        else:
            recv_part = left_part
        update_delta_w = _send_recv_res(left_send, recv_part, left_part, allreduce, parameter_divisibility,
                                        allreduce_node_num)

    return update_delta_w

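# Pairing sketch (illustrative, device_number = 1 for brevity): at step 0 the
# stride is 1, so ranks pair as (0,1), (2,3), ...; rank 0 sends its left half
# away and combines right halves, rank 1 the reverse. At step 1 the stride
# doubles to 2, pairing (0,2), (1,3), ..., so after log2(group_number) steps
# each rank holds a shard combined across all groups.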

_adasum_opt_rollback = C.MultitypeFuncGraph("adasum_opt_rollback")


@_adasum_opt_rollback.register("Bool", "Bool", "Tensor", "Function", "Function")
def _adasum_opt_rollback_process(left_send, parameter_divisibility, delta_w, send, recv):
    """adasum optimizer rollback process."""
    if parameter_divisibility:
        if left_send:
            recv_part = _send_before_receive(delta_w, send, recv)
        else:
            recv_part = _receive_before_send(delta_w, send, recv)

        recv_part = P.Squeeze()(recv_part)
        recv_part = P.Reshape()(recv_part, (-1,))
        delta_w = P.Reshape()(delta_w, (-1,))

        if left_send:
            res = P.Concat()((recv_part, delta_w))
        else:
            res = P.Concat()((delta_w, recv_part))
    else:
        res = delta_w
    return res

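# The rollback pass mirrors the forward split: each rank sends back the half
# it combined and receives the half its peer combined, then concatenates the
# two pieces in the original left/right order to rebuild the full-length
# delta for that step.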

class AdaSum(Cell):
    r"""
    The Adaptive Summation, or AdaSum, is a novel algorithm for improving
    distributed data-parallel training of deep learning models.

    Args:
        rank (int): Global rank id of this device.
        device_number (int): Number of devices in each group.
        group_number (int): Number of groups to communicate across.
        parameter_tuple (Tuple(Parameter)): Tuple of trainable parameters.

    Inputs:
        - **delta_weights** (Tuple(Tensor)) - Tuple of gradients.
        - **parameters** (Tuple(Parameter)) - Tuple of current parameters.
        - **old_parameters** (Tuple(Parameter)) - Tuple of parameters from the previous step.

    Outputs:
        - **adasum_parameters** (Tuple(Tensor)) - Tuple of parameters after the adasum process.
    """
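    # A minimal construction sketch (hypothetical values; this cell is normally
    # created by the framework's optimizer wrapper rather than by hand):
    #
    #   from mindspore.communication.management import get_rank, get_group_size
    #   device_number = 8                                  # devices per group, assumed
    #   group_number = get_group_size() // device_number   # assumed a power of two
    #   adasum = AdaSum(get_rank(), device_number, group_number,
    #                   tuple(net.trainable_params()))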
    def __init__(self, rank, device_number, group_number, parameter_tuple):
        super(AdaSum, self).__init__()
        self.rank = rank
        self.device_number = device_number
        self.group_number = group_number
        self.parameter_tuple = parameter_tuple
        self._generate_communication_op()
        self.hyper_map = C.HyperMap()

    @staticmethod
    def _hash(step, target, weights_index):
        """Derive a deterministic send/recv tag from (step, rank, weight index)."""
        target = "tag" + str(step) + str(target) + str(weights_index)
        target_hash = hashlib.sha1(target.encode()).hexdigest()
        max_num_hash = 2 ** 31
        hash_res = int(target_hash, 16) % max_num_hash
        return hash_res

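    # Example (illustrative): every rank evaluates the same expression, so
    # AdaSum._hash(0, 4, 1) yields one fixed tag cluster-wide; the Send on
    # rank 4 and the matching Receive on its peer therefore agree on sr_tag
    # without any extra handshake.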
    def construct(self, delta_weights, parameters, old_parameters):
        """Run the forward halving passes, roll the shards back into full deltas, and apply the update."""
        forward_weights = [delta_weights]
        for i in range(self.calc_times):
            process_weights = self.hyper_map(F.partial(_adasum_opt_forward, self.send_node[i], self.allreduce_list[i]),
                                             self.parameter_divisibility_list[i], self.allreduce_node_num_list[i],
                                             self.send_list_forward[i], self.recv_list_forward[i], forward_weights[-1])
            forward_weights.append(process_weights)
        for i in range(self.calc_times):
            j = self.calc_times - i - 1
            process_weights = self.hyper_map(F.partial(_adasum_opt_rollback, self.send_node[j]),
                                             self.parameter_divisibility_list[j], forward_weights[j + 1],
                                             self.send_list_rollback[j], self.recv_list_rollback[j])
            forward_weights[j] = process_weights
        adasum_parameters = self.hyper_map(F.partial(_update_parameters), delta_weights, forward_weights[0],
                                           parameters, old_parameters)
        return adasum_parameters

    def _generate_communication_op(self):
        """Build the per-step communication groups and send/recv/allreduce operators."""
        self.calc_times = int(math.log(self.group_number, 2))
        self.send_node = []
        self.send_list_forward = []
        self.recv_list_forward = []
        self.send_list_rollback = []
        self.recv_list_rollback = []
        self.allreduce_list = []
        self.broadcast_list = []
        self.parameter_divisibility_list = []
        self.allreduce_node_num_list = []
        last_delta_weights = []
        group_start_rank = (self.rank // self.device_number) * self.device_number
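        # Topology sketch (illustrative): with device_number = 8 and
        # group_number = 4, calc_times = 2. At step 0 each rank pairs with the
        # same device slot in the neighbouring group (stride 8); at step 1 the
        # stride doubles to 16, so the halved deltas spread across all groups
        # in a recursive-halving / recursive-doubling pattern.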

        for step in range(self.calc_times):
            current_group = self.device_number * (2 ** step)
            sr_target = self.rank
            if (sr_target // current_group) % 2 == 0:
                dest_target = sr_target + current_group
                self.send_node.append(True)
            else:
                dest_target = sr_target - current_group
                self.send_node.append(False)

            neighbor_ids = []
            group_name_last = 0
            for index in range(2 ** (step + 1)):
                node_rank = self.rank // self.device_number
                double_d = 2 ** (step + 1)
                neighbor_id = (node_rank // double_d * double_d + index) * self.device_number + \
                              self.rank % self.device_number
                neighbor_ids.append(neighbor_id)
                group_name_last += neighbor_id
            group_name = "adasum_{}_{}".format(str(step), str(group_name_last))
            create_group(group_name, neighbor_ids)

            send_left = []
            send_right = []
            recv_left = []
            recv_right = []
            allreduce_node_num = ()
            left_delta_weights, right_delta_weights, delta_weights_divisibility = \
                self._get_delta_weights_info(last_delta_weights)
            self.parameter_divisibility_list.append(delta_weights_divisibility)
            weights_index = 0
            fusion_id = (step + 1) * 3
            for shape, dtype in left_delta_weights:
                send_tag = AdaSum._hash(step, sr_target, weights_index)
                send = Send(sr_tag=send_tag, dest_rank=dest_target, group="hccl_world_group")
                send.add_prim_attr("fusion", fusion_id)
                recv_tag = AdaSum._hash(step, dest_target, weights_index)
                recv = Receive(sr_tag=recv_tag, src_rank=dest_target, shape=shape, dtype=dtype,
                               group="hccl_world_group")
                recv.add_prim_attr("fusion", fusion_id)
                send_left.append(send)
                recv_left.append(recv)
                weights_index += 1
            for shape, dtype in right_delta_weights:
                send_tag = AdaSum._hash(step, sr_target, weights_index)
                send = Send(sr_tag=send_tag, dest_rank=dest_target, group="hccl_world_group")
                send.add_prim_attr("fusion", fusion_id + 1)
                recv_tag = AdaSum._hash(step, dest_target, weights_index)
                recv = Receive(sr_tag=recv_tag, src_rank=dest_target, shape=shape, dtype=dtype,
                               group="hccl_world_group")
                recv.add_prim_attr("fusion", fusion_id + 1)
                send_right.append(send)
                recv_right.append(recv)
                weights_index += 1
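            # Tag symmetry: the sender hashes its own rank while the receiver
            # hashes the peer's rank, so both ends of each pair derive the
            # same sr_tag per weight without extra coordination.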

            if self.send_node and self.send_node[-1]:
                self.send_list_forward.append(send_left)
                self.send_list_rollback.append(send_right)
                self.recv_list_forward.append(recv_right)
                self.recv_list_rollback.append(recv_left)
                last_delta_weights = right_delta_weights
            else:
                self.send_list_forward.append(send_right)
                self.send_list_rollback.append(send_left)
                self.recv_list_forward.append(recv_left)
                self.recv_list_rollback.append(recv_right)
                last_delta_weights = left_delta_weights

            server_all_reduce = P.AllReduce("sum", group_name)
            server_all_reduce.add_prim_attr("fusion", fusion_id + 2)
            self.allreduce_list.append(server_all_reduce)

            for param_divisibility in delta_weights_divisibility:
                if param_divisibility:
                    allreduce_node_num += (0,)
                else:
                    allreduce_node_num += (2 ** (step + 1),)
            self.allreduce_node_num_list.append(allreduce_node_num)
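            # allreduce_node_num stays 0 for divisible tensors (their scaling
            # is handled inside _send_recv_res), while indivisible tensors are
            # mean-reduced over the 2 ** (step + 1) ranks of this step's group.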

        broadcast_group = list(range(group_start_rank, group_start_rank + self.device_number))
        broadcast_group_name = "broadcast_group_" + str(group_start_rank)
        create_group(broadcast_group_name, broadcast_group)
        for b_rank in range(len(broadcast_group)):
            self.broadcast_list.append(P.Broadcast(b_rank, group=broadcast_group_name))
        self.sync_barrier = P.AllReduce("sum", group=broadcast_group_name)

    def _get_delta_weights_info(self, last_delta_weights):
        """Compute left/right half shapes for each delta weight, splitting along the first axis larger than 1."""
        half_delta_weights = []
        if last_delta_weights:
            half_delta_weights = last_delta_weights
        else:
            for parameter in self.parameter_tuple:
                new_shape = [int(x) for x in parameter.shape]
                half_delta_weights.append((new_shape, parameter.dtype))
        left_delta_weights = []
        right_delta_weights = []
        delta_weights_divisibility = ()
        for shape, dtype in half_delta_weights:
            left_shape = copy.deepcopy(shape)
            right_shape = copy.deepcopy(shape)
            divisibility_flag = False
            for i, _ in enumerate(shape):
                if shape[i] > 1:
                    left_shape[i] = int(shape[i] // 2)
                    right_shape[i] = shape[i] - int(shape[i] // 2)
                    divisibility_flag = True
                    break
            left_delta_weights.append((left_shape, dtype))
            right_delta_weights.append((right_shape, dtype))
            delta_weights_divisibility += (divisibility_flag,)
        return left_delta_weights, right_delta_weights, delta_weights_divisibility
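# Shape-splitting example (illustrative): a delta weight of shape [5, 3] is
# split along its first axis into a left half [2, 3] and a right half [3, 3]
# with divisibility True, while a shape like [1] has no axis larger than 1,
# is marked False, and falls back to the grouped AllReduce averaging path.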