# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functional test for OptimizerV2."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

from absl.testing import parameterized
import numpy as np

from tensorflow.python import keras
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras import backend
from tensorflow.python.keras import callbacks
from tensorflow.python.keras import combinations
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import losses
from tensorflow.python.keras import optimizer_v1
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import input_layer
from tensorflow.python.keras.engine import sequential
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.layers import core
from tensorflow.python.keras.optimizer_v2 import adadelta
from tensorflow.python.keras.optimizer_v2 import adagrad
from tensorflow.python.keras.optimizer_v2 import adam
from tensorflow.python.keras.optimizer_v2 import adamax
from tensorflow.python.keras.optimizer_v2 import ftrl
from tensorflow.python.keras.optimizer_v2 import gradient_descent
from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
from tensorflow.python.keras.optimizer_v2 import nadam
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.keras.optimizer_v2 import rmsprop
from tensorflow.python.keras.utils import np_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import momentum
from tensorflow.python.training import training_util
from tensorflow.python.training.tracking import util as trackable_utils


_DATA_TYPES = [dtypes.half, dtypes.float32, dtypes.float64]
# TODO(b/141710709): complex support in NVCC and ROCM.
if (not test_util.IsBuiltWithNvcc() and not test.is_built_with_rocm()):
  _DATA_TYPES += [dtypes.complex64, dtypes.complex128]


class OptimizerTest(test.TestCase, parameterized.TestCase):
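  """Tests for core OptimizerV2 functionality."""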

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testBasic(self):
    for dtype in _DATA_TYPES:
      with testing_utils.use_gpu():
        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
        var1 = variables.Variable([3.0, 4.0], dtype=dtype)
        loss = lambda: 5 * var0 + 3 * var1  # pylint: disable=cell-var-from-loop
        sgd = gradient_descent.SGD(3.0)

        self.evaluate(variables.global_variables_initializer())
        # Fetch params to validate initial values
        self.assertAllClose([1.0, 2.0], self.evaluate(var0))
        self.assertAllClose([3.0, 4.0], self.evaluate(var1))
        # Run 1 step of sgd through optimizer
        opt_op = sgd.minimize(loss, var_list=[var0, var1])
        self.evaluate(variables.global_variables_initializer())
        self.evaluate(opt_op)
        # Validate updated params
        self.assertAllClose([-14., -13.], self.evaluate(var0))
        self.assertAllClose([-6., -5.], self.evaluate(var1))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testAdaptiveLearningRate(self):
    for dtype in _DATA_TYPES:
      with self.test_session():
        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
        var1 = variables.Variable([3.0, 4.0], dtype=dtype)

        def loss():
          return 5 * var0 + 3 * var1  # pylint: disable=cell-var-from-loop

        sgd = gradient_descent.SGD(1.0)

        self.evaluate(variables.global_variables_initializer())
        # Fetch params to validate initial values
        self.assertAllClose([1.0, 2.0], self.evaluate(var0))
        self.assertAllClose([3.0, 4.0], self.evaluate(var1))
        # Run 1 step of sgd through optimizer
        opt_op = sgd.minimize(loss, [var0, var1])
        self.evaluate(variables.global_variables_initializer())
        self.evaluate(opt_op)
        # Validate updated params
        # var0 = [1., 2.] - 1.0 * [5, 5]
        self.assertAllClose([-4., -3.], self.evaluate(var0))
        # var1 = [3., 4.] - 1.0 * [3, 3]
        self.assertAllClose([0., 1.], self.evaluate(var1))

        sgd.learning_rate = 0.5
        if context.executing_eagerly():
          sgd.minimize(loss, [var0, var1])
        else:
          self.evaluate(opt_op)
        # Validate updated params
        # var0 = [-4., -3.] - 0.5 * [5, 5]
        self.assertAllClose([-6.5, -5.5], self.evaluate(var0))
        # var1 = [0., 1.] - 0.5 * [3, 3]
        self.assertAllClose([-1.5, -0.5], self.evaluate(var1))

        sgd.learning_rate = learning_rate_schedule.InverseTimeDecay(
            0.5, decay_steps=1.0, decay_rate=0.5)
        if context.executing_eagerly():
          sgd.minimize(loss, [var0, var1])
        else:
          self.evaluate(opt_op)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testPrecomputedGradient(self):
    for dtype in _DATA_TYPES:
      with testing_utils.use_gpu():
        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
        var1 = variables.Variable([3.0, 4.0], dtype=dtype)
        loss = lambda: 5 * var0 + 3 * var1  # pylint: disable=cell-var-from-loop
        grad_loss = constant_op.constant([42, -42], dtype=dtype)
        sgd = gradient_descent.SGD(3.0)

        self.evaluate(variables.global_variables_initializer())
        # Fetch params to validate initial values
        self.assertAllClose([1.0, 2.0], self.evaluate(var0))
        self.assertAllClose([3.0, 4.0], self.evaluate(var1))
        # Run 1 step of sgd through optimizer
        opt_op = sgd.minimize(loss, var_list=[var0, var1], grad_loss=grad_loss)
        self.evaluate(variables.global_variables_initializer())
        self.evaluate(opt_op)
        # Validate updated params
        self.assertAllClose([1.0 - 3 * 5 * 42.0, 2.0 - 3 * 5 * (-42.0)],
                            self.evaluate(var0))
        self.assertAllClose([3.0 - 3 * 3 * 42.0, 4.0 - 3 * 3 * (-42.0)],
                            self.evaluate(var1))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testNoGradients(self):
    for dtype in _DATA_TYPES:
      with testing_utils.use_gpu():
        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
        var1 = variables.Variable([3.0, 4.0], dtype=dtype)
        loss = lambda: 5 * var0  # pylint: disable=cell-var-from-loop
        sgd_op = gradient_descent.SGD(3.0)
        with self.assertRaisesRegex(ValueError, 'No gradients'):
          # var1 has no gradient
          sgd_op.minimize(loss, var_list=[var1])

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testNoGradientsForAnyVariables_Minimize(self):
    for dtype in _DATA_TYPES:
      with testing_utils.use_gpu():
        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
        var1 = variables.Variable([3.0, 4.0], dtype=dtype)
        loss = lambda: constant_op.constant(5.0)

        sgd_op = gradient_descent.SGD(3.0)
        with self.assertRaisesRegex(ValueError,
                                    'No gradients provided for any variable'):
          sgd_op.minimize(loss, var_list=[var0, var1])

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testNoGradientsForAnyVariables_ApplyGradients(self):
    for dtype in _DATA_TYPES:
      with testing_utils.use_gpu():
        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
        var1 = variables.Variable([3.0, 4.0], dtype=dtype)
        sgd_op = gradient_descent.SGD(3.0)
        with self.assertRaisesRegex(ValueError,
                                    'No gradients provided for any variable'):
          sgd_op.apply_gradients([(None, var0), (None, var1)])

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testGradientsAsVariables(self):
    for i, dtype in enumerate(_DATA_TYPES):
      with testing_utils.use_gpu():
        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
        var1 = variables.Variable([3.0, 4.0], dtype=dtype)
        loss = lambda: 5 * var0 + 3 * var1  # pylint: disable=cell-var-from-loop

        sgd = gradient_descent.SGD(3.0)
        grads_and_vars = sgd._compute_gradients(loss, [var0, var1])
        # Convert gradients to tf.Variables
        converted_grads = [
            variables.Variable(
                array_ops.zeros([2], dtype), name='c_%d_%d' % (i, j))
            for j, gv in enumerate(grads_and_vars)
        ]
        convert_ops = [
            state_ops.assign(converted_grads[j], gv[0])
            for j, gv in enumerate(grads_and_vars)
        ]

        # Run convert_ops to copy the computed gradients into the variables.
        self.evaluate(variables.global_variables_initializer())
        self.evaluate(convert_ops)
        # Fetch params to validate initial values
        self.assertAllClose([1.0, 2.0], self.evaluate(var0))
        self.assertAllClose([3.0, 4.0], self.evaluate(var1))

        # Run 1 step of sgd through optimizer
        converted_grads_and_vars = list(zip(converted_grads, [var0, var1]))
        opt_op = sgd.apply_gradients(converted_grads_and_vars)
        self.evaluate(variables.global_variables_initializer())
        self.evaluate(convert_ops)
        self.evaluate(opt_op)

        # Validate updated params
        self.assertAllClose([-14., -13.], self.evaluate(var0))
        self.assertAllClose([-6., -5.], self.evaluate(var1))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testComputeGradientsWithTensors(self):
    with testing_utils.use_gpu():
      x = ops.convert_to_tensor_v2_with_dispatch(1.0)

      def f():
        return x * x

      sgd = gradient_descent.SGD(3.0)
      grads_and_vars = sgd._compute_gradients(f, [x])
      self.assertLen(grads_and_vars, 1)
      grad, x_as_var = grads_and_vars[0]
      self.assertIs(x, x_as_var)
      self.assertEqual(2.0, self.evaluate(grad))

      with self.assertRaises(NotImplementedError):
        sgd.apply_gradients(grads_and_vars)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testConstraint(self):
    constraint_01 = lambda x: clip_ops.clip_by_value(x, -0.1, 0.)
    constraint_0 = lambda x: clip_ops.clip_by_value(x, 0., 1.)
    with testing_utils.use_gpu():
      var0 = variables.Variable([1.0, 2.0],
                                constraint=constraint_01)
      var1 = variables.Variable([3.0, 4.0],
                                constraint=constraint_0)
      loss = lambda: 5 * var0 + 3 * var1
      sgd = gradient_descent.SGD(3.0)

      self.evaluate(variables.global_variables_initializer())
      # Fetch params to validate initial values
      self.assertAllClose([1.0, 2.0], self.evaluate(var0))
      self.assertAllClose([3.0, 4.0], self.evaluate(var1))
      # Run 1 step of sgd through optimizer
      opt_op = sgd.minimize(loss, var_list=[var0, var1])
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(opt_op)
      # Validate updated params
      self.assertAllClose([-0.1, -0.1], self.evaluate(var0))
      self.assertAllClose([0., 0.], self.evaluate(var1))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testIterationWithoutMinimize(self):
    with testing_utils.use_gpu():
      sgd = gradient_descent.SGD(3.0)
      self.evaluate(sgd.iterations.initializer)
      self.assertEqual(0, self.evaluate(sgd.iterations))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testConfig(self):
    with testing_utils.use_gpu():
      opt = gradient_descent.SGD(learning_rate=1.0)
      config = opt.get_config()
      opt2 = gradient_descent.SGD.from_config(config)
      lr = opt._get_hyper('learning_rate')
      lr2 = opt2._get_hyper('learning_rate')
      self.evaluate(variables.global_variables_initializer())
      # assert both are equal float values.
      self.assertEqual(self.evaluate(lr), self.evaluate(lr2))
      var0 = variables.Variable([[1.0], [2.0]], dtype=dtypes.float32)
      loss = lambda: 3 * var0
      # learning rate variable created when calling minimize.
      opt.minimize(loss, [var0])
      opt3 = gradient_descent.SGD.from_config(config)
      lr3 = opt3._get_hyper('learning_rate')
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(self.evaluate(lr), self.evaluate(lr3))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testConfigWithLearningRateDecay(self):
    with testing_utils.use_gpu():
      var0 = variables.Variable([[1.0], [2.0]], dtype=dtypes.float32)
      for decay_schedule in [
          learning_rate_schedule.InverseTimeDecay(
              0.5, decay_steps=1.0, decay_rate=0.1),
          learning_rate_schedule.PiecewiseConstantDecay(
              [5], [1., .5])
      ]:
        step = 10
        opt = gradient_descent.SGD(decay_schedule)
        config = opt.get_config()
        opt2 = gradient_descent.SGD.from_config(config)
        # assert both are equal float values.
        self.assertAllEqual(
            decay_schedule(step),
            opt._get_hyper('learning_rate')(step))
        self.assertAllEqual(
            decay_schedule(step),
            opt2._get_hyper('learning_rate')(step))
        loss = lambda: 3 * var0
        # learning rate variable is created when calling minimize.
        opt.minimize(loss, [var0])
        self.evaluate(variables.global_variables_initializer())
        config = opt.get_config()
        opt3 = gradient_descent.SGD.from_config(config)
        self.assertAllEqual(
            self.evaluate(opt._get_hyper('learning_rate')(step)),
            opt3._get_hyper('learning_rate')(step))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testGradClipValue(self):
    with testing_utils.use_gpu():
      var = variables.Variable([1.0, 2.0])
      loss = lambda: 3 * var
      opt = gradient_descent.SGD(learning_rate=1.0, clipvalue=1.0)
      opt_op = opt.minimize(loss, [var])
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(opt_op)
      self.assertAllClose([0., 1.], self.evaluate(var))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testGradClipNorm(self):
    with testing_utils.use_gpu():
      var = variables.Variable([1.0])
      loss = lambda: 3 * var
      opt = gradient_descent.SGD(learning_rate=1.0, clipnorm=1.0)
      opt_op = opt.minimize(loss, [var])
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(opt_op)
      self.assertAllClose([0.], self.evaluate(var))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testGradGlobalClipNorm(self):
    with testing_utils.use_gpu():
      # l2 norm is 5.0
      var1 = variables.Variable([1.0])
      var2 = variables.Variable([2.0])
      loss = lambda: 3 * var1 + 4 * var2
      opt = gradient_descent.SGD(learning_rate=1.0, global_clipnorm=2.0)
      opt_op = opt.minimize(loss, [var1, var2])
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(opt_op)
      # grad1 = 3.0 * 2.0 / 5.0 = 1.2
      self.assertAllClose([-.2], self.evaluate(var1))
      # grad2 = 4.0 * 2.0 / 5.0 = 1.6
      self.assertAllClose([.4], self.evaluate(var2))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testInvalidClipNorm(self):
    with self.assertRaisesRegex(ValueError, '>= 0'):
      gradient_descent.SGD(learning_rate=1.0, clipnorm=-1.0)

  @combinations.generate(
      combinations.combine(
          mode=['graph', 'eager'],
          clip_type=['clipnorm', 'global_clipnorm', 'clipvalue']))
  def testConfigWithClipping(self, clip_type):
    opt = gradient_descent.SGD(learning_rate=1.0, **{clip_type: 2.0})
    config = opt.get_config()
    opt = gradient_descent.SGD.from_config(config)
    self.assertEqual(getattr(opt, clip_type), 2.0)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testInvalidKwargs(self):
    with self.assertRaisesRegex(TypeError, 'Unexpected keyword argument'):
      gradient_descent.SGD(learning_rate=1.0, invalidkwargs=1.0)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testWeights(self):
    with testing_utils.use_gpu():
      opt1 = adam.Adam(learning_rate=1.0)
      var1 = variables.Variable([1.0, 2.0], dtype=dtypes.float32)
      loss1 = lambda: 3 * var1
      opt_op_1 = opt1.minimize(loss1, [var1])
      self.evaluate(variables.global_variables_initializer())
      config = opt1.get_config()
      opt2 = adam.Adam.from_config(config)
      var2 = variables.Variable([1.0, 2.0], dtype=dtypes.float32)
      loss2 = lambda: 3 * var2
      opt_op_2 = opt2.minimize(loss2, [var2])
      weights = opt1.get_weights()

      # Assert that after set_weights both variables update to the same value.
      self.evaluate(variables.global_variables_initializer())
      opt2.set_weights(weights)
      self.evaluate([opt_op_1, opt_op_2])
      self.assertAllClose(self.evaluate(var1), self.evaluate(var2))
      self.assertEqual(1, self.evaluate(opt1.iterations))
      self.assertEqual(1, self.evaluate(opt2.iterations))

      var3 = variables.Variable([1.0, 2.0, 3.0], dtype=dtypes.float32)
      var4 = variables.Variable([4.0, 5.0, 6.0], dtype=dtypes.float32)
      loss3 = lambda: 3 * var3 + 5 * var4
      opt_op_3 = opt1.minimize(loss3, [var3, var4])

      # Assert set_weights raises ValueError since the weight lists differ.
      self.evaluate(variables.global_variables_initializer())
      weights = opt1.get_weights()
      with self.assertRaisesRegex(ValueError, 'but the optimizer was'):
        opt2.set_weights(weights)

      # Assert that after set_weights both variable sets update identically.
      var5 = variables.Variable([1.0, 2.0, 3.0], dtype=dtypes.float32)
      var6 = variables.Variable([4.0, 5.0, 6.0], dtype=dtypes.float32)
      loss4 = lambda: 3 * var5 + 5 * var6
      opt_op_4 = opt2.minimize(loss4, [var5, var6])
      self.evaluate(variables.global_variables_initializer())
      opt2.set_weights(weights)
      self.evaluate([opt_op_3, opt_op_4])
      self.assertAllClose(
          self.evaluate([var3, var4]), self.evaluate([var5, var6]))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testGettingHyperParameters(self):
    with self.test_session():
      opt = adam.Adam(learning_rate=1.0)
      var = variables.Variable([1.0, 2.0], dtype=dtypes.float32)
      loss = lambda: 3 * var
      opt_op = opt.minimize(loss, [var])
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(opt_op)

      lr = self.evaluate(opt.lr)
      self.assertEqual(1.0, lr)

      opt.lr = 2.0
      lr = self.evaluate(opt.lr)
      self.assertEqual(2.0, lr)

      self.evaluate(opt.lr.assign(3.0))
      lr = self.evaluate(opt.lr)
      self.assertEqual(3.0, lr)

      with self.assertRaises(AttributeError):
        opt.not_an_attr += 3

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testGettingHyperParametersWithLrInConstructor(self):
    with self.test_session():
      opt = gradient_descent.SGD(lr=3.0)
      var = variables.Variable([1.0, 2.0], dtype=dtypes.float32)
      loss = lambda: 3 * var
      opt_op = opt.minimize(loss, [var])
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(opt_op)

      self.assertIsInstance(opt.lr, variables.Variable)
      self.assertIsInstance(opt.learning_rate, variables.Variable)

      lr = self.evaluate(opt.lr)
      self.assertEqual(3.0, lr)

      opt.lr = 2.0
      lr = self.evaluate(opt.lr)
      self.assertEqual(2.0, lr)

      self.evaluate(opt.lr.assign(4.0))
      lr = self.evaluate(opt.lr)
      self.assertEqual(4.0, lr)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testDir(self):
    opt = gradient_descent.SGD(learning_rate=1.0, momentum=0.1)
    dir_result = set(dir(opt))
    self.assertIn('learning_rate', dir_result)  # Hyperparameter
    self.assertIn('lr', dir_result)  # Hyperparameter
    self.assertIn('momentum', dir_result)  # Hyperparameter
    self.assertIn('nesterov', dir_result)  # Attribute
    self.assertIn('minimize', dir_result)  # Attribute

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testOptimizerWithKerasModel(self):
    a = input_layer.Input(shape=(3,), name='input_a')
    b = input_layer.Input(shape=(3,), name='input_b')

    dense = core.Dense(4, name='dense')
    c = dense(a)
    d = dense(b)
    e = core.Dropout(0.5, name='dropout')(c)

    model = training.Model([a, b], [d, e])

    optimizer = gradient_descent.SGD(learning_rate=0.001)
    loss = 'mse'
    model.compile(optimizer, loss, metrics=['mae'])

    input_a_np = np.random.random((10, 3))
    input_b_np = np.random.random((10, 3))

    output_d_np = np.random.random((10, 4))
    output_e_np = np.random.random((10, 4))

    model.fit([input_a_np, input_b_np], [output_d_np, output_e_np],
              epochs=1,
              batch_size=5)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testOptimizerWithCallbacks(self):
    np.random.seed(1331)
    input_np = np.random.random((10, 3))
    output_np = np.random.random((10, 4))
    a = input_layer.Input(shape=(3,), name='input_a')
    model = sequential.Sequential()
    model.add(core.Dense(4, kernel_initializer='zeros', name='dense'))
    model.add(core.Dropout(0.5, name='dropout'))
    model(a)
    optimizer = gradient_descent.SGD(learning_rate=0.1)
    model.compile(optimizer, loss='mse', metrics=['mae'])
    # This does not reduce the LR after the first epoch (due to low delta).
    cbks = [
        callbacks.ReduceLROnPlateau(
            monitor='val_loss', factor=0.1, min_delta=0, patience=1, cooldown=5)
    ]
    model.fit(
        input_np,
        output_np,
        batch_size=10,
        validation_data=(input_np, output_np),
        callbacks=cbks,
        epochs=2,
        verbose=0)
    self.assertAllClose(
        float(backend.get_value(model.optimizer.lr)), 0.1, atol=1e-4)

    # This should reduce the LR after the first epoch (due to high delta).
    cbks = [
        callbacks.ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.1,
            min_delta=10,
            patience=1,
            cooldown=5)
    ]
    model.fit(
        input_np,
        output_np,
        batch_size=10,
        validation_data=(input_np, output_np),
        callbacks=cbks,
        epochs=2,
        verbose=2)
    self.assertAllClose(
        float(backend.get_value(model.optimizer.lr)), 0.01, atol=1e-4)

  def testOptimizerSetIterations(self):
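    """Checks that a global step assigned to `iterations` is incremented."""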
    global_step = training_util.get_or_create_global_step()
    opt = adam.Adam(learning_rate=1.0)
    opt.iterations = global_step
    var = variables.Variable([1.0, 2.0], dtype=dtypes.float32)
    self.evaluate(variables.global_variables_initializer())
    init_step_value = self.evaluate(global_step)
    loss = lambda: 3 * var
    opt_op = opt.minimize(loss, [var])
    self.evaluate(variables.global_variables_initializer())
    self.evaluate(opt_op)
    new_step_value = self.evaluate(global_step)
    self.assertEqual(new_step_value, init_step_value + 1)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testOptimizerWithCallableVarList(self):
    train_samples = 20
    input_dim = 1
    num_classes = 2
    (x, y), _ = testing_utils.get_test_data(
        train_samples=train_samples,
        test_samples=10,
        input_shape=(input_dim,),
        num_classes=num_classes)
    y = np_utils.to_categorical(y)

    num_hidden = 1
    model = testing_utils.get_small_sequential_mlp(
        num_hidden=num_hidden, num_classes=num_classes)
    opt = adam.Adam()

    loss = lambda: losses.mean_squared_error(model(x), y)
    var_list = lambda: model.trainable_weights

    with self.assertRaisesRegex(
        ValueError, 'Weights for model .* have not yet been created'):
      var_list()
    train_op = opt.minimize(loss, var_list)
    if not context.executing_eagerly():
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(
          [[0.]], self.evaluate(opt.get_slot(var_list()[0], 'm')))
      self.evaluate(train_op)
    self.assertNotEqual(
        [[0.]], self.evaluate(opt.get_slot(var_list()[0], 'm')))
    self.assertLen(var_list(), 4)

  def testVarKey(self):
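    """Checks that _var_key uses the unique graph-mode variable names."""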
    with ops.get_default_graph().as_default():
      a = variables.Variable([1., 2.], name='var')
      b = variables.Variable([1.], name='var')
      self.assertTrue(a._in_graph_mode)
      self.assertTrue(b._in_graph_mode)
      var_key = optimizer_v2._var_key(a)
      self.assertEqual('var', var_key)
      var_key = optimizer_v2._var_key(b)
      self.assertEqual('var_1', var_key)

  def testVarName(self):
    with ops.get_default_graph().as_default():
      var = variables.Variable([1., 2.], name='var')
      loss = var + 1.
      opt = adam.Adam()
      opt.get_updates(loss, [var])
      opt_vars = opt.variables()
      self.assertLen(opt_vars, 3)
      self.assertEqual('Adam/iter:0', opt_vars[0].name)
      self.assertEqual('Adam/var/m:0', opt_vars[1].name)
      var_2 = variables.Variable([1., 2.], name='var_2')
      loss = var_2 + 1.
      with backend.name_scope('outter'):
        opt.get_updates(loss, [var_2])
      opt_vars = opt.variables()
      self.assertLen(opt_vars, 5)
      self.assertEqual('outter/Adam/var_2/m:0', opt_vars[3].name)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testEmptyVarList(self):
    opt = gradient_descent.SGD(1.)
    opt.minimize(lambda: constant_op.constant(1.), [])
    opt.apply_gradients([])

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testAggregationTrue(self):
    # Test that experimental_aggregate_gradients=True works without distributed
    # strategy.
    var = variables.Variable([1., 2.])
    opt = gradient_descent.SGD(3.0)

    self.evaluate(variables.global_variables_initializer())
    self.assertAllClose([1., 2.], self.evaluate(var))
    opt_op = opt.apply_gradients([([0.1, 0.1], var)],
                                 experimental_aggregate_gradients=True)
    self.evaluate(variables.global_variables_initializer())
    self.evaluate(opt_op)
    self.assertAllClose([0.7, 1.7], self.evaluate(var))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testAggregationFalse(self):
    # Test that experimental_aggregate_gradients=False works without distributed
    # strategy.
    var = variables.Variable([1., 2.])
    opt = gradient_descent.SGD(3.0)

    self.evaluate(variables.global_variables_initializer())
    self.assertAllClose([1., 2.], self.evaluate(var))
    opt_op = opt.apply_gradients([([0.1, 0.1], var)],
                                 experimental_aggregate_gradients=False)
    self.evaluate(variables.global_variables_initializer())
    self.evaluate(opt_op)
    self.assertAllClose([0.7, 1.7], self.evaluate(var))

  @combinations.generate(combinations.combine(mode=['eager']))
  def testRestoringIterationsWithoutAnOptimizer(self):
    opt = gradient_descent.SGD(3.0)
    opt.iterations.assign(5)
    checkpoint = trackable_utils.Checkpoint(optimizer=opt)
    path = checkpoint.save(self.get_temp_dir())

    # The following verifies that `iterations` can be restored in the absence
    # of an `Optimizer` object (using a `Checkpoint` as a placeholder).
    iterations_var = variables.Variable(0, dtype=dtypes.int64)
    optimizer_checkpoint = trackable_utils.Checkpoint(iter=iterations_var)
    checkpoint_to_restore = trackable_utils.Checkpoint(
        optimizer=optimizer_checkpoint)
    checkpoint_to_restore.restore(path)

    self.assertEqual(5, self.evaluate(iterations_var))


@keras_parameterized.run_all_keras_modes
class OptimizersCompatibilityTest(keras_parameterized.TestCase):
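  """Tests numeric compatibility between Keras v1 and v2 optimizers."""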

  def _testOptimizersCompatibility(self, opt_v1, opt_v2, test_weights=True):
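    """Trains identical models with opt_v1 and opt_v2 and compares results."""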
    if context.executing_eagerly():
      self.skipTest(
          'v1 optimizer does not run in eager mode')
    np.random.seed(1331)
    with testing_utils.use_gpu():
      train_samples = 20
      input_dim = 3
      num_classes = 2
      (x, y), _ = testing_utils.get_test_data(
          train_samples=train_samples,
          test_samples=10,
          input_shape=(input_dim,),
          num_classes=num_classes)
      y = np_utils.to_categorical(y)

      num_hidden = 5
      model_v1 = testing_utils.get_small_sequential_mlp(
          num_hidden=num_hidden, num_classes=num_classes, input_dim=input_dim)
      model_v1.compile(
          opt_v1,
          loss='categorical_crossentropy',
          metrics=[],
          run_eagerly=testing_utils.should_run_eagerly())
      model_v1.fit(x, y, batch_size=5, epochs=1)

      model_v2 = testing_utils.get_small_sequential_mlp(
          num_hidden=num_hidden, num_classes=num_classes, input_dim=input_dim)
      model_v2.set_weights(model_v1.get_weights())
      model_v2.compile(
          opt_v2,
          loss='categorical_crossentropy',
          metrics=[],
          run_eagerly=testing_utils.should_run_eagerly())
      if not ops.executing_eagerly_outside_functions():
        model_v2._make_train_function()
      if test_weights:
        opt_v2.set_weights(opt_v1.get_weights())

      hist_1 = model_v1.fit(x, y, batch_size=5, epochs=1, shuffle=False)
      hist_2 = model_v2.fit(x, y, batch_size=5, epochs=1, shuffle=False)
      self.assertAllClose(model_v1.get_weights(), model_v2.get_weights(),
                          rtol=1e-5, atol=1e-5)
      self.assertAllClose(hist_1.history['loss'], hist_2.history['loss'],
                          rtol=1e-5, atol=1e-5)

  def testAdadeltaCompatibility(self):
    opt_v1 = optimizer_v1.Adadelta(lr=0.01)
    opt_v2 = adadelta.Adadelta(learning_rate=0.01)
    self._testOptimizersCompatibility(opt_v1, opt_v2)

  def testAdagradCompatibility(self):
    opt_v1 = optimizer_v1.Adagrad(lr=0.01)
    opt_v2 = adagrad.Adagrad(learning_rate=0.01)
    self._testOptimizersCompatibility(opt_v1, opt_v2)

  def testAdamCompatibility(self):
    opt_v1 = optimizer_v1.Adam()
    opt_v2 = adam.Adam()
    self._testOptimizersCompatibility(opt_v1, opt_v2)

  def testAdamaxCompatibility(self):
    opt_v1 = optimizer_v1.Adamax(lr=0.01)
    opt_v2 = adamax.Adamax(learning_rate=0.01)
    self._testOptimizersCompatibility(opt_v1, opt_v2)

  def testNadamCompatibility(self):
    opt_v1 = optimizer_v1.Nadam(lr=0.001)
    opt_v2 = nadam.Nadam(learning_rate=0.001)
    self._testOptimizersCompatibility(opt_v1, opt_v2)

  def testMomentumCompatibility(self):
    opt_v1 = optimizer_v1.SGD(lr=0.01, momentum=0.9)
    opt_v2 = gradient_descent.SGD(learning_rate=0.01, momentum=0.9)
    self._testOptimizersCompatibility(opt_v1, opt_v2)

  def testRMSpropCompatibility(self):
    opt_v1 = optimizer_v1.RMSprop()
    opt_v2 = rmsprop.RMSprop()
    self._testOptimizersCompatibility(opt_v1, opt_v2)

  def testSGDCompatibility(self):
    opt_v1 = optimizer_v1.SGD(lr=0.01)
    opt_v2 = gradient_descent.SGD(learning_rate=0.01)
    self._testOptimizersCompatibility(opt_v1, opt_v2, False)

  def testNumericEquivalenceForNesterovMomentum(self):
    if context.executing_eagerly():
      self.skipTest(
          'v1 optimizer does not run in eager mode')
    np.random.seed(1331)
    with testing_utils.use_gpu():
      train_samples = 20
      input_dim = 3
      num_classes = 2
      (x, y), _ = testing_utils.get_test_data(
          train_samples=train_samples,
          test_samples=10,
          input_shape=(input_dim,),
          num_classes=num_classes)
      y = np_utils.to_categorical(y)

      num_hidden = 5
      model_k_v1 = testing_utils.get_small_sequential_mlp(
          num_hidden=num_hidden, num_classes=num_classes, input_dim=input_dim)
      model_k_v2 = testing_utils.get_small_sequential_mlp(
          num_hidden=num_hidden, num_classes=num_classes, input_dim=input_dim)
      model_k_v2.set_weights(model_k_v1.get_weights())
      model_tf = testing_utils.get_small_sequential_mlp(
          num_hidden=num_hidden, num_classes=num_classes, input_dim=input_dim)
      model_tf.set_weights(model_k_v2.get_weights())

      opt_k_v1 = optimizer_v1.SGD(momentum=0.9, nesterov=True)
      opt_k_v2 = gradient_descent.SGD(momentum=0.9, nesterov=True)
      opt_tf = momentum.MomentumOptimizer(
          learning_rate=0.01, momentum=0.9, use_nesterov=True)

      model_k_v1.compile(
          opt_k_v1,
          loss='categorical_crossentropy',
          metrics=[],
          run_eagerly=testing_utils.should_run_eagerly())
      model_k_v2.compile(
          opt_k_v2,
          loss='categorical_crossentropy',
          metrics=[],
          run_eagerly=testing_utils.should_run_eagerly())
      model_tf.compile(
          opt_tf,
          loss='categorical_crossentropy',
          metrics=[],
          run_eagerly=testing_utils.should_run_eagerly())

      hist_k_v1 = model_k_v1.fit(x, y, batch_size=5, epochs=10, shuffle=False)
      hist_k_v2 = model_k_v2.fit(x, y, batch_size=5, epochs=10, shuffle=False)
      hist_tf = model_tf.fit(x, y, batch_size=5, epochs=10, shuffle=False)

      self.assertAllClose(model_k_v1.get_weights(), model_tf.get_weights())
      self.assertAllClose(model_k_v1.get_weights(), model_k_v2.get_weights())
      self.assertAllClose(opt_k_v1.get_weights(), opt_k_v2.get_weights())
      self.assertAllClose(hist_k_v1.history['loss'], hist_tf.history['loss'])
      self.assertAllClose(hist_k_v1.history['loss'], hist_k_v2.history['loss'])

  def testNumericEquivalenceForAmsgrad(self):
    if context.executing_eagerly():
      self.skipTest(
          'v1 optimizer does not run in eager mode')
    np.random.seed(1331)
    with testing_utils.use_gpu():
      train_samples = 20
      input_dim = 3
      num_classes = 2
      (x, y), _ = testing_utils.get_test_data(
          train_samples=train_samples,
          test_samples=10,
          input_shape=(input_dim,),
          num_classes=num_classes)
      y = np_utils.to_categorical(y)

      num_hidden = 5
      model_k_v1 = testing_utils.get_small_sequential_mlp(
          num_hidden=num_hidden, num_classes=num_classes, input_dim=input_dim)
      model_k_v2 = testing_utils.get_small_sequential_mlp(
          num_hidden=num_hidden, num_classes=num_classes, input_dim=input_dim)
      model_k_v2.set_weights(model_k_v1.get_weights())

      opt_k_v1 = optimizer_v1.Adam(amsgrad=True)
      opt_k_v2 = adam.Adam(amsgrad=True)

      model_k_v1.compile(
          opt_k_v1,
          loss='categorical_crossentropy',
          metrics=[],
          run_eagerly=testing_utils.should_run_eagerly())
      model_k_v2.compile(
          opt_k_v2,
          loss='categorical_crossentropy',
          metrics=[],
          run_eagerly=testing_utils.should_run_eagerly())

      hist_k_v1 = model_k_v1.fit(x, y, batch_size=5, epochs=10, shuffle=False)
      hist_k_v2 = model_k_v2.fit(x, y, batch_size=5, epochs=10, shuffle=False)

      self.assertAllClose(model_k_v1.get_weights(), model_k_v2.get_weights())
      self.assertAllClose(opt_k_v1.get_weights(), opt_k_v2.get_weights())
      self.assertAllClose(hist_k_v1.history['loss'], hist_k_v2.history['loss'])


# Note: These tests are kept in a separate class to avoid bugs in some
# distributions of Python that break AutoGraph, which is used by tf.function.
@combinations.generate(combinations.combine(mode=['eager']))
class OptimizerWithFunctionTest(test.TestCase, parameterized.TestCase):
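  """Tests OptimizerV2 behavior when used inside tf.function."""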

  def testBasic(self):
    var = variables.Variable([1.0, 2.0], dtype=dtypes.float32)
    loss = lambda: 3 * var
    opt = adam.Adam(learning_rate=1.0)

    @def_function.function
    def fn():
      opt.minimize(loss, [var])
      return var

    self.assertAllClose([0., 1.], fn(), atol=1e-4)
    self.assertAllClose([-1, 0.], fn(), atol=1e-4)

  def testBasicWithConstantDecay(self):
    var = variables.Variable([1.0, 2.0], dtype=dtypes.float32)
    loss = lambda: 3 * var
    opt = adam.Adam(learning_rate=1.0)

    @def_function.function
    def fn():
      opt.minimize(loss, [var])
      return var

    self.assertAllClose([0., 1.], fn(), atol=1e-4)
    self.assertAllClose([-1, 0.], fn(), atol=1e-4)

  def testVarKeyWithVarCreatedInEager(self):
    a = variables.Variable([1., 2.], name='var')
    b = variables.Variable([1.], name='var')

    @test_util.also_run_as_tf_function
    def var_key_test():
      self.assertFalse(a._in_graph_mode)
      self.assertFalse(b._in_graph_mode)
      var_key_a = optimizer_v2._var_key(a)
      self.assertStartsWith(var_key_a, 'var_')
      var_key_b = optimizer_v2._var_key(b)
      self.assertStartsWith(var_key_b, 'var_')
      self.assertNotEqual(var_key_a, var_key_b)

    var_key_test()

  def testLearningRateDecayUsedInTwoFunctions(self):
    a = variables.Variable([1., 2.], name='var')
    b = variables.Variable([1.], name='var')

    learning_rate_decay = learning_rate_schedule.InverseTimeDecay(
        0.5, decay_steps=1.0, decay_rate=0.5)
    opt = adam.Adam(learning_rate=learning_rate_decay)
    loss_a = lambda: 3 * a
    loss_b = lambda: 2 * b

    @def_function.function
    def fn_a():
      opt.minimize(loss_a, [a])
      return a

    @def_function.function
    def fn_b():
      opt.minimize(loss_b, [b])
      return b

    fn_a()
    fn_b()


_NUM_LEARNERS = 50
APPLY_SCOPE = 'debug_apply'
ALLOWLIST = [
    # optimizer_v2._deduplicate_indexed_slices contains an indexed slice:
    #   array_ops.shape(unique_indices)[0]
    # which winds up expanding to [0:1:1] thereby creating three constants
    # to represent the indices.
    ('embeddings/strided_slice/stack', 'Const'),
]


def get_inputs(op):
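  """Returns the ops feeding `op` (data and control) and their names."""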
  op_inputs = list(op.inputs) + op.control_inputs
  names = [i.name for i in op_inputs]
  op_inputs = [getattr(i, 'op', i) for i in op_inputs]
  return op_inputs, names


def strip_name(node):
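  """Clears the node name (except placeholders) so identical defs match."""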
  if 'Placeholder' in node.op:
    return
  node.name = ''


def topological_sort(graph):
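  """Topologically sorts the ops in `graph` using Kahn's algorithm."""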
  graph_ops = graph.get_operations()

  sources = []
  result = []

  inputs = {}
  outputs = collections.defaultdict(set)
  for op in graph_ops:
    op_inputs = get_inputs(op)[0]
    if not op_inputs:
      sources.append(op)

    inputs[op] = set(op_inputs)
    for i in op_inputs:
      outputs[i].add(op)

  while sources:
    op = sources.pop()
    for op_output in outputs[op]:
      inputs[op_output].remove(op)
      if not inputs[op_output]:
        sources.append(op_output)

    result.append(op)

  # Check correctness.
  if len(result) != len(graph_ops):
    raise ValueError('Sort result has {} ops, source graph has {}.'
                     .format(len(result), len(graph_ops)))

  sort_check_seen = set()
  for op in result:
    sort_check_seen.add(op)
    for i in get_inputs(op)[0]:
      assert i in sort_check_seen

  return result


def identify_redundant_ops(graph):
  """Implements basic common subexpression elimination.

  This is not intended to replicate the graph semantics of TensorFlow Graphs
  (for instance it does not handle stateful op ordering), nor is it intended to
  replace the common subexpression elimination Grappler pass. Rather, it
  provides a high level sanity check that clearly redundant ops are not being
  created.

  Args:
    graph: The graph to be analyzed.

  Returns:
    A count of the duplicate ops and a description of the structure of each.
  """
  sorted_ops = topological_sort(graph)
  duplicates = collections.defaultdict(list)
  unified_node_defs = {}
  name_map = {}

  for op in sorted_ops:
    input_names = []
    for op_input, name in zip(*get_inputs(op)):
      input_def = op_input.node_def

      # Operations can have multiple outputs. We track which is used to prevent
      # overzealous elimination.
      input_def.name = name

      input_def.input[:] = [name_map.get(i, i) for i in input_def.input]
      strip_name(input_def)

      # NodeDef.SerializeToString() does not provide identical serialized
      # representations for identical NodeDefs, so we instead use string
      # representation as a dict key.
      key = repr(input_def)

      if key in unified_node_defs:
        input_names.append(unified_node_defs[key])

      else:
        unified_node_defs[key] = op_input.name
        input_names.append(name)

    node_def = op.node_def
    node_def.input[:] = input_names
    strip_name(node_def)

    key = repr(node_def)
    duplicates[key].append(op)
    name_map[op.name] = duplicates[key][0].name

  num_duplicates = 0
  duplicate_types = []
  for standard_def, op_defs in duplicates.items():
    # We are only interested in testing the apply method of the optimizer
    op_defs = [i for i in op_defs if APPLY_SCOPE in i.name]

    # We only check for per-apply redundant ops.
    if len(op_defs) < _NUM_LEARNERS:
      continue

    # Certain ops are not worth eliminating and are simply ignored.
    name, op_type = op_defs[0].name, op_defs[0].type
    if any(allowlisted_scope in name and op_type == allowlisted_type
           for allowlisted_scope, allowlisted_type in ALLOWLIST):
      continue

    num_duplicates += len(op_defs)
    traceback = []
    for level in op_defs[0].traceback:
      traceback.append('  {} {}:{}'.format(level[0], level[2], level[1]))

    duplicate_types.append(
        '# Example name: {}\n# Op creation stack:\n{}\n{}'.format(
            op_defs[0].name,
            '\n'.join(traceback),
            standard_def))

  return num_duplicates, duplicate_types


def make_model():
  r"""Constructs a simple ensemble of weak learners model.

  ---------    ---------             ---------    ---------
  | Input |    | Input |     ...     | Input |    | Input |
  ---------    ---------             ---------    ---------
      |            |                     |            |
      V            V                     V            V
  ---------    ---------             ---------    ---------
  | Embed |    | Embed |     ...     | Embed |    | Embed |
  ---------    ---------             ---------    ---------
      |            |                     |            |
      V            V                     V            V
  ---------    ---------             ---------    ---------
  | Dense |    | Dense |     ...     | Dense |    | Dense |
  ---------    ---------             ---------    ---------
      \            |                     |            /
       \           |                     |           /
        ---------------------------------------------
                              |
                          ---------
                          | Dense |
                          ---------

  This topology is chosen because it exercises both dense and sparse update
  paths.

  Returns:
    A model for testing optimizer coefficient reuse.
  """
  inputs = []
  intermediates = []
  for _ in range(_NUM_LEARNERS):
    inp = keras.layers.Input(shape=(1,), dtype=dtypes.int32)
    layer = keras.layers.Embedding(1, 4)(inp)
    layer = keras.layers.Dense(1)(layer)

    inputs.append(inp)
    intermediates.append(layer)

  layer = keras.layers.Concatenate(axis=-1)(intermediates)
  layer = keras.layers.Dense(1)(layer)

  return keras.models.Model(inputs, layer)


COEFFICIENT_PARAMS = (
    ('Adadelta', adadelta.Adadelta, None),
    ('Adagrad', adagrad.Adagrad, None),
    ('Adam', adam.Adam, None),
    ('Adam_amsgrad', adam.Adam, dict(amsgrad=True)),
    ('Adamax', adamax.Adamax, None),
    ('Ftrl', ftrl.Ftrl, None),
    ('Ftrl_l2_shrinkage', ftrl.Ftrl,
     dict(l2_shrinkage_regularization_strength=0.1)),
    ('SGD', gradient_descent.SGD, None),
    ('SGD_momentum', gradient_descent.SGD, dict(momentum=0.5)),
    ('Nadam', nadam.Nadam, None),
    ('RMSprop', rmsprop.RMSprop, None),
    ('RMSprop_centered', rmsprop.RMSprop, dict(centered=True)),
    ('RMSprop_momentum', rmsprop.RMSprop, dict(momentum=0.5)),
    ('RMSprop_momentum_centered', rmsprop.RMSprop,
     dict(momentum=0.5, centered=True)),
)


class OptimizerCoefficientTest(keras_parameterized.TestCase):
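  """Tests that optimizers create update coefficients once and reuse them."""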

  @parameterized.named_parameters(*COEFFICIENT_PARAMS)
  def test_duplicate_ops(self, optimizer_class, init_kwargs=None):
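    """Checks that apply_gradients does not emit clearly redundant ops."""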
    init_kwargs = init_kwargs or {}
    optimizer = optimizer_class(**init_kwargs)

    graph = ops.Graph()
    with graph.as_default():
      model = make_model()
      trainable_variables = model.trainable_variables
      grads = optimizer.get_gradients(model.outputs[0], trainable_variables)

      with backend.name_scope(APPLY_SCOPE):
        optimizer.apply_gradients(zip(grads, trainable_variables))

    num_duplicates, duplicate_types = identify_redundant_ops(graph)
    if num_duplicates:
      # Avoid spamming logs.
      if len(duplicate_types) > 3:
        duplicate_types = duplicate_types[:3] + ['...']

      num_total = len(graph.get_operations())
      raise ValueError('{} of {} ({:.1f}%) ops were duplicates:\n\n{}'.format(
          num_duplicates, num_total, num_duplicates / num_total * 100,
          '\n'.join(duplicate_types)))

  @parameterized.named_parameters(*COEFFICIENT_PARAMS)
  def test_subclass_compat(self, optimizer_class, init_kwargs=None):
    """Ensure that subclassed optimizers without apply_state still work."""

    class SubclassedOptimizer(optimizer_class):

      def _resource_apply_dense(self, grad, var):  # pylint: disable=useless-super-delegation
        return super(SubclassedOptimizer, self)._resource_apply_dense(grad, var)

      def _resource_apply_sparse(self, grad, var, indices):  # pylint: disable=useless-super-delegation
        return super(SubclassedOptimizer, self)._resource_apply_sparse(
            grad, var, indices)

    init_kwargs = init_kwargs or {}
    optimizer = SubclassedOptimizer(**init_kwargs)

    graph = ops.Graph()
    with graph.as_default():
      model = make_model()
      trainable_variables = model.trainable_variables
      grads = optimizer.get_gradients(model.outputs[0], trainable_variables)

      with backend.name_scope(APPLY_SCOPE):
        optimizer.apply_gradients(zip(grads, trainable_variables))


if __name__ == '__main__':
  test.main()