• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for Keras model saving code."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import os
23import shutil
24import sys
25import tempfile
26
27from absl.testing import parameterized
28import numpy as np
29from six import string_types
30
31from tensorflow.python import keras
32from tensorflow.python import tf2
33from tensorflow.python.eager import context
34from tensorflow.python.feature_column import feature_column_lib
35from tensorflow.python.framework import constant_op
36from tensorflow.python.framework import dtypes
37from tensorflow.python.framework import ops
38from tensorflow.python.framework import sparse_tensor
39from tensorflow.python.keras import combinations
40from tensorflow.python.keras import keras_parameterized
41from tensorflow.python.keras import losses
42from tensorflow.python.keras import optimizer_v1
43from tensorflow.python.keras import optimizers
44from tensorflow.python.keras import testing_utils
45from tensorflow.python.keras.engine import sequential
46from tensorflow.python.keras.feature_column import dense_features
47from tensorflow.python.keras.feature_column import sequence_feature_column as ksfc
48from tensorflow.python.keras.layers import core
49from tensorflow.python.keras.saving import model_config
50from tensorflow.python.keras.saving import save
51from tensorflow.python.keras.utils import generic_utils
52from tensorflow.python.ops import array_ops
53from tensorflow.python.ops import lookup_ops
54from tensorflow.python.ops import math_ops
55from tensorflow.python.platform import test
56from tensorflow.python.saved_model import loader_impl
57from tensorflow.python.training import training as training_module
58
59
60if sys.version_info >= (3, 6):
61  import pathlib  # pylint:disable=g-import-not-at-top
62try:
63  import h5py  # pylint:disable=g-import-not-at-top
64except ImportError:
65  h5py = None
66
67
class TestSaveModel(test.TestCase, parameterized.TestCase):
  """Tests for `save.save_model` / `save.load_model` path and format handling."""

  def setUp(self):
    super(TestSaveModel, self).setUp()
    # A small Sequential MLP (input shape known) and a small subclassed MLP
    # (input shape only known after it has been called on data).
    self.model = testing_utils.get_small_sequential_mlp(1, 2, 3)
    self.subclassed_model = testing_utils.get_small_subclass_mlp(1, 2)

  def assert_h5_format(self, path):
    """Asserts `path` is a valid HDF5 file; a no-op when h5py is missing."""
    if h5py is not None:
      self.assertTrue(h5py.is_hdf5(path),
                      'Model saved at path {} is not a valid hdf5 file.'
                      .format(path))

  def assert_saved_model(self, path):
    """Asserts `path` parses as a SavedModel (raises otherwise)."""
    loader_impl.parse_saved_model(path)

  @testing_utils.run_v2_only
  def test_save_format_defaults(self):
    """Without `save_format`, a plain string path is saved as a SavedModel."""
    path = os.path.join(self.get_temp_dir(), 'model_path')
    save.save_model(self.model, path)
    self.assert_saved_model(path)

  @testing_utils.run_v2_only
  def test_save_format_defaults_pathlib(self):
    """Same as above, but with a `pathlib.Path` destination."""
    if sys.version_info < (3, 6):
      self.skipTest('pathlib is only available for python version >= 3.6')
    path = pathlib.Path(self.get_temp_dir()) / 'model_path'
    save.save_model(self.model, path)
    self.assert_saved_model(path)

  @testing_utils.run_v2_only
  def test_save_hdf5(self):
    """HDF5 saving works for Sequential models and rejects subclassed ones."""
    path = os.path.join(self.get_temp_dir(), 'model')
    save.save_model(self.model, path, save_format='h5')
    self.assert_h5_format(path)
    with self.assertRaisesRegex(
        NotImplementedError,
        'requires the model to be a Functional model or a Sequential model.'):
      save.save_model(self.subclassed_model, path, save_format='h5')

  @testing_utils.run_v2_only
  def test_save_load_hdf5_pathlib(self):
    """HDF5 save/load round trip through a `pathlib.Path`."""
    if sys.version_info < (3, 6):
      self.skipTest('pathlib is only available for python version >= 3.6')
    path = pathlib.Path(self.get_temp_dir()) / 'model'
    save.save_model(self.model, path, save_format='h5')
    save.load_model(path)

  @testing_utils.run_v2_only
  def test_save_tf(self):
    """TF-format saving; subclassed models need known input shapes first."""
    path = os.path.join(self.get_temp_dir(), 'model')
    save.save_model(self.model, path, save_format='tf')
    self.assert_saved_model(path)
    with self.assertRaisesRegex(ValueError, 'input shapes have not been set'):
      save.save_model(self.subclassed_model, path, save_format='tf')
    # Calling the model once sets its input shapes, after which saving works.
    self.subclassed_model.predict(np.random.random((3, 5)))
    save.save_model(self.subclassed_model, path, save_format='tf')
    self.assert_saved_model(path)

  @testing_utils.run_v2_only
  def test_save_load_tf_string(self):
    """TF-format save/load round trip through a string path."""
    path = os.path.join(self.get_temp_dir(), 'model')
    save.save_model(self.model, path, save_format='tf')
    save.load_model(path)

  @testing_utils.run_v2_only
  def test_save_load_tf_pathlib(self):
    """TF-format save/load round trip through a `pathlib.Path`."""
    if sys.version_info < (3, 6):
      self.skipTest('pathlib is only available for python version >= 3.6')
    path = pathlib.Path(self.get_temp_dir()) / 'model'
    save.save_model(self.model, path, save_format='tf')
    save.load_model(path)

  @testing_utils.run_v2_only
  def test_save_load_weights_tf_pathlib(self):
    """TF-format weights save/load through a `pathlib.Path`."""
    if sys.version_info < (3, 6):
      self.skipTest('pathlib is only available for python version >= 3.6')
    path = pathlib.Path(self.get_temp_dir()) / 'model'
    self.model.save_weights(path, save_format='tf')
    self.model.load_weights(path)

  @testing_utils.run_v2_only
  def test_save_load_weights_hdf5_pathlib(self):
    """HDF5 weights save/load through a `pathlib.Path`."""
    if sys.version_info < (3, 6):
      self.skipTest('pathlib is only available for python version >= 3.6')
    path = pathlib.Path(self.get_temp_dir()) / 'model'
    self.model.save_weights(path, save_format='h5')
    self.model.load_weights(path)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_saving_with_dense_features(self):
    """A model using `DenseFeatures` survives a JSON config round trip."""
    cols = [
        feature_column_lib.numeric_column('a'),
        feature_column_lib.indicator_column(
            feature_column_lib.categorical_column_with_vocabulary_list(
                'b', ['one', 'two']))
    ]
    input_layers = {
        'a': keras.layers.Input(shape=(1,), name='a'),
        'b': keras.layers.Input(shape=(1,), name='b', dtype='string')
    }

    fc_layer = dense_features.DenseFeatures(cols)(input_layers)
    output = keras.layers.Dense(10)(fc_layer)

    model = keras.models.Model(input_layers, output)

    model.compile(
        loss=keras.losses.MSE,
        optimizer='rmsprop',
        metrics=[keras.metrics.categorical_accuracy])

    config = model.to_json()
    loaded_model = model_config.model_from_json(config)

    inputs_a = np.arange(10).reshape(10, 1)
    inputs_b = np.arange(10).reshape(10, 1).astype('str')

    with self.cached_session():
      # Initialize tables for V1 lookup.
      if not context.executing_eagerly():
        self.evaluate(lookup_ops.tables_initializer())

      self.assertLen(loaded_model.predict({'a': inputs_a, 'b': inputs_b}), 10)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_saving_with_sequence_features(self):
    """A model using `SequenceFeatures` survives a JSON config round trip."""
    cols = [
        feature_column_lib.sequence_numeric_column('a'),
        feature_column_lib.indicator_column(
            feature_column_lib.sequence_categorical_column_with_vocabulary_list(
                'b', ['one', 'two']))
    ]
    input_layers = {
        'a':
            keras.layers.Input(shape=(None, 1), sparse=True, name='a'),
        'b':
            keras.layers.Input(
                shape=(None, 1), sparse=True, name='b', dtype='string')
    }

    fc_layer, _ = ksfc.SequenceFeatures(cols)(input_layers)
    # TODO(tibell): Figure out the right dtype and apply masking.
    # sequence_length_mask = array_ops.sequence_mask(sequence_length)
    # x = keras.layers.GRU(32)(fc_layer, mask=sequence_length_mask)
    x = keras.layers.GRU(32)(fc_layer)
    output = keras.layers.Dense(10)(x)

    model = keras.models.Model(input_layers, output)

    model.compile(
        loss=keras.losses.MSE,
        optimizer='rmsprop',
        metrics=[keras.metrics.categorical_accuracy])

    config = model.to_json()
    loaded_model = model_config.model_from_json(config)

    batch_size = 10
    timesteps = 1

    values_a = np.arange(10, dtype=np.float32)
    indices_a = np.zeros((10, 3), dtype=np.int64)
    indices_a[:, 0] = np.arange(10)
    inputs_a = sparse_tensor.SparseTensor(indices_a, values_a,
                                          (batch_size, timesteps, 1))

    # `np.str` was only a deprecated alias of the builtin `str` and was
    # removed in NumPy 1.24; `str` yields the same unicode dtype.
    values_b = np.zeros(10, dtype=str)
    indices_b = np.zeros((10, 3), dtype=np.int64)
    indices_b[:, 0] = np.arange(10)
    inputs_b = sparse_tensor.SparseTensor(indices_b, values_b,
                                          (batch_size, timesteps, 1))

    with self.cached_session():
      # Initialize tables for V1 lookup.
      if not context.executing_eagerly():
        self.evaluate(lookup_ops.tables_initializer())

      self.assertLen(
          loaded_model.predict({
              'a': inputs_a,
              'b': inputs_b
          }, steps=1), batch_size)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_saving_h5_for_rnn_layers(self):
    """Stacked RNN cells get unique weight names when saved to HDF5."""
    # See https://github.com/tensorflow/tensorflow/issues/35731 for details.
    inputs = keras.Input([10, 91], name='train_input')
    rnn_layers = [
        keras.layers.LSTMCell(size, recurrent_dropout=0, name='rnn_cell%d' % i)
        for i, size in enumerate([512, 512])
    ]
    rnn_output = keras.layers.RNN(
        rnn_layers, return_sequences=True, name='rnn_layer')(inputs)
    pred_feat = keras.layers.Dense(91, name='prediction_features')(rnn_output)
    pred = keras.layers.Softmax()(pred_feat)
    model = keras.Model(inputs=[inputs], outputs=[pred, pred_feat])
    path = os.path.join(self.get_temp_dir(), 'model_path.h5')
    model.save(path)

    # Make sure the variable name is unique.
    self.assertNotEqual(rnn_layers[0].kernel.name,
                        rnn_layers[1].kernel.name)
    self.assertIn('rnn_cell1', rnn_layers[1].kernel.name)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_saving_optimizer_weights(self):
    """Optimizer state restored via `load_weights` reproduces the next loss."""

    class MyModel(keras.Model):

      def __init__(self):
        super(MyModel, self).__init__()
        self.layer = keras.layers.Dense(1)

      def call(self, x):
        return self.layer(x)

    path = os.path.join(self.get_temp_dir(), 'weights_path')
    x, y = np.ones((10, 10)), np.ones((10, 1))

    model = MyModel()
    model.compile('rmsprop', loss='bce')
    model.train_on_batch(x, y)
    model.reset_metrics()
    model.save_weights(path, save_format='tf')

    batch_loss = model.train_on_batch(x, y)

    new_model = MyModel()
    new_model.compile('rmsprop', loss='bce')
    # One training step creates the optimizer slots before weights are loaded.
    new_model.train_on_batch(x, y)
    new_model.reset_metrics()

    new_model.load_weights(path)
    new_batch_loss = new_model.train_on_batch(x, y)

    self.assertAllClose(batch_loss, new_batch_loss)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_saving_model_with_custom_object(self):
    """A registered custom loss lets a compiled model be reloaded."""
    with generic_utils.custom_object_scope(), self.cached_session():

      @generic_utils.register_keras_serializable()
      class CustomLoss(losses.MeanSquaredError):
        pass

      model = sequential.Sequential(
          [core.Dense(units=1, input_shape=(1,))])
      model.compile(optimizer='sgd', loss=CustomLoss())
      model.fit(np.zeros([10, 1]), np.zeros([10, 1]))

      temp_dir = self.get_temp_dir()
      filepath = os.path.join(temp_dir, 'saving')
      model.save(filepath)

      # Make sure the model can be correctly load back.
      _ = save.load_model(filepath, compile=True)
326
327@keras_parameterized.run_with_all_saved_model_formats
328class TestWholeModelSaving(keras_parameterized.TestCase):
329
330  def _save_model_dir(self, dirname='saved_model'):
331    temp_dir = self.get_temp_dir()
332    self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
333    return os.path.join(temp_dir, dirname)
334
335  def _assert_same_weights_and_metrics(self, model, loaded_model):
336    """Checks that the loaded weights and metrics are the same as the original.
337
338    Args:
339      model: original model
340      loaded_model: loaded model
341    """
342    self.assertAllClose(model.weights, loaded_model.weights)
343
344    if loaded_model.optimizer:
345      if testing_utils.get_save_format() == 'tf':
346        # TODO(b/153110928): Keras TF format doesn't restore optimizer weights
347        # currently.
348        return
349      self.assertAllClose(model.optimizer.weights,
350                          loaded_model.optimizer.weights)
351
352    # In V1/Graph mode, the model isn't built, so the metrics are not loaded
353    # immediately (requires model to be called on some data before building
354    # metrics).
355    check_metrics = tf2.enabled() and context.executing_eagerly()
356
357    if check_metrics:
358      self.assertAllEqual([m.name for m in model.metrics],
359                          [m.name for m in loaded_model.metrics])
360
361  @keras_parameterized.run_with_all_model_types
362  @keras_parameterized.run_all_keras_modes
363  def test_save_and_load(self):
364    saved_model_dir = self._save_model_dir()
365    save_format = testing_utils.get_save_format()
366    save_kwargs = testing_utils.get_save_kwargs()
367
368    if ((save_format == 'h5' or not save_kwargs.get('save_traces', True)) and
369        testing_utils.get_model_type() == 'subclass'):
370      # HDF5 format currently does not allow saving subclassed models.
371      # When saving with `save_traces=False`, the subclassed model must have a
372      # get_config/from_config, which the autogenerated model does not have.
373      return
374
375    with self.cached_session():
376      model = testing_utils.get_model_from_layers(
377          [keras.layers.Dense(2),
378           keras.layers.RepeatVector(3),
379           keras.layers.TimeDistributed(keras.layers.Dense(3))],
380          input_shape=(3,))
381      model.compile(
382          loss=keras.losses.MSE,
383          optimizer=keras.optimizer_v2.rmsprop.RMSprop(lr=0.0001),
384          metrics=[
385              keras.metrics.categorical_accuracy,
386              keras.metrics.CategoricalCrossentropy(
387                  name='cce', label_smoothing=constant_op.constant(0.2)),
388          ],
389          weighted_metrics=[
390              keras.metrics.categorical_crossentropy,
391              keras.metrics.CategoricalCrossentropy(
392                  name='cce', label_smoothing=constant_op.constant(0.2)),
393          ],
394          sample_weight_mode='temporal')
395
396      x = np.random.random((1, 3))
397      y = np.random.random((1, 3, 3))
398      model.train_on_batch(x, y)
399
400      out = model.predict(x)
401      keras.models.save_model(
402          model, saved_model_dir, save_format=save_format,
403          **save_kwargs)
404
405      loaded_model = keras.models.load_model(saved_model_dir)
406      self._assert_same_weights_and_metrics(model, loaded_model)
407
408      out2 = loaded_model.predict(x)
409      self.assertAllClose(out, out2, atol=1e-05)
410
411      eval_out = model.evaluate(x, y)
412      eval_out2 = loaded_model.evaluate(x, y)
413      self.assertArrayNear(eval_out, eval_out2, 0.001)
414
415  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
416  def test_sequential_model_saving_without_input_shape(self):
417    saved_model_dir = self._save_model_dir()
418    save_format = testing_utils.get_save_format()
419    with self.cached_session():
420      model = keras.models.Sequential()
421      model.add(keras.layers.Dense(2))
422      model.add(keras.layers.RepeatVector(3))
423      model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
424      model.compile(
425          loss=keras.losses.MSE,
426          optimizer='rmsprop',
427          metrics=[
428              keras.metrics.categorical_accuracy,
429              keras.metrics.CategoricalAccuracy(name='cat_acc')
430          ],
431          weighted_metrics=[
432              keras.metrics.categorical_accuracy,
433              keras.metrics.CategoricalAccuracy(name='cat_acc2')
434          ],
435          sample_weight_mode='temporal')
436      x = np.random.random((1, 3))
437      y = np.random.random((1, 3, 3))
438      model.train_on_batch(x, y)
439
440      out = model.predict(x)
441      model.save(saved_model_dir, save_format=save_format)
442
443      new_model = keras.models.load_model(saved_model_dir)
444
445      self._assert_same_weights_and_metrics(model, new_model)
446
447      out2 = new_model.predict(x)
448      self.assertAllClose(out, out2, atol=1e-05)
449
450  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
451  def test_sequential_model_saving_without_compile(self):
452    saved_model_dir = self._save_model_dir()
453    save_format = testing_utils.get_save_format()
454    with self.cached_session():
455      model = keras.models.Sequential()
456      model.add(keras.layers.Dense(2, input_shape=(3,)))
457      model.add(keras.layers.RepeatVector(3))
458      model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
459
460      x = np.random.random((1, 3))
461      out = model.predict(x)
462
463      # Save the model without any compilation or training.
464      keras.models.save_model(model, saved_model_dir, save_format=save_format)
465
466      new_model = keras.models.load_model(saved_model_dir)
467      self._assert_same_weights_and_metrics(model, new_model)
468
469      out2 = new_model.predict(x)
470      self.assertAllClose(out, out2, atol=1e-05)
471
472  def test_sequential_model_saving_2(self):
473    saved_model_dir = self._save_model_dir()
474    save_format = testing_utils.get_save_format()
475
476    with ops.Graph().as_default(), self.cached_session():
477      # test with custom optimizer, loss
478
479      class CustomOp(optimizer_v1.RMSprop):
480        pass
481
482      def custom_loss(y_true, y_pred):
483        return keras.losses.mse(y_true, y_pred)
484
485      model = keras.models.Sequential()
486      model.add(keras.layers.Dense(2, input_shape=(3,)))
487      model.add(keras.layers.Dense(3))
488      model.compile(loss=custom_loss, optimizer=CustomOp(), metrics=['acc'])
489
490      x = np.random.random((1, 3))
491      y = np.random.random((1, 3))
492      model.train_on_batch(x, y)
493
494      out = model.predict(x)
495      keras.models.save_model(model, saved_model_dir, save_format=save_format)
496
497      new_model = keras.models.load_model(
498          saved_model_dir,
499          custom_objects={'CustomOp': CustomOp,
500                          'custom_loss': custom_loss})
501      self._assert_same_weights_and_metrics(model, new_model)
502
503      out2 = new_model.predict(x)
504      self.assertAllClose(out, out2, atol=1e-05)
505
506  def test_saving_without_compilation(self):
507    saved_model_dir = self._save_model_dir()
508    save_format = testing_utils.get_save_format()
509    model = keras.models.Sequential()
510    model.add(keras.layers.Dense(2, input_shape=(3,)))
511    model.add(keras.layers.Dense(3))
512    model.compile(loss='mse', optimizer='sgd', metrics=['acc'])
513
514    keras.models.save_model(model, saved_model_dir, save_format=save_format)
515    model = keras.models.load_model(saved_model_dir)
516
517  def test_saving_with_tf_optimizer(self):
518    saved_model_dir = self._save_model_dir()
519    save_format = testing_utils.get_save_format()
520
521    model = keras.models.Sequential()
522    model.add(keras.layers.Dense(2, input_shape=(3,)))
523    model.add(keras.layers.Dense(3))
524    model.compile(loss='mse',
525                  optimizer=training_module.AdadeltaOptimizer(0.1),
526                  metrics=['acc'])
527
528    keras.models.save_model(model, saved_model_dir, save_format=save_format)
529    model = keras.models.load_model(saved_model_dir)
530
531  def test_saving_right_after_compilation(self):
532    saved_model_dir = self._save_model_dir()
533    save_format = testing_utils.get_save_format()
534    with self.cached_session():
535      model = keras.models.Sequential()
536      model.add(keras.layers.Dense(2, input_shape=(3,)))
537      model.add(keras.layers.Dense(3))
538      model.compile(loss='mse', optimizer='sgd', metrics=['acc'])
539      if not ops.executing_eagerly_outside_functions():
540        model._make_train_function()
541      keras.models.save_model(model, saved_model_dir, save_format=save_format)
542      model = keras.models.load_model(saved_model_dir)
543
544  def test_saving_lambda_numpy_array_arguments(self):
545    saved_model_dir = self._save_model_dir()
546    save_format = testing_utils.get_save_format()
547
548    if h5py is None:
549      self.skipTest('h5py required to run this test')
550
551    mean = np.random.random((4, 2, 3))
552    std = np.abs(np.random.random((4, 2, 3))) + 1e-5
553    inputs = keras.layers.Input(shape=(4, 2, 3))
554    output = keras.layers.Lambda(lambda image, mu, std: (image - mu) / std,
555                                 arguments={'mu': mean, 'std': std})(inputs)
556    model = keras.models.Model(inputs, output)
557    model.compile(loss='mse', optimizer='sgd', metrics=['acc'])
558
559    keras.models.save_model(model, saved_model_dir, save_format=save_format)
560
561    model = keras.models.load_model(saved_model_dir)
562
563    self.assertAllClose(mean, model.layers[1].arguments['mu'])
564    self.assertAllClose(std, model.layers[1].arguments['std'])
565
566  def test_saving_model_with_long_layer_names(self):
567    saved_model_dir = self._save_model_dir()
568    save_format = testing_utils.get_save_format()
569    with self.cached_session():
570      # This layer name will make the `layers_name` HDF5 attribute blow
571      # out of proportion. Note that it fits into the internal HDF5
572      # attribute memory limit on its own but because h5py converts
573      # the list of layer names into numpy array, which uses the same
574      # amount of memory for every item, it increases the memory
575      # requirements substantially.
576      x = keras.Input(shape=(2,), name='input_' + ('x' * (2**15)))
577      f = x
578      for i in range(4):
579        f = keras.layers.Dense(2, name='dense_%d' % (i,))(f)
580      model = keras.Model(inputs=[x], outputs=[f])
581      model.compile(
582          'adam', loss=keras.losses.MeanSquaredError(), metrics=['acc'])
583
584      x = np.random.random((1, 2))
585      y = np.random.random((1, 2))
586      model.train_on_batch(x, y)
587      out = model.predict(x)
588
589      keras.models.save_model(model, saved_model_dir, save_format=save_format)
590      model = keras.models.load_model(saved_model_dir)
591
592      if save_format in ['tf', 'tensorflow']:
593        return
594      # Check that the HDF5 files contains chunked array
595      # of layer names.
596      with h5py.File(saved_model_dir, 'r') as h5file:
597        num_names_arrays = len([attr for attr in h5file['model_weights'].attrs
598                                if attr.startswith('layer_names')])
599      # The chunking of layer names array should have happened.
600      self.assertGreater(num_names_arrays, 0)
601      out2 = model.predict(x)
602      self.assertAllClose(out, out2, atol=1e-05)
603
604  def test_saving_model_with_long_weights_names(self):
605    saved_model_dir = self._save_model_dir()
606    save_format = testing_utils.get_save_format()
607
608    with self.cached_session():
609      x = keras.Input(shape=(2,), name='nested_model_input')
610      f = x
611      for i in range(4):
612        f = keras.layers.Dense(2, name='nested_model_dense_%d' % (i,))(f)
613      # This layer name will make the `weights_name`
614      # HDF5 attribute blow out of proportion.
615      f = keras.layers.Dense(2, name='nested_model_output' + ('x' * (2**14)))(f)
616      nested_model = keras.Model(inputs=[x], outputs=[f], name='nested_model')
617
618      x = keras.Input(shape=(2,), name='outer_model_input')
619      f = nested_model(x)
620      f = keras.layers.Dense(2, name='outer_model_output')(f)
621
622      model = keras.Model(inputs=[x], outputs=[f])
623      model.compile(loss='mse', optimizer='adam', metrics=['acc'])
624
625      x = np.random.random((1, 2))
626      y = np.random.random((1, 2))
627      model.train_on_batch(x, y)
628      out = model.predict(x)
629
630      keras.models.save_model(model, saved_model_dir, save_format=save_format)
631      model = keras.models.load_model(saved_model_dir)
632
633      if save_format in ['h5', 'hdf5', 'keras']:
634        # Check that the HDF5 files contains chunked array
635        # of weight names.
636        with h5py.File(saved_model_dir, 'r') as h5file:
637          num_weight_arrays = len(
638              [attr for attr in h5file['model_weights']['nested_model'].attrs
639               if attr.startswith('weight_names')])
640        # The chunking of layer names array should have happened.
641        self.assertGreater(num_weight_arrays, 0)
642      out2 = model.predict(x)
643      self.assertAllClose(out, out2, atol=1e-05)
644
645  def test_model_saving_to_pre_created_h5py_file(self):
646    saved_model_dir = self._save_model_dir()
647    save_format = testing_utils.get_save_format()
648    with ops.Graph().as_default(), self.cached_session():
649      inputs = keras.Input(shape=(3,))
650      x = keras.layers.Dense(2)(inputs)
651      outputs = keras.layers.Dense(3)(x)
652
653      model = keras.Model(inputs, outputs)
654      model.compile(
655          loss=keras.losses.MSE,
656          optimizer=optimizer_v1.Adam(),
657          metrics=[
658              keras.metrics.categorical_accuracy,
659              keras.metrics.CategoricalAccuracy()
660          ])
661      x = np.random.random((1, 3))
662      y = np.random.random((1, 3))
663      model.train_on_batch(x, y)
664
665      out = model.predict(x)
666
667      keras.models.save_model(model, saved_model_dir, save_format=save_format)
668      loaded_model = keras.models.load_model(saved_model_dir)
669      out1 = loaded_model.predict(x)
670      self.assertAllClose(out, out1, atol=1e-05)
671      if save_format in ['tf', 'tensorflow']:
672        return
673
674      # Test h5 format specifically
675      fd, fname = tempfile.mkstemp('.h5')
676      with h5py.File(fname, mode='r+') as h5file:
677        keras.models.save_model(model, h5file)
678        loaded_model = keras.models.load_model(h5file)
679        out2 = loaded_model.predict(x)
680      self.assertAllClose(out, out2, atol=1e-05)
681
682      # Test non-default options in h5
683      with h5py.File(
684          '_', driver='core', mode='w', backing_store=False) as h5file:
685        keras.models.save_model(model, h5file)
686        loaded_model = keras.models.load_model(h5file)
687        out2 = loaded_model.predict(x)
688      self.assertAllClose(out, out2, atol=1e-05)
689
690      # Cleanup
691      os.close(fd)
692      os.remove(fname)
693
694  def test_model_saving_to_new_dir_path(self):
695    saved_model_dir = os.path.join(self._save_model_dir(), 'newdir',
696                                   'saved_model')
697    save_format = testing_utils.get_save_format()
698
699    with self.cached_session():
700      model = keras.models.Sequential()
701      model.add(keras.layers.Dense(2, input_shape=(3,)))
702      model.add(keras.layers.RepeatVector(3))
703      model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
704
705      x = np.random.random((1, 3))
706      out = model.predict(x)
707
708      keras.models.save_model(model, saved_model_dir, save_format=save_format)
709
710      new_model = keras.models.load_model(saved_model_dir)
711      self._assert_same_weights_and_metrics(model, new_model)
712
713      out2 = new_model.predict(x)
714      self.assertAllClose(out, out2, atol=1e-05)
715
716  def test_model_raise_exception_with_failed_saving(self):
717    if h5py is None:
718      self.skipTest('h5py required to run this test')
719
720    saved_model_dir = self._save_model_dir()
721    saved_model_path = os.path.join(saved_model_dir, 'saved_model.h5')
722
723    with self.cached_session():
724      model = keras.models.Sequential()
725      model.add(keras.layers.Dense(2, input_shape=(3,)))
726      model.add(keras.layers.RepeatVector(3))
727      model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
728
729      with self.assertRaisesRegex(OSError, 'Unable to create file'):
730        with h5py.File(saved_model_path, 'w'):
731          keras.models.save_model(model, saved_model_path)
732
733  def test_saving_constant_initializer_with_numpy(self):
734    saved_model_dir = self._save_model_dir()
735    save_format = testing_utils.get_save_format()
736
737    model = keras.models.Sequential()
738    model.add(
739        keras.layers.Dense(
740            2,
741            input_shape=(3,),
742            kernel_initializer=keras.initializers.Constant(np.ones((3, 2)))))
743    model.add(keras.layers.Dense(3))
744    model.compile(loss='mse', optimizer='sgd', metrics=['acc'])
745    keras.models.save_model(model, saved_model_dir, save_format=save_format)
746    model = keras.models.load_model(saved_model_dir)
747
748  def test_saving_group_naming_h5py(self):
749    # Test saving model with layer which name is prefix to a previous layer
750    # name.
751
752    temp_dir = self.get_temp_dir()
753    self.addCleanup(shutil.rmtree, temp_dir)
754    h5_path = os.path.join(temp_dir, 'test.h5')
755
756    input_layer = keras.layers.Input((None, None, 3), name='test_input')
757    x = keras.layers.Conv2D(1, 1, name='conv1/conv')(input_layer)
758    x = keras.layers.Activation('relu', name='conv1')(x)
759    model = keras.models.Model(inputs=input_layer, outputs=x)
760
761    model.save_weights(h5_path)
762    model.load_weights(h5_path)
763
764  def test_primitive_attrs_contain_no_extraneous_strings(self):
765    if h5py is None:
766      self.skipTest('h5py required to run this test')
767
768    saved_model_dir = self._save_model_dir()
769    save_format = testing_utils.get_save_format()
770    model = keras.models.Sequential()
771    model.add(keras.layers.Dense(1, input_shape=[2]))
772    model.save(saved_model_dir, save_format=save_format)
773    if save_format in ['tf', 'tensorflow']:
774      return
775
776    h5file = h5py.File(saved_model_dir, 'r')
777    self.assertRegex(h5file.attrs['keras_version'], r'^[\d]+\.[\d]+\.[\S]+$')
778
779  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
780  def test_functional_model_with_custom_loss_and_metric(self):
781    def _make_model():
782      inputs = keras.Input(shape=(4,))
783      x = keras.layers.Dense(8, activation='relu')(inputs)
784      outputs = keras.layers.Dense(3, activation='softmax')(x)
785      model = keras.Model(inputs=inputs, outputs=outputs)
786      custom_loss = keras.layers.Lambda(lambda x: keras.backend.sum(x * x))(x)
787      model.add_loss(custom_loss)
788      model.add_metric(custom_loss, aggregation='mean', name='custom_loss')
789      return model
790
791    saved_model_dir = self._save_model_dir()
792    save_format = testing_utils.get_save_format()
793
794    with self.cached_session():
795      model = _make_model()
796      model.compile(
797          loss=keras.losses.SparseCategoricalCrossentropy(),
798          optimizer=optimizers.gradient_descent_v2.SGD(),
799          metrics=[keras.metrics.SparseCategoricalCrossentropy()])
800      x = np.random.normal(size=(32, 4))
801      y = np.random.randint(0, 3, size=32)
802      model.train_on_batch(x, y)
803      evaluation_results = model.evaluate(x, y)
804      # Save and reload model.
805      model.save(saved_model_dir, save_format=save_format)
806      del model  # Prevent misuse.
807      loaded_model = keras.models.load_model(saved_model_dir)
808      loaded_model_eval_results = loaded_model.evaluate(x, y)
809      # Assert all evaluation results are the same.
810      self.assertAllClose(evaluation_results, loaded_model_eval_results, 1e-9)
811      # Check correctness of the loss calculation.
812      self.assertAllGreater(evaluation_results, 0.)
813      evaluation_results = dict(
814          zip(loaded_model.metrics_names, evaluation_results))
815      self.assertNear(
816          evaluation_results['sparse_categorical_crossentropy'] +
817          evaluation_results['custom_loss'], evaluation_results['loss'], 1e-6)
818
819  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
820  def test_save_uncompiled_model_with_optimizer(self):
821    with self.cached_session() as session:
822      saved_model_dir = self._save_model_dir()
823      save_format = testing_utils.get_save_format()
824      model = keras.models.Sequential([keras.layers.Dense(1, input_shape=(3,))])
825      # Set the model's optimizer but don't compile. This can happen if the
826      # model is trained with a custom training loop.
827      model.optimizer = keras.optimizer_v2.rmsprop.RMSprop(lr=0.0001)
828      if not context.executing_eagerly():
829        session.run([v.initializer for v in model.variables])
830      model.save(saved_model_dir, save_format=save_format)
831
832      if save_format in ['tf', 'tensorflow']:
833        loaded = keras.models.load_model(saved_model_dir)
834        self.assertIsInstance(loaded.optimizer,
835                              keras.optimizer_v2.optimizer_v2.OptimizerV2)
836
837  @combinations.generate(combinations.combine(mode=['eager']))
838  def test_functional_model_with_getitem_op_layer(self):
839    inp = keras.Input(shape=(8))
840
841    out = inp[:]
842    model = keras.Model(
843        inputs=[inp],
844        outputs=out)
845    batch_size = 7
846    x = array_ops.stack([
847        math_ops.range(8) for _ in range(batch_size)])
848    args = [x]
849    expected = x[:]
850
851    self.assertAllEqual(model(args), expected)
852    self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
853
854    # Make sure it can be successfully saved and loaded
855    save_format = testing_utils.get_save_format()
856    saved_model_dir = self._save_model_dir()
857    keras.models.save_model(model, saved_model_dir, save_format=save_format)
858
859    loaded_model = keras.models.load_model(saved_model_dir)
860
861    self.assertAllEqual(loaded_model(args), expected)
862    self.assertAllEqual(loaded_model.predict(args, batch_size=batch_size),
863                        expected)
864
  @combinations.generate(combinations.combine(mode=['eager']))
  def test_shared_objects(self):
    """Checks that a layer shared by two wrapper layers stays shared.

    Two `OuterLayer`s wrap one `InnerLayer` instance. After a save/load round
    trip, and after `from_config`, assigning to the inner weight through
    either outer layer must change both model outputs. The config must also
    carry exactly two shared-object markers.
    """
    class OuterLayer(keras.layers.Layer):
      """Delegates `call` to a wrapped inner layer and serializes it in config."""

      def __init__(self, inner_layer):
        super(OuterLayer, self).__init__()
        self.inner_layer = inner_layer

      def call(self, inputs):
        return self.inner_layer(inputs)

      def get_config(self):
        return {
            'inner_layer': generic_utils.serialize_keras_object(
                self.inner_layer)
        }

      @classmethod
      def from_config(cls, config):
        return cls(generic_utils.deserialize_keras_object(
            config['inner_layer']))

    class InnerLayer(keras.layers.Layer):
      """Adds a single scalar float32 weight `v` to its inputs."""

      def __init__(self):
        super(InnerLayer, self).__init__()
        self.v = self.add_weight(name='v', shape=[], dtype=dtypes.float32)

      def call(self, inputs):
        return self.v + inputs

      @classmethod
      def from_config(cls, config):
        # The layer has no constructor arguments, so the config is ignored.
        return cls()

    # Create a model with 2 output layers that share the same inner layer.
    inner_layer = InnerLayer()
    outer_layer_1 = OuterLayer(inner_layer)
    outer_layer_2 = OuterLayer(inner_layer)
    input_ = keras.Input(shape=(1,))
    model = keras.Model(
        inputs=input_, outputs=[outer_layer_1(input_), outer_layer_2(input_)])

    # Changes to the shared layer should affect both outputs.
    # Each output is v + 1 because the model is called on the input 1.
    model.layers[1].inner_layer.v.assign(5)
    self.assertAllEqual(model(1), [6.0, 6.0])
    model.layers[1].inner_layer.v.assign(3)
    self.assertAllEqual(model(1), [4.0, 4.0])

    # After loading, changes to the shared layer should still affect both
    # outputs.
    def _do_assertions(loaded):
      # Assign through each outer layer in turn; if the inner layer were
      # duplicated on load, only one of the two outputs would change.
      loaded.layers[1].inner_layer.v.assign(5)
      self.assertAllEqual(loaded(1), [6.0, 6.0])
      loaded.layers[1].inner_layer.v.assign(3)
      self.assertAllEqual(loaded(1), [4.0, 4.0])
      loaded.layers[2].inner_layer.v.assign(5)
      self.assertAllEqual(loaded(1), [6.0, 6.0])
      loaded.layers[2].inner_layer.v.assign(3)
      self.assertAllEqual(loaded(1), [4.0, 4.0])

    # We'd like to make sure we only attach shared object IDs when strictly
    # necessary, so we'll recursively traverse the generated config to count
    # whether we have the exact number we expect.
    def _get_all_keys_recursive(dict_or_iterable):
      """Yields every dict key found anywhere inside a nested config."""
      if isinstance(dict_or_iterable, dict):
        for key in dict_or_iterable.keys():
          yield key
        for key in _get_all_keys_recursive(dict_or_iterable.values()):
          yield key
      # Strings are iterable, but they must not be descended into: iterating
      # a one-character string yields itself, which would recurse forever.
      elif isinstance(dict_or_iterable, string_types):
        return
      else:
        try:
          for item in dict_or_iterable:
            for key in _get_all_keys_recursive(item):
              yield key
        # Not an iterable or dictionary
        except TypeError:
          return

    with generic_utils.CustomObjectScope({
        'OuterLayer': OuterLayer, 'InnerLayer': InnerLayer}):

      # Test saving and loading to disk
      save_format = testing_utils.get_save_format()
      saved_model_dir = self._save_model_dir()
      keras.models.save_model(model, saved_model_dir, save_format=save_format)
      loaded = keras.models.load_model(saved_model_dir)
      _do_assertions(loaded)

      # Test recreating directly from config
      config = model.get_config()
      key_count = collections.Counter(_get_all_keys_recursive(config))
      # Presumably one marker per OuterLayer config that references the
      # shared InnerLayer; verify against generic_utils if this changes.
      self.assertEqual(key_count[generic_utils.SHARED_OBJECT_KEY], 2)
      loaded = keras.Model.from_config(config)
      _do_assertions(loaded)
962
963  @combinations.generate(combinations.combine(mode=['eager']))
964  def test_shared_objects_wrapper(self):
965    """Tests that shared layers wrapped with `Wrapper` restore correctly."""
966    input_ = keras.Input(shape=(1,))
967    unwrapped = keras.layers.Layer(name='unwrapped')
968    wrapped = keras.layers.Wrapper(unwrapped, name='wrapped')
969    model = keras.Model(inputs=input_,
970                        outputs=[unwrapped(input_), wrapped(input_)])
971
972    # Test recreating directly from config
973    config = model.get_config()
974    loaded = keras.Model.from_config(config)
975    self.assertIs(loaded.layers[1], loaded.layers[2].layer)
976
977    # Test saving and loading to disk
978    save_format = testing_utils.get_save_format()
979    saved_model_dir = self._save_model_dir()
980    keras.models.save_model(model, saved_model_dir, save_format=save_format)
981    loaded = keras.models.load_model(saved_model_dir)
982    self.assertIs(loaded.layers[1], loaded.layers[2].layer)
983
984
985# Factory functions to create models that will be serialized inside a Network.
def _make_graph_network(input_size, output_size):
  """Builds a two-layer functional model: Dense(8, relu) -> Dense(output_size)."""
  model_input = keras.Input(input_size)
  hidden = keras.layers.Dense(8, activation='relu')(model_input)
  return keras.Model(
      inputs=model_input,
      outputs=keras.layers.Dense(output_size)(hidden))
991
992
def _make_sequential(input_size, output_size):
  """Builds a two-layer Sequential model without declaring an input shape.

  `input_size` is accepted only so the signature matches the other model
  factories. It is deliberately ignored, so no input shape is given to the
  model here.
  """
  del input_size  # Unused by design; see docstring.
  hidden = keras.layers.Dense(8, activation='relu')
  logits = keras.layers.Dense(output_size)
  return keras.Sequential([hidden, logits])
999
1000
def _make_sequential_built(input_size, output_size):
  """Same as `_make_sequential`, but built for inputs with `input_size` features."""
  built_model = _make_sequential(input_size, output_size)
  built_model.build(input_shape=(None, input_size))
  return built_model
1005
1006
def _make_sequential_graph_network(input_size, output_size):
  """Builds a Sequential model that begins with an explicit `InputLayer`."""
  input_layer = keras.layers.InputLayer(input_size)
  hidden = keras.layers.Dense(8, activation='relu')
  logits = keras.layers.Dense(output_size)
  return keras.Sequential([input_layer, hidden, logits])
1013
1014
def _make_sequential_input_shape(input_size, output_size):
  """Builds a Sequential model whose first layer declares `input_shape`."""
  first = keras.layers.Dense(8, activation='relu', input_shape=(input_size,))
  second = keras.layers.Dense(output_size)
  return keras.Sequential([first, second])
1020
1021
class _make_subclassed(keras.Model):  # pylint: disable=invalid-name
  """Subclassed two-layer model: Dense(8, relu) 'hidden' -> Dense 'logits'.

  It is named like a function so it can be passed interchangeably with the
  model-factory functions in parameterized tests. It is serializable through
  `get_config`/`from_config`.
  """

  def __init__(self, input_size, output_size):
    super(_make_subclassed, self).__init__()
    # Recorded only so from_config can recreate the model; input_size is not
    # used when creating the layers below.
    self._config = {'input_size': input_size, 'output_size': output_size}
    self._hidden_layer = keras.layers.Dense(8, activation='relu', name='hidden')
    self._logits_layer = keras.layers.Dense(output_size, name='logits')

  def call(self, inputs):
    x = self._hidden_layer(inputs)
    return self._logits_layer(x)

  def get_config(self):
    # Return a fresh plain dict. Callers that mutate the config then cannot
    # alter the constructor arguments recorded on this instance. It also
    # guards against Keras having wrapped the dict attribute in a tracking
    # container (assumption; depends on attribute auto-tracking).
    return dict(self._config)

  @classmethod
  def from_config(cls, config):
    return cls(**config)
1040
1041
class _make_subclassed_built(_make_subclassed):  # pylint: disable=invalid-name
  """`_make_subclassed` variant that calls `build` in its constructor."""

  def __init__(self, input_size, output_size):
    super(_make_subclassed_built, self).__init__(
        input_size=input_size, output_size=output_size)
    # Build for a batch of unknown size with `input_size` features.
    self.build(input_shape=(None, input_size))
1047
1048
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class TestWholeModelSavingWithNesting(test.TestCase, parameterized.TestCase):
  """Tests saving a whole model that contains other models."""

  @parameterized.named_parameters([
      ('graph_network', _make_graph_network),
      ('sequential', _make_sequential),
      ('sequential_built', _make_sequential_built),
      ('sequential_graph_network', _make_sequential_graph_network),
      ('sequential_input_shape', _make_sequential_input_shape),
      ('subclassed', _make_subclassed),
      ('subclassed_built', _make_subclassed_built),
  ])
  def test_functional(self, model_fn):
    """Round-trips through HDF5 a model whose two inputs share a nested model.

    `model_fn(input_size, output_size)` builds the nested model. It is
    applied to both inputs, and the two results are summed.
    """
    if h5py is None:
      self.skipTest('h5py required to run this test')

    def _build_shared_model():
      examples = keras.Input(shape=(4,), name='examples')
      neighbors = keras.Input(shape=(4,), name='neighbors')
      shared = model_fn(examples.shape.as_list()[-1], 2)
      summed = keras.layers.add([shared(examples), shared(neighbors)])
      return keras.Model(inputs=(examples, neighbors), outputs=summed)

    with self.cached_session():
      features = tuple(
          np.random.normal(size=(16, 4)).astype(np.float32) for _ in range(2))
      model = _build_shared_model()
      expected = model(features)

      model_path = os.path.join(self.get_temp_dir(), 'model.h5')
      model.save(model_path)
      del model  # Only the reloaded copy may be used from here on.

      # The subclassed factories are not registered Keras classes, so the
      # loader needs them explicitly.
      custom_objects = {
          '_make_subclassed': _make_subclassed,
          '_make_subclassed_built': _make_subclassed_built,
      }
      restored = keras.models.load_model(
          model_path, custom_objects=custom_objects, compile=False)
      self.assertAllClose(restored(features), expected, 1e-9)
1091
1092
if __name__ == '__main__':
  # Run the test cases defined in this module when executed as a script.
  test.main()
1095