# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""AdamWeightDecayForBert, a customized Adam optimizer for BERT. Inputs: gradients and an overflow flag."""
import numpy as np

from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.common.tensor import Tensor
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore.nn.optim.optimizer import Optimizer

_adam_opt = C.MultitypeFuncGraph("adam_opt")
_scaler_one = Tensor(1, mstype.int32)
_scaler_ten = Tensor(10, mstype.float32)
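
# Note on dispatch (descriptive comment only): _adam_opt is a MultitypeFuncGraph, so each of the
# implementations registered below is selected by matching the types of the arguments supplied at
# the call site (HyperMap over F.partial(_adam_opt, ...) in the optimizers' construct methods).
# Roughly: the 11-argument signature with a Number weight_decay routes to the fused
# P.AdamWeightDecay kernel, the 12-argument signature with an extra overflow Tensor routes to the
# composed overflow-aware update, and the signatures taking Function arguments cover the sparse,
# parameter-server and offload variants.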

@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
                    "Tensor", "Bool", "Bool")
def _update_run_kernel(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, decay_flags, optim_filter):
    """
    Update parameters by the AdamWeightDecay op.
    """
    if optim_filter:
        adam = P.AdamWeightDecay()
        if decay_flags:
            next_param = adam(param, m, v, lr, beta1, beta2, eps, weight_decay, gradient)
        else:
            next_param = adam(param, m, v, lr, beta1, beta2, eps, 0.0, gradient)
        return next_param
    return gradient


@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
                    "Tensor", "Bool", "Bool")
def _update_run_op(beta1, beta2, eps, lr, overflow, weight_decay, param, m, v, gradient, decay_flag, optim_filter):
    """
    Update parameters.

    Args:
        beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
        beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
        eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
        lr (Tensor): Learning rate.
        overflow (Tensor): Whether overflow occurs.
        weight_decay (Number): Weight decay. Should be equal to or greater than 0.
        param (Tensor): Parameters.
        m (Tensor): m value of parameters.
        v (Tensor): v value of parameters.
        gradient (Tensor): Gradient of parameters.
        decay_flag (bool): Applies weight decay or not.
        optim_filter (bool): Applies parameter update or not.

    Returns:
        Tensor, the updated parameter.
    """
    if optim_filter:
        op_mul = P.Mul()
        op_square = P.Square()
        op_sqrt = P.Sqrt()
        op_cast = P.Cast()
        op_reshape = P.Reshape()
        op_shape = P.Shape()
        op_select = P.Select()

        param_fp32 = op_cast(param, mstype.float32)
        m_fp32 = op_cast(m, mstype.float32)
        v_fp32 = op_cast(v, mstype.float32)
        gradient_fp32 = op_cast(gradient, mstype.float32)

        cond = op_cast(F.fill(mstype.int32, op_shape(m_fp32), 1) * op_reshape(overflow, (())), mstype.bool_)
        next_m = op_mul(beta1, m_fp32) + op_select(cond, m_fp32,
                                                   op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1,
                                                          gradient_fp32))

        next_v = op_mul(beta2, v_fp32) + op_select(cond, v_fp32,
                                                   op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta2,
                                                          op_square(gradient_fp32)))

        update = next_m / (eps + op_sqrt(next_v))
        if decay_flag:
            update = op_mul(weight_decay, param_fp32) + update

        update_with_lr = op_mul(lr, update)
        zeros = F.fill(mstype.float32, op_shape(param_fp32), 0)
        next_param = param_fp32 - op_select(cond, zeros, op_reshape(update_with_lr, op_shape(param_fp32)))

        next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
        next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m))))
        next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))

        return op_cast(next_param, F.dtype(param))
    return gradient
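

# Illustrative reference only, not used anywhere in this module: a minimal NumPy sketch of the
# dense update that _update_run_op composes from MindSpore primitives, assuming `overflow` is a
# scalar 0/1 flag. The helper name _np_adamw_step is hypothetical. Note that, matching the code
# above, no bias correction with beta1_power/beta2_power is applied on this path.
def _np_adamw_step(param, m, v, gradient, lr, beta1, beta2, eps, weight_decay, decay_flag, overflow):
    """NumPy sketch of one AdamWeightDecay step with the overflow guard used above."""
    if overflow:
        # Mirror of the op_select(cond, ...) branches: the parameter step is skipped and the
        # moment buffers take the overflow branch of the select.
        return param, beta1 * m + m, beta2 * v + v
    next_m = beta1 * m + (1.0 - beta1) * gradient
    next_v = beta2 * v + (1.0 - beta2) * np.square(gradient)
    update = next_m / (eps + np.sqrt(next_v))
    if decay_flag:
        update = weight_decay * param + update
    return param - lr * update, next_m, next_v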


@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
                    "Tensor", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power,
                         beta2_power, beta1, beta2, eps, lr, gradient, param, m, v, ps_parameter, cache_enable):
    """Apply sparse adam optimizer to the weight parameter when the gradient is sparse."""
    success = True
    indices = gradient.indices
    values = gradient.values
    if ps_parameter and not cache_enable:
        op_shape = P.Shape()
        shapes = (op_shape(param), op_shape(m), op_shape(v),
                  op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1),
                  op_shape(beta2), op_shape(eps), op_shape(values), op_shape(indices))
        success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2,
                                               eps, values, indices), shapes), param))
        return success

    if not target:
        success = F.depend(success, sparse_opt(param, m, v, beta1_power, beta2_power, lr, beta1, beta2,
                                               eps, values, indices))
    else:
        op_mul = P.Mul()
        op_square = P.Square()
        op_sqrt = P.Sqrt()
        scatter_add = P.ScatterAdd(use_locking)

        F.assign(m, op_mul(beta1, m))
        F.assign(v, op_mul(beta2, v))

        grad_indices = gradient.indices
        grad_value = gradient.values

        next_m = scatter_add(m,
                             grad_indices,
                             op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value))

        next_v = scatter_add(v,
                             grad_indices,
                             op_mul(F.tuple_to_array((1.0,)) - beta2, op_square(grad_value)))

        if use_nesterov:
            m_temp = next_m * _scaler_ten
            F.assign(m, op_mul(beta1, next_m))
            div_value = scatter_add(m,
                                    op_mul(grad_indices, _scaler_one),
                                    op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value))
            param_update = div_value / (op_sqrt(next_v) + eps)
            F.assign(m, m_temp / _scaler_ten)
        else:
            param_update = next_m / (op_sqrt(next_v) + eps)

        lr_t = lr * op_sqrt(1 - beta2_power) / (1 - beta1_power)

        next_param = param - lr_t * param_update

        success = F.depend(success, F.assign(param, next_param))
        success = F.depend(success, F.assign(m, next_m))
        success = F.depend(success, F.assign(v, next_v))

    return success
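

# Illustrative reference only, not used anywhere in this module: a minimal NumPy sketch of the
# host-side (`target`) sparse branch above, without the nesterov variant. `indices` and `values`
# stand for the RowTensor pieces of a sparse gradient; the helper name _np_sparse_adam_step is
# hypothetical.
def _np_sparse_adam_step(param, m, v, indices, values, lr, beta1, beta2, eps,
                         beta1_power, beta2_power):
    """NumPy sketch of one sparse Adam step: scatter-add into the moments, then a dense update."""
    m = beta1 * m
    v = beta2 * v
    np.add.at(m, indices, (1.0 - beta1) * values)               # ScatterAdd into m
    np.add.at(v, indices, (1.0 - beta2) * np.square(values))    # ScatterAdd into v
    lr_t = lr * np.sqrt(1.0 - beta2_power) / (1.0 - beta1_power)
    return param - lr_t * m / (np.sqrt(v) + eps), m, v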


@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
                    "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
def _run_opt_with_one_number(opt, sparse_opt, push, pull, use_locking, use_nesterov, target,
                             beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, param,
                             moment1, moment2, ps_parameter, cache_enable):
    """Apply adam optimizer to the weight parameter using Tensor."""
    success = True
    if ps_parameter and not cache_enable:
        op_shape = P.Shape()
        success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient),
                                              (op_shape(param), op_shape(moment1), op_shape(moment2))), param))
    else:
        success = F.depend(success, opt(param, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
                                        eps, gradient))
    return success


@_adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
                    "Tensor", "Tensor")
def _run_off_load_opt(opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, param, moment1, moment2):
    """Apply AdamOffload optimizer to the weight parameter using Tensor."""
    success = True
    delta_param = opt(moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2, eps, gradient)
    success = F.depend(success, F.assign_add(param, delta_param))
    return success


def _check_param_value(beta1, beta2, eps, prim_name):
    """Check the type of inputs."""
    validator.check_value_type("beta1", beta1, [float], prim_name)
    validator.check_value_type("beta2", beta2, [float], prim_name)
    validator.check_value_type("eps", eps, [float], prim_name)
    validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name)
    validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name)
    validator.check_positive_float(eps, "eps", prim_name)


class AdamWeightDecayForBert(Optimizer):
    """
    Implements the Adam algorithm with decoupled weight decay and an overflow flag input.

    Note:
        When separating parameter groups, the weight decay in each group will be applied on the parameters if the
        weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
        on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.

        To improve the performance of parameter groups, a customized order of parameters is supported.

    Args:
        params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
            the element in `params` must be class `Parameter`. When the `params` is a list of `dict`, the "params",
            "lr", "weight_decay" and "order_params" keys can be parsed.

            - params: Required. The value must be a list of `Parameter`.

            - lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used.
              If not, the `learning_rate` in the API will be used.

            - weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay
              will be used. If not, the `weight_decay` in the API will be used.

            - order_params: Optional. If "order_params" is in the keys, the value must be the order of parameters and
              the order will be followed in the optimizer. There should be no other keys in this `dict`, and the
              parameters listed in 'order_params' must be in one of the group parameters.

        learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
            When the learning_rate is an Iterable or a 1D Tensor, a dynamic learning rate is used: the i-th step
            takes the i-th value as the learning rate. When the learning_rate is a LearningRateSchedule, a dynamic
            learning rate is also used: the i-th learning rate is computed during training according to the formula
            of the LearningRateSchedule. When the learning_rate is a float or a 0D Tensor, a fixed learning rate is
            used. Other cases are not supported. The float learning rate must be equal to or greater than 0. If the
            type of `learning_rate` is int, it will be converted to float. Default: 1e-3.
        beta1 (float): The exponential decay rate for the 1st moment estimations. Default: 0.9.
            Should be in range (0.0, 1.0).
        beta2 (float): The exponential decay rate for the 2nd moment estimations. Default: 0.999.
            Should be in range (0.0, 1.0).
        eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
            Should be greater than 0.
        weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
        - **overflow** (tuple[Tensor]) - The overflow flag from dynamic loss scaling.

    Outputs:
        tuple[bool], all elements are True.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> net = Net()
        >>> #1) All parameters use the same learning rate and weight decay
        >>> optim = AdamWeightDecayForBert(params=net.trainable_params())
        >>>
        >>> #2) Use parameter groups and set different values
        >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
        >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
        >>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
        ...                 {'params': no_conv_params, 'lr': 0.01},
        ...                 {'order_params': net.trainable_params()}]
        >>> optim = AdamWeightDecayForBert(group_params, learning_rate=0.1, weight_decay=0.0)
        >>> # The parameters in conv_params will use the default learning rate of 0.1 and a weight decay of 0.01.
        >>> # The parameters in no_conv_params will use a learning rate of 0.01 and the default weight decay of 0.0.
        >>> # The final parameter order used by the optimizer follows the value of 'order_params'.
        >>>
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> model = Model(net, loss_fn=loss, optimizer=optim)
    """
    def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
        super(AdamWeightDecayForBert, self).__init__(learning_rate, params, weight_decay)
        _check_param_value(beta1, beta2, eps, self.cls_name)
        self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
        self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
        self.eps = Tensor(np.array([eps]).astype(np.float32))
        self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros')
        self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros')
        self.hyper_map = C.HyperMap()
        self.op_select = P.Select()
        self.op_cast = P.Cast()
        self.op_reshape = P.Reshape()
        self.op_shape = P.Shape()

    def construct(self, gradients, overflow):
        """AdamWeightDecayForBert"""
        lr = self.get_lr()
        cond = self.op_cast(F.fill(mstype.int32, self.op_shape(self.beta1), 1) *
                            self.op_reshape(overflow, (())), mstype.bool_)
        beta1 = self.op_select(cond, self.op_cast(F.tuple_to_array((1.0,)), mstype.float32), self.beta1)
        beta2 = self.op_select(cond, self.op_cast(F.tuple_to_array((1.0,)), mstype.float32), self.beta2)
        if self.is_group:
            if self.is_group_lr:
                optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps),
                                              lr, self.weight_decay, self.parameters, self.moments1, self.moments2,
                                              gradients, self.decay_flags, self.optim_filter)
            else:
                optim_result = self.hyper_map(F.partial(_adam_opt, beta1, beta2, self.eps, lr, overflow),
                                              self.weight_decay, self.parameters, self.moments1, self.moments2,
                                              gradients, self.decay_flags, self.optim_filter)
        else:
            optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr,
                                                    self.weight_decay),
                                          self.parameters, self.moments1, self.moments2,
                                          gradients, self.decay_flags, self.optim_filter)
        if self.use_parallel:
            self.broadcast_params(optim_result)
        return optim_result
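

# Descriptive note only: AdamWeightDecayForBert above passes the `overflow` tensor in its grouped,
# shared-learning-rate branch, which makes _adam_opt dispatch to the composed, overflow-aware
# _update_run_op. AdamWeightDecayOp below never passes an overflow tensor, so its branches
# dispatch to _update_run_kernel and the update is carried out by the single fused
# P.AdamWeightDecay operator, as its docstring states.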


class AdamWeightDecayOp(Optimizer):
    """
    Implements the Adam algorithm with decoupled weight decay. The update is performed by a complete
    fused operator rather than a combination of smaller ops.

    Note:
        When separating parameter groups, the weight decay in each group will be applied on the parameters if the
        weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
        on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.

        To improve the performance of parameter groups, a customized order of parameters is supported.

    Args:
        params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
            the element in `params` must be class `Parameter`. When the `params` is a list of `dict`, the "params",
            "lr", "weight_decay" and "order_params" keys can be parsed.

            - params: Required. The value must be a list of `Parameter`.

            - lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used.
              If not, the `learning_rate` in the API will be used.

            - weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay
              will be used. If not, the `weight_decay` in the API will be used.

            - order_params: Optional. If "order_params" is in the keys, the value must be the order of parameters and
              the order will be followed in the optimizer. There should be no other keys in this `dict`, and the
              parameters listed in 'order_params' must be in one of the group parameters.

        learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
            When the learning_rate is an Iterable or a 1D Tensor, a dynamic learning rate is used: the i-th step
            takes the i-th value as the learning rate. When the learning_rate is a LearningRateSchedule, a dynamic
            learning rate is also used: the i-th learning rate is computed during training according to the formula
            of the LearningRateSchedule. When the learning_rate is a float or a 0D Tensor, a fixed learning rate is
            used. Other cases are not supported. The float learning rate must be equal to or greater than 0. If the
            type of `learning_rate` is int, it will be converted to float. Default: 1e-3.
        beta1 (float): The exponential decay rate for the 1st moment estimations. Default: 0.9.
            Should be in range (0.0, 1.0).
        beta2 (float): The exponential decay rate for the 2nd moment estimations. Default: 0.999.
            Should be in range (0.0, 1.0).
        eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
            Should be greater than 0.
        weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

    Outputs:
        tuple[bool], all elements are True.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> net = Net()
        >>> #1) All parameters use the same learning rate and weight decay
        >>> optim = AdamWeightDecayOp(params=net.trainable_params())
        >>>
        >>> #2) Use parameter groups and set different values
        >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
        >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
        >>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
        ...                 {'params': no_conv_params, 'lr': 0.01},
        ...                 {'order_params': net.trainable_params()}]
        >>> optim = AdamWeightDecayOp(group_params, learning_rate=0.1, weight_decay=0.0)
        >>> # The parameters in conv_params will use the default learning rate of 0.1 and a weight decay of 0.01.
        >>> # The parameters in no_conv_params will use a learning rate of 0.01 and the default weight decay of 0.0.
        >>> # The final parameter order used by the optimizer follows the value of 'order_params'.
        >>>
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> model = Model(net, loss_fn=loss, optimizer=optim)
    """
    def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
        super(AdamWeightDecayOp, self).__init__(learning_rate, params, weight_decay)
        _check_param_value(beta1, beta2, eps, self.cls_name)
        self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
        self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
        self.eps = Tensor(np.array([eps]).astype(np.float32))
        self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros')
        self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros')
        self.hyper_map = C.HyperMap()

    def construct(self, gradients):
        """AdamWeightDecayOp"""
        lr = self.get_lr()
        if self.is_group:
            if self.is_group_lr:
                optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps),
                                              lr, self.weight_decay, self.parameters, self.moments1, self.moments2,
                                              gradients, self.decay_flags, self.optim_filter)
            else:
                optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr),
                                              self.weight_decay, self.parameters, self.moments1, self.moments2,
                                              gradients, self.decay_flags, self.optim_filter)
        else:
            optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr,
                                                    self.weight_decay),
                                          self.parameters, self.moments1, self.moments2,
                                          gradients, self.decay_flags, self.optim_filter)
        if self.use_parallel:
            self.broadcast_params(optim_result)
        return optim_result