# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""FTRL"""
16from mindspore.ops import functional as F, composite as C, operations as P
17from mindspore.common import Tensor
18import mindspore.common.dtype as mstype
19from mindspore._checkparam import Validator as validator
20from mindspore._checkparam import Rel
21from .optimizer import Optimizer, _apply_decay, _grad_scale
22from .optimizer import opt_init_args_register
23
24_ftrl_opt = C.MultitypeFuncGraph("ftrl_opt")
25
26
@_ftrl_opt.register("Function", "Function", "Function", "Function", "Number", "Number", "Number", "Tensor", "Tensor",
                    "RowTensor", "Tensor", "Tensor", "Bool", "Bool")
def _tensor_run_opt_with_sparse(opt, spars_opt, push, pull, l1, l2, lr_power, learning_rate, linear,
                                gradient, weight, moment, ps_parameter, cache_enable):
    """Apply the sparse FTRL optimizer to the weight parameter when the gradient is sparse."""
    success = True
    indices = gradient.indices
    values = gradient.values
    if ps_parameter and not cache_enable:
        # Parameter-server mode: push the sparse gradient (values, indices) together with the
        # operand shapes to the server, then pull the updated weight back.
        op_shape = P.Shape()
        shapes = (op_shape(weight), op_shape(moment), op_shape(linear), op_shape(values), op_shape(indices))
        success = F.depend(success, pull(push((values, indices), shapes), weight))
    else:
        # Local update with the sparse FTRL primitive.
        success = F.depend(success, spars_opt(weight, moment, linear, values, indices))
    return success


@_ftrl_opt.register("Function", "Function", "Function", "Function", "Number", "Number", "Number", "Tensor", "Tensor",
                    "Tensor", "Tensor", "Tensor", "Bool", "Bool")
def _tensor_run_opt(opt, spars_opt, push, pull, l1, l2, lr_power, learning_rate, linear,
                    gradient, weight, moment, ps_parameter, cache_enable):
    """Apply the FTRL optimizer to the weight parameter when the gradient is dense."""
    success = True
    if ps_parameter and not cache_enable:
        # Parameter-server mode: push the dense gradient together with the hyper-parameters to the
        # server, then pull the updated weight back.
        op_shape = P.Shape()
        success = F.depend(success, pull(push((gradient, learning_rate, l1, l2, lr_power),
                                              (op_shape(weight), op_shape(moment), op_shape(linear))), weight))
    else:
        # Local update with the ApplyFtrl primitive.
        success = F.depend(success, opt(weight, moment, linear, gradient, learning_rate, l1, l2, lr_power))
    return success


def _check_param(initial_accum, lr_power, l1, l2, use_locking, prim_name=None):
    """Check param."""
    validator.check_value_type("initial_accum", initial_accum, [float], prim_name)
    validator.check_number("initial_accum", initial_accum, 0.0, Rel.GE, prim_name)

    validator.check_value_type("lr_power", lr_power, [float], prim_name)
    validator.check_number("lr_power", lr_power, 0.0, Rel.LE, prim_name)

    validator.check_value_type("l1", l1, [float], prim_name)
    validator.check_number("l1", l1, 0.0, Rel.GE, prim_name)

    validator.check_value_type("l2", l2, [float], prim_name)
    validator.check_number("l2", l2, 0.0, Rel.GE, prim_name)

    validator.check_value_type("use_locking", use_locking, [bool], prim_name)


class FTRL(Optimizer):
    r"""
    Implements the FTRL algorithm with the ApplyFtrl operator.

    FTRL is an online convex optimization algorithm that adaptively chooses its regularization function
    based on the loss functions. Refer to the paper `Adaptive Bound Optimization for Online Convex Optimization
    <https://arxiv.org/abs/1002.4908>`_ for the algorithm, and to the paper `Ad Click Prediction: a View from
    the Trenches <https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf>`_ for engineering details.

    The updating formulas are as follows,

    .. math::

        \begin{array}{ll} \\
            m_{t+1} = m_{t} + g^2 \\
            u_{t+1} = u_{t} + g  - \frac{m_{t+1}^\text{-p} - m_{t}^\text{-p}}{\alpha } * \omega_{t} \\
            \omega_{t+1} =
            \begin{cases}
                \frac{(sign(u_{t+1}) * l1 - u_{t+1})}{\frac{m_{t+1}^\text{-p}}{\alpha } + 2 * l2 }
                    & \text{ if } |u_{t+1}| > l1 \\
                0.0
                    & \text{ otherwise }
            \end{cases}\\
        \end{array}

    :math:`m` represents `accum`, :math:`g` represents `grads`, :math:`t` represents the updating step,
    :math:`u` represents `linear`, :math:`p` represents `lr_power`, :math:`\alpha` represents `learning_rate`,
    and :math:`\omega` represents `params`.
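
    A minimal scalar sketch of one update step, where ``m``, ``g``, ``u``, ``w``, ``lr`` and ``p`` stand for
    :math:`m`, :math:`g`, :math:`u`, :math:`\omega`, :math:`\alpha` and :math:`p` above (illustrative only; the
    optimizer actually executes the fused `ApplyFtrl` or `SparseApplyFtrl` operator)::

        import math

        m_new = m + g * g
        sigma = (m_new ** -p - m ** -p) / lr      # coefficient applied to the old weight
        u_new = u + g - sigma * w
        if abs(u_new) > l1:
            w_new = (math.copysign(l1, u_new) - u_new) / (m_new ** -p / lr + 2 * l2)
        else:
            w_new = 0.0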

    Note:
        When separating parameter groups, the weight decay in each group will be applied to the parameters if the
        weight decay is positive. When parameter groups are not separated, the `weight_decay` in the API will be
        applied to all of the parameters.

        When separating parameter groups, set `grad_centralization` to True if you want to centralize the gradient.
        Gradient centralization can only be applied to the parameters of convolution layers; if it is set to True
        for the parameters of non-convolution layers, an error will be reported.

        To improve the performance of parameter groups, a customized order of parameters is supported.

        The sparse strategy is applied while the SparseGatherV2 operator is used in the forward network.
        The sparse feature is under continuous development. To execute the sparse strategy on the host,
        set the target to the CPU.
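        For example, ``optim.target = "CPU"`` routes the sparse update to the host, where ``optim`` is a constructed
        FTRL instance.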

    Args:
        params (Union[list[Parameter], list[dict]]): When `params` is a list of `Parameter` which will be updated,
            each element in `params` must be of class `Parameter`. When `params` is a list of `dict`, the keys
            "params", "lr", "weight_decay", "order_params" and "grad_centralization" can be parsed.

            - params: Required. The value must be a list of `Parameter`.

            - lr: Using a different learning rate for separate parameter groups is currently not supported.

            - weight_decay: Optional. If "weight_decay" is in the keys, its value will be used as the weight decay.
              If not, the `weight_decay` in the API will be used.

            - order_params: Optional. If "order_params" is in the keys, the value must be the order of parameters,
              and this order will be followed in the optimizer. There are no other keys in this `dict`, and the
              parameters in the value of 'order_params' must be in one of the group parameters.

            - grad_centralization: Optional. The data type of "grad_centralization" is Bool. If "grad_centralization"
              is in the keys, the set value will be used. If not, `grad_centralization` defaults to False.
              This parameter only works on the parameters of convolution layers.

        initial_accum (float): The starting value for accumulators, must be zero or a positive value. Default: 0.1.
        learning_rate (float): The learning rate value, must be zero or positive; a dynamic learning rate is currently
            not supported. Default: 0.001.
        lr_power (float): Learning rate power controls how the learning rate decreases during training, must be less
            than or equal to zero. A fixed learning rate is used if `lr_power` is zero. Default: -0.5.
        l1 (float): l1 regularization strength, must be greater than or equal to zero. Default: 0.0.
        l2 (float): l2 regularization strength, must be greater than or equal to zero. Default: 0.0.
        use_locking (bool): If True, use locks for the update operation. Default: False.
        loss_scale (float): Value for the loss scale. It must be greater than 0.0. In general, use the default value.
            Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
            `FixedLossScaleManager` is set to False does this value need to be the same as the `loss_scale` in
            `FixedLossScaleManager`. Refer to :class:`mindspore.FixedLossScaleManager` for more details.
            Default: 1.0.
        weight_decay (Union[float, int]): Weight decay value to multiply the weight by, must be zero or a positive
            value. Default: 0.0.

    Inputs:
        - **grads** (tuple[Tensor]) - The gradients of `params` in the optimizer, with the same shapes as the
          `params` in the optimizer.

    Outputs:
        tuple[Parameter], the updated parameters, with the same shapes as `params`.

    Raises:
        TypeError: If `initial_accum`, `learning_rate`, `lr_power`, `l1`, `l2` or `loss_scale` is not a float.
        TypeError: If an element of `parameters` is neither a Parameter nor a dict.
        TypeError: If `weight_decay` is neither float nor int.
        TypeError: If `use_locking` is not a bool.
        ValueError: If `lr_power` is greater than 0.
        ValueError: If `loss_scale` is less than or equal to 0.
        ValueError: If `initial_accum`, `l1` or `l2` is less than 0.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> net = Net()
        >>> #1) All parameters use the same learning rate and weight decay
        >>> optim = nn.FTRL(params=net.trainable_params())
        >>>
        >>> #2) Use parameter groups and set different values
        >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
        >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
        >>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'grad_centralization':True},
        ...                 {'params': no_conv_params},
        ...                 {'order_params': net.trainable_params()}]
        >>> optim = nn.FTRL(group_params, learning_rate=0.1, weight_decay=0.0)
        >>> # The conv_params's parameters will use the learning rate of 0.1 set in the API, a weight decay of 0.01
        >>> # and grad centralization set to True.
        >>> # The no_conv_params's parameters will use the weight decay of 0.0 set in the API and the default grad
        >>> # centralization of False.
        >>> # The optimizer will follow the parameter order given by the value of 'order_params'.
        >>>
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> model = Model(net, loss_fn=loss, optimizer=optim)
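        >>>
        >>> #3) Keep `loss_scale` consistent with a `FixedLossScaleManager` whose `drop_overflow_update` is False
        >>> # (a sketch with illustrative values)
        >>> loss_scale_manager = FixedLossScaleManager(1024.0, drop_overflow_update=False)
        >>> optim = nn.FTRL(net.trainable_params(), loss_scale=1024.0)
        >>> model = Model(net, loss_fn=loss, optimizer=optim, loss_scale_manager=loss_scale_manager)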
194    """
195
196    @opt_init_args_register
197    def __init__(self, params, initial_accum=0.1, learning_rate=0.001, lr_power=-0.5, l1=0.0, l2=0.0,
198                 use_locking=False, loss_scale=1.0, weight_decay=0.0):
199        super(FTRL, self).__init__(learning_rate, params, weight_decay, loss_scale=loss_scale)
200        if self.dynamic_lr or self.is_group_lr:
201            raise ValueError('Dynamic learning rate or group learning rate is currently not supported.')
202        _check_param(initial_accum, lr_power, l1, l2, use_locking, self.cls_name)
203        self.moments = self.parameters.clone(prefix="moments", init=initial_accum)
204        self.linear = self.parameters.clone(prefix="linear", init='zeros')
205        self.l1 = l1
206        self.l2 = l2
207        self.lr = learning_rate
208        self.lr_power = lr_power
209        if not self.is_group:
210            self.decay_flags = tuple((lambda: True)() for x in self.parameters)
211        self.opt = P.ApplyFtrl(use_locking=use_locking)
212        self.use_locking = use_locking
213        self.sparse_opt = P.SparseApplyFtrl(learning_rate, l1, l2, lr_power, use_locking=use_locking)
214        self._ps_pull = P.Pull()
215        self._ps_push = P.Push("Ftrl", [0, 1, 2])
216        self._ps_push.add_prim_attr("init_accum", initial_accum)
217        self._ps_push.add_prim_attr("lr", learning_rate)
218        self._ps_push.add_prim_attr("l1", l1)
219        self._ps_push.add_prim_attr("l2", l2)
220        self._ps_push.add_prim_attr("lr_power", lr_power)
221
    def construct(self, grads):
        """Apply one FTRL update step to all parameters."""
        params = self.parameters
        moments = self.moments
        linear = self.linear
        # Gradient preprocessing provided by the base Optimizer: weight decay, gradient
        # centralization, loss-scale correction and sparse-index deduplication.
        grads = self.decay_weight(grads)
        grads = self.gradients_centralization(grads)
        grads = self.scale_grad(grads)
        grads = self._grad_sparse_indices_deduplicate(grads)
        lr = self.get_lr()

        # Map the FTRL update over every (linear, gradient, parameter, moment) tuple.
        success = self.map_(F.partial(_ftrl_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
                                      self.l1, self.l2, self.lr_power, lr),
                            linear, grads, params, moments, self.ps_parameters, self.cache_enable)
        return success

    @Optimizer.target.setter
    def target(self, value):
        """
        If the input value is set to "CPU", the parameters will be updated on the host using the fused
        optimizer operation.
        """
        if not isinstance(value, str):
            raise TypeError("The value must be str type, but got value type is {}".format(type(value)))

        if value not in ('CPU', 'Ascend', 'GPU'):
            raise ValueError("The value must be 'CPU', 'Ascend' or 'GPU', but got value {}".format(value))

        if value == 'CPU':
            # Run the sparse update on the host with the fused CPU kernel.
            self.sparse_opt = P.FusedSparseFtrl(self.lr, self.l1, self.l2, self.lr_power, self.use_locking)
            self.sparse_opt.add_prim_attr("primitive_target", "CPU")
        else:
            self.sparse_opt = P.SparseApplyFtrl(self.lr, self.l1, self.l2, self.lr_power, self.use_locking)

        self._target = value