• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests specific to deferred-build `Sequential` models."""
16
17import os
18import unittest
19import numpy as np
20
21from tensorflow.python import keras
22from tensorflow.python.compat import v2_compat
23from tensorflow.python.keras import keras_parameterized
24from tensorflow.python.keras import testing_utils
25from tensorflow.python.ops import math_ops
26from tensorflow.python.platform import test
27
28try:
29  import h5py  # pylint:disable=g-import-not-at-top
30except ImportError:
31  h5py = None
32
33
class TestDeferredSequential(keras_parameterized.TestCase):
  """Tests for `Sequential` models whose graph is built lazily (deferred)."""

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_build_behavior(self):
    """Deferred models become graph networks after __call__/build/fit."""
    # Test graph network creation after __call__
    model = get_model()
    model(np.random.random((2, 6)))
    self.assertLen(model.weights, 4)
    self.assertTrue(model._is_graph_network)
    self.assertLen(model.inputs, 1)
    self.assertLen(model.outputs, 1)
    self.assertEqual(model.inputs[0].shape.as_list(), [2, 6])
    self.assertEqual(model.outputs[0].shape.as_list(), [2, 2])

    # Test effect of new __call__ with a different shape: the batch dimension
    # is relaxed to None so both batch sizes are accepted.
    model(np.random.random((3, 6)))
    self.assertLen(model.inputs, 1)
    self.assertLen(model.outputs, 1)
    self.assertEqual(model.inputs[0].shape.as_list(), [None, 6])
    self.assertEqual(model.outputs[0].shape.as_list(), [None, 2])
    model(np.random.random((4, 6)))
    self.assertLen(model.inputs, 1)
    self.assertLen(model.outputs, 1)
    self.assertEqual(model.inputs[0].shape.as_list(), [None, 6])
    self.assertEqual(model.outputs[0].shape.as_list(), [None, 2])

    # Test graph network creation after build
    model = get_model()
    model.build((None, 6))
    self.assertLen(model.weights, 4)
    self.assertTrue(model._is_graph_network)
    self.assertLen(model.inputs, 1)
    self.assertLen(model.outputs, 1)
    self.assertEqual(model.inputs[0].shape.as_list(), [None, 6])
    self.assertEqual(model.outputs[0].shape.as_list(), [None, 2])

    # Test graph network creation after compile/fit
    model = get_model()
    model.compile(
        loss='mse',
        optimizer='rmsprop',
        metrics=[keras.metrics.CategoricalAccuracy()],
        run_eagerly=testing_utils.should_run_eagerly())
    model.fit(np.zeros((2, 6)), np.zeros((2, 2)))
    self.assertLen(model.weights, 4)
    self.assertTrue(model._is_graph_network)
    self.assertLen(model.inputs, 1)
    self.assertLen(model.outputs, 1)
    # Inconsistency here: with eager `fit`, the model is built with shape
    # (2, 6), but with graph function `fit`, it is built with shape `(None, 6)`.
    # This is likely due to our assumption "the batch size should be dynamic"
    # at the level of `Model`. TODO(fchollet): investigate and resolve.
    # Only the feature dimensions are asserted, since the batch dimension
    # differs between the two modes.
    self.assertEqual(model.inputs[0].shape.as_list()[-1], 6)
    self.assertEqual(model.outputs[0].shape.as_list()[-1], 2)

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_add_and_pop(self):
    """`pop`/`add` on a built model keep it built as a graph network."""
    model = get_model()
    model.build((None, 6))
    self.assertTrue(model.built)
    self.assertTrue(model._is_graph_network)
    self.assertLen(model.layers, 3)
    self.assertLen(model.weights, 4)
    model.pop()
    self.assertTrue(model.built)
    self.assertTrue(model._is_graph_network)
    self.assertLen(model.layers, 2)
    self.assertLen(model.weights, 2)
    model.add(keras.layers.Dense(2))
    self.assertTrue(model.built)
    self.assertTrue(model._is_graph_network)
    self.assertLen(model.layers, 3)
    self.assertLen(model.weights, 4)

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_feature_extraction(self):
    """Layer connectivity stays usable for feature extraction after a rebuild."""
    # This tests layer connectivity reset when rebuilding
    model = get_model()
    model(np.random.random((3, 6)))  # First build
    model(np.random.random((4, 6)))  # Triggers a rebuild
    # Classic feature extractor pattern
    extractor = keras.Model(inputs=model.inputs,
                            outputs=[layer.output for layer in model.layers])
    # Check that inputs and outputs are connected
    _ = extractor(np.random.random((4, 6)))

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_saving_savedmodel(self):
    """A deferred model round-trips through the SavedModel format."""
    model = get_model()
    model(np.random.random((3, 6)))  # Build model

    path = os.path.join(self.get_temp_dir(), 'model_path')
    model.save(path)
    new_model = keras.models.load_model(path)
    model_layers = model._flatten_layers(include_self=True, recursive=False)
    new_model_layers = new_model._flatten_layers(
        include_self=True, recursive=False)
    for layer1, layer2 in zip(model_layers, new_model_layers):
      self.assertEqual(layer1.name, layer2.name)
      for w1, w2 in zip(layer1.weights, layer2.weights):
        self.assertAllClose(w1, w2)

  @unittest.skipIf(h5py is None, 'Test requires h5py')
  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_saving_h5(self):
    """A deferred model round-trips through the HDF5 format."""
    model = get_model()
    model(np.random.random((3, 6)))  # Build model

    # NOTE: the original computed `path` twice; the duplicate was removed.
    path = os.path.join(self.get_temp_dir(), 'model_path.h5')
    model.save(path)
    new_model = keras.models.load_model(path)
    model_layers = model._flatten_layers(include_self=True, recursive=False)
    new_model_layers = new_model._flatten_layers(
        include_self=True, recursive=False)
    for layer1, layer2 in zip(model_layers, new_model_layers):
      self.assertEqual(layer1.name, layer2.name)
      for w1, w2 in zip(layer1.weights, layer2.weights):
        self.assertAllClose(w1, w2)

  @keras_parameterized.run_all_keras_modes
  def test_shared_layer(self):
    """Sharing one layer across deferred models preserves its connectivity."""
    # This tests that preexisting layer connectivity is preserved
    # when auto-building graph networks
    shared_layer = keras.layers.Dense(2)
    m1 = keras.Sequential([shared_layer])
    m1(np.random.random((3, 6)))
    m2 = keras.Sequential([shared_layer])
    m2(np.random.random((3, 6)))
    # Nesting case
    shared_layer = keras.layers.Dense(2)
    m1 = keras.Sequential([shared_layer])
    m2 = keras.Sequential([shared_layer, m1])
    m2(np.random.random((3, 2)))

  @keras_parameterized.run_all_keras_modes
  def test_loss_layer(self):
    """Losses added via `add_loss` survive deferred builds and rebuilds."""

    class LossLayer(keras.layers.Layer):
      """Identity layer that adds the sum of its inputs as a loss."""

      def call(self, inputs):
        self.add_loss(math_ops.reduce_sum(inputs))
        return inputs

    # Test loss layer alone
    model = keras.Sequential([LossLayer()])
    model.compile('rmsprop', run_eagerly=testing_utils.should_run_eagerly())
    loss = model.train_on_batch(np.ones((2, 2)))
    self.assertAllClose(loss, 4.)
    model(np.random.random((4, 2)))  # Triggers a rebuild
    loss = model.train_on_batch(np.ones((1, 2)))
    self.assertAllClose(loss, 2.)

    # Test loss layer combined with another layer
    model = keras.Sequential([
        keras.layers.Dense(1, kernel_initializer='ones'),
        LossLayer()])
    model.compile('rmsprop', run_eagerly=testing_utils.should_run_eagerly())
    loss = model.train_on_batch(np.ones((2, 2)))
    self.assertAllClose(loss, 4.)
    model(np.random.random((4, 2)))  # Triggers a rebuild
    # After one optimizer step the Dense weights moved off 'ones', so the
    # loss must have decreased below the initial value of 2.
    loss = model.train_on_batch(np.ones((1, 2)))
    self.assertLess(loss, 2.)

    # Test loss layer combined with external loss. This section is a smoke
    # test: it only checks that training and rebuilding run without error.
    model = keras.Sequential([
        keras.layers.Dense(1, kernel_initializer='ones'),
        LossLayer()])
    model.compile('rmsprop', 'mse',
                  run_eagerly=testing_utils.should_run_eagerly())
    loss = model.train_on_batch(np.ones((2, 2)), np.ones((2, 2)))
    model(np.random.random((4, 2)))  # Triggers a rebuild
    loss = model.train_on_batch(np.ones((1, 2)), np.ones((1, 2)))
206
207
def get_model():
  """Returns an unbuilt 3-layer Sequential test model (Dense-Dropout-Dense)."""
  layers = [
      keras.layers.Dense(2, name='first_layer'),
      keras.layers.Dropout(0.3, name='dp'),
      keras.layers.Dense(2, name='last_layer'),
  ]
  return keras.models.Sequential(layers)
214
215
if __name__ == '__main__':
  # These tests target TF2 semantics; force v2 behavior before running them.
  v2_compat.enable_v2_behavior()
  test.main()
219