• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Integration tests for Keras applications."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from absl import flags
22from absl.testing import parameterized
23import numpy as np
24
25from tensorflow.python.keras.applications import densenet
26from tensorflow.python.keras.applications import efficientnet
27from tensorflow.python.keras.applications import inception_resnet_v2
28from tensorflow.python.keras.applications import inception_v3
29from tensorflow.python.keras.applications import mobilenet
30from tensorflow.python.keras.applications import mobilenet_v2
31from tensorflow.python.keras.applications import nasnet
32from tensorflow.python.keras.applications import resnet
33from tensorflow.python.keras.applications import resnet_v2
34from tensorflow.python.keras.applications import vgg16
35from tensorflow.python.keras.applications import vgg19
36from tensorflow.python.keras.applications import xception
37from tensorflow.python.keras.preprocessing import image
38from tensorflow.python.keras.utils import data_utils
39from tensorflow.python.platform import test
40
41
# Lookup table keyed by application name. Each value is a pair of
# (application module, list of model constructors taken from that module).
ARG_TO_MODEL = {
    'resnet': (
        resnet,
        [resnet.ResNet50, resnet.ResNet101, resnet.ResNet152],
    ),
    'resnet_v2': (
        resnet_v2,
        [resnet_v2.ResNet50V2, resnet_v2.ResNet101V2, resnet_v2.ResNet152V2],
    ),
    'vgg16': (vgg16, [vgg16.VGG16]),
    'vgg19': (vgg19, [vgg19.VGG19]),
    'xception': (xception, [xception.Xception]),
    'inception_v3': (inception_v3, [inception_v3.InceptionV3]),
    'inception_resnet_v2': (
        inception_resnet_v2,
        [inception_resnet_v2.InceptionResNetV2],
    ),
    'mobilenet': (mobilenet, [mobilenet.MobileNet]),
    'mobilenet_v2': (mobilenet_v2, [mobilenet_v2.MobileNetV2]),
    'densenet': (
        densenet,
        [densenet.DenseNet121, densenet.DenseNet169, densenet.DenseNet201],
    ),
    'nasnet': (nasnet, [nasnet.NASNetMobile, nasnet.NASNetLarge]),
    'efficientnet': (
        efficientnet,
        [
            efficientnet.EfficientNetB0,
            efficientnet.EfficientNetB1,
            efficientnet.EfficientNetB2,
            efficientnet.EfficientNetB3,
            efficientnet.EfficientNetB4,
            efficientnet.EfficientNetB5,
            efficientnet.EfficientNetB6,
            efficientnet.EfficientNetB7,
        ],
    ),
}
63
# Location of the sample image fed to every model under test.
TEST_IMAGE_PATH = (
    'https://storage.googleapis.com/tensorflow/'
    'keras-applications/tests/elephant.jpg'
)
# Size of the final prediction axis expected from an ImageNet classifier.
_IMAGENET_CLASSES = 1000

# The `module` flag selects which application module this run exercises.
# It is passed as an 'arg' of the build target so that the models of a
# module are only tested when that module file has been modified.
FLAGS = flags.FLAGS
flags.DEFINE_string(
    'module', None, 'Application module used in this test.')
75
76
def _get_elephant(target_size):
  """Loads the elephant test image as a single-element batch.

  Args:
    target_size: Sequence whose entries give the size the image is loaded at.
      When its first entry is None, (299, 299) is used instead: models
      without a Flatten step accept variable-size inputs even with ImageNet
      weights, so they report no fixed size.

  Returns:
    The image converted to an array, with a new leading axis inserted.
  """
  # A missing first dimension means the model has no fixed input size.
  size = (299, 299) if target_size[0] is None else tuple(target_size)
  local_path = data_utils.get_file('elephant.jpg', TEST_IMAGE_PATH)
  pixels = image.img_to_array(image.load_img(local_path, target_size=size))
  return np.expand_dims(pixels, axis=0)
88
89
class ApplicationsLoadWeightTest(test.TestCase, parameterized.TestCase):
  """Loads ImageNet weights for every model of one application module.

  The module under test is chosen with the `--module` flag, whose value must
  be one of the keys of `ARG_TO_MODEL`.
  """

  def assertShapeEqual(self, shape1, shape2):
    """Asserts that two shapes have the same length and compare equal.

    Args:
      shape1: First shape; must support `len()`.
      shape2: Second shape; compared against `shape1` with `!=`.

    Raises:
      AssertionError: If the lengths differ, or if the shapes are unequal.
    """
    if len(shape1) != len(shape2):
      raise AssertionError(
          'Shapes are different rank: %s vs %s' % (shape1, shape2))
    if shape1 != shape2:
      raise AssertionError('Shapes differ: %s vs %s' % (shape1, shape2))

  def test_application_pretrained_weights_loading(self):
    """Checks output shape and top-3 prediction for each selected model."""
    module_name = FLAGS.module
    # The flag defaults to None; fail with an actionable message instead of
    # letting the dictionary lookup below raise a bare KeyError.
    if module_name not in ARG_TO_MODEL:
      self.fail('--module must be one of %s, got %r'
                % (sorted(ARG_TO_MODEL), module_name))
    app_module, apps = ARG_TO_MODEL[module_name]
    for app in apps:
      app_name = getattr(app, '__name__', app)
      model = app(weights='imagenet')
      self.assertShapeEqual(model.output_shape, (None, _IMAGENET_CLASSES))
      # Entries 1 and 2 of the input shape are passed as the image size.
      x = _get_elephant(model.input_shape[1:3])
      x = app_module.preprocess_input(x)
      preds = model.predict(x)
      names = [p[1] for p in app_module.decode_predictions(preds)[0]]
      # Test correct label is in top 3 (weak correctness test). The message
      # names the model so a failure inside the loop can be attributed.
      self.assertIn(
          'African_elephant', names[:3],
          '%s: African_elephant not in top-3 predictions %s'
          % (app_name, names[:3]))
111
112
# Run the tests in this file when it is executed directly as a script.
if __name__ == '__main__':
  test.main()
115