• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for rmsprop."""
16
17import copy
18import itertools
19import math
20
21import numpy as np
22
23from tensorflow.python.eager import context
24from tensorflow.python.framework import constant_op
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import indexed_slices
27from tensorflow.python.framework import test_util
28from tensorflow.python.ops import embedding_ops
29from tensorflow.python.ops import math_ops
30from tensorflow.python.ops import resource_variable_ops
31from tensorflow.python.ops import variables
32from tensorflow.python.platform import test
33from tensorflow.python.training import rmsprop
34
# Floating-point dtypes that every parameterized case is run with.
_DATA_TYPES = [dtypes.half, dtypes.float32]

# Hyperparameter rows shared by the parameterized tests. Columns, in order:
#   learning_rate, decay, momentum, epsilon, centered, use_resource
_TEST_PARAM_VALUES = [
    [0.5, 0.9, 0.0, 1e-3, True, False],
    [0.5, 0.9, 0.0, 1e-3, False, False],
    [0.5, 0.9, 0.0, 1e-3, True, True],
    [0.5, 0.9, 0.0, 1e-3, False, True],
    [0.1, 0.9, 0.0, 1e-3, True, False],
    [0.5, 0.95, 0.0, 1e-3, False, False],
    [0.5, 0.95, 0.0, 1e-5, True, False],
    [0.5, 0.95, 0.9, 1e-5, True, False],
]

# Cross product of the two lists above. Each entry is
#   [dtype, learning_rate, decay, momentum, epsilon, centered, use_resource]
_TESTPARAMS = [
    [param_dtype, *param_row]
    for param_dtype, param_row in itertools.product(_DATA_TYPES,
                                                    _TEST_PARAM_VALUES)
]
53
54
55class RMSPropOptimizerTest(test.TestCase):
56
57  def _rmsprop_update_numpy(self, var, g, mg, rms, mom, lr, decay, momentum,
58                            epsilon, centered):
59    rms_t = rms * decay + (1 - decay) * g * g
60    denom_t = rms_t + epsilon
61    if centered:
62      mg_t = mg * decay + (1 - decay) * g
63      denom_t -= mg_t * mg_t
64    else:
65      mg_t = mg
66    mom_t = momentum * mom + lr * g / np.sqrt(denom_t, dtype=denom_t.dtype)
67    var_t = var - mom_t
68    return var_t, mg_t, rms_t, mom_t
69
70  def _sparse_rmsprop_update_numpy(self, var, gindexs, gvalues, mg, rms, mom,
71                                   lr, decay, momentum, epsilon, centered):
72    mg_t = copy.deepcopy(mg)
73    rms_t = copy.deepcopy(rms)
74    mom_t = copy.deepcopy(mom)
75    var_t = copy.deepcopy(var)
76    for i in range(len(gindexs)):
77      gindex = gindexs[i]
78      gvalue = gvalues[i]
79      rms_t[gindex] = rms[gindex] * decay + (1 - decay) * gvalue * gvalue
80      denom_t = rms_t[gindex] + epsilon
81      if centered:
82        mg_t[gindex] = mg_t[gindex] * decay + (1 - decay) * gvalue
83        denom_t -= mg_t[gindex] * mg_t[gindex]
84      mom_t[gindex] = momentum * mom[gindex] + lr * gvalue / np.sqrt(denom_t)
85      var_t[gindex] = var[gindex] - mom_t[gindex]
86    return var_t, mg_t, rms_t, mom_t
87
88  @test_util.run_deprecated_v1
89  def testDense(self):
90    # TODO(yori): Use ParameterizedTest when available
91    for (dtype, learning_rate, decay, momentum,
92         epsilon, centered, use_resource) in _TESTPARAMS:
93      with test_util.use_gpu():
94        # Initialize variables for numpy implementation.
95        var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
96        grads0_np = np.array([0.1, 0.2], dtype=dtype.as_numpy_dtype)
97        var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
98        grads1_np = np.array([0.01, 0.2], dtype=dtype.as_numpy_dtype)
99
100        if use_resource:
101          var0 = resource_variable_ops.ResourceVariable(var0_np)
102          var1 = resource_variable_ops.ResourceVariable(var1_np)
103        else:
104          var0 = variables.Variable(var0_np)
105          var1 = variables.Variable(var1_np)
106        grads0 = constant_op.constant(grads0_np)
107        grads1 = constant_op.constant(grads1_np)
108        opt = rmsprop.RMSPropOptimizer(
109            learning_rate=learning_rate,
110            decay=decay,
111            momentum=momentum,
112            epsilon=epsilon,
113            centered=centered)
114
115        update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
116        self.evaluate(variables.global_variables_initializer())
117
118        mg0 = opt.get_slot(var0, "mg")
119        self.assertEqual(mg0 is not None, centered)
120        mg1 = opt.get_slot(var1, "mg")
121        self.assertEqual(mg1 is not None, centered)
122        rms0 = opt.get_slot(var0, "rms")
123        self.assertTrue(rms0 is not None)
124        rms1 = opt.get_slot(var1, "rms")
125        self.assertTrue(rms1 is not None)
126        mom0 = opt.get_slot(var0, "momentum")
127        self.assertTrue(mom0 is not None)
128        mom1 = opt.get_slot(var1, "momentum")
129        self.assertTrue(mom1 is not None)
130
131        mg0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
132        mg1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
133        rms0_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype)
134        rms1_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype)
135        mom0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
136        mom1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
137
138        # Fetch params to validate initial values
139        self.assertAllClose([1.0, 2.0], self.evaluate(var0))
140        self.assertAllClose([3.0, 4.0], self.evaluate(var1))
141
142        # Run 4 steps of RMSProp
143        for _ in range(1, 5):
144          self.evaluate(update)
145
146          var0_np, mg0_np, rms0_np, mom0_np = self._rmsprop_update_numpy(
147              var0_np, grads0_np, mg0_np, rms0_np, mom0_np, learning_rate,
148              decay, momentum, epsilon, centered)
149          var1_np, mg1_np, rms1_np, mom1_np = self._rmsprop_update_numpy(
150              var1_np, grads1_np, mg1_np, rms1_np, mom1_np, learning_rate,
151              decay, momentum, epsilon, centered)
152
153          # Validate updated params
154          if centered:
155            self.assertAllCloseAccordingToType(mg0_np, self.evaluate(mg0))
156            self.assertAllCloseAccordingToType(mg1_np, self.evaluate(mg1))
157          self.assertAllCloseAccordingToType(rms0_np, self.evaluate(rms0))
158          self.assertAllCloseAccordingToType(rms1_np, self.evaluate(rms1))
159          self.assertAllCloseAccordingToType(mom0_np, self.evaluate(mom0))
160          self.assertAllCloseAccordingToType(mom1_np, self.evaluate(mom1))
161          self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
162          self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
163
164  @test_util.run_deprecated_v1
165  def testMinimizeSparseResourceVariable(self):
166    for dtype in [dtypes.float32, dtypes.float64]:
167      with self.cached_session():
168        var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
169        x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
170        pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
171        loss = pred * pred
172        sgd_op = rmsprop.RMSPropOptimizer(
173            learning_rate=1.0,
174            decay=0.0,
175            momentum=0.0,
176            epsilon=0.0,
177            centered=False).minimize(loss)
178        self.evaluate(variables.global_variables_initializer())
179        # Fetch params to validate initial values
180        self.assertAllCloseAccordingToType([[1.0, 2.0]], self.evaluate(var0))
181        # Run 1 step of sgd
182        self.evaluate(sgd_op)
183        # Validate updated params
184        self.assertAllCloseAccordingToType([[0., 1.]],
185                                           self.evaluate(var0),
186                                           atol=0.01)
187
188  @test_util.run_deprecated_v1
189  def testMinimizeSparseResourceVariableCentered(self):
190    for dtype in [dtypes.float32, dtypes.float64]:
191      with self.cached_session():
192        var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
193        x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
194        pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
195        loss = pred * pred
196        sgd_op = rmsprop.RMSPropOptimizer(
197            learning_rate=1.0,
198            decay=0.0,
199            momentum=0.0,
200            epsilon=1.0,
201            centered=True).minimize(loss)
202        self.evaluate(variables.global_variables_initializer())
203        # Fetch params to validate initial values
204        self.assertAllCloseAccordingToType([[1.0, 2.0]], self.evaluate(var0))
205        # Run 1 step of sgd
206        self.evaluate(sgd_op)
207        # Validate updated params
208        self.assertAllCloseAccordingToType([[-111, -138]],
209                                           self.evaluate(var0),
210                                           atol=0.01)
211
212  @test_util.run_deprecated_v1
213  def testSparse(self):
214    # TODO(yori): Use ParameterizedTest when available
215    for (dtype, learning_rate, decay,
216         momentum, epsilon, centered, _) in _TESTPARAMS:
217      with test_util.use_gpu():
218        # Initialize variables for numpy implementation.
219        var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
220        grads0_np = np.array([0.1], dtype=dtype.as_numpy_dtype)
221        var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
222        grads1_np = np.array([0.01], dtype=dtype.as_numpy_dtype)
223
224        var0 = variables.Variable(var0_np)
225        var1 = variables.Variable(var1_np)
226        grads0_np_indices = np.array([0], dtype=np.int32)
227        grads0 = indexed_slices.IndexedSlices(
228            constant_op.constant(grads0_np),
229            constant_op.constant(grads0_np_indices), constant_op.constant([1]))
230        grads1_np_indices = np.array([1], dtype=np.int32)
231        grads1 = indexed_slices.IndexedSlices(
232            constant_op.constant(grads1_np),
233            constant_op.constant(grads1_np_indices), constant_op.constant([1]))
234        opt = rmsprop.RMSPropOptimizer(
235            learning_rate=learning_rate,
236            decay=decay,
237            momentum=momentum,
238            epsilon=epsilon,
239            centered=centered)
240        update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
241        self.evaluate(variables.global_variables_initializer())
242
243        mg0 = opt.get_slot(var0, "mg")
244        self.assertEqual(mg0 is not None, centered)
245        mg1 = opt.get_slot(var1, "mg")
246        self.assertEqual(mg1 is not None, centered)
247        rms0 = opt.get_slot(var0, "rms")
248        self.assertTrue(rms0 is not None)
249        rms1 = opt.get_slot(var1, "rms")
250        self.assertTrue(rms1 is not None)
251        mom0 = opt.get_slot(var0, "momentum")
252        self.assertTrue(mom0 is not None)
253        mom1 = opt.get_slot(var1, "momentum")
254        self.assertTrue(mom1 is not None)
255
256        mg0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
257        mg1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
258        rms0_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype)
259        rms1_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype)
260        mom0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
261        mom1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
262
263        # Fetch params to validate initial values
264        self.assertAllClose([1.0, 2.0], self.evaluate(var0))
265        self.assertAllClose([3.0, 4.0], self.evaluate(var1))
266
267        # Run 4 steps of RMSProp
268        for _ in range(1, 5):
269          self.evaluate(update)
270
271          var0_np, mg0_np, rms0_np, mom0_np = self._sparse_rmsprop_update_numpy(
272              var0_np, grads0_np_indices, grads0_np, mg0_np, rms0_np, mom0_np,
273              learning_rate, decay, momentum, epsilon, centered)
274          var1_np, mg1_np, rms1_np, mom1_np = self._sparse_rmsprop_update_numpy(
275              var1_np, grads1_np_indices, grads1_np, mg1_np, rms1_np, mom1_np,
276              learning_rate, decay, momentum, epsilon, centered)
277
278          # Validate updated params
279          if centered:
280            self.assertAllCloseAccordingToType(mg0_np, self.evaluate(mg0))
281            self.assertAllCloseAccordingToType(mg1_np, self.evaluate(mg1))
282          self.assertAllCloseAccordingToType(rms0_np, self.evaluate(rms0))
283          self.assertAllCloseAccordingToType(rms1_np, self.evaluate(rms1))
284          self.assertAllCloseAccordingToType(mom0_np, self.evaluate(mom0))
285          self.assertAllCloseAccordingToType(mom1_np, self.evaluate(mom1))
286          self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
287          self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
288
289  @test_util.run_deprecated_v1
290  def testWithoutMomentum(self):
291    for dtype in [dtypes.half, dtypes.float32]:
292      with test_util.use_gpu():
293        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
294        var1 = variables.Variable([3.0, 4.0], dtype=dtype)
295        grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
296        grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
297        opt = rmsprop.RMSPropOptimizer(
298            learning_rate=2.0, decay=0.9, momentum=0.0, epsilon=1.0)
299        update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
300        self.evaluate(variables.global_variables_initializer())
301
302        rms0 = opt.get_slot(var0, "rms")
303        self.assertTrue(rms0 is not None)
304        rms1 = opt.get_slot(var1, "rms")
305        self.assertTrue(rms1 is not None)
306        mom0 = opt.get_slot(var0, "momentum")
307        self.assertTrue(mom0 is not None)
308        mom1 = opt.get_slot(var1, "momentum")
309        self.assertTrue(mom1 is not None)
310
311        # Fetch params to validate initial values
312        self.assertAllClose([1.0, 2.0], self.evaluate(var0))
313        self.assertAllClose([3.0, 4.0], self.evaluate(var1))
314        # Step 1: the rms accumulators where 1. So we should see a normal
315        # update: v -= grad * learning_rate
316        self.evaluate(update)
317        # Check the root mean square accumulators.
318        self.assertAllCloseAccordingToType(
319            np.array([0.901, 0.901]), self.evaluate(rms0))
320        self.assertAllCloseAccordingToType(
321            np.array([0.90001, 0.90001]), self.evaluate(rms1))
322        # Check the parameters.
323        self.assertAllCloseAccordingToType(
324            np.array([
325                1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)),
326                2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0))
327            ]), self.evaluate(var0))
328        self.assertAllCloseAccordingToType(
329            np.array([
330                3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)),
331                4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0))
332            ]), self.evaluate(var1))
333        # Step 2: the root mean square accumulators contain the previous update.
334        self.evaluate(update)
335        # Check the rms accumulators.
336        self.assertAllCloseAccordingToType(
337            np.array([0.901 * 0.9 + 0.001, 0.901 * 0.9 + 0.001]),
338            self.evaluate(rms0))
339        self.assertAllCloseAccordingToType(
340            np.array([0.90001 * 0.9 + 1e-5, 0.90001 * 0.9 + 1e-5]),
341            self.evaluate(rms1))
342        # Check the parameters.
343        self.assertAllCloseAccordingToType(
344            np.array([
345                1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)) -
346                (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1.0)),
347                2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)) -
348                (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1.0))
349            ]), self.evaluate(var0))
350        self.assertAllCloseAccordingToType(
351            np.array([
352                3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)) -
353                (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5 + 1.0)),
354                4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)) -
355                (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5 + 1.0))
356            ]), self.evaluate(var1))
357
358  @test_util.run_deprecated_v1
359  def testWithMomentum(self):
360    for dtype in [dtypes.half, dtypes.float32]:
361      with test_util.use_gpu():
362        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
363        var1 = variables.Variable([3.0, 4.0], dtype=dtype)
364        grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
365        grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
366
367        opt = rmsprop.RMSPropOptimizer(
368            learning_rate=2.0, decay=0.9, momentum=0.5, epsilon=1e-5)
369        update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
370        self.evaluate(variables.global_variables_initializer())
371
372        rms0 = opt.get_slot(var0, "rms")
373        self.assertTrue(rms0 is not None)
374        rms1 = opt.get_slot(var1, "rms")
375        self.assertTrue(rms1 is not None)
376        mom0 = opt.get_slot(var0, "momentum")
377        self.assertTrue(mom0 is not None)
378        mom1 = opt.get_slot(var1, "momentum")
379        self.assertTrue(mom1 is not None)
380
381        # Fetch params to validate initial values
382        self.assertAllClose([1.0, 2.0], self.evaluate(var0))
383        self.assertAllClose([3.0, 4.0], self.evaluate(var1))
384        # Step 1: rms = 1, mom = 0. So we should see a normal
385        # update: v -= grad * learning_rate
386        self.evaluate(update)
387        # Check the root mean square accumulators.
388        self.assertAllCloseAccordingToType(
389            np.array([0.901, 0.901]), self.evaluate(rms0))
390        self.assertAllCloseAccordingToType(
391            np.array([0.90001, 0.90001]), self.evaluate(rms1))
392        # Check the momentum accumulators
393        self.assertAllCloseAccordingToType(
394            np.array([(0.1 * 2.0 / math.sqrt(0.901 + 1e-5)),
395                      (0.1 * 2.0 / math.sqrt(0.901 + 1e-5))]),
396            self.evaluate(mom0))
397        self.assertAllCloseAccordingToType(
398            np.array([(0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)),
399                      (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5))]),
400            self.evaluate(mom1))
401
402        # Check that the parameters.
403        self.assertAllCloseAccordingToType(
404            np.array([
405                1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)),
406                2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5))
407            ]), self.evaluate(var0))
408        self.assertAllCloseAccordingToType(
409            np.array([
410                3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)),
411                4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5))
412            ]), self.evaluate(var1))
413
414        # Step 2: the root mean square accumulators contain the previous update.
415        self.evaluate(update)
416        # Check the rms accumulators.
417        self.assertAllCloseAccordingToType(
418            np.array([0.901 * 0.9 + 0.001, 0.901 * 0.9 + 0.001]),
419            self.evaluate(rms0))
420        self.assertAllCloseAccordingToType(
421            np.array([0.90001 * 0.9 + 1e-5, 0.90001 * 0.9 + 1e-5]),
422            self.evaluate(rms1))
423        self.assertAllCloseAccordingToType(
424            np.array([
425                0.5 * (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) +
426                (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1e-5)),
427                0.5 * (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) +
428                (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1e-5))
429            ]), self.evaluate(mom0))
430        self.assertAllCloseAccordingToType(
431            np.array([
432                0.5 * (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) +
433                (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5)),
434                0.5 * (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) +
435                (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5))
436            ]), self.evaluate(mom1))
437
438        # Check the parameters.
439        self.assertAllCloseAccordingToType(
440            np.array([
441                1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) -
442                (0.5 * (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) +
443                 (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1e-5))),
444                2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) -
445                (0.5 * (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) +
446                 (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1e-5)))
447            ]), self.evaluate(var0))
448
449        self.assertAllCloseAccordingToType(
450            np.array([
451                3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) -
452                (0.5 * (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) +
453                 (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5))),
454                4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) -
455                (0.5 * (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) +
456                 (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5)))
457            ]), self.evaluate(var1))
458
459  def testCallableParams(self):
460    with context.eager_mode():
461      for dtype in [dtypes.half, dtypes.float32]:
462        var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
463        var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
464        grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
465        grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
466
467        learning_rate = lambda: 2.0
468        decay = lambda: 0.9
469        momentum = lambda: 0.0
470        epsilon = lambda: 1.0
471        opt = rmsprop.RMSPropOptimizer(learning_rate, decay, momentum, epsilon)
472
473        # Fetch params to validate initial values
474        self.assertAllClose([1.0, 2.0], self.evaluate(var0))
475        self.assertAllClose([3.0, 4.0], self.evaluate(var1))
476        # Step 1: the rms accumulators where 1. So we should see a normal
477        # update: v -= grad * learning_rate
478        opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
479        # Check the parameters.
480        self.assertAllCloseAccordingToType(
481            np.array([
482                1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)),
483                2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0))
484            ]), self.evaluate(var0))
485        self.assertAllCloseAccordingToType(
486            np.array([
487                3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)),
488                4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0))
489            ]), self.evaluate(var1))
490        # Step 2: the root mean square accumulators contain the previous update.
491        opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
492        # Check the parameters.
493        self.assertAllCloseAccordingToType(
494            np.array([
495                1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)) -
496                (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1.0)),
497                2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)) -
498                (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1.0))
499            ]), self.evaluate(var0))
500        self.assertAllCloseAccordingToType(
501            np.array([
502                3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)) -
503                (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5 + 1.0)),
504                4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)) -
505                (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5 + 1.0))
506            ]), self.evaluate(var1))
507
508
# Hand control to the test runner when this file is executed as a script.
if __name__ == "__main__":
  test.main()
511