# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ftrl-proximal for TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import optimizer
from tensorflow.python.training import training_ops
from tensorflow.python.util.tf_export import tf_export


@tf_export(v1=["train.FtrlOptimizer"])
class FtrlOptimizer(optimizer.Optimizer):
  """Optimizer that implements the FTRL algorithm.

  See this [paper](
  https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf).
  This version has support for both online L2 (the L2 penalty given in the
  paper above) and shrinkage-type L2 (which is the addition of an L2 penalty
  to the loss function). A minimal usage sketch appears at the end of this
  module.
  """

  def __init__(self,
               learning_rate,
               learning_rate_power=-0.5,
               initial_accumulator_value=0.1,
               l1_regularization_strength=0.0,
               l2_regularization_strength=0.0,
               use_locking=False,
               name="Ftrl",
               accum_name=None,
               linear_name=None,
               l2_shrinkage_regularization_strength=0.0):
    r"""Construct a new FTRL optimizer.

    Args:
      learning_rate: A float value or a constant float `Tensor`.
      learning_rate_power: A float value, must be less than or equal to zero.
        Controls how the learning rate decreases during training. Use zero for
        a fixed learning rate. See section 3.1 in the
        [paper](https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf).
      initial_accumulator_value: The starting value for accumulators.
        Only zero or positive values are allowed.
      l1_regularization_strength: A float value, must be greater than or
        equal to zero.
      l2_regularization_strength: A float value, must be greater than or
        equal to zero.
      use_locking: If `True` use locks for update operations.
      name: Optional name prefix for the operations created when applying
        gradients. Defaults to "Ftrl".
      accum_name: The suffix for the variable that keeps the gradient squared
        accumulator. If not present, defaults to name.
      linear_name: The suffix for the variable that keeps the linear gradient
        accumulator. If not present, defaults to name + "_1".
      l2_shrinkage_regularization_strength: A float value, must be greater
        than or equal to zero. This differs from L2 above in that the L2
        above is a stabilization penalty, whereas this L2 shrinkage is a
        magnitude penalty. The FTRL formulation can be written as:
          w_{t+1} = argmin_w(\hat{g}_{1:t} w + L1*||w||_1 + L2*||w||_2^2),
        where \hat{g} = g + (2*L2_shrinkage*w), and g is the gradient of the
        loss function w.r.t. the weights w.
        Specifically, in the absence of L1 regularization, it is equivalent
        to the following update rule:
          w_{t+1} = w_t - lr_t / (1 + 2*L2*lr_t) * g_t -
                    2*L2_shrinkage*lr_t / (1 + 2*L2*lr_t) * w_t
        where lr_t is the learning rate at t.
        When input is sparse, shrinkage will only happen on the active
        weights. (A plain-Python transcription of this rule appears as an
        illustrative sketch at the end of this module.)

    Raises:
      ValueError: If one of the arguments is invalid.
    """
    super(FtrlOptimizer, self).__init__(use_locking, name)

    if initial_accumulator_value < 0.0:
      raise ValueError(
          "initial_accumulator_value %f needs to be positive or zero" %
          initial_accumulator_value)
    if learning_rate_power > 0.0:
      raise ValueError("learning_rate_power %f needs to be negative or zero" %
                       learning_rate_power)
    if l1_regularization_strength < 0.0:
      raise ValueError(
          "l1_regularization_strength %f needs to be positive or zero" %
          l1_regularization_strength)
    if l2_regularization_strength < 0.0:
      raise ValueError(
          "l2_regularization_strength %f needs to be positive or zero" %
          l2_regularization_strength)
    if l2_shrinkage_regularization_strength < 0.0:
      raise ValueError(
          "l2_shrinkage_regularization_strength %f needs to be positive"
          " or zero" % l2_shrinkage_regularization_strength)

    self._learning_rate = learning_rate
    self._learning_rate_power = learning_rate_power
    self._initial_accumulator_value = initial_accumulator_value
    self._l1_regularization_strength = l1_regularization_strength
    self._l2_regularization_strength = l2_regularization_strength
    self._l2_shrinkage_regularization_strength = (
        l2_shrinkage_regularization_strength)
    self._learning_rate_tensor = None
    self._learning_rate_power_tensor = None
    self._l1_regularization_strength_tensor = None
    self._l2_regularization_strength_tensor = None
    self._l2_shrinkage_regularization_strength_tensor = None
    self._accum_name = accum_name
    self._linear_name = linear_name

  def _create_slots(self, var_list):
    # Create the "accum" and "linear" slots.
    for v in var_list:
      with ops.colocate_with(v):
        val = constant_op.constant(
            self._initial_accumulator_value, dtype=v.dtype,
            shape=v.get_shape())
        self._get_or_make_slot(v, val, "accum", self._accum_name or self._name)
        self._zeros_slot(v, "linear", self._linear_name or self._name)

  def _prepare(self):
    self._learning_rate_tensor = ops.convert_to_tensor(
        self._learning_rate, name="learning_rate")
    self._l1_regularization_strength_tensor = ops.convert_to_tensor(
        self._l1_regularization_strength, name="l1_regularization_strength")
    self._l2_regularization_strength_tensor = ops.convert_to_tensor(
        self._l2_regularization_strength, name="l2_regularization_strength")
    self._l2_shrinkage_regularization_strength_tensor = ops.convert_to_tensor(
        self._l2_shrinkage_regularization_strength,
        name="l2_shrinkage_regularization_strength")
    self._learning_rate_power_tensor = ops.convert_to_tensor(
        self._learning_rate_power, name="learning_rate_power")

  def _apply_dense(self, grad, var):
    accum = self.get_slot(var, "accum")
    linear = self.get_slot(var, "linear")
    # The v2 kernels additionally apply the L2 shrinkage penalty; fall back
    # to the plain FTRL kernels when shrinkage is disabled.
    if self._l2_shrinkage_regularization_strength <= 0.0:
      return training_ops.apply_ftrl(
          var,
          accum,
          linear,
          grad,
          math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
          math_ops.cast(self._l1_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._l2_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._learning_rate_power_tensor,
                        var.dtype.base_dtype),
          use_locking=self._use_locking)
    else:
      return training_ops.apply_ftrl_v2(
          var,
          accum,
          linear,
          grad,
          math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
          math_ops.cast(self._l1_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._l2_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._l2_shrinkage_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._learning_rate_power_tensor,
                        var.dtype.base_dtype),
          use_locking=self._use_locking)

  def _resource_apply_dense(self, grad, var):
    accum = self.get_slot(var, "accum")
    linear = self.get_slot(var, "linear")
    if self._l2_shrinkage_regularization_strength <= 0.0:
      return training_ops.resource_apply_ftrl(
          var.handle,
          accum.handle,
          linear.handle,
          grad,
          math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
          math_ops.cast(self._l1_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._l2_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._learning_rate_power_tensor,
                        var.dtype.base_dtype),
          use_locking=self._use_locking)
    else:
      return training_ops.resource_apply_ftrl_v2(
          var.handle,
          accum.handle,
          linear.handle,
          grad,
          math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
          math_ops.cast(self._l1_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._l2_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._l2_shrinkage_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._learning_rate_power_tensor,
                        var.dtype.base_dtype),
          use_locking=self._use_locking)

  def _apply_sparse(self, grad, var):
    accum = self.get_slot(var, "accum")
    linear = self.get_slot(var, "linear")
    if self._l2_shrinkage_regularization_strength <= 0.0:
      return training_ops.sparse_apply_ftrl(
          var,
          accum,
          linear,
          grad.values,
          grad.indices,
          math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
          math_ops.cast(self._l1_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._l2_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._learning_rate_power_tensor,
                        var.dtype.base_dtype),
          use_locking=self._use_locking)
    else:
      return training_ops.sparse_apply_ftrl_v2(
          var,
          accum,
          linear,
          grad.values,
          grad.indices,
          math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
          math_ops.cast(self._l1_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._l2_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._l2_shrinkage_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._learning_rate_power_tensor,
                        var.dtype.base_dtype),
          use_locking=self._use_locking)

  def _resource_apply_sparse(self, grad, var, indices):
    accum = self.get_slot(var, "accum")
    linear = self.get_slot(var, "linear")
    if self._l2_shrinkage_regularization_strength <= 0.0:
      return training_ops.resource_sparse_apply_ftrl(
          var.handle,
          accum.handle,
          linear.handle,
          grad,
          indices,
          math_ops.cast(self._learning_rate_tensor, grad.dtype),
          math_ops.cast(self._l1_regularization_strength_tensor, grad.dtype),
          math_ops.cast(self._l2_regularization_strength_tensor, grad.dtype),
          math_ops.cast(self._learning_rate_power_tensor, grad.dtype),
          use_locking=self._use_locking)
    else:
      return training_ops.resource_sparse_apply_ftrl_v2(
          var.handle,
          accum.handle,
          linear.handle,
          grad,
          indices,
          math_ops.cast(self._learning_rate_tensor, grad.dtype),
          math_ops.cast(self._l1_regularization_strength_tensor, grad.dtype),
          math_ops.cast(self._l2_regularization_strength_tensor, grad.dtype),
          math_ops.cast(self._l2_shrinkage_regularization_strength_tensor,
                        grad.dtype),
          math_ops.cast(self._learning_rate_power_tensor, grad.dtype),
          use_locking=self._use_locking)
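

# What follows is an illustrative sketch, not part of the optimizer's API: it
# transcribes the simplified update rule quoted in the `__init__` docstring
# (the no-L1 case), assuming a fixed learning rate (learning_rate_power=0) so
# that lr_t reduces to a constant `lr`. The helper name
# `_ftrl_shrinkage_update_sketch` is illustrative only and is never called by
# this module.
def _ftrl_shrinkage_update_sketch(w, g, lr, l2, l2_shrinkage):
  """Illustrative only: one simplified FTRL step for a single weight.

  Implements, in plain Python:
    w_{t+1} = w_t - lr / (1 + 2*L2*lr) * g_t
                  - 2*L2_shrinkage*lr / (1 + 2*L2*lr) * w_t
  """
  denom = 1.0 + 2.0 * l2 * lr  # shared damping term from the online L2 penalty
  grad_step = (lr / denom) * g  # gradient step, damped by online L2
  shrink_step = (2.0 * l2_shrinkage * lr / denom) * w  # magnitude penalty
  return w - grad_step - shrink_step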
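

# A minimal usage sketch, assuming graph-mode TensorFlow 1.x; the variable,
# loss, and hyperparameter values below are arbitrary placeholders rather
# than recommendations. `_ftrl_usage_sketch` is illustrative only and is
# never called by this module.
def _ftrl_usage_sketch():
  """Illustrative only: fit a toy quadratic loss with FtrlOptimizer."""
  import tensorflow as tf  # local import keeps the sketch self-contained

  w = tf.Variable([0.5, -0.3], name="w")
  loss = tf.reduce_sum(tf.square(w))  # toy quadratic loss, minimized at 0
  opt = tf.train.FtrlOptimizer(
      learning_rate=0.1,
      l1_regularization_strength=0.01,
      l2_regularization_strength=0.01)
  train_op = opt.minimize(loss)
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(10):
      sess.run(train_op)  # each run applies one FTRL update to `w`
    return sess.run(w)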