# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""momentum"""
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore._checkparam import Validator
from .optimizer import Optimizer
from .optimizer import opt_init_args_register

# Type-dispatched graph function; `construct` maps it over every (gradient, weight, moment, ...) tuple.
_momentum_opt = C.MultitypeFuncGraph("momentum_opt")


@_momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, moment, ps_parameter, cache_enable):
    """
    Apply momentum optimizer to the weight parameter using Tensor.

    When `ps_parameter` is True and `cache_enable` is False, `opt` is not called: the learning rate, gradient and
    momentum (together with their shapes) are handed to `P.Push("ApplyMomentum", [])` and the weight is then
    refreshed through `P.Pull`. NOTE(review): this looks like the parameter-server path -- confirm against the
    Push/Pull operator documentation. In every other case `opt` is invoked directly as
    `opt(weight, moment, learning_rate, gradient, momentum)`.

    Args:
        opt (Function): The update operator, called only on the non-push/pull path.
        momentum (Tensor): The momentum hyperparameter.
        learning_rate (Tensor): The learning rate used for this parameter.
        gradient (Tensor): The gradient of `weight`.
        weight (Tensor): The parameter to update.
        moment (Tensor): The accumulated moment paired with `weight`; unused on the push/pull path.
        ps_parameter (bool): Selects the push/pull path when True (and `cache_enable` is False).
        cache_enable (bool): When True, forces the direct `opt` path even if `ps_parameter` is True.

    Returns:
        The constant True, attached to the update through `F.depend`.
    """
    if ps_parameter and not cache_enable:
        op_shape = P.Shape()
        _ps_pull = P.Pull()
        _ps_push = P.Push("ApplyMomentum", [])
        shapes = (op_shape(learning_rate), op_shape(gradient), op_shape(momentum))
        success = F.depend(True, _ps_pull(_ps_push((learning_rate, gradient, momentum), shapes), weight))
    else:
        success = F.depend(True, opt(weight, moment, learning_rate, gradient, momentum))
    return success


class Momentum(Optimizer):
    r"""
    Implements the Momentum algorithm.

    Refer to the paper on the importance of initialization and momentum in deep learning for more details.

    .. math::
            v_{t+1} = v_{t} \ast u + gradients

    If use_nesterov is True:

    .. math::
            p_{t+1} =  p_{t} - (grad \ast lr + v_{t+1} \ast u \ast lr)

    If use_nesterov is False:

    .. math::
            p_{t+1} = p_{t} - lr \ast v_{t+1}

    Here, grad, lr, p, v and u denote the gradients, learning_rate, params, moments, and momentum respectively.

    Note:
        When separating parameter groups, the weight decay in each group will be applied on the parameters if the
        weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
        on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.

        When separating parameter groups, if you want to centralize the gradient, set grad_centralization to True,
        but the gradient centralization can only be applied to the parameters of the convolution layer.
        If the parameters of the non convolution layer are set to True, an error will be reported.

        To improve parameter groups performance, the customized order of parameters can be supported.

    Args:
        params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
            the element in `params` must be class `Parameter`. When the `params` is a list of `dict`, the "params",
            "lr", "weight_decay", "order_params" and "grad_centralization" are the keys that can be parsed.

            - params: Required. The value must be a list of `Parameter`.

            - lr: Optional. If "lr" in the keys, the value of corresponding learning rate will be used.
              If not, the `learning_rate` in the API will be used.

            - weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay
              will be used. If not, the `weight_decay` in the API will be used.

            - order_params: Optional. If "order_params" in the keys, the value must be the order of parameters and
              the order will be followed in optimizer. There are no other keys in the `dict` and the parameters
              in the value of 'order_params' must be in one of the group parameters.

            - grad_centralization: Optional. The data type of "grad_centralization" is Bool. If "grad_centralization"
              is in the keys, the set value will be used. If not, the `grad_centralization` is False by default.
              This parameter only works on the convolution layer.

        learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
            When the learning_rate is an Iterable or a Tensor in a 1D dimension, use dynamic learning rate, then
            the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule,
            use dynamic learning rate, the i-th learning rate will be calculated during the process of training
            according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor in a zero
            dimension, use fixed learning rate. Other cases are not supported. The float learning rate must be
            equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float.
        momentum (float): Hyperparameter of type float, means momentum for the moving average.
            It must be at least 0.0.
        weight_decay (int, float): Weight decay (L2 penalty). It must be equal to or greater than 0.0. Default: 0.0.
        loss_scale (float): A floating point value for the loss scale. It must be greater than 0.0. In general, use the
            default value. Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
            `FixedLossScaleManager` is set to False, then this value needs to be the same as the `loss_scale` in
            `FixedLossScaleManager`. Refer to class :class:`mindspore.FixedLossScaleManager` for more details.
            Default: 1.0.
        use_nesterov (bool): Enable Nesterov momentum. Default: False.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

    Outputs:
        tuple[bool]. All elements are True.

    Raises:
        TypeError: If `learning_rate` is not one of int, float, Tensor, Iterable, LearningRateSchedule.
        TypeError: If element of `params` is neither Parameter nor dict.
        TypeError: If `loss_scale` or `momentum` is not a float.
        TypeError: If `weight_decay` is neither float nor int.
        TypeError: If `use_nesterov` is not a bool.
        ValueError: If `loss_scale` is less than or equal to 0.
        ValueError: If `weight_decay` or `momentum` is less than 0.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> net = Net()
        >>> #1) All parameters use the same learning rate and weight decay
        >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>>
        >>> #2) Use parameter groups and set different values
        >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
        >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
        >>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'grad_centralization':True},
        ...                 {'params': no_conv_params, 'lr': 0.01},
        ...                 {'order_params': net.trainable_params()}]
        >>> optim = nn.Momentum(group_params, learning_rate=0.1, momentum=0.9, weight_decay=0.0)
        >>> # The conv_params's parameters will use a learning rate of default value 0.1 and a weight decay of 0.01 and
        >>> # grad centralization of True.
        >>> # The no_conv_params's parameters will use a learning rate of 0.01 and a weight decay of default value 0.0
        >>> # and grad centralization of False.
        >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
        >>>
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
    """
    @opt_init_args_register
    def __init__(self, params, learning_rate, momentum, weight_decay=0.0, loss_scale=1.0, use_nesterov=False):
        """
        Validate the hyperparameters and create the optimizer state.

        `learning_rate`, `params`, `weight_decay` and `loss_scale` are forwarded to the `Optimizer` base class.
        `momentum` is stored as a float32 `Parameter`, one zero-initialised moment is cloned per parameter
        (prefix "moments"), and the `P.ApplyMomentum` operator is built with the validated `use_nesterov` flag.

        Raises:
            TypeError: If `momentum` is not a float.
            ValueError: If `momentum` is less than 0.0.
        """
        super(Momentum, self).__init__(learning_rate, params, weight_decay, loss_scale)
        # The type check rejects every non-float first, so the range check below needs no isinstance guard.
        Validator.check_value_type("momentum", momentum, [float], self.cls_name)
        if momentum < 0.0:
            raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
        self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
        self.params = self.parameters
        self.use_nesterov = Validator.check_bool(use_nesterov)
        self.moments = self.params.clone(prefix="moments", init='zeros')
        self.opt = P.ApplyMomentum(use_nesterov=self.use_nesterov)

    def construct(self, gradients):
        """
        Run one optimization step over all parameters.

        The gradients go through `decay_weight`, `gradients_centralization` and `scale_grad` (in that order)
        before `_momentum_opt` is mapped over the gradients, parameters and moments.

        Args:
            gradients (tuple[Tensor]): The gradients of `params`.

        Returns:
            The result of `hyper_map_reverse`; documented on the class as tuple[bool] with all elements True.
        """
        params = self.params
        moments = self.moments
        gradients = self.decay_weight(gradients)
        gradients = self.gradients_centralization(gradients)
        gradients = self.scale_grad(gradients)
        lr = self.get_lr()
        if self.is_group_lr:
            # Group learning rates: `lr` is mapped element-wise alongside the gradients.
            success = self.hyper_map_reverse(F.partial(_momentum_opt, self.opt, self.momentum),
                                             lr, gradients, params, moments, self.ps_parameters, self.cache_enable)
        else:
            # Single learning rate: bind it once into the partial so every parameter shares it.
            success = self.hyper_map_reverse(F.partial(_momentum_opt, self.opt, self.momentum, lr),
                                             gradients, params, moments, self.ps_parameters, self.cache_enable)
        return success