1# Copyright 2021-2022 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================ 15"""Boost Mode Cell Wrapper.""" 16from __future__ import absolute_import 17 18import numpy as np 19from mindspore.nn.wrap import TrainOneStepCell 20import mindspore.context as context 21from mindspore.context import ParallelMode 22from mindspore.parallel._utils import _get_global_rank, _get_device_num, _get_gradients_mean 23from mindspore.communication.management import get_group_size, create_group 24from mindspore.nn.cell import Cell 25from mindspore.nn import SequentialCell 26from mindspore.common import Tensor 27from mindspore.common.sparse_tensor import RowTensorInner 28from mindspore.common.parameter import Parameter, ParameterTuple 29from mindspore.nn.wrap.grad_reducer import DistributedGradReducer 30from mindspore.ops.operations.math_ops import NPUGetFloatStatusV2, NPUClearFloatStatusV2 31from mindspore.ops import functional as F 32from mindspore.ops import composite as C 33from mindspore.ops import operations as P 34from mindspore.common import dtype as mstype 35from mindspore.boost.boost import AutoBoost 36from mindspore.boost.grad_freeze import FreezeOpt, freeze_cell 37from mindspore.boost.adasum import AdaSum 38from mindspore.boost.dim_reduce import DimReduce 39from mindspore.boost.grad_accumulation import gradient_accumulation_op, gradient_clear_op 40from mindspore.boost.base import _load_local_pca_mat 41 42__all__ = ["BoostTrainOneStepCell", "BoostTrainOneStepWithLossScaleCell"] 43 44_get_delta_weight = C.MultitypeFuncGraph("_get_delta_weight") 45 46 47@_get_delta_weight.register("Tensor", "Tensor") 48def _get_delta_weight_process(new_parameter, old_parameter): 49 delta_w = old_parameter - new_parameter 50 return delta_w 51 52 53_save_weight = C.MultitypeFuncGraph("_save_weight") 54 55 56@_save_weight.register("Tensor", "Tensor") 57def _save_weight_process(new_parameter, old_parameter): 58 P.Assign()(new_parameter, old_parameter) 59 return new_parameter 60 61 62_grad_scale = C.MultitypeFuncGraph("grad_scale") 63reciprocal = P.Reciprocal() 64 65 66@_grad_scale.register("Tensor", "Tensor") 67def tensor_grad_scale(scale, grad): 68 """grad scale function for tensor""" 69 return grad * F.cast(reciprocal(scale), F.dtype(grad)) 70 71 72@_grad_scale.register("Tensor", "RowTensor") 73def tensor_grad_scale_row_tensor(scale, grad): 74 """grad scale function for row tensor""" 75 return RowTensorInner(grad.indices, 76 grad.values * F.cast(reciprocal(scale), F.dtype(grad.values)), 77 grad.dense_shape) 78 79 80_grad_overflow = C.MultitypeFuncGraph("_grad_overflow") 81grad_overflow = P.FloatStatus() 82 83 84@_grad_overflow.register("Tensor") 85def _tensor_grad_overflow(grad): 86 return grad_overflow(grad) 87 88 89@_grad_overflow.register("RowTensor") 90def _tensor_grad_overflow_row_tensor(grad): 91 return grad_overflow(grad.values) 92 93 94class _OutputToFloat16(Cell): 95 "Wrap cell for amp. Cast network output back to float16" 96 97 def __init__(self, op): 98 super(_OutputToFloat16, self).__init__(auto_prefix=False) 99 self._op = op 100 101 def construct(self, *inputs): 102 return F.cast(self._op(*inputs), mstype.float16) 103 104 105class BoostTrainOneStepCell(TrainOneStepCell): 106 r""" 107 Boost Network training package class. 108 109 Wraps the network with an optimizer. The resulting Cell is trained with input '\*inputs'. 110 The backward graph will be created in the construct function to update the parameter, and different 111 parallel modes are available for training. 112 113 Args: 114 network (Cell): The training network. The network only supports single output. 115 optimizer (Union[Cell]): Optimizer for updating the weights. 116 sens (numbers.Number): The scaling number to be filled as the input of backpropagation. 117 Default: ``None`` , which is ``1.0`` . 118 119 Inputs: 120 - **\*inputs** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`. 121 122 Outputs: 123 Tensor, a tensor means the loss value, the shape of which is usually :math:`()`. 124 125 - loss(Tensor): A scalar Tensor. 126 - overflow(Tensor): A scalar Tensor which type is bool. 127 - loss scaling value(Tensor): A scalar Tensor. 128 129 Raises: 130 TypeError: If `sens` is not a number. 131 132 Supported Platforms: 133 ``Ascend`` ``GPU`` ``CPU`` 134 135 Examples: 136 >>> from mindspore import boost 137 >>> from mindspore import nn 138 >>> # Define the network structure of LeNet5. Refer to 139 >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py 140 >>> net = LeNet5() 141 >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits() 142 >>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) 143 >>> #1) Using the WithLossCell existing provide 144 >>> loss_net = nn.WithLossCell(net, loss_fn) 145 >>> train_net = boost.BoostTrainOneStepCell(loss_net, optim) 146 >>> 147 >>> #2) Using user-defined WithLossCell 148 >>> class MyWithLossCell(nn.Cell): 149 ... def __init__(self, backbone, loss_fn): 150 ... super(MyWithLossCell, self).__init__(auto_prefix=False) 151 ... self._backbone = backbone 152 ... self._loss_fn = loss_fn 153 ... 154 ... def construct(self, x, y, label): 155 ... out = self._backbone(x, y) 156 ... return self._loss_fn(out, label) 157 ... 158 ... @property 159 ... def backbone_network(self): 160 ... return self._backbone 161 ... 162 >>> loss_net = MyWithLossCell(net, loss_fn) 163 >>> train_net = boost.BoostTrainOneStepCell(loss_net, optim) 164 """ 165 166 def __init__(self, network, optimizer, sens=None): 167 super(BoostTrainOneStepCell, self).__init__(network, optimizer, sens) 168 self.hyper_map = C.HyperMap() 169 self.freeze = isinstance(optimizer, FreezeOpt) 170 if not self.freeze: 171 self.weights = self.optimizer.parameters 172 self.train_strategy = getattr(self.optimizer, 'train_strategy', None) 173 174 self.auto_boost = AutoBoost() 175 self.use_grad_accumulation = self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.STAND_ALONE) 176 self.use_grad_accumulation = self.use_grad_accumulation & \ 177 self.auto_boost.boost_config.get("grad_accumulation", False) 178 self.max_accumulation_step = 1 179 if self.use_grad_accumulation: 180 181 self.max_accumulation_step = self.auto_boost.grad_accumulation_step 182 if self.max_accumulation_step <= 1: 183 self.max_accumulation_step = 1 184 self.use_grad_accumulation = False 185 self.accumulation_step = Parameter(Tensor(0, dtype=mstype.int32), name="accumulation_step") 186 if self.use_grad_accumulation: 187 self.grad_accumulation = self.weights.clone(prefix="grad_accumulation", init='zeros') 188 189 self.enable_dim_reduce = self.check_dim_reduce_enable() 190 if self.enable_dim_reduce: 191 self.__init_dim_reduce() 192 193 self.freeze_nets = None 194 self.step = Parameter(Tensor(0, dtype=mstype.int32)) 195 if self.freeze: 196 if self.reducer_flag: 197 self.mean = _get_gradients_mean() 198 self.degree = _get_device_num() 199 else: 200 self.mean = None 201 self.degree = None 202 self.freeze_nets = freeze_cell(self.reducer_flag, self.network, self.optimizer, self.sens, 203 self.grad, self.use_grad_accumulation, self.mean, self.degree, 204 self.max_accumulation_step) 205 206 self.enable_adasum = self.check_adasum_enable() 207 self.sync_tensor = Parameter(Tensor(0, dtype=mstype.int32)) 208 if self.enable_adasum: 209 self.__init_adasum() 210 211 def construct(self, *inputs): 212 if self.freeze: 213 loss = self.gradient_freeze_process(*inputs) 214 else: 215 if not self.sense_flag: 216 return self._no_sens_impl(*inputs) 217 loss = self.network(*inputs) 218 sens = F.fill(loss.dtype, loss.shape, self.sens) 219 grads = self.grad(self.network, self.weights)(*inputs, sens) 220 grads = self.grad_reducer(grads) 221 if self.use_grad_accumulation: 222 loss = self.gradient_accumulation_process(loss, grads, sens, *inputs) 223 else: 224 if self.enable_dim_reduce: 225 loss = F.depend(loss, self.dim_reduce(loss, grads, sens, self.weights, self.weights_clone, *inputs)) 226 elif self.enable_adasum: 227 loss = F.depend(loss, self.adasum_process(loss, grads)) 228 else: 229 loss = F.depend(loss, self.optimizer(grads)) 230 return loss 231 232 def gradient_freeze_process(self, *inputs): 233 r""" 234 Gradient freeze algorithm process. 235 236 Args: 237 inputs (tuple(Tensor)): Tuple of input tensors with shape :math:`(N, \ldots)`. 238 239 Returns: 240 - **loss** (Tensor) - Network loss, tensor with shape :math:`()`. 241 """ 242 if self.train_strategy is None: 243 step = self.step 244 max_index = len(self.freeze_nets) 245 else: 246 step = self.train_strategy[self.step] 247 max_index = len(self.train_strategy) 248 loss = self.freeze_nets[step](*inputs) 249 if self.step + 1 >= max_index: 250 self.step = 0 251 else: 252 self.step += 1 253 return loss 254 255 def gradient_accumulation_process(self, loss, grads, sens, *inputs): 256 r""" 257 Gradient accumulation algorithm process. 258 259 Args: 260 loss (Tensor): Tensor with shape :math:`()`. 261 grads (tuple(Tensor)): Tuple of gradient tensors. 262 sens (Tensor): Tensor with shape :math:`()`. 263 inputs (tuple(Tensor)): Tuple of input tensors with shape :math:`(N, \ldots)`. 264 265 Returns: 266 - **loss** (Tensor) - Network loss, tensor with shape :math:`()`. 267 """ 268 loss = F.depend(loss, self.hyper_map(F.partial(gradient_accumulation_op, self.max_accumulation_step), 269 self.grad_accumulation, grads)) 270 self.accumulation_step += 1 271 272 if self.accumulation_step >= self.max_accumulation_step: 273 if self.enable_dim_reduce: 274 loss = F.depend(loss, self.dim_reduce(loss, self.grad_accumulation, sens, self.weights, 275 self.weights_clone, *inputs)) 276 elif self.enable_adasum: 277 loss = F.depend(loss, self.adasum_process(loss, self.grad_accumulation)) 278 else: 279 loss = F.depend(loss, self.optimizer(self.grad_accumulation)) 280 self.accumulation_step = 0 281 282 if self.accumulation_step == 0: 283 loss = F.depend(loss, self.hyper_map(F.partial(gradient_clear_op), self.grad_accumulation)) 284 285 return loss 286 287 def adasum_process(self, loss, grads): 288 r""" 289 Adasum algorithm process. 290 291 Args: 292 loss (Tensor): Tensor with shape :math:`()`. 293 grads (tuple(Tensor)): Tuple of gradient tensors. 294 295 Returns: 296 - **loss** (Tensor) - Network loss, tensor with shape :math:`()`. 297 """ 298 loss = F.depend(loss, self.optimizer(grads)) 299 rank_weights = self.weights[self.start[self.server_rank]: self.end[self.server_rank]] 300 grad_clone = F.depend(self.grad_clone, loss) 301 delta_w = self.hyper_map(F.partial(_get_delta_weight), rank_weights, grad_clone) 302 adasum_res = self.adasum(delta_w, rank_weights, grad_clone) 303 sync_tensor = F.depend(self.sync_tensor, adasum_res) 304 sync_flag = self.adasum.sync_barrier(sync_tensor) 305 for i in range(self.device_number): 306 weight_tuple = self.weights[self.start[i]: self.end[i]] 307 node_rank = F.depend(weight_tuple, sync_flag) 308 update_weights = self.adasum.broadcast_list[i](node_rank) 309 if i == self.server_rank: 310 self.hyper_map(F.partial(_save_weight), self.grad_clone, update_weights) 311 else: 312 self.hyper_map(F.partial(_save_weight), weight_tuple, update_weights) 313 return loss 314 315 def check_adasum_enable(self): 316 r""" 317 Check adasum enable. 318 319 Returns: 320 - **enable_adasum** (bool) - Check whether the Adasum algorithm is enabled. 321 """ 322 if not getattr(self.optimizer, "adasum", None) or not self.reducer_flag: 323 return False 324 _rank_size = get_group_size() 325 _device_number = 8 326 group_number = _rank_size // _device_number 327 is_enable = bool(group_number > 1 and group_number & (group_number - 1) == 0) 328 return is_enable 329 330 def check_dim_reduce_enable(self): 331 r""" 332 Check dim_reduce enable. 333 334 Returns: 335 - **enable_dim_reduce** (bool) - Check whether the dimensionality reduction second-order training 336 algorithm is enabled. 337 """ 338 if not getattr(self.optimizer, "dim_reduce", None): 339 return False 340 return True 341 342 def _no_sens_impl(self, *inputs): 343 """construct implementation when the 'sens' parameter is passed in.""" 344 loss = self.network(*inputs) 345 sens = F.fill(loss.dtype, loss.shape, self.sens) 346 grads = self.grad_no_sens(self.network, self.weights)(*inputs) 347 grads = self.grad_reducer(grads) 348 if self.use_grad_accumulation: 349 loss = self.gradient_accumulation_process(loss, grads, sens, *inputs) 350 else: 351 if self.enable_dim_reduce: 352 loss = F.depend(loss, self.dim_reduce(loss, grads, sens, self.weights, self.weights_clone, *inputs)) 353 elif self.enable_adasum: 354 loss = F.depend(loss, self.adasum_process(loss, grads)) 355 else: 356 loss = F.depend(loss, self.optimizer(grads)) 357 358 def __init_dim_reduce(self): 359 """dim reduce algorithm init method.""" 360 local_pca_mat_path = self.auto_boost.local_pca_mat_path 361 rho = self.auto_boost.rho 362 gamma = self.auto_boost.gamma 363 alpha = self.auto_boost.alpha 364 sigma = self.auto_boost.sigma 365 _rank = _get_global_rank() 366 _rank_size = 1 if self.parallel_mode == ParallelMode.STAND_ALONE else get_group_size() 367 n_components = self.auto_boost.n_components 368 timeout = self.auto_boost.timeout 369 pca_mat = _load_local_pca_mat(local_pca_mat_path, timeout) 370 self.weights_clone = ParameterTuple(self.weights).clone(prefix="weights_clone", init="same") 371 self.dim_reduce = DimReduce(self.network, self.optimizer, self.weights, pca_mat, n_components, rho, gamma, 372 alpha, sigma, _rank, _rank_size) 373 374 def __init_adasum(self): 375 """adasum algorithm init method.""" 376 _rank = _get_global_rank() 377 _rank_size = get_group_size() 378 _device_number = self.auto_boost.device_number 379 self.device_number = _device_number 380 group_number = _rank_size // _device_number 381 382 self.server_rank = _rank % _device_number 383 parameter_rank_number = len(self.weights) // _device_number 384 self.start = [x * parameter_rank_number for x in range(_device_number)] 385 self.end = [(x + 1) * parameter_rank_number for x in range(_device_number)] 386 self.end[-1] = len(self.weights) 387 388 current_weights = self.weights[self.start[self.server_rank]: self.end[self.server_rank]] 389 self.grad_clone = ParameterTuple(current_weights).clone(prefix="delta_weight") 390 self.adasum = AdaSum(_rank, _device_number, group_number, self.grad_clone) 391 392 self.degree = int(self.degree // group_number) 393 group_list = [list(range(x * self.degree, (x + 1) * self.degree)) for x in range(group_number)] 394 current_index = _rank // _device_number 395 server_group_name = "allreduce_" + str(current_index) 396 create_group(server_group_name, group_list[current_index]) 397 self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree, group=server_group_name) 398 399 400class BoostTrainOneStepWithLossScaleCell(BoostTrainOneStepCell): 401 r""" 402 Boost Network training with loss scaling. 403 404 This is a training step with loss scaling. It takes a network, an optimizer and possibly a scale update 405 Cell as args. The loss scale value can be updated in both host side or device side. The 406 BoostTrainOneStepWithLossScaleCell will be compiled to be graph which takes `*inputs` as input data. 407 The Tensor type of `scale_sense` is acting as loss scaling value. If you want to update it on host side, 408 the value must be provided. If the Tensor type of `scale_sense` is not given, the loss scale update logic 409 must be provide by Cell type of `scale_sense`. 410 411 Args: 412 network (Cell): The training network. The network only supports single output. 413 optimizer (Cell): Optimizer for updating the weights. 414 scale_sense (Union[Tensor, Cell]): If this value is Cell type, the loss scaling update logic cell.If this value 415 is Tensor type, :func:`mindspore.nn.TrainOneStepWithLossScaleCell.set_sense_scale` can be called to update 416 loss scale factor, Tensor with shape :math:`()` or :math:`(1,)`. 417 418 Inputs: 419 - **\*inputs** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`. 420 421 Outputs: 422 Tuple of 3 Tensor, the loss, overflow flag and current loss scaling value. 423 424 - **loss** (Tensor) - Tensor with shape :math:`()`. 425 - **overflow** (Tensor) - Tensor with shape :math:`()`, type is bool. 426 - **loss scaling value** (Tensor) - Tensor with shape :math:`()` 427 428 Raises: 429 TypeError: If `scale_sense` is neither Cell nor Tensor. 430 ValueError: If shape of `scale_sense` is neither :math:`(1,)` nor :math:`()`. 431 432 Supported Platforms: 433 ``Ascend`` ``GPU`` 434 435 Examples: 436 >>> import numpy as np 437 >>> from mindspore import Tensor, Parameter, nn 438 >>> from mindspore import ops 439 >>> from mindspore.nn import WithLossCell 440 >>> from mindspore import dtype as mstype 441 >>> from mindspore import boost 442 >>> 443 >>> class Net(nn.Cell): 444 ... def __init__(self, in_features, out_features): 445 ... super(Net, self).__init__() 446 ... self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)), 447 ... name='weight') 448 ... self.matmul = ops.MatMul() 449 ... 450 ... def construct(self, x): 451 ... output = self.matmul(x, self.weight) 452 ... return output 453 ... 454 >>> size, in_features, out_features = 16, 16, 10 455 >>> #1) when the type of scale_sense is Cell: 456 >>> net = Net(in_features, out_features) 457 >>> loss = nn.MSELoss() 458 >>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) 459 >>> net_with_loss = WithLossCell(net, loss) 460 >>> manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**12, scale_factor=2, scale_window=1000) 461 >>> train_network = boost.BoostTrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager) 462 >>> input = Tensor(np.ones([out_features, in_features]), mstype.float32) 463 >>> labels = Tensor(np.ones([out_features,]), mstype.float32) 464 >>> output = train_network(input, labels) 465 >>> 466 >>> #2) when the type of scale_sense is Tensor: 467 >>> net = Net(in_features, out_features) 468 >>> loss = nn.MSELoss() 469 >>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) 470 >>> net_with_loss = WithLossCell(net, loss) 471 >>> inputs = Tensor(np.ones([size, in_features]).astype(np.float32)) 472 >>> label = Tensor(np.zeros([size, out_features]).astype(np.float32)) 473 >>> scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mstype.float32) 474 >>> train_network = boost.BoostTrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=scaling_sens) 475 >>> output = train_network(inputs, label) 476 """ 477 478 def __init__(self, network, optimizer, scale_sense): 479 super(BoostTrainOneStepWithLossScaleCell, self).__init__(network, optimizer, sens=None) 480 self.base = Tensor(1, mstype.float32) 481 self.reduce_sum = P.ReduceSum(keep_dims=False) 482 self.less_equal = P.LessEqual() 483 self.allreduce = P.AllReduce() 484 self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) 485 self.gpu_target = (context.get_context("device_target") == "GPU") 486 self.loss_scaling_manager = None 487 self.base0 = Tensor(0, mstype.int32) 488 self.reduce_all = P.ReduceAll(keep_dims=False) 489 self.logic_not = P.LogicalNot() 490 self.equal = P.Equal() 491 492 if self.auto_boost.boost_config.get("loss_scale_group", False): 493 self.enable_enhanced_amp = True 494 if not isinstance(scale_sense, Cell) or not hasattr(scale_sense, "set_loss_scale_status"): 495 raise TypeError("The scale_sense must be enhanced amp Cell, bug got {}".format(type(scale_sense))) 496 self.loss_scaling_manager = scale_sense 497 self.loss_scale_groups = scale_sense.loss_scale_groups 498 self._init_enhanced_amp() 499 self._do_keep_mix_fp32(self.network) 500 else: 501 self.enable_enhanced_amp = False 502 if isinstance(scale_sense, Cell): 503 self.loss_scaling_manager = scale_sense 504 self.scale_sense = Parameter(Tensor(scale_sense.get_loss_scale(), dtype=mstype.float32), 505 name="scale_sense") 506 elif isinstance(scale_sense, Tensor): 507 if scale_sense.shape == (1,) or scale_sense.shape == (): 508 self.scale_sense = Parameter(scale_sense, name='scale_sense') 509 else: 510 raise ValueError("The shape of scale_sense must be (1,) or (), \ 511 but got {}".format(scale_sense.shape)) 512 else: 513 raise TypeError("The scale_sense must be Cell or Tensor, but got {}".format(type(scale_sense))) 514 515 def construct(self, *inputs): 516 weights = self.weights 517 loss = self.network(*inputs) 518 519 if self.enable_enhanced_amp: 520 scaling_sens = F.fill(loss.dtype, loss.shape, 1) 521 grads = self.grad(self.network, weights)(*inputs, scaling_sens) 522 grads = self.grad_reducer(grads) 523 cond, scaling_sens = self._enhanced_amp_process_overflow_status(grads) 524 else: 525 scaling_sens = self.scale_sense 526 status, scaling_sens = self._start_overflow_check(loss, scaling_sens) 527 scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss)) 528 529 grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled) 530 grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads) 531 grads = self.grad_reducer(grads) 532 533 # get the overflow buffer 534 cond = self._get_overflow_status(status, grads) 535 overflow = self._process_loss_scale(cond) 536 # if there is no overflow, do optimize 537 if not overflow: 538 loss = self.__multi_update(loss, grads, scaling_sens_filled, *inputs) 539 return loss, cond, scaling_sens 540 541 def __multi_update(self, loss, grads, scaling_sens_filled, *inputs): 542 """enable multi-algorithm's process""" 543 if self.use_grad_accumulation: 544 loss = self.gradient_accumulation_process(loss, grads, scaling_sens_filled, *inputs) 545 else: 546 if self.enable_dim_reduce: 547 loss = F.depend(loss, self.dim_reduce(loss, grads, scaling_sens_filled, self.weights, 548 self.weights_clone, *inputs)) 549 elif self.enable_adasum: 550 loss = F.depend(loss, self.adasum_process(loss, grads)) 551 else: 552 loss = F.depend(loss, self.optimizer(grads)) 553 return loss 554 555 def _get_dynamic_overflow_status(self, param): 556 """ 557 Judge whether the current network overflows. 558 559 Inputs: 560 - **param** (Tensor) - Whether the overflow occurs or not. 561 562 Outputs: 563 bool, overflow value. 564 float, update ratio. 565 """ 566 flag_sum = self.equal(self.base0, param) 567 if self.reducer_flag: 568 flag_reduce = self.allreduce(flag_sum) 569 overflow = self.logic_not(self.reduce_all(flag_reduce)) 570 else: 571 overflow = self.logic_not(self.reduce_all(flag_sum)) 572 573 if overflow: 574 update_ratio = self.reduce_ratio 575 else: 576 update_ratio = self.growth_ratio 577 return overflow, update_ratio 578 579 def _enhanced_amp_process_overflow_status(self, grads): 580 """ 581 Enhanced hybrid precision update loss scale and update weights. 582 583 Inputs: 584 - **grads** (Tuple(Tensor)) - Tuple of gradients. 585 586 Outputs: 587 bool, overflow value. 588 float, loss scale value. 589 """ 590 overflow_global_flag = Tensor(0, mstype.int32) 591 layer = 0 592 loss_scale_temp = () 593 for param in self.overflow_status_list: 594 overflow, update_ratio = self._get_dynamic_overflow_status(param) 595 if overflow: 596 overflow_global_flag += 1 597 new_loss_scale_value = self.loss_scaling_manager.update_loss_scale_status(layer, update_ratio) 598 loss_scale_temp += (new_loss_scale_value,) * self.optimizer_loss_scale[layer] 599 layer += 1 600 if P.Less()(overflow_global_flag, self.base): 601 grads = self.hyper_map(F.partial(_grad_scale), loss_scale_temp, grads) 602 overflow_global_flag = F.depend(overflow_global_flag, self.optimizer(grads)) 603 return overflow_global_flag, loss_scale_temp[0] 604 605 def _set_sense_scale(self, sens): 606 """ 607 If the user has set the sens in the training process and wants to reassign the value, he can call 608 this function again to make modification, and sens needs to be of type Tensor. 609 610 Inputs: 611 - **sens** (Tensor) - The new sense whose shape and type are the same with original `scale_sense`. 612 """ 613 if self.scale_sense and isinstance(sens, Tensor): 614 self.scale_sense.set_data(sens) 615 else: 616 raise TypeError("The input type must be Tensor, but got {}".format(type(sens))) 617 618 def _start_overflow_check(self, pre_cond, compute_input): 619 """ 620 Start floating-point overflow detection. Create and clear the overflow detection state. 621 622 Specify the argument 'pre_cond' and 'compute_input' to make sure overflow status is cleared at the right time. 623 Taking this situation as an example, we need to execute state clearing after loss calculation and then detect 624 overflow in the process of gradient calculation. In this case, pre_cond should be the output of the loss 625 function, and compute_input should be the input of gradients-computing function. 626 627 Inputs: 628 - **pre_cond** (Tensor) - A precondition for starting overflow detection. It determines the executing order 629 of overflow state clearing and prior processions. It makes sure that the function 'start_overflow' 630 clears status after finishing the process of precondition. 631 - **compute_input** (object) - The input of subsequent process. Overflow detection should be performed on a 632 certain computation. Set `compute_input` as the input of the computation, to ensure overflow status is 633 cleared before executing the computation. 634 635 Outputs: 636 Tuple[object, object], the first value is False for GPU backend, while it is an instance of 637 NPUAllocFloatStatus for other backend. The status is used to detect overflow during overflow detection. 638 The second value is the same as the input of `compute_input`, but contains some information about the 639 execution order. 640 """ 641 status = Tensor([0] * 8, mstype.int32) 642 if not self.gpu_target: 643 status = F.depend(status, pre_cond) 644 # clear overflow buffer 645 clear_status = NPUClearFloatStatusV2()(status) 646 compute_input = F.depend(compute_input, clear_status) 647 return status, compute_input 648 649 def _get_overflow_status(self, status, compute_output): 650 """ 651 Get floating-point overflow status. 652 653 Get overflow results after executing the target process for overflow detection. 654 655 Inputs: 656 - **status** (object) - A status instance used to detect the overflow. 657 - **compute_output** - Overflow detection should be performed on a certain computation. Set `compute_output` 658 as the output of the computation, to ensure overflow status is acquired before executing the 659 computation. 660 661 Outputs: 662 bool, whether the overflow occurs or not. 663 """ 664 if not self.gpu_target: 665 status = F.depend(status, compute_output) 666 get_status = NPUGetFloatStatusV2()(status) 667 668 if self.is_distributed: 669 # sum overflow flag over devices 670 flag_reduce = self.allreduce(get_status) 671 # get_status not equal to [0]*8 means overflow 672 flag = self.equal(self.base0, flag_reduce) 673 status = F.depend(status, flag) 674 # distributed needs to skip allreduce to avoid its overflow affecting the next step 675 clear_status = NPUClearFloatStatusV2()(status) 676 flag = F.depend(flag, clear_status) 677 overall_finite = self.reduce_all(flag) 678 else: 679 status = F.depend(status, get_status) 680 clear_status = NPUClearFloatStatusV2()(status) 681 get_status = F.depend(get_status, clear_status) 682 flag = self.equal(self.base0, get_status) 683 overall_finite = self.reduce_all(flag) 684 overflow = self.logic_not(overall_finite) 685 else: 686 flag_sum = self.hyper_map(F.partial(_grad_overflow), compute_output) 687 flag_sum = P.AddN()(flag_sum) 688 # convert flag_sum to scalar 689 flag_sum = P.Reshape()(flag_sum, (())) 690 691 if self.is_distributed: 692 # sum overflow flag over devices 693 flag_reduce = self.allreduce(flag_sum) 694 overflow = self.less_equal(self.base, flag_reduce) 695 else: 696 overflow = self.less_equal(self.base, flag_sum) 697 return overflow 698 699 def _process_loss_scale(self, overflow): 700 """ 701 Calculate loss scale according to the overflow. 702 703 Inputs: 704 - **overflow** (bool) - Whether the overflow occurs or not. 705 706 Outputs: 707 bool, overflow value. 708 """ 709 if self.loss_scaling_manager is not None: 710 return self.loss_scaling_manager(self.scale_sense, overflow) 711 return overflow 712 713 def _init_enhanced_amp(self): 714 """ 715 Init enhanced hybrid precision. 716 """ 717 self.params_len = len(self.optimizer.params) 718 self.parent = list(range(self.params_len)) 719 self.layer_rank = [0 for _ in range(self.params_len)] 720 index = 0 721 loss_scale_number = len(self.loss_scale_groups) 722 for loss_scale_group in self.loss_scale_groups: 723 for i, _ in enumerate(loss_scale_group): 724 if i == 0: 725 index += 1 726 continue 727 self._union(index - 1, index) 728 index += 1 729 parent_set = list(set(self.parent)) 730 self.optimizer_loss_scale = [self.parent.count(x) for x in parent_set] 731 self.reduce_ratio = Tensor(1.0 / (2 ** 0.5), mstype.float32) 732 self.growth_ratio = Tensor(2 ** (1.0 / 1000.0), mstype.float32) 733 self.overflow_status_list = ParameterTuple(Parameter(Tensor(np.zeros(shape=[8]), mstype.int32), 734 name='mix_layer_status_{}'.format(x), requires_grad=False) 735 for x in range(loss_scale_number)) 736 self.loss_scaling_manager.set_loss_scale_status(loss_scale_number, self.loss_scaling_manager.get_loss_scale()) 737 738 def _get_root(self, i): 739 """ 740 Get parent id. 741 742 Args: 743 i (int): the current parameters's id. 744 745 Returns: 746 Number, the parent id. 747 """ 748 if self.parent[i] != self.parent[self.parent[i]]: 749 self.parent[i] = self.get_root(self.parent[i]) 750 return self.parent[i] 751 752 def _union(self, i, j): 753 """ 754 Aggregate parameters of the same category. 755 756 Args: 757 i (int): the last parameters's id. 758 j (int): the current parameters's id. 759 """ 760 i_root = self._get_root(i) 761 j_root = self._get_root(j) 762 763 if self.layer_rank[i_root] == self.layer_rank[j_root]: 764 self.parent[j_root] = i_root 765 self.layer_rank[i_root] += 1 766 elif self.layer_rank[i_root] > self.layer_rank[j_root]: 767 self.parent[j_root] = i_root 768 else: 769 self.parent[i_root] = j_root 770 771 def _do_keep_mix_fp32(self, network): 772 """ 773 Keep enhanced amp cell of type float32. 774 775 Args: 776 network (Cell): The training network. 777 """ 778 cells = network.name_cells() 779 change = False 780 for name in cells: 781 subcell = cells[name] 782 if subcell == network: 783 continue 784 if "GroupLossScaleManager" in subcell.cls_name: 785 network._cells[name] = _OutputToFloat16(subcell.to_float(mstype.float32)) # pylint: disable=W0212 786 change = True 787 else: 788 self._do_keep_mix_fp32(subcell) 789 if isinstance(network, SequentialCell) and change: 790 network.cell_list = list(network.cells()) 791