# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""came optimizer"""
16from __future__ import absolute_import
17
18from mindspore import context
19from mindspore.common import dtype as mstype
20from mindspore.log import logging
21from mindspore.common.initializer import initializer
22from mindspore.common.api import jit
23from mindspore.ops import operations as P
24from mindspore.ops import composite as C
25from mindspore.ops import functional as F
26from mindspore.common.parameter import Parameter, ParameterTuple
27from mindspore.common.tensor import Tensor
28try:
29    from mindspore._checkparam import Validator as validator
30    from mindspore._checkparam import Rel
31except ImportError:
32    import mindspore._checkparam as validator
33    import mindspore._checkparam as Rel
34from mindspore.nn.optim.optimizer import Optimizer
35from mindspore.nn.optim.optimizer import opt_init_args_register
36
37__all__ = ['Came']
38
39
40def _rms(update_tensor):
41    """calculate rms"""
42    return F.sqrt(P.ReduceMean(False)(F.square(update_tensor)))
43
44
45def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
46    """Approximation of exponential moving average of square of gradient"""
47    reduce_mean = P.ReduceMean(keep_dims=True)(exp_avg_sq_row, -1)
48    div_val = 1.0 / P.Sqrt()(P.RealDiv()(exp_avg_sq_row, reduce_mean))
49    r_factor = (P.ExpandDims()(div_val, -1))
50
51    exp_avg_sq_col = P.ExpandDims()(exp_avg_sq_col, -2)
52    c_factor = 1.0 / P.Sqrt()(exp_avg_sq_col)
53    return P.Mul()(r_factor, c_factor)
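# Note: with row statistics R = exp_avg_sq_row and column statistics C = exp_avg_sq_col,
# the full second moment is approximated element-wise as V[i, j] ~= R[i] * C[j] / mean(R),
# and _approx_sq_grad returns its reciprocal square root 1 / sqrt(V), i.e. the factor the
# gradient is multiplied by (the Adafactor-style rank-1 factorization).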


reduce_mean_keep_alive = P.ReduceMean().add_prim_attr("keep_alive", True)
_came_opt = C.MultitypeFuncGraph("came_opt")


@_came_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool", "Bool", "Bool",
                    "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
def _run_opt_with_one_number(eps, clip_threshold, beta1, beta2t, beta3, weight_decay, scale_parameter,
                             compression, use_first_moment, weight_decay_flag, learning_rate,
                             grad, param, exp_avg, exp_avg_sq_row, exp_avg_sq_col, exp_avg_sq,
                             exp_avg_insta_row, exp_avg_insta_col):
    """Apply the CAME optimizer to the weight parameter using Tensor."""
    cast = P.Cast()
    grad_dtype = F.dtype(grad)
    grad_shape = F.shape(grad)

    if grad_dtype == mstype.float16:
        grad = cast(grad, mstype.float32)
    p_data_fp32 = param
    if F.dtype(p_data_fp32) == mstype.float16:
        p_data_fp32 = cast(p_data_fp32, mstype.float32)

    factored = len(grad_shape) >= 2

    if scale_parameter:
        rms = _rms(p_data_fp32)
        param_scale = P.Maximum()(eps[1], rms)
        learning_rate_update = learning_rate * param_scale * F.ones_like(rms)
    else:
        learning_rate_update = learning_rate

    update = (grad ** 2) + eps[0]

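    # Second-moment accumulation: for parameters with 2 or more dimensions (factored is True)
    # only the row-wise and column-wise means of the squared gradient are tracked, which is
    # the memory-efficient Adafactor factorization; vectors and scalars keep the full
    # element-wise running average exp_avg_sq instead.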
    if factored:
        exp_avg_sq_row_update = cast(exp_avg_sq_row, grad_dtype)
        exp_avg_sq_row_update = P.Mul()(exp_avg_sq_row_update, beta2t)
        update_mean = reduce_mean_keep_alive(update, -1) * (1.0 - beta2t)
        exp_avg_sq_row_update = P.Add()(exp_avg_sq_row_update, update_mean)
        F.assign(exp_avg_sq_row, cast(exp_avg_sq_row_update, F.dtype(exp_avg_sq_row)))
        exp_avg_sq_row_update = exp_avg_sq_row

        exp_avg_sq_col_update = cast(exp_avg_sq_col, grad_dtype)
        exp_avg_sq_col_update = P.Mul()(exp_avg_sq_col_update, beta2t)
        update_mean = reduce_mean_keep_alive(update, -2) * (1.0 - beta2t)
        exp_avg_sq_col_update = P.Add()(exp_avg_sq_col_update, update_mean)
        F.assign(exp_avg_sq_col, cast(exp_avg_sq_col_update, F.dtype(exp_avg_sq_col)))
        exp_avg_sq_col_update = exp_avg_sq_col
        update = _approx_sq_grad(exp_avg_sq_row_update, exp_avg_sq_col_update)
        update = P.Mul()(update, grad)

    else:
        exp_avg_sq_update = cast(exp_avg_sq, grad_dtype)
        update = update * (1.0 - beta2t)
        exp_avg_sq_update = P.Add()(P.Mul()(exp_avg_sq_update, beta2t), update)
        F.assign(exp_avg_sq, cast(exp_avg_sq_update, F.dtype(exp_avg_sq)))
        exp_avg_sq_update = exp_avg_sq
        exp_avg_sq_update = 1.0 / P.Sqrt()(exp_avg_sq_update)
        update = P.Mul()(exp_avg_sq_update, grad)

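    # Update clipping: the update is divided by max(RMS(update) / clip_threshold, 1), so it
    # is only scaled down when its root mean square exceeds clip_threshold and is left
    # unchanged otherwise.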
    update_rms_thres = _rms(update) / clip_threshold
    update_coff = P.Maximum()(update_rms_thres, P.OnesLike()(update_rms_thres))
    update = P.RealDiv()(update, update_coff)

    if use_first_moment:
        exp_avg_update = exp_avg
        if compression:
            exp_avg_update = cast(exp_avg, grad_dtype)
        exp_avg_update = P.Add()(P.Mul()(exp_avg_update, beta1), update * (1 - beta1))
        F.assign(exp_avg, cast(exp_avg_update, F.dtype(exp_avg)))

    ###
    # CAME optimizer modification is here
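    # The instability matrix U_t = (u_t - m_t) ** 2 + eps3 measures how far the current
    # update u_t deviates from its running average m_t (exp_avg); its factored running
    # average (decayed by beta3) gives S_t, and the final step is lr * m_t / sqrt(S_t),
    # so directions with a stable (confident) update history take larger steps.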
    instability_matrix = (update - exp_avg) ** 2 + eps[2]

    if factored:
        exp_avg_insta_row_update = cast(exp_avg_insta_row, grad_dtype)
        exp_avg_insta_row_update = P.Mul()(exp_avg_insta_row_update, beta3)
        update_mean = reduce_mean_keep_alive(instability_matrix, -1) * (1.0 - beta3)
        exp_avg_insta_row_update = P.Add()(exp_avg_insta_row_update, update_mean)
        F.assign(exp_avg_insta_row, cast(exp_avg_insta_row_update, F.dtype(exp_avg_insta_row)))
        exp_avg_insta_row_update = exp_avg_insta_row

        exp_avg_insta_col_update = cast(exp_avg_insta_col, grad_dtype)
        exp_avg_insta_col_update = P.Mul()(exp_avg_insta_col_update, beta3)
        update_mean = reduce_mean_keep_alive(instability_matrix, -2) * (1.0 - beta3)
        exp_avg_insta_col_update = P.Add()(exp_avg_insta_col_update, update_mean)
        F.assign(exp_avg_insta_col, cast(exp_avg_insta_col_update, F.dtype(exp_avg_insta_col)))
        exp_avg_insta_col_update = exp_avg_insta_col

        s_t = _approx_sq_grad(exp_avg_insta_row_update, exp_avg_insta_col_update)
        update = s_t * exp_avg * learning_rate_update
    else:
        update = exp_avg * learning_rate_update
    # ###

    if weight_decay_flag:
        p_data_fp32_coff = p_data_fp32 * -weight_decay * learning_rate_update
        p_data_fp32 = P.Add()(p_data_fp32, p_data_fp32_coff)
    p_data_fp32 = P.Sub()(p_data_fp32, update)
    P.Assign()(param, cast(p_data_fp32, F.dtype(param)))
    return True


@_came_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
                    "Tensor", "Tensor", "Tensor", "Tensor")
def _run_fused_ada_factor(fused_ada_factor, eps, clip_threshold, beta1, beta2t, weight_decay, learning_rate,
                          grad, param, exp_avg, exp_avg_sq_row, exp_avg_sq_col, exp_avg_sq):
    """Apply the fused AdaFactor operator (used when the fused CPU path is enabled)."""
    fused_ada_factor(eps, clip_threshold, beta1, beta2t, weight_decay, learning_rate,
                     grad, param, exp_avg, exp_avg_sq_row, exp_avg_sq_col, exp_avg_sq)
    return True


def trans_to_tensor(param, is_tuple=False, fp32=True):
    """
    Transform params to tensor.
    """
    if param is None or isinstance(param, bool):
        return param
    data_type = mstype.float32 if fp32 else mstype.float16
    if is_tuple:
        new_param = [Tensor(ele, data_type) for ele in param]
        return tuple(new_param)
    return Tensor(param, data_type)


class Came(Optimizer):
    r"""
    Updates gradients by the Confidence-guided Adaptive Memory Efficient Optimization (CAME) algorithm.

    The CAME algorithm is proposed in `CAME: Confidence-guided Adaptive Memory Efficient Optimization
    <https://arxiv.org/abs/2307.02047>`_.

    Args:
        params (Union[list[Parameter], list[dict]]): When `params` is a list of `Parameter` which will be updated,
            each element in `params` must be of class `Parameter`.
        learning_rate (Union[float, Tensor]): A value or a graph for the learning rate.
            When `learning_rate` is a Tensor, it should be a scalar or a 1-D Tensor.
            If the type of `learning_rate` is int, it will be converted to float. Default: None.
        eps (tuple): The regularization constants for the squared gradient, parameter scale and instability matrix,
            respectively. Default: (1e-30, 1e-3, 1e-16).
        clip_threshold (Union[float, Tensor]): The threshold of the root mean square of the final gradient update.
            Default: 1.0.
        decay_rate (Union[float, Tensor]): The coefficient used to compute the running average of the squared
            gradient. Default: 0.8.
        beta1 (float): The coefficient used to compute the running average of the gradient. Should be in
            range (0.0, 1.0). Default: 0.9.
        beta3 (float): The coefficient used to compute the running average of the instability matrix
            (the confidence-guided term). Should be in range (0.0, 1.0). Default: 0.99.
        weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.
        scale_parameter (bool): If True, the learning rate is scaled by the root mean square of the parameter.
            Default: False.
        relative_step (bool): If True, a time-dependent learning rate is computed instead of using the external
            learning rate. Default: False.
        warmup_init (bool): Whether the time-dependent learning rate uses warm-up initialization. Only valid
            when `relative_step` is True. Default: False.
        compression (bool): If True, the data type of the running averages will be compressed to float16.
            Default: False.
        loss_scale (float): A floating point value for the loss scale. Should be greater than 0. In general, use the
            default value. Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
            `FixedLossScaleManager` is set to False, then this value needs to be the same as the `loss_scale` in
            `FixedLossScaleManager`. Refer to class :class:`mindspore.amp.FixedLossScaleManager` for more details.
            Default: 1.0.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

    Outputs:
        Tensor[bool], the value is True.

    Raises:
        TypeError: If `learning_rate` is not one of int, float, Tensor, Iterable, LearningRateSchedule.
        TypeError: If element of `parameters` is neither Parameter nor dict.
        TypeError: If `beta1`, `beta3`, `eps` or `loss_scale` is not a float.
        TypeError: If `weight_decay` is neither float nor int.
        TypeError: If `scale_parameter`, `relative_step` or `compression` is not a bool.
        ValueError: If `loss_scale` or `eps` is less than or equal to 0.
        ValueError: If `beta1` or `beta3` is not in range (0.0, 1.0).
        ValueError: If `weight_decay` is less than 0.

    Supported Platforms:
        ``Ascend``
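
    Examples:
        A minimal usage sketch (illustrative only; ``MyNet`` is a placeholder network and the
        import path of ``Came`` depends on how this module is packaged):

        >>> import mindspore.nn as nn
        >>> net = MyNet()  # any Cell with trainable parameters (hypothetical placeholder)
        >>> optim = Came(params=net.trainable_params(), learning_rate=5e-5)
        >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
        >>> train_net = nn.TrainOneStepCell(nn.WithLossCell(net, loss), optim)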
    """
    _support_parallel_optimizer = True

    @opt_init_args_register
    def __init__(self,
                 params,
                 learning_rate=None,
                 eps=(1e-30, 1e-3, 1e-16),
                 clip_threshold=1.0,
                 decay_rate=0.8,
                 beta1=0.9,
                 beta3=0.99,
                 weight_decay=0.0,
                 scale_parameter=False,
                 relative_step=False,
                 warmup_init=False,
                 compression=False,
                 loss_scale=1.0):

        if learning_rate is not None and relative_step:
            raise ValueError("Cannot combine a manual learning_rate with relative_step=True, "
                             "but got learning_rate={}.".format(learning_rate))
        if warmup_init and not relative_step:
            raise ValueError("warmup_init requires relative_step=True.")
        if learning_rate is None and not relative_step:
            raise ValueError("learning_rate cannot be None when relative_step=False.")
        if learning_rate is None:
            learning_rate = 0.0
        if beta1 is None:
            beta1 = 0.0

        if not isinstance(learning_rate, (float, int)) and learning_rate is not None:
            if relative_step or scale_parameter:
                logging.warning("When learning_rate is a learning rate scheduler, "
                                "updating the learning rate is not supported!")

        super(Came, self).__init__(learning_rate, params, weight_decay, loss_scale)
        validator.check_value_type("eps", eps, [list, tuple], self.cls_name)
        if len(eps) != 3:
            raise ValueError("eps must have 3 values: (eps1, eps2, eps3).")
        for i, ele in enumerate(eps):
            validator.check_value_type("eps{}".format(i), ele, [float], self.cls_name)
            validator.check_non_negative_float(ele, "eps{}".format(i), self.cls_name)
        validator.check_value_type("clip_threshold", clip_threshold, [float], self.cls_name)
        validator.check_non_negative_float(clip_threshold, "clip_threshold", self.cls_name)
        validator.check_value_type("decay_rate", decay_rate, [float], self.cls_name)
        validator.check_float_range(decay_rate, 0, 1, Rel.INC_NEITHER, "decay_rate", self.cls_name)
        validator.check_float_range(weight_decay, 0, 1, Rel.INC_LEFT, "weight_decay", self.cls_name)
        validator.check_value_type("scale_parameter", scale_parameter, [bool], self.cls_name)
        validator.check_value_type("relative_step", relative_step, [bool], self.cls_name)
        validator.check_value_type("compression", compression, [bool], self.cls_name)
        validator.check_value_type("beta1", beta1, [int, float], self.cls_name)
        validator.check_non_negative_float(float(beta1), "beta1", self.cls_name)
        validator.check_value_type("beta3", beta3, [int, float], self.cls_name)
        validator.check_non_negative_float(float(beta3), "beta3", self.cls_name)
        self.eps = trans_to_tensor(eps)
        self.clip_threshold = trans_to_tensor(clip_threshold)
        self.decay_rate = trans_to_tensor(-decay_rate)
        self.beta1 = trans_to_tensor(beta1)
        self.beta3 = trans_to_tensor(beta3)
        self.weight_decay = trans_to_tensor(weight_decay)
        self.weight_decay_flag = bool(weight_decay)

        self.scale_parameter = scale_parameter
        self.relative_step = relative_step
        self.warmup_init = warmup_init
        self.compression = compression
        self.init_came_state(beta1)
        self.step = Parameter(initializer(0, [1], mstype.float32), name='afactor_step')
        self.fused_ada_factor = P.FusedAdaFactor(enable_scale_parameter=self.scale_parameter,
                                                 enable_first_moment=self.use_first_moment,
                                                 enable_weight_decay=self.weight_decay_flag)
        if context.get_context("device_target") == "CPU":
            self.use_fused_ada_factor = True
        else:
            self.use_fused_ada_factor = False
        logging.info("Came init completed %s.", self.learning_rate)

    def init_came_state(self, beta1):
        """Initialize CAME state variables."""
        if beta1 > 0:
            self.use_first_moment = True
            self.exp_avg = self._parameters.clone(prefix="exp_avg", init='zeros')
        else:
            self.use_first_moment = False
            self.exp_avg = ParameterTuple([Parameter(Tensor(0.0))] * len(self._parameters))

        self.exp_avg_sq = []
        self.exp_avg_sq_col = []
        self.exp_avg_sq_row = []
        self.exp_avg_insta_col = []
        self.exp_avg_insta_row = []
        for param in self._parameters:
            param_dtype = param.dtype
            param_shape = param.shape
            param_name = param.name
            if len(param_shape) > 1:
                self.exp_avg_sq_row.append(Parameter(initializer(0, shape=param_shape[:-1], dtype=param_dtype),
                                                     name="exp_avg_sq_row_{}".format(param_name)))
                self.exp_avg_sq_col.append(Parameter(initializer(0, shape=param_shape[:-2] + param_shape[-1:],
                                                                 dtype=param_dtype),
                                                     name="exp_avg_sq_col_{}".format(param_name)))
                self.exp_avg_insta_row.append(Parameter(initializer(0, shape=param_shape[:-1], dtype=param_dtype),
                                                        name="exp_avg_insta_row_{}".format(param_name)))
                self.exp_avg_insta_col.append(Parameter(initializer(0, shape=param_shape[:-2] + param_shape[-1:],
                                                                    dtype=param_dtype),
                                                        name="exp_avg_insta_col_{}".format(param_name)))
                self.exp_avg_sq.append(Parameter(initializer(0, shape=(1,), dtype=param_dtype),
                                                 name="exp_avg_sq_{}".format(param_name)))

            else:
                self.exp_avg_sq_row.append(Parameter(initializer(0, shape=(1,), dtype=param_dtype),
                                                     name="exp_avg_sq_row_{}".format(param_name)))
                self.exp_avg_sq_col.append(Parameter(initializer(0, shape=(1,), dtype=param_dtype),
                                                     name="exp_avg_sq_col_{}".format(param_name)))
                self.exp_avg_insta_row.append(Parameter(initializer(0, shape=(1,), dtype=param_dtype),
                                                        name="exp_avg_insta_row_{}".format(param_name)))
                self.exp_avg_insta_col.append(Parameter(initializer(0, shape=(1,), dtype=param_dtype),
                                                        name="exp_avg_insta_col_{}".format(param_name)))

                if self.compression:
                    self.exp_avg_sq.append(Parameter(initializer(0, shape=param_shape, dtype=mstype.float16),
                                                     name="exp_avg_sq_{}".format(param_name)))
                else:
                    self.exp_avg_sq.append(Parameter(initializer(0, shape=param_shape, dtype=param_dtype),
                                                     name="exp_avg_sq_{}".format(param_name)))

        self.exp_avg_sq_row = ParameterTuple(self.exp_avg_sq_row)
        self.exp_avg_sq_col = ParameterTuple(self.exp_avg_sq_col)
        self.exp_avg_insta_row = ParameterTuple(self.exp_avg_insta_row)
        self.exp_avg_insta_col = ParameterTuple(self.exp_avg_insta_col)
        self.exp_avg_sq = ParameterTuple(self.exp_avg_sq)

    @property
    def supports_memory_efficient_fp16(self):
        """
        Whether the optimizer supports memory-efficient fp16 training.
        """
        return True

    @property
    def supports_flat_params(self):
        """
        Whether the optimizer supports flattened parameters.
        """
        return False

    @jit
    def construct(self, gradients):
        """construct of came optimizer."""
        gradients = self.flatten_gradients(gradients)
        lr = self.get_lr()
        self.assignadd(self.global_step, self.global_step_increase_tensor)
        F.assign_add(self.step, 1)
        step = self.step
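        # self.decay_rate holds -decay_rate, so beta2t below follows the Adafactor
        # time-dependent schedule beta2_t = 1 - step ** (-decay_rate).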
        beta2t = 1.0 - P.Pow()(step, self.decay_rate)

        if self.use_fused_ada_factor:
            success = self.hyper_map(F.partial(_came_opt, self.fused_ada_factor, self.eps, self.clip_threshold,
                                               self.beta1, beta2t, self.weight_decay, lr),
                                     gradients, self._parameters, self.exp_avg, self.exp_avg_sq_row,
                                     self.exp_avg_sq_col, self.exp_avg_sq)
        else:
            success = self.hyper_map(F.partial(_came_opt, self.eps, self.clip_threshold, self.beta1, beta2t,
                                               self.beta3, self.weight_decay, self.scale_parameter, self.compression,
                                               self.use_first_moment, self.weight_decay_flag, lr),
                                     gradients, self._parameters, self.exp_avg, self.exp_avg_sq_row,
                                     self.exp_avg_sq_col, self.exp_avg_sq, self.exp_avg_insta_row,
                                     self.exp_avg_insta_col)

        return success

    @Optimizer.target.setter
    def target(self, value):
        """
        If the input value is set to "CPU", the parameters will be updated on the host using the fused
        optimizer operation.
        """
        self._set_base_target(value)
        if value == 'CPU':
            self.fused_ada_factor.set_device("CPU")
            self.use_fused_ada_factor = True
        else:
            self.use_fused_ada_factor = False
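

# Usage note (illustrative, assuming `optimizer` is a constructed Came instance): setting
# `optimizer.target = "CPU"` routes the parameter update through the fused AdaFactor kernel
# on the host, mirroring the CPU branch selected automatically in __init__ when the device
# target is "CPU".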