• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for Keras model saving code."""
16
17import collections
18import os
19import pathlib
20import shutil
21import tempfile
22import warnings
23
24from absl.testing import parameterized
25import numpy as np
26
27from tensorflow.python import keras
28from tensorflow.python import tf2
29from tensorflow.python.eager import context
30from tensorflow.python.feature_column import feature_column_lib
31from tensorflow.python.framework import constant_op
32from tensorflow.python.framework import dtypes
33from tensorflow.python.framework import ops
34from tensorflow.python.framework import sparse_tensor
35from tensorflow.python.keras import combinations
36from tensorflow.python.keras import keras_parameterized
37from tensorflow.python.keras import losses
38from tensorflow.python.keras import optimizer_v1
39from tensorflow.python.keras import optimizers
40from tensorflow.python.keras import testing_utils
41from tensorflow.python.keras.engine import functional
42from tensorflow.python.keras.engine import sequential
43from tensorflow.python.keras.feature_column import dense_features
44from tensorflow.python.keras.feature_column import sequence_feature_column as ksfc
45from tensorflow.python.keras.layers import core
46from tensorflow.python.keras.saving import model_config
47from tensorflow.python.keras.saving import save
48from tensorflow.python.keras.utils import generic_utils
49from tensorflow.python.ops import array_ops
50from tensorflow.python.ops import lookup_ops
51from tensorflow.python.ops import math_ops
52from tensorflow.python.platform import test
53from tensorflow.python.saved_model import loader_impl
54from tensorflow.python.training import training as training_module
55
56
57try:
58  import h5py  # pylint:disable=g-import-not-at-top
59except ImportError:
60  h5py = None
61
62
class TestSaveModel(test.TestCase, parameterized.TestCase):
  """Round-trip tests for `save.save_model` / `save.load_model`.

  Covers string and `pathlib.Path` destinations, the HDF5 ('h5') and
  SavedModel ('tf') formats, and models using feature columns, RNN cells,
  custom objects and optimizer state.
  """

  def setUp(self):
    super(TestSaveModel, self).setUp()
    # Fixture models from testing_utils: a Sequential model, and a subclassed
    # model that the HDF5 writer rejects and the SavedModel writer rejects
    # until it has been called on data (see test_save_hdf5 / test_save_tf).
    self.model = testing_utils.get_small_sequential_mlp(1, 2, 3)
    self.subclassed_model = testing_utils.get_small_subclass_mlp(1, 2)

  def _skip_if_no_h5py(self):
    """Skips the current test when the optional `h5py` package is missing.

    Mirrors the h5py guard used by the HDF5-only tests elsewhere in this file,
    so HDF5 tests report "skipped" instead of erroring without h5py.
    """
    if h5py is None:
      self.skipTest('h5py required to run this test')

  def assert_h5_format(self, path):
    """Asserts `path` is a valid HDF5 file; silently passes without h5py."""
    if h5py is not None:
      self.assertTrue(h5py.is_hdf5(path),
                      'Model saved at path {} is not a valid hdf5 file.'
                      .format(path))

  def assert_saved_model(self, path):
    """Asserts `path` holds a SavedModel that `loader_impl` can parse."""
    loader_impl.parse_saved_model(path)

  @testing_utils.run_v2_only
  def test_save_format_defaults(self):
    """Without `save_format`, `save_model` writes a SavedModel."""
    path = os.path.join(self.get_temp_dir(), 'model_path')
    save.save_model(self.model, path)
    self.assert_saved_model(path)

  @testing_utils.run_v2_only
  def test_save_format_defaults_pathlib(self):
    """Same as test_save_format_defaults, with a `pathlib.Path` target."""
    path = pathlib.Path(self.get_temp_dir()) / 'model_path'
    save.save_model(self.model, path)
    self.assert_saved_model(path)

  @testing_utils.run_v2_only
  def test_save_hdf5(self):
    """HDF5 saving works for Sequential models and rejects subclassed ones."""
    self._skip_if_no_h5py()
    path = os.path.join(self.get_temp_dir(), 'model')
    save.save_model(self.model, path, save_format='h5')
    self.assert_h5_format(path)
    with self.assertRaisesRegex(
        NotImplementedError,
        'requires the model to be a Functional model or a Sequential model.'):
      save.save_model(self.subclassed_model, path, save_format='h5')

  @testing_utils.run_v2_only
  def test_save_load_hdf5_pathlib(self):
    """An HDF5 model can be saved to and loaded from a `pathlib.Path`."""
    self._skip_if_no_h5py()
    path = pathlib.Path(self.get_temp_dir()) / 'model'
    save.save_model(self.model, path, save_format='h5')
    save.load_model(path)

  @testing_utils.run_v2_only
  def test_save_tf(self):
    """SavedModel saving requires a subclassed model to be called first."""
    path = os.path.join(self.get_temp_dir(), 'model')
    save.save_model(self.model, path, save_format='tf')
    self.assert_saved_model(path)
    with self.assertRaisesRegex(ValueError, 'input shapes have not been set'):
      save.save_model(self.subclassed_model, path, save_format='tf')
    # Calling the model on data sets its input shapes; saving then succeeds.
    self.subclassed_model.predict(np.random.random((3, 5)))
    save.save_model(self.subclassed_model, path, save_format='tf')
    self.assert_saved_model(path)

  @testing_utils.run_v2_only
  def test_save_load_tf_string(self):
    """A SavedModel can be saved to and loaded from a string path."""
    path = os.path.join(self.get_temp_dir(), 'model')
    save.save_model(self.model, path, save_format='tf')
    save.load_model(path)

  @testing_utils.run_v2_only
  def test_save_load_tf_pathlib(self):
    """A SavedModel can be saved to and loaded from a `pathlib.Path`."""
    path = pathlib.Path(self.get_temp_dir()) / 'model'
    save.save_model(self.model, path, save_format='tf')
    save.load_model(path)

  @testing_utils.run_v2_only
  def test_save_load_weights_tf_pathlib(self):
    """TF-format weights accept a `pathlib.Path` for save and load."""
    path = pathlib.Path(self.get_temp_dir()) / 'model'
    self.model.save_weights(path, save_format='tf')
    self.model.load_weights(path)

  @testing_utils.run_v2_only
  def test_save_load_weights_hdf5_pathlib(self):
    """HDF5 weights accept a `pathlib.Path` for save and load."""
    self._skip_if_no_h5py()
    path = pathlib.Path(self.get_temp_dir()) / 'model'
    self.model.save_weights(path, save_format='h5')
    self.model.load_weights(path)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_saving_with_dense_features(self):
    """A model using `DenseFeatures` round-trips through its JSON config."""
    cols = [
        feature_column_lib.numeric_column('a'),
        feature_column_lib.indicator_column(
            feature_column_lib.categorical_column_with_vocabulary_list(
                'b', ['one', 'two']))
    ]
    input_layers = {
        'a': keras.layers.Input(shape=(1,), name='a'),
        'b': keras.layers.Input(shape=(1,), name='b', dtype='string')
    }

    fc_layer = dense_features.DenseFeatures(cols)(input_layers)
    output = keras.layers.Dense(10)(fc_layer)

    model = keras.models.Model(input_layers, output)

    model.compile(
        loss=keras.losses.MSE,
        optimizer='rmsprop',
        metrics=[keras.metrics.categorical_accuracy])

    config = model.to_json()
    loaded_model = model_config.model_from_json(config)

    inputs_a = np.arange(10).reshape(10, 1)
    inputs_b = np.arange(10).reshape(10, 1).astype('str')

    with self.cached_session():
      # Initialize tables for V1 lookup.
      if not context.executing_eagerly():
        self.evaluate(lookup_ops.tables_initializer())

      self.assertLen(loaded_model.predict({'a': inputs_a, 'b': inputs_b}), 10)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_saving_with_sequence_features(self):
    """A model using `SequenceFeatures` round-trips through its JSON config."""
    cols = [
        feature_column_lib.sequence_numeric_column('a'),
        feature_column_lib.indicator_column(
            feature_column_lib.sequence_categorical_column_with_vocabulary_list(
                'b', ['one', 'two']))
    ]
    input_layers = {
        'a':
            keras.layers.Input(shape=(None, 1), sparse=True, name='a'),
        'b':
            keras.layers.Input(
                shape=(None, 1), sparse=True, name='b', dtype='string')
    }

    fc_layer, _ = ksfc.SequenceFeatures(cols)(input_layers)
    # TODO(tibell): Figure out the right dtype and apply masking.
    # sequence_length_mask = array_ops.sequence_mask(sequence_length)
    # x = keras.layers.GRU(32)(fc_layer, mask=sequence_length_mask)
    x = keras.layers.GRU(32)(fc_layer)
    output = keras.layers.Dense(10)(x)

    model = keras.models.Model(input_layers, output)

    model.compile(
        loss=keras.losses.MSE,
        optimizer='rmsprop',
        metrics=[keras.metrics.categorical_accuracy])

    config = model.to_json()
    loaded_model = model_config.model_from_json(config)

    batch_size = 10
    timesteps = 1

    values_a = np.arange(10, dtype=np.float32)
    indices_a = np.zeros((10, 3), dtype=np.int64)
    indices_a[:, 0] = np.arange(10)
    inputs_a = sparse_tensor.SparseTensor(indices_a, values_a,
                                          (batch_size, timesteps, 1))

    values_b = np.zeros(10, dtype=np.str_)
    indices_b = np.zeros((10, 3), dtype=np.int64)
    indices_b[:, 0] = np.arange(10)
    inputs_b = sparse_tensor.SparseTensor(indices_b, values_b,
                                          (batch_size, timesteps, 1))

    with self.cached_session():
      # Initialize tables for V1 lookup.
      if not context.executing_eagerly():
        self.evaluate(lookup_ops.tables_initializer())

      self.assertLen(
          loaded_model.predict({
              'a': inputs_a,
              'b': inputs_b
          }, steps=1), batch_size)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_saving_h5_for_rnn_layers(self):
    """HDF5 saving keeps the variable names of stacked RNN cells unique."""
    self._skip_if_no_h5py()
    # See https://github.com/tensorflow/tensorflow/issues/35731 for details.
    inputs = keras.Input([10, 91], name='train_input')
    rnn_layers = [
        keras.layers.LSTMCell(size, recurrent_dropout=0, name='rnn_cell%d' % i)
        for i, size in enumerate([512, 512])
    ]
    rnn_output = keras.layers.RNN(
        rnn_layers, return_sequences=True, name='rnn_layer')(inputs)
    pred_feat = keras.layers.Dense(91, name='prediction_features')(rnn_output)
    pred = keras.layers.Softmax()(pred_feat)
    model = keras.Model(inputs=[inputs], outputs=[pred, pred_feat])
    path = os.path.join(self.get_temp_dir(), 'model_path.h5')
    model.save(path)

    # Make sure the variable name is unique.
    self.assertNotEqual(rnn_layers[0].kernel.name,
                        rnn_layers[1].kernel.name)
    self.assertIn('rnn_cell1', rnn_layers[1].kernel.name)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_saving_optimizer_weights(self):
    """Loading TF-format weights reproduces the original's next-step loss.

    The fresh model is trained once before loading, presumably so that its
    optimizer variables exist and can receive the restored values.
    """

    class MyModel(keras.Model):

      def __init__(self):
        super(MyModel, self).__init__()
        self.layer = keras.layers.Dense(1)

      def call(self, x):
        return self.layer(x)

    path = os.path.join(self.get_temp_dir(), 'weights_path')
    x, y = np.ones((10, 10)), np.ones((10, 1))

    model = MyModel()
    model.compile('rmsprop', loss='bce')
    model.train_on_batch(x, y)
    model.reset_metrics()
    model.save_weights(path, save_format='tf')

    batch_loss = model.train_on_batch(x, y)

    new_model = MyModel()
    new_model.compile('rmsprop', loss='bce')
    new_model.train_on_batch(x, y)
    new_model.reset_metrics()

    new_model.load_weights(path)
    new_batch_loss = new_model.train_on_batch(x, y)

    self.assertAllClose(batch_loss, new_batch_loss)

  @combinations.generate(combinations.combine(mode=['eager', 'graph']))
  def test_save_include_optimizer_false(self):
    """`include_optimizer=False` keeps optimizer variables out of the save."""

    def get_variables(file_name):
      # Build the checkpoint prefix from separate components instead of
      # embedding '/' so the platform's path separator is used throughout.
      reader = training_module.load_checkpoint(
          os.path.join(file_name, 'variables', 'variables'))
      shape_from_key = reader.get_variable_to_shape_map()
      return sorted(shape_from_key.keys())

    path = os.path.join(self.get_temp_dir(), 'no_optimizer')
    x, y = np.ones((10, 10)), np.ones((10, 1))

    model = keras.models.Sequential()
    model.add(keras.layers.Dense(1))
    model.compile('adam', loss='mse')
    model.train_on_batch(x, y)
    model.save(path, save_format='tf', include_optimizer=False)
    variables = get_variables(path)

    for v in variables:
      self.assertNotIn('optimizer', v)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_saving_model_with_custom_object(self):
    """A model compiled with a registered custom loss reloads compiled."""
    with generic_utils.custom_object_scope(), self.cached_session():

      @generic_utils.register_keras_serializable()
      class CustomLoss(losses.MeanSquaredError):
        pass

      model = sequential.Sequential(
          [core.Dense(units=1, input_shape=(1,))])
      model.compile(optimizer='sgd', loss=CustomLoss())
      model.fit(np.zeros([10, 1]), np.zeros([10, 1]))

      temp_dir = self.get_temp_dir()
      filepath = os.path.join(temp_dir, 'saving')
      model.save(filepath)

      # Make sure the model can be correctly load back.
      _ = save.load_model(filepath, compile=True)
332
333
334@keras_parameterized.run_with_all_saved_model_formats
335class TestWholeModelSaving(keras_parameterized.TestCase):
336
337  def _save_model_dir(self, dirname='saved_model'):
338    temp_dir = self.get_temp_dir()
339    self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
340    return os.path.join(temp_dir, dirname)
341
342  def _assert_same_weights_and_metrics(self, model, loaded_model):
343    """Checks that the loaded weights and metrics are the same as the original.
344
345    Args:
346      model: original model
347      loaded_model: loaded model
348    """
349    self.assertAllClose(model.weights, loaded_model.weights)
350
351    if loaded_model.optimizer:
352      if testing_utils.get_save_format() == 'tf':
353        # TODO(b/153110928): Keras TF format doesn't restore optimizer weights
354        # currently.
355        return
356      self.assertAllClose(model.optimizer.weights,
357                          loaded_model.optimizer.weights)
358
359    # In V1/Graph mode, the model isn't built, so the metrics are not loaded
360    # immediately (requires model to be called on some data before building
361    # metrics).
362    check_metrics = tf2.enabled() and context.executing_eagerly()
363
364    if check_metrics:
365      self.assertAllEqual([m.name for m in model.metrics],
366                          [m.name for m in loaded_model.metrics])
367
368  @keras_parameterized.run_with_all_model_types
369  @keras_parameterized.run_all_keras_modes
370  def test_save_and_load(self):
371    saved_model_dir = self._save_model_dir()
372    save_format = testing_utils.get_save_format()
373    save_kwargs = testing_utils.get_save_kwargs()
374
375    if ((save_format == 'h5' or not save_kwargs.get('save_traces', True)) and
376        testing_utils.get_model_type() == 'subclass'):
377      # HDF5 format currently does not allow saving subclassed models.
378      # When saving with `save_traces=False`, the subclassed model must have a
379      # get_config/from_config, which the autogenerated model does not have.
380      return
381
382    with self.cached_session():
383      model = testing_utils.get_model_from_layers(
384          [keras.layers.Dense(2),
385           keras.layers.RepeatVector(3),
386           keras.layers.TimeDistributed(keras.layers.Dense(3))],
387          input_shape=(3,))
388      model.compile(
389          loss=keras.losses.MSE,
390          optimizer=keras.optimizer_v2.rmsprop.RMSprop(lr=0.0001),
391          metrics=[
392              keras.metrics.categorical_accuracy,
393              keras.metrics.CategoricalCrossentropy(
394                  name='cce', label_smoothing=constant_op.constant(0.2)),
395          ],
396          weighted_metrics=[
397              keras.metrics.categorical_crossentropy,
398              keras.metrics.CategoricalCrossentropy(
399                  name='cce', label_smoothing=constant_op.constant(0.2)),
400          ],
401          sample_weight_mode='temporal')
402
403      x = np.random.random((1, 3))
404      y = np.random.random((1, 3, 3))
405      model.train_on_batch(x, y)
406
407      out = model.predict(x)
408      keras.models.save_model(
409          model, saved_model_dir, save_format=save_format,
410          **save_kwargs)
411
412      loaded_model = keras.models.load_model(saved_model_dir)
413      self._assert_same_weights_and_metrics(model, loaded_model)
414
415      out2 = loaded_model.predict(x)
416      self.assertAllClose(out, out2, atol=1e-05)
417
418      eval_out = model.evaluate(x, y)
419      eval_out2 = loaded_model.evaluate(x, y)
420      self.assertArrayNear(eval_out, eval_out2, 0.001)
421
422  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
423  def test_sequential_model_saving_without_input_shape(self):
424    saved_model_dir = self._save_model_dir()
425    save_format = testing_utils.get_save_format()
426    with self.cached_session():
427      model = keras.models.Sequential()
428      model.add(keras.layers.Dense(2))
429      model.add(keras.layers.RepeatVector(3))
430      model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
431      model.compile(
432          loss=keras.losses.MSE,
433          optimizer='rmsprop',
434          metrics=[
435              keras.metrics.categorical_accuracy,
436              keras.metrics.CategoricalAccuracy(name='cat_acc')
437          ],
438          weighted_metrics=[
439              keras.metrics.categorical_accuracy,
440              keras.metrics.CategoricalAccuracy(name='cat_acc2')
441          ],
442          sample_weight_mode='temporal')
443      x = np.random.random((1, 3))
444      y = np.random.random((1, 3, 3))
445      model.train_on_batch(x, y)
446
447      out = model.predict(x)
448      model.save(saved_model_dir, save_format=save_format)
449
450      new_model = keras.models.load_model(saved_model_dir)
451
452      self._assert_same_weights_and_metrics(model, new_model)
453
454      out2 = new_model.predict(x)
455      self.assertAllClose(out, out2, atol=1e-05)
456
457  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
458  def test_sequential_model_saving_without_compile(self):
459    saved_model_dir = self._save_model_dir()
460    save_format = testing_utils.get_save_format()
461    with self.cached_session():
462      model = keras.models.Sequential()
463      model.add(keras.layers.Dense(2, input_shape=(3,)))
464      model.add(keras.layers.RepeatVector(3))
465      model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
466
467      x = np.random.random((1, 3))
468      out = model.predict(x)
469
470      # Save the model without any compilation or training.
471      keras.models.save_model(model, saved_model_dir, save_format=save_format)
472
473      new_model = keras.models.load_model(saved_model_dir)
474      self._assert_same_weights_and_metrics(model, new_model)
475
476      out2 = new_model.predict(x)
477      self.assertAllClose(out, out2, atol=1e-05)
478
479  def test_sequential_model_saving_2(self):
480    saved_model_dir = self._save_model_dir()
481    save_format = testing_utils.get_save_format()
482
483    with ops.Graph().as_default(), self.cached_session():
484      # test with custom optimizer, loss
485
486      class CustomOp(optimizer_v1.RMSprop):
487        pass
488
489      def custom_loss(y_true, y_pred):
490        return keras.losses.mse(y_true, y_pred)
491
492      model = keras.models.Sequential()
493      model.add(keras.layers.Dense(2, input_shape=(3,)))
494      model.add(keras.layers.Dense(3))
495      model.compile(loss=custom_loss, optimizer=CustomOp(), metrics=['acc'])
496
497      x = np.random.random((1, 3))
498      y = np.random.random((1, 3))
499      model.train_on_batch(x, y)
500
501      out = model.predict(x)
502      keras.models.save_model(model, saved_model_dir, save_format=save_format)
503
504      new_model = keras.models.load_model(
505          saved_model_dir,
506          custom_objects={'CustomOp': CustomOp,
507                          'custom_loss': custom_loss})
508      self._assert_same_weights_and_metrics(model, new_model)
509
510      out2 = new_model.predict(x)
511      self.assertAllClose(out, out2, atol=1e-05)
512
513  def test_saving_without_compilation(self):
514    saved_model_dir = self._save_model_dir()
515    save_format = testing_utils.get_save_format()
516    model = keras.models.Sequential()
517    model.add(keras.layers.Dense(2, input_shape=(3,)))
518    model.add(keras.layers.Dense(3))
519    model.compile(loss='mse', optimizer='sgd', metrics=['acc'])
520
521    keras.models.save_model(model, saved_model_dir, save_format=save_format)
522    model = keras.models.load_model(saved_model_dir)
523
524  def test_saving_with_tf_optimizer(self):
525    saved_model_dir = self._save_model_dir()
526    save_format = testing_utils.get_save_format()
527
528    model = keras.models.Sequential()
529    model.add(keras.layers.Dense(2, input_shape=(3,)))
530    model.add(keras.layers.Dense(3))
531    model.compile(loss='mse',
532                  optimizer=training_module.AdadeltaOptimizer(0.1),
533                  metrics=['acc'])
534
535    keras.models.save_model(model, saved_model_dir, save_format=save_format)
536    model = keras.models.load_model(saved_model_dir)
537
538  def test_saving_right_after_compilation(self):
539    saved_model_dir = self._save_model_dir()
540    save_format = testing_utils.get_save_format()
541    with self.cached_session():
542      model = keras.models.Sequential()
543      model.add(keras.layers.Dense(2, input_shape=(3,)))
544      model.add(keras.layers.Dense(3))
545      model.compile(loss='mse', optimizer='sgd', metrics=['acc'])
546      if not ops.executing_eagerly_outside_functions():
547        model._make_train_function()
548      keras.models.save_model(model, saved_model_dir, save_format=save_format)
549      model = keras.models.load_model(saved_model_dir)
550
551  def test_saving_lambda_numpy_array_arguments(self):
552    saved_model_dir = self._save_model_dir()
553    save_format = testing_utils.get_save_format()
554
555    if h5py is None:
556      self.skipTest('h5py required to run this test')
557
558    mean = np.random.random((4, 2, 3))
559    std = np.abs(np.random.random((4, 2, 3))) + 1e-5
560    inputs = keras.layers.Input(shape=(4, 2, 3))
561    output = keras.layers.Lambda(lambda image, mu, std: (image - mu) / std,
562                                 arguments={'mu': mean, 'std': std})(inputs)
563    model = keras.models.Model(inputs, output)
564    model.compile(loss='mse', optimizer='sgd', metrics=['acc'])
565
566    keras.models.save_model(model, saved_model_dir, save_format=save_format)
567
568    model = keras.models.load_model(saved_model_dir)
569
570    self.assertAllClose(mean, model.layers[1].arguments['mu'])
571    self.assertAllClose(std, model.layers[1].arguments['std'])
572
573  def test_saving_model_with_long_layer_names(self):
574    saved_model_dir = self._save_model_dir()
575    save_format = testing_utils.get_save_format()
576    with self.cached_session():
577      # This layer name will make the `layers_name` HDF5 attribute blow
578      # out of proportion. Note that it fits into the internal HDF5
579      # attribute memory limit on its own but because h5py converts
580      # the list of layer names into numpy array, which uses the same
581      # amount of memory for every item, it increases the memory
582      # requirements substantially.
583      x = keras.Input(shape=(2,), name='input_' + ('x' * (2**15)))
584      f = x
585      for i in range(4):
586        f = keras.layers.Dense(2, name='dense_%d' % (i,))(f)
587      model = keras.Model(inputs=[x], outputs=[f])
588      model.compile(
589          'adam', loss=keras.losses.MeanSquaredError(), metrics=['acc'])
590
591      x = np.random.random((1, 2))
592      y = np.random.random((1, 2))
593      model.train_on_batch(x, y)
594      out = model.predict(x)
595
596      keras.models.save_model(model, saved_model_dir, save_format=save_format)
597      model = keras.models.load_model(saved_model_dir)
598
599      if save_format in ['tf', 'tensorflow']:
600        return
601      # Check that the HDF5 files contains chunked array
602      # of layer names.
603      with h5py.File(saved_model_dir, 'r') as h5file:
604        num_names_arrays = len([attr for attr in h5file['model_weights'].attrs
605                                if attr.startswith('layer_names')])
606      # The chunking of layer names array should have happened.
607      self.assertGreater(num_names_arrays, 0)
608      out2 = model.predict(x)
609      self.assertAllClose(out, out2, atol=1e-05)
610
611  def test_saving_model_with_long_weights_names(self):
612    saved_model_dir = self._save_model_dir()
613    save_format = testing_utils.get_save_format()
614
615    with self.cached_session():
616      x = keras.Input(shape=(2,), name='nested_model_input')
617      f = x
618      for i in range(4):
619        f = keras.layers.Dense(2, name='nested_model_dense_%d' % (i,))(f)
620      # This layer name will make the `weights_name`
621      # HDF5 attribute blow out of proportion.
622      f = keras.layers.Dense(2, name='nested_model_output' + ('x' * (2**14)))(f)
623      nested_model = keras.Model(inputs=[x], outputs=[f], name='nested_model')
624
625      x = keras.Input(shape=(2,), name='outer_model_input')
626      f = nested_model(x)
627      f = keras.layers.Dense(2, name='outer_model_output')(f)
628
629      model = keras.Model(inputs=[x], outputs=[f])
630      model.compile(loss='mse', optimizer='adam', metrics=['acc'])
631
632      x = np.random.random((1, 2))
633      y = np.random.random((1, 2))
634      model.train_on_batch(x, y)
635      out = model.predict(x)
636
637      keras.models.save_model(model, saved_model_dir, save_format=save_format)
638      model = keras.models.load_model(saved_model_dir)
639
640      if save_format in ['h5', 'hdf5', 'keras']:
641        # Check that the HDF5 files contains chunked array
642        # of weight names.
643        with h5py.File(saved_model_dir, 'r') as h5file:
644          num_weight_arrays = len(
645              [attr for attr in h5file['model_weights']['nested_model'].attrs
646               if attr.startswith('weight_names')])
647        # The chunking of layer names array should have happened.
648        self.assertGreater(num_weight_arrays, 0)
649      out2 = model.predict(x)
650      self.assertAllClose(out, out2, atol=1e-05)
651
652  def test_model_saving_to_pre_created_h5py_file(self):
653    saved_model_dir = self._save_model_dir()
654    save_format = testing_utils.get_save_format()
655    with ops.Graph().as_default(), self.cached_session():
656      inputs = keras.Input(shape=(3,))
657      x = keras.layers.Dense(2)(inputs)
658      outputs = keras.layers.Dense(3)(x)
659
660      model = keras.Model(inputs, outputs)
661      model.compile(
662          loss=keras.losses.MSE,
663          optimizer=optimizer_v1.Adam(),
664          metrics=[
665              keras.metrics.categorical_accuracy,
666              keras.metrics.CategoricalAccuracy()
667          ])
668      x = np.random.random((1, 3))
669      y = np.random.random((1, 3))
670      model.train_on_batch(x, y)
671
672      out = model.predict(x)
673
674      keras.models.save_model(model, saved_model_dir, save_format=save_format)
675      loaded_model = keras.models.load_model(saved_model_dir)
676      out1 = loaded_model.predict(x)
677      self.assertAllClose(out, out1, atol=1e-05)
678      if save_format in ['tf', 'tensorflow']:
679        return
680
681      # Test h5 format specifically
682      fd, fname = tempfile.mkstemp('.h5')
683      with h5py.File(fname, mode='r+') as h5file:
684        keras.models.save_model(model, h5file)
685        loaded_model = keras.models.load_model(h5file)
686        out2 = loaded_model.predict(x)
687      self.assertAllClose(out, out2, atol=1e-05)
688
689      # Test non-default options in h5
690      with h5py.File(
691          '_', driver='core', mode='w', backing_store=False) as h5file:
692        keras.models.save_model(model, h5file)
693        loaded_model = keras.models.load_model(h5file)
694        out2 = loaded_model.predict(x)
695      self.assertAllClose(out, out2, atol=1e-05)
696
697      # Cleanup
698      os.close(fd)
699      os.remove(fname)
700
701  def test_model_saving_to_new_dir_path(self):
702    saved_model_dir = os.path.join(self._save_model_dir(), 'newdir',
703                                   'saved_model')
704    save_format = testing_utils.get_save_format()
705
706    with self.cached_session():
707      model = keras.models.Sequential()
708      model.add(keras.layers.Dense(2, input_shape=(3,)))
709      model.add(keras.layers.RepeatVector(3))
710      model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
711
712      x = np.random.random((1, 3))
713      out = model.predict(x)
714
715      keras.models.save_model(model, saved_model_dir, save_format=save_format)
716
717      new_model = keras.models.load_model(saved_model_dir)
718      self._assert_same_weights_and_metrics(model, new_model)
719
720      out2 = new_model.predict(x)
721      self.assertAllClose(out, out2, atol=1e-05)
722
723  def test_model_raise_exception_with_failed_saving(self):
724    if h5py is None:
725      self.skipTest('h5py required to run this test')
726
727    saved_model_dir = self._save_model_dir()
728    saved_model_path = os.path.join(saved_model_dir, 'saved_model.h5')
729
730    with self.cached_session():
731      model = keras.models.Sequential()
732      model.add(keras.layers.Dense(2, input_shape=(3,)))
733      model.add(keras.layers.RepeatVector(3))
734      model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
735
736      with self.assertRaisesRegex(OSError, 'Unable to create file'):
737        with h5py.File(saved_model_path, 'w'):
738          keras.models.save_model(model, saved_model_path)
739
740  def test_saving_constant_initializer_with_numpy(self):
741    saved_model_dir = self._save_model_dir()
742    save_format = testing_utils.get_save_format()
743
744    model = keras.models.Sequential()
745    model.add(
746        keras.layers.Dense(
747            2,
748            input_shape=(3,),
749            kernel_initializer=keras.initializers.Constant(np.ones((3, 2)))))
750    model.add(keras.layers.Dense(3))
751    model.compile(loss='mse', optimizer='sgd', metrics=['acc'])
752    keras.models.save_model(model, saved_model_dir, save_format=save_format)
753    model = keras.models.load_model(saved_model_dir)
754
755  def test_saving_group_naming_h5py(self):
756    # Test saving model with layer which name is prefix to a previous layer
757    # name.
758
759    temp_dir = self.get_temp_dir()
760    self.addCleanup(shutil.rmtree, temp_dir)
761    h5_path = os.path.join(temp_dir, 'test.h5')
762
763    input_layer = keras.layers.Input((None, None, 3), name='test_input')
764    x = keras.layers.Conv2D(1, 1, name='conv1/conv')(input_layer)
765    x = keras.layers.Activation('relu', name='conv1')(x)
766    model = keras.models.Model(inputs=input_layer, outputs=x)
767
768    model.save_weights(h5_path)
769    model.load_weights(h5_path)
770
771  def test_primitive_attrs_contain_no_extraneous_strings(self):
772    if h5py is None:
773      self.skipTest('h5py required to run this test')
774
775    saved_model_dir = self._save_model_dir()
776    save_format = testing_utils.get_save_format()
777    model = keras.models.Sequential()
778    model.add(keras.layers.Dense(1, input_shape=[2]))
779    model.save(saved_model_dir, save_format=save_format)
780    if save_format in ['tf', 'tensorflow']:
781      return
782
783    h5file = h5py.File(saved_model_dir, 'r')
784    self.assertRegex(h5file.attrs['keras_version'], r'^[\d]+\.[\d]+\.[\S]+$')
785
786  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
787  def test_functional_model_with_custom_loss_and_metric(self):
788    def _make_model():
789      inputs = keras.Input(shape=(4,))
790      x = keras.layers.Dense(8, activation='relu')(inputs)
791      outputs = keras.layers.Dense(3, activation='softmax')(x)
792      model = keras.Model(inputs=inputs, outputs=outputs)
793      custom_loss = keras.layers.Lambda(lambda x: keras.backend.sum(x * x))(x)
794      model.add_loss(custom_loss)
795      model.add_metric(custom_loss, aggregation='mean', name='custom_loss')
796      return model
797
798    saved_model_dir = self._save_model_dir()
799    save_format = testing_utils.get_save_format()
800
801    with self.cached_session():
802      model = _make_model()
803      model.compile(
804          loss=keras.losses.SparseCategoricalCrossentropy(),
805          optimizer=optimizers.gradient_descent_v2.SGD(),
806          metrics=[keras.metrics.SparseCategoricalCrossentropy()])
807      x = np.random.normal(size=(32, 4))
808      y = np.random.randint(0, 3, size=32)
809      model.train_on_batch(x, y)
810      evaluation_results = model.evaluate(x, y)
811      # Save and reload model.
812      model.save(saved_model_dir, save_format=save_format)
813      del model  # Prevent misuse.
814      loaded_model = keras.models.load_model(saved_model_dir)
815      loaded_model_eval_results = loaded_model.evaluate(x, y)
816      # Assert all evaluation results are the same.
817      self.assertAllClose(evaluation_results, loaded_model_eval_results, 1e-9)
818      # Check correctness of the loss calculation.
819      self.assertAllGreater(evaluation_results, 0.)
820      evaluation_results = dict(
821          zip(loaded_model.metrics_names, evaluation_results))
822      self.assertNear(
823          evaluation_results['sparse_categorical_crossentropy'] +
824          evaluation_results['custom_loss'], evaluation_results['loss'], 1e-6)
825
826  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
827  def test_save_uncompiled_model_with_optimizer(self):
828    with self.cached_session() as session:
829      saved_model_dir = self._save_model_dir()
830      save_format = testing_utils.get_save_format()
831      model = keras.models.Sequential([keras.layers.Dense(1, input_shape=(3,))])
832      # Set the model's optimizer but don't compile. This can happen if the
833      # model is trained with a custom training loop.
834      model.optimizer = keras.optimizer_v2.rmsprop.RMSprop(lr=0.0001)
835      if not context.executing_eagerly():
836        session.run([v.initializer for v in model.variables])
837      model.save(saved_model_dir, save_format=save_format)
838
839      if save_format in ['tf', 'tensorflow']:
840        loaded = keras.models.load_model(saved_model_dir)
841        self.assertIsInstance(loaded.optimizer,
842                              keras.optimizer_v2.optimizer_v2.OptimizerV2)
843
844  @combinations.generate(combinations.combine(mode=['eager']))
845  def test_functional_model_with_getitem_op_layer(self):
846    inp = keras.Input(shape=(8))
847
848    out = inp[:]
849    model = keras.Model(
850        inputs=[inp],
851        outputs=out)
852    batch_size = 7
853    x = array_ops.stack([
854        math_ops.range(8) for _ in range(batch_size)])
855    args = [x]
856    expected = x[:]
857
858    self.assertAllEqual(model(args), expected)
859    self.assertAllEqual(model.predict(args, batch_size=batch_size), expected)
860
861    # Make sure it can be successfully saved and loaded.
862    save_format = testing_utils.get_save_format()
863    saved_model_dir = self._save_model_dir()
864    keras.models.save_model(model, saved_model_dir, save_format=save_format)
865
866    loaded_model = keras.models.load_model(saved_model_dir)
867
868    self.assertAllEqual(loaded_model(args), expected)
869    self.assertAllEqual(loaded_model.predict(args, batch_size=batch_size),
870                        expected)
871
872  @combinations.generate(combinations.combine(mode=['eager', 'graph']))
873  def test_custom_functional_registered(self):
874
875    def _get_cls_definition():
876      class CustomModel(keras.Model):
877
878        def c(self):
879          return 'c'
880
881      return CustomModel
882
883    cls = _get_cls_definition()
884    self.assertEqual(cls.__bases__[0], keras.Model)
885
886    with self.cached_session() as sess:
887      input_ = keras.layers.Input(shape=(1,))
888      output = keras.layers.Dense(1)(input_)
889      model = cls(input_, output)
890      # `cls` now inherits from `Functional` class.
891      self.assertEqual(cls.__bases__[0], functional.Functional)
892
893      if not context.executing_eagerly():
894        sess.run([v.initializer for v in model.variables])
895
896      save_format = testing_utils.get_save_format()
897      saved_model_dir = self._save_model_dir()
898      keras.models.save_model(model, saved_model_dir, save_format=save_format)
899
900    loaded_model = keras.models.load_model(
901        saved_model_dir, custom_objects={'CustomModel': cls})
902    self.assertIsInstance(loaded_model, cls)
903
904    # Check with "new" `CustomModel` class definition.
905    new_cls = _get_cls_definition()
906    # The new `CustomModel` class is *not* derived from `Functional`.
907    self.assertEqual(new_cls.__bases__[0], keras.Model)
908    reloaded_model = keras.models.load_model(
909        saved_model_dir, custom_objects={'CustomModel': new_cls})
910    self.assertIsInstance(reloaded_model, new_cls)
911
912  @combinations.generate(combinations.combine(mode=['eager']))
913  def test_shared_objects(self):
914    class OuterLayer(keras.layers.Layer):
915
916      def __init__(self, inner_layer):
917        super(OuterLayer, self).__init__()
918        self.inner_layer = inner_layer
919
920      def call(self, inputs):
921        return self.inner_layer(inputs)
922
923      def get_config(self):
924        return {
925            'inner_layer': generic_utils.serialize_keras_object(
926                self.inner_layer)
927        }
928
929      @classmethod
930      def from_config(cls, config):
931        return cls(generic_utils.deserialize_keras_object(
932            config['inner_layer']))
933
934    class InnerLayer(keras.layers.Layer):
935
936      def __init__(self):
937        super(InnerLayer, self).__init__()
938        self.v = self.add_weight(name='v', shape=[], dtype=dtypes.float32)
939
940      def call(self, inputs):
941        return self.v + inputs
942
943      @classmethod
944      def from_config(cls, config):
945        return cls()
946
947    # Create a model with 2 output layers that share the same inner layer.
948    inner_layer = InnerLayer()
949    outer_layer_1 = OuterLayer(inner_layer)
950    outer_layer_2 = OuterLayer(inner_layer)
951    input_ = keras.Input(shape=(1,))
952    model = keras.Model(
953        inputs=input_, outputs=[outer_layer_1(input_), outer_layer_2(input_)])
954
955    # Changes to the shared layer should affect both outputs.
956    model.layers[1].inner_layer.v.assign(5)
957    self.assertAllEqual(model(1), [6.0, 6.0])
958    model.layers[1].inner_layer.v.assign(3)
959    self.assertAllEqual(model(1), [4.0, 4.0])
960
961    # After loading, changes to the shared layer should still affect both
962    # outputs.
963    def _do_assertions(loaded):
964      loaded.layers[1].inner_layer.v.assign(5)
965      self.assertAllEqual(loaded(1), [6.0, 6.0])
966      loaded.layers[1].inner_layer.v.assign(3)
967      self.assertAllEqual(loaded(1), [4.0, 4.0])
968      loaded.layers[2].inner_layer.v.assign(5)
969      self.assertAllEqual(loaded(1), [6.0, 6.0])
970      loaded.layers[2].inner_layer.v.assign(3)
971      self.assertAllEqual(loaded(1), [4.0, 4.0])
972
973    # We'd like to make sure we only attach shared object IDs when strictly
974    # necessary, so we'll recursively traverse the generated config to count
975    # whether we have the exact number we expect.
976    def _get_all_keys_recursive(dict_or_iterable):
977      if isinstance(dict_or_iterable, dict):
978        for key in dict_or_iterable.keys():
979          yield key
980        for key in _get_all_keys_recursive(dict_or_iterable.values()):
981          yield key
982      elif isinstance(dict_or_iterable, str):
983        return
984      else:
985        try:
986          for item in dict_or_iterable:
987            for key in _get_all_keys_recursive(item):
988              yield key
989        # Not an iterable or dictionary
990        except TypeError:
991          return
992
993    with generic_utils.CustomObjectScope({
994        'OuterLayer': OuterLayer, 'InnerLayer': InnerLayer}):
995
996      # Test saving and loading to disk
997      save_format = testing_utils.get_save_format()
998      saved_model_dir = self._save_model_dir()
999      keras.models.save_model(model, saved_model_dir, save_format=save_format)
1000      loaded = keras.models.load_model(saved_model_dir)
1001      _do_assertions(loaded)
1002
1003      # Test recreating directly from config
1004      config = model.get_config()
1005      key_count = collections.Counter(_get_all_keys_recursive(config))
1006      self.assertEqual(key_count[generic_utils.SHARED_OBJECT_KEY], 2)
1007      loaded = keras.Model.from_config(config)
1008      _do_assertions(loaded)
1009
1010  @combinations.generate(combinations.combine(mode=['eager']))
1011  def test_shared_objects_wrapper(self):
1012    """Tests that shared layers wrapped with `Wrapper` restore correctly."""
1013    input_ = keras.Input(shape=(1,))
1014    unwrapped = keras.layers.Layer(name='unwrapped')
1015    wrapped = keras.layers.Wrapper(unwrapped, name='wrapped')
1016    model = keras.Model(inputs=input_,
1017                        outputs=[unwrapped(input_), wrapped(input_)])
1018
1019    # Test recreating directly from config
1020    config = model.get_config()
1021    loaded = keras.Model.from_config(config)
1022    self.assertIs(loaded.layers[1], loaded.layers[2].layer)
1023
1024    # Test saving and loading to disk
1025    save_format = testing_utils.get_save_format()
1026    saved_model_dir = self._save_model_dir()
1027    keras.models.save_model(model, saved_model_dir, save_format=save_format)
1028    loaded = keras.models.load_model(saved_model_dir)
1029    self.assertIs(loaded.layers[1], loaded.layers[2].layer)
1030
1031  @combinations.generate(
1032      combinations.combine(mode=['graph', 'eager'], fit=[True, False]))
1033  def test_multi_output_metrics_name_stay_same(self, fit):
1034    """Tests that metric names don't change with each save/load cycle.
1035
1036    e.g. "head_0_accuracy" should not become "head_0_head_0_accuracy" after
1037    saving and loading a model.
1038
1039    Arguments:
1040      fit: Whether the model should be fit before saving.
1041    """
1042    # This doesn't work at all, so we can't check whether metric names are
1043    # correct.
1044    if not context.executing_eagerly() and not fit:
1045      self.skipTest('b/181767784')
1046
1047    input_ = keras.Input((4,))
1048    model = keras.Model(
1049        input_,
1050        [keras.layers.Softmax(name='head_0')(keras.layers.Dense(3)(input_)),
1051         keras.layers.Softmax(name='head_1')(keras.layers.Dense(5)(input_))])
1052    metric = keras.metrics.BinaryAccuracy()
1053    model.compile(optimizer='rmsprop',
1054                  loss='mse',
1055                  metrics={'head_0': [metric, 'accuracy']})
1056
1057    x = np.random.rand(2, 4)
1058    y = {'head_0': np.random.randint(2, size=(2, 3)),
1059         'head_1': np.random.randint(2, size=(2, 5))}
1060
1061    # Make sure metrix prefixing works the same regardless of whether the user
1062    # has fit the model before saving.
1063    if fit:
1064      model.fit(x, y, verbose=0)
1065
1066    # Save and reload.
1067    save_format = testing_utils.get_save_format()
1068    saved_model_dir = self._save_model_dir()
1069    keras.models.save_model(model, saved_model_dir, save_format=save_format)
1070    loaded = keras.models.load_model(saved_model_dir)
1071
1072    # Make sure the metrics names from the model before saving match the loaded
1073    # model.
1074    self.assertSequenceEqual(model.metrics_names, loaded.metrics_names)
1075
1076  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
1077  def test_warning_when_saving_invalid_custom_mask_layer(self):
1078
1079    class MyMasking(keras.layers.Layer):
1080
1081      def call(self, inputs):
1082        return inputs
1083
1084      def compute_mask(self, inputs, mask=None):
1085        mask = math_ops.not_equal(inputs, 0)
1086        return mask
1087
1088    class MyLayer(keras.layers.Layer):
1089
1090      def call(self, inputs, mask=None):
1091        return array_ops.identity(inputs)
1092
1093    samples = np.random.random((2, 2))
1094    model = keras.Sequential([MyMasking(), MyLayer()])
1095    model.predict(samples)
1096    with warnings.catch_warnings(record=True) as w:
1097      model.save(self._save_model_dir(), testing_utils.get_save_format())
1098    self.assertIn(generic_utils.CustomMaskWarning,
1099                  {warning.category for warning in w})
1100
1101    # Test that setting up a custom mask correctly does not issue a warning.
1102    class MyCorrectMasking(keras.layers.Layer):
1103
1104      def call(self, inputs):
1105        return inputs
1106
1107      def compute_mask(self, inputs, mask=None):
1108        mask = math_ops.not_equal(inputs, 0)
1109        return mask
1110
1111      # This get_config doesn't actually do anything because our mask is
1112      # static and doesn't need any external information to work. We do need a
1113      # dummy get_config method to prevent the warning from appearing, however.
1114      def get_config(self, *args, **kwargs):
1115        return {}
1116
1117    model = keras.Sequential([MyCorrectMasking(), MyLayer()])
1118    model.predict(samples)
1119    with warnings.catch_warnings(record=True) as w:
1120      model.save(self._save_model_dir(), testing_utils.get_save_format())
1121    self.assertNotIn(generic_utils.CustomMaskWarning,
1122                     {warning.category for warning in w})
1123
1124
1125# Factory functions to create models that will be serialized inside a Network.
def _make_graph_network(input_size, output_size):
  """Builds a small functional (graph network) model: Dense(8) -> Dense."""
  features = keras.Input(input_size)
  hidden = keras.layers.Dense(8, activation='relu')(features)
  logits = keras.layers.Dense(output_size)(hidden)
  return keras.Model(inputs=features, outputs=logits)
1131
1132
def _make_sequential(input_size, output_size):
  """Builds an unbuilt two-layer Sequential model."""
  del input_size  # Unused: no input shape is declared here.
  stack = [
      keras.layers.Dense(8, activation='relu'),
      keras.layers.Dense(output_size),
  ]
  return keras.Sequential(stack)
1139
1140
def _make_sequential_built(input_size, output_size):
  """Like `_make_sequential`, but built for inputs of `input_size` features."""
  seq = _make_sequential(input_size, output_size)
  seq.build((None, input_size))
  return seq
1145
1146
def _make_sequential_graph_network(input_size, output_size):
  """Builds a Sequential model that starts with an explicit `InputLayer`."""
  stack = [
      keras.layers.InputLayer(input_shape=input_size),
      keras.layers.Dense(8, activation='relu'),
      keras.layers.Dense(output_size),
  ]
  return keras.Sequential(stack)
1153
1154
def _make_sequential_input_shape(input_size, output_size):
  """Builds a Sequential model whose first layer declares `input_shape`."""
  first = keras.layers.Dense(8, activation='relu', input_shape=(input_size,))
  last = keras.layers.Dense(output_size)
  return keras.Sequential([first, last])
1160
1161
class _make_subclassed(keras.Model):  # pylint: disable=invalid-name
  """Subclassed model with a hidden Dense(8) layer and a logits Dense layer."""

  def __init__(self, input_size, output_size):
    super(_make_subclassed, self).__init__()
    # Constructor arguments, kept so `get_config`/`from_config` round-trip.
    self._config = {'input_size': input_size, 'output_size': output_size}
    self._hidden_layer = keras.layers.Dense(8, activation='relu', name='hidden')
    self._logits_layer = keras.layers.Dense(output_size, name='logits')

  def call(self, inputs):
    x = self._hidden_layer(inputs)
    return self._logits_layer(x)

  def get_config(self):
    # Return a plain, independent copy so callers that mutate the config
    # cannot alter the arguments stored on this model. (Dict attributes on
    # Keras models may also be wrapped for tracking; `dict(...)` unwraps.)
    return dict(self._config)

  @classmethod
  def from_config(cls, config):
    return cls(**config)
1180
1181
class _make_subclassed_built(_make_subclassed):  # pylint: disable=invalid-name
  """Variant of `_make_subclassed` that is built in its constructor."""

  def __init__(self, input_size, output_size):
    super(_make_subclassed_built, self).__init__(
        input_size=input_size, output_size=output_size)
    self.build(input_shape=(None, input_size))
1187
1188
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class TestWholeModelSavingWithNesting(test.TestCase, parameterized.TestCase):
  """Tests saving a whole model that contains other models."""

  @parameterized.named_parameters([
      ('graph_network', _make_graph_network),
      ('sequential', _make_sequential),
      ('sequential_built', _make_sequential_built),
      ('sequential_graph_network', _make_sequential_graph_network),
      ('sequential_input_shape', _make_sequential_input_shape),
      ('subclassed', _make_subclassed),
      ('subclassed_built', _make_subclassed_built),
  ])
  def test_functional(self, model_fn):
    """Tests serializing a model that uses a nested model to share weights."""
    if h5py is None:
      self.skipTest('h5py required to run this test')

    def _make_model():
      examples = keras.Input(shape=(4,), name='examples')
      neighbors = keras.Input(shape=(4,), name='neighbors')
      shared = model_fn(examples.shape.as_list()[-1], 2)
      # Both inputs run through the same nested model, sharing its weights.
      merged = keras.layers.add([shared(examples), shared(neighbors)])
      return keras.Model(inputs=(examples, neighbors), outputs=merged)

    custom_objects = {
        '_make_subclassed': _make_subclassed,
        '_make_subclassed_built': _make_subclassed_built,
    }
    with self.cached_session():
      x = tuple(
          np.random.normal(size=(16, 4)).astype(np.float32) for _ in range(2))
      model = _make_model()
      expected = model(x)
      # Save to HDF5 and reload without compiling.
      path = os.path.join(self.get_temp_dir(), 'model.h5')
      model.save(path)
      del model  # Only the reloaded model is used below.
      restored = keras.models.load_model(
          path, custom_objects=custom_objects, compile=False)
      self.assertAllClose(restored(x), expected, rtol=1e-9)
1231
1232
if __name__ == '__main__':
  # Entry point when this test file is executed directly.
  test.main()
1235