# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""RMSProp optimizer implementation.

rmsprop algorithm [tieleman2012rmsprop]

RMSProp:

- maintain a moving (discounted) average of the square of gradients
- divide the gradient by the root of this average

mean_square = decay * mean_square{t-1} + (1-decay) * gradient ** 2
mom = momentum * mom{t-1} + learning_rate * g_t / sqrt(mean_square + epsilon)
delta = - mom

This implementation of RMSProp uses plain momentum, not Nesterov momentum.

The centered version additionally maintains a moving (discounted) average of
the gradients, and uses that average to estimate the variance:

mean_grad = decay * mean_grad{t-1} + (1-decay) * gradient
mean_square = decay * mean_square{t-1} + (1-decay) * gradient ** 2
mom = momentum * mom{t-1} + learning_rate * g_t /
    sqrt(mean_square - mean_grad**2 + epsilon)
delta = - mom
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import optimizer
from tensorflow.python.training import training_ops
from tensorflow.python.util.tf_export import tf_export


@tf_export(v1=["train.RMSPropOptimizer"])
class RMSPropOptimizer(optimizer.Optimizer):
  """Optimizer that implements the RMSProp algorithm (Tieleman and Hinton,
  2012).

  References:
    Coursera slide 29:
      Hinton, 2012
      ([pdf](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf))
  """

  def __init__(self,
               learning_rate,
               decay=0.9,
               momentum=0.0,
               epsilon=1e-10,
               use_locking=False,
               centered=False,
               name="RMSProp"):
    """Construct a new RMSProp optimizer.

    Note that in the dense implementation of this algorithm, variables and
    their corresponding accumulators (momentum, gradient moving average,
    square gradient moving average) will be updated even if the gradient is
    zero (i.e. accumulators will decay, momentum will be applied). The sparse
    implementation (used when the gradient is an `IndexedSlices` object,
    typically because of `tf.gather` or an embedding lookup in the forward
    pass) will not update variable slices or their accumulators unless those
    slices were used in the forward pass (nor is there an "eventual"
    correction to account for these omitted updates). This leads to more
    efficient updates for large embedding lookup tables (where most of the
    slices are not accessed in a particular graph execution), but differs
    from the published algorithm.

    Args:
      learning_rate: A Tensor or a floating point value. The learning rate.
      decay: Discounting factor for the gradient moving averages.
      momentum: A scalar tensor. The momentum factor.
      epsilon: Small value to avoid a zero denominator.
      use_locking: If True use locks for update operation.
      centered: If True, gradients are normalized by the estimated variance of
        the gradient; if False, by the uncentered second moment. Setting this
        to True may help with training, but is slightly more expensive in
        terms of computation and memory. Defaults to False.
      name: Optional name prefix for the operations created when applying
        gradients. Defaults to "RMSProp".

    @compatibility(eager)
    When eager execution is enabled, `learning_rate`, `decay`, `momentum`, and
    `epsilon` can each be a callable that takes no arguments and returns the
    actual value to use. This can be useful for changing these values across
    different invocations of optimizer functions.
    @end_compatibility
    """
    super(RMSPropOptimizer, self).__init__(use_locking, name)
    self._learning_rate = learning_rate
    self._decay = decay
    self._momentum = momentum
    self._epsilon = epsilon
    self._centered = centered

    # Tensors for learning rate and momentum. Created in _prepare.
    self._learning_rate_tensor = None
    self._decay_tensor = None
    self._momentum_tensor = None
    self._epsilon_tensor = None

  def _create_slots(self, var_list):
    for v in var_list:
      if v.get_shape().is_fully_defined():
        init_rms = init_ops.ones_initializer(dtype=v.dtype.base_dtype)
      else:
        init_rms = array_ops.ones_like(v)
      self._get_or_make_slot_with_initializer(v, init_rms, v.get_shape(),
                                              v.dtype.base_dtype, "rms",
                                              self._name)
      if self._centered:
        self._zeros_slot(v, "mg", self._name)
      self._zeros_slot(v, "momentum", self._name)

  def _prepare(self):
    lr = self._call_if_callable(self._learning_rate)
    decay = self._call_if_callable(self._decay)
    momentum = self._call_if_callable(self._momentum)
    epsilon = self._call_if_callable(self._epsilon)

    self._learning_rate_tensor = ops.convert_to_tensor(lr, name="learning_rate")
    self._decay_tensor = ops.convert_to_tensor(decay, name="decay")
    self._momentum_tensor = ops.convert_to_tensor(momentum, name="momentum")
    self._epsilon_tensor = ops.convert_to_tensor(epsilon, name="epsilon")

  def _apply_dense(self, grad, var):
    rms = self.get_slot(var, "rms")
    mom = self.get_slot(var, "momentum")
    if self._centered:
      mg = self.get_slot(var, "mg")
      return training_ops.apply_centered_rms_prop(
          var,
          mg,
          rms,
          mom,
          math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
          math_ops.cast(self._decay_tensor, var.dtype.base_dtype),
          math_ops.cast(self._momentum_tensor, var.dtype.base_dtype),
          math_ops.cast(self._epsilon_tensor, var.dtype.base_dtype),
          grad,
          use_locking=self._use_locking).op
    else:
      return training_ops.apply_rms_prop(
          var,
          rms,
          mom,
          math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
          math_ops.cast(self._decay_tensor, var.dtype.base_dtype),
          math_ops.cast(self._momentum_tensor, var.dtype.base_dtype),
          math_ops.cast(self._epsilon_tensor, var.dtype.base_dtype),
          grad,
          use_locking=self._use_locking).op

  def _resource_apply_dense(self, grad, var):
    rms = self.get_slot(var, "rms")
    mom = self.get_slot(var, "momentum")
    if self._centered:
      mg = self.get_slot(var, "mg")
      return training_ops.resource_apply_centered_rms_prop(
          var.handle,
          mg.handle,
          rms.handle,
          mom.handle,
          math_ops.cast(self._learning_rate_tensor, grad.dtype.base_dtype),
          math_ops.cast(self._decay_tensor, grad.dtype.base_dtype),
          math_ops.cast(self._momentum_tensor, grad.dtype.base_dtype),
          math_ops.cast(self._epsilon_tensor, grad.dtype.base_dtype),
          grad,
          use_locking=self._use_locking)
    else:
      return training_ops.resource_apply_rms_prop(
          var.handle,
          rms.handle,
          mom.handle,
          math_ops.cast(self._learning_rate_tensor, grad.dtype.base_dtype),
          math_ops.cast(self._decay_tensor, grad.dtype.base_dtype),
          math_ops.cast(self._momentum_tensor, grad.dtype.base_dtype),
          math_ops.cast(self._epsilon_tensor, grad.dtype.base_dtype),
          grad,
          use_locking=self._use_locking)

  def _apply_sparse(self, grad, var):
    rms = self.get_slot(var, "rms")
    mom = self.get_slot(var, "momentum")
    if self._centered:
      mg = self.get_slot(var, "mg")
      return training_ops.sparse_apply_centered_rms_prop(
          var,
          mg,
          rms,
          mom,
          math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
          math_ops.cast(self._decay_tensor, var.dtype.base_dtype),
          math_ops.cast(self._momentum_tensor, var.dtype.base_dtype),
          math_ops.cast(self._epsilon_tensor, var.dtype.base_dtype),
          grad.values,
          grad.indices,
          use_locking=self._use_locking)
    else:
      return training_ops.sparse_apply_rms_prop(
          var,
          rms,
          mom,
          math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
          math_ops.cast(self._decay_tensor, var.dtype.base_dtype),
          math_ops.cast(self._momentum_tensor, var.dtype.base_dtype),
          math_ops.cast(self._epsilon_tensor, var.dtype.base_dtype),
          grad.values,
          grad.indices,
          use_locking=self._use_locking)

  def _resource_apply_sparse(self, grad, var, indices):
    rms = self.get_slot(var, "rms")
    mom = self.get_slot(var, "momentum")
    if self._centered:
      mg = self.get_slot(var, "mg")
      return training_ops.resource_sparse_apply_centered_rms_prop(
          var.handle,
          mg.handle,
          rms.handle,
          mom.handle,
          math_ops.cast(self._learning_rate_tensor, grad.dtype),
          math_ops.cast(self._decay_tensor, grad.dtype),
          math_ops.cast(self._momentum_tensor, grad.dtype),
          math_ops.cast(self._epsilon_tensor, grad.dtype),
          grad,
          indices,
          use_locking=self._use_locking)
    else:
      return training_ops.resource_sparse_apply_rms_prop(
          var.handle,
          rms.handle,
          mom.handle,
          math_ops.cast(self._learning_rate_tensor, grad.dtype),
          math_ops.cast(self._decay_tensor, grad.dtype),
          math_ops.cast(self._momentum_tensor, grad.dtype),
          math_ops.cast(self._epsilon_tensor, grad.dtype),
          grad,
          indices,
          use_locking=self._use_locking)
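

# The helper below is an illustrative sketch, not part of the TensorFlow API:
# it restates the non-centered update equations from the module docstring in
# plain Python for a single scalar parameter. The name
# `_rmsprop_reference_step` and its defaults are hypothetical; the optimizer
# above delegates the actual updates to the fused kernels in `training_ops`.
def _rmsprop_reference_step(var, grad, mean_square, mom,
                            learning_rate=0.001, decay=0.9, momentum=0.0,
                            epsilon=1e-10):
  """One plain-Python RMSProp step for scalar `var` with gradient `grad`."""
  # mean_square = decay * mean_square{t-1} + (1-decay) * gradient ** 2
  mean_square = decay * mean_square + (1.0 - decay) * grad * grad
  # mom = momentum * mom{t-1} + learning_rate * g_t / sqrt(mean_square + epsilon)
  mom = momentum * mom + learning_rate * grad / (mean_square + epsilon) ** 0.5
  # delta = -mom, i.e. the variable moves against the accumulated update.
  var = var - mom
  return var, mean_square, mom


# Example usage, mirroring `_create_slots`: the "rms" accumulator starts at 1.0
# and the momentum accumulator at 0.0.
#   w, ms, m = 0.5, 1.0, 0.0
#   w, ms, m = _rmsprop_reference_step(w, grad=0.2, mean_square=ms, mom=m)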