# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functional test for OptimizerV2."""

import collections

from absl.testing import parameterized
import numpy as np

from tensorflow.python import keras
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras import backend
from tensorflow.python.keras import callbacks
from tensorflow.python.keras import combinations
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import losses
from tensorflow.python.keras import optimizer_v1
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import input_layer
from tensorflow.python.keras.engine import sequential
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.layers import core
from tensorflow.python.keras.optimizer_v2 import adadelta
from tensorflow.python.keras.optimizer_v2 import adagrad
from tensorflow.python.keras.optimizer_v2 import adam
from tensorflow.python.keras.optimizer_v2 import adamax
from tensorflow.python.keras.optimizer_v2 import ftrl
from tensorflow.python.keras.optimizer_v2 import gradient_descent
from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
from tensorflow.python.keras.optimizer_v2 import nadam
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.keras.optimizer_v2 import rmsprop
from tensorflow.python.keras.utils import np_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import momentum
from tensorflow.python.training import training_util
from tensorflow.python.training.tracking import util as trackable_utils


_DATA_TYPES = [dtypes.half, dtypes.float32, dtypes.float64]
# TODO(b/141710709): complex support in NVCC and ROCM.
if (not test_util.IsBuiltWithNvcc() and not test.is_built_with_rocm()):
  _DATA_TYPES += [dtypes.complex64, dtypes.complex128]


class OptimizerTest(test.TestCase, parameterized.TestCase):

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testBasic(self):
    for dtype in _DATA_TYPES:
      with testing_utils.use_gpu():
        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
        var1 = variables.Variable([3.0, 4.0], dtype=dtype)
        loss = lambda: 5 * var0 + 3 * var1  # pylint: disable=cell-var-from-loop
        sgd = gradient_descent.SGD(3.0)

        self.evaluate(variables.global_variables_initializer())
        # Fetch params to validate initial values
        self.assertAllClose([1.0, 2.0], self.evaluate(var0))
        self.assertAllClose([3.0, 4.0], self.evaluate(var1))
        # Run 1 step of sgd through optimizer
        opt_op = sgd.minimize(loss, var_list=[var0, var1])
        self.evaluate(variables.global_variables_initializer())
        self.evaluate(opt_op)
        # Validate updated params
        self.assertAllClose([-14., -13.], self.evaluate(var0))
        self.assertAllClose([-6., -5.], self.evaluate(var1))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testAdaptiveLearningRate(self):
    for dtype in _DATA_TYPES:
      with self.test_session():
        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
        var1 = variables.Variable([3.0, 4.0], dtype=dtype)

        def loss():
          return 5 * var0 + 3 * var1  # pylint: disable=cell-var-from-loop

        sgd = gradient_descent.SGD(1.0)

        self.evaluate(variables.global_variables_initializer())
        # Fetch params to validate initial values
        self.assertAllClose([1.0, 2.0], self.evaluate(var0))
        self.assertAllClose([3.0, 4.0], self.evaluate(var1))
        # Run 1 step of sgd through optimizer
        opt_op = sgd.minimize(loss, [var0, var1])
        self.evaluate(variables.global_variables_initializer())
        self.evaluate(opt_op)
        # Validate updated params
        # var0 = [1., 2.] - 1.0 * [5, 5]
        self.assertAllClose([-4., -3.], self.evaluate(var0))
        # var1 = [3., 4.] - 1.0 * [3, 3]
        self.assertAllClose([0., 1.], self.evaluate(var1))

        sgd.learning_rate = 0.5
        if context.executing_eagerly():
          sgd.minimize(loss, [var0, var1])
        else:
          self.evaluate(opt_op)
        # Validate updated params
        # var0 = [-4., -3.] - 0.5 * [5, 5]
        self.assertAllClose([-6.5, -5.5], self.evaluate(var0))
        # var1 = [0., 1.] - 0.5 * [3, 3]
        self.assertAllClose([-1.5, -0.5], self.evaluate(var1))

        sgd.learning_rate = learning_rate_schedule.InverseTimeDecay(
            0.5, decay_steps=1.0, decay_rate=0.5)
        if context.executing_eagerly():
          sgd.minimize(loss, [var0, var1])
        else:
          self.evaluate(opt_op)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testPrecomputedGradient(self):
    for dtype in _DATA_TYPES:
      with testing_utils.use_gpu():
        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
        var1 = variables.Variable([3.0, 4.0], dtype=dtype)
        loss = lambda: 5 * var0 + 3 * var1  # pylint: disable=cell-var-from-loop
        grad_loss = constant_op.constant([42, -42], dtype=dtype)
        sgd = gradient_descent.SGD(3.0)

        self.evaluate(variables.global_variables_initializer())
        # Fetch params to validate initial values
        self.assertAllClose([1.0, 2.0], self.evaluate(var0))
        self.assertAllClose([3.0, 4.0], self.evaluate(var1))
        # Run 1 step of sgd through optimizer
        opt_op = sgd.minimize(loss, var_list=[var0, var1], grad_loss=grad_loss)
        self.evaluate(variables.global_variables_initializer())
        self.evaluate(opt_op)
        # Validate updated params
        self.assertAllClose([1.0 - 3 * 5 * 42.0, 2.0 - 3 * 5 * (-42.0)],
                            self.evaluate(var0))
        self.assertAllClose([3.0 - 3 * 3 * 42.0, 4.0 - 3 * 3 * (-42.0)],
                            self.evaluate(var1))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testNoGradients(self):
    for dtype in _DATA_TYPES:
      with testing_utils.use_gpu():
        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
        var1 = variables.Variable([3.0, 4.0], dtype=dtype)
        loss = lambda: 5 * var0  # pylint: disable=cell-var-from-loop
        sgd_op = gradient_descent.SGD(3.0)
        with self.assertRaisesRegex(ValueError, 'No gradients'):
          # var1 has no gradient
          sgd_op.minimize(loss, var_list=[var1])

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testNoGradientsForAnyVariables_Minimize(self):
    for dtype in _DATA_TYPES:
      with testing_utils.use_gpu():
        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
        var1 = variables.Variable([3.0, 4.0], dtype=dtype)
        loss = lambda: constant_op.constant(5.0)

        sgd_op = gradient_descent.SGD(3.0)
        with self.assertRaisesRegex(ValueError,
                                    'No gradients provided for any variable'):
          sgd_op.minimize(loss, var_list=[var0, var1])

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testNoGradientsForAnyVariables_ApplyGradients(self):
    for dtype in _DATA_TYPES:
      with testing_utils.use_gpu():
        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
        var1 = variables.Variable([3.0, 4.0], dtype=dtype)
        sgd_op = gradient_descent.SGD(3.0)
        with self.assertRaisesRegex(ValueError,
                                    'No gradients provided for any variable'):
          sgd_op.apply_gradients([(None, var0), (None, var1)])

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testGradientsAsVariables(self):
    for i, dtype in enumerate(_DATA_TYPES):
      with testing_utils.use_gpu():
        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
        var1 = variables.Variable([3.0, 4.0], dtype=dtype)
        loss = lambda: 5 * var0 + 3 * var1  # pylint: disable=cell-var-from-loop

        sgd = gradient_descent.SGD(3.0)
        grads_and_vars = sgd._compute_gradients(loss, [var0, var1])
        # Convert gradients to tf.Variables
        converted_grads = [
            variables.Variable(
                array_ops.zeros([2], dtype), name='c_%d_%d' % (i, j))
            for j, gv in enumerate(grads_and_vars)
        ]
        convert_ops = [
            state_ops.assign(converted_grads[j], gv[0])
            for j, gv in enumerate(grads_and_vars)
        ]

        # Run convert_ops to perform the gradient conversion.
        self.evaluate(variables.global_variables_initializer())
        self.evaluate(convert_ops)
        # Fetch params to validate initial values
        self.assertAllClose([1.0, 2.0], self.evaluate(var0))
        self.assertAllClose([3.0, 4.0], self.evaluate(var1))

        # Run 1 step of sgd through optimizer
        converted_grads_and_vars = list(zip(converted_grads, [var0, var1]))
        opt_op = sgd.apply_gradients(converted_grads_and_vars)
        self.evaluate(variables.global_variables_initializer())
        self.evaluate(convert_ops)
        self.evaluate(opt_op)

        # Validate updated params
        self.assertAllClose([-14., -13.], self.evaluate(var0))
        self.assertAllClose([-6., -5.], self.evaluate(var1))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testComputeGradientsWithTensors(self):
    with testing_utils.use_gpu():
      x = ops.convert_to_tensor_v2_with_dispatch(1.0)

      def f():
        return x * x

      sgd = gradient_descent.SGD(3.0)
      grads_and_vars = sgd._compute_gradients(f, [x])
      self.assertLen(grads_and_vars, 1)
      grad, x_as_var = grads_and_vars[0]
      self.assertIs(x, x_as_var)
      self.assertEqual(2.0, self.evaluate(grad))

      with self.assertRaises(NotImplementedError):
        sgd.apply_gradients(grads_and_vars)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testConstraint(self):
    constraint_01 = lambda x: clip_ops.clip_by_value(x, -0.1, 0.)
    constraint_0 = lambda x: clip_ops.clip_by_value(x, 0., 1.)
    with testing_utils.use_gpu():
      var0 = variables.Variable([1.0, 2.0],
                                constraint=constraint_01)
      var1 = variables.Variable([3.0, 4.0],
                                constraint=constraint_0)
      loss = lambda: 5 * var0 + 3 * var1
      sgd = gradient_descent.SGD(3.0)

      self.evaluate(variables.global_variables_initializer())
      # Fetch params to validate initial values
      self.assertAllClose([1.0, 2.0], self.evaluate(var0))
      self.assertAllClose([3.0, 4.0], self.evaluate(var1))
      # Run 1 step of sgd through optimizer
      opt_op = sgd.minimize(loss, var_list=[var0, var1])
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(opt_op)
      # Validate updated params
      self.assertAllClose([-0.1, -0.1], self.evaluate(var0))
      self.assertAllClose([0., 0.], self.evaluate(var1))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testIterationWithoutMinimize(self):
    with testing_utils.use_gpu():
      sgd = gradient_descent.SGD(3.0)
      self.evaluate(sgd.iterations.initializer)
      self.assertEqual(0, self.evaluate(sgd.iterations))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testConfig(self):
    with testing_utils.use_gpu():
      opt = gradient_descent.SGD(learning_rate=1.0)
      config = opt.get_config()
      opt2 = gradient_descent.SGD.from_config(config)
      lr = opt._get_hyper('learning_rate')
      lr2 = opt2._get_hyper('learning_rate')
      self.evaluate(variables.global_variables_initializer())
      # assert both are equal float values.
      self.assertEqual(self.evaluate(lr), self.evaluate(lr2))
      var0 = variables.Variable([[1.0], [2.0]], dtype=dtypes.float32)
      loss = lambda: 3 * var0
      # learning rate variable created when calling minimize.
      opt.minimize(loss, [var0])
      opt3 = gradient_descent.SGD.from_config(config)
      lr3 = opt3._get_hyper('learning_rate')
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(self.evaluate(lr), self.evaluate(lr3))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testConfigWithLearningRateDecay(self):
    with testing_utils.use_gpu():
      var0 = variables.Variable([[1.0], [2.0]], dtype=dtypes.float32)
      for decay_schedule in [
          learning_rate_schedule.InverseTimeDecay(
              0.5, decay_steps=1.0, decay_rate=0.1),
          learning_rate_schedule.PiecewiseConstantDecay(
              [5], [1., .5])
      ]:
        step = 10
        opt = gradient_descent.SGD(decay_schedule)
        config = opt.get_config()
        opt2 = gradient_descent.SGD.from_config(config)
        # assert both are equal float values.
        self.assertAllEqual(
            decay_schedule(step),
            opt._get_hyper('learning_rate')(step))
        self.assertAllEqual(
            decay_schedule(step),
            opt2._get_hyper('learning_rate')(step))
        loss = lambda: 3 * var0
        # learning rate variable is created when calling minimize.
        opt.minimize(loss, [var0])
        self.evaluate(variables.global_variables_initializer())
        config = opt.get_config()
        opt3 = gradient_descent.SGD.from_config(config)
        self.assertAllEqual(
            self.evaluate(opt._get_hyper('learning_rate')(step)),
            opt3._get_hyper('learning_rate')(step))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testGradClipValue(self):
    with testing_utils.use_gpu():
      var = variables.Variable([1.0, 2.0])
      loss = lambda: 3 * var
      opt = gradient_descent.SGD(learning_rate=1.0, clipvalue=1.0)
      opt_op = opt.minimize(loss, [var])
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(opt_op)
      self.assertAllClose([0., 1.], self.evaluate(var))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testGradClipNorm(self):
    with testing_utils.use_gpu():
      var = variables.Variable([1.0])
      loss = lambda: 3 * var
      opt = gradient_descent.SGD(learning_rate=1.0, clipnorm=1.0)
      opt_op = opt.minimize(loss, [var])
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(opt_op)
      self.assertAllClose([0.], self.evaluate(var))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testGradGlobalClipNorm(self):
    with testing_utils.use_gpu():
      # l2 norm is 5.0
      var1 = variables.Variable([1.0])
      var2 = variables.Variable([2.0])
      loss = lambda: 3 * var1 + 4 * var2
      opt = gradient_descent.SGD(learning_rate=1.0, global_clipnorm=2.0)
      opt_op = opt.minimize(loss, [var1, var2])
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(opt_op)
      # grad1 = 3.0 * 2.0 / 5.0 = 1.2
      self.assertAllClose([-.2], self.evaluate(var1))
      # grad2 = 4.0 * 2.0 / 5.0 = 1.6
      self.assertAllClose([.4], self.evaluate(var2))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testInvalidClipNorm(self):
    with self.assertRaisesRegex(ValueError, '>= 0'):
      gradient_descent.SGD(learning_rate=1.0, clipnorm=-1.0)

  @combinations.generate(
      combinations.combine(
          mode=['graph', 'eager'],
          clip_type=['clipnorm', 'global_clipnorm', 'clipvalue']))
  def testConfigWithClipping(self, clip_type):
    opt = gradient_descent.SGD(learning_rate=1.0, **{clip_type: 2.0})
    config = opt.get_config()
    opt = gradient_descent.SGD.from_config(config)
    self.assertEqual(getattr(opt, clip_type), 2.0)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testInvalidKwargs(self):
    with self.assertRaisesRegex(TypeError, 'Unexpected keyword argument'):
      gradient_descent.SGD(learning_rate=1.0, invalidkwargs=1.0)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testWeights(self):
    with testing_utils.use_gpu():
      opt1 = adam.Adam(learning_rate=1.0)
      var1 = variables.Variable([1.0, 2.0], dtype=dtypes.float32)
      loss1 = lambda: 3 * var1
      opt_op_1 = opt1.minimize(loss1, [var1])
      self.evaluate(variables.global_variables_initializer())
      config = opt1.get_config()
      opt2 = adam.Adam.from_config(config)
      var2 = variables.Variable([1.0, 2.0], dtype=dtypes.float32)
      loss2 = lambda: 3 * var2
      opt_op_2 = opt2.minimize(loss2, [var2])
      weights = opt1.get_weights()

      # Assert that after set_weights both variables get updated to the same
      # value.
      self.evaluate(variables.global_variables_initializer())
      opt2.set_weights(weights)
      self.evaluate([opt_op_1, opt_op_2])
      self.assertAllClose(self.evaluate(var1), self.evaluate(var2))
      self.assertEqual(1, self.evaluate(opt1.iterations))
      self.assertEqual(1, self.evaluate(opt2.iterations))

      var3 = variables.Variable([1.0, 2.0, 3.0], dtype=dtypes.float32)
      var4 = variables.Variable([4.0, 5.0, 6.0], dtype=dtypes.float32)
      loss3 = lambda: 3 * var3 + 5 * var4
      opt_op_3 = opt1.minimize(loss3, [var3, var4])

      # Assert that set_weights raises a ValueError because the weight list
      # does not match.
      self.evaluate(variables.global_variables_initializer())
      weights = opt1.get_weights()
      with self.assertRaisesRegex(ValueError, 'but the optimizer was'):
        opt2.set_weights(weights)

      # Assert that set_weights succeeds and the variables get updated to the
      # same value.
      var5 = variables.Variable([1.0, 2.0, 3.0], dtype=dtypes.float32)
      var6 = variables.Variable([4.0, 5.0, 6.0], dtype=dtypes.float32)
      loss4 = lambda: 3 * var5 + 5 * var6
      opt_op_4 = opt2.minimize(loss4, [var5, var6])
      self.evaluate(variables.global_variables_initializer())
      opt2.set_weights(weights)
      self.evaluate([opt_op_3, opt_op_4])
      self.assertAllClose(
          self.evaluate([var3, var4]), self.evaluate([var5, var6]))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testGettingHyperParameters(self):
    with self.test_session():
      opt = adam.Adam(learning_rate=1.0)
      var = variables.Variable([1.0, 2.0], dtype=dtypes.float32)
      loss = lambda: 3 * var
      opt_op = opt.minimize(loss, [var])
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(opt_op)

      lr = self.evaluate(opt.lr)
      self.assertEqual(1.0, lr)

      opt.lr = 2.0
      lr = self.evaluate(opt.lr)
      self.assertEqual(2.0, lr)

      self.evaluate(opt.lr.assign(3.0))
      lr = self.evaluate(opt.lr)
      self.assertEqual(3.0, lr)

      with self.assertRaises(AttributeError):
        opt.not_an_attr += 3

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testGettingHyperParametersWithLrInConstructor(self):
    with self.test_session():
      opt = gradient_descent.SGD(lr=3.0)
      var = variables.Variable([1.0, 2.0], dtype=dtypes.float32)
      loss = lambda: 3 * var
      opt_op = opt.minimize(loss, [var])
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(opt_op)

      self.assertIsInstance(opt.lr, variables.Variable)
      self.assertIsInstance(opt.learning_rate, variables.Variable)

      lr = self.evaluate(opt.lr)
      self.assertEqual(3.0, lr)

      opt.lr = 2.0
      lr = self.evaluate(opt.lr)
      self.assertEqual(2.0, lr)

      self.evaluate(opt.lr.assign(4.0))
      lr = self.evaluate(opt.lr)
      self.assertEqual(4.0, lr)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testDir(self):
    opt = gradient_descent.SGD(learning_rate=1.0, momentum=0.1)
    dir_result = set(dir(opt))
    self.assertIn('learning_rate', dir_result)  # Hyperparameter
    self.assertIn('lr', dir_result)  # Hyperparameter
    self.assertIn('momentum', dir_result)  # Hyperparameter
    self.assertIn('nesterov', dir_result)  # Attribute
    self.assertIn('minimize', dir_result)  # Attribute

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testOptimizerWithKerasModel(self):
    a = input_layer.Input(shape=(3,), name='input_a')
    b = input_layer.Input(shape=(3,), name='input_b')

    dense = core.Dense(4, name='dense')
    c = dense(a)
    d = dense(b)
    e = core.Dropout(0.5, name='dropout')(c)

    model = training.Model([a, b], [d, e])

    optimizer = gradient_descent.SGD(learning_rate=0.001)
    loss = 'mse'
    model.compile(optimizer, loss, metrics=['mae'])

    input_a_np = np.random.random((10, 3))
    input_b_np = np.random.random((10, 3))

    output_d_np = np.random.random((10, 4))
    output_e_np = np.random.random((10, 4))

    model.fit([input_a_np, input_b_np], [output_d_np, output_e_np],
              epochs=1,
              batch_size=5)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testOptimizerWithCallbacks(self):
    np.random.seed(1331)
    input_np = np.random.random((10, 3))
    output_np = np.random.random((10, 4))
    a = input_layer.Input(shape=(3,), name='input_a')
    model = sequential.Sequential()
    model.add(core.Dense(4, kernel_initializer='zeros', name='dense'))
    model.add(core.Dropout(0.5, name='dropout'))
    model(a)
    optimizer = gradient_descent.SGD(learning_rate=0.1)
    model.compile(optimizer, loss='mse', metrics=['mae'])
    # This does not reduce the LR after the first epoch (due to low delta).
    cbks = [
        callbacks.ReduceLROnPlateau(
            monitor='val_loss', factor=0.1, min_delta=0, patience=1, cooldown=5)
    ]
    model.fit(
        input_np,
        output_np,
        batch_size=10,
        validation_data=(input_np, output_np),
        callbacks=cbks,
        epochs=2,
        verbose=0)
    self.assertAllClose(
        float(backend.get_value(model.optimizer.lr)), 0.1, atol=1e-4)

    # This should reduce the LR after the first epoch (due to high delta).
    cbks = [
        callbacks.ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.1,
            min_delta=10,
            patience=1,
            cooldown=5)
    ]
    model.fit(
        input_np,
        output_np,
        batch_size=10,
        validation_data=(input_np, output_np),
        callbacks=cbks,
        epochs=2,
        verbose=2)
    self.assertAllClose(
        float(backend.get_value(model.optimizer.lr)), 0.01, atol=1e-4)

  def testOptimizerSetIterations(self):
    global_step = training_util.get_or_create_global_step()
    opt = adam.Adam(learning_rate=1.0)
    opt.iterations = global_step
    var = variables.Variable([1.0, 2.0], dtype=dtypes.float32)
    self.evaluate(variables.global_variables_initializer())
    init_step_value = self.evaluate(global_step)
    loss = lambda: 3 * var
    opt_op = opt.minimize(loss, [var])
    self.evaluate(variables.global_variables_initializer())
    self.evaluate(opt_op)
    new_step_value = self.evaluate(global_step)
    self.assertEqual(new_step_value, init_step_value + 1)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testOptimizerWithCallableVarList(self):
    train_samples = 20
    input_dim = 1
    num_classes = 2
    (x, y), _ = testing_utils.get_test_data(
        train_samples=train_samples,
        test_samples=10,
        input_shape=(input_dim,),
        num_classes=num_classes)
    y = np_utils.to_categorical(y)

    num_hidden = 1
    model = testing_utils.get_small_sequential_mlp(
        num_hidden=num_hidden, num_classes=num_classes)
    opt = adam.Adam()

    loss = lambda: losses.mean_squared_error(model(x), y)
    var_list = lambda: model.trainable_weights

    with self.assertRaisesRegex(
        ValueError, 'Weights for model .* have not yet been created'):
      var_list()
    train_op = opt.minimize(loss, var_list)
    if not context.executing_eagerly():
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(
          [[0.]], self.evaluate(opt.get_slot(var_list()[0], 'm')))
      self.evaluate(train_op)
    self.assertNotEqual(
        [[0.]], self.evaluate(opt.get_slot(var_list()[0], 'm')))
    self.assertLen(var_list(), 4)

  def testVarKey(self):
    with ops.get_default_graph().as_default():
      a = variables.Variable([1., 2.], name='var')
      b = variables.Variable([1.], name='var')
      self.assertTrue(a._in_graph_mode)
      self.assertTrue(b._in_graph_mode)
      var_key = optimizer_v2._var_key(a)
      self.assertEqual('var', var_key)
      var_key = optimizer_v2._var_key(b)
      self.assertEqual('var_1', var_key)

  def testVarName(self):
    with ops.get_default_graph().as_default():
      var = variables.Variable([1., 2.], name='var')
      loss = var + 1.
      opt = adam.Adam()
      opt.get_updates(loss, [var])
      opt_vars = opt.variables()
      self.assertLen(opt_vars, 3)
      self.assertEqual('Adam/iter:0', opt_vars[0].name)
      self.assertEqual('Adam/var/m:0', opt_vars[1].name)
      var_2 = variables.Variable([1., 2.], name='var_2')
      loss = var_2 + 1.
      with backend.name_scope('outter'):
        opt.get_updates(loss, [var_2])
      opt_vars = opt.variables()
      self.assertLen(opt_vars, 5)
      self.assertEqual('outter/Adam/var_2/m:0', opt_vars[3].name)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testEmptyVarList(self):
    opt = gradient_descent.SGD(1.)
    opt.minimize(lambda: constant_op.constant(1.), [])
    opt.apply_gradients([])

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testAggregationTrue(self):
    # Test that experimental_aggregate_gradients=True works without distributed
    # strategy.
    var = variables.Variable([1., 2.])
    opt = gradient_descent.SGD(3.0)

    self.evaluate(variables.global_variables_initializer())
    self.assertAllClose([1., 2.], self.evaluate(var))
    opt_op = opt.apply_gradients([([0.1, 0.1], var)],
                                 experimental_aggregate_gradients=True)
    self.evaluate(variables.global_variables_initializer())
    self.evaluate(opt_op)
    self.assertAllClose([0.7, 1.7], self.evaluate(var))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testAggregationFalse(self):
    # Test that experimental_aggregate_gradients=False works without distributed
    # strategy.
    var = variables.Variable([1., 2.])
    opt = gradient_descent.SGD(3.0)

    self.evaluate(variables.global_variables_initializer())
    self.assertAllClose([1., 2.], self.evaluate(var))
    opt_op = opt.apply_gradients([([0.1, 0.1], var)],
                                 experimental_aggregate_gradients=False)
    self.evaluate(variables.global_variables_initializer())
    self.evaluate(opt_op)
    self.assertAllClose([0.7, 1.7], self.evaluate(var))

  @combinations.generate(combinations.combine(mode=['eager']))
  def testRestoringIterationsWithoutAnOptimizer(self):
    opt = gradient_descent.SGD(3.0)
    opt.iterations.assign(5)
    checkpoint = trackable_utils.Checkpoint(optimizer=opt)
    path = checkpoint.save(self.get_temp_dir())

    # The following verifies that `iterations` can be restored in the absence
    # of an `Optimizer` object (using a `Checkpoint` as a placeholder).
    iterations_var = variables.Variable(0, dtype=dtypes.int64)
    optimizer_checkpoint = trackable_utils.Checkpoint(iter=iterations_var)
    checkpoint_to_restore = trackable_utils.Checkpoint(
        optimizer=optimizer_checkpoint)
    checkpoint_to_restore.restore(path)

    self.assertEqual(5, self.evaluate(iterations_var))

  @combinations.generate(combinations.combine(mode=['eager']))
  def testSlotWithNonstandardShapeRestoresBasedOnCheckpoint(self):
    # First create an optimizer and a slot variable with a non-standard shape.
    x = variables.Variable([[1.0, 2.0], [3.0, 4.0]], dtype=dtypes.float32)
    slot_shape = [2, 1]
    optimizer_1 = optimizer_v2.OptimizerV2(name='test')
    optimizer_1.add_slot(x, 'test_slot', 'ones', shape=slot_shape)

    # Then save the variable and optimizer to a checkpoint.
    checkpoint_1 = trackable_utils.Checkpoint(var=x, optimizer=optimizer_1)
    checkpoint_path = checkpoint_1.save(self.get_temp_dir())

    # Create a new optimizer and call restore on it (and x)
    optimizer_2 = optimizer_v2.OptimizerV2(name='test')
    checkpoint_2 = trackable_utils.Checkpoint(var=x, optimizer=optimizer_2)
    checkpoint_2.restore(checkpoint_path)

    self.assertEqual(slot_shape,
                     optimizer_2.get_slot(x, 'test_slot').shape.as_list())

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_gradient_aggregator(self):
    def gradient_aggregator(grads_and_vars):
      # Simulate an all-reduce where the other replica has zeros for gradients,
      # by dividing each gradient by 2.
      grads = [g for g, _ in grads_and_vars]
      vars = [v for _, v in grads_and_vars]  # pylint: disable=redefined-builtin
      all_reduced_grads = [g / 2 for g in grads]
      return list(zip(all_reduced_grads, vars))

    var = variables.Variable(2.0)
    sgd = gradient_descent.SGD(1.0, gradient_aggregator=gradient_aggregator)
    loss = lambda: 2 * var
    opt_op = sgd.minimize(loss, var_list=[var])
    self.evaluate(variables.global_variables_initializer())
    self.evaluate(opt_op)
    self.assertEqual(self.evaluate(var), 1.0)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_override_aggregate_gradients(self):
    class MyOptimizer(gradient_descent.SGD):

      def _aggregate_gradients(self, grads_and_vars):
        # Simulate an all-reduce where the other replica has zeros for
        # gradients, by dividing each gradient by 2.
        grads = [g for g, _ in grads_and_vars]
        vars = [v for _, v in grads_and_vars]  # pylint: disable=redefined-builtin
        all_reduced_grads = [g / 2 for g in grads]
        return list(zip(all_reduced_grads, vars))

    var = variables.Variable(2.0)
    sgd = MyOptimizer(1.0)
    loss = lambda: 2 * var
    opt_op = sgd.minimize(loss, var_list=[var])
    self.evaluate(variables.global_variables_initializer())
    self.evaluate(opt_op)
    self.assertEqual(self.evaluate(var), 1.0)


@keras_parameterized.run_all_keras_modes
class OptimizersCompatibilityTest(keras_parameterized.TestCase):

  def _testOptimizersCompatibility(self, opt_v1, opt_v2, test_weights=True):
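    """Checks that opt_v1 and opt_v2 produce the same weights and losses."""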
    if context.executing_eagerly():
      self.skipTest(
          'v1 optimizer does not run in eager mode')
    np.random.seed(1331)
    with testing_utils.use_gpu():
      train_samples = 20
      input_dim = 3
      num_classes = 2
      (x, y), _ = testing_utils.get_test_data(
          train_samples=train_samples,
          test_samples=10,
          input_shape=(input_dim,),
          num_classes=num_classes)
      y = np_utils.to_categorical(y)

      num_hidden = 5
      model_v1 = testing_utils.get_small_sequential_mlp(
          num_hidden=num_hidden, num_classes=num_classes, input_dim=input_dim)
      model_v1.compile(
          opt_v1,
          loss='categorical_crossentropy',
          metrics=[],
          run_eagerly=testing_utils.should_run_eagerly())
      model_v1.fit(x, y, batch_size=5, epochs=1)

      model_v2 = testing_utils.get_small_sequential_mlp(
          num_hidden=num_hidden, num_classes=num_classes, input_dim=input_dim)
      model_v2.set_weights(model_v1.get_weights())
      model_v2.compile(
          opt_v2,
          loss='categorical_crossentropy',
          metrics=[],
          run_eagerly=testing_utils.should_run_eagerly())
      if not ops.executing_eagerly_outside_functions():
        model_v2._make_train_function()
      if test_weights:
        opt_v2.set_weights(opt_v1.get_weights())

      hist_1 = model_v1.fit(x, y, batch_size=5, epochs=1, shuffle=False)
      hist_2 = model_v2.fit(x, y, batch_size=5, epochs=1, shuffle=False)
      self.assertAllClose(model_v1.get_weights(), model_v2.get_weights(),
                          rtol=1e-5, atol=1e-5)
      self.assertAllClose(hist_1.history['loss'], hist_2.history['loss'],
                          rtol=1e-5, atol=1e-5)

  def testAdadeltaCompatibility(self):
    opt_v1 = optimizer_v1.Adadelta(lr=0.01)
    opt_v2 = adadelta.Adadelta(learning_rate=0.01)
    self._testOptimizersCompatibility(opt_v1, opt_v2)

  def testAdagradCompatibility(self):
    opt_v1 = optimizer_v1.Adagrad(lr=0.01)
    opt_v2 = adagrad.Adagrad(learning_rate=0.01)
    self._testOptimizersCompatibility(opt_v1, opt_v2)

  def testAdamCompatibility(self):
    opt_v1 = optimizer_v1.Adam()
    opt_v2 = adam.Adam()
    self._testOptimizersCompatibility(opt_v1, opt_v2)

  def testAdamaxCompatibility(self):
    opt_v1 = optimizer_v1.Adamax(lr=0.01)
    opt_v2 = adamax.Adamax(learning_rate=0.01)
    self._testOptimizersCompatibility(opt_v1, opt_v2)

  def testNadamCompatibility(self):
    opt_v1 = optimizer_v1.Nadam(lr=0.001)
    opt_v2 = nadam.Nadam(learning_rate=0.001)
    self._testOptimizersCompatibility(opt_v1, opt_v2)

  def testMomentumCompatibility(self):
    opt_v1 = optimizer_v1.SGD(lr=0.01, momentum=0.9)
    opt_v2 = gradient_descent.SGD(learning_rate=0.01, momentum=0.9)
    self._testOptimizersCompatibility(opt_v1, opt_v2)

  def testRMSpropCompatibility(self):
    opt_v1 = optimizer_v1.RMSprop()
    opt_v2 = rmsprop.RMSprop()
    self._testOptimizersCompatibility(opt_v1, opt_v2)

  def testSGDCompatibility(self):
    opt_v1 = optimizer_v1.SGD(lr=0.01)
    opt_v2 = gradient_descent.SGD(learning_rate=0.01)
    self._testOptimizersCompatibility(opt_v1, opt_v2, False)

  def testNumericEquivalenceForNesterovMomentum(self):
    if context.executing_eagerly():
      self.skipTest(
          'v1 optimizer does not run in eager mode')
    np.random.seed(1331)
    with testing_utils.use_gpu():
      train_samples = 20
      input_dim = 3
      num_classes = 2
      (x, y), _ = testing_utils.get_test_data(
          train_samples=train_samples,
          test_samples=10,
          input_shape=(input_dim,),
          num_classes=num_classes)
      y = np_utils.to_categorical(y)

      num_hidden = 5
      model_k_v1 = testing_utils.get_small_sequential_mlp(
          num_hidden=num_hidden, num_classes=num_classes, input_dim=input_dim)
      model_k_v2 = testing_utils.get_small_sequential_mlp(
          num_hidden=num_hidden, num_classes=num_classes, input_dim=input_dim)
      model_k_v2.set_weights(model_k_v1.get_weights())
      model_tf = testing_utils.get_small_sequential_mlp(
          num_hidden=num_hidden, num_classes=num_classes, input_dim=input_dim)
      model_tf.set_weights(model_k_v2.get_weights())

      opt_k_v1 = optimizer_v1.SGD(momentum=0.9, nesterov=True)
      opt_k_v2 = gradient_descent.SGD(momentum=0.9, nesterov=True)
      opt_tf = momentum.MomentumOptimizer(
          learning_rate=0.01, momentum=0.9, use_nesterov=True)

      model_k_v1.compile(
          opt_k_v1,
          loss='categorical_crossentropy',
          metrics=[],
          run_eagerly=testing_utils.should_run_eagerly())
      model_k_v2.compile(
          opt_k_v2,
          loss='categorical_crossentropy',
          metrics=[],
          run_eagerly=testing_utils.should_run_eagerly())
      model_tf.compile(
          opt_tf,
          loss='categorical_crossentropy',
          metrics=[],
          run_eagerly=testing_utils.should_run_eagerly())

      hist_k_v1 = model_k_v1.fit(x, y, batch_size=5, epochs=10, shuffle=False)
      hist_k_v2 = model_k_v2.fit(x, y, batch_size=5, epochs=10, shuffle=False)
      hist_tf = model_tf.fit(x, y, batch_size=5, epochs=10, shuffle=False)

      self.assertAllClose(model_k_v1.get_weights(), model_tf.get_weights())
      self.assertAllClose(model_k_v1.get_weights(), model_k_v2.get_weights())
      self.assertAllClose(opt_k_v1.get_weights(), opt_k_v2.get_weights())
      self.assertAllClose(hist_k_v1.history['loss'], hist_tf.history['loss'])
      self.assertAllClose(hist_k_v1.history['loss'], hist_k_v2.history['loss'])

  def testNumericEquivalenceForAmsgrad(self):
    if context.executing_eagerly():
      self.skipTest(
          'v1 optimizer does not run in eager mode')
    np.random.seed(1331)
    with testing_utils.use_gpu():
      train_samples = 20
      input_dim = 3
      num_classes = 2
      (x, y), _ = testing_utils.get_test_data(
          train_samples=train_samples,
          test_samples=10,
          input_shape=(input_dim,),
          num_classes=num_classes)
      y = np_utils.to_categorical(y)

      num_hidden = 5
      model_k_v1 = testing_utils.get_small_sequential_mlp(
          num_hidden=num_hidden, num_classes=num_classes, input_dim=input_dim)
      model_k_v2 = testing_utils.get_small_sequential_mlp(
          num_hidden=num_hidden, num_classes=num_classes, input_dim=input_dim)
      model_k_v2.set_weights(model_k_v1.get_weights())

      opt_k_v1 = optimizer_v1.Adam(amsgrad=True)
      opt_k_v2 = adam.Adam(amsgrad=True)

      model_k_v1.compile(
          opt_k_v1,
          loss='categorical_crossentropy',
          metrics=[],
          run_eagerly=testing_utils.should_run_eagerly())
      model_k_v2.compile(
          opt_k_v2,
          loss='categorical_crossentropy',
          metrics=[],
          run_eagerly=testing_utils.should_run_eagerly())

      hist_k_v1 = model_k_v1.fit(x, y, batch_size=5, epochs=10, shuffle=False)
      hist_k_v2 = model_k_v2.fit(x, y, batch_size=5, epochs=10, shuffle=False)

      self.assertAllClose(model_k_v1.get_weights(), model_k_v2.get_weights())
      self.assertAllClose(opt_k_v1.get_weights(), opt_k_v2.get_weights())
      self.assertAllClose(hist_k_v1.history['loss'], hist_k_v2.history['loss'])


# Note: These tests are kept in a separate class to avoid bugs in some
# distributions of Python that break AutoGraph, which is used by tf.function.
@combinations.generate(combinations.combine(mode=['eager']))
class OptimizerWithFunctionTest(test.TestCase, parameterized.TestCase):

  def testBasic(self):
    var = variables.Variable([1.0, 2.0], dtype=dtypes.float32)
    loss = lambda: 3 * var
    opt = adam.Adam(learning_rate=1.0)

    @def_function.function
    def fn():
      opt.minimize(loss, [var])
      return var

    self.assertAllClose([0., 1.], fn(), atol=1e-4)
    self.assertAllClose([-1, 0.], fn(), atol=1e-4)

  def testBasicWithConstantDecay(self):
    var = variables.Variable([1.0, 2.0], dtype=dtypes.float32)
    loss = lambda: 3 * var
    opt = adam.Adam(learning_rate=1.0)

    @def_function.function
    def fn():
      opt.minimize(loss, [var])
      return var

    self.assertAllClose([0., 1.], fn(), atol=1e-4)
    self.assertAllClose([-1, 0.], fn(), atol=1e-4)

  def testVarKeyWithVarCreatedInEager(self):
    a = variables.Variable([1., 2.], name='var')
    b = variables.Variable([1.], name='var')

    @test_util.also_run_as_tf_function
    def var_key_test():
      self.assertFalse(a._in_graph_mode)
      self.assertFalse(b._in_graph_mode)
      var_key_a = optimizer_v2._var_key(a)
      self.assertStartsWith(var_key_a, 'var_')
      var_key_b = optimizer_v2._var_key(b)
      self.assertStartsWith(var_key_b, 'var_')
      self.assertNotEqual(var_key_a, var_key_b)

    var_key_test()

  def testLearningRateDecayUsedInTwoFunctions(self):
    a = variables.Variable([1., 2.], name='var')
    b = variables.Variable([1.], name='var')

    learning_rate_decay = learning_rate_schedule.InverseTimeDecay(
        0.5, decay_steps=1.0, decay_rate=0.5)
    opt = adam.Adam(learning_rate=learning_rate_decay)
    loss_a = lambda: 3 * a
    loss_b = lambda: 2 * b

    @def_function.function
    def fn_a():
      opt.minimize(loss_a, [a])
      return a

    @def_function.function
    def fn_b():
      opt.minimize(loss_b, [b])
      return b

    fn_a()
    fn_b()


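# Number of weak learners in the model built by make_model(). The redundancy
# check in identify_redundant_ops only flags op groups that occur at least
# _NUM_LEARNERS times inside APPLY_SCOPE.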
_NUM_LEARNERS = 50
APPLY_SCOPE = 'debug_apply'
ALLOWLIST = [
    # optimizer_v2._deduplicate_indexed_slices contains an indexed slice:
    #   array_ops.shape(unique_indices)[0]
    # which winds up expanding to [0:1:1] thereby creating three constants
    # to represent the indices.
    ('embeddings/strided_slice/stack', 'Const'),
]


def get_inputs(op):
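  """Returns an op's data and control inputs as ops, along with their names."""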
  op_inputs = list(op.inputs) + op.control_inputs
  names = [i.name for i in op_inputs]
  op_inputs = [getattr(i, 'op', i) for i in op_inputs]
  return op_inputs, names


def strip_name(node):
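  """Clears a node's name so structurally identical nodes compare equal."""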
  if 'Placeholder' in node.op:
    return
  node.name = ''


def topological_sort(graph):
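  """Topologically sorts the ops in `graph` and checks the sort is valid."""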
  graph_ops = graph.get_operations()

  sources = []
  result = []

  inputs = {}
  outputs = collections.defaultdict(set)
  for op in graph_ops:
    op_inputs = get_inputs(op)[0]
    if not op_inputs:
      sources.append(op)

    inputs[op] = set(op_inputs)
    for i in op_inputs:
      outputs[i].add(op)

  while sources:
    op = sources.pop()
    for op_output in outputs[op]:
      inputs[op_output].remove(op)
      if not inputs[op_output]:
        sources.append(op_output)

    result.append(op)

  # Check correctness.
  if len(result) != len(graph_ops):
    raise ValueError('Sort result has {} ops, source graph has {}.'
                     .format(len(result), len(graph_ops)))

  sort_check_seen = set()
  for op in result:
    sort_check_seen.add(op)
    for i in get_inputs(op)[0]:
      assert i in sort_check_seen

  return result


def identify_redundant_ops(graph):
  """Implements basic common subexpression elimination.

  This is not intended to replicate the graph semantics of TensorFlow Graphs
  (for instance it does not handle stateful op ordering), nor is it intended to
  replace the common subexpression elimination Grappler pass. Rather, it
  provides a high level sanity check that clearly redundant ops are not being
  created.

  Args:
    graph: The graph to be analyzed.

  Returns:
    A count of the duplicate ops and a description of the structure of each.
  """
  sorted_ops = topological_sort(graph)
  duplicates = collections.defaultdict(list)
  unified_node_defs = {}
  name_map = {}

  for op in sorted_ops:
    input_names = []
    for op_input, name in zip(*get_inputs(op)):
      input_def = op_input.node_def

      # Operations can have multiple outputs. We track which is used to prevent
      # overzealous elimination.
      input_def.name = name

      input_def.input[:] = [name_map.get(i, i) for i in input_def.input]
      strip_name(input_def)

      # NodeDef.SerializeToString() does not provide identical serialized
      # representations for identical NodeDefs, so we instead use string
      # representation as a dict key.
      key = repr(input_def)

      if key in unified_node_defs:
        input_names.append(unified_node_defs[key])

      else:
        unified_node_defs[key] = op_input.name
        input_names.append(name)

    node_def = op.node_def
    node_def.input[:] = input_names
    strip_name(node_def)

    key = repr(node_def)
    duplicates[key].append(op)
    name_map[op.name] = duplicates[key][0].name

  num_duplicates = 0
  duplicate_types = []
  for standard_def, op_defs in duplicates.items():
    # We are only interested in testing the apply method of the optimizer
    op_defs = [i for i in op_defs if APPLY_SCOPE in i.name]

    # We only check for per-apply redundant ops.
    if len(op_defs) < _NUM_LEARNERS:
      continue

    # Certain ops are simply not worth eliminating, and are instead simply
    # ignored.
    name, op_type = op_defs[0].name, op_defs[0].type
    if any(allowlisted_scope in name and op_type == allowlisted_type
           for allowlisted_scope, allowlisted_type in ALLOWLIST):
      continue

    num_duplicates += len(op_defs)
    traceback = []
    for level in op_defs[0].traceback:
      traceback.append('  {} {}:{}'.format(level[0], level[2], level[1]))

    duplicate_types.append(
        '# Example name: {}\n# Op creation stack:\n{}\n{}'.format(
            op_defs[0].name,
            '\n'.join(traceback),
            standard_def))

  return num_duplicates, duplicate_types


def make_model():
  r"""Constructs a simple ensemble of weak learners model.

  ---------    ---------             ---------    ---------
  | Input |    | Input |     ...     | Input |    | Input |
  ---------    ---------             ---------    ---------
      |            |                     |            |
      V            V                     V            V
  ---------    ---------             ---------    ---------
  | Embed |    | Embed |     ...     | Embed |    | Embed |
  ---------    ---------             ---------    ---------
      |            |                     |            |
      V            V                     V            V
  ---------    ---------             ---------    ---------
  | Dense |    | Dense |     ...     | Dense |    | Dense |
  ---------    ---------             ---------    ---------
      \            |                     |            /
       \           |                     |           /
        ---------------------------------------------
                              |
                          ---------
                          | Dense |
                          ---------

  This topology is chosen because it exercises both dense and sparse update
  paths.

  Returns:
    A model for testing optimizer coefficient reuse.
  """
  inputs = []
  intermediates = []
  for _ in range(_NUM_LEARNERS):
    inp = keras.layers.Input(shape=(1,), dtype=dtypes.int32)
    layer = keras.layers.Embedding(1, 4)(inp)
    layer = keras.layers.Dense(1)(layer)

    inputs.append(inp)
    intermediates.append(layer)

  layer = keras.layers.Concatenate(axis=-1)(intermediates)
  layer = keras.layers.Dense(1)(layer)

  return keras.models.Model(inputs, layer)


COEFFICIENT_PARAMS = (
    ('Adadelta', adadelta.Adadelta, None),
    ('Adagrad', adagrad.Adagrad, None),
    ('Adam', adam.Adam, None),
    ('Adam_amsgrad', adam.Adam, dict(amsgrad=True)),
    ('Adamax', adamax.Adamax, None),
    ('Ftrl', ftrl.Ftrl, None),
    ('Ftrl_l2_shrinkage', ftrl.Ftrl,
     dict(l2_shrinkage_regularization_strength=0.1)),
    ('SGD', gradient_descent.SGD, None),
    ('SGD_momentum', gradient_descent.SGD, dict(momentum=0.5)),
    ('Nadam', nadam.Nadam, None),
    ('RMSprop', rmsprop.RMSprop, None),
    ('RMSprop_centered', rmsprop.RMSprop, dict(centered=True)),
    ('RMSprop_momentum', rmsprop.RMSprop, dict(momentum=0.5)),
    ('RMSprop_momentum_centered', rmsprop.RMSprop,
     dict(momentum=0.5, centered=True)),
)


class OptimizerCoefficientTest(keras_parameterized.TestCase):

  @parameterized.named_parameters(*COEFFICIENT_PARAMS)
  def test_duplicate_ops(self, optimizer_class, init_kwargs=None):
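    """Checks that apply_gradients does not create redundant ops."""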
    init_kwargs = init_kwargs or {}
    optimizer = optimizer_class(**init_kwargs)

    graph = ops.Graph()
    with graph.as_default():
      model = make_model()
      trainable_variables = model.trainable_variables
      grads = optimizer.get_gradients(model.outputs[0], trainable_variables)

      with backend.name_scope(APPLY_SCOPE):
        optimizer.apply_gradients(zip(grads, trainable_variables))

    num_duplicates, duplicate_types = identify_redundant_ops(graph)
    if num_duplicates:
      # Avoid spamming logs.
      if len(duplicate_types) > 3:
        duplicate_types = duplicate_types[:3] + ['...']

      num_total = len(graph.get_operations())
      raise ValueError('{} of {} ({:.1f}%) ops were duplicates:\n\n{}'.format(
          num_duplicates, num_total, num_duplicates / num_total * 100,
          '\n'.join(duplicate_types)))

  @parameterized.named_parameters(*COEFFICIENT_PARAMS)
  def test_subclass_compat(self, optimizer_class, init_kwargs=None):
    """Ensure that subclassed optimizers without apply_state still work."""

    class SubclassedOptimizer(optimizer_class):

      def _resource_apply_dense(self, grad, var):  # pylint: disable=useless-super-delegation
        return super(SubclassedOptimizer, self)._resource_apply_dense(grad, var)

      def _resource_apply_sparse(self, grad, var, indices):  # pylint: disable=useless-super-delegation
        return super(SubclassedOptimizer, self)._resource_apply_sparse(
            grad, var, indices)

    init_kwargs = init_kwargs or {}
    optimizer = SubclassedOptimizer(**init_kwargs)

    graph = ops.Graph()
    with graph.as_default():
      model = make_model()
      trainable_variables = model.trainable_variables
      grads = optimizer.get_gradients(model.outputs[0], trainable_variables)

      with backend.name_scope(APPLY_SCOPE):
        optimizer.apply_gradients(zip(grads, trainable_variables))


if __name__ == '__main__':
  test.main()