# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ftrl-proximal for TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import optimizer
from tensorflow.python.training import training_ops
from tensorflow.python.util.tf_export import tf_export


@tf_export(v1=["train.FtrlOptimizer"])
class FtrlOptimizer(optimizer.Optimizer):
  """Optimizer that implements the FTRL algorithm.

  This version has support for both online L2 (McMahan et al., 2013) and
  shrinkage-type L2, which is the addition of an L2 penalty
  to the loss function.

  References:
    Ad-click prediction:
      [McMahan et al., 2013](https://dl.acm.org/citation.cfm?id=2488200)
      ([pdf](https://dl.acm.org/ft_gateway.cfm?id=2488200&ftid=1388399&dwn=1&CFID=32233078&CFTOKEN=d60fe57a294c056a-CB75C374-F915-E7A6-1573FBBC7BF7D526))
  """
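
  # A minimal usage sketch (illustrative names; assumes a graph-mode
  # `tf.compat.v1` setup in which `loss` and the variables it depends on
  # already exist):
  #
  #   opt = tf.compat.v1.train.FtrlOptimizer(
  #       learning_rate=0.1,
  #       l1_regularization_strength=0.001,
  #       l2_regularization_strength=0.001)
  #   train_op = opt.minimize(loss)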

  def __init__(self,
               learning_rate,
               learning_rate_power=-0.5,
               initial_accumulator_value=0.1,
               l1_regularization_strength=0.0,
               l2_regularization_strength=0.0,
               use_locking=False,
               name="Ftrl",
               accum_name=None,
               linear_name=None,
               l2_shrinkage_regularization_strength=0.0,
               beta=None):
    r"""Construct a new FTRL optimizer.

    Args:
      learning_rate: A float value or a constant float `Tensor`.
      learning_rate_power: A float value, must be less than or equal to zero.
        Controls how the learning rate decreases during training. Use zero for
        a fixed learning rate. See section 3.1 in (McMahan et al., 2013).
      initial_accumulator_value: The starting value for accumulators. Only zero
        or positive values are allowed.
      l1_regularization_strength: A float value, must be greater than or equal
        to zero.
      l2_regularization_strength: A float value, must be greater than or equal
        to zero.
      use_locking: If `True` use locks for update operations.
      name: Optional name prefix for the operations created when applying
        gradients. Defaults to "Ftrl".
      accum_name: The suffix for the variable that keeps the gradient squared
        accumulator. If not present, defaults to name.
      linear_name: The suffix for the variable that keeps the linear gradient
        accumulator. If not present, defaults to name + "_1".
      l2_shrinkage_regularization_strength: A float value, must be greater than
        or equal to zero. This differs from L2 above in that the L2 above is a
        stabilization penalty, whereas this L2 shrinkage is a magnitude
        penalty. The FTRL formulation can be written as:
        w_{t+1} = argmin_w(\hat{g}_{1:t}w + L1*||w||_1 + L2*||w||_2^2), where
        \hat{g} = g + (2*L2_shrinkage*w), and g is the gradient of the loss
        function w.r.t. the weights w.
        Specifically, in the absence of L1 regularization, it is equivalent to
        the following update rule:
        w_{t+1} = w_t - lr_t / (beta + 2*L2*lr_t) * g_t -
                  2*L2_shrinkage*lr_t / (beta + 2*L2*lr_t) * w_t
        where lr_t is the learning rate at t.
        When input is sparse, shrinkage will only happen on the active weights.
      beta: A float value; corresponds to the beta parameter in the paper.

    Raises:
      ValueError: If one of the arguments is invalid.

    References:
      Ad-click prediction:
        [McMahan et al., 2013](https://dl.acm.org/citation.cfm?id=2488200)
        ([pdf](https://dl.acm.org/ft_gateway.cfm?id=2488200&ftid=1388399&dwn=1&CFID=32233078&CFTOKEN=d60fe57a294c056a-CB75C374-F915-E7A6-1573FBBC7BF7D526))
    """
    super(FtrlOptimizer, self).__init__(use_locking, name)

    if initial_accumulator_value < 0.0:
      raise ValueError(
          "initial_accumulator_value %f needs to be positive or zero" %
          initial_accumulator_value)
    if learning_rate_power > 0.0:
      raise ValueError("learning_rate_power %f needs to be negative or zero" %
                       learning_rate_power)
    if l1_regularization_strength < 0.0:
      raise ValueError(
          "l1_regularization_strength %f needs to be positive or zero" %
          l1_regularization_strength)
    if l2_regularization_strength < 0.0:
      raise ValueError(
          "l2_regularization_strength %f needs to be positive or zero" %
          l2_regularization_strength)
    if l2_shrinkage_regularization_strength < 0.0:
      raise ValueError(
          "l2_shrinkage_regularization_strength %f needs to be positive"
          " or zero" % l2_shrinkage_regularization_strength)

    self._learning_rate = learning_rate
    self._learning_rate_power = learning_rate_power
    self._initial_accumulator_value = initial_accumulator_value
    self._l1_regularization_strength = l1_regularization_strength
    self._l2_regularization_strength = l2_regularization_strength
    self._beta = (0.0 if beta is None else beta)
    self._l2_shrinkage_regularization_strength = (
        l2_shrinkage_regularization_strength)
    self._learning_rate_tensor = None
    self._learning_rate_power_tensor = None
    self._l1_regularization_strength_tensor = None
    self._adjusted_l2_regularization_strength_tensor = None
    self._l2_shrinkage_regularization_strength_tensor = None
    self._accum_name = accum_name
    self._linear_name = linear_name

  def _create_slots(self, var_list):
    # Create the "accum" and "linear" slots.
    for v in var_list:
      val = constant_op.constant(
          self._initial_accumulator_value, dtype=v.dtype, shape=v.get_shape())
      self._get_or_make_slot(v, val, "accum", self._accum_name or self._name)
      self._zeros_slot(v, "linear", self._linear_name or self._name)

  def _prepare(self):
    self._learning_rate_tensor = ops.convert_to_tensor(
        self._learning_rate, name="learning_rate")
    self._l1_regularization_strength_tensor = ops.convert_to_tensor(
        self._l1_regularization_strength, name="l1_regularization_strength")
    # L2 regularization strength with beta added in so that the underlying
    # TensorFlow ops do not need to include that parameter.
    self._adjusted_l2_regularization_strength_tensor = ops.convert_to_tensor(
        self._l2_regularization_strength + self._beta /
        (2. * math_ops.maximum(self._learning_rate, 1e-36)),
        name="adjusted_l2_regularization_strength")
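    # Consistency note on the substitution above: with
    # adjusted_L2 = L2 + beta / (2 * lr), the quantity 2 * adjusted_L2 * lr
    # equals beta + 2 * L2 * lr, which is the denominator of the update rule
    # in the constructor docstring; the maximum() guards against division by
    # a zero learning rate.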
    assert self._adjusted_l2_regularization_strength_tensor is not None
    self._beta_tensor = ops.convert_to_tensor(self._beta, name="beta")
    self._l2_shrinkage_regularization_strength_tensor = ops.convert_to_tensor(
        self._l2_shrinkage_regularization_strength,
        name="l2_shrinkage_regularization_strength")
    self._learning_rate_power_tensor = ops.convert_to_tensor(
        self._learning_rate_power, name="learning_rate_power")

  def _apply_dense(self, grad, var):
    accum = self.get_slot(var, "accum")
    linear = self.get_slot(var, "linear")
    if self._l2_shrinkage_regularization_strength <= 0.0:
      return training_ops.apply_ftrl(
          var,
          accum,
          linear,
          grad,
          math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
          math_ops.cast(self._l1_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._adjusted_l2_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._learning_rate_power_tensor, var.dtype.base_dtype),
          use_locking=self._use_locking)
    else:
      return training_ops.apply_ftrl_v2(
          var,
          accum,
          linear,
          grad,
          math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
          math_ops.cast(self._l1_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._adjusted_l2_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._l2_shrinkage_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._learning_rate_power_tensor, var.dtype.base_dtype),
          use_locking=self._use_locking)
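
  # The resource and sparse variants below follow the same dispatch as
  # `_apply_dense`: the plain FTRL kernel is used when no L2 shrinkage is
  # requested, while the `*_v2` kernel additionally receives the shrinkage
  # strength so it can fold the shrinkage term into the gradient as described
  # in the constructor docstring.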

  def _resource_apply_dense(self, grad, var):
    accum = self.get_slot(var, "accum")
    linear = self.get_slot(var, "linear")
    if self._l2_shrinkage_regularization_strength <= 0.0:
      return training_ops.resource_apply_ftrl(
          var.handle,
          accum.handle,
          linear.handle,
          grad,
          math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
          math_ops.cast(self._l1_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._adjusted_l2_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._learning_rate_power_tensor, var.dtype.base_dtype),
          use_locking=self._use_locking)
    else:
      return training_ops.resource_apply_ftrl_v2(
          var.handle,
          accum.handle,
          linear.handle,
          grad,
          math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
          math_ops.cast(self._l1_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._adjusted_l2_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._l2_shrinkage_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._learning_rate_power_tensor, var.dtype.base_dtype),
          use_locking=self._use_locking)

  def _apply_sparse(self, grad, var):
    accum = self.get_slot(var, "accum")
    linear = self.get_slot(var, "linear")
    if self._l2_shrinkage_regularization_strength <= 0.0:
      return training_ops.sparse_apply_ftrl(
          var,
          accum,
          linear,
          grad.values,
          grad.indices,
          math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
          math_ops.cast(self._l1_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._adjusted_l2_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._learning_rate_power_tensor,
                        var.dtype.base_dtype),
          use_locking=self._use_locking)
    else:
      return training_ops.sparse_apply_ftrl_v2(
          var,
          accum,
          linear,
          grad.values,
          grad.indices,
          math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
          math_ops.cast(self._l1_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._adjusted_l2_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._l2_shrinkage_regularization_strength_tensor,
                        grad.dtype.base_dtype),
          math_ops.cast(self._learning_rate_power_tensor, var.dtype.base_dtype),
          use_locking=self._use_locking)

  def _resource_apply_sparse(self, grad, var, indices):
    accum = self.get_slot(var, "accum")
    linear = self.get_slot(var, "linear")
    if self._l2_shrinkage_regularization_strength <= 0.0:
      return training_ops.resource_sparse_apply_ftrl(
          var.handle,
          accum.handle,
          linear.handle,
          grad,
          indices,
          math_ops.cast(self._learning_rate_tensor, grad.dtype),
          math_ops.cast(self._l1_regularization_strength_tensor, grad.dtype),
          math_ops.cast(self._adjusted_l2_regularization_strength_tensor,
                        grad.dtype),
          math_ops.cast(self._learning_rate_power_tensor, grad.dtype),
          use_locking=self._use_locking)
    else:
      return training_ops.resource_sparse_apply_ftrl_v2(
          var.handle,
          accum.handle,
          linear.handle,
          grad,
          indices,
          math_ops.cast(self._learning_rate_tensor, grad.dtype),
          math_ops.cast(self._l1_regularization_strength_tensor, grad.dtype),
          math_ops.cast(self._adjusted_l2_regularization_strength_tensor,
                        grad.dtype),
          math_ops.cast(self._l2_shrinkage_regularization_strength_tensor,
                        grad.dtype),
          math_ops.cast(self._learning_rate_power_tensor, grad.dtype),
          use_locking=self._use_locking)
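
# For reference only: the no-L1 shrinkage update documented in
# `FtrlOptimizer.__init__`, written out for a single scalar weight. This is an
# illustrative sketch of the formula, not part of the optimizer; the fused
# `training_ops` kernels above are the actual implementation.
#
#   def _scalar_ftrl_shrinkage_step(w, g, lr, l2, l2_shrinkage, beta):
#     denom = beta + 2.0 * l2 * lr
#     return w - lr / denom * g - 2.0 * l2_shrinkage * lr / denom * w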