# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Adam."""

import numpy as np

from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import adam


def adam_update_numpy(param,
                      g_t,
                      t,
                      m,
                      v,
                      alpha=0.001,
                      beta1=0.9,
                      beta2=0.999,
                      epsilon=1e-8):
  """Pure-numpy reference for one Adam step; returns (param_t, m_t, v_t)."""
  alpha_t = alpha * np.sqrt(1 - beta2**t) / (1 - beta1**t)

  m_t = beta1 * m + (1 - beta1) * g_t
  v_t = beta2 * v + (1 - beta2) * g_t * g_t

  param_t = param - alpha_t * m_t / (np.sqrt(v_t) + epsilon)
  return param_t, m_t, v_t
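
# Illustrative sketch (hypothetical values, not used by the tests below): with
# the default hyperparameters, the very first Adam step moves each parameter by
# roughly alpha, because m_t / (sqrt(v_t) + epsilon) is approximately sign(g_t)
# when m and v start at zero.
#
#   p, m, v = adam_update_numpy(
#       np.array([1.0, 2.0]), np.array([0.1, 0.1]), t=1, m=0.0, v=0.0)
#   # p is now approximately [0.999, 1.999].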


class AdamOptimizerTest(xla_test.XLATestCase):

  def testBasic(self):
    for dtype in self.float_types | self.complex_types:
      # TODO: test fails for float16 due to excessive precision requirements.
      if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]:
        continue
      with self.session(), self.test_scope():
        variable_scope.get_variable_scope().set_use_resource(True)

        # Initialize variables for numpy implementation.
        m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
        var0_np = np.array([1.0, 2.0], dtype=dtype)
        grads0_np = np.array([0.1, 0.1], dtype=dtype)
        var1_np = np.array([3.0, 4.0], dtype=dtype)
        grads1_np = np.array([0.01, 0.01], dtype=dtype)

        var0 = resource_variable_ops.ResourceVariable(var0_np)
        var1 = resource_variable_ops.ResourceVariable(var1_np)
        grads0 = array_ops.placeholder(dtype)
        grads1 = array_ops.placeholder(dtype)
        opt = adam.AdamOptimizer()
        update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
        self.evaluate(variables.global_variables_initializer())

        # Fetch params to validate initial values
        self.assertAllClose([1.0, 2.0], self.evaluate(var0))
        self.assertAllClose([3.0, 4.0], self.evaluate(var1))

        beta1_power, beta2_power = opt._get_beta_accumulators()

        # Run 3 steps of Adam
        for t in range(1, 4):
          self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power))
          self.assertAllCloseAccordingToType(0.999**t,
                                             self.evaluate(beta2_power))
          update.run(feed_dict={grads0: grads0_np, grads1: grads1_np})

          var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
          var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)

          # Validate updated params
          self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
          self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))

  def testTensorLearningRate(self):
    for dtype in self.float_types | self.complex_types:
      # TODO: test fails for float16 due to excessive precision requirements.
      if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]:
        continue
      with self.session(), self.test_scope():
        variable_scope.get_variable_scope().set_use_resource(True)

        # Initialize variables for numpy implementation.
        m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
        var0_np = np.array([1.0, 2.0], dtype=dtype)
        grads0_np = np.array([0.1, 0.1], dtype=dtype)
        var1_np = np.array([3.0, 4.0], dtype=dtype)
        grads1_np = np.array([0.01, 0.01], dtype=dtype)

        var0 = resource_variable_ops.ResourceVariable(var0_np)
        var1 = resource_variable_ops.ResourceVariable(var1_np)
        grads0 = array_ops.placeholder(dtype)
        grads1 = array_ops.placeholder(dtype)
        opt = adam.AdamOptimizer(constant_op.constant(0.001))
        update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
        self.evaluate(variables.global_variables_initializer())

        # Fetch params to validate initial values
        self.assertAllClose([1.0, 2.0], self.evaluate(var0))
        self.assertAllClose([3.0, 4.0], self.evaluate(var1))

        beta1_power, beta2_power = opt._get_beta_accumulators()

        # Run 3 steps of Adam
        for t in range(1, 4):
          self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power))
          self.assertAllCloseAccordingToType(0.999**t,
                                             self.evaluate(beta2_power))
          update.run(feed_dict={grads0: grads0_np, grads1: grads1_np})

          var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
          var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)

          # Validate updated params
          self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
          self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))

  def testSharing(self):
    for dtype in self.float_types | self.complex_types:
      # TODO: test fails for float16 due to excessive precision requirements.
      if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]:
        continue
      with self.session(), self.test_scope():
        variable_scope.get_variable_scope().set_use_resource(True)

        # Initialize variables for numpy implementation.
        m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
        var0_np = np.array([1.0, 2.0], dtype=dtype)
        grads0_np = np.array([0.1, 0.1], dtype=dtype)
        var1_np = np.array([3.0, 4.0], dtype=dtype)
        grads1_np = np.array([0.01, 0.01], dtype=dtype)

        var0 = resource_variable_ops.ResourceVariable(var0_np)
        var1 = resource_variable_ops.ResourceVariable(var1_np)
        grads0 = array_ops.placeholder(dtype)
        grads1 = array_ops.placeholder(dtype)
        opt = adam.AdamOptimizer()
        # Both update ops come from the same optimizer instance, so they share
        # the m/v slot variables and the beta-power accumulators.
        update1 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
        update2 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
        self.evaluate(variables.global_variables_initializer())

        beta1_power, beta2_power = opt._get_beta_accumulators()

        # Fetch params to validate initial values
        self.assertAllClose([1.0, 2.0], self.evaluate(var0))
        self.assertAllClose([3.0, 4.0], self.evaluate(var1))

        # Run 3 steps of intertwined Adam1 and Adam2; because the two update
        # ops share state, the result must match a single numpy trajectory.
        for t in range(1, 4):
          self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power))
          self.assertAllCloseAccordingToType(0.999**t,
                                             self.evaluate(beta2_power))
          if t % 2 == 0:
            update1.run(feed_dict={grads0: grads0_np, grads1: grads1_np})
          else:
            update2.run(feed_dict={grads0: grads0_np, grads1: grads1_np})

          var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
          var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)

          # Validate updated params
          self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
          self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))


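# Running this file directly invokes test.main() below. Within the TensorFlow
# source tree, these XLA compiler tests are normally driven through the Bazel
# targets generated from tensorflow/compiler/tests (the exact target names
# depend on the configured backends).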
if __name__ == "__main__":
  test.main()