# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for rmsprop."""

import copy
import itertools
import math

import numpy as np

from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import test_util
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import rmsprop

_DATA_TYPES = [dtypes.half, dtypes.float32]

_TEST_PARAM_VALUES = [
    # learning_rate, decay, momentum, epsilon, centered, use_resource
    [0.5, 0.9, 0.0, 1e-3, True, False],
    [0.5, 0.9, 0.0, 1e-3, False, False],
    [0.5, 0.9, 0.0, 1e-3, True, True],
    [0.5, 0.9, 0.0, 1e-3, False, True],
    [0.1, 0.9, 0.0, 1e-3, True, False],
    [0.5, 0.95, 0.0, 1e-3, False, False],
    [0.5, 0.95, 0.0, 1e-5, True, False],
    [0.5, 0.95, 0.9, 1e-5, True, False],
]

_TESTPARAMS = [
    [data_type] + values
    for data_type, values in itertools.product(_DATA_TYPES, _TEST_PARAM_VALUES)
]


class RMSPropOptimizerTest(test.TestCase):

  def _rmsprop_update_numpy(self, var, g, mg, rms, mom, lr, decay, momentum,
                            epsilon, centered):
    # Reference implementation of a single RMSProp step:
    #   rms_t = decay * rms + (1 - decay) * g**2
    #   mg_t  = decay * mg + (1 - decay) * g          (centered only)
    #   denom = rms_t + epsilon - (mg_t**2 if centered else 0)
    #   mom_t = momentum * mom + lr * g / sqrt(denom)
    #   var_t = var - mom_t
    rms_t = rms * decay + (1 - decay) * g * g
    denom_t = rms_t + epsilon
    if centered:
      mg_t = mg * decay + (1 - decay) * g
      denom_t -= mg_t * mg_t
    else:
      mg_t = mg
    mom_t = momentum * mom + lr * g / np.sqrt(denom_t, dtype=denom_t.dtype)
    var_t = var - mom_t
    return var_t, mg_t, rms_t, mom_t

  def _sparse_rmsprop_update_numpy(self, var, gindices, gvalues, mg, rms, mom,
                                   lr, decay, momentum, epsilon, centered):
    # Sparse variant of the reference update: only the rows named in
    # gindices are touched; all other slots keep their previous values.
    mg_t = copy.deepcopy(mg)
    rms_t = copy.deepcopy(rms)
    mom_t = copy.deepcopy(mom)
    var_t = copy.deepcopy(var)
    for i in range(len(gindices)):
      gindex = gindices[i]
      gvalue = gvalues[i]
      rms_t[gindex] = rms[gindex] * decay + (1 - decay) * gvalue * gvalue
      denom_t = rms_t[gindex] + epsilon
      if centered:
        mg_t[gindex] = mg_t[gindex] * decay + (1 - decay) * gvalue
        denom_t -= mg_t[gindex] * mg_t[gindex]
      mom_t[gindex] = momentum * mom[gindex] + lr * gvalue / np.sqrt(denom_t)
      var_t[gindex] = var[gindex] - mom_t[gindex]
    return var_t, mg_t, rms_t, mom_t
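
  def testNumpyReferenceSingleStep(self):
    # A minimal hand-computed sanity check of _rmsprop_update_numpy above.
    # This method is an added sketch (not one of the original op-backed
    # tests): one non-centered, momentum-free step with lr=0.1, decay=0.9,
    # epsilon=1e-3.
    var = np.array([1.0], dtype=np.float32)
    g = np.array([0.5], dtype=np.float32)
    zeros = np.array([0.0], dtype=np.float32)
    ones = np.array([1.0], dtype=np.float32)
    var_t, _, rms_t, mom_t = self._rmsprop_update_numpy(
        var, g, zeros, ones, zeros, lr=0.1, decay=0.9, momentum=0.0,
        epsilon=1e-3, centered=False)
    # rms_t = 0.9 * 1.0 + 0.1 * 0.5**2 = 0.925
    self.assertAllClose([0.925], rms_t)
    # mom_t = lr * g / sqrt(rms_t + epsilon) = 0.05 / sqrt(0.926)
    self.assertAllClose([0.05 / np.sqrt(0.926)], mom_t)
    # var_t = var - mom_t
    self.assertAllClose([1.0 - 0.05 / np.sqrt(0.926)], var_t)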

  @test_util.run_deprecated_v1
  def testDense(self):
    # TODO(yori): Use ParameterizedTest when available
    for (dtype, learning_rate, decay, momentum,
         epsilon, centered, use_resource) in _TESTPARAMS:
      with test_util.use_gpu():
        # Initialize variables for numpy implementation.
        var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
        grads0_np = np.array([0.1, 0.2], dtype=dtype.as_numpy_dtype)
        var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
        grads1_np = np.array([0.01, 0.2], dtype=dtype.as_numpy_dtype)

        if use_resource:
          var0 = resource_variable_ops.ResourceVariable(var0_np)
          var1 = resource_variable_ops.ResourceVariable(var1_np)
        else:
          var0 = variables.Variable(var0_np)
          var1 = variables.Variable(var1_np)
        grads0 = constant_op.constant(grads0_np)
        grads1 = constant_op.constant(grads1_np)
        opt = rmsprop.RMSPropOptimizer(
            learning_rate=learning_rate,
            decay=decay,
            momentum=momentum,
            epsilon=epsilon,
            centered=centered)

        update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
        self.evaluate(variables.global_variables_initializer())

        mg0 = opt.get_slot(var0, "mg")
        self.assertEqual(mg0 is not None, centered)
        mg1 = opt.get_slot(var1, "mg")
        self.assertEqual(mg1 is not None, centered)
        rms0 = opt.get_slot(var0, "rms")
        self.assertIsNotNone(rms0)
        rms1 = opt.get_slot(var1, "rms")
        self.assertIsNotNone(rms1)
        mom0 = opt.get_slot(var0, "momentum")
        self.assertIsNotNone(mom0)
        mom1 = opt.get_slot(var1, "momentum")
        self.assertIsNotNone(mom1)

        mg0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
        mg1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
        rms0_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype)
        rms1_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype)
        mom0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
        mom1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)

        # Fetch params to validate initial values
        self.assertAllClose([1.0, 2.0], self.evaluate(var0))
        self.assertAllClose([3.0, 4.0], self.evaluate(var1))

        # Run 4 steps of RMSProp
        for _ in range(4):
          self.evaluate(update)

          var0_np, mg0_np, rms0_np, mom0_np = self._rmsprop_update_numpy(
              var0_np, grads0_np, mg0_np, rms0_np, mom0_np, learning_rate,
              decay, momentum, epsilon, centered)
          var1_np, mg1_np, rms1_np, mom1_np = self._rmsprop_update_numpy(
              var1_np, grads1_np, mg1_np, rms1_np, mom1_np, learning_rate,
              decay, momentum, epsilon, centered)

          # Validate updated params
          if centered:
            self.assertAllCloseAccordingToType(mg0_np, self.evaluate(mg0))
            self.assertAllCloseAccordingToType(mg1_np, self.evaluate(mg1))
          self.assertAllCloseAccordingToType(rms0_np, self.evaluate(rms0))
          self.assertAllCloseAccordingToType(rms1_np, self.evaluate(rms1))
          self.assertAllCloseAccordingToType(mom0_np, self.evaluate(mom0))
          self.assertAllCloseAccordingToType(mom1_np, self.evaluate(mom1))
          self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
          self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
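
  # The two minimize tests below share the same setup: var0 = [[1., 2.]] and
  # x = [[4.], [5.]], so pred = var0 . x = 14, loss = pred**2 = 196, and the
  # gradient with respect to var0 is 2 * pred * x^T = [[112., 140.]].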

  @test_util.run_deprecated_v1
  def testMinimizeSparseResourceVariable(self):
    for dtype in [dtypes.float32, dtypes.float64]:
      with self.cached_session():
        var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]],
                                                      dtype=dtype)
        x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
        pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
        loss = pred * pred
        rmsprop_op = rmsprop.RMSPropOptimizer(
            learning_rate=1.0,
            decay=0.0,
            momentum=0.0,
            epsilon=0.0,
            centered=False).minimize(loss)
        self.evaluate(variables.global_variables_initializer())
        # Fetch params to validate initial values
        self.assertAllCloseAccordingToType([[1.0, 2.0]], self.evaluate(var0))
        # Run 1 step of RMSProp
        self.evaluate(rmsprop_op)
        # Validate updated params. With decay=0 and epsilon=0 the step is
        # lr * g / sqrt(g * g) = 1 elementwise, so var0 moves from
        # [[1., 2.]] to [[0., 1.]].
        self.assertAllCloseAccordingToType([[0., 1.]],
                                           self.evaluate(var0),
                                           atol=0.01)

  @test_util.run_deprecated_v1
  def testMinimizeSparseResourceVariableCentered(self):
    for dtype in [dtypes.float32, dtypes.float64]:
      with self.cached_session():
        var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]],
                                                      dtype=dtype)
        x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
        pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
        loss = pred * pred
        rmsprop_op = rmsprop.RMSPropOptimizer(
            learning_rate=1.0,
            decay=0.0,
            momentum=0.0,
            epsilon=1.0,
            centered=True).minimize(loss)
        self.evaluate(variables.global_variables_initializer())
        # Fetch params to validate initial values
        self.assertAllCloseAccordingToType([[1.0, 2.0]], self.evaluate(var0))
        # Run 1 step of RMSProp
        self.evaluate(rmsprop_op)
        # Validate updated params. With decay=0 the centered denominator is
        # g*g - g*g + epsilon = 1, so the step equals the gradient
        # [[112., 140.]] and var0 moves to [[-111., -138.]].
        self.assertAllCloseAccordingToType([[-111, -138]],
                                           self.evaluate(var0),
                                           atol=0.01)

  @test_util.run_deprecated_v1
  def testSparse(self):
    # TODO(yori): Use ParameterizedTest when available
    for (dtype, learning_rate, decay,
         momentum, epsilon, centered, _) in _TESTPARAMS:
      with test_util.use_gpu():
        # Initialize variables for numpy implementation.
        var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
        grads0_np = np.array([0.1], dtype=dtype.as_numpy_dtype)
        var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
        grads1_np = np.array([0.01], dtype=dtype.as_numpy_dtype)

        var0 = variables.Variable(var0_np)
        var1 = variables.Variable(var1_np)
        grads0_np_indices = np.array([0], dtype=np.int32)
        grads0 = indexed_slices.IndexedSlices(
            constant_op.constant(grads0_np),
            constant_op.constant(grads0_np_indices), constant_op.constant([1]))
        grads1_np_indices = np.array([1], dtype=np.int32)
        grads1 = indexed_slices.IndexedSlices(
            constant_op.constant(grads1_np),
            constant_op.constant(grads1_np_indices), constant_op.constant([1]))
        opt = rmsprop.RMSPropOptimizer(
            learning_rate=learning_rate,
            decay=decay,
            momentum=momentum,
            epsilon=epsilon,
            centered=centered)
        update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
        self.evaluate(variables.global_variables_initializer())

        mg0 = opt.get_slot(var0, "mg")
        self.assertEqual(mg0 is not None, centered)
        mg1 = opt.get_slot(var1, "mg")
        self.assertEqual(mg1 is not None, centered)
        rms0 = opt.get_slot(var0, "rms")
        self.assertIsNotNone(rms0)
        rms1 = opt.get_slot(var1, "rms")
        self.assertIsNotNone(rms1)
        mom0 = opt.get_slot(var0, "momentum")
        self.assertIsNotNone(mom0)
        mom1 = opt.get_slot(var1, "momentum")
        self.assertIsNotNone(mom1)

        mg0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
        mg1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
        rms0_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype)
        rms1_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype)
        mom0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
        mom1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)

        # Fetch params to validate initial values
        self.assertAllClose([1.0, 2.0], self.evaluate(var0))
        self.assertAllClose([3.0, 4.0], self.evaluate(var1))

        # Run 4 steps of RMSProp
        for _ in range(4):
          self.evaluate(update)

          var0_np, mg0_np, rms0_np, mom0_np = self._sparse_rmsprop_update_numpy(
              var0_np, grads0_np_indices, grads0_np, mg0_np, rms0_np, mom0_np,
              learning_rate, decay, momentum, epsilon, centered)
          var1_np, mg1_np, rms1_np, mom1_np = self._sparse_rmsprop_update_numpy(
              var1_np, grads1_np_indices, grads1_np, mg1_np, rms1_np, mom1_np,
              learning_rate, decay, momentum, epsilon, centered)

          # Validate updated params
          if centered:
            self.assertAllCloseAccordingToType(mg0_np, self.evaluate(mg0))
            self.assertAllCloseAccordingToType(mg1_np, self.evaluate(mg1))
          self.assertAllCloseAccordingToType(rms0_np, self.evaluate(rms0))
          self.assertAllCloseAccordingToType(rms1_np, self.evaluate(rms1))
          self.assertAllCloseAccordingToType(mom0_np, self.evaluate(mom0))
          self.assertAllCloseAccordingToType(mom1_np, self.evaluate(mom1))
          self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
          self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
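
  # The hand-written expectations in the tests below follow from the update
  # equations: with rms initialized to 1 and decay = 0.9, one step with
  # grad 0.1 gives rms = 0.9 * 1.0 + 0.1 * 0.1**2 = 0.901, and one step with
  # grad 0.01 gives rms = 0.9 * 1.0 + 0.1 * 0.01**2 = 0.90001; each parameter
  # then moves by lr * grad / sqrt(rms + epsilon), plus the momentum term
  # when momentum is nonzero.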

  @test_util.run_deprecated_v1
  def testWithoutMomentum(self):
    for dtype in [dtypes.half, dtypes.float32]:
      with test_util.use_gpu():
        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
        var1 = variables.Variable([3.0, 4.0], dtype=dtype)
        grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
        grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
        opt = rmsprop.RMSPropOptimizer(
            learning_rate=2.0, decay=0.9, momentum=0.0, epsilon=1.0)
        update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
        self.evaluate(variables.global_variables_initializer())

        rms0 = opt.get_slot(var0, "rms")
        self.assertIsNotNone(rms0)
        rms1 = opt.get_slot(var1, "rms")
        self.assertIsNotNone(rms1)
        mom0 = opt.get_slot(var0, "momentum")
        self.assertIsNotNone(mom0)
        mom1 = opt.get_slot(var1, "momentum")
        self.assertIsNotNone(mom1)

        # Fetch params to validate initial values
        self.assertAllClose([1.0, 2.0], self.evaluate(var0))
        self.assertAllClose([3.0, 4.0], self.evaluate(var1))
        # Step 1: the rms accumulators start at 1, so each parameter moves
        # by grad * learning_rate / sqrt(rms_t + epsilon).
        self.evaluate(update)
        # Check the root mean square accumulators.
        self.assertAllCloseAccordingToType(
            np.array([0.901, 0.901]), self.evaluate(rms0))
        self.assertAllCloseAccordingToType(
            np.array([0.90001, 0.90001]), self.evaluate(rms1))
        # Check the parameters.
        self.assertAllCloseAccordingToType(
            np.array([
                1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)),
                2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0))
            ]), self.evaluate(var0))
        self.assertAllCloseAccordingToType(
            np.array([
                3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)),
                4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0))
            ]), self.evaluate(var1))
        # Step 2: the root mean square accumulators contain the previous
        # update.
        self.evaluate(update)
        # Check the rms accumulators.
        self.assertAllCloseAccordingToType(
            np.array([0.901 * 0.9 + 0.001, 0.901 * 0.9 + 0.001]),
            self.evaluate(rms0))
        self.assertAllCloseAccordingToType(
            np.array([0.90001 * 0.9 + 1e-5, 0.90001 * 0.9 + 1e-5]),
            self.evaluate(rms1))
        # Check the parameters.
        self.assertAllCloseAccordingToType(
            np.array([
                1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)) -
                (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1.0)),
                2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)) -
                (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1.0))
            ]), self.evaluate(var0))
        self.assertAllCloseAccordingToType(
            np.array([
                3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)) -
                (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5 + 1.0)),
                4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)) -
                (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5 + 1.0))
            ]), self.evaluate(var1))

  @test_util.run_deprecated_v1
  def testWithMomentum(self):
    for dtype in [dtypes.half, dtypes.float32]:
      with test_util.use_gpu():
        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
        var1 = variables.Variable([3.0, 4.0], dtype=dtype)
        grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
        grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)

        opt = rmsprop.RMSPropOptimizer(
            learning_rate=2.0, decay=0.9, momentum=0.5, epsilon=1e-5)
        update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
        self.evaluate(variables.global_variables_initializer())

        rms0 = opt.get_slot(var0, "rms")
        self.assertIsNotNone(rms0)
        rms1 = opt.get_slot(var1, "rms")
        self.assertIsNotNone(rms1)
        mom0 = opt.get_slot(var0, "momentum")
        self.assertIsNotNone(mom0)
        mom1 = opt.get_slot(var1, "momentum")
        self.assertIsNotNone(mom1)

        # Fetch params to validate initial values
        self.assertAllClose([1.0, 2.0], self.evaluate(var0))
        self.assertAllClose([3.0, 4.0], self.evaluate(var1))
        # Step 1: rms = 1, mom = 0, so the momentum accumulator picks up the
        # scaled gradient grad * learning_rate / sqrt(rms_t + epsilon), and
        # the parameters move by that amount.
        self.evaluate(update)
        # Check the root mean square accumulators.
        self.assertAllCloseAccordingToType(
            np.array([0.901, 0.901]), self.evaluate(rms0))
        self.assertAllCloseAccordingToType(
            np.array([0.90001, 0.90001]), self.evaluate(rms1))
        # Check the momentum accumulators
        self.assertAllCloseAccordingToType(
            np.array([(0.1 * 2.0 / math.sqrt(0.901 + 1e-5)),
                      (0.1 * 2.0 / math.sqrt(0.901 + 1e-5))]),
            self.evaluate(mom0))
        self.assertAllCloseAccordingToType(
            np.array([(0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)),
                      (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5))]),
            self.evaluate(mom1))

        # Check the parameters.
        self.assertAllCloseAccordingToType(
            np.array([
                1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)),
                2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5))
            ]), self.evaluate(var0))
        self.assertAllCloseAccordingToType(
            np.array([
                3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)),
                4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5))
            ]), self.evaluate(var1))

        # Step 2: the root mean square accumulators contain the previous
        # update.
        self.evaluate(update)
        # Check the rms accumulators.
        self.assertAllCloseAccordingToType(
            np.array([0.901 * 0.9 + 0.001, 0.901 * 0.9 + 0.001]),
            self.evaluate(rms0))
        self.assertAllCloseAccordingToType(
            np.array([0.90001 * 0.9 + 1e-5, 0.90001 * 0.9 + 1e-5]),
            self.evaluate(rms1))
        # Check the momentum accumulators
        self.assertAllCloseAccordingToType(
            np.array([
                0.5 * (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) +
                (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1e-5)),
                0.5 * (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) +
                (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1e-5))
            ]), self.evaluate(mom0))
        self.assertAllCloseAccordingToType(
            np.array([
                0.5 * (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) +
                (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5)),
                0.5 * (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) +
                (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5))
            ]), self.evaluate(mom1))

        # Check the parameters.
        self.assertAllCloseAccordingToType(
            np.array([
                1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) -
                (0.5 * (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) +
                 (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1e-5))),
                2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) -
                (0.5 * (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) +
                 (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1e-5)))
            ]), self.evaluate(var0))

        self.assertAllCloseAccordingToType(
            np.array([
                3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) -
                (0.5 * (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) +
                 (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5))),
                4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) -
                (0.5 * (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) +
                 (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5)))
            ]), self.evaluate(var1))
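
  # testCallableParams repeats the testWithoutMomentum arithmetic, but passes
  # the hyperparameters as zero-argument callables rather than constants and
  # applies the updates eagerly; the expected values are identical because
  # the callables return the same settings (2.0, 0.9, 0.0, 1.0).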

  def testCallableParams(self):
    with context.eager_mode():
      for dtype in [dtypes.half, dtypes.float32]:
        var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
        var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
        grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
        grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)

        learning_rate = lambda: 2.0
        decay = lambda: 0.9
        momentum = lambda: 0.0
        epsilon = lambda: 1.0
        opt = rmsprop.RMSPropOptimizer(learning_rate, decay, momentum, epsilon)

        # Fetch params to validate initial values
        self.assertAllClose([1.0, 2.0], self.evaluate(var0))
        self.assertAllClose([3.0, 4.0], self.evaluate(var1))
        # Step 1: the rms accumulators start at 1, so each parameter moves
        # by grad * learning_rate / sqrt(rms_t + epsilon).
        opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
        # Check the parameters.
        self.assertAllCloseAccordingToType(
            np.array([
                1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)),
                2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0))
            ]), self.evaluate(var0))
        self.assertAllCloseAccordingToType(
            np.array([
                3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)),
                4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0))
            ]), self.evaluate(var1))
        # Step 2: the root mean square accumulators contain the previous
        # update.
        opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
        # Check the parameters.
        self.assertAllCloseAccordingToType(
            np.array([
                1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)) -
                (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1.0)),
                2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)) -
                (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1.0))
            ]), self.evaluate(var0))
        self.assertAllCloseAccordingToType(
            np.array([
                3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)) -
                (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5 + 1.0)),
                4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)) -
                (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5 + 1.0))
            ]), self.evaluate(var1))


if __name__ == "__main__":
  test.main()