# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for optimizers with weight decay."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.contrib.opt.python.training import weight_decay_optimizers
24from tensorflow.python.eager import context
25from tensorflow.python.framework import constant_op
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import ops
28from tensorflow.python.framework import test_util
29from tensorflow.python.ops import resource_variable_ops
30from tensorflow.python.ops import variables
31from tensorflow.python.platform import test
32from tensorflow.python.training import adam
33
# Decoupled weight decay rate used by the numpy reference updates and by the
# optimizers under test.
WEIGHT_DECAY = 0.01


def adamw_update_numpy(param, g_t, t, m, v, lr=0.001, beta1=0.9,
                       beta2=0.999, epsilon=1e-8):
  """Numpy reference for one AdamW step with decoupled weight decay."""
  lr_t = lr * np.sqrt(1 - beta2**t) / (1 - beta1**t)

  m_t = beta1 * m + (1 - beta1) * g_t
  v_t = beta2 * v + (1 - beta2) * g_t * g_t

  param_t = (param - lr_t * m_t / (np.sqrt(v_t) + epsilon) -
             (param * WEIGHT_DECAY))
  return param_t, m_t, v_t


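# Illustration only, not used by the tests: a plain Adam reference step, i.e.
# adamw_update_numpy above without the final decoupled `param * WEIGHT_DECAY`
# term. The helper name is ours, not part of TensorFlow.
def _adam_update_numpy_no_decay(param, g_t, t, m, v, lr=0.001, beta1=0.9,
                                beta2=0.999, epsilon=1e-8):
  lr_t = lr * np.sqrt(1 - beta2**t) / (1 - beta1**t)
  m_t = beta1 * m + (1 - beta1) * g_t
  v_t = beta2 * v + (1 - beta2) * g_t * g_t
  # Identical to the AdamW reference except that no weight decay is subtracted.
  param_t = param - lr_t * m_t / (np.sqrt(v_t) + epsilon)
  return param_t, m_t, v_t

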
def momentumw_update_numpy(param, g_t, m, lr=0.001, momentum=0.9, **_):
  """Numpy reference for one MomentumW step with decoupled weight decay."""
  # v and t are not needed for the momentum optimizer.
  m = momentum * m + g_t
  param_t = param - lr * m - param * WEIGHT_DECAY
  return param_t, m, None


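# Illustration only, not exercised by the tests: both numpy references above
# share the keyword interface used by doTest below, which is why
# momentumw_update_numpy accepts and ignores the extra `t` and `v` arguments.
# With the defaults above, one MomentumW step on param=1.0, g_t=0.1, m=0.0
# gives m_t = 0.9 * 0.0 + 0.1 = 0.1 and
# param_t = 1.0 - 0.001 * 0.1 - 1.0 * 0.01 = 0.9899. The helper name is ours.
def _example_momentumw_reference_step():
  return momentumw_update_numpy(np.array([1.0]), np.array([0.1]),
                                t=1, m=0.0, v=0.0)

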
class WeightDecayOptimizerTest(test.TestCase):

  def doTest(self, optimizer, update_fn, optimizer_name, slot_name,
             use_resource=False, do_sparse=False):
    """Runs three steps of `optimizer` and checks them against `update_fn`."""
    for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
      with self.session(graph=ops.Graph()):
        # Initialize variables for numpy implementation.
        m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
        var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
        grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
        var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
        grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)

        if use_resource:
          var0 = resource_variable_ops.ResourceVariable(
              var0_np, name="var0_%d" % i)
          var1 = resource_variable_ops.ResourceVariable(
              var1_np, name="var1_%d" % i)
        else:
          var0 = variables.Variable(var0_np)
          var1 = variables.Variable(var1_np)

        if do_sparse:
          grads0_np_indices = np.array([0, 1], dtype=np.int32)
          grads0 = ops.IndexedSlices(constant_op.constant(grads0_np),
                                     constant_op.constant(grads0_np_indices),
                                     constant_op.constant([2]))
          grads1_np_indices = np.array([0, 1], dtype=np.int32)
          grads1 = ops.IndexedSlices(constant_op.constant(grads1_np),
                                     constant_op.constant(grads1_np_indices),
                                     constant_op.constant([2]))
        else:
          grads0 = constant_op.constant(grads0_np)
          grads1 = constant_op.constant(grads1_np)

        opt = optimizer()
        update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))

        if not context.executing_eagerly():
          with ops.Graph().as_default():
            # Shouldn't return non-slot variables from other graphs.
            self.assertEqual(0, len(opt.variables()))
          self.evaluate(variables.global_variables_initializer())
          # Fetch params to validate initial values
          self.assertAllClose([1.0, 2.0], self.evaluate(var0))
          self.assertAllClose([3.0, 4.0], self.evaluate(var1))

        # Run 3 steps of the optimizer
        for t in range(1, 4):
          if not context.executing_eagerly():
            self.evaluate(update)
          elif t > 1:
            opt.apply_gradients(zip([grads0, grads1], [var0, var1]))

          var0_np, m0, v0 = update_fn(var0_np, grads0_np, t=t, m=m0, v=v0)
          var1_np, m1, v1 = update_fn(var1_np, grads1_np, t=t, m=m1, v=v1)

          # Validate updated params
          self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
          self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
          if use_resource:
            self.assertEqual("var0_%d/%s:0" % (i, optimizer_name),
                             opt.get_slot(var=var0, name=slot_name).name)


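# Illustration only, not exercised by the tests: a minimal sketch of the
# graph-mode API that doTest drives above. The helper name and the constant
# values are ours, not part of TensorFlow.
def _example_adamw_apply_gradients():
  var = variables.Variable([1.0, 2.0])
  grad = constant_op.constant([0.1, 0.1])
  opt = weight_decay_optimizers.AdamWOptimizer(
      weight_decay=WEIGHT_DECAY, learning_rate=0.001)
  return opt.apply_gradients([(grad, var)])

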
class AdamWOptimizerTest(WeightDecayOptimizerTest):

  @staticmethod
  def get_optimizer():
    return weight_decay_optimizers.AdamWOptimizer(WEIGHT_DECAY)

  def testSparse(self):
    self.doTest(self.get_optimizer, adamw_update_numpy, "AdamW", "m",
                use_resource=False, do_sparse=True)

  def testResourceSparse(self):
    self.doTest(self.get_optimizer, adamw_update_numpy, "AdamW", "m",
                use_resource=True, do_sparse=True)

  def testBasic(self):
    self.doTest(self.get_optimizer, adamw_update_numpy, "AdamW", "m",
                use_resource=False)

  @test_util.run_in_graph_and_eager_modes(reset_test=True)
  def testResourceBasic(self):
    self.doTest(self.get_optimizer, adamw_update_numpy, "AdamW", "m",
                use_resource=True)


class MomentumWOptimizerTest(WeightDecayOptimizerTest):

  @staticmethod
  def get_optimizer():
    return weight_decay_optimizers.MomentumWOptimizer(WEIGHT_DECAY, 0.001, 0.9)

  def testSparse(self):
    self.doTest(self.get_optimizer, momentumw_update_numpy, "MomentumW",
                "momentum", use_resource=False, do_sparse=True)

  def testResourceSparse(self):
    self.doTest(self.get_optimizer, momentumw_update_numpy, "MomentumW",
                "momentum", use_resource=True, do_sparse=True)

  def testBasic(self):
    self.doTest(self.get_optimizer, momentumw_update_numpy, "MomentumW",
                "momentum", use_resource=False)

  @test_util.run_in_graph_and_eager_modes(reset_test=True)
  def testResourceBasic(self):
    self.doTest(self.get_optimizer, momentumw_update_numpy, "MomentumW",
                "momentum", use_resource=True)


class ExtendWithWeightDecayTest(WeightDecayOptimizerTest):

  @staticmethod
  def get_optimizer():
    adamw = weight_decay_optimizers.extend_with_decoupled_weight_decay(
        adam.AdamOptimizer)
    return adamw(WEIGHT_DECAY)

  def testBasic(self):
    self.doTest(self.get_optimizer, adamw_update_numpy, "Adam", "m",
                use_resource=False)

  @test_util.run_in_graph_and_eager_modes(reset_test=True)
  def testResourceBasic(self):
    self.doTest(self.get_optimizer, adamw_update_numpy, "Adam", "m",
                use_resource=True)


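# Illustration only, not exercised by the tests:
# extend_with_decoupled_weight_decay can wrap other optimizers as well;
# wrapping MomentumOptimizer as sketched below should behave like the
# MomentumWOptimizer used above. The helper name and local import are ours.
def _example_extend_momentum_with_weight_decay():
  # Local import, used only by this illustrative helper.
  from tensorflow.python.training import momentum
  momentum_w = weight_decay_optimizers.extend_with_decoupled_weight_decay(
      momentum.MomentumOptimizer)
  return momentum_w(WEIGHT_DECAY, learning_rate=0.001, momentum=0.9)

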
if __name__ == "__main__":
  test.main()