# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""adamax"""
16from __future__ import absolute_import
17
18from mindspore.common import dtype as mstype
19from mindspore.common.initializer import initializer
20from mindspore.common.api import jit
21from mindspore.ops import operations as P
22from mindspore.ops import composite as C
23from mindspore.ops import functional as F
24from mindspore.common.parameter import Parameter
25from mindspore.common.tensor import Tensor
26from mindspore import _checkparam as validator
27from mindspore.nn.optim.optimizer import Optimizer
28from mindspore.nn.optim.optimizer import opt_init_args_register
29
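
# Hyper map that dispatches the per-parameter AdaMax update registered below.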
_ada_max_opt = C.MultitypeFuncGraph("ada_max_opt")


@_ada_max_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
                       "Tensor", "Tensor")
def _tensor_run_opt(opt, beta1, beta2, beta1_power, eps, learning_rate, weight, moment1, moment2, gradient):
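    """Apply the ApplyAdaMax operator to update `weight`, `moment1` and `moment2` for one parameter."""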
    success = True
    success = F.depend(success, opt(weight, moment1, moment2, beta1_power, learning_rate, beta1, beta2, eps, gradient))
    return success


def _check_param_value(beta1, beta2, eps, prim_name):
    """Check the type and value range of the inputs."""
    validator.check_value_type("beta1", beta1, [float], prim_name)
    validator.check_value_type("beta2", beta2, [float], prim_name)
    validator.check_value_type("eps", eps, [float], prim_name)
    validator.check_float_range(beta1, 0.0, 1.0, validator.INC_NEITHER, "beta1", prim_name)
    validator.check_float_range(beta2, 0.0, 1.0, validator.INC_NEITHER, "beta2", prim_name)
    validator.check_positive_float(eps, "eps", prim_name)


class AdaMax(Optimizer):
    r"""
    Implements the AdaMax algorithm, a variant of Adaptive Moment Estimation (Adam) based on the infinity norm.

    The AdaMax algorithm is proposed in `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_.

    The updating formulas are as follows,

    .. math::
        \begin{array}{ll} \\
            m_{t+1} = \beta_1 * m_{t} + (1 - \beta_1) * g \\
            v_{t+1} = \max(\beta_2 * v_{t}, \left| g \right|) \\
            w = w - \frac{l}{1 - \beta_1^{t+1}} * \frac{m_{t+1}}{v_{t+1} + \epsilon}
        \end{array}

    :math:`m` represents the 1st moment vector, :math:`v` represents the 2nd moment vector,
    :math:`g` represents `gradients`, :math:`\beta_1, \beta_2` represent `beta1` and `beta2`,
    :math:`t` represents the current step, :math:`\beta_1^t` represents `beta1_power`,
    :math:`l` represents `learning_rate`, :math:`w` represents `params`,
    :math:`\epsilon` represents `eps`.

    Note:
        If parameters are not grouped, the `weight_decay` in the optimizer will be applied on the network parameters
        that do not have 'beta' or 'gamma' in their names. Users can group parameters to change the weight decay
        strategy. When parameters are grouped, each group can set its own `weight_decay`; if it is not set, the
        `weight_decay` in the optimizer will be applied.

    Args:
        params (Union[list[Parameter], list[dict]]): Must be a list of `Parameter` or a list of `dict`. When the
            `params` is a list of `dict`, the strings "params", "lr", "weight_decay", "grad_centralization" and
            "order_params" are the keys that can be parsed.

            - params: Required. Parameters in the current group. The value must be a list of `Parameter`.

            - lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used.
              If not, the `learning_rate` in the optimizer will be used. Fixed and dynamic learning rates are
              supported.

            - weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay
              will be used. If not, the `weight_decay` in the optimizer will be used. It should be noted that weight
              decay can be a constant value or a Cell. It is a Cell only when dynamic weight decay is applied. Dynamic
              weight decay is similar to dynamic learning rate: users need to customize a weight decay schedule whose
              only input is the global step, and during training the optimizer calls the instance of
              WeightDecaySchedule to get the weight decay value of the current step.

            - grad_centralization: Optional. Must be Boolean. If "grad_centralization" is in the keys, the set value
              will be used. If not, `grad_centralization` is False by default. This configuration only works on the
              convolution layer.

            - order_params: Optional. When parameters are grouped, this is usually used to maintain the order in
              which the parameters appear in the network, to improve performance. The value should be the parameters
              whose order will be followed by the optimizer.
              If `order_params` is in the keys, other keys will be ignored, and the elements of 'order_params' must
              be in one group of `params`.

        learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule]): Default: ``0.001`` .

            - float: The fixed learning rate value. Must be equal to or greater than 0.

            - int: The fixed learning rate value. Must be equal to or greater than 0. It will be converted to float.

            - Tensor: Its value should be a scalar or a 1-D vector. For a scalar, a fixed learning rate will be
              applied. For a vector, the learning rate is dynamic and the i-th step will take the i-th value as the
              learning rate.

            - Iterable: Learning rate is dynamic. The i-th step will take the i-th value as the learning rate.

            - LearningRateSchedule: Learning rate is dynamic. During training, the optimizer calls the instance of
              `LearningRateSchedule
              <https://www.mindspore.cn/docs/en/master/api_python/mindspore.nn.html#learningrateschedule-class>`_
              with the step as the input to get the learning rate of the current step.

        beta1 (float): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
                       Default: ``0.9`` .
        beta2 (float): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
                       Default: ``0.999`` .
        eps (float): Term added to the denominator to improve numerical stability. Should be greater than 0.
                     Default: ``1e-08`` .

        weight_decay (Union[float, int, Cell]): Weight decay (L2 penalty). Default: ``0.0`` .

            - float: The fixed weight decay value. Must be equal to or greater than 0.

            - int: The fixed weight decay value. Must be equal to or greater than 0. It will be converted to float.

            - Cell: Weight decay is dynamic. During training, the optimizer calls the instance of
              the Cell with the step as the input to get the weight decay value of the current step.

        loss_scale (float): A floating point value for the loss scale. Should be greater than 0. In general, use the
            default value. Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
            `FixedLossScaleManager` is set to ``False`` , then this value needs to be the same as the `loss_scale` in
            `FixedLossScaleManager`. Refer to class :class:`mindspore.amp.FixedLossScaleManager` for more details.
            Default: ``1.0`` .

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

    Outputs:
        Tensor[bool], the value is True.

    Raises:
        TypeError: If `learning_rate` is not one of int, float, Tensor, Iterable, LearningRateSchedule.
        TypeError: If an element of `params` is neither a Parameter nor a dict.
        TypeError: If `beta1`, `beta2`, `eps` or `loss_scale` is not a float.
        TypeError: If `weight_decay` is neither float nor int.
        ValueError: If `loss_scale` or `eps` is less than or equal to 0.
        ValueError: If `beta1` or `beta2` is not in range (0.0, 1.0).
        ValueError: If `weight_decay` is less than 0.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore as ms
        >>> from mindspore import nn
        >>>
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> #1) All parameters use the same learning rate and weight decay
        >>> optim = nn.AdaMax(params=net.trainable_params())
        >>>
        >>> #2) Use parameter groups and set different values
        >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
        >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
        >>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'grad_centralization':True},
        ...                 {'params': no_conv_params, 'lr': 0.01},
        ...                 {'order_params': net.trainable_params()}]
        >>> optim = nn.AdaMax(group_params, learning_rate=0.1, weight_decay=0.0)
        >>> # The parameters in conv_params will use the default learning rate of 0.1, a weight decay of 0.01 and
        >>> # grad centralization of True.
        >>> # The parameters in no_conv_params will use a learning rate of 0.01, the default weight decay of 0.0 and
        >>> # grad centralization of False.
        >>> # The final order of the parameters followed by the optimizer is the value of 'order_params'.
        >>>
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> model = ms.Model(net, loss_fn=loss, optimizer=optim)
    """
    @opt_init_args_register
    def __init__(self, params, learning_rate=0.001, beta1=0.9, beta2=0.999, eps=1e-08,
                 weight_decay=0.0, loss_scale=1.0):
        super(AdaMax, self).__init__(learning_rate, params, weight_decay, loss_scale)
        _check_param_value(beta1, beta2, eps, self.cls_name)

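        # Optimizer state: decay-rate constants, the running power of beta1 used for bias correction,
        # and zero-initialized moment buffers cloned from the parameters.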
        self.beta1 = Tensor(beta1, mstype.float32)
        self.beta2 = Tensor(beta2, mstype.float32)
        self.beta1_power = Parameter(initializer(1, [1], mstype.float32), name="beta1_power")
        self.eps = Tensor(eps, mstype.float32)
        self.moment1 = self._parameters.clone(prefix="moment1", init='zeros')
        self.moment2 = self._parameters.clone(prefix="moment2", init='zeros')

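        # ApplyAdaMax fuses the moment updates and the parameter update into a single operator.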
        self.opt = P.ApplyAdaMax()

    @jit
    def construct(self, gradients):
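        """Apply one AdaMax update step to all parameters with the given `gradients`."""
        # Preprocess gradients: flatten, apply weight decay and gradient centralization, then undo the loss scale.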
        gradients = self.flatten_gradients(gradients)
        gradients = self.decay_weight(gradients)
        gradients = self.gradients_centralization(gradients)
        gradients = self.scale_grad(gradients)
        lr = self.get_lr()
        self.assignadd(self.global_step, self.global_step_increase_tensor)

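        # Accumulate beta1 ** t; ApplyAdaMax uses it to bias-correct the first moment.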
        self.beta1_power *= self.beta1

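        # With grouped learning rates `lr` is a tuple and is mapped per parameter;
        # otherwise the scalar learning rate is bound into the partial.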
        if self.is_group_lr:
            success = self.map_(F.partial(_ada_max_opt, self.opt, self.beta1, self.beta2, self.beta1_power, self.eps),
                                lr, self._parameters, self.moment1, self.moment2, gradients)
        else:
            success = self.map_(F.partial(_ada_max_opt, self.opt, self.beta1, self.beta2, self.beta1_power,
                                          self.eps, lr), self._parameters, self.moment1, self.moment2, gradients)

        return success