# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""came optimizer"""
from __future__ import absolute_import

from mindspore import context
from mindspore.common import dtype as mstype
from mindspore.log import logging
from mindspore.common.initializer import initializer
from mindspore.common.api import jit
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.tensor import Tensor
try:
    from mindspore._checkparam import Validator as validator
    from mindspore._checkparam import Rel
except ImportError:
    import mindspore._checkparam as validator
    import mindspore._checkparam as Rel
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.nn.optim.optimizer import opt_init_args_register

__all__ = ['Came']


def _rms(update_tensor):
    """calculate rms"""
    return F.sqrt(P.ReduceMean(False)(F.square(update_tensor)))


def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
    """Approximation of exponential moving average of square of gradient"""
    reduce_mean = P.ReduceMean(keep_dims=True)(exp_avg_sq_row, -1)
    div_val = 1.0 / P.Sqrt()(P.RealDiv()(exp_avg_sq_row, reduce_mean))
    r_factor = (P.ExpandDims()(div_val, -1))

    exp_avg_sq_col = P.ExpandDims()(exp_avg_sq_col, -2)
    c_factor = 1.0 / P.Sqrt()(exp_avg_sq_col)
    return P.Mul()(r_factor, c_factor)


reduce_mean_keep_alive = P.ReduceMean().add_prim_attr("keep_alive", True)
_came_opt = C.MultitypeFuncGraph("came_opt")


@_came_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool", "Bool", "Bool",
                    "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
def _run_opt_with_one_number(eps, clip_threshold, beta1, beta2t, beta3, weight_decay, scale_parameter,
                             compression, use_first_moment, weight_decay_flag, learning_rate,
                             grad, param, exp_avg, exp_avg_sq_row, exp_avg_sq_col, exp_avg_sq,
                             exp_avg_insta_row, exp_avg_insta_col):
    """Apply came optimizer to the weight parameter using Tensor."""
    cast = P.Cast()
    grad_dtype = F.dtype(grad)
    grad_shape = F.shape(grad)

    if grad_dtype == mstype.float16:
        grad = cast(grad, mstype.float32)
    p_data_fp32 = param
    if F.dtype(p_data_fp32) == mstype.float16:
        p_data_fp32 = cast(p_data_fp32, mstype.float32)

    factored = len(grad_shape) >= 2

    if scale_parameter:
        rms = _rms(p_data_fp32)
        param_scale = P.Maximum()(eps[1], rms)
        learning_rate_update = learning_rate * param_scale * F.ones_like(rms)
    else:
        learning_rate_update = learning_rate

    update = (grad ** 2) + eps[0]
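
    # Explanatory note: for matrices (ndim >= 2) the squared-gradient statistic is kept in factored
    # form as row/column exponential moving averages (AdaFactor-style), which is what makes the
    # optimizer memory efficient; vectors and scalars keep a full second-moment accumulator instead.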
    if factored:
        exp_avg_sq_row_update = cast(exp_avg_sq_row, grad_dtype)
        exp_avg_sq_row_update = P.Mul()(exp_avg_sq_row_update, beta2t)
        update_mean = reduce_mean_keep_alive(update, -1) * (1.0 - beta2t)
        exp_avg_sq_row_update = P.Add()(exp_avg_sq_row_update, update_mean)
        F.assign(exp_avg_sq_row, cast(exp_avg_sq_row_update, F.dtype(exp_avg_sq_row)))
        exp_avg_sq_row_update = exp_avg_sq_row

        exp_avg_sq_col_update = cast(exp_avg_sq_col, grad_dtype)
        exp_avg_sq_col_update = P.Mul()(exp_avg_sq_col_update, beta2t)
        update_mean = reduce_mean_keep_alive(update, -2) * (1.0 - beta2t)
        exp_avg_sq_col_update = P.Add()(exp_avg_sq_col_update, update_mean)
        F.assign(exp_avg_sq_col, cast(exp_avg_sq_col_update, F.dtype(exp_avg_sq_col)))
        exp_avg_sq_col_update = exp_avg_sq_col
        update = _approx_sq_grad(exp_avg_sq_row_update, exp_avg_sq_col_update)
        update = P.Mul()(update, grad)

    else:
        exp_avg_sq_update = cast(exp_avg_sq, grad_dtype)
        update = update * (1.0 - beta2t)
        exp_avg_sq_update = P.Add()(P.Mul()(exp_avg_sq_update, beta2t), update)
        F.assign(exp_avg_sq, cast(exp_avg_sq_update, F.dtype(exp_avg_sq)))
        exp_avg_sq_update = exp_avg_sq
        exp_avg_sq_update = 1.0 / P.Sqrt()(exp_avg_sq_update)
        update = P.Mul()(exp_avg_sq_update, grad)

    update_rms_thres = _rms(update) / clip_threshold
    update_coff = P.Maximum()(update_rms_thres, P.OnesLike()(update_rms_thres))
    update = P.RealDiv()(update, update_coff)

    if use_first_moment:
        exp_avg_update = exp_avg
        if compression:
            exp_avg_update = cast(exp_avg, grad_dtype)
        exp_avg_update = P.Add()(P.Mul()(exp_avg_update, beta1), update * (1 - beta1))
        F.assign(exp_avg, cast(exp_avg_update, F.dtype(exp_avg)))

    ###
    # CAME optimizer modification is here
    instability_matrix = (update - exp_avg) ** 2 + eps[2]

    if factored:
        exp_avg_insta_row_update = cast(exp_avg_insta_row, grad_dtype)
        exp_avg_insta_row_update = P.Mul()(exp_avg_insta_row_update, beta3)
        update_mean = reduce_mean_keep_alive(instability_matrix, -1) * (1.0 - beta3)
        exp_avg_insta_row_update = P.Add()(exp_avg_insta_row_update, update_mean)
        F.assign(exp_avg_insta_row, cast(exp_avg_insta_row_update, F.dtype(exp_avg_insta_row)))
        exp_avg_insta_row_update = exp_avg_insta_row

        exp_avg_insta_col_update = cast(exp_avg_insta_col, grad_dtype)
        exp_avg_insta_col_update = P.Mul()(exp_avg_insta_col_update, beta3)
        update_mean = reduce_mean_keep_alive(instability_matrix, -2) * (1.0 - beta3)
        exp_avg_insta_col_update = P.Add()(exp_avg_insta_col_update, update_mean)
        F.assign(exp_avg_insta_col, cast(exp_avg_insta_col_update, F.dtype(exp_avg_insta_col)))
        exp_avg_insta_col_update = exp_avg_insta_col

        s_t = _approx_sq_grad(exp_avg_insta_row_update, exp_avg_insta_col_update)
        update = s_t * exp_avg * learning_rate_update
    else:
        update = exp_avg * learning_rate_update
    # ###

    if weight_decay_flag:
        p_data_fp32_coff = p_data_fp32 * -weight_decay * learning_rate_update
        p_data_fp32 = P.Add()(p_data_fp32, p_data_fp32_coff)
    p_data_fp32 = P.Sub()(p_data_fp32, update)
    P.Assign()(param, cast(p_data_fp32, F.dtype(param)))
    return True
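

# Explanatory note: the overload below is selected when `use_fused_ada_factor` is True. It delegates
# to the FusedAdaFactor kernel, so this path applies an AdaFactor-style update and does not include
# CAME's instability (confidence) statistics handled in the composite update above.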
@_came_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
                    "Tensor", "Tensor", "Tensor", "Tensor")
def _run_fused_ada_factor(fused_ada_factor, eps, clip_threshold, beta1, beta2t, weight_decay, learning_rate,
                          grad, param, exp_avg, exp_avg_sq_row, exp_avg_sq_col, exp_avg_sq):
    fused_ada_factor(eps, clip_threshold, beta1, beta2t, weight_decay, learning_rate,
                     grad, param, exp_avg, exp_avg_sq_row, exp_avg_sq_col, exp_avg_sq)
    return True


def trans_to_tensor(param, is_tuple=False, fp32=True):
    """
    Transform params to tensor.
    """
    if param is None or isinstance(param, bool):
        return param
    data_type = mstype.float32 if fp32 else mstype.float16
    if is_tuple:
        new_param = [Tensor(ele, data_type) for ele in param]
        return tuple(new_param)
    return Tensor(param, data_type)


class Came(Optimizer):
    r"""
    Updates gradients by the Confidence-guided Adaptive Memory Efficient Optimization (CAME) algorithm.

    The CAME algorithm is proposed in `CAME: Confidence-guided Adaptive Memory Efficient Optimization
    <https://arxiv.org/abs/2307.02047>`_.

    Args:
        params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
            the elements in `params` must be of class `Parameter`.
        learning_rate (Union[float, Tensor]): A value or a graph for the learning rate. If `learning_rate` is a
            Tensor, it should be a 1-D Tensor. If the type of `learning_rate` is int, it will be converted to float.
            Default: None.
        eps (tuple): The regularization constants for the squared gradient, the parameter scale and the
            instability matrix, respectively. Default: (1e-30, 1e-3, 1e-16).
        clip_threshold (Union[float, Tensor]): The threshold of the root mean square of the final gradient update.
            Default: 1.0.
        decay_rate (Union[float, Tensor]): The coefficient used to compute the running averages of the squared
            gradient. Default: 0.8.
        beta1 (float): The coefficient used to compute the running averages of the gradient. Should be in
            range (0.0, 1.0). Default: 0.9.
        beta3 (float): The coefficient used to compute the running averages of the instability matrix. Should be in
            range (0.0, 1.0). Default: 0.99.
        weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.
        scale_parameter (bool): If True, the learning rate is scaled by the root mean square of the parameter.
            Default: False.
        relative_step (bool): If True, a time-dependent learning rate is computed instead of using the external
            learning rate. Default: False.
        warmup_init (bool): The time-dependent learning rate computation depends on whether warm-up
            initialization is being used. Default: False.
        compression (bool): If True, the data type of the running averages will be compressed to float16.
            Default: False.
        loss_scale (float): A floating point value for the loss scale. Should be greater than 0. In general, use the
            default value. Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
            `FixedLossScaleManager` is set to False, then this value needs to be the same as the `loss_scale` in
            `FixedLossScaleManager`. Refer to class :class:`mindspore.amp.FixedLossScaleManager` for more details.
            Default: 1.0.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

    Outputs:
        Tensor[bool], the value is True.

    Raises:
        TypeError: If `learning_rate` is not one of int, float, Tensor, Iterable, LearningRateSchedule.
        TypeError: If element of `parameters` is neither Parameter nor dict.
        TypeError: If `beta1`, `beta3`, `eps` or `loss_scale` is not a float.
        TypeError: If `weight_decay` is neither float nor int.
        TypeError: If `scale_parameter`, `relative_step` or `compression` is not a bool.
        ValueError: If `loss_scale` or `eps` is less than or equal to 0.
        ValueError: If `beta1` or `beta3` is not in range (0.0, 1.0).
        ValueError: If `weight_decay` is less than 0.

    Supported Platforms:
        ``Ascend``
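
    Examples:
        A minimal usage sketch; the network below is only illustrative and the import path for ``Came``
        depends on how this module is packaged:

        >>> from mindspore import nn
        >>> net = nn.Dense(16, 8)
        >>> loss_fn = nn.MAELoss()
        >>> optimizer = Came(params=net.trainable_params(), learning_rate=1e-3)
        >>> train_net = nn.TrainOneStepCell(nn.WithLossCell(net, loss_fn), optimizer)
        >>> # Calling train_net(data, label) then performs one CAME update step.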
    """
    _support_parallel_optimizer = True

    @opt_init_args_register
    def __init__(self,
                 params,
                 learning_rate=None,
                 eps=(1e-30, 1e-3, 1e-16),
                 clip_threshold=1.0,
                 decay_rate=0.8,
                 beta1=0.9,
                 beta3=0.99,
                 weight_decay=0.0,
                 scale_parameter=False,
                 relative_step=False,
                 warmup_init=False,
                 compression=False,
                 loss_scale=1.0):

        if learning_rate is not None and relative_step:
            raise ValueError("Cannot combine manual lr and relative_step options", learning_rate)
        if warmup_init and not relative_step:
            raise ValueError("warmup_init requires relative_step=True")
        if learning_rate is None and not relative_step:
            raise ValueError("learning_rate cannot be None when relative_step=False")
        if learning_rate is None:
            learning_rate = 0.0
        if beta1 is None:
            beta1 = 0.0

        if not isinstance(learning_rate, (float, int)) and learning_rate is not None:
            if relative_step or scale_parameter:
                logging.warning("When learning_rate is a learning rate schedule, "
                                "updating the learning rate is not supported!")

        super(Came, self).__init__(learning_rate, params, weight_decay, loss_scale)
        validator.check_value_type("eps", eps, [list, tuple], self.cls_name)
        if len(eps) != 3:
            raise ValueError("eps must have 3 values: (eps1, eps2, eps3).")
        for i, ele in enumerate(eps):
            validator.check_value_type("eps{}".format(i), ele, [float], self.cls_name)
            validator.check_non_negative_float(ele, "eps{}".format(i), self.cls_name)
        validator.check_value_type("clip_threshold", clip_threshold, [float], self.cls_name)
        validator.check_non_negative_float(clip_threshold, "clip_threshold", self.cls_name)
        validator.check_value_type("decay_rate", decay_rate, [float], self.cls_name)
        validator.check_float_range(decay_rate, 0, 1, Rel.INC_NEITHER, "decay_rate", self.cls_name)
        validator.check_float_range(weight_decay, 0, 1, Rel.INC_LEFT, "weight_decay", self.cls_name)
        validator.check_value_type("scale_parameter", scale_parameter, [bool], self.cls_name)
        validator.check_value_type("relative_step", relative_step, [bool], self.cls_name)
        validator.check_value_type("compression", compression, [bool], self.cls_name)
        validator.check_value_type("beta1", beta1, [int, float], self.cls_name)
        validator.check_non_negative_float(float(beta1), "beta1", self.cls_name)
        validator.check_value_type("beta3", beta3, [int, float], self.cls_name)
        validator.check_non_negative_float(float(beta3), "beta3", self.cls_name)
        self.eps = trans_to_tensor(eps)
        self.clip_threshold = trans_to_tensor(clip_threshold)
        self.decay_rate = trans_to_tensor(-decay_rate)
        self.beta1 = trans_to_tensor(beta1)
        self.beta3 = trans_to_tensor(beta3)
        self.weight_decay = trans_to_tensor(weight_decay)
        self.weight_decay_flag = bool(weight_decay)

        self.scale_parameter = scale_parameter
        self.relative_step = relative_step
        self.warmup_init = warmup_init
        self.compression = compression
        self.init_came_state(beta1)
        self.step = Parameter(initializer(0, [1], mstype.float32), name='afactor_step')
        self.fused_ada_factor = P.FusedAdaFactor(enable_scale_parameter=self.scale_parameter,
                                                 enable_first_moment=self.use_first_moment,
                                                 enable_weight_decay=self.weight_decay_flag)
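        # Explanatory note: the fused AdaFactor kernel is only enabled when running on CPU (or when
        # `target` is later set to "CPU"); other device targets use the composite CAME update above.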
        if context.get_context("device_target") == "CPU":
            self.use_fused_ada_factor = True
        else:
            self.use_fused_ada_factor = False
        logging.info("Came init completed %s.", self.learning_rate)

    def init_came_state(self, beta1):
        """init came variables"""
        if beta1 > 0:
            self.use_first_moment = True
            self.exp_avg = self._parameters.clone(prefix="exp_avg", init='zeros')
        else:
            self.use_first_moment = False
            self.exp_avg = ParameterTuple([Parameter(Tensor(0.0))] * len(self._parameters))

        self.exp_avg_sq = []
        self.exp_avg_sq_col = []
        self.exp_avg_sq_row = []
        self.exp_avg_insta_col = []
        self.exp_avg_insta_row = []
        for param in self._parameters:
            param_dtype = param.dtype
            param_shape = param.shape
            param_name = param.name
            if len(param_shape) > 1:
                self.exp_avg_sq_row.append(Parameter(initializer(0, shape=param_shape[:-1], dtype=param_dtype),
                                                     name="exp_avg_sq_row_{}".format(param_name)))
                self.exp_avg_sq_col.append(Parameter(initializer(0, shape=param_shape[:-2] + param_shape[-1:],
                                                                 dtype=param_dtype),
                                                     name="exp_avg_sq_col_{}".format(param_name)))
                self.exp_avg_insta_row.append(Parameter(initializer(0, shape=param_shape[:-1], dtype=param_dtype),
                                                        name="exp_avg_insta_row_{}".format(param_name)))
                self.exp_avg_insta_col.append(Parameter(initializer(0, shape=param_shape[:-2] + param_shape[-1:],
                                                                    dtype=param_dtype),
                                                        name="exp_avg_insta_col_{}".format(param_name)))
                self.exp_avg_sq.append(Parameter(initializer(0, shape=(1,), dtype=param_dtype),
                                                 name="exp_avg_sq_{}".format(param_name)))

            else:
                self.exp_avg_sq_row.append(Parameter(initializer(0, shape=(1,), dtype=param_dtype),
                                                     name="exp_avg_sq_row_{}".format(param_name)))
                self.exp_avg_sq_col.append(Parameter(initializer(0, shape=(1,), dtype=param_dtype),
                                                     name="exp_avg_sq_col_{}".format(param_name)))
                self.exp_avg_insta_row.append(Parameter(initializer(0, shape=(1,), dtype=param_dtype),
                                                        name="exp_avg_insta_row_{}".format(param_name)))
                self.exp_avg_insta_col.append(Parameter(initializer(0, shape=(1,), dtype=param_dtype),
                                                        name="exp_avg_insta_col_{}".format(param_name)))

                if self.compression:
                    self.exp_avg_sq.append(Parameter(initializer(0, shape=param_shape, dtype=mstype.float16),
                                                     name="exp_avg_sq_{}".format(param_name)))
                else:
                    self.exp_avg_sq.append(Parameter(initializer(0, shape=param_shape, dtype=param_dtype),
                                                     name="exp_avg_sq_{}".format(param_name)))

        self.exp_avg_sq_row = ParameterTuple(self.exp_avg_sq_row)
        self.exp_avg_sq_col = ParameterTuple(self.exp_avg_sq_col)
        self.exp_avg_insta_row = ParameterTuple(self.exp_avg_insta_row)
        self.exp_avg_insta_col = ParameterTuple(self.exp_avg_insta_col)
        self.exp_avg_sq = ParameterTuple(self.exp_avg_sq)

    @property
    def supports_memory_efficient_fp16(self):
        """
        Support memory efficient for fp16
        """
        return True

    @property
    def supports_flat_params(self):
        """
        Support flatten params
        """
        return False

    @jit
    def construct(self, gradients):
        """construct of came optimizer."""
        gradients = self.flatten_gradients(gradients)
        lr = self.get_lr()
        self.assignadd(self.global_step, self.global_step_increase_tensor)
        F.assign_add(self.step, 1)
        step = self.step
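        # Explanatory note: self.decay_rate stores -decay_rate, so P.Pow()(step, self.decay_rate) is
        # step ** (-decay_rate) and beta2t follows the AdaFactor schedule 1 - step^(-decay_rate).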
        beta2t = 1.0 - P.Pow()(step, self.decay_rate)

        if self.use_fused_ada_factor:
            success = self.hyper_map(F.partial(_came_opt, self.fused_ada_factor, self.eps, self.clip_threshold,
                                               self.beta1, beta2t, self.weight_decay, lr),
                                     gradients, self._parameters, self.exp_avg, self.exp_avg_sq_row,
                                     self.exp_avg_sq_col, self.exp_avg_sq)
        else:
            success = self.hyper_map(F.partial(_came_opt, self.eps, self.clip_threshold, self.beta1, beta2t,
                                               self.beta3, self.weight_decay, self.scale_parameter,
                                               self.compression, self.use_first_moment, self.weight_decay_flag, lr),
                                     gradients, self._parameters, self.exp_avg, self.exp_avg_sq_row,
                                     self.exp_avg_sq_col, self.exp_avg_sq, self.exp_avg_insta_row,
                                     self.exp_avg_insta_col)

        return success

    @Optimizer.target.setter
    def target(self, value):
        """
        If the input value is set to "CPU", the parameters will be updated on the host using the Fused
        optimizer operation.
        """
        self._set_base_target(value)
        if value == 'CPU':
            self.fused_ada_factor.set_device("CPU")
            self.use_fused_ada_factor = True
        else:
            self.use_fused_ada_factor = False