• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""AdamWeightDecayForBert, a customized Adam for bert. Input: gradient, overflow flag."""
16import numpy as np
17
18from mindspore.common import dtype as mstype
19from mindspore.ops import operations as P
20from mindspore.ops import composite as C
21from mindspore.ops import functional as F
22from mindspore.common.tensor import Tensor
23from mindspore._checkparam import Validator as validator
24from mindspore._checkparam import Rel
25from mindspore.nn.optim.optimizer import Optimizer
26
# Multitype dispatch graph for the Adam update; the overloads registered on it below are
# distinguished by the argument type names given to `@_adam_opt.register(...)`.
_adam_opt = C.MultitypeFuncGraph("adam_opt")
# Scalar tensors referenced only by the Nesterov branch of `_run_opt_with_sparse`.
_scaler_one = Tensor(1, mstype.int32)
_scaler_ten = Tensor(10, mstype.float32)
30
@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
                    "Tensor", "Bool", "Bool")
def _update_run_kernel(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, decay_flags, optim_filter):
    """
    Run one update step through the `P.AdamWeightDecay` operator.

    Args:
        beta1 (Tensor): Decay rate of the 1st moment estimate.
        beta2 (Tensor): Decay rate of the 2nd moment estimate.
        eps (Tensor): Term added to the denominator.
        lr (Tensor): Learning rate.
        weight_decay (Number): Weight decay, used only when `decay_flags` is True.
        param (Tensor): Parameter to update.
        m (Tensor): 1st moment of `param`.
        v (Tensor): 2nd moment of `param`.
        gradient (Tensor): Gradient of `param`.
        decay_flags (bool): Whether weight decay applies to this parameter.
        optim_filter (bool): Whether this parameter is updated at all.

    Returns:
        The operator's result when `optim_filter` is True, otherwise `gradient` unchanged.
    """
    # Parameters that are filtered out are passed through untouched.
    if not optim_filter:
        return gradient

    adam_weight_decay = P.AdamWeightDecay()
    if decay_flags:
        return adam_weight_decay(param, m, v, lr, beta1, beta2, eps, weight_decay, gradient)
    # No decay for this parameter: hand the operator a zero decay coefficient.
    return adam_weight_decay(param, m, v, lr, beta1, beta2, eps, 0.0, gradient)
45
@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
                    "Tensor", "Bool", "Bool")
def _update_run_op(beta1, beta2, eps, lr, overflow, weight_decay, param, m, v, gradient, decay_flag, optim_filter):
    """
    Update parameters, leaving the parameter and both moments unchanged when overflow is flagged.

    Args:
        beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
        beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
        eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
        lr (Tensor): Learning rate.
        overflow (Tensor): Whether overflow occurs. It is reshaped to a scalar and a non-zero value
            disables the update for this step.
        weight_decay (Number): Weight decay. Should be equal to or greater than 0.
        param (Tensor): Parameters.
        m (Tensor): m value of parameters.
        v (Tensor): v value of parameters.
        gradient (Tensor): Gradient of parameters.
        decay_flag (bool): Applies weight decay or not.
        optim_filter (bool): Applies parameter update or not.

    Returns:
        Tensor, the new value of `param` (cast to the dtype of `param`) when `optim_filter` is True,
        otherwise `gradient` unchanged.
    """
    if optim_filter:
        op_mul = P.Mul()
        op_square = P.Square()
        op_sqrt = P.Sqrt()
        op_cast = P.Cast()
        op_reshape = P.Reshape()
        op_shape = P.Shape()
        op_select = P.Select()

        # All arithmetic is carried out in float32; results are cast back when they are assigned.
        param_fp32 = op_cast(param, mstype.float32)
        m_fp32 = op_cast(m, mstype.float32)
        v_fp32 = op_cast(v, mstype.float32)
        gradient_fp32 = op_cast(gradient, mstype.float32)
        one = op_cast(F.tuple_to_array((1.0,)), mstype.float32)

        # Expand the scalar overflow flag into an element-wise boolean mask shaped like the moments.
        cond = op_cast(F.fill(mstype.int32, op_shape(m_fp32), 1) * op_reshape(overflow, (())), mstype.bool_)

        # On overflow keep the stored moments exactly as they are. Selecting between the old moment
        # and the complete new moment (rather than adding the old moment to `beta * moment`) avoids
        # doubling m and v when AdamWeightDecayForBert.construct substitutes 1.0 for beta1/beta2.
        next_m = op_select(cond, m_fp32,
                           op_mul(beta1, m_fp32) + op_mul(one - beta1, gradient_fp32))
        next_v = op_select(cond, v_fp32,
                           op_mul(beta2, v_fp32) + op_mul(one - beta2, op_square(gradient_fp32)))

        update = next_m / (eps + op_sqrt(next_v))
        if decay_flag:
            update = op_mul(weight_decay, param_fp32) + update

        update_with_lr = op_mul(lr, update)
        # On overflow the step applied to the parameter is replaced by zeros.
        zeros = F.fill(mstype.float32, op_shape(param_fp32), 0)
        next_param = param_fp32 - op_select(cond, zeros, op_reshape(update_with_lr, op_shape(param_fp32)))

        # Attach the in-place writes to the returned value so they are part of the computation.
        next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
        next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m))))
        next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))

        return op_cast(next_param, F.dtype(param))
    return gradient
104
105
@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
                    "Tensor", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power,
                         beta2_power, beta1, beta2, eps, lr, gradient, param, m, v, ps_parameter, cache_enable):
    """
    Apply sparse adam optimizer to the weight parameter when the gradient is sparse.

    Three paths are taken depending on the flags:

    - `ps_parameter and not cache_enable`: the hyper-parameters and the gradient's values/indices are
      handed to `push` together with their shapes, the result is passed to `pull` with `param`, and
      the function returns immediately.
    - `not target`: the update is delegated to `sparse_opt`.
    - otherwise: the update is composed from `Mul`, `Square`, `Sqrt` and `ScatterAdd`.

    Args:
        opt (Function): Not used by this overload.
        sparse_opt (Function): Sparse optimizer op used when `target` is False.
        push (Function): Push op used on the `ps_parameter` path.
        pull (Function): Pull op used on the `ps_parameter` path.
        use_locking (bool): Forwarded to `P.ScatterAdd`.
        use_nesterov (bool): Selects the Nesterov form of the update on the composite path.
        target (bool): Selects the composite path when True, `sparse_opt` when False.
        beta1_power (Tensor): Used as `1 - beta1_power` in the learning-rate correction.
        beta2_power (Tensor): Used as `sqrt(1 - beta2_power)` in the learning-rate correction.
        beta1 (Tensor): Decay rate of the 1st moment estimate.
        beta2 (Tensor): Decay rate of the 2nd moment estimate.
        eps (Tensor): Term added to the denominator.
        lr (Tensor): Learning rate.
        gradient (RowTensor): Sparse gradient; its `indices` and `values` are read.
        param (Tensor): Parameter to update.
        m (Tensor): 1st moment of `param`.
        v (Tensor): 2nd moment of `param`.
        ps_parameter (bool): Enables the push/pull path together with `cache_enable`.
        cache_enable (bool): Disables the push/pull path when True.

    Returns:
        `True` with the performed update attached through `F.depend`.
    """
    success = True
    indices = gradient.indices
    values = gradient.values
    if ps_parameter and not cache_enable:
        op_shape = P.Shape()
        # Shapes of every tensor involved, in the order param, m, v, then the pushed inputs.
        shapes = (op_shape(param), op_shape(m), op_shape(v),
                  op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1),
                  op_shape(beta2), op_shape(eps), op_shape(values), op_shape(indices))
        success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2,
                                               eps, values, indices), shapes), param))
        return success

    if not target:
        success = F.depend(success, sparse_opt(param, m, v, beta1_power, beta2_power, lr, beta1, beta2,
                                               eps, values, indices))
    else:
        op_mul = P.Mul()
        op_square = P.Square()
        op_sqrt = P.Sqrt()
        scatter_add = P.ScatterAdd(use_locking)

        # Decay both moments in place before the sparse contributions are scattered in.
        # NOTE(review): these assigns are bare statements, not tied in with F.depend - confirm
        # the execution order is guaranteed in graph mode.
        F.assign(m, op_mul(beta1, m))
        F.assign(v, op_mul(beta2, v))

        # Same data as `indices` / `values` read at the top of the function.
        grad_indices = gradient.indices
        grad_value = gradient.values

        # Add (1 - beta1) * g to the rows of m addressed by the gradient's indices.
        next_m = scatter_add(m,
                             grad_indices,
                             op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value))

        # Add (1 - beta2) * g^2 to the rows of v addressed by the gradient's indices.
        next_v = scatter_add(v,
                             grad_indices,
                             op_mul(F.tuple_to_array((1.0,)) - beta2, op_square(grad_value)))

        if use_nesterov:
            # presumably the multiply/divide by `_scaler_ten` keeps a separate copy of next_m while
            # m is temporarily overwritten below - TODO confirm
            m_temp = next_m * _scaler_ten
            F.assign(m, op_mul(beta1, next_m))
            # Numerator of the Nesterov step: beta1 * next_m plus (1 - beta1) * g on the touched rows.
            div_value = scatter_add(m,
                                    op_mul(grad_indices, _scaler_one),
                                    op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value))
            param_update = div_value / (op_sqrt(next_v) + eps)

            # Write the saved value of next_m back into m.
            F.assign(m, m_temp / _scaler_ten)

        else:
            param_update = next_m / (op_sqrt(next_v) + eps)

        # Learning rate scaled by sqrt(1 - beta2_power) / (1 - beta1_power).
        lr_t = lr * op_sqrt(1 - beta2_power) / (1 - beta1_power)

        next_param = param - lr_t * param_update



        # Attach the final writes to the returned flag.
        success = F.depend(success, F.assign(param, next_param))
        success = F.depend(success, F.assign(m, next_m))
        success = F.depend(success, F.assign(v, next_v))

    return success
170
171
@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
                    "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
def _run_opt_with_one_number(opt, sparse_opt, push, pull, use_locking, use_nesterov, target,
                             beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, param,
                             moment1, moment2, ps_parameter, cache_enable):
    """
    Apply the dense Adam update to one parameter.

    When `ps_parameter` is True and `cache_enable` is False, the hyper-parameters and the gradient
    are handed to `push` along with the shapes of `param`, `moment1` and `moment2`, and the result
    is given to `pull` together with `param`. Otherwise `opt` is called directly.
    `sparse_opt`, `use_locking`, `use_nesterov` and `target` are not used by this overload.

    Returns:
        `True` with the performed update attached through `F.depend`.
    """
    done = True
    if ps_parameter and not cache_enable:
        shape_of = P.Shape()
        pushed = push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient),
                      (shape_of(param), shape_of(moment1), shape_of(moment2)))
        return F.depend(done, pull(pushed, param))

    applied = opt(param, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2, eps, gradient)
    return F.depend(done, applied)
187
188
@_adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
                    "Tensor", "Tensor")
def _run_off_load_opt(opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, param, moment1, moment2):
    """
    Apply the offload Adam update: `opt` produces a delta that is added onto `param` in place.

    Returns:
        `True` with the in-place addition attached through `F.depend`.
    """
    done = True
    delta_param = opt(moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2, eps, gradient)
    return F.depend(done, F.assign_add(param, delta_param))
197
198
def _check_param_value(beta1, beta2, eps, prim_name):
    """
    Validate the Adam hyper-parameters.

    `beta1`, `beta2` and `eps` must all be floats; `beta1` and `beta2` must lie strictly inside
    (0.0, 1.0) and `eps` must be positive. `prim_name` is forwarded to the validator calls.
    """
    hyper_params = (("beta1", beta1), ("beta2", beta2), ("eps", eps))

    # Type checks come first, for all three values, in declaration order.
    for arg_name, arg_value in hyper_params:
        validator.check_value_type(arg_name, arg_value, [float], prim_name)

    # Range checks: both betas are open-interval, eps only needs to be positive.
    for arg_name, arg_value in hyper_params[:2]:
        validator.check_float_range(arg_value, 0.0, 1.0, Rel.INC_NEITHER, arg_name, prim_name)
    validator.check_positive_float(eps, "eps", prim_name)
207
class AdamWeightDecayForBert(Optimizer):
    """
    Implements the Adam algorithm with weight decay fix, taking an extra overflow flag as input.

    Note:
        When separating parameter groups, the weight decay in each group will be applied on the parameters if the
        weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
        on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.

        To improve parameter groups performance, the customized order of parameters can be supported.

        NOTE(review): `overflow` is only passed on to the update when parameter groups are used without a
        per-group learning rate; the other two branches of `construct` dispatch without it. Confirm this is
        intended.

    Args:
        params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
            the element in `params` must be class `Parameter`. When the `params` is a list of `dict`, the "params",
            "lr", "weight_decay" and "order_params" are the keys that can be parsed.

            - params: Required. The value must be a list of `Parameter`.

            - lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used.
              If not, the `learning_rate` in the API will be used.

            - weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay
              will be used. If not, the `weight_decay` in the API will be used.

            - order_params: Optional. If "order_params" is in the keys, the value must be the order of parameters and
              the order will be followed in the optimizer. There are no other keys in the `dict` and the parameters
              which are in 'order_params' must be in one of the group parameters.

        learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
            When the learning_rate is an Iterable or a Tensor in a 1D dimension, use the dynamic learning rate, then
            the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule,
            use dynamic learning rate, the i-th learning rate will be calculated during the process of training
            according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor in a zero
            dimension, use fixed learning rate. Other cases are not supported. The float learning rate must be
            equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float.
            Default: 1e-3.
        beta1 (float): The exponential decay rate for the 1st moment estimations. Default: 0.9.
            Should be in range (0.0, 1.0).
        beta2 (float): The exponential decay rate for the 2nd moment estimations. Default: 0.999.
            Should be in range (0.0, 1.0).
        eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
            Should be greater than 0.
        weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
        - **overflow** (Tensor) - The overflow flag from dynamic loss scaling. It is reshaped to a scalar,
          so it must hold a single element.

    Outputs:
        tuple, one element per parameter, as returned by the dispatched `_adam_opt` overload
        (`_update_run_op` or `_update_run_kernel`). For parameters excluded by `optim_filter` the element is
        the corresponding gradient.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> net = Net()
        >>> #1) All parameters use the same learning rate and weight decay
        >>> optim = AdamWeightDecayForBert(params=net.trainable_params())
        >>>
        >>> #2) Use parameter groups and set different values
        >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
        >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
        >>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
        ...                 {'params': no_conv_params, 'lr': 0.01},
        ...                 {'order_params': net.trainable_params()}]
        >>> optim = AdamWeightDecayForBert(group_params, learning_rate=0.1, weight_decay=0.0)
        >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01.
        >>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0.
        >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
        >>>
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> model = Model(net, loss_fn=loss, optimizer=optim)
    """
    def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
        super(AdamWeightDecayForBert, self).__init__(learning_rate, params, weight_decay)
        _check_param_value(beta1, beta2, eps, self.cls_name)
        # Hyper-parameters are stored as one-element float32 tensors.
        self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
        self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
        self.eps = Tensor(np.array([eps]).astype(np.float32))
        # One zero-initialised 1st and 2nd moment per parameter.
        self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros')
        self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros')
        self.hyper_map = C.HyperMap()
        # Ops used in `construct` to turn the overflow flag into a mask over the betas.
        self.op_select = P.Select()
        self.op_cast = P.Cast()
        self.op_reshape = P.Reshape()
        self.op_shape = P.Shape()

    def construct(self, gradients, overflow):
        """
        Run one optimizer step over all parameters.

        Args:
            gradients (tuple[Tensor]): Gradients, one per parameter.
            overflow (Tensor): Overflow flag; reshaped to a scalar before use.

        Returns:
            tuple, the per-parameter results of the dispatched `_adam_opt` overload.
        """
        lr = self.get_lr()
        # Boolean mask with the shape of beta1, true when the overflow flag is non-zero.
        cond = self.op_cast(F.fill(mstype.int32, self.op_shape(self.beta1), 1) *\
                            self.op_reshape(overflow, (())), mstype.bool_)
        # On overflow both decay rates are replaced by 1.0.
        beta1 = self.op_select(cond, self.op_cast(F.tuple_to_array((1.0,)), mstype.float32), self.beta1)
        beta2 = self.op_select(cond, self.op_cast(F.tuple_to_array((1.0,)), mstype.float32), self.beta2)
        if self.is_group:
            if self.is_group_lr:
                # Per-parameter lr and weight decay; uses the stored betas and does not pass `overflow`.
                optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps),
                                              lr, self.weight_decay, self.parameters, self.moments1, self.moments2,
                                              gradients, self.decay_flags, self.optim_filter)
            else:
                # Shared lr, per-parameter weight decay; the only branch that passes the masked betas
                # and `overflow`.
                optim_result = self.hyper_map(F.partial(_adam_opt, beta1, beta2, self.eps, lr, overflow),
                                              self.weight_decay, self.parameters, self.moments1, self.moments2,
                                              gradients, self.decay_flags, self.optim_filter)
        else:
            # No parameter groups: shared lr and weight decay; `overflow` is not passed.
            optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, self.weight_decay),
                                          self.parameters, self.moments1, self.moments2,
                                          gradients, self.decay_flags, self.optim_filter)
        if self.use_parallel:
            self.broadcast_params(optim_result)
        return optim_result
318
class AdamWeightDecayOp(Optimizer):
    """
    Adam with weight decay fix, where each parameter update is performed by the `_adam_opt` dispatch graph.

    Note:
        With parameter groups, each group's weight decay is applied to that group's parameters when it is
        positive. Without parameter groups, the `weight_decay` argument is applied to every parameter whose
        name contains neither 'beta' nor 'gamma', provided `weight_decay` is positive.

        A customized parameter order is supported to improve the performance of parameter groups.

    Args:
        params (Union[list[Parameter], list[dict]]): Either a list of `Parameter` objects to update, or a list
            of `dict` describing parameter groups. The recognised keys of a group `dict` are:

            - params: Required. A list of `Parameter`.

            - lr: Optional. Learning rate for this group; falls back to `learning_rate` when absent.

            - weight_decay: Optional. Weight decay for this group; falls back to `weight_decay` when absent.

            - order_params: Optional. The order in which the optimizer processes parameters. A `dict` holding
              this key must hold no other key, and every listed parameter must belong to one of the groups.

        learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning
            rate. An Iterable or a 1D Tensor gives a dynamic learning rate where step i uses the i-th value.
            A LearningRateSchedule gives a dynamic learning rate computed during training from the schedule's
            formula. A float or a zero-dimensional Tensor gives a fixed learning rate. Other cases are not
            supported. A float learning rate must be equal to or greater than 0; an int is converted to float.
            Default: 1e-3.
        beta1 (float): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
            Default: 0.9.
        beta2 (float): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
            Default: 0.999.
        eps (float): Term added to the denominator to improve numerical stability. Should be greater than 0.
            Default: 1e-6.
        weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

    Outputs:
        tuple, one element per parameter, as returned by `_adam_opt`.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> net = Net()
        >>> #1) All parameters use the same learning rate and weight decay
        >>> optim = AdamWeightDecayOp(params=net.trainable_params())
        >>>
        >>> #2) Use parameter groups and set different values
        >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
        >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
        >>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
        ...                 {'params': no_conv_params, 'lr': 0.01},
        ...                 {'order_params': net.trainable_params()}]
        >>> optim = AdamWeightDecayOp(group_params, learning_rate=0.1, weight_decay=0.0)
        >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01.
        >>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0.
        >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
        >>>
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> model = Model(net, loss_fn=loss, optimizer=optim)
    """
    def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
        super(AdamWeightDecayOp, self).__init__(learning_rate, params, weight_decay)
        _check_param_value(beta1, beta2, eps, self.cls_name)
        # Hyper-parameters live as one-element float32 tensors.
        self.beta1 = Tensor(np.array([beta1], dtype=np.float32))
        self.beta2 = Tensor(np.array([beta2], dtype=np.float32))
        self.eps = Tensor(np.array([eps], dtype=np.float32))
        # Zero-initialised moment buffers, one pair per parameter.
        self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros')
        self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros')
        self.hyper_map = C.HyperMap()

    def construct(self, gradients):
        """Run one optimizer step over all parameters and return the per-parameter results."""
        lr = self.get_lr()
        if not self.is_group:
            # Single learning rate and weight decay shared by every parameter.
            step_fn = F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, self.weight_decay)
            optim_result = self.hyper_map(step_fn,
                                          self.parameters, self.moments1, self.moments2,
                                          gradients, self.decay_flags, self.optim_filter)
        elif self.is_group_lr:
            # Learning rate and weight decay both vary per parameter.
            step_fn = F.partial(_adam_opt, self.beta1, self.beta2, self.eps)
            optim_result = self.hyper_map(step_fn,
                                          lr, self.weight_decay, self.parameters, self.moments1, self.moments2,
                                          gradients, self.decay_flags, self.optim_filter)
        else:
            # Shared learning rate, weight decay varies per parameter.
            step_fn = F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr)
            optim_result = self.hyper_map(step_fn,
                                          self.weight_decay, self.parameters, self.moments1, self.moments2,
                                          gradients, self.decay_flags, self.optim_filter)
        if self.use_parallel:
            self.broadcast_params(optim_result)
        return optim_result
420