# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""One-line documentation for rmsprop module.
16
17rmsprop algorithm [tieleman2012rmsprop]
18
19A detailed description of rmsprop.
20
21- maintain a moving (discounted) average of the square of gradients
22- divide gradient by the root of this average
23
24mean_square = decay * mean_square{t-1} + (1-decay) * gradient ** 2
25mom = momentum * mom{t-1} + learning_rate * g_t / sqrt(mean_square + epsilon)
26delta = - mom
27
28This implementation of RMSProp uses plain momentum, not Nesterov momentum.
29
30The centered version additionally maintains a moving (discounted) average of the
31gradients, and uses that average to estimate the variance:
32
33mean_grad = decay * mean_grad{t-1} + (1-decay) * gradient
34mean_square = decay * mean_square{t-1} + (1-decay) * gradient ** 2
35mom = momentum * mom{t-1} + learning_rate * g_t /
36    sqrt(mean_square - mean_grad**2 + epsilon)
37delta = - mom
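
For illustration only, a minimal NumPy sketch of one non-centered update step
under the formulas above (the names `rmsprop_step`, `mean_square`, and `mom`
are hypothetical and not part of the TensorFlow API):

```python
import numpy as np

def rmsprop_step(var, grad, mean_square, mom,
                 learning_rate=0.001, decay=0.9, momentum=0.0, epsilon=1e-10):
  # Discounted average of squared gradients.
  mean_square = decay * mean_square + (1 - decay) * grad ** 2
  # Momentum-smoothed step, scaled by the root of the average.
  mom = momentum * mom + learning_rate * grad / np.sqrt(mean_square + epsilon)
  return var - mom, mean_square, mom
```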
38"""
39
40from tensorflow.python.framework import ops
41from tensorflow.python.ops import array_ops
42from tensorflow.python.ops import init_ops
43from tensorflow.python.ops import math_ops
44from tensorflow.python.training import optimizer
45from tensorflow.python.training import training_ops
46from tensorflow.python.util.tf_export import tf_export
47
48
49@tf_export(v1=["train.RMSPropOptimizer"])
50class RMSPropOptimizer(optimizer.Optimizer):
51  """Optimizer that implements the RMSProp algorithm (Tielemans et al.
52
53  2012).
54
55  References:
56    Coursera slide 29:
57    Hinton, 2012
58    ([pdf](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf))

  @compatibility(TF2)
  tf.compat.v1.train.RMSPropOptimizer is compatible with eager mode and
  `tf.function`.
  When eager execution is enabled, `learning_rate`, `decay`, `momentum`, and
  `epsilon` can each be a callable that takes no arguments and returns the
  actual value to use. This can be useful for changing these values across
  different invocations of optimizer functions.
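
  For example, a callable learning rate (illustrative only; `lr_fn` is a
  hypothetical name, not a TensorFlow API):

  ```python
  lr_fn = lambda: 0.001  # returns the learning rate value to use
  optimizer = tf.compat.v1.train.RMSPropOptimizer(learning_rate=lr_fn)
  ```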

  To switch to native TF2 style, use [`tf.keras.optimizers.RMSprop`]
  (https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/RMSprop)
  instead. Note that, due to implementation differences,
  `tf.keras.optimizers.RMSprop` and `tf.compat.v1.train.RMSPropOptimizer` may
  have slight differences in floating point numerics even though the formula
  used for the variable updates still matches.

  #### Structural mapping to native TF2

  Before:

  ```python
  optimizer = tf.compat.v1.train.RMSPropOptimizer(
    learning_rate=learning_rate,
    decay=decay,
    momentum=momentum,
    epsilon=epsilon)
  ```

  After:

  ```python
  optimizer = tf.keras.optimizers.RMSprop(
    learning_rate=learning_rate,
    rho=decay,
    momentum=momentum,
    epsilon=epsilon)
  ```

  #### How to map arguments
  | TF1 Arg Name       | TF2 Arg Name   | Note                             |
  | ------------------ | -------------- | -------------------------------- |
  | `learning_rate`    | `learning_rate`| Be careful when setting a        |
  : : : `learning_rate` tensor value computed from the global step.        :
  : : : In TF1 this usually implied a dynamic learning rate that was       :
  : : : recomputed at each step. In TF2 (eager + `tf.function`) it is      :
  : : : treated as a scalar value that is computed only once, rather than  :
  : : : a symbolic placeholder recomputed at each step.                    :
  | `decay`            | `rho`          | -                                |
  | `momentum`         | `momentum`     | -                                |
  | `epsilon`          | `epsilon`      | Default value is 1e-10 in TF1,   |
  :                    :                : but 1e-07 in TF2.                :
  | `use_locking`      | -              | Not applicable in TF2.           |

  #### Before & after usage example

  Before:

  ```python
  x = tf.Variable([1, 2, 3], dtype=tf.float32)
  grad = tf.constant([0.1, 0.2, 0.3])
  optimizer = tf.compat.v1.train.RMSPropOptimizer(learning_rate=0.001)
  optimizer.apply_gradients(zip([grad], [x]))
  ```

  After:

  ```python
  x = tf.Variable([1, 2, 3], dtype=tf.float32)
  grad = tf.constant([0.1, 0.2, 0.3])
  optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.001)
  optimizer.apply_gradients(zip([grad], [x]))
  ```

  @end_compatibility
  """

  def __init__(self,
               learning_rate,
               decay=0.9,
               momentum=0.0,
               epsilon=1e-10,
               use_locking=False,
               centered=False,
               name="RMSProp"):
144    """Construct a new RMSProp optimizer.
145
146    Note that in the dense implementation of this algorithm, variables and their
147    corresponding accumulators (momentum, gradient moving average, square
148    gradient moving average) will be updated even if the gradient is zero
149    (i.e. accumulators will decay, momentum will be applied). The sparse
150    implementation (used when the gradient is an `IndexedSlices` object,
151    typically because of `tf.gather` or an embedding lookup in the forward pass)
152    will not update variable slices or their accumulators unless those slices
153    were used in the forward pass (nor is there an "eventual" correction to
154    account for these omitted updates). This leads to more efficient updates for
155    large embedding lookup tables (where most of the slices are not accessed in
156    a particular graph execution), but differs from the published algorithm.
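
    For illustration only (hypothetical names), such a sparse gradient
    typically comes from a gather in the forward pass:

    ```python
    emb = tf.Variable(tf.zeros([1000, 16]))
    with tf.GradientTape() as tape:
      loss = tf.reduce_sum(tf.gather(emb, [3, 7]))  # only rows 3 and 7 used
    grad = tape.gradient(loss, emb)  # an `IndexedSlices`, not a dense Tensor
    ```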

    Args:
      learning_rate: A Tensor or a floating point value.  The learning rate.
      decay: Discounting factor for the moving (discounted) averages of the
        gradient statistics.
      momentum: A scalar tensor. The momentum coefficient.
      epsilon: Small value to avoid a zero denominator.
      use_locking: If True, use locks for the update operations.
      centered: If True, gradients are normalized by the estimated variance of
        the gradient; if False, by the uncentered second moment. Setting this to
        True may help with training, but is slightly more expensive in terms of
        computation and memory. Defaults to False.
      name: Optional name prefix for the operations created when applying
        gradients. Defaults to "RMSProp".
    """
    super(RMSPropOptimizer, self).__init__(use_locking, name)
    self._learning_rate = learning_rate
    self._decay = decay
    self._momentum = momentum
    self._epsilon = epsilon
    self._centered = centered

    # Tensors for learning rate, decay, momentum, and epsilon.
    # Created in _prepare.
    self._learning_rate_tensor = None
    self._decay_tensor = None
    self._momentum_tensor = None
    self._epsilon_tensor = None

  def _create_slots(self, var_list):
    for v in var_list:
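      # Start the "rms" accumulator at one rather than zero so the first update
      # does not divide by a near-zero denominator; use a static initializer
      # when the shape is fully defined, otherwise a ones_like tensor.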
      if v.get_shape().is_fully_defined():
        init_rms = init_ops.ones_initializer(dtype=v.dtype.base_dtype)
      else:
        init_rms = array_ops.ones_like(v)
      self._get_or_make_slot_with_initializer(v, init_rms, v.get_shape(),
                                              v.dtype.base_dtype, "rms",
                                              self._name)
      if self._centered:
        self._zeros_slot(v, "mg", self._name)
      self._zeros_slot(v, "momentum", self._name)

  def _prepare(self):
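    # Hyperparameters may be given as zero-argument callables (useful in eager
    # mode); resolve them to concrete values before converting to tensors.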
    lr = self._call_if_callable(self._learning_rate)
    decay = self._call_if_callable(self._decay)
    momentum = self._call_if_callable(self._momentum)
    epsilon = self._call_if_callable(self._epsilon)

    self._learning_rate_tensor = ops.convert_to_tensor(lr, name="learning_rate")
    self._decay_tensor = ops.convert_to_tensor(decay, name="decay")
    self._momentum_tensor = ops.convert_to_tensor(momentum, name="momentum")
    self._epsilon_tensor = ops.convert_to_tensor(epsilon, name="epsilon")

  def _apply_dense(self, grad, var):
    rms = self.get_slot(var, "rms")
    mom = self.get_slot(var, "momentum")
    if self._centered:
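      # Centered variant: also track a moving average of the gradient ("mg")
      # and subtract its square from the mean square to estimate the variance.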
      mg = self.get_slot(var, "mg")
      return training_ops.apply_centered_rms_prop(
          var,
          mg,
          rms,
          mom,
          math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
          math_ops.cast(self._decay_tensor, var.dtype.base_dtype),
          math_ops.cast(self._momentum_tensor, var.dtype.base_dtype),
          math_ops.cast(self._epsilon_tensor, var.dtype.base_dtype),
          grad,
          use_locking=self._use_locking).op
    else:
      return training_ops.apply_rms_prop(
          var,
          rms,
          mom,
          math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
          math_ops.cast(self._decay_tensor, var.dtype.base_dtype),
          math_ops.cast(self._momentum_tensor, var.dtype.base_dtype),
          math_ops.cast(self._epsilon_tensor, var.dtype.base_dtype),
          grad,
          use_locking=self._use_locking).op

  def _resource_apply_dense(self, grad, var):
    rms = self.get_slot(var, "rms")
    mom = self.get_slot(var, "momentum")
    if self._centered:
      mg = self.get_slot(var, "mg")
      return training_ops.resource_apply_centered_rms_prop(
          var.handle,
          mg.handle,
          rms.handle,
          mom.handle,
          math_ops.cast(self._learning_rate_tensor, grad.dtype.base_dtype),
          math_ops.cast(self._decay_tensor, grad.dtype.base_dtype),
          math_ops.cast(self._momentum_tensor, grad.dtype.base_dtype),
          math_ops.cast(self._epsilon_tensor, grad.dtype.base_dtype),
          grad,
          use_locking=self._use_locking)
    else:
      return training_ops.resource_apply_rms_prop(
          var.handle,
          rms.handle,
          mom.handle,
          math_ops.cast(self._learning_rate_tensor, grad.dtype.base_dtype),
          math_ops.cast(self._decay_tensor, grad.dtype.base_dtype),
          math_ops.cast(self._momentum_tensor, grad.dtype.base_dtype),
          math_ops.cast(self._epsilon_tensor, grad.dtype.base_dtype),
          grad,
          use_locking=self._use_locking)

  def _apply_sparse(self, grad, var):
    rms = self.get_slot(var, "rms")
    mom = self.get_slot(var, "momentum")
    if self._centered:
      mg = self.get_slot(var, "mg")
      return training_ops.sparse_apply_centered_rms_prop(
          var,
          mg,
          rms,
          mom,
          math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
          math_ops.cast(self._decay_tensor, var.dtype.base_dtype),
          math_ops.cast(self._momentum_tensor, var.dtype.base_dtype),
          math_ops.cast(self._epsilon_tensor, var.dtype.base_dtype),
          grad.values,
          grad.indices,
          use_locking=self._use_locking)
    else:
      return training_ops.sparse_apply_rms_prop(
          var,
          rms,
          mom,
          math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
          math_ops.cast(self._decay_tensor, var.dtype.base_dtype),
          math_ops.cast(self._momentum_tensor, var.dtype.base_dtype),
          math_ops.cast(self._epsilon_tensor, var.dtype.base_dtype),
          grad.values,
          grad.indices,
          use_locking=self._use_locking)

  def _resource_apply_sparse(self, grad, var, indices):
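    # Sparse resource-variable update: only the rows selected by `indices` (and
    # the corresponding slot rows) are modified; untouched rows neither decay
    # nor accumulate, as described in the constructor docstring.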
    rms = self.get_slot(var, "rms")
    mom = self.get_slot(var, "momentum")
    if self._centered:
      mg = self.get_slot(var, "mg")
      return training_ops.resource_sparse_apply_centered_rms_prop(
          var.handle,
          mg.handle,
          rms.handle,
          mom.handle,
          math_ops.cast(self._learning_rate_tensor, grad.dtype),
          math_ops.cast(self._decay_tensor, grad.dtype),
          math_ops.cast(self._momentum_tensor, grad.dtype),
          math_ops.cast(self._epsilon_tensor, grad.dtype),
          grad,
          indices,
          use_locking=self._use_locking)
    else:
      return training_ops.resource_sparse_apply_rms_prop(
          var.handle,
          rms.handle,
          mom.handle,
          math_ops.cast(self._learning_rate_tensor, grad.dtype),
          math_ops.cast(self._decay_tensor, grad.dtype),
          math_ops.cast(self._momentum_tensor, grad.dtype),
          math_ops.cast(self._epsilon_tensor, grad.dtype),
          grad,
          indices,
          use_locking=self._use_locking)