# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Lamb optimizer."""
import numpy as np
from mindspore import context
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from .optimizer import Optimizer
from .optimizer import opt_init_args_register
from .. import layer


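# One-element tensor holding 1.0, reused for the (1 - beta) terms and the global-step increment.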
num_one = Tensor(np.ones([1]), mstype.float32)

_lamb_opt = C.MultitypeFuncGraph("lamb_opt")


@_lamb_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
                    "Tensor", "Bool", "Bool")
def _update_run_op(beta1, beta2, eps, global_step, lr, weight_decay, param, m, v, gradient, decay_flag, optim_filter):
    """
    Update parameters.

    Args:
        beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
        beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
        eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
        global_step (Tensor): Current global step.
        lr (Tensor): Learning rate.
        weight_decay (numbers.Number): Weight decay. Should be equal to or greater than 0.
        param (Tensor): Parameter to be updated.
        m (Tensor): First moment estimate of the parameter.
        v (Tensor): Second moment estimate of the parameter.
        gradient (Tensor): Gradient of the parameter.
        decay_flag (bool): Whether to apply weight decay when updating the parameter.
        optim_filter (bool): Whether to apply the parameter update.

    Returns:
        Tensor, the new value of the parameter if `optim_filter` is True, otherwise the original gradient.
    """
    if optim_filter:
        op_mul = P.Mul()
        op_sqrt = P.Sqrt()
        op_rsqrt = P.Rsqrt()
        op_square = P.Square()
        op_cast = P.Cast()
        op_reshape = P.Reshape()
        op_shape = P.Shape()
        op_pow = P.Pow()
        op_norm = layer.Norm()
        op_select = P.Select()
        op_greater = P.Greater()
        op_fill = P.Fill()
        op_dtype = P.DType()

        param_fp32 = op_cast(param, mstype.float32)
        m_fp32 = op_cast(m, mstype.float32)
        v_fp32 = op_cast(v, mstype.float32)
        gradient_fp32 = op_cast(gradient, mstype.float32)

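        # Adam-style biased moment updates: m_t = beta1 * m + (1 - beta1) * g and
        # v_t = beta2 * v + (1 - beta2) * g^2.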
        next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(num_one, mstype.float32) - beta1, gradient_fp32)

        next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(num_one, mstype.float32) - beta2, op_square(gradient_fp32))

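        # Bias correction: divide by 1 - beta^(t + 1), where t is the zero-based global step.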
        next_mm = next_m / (op_cast(num_one, mstype.float32)
                            - op_pow(beta1, op_cast(global_step + num_one, mstype.float32)))
        next_vv = next_v / (op_cast(num_one, mstype.float32) -
                            op_pow(beta2, op_cast(global_step + num_one, mstype.float32)))
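        # Layer-wise trust ratio ||w|| / ||m_hat * rsqrt(v_hat + eps) + weight_decay * w||; it falls
        # back to 1 when the weight or gradient norm is zero and is clipped to [0, 10].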
        w_norm = op_norm(param_fp32)
        g_norm = op_norm(gradient_fp32)

        g_norm_hat = op_norm(op_mul(next_mm, op_rsqrt(next_vv + eps)) + weight_decay * param_fp32)
        zeros = F.zeros_like(w_norm)
        ones = op_fill(op_dtype(w_norm), op_shape(w_norm), 1.0)
        trust_ratio = op_select(
            op_greater(w_norm, zeros),
            op_select(op_greater(g_norm, zeros), w_norm / g_norm_hat, ones),
            ones)
        tens = op_fill(op_dtype(trust_ratio), op_shape(trust_ratio), 10.0)
        trust_ratio = C.clip_by_value(trust_ratio, zeros, tens)
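        # Adam-style update direction; weight decay is added below when decay_flag is set.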
        update = next_mm / (op_sqrt(next_vv) + eps)

        if decay_flag:
            update = update + op_mul(weight_decay, param_fp32)

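        # Scale the update by the trust ratio and the learning rate, then take the step.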
        update_with_lr = op_mul(op_mul(trust_ratio, lr), update)

        next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))

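        # Write the new parameter and moments back in their original dtypes; F.depend ties the
        # assignments to the returned value so they are executed.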
        next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
        next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m))))
        next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))

        return op_cast(next_param, F.dtype(param))
    return gradient

_lamb_opt_ascend = C.MultitypeFuncGraph("lamb_opt_ascend")


@_lamb_opt_ascend.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
                           "Tensor", "Bool", "Bool")
def _update_run_op_ascend(beta1, beta2, eps, global_step, lr, weight_decay, param, m, v, gradient, decay_flag,
                          optim_filter):
    """
    Update parameters when the device target is Ascend.

    Args:
        beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
        beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
        eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
        global_step (Tensor): Current global step.
        lr (Tensor): Learning rate.
        weight_decay (numbers.Number): Weight decay. Should be equal to or greater than 0.
        param (Tensor): Parameter to be updated.
        m (Tensor): First moment estimate of the parameter.
        v (Tensor): Second moment estimate of the parameter.
        gradient (Tensor): Gradient of the parameter.
        decay_flag (bool): Whether to apply weight decay when updating the parameter.
        optim_filter (bool): Whether to apply the parameter update.

    Returns:
        Tensor, the computed update if `optim_filter` is True, otherwise the original gradient.
    """
    if optim_filter:
        op_cast = P.Cast()
        op_norm = layer.Norm()
        op_lamb_apply_optimizer_assign = P.LambApplyOptimizerAssign()
        op_lamb_apply_weight_assign = P.LambApplyWeightAssign()

        param_fp32 = op_cast(param, mstype.float32)
        gradient_fp32 = op_cast(gradient, mstype.float32)
        new_global_step = op_cast(global_step + num_one, mstype.float32)
        weight_decay_flag = op_cast(decay_flag, mstype.float32)

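        # Fused Ascend kernels: LambApplyOptimizerAssign computes the Adam-style update from m, v and
        # the gradient, and LambApplyWeightAssign applies the layer-wise (trust-ratio) scaled step to param.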
        update, _, _ = op_lamb_apply_optimizer_assign(gradient_fp32, v, m, param_fp32,
                                                      beta1, 1.0 - beta1, beta2, 1.0 - beta2, eps,
                                                      new_global_step, weight_decay_flag, weight_decay)
        w_norm = op_norm(param_fp32)
        g_norm = op_norm(update)
        update = F.depend(update, op_lamb_apply_weight_assign(w_norm, g_norm, lr, update, param))
        return update
    return gradient


def _check_param_value(beta1, beta2, eps, prim_name):
    """Check the types and value ranges of beta1, beta2 and eps."""
    validator.check_value_type("beta1", beta1, [float], prim_name)
    validator.check_value_type("beta2", beta2, [float], prim_name)
    validator.check_value_type("eps", eps, [float], prim_name)
    validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name)
    validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name)
    validator.check_positive_float(eps, "eps", prim_name)


class Lamb(Optimizer):
    r"""
    LAMB (Layer-wise Adaptive Moments optimizer for Batch training) with dynamic learning rate.

    LAMB is an optimization algorithm employing a layerwise adaptive large batch optimization technique.
    Refer to the paper `LARGE BATCH OPTIMIZATION FOR DEEP LEARNING: TRAINING BERT IN 76
    MINUTES <https://arxiv.org/abs/1904.00962>`_.

    The LAMB optimizer aims to increase the training batch size without reducing the accuracy,
    and it supports adaptive element-wise updating and accurate layer-wise correction.

    The parameter update follows:

    ..  math::
        \begin{gather*}
        m_t = \beta_1 m_{t - 1} + (1 - \beta_1) g_t\\
        v_t = \beta_2 v_{t - 1} + (1 - \beta_2) g_t^2\\
        \hat{m}_t = \frac{m_t}{1 - \beta_1^t}\\
        \hat{v}_t = \frac{v_t}{1 - \beta_2^t}\\
        r_t = \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}\\
        w_t = w_{t-1} - \eta_t \frac{\| w_{t-1} \|}{\| r_t + \lambda w_{t-1} \|} (r_t + \lambda w_{t-1})
        \end{gather*}

    where :math:`m` is the 1st moment, :math:`v` the 2nd moment, :math:`g_t` the gradient at step :math:`t`,
    :math:`\eta` the learning rate, and :math:`\lambda` the LAMB weight decay rate.

    Note:
        When separating parameter groups, the weight decay in each group will be applied on the parameters if the
        weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
        on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.

        When separating parameter groups, set `grad_centralization` to True to centralize the gradient. Gradient
        centralization can only be applied to the parameters of convolution layers; if it is set to True for
        parameters of non-convolution layers, an error will be reported.

        To improve the performance of parameter groups, a customized order of parameters is supported.

        There is usually no connection between an optimizer and mixed precision. However, when `FixedLossScaleManager`
        is used and `drop_overflow_update` in `FixedLossScaleManager` is set to False, the optimizer needs to handle
        the `loss_scale`. As this optimizer has no `loss_scale` argument, `loss_scale` needs to be processed by other
        means; refer to the document `LossScale
        <https://www.mindspore.cn/docs/programming_guide/zh-CN/r1.5/lossscale.html>`_ to process `loss_scale`
        correctly.

    Args:
        params (Union[list[Parameter], list[dict]]): When `params` is a list of `Parameter` that will be updated,
            each element must be a `Parameter`. When `params` is a list of `dict`, the keys "params",
            "lr", "weight_decay", "order_params" and "grad_centralization" can be parsed.

            - params: Required. The value must be a list of `Parameter`.

            - lr: Optional. If "lr" is in the keys, the corresponding learning rate will be used.
              If not, the `learning_rate` in the API will be used.

            - weight_decay: Optional. If "weight_decay" is in the keys, the corresponding weight decay
              will be used. If not, the `weight_decay` in the API will be used.

            - order_params: Optional. If "order_params" is in the keys, the value must be the order of parameters and
              this order will be followed in the optimizer. There are no other keys in this `dict`, and the parameters
              listed in 'order_params' must be in one of the group parameters.

            - grad_centralization: Optional. The data type of "grad_centralization" is Bool. If "grad_centralization"
              is in the keys, the set value will be used. If not, `grad_centralization` is False by default.
              This parameter only works on convolution layers.

        learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
            When the learning_rate is an Iterable or a one-dimensional Tensor, the learning rate is dynamic, and
            the i-th step takes the i-th value as the learning rate. When the learning_rate is a LearningRateSchedule,
            the learning rate is dynamic and the i-th learning rate is calculated during training according to the
            formula of the LearningRateSchedule. When the learning_rate is a float or a zero-dimensional Tensor, the
            learning rate is fixed. Other cases are not supported. The float learning rate must be
            equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float.
        beta1 (float): The exponential decay rate for the 1st moment estimations. Default: 0.9.
            Should be in range (0.0, 1.0).
        beta2 (float): The exponential decay rate for the 2nd moment estimations. Default: 0.999.
            Should be in range (0.0, 1.0).
        eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
            Should be greater than 0.
        weight_decay (float): Weight decay (L2 penalty). Default: 0.0. Should be equal to or greater than 0.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

    Outputs:
        tuple[bool], all elements are True.

    Raises:
        TypeError: If `learning_rate` is not one of int, float, Tensor, Iterable, LearningRateSchedule.
        TypeError: If element of `parameters` is neither Parameter nor dict.
        TypeError: If `beta1`, `beta2` or `eps` is not a float.
        TypeError: If `weight_decay` is neither float nor int.
        ValueError: If `eps` is less than or equal to 0.
        ValueError: If `beta1` or `beta2` is not in range (0.0, 1.0).
        ValueError: If `weight_decay` is less than 0.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> net = Net()
        >>> #1) All parameters use the same learning rate and weight decay
        >>> optim = nn.Lamb(params=net.trainable_params(), learning_rate=0.1)
        >>>
        >>> #2) Use parameter groups and set different values
        >>> poly_decay_lr = learning_rate_schedule.PolynomialDecayLR(learning_rate=0.1, end_learning_rate=0.01,
        ...                                                          decay_steps=4, power=0.5)
        >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
        >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
        >>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'grad_centralization': True},
        ...                 {'params': no_conv_params, 'lr': poly_decay_lr},
        ...                 {'order_params': net.trainable_params()}]
        >>> optim = nn.Lamb(group_params, learning_rate=0.1, weight_decay=0.0)
        >>> # The parameters in conv_params will use the default learning rate of 0.1, a weight decay of 0.01 and
        >>> # grad centralization of True.
        >>> # The parameters in no_conv_params will use the dynamic polynomial decay learning rate, the default
        >>> # weight decay of 0.0 and grad centralization of False.
        >>> # The order in which the optimizer updates the parameters follows the value of 'order_params'.
        >>>
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> model = Model(net, loss_fn=loss, optimizer=optim)
    """

    @opt_init_args_register
    def __init__(self, params, learning_rate, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
        super(Lamb, self).__init__(learning_rate, params, weight_decay)
        _check_param_value(beta1, beta2, eps, self.cls_name)

        # Keep beta1, beta2 and eps as 1-element tensors; convert them to scalars once
        # scalar/tensor mixed operations are supported.
        self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
        self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
        self.eps = Tensor(np.array([eps]).astype(np.float32))
        self.params = self.parameters
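        # Moment accumulators m and v, one per parameter, initialized to zeros.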
        self.moments1 = self.params.clone(prefix="lamb_m", init='zeros')
        self.moments2 = self.params.clone(prefix="lamb_v", init='zeros')

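        # Maintain a local global_step for the bias-correction term when the learning rate is static.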
        if not self.dynamic_lr:
            self.global_step = Parameter(initializer(0, [1]), name='global_step')
            self.assignadd = P.AssignAdd()
        self.device_ascend = context.get_context("device_target") == "Ascend"

    def construct(self, gradients):
        lr = self.get_lr()
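        # On Ascend, use the fused-kernel implementation; other backends use the composite implementation.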
        lamb_opt = _lamb_opt_ascend if self.device_ascend else _lamb_opt
        gradients = self.gradients_centralization(gradients)
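        # Shared scalars are bound once with F.partial; per-group learning rates and weight decays are
        # passed as mapped sequences when parameter groups are used.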
        if self.is_group:
            if self.is_group_lr:
                optim_result = self.hyper_map(F.partial(lamb_opt, self.beta1, self.beta2, self.eps,
                                                        self.global_step),
                                              lr, self.weight_decay, self.params, self.moments1, self.moments2,
                                              gradients, self.decay_flags, self.optim_filter)
            else:
                optim_result = self.hyper_map(F.partial(lamb_opt, self.beta1, self.beta2, self.eps,
                                                        self.global_step, lr),
                                              self.weight_decay, self.params, self.moments1, self.moments2,
                                              gradients, self.decay_flags, self.optim_filter)
        else:
            optim_result = self.hyper_map(F.partial(lamb_opt, self.beta1, self.beta2, self.eps,
                                                    self.global_step, lr, self.weight_decay),
                                          self.params, self.moments1, self.moments2, gradients,
                                          self.decay_flags, self.optim_filter)

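        # In parallel training mode, broadcast the updated parameters across devices.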
        if self.use_parallel:
            optim_result = F.depend(optim_result, self.broadcast_params(optim_result))

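        # With a static learning rate, advance the locally maintained global step.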
        if not self.dynamic_lr:
            optim_result = F.depend(optim_result, self.assignadd(self.global_step, 1))

        return optim_result