• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Integration tests for Keras applications."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from absl.testing import parameterized
22
23from tensorflow.python.keras import backend
24from tensorflow.python.keras.applications import densenet
25from tensorflow.python.keras.applications import efficientnet
26from tensorflow.python.keras.applications import inception_resnet_v2
27from tensorflow.python.keras.applications import inception_v3
28from tensorflow.python.keras.applications import mobilenet
29from tensorflow.python.keras.applications import mobilenet_v2
30from tensorflow.python.keras.applications import nasnet
31from tensorflow.python.keras.applications import resnet
32from tensorflow.python.keras.applications import resnet_v2
33from tensorflow.python.keras.applications import vgg16
34from tensorflow.python.keras.applications import vgg19
35from tensorflow.python.keras.applications import xception
36from tensorflow.python.platform import test
37
38
# (application constructor, channel count of its final feature map) pairs for
# every application under test except NASNet, grouped by model family. The
# order of entries is significant only for test-case naming.
MODEL_LIST_NO_NASNET = (
    # ResNet (v1).
    [
        (resnet.ResNet50, 2048),
        (resnet.ResNet101, 2048),
        (resnet.ResNet152, 2048),
    ] +
    # ResNet (v2).
    [
        (resnet_v2.ResNet50V2, 2048),
        (resnet_v2.ResNet101V2, 2048),
        (resnet_v2.ResNet152V2, 2048),
    ] +
    # VGG.
    [
        (vgg16.VGG16, 512),
        (vgg19.VGG19, 512),
    ] +
    # Xception and the Inception family.
    [
        (xception.Xception, 2048),
        (inception_v3.InceptionV3, 2048),
        (inception_resnet_v2.InceptionResNetV2, 1536),
    ] +
    # MobileNet.
    [
        (mobilenet.MobileNet, 1024),
        (mobilenet_v2.MobileNetV2, 1280),
    ] +
    # DenseNet.
    [
        (densenet.DenseNet121, 1024),
        (densenet.DenseNet169, 1664),
        (densenet.DenseNet201, 1920),
    ] +
    # EfficientNet B0 through B7.
    [
        (efficientnet.EfficientNetB0, 1280),
        (efficientnet.EfficientNetB1, 1280),
        (efficientnet.EfficientNetB2, 1408),
        (efficientnet.EfficientNetB3, 1536),
        (efficientnet.EfficientNetB4, 1792),
        (efficientnet.EfficientNetB5, 2048),
        (efficientnet.EfficientNetB6, 2304),
        (efficientnet.EfficientNetB7, 2560),
    ])
65
# NASNet applications, kept in their own list so that individual tests can
# include or exclude them. Same (constructor, last_dim) pair format as above.
NASNET_LIST = [
    (nasnet.NASNetMobile, 1056),
    (nasnet.NASNetLarge, 4032),
]

# Every application under test: the non-NASNet models followed by NASNet.
MODEL_LIST = list(MODEL_LIST_NO_NASNET)
MODEL_LIST.extend(NASNET_LIST)
72
73
class ApplicationsTest(test.TestCase, parameterized.TestCase):
  """Builds each Keras application and checks its serialization and shapes.

  All models are built with `weights=None`. Every test registers
  `backend.clear_session` as a cleanup, so global Keras state is reset after
  each parameterized case even when an assertion fails.
  """

  def assertShapeEqual(self, shape1, shape2):
    """Asserts that two shapes have the same rank and identical dimensions.

    Dimensions are compared with `!=`, so an unknown (`None`) dimension only
    matches another `None`.

    Args:
      shape1: The actual shape, e.g. a model's `output_shape`.
      shape2: The expected shape.

    Raises:
      AssertionError: If the ranks differ or any dimension differs.
    """
    if len(shape1) != len(shape2):
      raise AssertionError(
          'Shapes are different rank: %s vs %s' % (shape1, shape2))
    if any(v1 != v2 for v1, v2 in zip(shape1, shape2)):
      raise AssertionError('Shapes differ: %s vs %s' % (shape1, shape2))

  @staticmethod
  def _feature_map_shape(last_dim):
    """Returns the expected 4-D output shape of a headless application.

    Batch and spatial dimensions are expected to be unknown (`None`). The
    channel axis (holding `last_dim`) is placed according to
    `backend.image_data_format()`, so the expectation also holds when the
    tests run with `channels_first`.

    Args:
      last_dim: Expected number of channels in the final feature map.

    Returns:
      A 4-tuple with `last_dim` on the channel axis and `None` elsewhere.
    """
    if backend.image_data_format() == 'channels_first':
      return (None, last_dim, None, None)
    return (None, None, None, last_dim)

  @parameterized.parameters(*MODEL_LIST)
  def test_application_base(self, app, _):
    """Default construction and a get_config/from_config round trip work."""
    self.addCleanup(backend.clear_session)
    # Can be instantiated with default arguments.
    model = app(weights=None)
    # Can be serialized and deserialized.
    config = model.get_config()
    reconstructed_model = model.__class__.from_config(config)
    self.assertEqual(len(model.weights), len(reconstructed_model.weights))

  @parameterized.parameters(*MODEL_LIST)
  def test_application_notop(self, app, last_dim):
    """Headless models end in a feature map with `last_dim` channels."""
    self.addCleanup(backend.clear_session)
    output_shape = _get_output_shape(
        lambda: app(weights=None, include_top=False))
    if 'NASNet' in app.__name__:
      # NASNet's full output shape is not pinned here; only its trailing
      # dimension is checked.
      self.assertEqual(output_shape[-1], last_dim)
    else:
      self.assertShapeEqual(output_shape, self._feature_map_shape(last_dim))

  @parameterized.parameters(*MODEL_LIST)
  def test_application_pooling(self, app, last_dim):
    """`pooling='avg'` yields a rank-2 `(batch, last_dim)` output."""
    # Previously this test never cleared the session, leaking Keras state
    # into later cases.
    self.addCleanup(backend.clear_session)
    output_shape = _get_output_shape(
        lambda: app(weights=None, include_top=False, pooling='avg'))
    self.assertShapeEqual(output_shape, (None, last_dim))

  @parameterized.parameters(*MODEL_LIST_NO_NASNET)
  def test_application_variable_input_channels(self, app, last_dim):
    """Headless models accept inputs with 1 or 4 input channels."""
    self.addCleanup(backend.clear_session)
    expected_shape = self._feature_map_shape(last_dim)
    for channels in (1, 4):
      if backend.image_data_format() == 'channels_first':
        input_shape = (channels, None, None)
      else:
        input_shape = (None, None, channels)
      # Bind `input_shape` as a default to avoid late-binding the loop value.
      output_shape = _get_output_shape(
          lambda shape=input_shape: app(
              weights=None, include_top=False, input_shape=shape))
      self.assertShapeEqual(output_shape, expected_shape)
      # Reset global Keras state before building the next model.
      backend.clear_session()
133
134
def _get_output_shape(model_fn):
  """Builds a model via `model_fn` and returns its `output_shape`.

  Args:
    model_fn: Zero-argument callable that constructs and returns a model.

  Returns:
    The `output_shape` attribute of the freshly constructed model.
  """
  return model_fn().output_shape
138
139
if __name__ == '__main__':
  # Hand control to the TensorFlow test runner when executed as a script.
  test.main()
142