1# Copyright 2021 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================ 15"""AdamWeightDecayForBert, a customized Adam for bert. Input: gradient, overflow flag.""" 16import numpy as np 17 18from mindspore.common import dtype as mstype 19from mindspore.ops import operations as P 20from mindspore.ops import composite as C 21from mindspore.ops import functional as F 22from mindspore.common.tensor import Tensor 23from mindspore._checkparam import Validator as validator 24from mindspore._checkparam import Rel 25from mindspore.nn.optim.optimizer import Optimizer 26 27_adam_opt = C.MultitypeFuncGraph("adam_opt") 28_scaler_one = Tensor(1, mstype.int32) 29_scaler_ten = Tensor(10, mstype.float32) 30 31@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor", 32 "Tensor", "Bool", "Bool") 33def _update_run_kernel(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, decay_flags, optim_filter): 34 """ 35 Update parameters by AdamWeightDecay op. 36 """ 37 if optim_filter: 38 adam = P.AdamWeightDecay() 39 if decay_flags: 40 next_param = adam(param, m, v, lr, beta1, beta2, eps, weight_decay, gradient) 41 else: 42 next_param = adam(param, m, v, lr, beta1, beta2, eps, 0.0, gradient) 43 return next_param 44 return gradient 45 46@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor", 47 "Tensor", "Bool", "Bool") 48def _update_run_op(beta1, beta2, eps, lr, overflow, weight_decay, param, m, v, gradient, decay_flag, optim_filter): 49 """ 50 Update parameters. 51 52 Args: 53 beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0). 54 beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0). 55 eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0. 56 lr (Tensor): Learning rate. 57 overflow (Tensor): Whether overflow occurs. 58 weight_decay (Number): Weight decay. Should be equal to or greater than 0. 59 param (Tensor): Parameters. 60 m (Tensor): m value of parameters. 61 v (Tensor): v value of parameters. 62 gradient (Tensor): Gradient of parameters. 63 decay_flag (bool): Applies weight decay or not. 64 optim_filter (bool): Applies parameter update or not. 65 66 Returns: 67 Tensor, the new value of v after updating. 68 """ 69 if optim_filter: 70 op_mul = P.Mul() 71 op_square = P.Square() 72 op_sqrt = P.Sqrt() 73 op_cast = P.Cast() 74 op_reshape = P.Reshape() 75 op_shape = P.Shape() 76 op_select = P.Select() 77 78 param_fp32 = op_cast(param, mstype.float32) 79 m_fp32 = op_cast(m, mstype.float32) 80 v_fp32 = op_cast(v, mstype.float32) 81 gradient_fp32 = op_cast(gradient, mstype.float32) 82 83 cond = op_cast(F.fill(mstype.int32, op_shape(m_fp32), 1) * op_reshape(overflow, (())), mstype.bool_) 84 next_m = op_mul(beta1, m_fp32) + op_select(cond, m_fp32,\ 85 op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32)) 86 87 next_v = op_mul(beta2, v_fp32) + op_select(cond, v_fp32,\ 88 op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta2, op_square(gradient_fp32))) 89 90 update = next_m / (eps + op_sqrt(next_v)) 91 if decay_flag: 92 update = op_mul(weight_decay, param_fp32) + update 93 94 update_with_lr = op_mul(lr, update) 95 zeros = F.fill(mstype.float32, op_shape(param_fp32), 0) 96 next_param = param_fp32 - op_select(cond, zeros, op_reshape(update_with_lr, op_shape(param_fp32))) 97 98 next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param)))) 99 next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m)))) 100 next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v)))) 101 102 return op_cast(next_param, F.dtype(param)) 103 return gradient 104 105 106@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor", 107 "Tensor", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool") 108def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power, 109 beta2_power, beta1, beta2, eps, lr, gradient, param, m, v, ps_parameter, cache_enable): 110 """Apply sparse adam optimizer to the weight parameter when the gradient is sparse.""" 111 success = True 112 indices = gradient.indices 113 values = gradient.values 114 if ps_parameter and not cache_enable: 115 op_shape = P.Shape() 116 shapes = (op_shape(param), op_shape(m), op_shape(v), 117 op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1), 118 op_shape(beta2), op_shape(eps), op_shape(values), op_shape(indices)) 119 success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, 120 eps, values, indices), shapes), param)) 121 return success 122 123 if not target: 124 success = F.depend(success, sparse_opt(param, m, v, beta1_power, beta2_power, lr, beta1, beta2, 125 eps, values, indices)) 126 else: 127 op_mul = P.Mul() 128 op_square = P.Square() 129 op_sqrt = P.Sqrt() 130 scatter_add = P.ScatterAdd(use_locking) 131 132 F.assign(m, op_mul(beta1, m)) 133 F.assign(v, op_mul(beta2, v)) 134 135 grad_indices = gradient.indices 136 grad_value = gradient.values 137 138 next_m = scatter_add(m, 139 grad_indices, 140 op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value)) 141 142 next_v = scatter_add(v, 143 grad_indices, 144 op_mul(F.tuple_to_array((1.0,)) - beta2, op_square(grad_value))) 145 146 if use_nesterov: 147 m_temp = next_m * _scaler_ten 148 F.assign(m, op_mul(beta1, next_m)) 149 div_value = scatter_add(m, 150 op_mul(grad_indices, _scaler_one), 151 op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value)) 152 param_update = div_value / (op_sqrt(next_v) + eps) 153 154 F.assign(m, m_temp / _scaler_ten) 155 156 157 else: 158 param_update = next_m / (op_sqrt(next_v) + eps) 159 160 lr_t = lr * op_sqrt(1 - beta2_power) / (1 - beta1_power) 161 162 next_param = param - lr_t * param_update 163 164 165 166 success = F.depend(success, F.assign(param, next_param)) 167 success = F.depend(success, F.assign(m, next_m)) 168 success = F.depend(success, F.assign(v, next_v)) 169 170 return success 171 172 173@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor", 174 "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool") 175def _run_opt_with_one_number(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, 176 beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, param, 177 moment1, moment2, ps_parameter, cache_enable): 178 """Apply adam optimizer to the weight parameter using Tensor.""" 179 success = True 180 if ps_parameter and not cache_enable: 181 op_shape = P.Shape() 182 success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient), 183 (op_shape(param), op_shape(moment1), op_shape(moment2))), param)) 184 else: 185 success = F.depend(success, opt(param, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2, 186 eps, gradient)) 187 return success 188 189 190@_adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", 191 "Tensor", "Tensor") 192def _run_off_load_opt(opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, param, moment1, moment2): 193 """Apply AdamOffload optimizer to the weight parameter using Tensor.""" 194 success = True 195 delat_param = opt(moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2, eps, gradient) 196 success = F.depend(success, F.assign_add(param, delat_param)) 197 return success 198 199 200def _check_param_value(beta1, beta2, eps, prim_name): 201 """Check the type of inputs.""" 202 validator.check_value_type("beta1", beta1, [float], prim_name) 203 validator.check_value_type("beta2", beta2, [float], prim_name) 204 validator.check_value_type("eps", eps, [float], prim_name) 205 validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name) 206 validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name) 207 validator.check_positive_float(eps, "eps", prim_name) 208 209class AdamWeightDecayForBert(Optimizer): 210 """ 211 Implements the Adam algorithm to fix the weight decay. 212 213 Note: 214 When separating parameter groups, the weight decay in each group will be applied on the parameters if the 215 weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied 216 on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive. 217 218 To improve parameter groups performance, the customized order of parameters can be supported. 219 220 Args: 221 params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated, 222 the element in `params` must be class `Parameter`. When the `params` is a list of `dict`, the "params", 223 "lr", "weight_decay" and "order_params" are the keys can be parsed. 224 225 - params: Required. The value must be a list of `Parameter`. 226 227 - lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used. 228 If not, the `learning_rate` in the API will be used. 229 230 - weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay 231 will be used. If not, the `weight_decay` in the API will be used. 232 233 - order_params: Optional. If "order_params" is in the keys, the value must be the order of parameters and 234 the order will be followed in the optimizer. There are no other keys in the `dict` and the parameters 235 which in the 'order_params' must be in one of group parameters. 236 237 learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate. 238 When the learning_rate is an Iterable or a Tensor in a 1D dimension, use the dynamic learning rate, then 239 the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule, 240 use dynamic learning rate, the i-th learning rate will be calculated during the process of training 241 according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor in a zero 242 dimension, use fixed learning rate. Other cases are not supported. The float learning rate must be 243 equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float. 244 Default: 1e-3. 245 beta1 (float): The exponential decay rate for the 1st moment estimations. Default: 0.9. 246 Should be in range (0.0, 1.0). 247 beta2 (float): The exponential decay rate for the 2nd moment estimations. Default: 0.999. 248 Should be in range (0.0, 1.0). 249 eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6. 250 Should be greater than 0. 251 weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0. 252 253 Inputs: 254 - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`. 255 - **overflow** (tuple[Tensor]) - The overflow flag in dynamiclossscale. 256 257 Outputs: 258 tuple[bool], all elements are True. 259 260 Supported Platforms: 261 ``Ascend`` ``GPU`` 262 263 Examples: 264 >>> net = Net() 265 >>> #1) All parameters use the same learning rate and weight decay 266 >>> optim = AdamWeightDecay(params=net.trainable_params()) 267 >>> 268 >>> #2) Use parameter groups and set different values 269 >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) 270 >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) 271 >>> group_params = [{'params': conv_params, 'weight_decay': 0.01}, 272 ... {'params': no_conv_params, 'lr': 0.01}, 273 ... {'order_params': net.trainable_params()}] 274 >>> optim = AdamWeightDecay(group_params, learning_rate=0.1, weight_decay=0.0) 275 >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01. 276 >>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0. 277 >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'. 278 >>> 279 >>> loss = nn.SoftmaxCrossEntropyWithLogits() 280 >>> model = Model(net, loss_fn=loss, optimizer=optim) 281 """ 282 def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0): 283 super(AdamWeightDecayForBert, self).__init__(learning_rate, params, weight_decay) 284 _check_param_value(beta1, beta2, eps, self.cls_name) 285 self.beta1 = Tensor(np.array([beta1]).astype(np.float32)) 286 self.beta2 = Tensor(np.array([beta2]).astype(np.float32)) 287 self.eps = Tensor(np.array([eps]).astype(np.float32)) 288 self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros') 289 self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros') 290 self.hyper_map = C.HyperMap() 291 self.op_select = P.Select() 292 self.op_cast = P.Cast() 293 self.op_reshape = P.Reshape() 294 self.op_shape = P.Shape() 295 296 def construct(self, gradients, overflow): 297 """AdamWeightDecayForBert""" 298 lr = self.get_lr() 299 cond = self.op_cast(F.fill(mstype.int32, self.op_shape(self.beta1), 1) *\ 300 self.op_reshape(overflow, (())), mstype.bool_) 301 beta1 = self.op_select(cond, self.op_cast(F.tuple_to_array((1.0,)), mstype.float32), self.beta1) 302 beta2 = self.op_select(cond, self.op_cast(F.tuple_to_array((1.0,)), mstype.float32), self.beta2) 303 if self.is_group: 304 if self.is_group_lr: 305 optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps), 306 lr, self.weight_decay, self.parameters, self.moments1, self.moments2, 307 gradients, self.decay_flags, self.optim_filter) 308 else: 309 optim_result = self.hyper_map(F.partial(_adam_opt, beta1, beta2, self.eps, lr, overflow), 310 self.weight_decay, self.parameters, self.moments1, self.moments2, 311 gradients, self.decay_flags, self.optim_filter) 312 else: 313 optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, self.weight_decay), 314 self.parameters, self.moments1, self.moments2, 315 gradients, self.decay_flags, self.optim_filter) 316 if self.use_parallel: 317 self.broadcast_params(optim_result) 318 return optim_result 319 320class AdamWeightDecayOp(Optimizer): 321 """ 322 Implements the Adam algorithm to fix the weight decay. It is a complete operator, not a combination of other ops. 323 324 Note: 325 When separating parameter groups, the weight decay in each group will be applied on the parameters if the 326 weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied 327 on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive. 328 329 To improve parameter groups performance, the customized order of parameters can be supported. 330 331 Args: 332 params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated, 333 the element in `params` must be class `Parameter`. When the `params` is a list of `dict`, the "params", 334 "lr", "weight_decay" and "order_params" are the keys can be parsed. 335 336 - params: Required. The value must be a list of `Parameter`. 337 338 - lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used. 339 If not, the `learning_rate` in the API will be used. 340 341 - weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay 342 will be used. If not, the `weight_decay` in the API will be used. 343 344 - order_params: Optional. If "order_params" is in the keys, the value must be the order of parameters and 345 the order will be followed in the optimizer. There are no other keys in the `dict` and the parameters 346 which in the 'order_params' must be in one of group parameters. 347 348 learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate. 349 When the learning_rate is an Iterable or a Tensor in a 1D dimension, use the dynamic learning rate, then 350 the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule, 351 use dynamic learning rate, the i-th learning rate will be calculated during the process of training 352 according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor in a zero 353 dimension, use fixed learning rate. Other cases are not supported. The float learning rate must be 354 equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float. 355 Default: 1e-3. 356 beta1 (float): The exponential decay rate for the 1st moment estimations. Default: 0.9. 357 Should be in range (0.0, 1.0). 358 beta2 (float): The exponential decay rate for the 2nd moment estimations. Default: 0.999. 359 Should be in range (0.0, 1.0). 360 eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6. 361 Should be greater than 0. 362 weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0. 363 364 Inputs: 365 - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`. 366 367 Outputs: 368 tuple[bool], all elements are True. 369 370 Supported Platforms: 371 ``GPU`` 372 373 Examples: 374 >>> net = Net() 375 >>> #1) All parameters use the same learning rate and weight decay 376 >>> optim = AdamWeightDecayOp(params=net.trainable_params()) 377 >>> 378 >>> #2) Use parameter groups and set different values 379 >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) 380 >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) 381 >>> group_params = [{'params': conv_params, 'weight_decay': 0.01}, 382 ... {'params': no_conv_params, 'lr': 0.01}, 383 ... {'order_params': net.trainable_params()}] 384 >>> optim = AdamWeightDecayOp(group_params, learning_rate=0.1, weight_decay=0.0) 385 >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01. 386 >>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0. 387 >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'. 388 >>> 389 >>> loss = nn.SoftmaxCrossEntropyWithLogits() 390 >>> model = Model(net, loss_fn=loss, optimizer=optim) 391 """ 392 def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0): 393 super(AdamWeightDecayOp, self).__init__(learning_rate, params, weight_decay) 394 _check_param_value(beta1, beta2, eps, self.cls_name) 395 self.beta1 = Tensor(np.array([beta1]).astype(np.float32)) 396 self.beta2 = Tensor(np.array([beta2]).astype(np.float32)) 397 self.eps = Tensor(np.array([eps]).astype(np.float32)) 398 self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros') 399 self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros') 400 self.hyper_map = C.HyperMap() 401 402 def construct(self, gradients): 403 """AdamWeightDecayOp""" 404 lr = self.get_lr() 405 if self.is_group: 406 if self.is_group_lr: 407 optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps), 408 lr, self.weight_decay, self.parameters, self.moments1, self.moments2, 409 gradients, self.decay_flags, self.optim_filter) 410 else: 411 optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr), 412 self.weight_decay, self.parameters, self.moments1, self.moments2, 413 gradients, self.decay_flags, self.optim_filter) 414 else: 415 optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, self.weight_decay), 416 self.parameters, self.moments1, self.moments2, 417 gradients, self.decay_flags, self.optim_filter) 418 if self.use_parallel: 419 self.broadcast_params(optim_result) 420 return optim_result 421