# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""adafactor"""
from mindspore.common import dtype as mstype
from mindspore.log import logging
from mindspore.common.initializer import initializer
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.tensor import Tensor
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore.nn.optim.optimizer import opt_init_args_register
from .optimizer import Optimizer


def _get_lr(step, rms, learning_rate, relative_step, warmup_init, scale_parameter, eps):
    """update optimizer learning rate"""
    rel_step_sz = learning_rate
    if relative_step:
        if warmup_init:
            min_step = 1e-6 * step * 1.0
        else:
            min_step = 1e-2 * 1.0

        rel_step_sz = P.Minimum()(min_step, 1.0 / P.Sqrt()(step * 1.0))
    param_scale = 1.0
    if scale_parameter:
        param_scale = P.Maximum()(eps[1], rms)
    return rel_step_sz * param_scale * F.ones_like(rms)


def _rms(update_tensor):
    """calculate rms"""
    return F.sqrt(P.ReduceMean(False)(F.square(update_tensor)))


def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
    """Approximation of exponential moving average of square of gradient"""
    reduce_mean = P.ReduceMean(keep_dims=True)(exp_avg_sq_row, -1)
    div_val = 1.0 / P.Sqrt()(P.Div()(exp_avg_sq_row, reduce_mean))
    r_factor = (P.ExpandDims()(div_val, -1))

    exp_avg_sq_col = P.ExpandDims()(exp_avg_sq_col, -2)
    c_factor = 1.0 / P.Sqrt()(exp_avg_sq_col)
    return P.Mul()(r_factor, c_factor)

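# The reference sketch below is illustrative only and is not used by the
# optimizer (assumption: the helper name `_factored_inv_sqrt_reference` is
# introduced here purely for illustration and is not part of MindSpore).
# It spells out, in pure Python for a 2-D weight, what `_approx_sq_grad`
# computes with MindSpore ops: given the running row means R and column means
# C of the squared gradient, the full second moment is approximated as
# V[i][j] ~= R[i] * C[j] / mean(R), and the optimizer only needs the
# element-wise 1 / sqrt(V).
def _factored_inv_sqrt_reference(row, col):
    """Pure-Python sketch of the factored 1/sqrt(V) reconstruction."""
    row_mean = sum(row) / len(row)
    r_factor = [1.0 / ((r / row_mean) ** 0.5) for r in row]
    c_factor = [1.0 / (c ** 0.5) for c in col]
    # The outer product of the two rank-1 factors gives the full matrix.
    return [[r * c for c in c_factor] for r in r_factor]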


_adam_opt = C.MultitypeFuncGraph("adam_opt")


@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool",
                    "Bool", "Bool", "Bool", "Bool", "Bool", "Tensor", "Tensor", "Tensor",
                    "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
def _run_opt_with_one_number(eps, clip_threshold, decay_rate, beta1,
                             weight_decay, scale_lr, scale_parameter, relative_step,
                             warmup_init, compression, use_first_moment, weight_decay_flag,
                             learning_rate, step, grad, param,
                             exp_avg, exp_avg_sq_row,
                             exp_avg_sq_col, exp_avg_sq):
    """Apply ada factor optimizer to the weight parameter using Tensor."""
    success = True
    grad_dtype = F.dtype(grad)
    grad_shape = F.shape(grad)

    if grad_dtype == mstype.float16:
        grad = F.cast(grad, mstype.float32)
    p_data_fp32 = param
    if F.dtype(p_data_fp32) == mstype.float16:
        p_data_fp32 = F.cast(p_data_fp32, mstype.float32)

    factored = len(grad_shape) >= 2

    # State Initialization
    exp_avg_update = exp_avg
    exp_avg_sq_update = exp_avg_sq
    exp_avg_sq_row_update = exp_avg_sq_row
    exp_avg_sq_col_update = exp_avg_sq_col

    if use_first_moment:
        if compression:
            exp_avg_update = F.cast(exp_avg, mstype.float16)

    if factored:
        exp_avg_sq_row_update = F.cast(exp_avg_sq_row, grad_dtype)
        exp_avg_sq_col_update = F.cast(exp_avg_sq_col, grad_dtype)
    else:
        exp_avg_sq_update = F.cast(exp_avg_sq, grad_dtype)

    if scale_lr:
        rms = _rms(p_data_fp32)
        learning_rate_update = _get_lr(step, rms, learning_rate, relative_step, warmup_init, scale_parameter, eps)
        learning_rate_update = F.assign(learning_rate, F.cast(learning_rate_update, F.dtype(learning_rate)))
    else:
        learning_rate_update = learning_rate * 1.0

    beta2t = 1.0 - P.Pow()(step, decay_rate)
    update = (grad ** 2) + eps[0]

    if factored:
        exp_avg_sq_row_update = P.Mul()(exp_avg_sq_row_update, beta2t)
        update_mean = P.ReduceMean()(update, -1) * (1.0 - beta2t)
        exp_avg_sq_row_update = P.Add()(exp_avg_sq_row_update, update_mean)
        exp_avg_sq_row_update = F.assign(exp_avg_sq_row, F.cast(exp_avg_sq_row_update, F.dtype(exp_avg_sq_row)))

        exp_avg_sq_col_update = P.Mul()(exp_avg_sq_col_update, beta2t)
        update_mean = P.ReduceMean()(update, -2) * (1.0 - beta2t)
        exp_avg_sq_col_update = P.Add()(exp_avg_sq_col_update, update_mean)
        exp_avg_sq_col_update = F.assign(exp_avg_sq_col, F.cast(exp_avg_sq_col_update, F.dtype(exp_avg_sq_col)))

        update = _approx_sq_grad(exp_avg_sq_row_update, exp_avg_sq_col_update)
        update = P.Mul()(update, grad)
    else:
        update = update * (1.0 - beta2t)
        exp_avg_sq_update = P.Add()(P.Mul()(exp_avg_sq_update, beta2t), update)
        exp_avg_sq_update = F.assign(exp_avg_sq, F.cast(exp_avg_sq_update, F.dtype(exp_avg_sq)))
        exp_avg_sq_update = 1.0 / P.Sqrt()(exp_avg_sq_update)
        update = P.Mul()(exp_avg_sq_update, grad)

    update_rms_thres = _rms(update) / clip_threshold
    update_coff = P.Maximum()(update_rms_thres, P.OnesLike()(update_rms_thres))
    update = P.Mul()(P.Div()(update, update_coff), learning_rate_update)

    if use_first_moment:
        if compression:
            exp_avg_update = F.cast(exp_avg_update, grad_dtype)
        exp_avg_update = P.Add()(P.Mul()(exp_avg_update, beta1), update * (1 - beta1))
        update = F.assign(exp_avg, F.cast(exp_avg_update, F.dtype(exp_avg)))

    if weight_decay_flag:
        p_data_fp32_coff = p_data_fp32 * -weight_decay * learning_rate_update
        p_data_fp32 = P.Add()(p_data_fp32, p_data_fp32_coff)
    p_data_fp32 = P.Sub()(p_data_fp32, update)
    P.Assign()(param, F.cast(p_data_fp32, F.dtype(param)))
    return success


def trans_to_tensor(paras, is_tuple=False, fp32=True):
    """Convert the input to a Tensor (or tuple of Tensors) of the given float type; None and bool pass through."""
    if paras is None or isinstance(paras, bool):
        return paras
    data_type = mstype.float32 if fp32 else mstype.float16
    if is_tuple:
        new_paras = [Tensor(ele, data_type) for ele in paras]
        return tuple(new_paras)
    return Tensor(paras, data_type)

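# The sketch below is illustrative only and is not part of the optimizer
# (assumption: the helper name `_clipped_update_reference` is introduced here
# purely for illustration). It restates the RMS-based update clipping applied
# inside `_run_opt_with_one_number`: the unscaled update U is divided by
# max(1, RMS(U) / d), so an update whose root mean square exceeds the clip
# threshold d is scaled back down to it before the learning rate is applied.
def _clipped_update_reference(update, clip_threshold):
    """Pure-Python sketch of Adafactor's update clipping for a flat list."""
    rms = (sum(u * u for u in update) / len(update)) ** 0.5
    coeff = max(1.0, rms / clip_threshold)
    return [u / coeff for u in update]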


class AdaFactor(Optimizer):
    r"""
    Updates gradients by the Adaptive Learning Rates with Sublinear Memory Cost (Adafactor) algorithm.

    The Adafactor algorithm is proposed in `Adafactor: Adaptive Learning Rates with Sublinear Memory
    Cost <https://arxiv.org/abs/1804.04235>`_.

    .. warning::
        This is an experimental prototype that is subject to change and/or deletion.

    The Adafactor update for a weight vector is as follows,

    .. math::
        \begin{array}{l} \\
            \alpha_{t}=\max \left(\epsilon_{2}, \operatorname{RMS}\left(X_{t-1}\right)\right) \rho_{t} \\
            G_{t}=\nabla f_{t}\left(X_{t-1}\right) \\
            \hat{V}_{t}=\hat{\beta}_{2t} \hat{V}_{t-1}+\left(1-\hat{\beta}_{2t}\right)
            \left(G_{t}^{2}+\epsilon_{1} 1_{n}\right) \\
            U_{t}=G_{t} / \sqrt{\hat{V}_{t}} \\
            \hat{U}_{t}=U_{t} / \max \left(1, \operatorname{RMS}\left(U_{t}\right) / d\right) \\
            X_{t}=X_{t-1}-\alpha_{t} \hat{U}_{t}
        \end{array}

    The Adafactor update for a weight matrix is as follows,

    .. math::
        \begin{array}{l} \\
            \alpha_{t}=\max \left(\epsilon_{2}, \operatorname{RMS}\left(X_{t-1}\right)\right) \rho_{t} \\
            G_{t}=\nabla f_{t}\left(X_{t-1}\right) \\
            R_{t}=\hat{\beta}_{2t} R_{t-1}+\left(1-\hat{\beta}_{2t}\right)
            \left(G_{t}^{2}+\epsilon_{1} 1_{n} 1_{m}^{\top}\right) 1_{m} \\
            C_{t}=\hat{\beta}_{2t} C_{t-1}+\left(1-\hat{\beta}_{2t}\right)
            1_{n}^{\top}\left(G_{t}^{2}+\epsilon_{1} 1_{n} 1_{m}^{\top}\right) \\
            \hat{V}_{t}=R_{t} C_{t} / 1_{n}^{\top} R_{t} \\
            U_{t}=G_{t} / \sqrt{\hat{V}_{t}} \\
            \hat{U}_{t}=U_{t} / \max \left(1, \operatorname{RMS}\left(U_{t}\right) / d\right) \\
            X_{t}=X_{t-1}-\alpha_{t} \hat{U}_{t}
        \end{array}

    Where RMS is:

    .. math::
        \operatorname{RMS}\left(U_{t}\right)=\operatorname{RMS}_{x \in X}\left(u_{xt}\right)=
        \sqrt{\operatorname{Mean}_{x \in X}\left(\frac{\left(g_{xt}\right)^{2}}{\hat{v}_{xt}}\right)}

    :math:`x` is each individual parameter,
    :math:`t` is the current training step,
    :math:`\alpha_{t}` is the learning rate,
    :math:`f(X)` is the loss function,
    :math:`\epsilon_{1}` and :math:`\epsilon_{2}` are small positive numbers that prevent numerical errors,
    :math:`d` is the clipping threshold,
    :math:`\hat{\beta}_{2t}` is the second-moment decay rate,
    :math:`\rho_{t}` is the relative step size,
    :math:`R` is the running average of the row sums of the squared gradient,
    :math:`C` is the running average of the column sums of the squared gradient.

    Note:
        The learning rate of this optimizer is controlled by the `scale_parameter`, `relative_step` and
        `warmup_init` options. To use a manual (external) learning rate schedule, set
        `scale_parameter=False` and `relative_step=False`.

        If a parameter is not used in the network, do not add it to the optimizer;
        otherwise the calculation result will be abnormal.

        To improve the performance of parameter groups, a customized order of parameters is supported.

    Args:
        params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be
            updated, the element in `params` must be class `Parameter`.
        learning_rate (Union[float, Tensor]): A value or a graph for the learning rate. When the learning_rate is
            a Tensor, it should be in a 1-D dimension. If the type of `learning_rate` is int, it will be converted
            to float. Default: None.
        eps (tuple): The regularization constants for the squared gradient and the parameter scale respectively.
            Default: (1e-30, 1e-3).
        clip_threshold (Union[float, Tensor]): The threshold of the root mean square of the final gradient update.
            Default: 1.0.
        decay_rate (Union[float, Tensor]): The coefficient used to compute the running average of the squared
            gradient. Default: 0.8.
        beta1 (float): The coefficient used to compute the running average of the gradient. Should be in range
            (0.0, 1.0). Default: 0.9.
        weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.
        scale_parameter (bool): If True, the learning rate is scaled by the root mean square of the parameter.
            Default: True.
        relative_step (bool): If True, a time-dependent learning rate is computed instead of using the external
            learning rate. Default: True.
        warmup_init (bool): Whether the time-dependent learning rate uses warm-up initialization. It requires
            `relative_step=True`. Default: False.
        compression (bool): If True, the data type of the running averages is compressed to float16.
            Default: False.
        loss_scale (float): A floating point value for the loss scale. Should be greater than 0. In general, use the
            default value. Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
            `FixedLossScaleManager` is set to False, then this value needs to be the same as the `loss_scale` in
            `FixedLossScaleManager`. Refer to class :class:`mindspore.FixedLossScaleManager` for more details.
            Default: 1.0.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

    Outputs:
        Tensor[bool], the value is True.

    Raises:
        TypeError: If `learning_rate` is not one of int, float, Tensor, Iterable, LearningRateSchedule.
        TypeError: If element of `parameters` is neither Parameter nor dict.
        TypeError: If `beta1`, `decay_rate`, `eps` or `loss_scale` is not a float.
        TypeError: If `weight_decay` is neither float nor int.
        TypeError: If `scale_parameter`, `relative_step` or `compression` is not a bool.
        ValueError: If `loss_scale` or `eps` is less than or equal to 0.
        ValueError: If `decay_rate` is not in range (0.0, 1.0).
        ValueError: If `weight_decay` is less than 0.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> net = Net()
        >>> #1) All parameters use the default learning rate (None) and weight decay (0.0)
        >>> optim = nn.AdaFactor(params=net.trainable_params())
        >>>
        >>> #2) Use parameter groups
        >>> all_params = net.trainable_params()
        >>> group_params = [{'params': [all_params[0]]}, {'params': [all_params[1]]}]
        >>> optim = nn.AdaFactor(group_params, learning_rate=0.1, weight_decay=0.0, relative_step=False)
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> model = Model(net, loss_fn=loss, optimizer=optim)
    """

    @opt_init_args_register
    def __init__(self,
                 params,
                 learning_rate=None,
                 eps=(1e-30, 1e-3),
                 clip_threshold=1.0,
                 decay_rate=0.8,
                 beta1=0.9,
                 weight_decay=0.0,
                 scale_parameter=True,
                 relative_step=True,
                 warmup_init=False,
                 compression=False,
                 loss_scale=1.0):

        if learning_rate is not None and relative_step:
            raise ValueError("Cannot combine a manual learning_rate with relative_step=True, "
                             "but got learning_rate={}".format(learning_rate))
        if warmup_init and not relative_step:
            raise ValueError("warmup_init requires relative_step=True")
        if learning_rate is None and not relative_step:
            raise ValueError("learning_rate must be set when relative_step=False")
        if learning_rate is None:
            learning_rate = 0.0
        if beta1 is None:
            beta1 = 0.0
        self.scale_lr = True
        if not isinstance(learning_rate, (float, int)) and learning_rate is not None:
            self.scale_lr = False
            if relative_step or scale_parameter:
                logging.warning("When learning_rate is a learning rate schedule, "
                                "relative_step and scale_parameter cannot update it.")

        super(AdaFactor, self).__init__(learning_rate, params, weight_decay, loss_scale)
        validator.check_value_type("eps", eps, [list, tuple], self.cls_name)
        if len(eps) != 2:
            raise ValueError("eps must have 2 values: (eps1, eps2).")
        for i, ele in enumerate(eps):
            validator.check_value_type("eps{}".format(i), ele, [float], self.cls_name)
            validator.check_non_negative_float(ele, "eps{}".format(i), self.cls_name)
        validator.check_value_type("clip_threshold", clip_threshold, [float], self.cls_name)
        validator.check_non_negative_float(clip_threshold, "clip_threshold", self.cls_name)
        validator.check_value_type("decay_rate", decay_rate, [float], self.cls_name)
        validator.check_float_range(decay_rate, 0, 1, Rel.INC_NEITHER, "decay_rate", self.cls_name)
        validator.check_float_range(weight_decay, 0, 1, Rel.INC_LEFT, "weight_decay", self.cls_name)
        validator.check_value_type("scale_parameter", scale_parameter, [bool], self.cls_name)
        validator.check_value_type("relative_step", relative_step, [bool], self.cls_name)
        validator.check_value_type("compression", compression, [bool], self.cls_name)
        validator.check_value_type("beta1", beta1, [int, float], self.cls_name)
        validator.check_non_negative_float(float(beta1), "beta1", self.cls_name)
        self.eps = trans_to_tensor(eps)
        self.clip_threshold = trans_to_tensor(clip_threshold)
        self.decay_rate = trans_to_tensor(-decay_rate)
        self.beta1 = trans_to_tensor(beta1)
        self.weight_decay = trans_to_tensor(weight_decay)
        self.weight_decay_flag = bool(weight_decay)

        self.scale_parameter = scale_parameter
        self.relative_step = relative_step
        self.warmup_init = warmup_init
        self.compression = compression

        self.init_ada_factor_state(beta1)
        self.step = Parameter(initializer(0, [1], mstype.float32), name='afactor_step')
        print("AdaFactor init completed", self.learning_rate)

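    # State layout (illustrative summary of init_ada_factor_state below):
    # parameters of rank >= 2 keep factored second-moment statistics, a row
    # accumulator with the last axis reduced and a column accumulator with the
    # second-to-last axis reduced, while vectors and scalars fall back to a
    # full `exp_avg_sq` accumulator. The unused counterparts are allocated as
    # shape (1,) placeholders so that hyper_map receives per-parameter tuples
    # of equal length in construct.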
    def init_ada_factor_state(self, beta1):
        """init adafactor variables"""
        if beta1 > 0:
            self.use_first_moment = True
            self.exp_avg = self.parameters.clone(prefix="exp_avg", init='zeros')
        else:
            self.use_first_moment = False
            self.exp_avg = ParameterTuple([Parameter(Tensor(0.0))] * len(self.parameters))

        self.exp_avg_sq = []
        self.exp_avg_sq_col = []
        self.exp_avg_sq_row = []
        for paras in self.parameters:
            paras_dtype = paras.dtype
            paras_shape = paras.shape
            paras_name = paras.name
            if len(paras_shape) > 1:
                self.exp_avg_sq_row.append(Parameter(initializer(0, shape=paras_shape[:-1], dtype=paras_dtype),
                                                     name="exp_avg_sq_row_{}".format(paras_name)))
                self.exp_avg_sq_col.append(Parameter(initializer(0, shape=paras_shape[:-2] + paras_shape[-1:],
                                                                 dtype=paras_dtype),
                                                     name="exp_avg_sq_col_{}".format(paras_name)))
                if self.compression:
                    self.exp_avg_sq.append(Parameter(initializer(0, shape=(1,), dtype=mstype.float16),
                                                     name="exp_avg_sq_{}".format(paras_name)))
                else:
                    self.exp_avg_sq.append(Parameter(initializer(0, shape=(1,), dtype=paras_dtype),
                                                     name="exp_avg_sq_{}".format(paras_name)))

            else:
                self.exp_avg_sq_row.append(Parameter(initializer(0, shape=(1,), dtype=paras_dtype),
                                                     name="exp_avg_sq_row_{}".format(paras_name)))
                self.exp_avg_sq_col.append(Parameter(initializer(0, shape=(1,), dtype=paras_dtype),
                                                     name="exp_avg_sq_col_{}".format(paras_name)))

                if self.compression:
                    self.exp_avg_sq.append(Parameter(initializer(0, shape=paras_shape, dtype=mstype.float16),
                                                     name="exp_avg_sq_{}".format(paras_name)))
                else:
                    self.exp_avg_sq.append(Parameter(initializer(0, shape=paras_shape, dtype=paras_dtype),
                                                     name="exp_avg_sq_{}".format(paras_name)))

        self.exp_avg_sq_row = ParameterTuple(self.exp_avg_sq_row)
        self.exp_avg_sq_col = ParameterTuple(self.exp_avg_sq_col)
        self.exp_avg_sq = ParameterTuple(self.exp_avg_sq)

    @property
    def supports_memory_efficient_fp16(self):
        return True

    @property
    def supports_flat_params(self):
        return False

    def construct(self, gradients):
        lr = self.get_lr()
        step = F.assign_add(self.step, 1)
        success = self.hyper_map(F.partial(_adam_opt, self.eps, self.clip_threshold, self.decay_rate,
                                           self.beta1, self.weight_decay, self.scale_lr,
                                           self.scale_parameter, self.relative_step,
                                           self.warmup_init, self.compression, self.use_first_moment,
                                           self.weight_decay_flag, lr, step),
                                 gradients, self.parameters, self.exp_avg, self.exp_avg_sq_row,
                                 self.exp_avg_sq_col, self.exp_avg_sq)

        return success

    @Optimizer.target.setter
    def target(self, value):
        """
        If the input value is set to "CPU", the parameters will be updated on the host using the Fused
        optimizer operation.
        """
        self._set_base_target(value)
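
# Usage sketch (illustrative, assuming a user-defined `Net` cell; this mirrors
# the class docstring and Note rather than adding new API). To drive AdaFactor
# with an external learning rate instead of the built-in relative-step
# schedule, disable both `scale_parameter` and `relative_step`:
#
#     net = Net()
#     optim = AdaFactor(net.trainable_params(), learning_rate=1e-3,
#                       scale_parameter=False, relative_step=False)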