# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ftrl-proximal for TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import optimizer
from tensorflow.python.training import training_ops
from tensorflow.python.util.tf_export import tf_export


@tf_export(v1=["train.FtrlOptimizer"])
class FtrlOptimizer(optimizer.Optimizer):
  """Optimizer that implements the FTRL algorithm.

  This version has support for both online L2 (McMahan et al., 2013) and
  shrinkage-type L2, which is the addition of an L2 penalty
  to the loss function.

  References:
    Ad-click prediction:
      [McMahan et al., 2013](https://dl.acm.org/citation.cfm?id=2488200)
      ([pdf](https://dl.acm.org/ft_gateway.cfm?id=2488200&ftid=1388399&dwn=1&CFID=32233078&CFTOKEN=d60fe57a294c056a-CB75C374-F915-E7A6-1573FBBC7BF7D526))
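
  Example:

  A minimal usage sketch (illustrative only; `loss` is assumed to be a scalar
  `Tensor` built elsewhere in a `tf.compat.v1` graph):

  ```python
  opt = tf.compat.v1.train.FtrlOptimizer(
      learning_rate=0.1,
      l1_regularization_strength=0.001,
      l2_regularization_strength=0.001)
  train_op = opt.minimize(loss)
  ```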
  """

  def __init__(self,
               learning_rate,
               learning_rate_power=-0.5,
               initial_accumulator_value=0.1,
               l1_regularization_strength=0.0,
               l2_regularization_strength=0.0,
               use_locking=False,
               name="Ftrl",
               accum_name=None,
               linear_name=None,
               l2_shrinkage_regularization_strength=0.0,
               beta=None):
    r"""Construct a new FTRL optimizer.

    Args:
      learning_rate: A float value or a constant float `Tensor`.
      learning_rate_power: A float value, must be less than or equal to zero.
        Controls how the learning rate decreases during training. Use zero for
        a fixed learning rate. See section 3.1 in (McMahan et al., 2013).
      initial_accumulator_value: The starting value for accumulators.
        Only zero or positive values are allowed.
      l1_regularization_strength: A float value, must be greater than or
        equal to zero.
      l2_regularization_strength: A float value, must be greater than or
        equal to zero.
      use_locking: If `True` use locks for update operations.
      name: Optional name prefix for the operations created when applying
        gradients.  Defaults to "Ftrl".
      accum_name: The suffix for the variable that keeps the gradient squared
        accumulator.  If not present, defaults to name.
      linear_name: The suffix for the variable that keeps the linear gradient
        accumulator.  If not present, defaults to name + "_1".
      l2_shrinkage_regularization_strength: A float value, must be greater than
        or equal to zero. This differs from L2 above in that the L2 above is a
        stabilization penalty, whereas this L2 shrinkage is a magnitude penalty.
        The FTRL formulation can be written as:
        w_{t+1} = argmin_w(\hat{g}_{1:t}w + L1*||w||_1 + L2*||w||_2^2), where
        \hat{g} = g + (2*L2_shrinkage*w), and g is the gradient of the loss
        function w.r.t. the weights w.
        Specifically, in the absence of L1 regularization, it is equivalent to
        the following update rule:
        w_{t+1} = w_t - lr_t / (beta + 2*L2*lr_t) * g_t -
                  2*L2_shrinkage*lr_t / (beta + 2*L2*lr_t) * w_t
        where lr_t is the learning rate at t.
        When the input is sparse, shrinkage will only happen on the active
        weights.
      beta: A float value; corresponds to the beta parameter in the paper.

    Raises:
      ValueError: If one of the arguments is invalid.

    References:
      Ad-click prediction:
        [McMahan et al., 2013](https://dl.acm.org/citation.cfm?id=2488200)
        ([pdf](https://dl.acm.org/ft_gateway.cfm?id=2488200&ftid=1388399&dwn=1&CFID=32233078&CFTOKEN=d60fe57a294c056a-CB75C374-F915-E7A6-1573FBBC7BF7D526))
    """
    super(FtrlOptimizer, self).__init__(use_locking, name)

    if initial_accumulator_value < 0.0:
      raise ValueError(
          "initial_accumulator_value %f needs to be positive or zero" %
          initial_accumulator_value)
    if learning_rate_power > 0.0:
      raise ValueError("learning_rate_power %f needs to be negative or zero" %
                       learning_rate_power)
    if l1_regularization_strength < 0.0:
      raise ValueError(
          "l1_regularization_strength %f needs to be positive or zero" %
          l1_regularization_strength)
    if l2_regularization_strength < 0.0:
      raise ValueError(
          "l2_regularization_strength %f needs to be positive or zero" %
          l2_regularization_strength)
    if l2_shrinkage_regularization_strength < 0.0:
      raise ValueError(
          "l2_shrinkage_regularization_strength %f needs to be positive"
          " or zero" % l2_shrinkage_regularization_strength)

    self._learning_rate = learning_rate
    self._learning_rate_power = learning_rate_power
    self._initial_accumulator_value = initial_accumulator_value
    self._l1_regularization_strength = l1_regularization_strength
    self._l2_regularization_strength = l2_regularization_strength
    self._beta = (0.0 if beta is None else beta)
    self._l2_shrinkage_regularization_strength = (
        l2_shrinkage_regularization_strength)
    self._learning_rate_tensor = None
    self._learning_rate_power_tensor = None
    self._l1_regularization_strength_tensor = None
    self._adjusted_l2_regularization_strength_tensor = None
    self._l2_shrinkage_regularization_strength_tensor = None
    self._accum_name = accum_name
    self._linear_name = linear_name

  def _create_slots(self, var_list):
    # Create the "accum" and "linear" slots.
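    # "accum" keeps the gradient squared accumulator and starts from
    # initial_accumulator_value; "linear" keeps the linear gradient
    # accumulator and starts at zero.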
    for v in var_list:
      val = constant_op.constant(
          self._initial_accumulator_value, dtype=v.dtype, shape=v.get_shape())
      self._get_or_make_slot(v, val, "accum", self._accum_name or self._name)
      self._zeros_slot(v, "linear", self._linear_name or self._name)

  def _prepare(self):
    self._learning_rate_tensor = ops.convert_to_tensor(
        self._learning_rate, name="learning_rate")
    self._l1_regularization_strength_tensor = ops.convert_to_tensor(
        self._l1_regularization_strength, name="l1_regularization_strength")
    # L2 regularization strength with beta added in so that the underlying
    # TensorFlow ops do not need to include that parameter.
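    # For example (illustrative numbers only): with l2_regularization_strength
    # = 0.0, beta = 0.1 and learning_rate = 0.5, the adjusted strength below is
    # 0.0 + 0.1 / (2 * 0.5) = 0.1.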
    self._adjusted_l2_regularization_strength_tensor = ops.convert_to_tensor(
        self._l2_regularization_strength + self._beta /
        (2. * math_ops.maximum(self._learning_rate, 1e-36)),
        name="adjusted_l2_regularization_strength")
    assert self._adjusted_l2_regularization_strength_tensor is not None
    self._beta_tensor = ops.convert_to_tensor(self._beta, name="beta")
    self._l2_shrinkage_regularization_strength_tensor = ops.convert_to_tensor(
        self._l2_shrinkage_regularization_strength,
        name="l2_shrinkage_regularization_strength")
    self._learning_rate_power_tensor = ops.convert_to_tensor(
        self._learning_rate_power, name="learning_rate_power")

  def _apply_dense(self, grad, var):
    accum = self.get_slot(var, "accum")
    linear = self.get_slot(var, "linear")
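    # With no shrinkage-type L2, the plain FTRL kernel suffices; otherwise
    # dispatch to the *_v2 kernel, which takes the extra shrinkage strength.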
    if self._l2_shrinkage_regularization_strength <= 0.0:
      return training_ops.apply_ftrl(
          var,
          accum,
          linear,
          grad,
          math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
          math_ops.cast(self._l1_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._adjusted_l2_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._learning_rate_power_tensor, var.dtype.base_dtype),
          use_locking=self._use_locking)
    else:
      return training_ops.apply_ftrl_v2(
          var,
          accum,
          linear,
          grad,
          math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
          math_ops.cast(self._l1_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._adjusted_l2_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._l2_shrinkage_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._learning_rate_power_tensor, var.dtype.base_dtype),
          use_locking=self._use_locking)

  def _resource_apply_dense(self, grad, var):
    accum = self.get_slot(var, "accum")
    linear = self.get_slot(var, "linear")
    if self._l2_shrinkage_regularization_strength <= 0.0:
      return training_ops.resource_apply_ftrl(
          var.handle,
          accum.handle,
          linear.handle,
          grad,
          math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
          math_ops.cast(self._l1_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._adjusted_l2_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._learning_rate_power_tensor, var.dtype.base_dtype),
          use_locking=self._use_locking)
    else:
      return training_ops.resource_apply_ftrl_v2(
          var.handle,
          accum.handle,
          linear.handle,
          grad,
          math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
          math_ops.cast(self._l1_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._adjusted_l2_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._l2_shrinkage_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._learning_rate_power_tensor, var.dtype.base_dtype),
          use_locking=self._use_locking)

  def _apply_sparse(self, grad, var):
    accum = self.get_slot(var, "accum")
    linear = self.get_slot(var, "linear")
    if self._l2_shrinkage_regularization_strength <= 0.0:
      return training_ops.sparse_apply_ftrl(
          var,
          accum,
          linear,
          grad.values,
          grad.indices,
          math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
          math_ops.cast(self._l1_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._adjusted_l2_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._learning_rate_power_tensor, var.dtype.base_dtype),
          use_locking=self._use_locking)
    else:
      return training_ops.sparse_apply_ftrl_v2(
          var,
          accum,
          linear,
          grad.values,
          grad.indices,
          math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
          math_ops.cast(self._l1_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._adjusted_l2_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._l2_shrinkage_regularization_strength_tensor,
                        grad.dtype.base_dtype),
          math_ops.cast(self._learning_rate_power_tensor, var.dtype.base_dtype),
          use_locking=self._use_locking)

  def _resource_apply_sparse(self, grad, var, indices):
    accum = self.get_slot(var, "accum")
    linear = self.get_slot(var, "linear")
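    # As noted in the constructor docstring, when the gradient is sparse,
    # shrinkage only affects the active (gathered) rows.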
    if self._l2_shrinkage_regularization_strength <= 0.0:
      return training_ops.resource_sparse_apply_ftrl(
          var.handle,
          accum.handle,
          linear.handle,
          grad,
          indices,
          math_ops.cast(self._learning_rate_tensor, grad.dtype),
          math_ops.cast(self._l1_regularization_strength_tensor, grad.dtype),
          math_ops.cast(self._adjusted_l2_regularization_strength_tensor,
                        grad.dtype),
          math_ops.cast(self._learning_rate_power_tensor, grad.dtype),
          use_locking=self._use_locking)
    else:
      return training_ops.resource_sparse_apply_ftrl_v2(
          var.handle,
          accum.handle,
          linear.handle,
          grad,
          indices,
          math_ops.cast(self._learning_rate_tensor, grad.dtype),
          math_ops.cast(self._l1_regularization_strength_tensor, grad.dtype),
          math_ops.cast(self._adjusted_l2_regularization_strength_tensor,
                        grad.dtype),
          math_ops.cast(self._l2_shrinkage_regularization_strength_tensor,
                        grad.dtype),
          math_ops.cast(self._learning_rate_power_tensor, grad.dtype),
          use_locking=self._use_locking)