# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ftrl-proximal optimizer implementation."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import gen_training_ops
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.optimizers.Ftrl')
class Ftrl(optimizer_v2.OptimizerV2):
  r"""Optimizer that implements the FTRL algorithm.

  See Algorithm 1 of this
  [paper](https://research.google.com/pubs/archive/41159.pdf).
  This version has support for both online L2 (the L2 penalty given in the paper
  above) and shrinkage-type L2 (which is the addition of an L2 penalty to the
  loss function).

  Initialization:
  $$t = 0$$
  $$n_{0} = 0$$
  $$\sigma_{0} = 0$$
  $$z_{0} = 0$$

  Update ($$i$$ is variable index, $$\alpha$$ is the learning rate):
  $$t = t + 1$$
  $$n_{t,i} = n_{t-1,i} + g_{t,i}^{2}$$
  $$\sigma_{t,i} = (\sqrt{n_{t,i}} - \sqrt{n_{t-1,i}}) / \alpha$$
  $$z_{t,i} = z_{t-1,i} + g_{t,i} - \sigma_{t,i} * w_{t,i}$$
  $$w_{t,i} = - ((\beta+\sqrt{n_{t,i}}) / \alpha + 2 * \lambda_{2})^{-1} *
              (z_{i} - sgn(z_{i}) * \lambda_{1}) if \abs{z_{i}} > \lambda_{1}
                                                 else 0$$

  Check the documentation for the l2_shrinkage_regularization_strength
  parameter for more details when shrinkage is enabled, in which case gradient
  is replaced with gradient_with_shrinkage.

  Args:
    learning_rate: A `Tensor`, floating point value, or a schedule that is a
      `tf.keras.optimizers.schedules.LearningRateSchedule`. The learning rate.
    learning_rate_power: A float value, must be less or equal to zero.
      Controls how the learning rate decreases during training. Use zero for
      a fixed learning rate.
    initial_accumulator_value: The starting value for accumulators.
      Only zero or positive values are allowed.
    l1_regularization_strength: A float value, must be greater than or
      equal to zero. Defaults to 0.0.
    l2_regularization_strength: A float value, must be greater than or
      equal to zero. Defaults to 0.0.
    name: Optional name prefix for the operations created when applying
      gradients. Defaults to `"Ftrl"`.
    l2_shrinkage_regularization_strength: A float value, must be greater than
      or equal to zero. This differs from L2 above in that the L2 above is a
      stabilization penalty, whereas this L2 shrinkage is a magnitude penalty.
      When input is sparse shrinkage will only happen on the active weights.
    beta: A float value, representing the beta value from the paper.
      Defaults to 0.0.
    **kwargs: Keyword arguments. Allowed to be one of
      `"clipnorm"` or `"clipvalue"`.
      `"clipnorm"` (float) clips gradients by norm; `"clipvalue"` (float) clips
      gradients by value.

  Reference:
    - [paper](
      https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf)
  """

  def __init__(self,
               learning_rate=0.001,
               learning_rate_power=-0.5,
               initial_accumulator_value=0.1,
               l1_regularization_strength=0.0,
               l2_regularization_strength=0.0,
               name='Ftrl',
               l2_shrinkage_regularization_strength=0.0,
               beta=0.0,
               **kwargs):
    super(Ftrl, self).__init__(name, **kwargs)

    # Validate hyperparameters eagerly so misconfiguration fails at
    # construction time rather than at the first apply step.
    if initial_accumulator_value < 0.0:
      raise ValueError(
          'initial_accumulator_value %f needs to be positive or zero' %
          initial_accumulator_value)
    if learning_rate_power > 0.0:
      raise ValueError('learning_rate_power %f needs to be negative or zero' %
                       learning_rate_power)
    if l1_regularization_strength < 0.0:
      raise ValueError(
          'l1_regularization_strength %f needs to be positive or zero' %
          l1_regularization_strength)
    if l2_regularization_strength < 0.0:
      raise ValueError(
          'l2_regularization_strength %f needs to be positive or zero' %
          l2_regularization_strength)
    if l2_shrinkage_regularization_strength < 0.0:
      raise ValueError(
          'l2_shrinkage_regularization_strength %f needs to be positive'
          ' or zero' % l2_shrinkage_regularization_strength)

    self._set_hyper('learning_rate', learning_rate)
    self._set_hyper('decay', self._initial_decay)
    self._set_hyper('learning_rate_power', learning_rate_power)
    self._set_hyper('l1_regularization_strength', l1_regularization_strength)
    self._set_hyper('l2_regularization_strength', l2_regularization_strength)
    self._set_hyper('beta', beta)
    # Kept as plain Python floats (not hypers): the shrinkage strength is also
    # consulted as a Python value to pick between the Ftrl and FtrlV2 kernels.
    self._initial_accumulator_value = initial_accumulator_value
    self._l2_shrinkage_regularization_strength = (
        l2_shrinkage_regularization_strength)

  def _create_slots(self, var_list):
    """Creates the "accumulator" (n in the paper) and "linear" (z) slots."""
    for var in var_list:
      dtype = var.dtype.base_dtype
      init = init_ops.constant_initializer(
          self._initial_accumulator_value, dtype=dtype)
      self.add_slot(var, 'accumulator', init)
      self.add_slot(var, 'linear')

  def _prepare_local(self, var_device, var_dtype, apply_state):
    """Caches per-(device, dtype) constant tensors for the apply ops."""
    super(Ftrl, self)._prepare_local(var_device, var_dtype, apply_state)
    apply_state[(var_device, var_dtype)].update(
        dict(
            learning_rate_power=array_ops.identity(
                self._get_hyper('learning_rate_power', var_dtype)),
            l1_regularization_strength=array_ops.identity(
                self._get_hyper('l1_regularization_strength', var_dtype)),
            l2_regularization_strength=array_ops.identity(
                self._get_hyper('l2_regularization_strength', var_dtype)),
            beta=array_ops.identity(self._get_hyper('beta', var_dtype)),
            l2_shrinkage_regularization_strength=math_ops.cast(
                self._l2_shrinkage_regularization_strength, var_dtype)))

  def _adjusted_l2_regularization_strength(self, coefficients):
    """Returns the L2 strength with the `beta` term folded in.

    The underlying Ftrl kernels do not take `beta` directly; adding
    `beta / (2 * lr)` to the stabilization L2 penalty is equivalent (it
    contributes the same `beta / alpha` term to the update denominator shown
    in the class docstring). Factored out so the dense and sparse paths
    cannot drift apart.

    Args:
      coefficients: The per-(device, dtype) dict built by `_prepare_local`.

    Returns:
      A scalar tensor: `l2_regularization_strength + beta / (2 * lr_t)`.
    """
    return (coefficients['l2_regularization_strength'] +
            coefficients['beta'] / (2. * coefficients['lr_t']))

  def _resource_apply_dense(self, grad, var, apply_state=None):
    """Applies a dense gradient to `var` via the fused Ftrl kernel."""
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    adjusted_l2_regularization_strength = (
        self._adjusted_l2_regularization_strength(coefficients))

    accum = self.get_slot(var, 'accumulator')
    linear = self.get_slot(var, 'linear')

    # The V2 kernel (with shrinkage) is only selected when shrinkage is
    # actually configured; `__init__` guarantees the strength is >= 0.
    if self._l2_shrinkage_regularization_strength <= 0.0:
      return gen_training_ops.ResourceApplyFtrl(
          var=var.handle,
          accum=accum.handle,
          linear=linear.handle,
          grad=grad,
          lr=coefficients['lr_t'],
          l1=coefficients['l1_regularization_strength'],
          l2=adjusted_l2_regularization_strength,
          lr_power=coefficients['learning_rate_power'],
          use_locking=self._use_locking)
    else:
      return gen_training_ops.ResourceApplyFtrlV2(
          var=var.handle,
          accum=accum.handle,
          linear=linear.handle,
          grad=grad,
          lr=coefficients['lr_t'],
          l1=coefficients['l1_regularization_strength'],
          l2=adjusted_l2_regularization_strength,
          l2_shrinkage=coefficients['l2_shrinkage_regularization_strength'],
          lr_power=coefficients['learning_rate_power'],
          use_locking=self._use_locking)

  def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    """Applies a sparse gradient (rows `indices`) via the sparse Ftrl kernel."""
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    adjusted_l2_regularization_strength = (
        self._adjusted_l2_regularization_strength(coefficients))

    accum = self.get_slot(var, 'accumulator')
    linear = self.get_slot(var, 'linear')

    if self._l2_shrinkage_regularization_strength <= 0.0:
      return gen_training_ops.ResourceSparseApplyFtrl(
          var=var.handle,
          accum=accum.handle,
          linear=linear.handle,
          grad=grad,
          indices=indices,
          lr=coefficients['lr_t'],
          l1=coefficients['l1_regularization_strength'],
          l2=adjusted_l2_regularization_strength,
          lr_power=coefficients['learning_rate_power'],
          use_locking=self._use_locking)
    else:
      return gen_training_ops.ResourceSparseApplyFtrlV2(
          var=var.handle,
          accum=accum.handle,
          linear=linear.handle,
          grad=grad,
          indices=indices,
          lr=coefficients['lr_t'],
          l1=coefficients['l1_regularization_strength'],
          l2=adjusted_l2_regularization_strength,
          l2_shrinkage=coefficients['l2_shrinkage_regularization_strength'],
          lr_power=coefficients['learning_rate_power'],
          use_locking=self._use_locking)

  def get_config(self):
    """Returns the optimizer configuration for serialization."""
    config = super(Ftrl, self).get_config()
    config.update({
        'learning_rate':
            self._serialize_hyperparameter('learning_rate'),
        'decay':
            self._initial_decay,
        'initial_accumulator_value':
            self._initial_accumulator_value,
        'learning_rate_power':
            self._serialize_hyperparameter('learning_rate_power'),
        'l1_regularization_strength':
            self._serialize_hyperparameter('l1_regularization_strength'),
        'l2_regularization_strength':
            self._serialize_hyperparameter('l2_regularization_strength'),
        'beta':
            self._serialize_hyperparameter('beta'),
        'l2_shrinkage_regularization_strength':
            self._l2_shrinkage_regularization_strength,
    })
    return config