# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""lamb"""
import numpy as np
from mindspore import context
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from .optimizer import Optimizer
from .optimizer import opt_init_args_register
from .. import layer


num_one = Tensor(np.ones([1]), mstype.float32)

_lamb_opt = C.MultitypeFuncGraph("lamb_opt")


@_lamb_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
                    "Tensor", "Bool", "Bool")
def _update_run_op(beta1, beta2, eps, global_step, lr, weight_decay, param, m, v, gradient, decay_flag, optim_filter):
    """
    Update parameters.

    Args:
        beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
        beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
        eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
        global_step (Tensor): Global step.
        lr (Tensor): Learning rate.
        weight_decay (numbers.Number): Weight decay. Should be equal to or greater than 0.
        param (Tensor): Parameters.
        m (Tensor): m value of parameters.
        v (Tensor): v value of parameters.
        gradient (Tensor): Gradient of parameters.
        decay_flag (bool): Whether to apply weight decay when updating the parameter.
        optim_filter (bool): Whether to apply the parameter update.

    Returns:
        Tensor, the new value of the parameter after updating. If `optim_filter` is False, the gradient is
        returned unchanged.
    """
    if optim_filter:
        op_mul = P.Mul()
        op_sqrt = P.Sqrt()
        op_rsqrt = P.Rsqrt()
        op_square = P.Square()
        op_cast = P.Cast()
        op_reshape = P.Reshape()
        op_shape = P.Shape()
        op_pow = P.Pow()
        op_norm = layer.Norm()
        op_select = P.Select()
        op_greater = P.Greater()
        op_fill = P.Fill()
        op_dtype = P.DType()

        param_fp32 = op_cast(param, mstype.float32)
        m_fp32 = op_cast(m, mstype.float32)
        v_fp32 = op_cast(v, mstype.float32)
        gradient_fp32 = op_cast(gradient, mstype.float32)

        next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(num_one, mstype.float32) - beta1, gradient_fp32)

        next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(num_one, mstype.float32) - beta2, op_square(gradient_fp32))

        next_mm = next_m / (op_cast(num_one, mstype.float32)
                            - op_pow(beta1, op_cast(global_step + num_one, mstype.float32)))
        next_vv = next_v / (op_cast(num_one, mstype.float32) -
                            op_pow(beta2, op_cast(global_step + num_one, mstype.float32)))
        w_norm = op_norm(param_fp32)
        g_norm = op_norm(gradient_fp32)

        g_norm_hat = op_norm(op_mul(next_mm, op_rsqrt(next_vv + eps)) + weight_decay * param_fp32)
        zeros = F.zeros_like(w_norm)
        ones = op_fill(op_dtype(w_norm), op_shape(w_norm), 1.0)
        trust_ratio = op_select(
            op_greater(w_norm, zeros),
            op_select(op_greater(g_norm, zeros), w_norm / g_norm_hat, ones),
            ones)
        tens = op_fill(op_dtype(trust_ratio), op_shape(trust_ratio), 10.0)
        trust_ratio = C.clip_by_value(trust_ratio, zeros, tens)
        update = next_mm / (op_sqrt(next_vv) + eps)

        if decay_flag:
            update = update + op_mul(weight_decay, param_fp32)

        update_with_lr = op_mul(op_mul(trust_ratio, lr), update)

        next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))

        next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
        next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m))))
        next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))

        return op_cast(next_param, F.dtype(param))
    return gradient
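

# The helper below is a minimal, pure-NumPy sketch of the same LAMB step, kept here only as
# executable documentation for `_update_run_op`; it is not part of the MindSpore API and is never
# called by the optimizer. The name `_lamb_step_numpy_reference` and its flat ndarray arguments
# are illustrative assumptions, not existing identifiers.
def _lamb_step_numpy_reference(param, m, v, gradient, lr, beta1, beta2, eps, weight_decay,
                               global_step, decay_flag=True):
    """Reference-only NumPy version of one LAMB update, mirroring the graph ops in `_update_run_op`."""
    next_m = beta1 * m + (1.0 - beta1) * gradient
    next_v = beta2 * v + (1.0 - beta2) * gradient * gradient
    # Bias correction uses (global_step + 1), as in the graph code above.
    next_mm = next_m / (1.0 - beta1 ** (global_step + 1))
    next_vv = next_v / (1.0 - beta2 ** (global_step + 1))
    w_norm = np.linalg.norm(param)
    g_norm = np.linalg.norm(gradient)
    g_norm_hat = np.linalg.norm(next_mm / np.sqrt(next_vv + eps) + weight_decay * param)
    # Layer-wise trust ratio, clipped to [0, 10]; falls back to 1.0 when either norm is zero.
    trust_ratio = w_norm / g_norm_hat if (w_norm > 0 and g_norm > 0) else 1.0
    trust_ratio = np.clip(trust_ratio, 0.0, 10.0)
    update = next_mm / (np.sqrt(next_vv) + eps)
    if decay_flag:
        update = update + weight_decay * param
    next_param = param - trust_ratio * lr * update
    return next_param, next_m, next_v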


_lamb_opt_ascend = C.MultitypeFuncGraph("lamb_opt_ascend")


@_lamb_opt_ascend.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
                           "Tensor", "Bool", "Bool")
def _update_run_op_ascend(beta1, beta2, eps, global_step, lr, weight_decay, param, m, v, gradient, decay_flag,
                          optim_filter):
    """
    Update parameters when the device target is Ascend.

    Args:
        beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
        beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
        eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
        global_step (Tensor): Global step.
        lr (Tensor): Learning rate.
        weight_decay (numbers.Number): Weight decay. Should be equal to or greater than 0.
        param (Tensor): Parameters.
        m (Tensor): m value of parameters.
        v (Tensor): v value of parameters.
        gradient (Tensor): Gradient of parameters.
        decay_flag (bool): Whether to apply weight decay when updating the parameter.
        optim_filter (bool): Whether to apply the parameter update.

    Returns:
        Tensor, the update computed by the fused operators; the parameter and the moments are updated in place.
        If `optim_filter` is False, the gradient is returned unchanged.
    """
    if optim_filter:
        op_cast = P.Cast()
        op_norm = layer.Norm()
        op_lamb_apply_optimizer_assign = P.LambApplyOptimizerAssign()
        op_lamb_apply_weight_assign = P.LambApplyWeightAssign()

        param_fp32 = op_cast(param, mstype.float32)
        gradient_fp32 = op_cast(gradient, mstype.float32)
        new_global_step = op_cast(global_step + num_one, mstype.float32)
        weight_decay_flag = op_cast(decay_flag, mstype.float32)

        update, _, _ = op_lamb_apply_optimizer_assign(gradient_fp32, v, m, param_fp32,
                                                      beta1, 1.0 - beta1, beta2, 1.0 - beta2, eps,
                                                      new_global_step, weight_decay_flag, weight_decay)
        w_norm = op_norm(param_fp32)
        g_norm = op_norm(update)
        update = F.depend(update, op_lamb_apply_weight_assign(w_norm, g_norm, lr, update, param))
        return update
    return gradient
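

# Summary note for the Ascend path above (descriptive comment only): `_update_run_op_ascend`
# performs the same LAMB step as `_update_run_op`, but, roughly speaking, the moment update and the
# normalized step are fused into the Ascend primitive `LambApplyOptimizerAssign`, while the
# trust-ratio scaling and the weight write-back are fused into `LambApplyWeightAssign`, replacing
# the long chain of elementwise ops with two fused kernels.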


def _check_param_value(beta1, beta2, eps, prim_name):
    """Check the types and ranges of `beta1`, `beta2` and `eps`."""
    validator.check_value_type("beta1", beta1, [float], prim_name)
    validator.check_value_type("beta2", beta2, [float], prim_name)
    validator.check_value_type("eps", eps, [float], prim_name)
    validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name)
    validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name)
    validator.check_positive_float(eps, "eps", prim_name)


class Lamb(Optimizer):
    r"""
    Lamb (Layer-wise Adaptive Moments optimizer for Batch training) with dynamic learning rate.

    LAMB is an optimization algorithm that employs a layer-wise adaptive large-batch optimization technique.
    Refer to the paper `LARGE BATCH OPTIMIZATION FOR DEEP LEARNING: TRAINING BERT IN 76
    MINUTES <https://arxiv.org/abs/1904.00962>`_.

    The LAMB optimizer aims to increase the training batch size without reducing the accuracy,
    and it supports adaptive element-wise updating and accurate layer-wise correction.

    The parameter update follows:

    .. math::
        \begin{gather*}
            m_t = \beta_1 m_{t - 1} + (1 - \beta_1) g_t \\
            v_t = \beta_2 v_{t - 1} + (1 - \beta_2) g_t^2 \\
            \hat{m}_t = \frac{m_t}{1 - \beta_1^t} \\
            \hat{v}_t = \frac{v_t}{1 - \beta_2^t} \\
            r_t = \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} \\
            w_t = w_{t-1} - \eta_t \frac{\| w_{t-1} \|}{\| r_t + \lambda w_{t-1} \|} (r_t + \lambda w_{t-1})
        \end{gather*}

    where :math:`m` is the 1st moment, :math:`v` the 2nd moment, :math:`g` the gradient, :math:`\eta` the
    learning rate, and :math:`\lambda` the LAMB weight decay rate.

    Note:
        When the parameters are grouped, the weight decay of each group is applied to the parameters of that
        group if its weight decay is positive. When the parameters are not grouped, the `weight_decay` in the
        API is applied to the parameters whose names do not contain 'beta' or 'gamma', provided `weight_decay`
        is positive.

        When the parameters are grouped, gradient centralization can be enabled by setting `grad_centralization`
        to True, but it can only be applied to the parameters of convolution layers. If it is set to True for
        parameters of non-convolution layers, an error will be reported.

        To improve the performance of grouped parameters, a customized order of parameters is supported.

        There is usually no connection between an optimizer and mixed precision. However, when
        `FixedLossScaleManager` is used and `drop_overflow_update` in `FixedLossScaleManager` is set to False,
        the optimizer needs to handle the `loss_scale`. As this optimizer has no `loss_scale` argument, the
        `loss_scale` needs to be processed by other means; refer to the document
        `LossScale <https://www.mindspore.cn/docs/programming_guide/zh-CN/r1.5/lossscale.html>`_ on how to
        process `loss_scale` correctly.

    Args:
        params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be
            updated, the element in `params` must be class `Parameter`. When the `params` is a list of `dict`,
            the "params", "lr", "weight_decay", "grad_centralization" and "order_params" are the keys that can
            be parsed.

            - params: Required. The value must be a list of `Parameter`.

            - lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used.
              If not, the `learning_rate` in the API will be used.

            - weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight
              decay will be used. If not, the `weight_decay` in the API will be used.

            - order_params: Optional. If "order_params" is in the keys, the value must be the order of
              parameters, and this order will be followed by the optimizer. There are no other keys in this
              `dict`, and the parameters that appear in the value of 'order_params' must be in one of the
              group parameters.

            - grad_centralization: Optional. The data type of "grad_centralization" is Bool. If
              "grad_centralization" is in the keys, the set value will be used. If not, `grad_centralization`
              is False by default. This parameter only works on convolution layers.

        learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the
            learning rate. When the learning_rate is an Iterable or a one-dimensional Tensor, use the dynamic
            learning rate, and the i-th step will take the i-th value as the learning rate. When the
            learning_rate is a LearningRateSchedule, use the dynamic learning rate, and the i-th learning rate
            will be calculated during training according to the formula of LearningRateSchedule. When the
            learning_rate is a float or a zero-dimensional Tensor, use the fixed learning rate. Other cases
            are not supported. The float learning rate must be equal to or greater than 0. If the type of
            `learning_rate` is int, it will be converted to float.
        beta1 (float): The exponential decay rate for the 1st moment estimations. Default: 0.9.
            Should be in range (0.0, 1.0).
        beta2 (float): The exponential decay rate for the 2nd moment estimations. Default: 0.999.
            Should be in range (0.0, 1.0).
        eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
            Should be greater than 0.
        weight_decay (float): Weight decay (L2 penalty). Default: 0.0. Should be equal to or greater than 0.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

    Outputs:
        tuple[bool], all elements are True.

    Raises:
        TypeError: If `learning_rate` is not one of int, float, Tensor, Iterable, LearningRateSchedule.
        TypeError: If element of `parameters` is neither Parameter nor dict.
        TypeError: If `beta1`, `beta2` or `eps` is not a float.
        TypeError: If `weight_decay` is neither float nor int.
        ValueError: If `eps` is less than or equal to 0.
        ValueError: If `beta1` or `beta2` is not in range (0.0, 1.0).
        ValueError: If `weight_decay` is less than 0.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> net = Net()
        >>> #1) All parameters use the same learning rate and weight decay
        >>> optim = nn.Lamb(params=net.trainable_params(), learning_rate=0.1)
        >>>
        >>> #2) Use parameter groups and set different values
        >>> poly_decay_lr = learning_rate_schedule.PolynomialDecayLR(learning_rate=0.1, end_learning_rate=0.01,
        ...                                                          decay_steps=4, power=0.5)
        >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
        >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
        >>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'grad_centralization': True},
        ...                 {'params': no_conv_params, 'lr': poly_decay_lr},
        ...                 {'order_params': net.trainable_params()}]
        >>> optim = nn.Lamb(group_params, learning_rate=0.1, weight_decay=0.0)
        >>> # The conv_params's parameters will use the default learning rate of 0.1, a weight decay of 0.01
        >>> # and grad centralization of True.
        >>> # The no_conv_params's parameters will use the dynamic poly decay learning rate, the default
        >>> # weight decay of 0.0 and grad centralization of False.
        >>> # The final order of parameters followed by the optimizer is the value of 'order_params'.
        >>>
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> model = Model(net, loss_fn=loss, optimizer=optim)
    """

    @opt_init_args_register
    def __init__(self, params, learning_rate, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
        super(Lamb, self).__init__(learning_rate, params, weight_decay)
        _check_param_value(beta1, beta2, eps, self.cls_name)

        # store as 1-element tensors; turn them into scalars once scalar/tensor mixed operations are supported
        self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
        self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
        self.eps = Tensor(np.array([eps]).astype(np.float32))
        self.params = self.parameters
        self.moments1 = self.params.clone(prefix="lamb_m", init='zeros')
        self.moments2 = self.params.clone(prefix="lamb_v", init='zeros')

        if not self.dynamic_lr:
            self.global_step = Parameter(initializer(0, [1]), name='global_step')
            self.assignadd = P.AssignAdd()
        self.device_ascend = context.get_context("device_target") == "Ascend"
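
    # Descriptive note for `construct` below: the three `hyper_map` branches differ only in which
    # hyper-parameters are fixed through `F.partial`. With per-group learning rates, both `lr` and
    # `weight_decay` are passed as per-parameter sequences; with grouped weight decay only, `lr` is a
    # single shared value; otherwise both `lr` and `weight_decay` are shared by all parameters.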

    def construct(self, gradients):
        lr = self.get_lr()
        lamb_opt = _lamb_opt_ascend if self.device_ascend else _lamb_opt
        gradients = self.gradients_centralization(gradients)
        if self.is_group:
            if self.is_group_lr:
                optim_result = self.hyper_map(F.partial(lamb_opt, self.beta1, self.beta2, self.eps,
                                                        self.global_step),
                                              lr, self.weight_decay, self.params, self.moments1, self.moments2,
                                              gradients, self.decay_flags, self.optim_filter)
            else:
                optim_result = self.hyper_map(F.partial(lamb_opt, self.beta1, self.beta2, self.eps,
                                                        self.global_step, lr),
                                              self.weight_decay, self.params, self.moments1, self.moments2,
                                              gradients, self.decay_flags, self.optim_filter)
        else:
            optim_result = self.hyper_map(F.partial(lamb_opt, self.beta1, self.beta2, self.eps,
                                                    self.global_step, lr, self.weight_decay),
                                          self.params, self.moments1, self.moments2, gradients,
                                          self.decay_flags, self.optim_filter)

        if self.use_parallel:
            optim_result = F.depend(optim_result, self.broadcast_params(optim_result))

        if not self.dynamic_lr:
            optim_result = F.depend(optim_result, self.assignadd(self.global_step, 1))

        return optim_result