1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for imagenet_utils.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20from absl.testing import parameterized 21import numpy as np 22 23from tensorflow.python import keras 24from tensorflow.python.keras import keras_parameterized 25from tensorflow.python.keras.applications import imagenet_utils as utils 26from tensorflow.python.platform import test 27 28 29class TestImageNetUtils(keras_parameterized.TestCase): 30 31 def test_preprocess_input(self): 32 # Test image batch with float and int image input 33 x = np.random.uniform(0, 255, (2, 10, 10, 3)) 34 xint = x.astype('int32') 35 self.assertEqual(utils.preprocess_input(x).shape, x.shape) 36 self.assertEqual(utils.preprocess_input(xint).shape, xint.shape) 37 38 out1 = utils.preprocess_input(x, 'channels_last') 39 out1int = utils.preprocess_input(xint, 'channels_last') 40 out2 = utils.preprocess_input( 41 np.transpose(x, (0, 3, 1, 2)), 'channels_first') 42 out2int = utils.preprocess_input( 43 np.transpose(xint, (0, 3, 1, 2)), 'channels_first') 44 self.assertAllClose(out1, out2.transpose(0, 2, 3, 1)) 45 self.assertAllClose(out1int, out2int.transpose(0, 2, 3, 1)) 46 47 # Test single image 48 x = np.random.uniform(0, 255, (10, 10, 3)) 49 xint = x.astype('int32') 50 self.assertEqual(utils.preprocess_input(x).shape, x.shape) 51 self.assertEqual(utils.preprocess_input(xint).shape, xint.shape) 52 53 out1 = utils.preprocess_input(x, 'channels_last') 54 out1int = utils.preprocess_input(xint, 'channels_last') 55 out2 = utils.preprocess_input(np.transpose(x, (2, 0, 1)), 'channels_first') 56 out2int = utils.preprocess_input( 57 np.transpose(xint, (2, 0, 1)), 'channels_first') 58 self.assertAllClose(out1, out2.transpose(1, 2, 0)) 59 self.assertAllClose(out1int, out2int.transpose(1, 2, 0)) 60 61 # Test that writing over the input data works predictably 62 for mode in ['torch', 'tf']: 63 x = np.random.uniform(0, 255, (2, 10, 10, 3)) 64 xint = x.astype('int') 65 x2 = utils.preprocess_input(x, mode=mode) 66 xint2 = utils.preprocess_input(xint) 67 self.assertAllClose(x, x2) 68 self.assertNotEqual(xint.astype('float').max(), xint2.max()) 69 70 # Caffe mode works differently from the others 71 x = np.random.uniform(0, 255, (2, 10, 10, 3)) 72 xint = x.astype('int') 73 x2 = utils.preprocess_input(x, data_format='channels_last', mode='caffe') 74 xint2 = utils.preprocess_input(xint) 75 self.assertAllClose(x, x2[..., ::-1]) 76 self.assertNotEqual(xint.astype('float').max(), xint2.max()) 77 78 def test_preprocess_input_symbolic(self): 79 # Test image batch 80 x = np.random.uniform(0, 255, (2, 10, 10, 3)) 81 inputs = keras.layers.Input(shape=x.shape[1:]) 82 outputs = keras.layers.Lambda( 83 utils.preprocess_input, output_shape=x.shape[1:])( 84 inputs) 85 model = keras.Model(inputs, outputs) 86 self.assertEqual(model.predict(x).shape, x.shape) 87 88 outputs1 = keras.layers.Lambda( 89 lambda x: utils.preprocess_input(x, 'channels_last'), 90 output_shape=x.shape[1:])( 91 inputs) 92 model1 = keras.Model(inputs, outputs1) 93 out1 = model1.predict(x) 94 x2 = np.transpose(x, (0, 3, 1, 2)) 95 inputs2 = keras.layers.Input(shape=x2.shape[1:]) 96 outputs2 = keras.layers.Lambda( 97 lambda x: utils.preprocess_input(x, 'channels_first'), 98 output_shape=x2.shape[1:])( 99 inputs2) 100 model2 = keras.Model(inputs2, outputs2) 101 out2 = model2.predict(x2) 102 self.assertAllClose(out1, out2.transpose(0, 2, 3, 1)) 103 104 # Test single image 105 x = np.random.uniform(0, 255, (10, 10, 3)) 106 inputs = keras.layers.Input(shape=x.shape) 107 outputs = keras.layers.Lambda( 108 utils.preprocess_input, output_shape=x.shape)( 109 inputs) 110 model = keras.Model(inputs, outputs) 111 self.assertEqual(model.predict(x[np.newaxis])[0].shape, x.shape) 112 113 outputs1 = keras.layers.Lambda( 114 lambda x: utils.preprocess_input(x, 'channels_last'), 115 output_shape=x.shape)( 116 inputs) 117 model1 = keras.Model(inputs, outputs1) 118 out1 = model1.predict(x[np.newaxis])[0] 119 x2 = np.transpose(x, (2, 0, 1)) 120 inputs2 = keras.layers.Input(shape=x2.shape) 121 outputs2 = keras.layers.Lambda( 122 lambda x: utils.preprocess_input(x, 'channels_first'), 123 output_shape=x2.shape)( 124 inputs2) 125 model2 = keras.Model(inputs2, outputs2) 126 out2 = model2.predict(x2[np.newaxis])[0] 127 self.assertAllClose(out1, out2.transpose(1, 2, 0)) 128 129 @parameterized.named_parameters([ 130 {'testcase_name': 'channels_last_format', 131 'data_format': 'channels_last'}, 132 {'testcase_name': 'channels_first_format', 133 'data_format': 'channels_first'}, 134 ]) 135 def test_obtain_input_shape(self, data_format): 136 # input_shape and default_size are not identical. 137 with self.assertRaises(ValueError): 138 utils.obtain_input_shape( 139 input_shape=(224, 224, 3), 140 default_size=299, 141 min_size=139, 142 data_format='channels_last', 143 require_flatten=True, 144 weights='imagenet') 145 146 # Test invalid use cases 147 148 shape = (139, 139) 149 if data_format == 'channels_last': 150 input_shape = shape + (99,) 151 else: 152 input_shape = (99,) + shape 153 154 # input_shape is smaller than min_size. 155 shape = (100, 100) 156 if data_format == 'channels_last': 157 input_shape = shape + (3,) 158 else: 159 input_shape = (3,) + shape 160 with self.assertRaises(ValueError): 161 utils.obtain_input_shape( 162 input_shape=input_shape, 163 default_size=None, 164 min_size=139, 165 data_format=data_format, 166 require_flatten=False) 167 168 # shape is 1D. 169 shape = (100,) 170 if data_format == 'channels_last': 171 input_shape = shape + (3,) 172 else: 173 input_shape = (3,) + shape 174 with self.assertRaises(ValueError): 175 utils.obtain_input_shape( 176 input_shape=input_shape, 177 default_size=None, 178 min_size=139, 179 data_format=data_format, 180 require_flatten=False) 181 182 # the number of channels is 5 not 3. 183 shape = (100, 100) 184 if data_format == 'channels_last': 185 input_shape = shape + (5,) 186 else: 187 input_shape = (5,) + shape 188 with self.assertRaises(ValueError): 189 utils.obtain_input_shape( 190 input_shape=input_shape, 191 default_size=None, 192 min_size=139, 193 data_format=data_format, 194 require_flatten=False) 195 196 # require_flatten=True with dynamic input shape. 197 with self.assertRaises(ValueError): 198 utils.obtain_input_shape( 199 input_shape=None, 200 default_size=None, 201 min_size=139, 202 data_format='channels_first', 203 require_flatten=True) 204 205 # test include top 206 self.assertEqual(utils.obtain_input_shape( 207 input_shape=(3, 200, 200), 208 default_size=None, 209 min_size=139, 210 data_format='channels_first', 211 require_flatten=True), (3, 200, 200)) 212 213 self.assertEqual(utils.obtain_input_shape( 214 input_shape=None, 215 default_size=None, 216 min_size=139, 217 data_format='channels_last', 218 require_flatten=False), (None, None, 3)) 219 220 self.assertEqual(utils.obtain_input_shape( 221 input_shape=None, 222 default_size=None, 223 min_size=139, 224 data_format='channels_first', 225 require_flatten=False), (3, None, None)) 226 227 self.assertEqual(utils.obtain_input_shape( 228 input_shape=None, 229 default_size=None, 230 min_size=139, 231 data_format='channels_last', 232 require_flatten=False), (None, None, 3)) 233 234 self.assertEqual(utils.obtain_input_shape( 235 input_shape=(150, 150, 3), 236 default_size=None, 237 min_size=139, 238 data_format='channels_last', 239 require_flatten=False), (150, 150, 3)) 240 241 self.assertEqual(utils.obtain_input_shape( 242 input_shape=(3, None, None), 243 default_size=None, 244 min_size=139, 245 data_format='channels_first', 246 require_flatten=False), (3, None, None)) 247 248 249if __name__ == '__main__': 250 test.main() 251