# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Loss scale cell for loss scale training."""
import mindspore.context as context
from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_enable_parallel_optimizer
from .cell_wrapper import TrainOneStepCell
from ..cell import Cell
from ...common import Tensor, RowTensor
from ...common.parameter import Parameter
from ...ops import functional as F
from ...ops import composite as C
from ...ops import operations as P
from ...common import dtype as mstype

_grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()


@_grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    return grad * F.cast(reciprocal(scale), F.dtype(grad))


@_grad_scale.register("Tensor", "RowTensor")
def tensor_grad_scale_row_tensor(scale, grad):
    return RowTensor(grad.indices,
                     grad.values * F.cast(reciprocal(scale), F.dtype(grad.values)),
                     grad.dense_shape)


_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
grad_overflow = P.FloatStatus()


@_grad_overflow.register("Tensor")
def _tensor_grad_overflow(grad):
    return grad_overflow(grad)


@_grad_overflow.register("RowTensor")
def _tensor_grad_overflow_row_tensor(grad):
    return grad_overflow(grad.values)


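# The helper below is an illustrative sketch only (an assumption about typical usage, not
# part of this module's public API). It shows how `_grad_scale` is applied through
# `C.HyperMap` to divide every gradient in a tuple by the current loss scale, which is
# what `TrainOneStepWithLossScaleCell.construct` does right after the backward pass.
def _unscale_grads_sketch(scale, grads):
    """Hypothetical helper: multiply each gradient in `grads` by 1 / `scale`."""
    hyper_map = C.HyperMap()
    # `_grad_scale` dispatches on the gradient type (Tensor or RowTensor) registered above.
    return hyper_map(F.partial(_grad_scale, scale), grads)

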
class DynamicLossScaleUpdateCell(Cell):
    r"""
    Dynamic Loss scale update cell.

    For loss scaling training, the initial loss scaling value is set to `loss_scale_value`.
    In each training step, the loss scaling value is decreased to loss scaling value / `scale_factor`
    when there is an overflow, and increased to loss scaling value * `scale_factor` if there is no
    overflow for `scale_window` consecutive steps. This cell is used for graph mode training, in which
    all logic is executed on the device side (another training mode is normal (non-sink) mode, in which
    some logic is executed on the host side).

    Args:
        loss_scale_value (float): Initial loss scaling value.
        scale_factor (int): Coefficient by which the loss scale is increased or decreased.
        scale_window (int): Maximum number of consecutive training steps without overflow before the
            loss scale is increased.

    Inputs:
        - **loss_scale** (Tensor) - The loss scale value during training with shape :math:`()`.
        - **overflow** (bool) - Whether the overflow occurs or not.

    Outputs:
        bool, the input `overflow`.

    Raises:
        TypeError: If dtype of `inputs` or `label` is neither float16 nor float32.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> from mindspore import Tensor, Parameter, nn
        >>> import mindspore.ops as ops
        >>>
        >>> class Net(nn.Cell):
        ...     def __init__(self, in_features, out_features):
        ...         super(Net, self).__init__()
        ...         self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
        ...                                 name='weight')
        ...         self.matmul = ops.MatMul()
        ...
        ...     def construct(self, x):
        ...         output = self.matmul(x, self.weight)
        ...         return output
        ...
        >>> in_features, out_features = 16, 10
        >>> net = Net(in_features, out_features)
        >>> loss = nn.MSELoss()
        >>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> net_with_loss = nn.WithLossCell(net, loss)
        >>> manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**12, scale_factor=2, scale_window=1000)
        >>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager)
        >>> input = Tensor(np.ones([out_features, in_features]), mindspore.float32)
        >>> labels = Tensor(np.ones([out_features,]), mindspore.float32)
        >>> output = train_network(input, labels)
    """

    def __init__(self,
                 loss_scale_value,
                 scale_factor,
                 scale_window):
        super(DynamicLossScaleUpdateCell, self).__init__()

        self.scale_window = Tensor(scale_window, dtype=mstype.int32)
        self.scale_factor = Tensor(scale_factor, dtype=mstype.float32)
        self.loss_scale_value = loss_scale_value

        self.cur_iter = Parameter(Tensor(1, dtype=mstype.int32), name="current_iterator_step")
        self.last_overflow_iter = Parameter(Tensor(0, dtype=mstype.int32), name="last_overflow_iterator_step")
        self.select = P.Select()
        self.max = P.Maximum()
        self.minimum_loss_scale = Tensor(1.0, dtype=mstype.float32)
        self.reciprocal = P.Reciprocal()
        self.less_equal = P.LessEqual()
        self.logic_and = P.LogicalAnd()
        self.logic_not = P.LogicalNot()
        self.logic_or = P.LogicalOr()
        self.const_true = Tensor(True, dtype=mstype.bool_)

    def get_loss_scale(self):
        """
        Get Loss Scale value.
        """
        return self.loss_scale_value

    def construct(self, loss_scale, overflow):
        overflow_cond = overflow
        loss_scale_on_overflow = self.select(overflow_cond, self.max(loss_scale * self.reciprocal(self.scale_factor),
                                                                     self.minimum_loss_scale), loss_scale)
        should_inc = self.less_equal(self.scale_window, self.cur_iter - self.last_overflow_iter)
        last_iter_cond = self.logic_or(overflow_cond, should_inc)
        last_overflow_iter = self.select(last_iter_cond, self.cur_iter, self.last_overflow_iter)
        last_iter = F.assign(self.last_overflow_iter, last_overflow_iter)
        update_scale_cond = self.logic_and(should_inc, self.logic_not(overflow_cond))
        scale_mul_res = loss_scale_on_overflow * self.scale_factor
        scaled_loss_scale = self.select(update_scale_cond, scale_mul_res, loss_scale_on_overflow)
        F.assign(loss_scale, scaled_loss_scale)
        inc_cur_iter = self.cur_iter + 1
        inc_cur_iter = F.depend(inc_cur_iter, last_iter)
        F.assign(self.cur_iter, inc_cur_iter)
        return overflow


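# Pure-Python sketch (illustration only, not used anywhere in this module) of the update
# rule that `DynamicLossScaleUpdateCell.construct` implements with graph-mode operators:
# on overflow the scale is divided by `scale_factor` but never falls below 1.0, and after
# `scale_window` consecutive overflow-free steps it is multiplied by `scale_factor`.
# The default values below simply mirror the docstring example and are an assumption.
def _dynamic_update_sketch(loss_scale, overflow, good_steps, scale_factor=2, scale_window=1000):
    """Hypothetical helper: return the updated (loss_scale, good_steps) pair."""
    if overflow:
        return max(loss_scale / scale_factor, 1.0), 0
    good_steps += 1
    if good_steps >= scale_window:
        return loss_scale * scale_factor, 0
    return loss_scale, good_steps

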
class FixedLossScaleUpdateCell(Cell):
    """
    Update cell with a fixed loss scale; the loss scaling value will not be updated.

    For usage, refer to `DynamicLossScaleUpdateCell`.

    Args:
        loss_scale_value (float): Initial loss scaling value.

    Inputs:
        - **loss_scale** (Tensor) - The loss scale value during training with shape :math:`()`, which is ignored.
        - **overflow** (bool) - Whether the overflow occurs or not.

    Outputs:
        bool, the input `overflow`.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> from mindspore import Tensor, Parameter, nn, ops
        >>>
        >>> class Net(nn.Cell):
        ...     def __init__(self, in_features, out_features):
        ...         super(Net, self).__init__()
        ...         self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
        ...                                 name='weight')
        ...         self.matmul = ops.MatMul()
        ...
        ...     def construct(self, x):
        ...         output = self.matmul(x, self.weight)
        ...         return output
        ...
        >>> in_features, out_features = 16, 10
        >>> net = Net(in_features, out_features)
        >>> loss = nn.MSELoss()
        >>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> net_with_loss = nn.WithLossCell(net, loss)
        >>> manager = nn.FixedLossScaleUpdateCell(loss_scale_value=2**12)
        >>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager)
        >>> input = Tensor(np.ones([out_features, in_features]), mindspore.float32)
        >>> labels = Tensor(np.ones([out_features,]), mindspore.float32)
        >>> output = train_network(input, labels)
    """

    def __init__(self, loss_scale_value):
        super(FixedLossScaleUpdateCell, self).__init__()
        self.loss_scale_value = loss_scale_value

    def get_loss_scale(self):
        """
        Get Loss Scale value.
        """
        return self.loss_scale_value

    def construct(self, _, overflow):
        return overflow


class TrainOneStepWithLossScaleCell(TrainOneStepCell):
    r"""
    Network training with loss scaling.

    This is a training step with loss scaling. It takes a network, an optimizer and possibly a scale update
    Cell as args. The loss scale value can be updated on either the host side or the device side. The
    TrainOneStepWithLossScaleCell is compiled into a graph that takes `*inputs` as input data.
    A Tensor-typed `scale_sense` acts as the loss scaling value. If you want to update the loss scaling value
    on the host side, provide `scale_sense` as a Tensor. If no Tensor-typed `scale_sense` is given, the loss
    scale update logic must be provided by a Cell-typed `scale_sense`.

    Args:
        network (Cell): The training network. The network only supports single output.
        optimizer (Cell): Optimizer for updating the weights.
        scale_sense (Union[Tensor, Cell]): If this value is a Cell, it is the loss scaling update logic cell.
            If this value is a Tensor, it must have shape :math:`()` or :math:`(1,)`.

    Inputs:
        - **(*inputs)** (Tuple(Tensor)) - Tuple of input tensors with shape :math:`(N, \ldots)`.

    Outputs:
        Tuple of 3 Tensors, the loss, the overflow flag and the current loss scaling value.

        - **loss** (Tensor) - Tensor with shape :math:`()`.
        - **overflow** (Tensor) - Tensor with shape :math:`()`, type is bool.
        - **loss scaling value** (Tensor) - Tensor with shape :math:`()`.

    Raises:
        TypeError: If `scale_sense` is neither Cell nor Tensor.
        ValueError: If shape of `scale_sense` is neither (1,) nor ().

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> from mindspore import Tensor, Parameter, nn, ops
        >>> from mindspore import dtype as mstype
        >>>
        >>> class Net(nn.Cell):
        ...     def __init__(self, in_features, out_features):
        ...         super(Net, self).__init__()
        ...         self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
        ...                                 name='weight')
        ...         self.matmul = ops.MatMul()
        ...
        ...     def construct(self, x):
        ...         output = self.matmul(x, self.weight)
        ...         return output
        ...
        >>> size, in_features, out_features = 16, 16, 10
        >>> #1) when the type of scale_sense is Cell:
        >>> net = Net(in_features, out_features)
        >>> loss = nn.MSELoss()
        >>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> net_with_loss = nn.WithLossCell(net, loss)
        >>> manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**12, scale_factor=2, scale_window=1000)
        >>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager)
        >>> input = Tensor(np.ones([out_features, in_features]), mindspore.float32)
        >>> labels = Tensor(np.ones([out_features,]), mindspore.float32)
        >>> output = train_network(input, labels)
        >>>
        >>> #2) when the type of scale_sense is Tensor:
        >>> net = Net(in_features, out_features)
        >>> loss = nn.MSELoss()
        >>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> net_with_loss = nn.WithLossCell(net, loss)
        >>> inputs = Tensor(np.ones([size, in_features]).astype(np.float32))
        >>> label = Tensor(np.zeros([size, out_features]).astype(np.float32))
        >>> scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mstype.float32)
        >>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=scaling_sens)
        >>> output = train_network(inputs, label)
    """
    def __init__(self, network, optimizer, scale_sense):
        super(TrainOneStepWithLossScaleCell, self).__init__(network, optimizer, sens=None)
        self.hyper_map = C.HyperMap()
        self.base = Tensor(1, mstype.float32)
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.less_equal = P.LessEqual()
        self.allreduce = P.AllReduce()
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        self.gpu_target = (context.get_context("device_target") == "GPU")
        self.loss_scaling_manager = None
        if isinstance(scale_sense, Cell):
            self.loss_scaling_manager = scale_sense
            self.scale_sense = Parameter(Tensor(scale_sense.get_loss_scale(), dtype=mstype.float32),
                                         name="scale_sense")
        elif isinstance(scale_sense, Tensor):
            if scale_sense.shape == (1,) or scale_sense.shape == ():
                self.scale_sense = Parameter(scale_sense, name='scale_sense')
            else:
                raise ValueError("The shape of scale_sense must be (1,) or (), but got {}".format(scale_sense.shape))
        else:
            raise TypeError("The scale_sense must be Cell or Tensor, but got {}".format(type(scale_sense)))

    def construct(self, *inputs):
        weights = self.weights
        loss = self.network(*inputs)
        scaling_sens = self.scale_sense

        status, scaling_sens = self.start_overflow_check(loss, scaling_sens)

        scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
        grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
        grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)
        # apply grad reducer on grads
        grads = self.grad_reducer(grads)

        # get the overflow buffer
        cond = self.get_overflow_status(status, grads)
        overflow = self.process_loss_scale(cond)
        # if there is no overflow, do optimize
        if not overflow:
            loss = F.depend(loss, self.optimizer(grads))
        return loss, cond, scaling_sens

    def set_sense_scale(self, sens):
        """
        If the user has set `scale_sense` during training and wants to reassign its value, this function
        can be called to modify it. `sens` must be a Tensor.

        Inputs:
            - **sens** (Tensor) - The new sense whose shape and type are the same as the original `scale_sense`.
        """
        if self.scale_sense and isinstance(sens, Tensor):
            self.scale_sense.set_data(sens)
        else:
            raise TypeError("The input type must be Tensor, but got {}".format(type(sens)))

    def start_overflow_check(self, pre_cond, compute_input):
        """
        Start floating-point overflow detection. Create and clear the overflow detection state.

        Specify the arguments `pre_cond` and `compute_input` to make sure the overflow status is cleared at
        the right time. Taking this situation as an example: we need to clear the state after the loss is
        calculated and then detect overflow during gradient calculation. In this case, `pre_cond` should be
        the output of the loss function, and `compute_input` should be the input of the gradient-computing
        function.

        Inputs:
            - **pre_cond** (Tensor) - A precondition for starting overflow detection. It determines the
              execution order of the overflow state clearing and the prior processes. It makes sure that the
              function `start_overflow_check` clears the status after the precondition has finished.
            - **compute_input** (object) - The input of the subsequent process. Overflow detection should be
              performed on a certain computation. Set `compute_input` as the input of the computation, to
              ensure the overflow status is cleared before the computation is executed.

        Outputs:
            Tuple[object, object], the first value is False for the GPU backend, while it is an instance of
            NPUAllocFloatStatus for other backends. The status is used to detect overflow during overflow
            detection. The second value is the same as the input `compute_input`, but carries some
            information about the execution order.
        """
        status = False
        if not self.gpu_target:
            # init overflow buffer
            status = P.NPUAllocFloatStatus()()
            status = F.depend(status, pre_cond)
            # clear overflow buffer
            clear_status = P.NPUClearFloatStatus()(status)
            compute_input = F.depend(compute_input, clear_status)
        return status, compute_input

    def get_overflow_status(self, status, compute_output):
        """
        Get floating-point overflow status.

        Get the overflow result after executing the target process for overflow detection.

        Inputs:
            - **status** (object) - A status instance used to detect the overflow.
            - **compute_output** - Overflow detection should be performed on a certain computation. Set
              `compute_output` as the output of the computation, to ensure the overflow status is acquired
              after the computation has finished.

        Outputs:
            bool, whether the overflow occurs or not.
        """
        if not self.gpu_target:
            status = F.depend(status, compute_output)
            get_status = P.NPUGetFloatStatus()(status)
            status = F.depend(status, get_status)
            # sum overflow buffer elements, 0: no overflow, >0: overflow
            flag_sum = self.reduce_sum(status, (0,))
        else:
            flag_sum = self.hyper_map(F.partial(_grad_overflow), compute_output)
            flag_sum = P.AddN()(flag_sum)
            # convert flag_sum to scalar
            flag_sum = P.Reshape()(flag_sum, (()))

        if self.is_distributed:
            # sum overflow flag over devices
            flag_reduce = self.allreduce(flag_sum)
            overflow = self.less_equal(self.base, flag_reduce)
        else:
            overflow = self.less_equal(self.base, flag_sum)
        return overflow

    def process_loss_scale(self, overflow):
        """
        Update the loss scale according to the overflow status.

        Inputs:
            - **overflow** (bool) - Whether the overflow occurs or not.

        Outputs:
            bool, overflow value.
        """
        if self.loss_scaling_manager is not None:
            return self.loss_scaling_manager(self.scale_sense, overflow)
        return overflow


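# Illustrative sketch only (an assumption about intended usage, not a tested recipe): the
# three helpers above -- `start_overflow_check`, `get_overflow_status` and
# `process_loss_scale` -- form the overflow-detection protocol that a custom training cell
# can reuse. The hypothetical subclass below mirrors `construct` and marks where extra
# gradient post-processing (e.g. clipping) could be inserted.
class _CustomTrainOneStepSketch(TrainOneStepWithLossScaleCell):
    """Hypothetical subclass showing where custom gradient handling would go."""

    def construct(self, *inputs):
        weights = self.weights
        loss = self.network(*inputs)
        scaling_sens = self.scale_sense
        # clear the overflow state only after the loss has been computed
        status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
        scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
        grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
        grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)
        grads = self.grad_reducer(grads)
        # custom gradient post-processing (e.g. clipping) would go here
        cond = self.get_overflow_status(status, grads)
        overflow = self.process_loss_scale(cond)
        if not overflow:
            loss = F.depend(loss, self.optimizer(grads))
        return loss, cond, scaling_sens

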
grad_scale = C.MultitypeFuncGraph("grad_scale")
shard_grad_scale = C.MultitypeFuncGraph("shard_grad_scale")
reciprocal = P.Reciprocal()


@grad_scale.register("Tensor", "Tensor", "Tensor")
def tensor_grad_scale_pipeline(scale, grad, accu_grad):
    accu_grad = F.depend(accu_grad, grad)
    new_grad = accu_grad * reciprocal(scale)
    accu_grad = F.depend(accu_grad, new_grad)
    zeros = F.tensor_mul(accu_grad, 0.0)
    new_grad = F.depend(new_grad, F.assign(accu_grad, zeros))
    return new_grad


@shard_grad_scale.register("Tensor", "Tensor", "Tensor")
def tensor_shard_grad_scale_pipeline(scale, grad, accu_grad):
    new_grad = grad * reciprocal(scale)
    accu_grad = F.depend(accu_grad, new_grad)
    new_grad = F.depend(new_grad, F.assign(accu_grad, F.zeros_like(accu_grad)))
    return new_grad


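# Minimal NumPy sketch (illustration only, not used by this module) of the effect of
# `tensor_grad_scale_pipeline`: the accumulated gradient is unscaled by 1 / scale and the
# accumulation buffer is then cleared so the next accumulation round starts from zero.
def _pipeline_unscale_sketch(scale, accu_grad):
    """Hypothetical helper: return the unscaled gradient and clear `accu_grad` in place."""
    new_grad = accu_grad / scale
    accu_grad.fill(0.0)  # clear the accumulation buffer in place
    return new_grad

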
class _TrainPipelineWithLossScaleCell(TrainOneStepCell):
    """
    Append an optimizer to the training network. After that, the construct
    function can be called to create the backward graph.

    Args:
        network (Cell): The training network. Note that the loss function should have been added.
        optimizer (Optimizer): Optimizer for updating the weights.
        scale_sense (Cell): Cell to do the loss scale.
    """
    def __init__(self, network, optimizer, scale_sense):
        super(_TrainPipelineWithLossScaleCell, self).__init__(network, optimizer, sens=None)
        self.network = network
        self.network.add_flags(defer_inline=True)
        self.weights = optimizer.parameters
        self.accu_grads = self.weights.clone(prefix="accu_grads", init="zeros")
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.grad_reducer = F.identity
        self.degree = 1
        self.cast = P.Cast()
        self.alloc_status = P.NPUAllocFloatStatus()
        self.get_status = P.NPUGetFloatStatus()
        self.clear_before_grad = P.NPUClearFloatStatus()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.base = Tensor(1, mstype.float32)
        self.less_equal = P.LessEqual()
        self.hyper_map = C.HyperMap()
        self.reshape = P.Reshape()
        self.loss_scaling_manager = None
        if isinstance(scale_sense, Cell):
            self.loss_scaling_manager = scale_sense
            self.scale_sense = Parameter(Tensor(scale_sense.get_loss_scale(), dtype=mstype.float32),
                                         name="scale_sense")
        elif isinstance(scale_sense, Tensor):
            if scale_sense.shape == (1,) or scale_sense.shape == ():
                self.scale_sense = Parameter(scale_sense, name='scale_sense')
            else:
                raise ValueError("The shape of scale_sense must be (1,) or (), but got {}".format(scale_sense.shape))
        else:
            raise TypeError("The scale_sense must be Cell or Tensor, but got {}".format(type(scale_sense)))
        self.opt_shard = _get_enable_parallel_optimizer()

    def construct(self, *inputs):
        weights = self.weights
        loss = self.network(*inputs)
        scaling_sens = self.scale_sense
        init = self.alloc_status()
        status_clear = self.clear_before_grad(init)
        scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
        grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
        init = F.depend(init, grads)
        get_status = self.get_status(init)
        init = F.depend(init, get_status)
        flag_sum = self.reduce_sum(init, (0,))
        loss = F.depend(loss, status_clear)
        if self.opt_shard:
            grads = self.grad_reducer(grads)
            grads = self.hyper_map(F.partial(shard_grad_scale, scaling_sens * self.degree), grads, self.accu_grads)
        else:
            accu_grads = self.grad_reducer(self.accu_grads)
            grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads, accu_grads)
        cond = self.less_equal(self.base, flag_sum)
        overflow = cond
        if self.loss_scaling_manager is not None:
            overflow = self.loss_scaling_manager(self.scale_sense, cond)
        if overflow:
            succ = False
        else:
            succ = self.optimizer(grads)
        ret = (loss, overflow, scaling_sens)
        return F.depend(ret, succ)


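# Illustrative sketch only (an assumption about typical usage, not part of the public API):
# how the three outputs of `TrainOneStepWithLossScaleCell` are usually consumed in a
# hand-written training loop -- when `overflow` is True the optimizer step was skipped and
# the loss scale may already have been reduced by the attached update cell.
def _loss_scale_training_loop_sketch(train_network, dataset):
    """Hypothetical driver loop; `train_network` is a built TrainOneStepWithLossScaleCell."""
    skipped = 0
    for data, label in dataset:
        loss, overflow, scaling_sens = train_network(data, label)
        if overflow:
            skipped += 1  # this step did not update the weights
    return skipped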