1# Copyright 2020-2021 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================ 15"""Loss scale cell for loss scale training.""" 16from __future__ import absolute_import 17 18import os 19import mindspore.context as context 20from mindspore.context import ParallelMode 21from mindspore.parallel._utils import _get_enable_parallel_optimizer 22from mindspore import nn 23from mindspore.nn.wrap.cell_wrapper import TrainOneStepCell 24from mindspore.nn.cell import Cell 25from mindspore.common import Tensor 26from mindspore.common.sparse_tensor import RowTensorInner 27from mindspore.common.parameter import Parameter 28from mindspore.ops.operations.math_ops import NPUGetFloatStatusV2, NPUClearFloatStatusV2 29from mindspore.ops import functional as F 30from mindspore.ops import composite as C 31from mindspore.ops import operations as P 32from mindspore.ops.operations.nn_ops import AllFinite 33from mindspore.common import dtype as mstype 34from mindspore.common.api import jit 35from mindspore._c_expression import MSContext 36 37_grad_scale = C.MultitypeFuncGraph("grad_scale") 38reciprocal = P.Reciprocal() 39 40 41@_grad_scale.register("Tensor", "Tensor") 42def tensor_grad_scale(scale, grad): 43 return grad * F.cast(reciprocal(scale), F.dtype(grad)) 44 45 46@_grad_scale.register("Tensor", "RowTensor") 47def tensor_grad_scale_row_tensor(scale, grad): 48 return RowTensorInner(grad.indices, 49 grad.values * F.cast(reciprocal(scale), F.dtype(grad.values)), 50 grad.dense_shape) 51 52_grad_overflow = C.MultitypeFuncGraph("_grad_overflow") 53grad_overflow = P.FloatStatus() 54 55 56@_grad_overflow.register("Tensor") 57def _tensor_grad_overflow(grad): 58 return grad_overflow(grad) 59 60 61@_grad_overflow.register("RowTensor") 62def _tensor_grad_overflow_row_tensor(grad): 63 return grad_overflow(grad.values) 64 65 66_ascend_grad_overflow = C.MultitypeFuncGraph("_ascend_grad_overflow") 67ascend_grad_overflow = P.IsFinite() 68 69 70@_ascend_grad_overflow.register("Tensor") 71def _tensor_ascend_grad_overflow(grad): 72 status = ascend_grad_overflow(grad) 73 base = Tensor(1.0, dtype=mstype.float32) 74 output = base - status.all() 75 output = P.Reshape()(output, ((-1,))) 76 return output 77 78 79@_ascend_grad_overflow.register("RowTensor") 80def _tensor_ascend_grad_overflow_row_tensor(grad): 81 status = ascend_grad_overflow(grad.values) 82 base = Tensor(1.0, dtype=mstype.float32) 83 output = base - status.all() 84 output = P.Reshape()(output, ((1,))) 85 return output 86 87 88class DynamicLossScaleUpdateCell(Cell): 89 r""" 90 Dynamic Loss scale update cell. 91 92 For loss scaling training, the initial loss scaling value will be set to be `loss_scale_value`. 93 In each training step, the loss scaling value will be decreased by `loss_scale`/`scale_factor` 94 when there is an overflow. And it will be increased by `loss_scale` * `scale_factor` if there is no 95 overflow for a continuous `scale_window` steps. 96 97 `get_update_cell` method of :class:`mindspore.amp.DynamicLossScaleManager` will return this class. It will be called 98 by :class:`mindspore.nn.TrainOneStepWithLossScaleCell` during training to update loss scale. 99 100 Args: 101 loss_scale_value (float): Initializes loss scale. 102 scale_factor (int): Coefficient of increase and decrease. 103 scale_window (int): Maximum continuous training steps that do not have overflow to increase the loss scale. 104 105 Inputs: 106 - **loss_scale** (Tensor) - The loss scale value during training with shape :math:`()`. 107 - **overflow** (bool) - Whether the overflow occurs or not. 108 109 Outputs: 110 bool, the input `overflow`. 111 112 Supported Platforms: 113 ``Ascend`` ``GPU`` 114 115 Examples: 116 >>> import numpy as np 117 >>> import mindspore 118 >>> from mindspore import Tensor, Parameter, nn, ops 119 >>> 120 >>> class Net(nn.Cell): 121 ... def __init__(self, in_features, out_features): 122 ... super(Net, self).__init__() 123 ... self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)), 124 ... name='weight') 125 ... self.matmul = ops.MatMul() 126 ... 127 ... def construct(self, x): 128 ... output = self.matmul(x, self.weight) 129 ... return output 130 ... 131 >>> in_features, out_features = 16, 10 132 >>> net = Net(in_features, out_features) 133 >>> loss = nn.MSELoss() 134 >>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) 135 >>> net_with_loss = nn.WithLossCell(net, loss) 136 >>> manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**12, scale_factor=2, scale_window=1000) 137 >>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager) 138 >>> input = Tensor(np.ones([out_features, in_features]), mindspore.float32) 139 >>> labels = Tensor(np.ones([out_features,]), mindspore.float32) 140 >>> output = train_network(input, labels) 141 """ 142 143 def __init__(self, 144 loss_scale_value, 145 scale_factor, 146 scale_window): 147 super(DynamicLossScaleUpdateCell, self).__init__() 148 149 self.scale_window = Tensor(scale_window, dtype=mstype.int32) 150 self.scale_factor = Tensor(scale_factor, dtype=mstype.float32) 151 self.loss_scale_value = loss_scale_value 152 153 self.cur_iter = Parameter(Tensor(1, dtype=mstype.int32), name="current_iterator_step") 154 self.last_overflow_iter = Parameter(Tensor(0, dtype=mstype.int32), name="last_overflow_iterator_step") 155 self.select = P.Select() 156 self.max = P.Maximum() 157 self.minimum_loss_scale = Tensor(1.0, dtype=mstype.float32) 158 self.reciprocal = P.Reciprocal() 159 self.less_equal = P.LessEqual() 160 self.logic_and = P.LogicalAnd() 161 self.logic_not = P.LogicalNot() 162 self.logic_or = P.LogicalOr() 163 self.const_true = Tensor(True, dtype=mstype.bool_) 164 165 def get_loss_scale(self): 166 """ 167 Get Loss Scale value. 168 169 Returns: 170 float, the loss scale value. 171 172 Examples: 173 >>> from mindspore import nn 174 >>> manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=212, scale_factor=2, scale_window=1000) 175 >>> output = manager.get_loss_scale() 176 >>> print(output) 177 212 178 """ 179 return self.loss_scale_value 180 181 def construct(self, loss_scale, overflow): 182 overflow_cond = overflow 183 loss_scale_on_overflow = self.select(overflow_cond, self.max(loss_scale * self.reciprocal(self.scale_factor), 184 self.minimum_loss_scale), loss_scale) 185 should_inc = self.less_equal(self.scale_window, self.cur_iter - self.last_overflow_iter) 186 last_iter_cond = self.logic_or(overflow_cond, should_inc) 187 last_overflow_iter = self.select(last_iter_cond, self.cur_iter, self.last_overflow_iter) 188 last_iter = F.assign(self.last_overflow_iter, last_overflow_iter) 189 update_scale_cond = self.logic_and(should_inc, self.logic_not(overflow_cond)) 190 scale_mul_res = loss_scale_on_overflow * self.scale_factor 191 scaled_loss_scale = self.select(update_scale_cond, scale_mul_res, loss_scale_on_overflow) 192 F.assign(loss_scale, scaled_loss_scale) 193 inc_cur_iter = self.cur_iter + 1 194 inc_cur_iter = F.depend(inc_cur_iter, last_iter) 195 F.assign(self.cur_iter, inc_cur_iter) 196 return overflow 197 198 199class FixedLossScaleUpdateCell(Cell): 200 """ 201 Update cell with fixed loss scaling value. 202 203 `get_update_cell` method of :class:`mindspore.amp.FixedLossScaleManager` will return this class. It will be called 204 by :class:`mindspore.nn.TrainOneStepWithLossScaleCell` during trainning. 205 206 Args: 207 loss_scale_value (float): Initializes loss scale. 208 209 Inputs: 210 - **loss_scale** (Tensor) - The loss scale value during training with shape :math:`()`, it is ignored in this 211 class. 212 - **overflow** (bool) - Whether the overflow occurs or not. 213 214 Outputs: 215 bool, the input `overflow`. 216 217 Supported Platforms: 218 ``Ascend`` ``GPU`` 219 220 Examples: 221 >>> import numpy as np 222 >>> import mindspore 223 >>> from mindspore import Tensor, Parameter, nn, ops 224 >>> 225 >>> class Net(nn.Cell): 226 ... def __init__(self, in_features, out_features): 227 ... super(Net, self).__init__() 228 ... self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)), 229 ... name='weight') 230 ... self.matmul = ops.MatMul() 231 ... 232 ... def construct(self, x): 233 ... output = self.matmul(x, self.weight) 234 ... return output 235 ... 236 >>> in_features, out_features = 16, 10 237 >>> net = Net(in_features, out_features) 238 >>> loss = nn.MSELoss() 239 >>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) 240 >>> net_with_loss = nn.WithLossCell(net, loss) 241 >>> manager = nn.FixedLossScaleUpdateCell(loss_scale_value=2**12) 242 >>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager) 243 >>> input = Tensor(np.ones([out_features, in_features]), mindspore.float32) 244 >>> labels = Tensor(np.ones([out_features,]), mindspore.float32) 245 >>> output = train_network(input, labels) 246 """ 247 248 def __init__(self, loss_scale_value): 249 super(FixedLossScaleUpdateCell, self).__init__() 250 self.loss_scale_value = loss_scale_value 251 252 def get_loss_scale(self): 253 """ 254 Get Loss Scale value. 255 256 Returns: 257 float, the loss scale value. 258 259 Examples: 260 >>> from mindspore import nn 261 >>> manager = nn.FixedLossScaleUpdateCell(loss_scale_value=212) 262 >>> output = manager.get_loss_scale() 263 >>> print(output) 264 212 265 """ 266 return self.loss_scale_value 267 268 def construct(self, _, overflow): 269 return overflow 270 271 272class TrainOneStepWithLossScaleCell(TrainOneStepCell): 273 r""" 274 Network training with loss scaling. 275 276 This is a training step with loss scaling. It takes a network, an optimizer and a scale update Cell(or a Tensor) as 277 args. The loss scale value can be updated in both host side or device side. If you want to update it on 278 host side, using a value of Tensor type as `scale_sense`, otherwise, using a Cell instance for updating loss 279 scale as `scale_sense`. 280 281 Args: 282 network (Cell): The training network. The network only supports single output. 283 optimizer (Cell): Optimizer for updating the network parameters. 284 scale_sense (Union[Tensor, Cell]): If this value is a Cell, it will be called by `TrainOneStepWithLossScaleCell` 285 to update loss scale. If this value is a Tensor, the loss scale can be modified by `set_sense_scale`, 286 the shape should be :math:`()` or :math:`(1,)`. 287 288 Inputs: 289 - **\*inputs** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`. 290 291 Outputs: 292 Tuple of 3 Tensor, the loss, overflow flag and current loss scale value. 293 294 - **loss** (Tensor) - A scalar, the loss value. 295 - **overflow** (Tensor) - A scalar, whether overflow occur or not, the type is bool. 296 - **loss scale** (Tensor) - The loss scale value, the shape is :math:`()` or :math:`(1,)`. 297 298 Raises: 299 TypeError: If `scale_sense` is neither Cell nor Tensor. 300 ValueError: If shape of `scale_sense` is neither :math:`(1,)` nor :math:`()`. 301 302 Supported Platforms: 303 ``Ascend`` ``GPU`` 304 305 Examples: 306 >>> import numpy as np 307 >>> import mindspore 308 >>> from mindspore import Tensor, Parameter, nn, ops 309 >>> 310 >>> class Net(nn.Cell): 311 ... def __init__(self, in_features, out_features): 312 ... super(Net, self).__init__() 313 ... self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)), 314 ... name='weight') 315 ... self.matmul = ops.MatMul() 316 ... 317 ... def construct(self, x): 318 ... output = self.matmul(x, self.weight) 319 ... return output 320 ... 321 >>> size, in_features, out_features = 16, 16, 10 322 >>> #1) when the type of scale_sense is Cell: 323 >>> net = Net(in_features, out_features) 324 >>> loss_fn = nn.MSELoss() 325 >>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) 326 >>> net_with_loss = nn.WithLossCell(net, loss_fn) 327 >>> input = Tensor(np.ones([out_features, in_features]), mindspore.float32) 328 >>> labels = Tensor(np.ones([out_features,]), mindspore.float32) 329 >>> loss = net_with_loss(input, labels) 330 >>> manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**12, scale_factor=2, scale_window=1000) 331 >>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager) 332 >>> status = Tensor([0] * 8, mindspore.int32) 333 >>> scaling_sens = train_network.scale_sense 334 >>> scaling_sens_filled = ops.ones_like(loss) * ops.cast(scaling_sens, ops.dtype(loss)) 335 >>> grads = train_network.grad(train_network.network, train_network.weights)(input, labels, scaling_sens_filled) 336 >>> grads = train_network.grad_reducer(grads) 337 >>> cond = train_network.get_overflow_status(status, grads) 338 >>> overflow = train_network.process_loss_scale(cond) 339 >>> 340 >>> #2) when the type of scale_sense is Tensor: 341 >>> net = Net(in_features, out_features) 342 >>> loss = nn.MSELoss() 343 >>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) 344 >>> net_with_loss = nn.WithLossCell(net, loss) 345 >>> inputs = Tensor(np.ones([size, in_features]).astype(np.float32)) 346 >>> label = Tensor(np.zeros([size, out_features]).astype(np.float32)) 347 >>> scaling_sens = Tensor([1024], dtype=mindspore.float32) 348 >>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=scaling_sens) 349 >>> scaling_sens = Tensor([1], dtype=mstype.float32) 350 >>> train_network.set_sense_scale(scaling_sens) 351 >>> output = train_network(inputs, label) 352 >>> 353 >>> # update scaling sens and train the network 354 >>> scaling_sens = Tensor([1], dtype=mindspore.float32) 355 >>> train_network.set_sense_scale(scaling_sens) 356 >>> output = train_network(inputs, label) 357 """ 358 def __init__(self, network, optimizer, scale_sense): 359 super(TrainOneStepWithLossScaleCell, self).__init__(network, optimizer, sens=None) 360 self.hyper_map = C.HyperMap() 361 self.base = Tensor(1, mstype.float32) 362 self.base0 = Tensor(0, mstype.int32) 363 self.reduce_sum = P.ReduceSum(keep_dims=False) 364 self.reduce_all = P.ReduceAll(keep_dims=False) 365 self.less_equal = P.LessEqual() 366 self.equal = P.Equal() 367 self.logic_not = P.LogicalNot() 368 self.allreduce = P.AllReduce() 369 self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) 370 self.gpu_target = (context.get_context("device_target") == "GPU") 371 self.ascend_910a_target = (MSContext.get_instance().get_ascend_soc_version() == 'ascend910') 372 self.ascend_910bc_target = (MSContext.get_instance().get_ascend_soc_version() in ['ascend910b', 'ascend910c']) 373 self.loss_scaling_manager = None 374 self._ascend_check_overflow_mode = os.environ.get('MS_ASCEND_CHECK_OVERFLOW_MODE') 375 376 self.enable_allfinite = False 377 runtime_conf = os.environ.get('MS_DEV_RUNTIME_CONF') 378 global_jit_config = context.get_jit_config() 379 if runtime_conf is not None and ("all_finite:True" in runtime_conf or "all_finite:true" in runtime_conf): 380 self.enable_allfinite = True 381 elif runtime_conf is not None and ("all_finite:False" in runtime_conf or "all_finite:false" in runtime_conf): 382 self.enable_allfinite = False 383 elif global_jit_config: 384 self.enable_allfinite = global_jit_config["jit_level"] == "O0" or global_jit_config["jit_level"] == "O1" 385 386 if isinstance(scale_sense, Cell): 387 self.loss_scaling_manager = scale_sense 388 self.scale_sense = Parameter(Tensor(scale_sense.get_loss_scale(), dtype=mstype.float32), 389 name="scale_sense") 390 elif isinstance(scale_sense, Tensor): 391 if scale_sense.shape == (1,) or scale_sense.shape == (): 392 self.scale_sense = Parameter(scale_sense, name='scale_sense') 393 else: 394 raise ValueError("For 'TrainOneStepWithLossScaleCell', " 395 "the shape of 'scale_sense' must be (1,) or (), but got {}." 396 .format(scale_sense.shape)) 397 else: 398 raise TypeError("For 'TrainOneStepWithLossScaleCell', " 399 "the 'scale_sense' must be Cell or Tensor, but got 'scale_sense' type: {}." 400 .format(type(scale_sense))) 401 self.enable_tuple_broaden = True 402 self._get_attr_from_cell(network) 403 404 def construct(self, *inputs): 405 weights = self.weights 406 loss = self.network(*inputs) 407 scaling_sens = self.scale_sense 408 status = Tensor([0] * 8, mstype.int32) 409 410 scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss)) 411 grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled) 412 grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads) 413 # apply grad reducer on grads 414 grads = self.grad_reducer(grads) 415 416 # get the overflow buffer 417 cond = self.get_overflow_status(status, grads) 418 overflow = self.process_loss_scale(cond) 419 # if there is no overflow, do optimize 420 if not overflow: 421 loss = F.depend(loss, self.optimizer(grads)) 422 return loss, cond, scaling_sens 423 424 def set_sense_scale(self, sens): 425 """ 426 If the user has set the `scale_sense` of Tensor type, he can call this function to reassign the value. 427 428 Args: 429 sens(Tensor): The new sense whose shape and type are the same with original `scale_sense`. 430 """ 431 if self.scale_sense and isinstance(sens, Tensor): 432 self.scale_sense.set_data(sens) 433 else: 434 raise TypeError("For 'TrainOneStepWithLossScaleCell', " 435 "the type of 'sens' must be Tensor, but got {}".format(type(sens))) 436 437 def start_overflow_check(self, pre_cond, compute_input): 438 """ 439 Start floating-point overflow detection. Create and clear the overflow detection state. 440 441 Specify the argument 'pre_cond' and 'compute_input' to make sure overflow status is cleared at the right time. 442 Taking this situation as an example, we need to execute state clearing after loss calculation and then detect 443 overflow in the process of gradient calculation. In this case, pre_cond should be the output of the loss 444 function, and compute_input should be the input of gradients-computing function. User-defined training network 445 based on this class can also call this interface to process the overflow. 446 447 Args: 448 pre_cond(Tensor): A precondition for starting overflow detection. It determines the executing order 449 of overflow state clearing and prior processions. It makes sure that the function 'start_overflow' 450 clears status after finishing the process of precondition. 451 compute_input(object): The input of subsequent process. Overflow detection should be performed on a 452 certain computation. Set `compute_input` as the input of the computation, to ensure overflow status is 453 cleared before executing the computation. 454 455 Returns: 456 Tuple[object, object], the first output is used to control the execution sequence. To ensure that the 457 `start_overflow_check` is executed before get_overflow_status after compilation optimization is performed. 458 This value should be used as the first input of get_overflow_status. The second output is the same as 459 the input of compute_input, used to control the execution sequence, and make ensure that the overflow flag 460 is cleaned up when the function returns. 461 """ 462 status = Tensor([0] * 8, mstype.int32) 463 if self.ascend_910a_target or (self.ascend_910bc_target and \ 464 self._ascend_check_overflow_mode == "SATURATION_MODE"): 465 status = F.depend(status, pre_cond) 466 # clear overflow buffer 467 clear_status = NPUClearFloatStatusV2()(status) 468 compute_input = F.depend(compute_input, clear_status) 469 return status, compute_input 470 471 def _check_overflow_status_on_infnan_mode(self, grad_overflow_check_func, compute_output): 472 """check overflow status on infnan mode.""" 473 flag_sum = self.hyper_map(F.partial(grad_overflow_check_func), compute_output) 474 flag_sum = P.AddN()(flag_sum) 475 # convert flag_sum to scalar 476 flag_sum = P.Reshape()(flag_sum, (())) 477 return flag_sum 478 479 def _get_distributed_overflow_status_on_infnan_mode(self, grad_overflow_check_func, compute_output): 480 """converge the distributed overflow status on infnan mode.""" 481 flag_sum = self._check_overflow_status_on_infnan_mode(grad_overflow_check_func, compute_output) 482 483 if self.is_distributed: 484 # sum overflow flag over devices 485 flag_reduce = self.allreduce(flag_sum) 486 overflow = self.less_equal(self.base, flag_reduce) 487 else: 488 overflow = self.less_equal(self.base, flag_sum) 489 return overflow 490 491 def _get_distributed_overflow_status_on_infnan_enable_allfinite(self, compute_output): 492 """check overflow status on infnan kernel mode.""" 493 overflow = AllFinite()(compute_output) 494 495 if self.is_distributed: 496 overflow = P.Cast()(overflow, mstype.int8) 497 overflow = P.Cast()(self.allreduce(overflow), mstype.bool_) 498 return overflow 499 500 def _get_gpu_overflow_status(self, compute_output): 501 """get overflow status of gpu.""" 502 overflow = self._get_distributed_overflow_status_on_infnan_mode(_grad_overflow, compute_output) 503 return overflow 504 505 def _get_ascend_overflow_status_on_infnan_mode(self, compute_output): 506 """get overflow status of ascend on infnan mode.""" 507 overflow = False 508 if self.enable_allfinite: 509 overflow = self._get_distributed_overflow_status_on_infnan_enable_allfinite(compute_output) 510 else: 511 overflow = self._get_distributed_overflow_status_on_infnan_mode(_ascend_grad_overflow, compute_output) 512 return overflow 513 514 def _get_ascend_overflow_status_on_saturation_mode(self, status, compute_output): 515 """get overflow status of ascend on saturation mode""" 516 status = F.depend(status, compute_output) 517 get_status = NPUGetFloatStatusV2()(status) 518 519 if self.is_distributed: 520 # sum overflow flag over devices 521 flag_reduce = self.allreduce(get_status) 522 # get_status not equal to [0]*8 means overflow 523 flag = self.equal(self.base0, flag_reduce) 524 status = F.depend(status, flag) 525 # distributed needs to skip allreduce to avoid its overflow affecting the next step 526 clear_status = NPUClearFloatStatusV2()(status) 527 flag = F.depend(flag, clear_status) 528 overall_finite = self.reduce_all(flag) 529 else: 530 status = F.depend(status, get_status) 531 clear_status = NPUClearFloatStatusV2()(status) 532 get_status = F.depend(get_status, clear_status) 533 flag = self.equal(self.base0, get_status) 534 overall_finite = self.reduce_all(flag) 535 overflow = self.logic_not(overall_finite) 536 return overflow 537 538 @jit 539 def get_overflow_status(self, status, compute_output): 540 """ 541 Get floating-point overflow status. 542 543 Get overflow results after executing the target process for overflow detection. User-defined training network 544 based on this class can also call this interface to process the overflow. 545 546 Args: 547 status (object): To control the execution sequence with start_overflow_check, it should be set as the first 548 output of start_overflow_check. 549 compute_output: Overflow detection should be performed in a certain computation process. Set 550 `compute_output` as the output of the computation process. 551 552 Returns: 553 bool, whether the overflow occurs or not. 554 """ 555 if self.gpu_target: 556 overflow = self._get_gpu_overflow_status(compute_output) 557 elif self.ascend_910bc_target: 558 if self._ascend_check_overflow_mode == "SATURATION_MODE": 559 overflow = self._get_ascend_overflow_status_on_saturation_mode(status, compute_output) 560 else: 561 overflow = self._get_ascend_overflow_status_on_infnan_mode(compute_output) 562 else: # ascend_910a_target 563 overflow = self._get_ascend_overflow_status_on_saturation_mode(status, compute_output) 564 return overflow 565 566 def process_loss_scale(self, overflow): 567 """ 568 Calculate loss scale according to the overflow. 569 570 User-defined training network based on this class can also call this interface to process the overflow. 571 572 Args: 573 overflow(bool): Whether the overflow occurs or not. 574 575 Returns: 576 bool, the input overflow value. 577 """ 578 if self.loss_scaling_manager is not None: 579 return self.loss_scaling_manager(self.scale_sense, overflow) 580 return overflow 581 582 583grad_scale = C.MultitypeFuncGraph("grad_scale") 584shard_grad_scale = C.MultitypeFuncGraph("shard_grad_scale") 585reciprocal = P.Reciprocal() 586 587 588@grad_scale.register("Tensor", "Tensor", "Tensor") 589def tensor_grad_scale_pipeline(scale, grad, accu_grad): 590 accu_grad = F.depend(accu_grad, grad) 591 new_grad = accu_grad * reciprocal(scale) 592 accu_grad = F.depend(accu_grad, new_grad) 593 zeros = F.tensor_mul(accu_grad, 0.0) 594 new_grad = F.depend(new_grad, F.assign(accu_grad, zeros)) 595 return new_grad 596 597 598@shard_grad_scale.register("Tensor", "Tensor", "Tensor") 599def tensor_shard_grad_scale_pipeline(scale, grad, accu_grad): 600 new_grad = grad * reciprocal(scale) 601 accu_grad = F.depend(accu_grad, new_grad) 602 new_grad = F.depend(new_grad, F.assign(accu_grad, F.zeros_like(accu_grad))) 603 return new_grad 604 605 606class _TrainGradAccuWithLossScaleCell(TrainOneStepCell): 607 """ 608 Append an optimizer to the training network after that the construct 609 function can be called to create the backward graph. 610 611 Args: 612 network (Cell): The training network. Note that loss function should have been added. 613 optimizer (Optimizer): Optimizer for updating the weights. 614 scale_sense (Cell): Cell to do the loss scale. 615 """ 616 def __init__(self, network, optimizer, scale_sense): 617 super(_TrainGradAccuWithLossScaleCell, self).__init__(network, optimizer, sens=None) 618 self.network = network 619 self.network.add_flags(defer_inline=True) 620 self.weights = optimizer.parameters 621 self.accu_grads = self.weights.clone(prefix="accu_grads", init="zeros") 622 self.optimizer = optimizer 623 self.grad = C.GradOperation(get_by_list=True, sens_param=True) 624 self.grad_reducer = nn.Identity() 625 self.degree = 1 626 self.cast = P.Cast() 627 self.alloc_status = P.NPUAllocFloatStatus() 628 self.get_status = P.NPUGetFloatStatus() 629 self.clear_before_grad = P.NPUClearFloatStatus() 630 self.reduce_sum = P.ReduceSum(keep_dims=False) 631 if self.parallel_mode not in [ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL]: 632 raise ValueError(f"ParallelMode must be one of " 633 f"[ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL], but found " 634 f"{self.parallel_mode}.") 635 self.allreduce = P.AllReduce() 636 self.base = Tensor(1, mstype.float32) 637 self.less_equal = P.LessEqual() 638 self.hyper_map = C.HyperMap() 639 self.reshape = P.Reshape() 640 self.loss_scaling_manager = None 641 if isinstance(scale_sense, Cell): 642 self.loss_scaling_manager = scale_sense 643 self.scale_sense = Parameter(Tensor(scale_sense.get_loss_scale(), dtype=mstype.float32), 644 name="scale_sense") 645 elif isinstance(scale_sense, Tensor): 646 if scale_sense.shape == (1,) or scale_sense.shape == (): 647 self.scale_sense = Parameter(scale_sense, name='scale_sense') 648 else: 649 raise ValueError("The shape of 'scale_sense' must be (1,) or (), but got {}" 650 .format(scale_sense.shape)) 651 else: 652 raise TypeError("The 'scale_sense' must be Cell or Tensor, but got {}".format(type(scale_sense))) 653 self.opt_shard = _get_enable_parallel_optimizer() 654 655 def construct(self, *inputs): 656 loss = self.network(*inputs) 657 scaling_sens = self.scale_sense 658 init = self.alloc_status() 659 scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss)) 660 scaling_sens_filled = F.depend(scaling_sens_filled, self.clear_before_grad(init)) 661 grads = self.grad(self.network, self.weights)(*inputs, scaling_sens_filled) 662 init = F.depend(init, grads) 663 get_status = self.get_status(init) 664 init = F.depend(init, get_status) 665 flag_sum = self.reduce_sum(init, (0,)) 666 if self.opt_shard: 667 grads = self.grad_reducer(grads) 668 grads = self.hyper_map(F.partial(shard_grad_scale, scaling_sens * self.degree), grads, self.accu_grads) 669 else: 670 accu_grads = self.grad_reducer(self.accu_grads) 671 grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads, accu_grads) 672 # sum overflow flag over devices 673 flag_reduce = self.allreduce(flag_sum) 674 cond = self.less_equal(self.base, flag_reduce) 675 overflow = cond 676 if self.loss_scaling_manager is not None: 677 overflow = self.loss_scaling_manager(self.scale_sense, cond) 678 if not overflow: 679 self.optimizer(grads) 680 return (loss, overflow, scaling_sens) 681