# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""adamax"""
from __future__ import absolute_import

from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer
from mindspore.common.api import jit
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore import _checkparam as validator
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.nn.optim.optimizer import opt_init_args_register

_ada_max_opt = C.MultitypeFuncGraph("ada_max_opt")


@_ada_max_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
                       "Tensor", "Tensor")
def _tensor_run_opt(opt, beta1, beta2, beta1_power, eps, learning_rate, weight, moment1, moment2, gradient):
    """Apply the AdaMax operator to update one parameter and its moment tensors."""
    success = True
    success = F.depend(success, opt(weight, moment1, moment2, beta1_power, learning_rate, beta1, beta2, eps, gradient))
    return success


def _check_param_value(beta1, beta2, eps, prim_name):
    """Check the type and range of the inputs."""
    validator.check_value_type("beta1", beta1, [float], prim_name)
    validator.check_value_type("beta2", beta2, [float], prim_name)
    validator.check_value_type("eps", eps, [float], prim_name)
    validator.check_float_range(beta1, 0.0, 1.0, validator.INC_NEITHER, "beta1", prim_name)
    validator.check_float_range(beta2, 0.0, 1.0, validator.INC_NEITHER, "beta2", prim_name)
    validator.check_positive_float(eps, "eps", prim_name)


class AdaMax(Optimizer):
    r"""
    Implements the AdaMax algorithm, a variant of Adaptive Movement Estimation (Adam) based on the infinity norm.

    The AdaMax algorithm is proposed in `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_.

    The updating formulas are as follows:

    .. math::
        \begin{array}{ll} \\
            m_{t+1} = \beta_1 * m_{t} + (1 - \beta_1) * g \\
            v_{t+1} = \max(\beta_2 * v_{t}, \left| g \right|) \\
            w = w - \frac{l}{1 - \beta_1^{t+1}} * \frac{m_{t+1}}{v_{t+1} + \epsilon}
        \end{array}

    :math:`m` represents the 1st moment vector, :math:`v` represents the 2nd moment vector,
    :math:`g` represents `gradients`, :math:`\beta_1, \beta_2` represent `beta1` and `beta2`,
    :math:`t` represents the current step, :math:`\beta_1^t` represents `beta1_power`,
    :math:`l` represents `learning_rate`, :math:`w` represents `params`,
    :math:`\epsilon` represents `eps`.

    Note:
        If parameters are not grouped, the `weight_decay` in the optimizer will be applied to the network parameters
        without 'beta' or 'gamma' in their names. Users can group parameters to change the strategy of decaying
        weight. When parameters are grouped, each group can set `weight_decay`.
        If not, the `weight_decay` in the optimizer will be applied.

    Args:
        params (Union[list[Parameter], list[dict]]): Must be a list of `Parameter` or a list of `dict`. When the
            `params` is a list of `dict`, the strings "params", "lr", "weight_decay", "grad_centralization" and
            "order_params" are the keys that can be parsed.

            - params: Required. The parameters in the current group. The value must be a list of `Parameter`.

            - lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used.
              If not, the `learning_rate` in the optimizer will be used. Fixed and dynamic learning rates are
              supported.

            - weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay
              will be used. If not, the `weight_decay` in the optimizer will be used. It should be noted that weight
              decay can be a constant value or a Cell. It is a Cell only when dynamic weight decay is applied.
              Dynamic weight decay is similar to dynamic learning rate: users need to customize a weight decay
              schedule whose only input is the global step, and during training the optimizer calls the instance of
              WeightDecaySchedule to get the weight decay value of the current step.

            - grad_centralization: Optional. Must be Boolean. If "grad_centralization" is in the keys, the set value
              will be used. If not, `grad_centralization` is False by default. This configuration only works on the
              convolution layer.

            - order_params: Optional. When parameters are grouped, this is usually used to maintain the order of the
              parameters that appear in the network so as to improve performance. The value should be the parameters
              whose order will be followed by the optimizer.
              If `order_params` is in the keys, the other keys will be ignored and the elements of 'order_params'
              must be in one group of `params`.

        learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule]): Default: ``0.001`` .

            - float: The fixed learning rate value. Must be equal to or greater than 0.

            - int: The fixed learning rate value. Must be equal to or greater than 0. It will be converted to float.

            - Tensor: Its value should be a scalar or a 1-D vector. For a scalar, a fixed learning rate will be
              applied. For a vector, the learning rate is dynamic and the i-th step will take the i-th value as the
              learning rate.

            - Iterable: The learning rate is dynamic. The i-th step will take the i-th value as the learning rate.

            - LearningRateSchedule: The learning rate is dynamic. During training, the optimizer calls the instance
              of `LearningRateSchedule
              <https://www.mindspore.cn/docs/en/master/api_python/mindspore.nn.html#learningrateschedule-class>`_
              with step as the input to get the learning rate of the current step.

        beta1 (float): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
            Default: ``0.9`` .
        beta2 (float): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
            Default: ``0.999`` .
        eps (float): Term added to the denominator to improve numerical stability. Should be greater than 0.
            Default: ``1e-08`` .

        weight_decay (Union[float, int, Cell]): Weight decay (L2 penalty). Default: ``0.0`` .

            - float: The fixed weight decay value. Must be equal to or greater than 0.

            - int: The fixed weight decay value. Must be equal to or greater than 0. It will be converted to float.

            - Cell: Weight decay is dynamic.
              During training, the optimizer calls the instance of the Cell with step as the input to get the
              weight decay value of the current step.

        loss_scale (float): A floating point value for the loss scale. Should be greater than 0. In general, use the
            default value. Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
            `FixedLossScaleManager` is set to ``False`` does this value need to be the same as the `loss_scale` in
            `FixedLossScaleManager`. Refer to class :class:`mindspore.amp.FixedLossScaleManager` for more details.
            Default: ``1.0`` .

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

    Outputs:
        Tensor[bool], the value is True.

    Raises:
        TypeError: If `learning_rate` is not one of int, float, Tensor, Iterable, LearningRateSchedule.
        TypeError: If an element of `parameters` is neither a Parameter nor a dict.
        TypeError: If `beta1`, `beta2`, `eps` or `loss_scale` is not a float.
        TypeError: If `weight_decay` is neither float nor int.
        ValueError: If `loss_scale` or `eps` is less than or equal to 0.
        ValueError: If `beta1` or `beta2` is not in range (0.0, 1.0).
        ValueError: If `weight_decay` is less than 0.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore as ms
        >>> from mindspore import nn
        >>>
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> # 1) All parameters use the same learning rate and weight decay
        >>> optim = nn.AdaMax(params=net.trainable_params())
        >>>
        >>> # 2) Use parameter groups and set different values
        >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
        >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
        >>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'grad_centralization': True},
        ...                 {'params': no_conv_params, 'lr': 0.01},
        ...                 {'order_params': net.trainable_params()}]
        >>> optim = nn.AdaMax(group_params, learning_rate=0.1, weight_decay=0.0)
        >>> # The parameters in conv_params will use the default learning rate of 0.1, a weight decay of 0.01 and
        >>> # grad centralization set to True.
        >>> # The parameters in no_conv_params will use a learning rate of 0.01, the default weight decay of 0.0 and
        >>> # grad centralization set to False.
        >>> # The final order of the parameters followed by the optimizer is the value of 'order_params'.
        >>>
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> model = ms.Model(net, loss_fn=loss, optimizer=optim)
    """
    @opt_init_args_register
    def __init__(self, params, learning_rate=0.001, beta1=0.9, beta2=0.999, eps=1e-08,
                 weight_decay=0.0, loss_scale=1.0):
        super(AdaMax, self).__init__(learning_rate, params, weight_decay, loss_scale)
        _check_param_value(beta1, beta2, eps, self.cls_name)

        self.beta1 = Tensor(beta1, mstype.float32)
        self.beta2 = Tensor(beta2, mstype.float32)
        self.beta1_power = Parameter(initializer(1, [1], mstype.float32), name="beta1_power")
        self.eps = Tensor(eps, mstype.float32)
        self.moment1 = self._parameters.clone(prefix="moment1", init='zeros')
        self.moment2 = self._parameters.clone(prefix="moment2", init='zeros')

        self.opt = P.ApplyAdaMax()

    @jit
    def construct(self, gradients):
        """Apply one AdaMax update step to all parameters using the given gradients."""
        gradients = self.flatten_gradients(gradients)
        gradients = self.decay_weight(gradients)
        gradients = self.gradients_centralization(gradients)
        gradients = self.scale_grad(gradients)
        lr = self.get_lr()
        self.assignadd(self.global_step, self.global_step_increase_tensor)

        # Accumulate beta1^t, which ApplyAdaMax uses for bias correction of the 1st moment.
        self.beta1_power *= self.beta1

        if self.is_group_lr:
            success = self.map_(F.partial(_ada_max_opt, self.opt, self.beta1, self.beta2, self.beta1_power, self.eps),
                                lr, self._parameters, self.moment1, self.moment2, gradients)
        else:
            success = self.map_(F.partial(_ada_max_opt, self.opt, self.beta1, self.beta2, self.beta1_power,
                                          self.eps, lr), self._parameters, self.moment1, self.moment2, gradients)

        return success
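

# ---------------------------------------------------------------------------
# Illustrative sketch only, not part of the MindSpore API and never called by
# the optimizer above. It mirrors, with plain NumPy arrays, the element-wise
# update that the fused `P.ApplyAdaMax` primitive is expected to perform,
# following the formulas in the `AdaMax` docstring. The helper name
# `_adamax_reference_step` is hypothetical and introduced here purely for
# illustration.
# ---------------------------------------------------------------------------
def _adamax_reference_step(w, m, v, g, lr, beta1, beta2, beta1_power, eps):
    """Return (w, m, v) after one AdaMax step on NumPy arrays (illustration only)."""
    import numpy as np

    m = beta1 * m + (1.0 - beta1) * g                    # 1st moment estimate
    v = np.maximum(beta2 * v, np.abs(g))                 # infinity-norm 2nd moment
    w = w - (lr / (1.0 - beta1_power)) * m / (v + eps)   # bias-corrected parameter update
    return w, m, v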