# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Keras callbacks."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import csv
import json
import os
import re
import shutil
import sys
import threading
import time
import unittest

from absl.testing import parameterized
import numpy as np

from tensorflow.core.framework import summary_pb2
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import random_seed
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import sequential
from tensorflow.python.keras.optimizer_v2 import gradient_descent
from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
from tensorflow.python.keras.utils import np_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary_iterator
from tensorflow.python.training import adam
from tensorflow.python.training import checkpoint_management

try:
  import h5py  # pylint:disable=g-import-not-at-top
except ImportError:
  h5py = None

try:
  import requests  # pylint:disable=g-import-not-at-top
except ImportError:
  requests = None


TRAIN_SAMPLES = 10
TEST_SAMPLES = 10
NUM_CLASSES = 2
INPUT_DIM = 3
NUM_HIDDEN = 5
BATCH_SIZE = 5


class Counter(keras.callbacks.Callback):
  """Counts the number of times each callback method was run.

  Attributes:
    method_counts: dict. Contains the number of times each callback method
      was run.
  """

  def __init__(self):
    self.method_counts = collections.defaultdict(int)
    methods_to_count = [
        'on_batch_begin', 'on_batch_end', 'on_epoch_begin', 'on_epoch_end',
        'on_predict_batch_begin', 'on_predict_batch_end', 'on_predict_begin',
        'on_predict_end', 'on_test_batch_begin', 'on_test_batch_end',
        'on_test_begin', 'on_test_end', 'on_train_batch_begin',
        'on_train_batch_end', 'on_train_begin', 'on_train_end'
    ]
    for method_name in methods_to_count:
      setattr(self, method_name,
              self.wrap_with_counts(method_name, getattr(self, method_name)))

  def wrap_with_counts(self, method_name, method):
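    """Returns `method` wrapped to increment `method_counts[method_name]` on each call."""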

    def _call_and_count(*args, **kwargs):
      self.method_counts[method_name] += 1
      return method(*args, **kwargs)

    return _call_and_count


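# Test-data helpers used to parameterize the callback-count tests below.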
def _get_numpy():
  return np.ones((10, 10)), np.ones((10, 1))


def _get_sequence():

  class MySequence(keras.utils.data_utils.Sequence):

    def __getitem__(self, _):
      return np.ones((2, 10)), np.ones((2, 1))

    def __len__(self):
      return 5

  return MySequence(), None


@keras_parameterized.run_with_all_model_types
@keras_parameterized.run_all_keras_modes
class CallbackCountsTest(keras_parameterized.TestCase):

  def _check_counts(self, counter, expected_counts):
    """Checks that the counts registered by `counter` are those expected."""
    for method_name, expected_count in expected_counts.items():
      self.assertEqual(
          counter.method_counts[method_name],
          expected_count,
          msg='For method {}: expected {}, got: {}'.format(
              method_name, expected_count, counter.method_counts[method_name]))

  def _get_model(self):
    layers = [
        keras.layers.Dense(10, activation='relu'),
        keras.layers.Dense(1, activation='sigmoid')
    ]
    model = testing_utils.get_model_from_layers(layers, input_shape=(10,))
    model.compile(
        adam.AdamOptimizer(0.001),
        'binary_crossentropy',
        run_eagerly=testing_utils.should_run_eagerly(),
        experimental_run_tf_function=testing_utils.should_run_tf_function())
    return model

  @parameterized.named_parameters(('with_numpy', _get_numpy()),
                                  ('with_sequence', _get_sequence()))
  def test_callback_hooks_are_called_in_fit(self, data):
    x, y = data
    val_x, val_y = np.ones((4, 10)), np.ones((4, 1))
    is_sequence = isinstance(x, keras.utils.data_utils.Sequence)

    model = self._get_model()
    counter = Counter()
    model.fit(
        x,
        y,
        validation_data=(val_x, val_y),
        batch_size=2 if not is_sequence else None,
        steps_per_epoch=5 if is_sequence else None,
        epochs=5,
        callbacks=[counter])

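    # 5 train batches per epoch (10 samples / batch_size 2, or 5 sequence
    # steps) over 5 epochs gives the 25 `on_(train_)batch_*` calls below.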
    self._check_counts(
        counter, {
            'on_batch_begin': 25,
            'on_batch_end': 25,
            'on_epoch_begin': 5,
            'on_epoch_end': 5,
            'on_predict_batch_begin': 0,
            'on_predict_batch_end': 0,
            'on_predict_begin': 0,
            'on_predict_end': 0,
            'on_test_batch_begin': 10,
            'on_test_batch_end': 10,
            'on_test_begin': 5,
            'on_test_end': 5,
            'on_train_batch_begin': 25,
            'on_train_batch_end': 25,
            'on_train_begin': 1,
            'on_train_end': 1
        })

  @parameterized.named_parameters(('with_numpy', _get_numpy()),
                                  ('with_sequence', _get_sequence()))
  def test_callback_hooks_are_called_in_evaluate(self, data):
    x, y = data
    is_sequence = isinstance(x, keras.utils.data_utils.Sequence)

    model = self._get_model()
    counter = Counter()
    model.evaluate(
        x,
        y,
        batch_size=2 if not is_sequence else None,
        steps=5 if is_sequence else None,
        callbacks=[counter])
    self._check_counts(
        counter, {
            'on_test_batch_begin': 5,
            'on_test_batch_end': 5,
            'on_test_begin': 1,
            'on_test_end': 1
        })

  @parameterized.named_parameters(('with_numpy', _get_numpy()),
                                  ('with_sequence', _get_sequence()))
  def test_callback_hooks_are_called_in_predict(self, data):
    x = data[0]
    is_sequence = isinstance(x, keras.utils.data_utils.Sequence)

    model = self._get_model()
    counter = Counter()
    model.predict(
        x,
        batch_size=2 if not is_sequence else None,
        steps=5 if is_sequence else None,
        callbacks=[counter])
    self._check_counts(
        counter, {
            'on_predict_batch_begin': 5,
            'on_predict_batch_end': 5,
            'on_predict_begin': 1,
            'on_predict_end': 1
        })

  def test_callback_list_methods(self):
    counter = Counter()
    callback_list = keras.callbacks.CallbackList([counter])

    batch = 0
    callback_list.on_test_batch_begin(batch)
    callback_list.on_test_batch_end(batch)
    callback_list.on_predict_batch_begin(batch)
    callback_list.on_predict_batch_end(batch)

    self._check_counts(
        counter, {
            'on_test_batch_begin': 1,
            'on_test_batch_end': 1,
            'on_predict_batch_begin': 1,
            'on_predict_batch_end': 1
        })


class KerasCallbacksTest(keras_parameterized.TestCase):

  def _get_model(self, input_shape=None):
    layers = [
        keras.layers.Dense(3, activation='relu'),
        keras.layers.Dense(2, activation='softmax')
    ]
    model = testing_utils.get_model_from_layers(layers, input_shape=input_shape)
    model.compile(
        loss='mse',
        optimizer='rmsprop',
        metrics=[keras.metrics.CategoricalAccuracy(name='my_acc')],
        run_eagerly=testing_utils.should_run_eagerly(),
        experimental_run_tf_function=testing_utils.should_run_tf_function())
    return model

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  def test_progbar_logging(self):
    model = self._get_model(input_shape=(3,))

    x = array_ops.ones((50, 3))
    y = array_ops.zeros((50, 2))
    dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).batch(10)
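    # The progress bar should report both the loss and the custom metric.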
    expected_log = r'(.*- loss:.*- my_acc:.*)+'

    with self.captureWritesToStream(sys.stdout) as printed:
      model.fit(dataset, epochs=2, steps_per_epoch=10)
      self.assertRegexpMatches(printed.contents(), expected_log)

  @keras_parameterized.run_with_all_model_types(exclude_models='functional')
  @keras_parameterized.run_all_keras_modes
  def test_progbar_logging_deferred_model_build(self):
    model = self._get_model()
    self.assertFalse(model.built)

    x = array_ops.ones((50, 3))
    y = array_ops.zeros((50, 2))
    dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).batch(10)
    expected_log = r'(.*- loss:.*- my_acc:.*)+'

    with self.captureWritesToStream(sys.stdout) as printed:
      model.fit(dataset, epochs=2, steps_per_epoch=10)
      self.assertRegexpMatches(printed.contents(), expected_log)

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  def test_progbar_logging_validation_data(self):
    model = self._get_model(input_shape=(3,))

    x = array_ops.ones((50, 3))
    y = array_ops.zeros((50, 2))
    training_dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).batch(10)
    val_dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).batch(10)
    expected_log = r'(.*5/5.*- loss:.*- my_acc:.*- val_loss:.*- val_my_acc:.*)+'

    with self.captureWritesToStream(sys.stdout) as printed:
      model.fit(training_dataset, epochs=2, validation_data=val_dataset)
      self.assertRegexpMatches(printed.contents(), expected_log)

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  def test_progbar_logging_validation_split(self):
    model = self._get_model(input_shape=(3,))

    x = np.ones((100, 3))
    y = np.zeros((100, 2))
    expected_log = (
        r'(?s).*1/2.*80/80.*- loss:.*- my_acc:.*- val_loss:.*- val_my_acc:'
        r'.*2/2.*80/80.*- loss:.*- my_acc:.*- val_loss:.*- val_my_acc:.*')

    with self.captureWritesToStream(sys.stdout) as printed:
      model.fit(x, y, batch_size=10, epochs=2, validation_split=0.2)
      self.assertRegexpMatches(printed.contents(), expected_log)

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_progbar_logging_training_validation(self):
    model = self._get_model(input_shape=(2,))

    def generator():
      for _ in range(100):
        yield [1, 1], 1

    training = dataset_ops.Dataset \
        .from_generator(
            generator=generator,
            output_types=('float64', 'float64'),
            output_shapes=([2], [])) \
        .batch(2) \
        .repeat()
    validation = dataset_ops.Dataset \
        .from_generator(
            generator=generator,
            output_types=('float64', 'float64'),
            output_shapes=([2], [])) \
        .batch(2)
    expected_log = (
        r'(?s).*1/2.*20/20.*- loss:.*- my_acc:.*- val_loss:.*- val_my_acc:'
        r'.*2/2.*20/20.*- loss:.*- my_acc:.*- val_loss:.*- val_my_acc:.*')

    with self.captureWritesToStream(sys.stdout) as printed:
      model.fit(
          x=training, validation_data=validation, epochs=2, steps_per_epoch=20)
      self.assertRegexpMatches(printed.contents(), expected_log)

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_progbar_logging_with_dataset_and_partial_batch(self):
    model = self._get_model(input_shape=(2,))

    def generator():
      # Have a partial batch at the end.
      for _ in range(9):
        yield np.random.random(2), 1

    training = dataset_ops.Dataset \
      .from_generator(
          generator=generator,
          output_types=('float64', 'float64'),
          output_shapes=([2], [])) \
      .batch(2)
    validation = dataset_ops.Dataset \
      .from_generator(
          generator=generator,
          output_types=('float64', 'float64'),
          output_shapes=([2], [])) \
      .batch(2)

    with self.captureWritesToStream(sys.stdout) as printed:
      model.fit(x=training, validation_data=validation)

      # Make sure the values of the val_ metrics are not zero.
      log_content = printed.contents()
      val_loss = re.findall(r'val_loss: (\d\.\d+)', log_content)
      self.assertLen(val_loss, 1)
      self.assertGreater(float(val_loss[0]), 0.0)

  @keras_parameterized.run_with_all_model_types
  def test_ModelCheckpoint(self):
    if h5py is None:
      return  # Skip test if models cannot be saved.

    layers = [
        keras.layers.Dense(NUM_HIDDEN, input_dim=INPUT_DIM, activation='relu'),
        keras.layers.Dense(NUM_CLASSES, activation='softmax')
    ]
    model = testing_utils.get_model_from_layers(layers, input_shape=(10,))
    model.compile(
        loss='categorical_crossentropy', optimizer='rmsprop', metrics=['acc'])

    temp_dir = self.get_temp_dir()
    self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)

    filepath = os.path.join(temp_dir, 'checkpoint.h5')
    (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
        train_samples=TRAIN_SAMPLES,
        test_samples=TEST_SAMPLES,
        input_shape=(INPUT_DIM,),
        num_classes=NUM_CLASSES)
    y_test = np_utils.to_categorical(y_test)
    y_train = np_utils.to_categorical(y_train)
    # Case 1: monitor 'val_loss' in 'auto' mode, saving every epoch.
    monitor = 'val_loss'
    save_best_only = False
    mode = 'auto'

    model = keras.models.Sequential()
    model.add(
        keras.layers.Dense(
            NUM_HIDDEN, input_dim=INPUT_DIM, activation='relu'))
    model.add(keras.layers.Dense(NUM_CLASSES, activation='softmax'))
    model.compile(
        loss='categorical_crossentropy', optimizer='rmsprop', metrics=['acc'])

    cbks = [
        keras.callbacks.ModelCheckpoint(
            filepath,
            monitor=monitor,
            save_best_only=save_best_only,
            mode=mode)
    ]
    model.fit(
        x_train,
        y_train,
        batch_size=BATCH_SIZE,
        validation_data=(x_test, y_test),
        callbacks=cbks,
        epochs=1,
        verbose=0)
    assert os.path.exists(filepath)
    os.remove(filepath)

    # Case 2: mode='min'.
    mode = 'min'
    cbks = [
        keras.callbacks.ModelCheckpoint(
            filepath,
            monitor=monitor,
            save_best_only=save_best_only,
            mode=mode)
    ]
    model.fit(
        x_train,
        y_train,
        batch_size=BATCH_SIZE,
        validation_data=(x_test, y_test),
        callbacks=cbks,
        epochs=1,
        verbose=0)
    assert os.path.exists(filepath)
    os.remove(filepath)

    # Case 3: mode='max', monitoring 'val_acc'.
    mode = 'max'
    monitor = 'val_acc'
    cbks = [
        keras.callbacks.ModelCheckpoint(
            filepath,
            monitor=monitor,
            save_best_only=save_best_only,
            mode=mode)
    ]
    model.fit(
        x_train,
        y_train,
        batch_size=BATCH_SIZE,
        validation_data=(x_test, y_test),
        callbacks=cbks,
        epochs=1,
        verbose=0)
    assert os.path.exists(filepath)
    os.remove(filepath)

    # Case 4: save_best_only=True.
    save_best_only = True
    cbks = [
        keras.callbacks.ModelCheckpoint(
            filepath,
            monitor=monitor,
            save_best_only=save_best_only,
            mode=mode)
    ]
    model.fit(
        x_train,
        y_train,
        batch_size=BATCH_SIZE,
        validation_data=(x_test, y_test),
        callbacks=cbks,
        epochs=1,
        verbose=0)
    assert os.path.exists(filepath)
    os.remove(filepath)

    # Case: metric not available.
    cbks = [
        keras.callbacks.ModelCheckpoint(
            filepath,
            monitor='unknown',
            save_best_only=True)
    ]
    model.fit(
        x_train,
        y_train,
        batch_size=BATCH_SIZE,
        validation_data=(x_test, y_test),
        callbacks=cbks,
        epochs=1,
        verbose=0)
    # File won't be written.
    assert not os.path.exists(filepath)

    # Case 5: checkpoint every `period` epochs.
    save_best_only = False
    period = 2
    mode = 'auto'

    filepath = os.path.join(temp_dir, 'checkpoint.{epoch:02d}.h5')
    cbks = [
        keras.callbacks.ModelCheckpoint(
            filepath,
            monitor=monitor,
            save_best_only=save_best_only,
            mode=mode,
            period=period)
    ]
    model.fit(
        x_train,
        y_train,
        batch_size=BATCH_SIZE,
        validation_data=(x_test, y_test),
        callbacks=cbks,
        epochs=4,
        verbose=1)
    assert os.path.exists(filepath.format(epoch=2))
    assert os.path.exists(filepath.format(epoch=4))
    os.remove(filepath.format(epoch=2))
    os.remove(filepath.format(epoch=4))
    assert not os.path.exists(filepath.format(epoch=1))
    assert not os.path.exists(filepath.format(epoch=3))

    # Invalid use: this will raise a warning but not an Exception.
    keras.callbacks.ModelCheckpoint(
        filepath,
        monitor=monitor,
        save_best_only=save_best_only,
        mode='unknown')

    # Case 6: `ModelCheckpoint` with a combination of `save_freq` and `period`.
    # Though `period` is deprecated, we're testing it for
    # backward-compatibility.
    filepath = os.path.join(temp_dir, 'checkpoint.epoch{epoch:02d}.h5')
    cbks = [
        keras.callbacks.ModelCheckpoint(
            filepath, monitor=monitor, mode=mode, save_freq='epoch', period=5)
    ]
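    # With period=5 over 10 epochs, checkpoints should be written only at
    # epochs 5 and 10.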
    assert not os.path.exists(filepath.format(epoch=0))
    assert not os.path.exists(filepath.format(epoch=5))
    model.fit(
        x_train,
        y_train,
        batch_size=2,
        validation_data=(x_test, y_test),
        callbacks=cbks,
        epochs=10,
        verbose=1)
    assert not os.path.exists(filepath.format(epoch=1))
    assert not os.path.exists(filepath.format(epoch=2))
    assert not os.path.exists(filepath.format(epoch=3))
    assert not os.path.exists(filepath.format(epoch=4))
    assert os.path.exists(filepath.format(epoch=5))
    assert not os.path.exists(filepath.format(epoch=6))
    assert os.path.exists(filepath.format(epoch=10))
    os.remove(filepath.format(epoch=5))
    os.remove(filepath.format(epoch=10))

    # Case 7: `ModelCheckpoint` with an integer `save_freq`
    filepath = os.path.join(temp_dir, 'checkpoint.epoch{epoch:02d}.h5')
    cbks = [
        keras.callbacks.ModelCheckpoint(
            filepath,
            monitor=monitor,
            save_best_only=save_best_only,
            mode=mode,
            save_freq=30,
            period=100)  # `period` should be ignored; this test verifies that.
    ]
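    # With 10 training samples per epoch, save_freq=30 works out to a
    # checkpoint every 3 epochs (epochs 3, 6, and 9, per the assertions
    # below).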
    assert not os.path.exists(filepath.format(epoch=3))
    model.fit(
        x_train,
        y_train,
        batch_size=2,
        validation_data=(x_test, y_test),
        callbacks=cbks,
        epochs=10,
        verbose=1)
    assert not os.path.exists(filepath.format(epoch=1))
    assert not os.path.exists(filepath.format(epoch=2))
    assert os.path.exists(filepath.format(epoch=3))
    assert not os.path.exists(filepath.format(epoch=4))
    assert not os.path.exists(filepath.format(epoch=5))
    assert os.path.exists(filepath.format(epoch=6))
    assert not os.path.exists(filepath.format(epoch=7))
    assert not os.path.exists(filepath.format(epoch=8))
    assert os.path.exists(filepath.format(epoch=9))
    os.remove(filepath.format(epoch=3))
    os.remove(filepath.format(epoch=6))
    os.remove(filepath.format(epoch=9))

    # Case 8: `ModelCheckpoint` with valid and invalid save_freq argument.
    with self.assertRaisesRegexp(ValueError, 'Unrecognized save_freq'):
      keras.callbacks.ModelCheckpoint(
          filepath,
          monitor=monitor,
          save_best_only=save_best_only,
          mode=mode,
          save_freq='invalid_save_freq')
    # The following should not raise ValueError.
    keras.callbacks.ModelCheckpoint(
        filepath,
        monitor=monitor,
        save_best_only=save_best_only,
        mode=mode,
        save_freq='epoch')
    keras.callbacks.ModelCheckpoint(
        filepath,
        monitor=monitor,
        save_best_only=save_best_only,
        mode=mode,
        save_freq=3)

  def _get_dummy_resource_for_model_checkpoint_testing(self):

    def get_input_datasets():
      # Simple training input.
      train_input = [[1]] * 16
      train_label = [[0]] * 16
      ds = dataset_ops.Dataset.from_tensor_slices((train_input, train_label))
      return ds.batch(8, drop_remainder=True)

    # Very simple bias model to eliminate randomness.
    optimizer = gradient_descent.SGD(0.1)
    model = sequential.Sequential()
    model.add(testing_utils.Bias(input_shape=(1,)))
    model.compile(loss='mae', optimizer=optimizer, metrics=['mae'])
    train_ds = get_input_datasets()

    temp_dir = self.get_temp_dir()
    filepath = os.path.join(temp_dir, 'checkpoint.epoch{epoch:02d}.h5')

    # The filepath shouldn't exist at the beginning.
    self.assertFalse(os.path.exists(filepath))
    callback = keras.callbacks.ModelCheckpoint(
        filepath=filepath, save_weights_only=True)

    return model, train_ds, callback, filepath

  def _run_load_weights_on_restart_test_common_iterations(self):

    (model, train_ds, callback,
     filepath) = self._get_dummy_resource_for_model_checkpoint_testing()
    initial_epochs = 3
    model.fit(train_ds, epochs=initial_epochs, callbacks=[callback])

    # The checkpoint files should exist after fitting with the callback.
    for epoch in range(initial_epochs):
      self.assertTrue(os.path.exists(filepath.format(epoch=epoch + 1)))
    self.assertFalse(os.path.exists(filepath.format(epoch=initial_epochs + 1)))
    self.assertEqual(
        callback._get_most_recently_modified_file_matching_pattern(filepath),
        filepath.format(epoch=initial_epochs))

    model.fit(train_ds, epochs=1)
    weights_after_one_more_epoch = model.get_weights()

    # The checkpoint files should continue to exist after fitting without
    # the callback.
    for epoch in range(initial_epochs):
      self.assertTrue(os.path.exists(filepath.format(epoch=epoch + 1)))

    return model, train_ds, filepath, weights_after_one_more_epoch

  @staticmethod
  def get_ModelCheckpoint_load_weights_on_restart_true_test(save_weights_only):

    def func(self):
      (model, train_ds, filepath, weights_after_one_more_epoch
      ) = self._run_load_weights_on_restart_test_common_iterations()

      # Sleep briefly to ensure the files are created with different
      # modification times (on macOS in OSS the granularity is only
      # 1 second).
      time.sleep(2)
      callback = keras.callbacks.ModelCheckpoint(
          filepath=filepath,
          save_weights_only=save_weights_only,
          load_weights_on_restart=True)
      model.fit(train_ds, epochs=1, callbacks=[callback])
      weights_after_model_restoring_and_one_more_epoch = model.get_weights()

      self.assertEqual(
          callback._get_most_recently_modified_file_matching_pattern(filepath),
          filepath.format(epoch=1))

      model.fit(
          train_ds,
          epochs=1,
          callbacks=[
              keras.callbacks.ModelCheckpoint(
                  filepath=filepath,
                  save_weights_only=save_weights_only,
                  load_weights_on_restart=True)
          ])
      weights_with_one_final_extra_epoch = model.get_weights()

      # Assert that the weights one epoch after the initial fitting and one
      # epoch after that are close, since a ModelCheckpoint with
      # load_weights_on_restart=True restores the model at the beginning of
      # training.
      self.assertAllClose(weights_after_one_more_epoch,
                          weights_after_model_restoring_and_one_more_epoch)

      self.assertNotAllClose(weights_after_one_more_epoch,
                             weights_with_one_final_extra_epoch)

    return func

  @staticmethod
  def get_ModelCheckpoint_load_weights_on_restart_false_test(save_weights_only):

    def func(self):
      (model, train_ds, filepath, weights_after_one_more_epoch
      ) = self._run_load_weights_on_restart_test_common_iterations()

      model.fit(
          train_ds,
          epochs=1,
          callbacks=[
              keras.callbacks.ModelCheckpoint(
                  filepath=filepath, save_weights_only=save_weights_only)
          ])
      weights_after_model_restoring_and_one_more_epoch = model.get_weights()

      # Assert that the weights one epoch after the initial fitting and one
      # epoch after that are different, since a ModelCheckpoint with
      # load_weights_on_restart=False does not restore the model at the
      # beginning of training.
      self.assertNotAllClose(weights_after_one_more_epoch,
                             weights_after_model_restoring_and_one_more_epoch)

    return func

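  # Materialize the four (load_weights_on_restart x save_weights_only)
  # combinations as concrete test methods. `__func__` unwraps the
  # staticmethod so the factory can be called at class-definition time.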
  test_model_checkpoint_load_weights_on_restart_true_save_weights_only_true = \
        get_ModelCheckpoint_load_weights_on_restart_true_test.__func__(True)

  test_model_checkpoint_load_weights_on_restart_true_save_weights_only_false = \
        get_ModelCheckpoint_load_weights_on_restart_true_test.__func__(False)

  test_model_checkpoint_load_weights_on_restart_false_save_weights_only_true = \
        get_ModelCheckpoint_load_weights_on_restart_false_test.__func__(True)

  test_model_checkpoint_load_weights_on_restart_false_save_weights_only_false \
        = get_ModelCheckpoint_load_weights_on_restart_false_test.__func__(False)

  def test_ModelCheckpoint_override_if_file_exist(self):
    (model, train_ds, filepath,
     _) = self._run_load_weights_on_restart_test_common_iterations()

    # Sleep briefly to ensure the files are created with different
    # modification times (on macOS in OSS the granularity is only 1 second).
    time.sleep(2)
    callback = keras.callbacks.ModelCheckpoint(
        filepath=filepath, save_weights_only=True)
    model.load_weights(
        callback._get_most_recently_modified_file_matching_pattern(filepath))
    weights_before_additional_fit = model.get_weights()
    model.fit(train_ds, epochs=1, callbacks=[callback])
    model.load_weights(
        callback._get_most_recently_modified_file_matching_pattern(filepath))
    weights_after_additional_fit = model.get_weights()

    self.assertNotAllClose(weights_before_additional_fit,
                           weights_after_additional_fit)

  def test_fit_with_ModelCheckpoint_with_tf_config(self):
    (model, train_ds, callback,
     _) = self._get_dummy_resource_for_model_checkpoint_testing()

    os.environ['TF_CONFIG'] = json.dumps({
        'cluster': {
            'worker': ['localhost:23333']
        },
        'task': {
            'type': 'worker',
            'index': 0
        }
    })

    # `model.fit()` should work regardless of the presence of `TF_CONFIG`.
    model.fit(train_ds, epochs=1, callbacks=[callback])

  def test_fit_with_ModelCheckpoint_with_dir_as_h5_filepath(self):
    (model, train_ds, callback,
     filepath) = self._get_dummy_resource_for_model_checkpoint_testing()

    temp_dir = self.get_temp_dir()
    filepath = os.path.join(temp_dir, 'temp.h5')

    self.assertFalse(os.path.exists(filepath))
    os.mkdir(filepath)
    self.assertTrue(os.path.exists(filepath))

    callback = keras.callbacks.ModelCheckpoint(filepath=filepath)

    with self.assertRaisesRegexp(IOError, 'Please specify a non-directory '
                                          'filepath for ModelCheckpoint.'):
      model.fit(train_ds, epochs=1, callbacks=[callback])

  def test_ModelCheckpoint_with_bad_path_placeholders(self):
    (model, train_ds, callback,
     filepath) = self._get_dummy_resource_for_model_checkpoint_testing()

    temp_dir = self.get_temp_dir()
    filepath = os.path.join(temp_dir, 'chkpt_{epoch:02d}_{mape:.2f}.h5')
    callback = keras.callbacks.ModelCheckpoint(filepath=filepath)

    with self.assertRaisesRegexp(KeyError, 'Failed to format this callback '
                                           'filepath.*'):
      model.fit(train_ds, epochs=1, callbacks=[callback])

  def test_EarlyStopping(self):
    with self.cached_session():
      np.random.seed(123)
      (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
          train_samples=TRAIN_SAMPLES,
          test_samples=TEST_SAMPLES,
          input_shape=(INPUT_DIM,),
          num_classes=NUM_CLASSES)
      y_test = np_utils.to_categorical(y_test)
      y_train = np_utils.to_categorical(y_train)
      model = testing_utils.get_small_sequential_mlp(
          num_hidden=NUM_HIDDEN, num_classes=NUM_CLASSES, input_dim=INPUT_DIM)
      model.compile(
          loss='categorical_crossentropy', optimizer='rmsprop', metrics=['acc'])

      cases = [
          ('max', 'val_acc'),
          ('min', 'val_loss'),
          ('auto', 'val_acc'),
          ('auto', 'loss'),
          ('unknown', 'unknown')
      ]
      for mode, monitor in cases:
        patience = 0
        cbks = [
            keras.callbacks.EarlyStopping(
                patience=patience, monitor=monitor, mode=mode)
        ]
        model.fit(
            x_train,
            y_train,
            batch_size=BATCH_SIZE,
            validation_data=(x_test, y_test),
            callbacks=cbks,
            epochs=5,
            verbose=0)

  def test_EarlyStopping_reuse(self):
    with self.cached_session():
      np.random.seed(1337)
      patience = 3
      data = np.random.random((100, 1))
      labels = np.where(data > 0.5, 1, 0)
      model = keras.models.Sequential((keras.layers.Dense(
          1, input_dim=1, activation='relu'), keras.layers.Dense(
              1, activation='sigmoid'),))
      model.compile(
          optimizer='sgd', loss='binary_crossentropy', metrics=['accuracy'])
      weights = model.get_weights()

      stopper = keras.callbacks.EarlyStopping(monitor='acc', patience=patience)
      hist = model.fit(data, labels, callbacks=[stopper], verbose=0, epochs=20)
      assert len(hist.epoch) >= patience

      # This should allow training to go for at least `patience` epochs
      model.set_weights(weights)
      hist = model.fit(data, labels, callbacks=[stopper], verbose=0, epochs=20)
      assert len(hist.epoch) >= patience

  def test_EarlyStopping_with_baseline(self):
    with self.cached_session():
      np.random.seed(1337)
      baseline = 0.5
      (data, labels), _ = testing_utils.get_test_data(
          train_samples=100,
          test_samples=50,
          input_shape=(1,),
          num_classes=NUM_CLASSES)
      model = testing_utils.get_small_sequential_mlp(
          num_hidden=1, num_classes=1, input_dim=1)
      model.compile(
          optimizer='sgd', loss='binary_crossentropy', metrics=['acc'])

      stopper = keras.callbacks.EarlyStopping(monitor='acc',
                                              baseline=baseline)
      hist = model.fit(data, labels, callbacks=[stopper], verbose=0, epochs=20)
      assert len(hist.epoch) == 1

      patience = 3
      stopper = keras.callbacks.EarlyStopping(monitor='acc',
                                              patience=patience,
                                              baseline=baseline)
      hist = model.fit(data, labels, callbacks=[stopper], verbose=0, epochs=20)
      assert len(hist.epoch) >= patience

  def test_EarlyStopping_final_weights_when_restoring_model_weights(self):

    class DummyModel(object):

      def __init__(self):
        self.stop_training = False
        self.weights = -1

      def get_weights(self):
        return self.weights

      def set_weights(self, weights):
        self.weights = weights

      def set_weight_to_epoch(self, epoch):
        self.weights = epoch

    early_stop = keras.callbacks.EarlyStopping(monitor='val_loss',
                                               patience=2,
                                               restore_best_weights=True)
    early_stop.model = DummyModel()
    losses = [0.2, 0.15, 0.1, 0.11, 0.12]
    # The best configuration is in epoch 2 (loss = 0.1000).
    epochs_trained = 0
    early_stop.on_train_begin()
    for epoch in range(len(losses)):
      epochs_trained += 1
      early_stop.model.set_weight_to_epoch(epoch=epoch)
      early_stop.on_epoch_end(epoch, logs={'val_loss': losses[epoch]})
      if early_stop.model.stop_training:
        break
    # The best configuration is in epoch 2 (loss = 0.1000). Although
    # patience = 2 lets training run two more epochs, restore_best_weights
    # means we end up with the weights from the best epoch, i.e. epoch 2.
    self.assertEqual(early_stop.model.get_weights(), 2)

  def test_RemoteMonitor(self):
    if requests is None:
      return

    monitor = keras.callbacks.RemoteMonitor()
    # This will raise a warning since the default address is unreachable:
    monitor.on_epoch_end(0, logs={'loss': 0.})

  def test_LearningRateScheduler(self):
    with self.cached_session():
      np.random.seed(1337)
      (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
          train_samples=TRAIN_SAMPLES,
          test_samples=TEST_SAMPLES,
          input_shape=(INPUT_DIM,),
          num_classes=NUM_CLASSES)
      y_test = np_utils.to_categorical(y_test)
      y_train = np_utils.to_categorical(y_train)
      model = testing_utils.get_small_sequential_mlp(
          num_hidden=NUM_HIDDEN, num_classes=NUM_CLASSES, input_dim=INPUT_DIM)
      model.compile(
          loss='categorical_crossentropy',
          optimizer='sgd',
          metrics=['accuracy'])

      cbks = [keras.callbacks.LearningRateScheduler(lambda x: 1. / (1. + x))]
      model.fit(
          x_train,
          y_train,
          batch_size=BATCH_SIZE,
          validation_data=(x_test, y_test),
          callbacks=cbks,
          epochs=5,
          verbose=0)
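      # The schedule 1 / (1 + epoch) leaves lr at 1 / (1 + 4) = 0.2 after
      # the fifth epoch (epoch index 4).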
      assert (
          float(keras.backend.get_value(
              model.optimizer.lr)) - 0.2) < keras.backend.epsilon()

      cbks = [keras.callbacks.LearningRateScheduler(lambda x, lr: lr / 2)]
      model.compile(
          loss='categorical_crossentropy',
          optimizer='sgd',
          metrics=['accuracy'])
      model.fit(
          x_train,
          y_train,
          batch_size=BATCH_SIZE,
          validation_data=(x_test, y_test),
          callbacks=cbks,
          epochs=2,
          verbose=0)
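      # SGD's default lr of 0.01 is halved once per epoch, so two epochs
      # leave it at 0.01 / 4.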
      assert (
          float(keras.backend.get_value(
              model.optimizer.lr)) - 0.01 / 4) < keras.backend.epsilon()

      cbks = [
          keras.callbacks.LearningRateScheduler(
              lambda epoch, _: learning_rate_schedule.CosineDecay(0.01, 2)
              (epoch))
      ]
      model.compile(
          loss='categorical_crossentropy',
          optimizer='sgd',
          metrics=['accuracy'])
      model.fit(
          x_train,
          y_train,
          batch_size=BATCH_SIZE,
          validation_data=(x_test, y_test),
          callbacks=cbks,
          epochs=2,
          verbose=0)

      cosine_decay_np = 0.5 * (1 + np.cos(np.pi * (1 / 2)))
      decayed_learning_rate = 0.01 * cosine_decay_np

      assert (float(keras.backend.get_value(model.optimizer.lr)) -
              decayed_learning_rate) < keras.backend.epsilon()

  def test_ReduceLROnPlateau(self):
    with self.cached_session():
      np.random.seed(1337)
      (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
          train_samples=TRAIN_SAMPLES,
          test_samples=TEST_SAMPLES,
          input_shape=(INPUT_DIM,),
          num_classes=NUM_CLASSES)
      y_test = np_utils.to_categorical(y_test)
      y_train = np_utils.to_categorical(y_train)

      def make_model():
        random_seed.set_random_seed(1234)
        np.random.seed(1337)
        model = testing_utils.get_small_sequential_mlp(
            num_hidden=NUM_HIDDEN, num_classes=NUM_CLASSES, input_dim=INPUT_DIM)
        model.compile(
            loss='categorical_crossentropy',
            optimizer=gradient_descent.SGD(lr=0.1))
        return model

      # TODO(psv): Make sure the callback works correctly when min_delta is
      # set as 0. Test fails when the order of this callback and assertion is
      # interchanged.
      model = make_model()
      cbks = [
          keras.callbacks.ReduceLROnPlateau(
              monitor='val_loss',
              factor=0.1,
              min_delta=0,
              patience=1,
              cooldown=5)
      ]
      model.fit(
          x_train,
          y_train,
          batch_size=BATCH_SIZE,
          validation_data=(x_test, y_test),
          callbacks=cbks,
          epochs=2,
          verbose=0)
      self.assertAllClose(
          float(keras.backend.get_value(model.optimizer.lr)), 0.1, atol=1e-4)

      model = make_model()
      # This should reduce the LR after the first epoch (due to the high
      # min_delta).
      cbks = [
          keras.callbacks.ReduceLROnPlateau(
              monitor='val_loss',
              factor=0.1,
              min_delta=10,
              patience=1,
              cooldown=5)
      ]
      model.fit(
          x_train,
          y_train,
          batch_size=BATCH_SIZE,
          validation_data=(x_test, y_test),
          callbacks=cbks,
          epochs=2,
          verbose=2)
      self.assertAllClose(
          float(keras.backend.get_value(model.optimizer.lr)), 0.01, atol=1e-4)

  def test_ReduceLROnPlateau_patience(self):

    class DummyOptimizer(object):

      def __init__(self):
        self.lr = keras.backend.variable(1.0)

    class DummyModel(object):

      def __init__(self):
        self.optimizer = DummyOptimizer()

    reduce_on_plateau = keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss', patience=2)
    reduce_on_plateau.model = DummyModel()

    losses = [0.0860, 0.1096, 0.1040]
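    # The first loss is the best; the two non-improving epochs that follow
    # exhaust patience=2 and trigger a reduction on the last epoch.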
    lrs = []

    for epoch in range(len(losses)):
      reduce_on_plateau.on_epoch_end(epoch, logs={'val_loss': losses[epoch]})
      lrs.append(keras.backend.get_value(reduce_on_plateau.model.optimizer.lr))

    # The learning rate should stay at 1.0 for all but the last epoch.
    for lr in lrs[:-1]:
      self.assertEqual(lr, 1.0)
    self.assertLess(lrs[-1], 1.0)

  def test_ReduceLROnPlateau_backwards_compatibility(self):
    with test.mock.patch.object(logging, 'warning') as mock_log:
      reduce_on_plateau = keras.callbacks.ReduceLROnPlateau(epsilon=1e-13)
      self.assertRegexpMatches(
          str(mock_log.call_args), '`epsilon` argument is deprecated')
    self.assertFalse(hasattr(reduce_on_plateau, 'epsilon'))
    self.assertTrue(hasattr(reduce_on_plateau, 'min_delta'))
    self.assertEqual(reduce_on_plateau.min_delta, 1e-13)

  def test_CSVLogger(self):
    with self.cached_session():
      np.random.seed(1337)
      temp_dir = self.get_temp_dir()
      self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
      filepath = os.path.join(temp_dir, 'log.tsv')

      sep = '\t'
      (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
          train_samples=TRAIN_SAMPLES,
          test_samples=TEST_SAMPLES,
          input_shape=(INPUT_DIM,),
          num_classes=NUM_CLASSES)
      y_test = np_utils.to_categorical(y_test)
      y_train = np_utils.to_categorical(y_train)

      def make_model():
        np.random.seed(1337)
        model = testing_utils.get_small_sequential_mlp(
            num_hidden=NUM_HIDDEN, num_classes=NUM_CLASSES, input_dim=INPUT_DIM)
        model.compile(
            loss='categorical_crossentropy',
            optimizer=gradient_descent.SGD(lr=0.1),
            metrics=['accuracy'])
        return model

      # Case 1: create new file with defined separator.
      model = make_model()
      cbks = [keras.callbacks.CSVLogger(filepath, separator=sep)]
      model.fit(
          x_train,
          y_train,
          batch_size=BATCH_SIZE,
          validation_data=(x_test, y_test),
          callbacks=cbks,
          epochs=1,
          verbose=0)

      assert os.path.exists(filepath)
      with open(filepath) as csvfile:
        dialect = csv.Sniffer().sniff(csvfile.read())
      assert dialect.delimiter == sep
      del model
      del cbks

      # Case 2: append data to existing file, skip header.
      model = make_model()
      cbks = [keras.callbacks.CSVLogger(filepath, separator=sep, append=True)]
      model.fit(
          x_train,
          y_train,
          batch_size=BATCH_SIZE,
          validation_data=(x_test, y_test),
          callbacks=cbks,
          epochs=1,
          verbose=0)

      # Case 3: reuse of CSVLogger object.
      model.fit(
          x_train,
          y_train,
          batch_size=BATCH_SIZE,
          validation_data=(x_test, y_test),
          callbacks=cbks,
          epochs=2,
          verbose=0)
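
      # Four epochs were logged in total (1 + 1 + 2) plus a single header
      # row: five lines, each with five columns (epoch plus four logged
      # quantities), i.e. four separators per line.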
      with open(filepath) as csvfile:
        list_lines = csvfile.readlines()
        for line in list_lines:
          assert line.count(sep) == 4
        assert len(list_lines) == 5
        output = ' '.join(list_lines)
        assert len(re.findall('epoch', output)) == 1

      os.remove(filepath)

  def test_stop_training_csv(self):
    # Test that using the CSVLogger callback with the TerminateOnNaN callback
    # does not result in invalid CSVs.
    np.random.seed(1337)
    tmpdir = self.get_temp_dir()
    self.addCleanup(shutil.rmtree, tmpdir, ignore_errors=True)

    with self.cached_session():
      fp = os.path.join(tmpdir, 'test.csv')
      (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
          train_samples=TRAIN_SAMPLES,
          test_samples=TEST_SAMPLES,
          input_shape=(INPUT_DIM,),
          num_classes=NUM_CLASSES)

      y_test = np_utils.to_categorical(y_test)
      y_train = np_utils.to_categorical(y_train)
      cbks = [keras.callbacks.TerminateOnNaN(), keras.callbacks.CSVLogger(fp)]
      model = keras.models.Sequential()
      for _ in range(5):
        model.add(keras.layers.Dense(2, input_dim=INPUT_DIM, activation='relu'))
      model.add(keras.layers.Dense(NUM_CLASSES, activation='linear'))
      model.compile(loss='mean_squared_error',
                    optimizer='rmsprop')

      def data_generator():
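        # Yield real batches until `tot` exceeds 3 * len(x_train), then
        # switch to all-NaN batches so that TerminateOnNaN halts training
        # and CSVLogger still has to log the final epoch.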
        i = 0
        max_batch_index = len(x_train) // BATCH_SIZE
        tot = 0
        while 1:
          if tot > 3 * len(x_train):
            yield (np.ones([BATCH_SIZE, INPUT_DIM]) * np.nan,
                   np.ones([BATCH_SIZE, NUM_CLASSES]) * np.nan)
          else:
            yield (x_train[i * BATCH_SIZE: (i + 1) * BATCH_SIZE],
                   y_train[i * BATCH_SIZE: (i + 1) * BATCH_SIZE])
          i += 1
          tot += 1
          i %= max_batch_index

      history = model.fit_generator(data_generator(),
                                    len(x_train) // BATCH_SIZE,
                                    validation_data=(x_test, y_test),
                                    callbacks=cbks,
                                    epochs=20)
      loss = history.history['loss']
      assert len(loss) > 1
      assert loss[-1] == np.inf or np.isnan(loss[-1])

      values = []
      with open(fp) as f:
        for x in csv.reader(f):
          # On Windows, \r\n line endings can cause empty lines to be read
          # after each data line. Skip empty lines.
          if x:
            values.append(x)
      assert 'nan' in values[-1], 'The last epoch was not logged.'

  def test_TerminateOnNaN(self):
    with self.cached_session():
      np.random.seed(1337)
      (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
          train_samples=TRAIN_SAMPLES,
          test_samples=TEST_SAMPLES,
          input_shape=(INPUT_DIM,),
          num_classes=NUM_CLASSES)

      y_test = np_utils.to_categorical(y_test)
      y_train = np_utils.to_categorical(y_train)
      cbks = [keras.callbacks.TerminateOnNaN()]
      model = keras.models.Sequential()
      initializer = keras.initializers.Constant(value=1e5)
      for _ in range(5):
        model.add(
            keras.layers.Dense(
                2,
                input_dim=INPUT_DIM,
                activation='relu',
                kernel_initializer=initializer))
      model.add(keras.layers.Dense(NUM_CLASSES))
      model.compile(loss='mean_squared_error', optimizer='rmsprop')

      history = model.fit(
          x_train,
          y_train,
          batch_size=BATCH_SIZE,
          validation_data=(x_test, y_test),
          callbacks=cbks,
          epochs=20)
      loss = history.history['loss']
      self.assertEqual(len(loss), 1)
      self.assertEqual(loss[0], np.inf)

  @unittest.skipIf(
      os.name == 'nt',
      'use_multiprocessing=True does not work on windows properly.')
  def test_LambdaCallback(self):
    with self.cached_session():
      np.random.seed(1337)
      (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
          train_samples=TRAIN_SAMPLES,
          test_samples=TEST_SAMPLES,
          input_shape=(INPUT_DIM,),
          num_classes=NUM_CLASSES)
      y_test = np_utils.to_categorical(y_test)
      y_train = np_utils.to_categorical(y_train)
      model = keras.models.Sequential()
      model.add(
          keras.layers.Dense(
              NUM_HIDDEN, input_dim=INPUT_DIM, activation='relu'))
      model.add(keras.layers.Dense(NUM_CLASSES, activation='softmax'))
      model.compile(
          loss='categorical_crossentropy',
          optimizer='sgd',
          metrics=['accuracy'])

      # Start an arbitrary thread that should run during model
      # training and be terminated after training has completed.
      e = threading.Event()

      def target():
        e.wait()

      t = threading.Thread(target=target)
      t.start()
      cleanup_callback = keras.callbacks.LambdaCallback(
          on_train_end=lambda logs: e.set())

      cbks = [cleanup_callback]
      model.fit(
          x_train,
          y_train,
          batch_size=BATCH_SIZE,
          validation_data=(x_test, y_test),
          callbacks=cbks,
          epochs=5,
          verbose=0)
      t.join()
      assert not t.is_alive()

  def test_RemoteMonitor_np_array(self):
    if requests is None:
      self.skipTest('`requests` required to run this test')
    with test.mock.patch.object(requests, 'post') as requests_post:
      monitor = keras.callbacks.RemoteMonitor(send_as_json=True)
      a = np.arange(1)  # a length-1 ndarray
      logs = {'loss': 0., 'val': a}
      monitor.on_epoch_end(0, logs=logs)
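      # The ndarray in `logs` should be converted to a native Python scalar
      # before being posted as JSON.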
      send = {'loss': 0., 'epoch': 0, 'val': 0}
      requests_post.assert_called_once_with(
          monitor.root + monitor.path, json=send, headers=monitor.headers)

  def test_RemoteMonitor_np_float32(self):
    if requests is None:
      self.skipTest('`requests` required to run this test')

    with test.mock.patch.object(requests, 'post') as requests_post:
      monitor = keras.callbacks.RemoteMonitor(send_as_json=True)
      a = np.float32(1.0)  # a float32 generic type
      logs = {'loss': 0., 'val': a}
      monitor.on_epoch_end(0, logs=logs)
      send = {'loss': 0., 'epoch': 0, 'val': 1.0}
      requests_post.assert_called_once_with(
          monitor.root + monitor.path, json=send, headers=monitor.headers)

  def test_RemoteMonitorWithJsonPayload(self):
    if requests is None:
      self.skipTest('`requests` required to run this test')
    with self.cached_session():
      (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
          train_samples=TRAIN_SAMPLES,
          test_samples=TEST_SAMPLES,
          input_shape=(INPUT_DIM,),
          num_classes=NUM_CLASSES)
      y_test = keras.utils.np_utils.to_categorical(y_test)
      y_train = keras.utils.np_utils.to_categorical(y_train)
      model = keras.models.Sequential()
      model.add(
          keras.layers.Dense(
              NUM_HIDDEN, input_dim=INPUT_DIM, activation='relu'))
      model.add(keras.layers.Dense(NUM_CLASSES, activation='softmax'))
      model.compile(
          loss='categorical_crossentropy',
          optimizer='rmsprop',
          metrics=['accuracy'])
      cbks = [keras.callbacks.RemoteMonitor(send_as_json=True)]

      with test.mock.patch.object(requests, 'post'):
        model.fit(
            x_train,
            y_train,
            batch_size=BATCH_SIZE,
            validation_data=(x_test, y_test),
            callbacks=cbks,
            epochs=1)

  def test_callback_params_samples(self):
    x, y = np.ones((64, 3)), np.ones((64, 2))
    model = testing_utils.get_small_sequential_mlp(
        num_hidden=10, num_classes=2, input_dim=3)
    model.compile('sgd', 'mse')
    callback = keras.callbacks.Callback()
    model.evaluate(x, y, callbacks=[callback])
    self.assertEqual(callback.params['samples'], 64)


# A summary that was emitted during a test. Fields:
#   logdir: str. The logdir of the FileWriter to which the summary was
#     written.
#   tag: str. The name of the summary.
_ObservedSummary = collections.namedtuple('_ObservedSummary', ('logdir', 'tag'))


class _SummaryFile(object):
  """A record of summary tags and the files to which they were written.

  Fields `scalars`, `images`, `histograms`, and `tensors` are sets
  containing `_ObservedSummary` values.
  """

  def __init__(self):
    self.scalars = set()
    self.images = set()
    self.histograms = set()
    self.tensors = set()


def list_summaries(logdir):
  """Read all summaries under the logdir into a `_SummaryFile`.

  Args:
    logdir: A path to a directory that contains zero or more event
      files, either as direct children or in transitive subdirectories.
      Summaries in these events must only contain old-style scalars,
      images, and histograms. Non-summary events, like `graph_def`s, are
      ignored.

  Returns:
    A `_SummaryFile` object reflecting all summaries written to any
    event files in the logdir or any of its descendant directories.

  Raises:
    ValueError: If an event file contains a summary of an unexpected kind.
1456  """
1457  result = _SummaryFile()
1458  for (dirpath, dirnames, filenames) in os.walk(logdir):
1459    del dirnames  # unused
1460    for filename in filenames:
1461      if not filename.startswith('events.out.'):
1462        continue
1463      path = os.path.join(dirpath, filename)
1464      for event in summary_iterator.summary_iterator(path):
1465        if not event.summary:  # (e.g., it's a `graph_def` event)
1466          continue
1467        for value in event.summary.value:
1468          tag = value.tag
1469          # Case on the `value` rather than the summary metadata because
1470          # the Keras callback uses `summary_ops_v2` to emit old-style
1471          # summaries. See b/124535134.
1472          kind = value.WhichOneof('value')
1473          container = {
1474              'simple_value': result.scalars,
1475              'image': result.images,
1476              'histo': result.histograms,
1477              'tensor': result.tensors,
1478          }.get(kind)
1479          if container is None:
1480            raise ValueError(
1481                'Unexpected summary kind %r in event file %s:\n%r'
1482                % (kind, path, event))
1483          elif kind == 'tensor' and tag != 'keras':
1484            # Check for V2 scalar summaries, which have a different PB
1485            # structure.
1486            if event.summary.value[
1487                0].metadata.plugin_data.plugin_name == 'scalars':
1488              container = result.scalars
1489          container.add(_ObservedSummary(logdir=dirpath, tag=tag))
1490  return result
1491
1492
@keras_parameterized.run_with_all_model_types
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
class TestTensorBoardV2(keras_parameterized.TestCase):

  def setUp(self):
    super(TestTensorBoardV2, self).setUp()
    self.logdir = os.path.join(self.get_temp_dir(), 'tb')
    self.train_dir = os.path.join(self.logdir, 'train')
    self.validation_dir = os.path.join(self.logdir, 'validation')

  def _get_model(self):
    layers = [
        keras.layers.Conv2D(8, (3, 3)),
        keras.layers.Flatten(),
        keras.layers.Dense(1)
    ]
    model = testing_utils.get_model_from_layers(layers, input_shape=(10, 10, 1))
    opt = gradient_descent.SGD(learning_rate=0.001)
    model.compile(
        opt,
        'mse',
        run_eagerly=testing_utils.should_run_eagerly(),
        experimental_run_tf_function=testing_utils.should_run_tf_function())
    return model

  def test_TensorBoard_default_logdir(self):
    """Regression test for cross-platform pathsep in default logdir."""
    os.chdir(self.get_temp_dir())

    model = self._get_model()
    x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
    tb_cbk = keras.callbacks.TensorBoard()  # no logdir specified

    model.fit(
        x,
        y,
        batch_size=2,
        epochs=2,
        validation_data=(x, y),
        callbacks=[tb_cbk])

    summary_file = list_summaries(logdir='.')
    train_dir = os.path.join('.', 'logs', 'train')
    validation_dir = os.path.join('.', 'logs', 'validation')
    self.assertEqual(
        summary_file.scalars, {
            _ObservedSummary(logdir=train_dir, tag='epoch_loss'),
            _ObservedSummary(logdir=validation_dir, tag='epoch_loss'),
        })

  def test_TensorBoard_basic(self):
    model = self._get_model()
    x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
    tb_cbk = keras.callbacks.TensorBoard(self.logdir)

    model.fit(
        x,
        y,
        batch_size=2,
        epochs=2,
        validation_data=(x, y),
        callbacks=[tb_cbk])

    summary_file = list_summaries(self.logdir)
    self.assertEqual(
        summary_file.scalars, {
            _ObservedSummary(logdir=self.train_dir, tag='epoch_loss'),
            _ObservedSummary(logdir=self.validation_dir, tag='epoch_loss'),
        })

  def test_TensorBoard_across_invocations(self):
    """Regression test for summary writer resource use-after-free.

    See: <https://github.com/tensorflow/tensorflow/issues/25707>
    """
    model = self._get_model()
    x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
    tb_cbk = keras.callbacks.TensorBoard(self.logdir)

    for _ in (1, 2):
      model.fit(
          x,
          y,
          batch_size=2,
          epochs=2,
          validation_data=(x, y),
          callbacks=[tb_cbk])

    summary_file = list_summaries(self.logdir)
    self.assertEqual(
        summary_file.scalars, {
            _ObservedSummary(logdir=self.train_dir, tag='epoch_loss'),
            _ObservedSummary(logdir=self.validation_dir, tag='epoch_loss'),
        })

  def test_TensorBoard_no_spurious_event_files(self):
    model = self._get_model()
    x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
    tb_cbk = keras.callbacks.TensorBoard(self.logdir)

    model.fit(
        x,
        y,
        batch_size=2,
        epochs=2,
        callbacks=[tb_cbk])

    events_file_run_basenames = set()
    for (dirpath, dirnames, filenames) in os.walk(self.logdir):
      del dirnames  # unused
      if any(fn.startswith('events.out.') for fn in filenames):
        events_file_run_basenames.add(os.path.basename(dirpath))
    self.assertEqual(events_file_run_basenames, {'train'})

  def test_TensorBoard_batch_metrics(self):
    model = self._get_model()
    x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
    tb_cbk = keras.callbacks.TensorBoard(self.logdir, update_freq=1)

    model.fit(
        x,
        y,
        batch_size=2,
        epochs=2,
        validation_data=(x, y),
        callbacks=[tb_cbk])

    summary_file = list_summaries(self.logdir)
    self.assertEqual(
        summary_file.scalars,
        {
            _ObservedSummary(logdir=self.train_dir, tag='batch_loss'),
            _ObservedSummary(logdir=self.train_dir, tag='epoch_loss'),
            _ObservedSummary(logdir=self.validation_dir, tag='epoch_loss'),
        },
    )
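    # (`update_freq=1` makes the callback write metrics after every batch,
    # which is why 'batch_loss' appears above alongside the per-epoch
    # scalars.)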

  def test_TensorBoard_weight_histograms(self):
    model = self._get_model()
    x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
    tb_cbk = keras.callbacks.TensorBoard(self.logdir, histogram_freq=1)
    model_type = testing_utils.get_model_type()

    model.fit(
        x,
        y,
        batch_size=2,
        epochs=2,
        validation_data=(x, y),
        callbacks=[tb_cbk])
    summary_file = list_summaries(self.logdir)

    self.assertEqual(
        summary_file.scalars,
        {
            _ObservedSummary(logdir=self.train_dir, tag='epoch_loss'),
            _ObservedSummary(logdir=self.validation_dir, tag='epoch_loss'),
        },
    )
    self.assertEqual(
        self._strip_layer_names(summary_file.histograms, model_type),
        {
            _ObservedSummary(logdir=self.train_dir, tag='bias_0'),
            _ObservedSummary(logdir=self.train_dir, tag='kernel_0'),
        },
    )

  def test_TensorBoard_weight_images(self):
    model = self._get_model()
    x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
    tb_cbk = keras.callbacks.TensorBoard(
        self.logdir, histogram_freq=1, write_images=True)
    model_type = testing_utils.get_model_type()

    model.fit(
        x,
        y,
        batch_size=2,
        epochs=2,
        validation_data=(x, y),
        callbacks=[tb_cbk])
    summary_file = list_summaries(self.logdir)

    self.assertEqual(
        summary_file.scalars,
        {
            _ObservedSummary(logdir=self.train_dir, tag='epoch_loss'),
            _ObservedSummary(logdir=self.validation_dir, tag='epoch_loss'),
        },
    )
    self.assertEqual(
        self._strip_layer_names(summary_file.histograms, model_type),
        {
            _ObservedSummary(logdir=self.train_dir, tag='bias_0'),
            _ObservedSummary(logdir=self.train_dir, tag='kernel_0'),
        },
    )
    self.assertEqual(
        self._strip_layer_names(summary_file.images, model_type),
        {
            _ObservedSummary(logdir=self.train_dir, tag='bias_0/image/0'),
            _ObservedSummary(logdir=self.train_dir, tag='kernel_0/image/0'),
            _ObservedSummary(logdir=self.train_dir, tag='kernel_0/image/1'),
            _ObservedSummary(logdir=self.train_dir, tag='kernel_0/image/2'),
        },
    )

  def test_custom_summary(self):
    if not testing_utils.should_run_tf_function():
      self.skipTest('Custom summaries only supported in V2 code path.')

    def scalar_v2_mock(name, data, step=None):
      """A reimplementation of the scalar plugin to avoid circular deps."""
      metadata = summary_pb2.SummaryMetadata()
      # Should match value in tensorboard/plugins/scalar/metadata.py.
      metadata.plugin_data.plugin_name = 'scalars'
      with summary_ops_v2.summary_scope(
          name, 'scalar_summary', values=[data, step]) as (tag, _):
        return summary_ops_v2.write(
            tag=tag,
            tensor=math_ops.cast(data, 'float32'),
            step=step,
            metadata=metadata)

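    # `scalar_v2_mock` writes the scalar as a 'tensor' value whose metadata
    # names the 'scalars' plugin; `list_summaries` above reroutes exactly
    # such values into `result.scalars`, so these show up as scalars below.
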
    class LayerWithSummary(keras.layers.Layer):

      def call(self, x):
        scalar_v2_mock('custom_summary', math_ops.reduce_sum(x))
        return x

    model = testing_utils.get_model_from_layers([LayerWithSummary()],
                                                input_shape=(5,),
                                                name='model')

    model.compile(
        'sgd',
        'mse',
        run_eagerly=testing_utils.should_run_eagerly(),
        experimental_run_tf_function=testing_utils.should_run_tf_function())
    tb_cbk = keras.callbacks.TensorBoard(self.logdir, update_freq=1)
    x, y = np.ones((10, 5)), np.ones((10, 5))
    model.fit(x, y, batch_size=2, validation_data=(x, y), callbacks=[tb_cbk])
    summary_file = list_summaries(self.logdir)
    self.assertEqual(
        summary_file.scalars,
        {
            _ObservedSummary(logdir=self.train_dir, tag='epoch_loss'),
            _ObservedSummary(logdir=self.validation_dir, tag='epoch_loss'),
            _ObservedSummary(logdir=self.train_dir, tag='batch_loss'),
            _ObservedSummary(
                logdir=self.train_dir,
                tag='model/layer_with_summary/custom_summary'),
            _ObservedSummary(
                logdir=self.validation_dir,
                tag='model/layer_with_summary/custom_summary')
        },
    )

  def _strip_layer_names(self, summaries, model_type):
    """Deduplicate summary names modulo layer prefix.

    This removes the first slash-component of each tag name: for
    instance, "foo/bar/baz" becomes "bar/baz".

    Args:
      summaries: A `set` of `_ObservedSummary` values.
      model_type: The model type currently being tested.

    Returns:
      A new `set` of `_ObservedSummary` values with layer prefixes
      removed.
    """
    result = set()
    for summary in summaries:
      if '/' not in summary.tag:
        raise ValueError('tag has no layer name: %r' % summary.tag)
      # Subclassed models add an extra model-name prefix, so strip one
      # more component for them.
      start_from = 2 if 'subclass' in model_type else 1
      new_tag = '/'.join(summary.tag.split('/')[start_from:])
      result.add(summary._replace(tag=new_tag))
    return result

  def test_TensorBoard_invalid_argument(self):
    with self.assertRaisesRegexp(ValueError, 'Unrecognized arguments'):
      # `wwrite_images` is an intentional misspelling of `write_images`.
      keras.callbacks.TensorBoard(wwrite_images=True)


# Note that these tests specify their model types explicitly.
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase):

  def setUp(self):
    super(TestTensorBoardV2NonParameterizedTest, self).setUp()
    self.logdir = os.path.join(self.get_temp_dir(), 'tb')
    self.train_dir = os.path.join(self.logdir, 'train')
    self.validation_dir = os.path.join(self.logdir, 'validation')

  def _get_seq_model(self):
    model = keras.models.Sequential([
        keras.layers.Conv2D(8, (3, 3), input_shape=(10, 10, 1)),
        keras.layers.Flatten(),
        keras.layers.Dense(1),
    ])
    opt = gradient_descent.SGD(learning_rate=0.001)
    model.compile(
        opt,
        'mse',
        run_eagerly=testing_utils.should_run_eagerly(),
        experimental_run_tf_function=testing_utils.should_run_tf_function())
    return model

  def fitModelAndAssertKerasModelWritten(self, model):
    x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
    tb_cbk = keras.callbacks.TensorBoard(self.logdir,
                                         write_graph=True,
                                         profile_batch=0)
    model.fit(
        x,
        y,
        batch_size=2,
        epochs=2,
        validation_data=(x, y),
        callbacks=[tb_cbk])
    summary_file = list_summaries(self.logdir)
    self.assertEqual(
        summary_file.tensors,
        {
            _ObservedSummary(logdir=self.train_dir, tag='keras'),
        },
    )
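    # With `write_graph=True`, the callback serializes the Keras model and
    # writes it as a single tensor summary tagged 'keras', which is what
    # this helper asserts.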

  def test_TensorBoard_writeSequentialModel_noInputShape(self):
    model = keras.models.Sequential([
        keras.layers.Conv2D(8, (3, 3)),
        keras.layers.Flatten(),
        keras.layers.Dense(1),
    ])
    model.compile('sgd', 'mse', run_eagerly=False)
    self.fitModelAndAssertKerasModelWritten(model)

  def test_TensorBoard_writeSequentialModel_withInputShape(self):
    model = keras.models.Sequential([
        keras.layers.Conv2D(8, (3, 3), input_shape=(10, 10, 1)),
        keras.layers.Flatten(),
        keras.layers.Dense(1),
    ])
    model.compile('sgd', 'mse', run_eagerly=False)
    self.fitModelAndAssertKerasModelWritten(model)

  def test_TensorBoard_writeModel(self):
    inputs = keras.layers.Input([10, 10, 1])
    x = keras.layers.Conv2D(8, (3, 3), activation='relu')(inputs)
    x = keras.layers.Flatten()(x)
    x = keras.layers.Dense(1)(x)
    model = keras.models.Model(inputs=inputs, outputs=[x])
    model.compile('sgd', 'mse', run_eagerly=False)
    self.fitModelAndAssertKerasModelWritten(model)

  def test_TensorBoard_autoTrace(self):
    model = self._get_seq_model()
    x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
    tb_cbk = keras.callbacks.TensorBoard(
        self.logdir, histogram_freq=1, profile_batch=1, write_graph=False)

    model.fit(
        x,
        y,
        batch_size=2,
        epochs=2,
        validation_data=(x, y),
        callbacks=[tb_cbk])
    summary_file = list_summaries(self.logdir)

    self.assertEqual(
        summary_file.tensors,
        {
            _ObservedSummary(logdir=self.train_dir, tag=u'batch_1'),
        },
    )

  def test_TensorBoard_autoTrace_tagNameWithBatchNum(self):
    model = self._get_seq_model()
    x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
    tb_cbk = keras.callbacks.TensorBoard(
        self.logdir, histogram_freq=1, profile_batch=2, write_graph=False)

    model.fit(
        x,
        y,
        batch_size=2,
        epochs=2,
        validation_data=(x, y),
        callbacks=[tb_cbk])
    summary_file = list_summaries(self.logdir)

    self.assertEqual(
        summary_file.tensors,
        {
            _ObservedSummary(logdir=self.train_dir, tag=u'batch_2'),
        },
    )

  def test_TensorBoard_autoTrace_profile_batch_largerThanBatchCount(self):
    model = self._get_seq_model()
    x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
    tb_cbk = keras.callbacks.TensorBoard(
        self.logdir, histogram_freq=1, profile_batch=10000, write_graph=False)

    model.fit(
        x,
        y,
        batch_size=2,
        epochs=2,
        validation_data=(x, y),
        callbacks=[tb_cbk])
    summary_file = list_summaries(self.logdir)

    # Tracing was requested only for batch 10000, which training never
    # reaches, so no trace summaries should have been written.
    self.assertEmpty(summary_file.tensors)


class MostRecentlyModifiedFileMatchingPatternTest(test.TestCase):

  def test_get_most_recently_modified_file_matching_pattern(self):
    file_pattern = 'f.batch{batch:02d}epoch{epoch:02d}.h5'
    test_dir = self.get_temp_dir()
    path_pattern = os.path.join(test_dir, file_pattern)
    file_paths = [
        os.path.join(test_dir, file_name) for file_name in
        ['f.batch03epoch02.h5', 'f.batch02epoch02.h5', 'f.batch01epoch01.h5']
    ]
    for file_path in file_paths:
      with open(file_path, 'w') as f:
        # Sleep so that the files' modification times are clearly distinct.
        time.sleep(2)
        f.write('foo bar')
    # Ensure the files have actually been written.
    self.assertEqual(
        set([
            os.path.join(test_dir, file_name)
            for file_name in os.listdir(test_dir)
        ]), set(file_paths))
    self.assertEqual(
        keras.callbacks.ModelCheckpoint(None)
        ._get_most_recently_modified_file_matching_pattern(path_pattern),
        file_paths[-1])
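
  # Aside (illustrative): the patterns in these tests are plain `str.format`
  # templates, e.g.
  #   'f.batch{batch:02d}epoch{epoch:02d}.h5'.format(batch=3, epoch=2)
  # yields 'f.batch03epoch02.h5', which is how the file names above were
  # derived.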

  def test_some_file_not_matching_pattern(self):
    file_pattern = 'f.batch{batch:02d}epoch{epoch:02d}.h5'
    test_dir = self.get_temp_dir()
    path_pattern = os.path.join(test_dir, file_pattern)
    file_paths = [
        os.path.join(test_dir, file_name) for file_name in
        ['f.batch03epoch02.h5', 'f.batch02epoch02.h5', 'f.baatch01epoch01.h5']
    ]
    for file_path in file_paths:
      with open(file_path, 'w') as f:
        # Sleep so that the files' modification times are clearly distinct.
        time.sleep(2)
        f.write('foo bar')
    # The last file is the most recently modified but does not match the
    # pattern (note the intentional 'baatch'), so the second file should be
    # returned.
    self.assertEqual(
        keras.callbacks.ModelCheckpoint(None)
        ._get_most_recently_modified_file_matching_pattern(path_pattern),
        file_paths[-2])

  def test_get_same_file_if_file_name_equals_pattern(self):
    file_name = 'f.batch02.h5'
    test_dir = self.get_temp_dir()
    file_path = os.path.join(test_dir, file_name)
    with open(file_path, 'w') as f:
      f.write('foo bar')
    self.assertEqual(os.path.join(test_dir, os.listdir(test_dir)[0]), file_path)
    self.assertEqual(
        keras.callbacks.ModelCheckpoint(
            None)._get_most_recently_modified_file_matching_pattern(file_path),
        file_path)

  def test_get_none_if_file_does_not_exist(self):
    file_name = 'f.batch02.h5'
    test_dir = self.get_temp_dir()
    file_path = os.path.join(test_dir, file_name)
    self.assertLen(os.listdir(test_dir), 0)
    self.assertEqual(
        keras.callbacks.ModelCheckpoint(
            None)._get_most_recently_modified_file_matching_pattern(file_path),
        None)

  def test_using_checkpoint_management_latest_checkpoint(self):
    file_pattern = 'f.batch{batch:02d}epoch{epoch:02d}'
    ckpt_file_name = 'f.batchXepochY'
    test_dir = self.get_temp_dir()
    path_pattern = os.path.join(test_dir, file_pattern)
    ckpt_file_path = os.path.join(test_dir, ckpt_file_name)
    with open(ckpt_file_path, 'w') as f:
      f.write('dummy ckpt')
    checkpoint_management.update_checkpoint_state_internal(
        test_dir, ckpt_file_path)

    file_paths = [
        os.path.join(test_dir, file_name)
        for file_name in ['f.batch03epoch02', 'f.batch02epoch02']
    ]
    for file_path in file_paths:
      with open(file_path, 'w') as f:
        f.write('foo bar')

    # The path returned by `checkpoint_management.latest_checkpoint` takes
    # priority, so it should be returned even though it was written earlier
    # than the files matching the pattern.
    self.assertEqual(
        keras.callbacks.ModelCheckpoint(None)
        ._get_most_recently_modified_file_matching_pattern(path_pattern),
        ckpt_file_path)


if __name__ == '__main__':
  test.main()