# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ftrl-proximal optimizer implementation."""
# pylint: disable=g-classes-have-attributes

from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import gen_training_ops
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.optimizers.Ftrl')
class Ftrl(optimizer_v2.OptimizerV2):
  r"""Optimizer that implements the FTRL algorithm.

  "Follow The Regularized Leader" (FTRL) is an optimization algorithm
  developed at Google for click-through rate prediction in the early 2010s.
  It is most suitable for shallow models with large and sparse feature
  spaces. The algorithm is described by
  [McMahan et al., 2013](https://research.google.com/pubs/archive/41159.pdf).
  The Keras version supports both online L2 regularization (the L2 penalty
  described in the paper above) and shrinkage-type L2 regularization (the
  addition of an L2 penalty to the loss function).

  Initialization:

  ```python
  n = 0
  sigma = 0
  z = 0
  ```

  Update rule for one variable `w`:

  ```python
  prev_n = n
  n = n + g ** 2
  sigma = (sqrt(n) - sqrt(prev_n)) / lr
  z = z + g - sigma * w
  if abs(z) < lambda_1:
    w = 0
  else:
    w = (sgn(z) * lambda_1 - z) / ((beta + sqrt(n)) / lr + lambda_2)
  ```

  Notation:

  - `lr` is the learning rate
  - `g` is the gradient for the variable
  - `lambda_1` is the L1 regularization strength
  - `lambda_2` is the L2 regularization strength
  - `beta` is the `beta` hyperparameter

  When shrinkage is enabled, the gradient `g` in the update rule above is
  replaced with a shrinkage-adjusted gradient; see the documentation for the
  `l2_shrinkage_regularization_strength` parameter for details.

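  A minimal usage sketch (the variable and loss below are illustrative, not
  part of the API):

  ```python
  opt = tf.keras.optimizers.Ftrl(learning_rate=0.1)
  var = tf.Variable([1.0, 2.0])
  loss = lambda: tf.reduce_sum(var * var)  # a toy quadratic loss
  opt.minimize(loss, var_list=[var])  # applies one FTRL update to `var`
  ```
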
  Args:
    learning_rate: A `Tensor`, floating point value, or a schedule that is a
      `tf.keras.optimizers.schedules.LearningRateSchedule`. The learning
      rate. Defaults to 0.001.
    learning_rate_power: A float value, must be less than or equal to zero.
      Controls how the learning rate decreases during training. Use zero for
      a fixed learning rate. Defaults to -0.5.
    initial_accumulator_value: The starting value for accumulators.
      Only zero or positive values are allowed. Defaults to 0.1.
    l1_regularization_strength: A float value, must be greater than or
      equal to zero. Defaults to 0.0.
    l2_regularization_strength: A float value, must be greater than or
      equal to zero. Defaults to 0.0.
    name: Optional name prefix for the operations created when applying
      gradients. Defaults to `"Ftrl"`.
    l2_shrinkage_regularization_strength: A float value, must be greater than
      or equal to zero. This differs from L2 above in that the L2 above is a
      stabilization penalty, whereas this L2 shrinkage is a magnitude
      penalty. When input is sparse, shrinkage will only happen on the
      active weights. Defaults to 0.0.
    beta: A float value, representing the beta value from the paper.
      Defaults to 0.0.
    **kwargs: Keyword arguments. Allowed to be one of `"clipnorm"` or
      `"clipvalue"`. `"clipnorm"` (float) clips gradients by norm;
      `"clipvalue"` (float) clips gradients by value.

  Reference:
    - [McMahan et al., 2013](
      https://research.google.com/pubs/archive/41159.pdf)
  """

  def __init__(self,
               learning_rate=0.001,
               learning_rate_power=-0.5,
               initial_accumulator_value=0.1,
               l1_regularization_strength=0.0,
               l2_regularization_strength=0.0,
               name='Ftrl',
               l2_shrinkage_regularization_strength=0.0,
               beta=0.0,
               **kwargs):
    super(Ftrl, self).__init__(name, **kwargs)

    if initial_accumulator_value < 0.0:
      raise ValueError(
          'initial_accumulator_value %f needs to be positive or zero' %
          initial_accumulator_value)
    if learning_rate_power > 0.0:
      raise ValueError('learning_rate_power %f needs to be negative or zero' %
                       learning_rate_power)
    if l1_regularization_strength < 0.0:
      raise ValueError(
          'l1_regularization_strength %f needs to be positive or zero' %
          l1_regularization_strength)
    if l2_regularization_strength < 0.0:
      raise ValueError(
          'l2_regularization_strength %f needs to be positive or zero' %
          l2_regularization_strength)
    if l2_shrinkage_regularization_strength < 0.0:
      raise ValueError(
          'l2_shrinkage_regularization_strength %f needs to be positive'
          ' or zero' % l2_shrinkage_regularization_strength)

    self._set_hyper('learning_rate', learning_rate)
    self._set_hyper('decay', self._initial_decay)
    self._set_hyper('learning_rate_power', learning_rate_power)
    self._set_hyper('l1_regularization_strength', l1_regularization_strength)
    self._set_hyper('l2_regularization_strength', l2_regularization_strength)
    self._set_hyper('beta', beta)
    self._initial_accumulator_value = initial_accumulator_value
    self._l2_shrinkage_regularization_strength = (
        l2_shrinkage_regularization_strength)

  def _create_slots(self, var_list):
    # Create the "accumulator" and "linear" slots.
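    # "accumulator" holds the running sum of squared gradients (`n` in the
    # class docstring) and starts at `initial_accumulator_value`; "linear"
    # holds the `z` term and defaults to zeros.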
    for var in var_list:
      dtype = var.dtype.base_dtype
      init = init_ops.constant_initializer(
          self._initial_accumulator_value, dtype=dtype)
      self.add_slot(var, 'accumulator', init)
      self.add_slot(var, 'linear')

  def _prepare_local(self, var_device, var_dtype, apply_state):
    super(Ftrl, self)._prepare_local(var_device, var_dtype, apply_state)
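    # Cache dtype-cast copies of the hyperparameters per (device, dtype) pair
    # so the apply ops below can read tensors of the matching dtype.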
    apply_state[(var_device, var_dtype)].update(
        dict(
            learning_rate_power=array_ops.identity(
                self._get_hyper('learning_rate_power', var_dtype)),
            l1_regularization_strength=array_ops.identity(
                self._get_hyper('l1_regularization_strength', var_dtype)),
            l2_regularization_strength=array_ops.identity(
                self._get_hyper('l2_regularization_strength', var_dtype)),
            beta=array_ops.identity(self._get_hyper('beta', var_dtype)),
            l2_shrinkage_regularization_strength=math_ops.cast(
                self._l2_shrinkage_regularization_strength, var_dtype)))

  def _resource_apply_dense(self, grad, var, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    # Adjust L2 regularization strength to include beta to avoid the underlying
    # TensorFlow ops needing to include it.
    adjusted_l2_regularization_strength = (
        coefficients['l2_regularization_strength'] + coefficients['beta'] /
        (2. * coefficients['lr_t']))
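    # With the default learning_rate_power of -0.5, the kernel's update
    # denominator takes the form sqrt(n) / lr + 2 * l2, so folding
    # beta / (2 * lr) into l2 contributes the (beta + sqrt(n)) / lr term of
    # the update rule without the kernel needing a separate beta argument.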

    accum = self.get_slot(var, 'accumulator')
    linear = self.get_slot(var, 'linear')

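    # ResourceApplyFtrlV2 additionally folds the shrinkage penalty into the
    # gradient, so it is only needed when shrinkage is enabled.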
    if self._l2_shrinkage_regularization_strength <= 0.0:
      return gen_training_ops.ResourceApplyFtrl(
          var=var.handle,
          accum=accum.handle,
          linear=linear.handle,
          grad=grad,
          lr=coefficients['lr_t'],
          l1=coefficients['l1_regularization_strength'],
          l2=adjusted_l2_regularization_strength,
          lr_power=coefficients['learning_rate_power'],
          use_locking=self._use_locking)
    else:
      return gen_training_ops.ResourceApplyFtrlV2(
          var=var.handle,
          accum=accum.handle,
          linear=linear.handle,
          grad=grad,
          lr=coefficients['lr_t'],
          l1=coefficients['l1_regularization_strength'],
          l2=adjusted_l2_regularization_strength,
          l2_shrinkage=coefficients['l2_shrinkage_regularization_strength'],
          lr_power=coefficients['learning_rate_power'],
          use_locking=self._use_locking)

  def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    # Adjust L2 regularization strength to include beta to avoid the underlying
    # TensorFlow ops needing to include it.
    adjusted_l2_regularization_strength = (
        coefficients['l2_regularization_strength'] + coefficients['beta'] /
        (2. * coefficients['lr_t']))

    accum = self.get_slot(var, 'accumulator')
    linear = self.get_slot(var, 'linear')

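    # The sparse kernels mirror the dense ones above but update only the rows
    # of `var` (and its slots) selected by `indices`.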
    if self._l2_shrinkage_regularization_strength <= 0.0:
      return gen_training_ops.ResourceSparseApplyFtrl(
          var=var.handle,
          accum=accum.handle,
          linear=linear.handle,
          grad=grad,
          indices=indices,
          lr=coefficients['lr_t'],
          l1=coefficients['l1_regularization_strength'],
          l2=adjusted_l2_regularization_strength,
          lr_power=coefficients['learning_rate_power'],
          use_locking=self._use_locking)
    else:
      return gen_training_ops.ResourceSparseApplyFtrlV2(
          var=var.handle,
          accum=accum.handle,
          linear=linear.handle,
          grad=grad,
          indices=indices,
          lr=coefficients['lr_t'],
          l1=coefficients['l1_regularization_strength'],
          l2=adjusted_l2_regularization_strength,
          l2_shrinkage=coefficients['l2_shrinkage_regularization_strength'],
          lr_power=coefficients['learning_rate_power'],
          use_locking=self._use_locking)

  def get_config(self):
    config = super(Ftrl, self).get_config()
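    # `initial_accumulator_value` and `l2_shrinkage_regularization_strength`
    # are stored as plain Python floats rather than hyperparameters, so they
    # are emitted directly instead of via _serialize_hyperparameter.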
    config.update({
        'learning_rate':
            self._serialize_hyperparameter('learning_rate'),
        'decay':
            self._initial_decay,
        'initial_accumulator_value':
            self._initial_accumulator_value,
        'learning_rate_power':
            self._serialize_hyperparameter('learning_rate_power'),
        'l1_regularization_strength':
            self._serialize_hyperparameter('l1_regularization_strength'),
        'l2_regularization_strength':
            self._serialize_hyperparameter('l2_regularization_strength'),
        'beta':
            self._serialize_hyperparameter('beta'),
        'l2_shrinkage_regularization_strength':
            self._l2_shrinkage_regularization_strength,
    })
    return config