# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Integration tests for Keras applications."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized

from tensorflow.python.keras import backend
from tensorflow.python.keras.applications import densenet
from tensorflow.python.keras.applications import efficientnet
from tensorflow.python.keras.applications import inception_resnet_v2
from tensorflow.python.keras.applications import inception_v3
from tensorflow.python.keras.applications import mobilenet
from tensorflow.python.keras.applications import mobilenet_v2
from tensorflow.python.keras.applications import nasnet
from tensorflow.python.keras.applications import resnet
from tensorflow.python.keras.applications import resnet_v2
from tensorflow.python.keras.applications import vgg16
from tensorflow.python.keras.applications import vgg19
from tensorflow.python.keras.applications import xception
from tensorflow.python.platform import test


# (application constructor, expected channel count of the final feature map)
# pairs for every application except NASNet.
MODEL_LIST_NO_NASNET = [
    (resnet.ResNet50, 2048),
    (resnet.ResNet101, 2048),
    (resnet.ResNet152, 2048),
    (resnet_v2.ResNet50V2, 2048),
    (resnet_v2.ResNet101V2, 2048),
    (resnet_v2.ResNet152V2, 2048),
    (vgg16.VGG16, 512),
    (vgg19.VGG19, 512),
    (xception.Xception, 2048),
    (inception_v3.InceptionV3, 2048),
    (inception_resnet_v2.InceptionResNetV2, 1536),
    (mobilenet.MobileNet, 1024),
    (mobilenet_v2.MobileNetV2, 1280),
    (densenet.DenseNet121, 1024),
    (densenet.DenseNet169, 1664),
    (densenet.DenseNet201, 1920),
    (efficientnet.EfficientNetB0, 1280),
    (efficientnet.EfficientNetB1, 1280),
    (efficientnet.EfficientNetB2, 1408),
    (efficientnet.EfficientNetB3, 1536),
    (efficientnet.EfficientNetB4, 1792),
    (efficientnet.EfficientNetB5, 2048),
    (efficientnet.EfficientNetB6, 2304),
    (efficientnet.EfficientNetB7, 2560),
]

# NASNet models are listed separately: they are excluded from the
# variable-input-channels test, and the no-top test checks only their channel
# count rather than the full output shape.
NASNET_LIST = [
    (nasnet.NASNetMobile, 1056),
    (nasnet.NASNetLarge, 4032),
]

MODEL_LIST = MODEL_LIST_NO_NASNET + NASNET_LIST


class ApplicationsTest(test.TestCase, parameterized.TestCase):
  """Builds each Keras application and checks config round-trip and shapes."""

  def tearDown(self):
    # Reset Keras global state after every test -- including ones whose
    # assertions failed -- so one model's state cannot leak into the next
    # parameterized case.
    backend.clear_session()
    super(ApplicationsTest, self).tearDown()

  def assertShapeEqual(self, shape1, shape2):
    """Asserts that two shapes have equal rank and equal entries.

    `None` entries must match `None` exactly (unknown dims are compared
    literally, not treated as wildcards).

    NOTE(review): this name may shadow a helper with a different signature on
    `test.TestCase`; within this class the element-wise version below is used.

    Args:
      shape1: Actual shape (a sequence of ints / `None`).
      shape2: Expected shape (a sequence of ints / `None`).

    Raises:
      AssertionError: If the ranks differ or any entry differs.
    """
    if len(shape1) != len(shape2):
      raise AssertionError(
          'Shapes are different rank: %s vs %s' % (shape1, shape2))
    for v1, v2 in zip(shape1, shape2):
      if v1 != v2:
        raise AssertionError('Shapes differ: %s vs %s' % (shape1, shape2))

  @parameterized.parameters(*MODEL_LIST)
  def test_application_base(self, app, _):
    """Builds with default args and checks a get_config/from_config round trip."""
    # Can be instantiated with default arguments.
    model = app(weights=None)
    # Can be serialized and deserialized.
    config = model.get_config()
    reconstructed_model = model.__class__.from_config(config)
    self.assertEqual(len(model.weights), len(reconstructed_model.weights))

  @parameterized.parameters(*MODEL_LIST)
  def test_application_notop(self, app, last_dim):
    """Checks the feature-map shape produced with `include_top=False`."""
    output_shape = _get_output_shape(
        lambda: app(weights=None, include_top=False))
    if 'NASNet' in app.__name__:
      # For NASNet only the channel count is asserted; the spatial dims of
      # its output are not checked here.
      self.assertEqual(output_shape[_channel_axis()], last_dim)
    else:
      self.assertShapeEqual(output_shape, _feature_map_shape(last_dim))

  @parameterized.parameters(*MODEL_LIST)
  def test_application_pooling(self, app, last_dim):
    """Checks that `pooling='avg'` yields a (batch, channels) output."""
    output_shape = _get_output_shape(
        lambda: app(weights=None, include_top=False, pooling='avg'))
    self.assertShapeEqual(output_shape, (None, last_dim))

  @parameterized.parameters(*MODEL_LIST_NO_NASNET)
  def test_application_variable_input_channels(self, app, last_dim):
    """Checks that non-RGB inputs (1 and 4 channels) are accepted."""
    for input_channels in (1, 4):
      input_shape = _input_shape(input_channels)
      # Bind `input_shape` as a default so the lambda does not close over the
      # loop variable.
      output_shape = _get_output_shape(
          lambda shape=input_shape: app(
              weights=None, include_top=False, input_shape=shape))
      self.assertShapeEqual(output_shape, _feature_map_shape(last_dim))
      # Release the first model before building the second one.
      backend.clear_session()


def _channel_axis():
  """Returns the channel-axis index of a 4D tensor for the current format."""
  return 1 if backend.image_data_format() == 'channels_first' else -1


def _input_shape(channels):
  """Returns a 3D input shape with `channels` channels and unknown H and W."""
  if backend.image_data_format() == 'channels_first':
    return (channels, None, None)
  return (None, None, channels)


def _feature_map_shape(channels):
  """Returns the expected 4D unpooled output shape with unknown batch/H/W."""
  if backend.image_data_format() == 'channels_first':
    return (None, channels, None, None)
  return (None, None, None, channels)


def _get_output_shape(model_fn):
  """Calls `model_fn` to build a model and returns its `output_shape`."""
  model = model_fn()
  return model.output_shape


if __name__ == '__main__':
  test.main()