# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""RMSprop for TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import ops
from tensorflow.python.keras import backend_config
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import training_ops
from tensorflow.python.util.tf_export import keras_export


@keras_export("keras.optimizers.RMSprop")
class RMSprop(optimizer_v2.OptimizerV2):
  r"""Optimizer that implements the RMSprop algorithm.

  The algorithm, in brief:

  - maintains a moving (discounted) average of the square of gradients
  - divides the gradient by the root of this average

  $$mean_square_t = rho * mean_square_{t-1} + (1-rho) * gradient ** 2$$
  $$mom_t = momentum * mom_{t-1} + learning_rate * gradient /
      \sqrt{mean_square_t + \epsilon}$$
  $$variable_t := variable_{t-1} - mom_t$$

  This implementation of RMSprop uses plain momentum, not Nesterov momentum.

  The centered version additionally maintains a moving average of the
  gradients, and uses that average to estimate the variance:

  $$mean_grad_t = rho * mean_grad_{t-1} + (1-rho) * gradient$$
  $$mean_square_t = rho * mean_square_{t-1} + (1-rho) * gradient ** 2$$
  $$mom_t = momentum * mom_{t-1} + learning_rate * gradient /
      \sqrt{mean_square_t - mean_grad_t ** 2 + \epsilon}$$
  $$variable_t := variable_{t-1} - mom_t$$

  References:
    See
    ([pdf](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)).
  """

  def __init__(self,
               learning_rate=0.001,
               rho=0.9,
               momentum=0.0,
               epsilon=1e-7,
               centered=False,
               name="RMSprop",
               **kwargs):
    """Construct a new RMSprop optimizer.

    Note that in the dense implementation of this algorithm, variables and
    their corresponding accumulators (momentum, gradient moving average,
    square gradient moving average) will be updated even if the gradient is
    zero (i.e. accumulators will decay, momentum will be applied). The sparse
    implementation (used when the gradient is an `IndexedSlices` object,
    typically because of `tf.gather` or an embedding lookup in the forward
    pass) will not update variable slices or their accumulators unless those
    slices were used in the forward pass (nor is there an "eventual"
    correction to account for these omitted updates). This leads to more
    efficient updates for large embedding lookup tables (where most of the
    slices are not accessed in a particular graph execution), but differs
    from the published algorithm.
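
    For example, the sparse path is taken when the incoming gradient is an
    `IndexedSlices`, as produced by differentiating a `tf.gather` (or an
    embedding lookup) on a variable. A minimal sketch, assuming eager
    execution:

    ```python
    import tensorflow as tf

    params = tf.Variable(tf.zeros([100, 8]))
    with tf.GradientTape() as tape:
      loss = tf.reduce_sum(tf.gather(params, [3, 7, 7]))
    grad = tape.gradient(loss, params)  # An IndexedSlices, not a dense Tensor.
    ```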

    Args:
      learning_rate: A Tensor or a floating point value. The learning rate.
      rho: Discounting factor for the history/coming gradient. Defaults to 0.9.
      momentum: A scalar or a scalar tensor. Defaults to 0.0.
      epsilon: Small value used to avoid a zero denominator. Defaults to 1e-7.
      centered: If True, gradients are normalized by the estimated variance of
        the gradient; if False, by the uncentered second moment. Setting this
        to True may help with training, but is slightly more expensive in
        terms of computation and memory. Defaults to False.
      name: Optional name prefix for the operations created when applying
        gradients. Defaults to "RMSprop".
      **kwargs: Keyword arguments. Allowed to be {`clipnorm`, `clipvalue`,
        `lr`, `decay`}. `clipnorm` clips gradients by norm; `clipvalue` clips
        gradients by value; `decay` is included for backward compatibility to
        allow time inverse decay of the learning rate; `lr` is included for
        backward compatibility, and `learning_rate` is recommended instead.

    @compatibility(eager)
    When eager execution is enabled, `learning_rate`, `decay`, `momentum`, and
    `epsilon` can each be a callable that takes no arguments and returns the
    actual value to use. This can be useful for changing these values across
    different invocations of optimizer functions.
    @end_compatibility
    """
    if epsilon is None:
      epsilon = backend_config.epsilon()
    super(RMSprop, self).__init__(name, **kwargs)
    self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
    self._set_hyper("decay", self._initial_decay)
    self._set_hyper("rho", rho)

    self._momentum = False
    if isinstance(momentum, ops.Tensor) or callable(momentum) or momentum > 0:
      self._momentum = True
    if isinstance(momentum, (int, float)) and (momentum < 0 or momentum > 1):
      raise ValueError("`momentum` must be in the range [0, 1].")
    self._set_hyper("momentum", momentum)

    self._set_hyper("epsilon", epsilon)
    self.centered = centered

  def _create_slots(self, var_list):
    for var in var_list:
      self.add_slot(var, "rms")
    if self._momentum:
      for var in var_list:
        self.add_slot(var, "momentum")
    if self.centered:
      for var in var_list:
        self.add_slot(var, "mg")

  def _resource_apply_dense(self, grad, var):
    var_dtype = var.dtype.base_dtype
    lr_t = self._decayed_lr(var_dtype)
    rms = self.get_slot(var, "rms")
    rho = self._get_hyper("rho", var_dtype)
    momentum = self._get_hyper("momentum", var_dtype)
    epsilon = self._get_hyper("epsilon", var_dtype)
    if self._momentum:
      mom = self.get_slot(var, "momentum")
      if self.centered:
        mg = self.get_slot(var, "mg")
        return training_ops.resource_apply_centered_rms_prop(
            var.handle,
            mg.handle,
            rms.handle,
            mom.handle,
            lr_t,
            rho,
            momentum,
            epsilon,
            grad,
            use_locking=self._use_locking)
      else:
        return training_ops.resource_apply_rms_prop(
            var.handle,
            rms.handle,
            mom.handle,
            lr_t,
            rho,
            momentum,
            epsilon,
            grad,
            use_locking=self._use_locking)
    else:
      rms_t = rho * rms + (1. - rho) * math_ops.square(grad)
      rms_t = state_ops.assign(rms, rms_t, use_locking=self._use_locking)
      denom_t = rms_t
      if self.centered:
        mg = self.get_slot(var, "mg")
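        # `mg` holds a discounted moving average of the raw gradients; the
        # centered variant subtracts its square from the mean-square term so
        # that the denominator estimates the gradient variance.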
        mg_t = rho * mg + (1. - rho) * grad
        mg_t = state_ops.assign(mg, mg_t, use_locking=self._use_locking)
        denom_t = rms_t - math_ops.square(mg_t)
      var_t = var - lr_t * grad / (math_ops.sqrt(denom_t) + epsilon)
      return state_ops.assign(var, var_t, use_locking=self._use_locking).op

  def _resource_apply_sparse(self, grad, var, indices):
    var_dtype = var.dtype.base_dtype
    lr_t = self._decayed_lr(var_dtype)
    rms = self.get_slot(var, "rms")
    rho = self._get_hyper("rho", var_dtype)
    momentum = self._get_hyper("momentum", var_dtype)
    epsilon = self._get_hyper("epsilon", var_dtype)
    if self._momentum:
      mom = self.get_slot(var, "momentum")
      if self.centered:
        mg = self.get_slot(var, "mg")
        return training_ops.resource_sparse_apply_centered_rms_prop(
            var.handle,
            mg.handle,
            rms.handle,
            mom.handle,
            lr_t,
            rho,
            momentum,
            epsilon,
            grad,
            indices,
            use_locking=self._use_locking)
      else:
        return training_ops.resource_sparse_apply_rms_prop(
            var.handle,
            rms.handle,
            mom.handle,
            lr_t,
            rho,
            momentum,
            epsilon,
            grad,
            indices,
            use_locking=self._use_locking)
    else:
      rms_scaled_g_values = (grad * grad) * (1. - rho)
      rms_t = state_ops.assign(rms, rms * rho, use_locking=self._use_locking)
      with ops.control_dependencies([rms_t]):
        rms_t = self._resource_scatter_add(rms, indices, rms_scaled_g_values)
        rms_slice = array_ops.gather(rms_t, indices)
      denom_slice = rms_slice
      if self.centered:
        mg = self.get_slot(var, "mg")
        mg_scaled_g_values = grad * (1. - rho)
        mg_t = state_ops.assign(mg, mg * rho, use_locking=self._use_locking)
        with ops.control_dependencies([mg_t]):
          mg_t = self._resource_scatter_add(mg, indices, mg_scaled_g_values)
          mg_slice = array_ops.gather(mg_t, indices)
          denom_slice = rms_slice - math_ops.square(mg_slice)
      var_update = self._resource_scatter_add(
          var, indices, -lr_t * grad / (math_ops.sqrt(denom_slice) + epsilon))
      if self.centered:
        return control_flow_ops.group(*[var_update, rms_t, mg_t])
      return control_flow_ops.group(*[var_update, rms_t])

  def set_weights(self, weights):
    params = self.weights
    # Override set_weights for backward compatibility of Keras V1 optimizer
    # since it does not include iteration at head of the weight list. Set
    # iteration to 0.
    if len(params) == len(weights) + 1:
      weights = [np.array(0)] + weights
    super(RMSprop, self).set_weights(weights)

  def get_config(self):
    config = super(RMSprop, self).get_config()
    config.update({
        "learning_rate": self._serialize_hyperparameter("learning_rate"),
        "decay": self._serialize_hyperparameter("decay"),
        "rho": self._serialize_hyperparameter("rho"),
        "momentum": self._serialize_hyperparameter("momentum"),
        "epsilon": self._serialize_hyperparameter("epsilon"),
        "centered": self.centered,
    })
    return config


RMSProp = RMSprop
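
# A minimal configuration round-trip sketch (illustrative only; `from_config`
# is inherited from `optimizer_v2.OptimizerV2`):
#
#   opt = RMSprop(learning_rate=0.01, momentum=0.5, centered=True)
#   config = opt.get_config()
#   restored = RMSprop.from_config(config)  # Same hyperparameters as `opt`.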