# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ftrl-proximal for TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import optimizer
from tensorflow.python.training import training_ops
from tensorflow.python.util.tf_export import tf_export


@tf_export("train.FtrlOptimizer")
class FtrlOptimizer(optimizer.Optimizer):
  """Optimizer that implements the FTRL algorithm.

  See this [paper](
  https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf).
  This version has support for both online L2 (the L2 penalty given in the
  paper above) and shrinkage-type L2 (which is the addition of an L2 penalty
  to the loss function).
  """

  def __init__(self,
               learning_rate,
               learning_rate_power=-0.5,
               initial_accumulator_value=0.1,
               l1_regularization_strength=0.0,
               l2_regularization_strength=0.0,
               use_locking=False,
               name="Ftrl",
               accum_name=None,
               linear_name=None,
               l2_shrinkage_regularization_strength=0.0):
    r"""Construct a new FTRL optimizer.

    Args:
      learning_rate: A float value or a constant float `Tensor`.
      learning_rate_power: A float value, must be less than or equal to zero.
      initial_accumulator_value: The starting value for accumulators.
        Only positive values are allowed.
      l1_regularization_strength: A float value, must be greater than or
        equal to zero.
      l2_regularization_strength: A float value, must be greater than or
        equal to zero.
      use_locking: If `True` use locks for update operations.
      name: Optional name prefix for the operations created when applying
        gradients. Defaults to "Ftrl".
      accum_name: The suffix for the variable that keeps the gradient squared
        accumulator. If not present, defaults to name.
      linear_name: The suffix for the variable that keeps the linear gradient
        accumulator. If not present, defaults to name + "_1".
      l2_shrinkage_regularization_strength: A float value, must be greater than
        or equal to zero. This differs from L2 above in that the L2 above is a
        stabilization penalty, whereas this L2 shrinkage is a magnitude
        penalty. The FTRL formulation can be written as:
        w_{t+1} = argmin_w(\hat{g}_{1:t}w + L1*||w||_1 + L2*||w||_2^2), where
        \hat{g} = g + (2*L2_shrinkage*w), and g is the gradient of the loss
        function w.r.t. the weights w.
        Specifically, in the absence of L1 regularization, it is equivalent to
        the following update rule:
        w_{t+1} = w_t - lr_t / (1 + 2*L2*lr_t) * g_t -
                  2*L2_shrinkage*lr_t / (1 + 2*L2*lr_t) * w_t
        where lr_t is the learning rate at t.
        When input is sparse, shrinkage will only happen on the active weights.

    Raises:
      ValueError: If one of the arguments is invalid.
    """
    super(FtrlOptimizer, self).__init__(use_locking, name)

    if initial_accumulator_value <= 0.0:
      raise ValueError("initial_accumulator_value %f needs to be positive" %
                       initial_accumulator_value)
    if learning_rate_power > 0.0:
      raise ValueError("learning_rate_power %f needs to be negative or zero" %
                       learning_rate_power)
    if l1_regularization_strength < 0.0:
      raise ValueError(
          "l1_regularization_strength %f needs to be positive or zero" %
          l1_regularization_strength)
    if l2_regularization_strength < 0.0:
      raise ValueError(
          "l2_regularization_strength %f needs to be positive or zero" %
          l2_regularization_strength)
    if l2_shrinkage_regularization_strength < 0.0:
      raise ValueError(
          "l2_shrinkage_regularization_strength %f needs to be positive"
          " or zero" % l2_shrinkage_regularization_strength)

    self._learning_rate = learning_rate
    self._learning_rate_power = learning_rate_power
    self._initial_accumulator_value = initial_accumulator_value
    self._l1_regularization_strength = l1_regularization_strength
    self._l2_regularization_strength = l2_regularization_strength
    self._l2_shrinkage_regularization_strength = (
        l2_shrinkage_regularization_strength)
    self._learning_rate_tensor = None
    self._learning_rate_power_tensor = None
    self._l1_regularization_strength_tensor = None
    self._l2_regularization_strength_tensor = None
    self._l2_shrinkage_regularization_strength_tensor = None
    self._accum_name = accum_name
    self._linear_name = linear_name

  def _create_slots(self, var_list):
    # Create the "accum" and "linear" slots.
    for v in var_list:
      with ops.colocate_with(v):
        val = constant_op.constant(
            self._initial_accumulator_value, dtype=v.dtype,
            shape=v.get_shape())
        self._get_or_make_slot(v, val, "accum", self._accum_name or self._name)
        self._zeros_slot(v, "linear", self._linear_name or self._name)

  def _prepare(self):
    self._learning_rate_tensor = ops.convert_to_tensor(
        self._learning_rate, name="learning_rate")
    self._l1_regularization_strength_tensor = ops.convert_to_tensor(
        self._l1_regularization_strength, name="l1_regularization_strength")
    self._l2_regularization_strength_tensor = ops.convert_to_tensor(
        self._l2_regularization_strength, name="l2_regularization_strength")
    self._l2_shrinkage_regularization_strength_tensor = ops.convert_to_tensor(
        self._l2_shrinkage_regularization_strength,
        name="l2_shrinkage_regularization_strength")
    self._learning_rate_power_tensor = ops.convert_to_tensor(
        self._learning_rate_power, name="learning_rate_power")

  def _apply_dense(self, grad, var):
    accum = self.get_slot(var, "accum")
    linear = self.get_slot(var, "linear")
    if self._l2_shrinkage_regularization_strength <= 0.0:
      return training_ops.apply_ftrl(
          var,
          accum,
          linear,
          grad,
          math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
          math_ops.cast(self._l1_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._l2_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._learning_rate_power_tensor,
                        var.dtype.base_dtype),
          use_locking=self._use_locking)
    else:
      return training_ops.apply_ftrl_v2(
          var,
          accum,
          linear,
          grad,
          math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
          math_ops.cast(self._l1_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._l2_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._l2_shrinkage_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._learning_rate_power_tensor,
                        var.dtype.base_dtype),
          use_locking=self._use_locking)

  def _resource_apply_dense(self, grad, var):
    accum = self.get_slot(var, "accum")
    linear = self.get_slot(var, "linear")
    if self._l2_shrinkage_regularization_strength <= 0.0:
      return training_ops.resource_apply_ftrl(
          var.handle,
          accum.handle,
          linear.handle,
          grad,
          math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
          math_ops.cast(self._l1_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._l2_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._learning_rate_power_tensor,
                        var.dtype.base_dtype),
          use_locking=self._use_locking)
    else:
      return training_ops.resource_apply_ftrl_v2(
          var.handle,
          accum.handle,
          linear.handle,
          grad,
          math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
          math_ops.cast(self._l1_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._l2_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._l2_shrinkage_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._learning_rate_power_tensor,
                        var.dtype.base_dtype),
          use_locking=self._use_locking)

  def _apply_sparse(self, grad, var):
    accum = self.get_slot(var, "accum")
    linear = self.get_slot(var, "linear")
    if self._l2_shrinkage_regularization_strength <= 0.0:
      return training_ops.sparse_apply_ftrl(
          var,
          accum,
          linear,
          grad.values,
          grad.indices,
          math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
          math_ops.cast(self._l1_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._l2_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._learning_rate_power_tensor,
                        var.dtype.base_dtype),
          use_locking=self._use_locking)
    else:
      return training_ops.sparse_apply_ftrl_v2(
          var,
          accum,
          linear,
          grad.values,
          grad.indices,
          math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
          math_ops.cast(self._l1_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._l2_regularization_strength_tensor,
                        var.dtype.base_dtype),
          math_ops.cast(self._l2_shrinkage_regularization_strength_tensor,
                        grad.dtype.base_dtype),
          math_ops.cast(self._learning_rate_power_tensor,
                        var.dtype.base_dtype),
          use_locking=self._use_locking)

  def _resource_apply_sparse(self, grad, var, indices):
    accum = self.get_slot(var, "accum")
    linear = self.get_slot(var, "linear")
    if self._l2_shrinkage_regularization_strength <= 0.0:
      return training_ops.resource_sparse_apply_ftrl(
          var.handle,
          accum.handle,
          linear.handle,
          grad,
          indices,
          math_ops.cast(self._learning_rate_tensor, grad.dtype),
          math_ops.cast(self._l1_regularization_strength_tensor, grad.dtype),
          math_ops.cast(self._l2_regularization_strength_tensor, grad.dtype),
          math_ops.cast(self._learning_rate_power_tensor, grad.dtype),
          use_locking=self._use_locking)
    else:
      return training_ops.resource_sparse_apply_ftrl_v2(
          var.handle,
          accum.handle,
          linear.handle,
          grad,
          indices,
          math_ops.cast(self._learning_rate_tensor, grad.dtype),
          math_ops.cast(self._l1_regularization_strength_tensor, grad.dtype),
          math_ops.cast(self._l2_regularization_strength_tensor, grad.dtype),
          math_ops.cast(self._l2_shrinkage_regularization_strength_tensor,
                        grad.dtype),
          math_ops.cast(self._learning_rate_power_tensor, grad.dtype),
          use_locking=self._use_locking)
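

# Usage sketch (illustrative only, not part of the exported library code):
# a minimal TF 1.x graph-mode example of driving this optimizer through its
# public alias `tf.train.FtrlOptimizer`. The variable, target values, loss,
# and step count below are hypothetical placeholders chosen for the sketch.
#
#   import tensorflow as tf
#
#   w = tf.Variable([0.0, 0.0], name="weights")
#   loss = tf.reduce_sum(tf.square(w - tf.constant([1.0, -2.0])))
#   opt = tf.train.FtrlOptimizer(
#       learning_rate=0.1,
#       l1_regularization_strength=0.001,
#       l2_regularization_strength=0.001)
#   train_op = opt.minimize(loss)
#
#   with tf.Session() as sess:
#     sess.run(tf.global_variables_initializer())
#     for _ in range(100):
#       sess.run(train_op)
#
# Passing a nonzero l2_shrinkage_regularization_strength would route the
# updates through the *_ftrl_v2 kernels shown above instead of *_ftrl.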