# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""adasum"""
from __future__ import absolute_import

import copy
import hashlib
import math

import mindspore.nn as nn
import mindspore.log as logger
from mindspore import context
from mindspore import _checkparam as validator
from mindspore.nn.cell import Cell
from mindspore.common.parameter import ParameterTuple, Parameter
from mindspore.parallel._utils import _get_global_rank, _get_stage_device_num
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops import Send, Receive
from mindspore.common.tensor import Tensor
from mindspore.common import dtype as mstype
from mindspore.communication.management import create_group

__all__ = ["AdaSumByDeltaWeightWrapCell", "AdaSumByGradWrapCell"]

MAX_NUM_HASH = 2 ** 31

_update_parameters = C.MultitypeFuncGraph("update_parameters")
_reshape_grads = C.MultitypeFuncGraph("reshape_grads")


@_update_parameters.register("Tensor", "Tensor", "Tensor", "Tensor", "Function")
def _update_parameters_adasum(delta_weight, update_delta_weight, parameter, old_parameter, reshape):
    shape = F.shape(delta_weight)
    update_delta_weight = reshape(update_delta_weight, shape)
    new_parameter = old_parameter - update_delta_weight
    P.Assign()(parameter, new_parameter)
    return parameter

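
# Note: the MultitypeFuncGraph registrations above and below are applied element-wise over the tuple
# of parameters/gradients through C.HyperMap in the cells defined later in this file; each call
# reshapes the flattened (and possibly split) tensor produced by the communication path back to the
# shape of the corresponding parameter before it is used.
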
@_reshape_grads.register("Tensor", "Tensor", "Function")
def reshape_grads_adasum(grads, update_grads, reshape):
    """
    Reshape gradient.
    """
    shape = F.shape(grads)
    update_grads = reshape(update_grads, shape)
    return update_grads


def _send_before_receive(send_part, send, recv):
    send_ok = send(send_part)
    return recv(send_ok)


def _receive_before_send(send_part, send, recv):
    receive_ok = recv(send_part)
    send_part = F.depend(send_part, receive_ok)
    return F.depend(receive_ok, send(send_part))


def _send_recv_res(left_send, recv_part, local_part, allreduce, parameter_divisibility, allreduce_node_num):
    """send result and receive result."""
    if parameter_divisibility:
        recv_part = P.Squeeze()(recv_part)
        if F.shape(recv_part) is None:
            recv_part = Tensor([recv_part])
        local_part = F.depend(local_part, recv_part)
        eps = 1e-12
        scale_value = P.ReduceMax()(local_part) + eps
        local_part_scale = local_part / scale_value
        recv_part_scale = recv_part / scale_value
        recv_part_scale = F.depend(recv_part_scale, local_part_scale)
        value_0 = P.ReduceSum()(local_part_scale * recv_part_scale) + eps
        if left_send:
            value_1 = P.ReduceSum()(local_part_scale * local_part_scale) + eps
            value_2 = P.ReduceSum()(recv_part_scale * recv_part_scale) + eps
        else:
            value_1 = P.ReduceSum()(recv_part_scale * recv_part_scale) + eps
            value_2 = P.ReduceSum()(local_part_scale * local_part_scale) + eps
        value_0 = allreduce(value_0)
        value_1 = F.depend(allreduce(value_1), value_0)
        value_2 = F.depend(allreduce(value_2), value_1)
        if left_send:
            res = (1 - (value_0 / (2 * value_1))) * local_part + (1 - (value_0 / (2 * value_2))) * recv_part
        else:
            res = (1 - (value_0 / (2 * value_1))) * recv_part + (1 - (value_0 / (2 * value_2))) * local_part
    else:
        res = allreduce(local_part)
        res = res / allreduce_node_num
    return res

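
# _send_recv_res combines the locally kept half with the half received from the partner using the
# AdaSum rule res = (1 - <a, b> / (2 * ||a||^2)) * a + (1 - <a, b> / (2 * ||b||^2)) * b, where the
# dot products and squared norms are accumulated across the group with AllReduce and both halves are
# scaled by the local maximum (plus a small eps) for numerical stability. Parameters that cannot be
# split fall back to a plain AllReduce average over allreduce_node_num devices.
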
_adasum_opt_forward = C.MultitypeFuncGraph("adasum_opt_forward")
_adasum_opt_rollback = C.MultitypeFuncGraph("adasum_opt_rollback")


@_adasum_opt_forward.register("Bool", "Function", "Bool", "Int64", "Function", "Function", "Tensor")
def _adasum_opt_forward_process(left_send, allreduce, parameter_divisibility, allreduce_node_num, send, recv, delta_w):
    """adasum optimizer process."""
    if parameter_divisibility:
        delta_w = P.Squeeze()(delta_w)
        ori_len = F.shape(delta_w)[0]
        divide_len = ori_len // 2
        left_part = delta_w[:divide_len]
        right_part = delta_w[divide_len:]
    else:
        left_part = delta_w
        right_part = delta_w

    if left_send:
        if parameter_divisibility:
            recv_part = _send_before_receive(left_part, send, recv)
        else:
            recv_part = right_part
        update_delta_w = _send_recv_res(left_send, recv_part, right_part, allreduce, parameter_divisibility,
                                        allreduce_node_num)
    else:
        if parameter_divisibility:
            recv_part = _receive_before_send(right_part, send, recv)
        else:
            recv_part = left_part
        update_delta_w = _send_recv_res(left_send, recv_part, left_part, allreduce, parameter_divisibility,
                                        allreduce_node_num)
    return update_delta_w


@_adasum_opt_rollback.register("Bool", "Bool", "Tensor", "Function", "Function")
def _adasum_opt_rollback_process(left_send, parameter_divisibility, delta_w, send, recv):
    """adasum optimizer rollback process."""
    if parameter_divisibility:
        if left_send:
            recv_part = _send_before_receive(delta_w, send, recv)
        else:
            recv_part = _receive_before_send(delta_w, send, recv)

        recv_part = P.Squeeze()(recv_part)
        if F.shape(recv_part) is None:
            recv_part = Tensor([recv_part])
        if F.shape(delta_w) is None:
            delta_w = Tensor([delta_w])
        recv_part = P.Reshape()(recv_part, (-1,))
        delta_w = P.Reshape()(delta_w, (-1,))

        if left_send:
            res = P.Concat()((recv_part, delta_w))
        else:
            res = P.Concat()((delta_w, recv_part))
    else:
        res = delta_w
    return res


class _AdaSum(Cell):
    r"""
    The Adaptive Summation, or AdaSum, is a novel algorithm for improving distributed data
    parallel training of Deep Learning models.

    Inputs:
        - **delta_weights** (Tuple(Tensor)) - Tuple of gradients.
        - **parameters** (Tuple(Parameter)) - Tuple of current parameters.
        - **old_parameters** (Tuple(Parameter)) - Tuple of last parameters.

    Outputs:
        - **adasum_parameters** (Tuple(Tensor)) - Tuple of parameters after adasum process.
    """
    def __init__(self, rank, device_number, group_number, parameter_tuple):
        super(_AdaSum, self).__init__()
        self.rank = rank
        self.device_number = device_number
        self.group_number = group_number
        self.parameter_tuple = parameter_tuple
        self.calc_times = int(math.log(self.group_number, 2))
        self.send_node = []
        self.send_list_forward = []
        self.recv_list_forward = []
        self.send_list_rollback = []
        self.recv_list_rollback = []
        self.allreduce_list = []
        self.parameter_divisibility_list = []
        self.allreduce_node_num_list = []
        self._generate_communication_op()
        self.hyper_map = C.HyperMap()
        self.update_reshape_list = []
        for parameter in self.parameter_tuple:
            reshape = P.Reshape().add_prim_attr("target_param", "adasum_delta_weight." + parameter.name)
            self.update_reshape_list.append(reshape)

    @staticmethod
    def _hash(step, target, weights_index):
        target = "tag" + str(step) + str(target) + str(weights_index)
        target_hash = hashlib.sha1(target.encode()).hexdigest()
        hash_res = int(int(target_hash, 16) % MAX_NUM_HASH)
        return hash_res

    def construct(self, delta_weights, parameters, old_parameters):
        forward_weights = [delta_weights]
        for i in range(self.calc_times):
            process_weights = self.hyper_map(F.partial(_adasum_opt_forward, self.send_node[i]), self.allreduce_list[i],
                                             self.parameter_divisibility_list[i], self.allreduce_node_num_list[i],
                                             self.send_list_forward[i], self.recv_list_forward[i], forward_weights[-1])
            forward_weights.append(process_weights)
        for i in range(self.calc_times):
            j = self.calc_times - i - 1
            process_weights = self.hyper_map(F.partial(_adasum_opt_rollback, self.send_node[j]),
                                             self.parameter_divisibility_list[j], forward_weights[j + 1],
                                             self.send_list_rollback[j], self.recv_list_rollback[j])
            forward_weights[j] = process_weights
        adasum_parameters = self.hyper_map(F.partial(_update_parameters), delta_weights, forward_weights[0],
                                           parameters, old_parameters, self.update_reshape_list)
        return adasum_parameters

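    # The forward loop above halves each delta weight log2(group_number) times, exchanging one half
    # with a partner group at every step and fusing the halves with the AdaSum rule; the rollback
    # loop then concatenates the pieces back in reverse order, so forward_weights[0] finally has the
    # original (flattened) size of every delta weight before the parameters are updated.
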
    def _generate_communication_op(self):
        """generate communication op."""
        last_delta_weights = []
        fusion_attr = "origin_fusion"
        if context.get_auto_parallel_context("parallel_mode") in ["data_parallel", "hybrid_parallel"]:
            fusion_attr = "fusion"
        for step in range(self.calc_times):
            current_group = self.device_number * (2 ** step)
            if (self.rank // current_group) % 2 == 0:
                dest_target = self.rank + current_group
                self.send_node.append(True)
            else:
                dest_target = self.rank - current_group
                self.send_node.append(False)
            send_left = []
            send_right = []
            recv_left = []
            recv_right = []
            allreduce_node_num = ()
            left_delta_weights, right_delta_weights, delta_weights_divisibility = \
                self._get_delta_weights_info(last_delta_weights)
            self.parameter_divisibility_list.append(delta_weights_divisibility)
            weights_index = 0
            fusion_id = (step + 1) * 3
            for shape, dtype, name in left_delta_weights:
                send_tag = self._hash(step, self.rank, weights_index)
                send = Send(sr_tag=send_tag, dest_rank=dest_target, group="hccl_world_group")
                send.add_prim_attr(fusion_attr, fusion_id)
                send.add_prim_attr("opposite_rank", dest_target)
                send.add_prim_attr("target_param", name)
                recv_tag = self._hash(step, dest_target, weights_index)
                recv = Receive(sr_tag=recv_tag, src_rank=dest_target, shape=shape, dtype=dtype,
                               group="hccl_world_group")
                recv.add_prim_attr(fusion_attr, fusion_id)
                recv.add_prim_attr("opposite_rank", dest_target)
                recv.add_prim_attr("target_param", name)
                send_left.append(send)
                recv_left.append(recv)
                weights_index += 1
            for shape, dtype, name in right_delta_weights:
                send_tag = self._hash(step, self.rank, weights_index)
                send = Send(sr_tag=send_tag, dest_rank=dest_target, group="hccl_world_group")
                send.add_prim_attr(fusion_attr, fusion_id + 1)
                send.add_prim_attr("opposite_rank", dest_target)
                send.add_prim_attr("target_param", name)
                recv_tag = self._hash(step, dest_target, weights_index)
                recv = Receive(sr_tag=recv_tag, src_rank=dest_target, shape=shape, dtype=dtype,
                               group="hccl_world_group")
                recv.add_prim_attr(fusion_attr, fusion_id + 1)
                recv.add_prim_attr("opposite_rank", dest_target)
                recv.add_prim_attr("target_param", name)
                send_right.append(send)
                recv_right.append(recv)
                weights_index += 1
            if self.send_node and self.send_node[-1]:
                self.send_list_forward.append(send_left)
                self.send_list_rollback.append(send_right)
                self.recv_list_forward.append(recv_right)
                self.recv_list_rollback.append(recv_left)
                last_delta_weights = right_delta_weights
            else:
                self.send_list_forward.append(send_right)
                self.send_list_rollback.append(send_left)
                self.recv_list_forward.append(recv_left)
                self.recv_list_rollback.append(recv_right)
                last_delta_weights = left_delta_weights
            param_allreduce_list = []
            neighbor_ids = []
            rank_ids = []
            for index in range(2 ** (step + 1)):
                node_rank = self.rank // self.device_number
                double_d = 2 ** (step + 1)
                neighbor_id = (node_rank // double_d * double_d + index) * self.device_number + \
                    self.rank % self.device_number
                neighbor_ids.append(str(neighbor_id))
                rank_ids.append(neighbor_id)
            group_name = "-".join(neighbor_ids)
            if context.get_auto_parallel_context("parallel_mode") in ["data_parallel", "hybrid_parallel"]:
                create_group(group_name, rank_ids)
            for parameter in self.parameter_tuple:
                allreduce = P.AllReduce("sum", group_name)
                allreduce.add_prim_attr("target_param", "adasum_delta_weight." + parameter.name)
                allreduce.add_prim_attr(fusion_attr, fusion_id + 2)
                allreduce.add_prim_attr("step", step)
                param_allreduce_list.append(allreduce)
            self.allreduce_list.append(param_allreduce_list)
            for param_divisibility in delta_weights_divisibility:
                if param_divisibility:
                    allreduce_node_num += (0,)
                else:
                    allreduce_node_num += (2 ** (step + 1),)
            self.allreduce_node_num_list.append(allreduce_node_num)

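    # A worked example for the step loop above (assuming 32 devices with device_number=8, so
    # group_number=4 and calc_times=2): at step 0 a rank exchanges halves with the rank 8 positions
    # away and allreduces the scalars inside a group of 2 nodes that share the same local rank; at
    # step 1 it pairs with the rank 16 positions away and allreduces inside a group of 4 such nodes.
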
    def _get_delta_weights_info(self, last_delta_weights):
        """get delta weights info."""
        half_delta_weights = []
        if last_delta_weights:
            half_delta_weights = last_delta_weights
        else:
            for parameter in self.parameter_tuple:
                new_shape = [int(x) for x in parameter.shape]
                half_delta_weights.append((new_shape, parameter.dtype, "adasum_delta_weight." + parameter.name))
        left_delta_weights = []
        right_delta_weights = []
        delta_weights_divisibility = ()
        for shape, dtype, name in half_delta_weights:
            left_shape = copy.deepcopy(shape)
            right_shape = copy.deepcopy(shape)
            divisibility_flag = False
            for i, value in enumerate(shape):
                if value > 1:
                    left_shape[i] = int(value // 2)
                    right_shape[i] = value - int(value // 2)
                    divisibility_flag = True
                    break
            left_delta_weights.append((left_shape, dtype, name))
            right_delta_weights.append((right_shape, dtype, name))
            delta_weights_divisibility += (divisibility_flag,)
        return left_delta_weights, right_delta_weights, delta_weights_divisibility


class _AdaSumByGrad(_AdaSum):
    """Apply adasum by gradients"""
    def construct(self, grads):
        forward_grads = [grads]
        for i in range(self.calc_times):
            process_weights = self.hyper_map(F.partial(_adasum_opt_forward, self.send_node[i]), self.allreduce_list[i],
                                             self.parameter_divisibility_list[i], self.allreduce_node_num_list[i],
                                             self.send_list_forward[i], self.recv_list_forward[i], forward_grads[-1])
            forward_grads.append(process_weights)
        for i in range(self.calc_times):
            j = self.calc_times - i - 1
            process_weights = self.hyper_map(F.partial(_adasum_opt_rollback, self.send_node[j]),
                                             self.parameter_divisibility_list[j], forward_grads[j + 1],
                                             self.send_list_rollback[j], self.recv_list_rollback[j])
            forward_grads[j] = process_weights
        update_grads = self.hyper_map(F.partial(_reshape_grads), grads, forward_grads[0],
                                      self.update_reshape_list)
        return update_grads


_get_delta_weight = C.MultitypeFuncGraph("_get_delta_weight")
_save_weight = C.MultitypeFuncGraph("_save_weight")
scale_mul = P.Mul().add_prim_attr("keep_alive", True)
_clone_weight = C.MultitypeFuncGraph("_clone_weight")


@_get_delta_weight.register("Tensor", "Tensor")
def _get_delta_weight_process(new_parameter, old_parameter):
    delta_w = old_parameter - new_parameter
    return delta_w


@_save_weight.register("Tensor", "Tensor")
def _save_weight_process(new_parameter, old_parameter):
    P.Assign()(new_parameter, old_parameter)
    return new_parameter


@_clone_weight.register("Tensor", "Tensor")
def _clone_weight_process(scale, weight):
    return scale_mul(weight, scale)

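
# Note: AdaSumByDeltaWeightWrapCell below uses _clone_weight (a multiply by 1.0 through scale_mul,
# which carries the "keep_alive" attribute, presumably so the copy is not folded away) to snapshot
# the parameters before the optimizer step; the difference between the old and the updated weights
# is then fed to AdaSum.
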
def _parallel_check():
    """Check the parallel configuration."""
    if context.get_auto_parallel_context("parallel_mode") == "stand_alone":
        raise RuntimeError("The stand_alone mode does not support applying adasum.")
    if context.get_auto_parallel_context("parallel_mode") in ["data_parallel", "hybrid_parallel"]:
        logger.warning("For data parallel mode or hybrid parallel mode, "
                       "it is recommended to use mindspore.boost to enable adasum.")
    if context.get_auto_parallel_context("enable_parallel_optimizer"):
        raise RuntimeError("Currently, optimizer sharding is not supported when applying adasum.")
    if context.get_auto_parallel_context("pipeline_stages") > 1:
        raise RuntimeError("Currently, pipeline parallel is not supported when applying adasum.")
    stage_device_num = _get_stage_device_num()
    if stage_device_num < 16 or (stage_device_num & (stage_device_num - 1) != 0):
        raise RuntimeError("The device_num must be at least 16 and must be a power of 2 when applying adasum.")


class AdaSumByGradWrapCell(Cell):
    r"""
    Enable AdaSum in "auto_parallel/semi_auto_parallel" mode.
    The implementation of the Adaptive Summation (AdaSum) algorithm is calculated by gradients.
    See the paper `AdaSum: Scaling Distributed Training with Adaptive Summation <https://arxiv.org/abs/2006.02924>`_.

    .. math::
        \begin{array}{ll}
            w_{t+1}=w_{t} - \alpha \cdot Adasum(g_{1}, g_{2}) \\
            w_{t+1}=w_{t} - \alpha \cdot [(1 - \frac{g_2^{T}\cdot g_1}{2\cdot \left \| g_1 \right \|^2 })\cdot g_1 +
            (1 - \frac{g_1^{T}\cdot g_2}{2\cdot \left \| g_2 \right \|^2 })\cdot g_2] \\
        \end{array}

    In this implementation, :math:`g` represents the gradient of the weights,
    and the subscripts represent different devices in the data-parallel dimension.

    Note:
        When using AdaSum, the number of training devices must be a power of 2 and at least 16 devices are required.
        Currently, optimizer sharding and pipeline parallelism are not supported when using AdaSum.
        It is recommended to use AdaSumByGradWrapCell in semi_auto_parallel/auto_parallel mode. In data_parallel
        mode, we recommend using mindspore.boost to apply AdaSum.

    Args:
        optimizer (Union[Cell]): Optimizer for updating the weights. The construct function of the optimizer
            requires only one input.

    Inputs:
        - **grads** (Tuple(Tensor)) - Tuple of gradients, the same as the input of the passed optimizer.

    Raises:
        RuntimeError: If `parallel_mode` is set to `stand_alone`. AdaSum only supports distributed scenarios.
        RuntimeError: If optimizer parallel is used when using AdaSum.
        RuntimeError: If pipeline parallel is used when using AdaSum.
        RuntimeError: If `device_num` is not a power of 2, or less than 16.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import mindspore as ms
        >>> from mindspore import nn
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> optim = nn.AdaSumByGradWrapCell(nn.Momentum(params=net.trainable_params(), learning_rate=0.1,
        ...                                             momentum=0.9))
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> model = ms.train.Model(net, loss_fn=loss, optimizer=optim, metrics=None)
    """
    def __init__(self, optimizer):
        super(AdaSumByGradWrapCell, self).__init__(auto_prefix=False)
        _device_number = 8
        _parallel_check()
        self.optimizer = optimizer
        validator.check_value_type('optimizer', optimizer, (nn.Optimizer,))
        self.parameters = optimizer.parameters
        self.hyper_map = C.HyperMap()
        group_number = _get_stage_device_num() // _device_number
        self.grad_clone = ParameterTuple(self.parameters)
        self.adasum = _AdaSumByGrad(_get_global_rank(), _device_number, group_number, self.grad_clone)
        self.sync_tensor = Parameter(Tensor(0, dtype=mstype.int32))

    def construct(self, grads):
        adasum_res = self.adasum(grads)
        sync_tensor = F.depend(self.sync_tensor, adasum_res)
        sync_flag = P.AllReduce()(sync_tensor)
        return F.depend(self.optimizer(adasum_res), sync_flag)

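
# Note: AdaSumByGradWrapCell applies AdaSum to the raw gradients before the wrapped optimizer runs,
# while AdaSumByDeltaWeightWrapCell (below) lets the optimizer update the weights first, applies
# AdaSum to the resulting weight difference, and writes the combined update back to the parameters.
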
class AdaSumByDeltaWeightWrapCell(Cell):
    r"""
    Enable AdaSum in "auto_parallel/semi_auto_parallel" mode.
    The implementation of the Adaptive Summation (AdaSum) algorithm is calculated based on the difference of weights
    before and after the updating of optimizer.
    See the paper `AdaSum: Scaling Distributed Training with Adaptive Summation <https://arxiv.org/abs/2006.02924>`_.

    .. math::
        \begin{array}{ll}
            w_{t+1}=w_{t} - \alpha \cdot Adasum(g_{1}, g_{2}) \\
            w_{t+1}=w_{t} - \alpha \cdot [(1 - \frac{g_2^{T}\cdot g_1}{2\cdot \left \| g_1 \right \|^2 })\cdot g_1 +
            (1 - \frac{g_1^{T}\cdot g_2}{2\cdot \left \| g_2 \right \|^2 })\cdot g_2] \\
        \end{array}

    In this implementation, :math:`g` represents the weight difference before and after the updating of optimizer,
    and the subscripts represent different devices in the data-parallel dimension.

    Note:
        When using AdaSum, the number of training devices must be a power of 2 and at least 16 devices are required.
        Currently, optimizer sharding and pipeline parallelism are not supported when using AdaSum.
        It is recommended to use AdaSumByDeltaWeightWrapCell in semi_auto_parallel/auto_parallel mode.
        In data_parallel mode, we recommend using mindspore.boost to apply AdaSum.

    Args:
        optimizer (Union[Cell]): Optimizer for updating the weights. The construct function of the optimizer
            requires only one input.

    Inputs:
        - **grads** (Tuple(Tensor)) - Tuple of gradients, the same as the input of the passed optimizer.

    Raises:
        RuntimeError: If `parallel_mode` is set to `stand_alone`. AdaSum only supports distributed scenarios.
        RuntimeError: If optimizer parallel is used when using AdaSum.
        RuntimeError: If pipeline parallel is used when using AdaSum.
        RuntimeError: If `device_num` is not a power of 2, or less than 16.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import mindspore as ms
        >>> from mindspore import nn
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> optim = nn.AdaSumByDeltaWeightWrapCell(nn.Momentum(params=net.trainable_params(),
        ...                                        learning_rate=0.1, momentum=0.9))
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> model = ms.train.Model(net, loss_fn=loss, optimizer=optim, metrics=None)
    """
    def __init__(self, optimizer):
        super(AdaSumByDeltaWeightWrapCell, self).__init__(auto_prefix=False)
        _parallel_check()
        self.optimizer = optimizer
        validator.check_value_type('optimizer', optimizer, (nn.Optimizer,))
        self.parameters = optimizer.parameters
        self.hyper_map = C.HyperMap()
        _device_number = 8
        group_number = _get_stage_device_num() // _device_number
        self.grad_clone = ParameterTuple(self.parameters)
        self.adasum = _AdaSum(_get_global_rank(), _device_number, group_number, self.grad_clone)
        self.sync_tensor = Parameter(Tensor(0, dtype=mstype.int32))
        self.scale = Tensor(1.0, dtype=mstype.float32)

    def construct(self, grads):
        grad_clone = self.hyper_map(F.partial(_clone_weight, self.scale), self.parameters)
        grads = F.depend(grads, grad_clone)
        opt_result = self.optimizer(grads)
        parameters = F.depend(self.parameters, opt_result)
        delta_w = self.hyper_map(F.partial(_get_delta_weight), parameters, grad_clone)
        adasum_res = self.adasum(delta_w, parameters, grad_clone)
        sync_tensor = F.depend(self.sync_tensor, adasum_res)
        sync_flag = P.AllReduce()(sync_tensor)
        updated_weights = F.depend(parameters, sync_flag)
        return updated_weights