# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Loss scale manager abstract class."""

from .._checkparam import Validator as validator
from .. import nn


class LossScaleManager:
    """
    Loss scale manager abstract class.

    Derived classes such as `FixedLossScaleManager` and `DynamicLossScaleManager` must override all of
    `LossScaleManager`'s methods.
    """
    def get_loss_scale(self):
        """Get the loss scale value."""

    def update_loss_scale(self, overflow):
        """
        Update the loss scale value.

        Args:
            overflow (bool): Whether an overflow occurred.
        """

    def get_update_cell(self):
        """Get the loss scaling update logic cell."""


class FixedLossScaleManager(LossScaleManager):
    """
    Loss scale manager with a fixed loss scale value, inherits from LossScaleManager.

    Args:
        loss_scale (float): Loss scale value. Note that if `drop_overflow_update` is set to False, the `loss_scale`
            of the optimizer you use needs to be set to the same value as here. Default: 128.0.
        drop_overflow_update (bool): Whether to skip the optimizer update when an overflow occurs. If True, the
            optimizer will not be executed when an overflow occurs. Default: True.

    Examples:
        >>> from mindspore import Model, nn, FixedLossScaleManager
        >>>
        >>> net = Net()
        >>> #1) Drop the parameter update if there is an overflow
        >>> loss_scale_manager = FixedLossScaleManager()
        >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> model = Model(net, loss_scale_manager=loss_scale_manager, optimizer=optim)
        >>>
        >>> #2) Execute the parameter update even if an overflow occurs
        >>> loss_scale = 1024.0
        >>> loss_scale_manager = FixedLossScaleManager(loss_scale, False)
        >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9, loss_scale=loss_scale)
        >>> model = Model(net, loss_scale_manager=loss_scale_manager, optimizer=optim)
    """
    def __init__(self, loss_scale=128.0, drop_overflow_update=True):
        if loss_scale < 1:
            raise ValueError("The argument 'loss_scale' must be >= 1, "
                             "but got {}".format(loss_scale))
        self._loss_scale = loss_scale
        self._drop_overflow_update = drop_overflow_update

    def get_loss_scale(self):
        """
        Get the loss scale value.

        Returns:
            float, `loss_scale` value.
        """
        return self._loss_scale

    def get_drop_overflow_update(self):
        """
        Get the flag that indicates whether to drop the optimizer update when an overflow occurs.

        Returns:
            bool, `drop_overflow_update` value.
        """
        return self._drop_overflow_update

    def update_loss_scale(self, overflow):
        """
        Update the loss scale value. This method does nothing in `FixedLossScaleManager`.

        Args:
            overflow (bool): Whether an overflow occurred.
        """

    def get_update_cell(self):
        """
        Return the update cell for `TrainOneStepWithLossScaleCell`.

        Returns:
            None or Cell. Cell object, used to update `loss_scale`, when `drop_overflow_update` is True. None when
            `drop_overflow_update` is False.
        """
        if not self._drop_overflow_update:
            return None
        return nn.FixedLossScaleUpdateCell(self._loss_scale)


class DynamicLossScaleManager(LossScaleManager):
    """
    Loss scale manager that dynamically adjusts the loss scale, inherits from LossScaleManager.

    Args:
        init_loss_scale (float): Initial loss scale value. Default: 2**24.
        scale_factor (int): Factor by which the loss scale is increased or decreased. Default: 2.
        scale_window (int): Number of consecutive steps without overflow after which the loss scale is
            increased. Default: 2000.

    Examples:
        >>> from mindspore import Model, nn, DynamicLossScaleManager
        >>>
        >>> net = Net()
        >>> loss_scale_manager = DynamicLossScaleManager()
        >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> model = Model(net, loss_scale_manager=loss_scale_manager, optimizer=optim)
    """
    def __init__(self,
                 init_loss_scale=2 ** 24,
                 scale_factor=2,
                 scale_window=2000):
        if init_loss_scale < 1.0:
            raise ValueError("The argument 'init_loss_scale' must be >= 1, but got {}".format(init_loss_scale))
        self.loss_scale = init_loss_scale
        validator.check_positive_int(scale_window, "scale_window", self.__class__.__name__)
        self.scale_window = scale_window
        if scale_factor <= 0:
            raise ValueError("The argument 'scale_factor' must be > 0, but got {}".format(scale_factor))
        self.scale_factor = scale_factor
        self.increase_ratio = scale_factor
        self.decrease_ratio = 1 / scale_factor
        self.cur_iter = 1
        self.last_overflow_iter = 0
        self.bad_step_max = 1000
        self.bad_step = 0

    def get_loss_scale(self):
        """
        Get the loss scale value.

        Returns:
            float, `loss_scale` value.
        """
        return self.loss_scale

    def update_loss_scale(self, overflow):
        """
        Update the loss scale value. Decrease the loss scale when an overflow occurs; increase it after
        `scale_window` consecutive steps without overflow.

        Args:
            overflow (bool): Whether an overflow occurred.
        """
        if overflow:
            # Decrease the loss scale on overflow, but never below 1.
            self.loss_scale = max(self.loss_scale * self.decrease_ratio, 1)
            self.last_overflow_iter = self.cur_iter
            self.bad_step += 1
        else:
            # Increase the loss scale after `scale_window` consecutive clean steps.
            if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
                self.loss_scale *= self.increase_ratio
            self.bad_step = 0

        if self.bad_step > self.bad_step_max:
            raise RuntimeError("Dynamic loss scale overflowed for {} consecutive steps, which exceeds the "
                               "maximum threshold of {}.".format(self.bad_step, self.bad_step_max))

        self.cur_iter += 1

    def get_drop_overflow_update(self):
        """
        Get the flag that indicates whether to drop the optimizer update when an overflow occurs.

        Returns:
            bool, always True for `DynamicLossScaleManager`.
        """
        return True

    def get_update_cell(self):
        """
        Return the update cell for `TrainOneStepWithLossScaleCell`.

        Returns:
            Cell, cell object used to update `loss_scale`.
        """
        return nn.DynamicLossScaleUpdateCell(self.loss_scale, self.scale_factor, self.scale_window)
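

# ----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the public API): it
# shows how DynamicLossScaleManager halves the scale on an overflowing step and
# doubles it after `scale_window` consecutive clean steps. The overflow flags
# below are simulated; in real training they come from the overflow detection
# performed by `TrainOneStepWithLossScaleCell`. Because this module uses
# relative imports, run it as a module, e.g.
# `python -m mindspore.train.loss_scale_manager` (adjust the path if this file
# lives elsewhere in the package).
if __name__ == "__main__":
    demo_manager = DynamicLossScaleManager(init_loss_scale=2 ** 10, scale_factor=2, scale_window=4)
    # Hypothetical overflow pattern: four clean steps, one overflow, then clean steps again.
    simulated_overflow = [False, False, False, False, True, False, False, False, False]
    for demo_step, demo_overflow in enumerate(simulated_overflow):
        demo_manager.update_loss_scale(demo_overflow)
        print("step {}: overflow={}, loss_scale={}".format(
            demo_step, demo_overflow, demo_manager.get_loss_scale()))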