• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for imagenet_utils."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from absl.testing import parameterized
21import numpy as np
22
23from tensorflow.python import keras
24from tensorflow.python.keras import keras_parameterized
25from tensorflow.python.keras.applications import imagenet_utils as utils
26from tensorflow.python.platform import test
27
28
class TestImageNetUtils(keras_parameterized.TestCase):
  """Tests for `preprocess_input` and `obtain_input_shape` in imagenet_utils."""

  def _assert_layouts_agree(self, x, to_channels_first, to_channels_last):
    """Checks both data formats give the same result for the same pixels.

    Float inputs are overwritten in place by `preprocess_input` (see the
    overwrite checks in `test_preprocess_input`), so each call gets a private
    copy. Otherwise the first call mutates `x` before it is transposed for the
    second call, and the two outputs may alias one buffer, which would make the
    comparison vacuous.

    Args:
      x: channels_last image or image batch (numpy array, float or int).
      to_channels_first: axis permutation taking `x` to channels_first layout.
      to_channels_last: axis permutation taking a channels_first result back
        to channels_last layout.
    """
    out_last = utils.preprocess_input(np.copy(x), 'channels_last')
    out_first = utils.preprocess_input(
        np.transpose(x, to_channels_first).copy(), 'channels_first')
    self.assertAllClose(out_last, out_first.transpose(to_channels_last))

  def test_preprocess_input(self):
    # Test image batch with float and int image input
    x = np.random.uniform(0, 255, (2, 10, 10, 3))
    xint = x.astype('int32')
    self.assertEqual(utils.preprocess_input(np.copy(x)).shape, x.shape)
    self.assertEqual(utils.preprocess_input(xint).shape, xint.shape)
    self._assert_layouts_agree(x, (0, 3, 1, 2), (0, 2, 3, 1))
    self._assert_layouts_agree(xint, (0, 3, 1, 2), (0, 2, 3, 1))

    # Test single image
    x = np.random.uniform(0, 255, (10, 10, 3))
    xint = x.astype('int32')
    self.assertEqual(utils.preprocess_input(np.copy(x)).shape, x.shape)
    self.assertEqual(utils.preprocess_input(xint).shape, xint.shape)
    self._assert_layouts_agree(x, (2, 0, 1), (1, 2, 0))
    self._assert_layouts_agree(xint, (2, 0, 1), (1, 2, 0))

    # Test that writing over the input data works predictably: float input is
    # overwritten in place, while int input is converted into a new array and
    # keeps its original values.
    for mode in ['torch', 'tf']:
      x = np.random.uniform(0, 255, (2, 10, 10, 3))
      xint = x.astype('int32')
      x2 = utils.preprocess_input(x, mode=mode)
      # Use the same mode for the int input so both halves test one mode.
      xint2 = utils.preprocess_input(xint, mode=mode)
      self.assertAllClose(x, x2)
      self.assertNotEqual(xint.astype('float').max(), xint2.max())

    # Caffe mode works differently from the others: the float input is still
    # overwritten in place, but the returned array has its channel axis
    # reversed relative to the input.
    x = np.random.uniform(0, 255, (2, 10, 10, 3))
    xint = x.astype('int32')
    x2 = utils.preprocess_input(x, data_format='channels_last', mode='caffe')
    xint2 = utils.preprocess_input(
        xint, data_format='channels_last', mode='caffe')
    self.assertAllClose(x, x2[..., ::-1])
    self.assertNotEqual(xint.astype('float').max(), xint2.max())

  def _predict_preprocessed(self, batch, data_format=None):
    """Applies `preprocess_input` to `batch` through a Keras `Lambda` model.

    Args:
      batch: numpy array whose first axis is the batch axis.
      data_format: forwarded to `preprocess_input`; None keeps its default.

    Returns:
      The model's prediction for `batch`.
    """
    def preprocess(tensor):
      return utils.preprocess_input(tensor, data_format=data_format)

    inputs = keras.layers.Input(shape=batch.shape[1:])
    outputs = keras.layers.Lambda(
        preprocess, output_shape=batch.shape[1:])(inputs)
    return keras.Model(inputs, outputs).predict(batch)

  def test_preprocess_input_symbolic(self):
    # Test image batch
    x = np.random.uniform(0, 255, (2, 10, 10, 3))
    self.assertEqual(self._predict_preprocessed(x).shape, x.shape)

    out1 = self._predict_preprocessed(x, 'channels_last')
    out2 = self._predict_preprocessed(
        np.transpose(x, (0, 3, 1, 2)), 'channels_first')
    self.assertAllClose(out1, out2.transpose(0, 2, 3, 1))

    # Test single image (fed to the model as a batch of one)
    x = np.random.uniform(0, 255, (10, 10, 3))
    self.assertEqual(
        self._predict_preprocessed(x[np.newaxis])[0].shape, x.shape)

    out1 = self._predict_preprocessed(x[np.newaxis], 'channels_last')[0]
    x2 = np.transpose(x, (2, 0, 1))
    out2 = self._predict_preprocessed(x2[np.newaxis], 'channels_first')[0]
    self.assertAllClose(out1, out2.transpose(1, 2, 0))

  @staticmethod
  def _with_channels(spatial_shape, channels, data_format):
    """Places a channel axis of size `channels` around `spatial_shape`.

    Args:
      spatial_shape: tuple of spatial dimensions.
      channels: size of the channel axis.
      data_format: 'channels_last' appends the channel axis; any other value
        prepends it.

    Returns:
      The full input shape tuple.
    """
    if data_format == 'channels_last':
      return spatial_shape + (channels,)
    return (channels,) + spatial_shape

  @parameterized.named_parameters([
      {'testcase_name': 'channels_last_format',
       'data_format': 'channels_last'},
      {'testcase_name': 'channels_first_format',
       'data_format': 'channels_first'},
  ])
  def test_obtain_input_shape(self, data_format):
    # input_shape and default_size are not identical.
    with self.assertRaises(ValueError):
      utils.obtain_input_shape(
          input_shape=(224, 224, 3),
          default_size=299,
          min_size=139,
          data_format='channels_last',
          require_flatten=True,
          weights='imagenet')

    # Without imagenet weights, an unusual channel count is accepted and the
    # shape is returned as given.
    input_shape = self._with_channels((139, 139), 99, data_format)
    self.assertEqual(
        utils.obtain_input_shape(
            input_shape=input_shape,
            default_size=None,
            min_size=139,
            data_format=data_format,
            require_flatten=False,
            weights='fake_weights'), input_shape)

    # Test invalid use cases

    # input_shape is smaller than min_size.
    with self.assertRaises(ValueError):
      utils.obtain_input_shape(
          input_shape=self._with_channels((100, 100), 3, data_format),
          default_size=None,
          min_size=139,
          data_format=data_format,
          require_flatten=False)

    # shape is 1D.
    with self.assertRaises(ValueError):
      utils.obtain_input_shape(
          input_shape=self._with_channels((100,), 3, data_format),
          default_size=None,
          min_size=139,
          data_format=data_format,
          require_flatten=False)

    # The number of channels is 5, not 3, while imagenet weights are requested.
    # The spatial size satisfies min_size, so the error should come from the
    # channel count rather than from the size check.
    with self.assertRaises(ValueError):
      utils.obtain_input_shape(
          input_shape=self._with_channels((139, 139), 5, data_format),
          default_size=None,
          min_size=139,
          data_format=data_format,
          require_flatten=False,
          weights='imagenet')

    # require_flatten=True with dynamic input shape.
    with self.assertRaises(ValueError):
      utils.obtain_input_shape(
          input_shape=None,
          default_size=None,
          min_size=139,
          data_format='channels_first',
          require_flatten=True)

    # test include top
    self.assertEqual(utils.obtain_input_shape(
        input_shape=(3, 200, 200),
        default_size=None,
        min_size=139,
        data_format='channels_first',
        require_flatten=True), (3, 200, 200))

    # Dynamic shapes are filled in with None spatial dims per data format.
    self.assertEqual(utils.obtain_input_shape(
        input_shape=None,
        default_size=None,
        min_size=139,
        data_format='channels_last',
        require_flatten=False), (None, None, 3))

    self.assertEqual(utils.obtain_input_shape(
        input_shape=None,
        default_size=None,
        min_size=139,
        data_format='channels_first',
        require_flatten=False), (3, None, None))

    self.assertEqual(utils.obtain_input_shape(
        input_shape=(150, 150, 3),
        default_size=None,
        min_size=139,
        data_format='channels_last',
        require_flatten=False), (150, 150, 3))

    self.assertEqual(utils.obtain_input_shape(
        input_shape=(3, None, None),
        default_size=None,
        min_size=139,
        data_format='channels_first',
        require_flatten=False), (3, None, None))
247
248
if __name__ == '__main__':
  # Hand control to the test runner from tensorflow.python.platform.test,
  # which executes the test cases defined in this module.
  test.main()
251