# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""lazy adam"""
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from .optimizer import Optimizer
from .optimizer import opt_init_args_register

_lazy_adam_opt = C.MultitypeFuncGraph("lazy_adam_opt")


@_lazy_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
                         "Tensor", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool",
                         "Bool")
def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power, beta2_power,
                         beta1, beta2, eps, lr, gradient, params, m, v, ps_parameter, cache_enable):
    """Apply the sparse lazy Adam optimizer to the weight parameter when the gradient is sparse."""
    success = True
    indices = gradient.indices
    values = gradient.values
    if ps_parameter and not cache_enable:
        op_shape = P.Shape()
        shapes = (op_shape(params), op_shape(m), op_shape(v),
                  op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1),
                  op_shape(beta2), op_shape(eps), op_shape(values), op_shape(indices))
        success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2,
                                               eps, values, indices), shapes), params))
        return success

    if not target:
        success = F.depend(success, sparse_opt(params, m, v, beta1_power, beta2_power, lr, beta1, beta2,
                                               eps, values, indices))
    else:
        op_gather = P.Gather()
        op_sqrt = P.Sqrt()
        scatter_add = P.ScatterAdd(use_locking)
        scatter_update = P.ScatterUpdate(use_locking)

        m_slice = op_gather(m, indices, 0)
        v_slice = op_gather(v, indices, 0)

        next_m = m_slice * beta1 + values * (1 - beta1)
        next_v = v_slice * beta2 + values * values * (1 - beta2)

        lr_t = lr * op_sqrt(1 - beta2_power) / (1 - beta1_power)

        if use_nesterov:
            m_temp = beta1 * next_m + values * (1 - beta1)
            param_update = m_temp / (op_sqrt(next_v) + eps)
        else:
            param_update = next_m / (op_sqrt(next_v) + eps)

        success = F.depend(success, scatter_add(params, indices, -lr_t * param_update))
        success = F.depend(success, scatter_update(m, indices, next_m))
        success = F.depend(success, scatter_update(v, indices, next_v))

    return success


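# The helper below is an illustrative sketch only and is not used by the optimizer. It mirrors,
# in plain NumPy, the host branch of `_run_opt_with_sparse` above (the `else` branch that gathers
# rows, updates them and scatters them back), assuming `indices` has already been deduplicated,
# as `construct` ensures via `_grad_sparse_indices_deduplicate`. The function name is hypothetical.
# It shows why the behavior is "lazy": only the rows addressed by `indices` are touched, so the
# moments of untouched rows do not decay, unlike in the original dense Adam algorithm.
def _lazy_adam_sparse_reference(param, m, v, indices, values, lr, beta1, beta2, eps,
                                beta1_power, beta2_power, use_nesterov=False):
    """NumPy reference sketch of the row-wise lazy Adam update (documentation only)."""
    import numpy as np  # local import so the sketch adds no module-level dependency

    # Gather only the rows that received a gradient.
    m_slice = m[indices]
    v_slice = v[indices]

    # Exponential moving averages of the gradient and the squared gradient.
    next_m = m_slice * beta1 + values * (1 - beta1)
    next_v = v_slice * beta2 + values * values * (1 - beta2)

    # Bias-corrected step size, matching `lr_t` above.
    lr_t = lr * np.sqrt(1 - beta2_power) / (1 - beta1_power)

    if use_nesterov:
        update = (beta1 * next_m + values * (1 - beta1)) / (np.sqrt(next_v) + eps)
    else:
        update = next_m / (np.sqrt(next_v) + eps)

    # Scatter the results back into the full parameter and moment arrays.
    param[indices] -= lr_t * update
    m[indices] = next_m
    v[indices] = next_v
    return param, m, v

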
@_lazy_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
                         "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
def _run_opt_with_one_number(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power, beta2_power,
                             beta1, beta2, eps, lr, gradient, params, moment1, moment2, ps_parameter, cache_enable):
    """Apply the lazy Adam optimizer to the weight parameter when the gradient is a dense Tensor."""
    success = True
    if ps_parameter and not cache_enable:
        op_shape = P.Shape()
        success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient),
                                              (op_shape(params), op_shape(moment1), op_shape(moment2))), params))
    else:
        success = F.depend(success, opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
                                        eps, gradient))
    return success


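# Like the sparse sketch above, the helper below is illustrative only and is not used by the
# optimizer. It is a plain NumPy sketch of the standard dense Adam step that the `opt` branch
# (`P.Adam`) delegates to, written with the bias correction folded into the step size; the exact
# fused kernel may differ in details. The function name is hypothetical. In contrast to the lazy
# sparse path, every row of `m`, `v` and `param` is updated on every step.
def _adam_dense_reference(param, m, v, gradient, lr, beta1, beta2, eps,
                          beta1_power, beta2_power, use_nesterov=False):
    """NumPy reference sketch of the dense Adam update (documentation only)."""
    import numpy as np  # local import so the sketch adds no module-level dependency

    # Exponential moving averages over the full (dense) gradient.
    m[:] = m * beta1 + gradient * (1 - beta1)
    v[:] = v * beta2 + gradient * gradient * (1 - beta2)

    # Bias-corrected step size.
    lr_t = lr * np.sqrt(1 - beta2_power) / (1 - beta1_power)

    if use_nesterov:
        update = (beta1 * m + (1 - beta1) * gradient) / (np.sqrt(v) + eps)
    else:
        update = m / (np.sqrt(v) + eps)

    param -= lr_t * update
    return param, m, v

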
def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
    """Check the types and ranges of the inputs."""
    validator.check_value_type("beta1", beta1, [float], prim_name)
    validator.check_value_type("beta2", beta2, [float], prim_name)
    validator.check_value_type("eps", eps, [float], prim_name)
    validator.check_value_type("weight_decay", weight_decay, [float], prim_name)
    validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name)
    validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name)
    validator.check_positive_float(eps, "eps", prim_name)
    validator.check_non_negative_float(weight_decay, "weight_decay", prim_name)


class LazyAdam(Optimizer):
    r"""
    This optimizer applies the lazy Adam algorithm when the gradient is sparse.

    The original Adam algorithm is proposed in
    `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_.

    The updating formulas are as follows,

    .. math::
        \begin{array}{ll} \\
            m_{t+1} = \beta_1 * m_{t} + (1 - \beta_1) * g \\
            v_{t+1} = \beta_2 * v_{t} + (1 - \beta_2) * g * g \\
            l = \alpha * \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \\
            w_{t+1} = w_{t} - l * \frac{m_{t+1}}{\sqrt{v_{t+1}} + \epsilon}
        \end{array}

    :math:`m` represents the 1st moment vector `moment1`, :math:`v` represents the 2nd moment vector `moment2`,
    :math:`g` represents `gradients`, :math:`l` represents the scaling factor, :math:`\beta_1, \beta_2` represent
    `beta1` and `beta2`, :math:`t` represents the updating step while :math:`\beta_1^t` and :math:`\beta_2^t` represent
    `beta1_power` and `beta2_power`, :math:`\alpha` represents `learning_rate`, :math:`w` represents `params`, and
    :math:`\epsilon` represents `eps`.

    Note:
        When separating parameter groups, the weight decay in each group will be applied to the parameters if the
        weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
        to the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.

        When separating parameter groups, set `grad_centralization` to True if you want to centralize the gradient.
        Note that gradient centralization can only be applied to the parameters of convolution layers; if it is set
        to True for the parameters of non-convolution layers, an error will be reported.

        To improve performance when separating parameter groups, a customized order of parameters is supported.

        The sparse strategy is applied when the SparseGatherV2 operator is used in the forward network. Note that
        the sparse behavior is not equivalent to the original Adam algorithm, as only the parameters at the current
        indices are updated. The sparse feature is under continuous development. To run the sparse strategy on the
        host, set the target to the CPU.

    Args:
        params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
            the elements in `params` must be of class `Parameter`. When the `params` is a list of `dict`, the "params",
            "lr", "weight_decay", "grad_centralization" and "order_params" keys can be parsed.

            - params: Required. The value must be a list of `Parameter`.

            - lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used.
              If not, the `learning_rate` in the API will be used.

            - weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay
              will be used. If not, the `weight_decay` in the API will be used.

            - order_params: Optional. If "order_params" is in the keys, the value must be the order of parameters and
              this order will be followed in the optimizer. There can be no other keys in this `dict`, and the
              parameters in the value of 'order_params' must be in one of the group parameters.

            - grad_centralization: Optional. The data type of "grad_centralization" is Bool. If "grad_centralization"
              is in the keys, the set value will be used. If not, `grad_centralization` is False by default.
              This parameter only works on convolution layers.

        learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
            When the learning_rate is an Iterable or a one-dimensional Tensor, a dynamic learning rate is used, and
            the i-th step will take the i-th value as the learning rate. When the learning_rate is a
            LearningRateSchedule, a dynamic learning rate is also used, and the i-th learning rate will be calculated
            during training according to the formula of the LearningRateSchedule. When the learning_rate is a float or
            a zero-dimensional Tensor, a fixed learning rate is used. Other cases are not supported. The float
            learning rate must be equal to or greater than 0. If the type of `learning_rate` is int, it will be
            converted to float. Default: 1e-3.
        beta1 (float): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
                       Default: 0.9.
        beta2 (float): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
                       Default: 0.999.
        eps (float): Term added to the denominator to improve numerical stability. Should be greater than 0. Default:
                     1e-8.
        use_locking (bool): Whether to enable a lock to protect variable tensors from being updated.
            If true, updates of the var, m, and v tensors will be protected by a lock.
            If false, the result is unpredictable. Default: False.
        use_nesterov (bool): Whether to use the Nesterov Accelerated Gradient (NAG) algorithm to update the gradients.
            If true, update the gradients using NAG.
            If false, update the gradients without using NAG. Default: False.
        weight_decay (Union[float, int]): Weight decay (L2 penalty). Default: 0.0.
        loss_scale (float): A floating point value for the loss scale. Should be equal to or greater than 1. In
            general, use the default value. Only when `FixedLossScaleManager` is used for training and the
            `drop_overflow_update` in `FixedLossScaleManager` is set to False, does this value need to be the same as
            the `loss_scale` in `FixedLossScaleManager`. Refer to class :class:`mindspore.FixedLossScaleManager` for
            more details. Default: 1.0.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

    Outputs:
        Tensor[bool], the value is True.

    Raises:
        TypeError: If `learning_rate` is not one of int, float, Tensor, Iterable, LearningRateSchedule.
        TypeError: If an element of `parameters` is neither a Parameter nor a dict.
        TypeError: If `beta1`, `beta2`, `eps` or `loss_scale` is not a float.
        TypeError: If `weight_decay` is neither float nor int.
        TypeError: If `use_locking` or `use_nesterov` is not a bool.
        ValueError: If `loss_scale` or `eps` is less than or equal to 0.
        ValueError: If `beta1` or `beta2` is not in range (0.0, 1.0).
        ValueError: If `weight_decay` is less than 0.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> net = Net()
        >>> #1) All parameters use the same learning rate and weight decay
        >>> optim = nn.LazyAdam(params=net.trainable_params())
        >>>
        >>> #2) Use parameter groups and set different values
        >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
        >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
        >>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'grad_centralization': True},
        ...                 {'params': no_conv_params, 'lr': 0.01},
        ...                 {'order_params': net.trainable_params()}]
        >>> optim = nn.LazyAdam(group_params, learning_rate=0.1, weight_decay=0.0)
        >>> # The conv_params group will use the default learning rate of 0.1, a weight decay of 0.01 and grad
        >>> # centralization of True.
        >>> # The no_conv_params group will use a learning rate of 0.01, the default weight decay of 0.0 and grad
        >>> # centralization of False.
        >>> # The final order of parameters that the optimizer follows is the value of 'order_params'.
        >>>
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> model = Model(net, loss_fn=loss, optimizer=optim)
    """

    @opt_init_args_register
    def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False,
                 use_nesterov=False, weight_decay=0.0, loss_scale=1.0):
        super(LazyAdam, self).__init__(learning_rate, params, weight_decay, loss_scale)
        _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
        validator.check_value_type("use_locking", use_locking, [bool], self.cls_name)
        validator.check_value_type("use_nesterov", use_nesterov, [bool], self.cls_name)

        self.beta1 = Tensor(beta1, mstype.float32)
        self.beta2 = Tensor(beta2, mstype.float32)
        self.beta1_power = Parameter(initializer(1, [1], mstype.float32), name="beta1_power")
        self.beta2_power = Parameter(initializer(1, [1], mstype.float32), name="beta2_power")
        self.eps = Tensor(eps, mstype.float32)
        self.use_nesterov = use_nesterov
        self.use_locking = use_locking
        self._is_device = True
        self.moment1 = self.parameters.clone(prefix="moment1", init='zeros')
        self.moment2 = self.parameters.clone(prefix="moment2", init='zeros')
        self.opt = P.Adam(use_locking, use_nesterov)
        self.sparse_opt = P.FusedSparseLazyAdam(use_locking, use_nesterov)
        self.sparse_opt.add_prim_attr("primitive_target", "CPU")
        self._ps_pull = P.Pull()
        self._ps_push = P.Push("Adam", [0, 1, 2])
        self._ps_push.add_prim_attr("use_nesterov", use_nesterov)

    def construct(self, gradients):
        gradients = self.decay_weight(gradients)
        gradients = self.gradients_centralization(gradients)
        gradients = self.scale_grad(gradients)
        gradients = self._grad_sparse_indices_deduplicate(gradients)
        lr = self.get_lr()

        self.beta1_power = self.beta1_power * self.beta1
        self.beta2_power = self.beta2_power * self.beta2

        if self.is_group_lr:
            success = self.map_reverse(F.partial(_lazy_adam_opt, self.opt, self.sparse_opt, self._ps_push,
                                                 self._ps_pull, self.use_locking, self.use_nesterov, self._is_device,
                                                 self.beta1_power, self.beta2_power, self.beta1, self.beta2, self.eps),
                                       lr, gradients, self.parameters, self.moment1, self.moment2, self.ps_parameters,
                                       self.cache_enable)
        else:
            success = self.map_reverse(F.partial(_lazy_adam_opt, self.opt, self.sparse_opt, self._ps_push,
                                                 self._ps_pull, self.use_locking, self.use_nesterov, self._is_device,
                                                 self.beta1_power, self.beta2_power, self.beta1, self.beta2, self.eps,
                                                 lr),
                                       gradients, self.parameters, self.moment1, self.moment2, self.ps_parameters,
                                       self.cache_enable)
        return success

    @Optimizer.target.setter
    def target(self, value):
        """
        If the input value is set to "CPU", the parameters will be updated on the host using the fused
        optimizer operation.
        """
        self._set_base_target(value)