# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ftrl-proximal for TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import training_ops
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.optimizers.Ftrl')
class Ftrl(optimizer_v2.OptimizerV2):
  r"""Optimizer that implements the FTRL algorithm.

  See Algorithm 1 of this [paper](
  https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf).
  This version has support for both online L2 (the L2 penalty given in the paper
  above) and shrinkage-type L2 (which is the addition of an L2 penalty to the
  loss function).

  Notation: $$\alpha$$ is `learning_rate`, $$\lambda_{1}$$ is
  `l1_regularization_strength`, $$\lambda_{2}$$ is `l2_regularization_strength`
  and $$n_{init}$$ is `initial_accumulator_value`. The square roots below
  correspond to the default `learning_rate_power` of -0.5; for a general power
  $$p$$, $$\sqrt{n}$$ is replaced by $$n^{-p}$$.

  Initialization:
  $$t = 0$$
  $$n_{0,i} = n_{init}$$
  $$\sigma_{0,i} = 0$$
  $$z_{0,i} = 0$$

  Update ($$i$$ is the variable index):
  $$t = t + 1$$
  $$n_{t,i} = n_{t-1,i} + g_{t,i}^{2}$$
  $$\sigma_{t,i} = (\sqrt{n_{t,i}} - \sqrt{n_{t-1,i}}) / \alpha$$
  $$z_{t,i} = z_{t-1,i} + g_{t,i} - \sigma_{t,i} * w_{t-1,i}$$
  $$w_{t,i} = - (\sqrt{n_{t,i}} / \alpha + 2 * \lambda_{2})^{-1} *
  (z_{t,i} - sgn(z_{t,i}) * \lambda_{1})$$ if $$|z_{t,i}| > \lambda_{1}$$,
  otherwise $$w_{t,i} = 0$$.

  Check the documentation for the l2_shrinkage_regularization_strength
  parameter for more details when shrinkage is enabled, where gradient is
  replaced with gradient_with_shrinkage.
  """

  def __init__(self,
               learning_rate=0.001,
               learning_rate_power=-0.5,
               initial_accumulator_value=0.1,
               l1_regularization_strength=0.0,
               l2_regularization_strength=0.0,
               name='Ftrl',
               l2_shrinkage_regularization_strength=0.0,
               **kwargs):
    r"""Construct a new FTRL optimizer.

    Args:
      learning_rate: A `Tensor`, floating point value, or a schedule that is a
        `tf.keras.optimizers.schedules.LearningRateSchedule`. The learning rate.
      learning_rate_power: A float value, must be less or equal to zero.
        Controls how the learning rate decreases during training. Use zero for
        a fixed learning rate.
      initial_accumulator_value: The starting value for accumulators.
        Only zero or positive values are allowed.
      l1_regularization_strength: A float value, must be greater than or
        equal to zero.
      l2_regularization_strength: A float value, must be greater than or
        equal to zero.
      name: Optional name prefix for the operations created when applying
        gradients.  Defaults to "Ftrl".
      l2_shrinkage_regularization_strength: A float value, must be greater than
        or equal to zero. This differs from L2 above in that the L2 above is a
        stabilization penalty, whereas this L2 shrinkage is a magnitude penalty.
        The FTRL formulation can be written as:
        w_{t+1} = argmin_w(\hat{g}_{1:t}w + L1*||w||_1 + L2*||w||_2^2), where
        \hat{g} = g + (2*L2_shrinkage*w), and g is the gradient of the loss
        function w.r.t. the weights w.
        Specifically, in the absence of L1 regularization, it is equivalent to
        the following update rule:
        w_{t+1} = w_t - lr_t / (1 + 2*L2*lr_t) * g_t -
                  2*L2_shrinkage*lr_t / (1 + 2*L2*lr_t) * w_t
        where lr_t is the learning rate at t.
        When input is sparse shrinkage will only happen on the active weights.
      **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`,
        `decay`}. `clipnorm` is clip gradients by norm; `clipvalue` is clip
        gradients by value, `decay` is included for backward compatibility to
        allow time inverse decay of learning rate. `lr` is included for backward
        compatibility, recommended to use `learning_rate` instead; when given,
        it takes precedence over `learning_rate`.

    Raises:
      ValueError: If one of the arguments is invalid.

    References:
      See [paper]
        (https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf)
    """
    super(Ftrl, self).__init__(name, **kwargs)

    if initial_accumulator_value < 0.0:
      raise ValueError(
          'initial_accumulator_value %f needs to be positive or zero' %
          initial_accumulator_value)
    if learning_rate_power > 0.0:
      raise ValueError('learning_rate_power %f needs to be negative or zero' %
                       learning_rate_power)
    if l1_regularization_strength < 0.0:
      raise ValueError(
          'l1_regularization_strength %f needs to be positive or zero' %
          l1_regularization_strength)
    if l2_regularization_strength < 0.0:
      raise ValueError(
          'l2_regularization_strength %f needs to be positive or zero' %
          l2_regularization_strength)
    if l2_shrinkage_regularization_strength < 0.0:
      raise ValueError(
          'l2_shrinkage_regularization_strength %f needs to be positive'
          ' or zero' % l2_shrinkage_regularization_strength)

    # `lr` is documented above as an accepted backward-compatible alias and is
    # forwarded to the base class; honour it here instead of silently falling
    # back to the default `learning_rate`.
    self._set_hyper('learning_rate', kwargs.get('lr', learning_rate))
    self._set_hyper('decay', self._initial_decay)
    self._set_hyper('learning_rate_power', learning_rate_power)
    self._set_hyper('l1_regularization_strength', l1_regularization_strength)
    self._set_hyper('l2_regularization_strength', l2_regularization_strength)
    self._initial_accumulator_value = initial_accumulator_value
    # Stored as a plain Python value: the apply methods compare it directly in
    # Python to choose between the FTRL and FTRL-V2 (shrinkage) kernels.
    self._l2_shrinkage_regularization_strength = (
        l2_shrinkage_regularization_strength)

  def _create_slots(self, var_list):
    """Create the per-variable "accumulator" (n) and "linear" (z) slots.

    The accumulator starts at `initial_accumulator_value`; the linear slot
    uses the base class's default slot initializer.
    """
    for var in var_list:
      dtype = var.dtype.base_dtype
      init = init_ops.constant_initializer(
          self._initial_accumulator_value, dtype=dtype)
      self.add_slot(var, 'accumulator', init)
      self.add_slot(var, 'linear')

  def _prepare_local(self, var_device, var_dtype, apply_state):
    """Cache FTRL coefficients for one (device, dtype) pair in `apply_state`.

    The apply methods below look these coefficients up by
    `(var.device, var.dtype.base_dtype)`.
    """
    super(Ftrl, self)._prepare_local(var_device, var_dtype, apply_state)
    apply_state[(var_device, var_dtype)].update(dict(
        learning_rate_power=array_ops.identity(
            self._get_hyper('learning_rate_power', var_dtype)),
        l1_regularization_strength=array_ops.identity(
            self._get_hyper('l1_regularization_strength', var_dtype)),
        l2_regularization_strength=array_ops.identity(
            self._get_hyper('l2_regularization_strength', var_dtype)),
        l2_shrinkage_regularization_strength=math_ops.cast(
            self._l2_shrinkage_regularization_strength, var_dtype)
    ))

  def _resource_apply_dense(self, grad, var, apply_state=None):
    """Apply one dense FTRL step to resource variable `var` given `grad`."""
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    accum = self.get_slot(var, 'accumulator')
    linear = self.get_slot(var, 'linear')

    # 'lr_t' is not set in this class; presumably it is provided by the base
    # class's _prepare_local -- confirm against OptimizerV2.
    if self._l2_shrinkage_regularization_strength <= 0.0:
      # No shrinkage: the plain FTRL kernel takes no shrinkage argument.
      return training_ops.resource_apply_ftrl(
          var.handle,
          accum.handle,
          linear.handle,
          grad,
          coefficients['lr_t'],
          coefficients['l1_regularization_strength'],
          coefficients['l2_regularization_strength'],
          coefficients['learning_rate_power'],
          use_locking=self._use_locking)
    else:
      return training_ops.resource_apply_ftrl_v2(
          var.handle,
          accum.handle,
          linear.handle,
          grad,
          coefficients['lr_t'],
          coefficients['l1_regularization_strength'],
          coefficients['l2_regularization_strength'],
          coefficients['l2_shrinkage_regularization_strength'],
          coefficients['learning_rate_power'],
          use_locking=self._use_locking)

  def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    """Apply one sparse FTRL step to the rows of `var` selected by `indices`."""
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    accum = self.get_slot(var, 'accumulator')
    linear = self.get_slot(var, 'linear')

    if self._l2_shrinkage_regularization_strength <= 0.0:
      # No shrinkage: the plain sparse FTRL kernel takes no shrinkage argument.
      return training_ops.resource_sparse_apply_ftrl(
          var.handle,
          accum.handle,
          linear.handle,
          grad,
          indices,
          coefficients['lr_t'],
          coefficients['l1_regularization_strength'],
          coefficients['l2_regularization_strength'],
          coefficients['learning_rate_power'],
          use_locking=self._use_locking)
    else:
      return training_ops.resource_sparse_apply_ftrl_v2(
          var.handle,
          accum.handle,
          linear.handle,
          grad,
          indices,
          coefficients['lr_t'],
          coefficients['l1_regularization_strength'],
          coefficients['l2_regularization_strength'],
          coefficients['l2_shrinkage_regularization_strength'],
          coefficients['learning_rate_power'],
          use_locking=self._use_locking)

  def get_config(self):
    """Return the base-class config extended with the FTRL hyperparameters."""
    config = super(Ftrl, self).get_config()
    config.update({
        'learning_rate':
            self._serialize_hyperparameter('learning_rate'),
        'decay':
            self._serialize_hyperparameter('decay'),
        'initial_accumulator_value':
            self._initial_accumulator_value,
        'learning_rate_power':
            self._serialize_hyperparameter('learning_rate_power'),
        'l1_regularization_strength':
            self._serialize_hyperparameter('l1_regularization_strength'),
        'l2_regularization_strength':
            self._serialize_hyperparameter('l2_regularization_strength'),
        'l2_shrinkage_regularization_strength':
            self._l2_shrinkage_regularization_strength,
    })
    return config