# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Loss scale manager abstract class."""
from __future__ import absolute_import

from mindspore import _checkparam as validator
from mindspore import nn


class LossScaleManager:
    """
    Loss scale manager abstract class used with mixed precision. The loss scale is the magnification factor of the
    gradients when mixed precision is used.

    A derived class needs to implement all of its methods. `get_loss_scale` is used to get the current loss scale
    value. `update_loss_scale` is used to update the loss scale value and will be called during training.
    `get_update_cell` is used to get the instance of :class:`mindspore.nn.Cell` that is used to update the loss
    scale; the instance will be called during training. Currently, `get_update_cell` is the one most commonly used.

    For example, see :class:`mindspore.amp.FixedLossScaleManager` and :class:`mindspore.amp.DynamicLossScaleManager`.
    """
    def get_loss_scale(self):
        """Get the value of the loss scale, which is the amplification factor of the gradients."""

    def update_loss_scale(self, overflow):
        """
        Update the loss scale value according to the status of `overflow`.

        Args:
            overflow (bool): Whether overflow occurs during training.
        """
    def get_update_cell(self):
        """Get the instance of :class:`mindspore.nn.Cell` that is used to update the loss scale."""


class FixedLossScaleManager(LossScaleManager):
    """
    Loss scale manager with a fixed loss scale value (the magnification factor of gradients when mixed precision is
    used), inherits from :class:`mindspore.amp.LossScaleManager`.

    Args:
        loss_scale (float): Magnification factor of gradients. Note that if `drop_overflow_update` is set to
            ``False`` , the value of `loss_scale` in the optimizer should be set to the same value as here.
            Default: ``128.0`` .
        drop_overflow_update (bool): Whether to skip the optimizer update when there is an overflow. If ``True`` ,
            the optimizer will not be executed when overflow occurs. Default: ``True`` .

    Examples:
        >>> import mindspore as ms
        >>> from mindspore import amp, nn
        >>>
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> loss_scale = 1024.0
        >>> loss_scale_manager = amp.FixedLossScaleManager(loss_scale, False)
        >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9, loss_scale=loss_scale)
        >>> model = ms.Model(net, loss_scale_manager=loss_scale_manager, optimizer=optim)
    """
    def __init__(self, loss_scale=128.0, drop_overflow_update=True):
        if loss_scale < 1:
            raise ValueError("The argument 'loss_scale' must be >= 1, "
                             "but got {}".format(loss_scale))
        self._loss_scale = loss_scale
        self._drop_overflow_update = drop_overflow_update

    def get_loss_scale(self):
        """
        Get the loss scale value.

        Returns:
            float, `loss_scale` value.
        """
        return self._loss_scale

    def get_drop_overflow_update(self):
        """
        Get `drop_overflow_update`, whether to drop the optimizer update for the current step when there is an
        overflow.

        Returns:
            bool, `drop_overflow_update` value.
        """
        return self._drop_overflow_update

    def update_loss_scale(self, overflow):
        """
        Update the loss scale value. This method does nothing in :class:`mindspore.amp.FixedLossScaleManager`.

        Args:
            overflow (bool): Whether overflow occurs.
        """

    def get_update_cell(self):
        """
        Returns the instance of :class:`mindspore.nn.Cell` that is used to update the loss scale, which will be
        called by :class:`mindspore.nn.TrainOneStepWithLossScaleCell`. As the loss scale is fixed in this class,
        the instance will do nothing.

        Returns:
            None or :class:`mindspore.nn.FixedLossScaleUpdateCell`. Instance of
            :class:`mindspore.nn.FixedLossScaleUpdateCell` when `drop_overflow_update` is True. None when
            `drop_overflow_update` is False.
        """
        if not self._drop_overflow_update:
            return None
        return nn.FixedLossScaleUpdateCell(self._loss_scale)


class DynamicLossScaleManager(LossScaleManager):
    """
    Loss scale manager with a dynamically adjusted loss scale (the magnification factor of gradients when mixed
    precision is used), inherits from :class:`mindspore.amp.LossScaleManager`.

    Args:
        init_loss_scale (float): Initial loss scale. Default: ``2 ** 24`` .
        scale_factor (int): Coefficient by which the loss scale is increased or decreased. Default: ``2`` .
        scale_window (int): Number of consecutive non-overflow steps after which the loss scale is increased.
            Default: ``2000`` .

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import mindspore as ms
        >>> from mindspore import amp, nn
        >>>
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> loss_scale_manager = amp.DynamicLossScaleManager()
        >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> model = ms.Model(net, loss_scale_manager=loss_scale_manager, optimizer=optim)
    """
    def __init__(self,
                 init_loss_scale=2 ** 24,
                 scale_factor=2,
                 scale_window=2000):
        if init_loss_scale < 1.0:
            raise ValueError("The argument 'init_loss_scale' must be >= 1, but got {}".format(init_loss_scale))
        self.loss_scale = init_loss_scale
        validator.check_positive_int(scale_window, "scale_window", self.__class__.__name__)
        self.scale_window = scale_window
        if scale_factor <= 0:
            raise ValueError("The argument 'scale_factor' must be > 0, but got {}".format(scale_factor))
        self.scale_factor = scale_factor
        self.increase_ratio = scale_factor
        self.decrease_ratio = 1 / scale_factor
        self.cur_iter = 1
        self.last_overflow_iter = 0
        self.bad_step_max = 1000
        self.bad_step = 0

    def get_loss_scale(self):
        """
        Get the current loss scale value.

        Returns:
            float, `loss_scale` value.
        """
        return self.loss_scale

    def update_loss_scale(self, overflow):
        """
        Update the loss scale value according to the status of `overflow`. If overflow occurs, decrease the loss
        scale; otherwise, increase the loss scale every `scale_window` consecutive non-overflow steps.

        Args:
            overflow (bool): Whether overflow occurs.
        """
        if overflow:
            self.loss_scale = max(self.loss_scale * self.decrease_ratio, 1)
            self.last_overflow_iter = self.cur_iter
            self.bad_step += 1
        else:
            if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
                self.loss_scale *= self.increase_ratio
            self.bad_step = 0

        if self.bad_step > self.bad_step_max:
            raise RuntimeError("Dynamic loss scale: continuous overflow has occurred for {} steps, "
                               "which exceeds the maximum threshold of {}.".format(self.bad_step, self.bad_step_max))

        self.cur_iter += 1

    def get_drop_overflow_update(self):
        """
        Get whether to drop the optimizer update for the current step when there is an overflow.

        Returns:
            bool, always True.
        """
        return True

    def get_update_cell(self):
        """
        Returns the instance of :class:`mindspore.nn.Cell` that is used to update the loss scale, which will be
        called by :class:`mindspore.nn.TrainOneStepWithLossScaleCell`.

        Returns:
            :class:`mindspore.nn.DynamicLossScaleUpdateCell`.
        """
        return nn.DynamicLossScaleUpdateCell(self.loss_scale, self.scale_factor, self.scale_window)
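

# The following is an illustrative sketch, not part of the MindSpore public API: it shows how a user-defined class
# could implement the `LossScaleManager` interface described above. The class name
# `_ExampleConstantLossScaleManager` and its default value are assumptions chosen only for demonstration; the
# manager keeps the loss scale constant and never drops the optimizer update, so `get_update_cell` returns None
# (mirroring the behavior of `FixedLossScaleManager` with `drop_overflow_update=False`).
class _ExampleConstantLossScaleManager(LossScaleManager):
    """Minimal example subclass of `LossScaleManager` (illustrative sketch only)."""
    def __init__(self, loss_scale=1024.0):
        # Keep a single constant magnification factor for the gradients.
        self.loss_scale = loss_scale

    def get_loss_scale(self):
        # Return the current loss scale value.
        return self.loss_scale

    def update_loss_scale(self, overflow):
        # A constant manager ignores the overflow status and leaves the scale unchanged.
        pass

    def get_update_cell(self):
        # No update cell is needed because the loss scale never changes.
        return None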