# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""RMSprop for TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import ops
from tensorflow.python.keras import backend_config
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import training_ops
from tensorflow.python.util.tf_export import keras_export


@keras_export("keras.optimizers.RMSprop")
class RMSprop(optimizer_v2.OptimizerV2):
  r"""Optimizer that implements the RMSprop algorithm.

  The gist of RMSprop is to:

    - maintain a moving (discounted) average of the square of gradients
    - divide the gradient by the root of this average

  $$mean\_square_t = rho * mean\_square_{t-1} + (1 - rho) * gradient^2$$
  $$mom_t = momentum * mom_{t-1} + learning\_rate * gradient /
      \sqrt{mean\_square_t + \epsilon}$$
  $$variable_t := variable_{t-1} - mom_t$$

  This implementation of RMSprop uses plain momentum, not Nesterov momentum.

  The centered version additionally maintains a moving average of the
  gradients, and uses that average to estimate the variance:

  $$mean\_grad_t = rho * mean\_grad_{t-1} + (1 - rho) * gradient$$
  $$mean\_square_t = rho * mean\_square_{t-1} + (1 - rho) * gradient^2$$
  $$mom_t = momentum * mom_{t-1} + learning\_rate * gradient /
      \sqrt{mean\_square_t - mean\_grad_t^2 + \epsilon}$$
  $$variable_t := variable_{t-1} - mom_t$$
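
  A minimal usage sketch (illustrative only; it relies on the public
  `tf.keras.optimizers.RMSprop` export of this class and standard eager-mode
  TensorFlow APIs):

  ```python
  opt = tf.keras.optimizers.RMSprop(learning_rate=0.1, momentum=0.9)
  var = tf.Variable(10.0)
  with tf.GradientTape() as tape:
    loss = var ** 2                       # toy objective
  grads = tape.gradient(loss, [var])
  opt.apply_gradients(zip(grads, [var]))  # one RMSprop step on `var`
  ```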

  References:
    See ([pdf](
      http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)).
  """

  def __init__(self,
               learning_rate=0.001,
               rho=0.9,
               momentum=0.0,
               epsilon=1e-7,
               centered=False,
               name="RMSprop",
               **kwargs):
    """Construct a new RMSprop optimizer.

    Note that in the dense implementation of this algorithm, variables and their
    corresponding accumulators (momentum, gradient moving average, square
    gradient moving average) will be updated even if the gradient is zero
    (i.e. accumulators will decay, momentum will be applied). The sparse
    implementation (used when the gradient is an `IndexedSlices` object,
    typically because of `tf.gather` or an embedding lookup in the forward pass)
    will not update variable slices or their accumulators unless those slices
    were used in the forward pass (nor is there an "eventual" correction to
    account for these omitted updates). This leads to more efficient updates for
    large embedding lookup tables (where most of the slices are not accessed in
    a particular graph execution), but differs from the published algorithm.
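
    For example (an illustrative sketch, not part of this module), a gradient
    taken through `tf.gather` arrives as an `IndexedSlices`, and only the
    gathered rows of the variable receive an update:

    ```python
    emb = tf.Variable(tf.random.normal([1000, 16]))
    opt = tf.keras.optimizers.RMSprop()
    with tf.GradientTape() as tape:
      loss = tf.reduce_sum(tf.gather(emb, [3, 7]))
    grad = tape.gradient(loss, emb)     # IndexedSlices covering rows 3 and 7
    opt.apply_gradients([(grad, emb)])  # only rows 3 and 7 of `emb` are updated
    ```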

    Args:
      learning_rate: A Tensor or a floating point value. The learning rate.
      rho: Discounting factor for the history/coming gradient. Defaults to 0.9.
      momentum: A scalar or a scalar Tensor. The momentum factor. Defaults to
        0.0.
      epsilon: A small constant added to the denominator to avoid division by
        zero. Defaults to 1e-7.
      centered: If True, gradients are normalized by the estimated variance of
        the gradient; if False, by the uncentered second moment. Setting this to
        True may help with training, but is slightly more expensive in terms of
        computation and memory. Defaults to False.
      name: Optional name prefix for the operations created when applying
        gradients. Defaults to "RMSprop".
      **kwargs: Keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`,
        `decay`}. `clipnorm` clips gradients by norm; `clipvalue` clips
        gradients by value; `decay` is included for backward compatibility to
        allow time inverse decay of the learning rate; `lr` is included for
        backward compatibility, but using `learning_rate` instead is
        recommended.

    @compatibility(eager)
    When eager execution is enabled, `learning_rate`, `decay`, `momentum`, and
    `epsilon` can each be a callable that takes no arguments and returns the
    actual value to use. This can be useful for changing these values across
    different invocations of optimizer functions.
    @end_compatibility
    """
    if epsilon is None:
      epsilon = backend_config.epsilon()
    super(RMSprop, self).__init__(name, **kwargs)
    self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
    self._set_hyper("decay", self._initial_decay)
    self._set_hyper("rho", rho)

    self._momentum = False
    if isinstance(momentum, ops.Tensor) or callable(momentum) or momentum > 0:
      self._momentum = True
    if isinstance(momentum, (int, float)) and (momentum < 0 or momentum > 1):
      raise ValueError("`momentum` must be in the range [0, 1].")
    self._set_hyper("momentum", momentum)

    self._set_hyper("epsilon", epsilon)
    self.centered = centered

  def _create_slots(self, var_list):
    for var in var_list:
      self.add_slot(var, "rms")
    if self._momentum:
      for var in var_list:
        self.add_slot(var, "momentum")
    if self.centered:
      for var in var_list:
        self.add_slot(var, "mg")

  def _resource_apply_dense(self, grad, var):
    var_dtype = var.dtype.base_dtype
    lr_t = self._decayed_lr(var_dtype)
    rms = self.get_slot(var, "rms")
    rho = self._get_hyper("rho", var_dtype)
    momentum = self._get_hyper("momentum", var_dtype)
    epsilon = self._get_hyper("epsilon", var_dtype)
    if self._momentum:
      mom = self.get_slot(var, "momentum")
      if self.centered:
        mg = self.get_slot(var, "mg")
        return training_ops.resource_apply_centered_rms_prop(
            var.handle,
            mg.handle,
            rms.handle,
            mom.handle,
            lr_t,
            rho,
            momentum,
            epsilon,
            grad,
            use_locking=self._use_locking)
      else:
        return training_ops.resource_apply_rms_prop(
            var.handle,
            rms.handle,
            mom.handle,
            lr_t,
            rho,
            momentum,
            epsilon,
            grad,
            use_locking=self._use_locking)
    else:
      # Without momentum, compose the update from primitive ops instead of the
      # fused RMSprop kernels.
      rms_t = rho * rms + (1. - rho) * math_ops.square(grad)
      rms_t = state_ops.assign(rms, rms_t, use_locking=self._use_locking)
      denom_t = rms_t
      if self.centered:
        mg = self.get_slot(var, "mg")
        mg_t = rho * mg + (1. - rho) * grad
        mg_t = state_ops.assign(mg, mg_t, use_locking=self._use_locking)
        denom_t = rms_t - math_ops.square(mg_t)
      var_t = var - lr_t * grad / (math_ops.sqrt(denom_t) + epsilon)
      return state_ops.assign(var, var_t, use_locking=self._use_locking).op

  def _resource_apply_sparse(self, grad, var, indices):
    var_dtype = var.dtype.base_dtype
    lr_t = self._decayed_lr(var_dtype)
    rms = self.get_slot(var, "rms")
    rho = self._get_hyper("rho", var_dtype)
    momentum = self._get_hyper("momentum", var_dtype)
    epsilon = self._get_hyper("epsilon", var_dtype)
    if self._momentum:
      mom = self.get_slot(var, "momentum")
      if self.centered:
        mg = self.get_slot(var, "mg")
        return training_ops.resource_sparse_apply_centered_rms_prop(
            var.handle,
            mg.handle,
            rms.handle,
            mom.handle,
            lr_t,
            rho,
            momentum,
            epsilon,
            grad,
            indices,
            use_locking=self._use_locking)
      else:
        return training_ops.resource_sparse_apply_rms_prop(
            var.handle,
            rms.handle,
            mom.handle,
            lr_t,
            rho,
            momentum,
            epsilon,
            grad,
            indices,
            use_locking=self._use_locking)
    else:
      # Without momentum, decay the accumulators, scatter-add the new gradient
      # contribution at `indices`, and update only those rows of the variable.
      rms_scaled_g_values = (grad * grad) * (1. - rho)
      rms_t = state_ops.assign(rms, rms * rho, use_locking=self._use_locking)
      with ops.control_dependencies([rms_t]):
        rms_t = self._resource_scatter_add(rms, indices, rms_scaled_g_values)
        rms_slice = array_ops.gather(rms_t, indices)
      denom_slice = rms_slice
      if self.centered:
        mg = self.get_slot(var, "mg")
        mg_scaled_g_values = grad * (1. - rho)
        mg_t = state_ops.assign(mg, mg * rho, use_locking=self._use_locking)
        with ops.control_dependencies([mg_t]):
          mg_t = self._resource_scatter_add(mg, indices, mg_scaled_g_values)
          mg_slice = array_ops.gather(mg_t, indices)
          denom_slice = rms_slice - math_ops.square(mg_slice)
      var_update = self._resource_scatter_add(
          var, indices, -lr_t * grad / (math_ops.sqrt(denom_slice) + epsilon))
      if self.centered:
        return control_flow_ops.group(*[var_update, rms_t, mg_t])
      return control_flow_ops.group(*[var_update, rms_t])

  def set_weights(self, weights):
    params = self.weights
    # Override set_weights for backward compatibility with the Keras V1
    # optimizer, whose weight list does not include the iteration count at the
    # head. Set iteration to 0.
    if len(params) == len(weights) + 1:
      weights = [np.array(0)] + weights
    super(RMSprop, self).set_weights(weights)

  def get_config(self):
    config = super(RMSprop, self).get_config()
    config.update({
        "learning_rate": self._serialize_hyperparameter("learning_rate"),
        "decay": self._serialize_hyperparameter("decay"),
        "rho": self._serialize_hyperparameter("rho"),
        "momentum": self._serialize_hyperparameter("momentum"),
        "epsilon": self._serialize_hyperparameter("epsilon"),
        "centered": self.centered,
    })
    return config


RMSProp = RMSprop
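

# A minimal serialization round-trip sketch (illustrative only; `from_config`
# is inherited from `optimizer_v2.OptimizerV2`):
#
#   opt = RMSprop(learning_rate=0.01, momentum=0.9, centered=True)
#   config = opt.get_config()               # plain Python dict of hyperparameters
#   restored = RMSprop.from_config(config)  # rebuilds an equivalent optimizer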