• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for ImageNet data preprocessing & prediction decoding."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import json
21import warnings
22
23import numpy as np
24
25from tensorflow.python.keras import backend
26from tensorflow.python.keras.utils import data_utils
27from tensorflow.python.util.tf_export import keras_export
28
29
# Cache for the ImageNet class index; remains None until decode_predictions()
# loads it for the first time.
CLASS_INDEX = None
# URL of the class-index JSON file that decode_predictions() downloads.
CLASS_INDEX_PATH = (
    'https://storage.googleapis.com/download.tensorflow.org/'
    'data/imagenet_class_index.json'
)
33
34
@keras_export('keras.applications.imagenet_utils.preprocess_input')
def preprocess_input(x, data_format=None, mode='caffe'):
  """Preprocesses a tensor or Numpy array encoding a batch of images.

  Arguments:
    x: Input Numpy or symbolic tensor, 3D or 4D.
      The preprocessed data is written over the input data
      if the data types are compatible. To avoid this
      behaviour, `numpy.copy(x)` can be used.
    data_format: Data format of the image tensor/array. When `None`,
      `backend.image_data_format()` is used.
    mode: One of "caffe", "tf" or "torch".
      - caffe: will convert the images from RGB to BGR,
          then will zero-center each color channel with
          respect to the ImageNet dataset,
          without scaling.
      - tf: will scale pixels between -1 and 1,
          sample-wise.
      - torch: will scale pixels between 0 and 1 and then
          will normalize each channel with respect to the
          ImageNet dataset.

  Returns:
      Preprocessed tensor or Numpy array.

  Raises:
      ValueError: In case of unknown `data_format` or `mode` argument.
  """
  # Both helpers treat any mode other than 'tf'/'torch' as 'caffe', so an
  # unrecognised value (e.g. a typo) would silently get caffe preprocessing.
  if mode not in {'caffe', 'tf', 'torch'}:
    raise ValueError('Unknown mode ' + str(mode))

  if data_format is None:
    data_format = backend.image_data_format()
  if data_format not in {'channels_first', 'channels_last'}:
    raise ValueError('Unknown data_format ' + str(data_format))

  # NumPy arrays are processed eagerly; anything else is treated as a
  # backend (symbolic) tensor.
  if isinstance(x, np.ndarray):
    return _preprocess_numpy_input(
        x, data_format=data_format, mode=mode)
  else:
    return _preprocess_symbolic_input(
        x, data_format=data_format, mode=mode)
73
74
@keras_export('keras.applications.imagenet_utils.decode_predictions')
def decode_predictions(preds, top=5):
  """Decodes the prediction of an ImageNet model.

  Arguments:
    preds: Numpy tensor encoding a batch of predictions.
    top: Integer, how many top-guesses to return. `0` yields an empty
      list for each sample.

  Returns:
    A list of lists of top class prediction tuples
    `(class_name, class_description, score)`.
    One list of tuples per sample in batch input.

  Raises:
    ValueError: In case of invalid shape of the `preds` array
      (must be 2D).
  """
  global CLASS_INDEX

  if len(preds.shape) != 2 or preds.shape[1] != 1000:
    raise ValueError('`decode_predictions` expects '
                     'a batch of predictions '
                     '(i.e. a 2D array of shape (samples, 1000)). '
                     'Found array with shape: ' + str(preds.shape))
  if CLASS_INDEX is None:
    # Fetch (or reuse the locally cached copy of) the class-index JSON and
    # memoize it in the module-level global for subsequent calls.
    fpath = data_utils.get_file(
        'imagenet_class_index.json',
        CLASS_INDEX_PATH,
        cache_subdir='models',
        file_hash='c2c37ea517e94d9795004a39431a14cb')
    with open(fpath, encoding='utf-8') as f:
      CLASS_INDEX = json.load(f)
  results = []
  for pred in preds:
    # Indices ordered by descending score, then truncated to `top`. Slicing
    # after reversing (instead of `[-top:]`) keeps `top=0` from selecting
    # all 1000 classes, while giving the same indices for every other `top`.
    top_indices = pred.argsort()[::-1][:top]
    result = [tuple(CLASS_INDEX[str(i)]) + (pred[i],) for i in top_indices]
    result.sort(key=lambda x: x[2], reverse=True)
    results.append(result)
  return results
114
115
116def _preprocess_numpy_input(x, data_format, mode):
117  """Preprocesses a Numpy array encoding a batch of images.
118
119  Arguments:
120    x: Input array, 3D or 4D.
121    data_format: Data format of the image array.
122    mode: One of "caffe", "tf" or "torch".
123      - caffe: will convert the images from RGB to BGR,
124          then will zero-center each color channel with
125          respect to the ImageNet dataset,
126          without scaling.
127      - tf: will scale pixels between -1 and 1,
128          sample-wise.
129      - torch: will scale pixels between 0 and 1 and then
130          will normalize each channel with respect to the
131          ImageNet dataset.
132
133  Returns:
134      Preprocessed Numpy array.
135  """
136  if not issubclass(x.dtype.type, np.floating):
137    x = x.astype(backend.floatx(), copy=False)
138
139  if mode == 'tf':
140    x /= 127.5
141    x -= 1.
142    return x
143
144  if mode == 'torch':
145    x /= 255.
146    mean = [0.485, 0.456, 0.406]
147    std = [0.229, 0.224, 0.225]
148  else:
149    if data_format == 'channels_first':
150      # 'RGB'->'BGR'
151      if x.ndim == 3:
152        x = x[::-1, ...]
153      else:
154        x = x[:, ::-1, ...]
155    else:
156      # 'RGB'->'BGR'
157      x = x[..., ::-1]
158    mean = [103.939, 116.779, 123.68]
159    std = None
160
161  # Zero-center by mean pixel
162  if data_format == 'channels_first':
163    if x.ndim == 3:
164      x[0, :, :] -= mean[0]
165      x[1, :, :] -= mean[1]
166      x[2, :, :] -= mean[2]
167      if std is not None:
168        x[0, :, :] /= std[0]
169        x[1, :, :] /= std[1]
170        x[2, :, :] /= std[2]
171    else:
172      x[:, 0, :, :] -= mean[0]
173      x[:, 1, :, :] -= mean[1]
174      x[:, 2, :, :] -= mean[2]
175      if std is not None:
176        x[:, 0, :, :] /= std[0]
177        x[:, 1, :, :] /= std[1]
178        x[:, 2, :, :] /= std[2]
179  else:
180    x[..., 0] -= mean[0]
181    x[..., 1] -= mean[1]
182    x[..., 2] -= mean[2]
183    if std is not None:
184      x[..., 0] /= std[0]
185      x[..., 1] /= std[1]
186      x[..., 2] /= std[2]
187  return x
188
189
def _preprocess_symbolic_input(x, data_format, mode):
  """Preprocesses a tensor encoding a batch of images.

  Arguments:
    x: Input tensor, 3D or 4D.
    data_format: Data format of the image tensor.
    mode: One of "caffe", "tf" or "torch".
      - caffe: will convert the images from RGB to BGR,
          then will zero-center each color channel with
          respect to the ImageNet dataset,
          without scaling.
      - tf: will scale pixels between -1 and 1,
          sample-wise.
      - torch: will scale pixels between 0 and 1 and then
          will normalize each channel with respect to the
          ImageNet dataset.

  Returns:
      Preprocessed tensor.
  """
  if mode == 'tf':
    x /= 127.5
    x -= 1.
    return x

  if mode == 'torch':
    x /= 255.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
  else:
    if data_format == 'channels_first':
      # 'RGB'->'BGR'
      if backend.ndim(x) == 3:
        x = x[::-1, ...]
      else:
        x = x[:, ::-1, ...]
    else:
      # 'RGB'->'BGR'
      x = x[..., ::-1]
    mean = [103.939, 116.779, 123.68]
    std = None

  # bias_add adds a per-channel vector along the channel axis selected by
  # data_format, so adding the negated mean zero-centers each channel.
  mean_tensor = backend.constant(-np.array(mean))

  # Zero-center by mean pixel
  if backend.dtype(x) != backend.dtype(mean_tensor):
    x = backend.bias_add(
        x, backend.cast(mean_tensor, backend.dtype(x)), data_format=data_format)
  else:
    x = backend.bias_add(x, mean_tensor, data_format)
  if std is not None:
    std_tensor = backend.constant(np.array(std))
    if backend.dtype(x) != backend.dtype(std_tensor):
      std_tensor = backend.cast(std_tensor, backend.dtype(x))
    if data_format == 'channels_first':
      # A flat length-3 vector would broadcast over the last (width) axis.
      # Reshaped to (3, 1, 1) it lines up with the channel axis of both
      # (C, H, W) and (N, C, H, W) inputs.
      std_tensor = backend.reshape(std_tensor, (-1, 1, 1))
    x /= std_tensor
  return x
243
244
def obtain_input_shape(input_shape,
                       default_size,
                       min_size,
                       data_format,
                       require_flatten,
                       weights=None):
  """Internal utility to compute/validate a model's input shape.

  Arguments:
    input_shape: Either None (will return the default network input shape),
      or a user-provided shape (tuple or list) to be validated.
    default_size: Default input width/height for the model.
    min_size: Minimum input width/height accepted by the model.
    data_format: Image data format to use.
    require_flatten: Whether the model is expected to
      be linked to a classifier via a Flatten layer.
    weights: One of `None` (random initialization)
      or 'imagenet' (pre-training on ImageNet).
      If weights='imagenet' input channels must be equal to 3.

  Returns:
    An integer shape tuple (may include None entries). A user-provided
    `input_shape` is returned as given.

  Raises:
    ValueError: In case of invalid argument values.
  """
  channels_first = data_format == 'channels_first'

  # Channel count of the default shape: the user's own value when a 3D
  # shape is given without ImageNet weights, otherwise RGB (3).
  if weights != 'imagenet' and input_shape and len(input_shape) == 3:
    default_channels = input_shape[0] if channels_first else input_shape[-1]
    if default_channels not in {1, 3}:
      warnings.warn('This model usually expects 1 or 3 input channels. '
                    'However, it was passed an input_shape with ' +
                    str(default_channels) + ' input channels.')
  else:
    default_channels = 3
  if channels_first:
    default_shape = (default_channels, default_size, default_size)
  else:
    default_shape = (default_size, default_size, default_channels)

  if weights == 'imagenet' and require_flatten:
    # The pretrained classifier head fixes the input size. Compare as a
    # tuple so an equivalent list such as [224, 224, 3] is accepted.
    if input_shape is not None:
      if tuple(input_shape) != default_shape:
        raise ValueError('When setting `include_top=True` '
                         'and loading `imagenet` weights, '
                         '`input_shape` should be ' + str(default_shape) + '.')
    return default_shape

  if input_shape:
    if len(input_shape) != 3:
      raise ValueError('`input_shape` must be a tuple of three integers.')
    if channels_first:
      channels, spatial = input_shape[0], input_shape[1:]
    else:
      channels, spatial = input_shape[-1], input_shape[:2]
    if channels != 3 and weights == 'imagenet':
      raise ValueError('The input must have 3 channels; got '
                       '`input_shape=' + str(input_shape) + '`')
    # Unknown (None) spatial dimensions are allowed; known ones must
    # reach the minimum size.
    if any(dim is not None and dim < min_size for dim in spatial):
      raise ValueError('Input size must be at least ' + str(min_size) +
                       'x' + str(min_size) + '; got `input_shape=' +
                       str(input_shape) + '`')
  elif require_flatten:
    input_shape = default_shape
  elif channels_first:
    input_shape = (3, None, None)
  else:
    input_shape = (None, None, 3)

  # A Flatten-based classifier needs every dimension to be known.
  if require_flatten and None in input_shape:
    raise ValueError('If `include_top` is True, '
                     'you should specify a static `input_shape`. '
                     'Got `input_shape=' + str(input_shape) + '`')
  return input_shape
335
336
def correct_pad(inputs, kernel_size):
  """Returns a tuple for zero-padding for 2D convolution with downsampling.

  Arguments:
    inputs: Input tensor.
    kernel_size: An integer or tuple/list of 2 integers.

  Returns:
    A tuple of two `(before, after)` padding pairs, one for each of the two
    spatial dimensions read from `inputs`.
  """
  # Index of the first spatial axis: after (batch, channels) in
  # channels_first layout, after batch alone in channels_last layout.
  img_dim = 2 if backend.image_data_format() == 'channels_first' else 1
  input_size = backend.int_shape(inputs)[img_dim:(img_dim + 2)]
  if isinstance(kernel_size, int):
    kernel_size = (kernel_size, kernel_size)
  # Even or unknown (None) sizes get one pixel less padding before the data.
  # Each axis is handled on its own, so a partially-known shape such as
  # (224, None) no longer fails on `None % 2`.
  adjust = tuple(1 if size is None else 1 - size % 2 for size in input_size)
  correct = (kernel_size[0] // 2, kernel_size[1] // 2)
  return ((correct[0] - adjust[0], correct[0]),
          (correct[1] - adjust[1], correct[1]))
358