1# Copyright 2021 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================ 15 16"""AdamWeightDecayForBert, a customized Adam for bert. Input: gradient, overflow flag.""" 17import numpy as np 18 19from mindspore.common import dtype as mstype 20from mindspore.ops import operations as P 21from mindspore.ops import composite as C 22from mindspore.ops import functional as F 23from mindspore.common.tensor import Tensor 24from mindspore._checkparam import Validator as validator 25from mindspore._checkparam import Rel 26from mindspore.nn.optim.optimizer import Optimizer 27 28_adam_opt = C.MultitypeFuncGraph("adam_opt") 29_scaler_one = Tensor(1, mstype.int32) 30_scaler_ten = Tensor(10, mstype.float32) 31 32@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor", 33 "Tensor", "Bool", "Bool") 34def _update_run_kernel(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, decay_flags, optim_filter): 35 """ 36 Update parameters by AdamWeightDecay op. 37 """ 38 if optim_filter: 39 adam = P.AdamWeightDecay() 40 if decay_flags: 41 next_param = adam(param, m, v, lr, beta1, beta2, eps, weight_decay, gradient) 42 else: 43 next_param = adam(param, m, v, lr, beta1, beta2, eps, 0.0, gradient) 44 return next_param 45 return gradient 46 47@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor", 48 "Tensor", "Bool", "Bool") 49def _update_run_op(beta1, beta2, eps, lr, overflow, weight_decay, param, m, v, gradient, decay_flag, optim_filter): 50 """ 51 Update parameters. 52 53 Args: 54 beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0). 55 beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0). 56 eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0. 57 lr (Tensor): Learning rate. 58 overflow (Tensor): Whether overflow occurs. 59 weight_decay (Number): Weight decay. Should be equal to or greater than 0. 60 param (Tensor): Parameters. 61 m (Tensor): m value of parameters. 62 v (Tensor): v value of parameters. 63 gradient (Tensor): Gradient of parameters. 64 decay_flag (bool): Applies weight decay or not. 65 optim_filter (bool): Applies parameter update or not. 66 67 Returns: 68 Tensor, the new value of v after updating. 69 """ 70 if optim_filter: 71 op_mul = P.Mul() 72 op_square = P.Square() 73 op_sqrt = P.Sqrt() 74 op_cast = P.Cast() 75 op_reshape = P.Reshape() 76 op_shape = P.Shape() 77 op_select = P.Select() 78 79 param_fp32 = op_cast(param, mstype.float32) 80 m_fp32 = op_cast(m, mstype.float32) 81 v_fp32 = op_cast(v, mstype.float32) 82 gradient_fp32 = op_cast(gradient, mstype.float32) 83 84 cond = op_cast(F.fill(mstype.int32, op_shape(m_fp32), 1) * op_reshape(overflow, (())), mstype.bool_) 85 next_m = op_mul(beta1, m_fp32) + op_select(cond, m_fp32,\ 86 op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32)) 87 88 next_v = op_mul(beta2, v_fp32) + op_select(cond, v_fp32,\ 89 op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta2, op_square(gradient_fp32))) 90 91 update = next_m / (eps + op_sqrt(next_v)) 92 if decay_flag: 93 update = op_mul(weight_decay, param_fp32) + update 94 95 update_with_lr = op_mul(lr, update) 96 zeros = F.fill(mstype.float32, op_shape(param_fp32), 0) 97 next_param = param_fp32 - op_select(cond, zeros, op_reshape(update_with_lr, op_shape(param_fp32))) 98 99 next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param)))) 100 next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m)))) 101 next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v)))) 102 103 return op_cast(next_param, F.dtype(param)) 104 return gradient 105 106 107@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor", 108 "Tensor", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool") 109def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power, 110 beta2_power, beta1, beta2, eps, lr, gradient, param, m, v, ps_parameter, cache_enable): 111 """Apply sparse adam optimizer to the weight parameter when the gradient is sparse.""" 112 success = True 113 indices = gradient.indices 114 values = gradient.values 115 if ps_parameter and not cache_enable: 116 op_shape = P.Shape() 117 shapes = (op_shape(param), op_shape(m), op_shape(v), 118 op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1), 119 op_shape(beta2), op_shape(eps), op_shape(values), op_shape(indices)) 120 success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, 121 eps, values, indices), shapes), param)) 122 return success 123 124 if not target: 125 success = F.depend(success, sparse_opt(param, m, v, beta1_power, beta2_power, lr, beta1, beta2, 126 eps, values, indices)) 127 else: 128 op_mul = P.Mul() 129 op_square = P.Square() 130 op_sqrt = P.Sqrt() 131 scatter_add = P.ScatterAdd(use_locking) 132 133 F.assign(m, op_mul(beta1, m)) 134 F.assign(v, op_mul(beta2, v)) 135 136 grad_indices = gradient.indices 137 grad_value = gradient.values 138 139 next_m = scatter_add(m, 140 grad_indices, 141 op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value)) 142 143 next_v = scatter_add(v, 144 grad_indices, 145 op_mul(F.tuple_to_array((1.0,)) - beta2, op_square(grad_value))) 146 147 if use_nesterov: 148 m_temp = next_m * _scaler_ten 149 F.assign(m, op_mul(beta1, next_m)) 150 div_value = scatter_add(m, 151 op_mul(grad_indices, _scaler_one), 152 op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value)) 153 param_update = div_value / (op_sqrt(next_v) + eps) 154 155 F.assign(m, m_temp / _scaler_ten) 156 157 158 else: 159 param_update = next_m / (op_sqrt(next_v) + eps) 160 161 lr_t = lr * op_sqrt(1 - beta2_power) / (1 - beta1_power) 162 163 next_param = param - lr_t * param_update 164 165 166 success = F.depend(success, F.assign(param, next_param)) 167 success = F.depend(success, F.assign(m, next_m)) 168 success = F.depend(success, F.assign(v, next_v)) 169 170 return success 171 172 173@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor", 174 "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool") 175def _run_opt_with_one_number(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, 176 beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, param, 177 moment1, moment2, ps_parameter, cache_enable): 178 """Apply adam optimizer to the weight parameter using Tensor.""" 179 success = True 180 if ps_parameter and not cache_enable: 181 op_shape = P.Shape() 182 success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient), 183 (op_shape(param), op_shape(moment1), op_shape(moment2))), param)) 184 else: 185 success = F.depend(success, opt(param, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2, 186 eps, gradient)) 187 return success 188 189 190@_adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", 191 "Tensor", "Tensor") 192def _run_off_load_opt(opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, param, moment1, moment2): 193 """Apply AdamOffload optimizer to the weight parameter using Tensor.""" 194 success = True 195 delat_param = opt(moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2, eps, gradient) 196 success = F.depend(success, F.assign_add(param, delat_param)) 197 return success 198 199 200def _check_param_value(beta1, beta2, eps, prim_name): 201 """Check the type of inputs.""" 202 validator.check_value_type("beta1", beta1, [float], prim_name) 203 validator.check_value_type("beta2", beta2, [float], prim_name) 204 validator.check_value_type("eps", eps, [float], prim_name) 205 validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name) 206 validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name) 207 validator.check_positive_float(eps, "eps", prim_name) 208 209class AdamWeightDecayForBert(Optimizer): 210 """ 211 Implements the Adam algorithm to fix the weight decay. 212 213 Note: 214 When separating parameter groups, the weight decay in each group will be applied on the parameters if the 215 weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied 216 on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive. 217 218 To improve parameter groups performance, the customized order of parameters can be supported. 219 220 Args: 221 params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated, 222 the element in `params` must be class `Parameter`. When the `params` is a list of `dict`, the "params", 223 "lr", "weight_decay" and "order_params" are the keys can be parsed. 224 225 - params: Required. The value must be a list of `Parameter`. 226 227 - lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used. 228 If not, the `learning_rate` in the API will be used. 229 230 - weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay 231 will be used. If not, the `weight_decay` in the API will be used. 232 233 - order_params: Optional. If "order_params" is in the keys, the value must be the order of parameters and 234 the order will be followed in the optimizer. There are no other keys in the `dict` and the parameters 235 which in the 'order_params' must be in one of group parameters. 236 237 learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate. 238 When the learning_rate is an Iterable or a Tensor in a 1D dimension, use the dynamic learning rate, then 239 the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule, 240 use dynamic learning rate, the i-th learning rate will be calculated during the process of training 241 according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor in a zero 242 dimension, use fixed learning rate. Other cases are not supported. The float learning rate must be 243 equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float. 244 Default: 1e-3. 245 beta1 (float): The exponential decay rate for the 1st moment estimations. Default: 0.9. 246 Should be in range (0.0, 1.0). 247 beta2 (float): The exponential decay rate for the 2nd moment estimations. Default: 0.999. 248 Should be in range (0.0, 1.0). 249 eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6. 250 Should be greater than 0. 251 weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0. 252 253 Inputs: 254 - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`. 255 - **overflow** (tuple[Tensor]) - The overflow flag in dynamiclossscale. 256 257 Outputs: 258 tuple[bool], all elements are True. 259 260 Supported Platforms: 261 ``Ascend`` ``GPU`` 262 263 Examples: 264 >>> net = Net() 265 >>> #1) All parameters use the same learning rate and weight decay 266 >>> optim = AdamWeightDecay(params=net.trainable_params()) 267 >>> 268 >>> #2) Use parameter groups and set different values 269 >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) 270 >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) 271 >>> group_params = [{'params': conv_params, 'weight_decay': 0.01}, 272 ... {'params': no_conv_params, 'lr': 0.01}, 273 ... {'order_params': net.trainable_params()}] 274 >>> optim = AdamWeightDecay(group_params, learning_rate=0.1, weight_decay=0.0) 275 >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01. 276 >>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0. 277 >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'. 278 >>> 279 >>> loss = nn.SoftmaxCrossEntropyWithLogits() 280 >>> model = Model(net, loss_fn=loss, optimizer=optim) 281 """ 282 def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0): 283 super(AdamWeightDecayForBert, self).__init__(learning_rate, params, weight_decay) 284 _check_param_value(beta1, beta2, eps, self.cls_name) 285 self.beta1 = Tensor(np.array([beta1]).astype(np.float32)) 286 self.beta2 = Tensor(np.array([beta2]).astype(np.float32)) 287 self.eps = Tensor(np.array([eps]).astype(np.float32)) 288 self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros') 289 self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros') 290 self.hyper_map = C.HyperMap() 291 self.op_select = P.Select() 292 self.op_cast = P.Cast() 293 self.op_reshape = P.Reshape() 294 self.op_shape = P.Shape() 295 296 def construct(self, gradients, overflow): 297 """AdamWeightDecayForBert""" 298 lr = self.get_lr() 299 cond = self.op_cast(F.fill(mstype.int32, self.op_shape(self.beta1), 1) *\ 300 self.op_reshape(overflow, (())), mstype.bool_) 301 beta1 = self.op_select(cond, self.op_cast(F.tuple_to_array((1.0,)), mstype.float32), self.beta1) 302 beta2 = self.op_select(cond, self.op_cast(F.tuple_to_array((1.0,)), mstype.float32), self.beta2) 303 if self.is_group: 304 if self.is_group_lr: 305 optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps), 306 lr, self.weight_decay, self.parameters, self.moments1, self.moments2, 307 gradients, self.decay_flags, self.optim_filter) 308 else: 309 optim_result = self.hyper_map(F.partial(_adam_opt, beta1, beta2, self.eps, lr, overflow), 310 self.weight_decay, self.parameters, self.moments1, self.moments2, 311 gradients, self.decay_flags, self.optim_filter) 312 else: 313 optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, self.weight_decay), 314 self.parameters, self.moments1, self.moments2, 315 gradients, self.decay_flags, self.optim_filter) 316 if self.use_parallel: 317 self.broadcast_params(optim_result) 318 return optim_result 319 320class AdamWeightDecayOp(Optimizer): 321 """ 322 Implements the Adam algorithm to fix the weight decay. It is a complete operator, not a combination of other ops. 323 324 Note: 325 When separating parameter groups, the weight decay in each group will be applied on the parameters if the 326 weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied 327 on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive. 328 329 To improve parameter groups performance, the customized order of parameters can be supported. 330 331 Args: 332 params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated, 333 the element in `params` must be class `Parameter`. When the `params` is a list of `dict`, the "params", 334 "lr", "weight_decay" and "order_params" are the keys can be parsed. 335 336 - params: Required. The value must be a list of `Parameter`. 337 338 - lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used. 339 If not, the `learning_rate` in the API will be used. 340 341 - weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay 342 will be used. If not, the `weight_decay` in the API will be used. 343 344 - order_params: Optional. If "order_params" is in the keys, the value must be the order of parameters and 345 the order will be followed in the optimizer. There are no other keys in the `dict` and the parameters 346 which in the 'order_params' must be in one of group parameters. 347 348 learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate. 349 When the learning_rate is an Iterable or a Tensor in a 1D dimension, use the dynamic learning rate, then 350 the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule, 351 use dynamic learning rate, the i-th learning rate will be calculated during the process of training 352 according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor in a zero 353 dimension, use fixed learning rate. Other cases are not supported. The float learning rate must be 354 equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float. 355 Default: 1e-3. 356 beta1 (float): The exponential decay rate for the 1st moment estimations. Default: 0.9. 357 Should be in range (0.0, 1.0). 358 beta2 (float): The exponential decay rate for the 2nd moment estimations. Default: 0.999. 359 Should be in range (0.0, 1.0). 360 eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6. 361 Should be greater than 0. 362 weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0. 363 364 Inputs: 365 - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`. 366 367 Outputs: 368 tuple[bool], all elements are True. 369 370 Supported Platforms: 371 ``GPU`` 372 373 Examples: 374 >>> net = Net() 375 >>> #1) All parameters use the same learning rate and weight decay 376 >>> optim = AdamWeightDecayOp(params=net.trainable_params()) 377 >>> 378 >>> #2) Use parameter groups and set different values 379 >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) 380 >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) 381 >>> group_params = [{'params': conv_params, 'weight_decay': 0.01}, 382 ... {'params': no_conv_params, 'lr': 0.01}, 383 ... {'order_params': net.trainable_params()}] 384 >>> optim = AdamWeightDecayOp(group_params, learning_rate=0.1, weight_decay=0.0) 385 >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01. 386 >>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0. 387 >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'. 388 >>> 389 >>> loss = nn.SoftmaxCrossEntropyWithLogits() 390 >>> model = Model(net, loss_fn=loss, optimizer=optim) 391 """ 392 def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0): 393 super(AdamWeightDecayOp, self).__init__(learning_rate, params, weight_decay) 394 _check_param_value(beta1, beta2, eps, self.cls_name) 395 self.beta1 = Tensor(np.array([beta1]).astype(np.float32)) 396 self.beta2 = Tensor(np.array([beta2]).astype(np.float32)) 397 self.eps = Tensor(np.array([eps]).astype(np.float32)) 398 self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros') 399 self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros') 400 self.hyper_map = C.HyperMap() 401 402 def construct(self, gradients): 403 """AdamWeightDecayOp""" 404 lr = self.get_lr() 405 if self.is_group: 406 if self.is_group_lr: 407 optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps), 408 lr, self.weight_decay, self.parameters, self.moments1, self.moments2, 409 gradients, self.decay_flags, self.optim_filter) 410 else: 411 optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr), 412 self.weight_decay, self.parameters, self.moments1, self.moments2, 413 gradients, self.decay_flags, self.optim_filter) 414 else: 415 optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, self.weight_decay), 416 self.parameters, self.moments1, self.moments2, 417 gradients, self.decay_flags, self.optim_filter) 418 if self.use_parallel: 419 self.broadcast_params(optim_result) 420 return optim_result 421