• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""AdamWeightDecayForBert, a customized Adam for bert. Input: gradient, overflow flag."""
16import numpy as np
17
18from mindspore.common import dtype as mstype
19from mindspore.ops import operations as P
20from mindspore.ops import composite as C
21from mindspore.ops import functional as F
22from mindspore.common.tensor import Tensor
23from mindspore._checkparam import Validator as validator
24from mindspore._checkparam import Rel
25from mindspore.nn.optim.optimizer import Optimizer
26
# Type-dispatched function graph named "adam_opt"; the `@_adam_opt.register(...)` overloads
# defined in this module are attached to it together with their argument-type signatures.
_adam_opt = C.MultitypeFuncGraph("adam_opt")
# Constant tensors (int32 one, float32 ten) used by the Nesterov branch of `_run_opt_with_sparse`.
_scaler_one = Tensor(1, mstype.int32)
_scaler_ten = Tensor(10, mstype.float32)
30
@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
                    "Tensor", "Bool", "Bool")
def _update_run_kernel(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, decay_flags, optim_filter):
    """
    Apply one step of the fused `P.AdamWeightDecay` operator to a single parameter.

    Args:
        beta1 (Tensor): Decay rate for the 1st moment estimations.
        beta2 (Tensor): Decay rate for the 2nd moment estimations.
        eps (Tensor): Term added to the denominator.
        lr (Tensor): Learning rate.
        weight_decay (Number): Weight decay, used only when `decay_flags` is True.
        param (Tensor): Parameter to update.
        m (Tensor): 1st moment of the parameter.
        v (Tensor): 2nd moment of the parameter.
        gradient (Tensor): Gradient of the parameter.
        decay_flags (bool): Whether `weight_decay` is passed to the operator (0.0 is passed otherwise).
        optim_filter (bool): Whether the parameter is updated at all.

    Returns:
        The output of the `P.AdamWeightDecay` call when `optim_filter` is True, otherwise `gradient`.
    """
    # Parameters filtered out of optimization are skipped; their gradient is handed back as-is.
    if not optim_filter:
        return gradient

    fused_adam = P.AdamWeightDecay()
    if decay_flags:
        return fused_adam(param, m, v, lr, beta1, beta2, eps, weight_decay, gradient)
    # No decay for this parameter: run the same operator with a zero decay coefficient.
    return fused_adam(param, m, v, lr, beta1, beta2, eps, 0.0, gradient)
45
@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
                    "Tensor", "Bool", "Bool")
def _update_run_op(beta1, beta2, eps, lr, overflow, weight_decay, param, m, v, gradient, decay_flag, optim_filter):
    """
    Update parameters with an overflow-aware Adam weight decay step composed of basic ops.

    All arithmetic is done on float32 copies of `param`, `m`, `v` and `gradient`. The results are then
    assigned back to `param`, `m` and `v`, cast to each tensor's own dtype. Where the overflow condition
    is true, the step applied to `param` is replaced by zeros, so `param` keeps its value.

    Args:
        beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
        beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
        eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
        lr (Tensor): Learning rate.
        overflow (Tensor): Whether overflow occurs. It is reshaped to a scalar before being used.
        weight_decay (Number): Weight decay. Should be equal to or greater than 0.
        param (Tensor): Parameters.
        m (Tensor): m value of parameters.
        v (Tensor): v value of parameters.
        gradient (Tensor): Gradient of parameters.
        decay_flag (bool): Applies weight decay or not.
        optim_filter (bool): Applies parameter update or not.

    Returns:
        Tensor, the updated value of `param` (cast to the dtype of `param`) when `optim_filter` is True;
        otherwise `gradient` is returned unchanged.
    """
    if optim_filter:
        op_mul = P.Mul()
        op_square = P.Square()
        op_sqrt = P.Sqrt()
        op_cast = P.Cast()
        op_reshape = P.Reshape()
        op_shape = P.Shape()
        op_select = P.Select()

        # Work on float32 copies regardless of the storage dtype of the inputs.
        param_fp32 = op_cast(param, mstype.float32)
        m_fp32 = op_cast(m, mstype.float32)
        v_fp32 = op_cast(v, mstype.float32)
        gradient_fp32 = op_cast(gradient, mstype.float32)

        # Expand the scalar overflow flag to a boolean mask with the shape of `m`, as required by Select.
        cond = op_cast(F.fill(mstype.int32, op_shape(m_fp32), 1) * op_reshape(overflow, (())), mstype.bool_)
        # Without overflow: next_m = beta1 * m + (1 - beta1) * g.
        # With overflow the second term is `m_fp32` itself, i.e. next_m = beta1 * m + m.
        # NOTE(review): AdamWeightDecayForBert.construct passes beta1 = beta2 = 1.0 on overflow, which makes
        # this evaluate to 2 * m (and 2 * v below) - confirm that the moments are meant to change here.
        next_m = op_mul(beta1, m_fp32) + op_select(cond, m_fp32,\
                op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32))

        # Same structure for the 2nd moment, using the squared gradient.
        next_v = op_mul(beta2, v_fp32) + op_select(cond, v_fp32,\
                op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta2, op_square(gradient_fp32)))

        update = next_m / (eps + op_sqrt(next_v))
        if decay_flag:
            # Weight decay is added to the update itself rather than to the gradient.
            update = op_mul(weight_decay, param_fp32) + update

        update_with_lr = op_mul(lr, update)
        zeros = F.fill(mstype.float32, op_shape(param_fp32), 0)
        # On overflow subtract zeros instead of the computed step, leaving the parameter value unchanged.
        next_param = param_fp32 - op_select(cond, zeros, op_reshape(update_with_lr, op_shape(param_fp32)))

        # Write the new values back in place; F.depend ties each assign to the returned tensor.
        next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
        next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m))))
        next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))

        return op_cast(next_param, F.dtype(param))
    return gradient
104
105
@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
                    "Tensor", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power,
                         beta2_power, beta1, beta2, eps, lr, gradient, param, m, v, ps_parameter, cache_enable):
    """
    Apply sparse adam optimizer to the weight parameter when the gradient is sparse.

    Three paths are taken depending on the flags:

    - `ps_parameter and not cache_enable`: the hyper-parameters and the gradient values/indices are handed
      to `push`, and its result is passed to `pull` together with `param`.
    - `target` is False: the update is delegated to `sparse_opt`.
    - otherwise: the update is composed from Mul/Square/Sqrt/ScatterAdd ops and assigned to
      `param`, `m` and `v`.

    Args:
        opt (Function): Not used by this overload.
        sparse_opt (Function): Callable performing the sparse update when `target` is False.
        push (Function): Callable used on the `ps_parameter` path.
        pull (Function): Callable used on the `ps_parameter` path.
        use_locking (bool): Forwarded to `P.ScatterAdd`.
        use_nesterov (bool): Selects the Nesterov form of the composed update.
        target (bool): Chooses between `sparse_opt` (False) and the composed update (True).
        beta1_power (Tensor): Used in the bias-corrected learning rate.
        beta2_power (Tensor): Used in the bias-corrected learning rate.
        beta1 (Tensor): Decay rate for the 1st moment estimations.
        beta2 (Tensor): Decay rate for the 2nd moment estimations.
        eps (Tensor): Term added to the denominator.
        lr (Tensor): Learning rate.
        gradient (RowTensor): Sparse gradient; its `indices` and `values` are read.
        param (Tensor): Parameter to update.
        m (Tensor): 1st moment of the parameter.
        v (Tensor): 2nd moment of the parameter.
        ps_parameter (bool): See the first path above.
        cache_enable (bool): See the first path above.

    Returns:
        `success` (initialised to True) with the update operations attached through `F.depend`.
    """
    success = True
    indices = gradient.indices
    values = gradient.values
    if ps_parameter and not cache_enable:
        op_shape = P.Shape()
        # Shapes of every tensor involved in the update are passed to `push` alongside the values.
        shapes = (op_shape(param), op_shape(m), op_shape(v),
                  op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1),
                  op_shape(beta2), op_shape(eps), op_shape(values), op_shape(indices))
        success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2,
                                               eps, values, indices), shapes), param))
        return success

    if not target:
        success = F.depend(success, sparse_opt(param, m, v, beta1_power, beta2_power, lr, beta1, beta2,
                                               eps, values, indices))
    else:
        op_mul = P.Mul()
        op_square = P.Square()
        op_sqrt = P.Sqrt()
        scatter_add = P.ScatterAdd(use_locking)

        # Decay both moments in place first; the gradient contribution is scattered in afterwards.
        F.assign(m, op_mul(beta1, m))
        F.assign(v, op_mul(beta2, v))

        # Same tensors as `indices` / `values` read at the top of the function.
        grad_indices = gradient.indices
        grad_value = gradient.values

        # Add (1 - beta1) * g to the rows of m selected by the gradient indices.
        next_m = scatter_add(m,
                             grad_indices,
                             op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value))

        # Add (1 - beta2) * g^2 to the rows of v selected by the gradient indices.
        next_v = scatter_add(v,
                             grad_indices,
                             op_mul(F.tuple_to_array((1.0,)) - beta2, op_square(grad_value)))

        if use_nesterov:
            # NOTE(review): `m_temp` holds next_m scaled by `_scaler_ten` and is divided by the same
            # constant when written back below; presumably this preserves next_m as a separate value
            # while `m` is overwritten for the look-ahead computation - confirm.
            m_temp = next_m * _scaler_ten
            F.assign(m, op_mul(beta1, next_m))
            # Multiplying the indices by `_scaler_one` (int32 one) leaves their values unchanged.
            div_value = scatter_add(m,
                                    op_mul(grad_indices, _scaler_one),
                                    op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value))
            param_update = div_value / (op_sqrt(next_v) + eps)

            F.assign(m, m_temp / _scaler_ten)


        else:
            param_update = next_m / (op_sqrt(next_v) + eps)

        # Bias-corrected step size.
        lr_t = lr * op_sqrt(1 - beta2_power) / (1 - beta1_power)

        next_param = param - lr_t * param_update



        # Attach the in-place writes to the returned flag.
        success = F.depend(success, F.assign(param, next_param))
        success = F.depend(success, F.assign(m, next_m))
        success = F.depend(success, F.assign(v, next_v))

    return success
171
172
@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
                    "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
def _run_opt_with_one_number(opt, sparse_opt, push, pull, use_locking, use_nesterov, target,
                             beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, param,
                             moment1, moment2, ps_parameter, cache_enable):
    """
    Run the dense Adam update for one parameter whose gradient is a Tensor.

    When `ps_parameter` is set and `cache_enable` is not, the hyper-parameters and gradient are handed to
    `push` (with the shapes of `param`, `moment1` and `moment2`) and the result is passed to `pull`
    together with `param`. Otherwise `opt` is called directly. `sparse_opt`, `use_locking`,
    `use_nesterov` and `target` are not used by this overload.

    Returns:
        True, with the executed update attached through `F.depend`.
    """
    done = True
    if not ps_parameter or cache_enable:
        # Local update through the fused optimizer op.
        return F.depend(done, opt(param, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
                                  eps, gradient))

    shape_of = P.Shape()
    pushed_inputs = (beta1_power, beta2_power, lr, beta1, beta2, eps, gradient)
    state_shapes = (shape_of(param), shape_of(moment1), shape_of(moment2))
    return F.depend(done, pull(push(pushed_inputs, state_shapes), param))
188
189
@_adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
                    "Tensor", "Tensor")
def _run_off_load_opt(opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, param, moment1, moment2):
    """
    Run the AdamOffload update for one parameter whose gradient is a Tensor.

    `opt` is called with the moments, hyper-parameters and gradient; its result is added onto `param`
    in place with `F.assign_add`.

    Returns:
        True, with the in-place addition attached through `F.depend`.
    """
    delta_param = opt(moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2, eps, gradient)
    return F.depend(True, F.assign_add(param, delta_param))
198
199
def _check_param_value(beta1, beta2, eps, prim_name):
    """
    Validate the Adam hyper-parameters.

    `beta1`, `beta2` and `eps` must all be floats; both betas must lie strictly between 0.0 and 1.0 and
    `eps` must be positive. The checks are delegated to `validator`, with `prim_name` passed through for
    its messages.
    """
    named_betas = (("beta1", beta1), ("beta2", beta2))

    # Type checks first, for all three values, before any range check runs.
    for arg_name, arg_value in named_betas + (("eps", eps),):
        validator.check_value_type(arg_name, arg_value, [float], prim_name)

    for arg_name, arg_value in named_betas:
        validator.check_float_range(arg_value, 0.0, 1.0, Rel.INC_NEITHER, arg_name, prim_name)

    validator.check_positive_float(eps, "eps", prim_name)
208
class AdamWeightDecayForBert(Optimizer):
    """
    Implements the Adam algorithm to fix the weight decay, taking an additional overflow flag as input.

    Note:
        When separating parameter groups, the weight decay in each group will be applied on the parameters if the
        weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
        on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.

        To improve parameter groups performance, the customized order of parameters can be supported.

    Args:
        params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
            the element in `params` must be class `Parameter`. When the `params` is a list of `dict`, the "params",
            "lr", "weight_decay" and "order_params" are the keys that can be parsed.

            - params: Required. The value must be a list of `Parameter`.

            - lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used.
              If not, the `learning_rate` in the API will be used.

            - weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay
              will be used. If not, the `weight_decay` in the API will be used.

            - order_params: Optional. If "order_params" is in the keys, the value must be the order of parameters and
              the order will be followed in the optimizer. There are no other keys in the `dict` and the parameters
              which are in the 'order_params' must be in one of the group parameters.

        learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
            When the learning_rate is an Iterable or a Tensor in a 1D dimension, use the dynamic learning rate, then
            the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule,
            use dynamic learning rate, the i-th learning rate will be calculated during the process of training
            according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor in a zero
            dimension, use fixed learning rate. Other cases are not supported. The float learning rate must be
            equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float.
            Default: 1e-3.
        beta1 (float): The exponential decay rate for the 1st moment estimations. Default: 0.9.
            Should be in range (0.0, 1.0).
        beta2 (float): The exponential decay rate for the 2nd moment estimations. Default: 0.999.
            Should be in range (0.0, 1.0).
        eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
            Should be greater than 0.
        weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
        - **overflow** (Tensor) - The overflow flag from dynamic loss scaling. It is reshaped to a scalar
          inside `construct`.

    Outputs:
        tuple, one element per parameter, as returned by the `_adam_opt` overload that handled it.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> net = Net()
        >>> #1) All parameters use the same learning rate and weight decay
        >>> optim = AdamWeightDecayForBert(params=net.trainable_params())
        >>>
        >>> #2) Use parameter groups and set different values
        >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
        >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
        >>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
        ...                 {'params': no_conv_params, 'lr': 0.01},
        ...                 {'order_params': net.trainable_params()}]
        >>> optim = AdamWeightDecayForBert(group_params, learning_rate=0.1, weight_decay=0.0)
        >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01.
        >>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0.
        >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
        >>>
        >>> # Unlike the standard optimizers, this one is called with two inputs:
        >>> # optim(gradients, overflow)
    """
    def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
        super(AdamWeightDecayForBert, self).__init__(learning_rate, params, weight_decay)
        _check_param_value(beta1, beta2, eps, self.cls_name)
        # Hyper-parameters are kept as single-element float32 tensors.
        self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
        self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
        self.eps = Tensor(np.array([eps]).astype(np.float32))
        # One zero-initialised 1st and 2nd moment per optimized parameter.
        self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros')
        self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros')
        self.hyper_map = C.HyperMap()
        self.op_select = P.Select()
        self.op_cast = P.Cast()
        self.op_reshape = P.Reshape()
        self.op_shape = P.Shape()

    def construct(self, gradients, overflow):
        """Apply one optimizer step to every parameter, given its gradients and the overflow flag."""
        lr = self.get_lr()
        # Boolean mask with the shape of `self.beta1`, true when the (scalar-reshaped) overflow flag is non-zero.
        cond = self.op_cast(F.fill(mstype.int32, self.op_shape(self.beta1), 1) *\
                            self.op_reshape(overflow, (())), mstype.bool_)
        # On overflow both decay rates are replaced by 1.0.
        beta1 = self.op_select(cond, self.op_cast(F.tuple_to_array((1.0,)), mstype.float32), self.beta1)
        beta2 = self.op_select(cond, self.op_cast(F.tuple_to_array((1.0,)), mstype.float32), self.beta2)
        # NOTE(review): only the `is_group and not is_group_lr` branch below forwards `overflow` and the
        # selected `beta1`/`beta2`. The other two branches pass `self.beta1`/`self.beta2` and no overflow
        # flag, so they appear to dispatch to an overload without overflow handling - confirm this is intended.
        if self.is_group:
            if self.is_group_lr:
                optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps),
                                              lr, self.weight_decay, self.parameters, self.moments1, self.moments2,
                                              gradients, self.decay_flags, self.optim_filter)
            else:
                optim_result = self.hyper_map(F.partial(_adam_opt, beta1, beta2, self.eps, lr, overflow),
                                              self.weight_decay, self.parameters, self.moments1, self.moments2,
                                              gradients, self.decay_flags, self.optim_filter)
        else:
            optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, self.weight_decay),
                                          self.parameters, self.moments1, self.moments2,
                                          gradients, self.decay_flags, self.optim_filter)
        if self.use_parallel:
            self.broadcast_params(optim_result)
        return optim_result
319
class AdamWeightDecayOp(Optimizer):
    """
    Adam optimizer with weight decay fix, executed through the `_adam_opt` overloads of this module.

    Note:
        With parameter groups, each group's weight decay is applied to its parameters when that weight decay
        is positive. Without parameter groups, the `weight_decay` argument is applied to every parameter whose
        name contains neither 'beta' nor 'gamma', provided `weight_decay` is positive.

        A customized parameter order is supported to improve the performance of parameter groups.

    Args:
        params (Union[list[Parameter], list[dict]]): Either a list of `Parameter` objects to update, or a list
            of `dict` describing parameter groups. The recognised keys of a group `dict` are:

            - params: Required. A list of `Parameter`.

            - lr: Optional. Learning rate for the group; falls back to `learning_rate` when absent.

            - weight_decay: Optional. Weight decay for the group; falls back to `weight_decay` when absent.

            - order_params: Optional. The order in which the optimizer handles the parameters. A `dict` with
              this key holds no other key, and every parameter listed must belong to one of the groups.

        learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): Learning rate value or schedule.
            An Iterable or a 1D Tensor gives a dynamic learning rate where step i uses the i-th value.
            A LearningRateSchedule gives a dynamic learning rate computed from its formula at each step.
            A float or a zero-dimensional Tensor gives a fixed learning rate. Nothing else is supported.
            A float learning rate must be equal to or greater than 0; an int is converted to float.
            Default: 1e-3.
        beta1 (float): Exponential decay rate of the 1st moment estimations, in range (0.0, 1.0).
            Default: 0.9.
        beta2 (float): Exponential decay rate of the 2nd moment estimations, in range (0.0, 1.0).
            Default: 0.999.
        eps (float): Positive term added to the denominator for numerical stability. Default: 1e-6.
        weight_decay (float): Weight decay (L2 penalty), equal to or greater than 0. Default: 0.0.

    Inputs:
        - **gradients** (tuple[Tensor]) - Gradients of `params`, with the same shapes as `params`.

    Outputs:
        tuple, one element per parameter, as returned by the `_adam_opt` overload that handled it.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> net = Net()
        >>> #1) All parameters use the same learning rate and weight decay
        >>> optim = AdamWeightDecayOp(params=net.trainable_params())
        >>>
        >>> #2) Use parameter groups and set different values
        >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
        >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
        >>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
        ...                 {'params': no_conv_params, 'lr': 0.01},
        ...                 {'order_params': net.trainable_params()}]
        >>> optim = AdamWeightDecayOp(group_params, learning_rate=0.1, weight_decay=0.0)
        >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01.
        >>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0.
        >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
        >>>
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> model = Model(net, loss_fn=loss, optimizer=optim)
    """
    def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
        super().__init__(learning_rate, params, weight_decay)
        _check_param_value(beta1, beta2, eps, self.cls_name)
        # Hyper-parameters live as single-element float32 tensors.
        self.beta1 = Tensor(np.array([beta1], dtype=np.float32))
        self.beta2 = Tensor(np.array([beta2], dtype=np.float32))
        self.eps = Tensor(np.array([eps], dtype=np.float32))
        # Zero-initialised moment buffers mirroring the optimized parameters.
        self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros')
        self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros')
        self.hyper_map = C.HyperMap()

    def construct(self, gradients):
        """Apply one optimizer step to every parameter, given its gradients."""
        lr = self.get_lr()
        if not self.is_group:
            # Single learning rate and weight decay shared by all parameters.
            optim_result = self.hyper_map(
                F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, self.weight_decay),
                self.parameters, self.moments1, self.moments2,
                gradients, self.decay_flags, self.optim_filter)
        elif self.is_group_lr:
            # Learning rate and weight decay are both mapped per parameter.
            optim_result = self.hyper_map(
                F.partial(_adam_opt, self.beta1, self.beta2, self.eps),
                lr, self.weight_decay, self.parameters, self.moments1, self.moments2,
                gradients, self.decay_flags, self.optim_filter)
        else:
            # Shared learning rate, weight decay mapped per parameter.
            optim_result = self.hyper_map(
                F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr),
                self.weight_decay, self.parameters, self.moments1, self.moments2,
                gradients, self.decay_flags, self.optim_filter)
        if self.use_parallel:
            self.broadcast_params(optim_result)
        return optim_result
421