# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ftrl-proximal optimizer implementation."""
# pylint: disable=g-classes-have-attributes

from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import gen_training_ops
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.optimizers.Ftrl')
class Ftrl(optimizer_v2.OptimizerV2):
  r"""Optimizer that implements the FTRL algorithm.

  "Follow The Regularized Leader" (FTRL) is an optimization algorithm developed
  at Google for click-through rate prediction in the early 2010s. It is most
  suitable for shallow models with large and sparse feature spaces.
  The algorithm is described by
  [McMahan et al., 2013](https://research.google.com/pubs/archive/41159.pdf).
  The Keras version has support for both online L2 regularization
  (the L2 regularization described in the paper above) and shrinkage-type L2
  regularization (which is the addition of an L2 penalty to the loss function).

  Initialization:

  ```python
  n = 0
  sigma = 0
  z = 0
  ```

  Update rule for one variable `w`:

  ```python
  prev_n = n
  n = n + g ** 2
  sigma = (sqrt(n) - sqrt(prev_n)) / lr
  z = z + g - sigma * w
  if abs(z) < lambda_1:
    w = 0
  else:
    w = (sgn(z) * lambda_1 - z) / ((beta + sqrt(n)) / alpha + lambda_2)
  ```

  Notation:

  - `lr` is the learning rate
  - `g` is the gradient for the variable
  - `lambda_1` is the L1 regularization strength
  - `lambda_2` is the L2 regularization strength

  Check the documentation for the `l2_shrinkage_regularization_strength`
  parameter for more details when shrinkage is enabled, in which case the
  gradient is replaced with a gradient with shrinkage.
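
  As a rough sketch (this mirrors the intent rather than the exact kernel
  implementation), when shrinkage is enabled the gradient used in the `z`
  update above picks up a term proportional to the current weight:

  ```python
  # lambda_shrinkage stands for l2_shrinkage_regularization_strength.
  g_with_shrinkage = g + 2 * lambda_shrinkage * w
  ```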

  Args:
    learning_rate: A `Tensor`, floating point value, or a schedule that is a
      `tf.keras.optimizers.schedules.LearningRateSchedule`. The learning rate.
    learning_rate_power: A float value, must be less or equal to zero.
      Controls how the learning rate decreases during training. Use zero for
      a fixed learning rate.
    initial_accumulator_value: The starting value for accumulators.
      Only zero or positive values are allowed.
    l1_regularization_strength: A float value, must be greater than or
      equal to zero. Defaults to 0.0.
    l2_regularization_strength: A float value, must be greater than or
      equal to zero. Defaults to 0.0.
    name: Optional name prefix for the operations created when applying
      gradients. Defaults to `"Ftrl"`.
    l2_shrinkage_regularization_strength: A float value, must be greater than
      or equal to zero. This differs from L2 above in that the L2 above is a
      stabilization penalty, whereas this L2 shrinkage is a magnitude penalty.
      When input is sparse, shrinkage will only happen on the active weights.
    beta: A float value, representing the beta value from the paper.
      Defaults to 0.0.
    **kwargs: Keyword arguments. Allowed to be one of
      `"clipnorm"` or `"clipvalue"`.
      `"clipnorm"` (float) clips gradients by norm; `"clipvalue"` (float) clips
      gradients by value.

  Reference:
    - [McMahan et al., 2013](
      https://research.google.com/pubs/archive/41159.pdf)
  """

  def __init__(self,
               learning_rate=0.001,
               learning_rate_power=-0.5,
               initial_accumulator_value=0.1,
               l1_regularization_strength=0.0,
               l2_regularization_strength=0.0,
               name='Ftrl',
               l2_shrinkage_regularization_strength=0.0,
               beta=0.0,
               **kwargs):
    super(Ftrl, self).__init__(name, **kwargs)

    if initial_accumulator_value < 0.0:
      raise ValueError(
          'initial_accumulator_value %f needs to be positive or zero' %
          initial_accumulator_value)
    if learning_rate_power > 0.0:
      raise ValueError('learning_rate_power %f needs to be negative or zero' %
                       learning_rate_power)
    if l1_regularization_strength < 0.0:
      raise ValueError(
          'l1_regularization_strength %f needs to be positive or zero' %
          l1_regularization_strength)
    if l2_regularization_strength < 0.0:
      raise ValueError(
          'l2_regularization_strength %f needs to be positive or zero' %
          l2_regularization_strength)
    if l2_shrinkage_regularization_strength < 0.0:
      raise ValueError(
          'l2_shrinkage_regularization_strength %f needs to be positive'
          ' or zero' % l2_shrinkage_regularization_strength)

    self._set_hyper('learning_rate', learning_rate)
    self._set_hyper('decay', self._initial_decay)
    self._set_hyper('learning_rate_power', learning_rate_power)
    self._set_hyper('l1_regularization_strength', l1_regularization_strength)
    self._set_hyper('l2_regularization_strength', l2_regularization_strength)
    self._set_hyper('beta', beta)
    self._initial_accumulator_value = initial_accumulator_value
    self._l2_shrinkage_regularization_strength = (
        l2_shrinkage_regularization_strength)

  def _create_slots(self, var_list):
    # Create the "accum" and "linear" slots.
    for var in var_list:
      dtype = var.dtype.base_dtype
      init = init_ops.constant_initializer(
          self._initial_accumulator_value, dtype=dtype)
      self.add_slot(var, 'accumulator', init)
      self.add_slot(var, 'linear')

  def _prepare_local(self, var_device, var_dtype, apply_state):
    super(Ftrl, self)._prepare_local(var_device, var_dtype, apply_state)
    apply_state[(var_device, var_dtype)].update(
        dict(
            learning_rate_power=array_ops.identity(
                self._get_hyper('learning_rate_power', var_dtype)),
            l1_regularization_strength=array_ops.identity(
                self._get_hyper('l1_regularization_strength', var_dtype)),
            l2_regularization_strength=array_ops.identity(
                self._get_hyper('l2_regularization_strength', var_dtype)),
            beta=array_ops.identity(self._get_hyper('beta', var_dtype)),
            l2_shrinkage_regularization_strength=math_ops.cast(
                self._l2_shrinkage_regularization_strength, var_dtype)))

  def _resource_apply_dense(self, grad, var, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    # Adjust L2 regularization strength to include beta to avoid the underlying
    # TensorFlow ops needing to include it.
    adjusted_l2_regularization_strength = (
        coefficients['l2_regularization_strength'] + coefficients['beta'] /
        (2. * coefficients['lr_t']))

    accum = self.get_slot(var, 'accumulator')
    linear = self.get_slot(var, 'linear')

    if self._l2_shrinkage_regularization_strength <= 0.0:
      return gen_training_ops.ResourceApplyFtrl(
          var=var.handle,
          accum=accum.handle,
          linear=linear.handle,
          grad=grad,
          lr=coefficients['lr_t'],
          l1=coefficients['l1_regularization_strength'],
          l2=adjusted_l2_regularization_strength,
          lr_power=coefficients['learning_rate_power'],
          use_locking=self._use_locking)
    else:
      return gen_training_ops.ResourceApplyFtrlV2(
          var=var.handle,
          accum=accum.handle,
          linear=linear.handle,
          grad=grad,
          lr=coefficients['lr_t'],
          l1=coefficients['l1_regularization_strength'],
          l2=adjusted_l2_regularization_strength,
          l2_shrinkage=coefficients['l2_shrinkage_regularization_strength'],
          lr_power=coefficients['learning_rate_power'],
          use_locking=self._use_locking)

  def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    # Adjust L2 regularization strength to include beta to avoid the underlying
    # TensorFlow ops needing to include it.
    adjusted_l2_regularization_strength = (
        coefficients['l2_regularization_strength'] + coefficients['beta'] /
        (2. * coefficients['lr_t']))

    accum = self.get_slot(var, 'accumulator')
    linear = self.get_slot(var, 'linear')

    if self._l2_shrinkage_regularization_strength <= 0.0:
      return gen_training_ops.ResourceSparseApplyFtrl(
          var=var.handle,
          accum=accum.handle,
          linear=linear.handle,
          grad=grad,
          indices=indices,
          lr=coefficients['lr_t'],
          l1=coefficients['l1_regularization_strength'],
          l2=adjusted_l2_regularization_strength,
          lr_power=coefficients['learning_rate_power'],
          use_locking=self._use_locking)
    else:
      return gen_training_ops.ResourceSparseApplyFtrlV2(
          var=var.handle,
          accum=accum.handle,
          linear=linear.handle,
          grad=grad,
          indices=indices,
          lr=coefficients['lr_t'],
          l1=coefficients['l1_regularization_strength'],
          l2=adjusted_l2_regularization_strength,
          l2_shrinkage=coefficients['l2_shrinkage_regularization_strength'],
          lr_power=coefficients['learning_rate_power'],
          use_locking=self._use_locking)

  def get_config(self):
    config = super(Ftrl, self).get_config()
    config.update({
        'learning_rate':
            self._serialize_hyperparameter('learning_rate'),
        'decay':
            self._initial_decay,
        'initial_accumulator_value':
            self._initial_accumulator_value,
        'learning_rate_power':
            self._serialize_hyperparameter('learning_rate_power'),
        'l1_regularization_strength':
            self._serialize_hyperparameter('l1_regularization_strength'),
        'l2_regularization_strength':
            self._serialize_hyperparameter('l2_regularization_strength'),
        'beta':
            self._serialize_hyperparameter('beta'),
        'l2_shrinkage_regularization_strength':
            self._l2_shrinkage_regularization_strength,
    })
    return config
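
# A minimal usage sketch (illustrative only; it assumes the public `tf.keras`
# namespace, an already-built Keras `model`, and `x_train`/`y_train` data,
# none of which are defined in this module):
#
#   import tensorflow as tf
#
#   optimizer = tf.keras.optimizers.Ftrl(
#       learning_rate=0.001,
#       l1_regularization_strength=0.01,
#       l2_regularization_strength=0.01)
#   model.compile(optimizer=optimizer, loss='binary_crossentropy')
#   model.fit(x_train, y_train, epochs=3)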