# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""adasum"""
from __future__ import absolute_import

import copy
import hashlib
import math

import mindspore.nn as nn
import mindspore.log as logger
from mindspore import context
from mindspore import _checkparam as validator
from mindspore.nn.cell import Cell
from mindspore.common.parameter import ParameterTuple, Parameter
from mindspore.parallel._utils import _get_global_rank, _get_stage_device_num
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops import Send, Receive
from mindspore.common.tensor import Tensor
from mindspore.common import dtype as mstype
from mindspore.communication.management import create_group

__all__ = ["AdaSumByDeltaWeightWrapCell", "AdaSumByGradWrapCell"]

MAX_NUM_HASH = 2 ** 31

_update_parameters = C.MultitypeFuncGraph("update_parameters")
_reshape_grads = C.MultitypeFuncGraph("reshape_grads")


@_update_parameters.register("Tensor", "Tensor", "Tensor", "Tensor", "Function")
def _update_parameters_adasum(delta_weight, update_delta_weight, parameter, old_parameter, reshape):
    """Apply the AdaSum result: parameter = old_parameter - reshaped update_delta_weight."""
    shape = F.shape(delta_weight)
    update_delta_weight = reshape(update_delta_weight, shape)
    new_parameter = old_parameter - update_delta_weight
    P.Assign()(parameter, new_parameter)
    return parameter


@_reshape_grads.register("Tensor", "Tensor", "Function")
def reshape_grads_adasum(grads, update_grads, reshape):
    """
    Reshape the processed gradient back to the original gradient shape.
    """
    shape = F.shape(grads)
    update_grads = reshape(update_grads, shape)
    return update_grads


def _send_before_receive(send_part, send, recv):
    """Send the local half first, then receive the partner's half."""
    send_ok = send(send_part)
    return recv(send_ok)


def _receive_before_send(send_part, send, recv):
    """Receive the partner's half first, then send the local half."""
    receive_ok = recv(send_part)
    send_part = F.depend(send_part, receive_ok)
    return F.depend(receive_ok, send(send_part))


def _send_recv_res(left_send, recv_part, local_part, allreduce, parameter_divisibility, allreduce_node_num):
    """Combine the local half with the received half by the AdaSum rule, or average via AllReduce."""
    if parameter_divisibility:
        recv_part = P.Squeeze()(recv_part)
        if F.shape(recv_part) is None:
            recv_part = Tensor([recv_part])
        local_part = F.depend(local_part, recv_part)
        eps = 1e-12
        scale_value = P.ReduceMax()(local_part) + eps
        local_part_scale = local_part / scale_value
        recv_part_scale = recv_part / scale_value
        recv_part_scale = F.depend(recv_part_scale, local_part_scale)
        value_0 = P.ReduceSum()(local_part_scale * recv_part_scale) + eps
        if left_send:
            value_1 = P.ReduceSum()(local_part_scale * local_part_scale) + eps
            value_2 = P.ReduceSum()(recv_part_scale * recv_part_scale) + eps
        else:
            value_1 = P.ReduceSum()(recv_part_scale * recv_part_scale) + eps
            value_2 = P.ReduceSum()(local_part_scale * local_part_scale) + eps
        value_0 = allreduce(value_0)
        value_1 = F.depend(allreduce(value_1), value_0)
        value_2 = F.depend(allreduce(value_2), value_1)
        if left_send:
            res = (1 - (value_0 / (2 * value_1))) * local_part + (1 - (value_0 / (2 * value_2))) * recv_part
        else:
            res = (1 - (value_0 / (2 * value_1))) * recv_part + (1 - (value_0 / (2 * value_2))) * local_part
    else:
        res = allreduce(local_part)
        res = res / allreduce_node_num
    return res

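# A minimal reference sketch (illustrative only, not called anywhere in this module) of the
# pairwise AdaSum rule that _send_recv_res realizes with Squeeze/ReduceSum/AllReduce primitives.
# The helper name and the plain-NumPy setting are assumptions made for readability.
def _adasum_pair_sketch(grad_1, grad_2):
    """Combine two flattened gradient vectors with the AdaSum rule (reference only)."""
    import numpy as np  # local import: only this illustrative sketch needs it
    dot = float(np.dot(grad_1, grad_2))
    scale_1 = 1.0 - dot / (2.0 * float(np.dot(grad_1, grad_1)))
    scale_2 = 1.0 - dot / (2.0 * float(np.dot(grad_2, grad_2)))
    return scale_1 * grad_1 + scale_2 * grad_2
# The eps terms and the shared ReduceMax scaling in _send_recv_res serve the same formula while
# keeping the reductions numerically stable; the AllReduce calls aggregate the dot products
# across the whole communication group.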

_adasum_opt_forward = C.MultitypeFuncGraph("adasum_opt_forward")
_adasum_opt_rollback = C.MultitypeFuncGraph("adasum_opt_rollback")


@_adasum_opt_forward.register("Bool", "Function", "Bool", "Int64", "Function", "Function", "Tensor")
def _adasum_opt_forward_process(left_send, allreduce, parameter_divisibility, allreduce_node_num, send, recv, delta_w):
    """AdaSum forward process: split the delta weight, exchange one half and combine the kept half."""
    if parameter_divisibility:
        delta_w = P.Squeeze()(delta_w)
        ori_len = F.shape(delta_w)[0]
        divide_len = ori_len // 2
        left_part = delta_w[:divide_len]
        right_part = delta_w[divide_len:]
    else:
        left_part = delta_w
        right_part = delta_w

    if left_send:
        if parameter_divisibility:
            recv_part = _send_before_receive(left_part, send, recv)
        else:
            recv_part = right_part
        update_delta_w = _send_recv_res(left_send, recv_part, right_part, allreduce, parameter_divisibility,
                                        allreduce_node_num)
    else:
        if parameter_divisibility:
            recv_part = _receive_before_send(right_part, send, recv)
        else:
            recv_part = left_part
        update_delta_w = _send_recv_res(left_send, recv_part, left_part, allreduce, parameter_divisibility,
                                        allreduce_node_num)
    return update_delta_w


@_adasum_opt_rollback.register("Bool", "Bool", "Tensor", "Function", "Function")
def _adasum_opt_rollback_process(left_send, parameter_divisibility, delta_w, send, recv):
    """AdaSum rollback process: exchange the processed halves and concatenate them back."""
    if parameter_divisibility:
        if left_send:
            recv_part = _send_before_receive(delta_w, send, recv)
        else:
            recv_part = _receive_before_send(delta_w, send, recv)

        recv_part = P.Squeeze()(recv_part)
        if F.shape(recv_part) is None:
            recv_part = Tensor([recv_part])
        if F.shape(delta_w) is None:
            delta_w = Tensor([delta_w])
        recv_part = P.Reshape()(recv_part, (-1,))
        delta_w = P.Reshape()(delta_w, (-1,))

        if left_send:
            res = P.Concat()((recv_part, delta_w))
        else:
            res = P.Concat()((delta_w, recv_part))
    else:
        res = delta_w
    return res


class _AdaSum(Cell):
    r"""
    The Adaptive Summation, or AdaSum, is an algorithm for improving distributed data
    parallel training of deep learning models.

    Inputs:
        - **delta_weights** (Tuple(Tensor)) - Tuple of gradients.
        - **parameters** (Tuple(Parameter)) - Tuple of current parameters.
        - **old_parameters** (Tuple(Parameter)) - Tuple of the parameters before the update.

    Outputs:
        - **adasum_parameters** (Tuple(Tensor)) - Tuple of parameters after the AdaSum process.
    """
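    # A simplified illustration of the recursion (assuming device_number=8 and group_number=4,
    # i.e. 32 devices): calc_times = log2(4) = 2 forward steps. At step 0 each rank pairs with
    # the matching rank one node group away (rank +/- 8); each side keeps one half of every
    # delta weight and exchanges the other half. At step 1 the partner is two node groups away
    # (rank +/- 16) and the surviving halves are split again. The rollback loop then walks the
    # steps in reverse and concatenates the halves so that construct() can restore the original
    # shapes.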
    def __init__(self, rank, device_number, group_number, parameter_tuple):
        super(_AdaSum, self).__init__()
        self.rank = rank
        self.device_number = device_number
        self.group_number = group_number
        self.parameter_tuple = parameter_tuple
        # Number of forward (and rollback) steps: one per doubling of the node groups.
        self.calc_times = int(math.log(self.group_number, 2))
        self.send_node = []
        self.send_list_forward = []
        self.recv_list_forward = []
        self.send_list_rollback = []
        self.recv_list_rollback = []
        self.allreduce_list = []
        self.parameter_divisibility_list = []
        self.allreduce_node_num_list = []
        self._generate_communication_op()
        self.hyper_map = C.HyperMap()
        self.update_reshape_list = []
        for parameter in self.parameter_tuple:
            reshape = P.Reshape().add_prim_attr("target_param", "adasum_delta_weight." + parameter.name)
            self.update_reshape_list.append(reshape)

    @staticmethod
    def _hash(step, target, weights_index):
        target = "tag" + str(step) + str(target) + str(weights_index)
        target_hash = hashlib.sha1(target.encode()).hexdigest()
        hash_res = int(int(target_hash, 16) % MAX_NUM_HASH)
        return hash_res

    def construct(self, delta_weights, parameters, old_parameters):
        forward_weights = [delta_weights]
        # Forward pass: at each step, split the delta weights and combine them with the partner's.
        for i in range(self.calc_times):
            process_weights = self.hyper_map(F.partial(_adasum_opt_forward, self.send_node[i]), self.allreduce_list[i],
                                             self.parameter_divisibility_list[i], self.allreduce_node_num_list[i],
                                             self.send_list_forward[i], self.recv_list_forward[i], forward_weights[-1])
            forward_weights.append(process_weights)
        # Rollback pass: walk the steps in reverse and concatenate the halves back together.
        for i in range(self.calc_times):
            j = self.calc_times - i - 1
            process_weights = self.hyper_map(F.partial(_adasum_opt_rollback, self.send_node[j]),
                                             self.parameter_divisibility_list[j], forward_weights[j + 1],
                                             self.send_list_rollback[j], self.recv_list_rollback[j])
            forward_weights[j] = process_weights
        adasum_parameters = self.hyper_map(F.partial(_update_parameters), delta_weights, forward_weights[0],
                                           parameters, old_parameters, self.update_reshape_list)
        return adasum_parameters

    def _generate_communication_op(self):
        """Generate the Send/Receive/AllReduce operators used at each AdaSum step."""
        last_delta_weights = []
        fusion_attr = "origin_fusion"
        if context.get_auto_parallel_context("parallel_mode") in ["data_parallel", "hybrid_parallel"]:
            fusion_attr = "fusion"
        for step in range(self.calc_times):
            current_group = self.device_number * (2 ** step)
            # Pair this rank with the matching rank current_group devices away;
            # even-side ranks send their left half first.
            if (self.rank // current_group) % 2 == 0:
                dest_target = self.rank + current_group
                self.send_node.append(True)
            else:
                dest_target = self.rank - current_group
                self.send_node.append(False)
            send_left = []
            send_right = []
            recv_left = []
            recv_right = []
            allreduce_node_num = ()
            left_delta_weights, right_delta_weights, delta_weights_divisibility = \
                self._get_delta_weights_info(last_delta_weights)
            self.parameter_divisibility_list.append(delta_weights_divisibility)
            weights_index = 0
            fusion_id = (step + 1) * 3
            for shape, dtype, name in left_delta_weights:
                send_tag = self._hash(step, self.rank, weights_index)
                send = Send(sr_tag=send_tag, dest_rank=dest_target, group="hccl_world_group")
                send.add_prim_attr(fusion_attr, fusion_id)
                send.add_prim_attr("opposite_rank", dest_target)
                send.add_prim_attr("target_param", name)
                recv_tag = self._hash(step, dest_target, weights_index)
                recv = Receive(sr_tag=recv_tag, src_rank=dest_target, shape=shape, dtype=dtype,
                               group="hccl_world_group")
                recv.add_prim_attr(fusion_attr, fusion_id)
                recv.add_prim_attr("opposite_rank", dest_target)
                recv.add_prim_attr("target_param", name)
                send_left.append(send)
                recv_left.append(recv)
                weights_index += 1
            for shape, dtype, name in right_delta_weights:
                send_tag = self._hash(step, self.rank, weights_index)
                send = Send(sr_tag=send_tag, dest_rank=dest_target, group="hccl_world_group")
                send.add_prim_attr(fusion_attr, fusion_id + 1)
                send.add_prim_attr("opposite_rank", dest_target)
                send.add_prim_attr("target_param", name)
                recv_tag = self._hash(step, dest_target, weights_index)
                recv = Receive(sr_tag=recv_tag, src_rank=dest_target, shape=shape, dtype=dtype,
                               group="hccl_world_group")
                recv.add_prim_attr(fusion_attr, fusion_id + 1)
                recv.add_prim_attr("opposite_rank", dest_target)
                recv.add_prim_attr("target_param", name)
                send_right.append(send)
                recv_right.append(recv)
                weights_index += 1
            if self.send_node and self.send_node[-1]:
                self.send_list_forward.append(send_left)
                self.send_list_rollback.append(send_right)
                self.recv_list_forward.append(recv_right)
                self.recv_list_rollback.append(recv_left)
                last_delta_weights = right_delta_weights
            else:
                self.send_list_forward.append(send_right)
                self.send_list_rollback.append(send_left)
                self.recv_list_forward.append(recv_left)
                self.recv_list_rollback.append(recv_right)
                last_delta_weights = left_delta_weights
            param_allreduce_list = []
            neighbor_ids = []
            rank_ids = []
            for index in range(2 ** (step + 1)):
                node_rank = self.rank // self.device_number
                double_d = 2 ** (step + 1)
                neighbor_id = (node_rank // double_d * double_d + index) * self.device_number + \
                              self.rank % self.device_number
                neighbor_ids.append(str(neighbor_id))
                rank_ids.append(neighbor_id)
            group_name = "-".join(neighbor_ids)
            if context.get_auto_parallel_context("parallel_mode") in ["data_parallel", "hybrid_parallel"]:
                create_group(group_name, rank_ids)
            for parameter in self.parameter_tuple:
                allreduce = P.AllReduce("sum", group_name)
                allreduce.add_prim_attr("target_param", "adasum_delta_weight." + parameter.name)
                allreduce.add_prim_attr(fusion_attr, fusion_id + 2)
                allreduce.add_prim_attr("step", step)
                param_allreduce_list.append(allreduce)
            self.allreduce_list.append(param_allreduce_list)
            for param_divisibility in delta_weights_divisibility:
                if param_divisibility:
                    allreduce_node_num += (0,)
                else:
                    allreduce_node_num += (2 ** (step + 1),)
            self.allreduce_node_num_list.append(allreduce_node_num)

    def _get_delta_weights_info(self, last_delta_weights):
        """Get the (shape, dtype, name) info of the split delta weights for the current step."""
        half_delta_weights = []
        if last_delta_weights:
            half_delta_weights = last_delta_weights
        else:
            for parameter in self.parameter_tuple:
                new_shape = [int(x) for x in parameter.shape]
                half_delta_weights.append((new_shape, parameter.dtype, "adasum_delta_weight." + parameter.name))
        left_delta_weights = []
        right_delta_weights = []
        delta_weights_divisibility = ()
        for shape, dtype, name in half_delta_weights:
            left_shape = copy.deepcopy(shape)
            right_shape = copy.deepcopy(shape)
            divisibility_flag = False
            for i, value in enumerate(shape):
                if value > 1:
                    left_shape[i] = int(value // 2)
                    right_shape[i] = value - int(value // 2)
                    divisibility_flag = True
                    break
            left_delta_weights.append((left_shape, dtype, name))
            right_delta_weights.append((right_shape, dtype, name))
            delta_weights_divisibility += (divisibility_flag,)
        return left_delta_weights, right_delta_weights, delta_weights_divisibility
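    # Worked example (illustrative): a delta weight of shape [5, 3] splits on its first dimension
    # larger than 1 into a left half of shape [2, 3] and a right half of shape [3, 3] with
    # divisibility_flag True, while a shape such as [1, 1] cannot be split and is instead averaged
    # over 2 ** (step + 1) devices via AllReduce in _send_recv_res.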


class _AdaSumByGrad(_AdaSum):
    """Apply AdaSum directly to gradients."""
    def construct(self, grads):
        forward_grads = [grads]
        for i in range(self.calc_times):
            process_weights = self.hyper_map(F.partial(_adasum_opt_forward, self.send_node[i]), self.allreduce_list[i],
                                             self.parameter_divisibility_list[i], self.allreduce_node_num_list[i],
                                             self.send_list_forward[i], self.recv_list_forward[i], forward_grads[-1])
            forward_grads.append(process_weights)
        for i in range(self.calc_times):
            j = self.calc_times - i - 1
            process_weights = self.hyper_map(F.partial(_adasum_opt_rollback, self.send_node[j]),
                                             self.parameter_divisibility_list[j], forward_grads[j + 1],
                                             self.send_list_rollback[j], self.recv_list_rollback[j])
            forward_grads[j] = process_weights
        update_grads = self.hyper_map(F.partial(_reshape_grads), grads, forward_grads[0],
                                      self.update_reshape_list)
        return update_grads


_get_delta_weight = C.MultitypeFuncGraph("_get_delta_weight")
_save_weight = C.MultitypeFuncGraph("_save_weight")
scale_mul = P.Mul().add_prim_attr("keep_alive", True)
_clone_weight = C.MultitypeFuncGraph("_clone_weight")


@_get_delta_weight.register("Tensor", "Tensor")
def _get_delta_weight_process(new_parameter, old_parameter):
    delta_w = old_parameter - new_parameter
    return delta_w


@_save_weight.register("Tensor", "Tensor")
def _save_weight_process(new_parameter, old_parameter):
    P.Assign()(new_parameter, old_parameter)
    return new_parameter


@_clone_weight.register("Tensor", "Tensor")
def _clone_weight_process(scale, weight):
    return scale_mul(weight, scale)


def _parallel_check():
    """Check that the parallel configuration supports AdaSum."""
    if context.get_auto_parallel_context("parallel_mode") == "stand_alone":
        raise RuntimeError("AdaSum is not supported in stand_alone parallel mode.")
    if context.get_auto_parallel_context("parallel_mode") in ["data_parallel", "hybrid_parallel"]:
        logger.warning("For data parallel mode or hybrid parallel mode, "
                       "it is recommended to use mindspore.boost to enable AdaSum.")
    if context.get_auto_parallel_context("enable_parallel_optimizer"):
        raise RuntimeError("Currently, optimizer sharding is not supported when applying AdaSum.")
    if context.get_auto_parallel_context("pipeline_stages") > 1:
        raise RuntimeError("Currently, pipeline parallelism is not supported when applying AdaSum.")
    stage_device_num = _get_stage_device_num()
    if stage_device_num < 16 or (stage_device_num & (stage_device_num - 1) != 0):
        raise RuntimeError("The device_num must be at least 16 and must be a power of 2 when applying AdaSum.")


class AdaSumByGradWrapCell(Cell):
    r"""
    Enable AdaSum in "auto_parallel/semi_auto_parallel" mode.
    This implementation of the Adaptive Summation (AdaSum) algorithm operates on the gradients.
    See the paper `AdaSum: Scaling Distributed Training with Adaptive Summation <https://arxiv.org/abs/2006.02924>`_.

    .. math::
        \begin{array}{ll}
          w_{t+1}=w_{t} - \alpha \cdot Adasum(g_{1}, g_{2})  \\
          w_{t+1}=w_{t} - \alpha \cdot [(1 - \frac{g_2^{T}\cdot g_1}{2\cdot \left \| g_1 \right \|^2 })\cdot g_1 + (1 -
          \frac{g_1^{T}\cdot g_2}{2\cdot \left \| g_2 \right \|^2 })\cdot g_2]  \\
        \end{array}

    In this implementation, :math:`g` represents the gradient of the weights,
    and the subscripts represent different devices in the data-parallel dimension.

    Note:
        When using AdaSum, the number of training cards must be a power of 2 and at least 16 cards are required.
        Currently, optimizer sharding and pipeline parallelism are not supported when using AdaSum.
        It is recommended to use AdaSumByGradWrapCell in semi_auto_parallel/auto_parallel mode. In data parallel
        mode, we recommend using mindspore.boost to apply AdaSum.

    Args:
        optimizer (Union[Cell]): Optimizer for updating the weights. The construct function of the optimizer
            requires only one input.

    Inputs:
        - **grads** (Tuple(Tensor)) - Tuple of gradients, same as the input of the passed optimizer.

    Raises:
        RuntimeError: If `parallel_mode` is `stand_alone`; AdaSum only supports distributed scenarios.
        RuntimeError: If optimizer parallelism is used when using AdaSum.
        RuntimeError: If pipeline parallelism is used when using AdaSum.
        RuntimeError: If `device_num` is not a power of 2, or is less than 16.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import mindspore as ms
        >>> from mindspore import nn
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> optim = nn.AdaSumByGradWrapCell(nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9))
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> model = ms.train.Model(net, loss_fn=loss, optimizer=optim, metrics=None)
    """
    def __init__(self, optimizer):
        super(AdaSumByGradWrapCell, self).__init__(auto_prefix=False)
        _device_number = 8
        _parallel_check()
        self.optimizer = optimizer
        validator.check_value_type('optimizer', optimizer, (nn.Optimizer,))
        self.parameters = optimizer.parameters
        self.hyper_map = C.HyperMap()
        group_number = _get_stage_device_num() // _device_number
        self.grad_clone = ParameterTuple(self.parameters)
        self.adasum = _AdaSumByGrad(_get_global_rank(), _device_number, group_number, self.grad_clone)
        self.sync_tensor = Parameter(Tensor(0, dtype=mstype.int32))

    def construct(self, grads):
        adasum_res = self.adasum(grads)
        sync_tensor = F.depend(self.sync_tensor, adasum_res)
        sync_flag = P.AllReduce()(sync_tensor)
        return F.depend(self.optimizer(adasum_res), sync_flag)


class AdaSumByDeltaWeightWrapCell(Cell):
    r"""
    Enable AdaSum in "auto_parallel/semi_auto_parallel" mode.
    This implementation of the Adaptive Summation (AdaSum) algorithm operates on the difference of the weights
    before and after the optimizer update.
    See the paper `AdaSum: Scaling Distributed Training with Adaptive Summation <https://arxiv.org/abs/2006.02924>`_.

    .. math::
        \begin{array}{ll}
          w_{t+1}=w_{t} - \alpha \cdot Adasum(g_{1}, g_{2})  \\
          w_{t+1}=w_{t} - \alpha \cdot [(1 - \frac{g_2^{T}\cdot g_1}{2\cdot \left \| g_1 \right \|^2 })\cdot g_1 + (1 -
          \frac{g_1^{T}\cdot g_2}{2\cdot \left \| g_2 \right \|^2 })\cdot g_2]  \\
        \end{array}

    In this implementation, :math:`g` represents the weight difference before and after the optimizer update,
    and the subscripts represent different devices in the data-parallel dimension.

    Note:
        When using AdaSum, the number of training cards must be a power of 2 and at least 16 cards are required.
        Currently, optimizer sharding and pipeline parallelism are not supported when using AdaSum.
        It is recommended to use AdaSumByDeltaWeightWrapCell in semi_auto_parallel/auto_parallel mode.
        In data parallel mode, we recommend using mindspore.boost to apply AdaSum.

    Args:
        optimizer (Union[Cell]): Optimizer for updating the weights. The construct function of the optimizer
            requires only one input.

    Inputs:
        - **grads** (Tuple(Tensor)) - Tuple of gradients, same as the input of the passed optimizer.

    Raises:
        RuntimeError: If `parallel_mode` is `stand_alone`; AdaSum only supports distributed scenarios.
        RuntimeError: If optimizer parallelism is used when using AdaSum.
        RuntimeError: If pipeline parallelism is used when using AdaSum.
        RuntimeError: If `device_num` is not a power of 2, or is less than 16.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import mindspore as ms
        >>> from mindspore import nn
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> optim = nn.AdaSumByDeltaWeightWrapCell(nn.Momentum(params=net.trainable_params(),
        ...                                                 learning_rate=0.1, momentum=0.9))
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> model = ms.train.Model(net, loss_fn=loss, optimizer=optim, metrics=None)
    """
    def __init__(self, optimizer):
        super(AdaSumByDeltaWeightWrapCell, self).__init__(auto_prefix=False)
        _parallel_check()
        self.optimizer = optimizer
        validator.check_value_type('optimizer', optimizer, (nn.Optimizer,))
        self.parameters = optimizer.parameters
        self.hyper_map = C.HyperMap()
        _device_number = 8
        group_number = _get_stage_device_num() // _device_number
        self.grad_clone = ParameterTuple(self.parameters)
        self.adasum = _AdaSum(_get_global_rank(), _device_number, group_number, self.grad_clone)
        self.sync_tensor = Parameter(Tensor(0, dtype=mstype.int32))
        self.scale = Tensor(1.0, dtype=mstype.float32)

    def construct(self, grads):
        # Snapshot the current parameters, run the wrapped optimizer, then apply AdaSum to the
        # resulting weight difference (delta_w) and write the combined update back.
        grad_clone = self.hyper_map(F.partial(_clone_weight, self.scale), self.parameters)
        grads = F.depend(grads, grad_clone)
        opt_result = self.optimizer(grads)
        parameters = F.depend(self.parameters, opt_result)
        delta_w = self.hyper_map(F.partial(_get_delta_weight), parameters, grad_clone)
        adasum_res = self.adasum(delta_w, parameters, grad_clone)
        sync_tensor = F.depend(self.sync_tensor, adasum_res)
        sync_flag = P.AllReduce()(sync_tensor)
        updated_weights = F.depend(parameters, sync_flag)
        return updated_weights