1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14#,============================================================================ 15"""Tests for model saving in the HDF5 format.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import os 22import shutil 23import tempfile 24from absl.testing import parameterized 25import numpy as np 26 27from tensorflow.python import keras 28from tensorflow.python.eager import context 29from tensorflow.python.framework import constant_op 30from tensorflow.python.framework import dtypes 31from tensorflow.python.framework import ops 32from tensorflow.python.framework import test_util 33from tensorflow.python.keras import optimizers 34from tensorflow.python.keras.engine import training 35from tensorflow.python.keras.saving import hdf5_format 36from tensorflow.python.lib.io import file_io 37from tensorflow.python.ops import array_ops 38from tensorflow.python.ops import random_ops 39from tensorflow.python.platform import test 40from tensorflow.python.platform import tf_logging as logging 41from tensorflow.python.training import checkpoint_management 42from tensorflow.python.training import training as training_module 43from tensorflow.python.training.tracking import util as trackable 44 45try: 46 import h5py # pylint:disable=g-import-not-at-top 47except ImportError: 48 h5py = None 49 50 51class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase): 52 53 @test_util.run_in_graph_and_eager_modes 54 def test_weight_loading(self): 55 with self.cached_session(): 56 a = keras.layers.Input(shape=(2,)) 57 x = keras.layers.Dense(3)(a) 58 b = keras.layers.Dense(1)(x) 59 model = keras.models.Model(a, b) 60 61 x = np.random.random((3, 2)) 62 ref_y = model.predict(x) 63 weights = model.get_weights() 64 model.set_weights(weights) 65 y = model.predict(x) 66 self.assertAllClose(ref_y, y) 67 68 with self.assertRaises(ValueError): 69 model.set_weights(weights[1:]) 70 with self.assertRaises(ValueError): 71 model.set_weights(weights[::-1]) 72 73 temp_dir = self.get_temp_dir() 74 self.addCleanup(shutil.rmtree, temp_dir) 75 76 no_extension_path = os.path.join(temp_dir, 'test') 77 model.save_weights(no_extension_path, save_format='tf') 78 model.load_weights(no_extension_path) 79 y = model.predict(x) 80 self.assertAllClose(ref_y, y) 81 82 if h5py is None: 83 return # Skip rest of test if H5py isn't available. 84 85 h5_path = os.path.join(temp_dir, 'test.h5') 86 model.save_weights(h5_path) 87 model.load_weights(h5_path) 88 y = model.predict(x) 89 self.assertAllClose(ref_y, y) 90 91 model.load_weights(h5_path, by_name=True) 92 y = model.predict(x) 93 self.assertAllClose(ref_y, y) 94 95 model.save_weights(no_extension_path, save_format='hdf5') 96 model.load_weights(no_extension_path) 97 y = model.predict(x) 98 self.assertAllClose(ref_y, y) 99 100 @test_util.run_in_graph_and_eager_modes 101 def test_weight_preprocessing(self): 102 input_dim = 3 103 output_dim = 3 104 size = 2 105 cases = [ 106 [ 107 (keras.layers.Bidirectional(keras.layers.SimpleRNN(2))), 108 [np.random.random((2, 1)), np.random.random((2, 1))], 109 (None, 3, 2), 110 ], 111 [ 112 (keras.layers.TimeDistributed(keras.layers.Dense(1))), 113 [np.random.random((2, 1)), np.random.random((1,))], 114 (None, 3, 2), 115 ], 116 [ 117 (keras.layers.Conv1D(output_dim, size, use_bias=False)), 118 [np.random.random((output_dim, input_dim, size, 1))], 119 (None, 4, input_dim), 120 ], 121 [ 122 (keras.layers.Conv2D(output_dim, size, 123 use_bias=False, data_format='channels_first')), 124 [np.random.random((output_dim, input_dim, size, size))], 125 (None, input_dim, 4, 4), 126 ], 127 [ 128 (keras.layers.Conv2DTranspose(output_dim, size, 129 use_bias=False, 130 data_format='channels_first')), 131 [np.random.random((output_dim, input_dim, size, size))], 132 (None, input_dim, 4, 4), 133 ], 134 [ 135 (keras.layers.Conv2DTranspose(output_dim, size, 136 use_bias=False, 137 data_format='channels_last')), 138 [np.random.random((size, size, input_dim, output_dim))], 139 (None, 4, 4, input_dim), 140 ], 141 [ 142 (keras.layers.Conv3D(output_dim, size, 143 use_bias=False, data_format='channels_first')), 144 [np.random.random((output_dim, input_dim, size, size, size))], 145 (None, input_dim, 4, 4, 4), 146 ], 147 [ 148 (keras.layers.GRU(output_dim)), 149 [np.random.random((input_dim, output_dim)), 150 np.random.random((output_dim, output_dim)), 151 np.random.random((output_dim,)), 152 np.random.random((input_dim, output_dim)), 153 np.random.random((output_dim, output_dim)), 154 np.random.random((output_dim,)), 155 np.random.random((input_dim, output_dim)), 156 np.random.random((output_dim, output_dim)), 157 np.random.random((output_dim,))], 158 (None, 4, input_dim), 159 ], 160 [ 161 (keras.layers.LSTM(output_dim)), 162 [np.random.random((input_dim, output_dim)), 163 np.random.random((output_dim, output_dim)), 164 np.random.random((output_dim,)), 165 np.random.random((input_dim, output_dim)), 166 np.random.random((output_dim, output_dim)), 167 np.random.random((output_dim,)), 168 np.random.random((input_dim, output_dim)), 169 np.random.random((output_dim, output_dim)), 170 np.random.random((output_dim,)), 171 np.random.random((input_dim, output_dim)), 172 np.random.random((output_dim, output_dim)), 173 np.random.random((output_dim,))], 174 (None, 4, input_dim), 175 ], 176 ] 177 for layer, weights, input_shape in cases: 178 layer.build(input_shape) 179 _ = hdf5_format.preprocess_weights_for_loading( 180 layer, weights, original_keras_version='1') 181 182 model = keras.models.Sequential([keras.layers.Dense(2, input_dim=2)]) 183 _ = hdf5_format.preprocess_weights_for_loading( 184 model, model.weights, original_keras_version='1') 185 186 x = keras.Input((2,)) 187 y = keras.layers.Dense(2)(x) 188 model = keras.models.Model(x, y) 189 _ = hdf5_format.preprocess_weights_for_loading( 190 model, model.weights, original_keras_version='1') 191 192 @parameterized.named_parameters( 193 ('gru', keras.layers.GRU, { 194 'units': 2, 195 'input_shape': (3, 5) 196 }), 197 ('gru_with_reset_after', keras.layers.GRU, { 198 'units': 2, 199 'input_shape': (3, 5), 200 'reset_after': True 201 }), 202 ('lstm', keras.layers.LSTM, { 203 'units': 2, 204 'input_shape': (3, 5) 205 }), 206 ('cudnngru', keras.layers.CuDNNGRU, { 207 'units': 2, 208 'input_shape': (3, 5) 209 }), 210 ('cudnnlstm', keras.layers.CuDNNLSTM, { 211 'units': 2, 212 'input_shape': (3, 5) 213 })) 214 def test_preprocess_weights_for_loading_rnn_should_be_idempotent( 215 self, layer_class, layer_args): 216 with self.cached_session(): 217 layer = layer_class(**layer_args) 218 layer.build(input_shape=layer_args.get('input_shape')) 219 weights1 = layer.get_weights() 220 weights2 = hdf5_format.preprocess_weights_for_loading( 221 layer, weights1) 222 _ = [ 223 self.assertAllClose(x, y, rtol=1e-05) 224 for (x, y) in zip(weights1, weights2) 225 ] 226 227 @test_util.run_in_graph_and_eager_modes 228 def test_sequential_weight_loading(self): 229 if h5py is None: 230 return 231 232 temp_dir = self.get_temp_dir() 233 self.addCleanup(shutil.rmtree, temp_dir) 234 h5_path = os.path.join(temp_dir, 'test.h5') 235 236 num_hidden = 5 237 input_dim = 3 238 batch_size = 5 239 num_classes = 2 240 241 with self.cached_session(): 242 model = keras.models.Sequential() 243 model.add(keras.layers.Dense(num_hidden, input_dim=input_dim)) 244 model.add(keras.layers.Dense(num_classes)) 245 246 x = np.random.random((batch_size, input_dim)) 247 ref_y = model.predict(x) 248 249 model.save_weights(h5_path) 250 251 model = keras.models.Sequential() 252 model.add(keras.layers.Dense(num_hidden, input_dim=input_dim)) 253 model.add(keras.layers.Dense(num_classes)) 254 model.load_weights(h5_path) 255 y = model.predict(x) 256 257 self.assertAllClose(y, ref_y) 258 259 @test_util.run_in_graph_and_eager_modes 260 def test_sequential_weight_loading_group_name_with_incorrect_length(self): 261 if h5py is None: 262 return 263 264 temp_dir = self.get_temp_dir() 265 self.addCleanup(shutil.rmtree, temp_dir) 266 h5_path = os.path.join(temp_dir, 'test.h5') 267 268 num_hidden = 5 269 input_dim = 3 270 num_classes = 2 271 with self.cached_session(): 272 ref_model = keras.models.Sequential() 273 ref_model.add(keras.layers.Dense(num_hidden, input_dim=input_dim, 274 name='d1')) 275 ref_model.add(keras.layers.Dense(num_classes, name='d2')) 276 ref_model.compile(loss=keras.losses.MSE, 277 optimizer=keras.optimizers.RMSprop(lr=0.0001), 278 metrics=[keras.metrics.categorical_accuracy]) 279 280 f_ref_model = h5py.File(h5_path, 'w') 281 hdf5_format.save_weights_to_hdf5_group(f_ref_model, ref_model.layers) 282 283 f_model = h5py.File(h5_path, 'r') 284 model = keras.models.Sequential() 285 model.add(keras.layers.Dense(num_hidden, use_bias=False, 286 input_dim=input_dim, name='d1')) 287 model.add(keras.layers.Dense(num_classes, name='d2')) 288 model.compile(loss=keras.losses.MSE, 289 optimizer=keras.optimizers.RMSprop(lr=0.0001), 290 metrics=[keras.metrics.categorical_accuracy]) 291 with self.assertRaisesRegexp(ValueError, 292 r'Layer #0 \(named \"d1\"\) expects 1 ' 293 r'weight\(s\), but the saved weights have 2 ' 294 r'element\(s\)\.'): 295 hdf5_format.load_weights_from_hdf5_group_by_name(f_model, model.layers) 296 297 @test_util.run_deprecated_v1 298 def test_sequential_weight_loading_group_name_with_incorrect_shape(self): 299 if h5py is None: 300 return 301 302 temp_dir = self.get_temp_dir() 303 self.addCleanup(shutil.rmtree, temp_dir) 304 h5_path = os.path.join(temp_dir, 'test.h5') 305 306 num_hidden = 5 307 input_dim = 3 308 num_classes = 2 309 with self.cached_session(): 310 ref_model = keras.models.Sequential() 311 ref_model.add(keras.layers.Dense(num_hidden, input_dim=input_dim, 312 name='d1')) 313 ref_model.add(keras.layers.Dense(num_classes, name='d2')) 314 ref_model.compile(loss=keras.losses.MSE, 315 optimizer=keras.optimizers.RMSprop(lr=0.0001), 316 metrics=[keras.metrics.categorical_accuracy]) 317 318 f_ref_model = h5py.File(h5_path, 'w') 319 hdf5_format.save_weights_to_hdf5_group(f_ref_model, ref_model.layers) 320 321 f_model = h5py.File(h5_path, 'r') 322 model = keras.models.Sequential() 323 model.add(keras.layers.Dense(num_hidden + 5, input_dim=input_dim, 324 name='d1')) 325 model.add(keras.layers.Dense(num_classes, name='d2')) 326 model.compile(loss=keras.losses.MSE, 327 optimizer=keras.optimizers.RMSprop(lr=0.0001), 328 metrics=[keras.metrics.categorical_accuracy]) 329 with self.assertRaisesRegexp(ValueError, 330 r'Layer #0 \(named "d1"\), weight ' 331 r'<tf\.Variable \'d1_1\/kernel:0\' ' 332 r'shape=\(3, 10\) dtype=float32> has ' 333 r'shape \(3, 10\), but the saved weight has ' 334 r'shape \(3, 5\)\.'): 335 hdf5_format.load_weights_from_hdf5_group_by_name(f_model, model.layers) 336 337 338class TestWholeModelSaving(test.TestCase): 339 340 @test_util.run_v1_only('b/120994067') 341 def test_sequential_model_saving(self): 342 if h5py is None: 343 self.skipTest('h5py required to run this test') 344 345 with self.cached_session(): 346 model = keras.models.Sequential() 347 model.add(keras.layers.Dense(2, input_shape=(3,))) 348 model.add(keras.layers.RepeatVector(3)) 349 model.add(keras.layers.TimeDistributed(keras.layers.Dense(3))) 350 model.compile( 351 loss=keras.losses.MSE, 352 optimizer=keras.optimizers.RMSprop(lr=0.0001), 353 metrics=[ 354 keras.metrics.categorical_accuracy, 355 keras.metrics.CategoricalCrossentropy( 356 name='cce', label_smoothing=constant_op.constant(0.2)), 357 ], 358 weighted_metrics=[ 359 keras.metrics.categorical_crossentropy, 360 keras.metrics.CategoricalCrossentropy( 361 name='cce', label_smoothing=constant_op.constant(0.2)), 362 ], 363 sample_weight_mode='temporal') 364 365 x = np.random.random((1, 3)) 366 y = np.random.random((1, 3, 3)) 367 model.train_on_batch(x, y) 368 369 out = model.predict(x) 370 fd, fname = tempfile.mkstemp('.h5') 371 keras.models.save_model(model, fname) 372 373 new_model = keras.models.load_model(fname) 374 os.close(fd) 375 os.remove(fname) 376 377 out2 = new_model.predict(x) 378 self.assertAllClose(out, out2, atol=1e-05) 379 380 # test that new updates are the same with both models 381 x = np.random.random((1, 3)) 382 y = np.random.random((1, 3, 3)) 383 model.train_on_batch(x, y) 384 new_model.train_on_batch(x, y) 385 386 x = np.random.random((1, 3)) 387 y = np.random.random((1, 3, 3)) 388 eval_out = model.evaluate(x, y) 389 eval_out2 = new_model.evaluate(x, y) 390 self.assertArrayNear(eval_out, eval_out2, 0.001) 391 392 out = model.predict(x) 393 out2 = new_model.predict(x) 394 395 # TODO(b/120930751) This tolerance should be 1e-05, 396 # very concerning that its not. 397 self.assertAllClose(out, out2, atol=1e-03) 398 399 @test_util.run_deprecated_v1 400 def test_sequential_model_saving_without_input_shape(self): 401 if h5py is None: 402 self.skipTest('h5py required to run this test') 403 404 with self.cached_session(): 405 model = keras.models.Sequential() 406 model.add(keras.layers.Dense(2)) 407 model.add(keras.layers.RepeatVector(3)) 408 model.add(keras.layers.TimeDistributed(keras.layers.Dense(3))) 409 model.compile( 410 loss=keras.losses.MSE, 411 optimizer=keras.optimizers.RMSprop(lr=0.0001), 412 metrics=[ 413 keras.metrics.categorical_accuracy, 414 keras.metrics.CategoricalAccuracy() 415 ], 416 weighted_metrics=[ 417 keras.metrics.categorical_accuracy, 418 keras.metrics.CategoricalAccuracy() 419 ], 420 sample_weight_mode='temporal') 421 x = np.random.random((1, 3)) 422 y = np.random.random((1, 3, 3)) 423 model.train_on_batch(x, y) 424 425 out = model.predict(x) 426 fd, fname = tempfile.mkstemp('.h5', dir=self.get_temp_dir()) 427 model.save(fname) 428 429 new_model = keras.models.load_model(fname) 430 os.close(fd) 431 os.remove(fname) 432 433 out2 = new_model.predict(x) 434 self.assertAllClose(out, out2, atol=1e-05) 435 436 def test_sequential_model_saving_without_compile(self): 437 if h5py is None: 438 self.skipTest('h5py required to run this test') 439 440 with self.cached_session(): 441 model = keras.models.Sequential() 442 model.add(keras.layers.Dense(2, input_shape=(3,))) 443 model.add(keras.layers.RepeatVector(3)) 444 model.add(keras.layers.TimeDistributed(keras.layers.Dense(3))) 445 446 x = np.random.random((1, 3)) 447 out = model.predict(x) 448 fd, fname = tempfile.mkstemp('.h5') 449 450 # Save the model without any compilation or training. 451 keras.models.save_model(model, fname) 452 453 new_model = keras.models.load_model(fname) 454 os.close(fd) 455 os.remove(fname) 456 457 out2 = new_model.predict(x) 458 self.assertAllClose(out, out2, atol=1e-05) 459 460 @test_util.run_deprecated_v1 461 def test_sequential_model_saving_2(self): 462 if h5py is None: 463 self.skipTest('h5py required to run this test') 464 465 with self.cached_session(): 466 # test with custom optimizer, loss 467 468 class CustomOp(keras.optimizers.RMSprop): 469 pass 470 471 def custom_loss(y_true, y_pred): 472 return keras.losses.mse(y_true, y_pred) 473 474 model = keras.models.Sequential() 475 model.add(keras.layers.Dense(2, input_shape=(3,))) 476 model.add(keras.layers.Dense(3)) 477 model.compile(loss=custom_loss, optimizer=CustomOp(), metrics=['acc']) 478 479 x = np.random.random((1, 3)) 480 y = np.random.random((1, 3)) 481 model.train_on_batch(x, y) 482 483 out = model.predict(x) 484 fd, fname = tempfile.mkstemp('.h5') 485 keras.models.save_model(model, fname) 486 487 model = keras.models.load_model( 488 fname, 489 custom_objects={'CustomOp': CustomOp, 490 'custom_loss': custom_loss}) 491 os.close(fd) 492 os.remove(fname) 493 494 out2 = model.predict(x) 495 self.assertAllClose(out, out2, atol=1e-05) 496 497 @test_util.run_deprecated_v1 498 def test_functional_model_saving(self): 499 if h5py is None: 500 self.skipTest('h5py required to run this test') 501 502 with self.cached_session(): 503 inputs = keras.layers.Input(shape=(3,)) 504 x = keras.layers.Dense(2)(inputs) 505 output = keras.layers.Dense(3)(x) 506 507 model = keras.models.Model(inputs, output) 508 model.compile( 509 loss=keras.losses.MSE, 510 optimizer=keras.optimizers.RMSprop(lr=0.0001), 511 metrics=[ 512 keras.metrics.categorical_accuracy, 513 keras.metrics.CategoricalAccuracy() 514 ], 515 weighted_metrics=[ 516 keras.metrics.categorical_accuracy, 517 keras.metrics.CategoricalAccuracy() 518 ]) 519 x = np.random.random((1, 3)) 520 y = np.random.random((1, 3)) 521 model.train_on_batch(x, y) 522 523 out = model.predict(x) 524 fd, fname = tempfile.mkstemp('.h5') 525 keras.models.save_model(model, fname) 526 527 model = keras.models.load_model(fname) 528 os.close(fd) 529 os.remove(fname) 530 531 out2 = model.predict(x) 532 self.assertAllClose(out, out2, atol=1e-05) 533 534 def test_saving_without_compilation(self): 535 if h5py is None: 536 self.skipTest('h5py required to run this test') 537 538 with self.cached_session(): 539 model = keras.models.Sequential() 540 model.add(keras.layers.Dense(2, input_shape=(3,))) 541 model.add(keras.layers.Dense(3)) 542 model.compile(loss='mse', optimizer='sgd', metrics=['acc']) 543 544 fd, fname = tempfile.mkstemp('.h5') 545 keras.models.save_model(model, fname) 546 model = keras.models.load_model(fname) 547 os.close(fd) 548 os.remove(fname) 549 550 def test_saving_with_tf_optimizer(self): 551 if h5py is None: 552 self.skipTest('h5py required to run this test') 553 554 with self.cached_session(): 555 model = keras.models.Sequential() 556 model.add(keras.layers.Dense(2, input_shape=(3,))) 557 model.add(keras.layers.Dense(3)) 558 model.compile(loss='mse', 559 optimizer=training_module.AdadeltaOptimizer(0.1), 560 metrics=['acc']) 561 562 fd, fname = tempfile.mkstemp('.h5') 563 keras.models.save_model(model, fname) 564 model = keras.models.load_model(fname) 565 os.close(fd) 566 os.remove(fname) 567 568 def test_saving_right_after_compilation(self): 569 if h5py is None: 570 self.skipTest('h5py required to run this test') 571 572 with self.cached_session(): 573 model = keras.models.Sequential() 574 model.add(keras.layers.Dense(2, input_shape=(3,))) 575 model.add(keras.layers.Dense(3)) 576 model.compile(loss='mse', optimizer='sgd', metrics=['acc']) 577 model._make_train_function() 578 579 fd, fname = tempfile.mkstemp('.h5') 580 keras.models.save_model(model, fname) 581 model = keras.models.load_model(fname) 582 os.close(fd) 583 os.remove(fname) 584 585 def test_saving_lambda_numpy_array_arguments(self): 586 with self.cached_session(): 587 if h5py is None: 588 self.skipTest('h5py required to run this test') 589 590 mean = np.random.random((4, 2, 3)) 591 std = np.abs(np.random.random((4, 2, 3))) + 1e-5 592 inputs = keras.layers.Input(shape=(4, 2, 3)) 593 output = keras.layers.Lambda(lambda image, mu, std: (image - mu) / std, 594 arguments={'mu': mean, 'std': std})(inputs) 595 model = keras.models.Model(inputs, output) 596 model.compile(loss='mse', optimizer='sgd', metrics=['acc']) 597 598 fd, fname = tempfile.mkstemp('.h5') 599 keras.models.save_model(model, fname) 600 601 model = keras.models.load_model(fname) 602 os.close(fd) 603 os.remove(fname) 604 605 self.assertAllClose(mean, model.layers[1].arguments['mu']) 606 self.assertAllClose(std, model.layers[1].arguments['std']) 607 608 def test_saving_model_with_long_layer_names(self): 609 if h5py is None: 610 self.skipTest('h5py required to run this test') 611 612 with self.cached_session(): 613 # This layer name will make the `layers_name` HDF5 attribute blow 614 # out of proportion. Note that it fits into the internal HDF5 615 # attribute memory limit on its own but because h5py converts 616 # the list of layer names into numpy array, which uses the same 617 # amout of memory for every item, it increases the memory 618 # requirements substantially. 619 x = keras.Input(shape=(2,), name='input_' + ('x' * (2**15))) 620 f = x 621 for i in range(4): 622 f = keras.layers.Dense(2, name='dense_%d' % (i,))(f) 623 model = keras.Model(inputs=[x], outputs=[f]) 624 model.compile(loss='mse', optimizer='adam', metrics=['acc']) 625 626 x = np.random.random((1, 2)) 627 y = np.random.random((1, 2)) 628 model.train_on_batch(x, y) 629 out = model.predict(x) 630 631 fd, fname = tempfile.mkstemp('.h5') 632 keras.models.save_model(model, fname) 633 model = keras.models.load_model(fname) 634 635 # Check that the HDF5 files contains chunked array 636 # of layer names. 637 with h5py.File(fname, 'r') as h5file: 638 num_names_arrays = len([attr for attr in h5file['model_weights'].attrs 639 if attr.startswith('layer_names')]) 640 # The chunking of layer names array should have happened. 641 self.assertGreater(num_names_arrays, 0) 642 out2 = model.predict(x) 643 self.assertAllClose(out, out2, atol=1e-05) 644 645 # Cleanup 646 os.close(fd) 647 os.remove(fname) 648 649 def test_saving_model_with_long_weights_names(self): 650 if h5py is None: 651 self.skipTest('h5py required to run this test') 652 653 with self.cached_session(): 654 x = keras.Input(shape=(2,), name='nested_model_input') 655 f = x 656 for i in range(4): 657 f = keras.layers.Dense(2, name='nested_model_dense_%d' % (i,))(f) 658 # This layer name will make the `weights_name` 659 # HDF5 attribute blow out of proportion. 660 f = keras.layers.Dense(2, name='nested_model_output' + ('x' * (2**14)))(f) 661 nested_model = keras.Model(inputs=[x], outputs=[f], name='nested_model') 662 663 x = keras.Input(shape=(2,), name='outer_model_input') 664 f = nested_model(x) 665 f = keras.layers.Dense(2, name='outer_model_output')(f) 666 667 model = keras.Model(inputs=[x], outputs=[f]) 668 model.compile(loss='mse', optimizer='adam', metrics=['acc']) 669 670 x = np.random.random((1, 2)) 671 y = np.random.random((1, 2)) 672 model.train_on_batch(x, y) 673 out = model.predict(x) 674 675 fd, fname = tempfile.mkstemp('.h5') 676 keras.models.save_model(model, fname) 677 model = keras.models.load_model(fname) 678 679 # Check that the HDF5 files contains chunked array 680 # of weight names. 681 with h5py.File(fname, 'r') as h5file: 682 num_weight_arrays = len( 683 [attr for attr in h5file['model_weights']['nested_model'].attrs 684 if attr.startswith('weight_names')]) 685 # The chunking of layer names array should have happened. 686 self.assertGreater(num_weight_arrays, 0) 687 out2 = model.predict(x) 688 self.assertAllClose(out, out2, atol=1e-05) 689 690 # Cleanup 691 os.close(fd) 692 os.remove(fname) 693 694 @test_util.run_deprecated_v1 695 def test_model_saving_to_pre_created_h5py_file(self): 696 if h5py is None: 697 self.skipTest('h5py required to run this test') 698 699 with self.cached_session(): 700 inputs = keras.Input(shape=(3,)) 701 x = keras.layers.Dense(2)(inputs) 702 outputs = keras.layers.Dense(3)(x) 703 704 model = keras.Model(inputs, outputs) 705 model.compile( 706 loss=keras.losses.MSE, 707 optimizer=keras.optimizers.Adam(), 708 metrics=[ 709 keras.metrics.categorical_accuracy, 710 keras.metrics.CategoricalAccuracy() 711 ]) 712 x = np.random.random((1, 3)) 713 y = np.random.random((1, 3)) 714 model.train_on_batch(x, y) 715 716 out = model.predict(x) 717 fd, fname = tempfile.mkstemp('.h5') 718 with h5py.File(fname, mode='r+') as h5file: 719 keras.models.save_model(model, h5file) 720 loaded_model = keras.models.load_model(h5file) 721 out2 = loaded_model.predict(x) 722 self.assertAllClose(out, out2, atol=1e-05) 723 724 # Test non-default options in h5 725 with h5py.File('_', driver='core', 726 backing_store=False) as h5file: 727 keras.models.save_model(model, h5file) 728 loaded_model = keras.models.load_model(h5file) 729 out2 = loaded_model.predict(x) 730 self.assertAllClose(out, out2, atol=1e-05) 731 732 # Cleanup 733 os.close(fd) 734 os.remove(fname) 735 736 def test_saving_constant_initializer_with_numpy(self): 737 if h5py is None: 738 self.skipTest('h5py required to run this test') 739 740 with self.cached_session(): 741 model = keras.models.Sequential() 742 model.add( 743 keras.layers.Dense( 744 2, 745 input_shape=(3,), 746 kernel_initializer=keras.initializers.Constant(np.ones((3, 2))))) 747 model.add(keras.layers.Dense(3)) 748 model.compile(loss='mse', optimizer='sgd', metrics=['acc']) 749 fd, fname = tempfile.mkstemp('.h5') 750 keras.models.save_model(model, fname) 751 model = keras.models.load_model(fname) 752 os.close(fd) 753 os.remove(fname) 754 755 756class SubclassedModel(training.Model): 757 758 def __init__(self): 759 super(SubclassedModel, self).__init__() 760 self.x_layer = keras.layers.Dense(3) 761 self.b_layer = keras.layers.Dense(1) 762 763 def call(self, a): 764 return self.b_layer(self.x_layer(a)) 765 766 767class TestWeightSavingAndLoadingTFFormat(test.TestCase): 768 769 def test_keras_optimizer_warning(self): 770 graph = ops.Graph() 771 with graph.as_default(), self.session(graph): 772 model = keras.models.Sequential() 773 model.add(keras.layers.Dense(2, input_shape=(3,))) 774 model.add(keras.layers.Dense(3)) 775 model.compile(loss='mse', optimizer=optimizers.Adam(), metrics=['acc']) 776 model._make_train_function() 777 temp_dir = self.get_temp_dir() 778 prefix = os.path.join(temp_dir, 'ckpt') 779 with test.mock.patch.object(logging, 'warning') as mock_log: 780 model.save_weights(prefix) 781 self.assertRegexpMatches( 782 str(mock_log.call_args), 783 'Keras optimizer') 784 785 @test_util.run_in_graph_and_eager_modes 786 def test_tensorflow_format_overwrite(self): 787 with self.cached_session() as session: 788 model = SubclassedModel() 789 temp_dir = self.get_temp_dir() 790 prefix = os.path.join(temp_dir, 'ckpt') 791 792 x = constant_op.constant(np.random.random((3, 2)), dtype=dtypes.float32) 793 executing_eagerly = context.executing_eagerly() 794 model(x) # pylint: disable=not-callable 795 if not executing_eagerly: 796 session.run([v.initializer for v in model.variables]) 797 model.save_weights(prefix, save_format='tensorflow') 798 model.save_weights(prefix, save_format='tensorflow', overwrite=True) 799 with self.assertRaises(EOFError): 800 # Indirectly tests that the user is prompted 801 model.save_weights(prefix, save_format='tensorflow', overwrite=False) 802 803 def test_no_default_session(self): 804 with ops.Graph().as_default(): 805 self.assertFalse(ops.get_default_session()) 806 data = np.random.random((1000, 32)).astype(np.float32) 807 labels = np.random.random((1000, 10)).astype(np.float32) 808 809 model = keras.models.Sequential([ 810 keras.layers.Dense(10, activation='softmax'), 811 keras.layers.Dense(10, activation='softmax')]) 812 813 model.compile(optimizer=training_module.RMSPropOptimizer(0.001), 814 loss='categorical_crossentropy', 815 metrics=['accuracy']) 816 817 model.fit(data, labels) 818 fname = os.path.join(self.get_temp_dir(), 'weights', 'ckpt') 819 model.save_weights(fname) 820 model.load_weights(fname) 821 822 def test_no_graph_pollution(self): 823 with context.graph_mode(): 824 graph = ops.Graph() 825 with graph.as_default(), self.session(graph) as session: 826 model = SubclassedModel() 827 temp_dir = self.get_temp_dir() 828 prefix = os.path.join(temp_dir, 'ckpt') 829 830 x = constant_op.constant(np.random.random((3, 2)), dtype=dtypes.float32) 831 model(x) # pylint: disable=not-callable 832 session.run([v.initializer for v in model.variables]) 833 model.save_weights(prefix, save_format='tensorflow') 834 op_count = len(graph.get_operations()) 835 model.save_weights(prefix, save_format='tensorflow') 836 self.assertEqual(len(graph.get_operations()), op_count) 837 838 model.load_weights(prefix) 839 op_count = len(graph.get_operations()) 840 model.load_weights(prefix) 841 self.assertEqual(len(graph.get_operations()), op_count) 842 843 def _weight_loading_test_template(self, make_model_fn): 844 with self.cached_session(): 845 model = make_model_fn() 846 model.compile( 847 loss='mse', 848 optimizer=training_module.RMSPropOptimizer(0.1), 849 metrics=['acc', keras.metrics.CategoricalAccuracy()]) 850 temp_dir = self.get_temp_dir() 851 prefix = os.path.join(temp_dir, 'ckpt') 852 train_x = np.random.random((3, 2)) 853 train_y = np.random.random((3,)) 854 x = constant_op.constant(train_x, dtype=dtypes.float32) 855 856 model.train_on_batch(train_x, train_y) 857 model.save_weights(prefix, save_format='tf') 858 ref_y_before_train = model.predict(train_x) 859 model.train_on_batch(train_x, train_y) 860 ref_y_after_train = model.predict(train_x) 861 for v in model.variables: 862 self.evaluate( 863 v.assign(random_ops.random_normal(shape=array_ops.shape(v)))) 864 865 self.addCleanup(shutil.rmtree, temp_dir) 866 867 model.load_weights(prefix) 868 self.assertAllClose(ref_y_before_train, self.evaluate(model(x))) 869 870 # Test restore-on-create if this is a subclassed Model (graph Networks 871 # will have already created their variables). 872 load_model = make_model_fn() 873 load_model.load_weights(prefix) 874 self.assertAllClose( 875 ref_y_before_train, 876 self.evaluate(load_model(x))) 877 load_model = make_model_fn() 878 load_model.load_weights(prefix) 879 # We need to run some of the restore ops for predict(), but not all 880 # variables have been created yet (optimizer slot variables). Tests 881 # incremental restore. 882 load_model.predict(train_x) 883 load_model.compile( 884 loss='mse', 885 optimizer=training_module.RMSPropOptimizer(0.1), 886 metrics=['acc', keras.metrics.CategoricalAccuracy()]) 887 load_model.train_on_batch(train_x, train_y) 888 self.assertAllClose(ref_y_after_train, self.evaluate(load_model(x))) 889 890 @test_util.run_in_graph_and_eager_modes 891 def test_weight_loading_graph_model(self): 892 def _make_graph_model(): 893 a = keras.layers.Input(shape=(2,)) 894 x = keras.layers.Dense(3)(a) 895 b = keras.layers.Dense(1)(x) 896 return keras.models.Model(a, b) 897 898 self._weight_loading_test_template(_make_graph_model) 899 900 @test_util.run_in_graph_and_eager_modes 901 def test_weight_loading_subclassed_model(self): 902 self._weight_loading_test_template(SubclassedModel) 903 904 def _new_layer_weight_loading_test_template( 905 self, first_model_fn, second_model_fn, restore_init_fn): 906 with self.cached_session() as session: 907 model = first_model_fn() 908 temp_dir = self.get_temp_dir() 909 prefix = os.path.join(temp_dir, 'ckpt') 910 911 x = constant_op.constant(np.random.random((3, 2)), dtype=dtypes.float32) 912 executing_eagerly = context.executing_eagerly() 913 ref_y_tensor = model(x) 914 if not executing_eagerly: 915 session.run([v.initializer for v in model.variables]) 916 ref_y = self.evaluate(ref_y_tensor) 917 model.save_weights(prefix) 918 self.assertEqual( 919 prefix, 920 checkpoint_management.latest_checkpoint(temp_dir)) 921 for v in model.variables: 922 self.evaluate( 923 v.assign(random_ops.random_normal(shape=array_ops.shape(v)))) 924 925 self.addCleanup(shutil.rmtree, temp_dir) 926 927 second_model = second_model_fn() 928 second_model.load_weights(prefix) 929 second_model(x) 930 self.evaluate(restore_init_fn(second_model)) 931 second_model.save_weights(prefix) 932 # Check that the second model's checkpoint loads into the original model 933 model.load_weights(prefix) 934 y = self.evaluate(model(x)) 935 self.assertAllClose(ref_y, y) 936 937 @test_util.run_in_graph_and_eager_modes 938 def test_weight_loading_graph_model_added_layer(self): 939 def _save_graph_model(): 940 a = keras.layers.Input(shape=(2,)) 941 x = keras.layers.Dense(3, name='first')(a) 942 b = keras.layers.Dense(1, name='second')(x) 943 return keras.models.Model(a, b) 944 def _restore_graph_model(): 945 a = keras.layers.Input(shape=(2,)) 946 x = keras.layers.Dense(3, name='first')(a) 947 y = keras.layers.Dense(1, name='second')(x) 948 b = keras.layers.Dense(3, name='secondjr')(y) 949 return keras.models.Model(a, b) 950 def _restore_init_fn(restore_model): 951 return [v.initializer for v in restore_model.layers[-1].variables] 952 953 self._new_layer_weight_loading_test_template( 954 _save_graph_model, _restore_graph_model, 955 _restore_init_fn) 956 957 @test_util.run_in_graph_and_eager_modes 958 def test_weight_loading_graph_model_added_no_weight_layer(self): 959 def _save_graph_model(): 960 a = keras.layers.Input(shape=(2,)) 961 x = keras.layers.Dense(3, name='first')(a) 962 b = keras.layers.Dense(1, name='second')(x) 963 return keras.models.Model(a, b) 964 def _restore_graph_model(): 965 a = keras.layers.Input(shape=(2,)) 966 x = keras.layers.Dense(3, name='first')(a) 967 y = keras.layers.Dropout(rate=0.1)(x) 968 b = keras.layers.Dense(1, name='second')(y) 969 return keras.models.Model(a, b) 970 def _restore_init_fn(restore_model): 971 del restore_model # unused 972 return [] 973 974 self._new_layer_weight_loading_test_template( 975 _save_graph_model, _restore_graph_model, 976 _restore_init_fn) 977 978 @test_util.run_in_graph_and_eager_modes 979 def test_weight_loading_subclassed_model_added_layer(self): 980 981 class SubclassedModelRestore(training.Model): 982 983 def __init__(self): 984 super(SubclassedModelRestore, self).__init__() 985 self.x_layer = keras.layers.Dense(3) 986 self.y_layer = keras.layers.Dense(3) 987 self.b_layer = keras.layers.Dense(1) 988 989 def call(self, a): 990 return self.b_layer(self.y_layer(self.x_layer(a))) 991 992 def _restore_init_fn(restore_model): 993 return [v.initializer for v in restore_model.y_layer.variables] 994 995 self._new_layer_weight_loading_test_template( 996 SubclassedModel, SubclassedModelRestore, 997 _restore_init_fn) 998 999 @test_util.run_in_graph_and_eager_modes 1000 def test_incompatible_checkpoint(self): 1001 save_path = trackable.Checkpoint().save( 1002 os.path.join(self.get_temp_dir(), 'ckpt')) 1003 m = keras.Model() 1004 with self.assertRaisesRegexp(AssertionError, 'Nothing to load'): 1005 m.load_weights(save_path) 1006 m.dense = keras.layers.Dense(2) 1007 m.dense(constant_op.constant([[1.]])) 1008 with self.assertRaisesRegexp( 1009 AssertionError, 'Nothing except the root object matched'): 1010 m.load_weights(save_path) 1011 1012 @test_util.run_in_graph_and_eager_modes 1013 def test_directory_passed(self): 1014 m = keras.Model() 1015 v = m.add_weight(name='v', shape=[]) 1016 self.evaluate(v.assign(42.)) 1017 prefix = os.path.join(self.get_temp_dir(), '{}'.format(ops.uid()), 'ckpt/') 1018 m.save_weights(prefix) 1019 self.evaluate(v.assign(2.)) 1020 m.load_weights(prefix) 1021 self.assertEqual(42., self.evaluate(v)) 1022 1023 @test_util.run_in_graph_and_eager_modes 1024 def test_relative_path(self): 1025 m = keras.Model() 1026 v = m.add_weight(name='v', shape=[]) 1027 os.chdir(self.get_temp_dir()) 1028 1029 prefix = 'ackpt' 1030 self.evaluate(v.assign(42.)) 1031 m.save_weights(prefix) 1032 self.assertTrue(file_io.file_exists('ackpt.index')) 1033 self.evaluate(v.assign(1.)) 1034 m.load_weights(prefix) 1035 self.assertEqual(42., self.evaluate(v)) 1036 1037 prefix = 'subdir/ackpt' 1038 self.evaluate(v.assign(43.)) 1039 m.save_weights(prefix) 1040 self.assertTrue(file_io.file_exists('subdir/ackpt.index')) 1041 self.evaluate(v.assign(2.)) 1042 m.load_weights(prefix) 1043 self.assertEqual(43., self.evaluate(v)) 1044 1045 prefix = 'ackpt/' 1046 self.evaluate(v.assign(44.)) 1047 m.save_weights(prefix) 1048 self.assertTrue(file_io.file_exists('ackpt/.index')) 1049 self.evaluate(v.assign(3.)) 1050 m.load_weights(prefix) 1051 self.assertEqual(44., self.evaluate(v)) 1052 1053 @test_util.run_in_graph_and_eager_modes 1054 def test_nonexistant_prefix_directory(self): 1055 m = keras.Model() 1056 v = m.add_weight(name='v', shape=[]) 1057 self.evaluate(v.assign(42.)) 1058 prefix = os.path.join(self.get_temp_dir(), '{}'.format(ops.uid()), 'bckpt') 1059 m.save_weights(prefix) 1060 self.evaluate(v.assign(2.)) 1061 m.load_weights(prefix) 1062 self.assertEqual(42., self.evaluate(v)) 1063 1064if __name__ == '__main__': 1065 test.main() 1066