# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""One-line documentation for rmsprop module.
16
17rmsprop algorithm [tieleman2012rmsprop]
18
19A detailed description of rmsprop.
20
21- maintain a moving (discounted) average of the square of gradients
22- divide gradient by the root of this average
23
24mean_square = decay * mean_square{t-1} + (1-decay) * gradient ** 2
25mom = momentum * mom{t-1} + learning_rate * g_t / sqrt(mean_square + epsilon)
26delta = - mom
27
28This implementation of RMSProp uses plain momentum, not Nesterov momentum.
29
30The centered version additionally maintains a moving (discounted) average of the
31gradients, and uses that average to estimate the variance:
32
33mean_grad = decay * mean_grad{t-1} + (1-decay) * gradient
34mean_square = decay * mean_square{t-1} + (1-decay) * gradient ** 2
35mom = momentum * mom{t-1} + learning_rate * g_t /
36    sqrt(mean_square - mean_grad**2 + epsilon)
37delta = - mom
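
As a rough illustration only (not the code path used by this module, which
relies on the fused kernels in `training_ops`), a single non-centered update
step could be sketched in NumPy as:

    import numpy as np

    def rmsprop_step(var, grad, mean_square, mom,
                     learning_rate=0.001, decay=0.9, momentum=0.0,
                     epsilon=1e-10):
      # Moving (discounted) average of squared gradients.
      mean_square = decay * mean_square + (1 - decay) * grad ** 2
      # Momentum-accumulated, RMS-scaled step; the applied delta is -mom.
      mom = momentum * mom + learning_rate * grad / np.sqrt(mean_square +
                                                            epsilon)
      var = var - mom
      return var, mean_square, mom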
38"""
39
40from __future__ import absolute_import
41from __future__ import division
42from __future__ import print_function
43
44from tensorflow.python.framework import ops
45from tensorflow.python.ops import array_ops
46from tensorflow.python.ops import init_ops
47from tensorflow.python.ops import math_ops
48from tensorflow.python.training import optimizer
49from tensorflow.python.training import training_ops
50from tensorflow.python.util.tf_export import tf_export
51
52
@tf_export(v1=["train.RMSPropOptimizer"])
class RMSPropOptimizer(optimizer.Optimizer):
  """Optimizer that implements the RMSProp algorithm (Tieleman & Hinton, 2012).

  References:
    Coursera slide 29: Hinton, 2012
    ([pdf](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf))
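
  A minimal usage sketch (graph mode; the model, `loss`, and session handling
  are assumed to be defined elsewhere):

    optimizer = tf.compat.v1.train.RMSPropOptimizer(learning_rate=0.001)
    train_op = optimizer.minimize(loss)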
63  """
64
  def __init__(self,
               learning_rate,
               decay=0.9,
               momentum=0.0,
               epsilon=1e-10,
               use_locking=False,
               centered=False,
               name="RMSProp"):
    """Construct a new RMSProp optimizer.

    Note that in the dense implementation of this algorithm, variables and their
    corresponding accumulators (momentum, gradient moving average, square
    gradient moving average) will be updated even if the gradient is zero
    (i.e. accumulators will decay, momentum will be applied). The sparse
    implementation (used when the gradient is an `IndexedSlices` object,
    typically because of `tf.gather` or an embedding lookup in the forward pass)
    will not update variable slices or their accumulators unless those slices
    were used in the forward pass (nor is there an "eventual" correction to
    account for these omitted updates). This leads to more efficient updates for
    large embedding lookup tables (where most of the slices are not accessed in
    a particular graph execution), but differs from the published algorithm.

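    For example, a sparse gradient of the kind described above typically comes
    from an embedding lookup (illustrative sketch; `params` and `ids` are
    assumed to be defined by the caller):

      embedded = tf.gather(params, ids)
      # The gradient with respect to `params` is an `IndexedSlices` object, so
      # only the looked-up rows of `params` and their accumulators are updated.
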
    Args:
      learning_rate: A Tensor or a floating point value. The learning rate.
      decay: Discounting factor for the moving averages of the (squared)
        gradients. Defaults to 0.9.
      momentum: A scalar Tensor or floating point value. The momentum factor
        applied to the accumulated update. Defaults to 0.0.
      epsilon: A small constant added to the denominator to avoid division by
        zero. Defaults to 1e-10.
      use_locking: If True, use locks for the update operations.
      centered: If True, gradients are normalized by the estimated variance of
        the gradient; if False, by the uncentered second moment. Setting this to
        True may help with training, but is slightly more expensive in terms of
        computation and memory. Defaults to False.
      name: Optional name prefix for the operations created when applying
        gradients. Defaults to "RMSProp".

    @compatibility(eager)
    When eager execution is enabled, `learning_rate`, `decay`, `momentum`, and
    `epsilon` can each be a callable that takes no arguments and returns the
    actual value to use. This can be useful for changing these values across
    different invocations of optimizer functions.
    @end_compatibility
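
    For example, a callable learning rate might be passed like this
    (illustrative sketch; `lr_variable` is assumed to be defined by the caller):

      lr_variable = tf.Variable(0.001)
      opt = tf.compat.v1.train.RMSPropOptimizer(
          learning_rate=lambda: lr_variable.read_value())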
106    """
107    super(RMSPropOptimizer, self).__init__(use_locking, name)
108    self._learning_rate = learning_rate
109    self._decay = decay
110    self._momentum = momentum
111    self._epsilon = epsilon
112    self._centered = centered
113
    # Tensors for the learning rate, decay, momentum and epsilon.  Created in
    # _prepare.
    self._learning_rate_tensor = None
    self._decay_tensor = None
    self._momentum_tensor = None
    self._epsilon_tensor = None

  def _create_slots(self, var_list):
    for v in var_list:
      # The "rms" accumulator starts at one (not zero) so that early updates
      # are not inflated by a near-zero denominator.
      if v.get_shape().is_fully_defined():
        init_rms = init_ops.ones_initializer(dtype=v.dtype.base_dtype)
      else:
        init_rms = array_ops.ones_like(v)
      self._get_or_make_slot_with_initializer(v, init_rms, v.get_shape(),
                                              v.dtype.base_dtype, "rms",
                                              self._name)
      if self._centered:
        self._zeros_slot(v, "mg", self._name)
      self._zeros_slot(v, "momentum", self._name)

  def _prepare(self):
    lr = self._call_if_callable(self._learning_rate)
    decay = self._call_if_callable(self._decay)
    momentum = self._call_if_callable(self._momentum)
    epsilon = self._call_if_callable(self._epsilon)

    self._learning_rate_tensor = ops.convert_to_tensor(lr, name="learning_rate")
    self._decay_tensor = ops.convert_to_tensor(decay, name="decay")
    self._momentum_tensor = ops.convert_to_tensor(momentum, name="momentum")
    self._epsilon_tensor = ops.convert_to_tensor(epsilon, name="epsilon")

  def _apply_dense(self, grad, var):
    rms = self.get_slot(var, "rms")
    mom = self.get_slot(var, "momentum")
    if self._centered:
      mg = self.get_slot(var, "mg")
      return training_ops.apply_centered_rms_prop(
          var,
          mg,
          rms,
          mom,
          math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
          math_ops.cast(self._decay_tensor, var.dtype.base_dtype),
          math_ops.cast(self._momentum_tensor, var.dtype.base_dtype),
          math_ops.cast(self._epsilon_tensor, var.dtype.base_dtype),
          grad,
          use_locking=self._use_locking).op
    else:
      return training_ops.apply_rms_prop(
          var,
          rms,
          mom,
          math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
          math_ops.cast(self._decay_tensor, var.dtype.base_dtype),
          math_ops.cast(self._momentum_tensor, var.dtype.base_dtype),
          math_ops.cast(self._epsilon_tensor, var.dtype.base_dtype),
          grad,
          use_locking=self._use_locking).op

  def _resource_apply_dense(self, grad, var):
    rms = self.get_slot(var, "rms")
    mom = self.get_slot(var, "momentum")
    if self._centered:
      mg = self.get_slot(var, "mg")
      return training_ops.resource_apply_centered_rms_prop(
          var.handle,
          mg.handle,
          rms.handle,
          mom.handle,
          math_ops.cast(self._learning_rate_tensor, grad.dtype.base_dtype),
          math_ops.cast(self._decay_tensor, grad.dtype.base_dtype),
          math_ops.cast(self._momentum_tensor, grad.dtype.base_dtype),
          math_ops.cast(self._epsilon_tensor, grad.dtype.base_dtype),
          grad,
          use_locking=self._use_locking)
    else:
      return training_ops.resource_apply_rms_prop(
          var.handle,
          rms.handle,
          mom.handle,
          math_ops.cast(self._learning_rate_tensor, grad.dtype.base_dtype),
          math_ops.cast(self._decay_tensor, grad.dtype.base_dtype),
          math_ops.cast(self._momentum_tensor, grad.dtype.base_dtype),
          math_ops.cast(self._epsilon_tensor, grad.dtype.base_dtype),
          grad,
          use_locking=self._use_locking)

  def _apply_sparse(self, grad, var):
    rms = self.get_slot(var, "rms")
    mom = self.get_slot(var, "momentum")
    if self._centered:
      mg = self.get_slot(var, "mg")
      return training_ops.sparse_apply_centered_rms_prop(
          var,
          mg,
          rms,
          mom,
          math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
          math_ops.cast(self._decay_tensor, var.dtype.base_dtype),
          math_ops.cast(self._momentum_tensor, var.dtype.base_dtype),
          math_ops.cast(self._epsilon_tensor, var.dtype.base_dtype),
          grad.values,
          grad.indices,
          use_locking=self._use_locking)
    else:
      return training_ops.sparse_apply_rms_prop(
          var,
          rms,
          mom,
          math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
          math_ops.cast(self._decay_tensor, var.dtype.base_dtype),
          math_ops.cast(self._momentum_tensor, var.dtype.base_dtype),
          math_ops.cast(self._epsilon_tensor, var.dtype.base_dtype),
          grad.values,
          grad.indices,
          use_locking=self._use_locking)

  def _resource_apply_sparse(self, grad, var, indices):
    rms = self.get_slot(var, "rms")
    mom = self.get_slot(var, "momentum")
    if self._centered:
      mg = self.get_slot(var, "mg")
      return training_ops.resource_sparse_apply_centered_rms_prop(
          var.handle,
          mg.handle,
          rms.handle,
          mom.handle,
          math_ops.cast(self._learning_rate_tensor, grad.dtype),
          math_ops.cast(self._decay_tensor, grad.dtype),
          math_ops.cast(self._momentum_tensor, grad.dtype),
          math_ops.cast(self._epsilon_tensor, grad.dtype),
          grad,
          indices,
          use_locking=self._use_locking)
    else:
      return training_ops.resource_sparse_apply_rms_prop(
          var.handle,
          rms.handle,
          mom.handle,
          math_ops.cast(self._learning_rate_tensor, grad.dtype),
          math_ops.cast(self._decay_tensor, grad.dtype),
          math_ops.cast(self._momentum_tensor, grad.dtype),
          math_ops.cast(self._epsilon_tensor, grad.dtype),
          grad,
          indices,
          use_locking=self._use_locking)
