• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Ftrl-proximal optimizer implementation."""
16# pylint: disable=g-classes-have-attributes
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.keras.optimizer_v2 import optimizer_v2
22from tensorflow.python.ops import array_ops
23from tensorflow.python.ops import init_ops
24from tensorflow.python.ops import math_ops
25from tensorflow.python.training import gen_training_ops
26from tensorflow.python.util.tf_export import keras_export
27
28
@keras_export('keras.optimizers.Ftrl')
class Ftrl(optimizer_v2.OptimizerV2):
  r"""Optimizer that implements the FTRL algorithm.

  See Algorithm 1 of this
  [paper](https://research.google.com/pubs/archive/41159.pdf).
  This version has support for both online L2 (the L2 penalty given in the paper
  above) and shrinkage-type L2 (which is the addition of an L2 penalty to the
  loss function).

  Initialization:
  $$t = 0$$
  $$n_{0} = \text{initial\_accumulator\_value}$$
  $$\sigma_{0} = 0$$
  $$z_{0} = 0$$

  Update ($$i$$ is variable index, $$\alpha$$ is the learning rate,
  $$g_{t,i}$$ is the gradient, $$\lambda_{1}$$ and $$\lambda_{2}$$ are
  `l1_regularization_strength` and `l2_regularization_strength`, $$\beta$$ is
  `beta`; the square roots correspond to the default `learning_rate_power` of
  -0.5):
  $$t = t + 1$$
  $$n_{t,i} = n_{t-1,i} + g_{t,i}^{2}$$
  $$\sigma_{t,i} = (\sqrt{n_{t,i}} - \sqrt{n_{t-1,i}}) / \alpha$$
  $$z_{t,i} = z_{t-1,i} + g_{t,i} - \sigma_{t,i} * w_{t-1,i}$$
  $$w_{t,i} = - ((\beta+\sqrt{n_{t,i}}) / \alpha + 2 * \lambda_{2})^{-1} *
              (z_{t,i} - sgn(z_{t,i}) * \lambda_{1})
              \text{ if } |z_{t,i}| > \lambda_{1} \text{ else } 0$$

  Check the documentation for the l2_shrinkage_regularization_strength
  parameter for more details when shrinkage is enabled, in which case gradient
  is replaced with gradient_with_shrinkage.

  Args:
    learning_rate: A `Tensor`, floating point value, or a schedule that is a
      `tf.keras.optimizers.schedules.LearningRateSchedule`. The learning rate.
    learning_rate_power: A float value, must be less or equal to zero.
      Controls how the learning rate decreases during training. Use zero for
      a fixed learning rate.
    initial_accumulator_value: The starting value for accumulators.
      Only zero or positive values are allowed.
    l1_regularization_strength: A float value, must be greater than or
      equal to zero. Defaults to 0.0.
    l2_regularization_strength: A float value, must be greater than or
      equal to zero. Defaults to 0.0.
    name: Optional name prefix for the operations created when applying
      gradients.  Defaults to `"Ftrl"`.
    l2_shrinkage_regularization_strength: A float value, must be greater than
      or equal to zero. This differs from L2 above in that the L2 above is a
      stabilization penalty, whereas this L2 shrinkage is a magnitude penalty.
      When input is sparse shrinkage will only happen on the active weights.
    beta: A float value, representing the beta value from the paper.
      Defaults to 0.0. It is folded into the L2 term as
      `beta / (2 * learning_rate)` before the training op is invoked.
    **kwargs: Keyword arguments forwarded to the base `OptimizerV2`
      constructor, e.g. `"clipnorm"` (float) clips gradients by norm and
      `"clipvalue"` (float) clips gradients by value. `"decay"` is also
      accepted, since `get_config` emits it.

  Reference:
    - [paper](
      https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf)
  """

  def __init__(self,
               learning_rate=0.001,
               learning_rate_power=-0.5,
               initial_accumulator_value=0.1,
               l1_regularization_strength=0.0,
               l2_regularization_strength=0.0,
               name='Ftrl',
               l2_shrinkage_regularization_strength=0.0,
               beta=0.0,
               **kwargs):
    super(Ftrl, self).__init__(name, **kwargs)

    # Reject out-of-range float arguments at construction time.
    if initial_accumulator_value < 0.0:
      raise ValueError(
          'initial_accumulator_value %f needs to be positive or zero' %
          initial_accumulator_value)
    if learning_rate_power > 0.0:
      raise ValueError('learning_rate_power %f needs to be negative or zero' %
                       learning_rate_power)
    if l1_regularization_strength < 0.0:
      raise ValueError(
          'l1_regularization_strength %f needs to be positive or zero' %
          l1_regularization_strength)
    if l2_regularization_strength < 0.0:
      raise ValueError(
          'l2_regularization_strength %f needs to be positive or zero' %
          l2_regularization_strength)
    if l2_shrinkage_regularization_strength < 0.0:
      raise ValueError(
          'l2_shrinkage_regularization_strength %f needs to be positive'
          ' or zero' % l2_shrinkage_regularization_strength)

    self._set_hyper('learning_rate', learning_rate)
    self._set_hyper('decay', self._initial_decay)
    self._set_hyper('learning_rate_power', learning_rate_power)
    self._set_hyper('l1_regularization_strength', l1_regularization_strength)
    self._set_hyper('l2_regularization_strength', l2_regularization_strength)
    self._set_hyper('beta', beta)
    # These two are stored as plain attributes rather than hyperparameters:
    # the first only seeds the slot initializer, and the second also selects
    # which training op is used (see `_use_l2_shrinkage`).
    self._initial_accumulator_value = initial_accumulator_value
    self._l2_shrinkage_regularization_strength = (
        l2_shrinkage_regularization_strength)

  def _create_slots(self, var_list):
    """Creates the per-variable "accumulator" and "linear" slots.

    The "accumulator" slot holds n from the class docstring and starts at
    `initial_accumulator_value`. The "linear" slot holds z and uses
    `add_slot`'s default initializer.
    """
    for var in var_list:
      dtype = var.dtype.base_dtype
      init = init_ops.constant_initializer(
          self._initial_accumulator_value, dtype=dtype)
      self.add_slot(var, 'accumulator', init)
      self.add_slot(var, 'linear')

  def _prepare_local(self, var_device, var_dtype, apply_state):
    """Caches the FTRL coefficients for `(var_device, var_dtype)`."""
    super(Ftrl, self)._prepare_local(var_device, var_dtype, apply_state)
    apply_state[(var_device, var_dtype)].update(
        dict(
            learning_rate_power=array_ops.identity(
                self._get_hyper('learning_rate_power', var_dtype)),
            l1_regularization_strength=array_ops.identity(
                self._get_hyper('l1_regularization_strength', var_dtype)),
            l2_regularization_strength=array_ops.identity(
                self._get_hyper('l2_regularization_strength', var_dtype)),
            beta=array_ops.identity(self._get_hyper('beta', var_dtype)),
            # Not a hyperparameter, so cast the stored float explicitly.
            l2_shrinkage_regularization_strength=math_ops.cast(
                self._l2_shrinkage_regularization_strength, var_dtype)))

  def _use_l2_shrinkage(self):
    """Returns True when the shrinkage-aware (V2) FTRL ops must be used."""
    return self._l2_shrinkage_regularization_strength > 0.0

  def _ftrl_op_kwargs(self, var, apply_state):
    """Builds the keyword arguments shared by the dense and sparse FTRL ops.

    Args:
      var: The resource variable being updated.
      apply_state: Optional dict of cached coefficients keyed by
        `(device, dtype)`, as filled in by `_prepare_local`.

    Returns:
      A dict with every op argument except `grad` (and `indices` for the sparse
      ops). `l2_shrinkage` is included only when `_use_l2_shrinkage()` is True.
    """
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    # Adjust L2 regularization strength to include beta to avoid the underlying
    # TensorFlow ops needing to include it: with the `2 * lambda_2` term of the
    # class-docstring formula, 2 * (l2 + beta / (2 * lr)) == 2 * l2 + beta / lr.
    adjusted_l2_regularization_strength = (
        coefficients['l2_regularization_strength'] + coefficients['beta'] /
        (2. * coefficients['lr_t']))

    kwargs = dict(
        var=var.handle,
        accum=self.get_slot(var, 'accumulator').handle,
        linear=self.get_slot(var, 'linear').handle,
        lr=coefficients['lr_t'],
        l1=coefficients['l1_regularization_strength'],
        l2=adjusted_l2_regularization_strength,
        lr_power=coefficients['learning_rate_power'],
        use_locking=self._use_locking)
    if self._use_l2_shrinkage():
      kwargs['l2_shrinkage'] = (
          coefficients['l2_shrinkage_regularization_strength'])
    return kwargs

  def _resource_apply_dense(self, grad, var, apply_state=None):
    """Applies one dense FTRL step to `var` using `grad`."""
    kwargs = self._ftrl_op_kwargs(var, apply_state)
    if self._use_l2_shrinkage():
      return gen_training_ops.ResourceApplyFtrlV2(grad=grad, **kwargs)
    return gen_training_ops.ResourceApplyFtrl(grad=grad, **kwargs)

  def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    """Applies one sparse FTRL step to the rows of `var` named by `indices`."""
    kwargs = self._ftrl_op_kwargs(var, apply_state)
    if self._use_l2_shrinkage():
      return gen_training_ops.ResourceSparseApplyFtrlV2(
          grad=grad, indices=indices, **kwargs)
    return gen_training_ops.ResourceSparseApplyFtrl(
        grad=grad, indices=indices, **kwargs)

  def get_config(self):
    """Returns the constructor arguments needed to recreate this optimizer."""
    config = super(Ftrl, self).get_config()
    config.update({
        'learning_rate':
            self._serialize_hyperparameter('learning_rate'),
        'decay':
            self._initial_decay,
        'initial_accumulator_value':
            self._initial_accumulator_value,
        'learning_rate_power':
            self._serialize_hyperparameter('learning_rate_power'),
        'l1_regularization_strength':
            self._serialize_hyperparameter('l1_regularization_strength'),
        'l2_regularization_strength':
            self._serialize_hyperparameter('l2_regularization_strength'),
        'beta':
            self._serialize_hyperparameter('beta'),
        'l2_shrinkage_regularization_strength':
            self._l2_shrinkage_regularization_strength,
    })
    return config
252