1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for Keras callbacks.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import csv 23import json 24import os 25import re 26import shutil 27import sys 28import threading 29import time 30import unittest 31 32from absl.testing import parameterized 33import numpy as np 34 35from tensorflow.core.framework import summary_pb2 36from tensorflow.python import keras 37from tensorflow.python.data.ops import dataset_ops 38from tensorflow.python.framework import random_seed 39from tensorflow.python.keras import keras_parameterized 40from tensorflow.python.keras import testing_utils 41from tensorflow.python.keras.engine import sequential 42from tensorflow.python.keras.optimizer_v2 import gradient_descent 43from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule 44from tensorflow.python.keras.utils import np_utils 45from tensorflow.python.ops import array_ops 46from tensorflow.python.ops import math_ops 47from tensorflow.python.ops import summary_ops_v2 48from tensorflow.python.platform import test 49from tensorflow.python.platform import tf_logging as logging 50from tensorflow.python.summary import summary_iterator 51from tensorflow.python.training import adam 52from 
tensorflow.python.training import checkpoint_management 53 54try: 55 import h5py # pylint:disable=g-import-not-at-top 56except ImportError: 57 h5py = None 58 59try: 60 import requests # pylint:disable=g-import-not-at-top 61except ImportError: 62 requests = None 63 64 65TRAIN_SAMPLES = 10 66TEST_SAMPLES = 10 67NUM_CLASSES = 2 68INPUT_DIM = 3 69NUM_HIDDEN = 5 70BATCH_SIZE = 5 71 72 73class Counter(keras.callbacks.Callback): 74 """Counts the number of times each callback method was run. 75 76 Attributes: 77 method_counts: dict. Contains the counts of time each callback method was 78 run. 79 """ 80 81 def __init__(self): 82 self.method_counts = collections.defaultdict(int) 83 methods_to_count = [ 84 'on_batch_begin', 'on_batch_end', 'on_epoch_begin', 'on_epoch_end', 85 'on_predict_batch_begin', 'on_predict_batch_end', 'on_predict_begin', 86 'on_predict_end', 'on_test_batch_begin', 'on_test_batch_end', 87 'on_test_begin', 'on_test_end', 'on_train_batch_begin', 88 'on_train_batch_end', 'on_train_begin', 'on_train_end' 89 ] 90 for method_name in methods_to_count: 91 setattr(self, method_name, 92 self.wrap_with_counts(method_name, getattr(self, method_name))) 93 94 def wrap_with_counts(self, method_name, method): 95 96 def _call_and_count(*args, **kwargs): 97 self.method_counts[method_name] += 1 98 return method(*args, **kwargs) 99 100 return _call_and_count 101 102 103def _get_numpy(): 104 return np.ones((10, 10)), np.ones((10, 1)) 105 106 107def _get_sequence(): 108 109 class MySequence(keras.utils.data_utils.Sequence): 110 111 def __getitem__(self, _): 112 return np.ones((2, 10)), np.ones((2, 1)) 113 114 def __len__(self): 115 return 5 116 117 return MySequence(), None 118 119 120@keras_parameterized.run_with_all_model_types 121@keras_parameterized.run_all_keras_modes 122class CallbackCountsTest(keras_parameterized.TestCase): 123 124 def _check_counts(self, counter, expected_counts): 125 """Checks that the counts registered by `counter` are those expected.""" 126 for 
method_name, expected_count in expected_counts.items(): 127 self.assertEqual( 128 counter.method_counts[method_name], 129 expected_count, 130 msg='For method {}: expected {}, got: {}'.format( 131 method_name, expected_count, counter.method_counts[method_name])) 132 133 def _get_model(self): 134 layers = [ 135 keras.layers.Dense(10, activation='relu'), 136 keras.layers.Dense(1, activation='sigmoid') 137 ] 138 model = testing_utils.get_model_from_layers(layers, input_shape=(10,)) 139 model.compile( 140 adam.AdamOptimizer(0.001), 141 'binary_crossentropy', 142 run_eagerly=testing_utils.should_run_eagerly(), 143 experimental_run_tf_function=testing_utils.should_run_tf_function()) 144 return model 145 146 @parameterized.named_parameters(('with_numpy', _get_numpy()), 147 ('with_sequence', _get_sequence())) 148 def test_callback_hooks_are_called_in_fit(self, data): 149 x, y = data 150 val_x, val_y = np.ones((4, 10)), np.ones((4, 1)) 151 is_sequence = isinstance(x, keras.utils.data_utils.Sequence) 152 153 model = self._get_model() 154 counter = Counter() 155 model.fit( 156 x, 157 y, 158 validation_data=(val_x, val_y), 159 batch_size=2 if not is_sequence else None, 160 steps_per_epoch=5 if is_sequence else None, 161 epochs=5, 162 callbacks=[counter]) 163 164 self._check_counts( 165 counter, { 166 'on_batch_begin': 25, 167 'on_batch_end': 25, 168 'on_epoch_begin': 5, 169 'on_epoch_end': 5, 170 'on_predict_batch_begin': 0, 171 'on_predict_batch_end': 0, 172 'on_predict_begin': 0, 173 'on_predict_end': 0, 174 'on_test_batch_begin': 10, 175 'on_test_batch_end': 10, 176 'on_test_begin': 5, 177 'on_test_end': 5, 178 'on_train_batch_begin': 25, 179 'on_train_batch_end': 25, 180 'on_train_begin': 1, 181 'on_train_end': 1 182 }) 183 184 @parameterized.named_parameters(('with_numpy', _get_numpy()), 185 ('with_sequence', _get_sequence())) 186 def test_callback_hooks_are_called_in_evaluate(self, data): 187 x, y = data 188 is_sequence = isinstance(x, keras.utils.data_utils.Sequence) 189 
190 model = self._get_model() 191 counter = Counter() 192 model.evaluate( 193 x, 194 y, 195 batch_size=2 if not is_sequence else None, 196 steps=5 if is_sequence else None, 197 callbacks=[counter]) 198 self._check_counts( 199 counter, { 200 'on_test_batch_begin': 5, 201 'on_test_batch_end': 5, 202 'on_test_begin': 1, 203 'on_test_end': 1 204 }) 205 206 @parameterized.named_parameters(('with_numpy', _get_numpy()), 207 ('with_sequence', _get_sequence())) 208 def test_callback_hooks_are_called_in_predict(self, data): 209 x = data[0] 210 is_sequence = isinstance(x, keras.utils.data_utils.Sequence) 211 212 model = self._get_model() 213 counter = Counter() 214 model.predict( 215 x, 216 batch_size=2 if not is_sequence else None, 217 steps=5 if is_sequence else None, 218 callbacks=[counter]) 219 self._check_counts( 220 counter, { 221 'on_predict_batch_begin': 5, 222 'on_predict_batch_end': 5, 223 'on_predict_begin': 1, 224 'on_predict_end': 1 225 }) 226 227 def test_callback_list_methods(self): 228 counter = Counter() 229 callback_list = keras.callbacks.CallbackList([counter]) 230 231 batch = 0 232 callback_list.on_test_batch_begin(batch) 233 callback_list.on_test_batch_end(batch) 234 callback_list.on_predict_batch_begin(batch) 235 callback_list.on_predict_batch_end(batch) 236 237 self._check_counts( 238 counter, { 239 'on_test_batch_begin': 1, 240 'on_test_batch_end': 1, 241 'on_predict_batch_begin': 1, 242 'on_predict_batch_end': 1 243 }) 244 245 246class KerasCallbacksTest(keras_parameterized.TestCase): 247 248 def _get_model(self, input_shape=None): 249 layers = [ 250 keras.layers.Dense(3, activation='relu'), 251 keras.layers.Dense(2, activation='softmax') 252 ] 253 model = testing_utils.get_model_from_layers(layers, input_shape=input_shape) 254 model.compile( 255 loss='mse', 256 optimizer='rmsprop', 257 metrics=[keras.metrics.CategoricalAccuracy(name='my_acc')], 258 run_eagerly=testing_utils.should_run_eagerly(), 259 
experimental_run_tf_function=testing_utils.should_run_tf_function()) 260 return model 261 262 @keras_parameterized.run_with_all_model_types 263 @keras_parameterized.run_all_keras_modes 264 def test_progbar_logging(self): 265 model = self._get_model(input_shape=(3,)) 266 267 x = array_ops.ones((50, 3)) 268 y = array_ops.zeros((50, 2)) 269 dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).batch(10) 270 expected_log = r'(.*- loss:.*- my_acc:.*)+' 271 272 with self.captureWritesToStream(sys.stdout) as printed: 273 model.fit(dataset, epochs=2, steps_per_epoch=10) 274 self.assertRegexpMatches(printed.contents(), expected_log) 275 276 @keras_parameterized.run_with_all_model_types(exclude_models='functional') 277 @keras_parameterized.run_all_keras_modes 278 def test_progbar_logging_deferred_model_build(self): 279 model = self._get_model() 280 self.assertFalse(model.built) 281 282 x = array_ops.ones((50, 3)) 283 y = array_ops.zeros((50, 2)) 284 dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).batch(10) 285 expected_log = r'(.*- loss:.*- my_acc:.*)+' 286 287 with self.captureWritesToStream(sys.stdout) as printed: 288 model.fit(dataset, epochs=2, steps_per_epoch=10) 289 self.assertRegexpMatches(printed.contents(), expected_log) 290 291 @keras_parameterized.run_with_all_model_types 292 @keras_parameterized.run_all_keras_modes 293 def test_progbar_logging_validation_data(self): 294 model = self._get_model(input_shape=(3,)) 295 296 x = array_ops.ones((50, 3)) 297 y = array_ops.zeros((50, 2)) 298 training_dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).batch(10) 299 val_dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).batch(10) 300 expected_log = r'(.*5/5.*- loss:.*- my_acc:.*- val_loss:.*- val_my_acc:.*)+' 301 302 with self.captureWritesToStream(sys.stdout) as printed: 303 model.fit(training_dataset, epochs=2, validation_data=val_dataset) 304 self.assertRegexpMatches(printed.contents(), expected_log) 305 306 
@keras_parameterized.run_with_all_model_types 307 @keras_parameterized.run_all_keras_modes 308 def test_progbar_logging_validation_split(self): 309 model = self._get_model(input_shape=(3,)) 310 311 x = np.ones((100, 3)) 312 y = np.zeros((100, 2)) 313 expected_log = ( 314 r'(?s).*1/2.*80/80.*- loss:.*- my_acc:.*- val_loss:.*- val_my_acc:' 315 r'.*2/2.*80/80.*- loss:.*- my_acc:.*- val_loss:.*- val_my_acc:.*') 316 317 with self.captureWritesToStream(sys.stdout) as printed: 318 model.fit(x, y, batch_size=10, epochs=2, validation_split=0.2) 319 self.assertRegexpMatches(printed.contents(), expected_log) 320 321 @keras_parameterized.run_with_all_model_types 322 @keras_parameterized.run_all_keras_modes(always_skip_v1=True) 323 def test_progbar_logging_training_validation(self): 324 model = self._get_model(input_shape=(2,)) 325 326 def generator(): 327 for _ in range(100): 328 yield [1, 1], 1 329 330 training = dataset_ops.Dataset \ 331 .from_generator( 332 generator=generator, 333 output_types=('float64', 'float64'), 334 output_shapes=([2], [])) \ 335 .batch(2) \ 336 .repeat() 337 validation = dataset_ops.Dataset \ 338 .from_generator( 339 generator=generator, 340 output_types=('float64', 'float64'), 341 output_shapes=([2], [])) \ 342 .batch(2) 343 expected_log = ( 344 r'(?s).*1/2.*20/20.*- loss:.*- my_acc:.*- val_loss:.*- val_my_acc:' 345 r'.*2/2.*20/20.*- loss:.*- my_acc:.*- val_loss:.*- val_my_acc:.*') 346 347 with self.captureWritesToStream(sys.stdout) as printed: 348 model.fit( 349 x=training, validation_data=validation, epochs=2, steps_per_epoch=20) 350 self.assertRegexpMatches(printed.contents(), expected_log) 351 352 @keras_parameterized.run_with_all_model_types 353 @keras_parameterized.run_all_keras_modes(always_skip_v1=True) 354 def test_progbar_logging_with_dataset_and_partial_batch(self): 355 model = self._get_model(input_shape=(2,)) 356 357 def generator(): 358 # Have a partial batch at the end. 
359 for _ in range(9): 360 yield np.random.random(2), 1 361 362 training = dataset_ops.Dataset \ 363 .from_generator( 364 generator=generator, 365 output_types=('float64', 'float64'), 366 output_shapes=([2], [])) \ 367 .batch(2) 368 validation = dataset_ops.Dataset \ 369 .from_generator( 370 generator=generator, 371 output_types=('float64', 'float64'), 372 output_shapes=([2], [])) \ 373 .batch(2) 374 375 with self.captureWritesToStream(sys.stdout) as printed: 376 model.fit(x=training, validation_data=validation) 377 378 # Make sure the value of val_ metrics are not zeros. 379 log_content = printed.contents() 380 val_loss = re.findall(r'val_loss: (\d\.\d+)', log_content) 381 self.assertLen(val_loss, 1) 382 self.assertGreater(float(val_loss[0]), 0.0) 383 384 @keras_parameterized.run_with_all_model_types 385 def test_ModelCheckpoint(self): 386 if h5py is None: 387 return # Skip test if models cannot be saved. 388 389 layers = [ 390 keras.layers.Dense(NUM_HIDDEN, input_dim=INPUT_DIM, activation='relu'), 391 keras.layers.Dense(NUM_CLASSES, activation='softmax') 392 ] 393 model = testing_utils.get_model_from_layers(layers, input_shape=(10,)) 394 model.compile( 395 loss='categorical_crossentropy', optimizer='rmsprop', metrics=['acc']) 396 397 temp_dir = self.get_temp_dir() 398 self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True) 399 400 filepath = os.path.join(temp_dir, 'checkpoint.h5') 401 (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data( 402 train_samples=TRAIN_SAMPLES, 403 test_samples=TEST_SAMPLES, 404 input_shape=(INPUT_DIM,), 405 num_classes=NUM_CLASSES) 406 y_test = np_utils.to_categorical(y_test) 407 y_train = np_utils.to_categorical(y_train) 408 # case 1 409 monitor = 'val_loss' 410 save_best_only = False 411 mode = 'auto' 412 413 model = keras.models.Sequential() 414 model.add( 415 keras.layers.Dense( 416 NUM_HIDDEN, input_dim=INPUT_DIM, activation='relu')) 417 model.add(keras.layers.Dense(NUM_CLASSES, activation='softmax')) 418 
model.compile( 419 loss='categorical_crossentropy', optimizer='rmsprop', metrics=['acc']) 420 421 cbks = [ 422 keras.callbacks.ModelCheckpoint( 423 filepath, 424 monitor=monitor, 425 save_best_only=save_best_only, 426 mode=mode) 427 ] 428 model.fit( 429 x_train, 430 y_train, 431 batch_size=BATCH_SIZE, 432 validation_data=(x_test, y_test), 433 callbacks=cbks, 434 epochs=1, 435 verbose=0) 436 assert os.path.exists(filepath) 437 os.remove(filepath) 438 439 # case 2 440 mode = 'min' 441 cbks = [ 442 keras.callbacks.ModelCheckpoint( 443 filepath, 444 monitor=monitor, 445 save_best_only=save_best_only, 446 mode=mode) 447 ] 448 model.fit( 449 x_train, 450 y_train, 451 batch_size=BATCH_SIZE, 452 validation_data=(x_test, y_test), 453 callbacks=cbks, 454 epochs=1, 455 verbose=0) 456 assert os.path.exists(filepath) 457 os.remove(filepath) 458 459 # case 3 460 mode = 'max' 461 monitor = 'val_acc' 462 cbks = [ 463 keras.callbacks.ModelCheckpoint( 464 filepath, 465 monitor=monitor, 466 save_best_only=save_best_only, 467 mode=mode) 468 ] 469 model.fit( 470 x_train, 471 y_train, 472 batch_size=BATCH_SIZE, 473 validation_data=(x_test, y_test), 474 callbacks=cbks, 475 epochs=1, 476 verbose=0) 477 assert os.path.exists(filepath) 478 os.remove(filepath) 479 480 # case 4 481 save_best_only = True 482 cbks = [ 483 keras.callbacks.ModelCheckpoint( 484 filepath, 485 monitor=monitor, 486 save_best_only=save_best_only, 487 mode=mode) 488 ] 489 model.fit( 490 x_train, 491 y_train, 492 batch_size=BATCH_SIZE, 493 validation_data=(x_test, y_test), 494 callbacks=cbks, 495 epochs=1, 496 verbose=0) 497 assert os.path.exists(filepath) 498 os.remove(filepath) 499 500 # Case: metric not available. 501 cbks = [ 502 keras.callbacks.ModelCheckpoint( 503 filepath, 504 monitor='unknown', 505 save_best_only=True) 506 ] 507 model.fit( 508 x_train, 509 y_train, 510 batch_size=BATCH_SIZE, 511 validation_data=(x_test, y_test), 512 callbacks=cbks, 513 epochs=1, 514 verbose=0) 515 # File won't be written. 
516 assert not os.path.exists(filepath) 517 518 # case 5 519 save_best_only = False 520 period = 2 521 mode = 'auto' 522 523 filepath = os.path.join(temp_dir, 'checkpoint.{epoch:02d}.h5') 524 cbks = [ 525 keras.callbacks.ModelCheckpoint( 526 filepath, 527 monitor=monitor, 528 save_best_only=save_best_only, 529 mode=mode, 530 period=period) 531 ] 532 model.fit( 533 x_train, 534 y_train, 535 batch_size=BATCH_SIZE, 536 validation_data=(x_test, y_test), 537 callbacks=cbks, 538 epochs=4, 539 verbose=1) 540 assert os.path.exists(filepath.format(epoch=2)) 541 assert os.path.exists(filepath.format(epoch=4)) 542 os.remove(filepath.format(epoch=2)) 543 os.remove(filepath.format(epoch=4)) 544 assert not os.path.exists(filepath.format(epoch=1)) 545 assert not os.path.exists(filepath.format(epoch=3)) 546 547 # Invalid use: this will raise a warning but not an Exception. 548 keras.callbacks.ModelCheckpoint( 549 filepath, 550 monitor=monitor, 551 save_best_only=save_best_only, 552 mode='unknown') 553 554 # Case 6: `ModelCheckpoint` with a combination of `save_freq` and `period`. 555 # Though `period` is deprecated, we're testing it for 556 # backward-compatibility. 
557 filepath = os.path.join(temp_dir, 'checkpoint.epoch{epoch:02d}.h5') 558 cbks = [ 559 keras.callbacks.ModelCheckpoint( 560 filepath, monitor=monitor, mode=mode, save_freq='epoch', period=5) 561 ] 562 assert not os.path.exists(filepath.format(epoch=0)) 563 assert not os.path.exists(filepath.format(epoch=5)) 564 model.fit( 565 x_train, 566 y_train, 567 batch_size=2, 568 validation_data=(x_test, y_test), 569 callbacks=cbks, 570 epochs=10, 571 verbose=1) 572 assert not os.path.exists(filepath.format(epoch=1)) 573 assert not os.path.exists(filepath.format(epoch=2)) 574 assert not os.path.exists(filepath.format(epoch=3)) 575 assert not os.path.exists(filepath.format(epoch=4)) 576 assert os.path.exists(filepath.format(epoch=5)) 577 assert not os.path.exists(filepath.format(epoch=6)) 578 assert os.path.exists(filepath.format(epoch=10)) 579 os.remove(filepath.format(epoch=5)) 580 os.remove(filepath.format(epoch=10)) 581 582 # Case 7: `ModelCheckpoint` with an integer `save_freq` 583 filepath = os.path.join(temp_dir, 'checkpoint.epoch{epoch:02d}.h5') 584 cbks = [ 585 keras.callbacks.ModelCheckpoint( 586 filepath, 587 monitor=monitor, 588 save_best_only=save_best_only, 589 mode=mode, 590 save_freq=30, 591 period=100) # The period should be ignored (this test tests this). 
592 ] 593 assert not os.path.exists(filepath.format(epoch=3)) 594 model.fit( 595 x_train, 596 y_train, 597 batch_size=2, 598 validation_data=(x_test, y_test), 599 callbacks=cbks, 600 epochs=10, 601 verbose=1) 602 assert not os.path.exists(filepath.format(epoch=1)) 603 assert not os.path.exists(filepath.format(epoch=2)) 604 assert os.path.exists(filepath.format(epoch=3)) 605 assert not os.path.exists(filepath.format(epoch=4)) 606 assert not os.path.exists(filepath.format(epoch=5)) 607 assert os.path.exists(filepath.format(epoch=6)) 608 assert not os.path.exists(filepath.format(epoch=7)) 609 assert not os.path.exists(filepath.format(epoch=8)) 610 assert os.path.exists(filepath.format(epoch=9)) 611 os.remove(filepath.format(epoch=3)) 612 os.remove(filepath.format(epoch=6)) 613 os.remove(filepath.format(epoch=9)) 614 615 # Case 8: `ModelCheckpoint` with valid and invalid save_freq argument. 616 with self.assertRaisesRegexp(ValueError, 'Unrecognized save_freq'): 617 keras.callbacks.ModelCheckpoint( 618 filepath, 619 monitor=monitor, 620 save_best_only=save_best_only, 621 mode=mode, 622 save_freq='invalid_save_freq') 623 # The following should not raise ValueError. 624 keras.callbacks.ModelCheckpoint( 625 filepath, 626 monitor=monitor, 627 save_best_only=save_best_only, 628 mode=mode, 629 save_freq='epoch') 630 keras.callbacks.ModelCheckpoint( 631 filepath, 632 monitor=monitor, 633 save_best_only=save_best_only, 634 mode=mode, 635 save_freq=3) 636 637 def _get_dummy_resource_for_model_checkpoint_testing(self): 638 639 def get_input_datasets(): 640 # Simple training input. 641 train_input = [[1]] * 16 642 train_label = [[0]] * 16 643 ds = dataset_ops.Dataset.from_tensor_slices((train_input, train_label)) 644 return ds.batch(8, drop_remainder=True) 645 646 # Very simple bias model to eliminate randomness. 
647 optimizer = gradient_descent.SGD(0.1) 648 model = sequential.Sequential() 649 model.add(testing_utils.Bias(input_shape=(1,))) 650 model.compile(loss='mae', optimizer=optimizer, metrics=['mae']) 651 train_ds = get_input_datasets() 652 653 temp_dir = self.get_temp_dir() 654 filepath = os.path.join(temp_dir, 'checkpoint.epoch{epoch:02d}.h5') 655 656 # The filepath shouldn't exist at the beginning. 657 self.assertFalse(os.path.exists(filepath)) 658 callback = keras.callbacks.ModelCheckpoint( 659 filepath=filepath, save_weights_only=True) 660 661 return model, train_ds, callback, filepath 662 663 def _run_load_weights_on_restart_test_common_iterations(self): 664 665 (model, train_ds, callback, 666 filepath) = self._get_dummy_resource_for_model_checkpoint_testing() 667 initial_epochs = 3 668 model.fit(train_ds, epochs=initial_epochs, callbacks=[callback]) 669 670 # The files should exist after fitting with callback. 671 for epoch in range(initial_epochs): 672 self.assertTrue(os.path.exists(filepath.format(epoch=epoch + 1))) 673 self.assertFalse(os.path.exists(filepath.format(epoch=initial_epochs + 1))) 674 self.assertEqual( 675 callback._get_most_recently_modified_file_matching_pattern(filepath), 676 filepath.format(epoch=initial_epochs)) 677 678 model.fit(train_ds, epochs=1) 679 weights_after_one_more_epoch = model.get_weights() 680 681 # The filepath should continue to exist after fitting without callback. 
682 for epoch in range(initial_epochs): 683 self.assertTrue(os.path.exists(filepath.format(epoch=epoch + 1))) 684 685 return model, train_ds, filepath, weights_after_one_more_epoch 686 687 @staticmethod 688 def get_ModelCheckpoint_load_weights_on_restart_true_test(save_weights_only): 689 690 def func(self): 691 (model, train_ds, filepath, weights_after_one_more_epoch 692 ) = self._run_load_weights_on_restart_test_common_iterations() 693 694 # Sleep for some short time period ensuring the files are created with 695 # a different time (in MacOS OSS the granularity is only 1 second). 696 time.sleep(2) 697 callback = keras.callbacks.ModelCheckpoint( 698 filepath=filepath, 699 save_weights_only=save_weights_only, 700 load_weights_on_restart=True) 701 model.fit(train_ds, epochs=1, callbacks=[callback]) 702 weights_after_model_restoring_and_one_more_epoch = model.get_weights() 703 704 self.assertEqual( 705 callback._get_most_recently_modified_file_matching_pattern(filepath), 706 filepath.format(epoch=1)) 707 708 model.fit( 709 train_ds, 710 epochs=1, 711 callbacks=[ 712 keras.callbacks.ModelCheckpoint( 713 filepath=filepath, 714 save_weights_only=save_weights_only, 715 load_weights_on_restart=True) 716 ]) 717 weights_with_one_final_extra_epoch = model.get_weights() 718 719 # Asserting the weights one epoch after initial fitting and another epoch 720 # after that are closed, if a ModelCheckpoint with 721 # load_weights_on_restart=True is given (so the model is restored at the 722 # beginning of training). 
723 self.assertAllClose(weights_after_one_more_epoch, 724 weights_after_model_restoring_and_one_more_epoch) 725 726 self.assertNotAllClose(weights_after_one_more_epoch, 727 weights_with_one_final_extra_epoch) 728 729 return func 730 731 @staticmethod 732 def get_ModelCheckpoint_load_weights_on_restart_false_test(save_weights_only): 733 734 def func(self): 735 (model, train_ds, filepath, weights_after_one_more_epoch 736 ) = self._run_load_weights_on_restart_test_common_iterations() 737 738 model.fit( 739 train_ds, 740 epochs=1, 741 callbacks=[ 742 keras.callbacks.ModelCheckpoint( 743 filepath=filepath, save_weights_only=save_weights_only) 744 ]) 745 weights_after_model_restoring_and_one_more_epoch = model.get_weights() 746 747 # Asserting the weights one epoch after initial fitting and another epoch 748 # after that are different, if a ModelCheckpoint with 749 # load_weights_on_restart=False is given (so the model is not restored at 750 # the beginning of training). 751 self.assertNotAllClose(weights_after_one_more_epoch, 752 weights_after_model_restoring_and_one_more_epoch) 753 754 return func 755 756 test_model_checkpoint_load_weights_on_restart_true_save_weights_only_true = \ 757 get_ModelCheckpoint_load_weights_on_restart_true_test.__func__(True) 758 759 test_model_checkpoint_load_weights_on_restart_true_save_weights_only_false = \ 760 get_ModelCheckpoint_load_weights_on_restart_true_test.__func__(False) 761 762 test_model_checkpoint_load_weights_on_restart_false_save_weights_only_true = \ 763 get_ModelCheckpoint_load_weights_on_restart_false_test.__func__(True) 764 765 test_model_checkpoint_load_weights_on_restart_false_save_weights_only_false \ 766 = get_ModelCheckpoint_load_weights_on_restart_false_test.__func__(False) 767 768 def test_ModelCheckpoint_override_if_file_exist(self): 769 (model, train_ds, filepath, 770 _) = self._run_load_weights_on_restart_test_common_iterations() 771 772 # Sleep for some short time period to ensure the files are created with 
773 # a different time (in MacOS OSS the granularity is only 1 second). 774 time.sleep(2) 775 callback = keras.callbacks.ModelCheckpoint( 776 filepath=filepath, save_weights_only=True) 777 model.load_weights( 778 callback._get_most_recently_modified_file_matching_pattern(filepath)) 779 weights_before_additional_fit = model.get_weights() 780 model.fit(train_ds, epochs=1, callbacks=[callback]) 781 model.load_weights( 782 callback._get_most_recently_modified_file_matching_pattern(filepath)) 783 weights_after_additional_fit = model.get_weights() 784 785 self.assertNotAllClose(weights_before_additional_fit, 786 weights_after_additional_fit) 787 788 def test_fit_with_ModelCheckpoint_with_tf_config(self): 789 (model, train_ds, callback, 790 _) = self._get_dummy_resource_for_model_checkpoint_testing() 791 792 os.environ['TF_CONFIG'] = json.dumps({ 793 'cluster': { 794 'worker': ['localhost:23333'] 795 }, 796 'task': { 797 'type': 'worker', 798 'index': 0 799 } 800 }) 801 802 # `model.fit()` should work regardless of the presence of `TF_CONFIG`. 
803 model.fit(train_ds, epochs=1, callbacks=[callback]) 804 805 def test_fit_with_ModelCheckpoint_with_dir_as_h5_filepath(self): 806 (model, train_ds, callback, 807 filepath) = self._get_dummy_resource_for_model_checkpoint_testing() 808 809 temp_dir = self.get_temp_dir() 810 filepath = os.path.join(temp_dir, 'temp.h5') 811 812 self.assertFalse(os.path.exists(filepath)) 813 os.mkdir(filepath) 814 self.assertTrue(os.path.exists(filepath)) 815 816 callback = keras.callbacks.ModelCheckpoint(filepath=filepath) 817 818 with self.assertRaisesRegexp(IOError, 'Please specify a non-directory ' 819 'filepath for ModelCheckpoint.'): 820 model.fit(train_ds, epochs=1, callbacks=[callback]) 821 822 def test_ModelCheckpoint_with_bad_path_placeholders(self): 823 (model, train_ds, callback, 824 filepath) = self._get_dummy_resource_for_model_checkpoint_testing() 825 826 temp_dir = self.get_temp_dir() 827 filepath = os.path.join(temp_dir, 'chkpt_{epoch:02d}_{mape:.2f}.h5') 828 callback = keras.callbacks.ModelCheckpoint(filepath=filepath) 829 830 with self.assertRaisesRegexp(KeyError, 'Failed to format this callback ' 831 'filepath.*'): 832 model.fit(train_ds, epochs=1, callbacks=[callback]) 833 834 def test_EarlyStopping(self): 835 with self.cached_session(): 836 np.random.seed(123) 837 (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data( 838 train_samples=TRAIN_SAMPLES, 839 test_samples=TEST_SAMPLES, 840 input_shape=(INPUT_DIM,), 841 num_classes=NUM_CLASSES) 842 y_test = np_utils.to_categorical(y_test) 843 y_train = np_utils.to_categorical(y_train) 844 model = testing_utils.get_small_sequential_mlp( 845 num_hidden=NUM_HIDDEN, num_classes=NUM_CLASSES, input_dim=INPUT_DIM) 846 model.compile( 847 loss='categorical_crossentropy', optimizer='rmsprop', metrics=['acc']) 848 849 cases = [ 850 ('max', 'val_acc'), 851 ('min', 'val_loss'), 852 ('auto', 'val_acc'), 853 ('auto', 'loss'), 854 ('unknown', 'unknown') 855 ] 856 for mode, monitor in cases: 857 patience = 0 858 cbks = [ 
859 keras.callbacks.EarlyStopping( 860 patience=patience, monitor=monitor, mode=mode) 861 ] 862 model.fit( 863 x_train, 864 y_train, 865 batch_size=BATCH_SIZE, 866 validation_data=(x_test, y_test), 867 callbacks=cbks, 868 epochs=5, 869 verbose=0) 870 871 def test_EarlyStopping_reuse(self): 872 with self.cached_session(): 873 np.random.seed(1337) 874 patience = 3 875 data = np.random.random((100, 1)) 876 labels = np.where(data > 0.5, 1, 0) 877 model = keras.models.Sequential((keras.layers.Dense( 878 1, input_dim=1, activation='relu'), keras.layers.Dense( 879 1, activation='sigmoid'),)) 880 model.compile( 881 optimizer='sgd', loss='binary_crossentropy', metrics=['accuracy']) 882 weights = model.get_weights() 883 884 stopper = keras.callbacks.EarlyStopping(monitor='acc', patience=patience) 885 hist = model.fit(data, labels, callbacks=[stopper], verbose=0, epochs=20) 886 assert len(hist.epoch) >= patience 887 888 # This should allow training to go for at least `patience` epochs 889 model.set_weights(weights) 890 hist = model.fit(data, labels, callbacks=[stopper], verbose=0, epochs=20) 891 assert len(hist.epoch) >= patience 892 893 def test_EarlyStopping_with_baseline(self): 894 with self.cached_session(): 895 np.random.seed(1337) 896 baseline = 0.5 897 (data, labels), _ = testing_utils.get_test_data( 898 train_samples=100, 899 test_samples=50, 900 input_shape=(1,), 901 num_classes=NUM_CLASSES) 902 model = testing_utils.get_small_sequential_mlp( 903 num_hidden=1, num_classes=1, input_dim=1) 904 model.compile( 905 optimizer='sgd', loss='binary_crossentropy', metrics=['acc']) 906 907 stopper = keras.callbacks.EarlyStopping(monitor='acc', 908 baseline=baseline) 909 hist = model.fit(data, labels, callbacks=[stopper], verbose=0, epochs=20) 910 assert len(hist.epoch) == 1 911 912 patience = 3 913 stopper = keras.callbacks.EarlyStopping(monitor='acc', 914 patience=patience, 915 baseline=baseline) 916 hist = model.fit(data, labels, callbacks=[stopper], verbose=0, epochs=20) 917 
assert len(hist.epoch) >= patience 918 919 def test_EarlyStopping_final_weights_when_restoring_model_weights(self): 920 921 class DummyModel(object): 922 923 def __init__(self): 924 self.stop_training = False 925 self.weights = -1 926 927 def get_weights(self): 928 return self.weights 929 930 def set_weights(self, weights): 931 self.weights = weights 932 933 def set_weight_to_epoch(self, epoch): 934 self.weights = epoch 935 936 early_stop = keras.callbacks.EarlyStopping(monitor='val_loss', 937 patience=2, 938 restore_best_weights=True) 939 early_stop.model = DummyModel() 940 losses = [0.2, 0.15, 0.1, 0.11, 0.12] 941 # The best configuration is in the epoch 2 (loss = 0.1000). 942 epochs_trained = 0 943 early_stop.on_train_begin() 944 for epoch in range(len(losses)): 945 epochs_trained += 1 946 early_stop.model.set_weight_to_epoch(epoch=epoch) 947 early_stop.on_epoch_end(epoch, logs={'val_loss': losses[epoch]}) 948 if early_stop.model.stop_training: 949 break 950 # The best configuration is in epoch 2 (loss = 0.1000), 951 # and while patience = 2, we're restoring the best weights, 952 # so we end up at the epoch with the best weights, i.e. 
def test_LearningRateScheduler(self):
  """Checks the optimizer LR left behind by `LearningRateScheduler`.

  Three schedules are exercised, each after recompiling the model with the
  'sgd' optimizer string: a function of the epoch only, a function of
  `(epoch, lr)`, and a function that evaluates a `CosineDecay` object at the
  epoch index. After each fit the optimizer's learning rate must match the
  value the schedule yields for the last epoch, within
  `keras.backend.epsilon()`.
  """
  with self.cached_session():
    np.random.seed(1337)
    (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
        train_samples=TRAIN_SAMPLES,
        test_samples=TEST_SAMPLES,
        input_shape=(INPUT_DIM,),
        num_classes=NUM_CLASSES)
    y_test = np_utils.to_categorical(y_test)
    y_train = np_utils.to_categorical(y_train)
    model = testing_utils.get_small_sequential_mlp(
        num_hidden=NUM_HIDDEN, num_classes=NUM_CLASSES, input_dim=INPUT_DIM)

    def fit_with_schedule(schedule, epochs):
      """Recompiles the model, trains with `schedule`, returns the final LR."""
      # NOTE(review): recompiling with the 'sgd' string presumably resets the
      # LR to the optimizer default; the `0.01 / 4` expectation relies on it.
      model.compile(
          loss='categorical_crossentropy',
          optimizer='sgd',
          metrics=['accuracy'])
      model.fit(
          x_train,
          y_train,
          batch_size=BATCH_SIZE,
          validation_data=(x_test, y_test),
          callbacks=[keras.callbacks.LearningRateScheduler(schedule)],
          epochs=epochs,
          verbose=0)
      return float(keras.backend.get_value(model.optimizer.lr))

    def assert_lr_close(actual, expected):
      """Asserts `actual` is within backend epsilon of `expected`."""
      # The absolute value matters: a signed difference would also pass for
      # any learning rate that is far *below* the expected one.
      self.assertLess(abs(actual - expected), keras.backend.epsilon())

    # Schedule depends on the epoch only; the last of 5 epochs has index 4.
    assert_lr_close(fit_with_schedule(lambda x: 1. / (1. + x), epochs=5), 0.2)

    # Schedule halves the incoming LR once per epoch, for 2 epochs.
    assert_lr_close(
        fit_with_schedule(lambda x, lr: lr / 2, epochs=2), 0.01 / 4)

    # Schedule delegates to a learning-rate-schedule object.
    def cosine_schedule(epoch, _):
      return learning_rate_schedule.CosineDecay(0.01, 2)(epoch)

    final_lr = fit_with_schedule(cosine_schedule, epochs=2)

    # Reference value computed with NumPy for the last epoch (index 1 of 2).
    cosine_decay_np = 0.5 * (1 + np.cos(np.pi * (1 / 2)))
    decayed_learning_rate = 0.01 * cosine_decay_np
    assert_lr_close(final_lr, decayed_learning_rate)
def test_ReduceLROnPlateau_patience(self):
  """ReduceLROnPlateau with patience=2 lowers the LR only on the last epoch.

  The callback is driven by hand against a stub model whose optimizer exposes
  a backend variable `lr` initialised to 1.0. The learning rate is sampled
  after every `on_epoch_end` call.
  """

  class DummyOptimizer(object):

    def __init__(self):
      self.lr = keras.backend.variable(1.0)

  class DummyModel(object):

    def __init__(self):
      self.optimizer = DummyOptimizer()

  reduce_on_plateau = keras.callbacks.ReduceLROnPlateau(
      monitor='val_loss', patience=2)
  reduce_on_plateau.model = DummyModel()

  # The first loss is the best one; the two that follow do not improve on it.
  losses = [0.0860, 0.1096, 0.1040]
  observed_lrs = []

  for epoch, val_loss in enumerate(losses):
    reduce_on_plateau.on_epoch_end(epoch, logs={'val_loss': val_loss})
    current_lr = keras.backend.get_value(reduce_on_plateau.model.optimizer.lr)
    observed_lrs.append(current_lr)

  # Every learning rate but the last must still be 1.0.
  last_lr = observed_lrs[-1]
  for earlier_lr in observed_lrs[:-1]:
    self.assertEqual(earlier_lr, 1.0)
  self.assertLess(last_lr, 1.0)
def test_stop_training_csv(self):
  """CSVLogger used together with TerminateOnNaN must not write invalid CSVs.

  Training data comes from a generator that switches to all-NaN batches once
  enough batches have been produced. The test checks that more than one loss
  was recorded, that the final loss is inf or NaN, and that the last
  non-empty row of the CSV log contains 'nan'.
  """
  np.random.seed(1337)
  tmpdir = self.get_temp_dir()
  self.addCleanup(shutil.rmtree, tmpdir, ignore_errors=True)

  with self.cached_session():
    fp = os.path.join(tmpdir, 'test.csv')
    (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
        train_samples=TRAIN_SAMPLES,
        test_samples=TEST_SAMPLES,
        input_shape=(INPUT_DIM,),
        num_classes=NUM_CLASSES)

    y_test = np_utils.to_categorical(y_test)
    y_train = np_utils.to_categorical(y_train)
    cbks = [keras.callbacks.TerminateOnNaN(), keras.callbacks.CSVLogger(fp)]
    model = keras.models.Sequential()
    for _ in range(5):
      model.add(keras.layers.Dense(2, input_dim=INPUT_DIM, activation='relu'))
    model.add(keras.layers.Dense(NUM_CLASSES, activation='linear'))
    model.compile(loss='mean_squared_error', optimizer='rmsprop')

    def data_generator():
      """Yields real batches cyclically, then NaN batches past a threshold."""
      batches_per_pass = len(x_train) // BATCH_SIZE
      nan_threshold = 3 * len(x_train)
      batch_index = 0
      batches_yielded = 0
      while True:
        if batches_yielded > nan_threshold:
          yield (np.ones([BATCH_SIZE, INPUT_DIM]) * np.nan,
                 np.ones([BATCH_SIZE, NUM_CLASSES]) * np.nan)
        else:
          begin = batch_index * BATCH_SIZE
          end = (batch_index + 1) * BATCH_SIZE
          yield (x_train[begin:end], y_train[begin:end])
        batches_yielded += 1
        # Wrap around so the real data is replayed from its first batch.
        batch_index = (batch_index + 1) % batches_per_pass

    steps_per_epoch = len(x_train) // BATCH_SIZE
    history = model.fit_generator(data_generator(),
                                  steps_per_epoch,
                                  validation_data=(x_test, y_test),
                                  callbacks=cbks,
                                  epochs=20)
    loss = history.history['loss']
    assert len(loss) > 1
    assert loss[-1] == np.inf or np.isnan(loss[-1])

    with open(fp) as f:
      # In windows, due to \r\n line ends we may end up reading empty lines
      # after each line. Skip empty lines.
      values = [row for row in csv.reader(f) if row]
    assert 'nan' in values[-1], 'The last epoch was not logged.'
def test_RemoteMonitor_np_array(self):
  """RemoteMonitor posts a JSON payload when a log value is an ndarray.

  `requests.post` is mocked. The single-element array logged under 'val' is
  expected to show up as the plain value 0 in the payload, next to the loss
  and the epoch index.
  """
  if requests is None:
    self.skipTest('`requests` required to run this test')
  with test.mock.patch.object(requests, 'post') as requests_post:
    monitor = keras.callbacks.RemoteMonitor(send_as_json=True)
    # np.arange(1) is a 1-D array holding the single element 0.
    single_element = np.arange(1)
    monitor.on_epoch_end(0, logs={'loss': 0., 'val': single_element})
    expected_payload = {'loss': 0., 'epoch': 0, 'val': 0}
    requests_post.assert_called_once_with(
        monitor.root + monitor.path,
        json=expected_payload,
        headers=monitor.headers)
def list_summaries(logdir):
  """Read all summaries under the logdir into a `_SummaryFile`.

  Args:
    logdir: A path to a directory that contains zero or more event
      files, either as direct children or in transitive subdirectories.
      Summary values in these events must be old-style scalars, images,
      or histograms, or tensor summaries. Non-summary events, like
      `graph_def`s, are ignored.

  Returns:
    A `_SummaryFile` object reflecting all summaries written to any
    event files in the logdir or any of its descendant directories.

  Raises:
    ValueError: If an event file contains a summary of unexpected kind.
  """
  result = _SummaryFile()
  # Maps the populated field of the `value` oneof to the set recording it.
  containers_by_kind = {
      'simple_value': result.scalars,
      'image': result.images,
      'histo': result.histograms,
      'tensor': result.tensors,
  }
  for (dirpath, dirnames, filenames) in os.walk(logdir):
    del dirnames  # unused
    for filename in filenames:
      if not filename.startswith('events.out.'):
        continue
      path = os.path.join(dirpath, filename)
      for event in summary_iterator.summary_iterator(path):
        if not event.summary:  # (e.g., it's a `graph_def` event)
          continue
        for value in event.summary.value:
          tag = value.tag
          # Case on the `value` rather than the summary metadata because
          # the Keras callback uses `summary_ops_v2` to emit old-style
          # summaries. See b/124535134.
          kind = value.WhichOneof('value')
          container = containers_by_kind.get(kind)
          if container is None:
            raise ValueError(
                'Unexpected summary kind %r in event file %s:\n%r'
                % (kind, path, event))
          if kind == 'tensor' and tag != 'keras':
            # Check for V2 scalar summaries, which have a different PB
            # structure. The plugin name is read from this value's own
            # metadata, so every value in a multi-value event is classified
            # independently of the first one.
            if value.metadata.plugin_data.plugin_name == 'scalars':
              container = result.scalars
          container.add(_ObservedSummary(logdir=dirpath, tag=tag))
  return result
def test_TensorBoard_no_spurious_event_files(self):
  """Fitting without validation data writes event files only under 'train'.

  After training, the logdir tree is walked and the basename of every
  directory holding an 'events.out.*' file is collected.
  """
  model = self._get_model()
  x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
  tb_cbk = keras.callbacks.TensorBoard(self.logdir)

  model.fit(
      x,
      y,
      batch_size=2,
      epochs=2,
      callbacks=[tb_cbk])

  # Basenames of all directories that contain at least one event file.
  run_dirs_with_events = {
      os.path.basename(dirpath)
      for (dirpath, _, filenames) in os.walk(self.logdir)
      if any(name.startswith('events.out.') for name in filenames)
  }
  self.assertEqual(run_dirs_with_events, {'train'})
def test_TensorBoard_weight_images(self):
  """Checks summaries written with `histogram_freq=1, write_images=True`.

  Expects 'epoch_loss' scalars for the train and validation runs, and kernel
  and bias histogram and image summaries under the train directory. Histogram
  and image tags are compared after removing their layer-name prefix.
  """
  model = self._get_model()
  x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
  tb_cbk = keras.callbacks.TensorBoard(
      self.logdir, histogram_freq=1, write_images=True)
  model_type = testing_utils.get_model_type()

  model.fit(
      x,
      y,
      batch_size=2,
      epochs=2,
      validation_data=(x, y),
      callbacks=[tb_cbk])
  summary_file = list_summaries(self.logdir)

  def observed_in_train(*tags):
    """Builds the set of expected train-run summaries for `tags`."""
    return {_ObservedSummary(logdir=self.train_dir, tag=tag) for tag in tags}

  expected_scalars = observed_in_train('epoch_loss')
  expected_scalars.add(
      _ObservedSummary(logdir=self.validation_dir, tag='epoch_loss'))
  self.assertEqual(summary_file.scalars, expected_scalars)

  stripped_histograms = self._strip_layer_names(
      summary_file.histograms, model_type)
  self.assertEqual(stripped_histograms,
                   observed_in_train('bias_0', 'kernel_0'))

  stripped_images = self._strip_layer_names(summary_file.images, model_type)
  self.assertEqual(
      stripped_images,
      observed_in_train(
          'bias_0/image/0',
          'kernel_0/image/0',
          'kernel_0/image/1',
          'kernel_0/image/2'))
1708 metadata.plugin_data.plugin_name = 'scalars' 1709 with summary_ops_v2.summary_scope( 1710 name, 'scalar_summary', values=[data, step]) as (tag, _): 1711 return summary_ops_v2.write( 1712 tag=tag, 1713 tensor=math_ops.cast(data, 'float32'), 1714 step=step, 1715 metadata=metadata) 1716 1717 class LayerWithSummary(keras.layers.Layer): 1718 1719 def call(self, x): 1720 scalar_v2_mock('custom_summary', math_ops.reduce_sum(x)) 1721 return x 1722 1723 model = testing_utils.get_model_from_layers([LayerWithSummary()], 1724 input_shape=(5,), 1725 name='model') 1726 1727 model.compile( 1728 'sgd', 1729 'mse', 1730 run_eagerly=testing_utils.should_run_eagerly(), 1731 experimental_run_tf_function=testing_utils.should_run_tf_function()) 1732 tb_cbk = keras.callbacks.TensorBoard(self.logdir, update_freq=1) 1733 x, y = np.ones((10, 5)), np.ones((10, 5)) 1734 model.fit(x, y, batch_size=2, validation_data=(x, y), callbacks=[tb_cbk]) 1735 summary_file = list_summaries(self.logdir) 1736 self.assertEqual( 1737 summary_file.scalars, 1738 { 1739 _ObservedSummary(logdir=self.train_dir, tag='epoch_loss'), 1740 _ObservedSummary(logdir=self.validation_dir, tag='epoch_loss'), 1741 _ObservedSummary(logdir=self.train_dir, tag='batch_loss'), 1742 _ObservedSummary( 1743 logdir=self.train_dir, 1744 tag='model/layer_with_summary/custom_summary'), 1745 _ObservedSummary( 1746 logdir=self.validation_dir, 1747 tag='model/layer_with_summary/custom_summary') 1748 }, 1749 ) 1750 1751 def _strip_layer_names(self, summaries, model_type): 1752 """Deduplicate summary names modulo layer prefix. 1753 1754 This removes the first slash-component of each tag name: for 1755 instance, "foo/bar/baz" becomes "bar/baz". 1756 1757 Args: 1758 summaries: A `set` of `_ObservedSummary` values. 1759 model_type: The model type currently being tested. 1760 1761 Returns: 1762 A new `set` of `_ObservedSummary` values with layer prefixes 1763 removed. 
1764 """ 1765 result = set() 1766 for summary in summaries: 1767 if '/' not in summary.tag: 1768 raise ValueError('tag has no layer name: %r' % summary.tag) 1769 start_from = 2 if 'subclass' in model_type else 1 1770 new_tag = '/'.join(summary.tag.split('/')[start_from:]) 1771 result.add(summary._replace(tag=new_tag)) 1772 return result 1773 1774 def test_TensorBoard_invalid_argument(self): 1775 with self.assertRaisesRegexp(ValueError, 'Unrecognized arguments'): 1776 keras.callbacks.TensorBoard(wwrite_images=True) 1777 1778 1779# Note that this test specifies model_type explicitly. 1780@keras_parameterized.run_all_keras_modes(always_skip_v1=True) 1781class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase): 1782 1783 def setUp(self): 1784 super(TestTensorBoardV2NonParameterizedTest, self).setUp() 1785 self.logdir = os.path.join(self.get_temp_dir(), 'tb') 1786 self.train_dir = os.path.join(self.logdir, 'train') 1787 self.validation_dir = os.path.join(self.logdir, 'validation') 1788 1789 def _get_seq_model(self): 1790 model = keras.models.Sequential([ 1791 keras.layers.Conv2D(8, (3, 3), input_shape=(10, 10, 1)), 1792 keras.layers.Flatten(), 1793 keras.layers.Dense(1), 1794 ]) 1795 opt = gradient_descent.SGD(learning_rate=0.001) 1796 model.compile( 1797 opt, 1798 'mse', 1799 run_eagerly=testing_utils.should_run_eagerly(), 1800 experimental_run_tf_function=testing_utils.should_run_tf_function()) 1801 return model 1802 1803 def fitModelAndAssertKerasModelWritten(self, model): 1804 x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1)) 1805 tb_cbk = keras.callbacks.TensorBoard(self.logdir, 1806 write_graph=True, 1807 profile_batch=0) 1808 model.fit( 1809 x, 1810 y, 1811 batch_size=2, 1812 epochs=2, 1813 validation_data=(x, y), 1814 callbacks=[tb_cbk]) 1815 summary_file = list_summaries(self.logdir) 1816 self.assertEqual( 1817 summary_file.tensors, 1818 { 1819 _ObservedSummary(logdir=self.train_dir, tag='keras'), 1820 }, 1821 ) 1822 1823 def 
def test_TensorBoard_autoTrace(self):
  """With `profile_batch=1` a single 'batch_1' tensor summary is expected.

  The callback is created with `write_graph=False`, and the only tensor
  summary the test accepts is the 'batch_1' tag in the train directory.
  """
  model = self._get_seq_model()
  features = np.ones((10, 10, 10, 1))
  targets = np.ones((10, 1))
  tb_cbk = keras.callbacks.TensorBoard(
      self.logdir, histogram_freq=1, profile_batch=1, write_graph=False)

  model.fit(
      features,
      targets,
      batch_size=2,
      epochs=2,
      validation_data=(features, targets),
      callbacks=[tb_cbk])

  observed_tensors = list_summaries(self.logdir).tensors
  expected_tensors = {
      _ObservedSummary(logdir=self.train_dir, tag=u'batch_1'),
  }
  self.assertEqual(observed_tensors, expected_tensors)
class MostRecentlyModifiedFileMatchingPatternTest(test.TestCase):
  """Tests `ModelCheckpoint._get_most_recently_modified_file_matching_pattern`.

  Every test creates files in a fresh temp dir and compares the path returned
  by the lookup against the expected one.
  """

  def _latest_match(self, pattern):
    """Runs the lookup under test on a `ModelCheckpoint(None)` instance."""
    checkpoint_cbk = keras.callbacks.ModelCheckpoint(None)
    return checkpoint_cbk._get_most_recently_modified_file_matching_pattern(
        pattern)

  def _write_files(self, file_paths, pause_secs=0):
    """Writes 'foo bar' to each path, pausing first when `pause_secs` is set."""
    for file_path in file_paths:
      with open(file_path, 'w') as f:
        if pause_secs:
          # Ensure there are some intervals between file creation.
          time.sleep(pause_secs)
        f.write('foo bar')

  def test_get_most_recently_modified_file_matching_pattern(self):
    file_pattern = 'f.batch{batch:02d}epoch{epoch:02d}.h5'
    test_dir = self.get_temp_dir()
    path_pattern = os.path.join(test_dir, file_pattern)
    file_names = [
        'f.batch03epoch02.h5', 'f.batch02epoch02.h5', 'f.batch01epoch01.h5'
    ]
    file_paths = [os.path.join(test_dir, name) for name in file_names]
    self._write_files(file_paths, pause_secs=2)
    # Ensure the files have been actually written.
    paths_on_disk = {
        os.path.join(test_dir, name) for name in os.listdir(test_dir)
    }
    self.assertEqual(paths_on_disk, set(file_paths))
    # The file written last is the expected result.
    self.assertEqual(self._latest_match(path_pattern), file_paths[-1])

  def test_some_file_not_matching_pattern(self):
    file_pattern = 'f.batch{batch:02d}epoch{epoch:02d}.h5'
    test_dir = self.get_temp_dir()
    path_pattern = os.path.join(test_dir, file_pattern)
    # The last name ('baatch') deliberately does not match the pattern.
    file_names = [
        'f.batch03epoch02.h5', 'f.batch02epoch02.h5', 'f.baatch01epoch01.h5'
    ]
    file_paths = [os.path.join(test_dir, name) for name in file_names]
    self._write_files(file_paths, pause_secs=2)
    # The second-to-last file is the newest one that matches.
    self.assertEqual(self._latest_match(path_pattern), file_paths[-2])

  def test_get_same_file_if_file_name_equals_pattern(self):
    test_dir = self.get_temp_dir()
    file_path = os.path.join(test_dir, 'f.batch02.h5')
    self._write_files([file_path])
    only_entry = os.listdir(test_dir)[0]
    self.assertEqual(os.path.join(test_dir, only_entry), file_path)
    self.assertEqual(self._latest_match(file_path), file_path)

  def test_get_none_if_file_does_not_exist(self):
    test_dir = self.get_temp_dir()
    file_path = os.path.join(test_dir, 'f.batch02.h5')
    self.assertLen(os.listdir(test_dir), 0)
    self.assertEqual(self._latest_match(file_path), None)

  def test_using_checkpoint_management_latest_checkpoint(self):
    file_pattern = 'f.batch{batch:02d}epoch{epoch:02d}'
    test_dir = self.get_temp_dir()
    path_pattern = os.path.join(test_dir, file_pattern)
    ckpt_file_path = os.path.join(test_dir, 'f.batchXepochY')
    with open(ckpt_file_path, 'w') as f:
      f.write('dummy ckpt')
    checkpoint_management.update_checkpoint_state_internal(
        test_dir, ckpt_file_path)

    later_paths = [
        os.path.join(test_dir, name)
        for name in ['f.batch03epoch02', 'f.batch02epoch02']
    ]
    self._write_files(later_paths)

    # The result returned from checkpoint_management.latest_checkpoint takes
    # priority, so even if it was written earlier, we should still return that.
    self.assertEqual(self._latest_match(path_pattern), ckpt_file_path)