# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for optimizers with weight decay."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.contrib.opt.python.training import weight_decay_optimizers
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import adam

WEIGHT_DECAY = 0.01


def adamw_update_numpy(param, g_t, t, m, v, lr=0.001, beta1=0.9,
                       beta2=0.999, epsilon=1e-8):
  lr_t = lr * np.sqrt(1 - beta2**t) / (1 - beta1**t)

  m_t = beta1 * m + (1 - beta1) * g_t
  v_t = beta2 * v + (1 - beta2) * g_t * g_t

  param_t = (param - lr_t * m_t / (np.sqrt(v_t) + epsilon) -
             (param * WEIGHT_DECAY))
  return param_t, m_t, v_t


def momentumw_update_numpy(param, g_t, m, lr=0.001, momentum=0.9, **_):
  # v, t are not needed for momentum optimizer
  m = momentum * m + g_t
  param_t = param - lr * m - param * WEIGHT_DECAY
  return param_t, m, None


class WeightDecayOptimizerTest(test.TestCase):

  def doTest(self, optimizer, update_fn, optimizer_name, slot_name,
             use_resource=False, do_sparse=False):
    for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
      with self.session(graph=ops.Graph()):
        # Initialize variables for numpy implementation.
        m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
        var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
        grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
        var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
        grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)

        if use_resource:
          var0 = resource_variable_ops.ResourceVariable(
              var0_np, name="var0_%d" % i)
          var1 = resource_variable_ops.ResourceVariable(
              var1_np, name="var1_%d" % i)
        else:
          var0 = variables.Variable(var0_np)
          var1 = variables.Variable(var1_np)

        if do_sparse:
          grads0_np_indices = np.array([0, 1], dtype=np.int32)
          grads0 = ops.IndexedSlices(constant_op.constant(grads0_np),
                                     constant_op.constant(grads0_np_indices),
                                     constant_op.constant([2]))
          grads1_np_indices = np.array([0, 1], dtype=np.int32)
          grads1 = ops.IndexedSlices(constant_op.constant(grads1_np),
                                     constant_op.constant(grads1_np_indices),
                                     constant_op.constant([2]))
        else:
          grads0 = constant_op.constant(grads0_np)
          grads1 = constant_op.constant(grads1_np)

        opt = optimizer()
        update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))

        if not context.executing_eagerly():
          with ops.Graph().as_default():
            # Shouldn't return non-slot variables from other graphs.
            self.assertEqual(0, len(opt.variables()))
          self.evaluate(variables.global_variables_initializer())
          # Fetch params to validate initial values
          self.assertAllClose([1.0, 2.0], self.evaluate(var0))
          self.assertAllClose([3.0, 4.0], self.evaluate(var1))

        # Run 3 steps of the optimizer
        for t in range(1, 4):
          if not context.executing_eagerly():
            self.evaluate(update)
          elif t > 1:
            opt.apply_gradients(zip([grads0, grads1], [var0, var1]))

          var0_np, m0, v0 = update_fn(var0_np, grads0_np, t=t, m=m0, v=v0)
          var1_np, m1, v1 = update_fn(var1_np, grads1_np, t=t, m=m1, v=v1)

          # Validate updated params
          self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
          self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
          if use_resource:
            self.assertEqual("var0_%d/%s:0" % (i, optimizer_name),
                             opt.get_slot(var=var0, name=slot_name).name)


class AdamWOptimizerTest(WeightDecayOptimizerTest):

  @staticmethod
  def get_optimizer():
    return weight_decay_optimizers.AdamWOptimizer(WEIGHT_DECAY)

  def testSparse(self):
    self.doTest(self.get_optimizer, adamw_update_numpy, "AdamW", "m",
                use_resource=False, do_sparse=True)

  def testResourceSparse(self):
    self.doTest(self.get_optimizer, adamw_update_numpy, "AdamW", "m",
                use_resource=True, do_sparse=True)

  def testBasic(self):
    self.doTest(self.get_optimizer, adamw_update_numpy, "AdamW", "m",
                use_resource=False)

  @test_util.run_in_graph_and_eager_modes(reset_test=True)
  def testResourceBasic(self):
    self.doTest(self.get_optimizer, adamw_update_numpy, "AdamW", "m",
                use_resource=True)


class MomentumWOptimizerTest(WeightDecayOptimizerTest):

  @staticmethod
  def get_optimizer():
    return weight_decay_optimizers.MomentumWOptimizer(WEIGHT_DECAY, 0.001, 0.9)

  def testSparse(self):
    self.doTest(self.get_optimizer, momentumw_update_numpy, "MomentumW",
                "momentum", use_resource=False, do_sparse=True)

  def testResourceSparse(self):
    self.doTest(self.get_optimizer, momentumw_update_numpy, "MomentumW",
                "momentum", use_resource=True, do_sparse=True)

  def testBasic(self):
    self.doTest(self.get_optimizer, momentumw_update_numpy, "MomentumW",
                "momentum", use_resource=False)

  @test_util.run_in_graph_and_eager_modes(reset_test=True)
  def testResourceBasic(self):
    self.doTest(self.get_optimizer, momentumw_update_numpy, "MomentumW",
                "momentum", use_resource=True)


class ExtendWithWeightDecayTest(WeightDecayOptimizerTest):

  @staticmethod
  def get_optimizer():
    adamw = weight_decay_optimizers.extend_with_decoupled_weight_decay(
        adam.AdamOptimizer)
    return adamw(WEIGHT_DECAY)

  def testBasic(self):
    self.doTest(self.get_optimizer, adamw_update_numpy, "Adam", "m",
                use_resource=False)

  @test_util.run_in_graph_and_eager_modes(reset_test=True)
  def testResourceBasic(self):
    self.doTest(self.get_optimizer, adamw_update_numpy, "Adam", "m",
                use_resource=True)


if __name__ == "__main__":
  test.main()