• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Ftrl-proximal for TensorFlow."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.python.keras.optimizer_v2 import optimizer_v2
21from tensorflow.python.ops import array_ops
22from tensorflow.python.ops import init_ops
23from tensorflow.python.ops import math_ops
24from tensorflow.python.training import training_ops
25from tensorflow.python.util.tf_export import keras_export
26
27
@keras_export('keras.optimizers.Ftrl')
class Ftrl(optimizer_v2.OptimizerV2):
  r"""Optimizer that implements the FTRL algorithm.

  See Algorithm 1 of this [paper](
  https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf).
  This version has support for both online L2 (the L2 penalty given in the paper
  above) and shrinkage-type L2 (which is the addition of an L2 penalty to the
  loss function).

  Notation: $$\alpha$$ is `learning_rate`, $$\lambda_{1}$$ is
  `l1_regularization_strength`, $$\lambda_{2}$$ is `l2_regularization_strength`
  and $$g_{t,i}$$ is the gradient of the loss w.r.t. $$w_{i}$$ at step $$t$$.
  $$\beta$$ is the paper's per-coordinate constant; this optimizer exposes no
  argument for it. The square roots below correspond to the default
  `learning_rate_power` of -0.5; another value $$p$$ replaces $$\sqrt{n}$$ with
  $$n^{-p}$$.

  Initialization:
  $$t = 0$$
  $$n_{0}$$ = `initial_accumulator_value`
  $$\sigma_{0} = 0$$
  $$z_{0} = 0$$

  Update ($$i$$ is variable index):
  $$t = t + 1$$
  $$n_{t,i} = n_{t-1,i} + g_{t,i}^{2}$$
  $$\sigma_{t,i} = (\sqrt{n_{t,i}} - \sqrt{n_{t-1,i}}) / \alpha$$
  $$z_{t,i} = z_{t-1,i} + g_{t,i} - \sigma_{t,i} * w_{t,i}$$
  $$w_{t+1,i} = - ((\beta + \sqrt{n_{t,i}}) / \alpha + \lambda_{2})^{-1} *
               (z_{t,i} - sgn(z_{t,i}) * \lambda_{1})$$
  if $$|z_{t,i}| > \lambda_{1}$$, else $$w_{t+1,i} = 0$$.

  Check the documentation for the l2_shrinkage_regularization_strength
  parameter for more details when shrinkage is enabled, where gradient is
  replaced with gradient_with_shrinkage.
  """

  def __init__(self,
               learning_rate=0.001,
               learning_rate_power=-0.5,
               initial_accumulator_value=0.1,
               l1_regularization_strength=0.0,
               l2_regularization_strength=0.0,
               name='Ftrl',
               l2_shrinkage_regularization_strength=0.0,
               **kwargs):
    r"""Construct a new FTRL optimizer.

    Args:
      learning_rate: A `Tensor`, floating point value, or a schedule that is a
        `tf.keras.optimizers.schedules.LearningRateSchedule`. The learning rate.
      learning_rate_power: A float value, must be less than or equal to zero.
        Controls how the learning rate decreases during training. Use zero for
        a fixed learning rate.
      initial_accumulator_value: The starting value for accumulators.
        Only zero or positive values are allowed.
      l1_regularization_strength: A float value, must be greater than or
        equal to zero.
      l2_regularization_strength: A float value, must be greater than or
        equal to zero.
      name: Optional name prefix for the operations created when applying
        gradients.  Defaults to "Ftrl".
      l2_shrinkage_regularization_strength: A float value, must be greater than
        or equal to zero. This differs from L2 above in that the L2 above is a
        stabilization penalty, whereas this L2 shrinkage is a magnitude penalty.
        The FTRL formulation can be written as:
        w_{t+1} = argmin_w(\hat{g}_{1:t}w + L1*||w||_1 + L2*||w||_2^2), where
        \hat{g} = g + (2*L2_shrinkage*w), and g is the gradient of the loss
        function w.r.t. the weights w.
        Specifically, in the absence of L1 regularization, it is equivalent to
        the following update rule:
        w_{t+1} = w_t - lr_t / (1 + 2*L2*lr_t) * g_t -
                  2*L2_shrinkage*lr_t / (1 + 2*L2*lr_t) * w_t
        where lr_t is the learning rate at t.
        When the input is sparse, shrinkage will only happen on the active
        weights.
      **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`,
        `decay`}. `clipnorm` is clip gradients by norm; `clipvalue` is clip
        gradients by value, `decay` is included for backward compatibility to
        allow time inverse decay of learning rate. `lr` is included for backward
        compatibility, recommended to use `learning_rate` instead.

    Raises:
      ValueError: If one of the arguments is invalid.

    References:
      See [paper]
        (https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf)
    """
    super(Ftrl, self).__init__(name, **kwargs)

    # Reject out-of-range settings at construction time.
    if initial_accumulator_value < 0.0:
      raise ValueError(
          'initial_accumulator_value %f needs to be positive or zero' %
          initial_accumulator_value)
    if learning_rate_power > 0.0:
      raise ValueError('learning_rate_power %f needs to be negative or zero' %
                       learning_rate_power)
    if l1_regularization_strength < 0.0:
      raise ValueError(
          'l1_regularization_strength %f needs to be positive or zero' %
          l1_regularization_strength)
    if l2_regularization_strength < 0.0:
      raise ValueError(
          'l2_regularization_strength %f needs to be positive or zero' %
          l2_regularization_strength)
    if l2_shrinkage_regularization_strength < 0.0:
      raise ValueError(
          'l2_shrinkage_regularization_strength %f needs to be positive'
          ' or zero' % l2_shrinkage_regularization_strength)

    self._set_hyper('learning_rate', learning_rate)
    # `decay` comes from the base class (set from kwargs) and is registered as a
    # hyperparameter so it is serialized alongside the others in get_config.
    self._set_hyper('decay', self._initial_decay)
    self._set_hyper('learning_rate_power', learning_rate_power)
    self._set_hyper('l1_regularization_strength', l1_regularization_strength)
    self._set_hyper('l2_regularization_strength', l2_regularization_strength)
    self._initial_accumulator_value = initial_accumulator_value
    # Kept as a plain Python float rather than a hyperparameter because the
    # `_resource_apply_*` methods branch on it in Python to pick the FTRL op.
    self._l2_shrinkage_regularization_strength = (
        l2_shrinkage_regularization_strength)

  def _create_slots(self, var_list):
    """Create the per-variable "accumulator" and "linear" slots.

    The accumulator starts at `initial_accumulator_value`; the "linear" slot
    is created with `add_slot`'s default initializer.
    """
    for var in var_list:
      dtype = var.dtype.base_dtype
      init = init_ops.constant_initializer(
          self._initial_accumulator_value, dtype=dtype)
      self.add_slot(var, 'accumulator', init)
      self.add_slot(var, 'linear')

  def _prepare_local(self, var_device, var_dtype, apply_state):
    """Add FTRL-specific coefficients to the per-(device, dtype) state.

    `lr_t`, which the apply methods also read, is not set here; it comes from
    the base-class implementation called first.
    """
    super(Ftrl, self)._prepare_local(var_device, var_dtype, apply_state)
    apply_state[(var_device, var_dtype)].update(dict(
        learning_rate_power=array_ops.identity(
            self._get_hyper('learning_rate_power', var_dtype)),
        l1_regularization_strength=array_ops.identity(
            self._get_hyper('l1_regularization_strength', var_dtype)),
        l2_regularization_strength=array_ops.identity(
            self._get_hyper('l2_regularization_strength', var_dtype)),
        # Not a hyperparameter, so cast the stored Python float directly.
        l2_shrinkage_regularization_strength=math_ops.cast(
            self._l2_shrinkage_regularization_strength, var_dtype)
        ))

  def _resource_apply_dense(self, grad, var, apply_state=None):
    """Apply a dense FTRL update to `var` using gradient `grad`."""
    var_device, var_dtype = var.device, var.dtype.base_dtype
    # `apply_state` may be None or lack this (device, dtype) entry; in that
    # case the coefficients are computed on demand by the base class.
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    accum = self.get_slot(var, 'accumulator')
    linear = self.get_slot(var, 'linear')

    # Without shrinkage use the plain FTRL op; otherwise the V2 op, which
    # additionally takes the shrinkage strength.
    if self._l2_shrinkage_regularization_strength <= 0.0:
      return training_ops.resource_apply_ftrl(
          var.handle,
          accum.handle,
          linear.handle,
          grad,
          coefficients['lr_t'],
          coefficients['l1_regularization_strength'],
          coefficients['l2_regularization_strength'],
          coefficients['learning_rate_power'],
          use_locking=self._use_locking)
    else:
      return training_ops.resource_apply_ftrl_v2(
          var.handle,
          accum.handle,
          linear.handle,
          grad,
          coefficients['lr_t'],
          coefficients['l1_regularization_strength'],
          coefficients['l2_regularization_strength'],
          coefficients['l2_shrinkage_regularization_strength'],
          coefficients['learning_rate_power'],
          use_locking=self._use_locking)

  def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    """Apply a sparse FTRL update to the entries of `var` given by `indices`."""
    var_device, var_dtype = var.device, var.dtype.base_dtype
    # Same on-demand fallback as the dense path when `apply_state` lacks an
    # entry for this (device, dtype).
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    accum = self.get_slot(var, 'accumulator')
    linear = self.get_slot(var, 'linear')

    # Op selection mirrors `_resource_apply_dense`.
    if self._l2_shrinkage_regularization_strength <= 0.0:
      return training_ops.resource_sparse_apply_ftrl(
          var.handle,
          accum.handle,
          linear.handle,
          grad,
          indices,
          coefficients['lr_t'],
          coefficients['l1_regularization_strength'],
          coefficients['l2_regularization_strength'],
          coefficients['learning_rate_power'],
          use_locking=self._use_locking)
    else:
      return training_ops.resource_sparse_apply_ftrl_v2(
          var.handle,
          accum.handle,
          linear.handle,
          grad,
          indices,
          coefficients['lr_t'],
          coefficients['l1_regularization_strength'],
          coefficients['l2_regularization_strength'],
          coefficients['l2_shrinkage_regularization_strength'],
          coefficients['learning_rate_power'],
          use_locking=self._use_locking)

  def get_config(self):
    """Return the base config extended with every FTRL constructor argument."""
    config = super(Ftrl, self).get_config()
    config.update({
        'learning_rate':
            self._serialize_hyperparameter('learning_rate'),
        'decay':
            self._serialize_hyperparameter('decay'),
        'initial_accumulator_value':
            self._initial_accumulator_value,
        'learning_rate_power':
            self._serialize_hyperparameter('learning_rate_power'),
        'l1_regularization_strength':
            self._serialize_hyperparameter('l1_regularization_strength'),
        'l2_regularization_strength':
            self._serialize_hyperparameter('l2_regularization_strength'),
        'l2_shrinkage_regularization_strength':
            self._l2_shrinkage_regularization_strength,
    })
    return config
246