# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Integration tests for Keras applications."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import flags
from absl.testing import parameterized
import numpy as np

from tensorflow.python.keras.applications import densenet
from tensorflow.python.keras.applications import efficientnet
from tensorflow.python.keras.applications import inception_resnet_v2
from tensorflow.python.keras.applications import inception_v3
from tensorflow.python.keras.applications import mobilenet
from tensorflow.python.keras.applications import mobilenet_v2
from tensorflow.python.keras.applications import nasnet
from tensorflow.python.keras.applications import resnet
from tensorflow.python.keras.applications import resnet_v2
from tensorflow.python.keras.applications import vgg16
from tensorflow.python.keras.applications import vgg19
from tensorflow.python.keras.applications import xception
from tensorflow.python.keras.preprocessing import image
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.platform import test


# Maps each accepted `--module` flag value to a pair of
# (application module, list of model constructors defined in that module).
# The module supplies `preprocess_input` / `decode_predictions`; every
# constructor in the list is built and checked by the test below.
ARG_TO_MODEL = {
    'resnet': (resnet, [resnet.ResNet50, resnet.ResNet101, resnet.ResNet152]),
    'resnet_v2': (resnet_v2, [resnet_v2.ResNet50V2, resnet_v2.ResNet101V2,
                              resnet_v2.ResNet152V2]),
    'vgg16': (vgg16, [vgg16.VGG16]),
    'vgg19': (vgg19, [vgg19.VGG19]),
    'xception': (xception, [xception.Xception]),
    'inception_v3': (inception_v3, [inception_v3.InceptionV3]),
    'inception_resnet_v2': (inception_resnet_v2,
                            [inception_resnet_v2.InceptionResNetV2]),
    'mobilenet': (mobilenet, [mobilenet.MobileNet]),
    'mobilenet_v2': (mobilenet_v2, [mobilenet_v2.MobileNetV2]),
    'densenet': (densenet, [densenet.DenseNet121,
                            densenet.DenseNet169, densenet.DenseNet201]),
    'nasnet': (nasnet, [nasnet.NASNetMobile, nasnet.NASNetLarge]),
    'efficientnet': (efficientnet,
                     [efficientnet.EfficientNetB0, efficientnet.EfficientNetB1,
                      efficientnet.EfficientNetB2, efficientnet.EfficientNetB3,
                      efficientnet.EfficientNetB4, efficientnet.EfficientNetB5,
                      efficientnet.EfficientNetB6, efficientnet.EfficientNetB7])
}

TEST_IMAGE_PATH = ('https://storage.googleapis.com/tensorflow/'
                   'keras-applications/tests/elephant.jpg')
_IMAGENET_CLASSES = 1000

# Add a flag to define which application module file is tested.
# This is set as an 'arg' in the build target to guarantee that
# it only triggers the tests of the application models in the module
# if that module file has been modified.
FLAGS = flags.FLAGS
flags.DEFINE_string('module', None,
                    'Application module used in this test.')


def _get_elephant(target_size):
  """Downloads the elephant test image and returns it as a one-image batch.

  Args:
    target_size: A 2-element sequence giving the size to resize the image to,
      as passed to `image.load_img` (the test passes
      `model.input_shape[1:3]`). If either element is `None`, the model
      accepts variable-size inputs and a 299x299 image is used instead.

  Returns:
    The image as a numpy array with a new leading axis of size 1 prepended
    by `np.expand_dims(..., axis=0)`.
  """
  # For models that don't include a Flatten step,
  # the default is to accept variable-size inputs
  # even when loading ImageNet weights (since it is possible).
  # In this case, default to 299x299.
  target_size = tuple(target_size)
  # Check every dimension, not just the first: a partially unknown size
  # would otherwise be handed to `load_img` unchanged.
  if None in target_size:
    target_size = (299, 299)
  test_image = data_utils.get_file('elephant.jpg', TEST_IMAGE_PATH)
  img = image.load_img(test_image, target_size=target_size)
  x = image.img_to_array(img)
  return np.expand_dims(x, axis=0)


class ApplicationsLoadWeightTest(test.TestCase, parameterized.TestCase):
  """Loads ImageNet weights for every model in the module named by --module."""

  def assertShapeEqual(self, shape1, shape2):
    """Raises `AssertionError` unless `shape1 == shape2`.

    A length (rank) mismatch is reported with its own message so a failure
    says whether the rank or an individual dimension differs.
    """
    if len(shape1) != len(shape2):
      raise AssertionError(
          'Shapes are different rank: %s vs %s' % (shape1, shape2))
    if shape1 != shape2:
      raise AssertionError('Shapes differ: %s vs %s' % (shape1, shape2))

  def test_application_pretrained_weights_loading(self):
    """Builds each model with ImageNet weights and classifies the test image.

    Raises:
      ValueError: If `--module` is unset or is not a key of `ARG_TO_MODEL`.
    """
    # Fail with an actionable message instead of a bare `KeyError: None`
    # when the build target forgets (or misspells) the `--module` arg.
    if FLAGS.module not in ARG_TO_MODEL:
      raise ValueError(
          'Unknown or missing --module %r; expected one of: %s' %
          (FLAGS.module, ', '.join(sorted(ARG_TO_MODEL))))
    app_module, apps = ARG_TO_MODEL[FLAGS.module]
    for app in apps:
      model = app(weights='imagenet')
      self.assertShapeEqual(model.output_shape, (None, _IMAGENET_CLASSES))
      # NOTE(review): slicing [1:3] as the spatial size assumes a
      # channels_last input_shape of (batch, height, width, channels) --
      # confirm if this test is ever run with channels_first.
      x = _get_elephant(model.input_shape[1:3])
      x = app_module.preprocess_input(x)
      preds = model.predict(x)
      # Keep only the label string (second field) of each decoded top
      # prediction for the single image in the batch.
      names = [p[1] for p in app_module.decode_predictions(preds)[0]]
      # Test correct label is in top 3 (weak correctness test).
      self.assertIn('African_elephant', names[:3])


if __name__ == '__main__':
  test.main()