# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""lazy adam"""
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from .optimizer import Optimizer
from .optimizer import opt_init_args_register

_lazy_adam_opt = C.MultitypeFuncGraph("lazy_adam_opt")


@_lazy_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
                         "Tensor", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool",
                         "Bool")
def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power, beta2_power,
                         beta1, beta2, eps, lr, gradient, params, m, v, ps_parameter, cache_enable):
    """Apply sparse lazy adam optimizer to the weight parameter when the gradient is sparse."""
    success = True
    indices = gradient.indices
    values = gradient.values
    if ps_parameter and not cache_enable:
        # Parameter-server path: push the hyper-parameters and sparse gradient to the server
        # and pull back the updated weight.
        op_shape = P.Shape()
        shapes = (op_shape(params), op_shape(m), op_shape(v),
                  op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1),
                  op_shape(beta2), op_shape(eps), op_shape(values), op_shape(indices))
        success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2,
                                               eps, values, indices), shapes), params))
        return success

    if not target:
        # `target` receives `self._is_device` from construct; when updating on the host,
        # use the fused sparse kernel directly.
        success = F.depend(success, sparse_opt(params, m, v, beta1_power, beta2_power, lr, beta1, beta2,
                                               eps, values, indices))
    else:
        # On device, compose the lazy update from Gather/ScatterAdd/ScatterUpdate so that
        # only the rows addressed by `indices` are touched.
        op_gather = P.Gather()
        op_sqrt = P.Sqrt()
        scatter_add = P.ScatterAdd(use_locking)
        scatter_update = P.ScatterUpdate(use_locking)

        m_slice = op_gather(m, indices, 0)
        v_slice = op_gather(v, indices, 0)

        next_m = m_slice * beta1 + values * (1 - beta1)
        next_v = v_slice * beta2 + values * values * (1 - beta2)

        lr_t = lr * op_sqrt(1 - beta2_power) / (1 - beta1_power)

        if use_nesterov:
            m_temp = beta1 * next_m + values * (1 - beta1)
            param_update = m_temp / (op_sqrt(next_v) + eps)
        else:
            param_update = next_m / (op_sqrt(next_v) + eps)

        success = F.depend(success, scatter_add(params, indices, - lr_t * param_update))
        success = F.depend(success, scatter_update(m, indices, next_m))
        success = F.depend(success, scatter_update(v, indices, next_v))

    return success
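
# Reference sketch (comments only): without Nesterov, the gather/scatter branch above amounts to
# the following per-row arithmetic, where `w`, `grad_rows` and `idx` are placeholder names used
# purely for illustration and are not defined in this module:
#
#     m[idx] = beta1 * m[idx] + (1 - beta1) * grad_rows
#     v[idx] = beta2 * v[idx] + (1 - beta2) * grad_rows ** 2
#     lr_t = lr * sqrt(1 - beta2_power) / (1 - beta1_power)
#     w[idx] -= lr_t * m[idx] / (sqrt(v[idx]) + eps)
#
# Rows that never appear in `idx` keep stale moment estimates, which is what makes the update
# "lazy" and slightly different from dense Adam.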


@_lazy_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
                         "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
def _run_opt_with_one_number(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power, beta2_power,
                             beta1, beta2, eps, lr, gradient, params, moment1, moment2, ps_parameter, cache_enable):
    """Apply lazy adam optimizer to the weight parameter using Tensor."""
    success = True
    if ps_parameter and not cache_enable:
        op_shape = P.Shape()
        success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient),
                                              (op_shape(params), op_shape(moment1), op_shape(moment2))), params))
    else:
        success = F.depend(success, opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
                                        eps, gradient))
    return success


def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
    """Check the type of inputs."""
    validator.check_value_type("beta1", beta1, [float], prim_name)
    validator.check_value_type("beta2", beta2, [float], prim_name)
    validator.check_value_type("eps", eps, [float], prim_name)
    validator.check_value_type("weight_decay", weight_decay, [float], prim_name)
    validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name)
    validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name)
    validator.check_positive_float(eps, "eps", prim_name)
    validator.check_non_negative_float(weight_decay, "weight_decay", prim_name)
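
# For example, the LazyAdam defaults (beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.0) pass
# these checks, while boundary or invalid values (e.g. beta1=1.0, eps=0.0 or a negative
# weight_decay) are rejected by the validators above.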


class LazyAdam(Optimizer):
    r"""
    This optimizer applies a lazy Adam algorithm when the gradient is sparse.

    The original Adam algorithm is proposed in
    `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_.

    The updating formulas are as follows,

    .. math::
        \begin{array}{ll} \\
            m_{t+1} = \beta_1 * m_{t} + (1 - \beta_1) * g \\
            v_{t+1} = \beta_2 * v_{t} + (1 - \beta_2) * g * g \\
            l = \alpha * \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \\
            w_{t+1} = w_{t} - l * \frac{m_{t+1}}{\sqrt{v_{t+1}} + \epsilon}
        \end{array}

    :math:`m` represents the 1st moment vector `moment1`, :math:`v` represents the 2nd moment vector `moment2`,
    :math:`g` represents `gradients`, :math:`l` represents the scaling factor, :math:`\beta_1, \beta_2` represent
    `beta1` and `beta2`, :math:`t` represents the updating step while :math:`\beta_1^t` and :math:`\beta_2^t`
    represent `beta1_power` and `beta2_power`, :math:`\alpha` represents `learning_rate`, :math:`w` represents
    `params`, :math:`\epsilon` represents `eps`.

    Note:
        When separating parameter groups, the weight decay in each group will be applied on the parameters if the
        weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
        on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.

        When separating parameter groups, if you want to centralize the gradient, set grad_centralization to True,
        but the gradient centralization can only be applied to the parameters of the convolution layer.
        If it is set to True for parameters of a non-convolution layer, an error will be reported.

        To improve the performance of parameter groups, a customized order of parameters is supported.

        The sparse strategy is applied while the SparseGatherV2 operator is used for the forward network.
        Note that the sparse behavior is not equivalent to the original Adam algorithm, as only the parameters at
        the current indices are updated. The sparse feature is under continuous development.
        To execute the sparse strategy on the host, set the target to the CPU.

    Args:
        params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
            the element in `params` must be class `Parameter`. When the `params` is a list of `dict`, the "params",
            "lr" and "weight_decay" are the keys that can be parsed.

            - params: Required. The value must be a list of `Parameter`.

            - lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used.
              If not, the `learning_rate` in the API will be used.

            - weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay
              will be used. If not, the `weight_decay` in the API will be used.

            - order_params: Optional. If "order_params" is in the keys, the value must be the order of parameters and
              this order will be followed in the optimizer. There should be no other keys in this `dict`, and the
              parameters in the value of 'order_params' must be in one of the group parameters.

            - grad_centralization: Optional. The data type of "grad_centralization" is Bool. If "grad_centralization"
              is in the keys, the set value will be used. If not, `grad_centralization` is False by default.
              This parameter only works on the convolution layer.

        learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
            When the learning_rate is an Iterable or a Tensor with one dimension, a dynamic learning rate is used, and
            the i-th step will take the i-th value as the learning rate. When the learning_rate is a
            LearningRateSchedule, a dynamic learning rate is used, and the i-th learning rate will be calculated
            during training according to the formula of LearningRateSchedule. When the learning_rate is a float or a
            Tensor with zero dimensions, a fixed learning rate is used. Other cases are not supported. The float
            learning rate must be equal to or greater than 0. If the type of `learning_rate` is int, it will be
            converted to float. Default: 1e-3.
        beta1 (float): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
            Default: 0.9.
        beta2 (float): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
            Default: 0.999.
        eps (float): Term added to the denominator to improve numerical stability. Should be greater than 0.
            Default: 1e-8.
        use_locking (bool): Whether to enable a lock to protect variable tensors from being updated.
            If true, updates of the var, m, and v tensors will be protected by a lock.
            If false, the result is unpredictable. Default: False.
        use_nesterov (bool): Whether to use the Nesterov Accelerated Gradient (NAG) algorithm to update the gradients.
            If true, update the gradients using NAG.
            If false, update the gradients without using NAG. Default: False.
        weight_decay (Union[float, int]): Weight decay (L2 penalty). Default: 0.0.
        loss_scale (float): A floating point value for the loss scale. Should be equal to or greater than 1. In
            general, use the default value. Only when `FixedLossScaleManager` is used for training and the
            `drop_overflow_update` in `FixedLossScaleManager` is set to False, this value needs to be the same as the
            `loss_scale` in `FixedLossScaleManager`. Refer to class :class:`mindspore.FixedLossScaleManager` for more
            details. Default: 1.0.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

    Outputs:
        Tensor[bool], the value is True.

    Raises:
        TypeError: If `learning_rate` is not one of int, float, Tensor, Iterable, LearningRateSchedule.
        TypeError: If element of `parameters` is neither Parameter nor dict.
        TypeError: If `beta1`, `beta2`, `eps` or `loss_scale` is not a float.
        TypeError: If `weight_decay` is neither float nor int.
        TypeError: If `use_locking` or `use_nesterov` is not a bool.
        ValueError: If `loss_scale` or `eps` is less than or equal to 0.
        ValueError: If `beta1` or `beta2` is not in range (0.0, 1.0).
        ValueError: If `weight_decay` is less than 0.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> net = Net()
        >>> #1) All parameters use the same learning rate and weight decay
        >>> optim = nn.LazyAdam(params=net.trainable_params())
        >>>
        >>> #2) Use parameter groups and set different values
        >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
        >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
        >>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'grad_centralization':True},
        ...                 {'params': no_conv_params, 'lr': 0.01},
        ...                 {'order_params': net.trainable_params()}]
        >>> optim = nn.LazyAdam(group_params, learning_rate=0.1, weight_decay=0.0)
        >>> # The conv_params's parameters will use the default learning rate of 0.1, a weight decay of 0.01 and
        >>> # grad centralization of True.
        >>> # The no_conv_params's parameters will use a learning rate of 0.01, the default weight decay of 0.0 and
        >>> # grad centralization of False.
        >>> # The final order of parameters that the optimizer follows is the value of 'order_params'.
        >>>
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> model = Model(net, loss_fn=loss, optimizer=optim)
    """

    @opt_init_args_register
    def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False,
                 use_nesterov=False, weight_decay=0.0, loss_scale=1.0):
        super(LazyAdam, self).__init__(learning_rate, params, weight_decay, loss_scale)
        _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
        validator.check_value_type("use_locking", use_locking, [bool], self.cls_name)
        validator.check_value_type("use_nesterov", use_nesterov, [bool], self.cls_name)

        self.beta1 = Tensor(beta1, mstype.float32)
        self.beta2 = Tensor(beta2, mstype.float32)
        self.beta1_power = Parameter(initializer(1, [1], mstype.float32), name="beta1_power")
        self.beta2_power = Parameter(initializer(1, [1], mstype.float32), name="beta2_power")
        self.eps = Tensor(eps, mstype.float32)
        self.use_nesterov = use_nesterov
        self.use_locking = use_locking
        self._is_device = True
        self.moment1 = self.parameters.clone(prefix="moment1", init='zeros')
        self.moment2 = self.parameters.clone(prefix="moment2", init='zeros')
        self.opt = P.Adam(use_locking, use_nesterov)
        self.sparse_opt = P.FusedSparseLazyAdam(use_locking, use_nesterov)
        self.sparse_opt.add_prim_attr("primitive_target", "CPU")
        self._ps_pull = P.Pull()
        self._ps_push = P.Push("Adam", [0, 1, 2])
        self._ps_push.add_prim_attr("use_nesterov", use_nesterov)

    def construct(self, gradients):
        gradients = self.decay_weight(gradients)
        gradients = self.gradients_centralization(gradients)
        gradients = self.scale_grad(gradients)
        gradients = self._grad_sparse_indices_deduplicate(gradients)
        lr = self.get_lr()

        # Accumulate the powers of beta1/beta2 used for bias correction.
        self.beta1_power = self.beta1_power * self.beta1
        self.beta2_power = self.beta2_power * self.beta2

        if self.is_group_lr:
            success = self.map_reverse(F.partial(_lazy_adam_opt, self.opt, self.sparse_opt, self._ps_push,
                                                 self._ps_pull, self.use_locking, self.use_nesterov, self._is_device,
                                                 self.beta1_power, self.beta2_power, self.beta1, self.beta2, self.eps),
                                       lr, gradients, self.parameters, self.moment1, self.moment2, self.ps_parameters,
                                       self.cache_enable)
        else:
            success = self.map_reverse(F.partial(_lazy_adam_opt, self.opt, self.sparse_opt, self._ps_push,
                                                 self._ps_pull, self.use_locking, self.use_nesterov, self._is_device,
                                                 self.beta1_power, self.beta2_power, self.beta1, self.beta2, self.eps,
                                                 lr),
                                       gradients, self.parameters, self.moment1, self.moment2, self.ps_parameters,
                                       self.cache_enable)
        return success

    @Optimizer.target.setter
    def target(self, value):
        """
        If the input value is set to "CPU", the parameters will be updated on the host using the Fused
        optimizer operation.
        """
        self._set_base_target(value)
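
# A minimal usage sketch (assumes a user-defined `Net` with trainable parameters and the usual
# `mindspore.nn` / `mindspore.train.Model` imports; shown only for illustration):
#
#     net = Net()
#     optim = LazyAdam(net.trainable_params(), learning_rate=1e-3)
#     optim.target = "CPU"  # run the sparse update on the host via the fused kernel
#     model = Model(net, loss_fn=nn.SoftmaxCrossEntropyWithLogits(), optimizer=optim)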