# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""adafactor"""
16from mindspore.common import dtype as mstype
17from mindspore.log import logging
18from mindspore.common.initializer import initializer
19from mindspore.ops import operations as P
20from mindspore.ops import composite as C
21from mindspore.ops import functional as F
22from mindspore.common.parameter import Parameter, ParameterTuple
23from mindspore.common.tensor import Tensor
24from mindspore._checkparam import Validator as validator
25from mindspore._checkparam import Rel
26from mindspore.nn.optim.optimizer import opt_init_args_register
27from .optimizer import Optimizer
28
29
30def _get_lr(step, rms, learning_rate, relative_step, warmup_init, scale_parameter, eps):
31    """update optimizer learning rete"""
    rel_step_sz = learning_rate
    if relative_step:
        if warmup_init:
            min_step = 1e-6 * step * 1.0
        else:
            min_step = 1e-2 * 1.0

        rel_step_sz = P.Minimum()(min_step, 1.0 / P.Sqrt()(step * 1.0))
    param_scale = 1.0
    if scale_parameter:
        param_scale = P.Maximum()(eps[1], rms)
    return rel_step_sz * param_scale * F.ones_like(rms)
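
# A worked illustration of the schedule in _get_lr above (the example numbers are
# assumptions for illustration, not part of the original source):
#   with relative_step=True and warmup_init=False, min_step = 1e-2, so at
#   step = 10000 the relative step size is min(1e-2, 1 / sqrt(10000)) = 1e-2;
#   with warmup_init=True and step = 100, min_step = 1e-6 * 100 = 1e-4, which is
#   smaller than 1 / sqrt(100) = 0.1, so the step size warms up as 1e-6 * step.
# When scale_parameter=True the result is further multiplied by
# max(eps[1], RMS(parameter)).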


def _rms(update_tensor):
    """calculate rms"""
    return F.sqrt(P.ReduceMean(False)(F.square(update_tensor)))
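
# Example of _rms above (illustrative values, not from the original source):
# for a tensor with values [3.0, 4.0], _rms returns sqrt((9 + 16) / 2), which is
# approximately 3.54, i.e. the root mean square of all elements.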


def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
    """Approximation of exponential moving average of square of gradient"""
    reduce_mean = P.ReduceMean(keep_dims=True)(exp_avg_sq_row, -1)
    div_val = 1.0 / P.Sqrt()(P.Div()(exp_avg_sq_row, reduce_mean))
    r_factor = (P.ExpandDims()(div_val, -1))

    exp_avg_sq_col = P.ExpandDims()(exp_avg_sq_col, -2)
    c_factor = 1.0 / P.Sqrt()(exp_avg_sq_col)
    return P.Mul()(r_factor, c_factor)
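
# Descriptive note on _approx_sq_grad above: r_factor is 1 / sqrt(R / mean(R))
# expanded on the last axis, c_factor is 1 / sqrt(C) expanded on the second-to-last
# axis, and their product approximates 1 / sqrt(V_hat) for the factored second
# moment V_hat = R C / (1^T R) described in the class docstring (R and C here hold
# running row/column means, so normalizing by mean(R) is equivalent to 1^T R).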


_adam_opt = C.MultitypeFuncGraph("adam_opt")


@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool",
                    "Bool", "Bool", "Bool", "Bool", "Bool", "Tensor", "Tensor", "Tensor",
                    "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
def _run_opt_with_one_number(eps, clip_threshold, decay_rate, beta1,
                             weight_decay, scale_lr, scale_parameter, relative_step,
                             warmup_init, compression, use_first_moment, weight_decay_flag,
                             learning_rate, step, grad, param,
                             exp_avg, exp_avg_sq_row,
                             exp_avg_sq_col, exp_avg_sq):
74    """Apply ada factor optimizer to the weight parameter using Tensor."""
    success = True
    grad_dtype = F.dtype(grad)
    grad_shape = F.shape(grad)

    if grad_dtype == mstype.float16:
        grad = F.cast(grad, mstype.float32)
    p_data_fp32 = param
    if F.dtype(p_data_fp32) == mstype.float16:
        p_data_fp32 = F.cast(p_data_fp32, mstype.float32)

    factored = len(grad_shape) >= 2

    # State Initialization
    exp_avg_update = exp_avg
    exp_avg_sq_update = exp_avg_sq
    exp_avg_sq_row_update = exp_avg_sq_row
    exp_avg_sq_col_update = exp_avg_sq_col

    if use_first_moment:
        if compression:
            exp_avg_update = F.cast(exp_avg, mstype.float16)

    if factored:
        exp_avg_sq_row_update = F.cast(exp_avg_sq_row, grad_dtype)
        exp_avg_sq_col_update = F.cast(exp_avg_sq_col, grad_dtype)
    else:
        exp_avg_sq_update = F.cast(exp_avg_sq, grad_dtype)

    if scale_lr:
        rms = _rms(p_data_fp32)
        learning_rate_update = _get_lr(step, rms, learning_rate, relative_step, warmup_init, scale_parameter, eps)
        learning_rate_update = F.assign(learning_rate, F.cast(learning_rate_update, F.dtype(learning_rate)))
    else:
        learning_rate_update = learning_rate * 1.0

    # Note: decay_rate is passed in negated (see __init__), so this computes the
    # paper's schedule beta2_t = 1 - step ** (-decay_rate).
    beta2t = 1.0 - P.Pow()(step, decay_rate)
    update = (grad ** 2) + eps[0]

    if factored:
        exp_avg_sq_row_update = P.Mul()(exp_avg_sq_row_update, beta2t)
        update_mean = P.ReduceMean()(update, -1) * (1.0 - beta2t)
        exp_avg_sq_row_update = P.Add()(exp_avg_sq_row_update, update_mean)
        exp_avg_sq_row_update = F.assign(exp_avg_sq_row, F.cast(exp_avg_sq_row_update, F.dtype(exp_avg_sq_row)))

        exp_avg_sq_col_update = P.Mul()(exp_avg_sq_col_update, beta2t)
        update_mean = P.ReduceMean()(update, -2) * (1.0 - beta2t)
        exp_avg_sq_col_update = P.Add()(exp_avg_sq_col_update, update_mean)
        exp_avg_sq_col_update = F.assign(exp_avg_sq_col, F.cast(exp_avg_sq_col_update, F.dtype(exp_avg_sq_col)))

        update = _approx_sq_grad(exp_avg_sq_row_update, exp_avg_sq_col_update)
        update = P.Mul()(update, grad)
    else:
        update = update * (1.0 - beta2t)
        exp_avg_sq_update = P.Add()(P.Mul()(exp_avg_sq_update, beta2t), update)
        exp_avg_sq_update = F.assign(exp_avg_sq, F.cast(exp_avg_sq_update, F.dtype(exp_avg_sq)))
        exp_avg_sq_update = 1.0 / P.Sqrt()(exp_avg_sq_update)
        update = P.Mul()(exp_avg_sq_update, grad)

    update_rms_thres = _rms(update) / clip_threshold
    update_coff = P.Maximum()(update_rms_thres, P.OnesLike()(update_rms_thres))
    update = P.Mul()(P.Div()(update, update_coff), learning_rate_update)

    if use_first_moment:
        if compression:
            exp_avg_update = F.cast(exp_avg_update, grad_dtype)
        exp_avg_update = P.Add()(P.Mul()(exp_avg_update, beta1), update * (1 - beta1))
        update = F.assign(exp_avg, F.cast(exp_avg_update, F.dtype(exp_avg)))

    if weight_decay_flag:
        p_data_fp32_coff = p_data_fp32 * -weight_decay * learning_rate_update
        p_data_fp32 = P.Add()(p_data_fp32, p_data_fp32_coff)
    p_data_fp32 = P.Sub()(p_data_fp32, update)
    P.Assign()(param, F.cast(p_data_fp32, F.dtype(param)))
    return success
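
# Summary of _run_opt_with_one_number above (descriptive comment, not part of the
# original source): the learning rate is optionally rescaled by RMS(parameter); the
# second moment is tracked either in factored row/column form (rank >= 2) or in full
# (rank < 2); the update is clipped so that RMS(update) <= clip_threshold, optionally
# smoothed with the first moment (beta1), and applied to the parameter along with
# optional weight decay.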


def trans_to_tensor(paras, is_tuple=False, fp32=True):
    """Convert a number or a tuple of numbers to Tensor(s); None and bool inputs are returned unchanged."""
    if paras is None or isinstance(paras, bool):
        return paras
    data_type = mstype.float32 if fp32 else mstype.float16
    if is_tuple:
        new_paras = [Tensor(ele, data_type) for ele in paras]
        return tuple(new_paras)
    return Tensor(paras, data_type)
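
# Example behaviour of trans_to_tensor above (illustrative, not from the original
# source): trans_to_tensor((1e-30, 1e-3), is_tuple=True) returns a tuple of two
# float32 Tensors, trans_to_tensor(0.8) returns Tensor(0.8, mstype.float32), and
# trans_to_tensor(True) returns True unchanged.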


class AdaFactor(Optimizer):
    r"""
    Updates gradients by the Adaptive Learning Rates with Sublinear Memory Cost (Adafactor) algorithm.

    The Adafactor algorithm is proposed in `Adafactor: Adaptive Learning Rates with Sublinear Memory
    Cost <https://arxiv.org/abs/1804.04235>`_.

    .. warning::
        This is an experimental prototype that is subject to change and/or deletion.

    Adafactor for weight vectors is as follows,

    .. math::
        \begin{array}{l} \\
        \alpha_{t}=\max \left(\epsilon_{2}, \operatorname{RMS}\left(X_{t-1}\right)\right) \rho_{t} \\
        G_{t}=\nabla f_{t}\left(X_{t-1}\right) \\
        \hat{V}_{t}=\hat{\beta}_{2 t} \hat{V}_{t-1}+\left(1-\hat{\beta}_{2 t}\right)\left(G_{t}^{2}+ \\
        \epsilon_{1} 1_{n}\right) \\
        U_{t}=G_{t} / \sqrt{\hat{V}_{t}} \\
        \hat{U}_{t}=U_{t} / \max \left(1, \operatorname{RMS}\left(U_{t}\right) / d\right) \\
        X_{t}=X_{t-1}-\alpha_{t} \hat{U}_{t}
        \end{array}

    Adafactor for weight matrices is as follows,

    .. math::
        \begin{array}{l} \\
        \alpha_{t}=\max \left(\epsilon_{2}, \operatorname{RMS}\left(X_{t-1}\right)\right) \rho_{t} \\
        G_{t}=\nabla f_{t}\left(X_{t-1}\right) \\
        R_{t}=\hat{\beta}_{2 t} R_{t-1}+\left(1-\hat{\beta}_{2 t}\right)\left(G_{t}^{2}+ \\
        \epsilon_{1} 1_{n} 1_{m}^{\top}\right) 1_{m} \\
        C_{t}=\hat{\beta}_{2 t} C_{t-1}+\left(1-\hat{\beta}_{2 t}\right) 1_{n}^{\top}\left(G_{t}^{2}+ \\
        \epsilon_{1} 1_{n} 1_{m}^{\top}\right) \\
        \hat{V}_{t}=R_{t} C_{t} / 1_{n}^{\top} R_{t} \\
        U_{t}=G_{t} / \sqrt{\hat{V}_{t}} \\
        \hat{U}_{t}=U_{t} / \max \left(1, \operatorname{RMS}\left(U_{t}\right) / d\right) \\
        X_{t}=X_{t-1}-\alpha_{t} \hat{U}_{t}
        \end{array}

    Where RMS is:

    .. math::
        \operatorname{RMS}\left(U_{t}\right)=\operatorname{RMS}_{x \in X}\left(u_{x t}\right)= \\
        \sqrt{\operatorname{Mean}_{x \in X}\left(\frac{\left(g_{x t}\right)^{2}}{\hat{v}_{x t}}\right)}

    :math:`x` is each individual parameter,
    :math:`t` is the current number of steps,
    :math:`\alpha_{t}` is the learning rate,
    :math:`f(X)` is the loss function,
    :math:`\epsilon_{1}` and :math:`\epsilon_{2}` are small positive numbers that prevent numerical errors,
    :math:`d` is the clipping threshold,
    :math:`\beta_{2}` is the moment decay,
    :math:`\rho` is the relative step size,
    :math:`R` is the running average of the row sums of the squared gradient,
    :math:`C` is the running average of the column sums of the squared gradient.

    Note:
        The learning rate of this optimizer is controlled by the *scale_parameter*, *relative_step* and
        *warmup_init* options. To use a manual (external) learning rate schedule, set
        `scale_parameter=False` and `relative_step=False`.

        If a parameter is not used in the network, do not add it to the optimizer;
        otherwise the calculation result will be abnormal.

        To improve the performance of parameter groups, a customized order of parameters is supported.

    Args:
        params (Union[list[Parameter], list[dict]]): When `params` is a list of `Parameter` which will be updated,
            each element in `params` must be of class `Parameter`.

        learning_rate (Union[float, Tensor]): A value or a graph for the learning rate. When the `learning_rate`
            is a Tensor, it should be a 1-D Tensor.
            If the type of `learning_rate` is int, it will be converted to float. Default: None.
        eps (Union[list, tuple]): The regularization constants for the square gradient and the parameter scale,
            respectively. Default: (1e-30, 1e-3).
        clip_threshold (Union[float, Tensor]): The threshold of the root mean square of the final gradient update.
            Default: 1.0.
        decay_rate (Union[float, Tensor]): The coefficient used to compute the running average of the square
            gradient. Default: 0.8.
        beta1 (float): The coefficient used to compute the running average of the gradient. Should be in range
            (0.0, 1.0). Default: 0.9.
        weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.
        scale_parameter (bool): If True, the learning rate is scaled by the root mean square of the parameter.
            Default: True.
        relative_step (bool): If True, a time-dependent learning rate is computed instead of using an external
            learning rate. Default: True.
        warmup_init (bool): If True, the time-dependent learning rate starts with a warm-up; requires
            `relative_step=True`. Default: False.
        compression (bool): If True, the data type of the exponential running averages is compressed to float16.
            Default: False.
        loss_scale (float): A floating point value for the loss scale. Should be greater than 0. In general, use the
            default value. Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
            `FixedLossScaleManager` is set to False, then this value needs to be the same as the `loss_scale` in
            `FixedLossScaleManager`. Refer to class :class:`mindspore.FixedLossScaleManager` for more details.
            Default: 1.0.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

    Outputs:
        Tensor[bool], the value is True.

    Raises:
        TypeError: If `learning_rate` is not one of int, float, Tensor, Iterable, LearningRateSchedule.
        TypeError: If element of `parameters` is neither Parameter nor dict.
        TypeError: If `beta1`, `eps` or `loss_scale` is not a float.
        TypeError: If `weight_decay` is neither float nor int.
        TypeError: If `scale_parameter`, `relative_step` or `compression` is not a bool.
        ValueError: If `loss_scale` or `eps` is less than or equal to 0.
        ValueError: If `decay_rate` is not in range (0.0, 1.0).
        ValueError: If `weight_decay` is less than 0.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> net = Net()
        >>> #1) Parameters use the default learning rate with None and weight decay with 0.
        >>> optim = nn.AdaFactor(params=net.trainable_params())
        >>>
        >>> #2) Use parameter groups
        >>> all_params = net.trainable_params()
        >>> group_params = [{'params': [all_params[0]]}, {'params': [all_params[1]]}]
        >>> optim = nn.AdaFactor(group_params, learning_rate=0.1, weight_decay=0.0, relative_step=False)
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> model = Model(net, loss_fn=loss, optimizer=optim)
    """

    @opt_init_args_register
    def __init__(self,
                 params,
                 learning_rate=None,
                 eps=(1e-30, 1e-3),
                 clip_threshold=1.0,
                 decay_rate=0.8,
                 beta1=0.9,
                 weight_decay=0.0,
                 scale_parameter=True,
                 relative_step=True,
                 warmup_init=False,
                 compression=False,
                 loss_scale=1.0):

        if learning_rate is not None and relative_step:
            raise ValueError("Cannot combine a manual learning_rate with relative_step=True, "
                             "but got learning_rate={}".format(learning_rate))
        if warmup_init and not relative_step:
            raise ValueError("warmup_init requires relative_step=True")
        if learning_rate is None and not relative_step:
            raise ValueError("learning_rate cannot be None when relative_step=False")
        if learning_rate is None:
            learning_rate = 0.0
        if beta1 is None:
            beta1 = 0.0
        self.scale_lr = True
        if not isinstance(learning_rate, (float, int)) and learning_rate is not None:
            self.scale_lr = False
            if relative_step or scale_parameter:
                logging.warning("When learning_rate is a learning rate schedule, the relative_step and "
                                "scale_parameter options cannot update the learning rate.")

        super(AdaFactor, self).__init__(learning_rate, params, weight_decay, loss_scale)
        validator.check_value_type("eps", eps, [list, tuple], self.cls_name)
        if len(eps) != 2:
321            raise ValueError("eps must have 2 value: (eps1, eps2).")
        for i, ele in enumerate(eps):
            validator.check_value_type("eps{}".format(i), ele, [float], self.cls_name)
            validator.check_non_negative_float(ele, "eps{}".format(i), self.cls_name)
        validator.check_value_type("clip_threshold", clip_threshold, [float], self.cls_name)
        validator.check_non_negative_float(clip_threshold, "clip_threshold", self.cls_name)
        validator.check_value_type("decay_rate", decay_rate, [float], self.cls_name)
        validator.check_float_range(decay_rate, 0, 1, Rel.INC_NEITHER, "decay_rate", self.cls_name)
        validator.check_float_range(weight_decay, 0, 1, Rel.INC_LEFT, "weight_decay", self.cls_name)
        validator.check_value_type("scale_parameter", scale_parameter, [bool], self.cls_name)
        validator.check_value_type("relative_step", relative_step, [bool], self.cls_name)
        validator.check_value_type("compression", compression, [bool], self.cls_name)
        validator.check_value_type("beta1", beta1, [int, float], self.cls_name)
        validator.check_non_negative_float(float(beta1), "beta1", self.cls_name)
        self.eps = trans_to_tensor(eps)
        self.clip_threshold = trans_to_tensor(clip_threshold)
        self.decay_rate = trans_to_tensor(-decay_rate)
        self.beta1 = trans_to_tensor(beta1)
        self.weight_decay = trans_to_tensor(weight_decay)
        self.weight_decay_flag = bool(weight_decay)

        self.scale_parameter = scale_parameter
        self.relative_step = relative_step
        self.warmup_init = warmup_init
        self.compression = compression

        self.init_ada_factor_state(beta1)
        self.step = Parameter(initializer(0, [1], mstype.float32), name='afactor_step')
        print("AdaFactor init completed", self.learning_rate)

    def init_ada_factor_state(self, beta1):
        """Initialize AdaFactor state: the optional first moment, plus factored row/column second-moment
        statistics for parameters of rank >= 2 or a full second-moment accumulator otherwise."""
        if beta1 > 0:
            self.use_first_moment = True
            self.exp_avg = self.parameters.clone(prefix="exp_avg", init='zeros')
        else:
            self.use_first_moment = False
            self.exp_avg = ParameterTuple([Parameter(Tensor(0.0))] * len(self.parameters))

        self.exp_avg_sq = []
        self.exp_avg_sq_col = []
        self.exp_avg_sq_row = []
        for paras in self.parameters:
            paras_dtype = paras.dtype
            paras_shape = paras.shape
            paras_name = paras.name
            if len(paras_shape) > 1:
                self.exp_avg_sq_row.append(Parameter(initializer(0, shape=paras_shape[:-1], dtype=paras_dtype),
                                                     name="exp_avg_sq_row_{}".format(paras_name)))
                self.exp_avg_sq_col.append(Parameter(initializer(0, shape=paras_shape[:-2] + paras_shape[-1:],
                                                                 dtype=paras_dtype),
                                                     name="exp_avg_sq_col_{}".format(paras_name)))
                if self.compression:
                    self.exp_avg_sq.append(Parameter(initializer(0, shape=(1,), dtype=mstype.float16),
                                                     name="exp_avg_sq_{}".format(paras_name)))
                else:
                    self.exp_avg_sq.append(Parameter(initializer(0, shape=(1,), dtype=paras_dtype),
                                                     name="exp_avg_sq_{}".format(paras_name)))

            else:
                self.exp_avg_sq_row.append(Parameter(initializer(0, shape=(1,), dtype=paras_dtype),
                                                     name="exp_avg_sq_row_{}".format(paras_name)))
                self.exp_avg_sq_col.append(Parameter(initializer(0, shape=(1,), dtype=paras_dtype),
                                                     name="exp_avg_sq_col_{}".format(paras_name)))

                if self.compression:
                    self.exp_avg_sq.append(Parameter(initializer(0, shape=paras_shape, dtype=mstype.float16),
                                                     name="exp_avg_sq_{}".format(paras_name)))
                else:
                    self.exp_avg_sq.append(Parameter(initializer(0, shape=paras_shape, dtype=paras_dtype),
                                                     name="exp_avg_sq_{}".format(paras_name)))

        self.exp_avg_sq_row = ParameterTuple(self.exp_avg_sq_row)
        self.exp_avg_sq_col = ParameterTuple(self.exp_avg_sq_col)
        self.exp_avg_sq = ParameterTuple(self.exp_avg_sq)

    @property
    def supports_memory_efficient_fp16(self):
        return True

    @property
    def supports_flat_params(self):
        return False

    def construct(self, gradients):
        lr = self.get_lr()
        step = F.assign_add(self.step, 1)
        success = self.hyper_map(F.partial(_adam_opt, self.eps, self.clip_threshold, self.decay_rate,
                                           self.beta1, self.weight_decay, self.scale_lr,
                                           self.scale_parameter, self.relative_step,
                                           self.warmup_init, self.compression, self.use_first_moment,
                                           self.weight_decay_flag, lr, step),
                                 gradients, self.parameters, self.exp_avg, self.exp_avg_sq_row,
                                 self.exp_avg_sq_col, self.exp_avg_sq)

        return success

    @Optimizer.target.setter
    def target(self, value):
        """
        If the input value is set to "CPU", the parameters will be updated on the host using the Fused
        optimizer operation.
        """
        self._set_base_target(value)