# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

16"""AdamWeightDecayForBert, a customized Adam for bert. Input: gradient, overflow flag."""
import numpy as np

from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.common.tensor import Tensor
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore.nn.optim.optimizer import Optimizer

_adam_opt = C.MultitypeFuncGraph("adam_opt")
_scaler_one = Tensor(1, mstype.int32)
_scaler_ten = Tensor(10, mstype.float32)

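# Each overload registered on `_adam_opt` below is selected by C.MultitypeFuncGraph according to
# the types of the arguments supplied through HyperMap / F.partial in the optimizers' `construct`.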
@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
                    "Tensor", "Bool", "Bool")
def _update_run_kernel(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, decay_flags, optim_filter):
    """
    Update parameters by the fused AdamWeightDecay op.
    """
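    # For reference, the fused AdamWeightDecay op performs the same update that `_update_run_op`
    # below spells out explicitly (overflow handling aside):
    #   m <- beta1 * m + (1 - beta1) * grad
    #   v <- beta2 * v + (1 - beta2) * grad^2
    #   param <- param - lr * (m / (sqrt(v) + eps) + weight_decay * param)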
    if optim_filter:
        adam = P.AdamWeightDecay()
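        # Parameters whose decay flag is False are updated with weight_decay forced to 0.0.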
        if decay_flags:
            next_param = adam(param, m, v, lr, beta1, beta2, eps, weight_decay, gradient)
        else:
            next_param = adam(param, m, v, lr, beta1, beta2, eps, 0.0, gradient)
        return next_param
    return gradient

@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
                    "Tensor", "Bool", "Bool")
def _update_run_op(beta1, beta2, eps, lr, overflow, weight_decay, param, m, v, gradient, decay_flag, optim_filter):
    """
    Update parameters.

    Args:
        beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
        beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
        eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
        lr (Tensor): Learning rate.
        overflow (Tensor): Whether overflow occurs.
        weight_decay (Number): Weight decay. Should be equal to or greater than 0.
        param (Tensor): Parameters.
        m (Tensor): m value of parameters.
        v (Tensor): v value of parameters.
        gradient (Tensor): Gradient of parameters.
        decay_flag (bool): Applies weight decay or not.
        optim_filter (bool): Applies parameter update or not.

    Returns:
        Tensor, the new value of the parameter after updating.
    """
    if optim_filter:
        op_mul = P.Mul()
        op_square = P.Square()
        op_sqrt = P.Sqrt()
        op_cast = P.Cast()
        op_reshape = P.Reshape()
        op_shape = P.Shape()
        op_select = P.Select()

        param_fp32 = op_cast(param, mstype.float32)
        m_fp32 = op_cast(m, mstype.float32)
        v_fp32 = op_cast(v, mstype.float32)
        gradient_fp32 = op_cast(gradient, mstype.float32)

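        # Broadcast the scalar overflow flag to the shape of m. Where it is set, the select on
        # the final update below substitutes zeros, so the parameter is left unchanged this step.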
        cond = op_cast(F.fill(mstype.int32, op_shape(m_fp32), 1) * op_reshape(overflow, (())), mstype.bool_)
        next_m = op_mul(beta1, m_fp32) + op_select(cond, m_fp32,
                op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32))

        next_v = op_mul(beta2, v_fp32) + op_select(cond, v_fp32,
                op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta2, op_square(gradient_fp32)))

        update = next_m / (eps + op_sqrt(next_v))
        if decay_flag:
            update = op_mul(weight_decay, param_fp32) + update

        update_with_lr = op_mul(lr, update)
        zeros = F.fill(mstype.float32, op_shape(param_fp32), 0)
        next_param = param_fp32 - op_select(cond, zeros, op_reshape(update_with_lr, op_shape(param_fp32)))

        next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
        next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m))))
        next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))

        return op_cast(next_param, F.dtype(param))
    return gradient


@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
                    "Tensor", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power,
                         beta2_power, beta1, beta2, eps, lr, gradient, param, m, v, ps_parameter, cache_enable):
    """Apply the sparse Adam optimizer to the weight parameter when the gradient is sparse."""
    success = True
    indices = gradient.indices
    values = gradient.values
    if ps_parameter and not cache_enable:
        op_shape = P.Shape()
        shapes = (op_shape(param), op_shape(m), op_shape(v),
                  op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1),
                  op_shape(beta2), op_shape(eps), op_shape(values), op_shape(indices))
        success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2,
                                               eps, values, indices), shapes), param))
        return success

    if not target:
        success = F.depend(success, sparse_opt(param, m, v, beta1_power, beta2_power, lr, beta1, beta2,
                                               eps, values, indices))
    else:
        op_mul = P.Mul()
        op_square = P.Square()
        op_sqrt = P.Sqrt()
        scatter_add = P.ScatterAdd(use_locking)

        F.assign(m, op_mul(beta1, m))
        F.assign(v, op_mul(beta2, v))

        grad_indices = gradient.indices
        grad_value = gradient.values

        next_m = scatter_add(m,
                             grad_indices,
                             op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value))

        next_v = scatter_add(v,
                             grad_indices,
                             op_mul(F.tuple_to_array((1.0,)) - beta2, op_square(grad_value)))

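        # For Nesterov momentum the first moment is applied twice: `next_m` is stashed in
        # `m_temp` (scaled by `_scaler_ten` and scaled back afterwards), `m` is re-decayed and
        # re-accumulated with the current gradient to build the lookahead value, and `m` is
        # finally restored to `next_m` before the assignments below.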
        if use_nesterov:
            m_temp = next_m * _scaler_ten
            F.assign(m, op_mul(beta1, next_m))
            div_value = scatter_add(m,
                                    op_mul(grad_indices, _scaler_one),
                                    op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value))
            param_update = div_value / (op_sqrt(next_v) + eps)

            F.assign(m, m_temp / _scaler_ten)
        else:
            param_update = next_m / (op_sqrt(next_v) + eps)

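        # Fold the Adam bias corrections into the learning rate:
        # lr_t = lr * sqrt(1 - beta2^t) / (1 - beta1^t).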
        lr_t = lr * op_sqrt(1 - beta2_power) / (1 - beta1_power)

        next_param = param - lr_t * param_update

        success = F.depend(success, F.assign(param, next_param))
        success = F.depend(success, F.assign(m, next_m))
        success = F.depend(success, F.assign(v, next_v))

    return success


@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
                    "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
def _run_opt_with_one_number(opt, sparse_opt, push, pull, use_locking, use_nesterov, target,
                             beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, param,
                             moment1, moment2, ps_parameter, cache_enable):
    """Apply the Adam optimizer to the weight parameter when the gradient is a dense Tensor."""
    success = True
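    # Parameters hosted on a parameter server (with the cache disabled) are updated remotely via
    # push/pull; otherwise the fused Adam op runs locally.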
    if ps_parameter and not cache_enable:
        op_shape = P.Shape()
        success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient),
                                              (op_shape(param), op_shape(moment1), op_shape(moment2))), param))
    else:
        success = F.depend(success, opt(param, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
                                        eps, gradient))
    return success


@_adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
                    "Tensor", "Tensor")
def _run_off_load_opt(opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, param, moment1, moment2):
    """Apply the AdamOffload optimizer to the weight parameter using Tensor."""
    success = True
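    # The offload op returns only the parameter delta, which is applied to `param` with assign_add.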
    delta_param = opt(moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2, eps, gradient)
    success = F.depend(success, F.assign_add(param, delta_param))
    return success


def _check_param_value(beta1, beta2, eps, prim_name):
    """Check the type and range of the inputs."""
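    # For example, beta1=0.9, beta2=0.999, eps=1e-6 pass; beta1=1.0 or eps=0.0 fail the
    # open-interval / positivity checks below.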
    validator.check_value_type("beta1", beta1, [float], prim_name)
    validator.check_value_type("beta2", beta2, [float], prim_name)
    validator.check_value_type("eps", eps, [float], prim_name)
    validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name)
    validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name)
    validator.check_positive_float(eps, "eps", prim_name)

class AdamWeightDecayForBert(Optimizer):
    """
    Implements the Adam algorithm with weight decay, with support for the overflow flag of the dynamic loss scale.

    Note:
        When separating parameter groups, the weight decay in each group will be applied on the parameters if the
        weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
        on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.

        To improve the performance of parameter groups, a customized order of parameters is supported.

    Args:
        params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
            the element in `params` must be class `Parameter`. When the `params` is a list of `dict`, the "params",
            "lr", "weight_decay" and "order_params" are the keys that can be parsed.

            - params: Required. The value must be a list of `Parameter`.

            - lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used.
              If not, the `learning_rate` in the API will be used.

            - weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay
              will be used. If not, the `weight_decay` in the API will be used.

            - order_params: Optional. If "order_params" is in the keys, the value must be the order of parameters and
              this order will be followed in the optimizer. There are no other keys in this `dict`, and the parameters
              listed in 'order_params' must be in one of the group parameter lists.

        learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
            When the learning_rate is an Iterable or a 1-D Tensor, the dynamic learning rate is used, and the i-th
            step takes the i-th value as the learning rate. When the learning_rate is a LearningRateSchedule, the
            dynamic learning rate is used, and the i-th learning rate is calculated during training according to the
            formula of LearningRateSchedule. When the learning_rate is a float or a zero-dimensional Tensor, a fixed
            learning rate is used. Other cases are not supported. The float learning rate must be equal to or greater
            than 0. If the type of `learning_rate` is int, it will be converted to float. Default: 1e-3.
        beta1 (float): The exponential decay rate for the 1st moment estimations. Default: 0.9.
            Should be in range (0.0, 1.0).
        beta2 (float): The exponential decay rate for the 2nd moment estimations. Default: 0.999.
            Should be in range (0.0, 1.0).
        eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
            Should be greater than 0.
        weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
        - **overflow** (Tensor) - The overflow flag produced by the dynamic loss scale.

    Outputs:
        tuple[bool], all elements are True.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> net = Net()
        >>> #1) All parameters use the same learning rate and weight decay
        >>> optim = AdamWeightDecayForBert(params=net.trainable_params())
        >>>
        >>> #2) Use parameter groups and set different values
        >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
        >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
        >>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
        ...                 {'params': no_conv_params, 'lr': 0.01},
        ...                 {'order_params': net.trainable_params()}]
        >>> optim = AdamWeightDecayForBert(group_params, learning_rate=0.1, weight_decay=0.0)
        >>> # The conv_params's parameters will use the default learning rate of 0.1 and a weight decay of 0.01.
        >>> # The no_conv_params's parameters will use a learning rate of 0.01 and the default weight decay of 0.0.
        >>> # The final order of parameters that the optimizer follows is the value of 'order_params'.
        >>>
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> model = Model(net, loss_fn=loss, optimizer=optim)
    """
    def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
        super(AdamWeightDecayForBert, self).__init__(learning_rate, params, weight_decay)
        _check_param_value(beta1, beta2, eps, self.cls_name)
        self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
        self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
        self.eps = Tensor(np.array([eps]).astype(np.float32))
        self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros')
        self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros')
        self.hyper_map = C.HyperMap()
        self.op_select = P.Select()
        self.op_cast = P.Cast()
        self.op_reshape = P.Reshape()
        self.op_shape = P.Shape()

    def construct(self, gradients, overflow):
        """Run one AdamWeightDecayForBert update step."""
        lr = self.get_lr()
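        # Broadcast the scalar overflow flag to the shape of beta1 so it can drive the selects below.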
        cond = self.op_cast(F.fill(mstype.int32, self.op_shape(self.beta1), 1) *
                            self.op_reshape(overflow, (())), mstype.bool_)
        beta1 = self.op_select(cond, self.op_cast(F.tuple_to_array((1.0,)), mstype.float32), self.beta1)
        beta2 = self.op_select(cond, self.op_cast(F.tuple_to_array((1.0,)), mstype.float32), self.beta2)
        if self.is_group:
            if self.is_group_lr:
                optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps),
                                              lr, self.weight_decay, self.parameters, self.moments1, self.moments2,
                                              gradients, self.decay_flags, self.optim_filter)
            else:
                optim_result = self.hyper_map(F.partial(_adam_opt, beta1, beta2, self.eps, lr, overflow),
                                              self.weight_decay, self.parameters, self.moments1, self.moments2,
                                              gradients, self.decay_flags, self.optim_filter)
        else:
            optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, self.weight_decay),
                                          self.parameters, self.moments1, self.moments2,
                                          gradients, self.decay_flags, self.optim_filter)
        if self.use_parallel:
            self.broadcast_params(optim_result)
        return optim_result

class AdamWeightDecayOp(Optimizer):
    """
    Implements the Adam algorithm with weight decay. The update is performed by a single fused operator rather than
    a combination of smaller ops.

    Note:
        When separating parameter groups, the weight decay in each group will be applied on the parameters if the
        weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
        on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.

        To improve the performance of parameter groups, a customized order of parameters is supported.

    Args:
        params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
            the element in `params` must be class `Parameter`. When the `params` is a list of `dict`, the "params",
            "lr", "weight_decay" and "order_params" are the keys that can be parsed.

            - params: Required. The value must be a list of `Parameter`.

            - lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used.
              If not, the `learning_rate` in the API will be used.

            - weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay
              will be used. If not, the `weight_decay` in the API will be used.

            - order_params: Optional. If "order_params" is in the keys, the value must be the order of parameters and
              this order will be followed in the optimizer. There are no other keys in this `dict`, and the parameters
              listed in 'order_params' must be in one of the group parameter lists.

        learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
            When the learning_rate is an Iterable or a 1-D Tensor, the dynamic learning rate is used, and the i-th
            step takes the i-th value as the learning rate. When the learning_rate is a LearningRateSchedule, the
            dynamic learning rate is used, and the i-th learning rate is calculated during training according to the
            formula of LearningRateSchedule. When the learning_rate is a float or a zero-dimensional Tensor, a fixed
            learning rate is used. Other cases are not supported. The float learning rate must be equal to or greater
            than 0. If the type of `learning_rate` is int, it will be converted to float. Default: 1e-3.
        beta1 (float): The exponential decay rate for the 1st moment estimations. Default: 0.9.
            Should be in range (0.0, 1.0).
        beta2 (float): The exponential decay rate for the 2nd moment estimations. Default: 0.999.
            Should be in range (0.0, 1.0).
        eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
            Should be greater than 0.
        weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

    Outputs:
        tuple[bool], all elements are True.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> net = Net()
        >>> #1) All parameters use the same learning rate and weight decay
        >>> optim = AdamWeightDecayOp(params=net.trainable_params())
        >>>
        >>> #2) Use parameter groups and set different values
        >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
        >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
        >>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
        ...                 {'params': no_conv_params, 'lr': 0.01},
        ...                 {'order_params': net.trainable_params()}]
        >>> optim = AdamWeightDecayOp(group_params, learning_rate=0.1, weight_decay=0.0)
        >>> # The conv_params's parameters will use the default learning rate of 0.1 and a weight decay of 0.01.
        >>> # The no_conv_params's parameters will use a learning rate of 0.01 and the default weight decay of 0.0.
        >>> # The final order of parameters that the optimizer follows is the value of 'order_params'.
        >>>
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> model = Model(net, loss_fn=loss, optimizer=optim)
    """
    def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
        super(AdamWeightDecayOp, self).__init__(learning_rate, params, weight_decay)
        _check_param_value(beta1, beta2, eps, self.cls_name)
        self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
        self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
        self.eps = Tensor(np.array([eps]).astype(np.float32))
        self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros')
        self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros')
        self.hyper_map = C.HyperMap()

    def construct(self, gradients):
        """Run one AdamWeightDecayOp update step."""
        lr = self.get_lr()
        if self.is_group:
            if self.is_group_lr:
                optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps),
                                              lr, self.weight_decay, self.parameters, self.moments1, self.moments2,
                                              gradients, self.decay_flags, self.optim_filter)
            else:
                optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr),
                                              self.weight_decay, self.parameters, self.moments1, self.moments2,
                                              gradients, self.decay_flags, self.optim_filter)
        else:
            optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, self.weight_decay),
                                          self.parameters, self.moments1, self.moments2,
                                          gradients, self.decay_flags, self.optim_filter)
        if self.use_parallel:
            self.broadcast_params(optim_result)
        return optim_result