• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""ImageNet preprocessing for ResNet."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import tensorflow as tf
21
# Default value of the `image_size` argument of the preprocess_* functions:
# the side length of the square image they produce.
IMAGE_SIZE = 224
# Used by _decode_and_center_crop: the center crop spans
# image_size / (image_size + CROP_PADDING) of the shorter image side.
CROP_PADDING = 32
24
25
def distorted_bounding_box_crop(image_bytes,
                                bbox,
                                min_object_covered=0.1,
                                aspect_ratio_range=(0.75, 1.33),
                                area_range=(0.05, 1.0),
                                max_attempts=100,
                                scope=None):
  """Decodes a JPEG and crops it to a randomly sampled bounding box.

  The crop window is drawn by `tf.image.sample_distorted_bounding_box`; see
  that op for the full sampling semantics. Only the selected window is
  decoded, via `tf.image.decode_and_crop_jpeg`.

  Args:
    image_bytes: `Tensor` of binary (JPEG-encoded) image data.
    bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]`, with
        each coordinate in [0, 1) and ordered `[ymin, xmin, ymax, xmax]`.
        When num_boxes is 0 the whole image is used as the box.
    min_object_covered: Optional `float`, defaults to `0.1`. The crop must
        contain at least this fraction of any supplied bounding box.
    aspect_ratio_range: Optional pair of `float`s. The crop's aspect ratio
        (width / height) must lie within this range.
    area_range: Optional pair of `float`s. The crop must cover a fraction of
        the supplied image that lies within this range.
    max_attempts: Optional `int`. Number of tries at sampling a crop that
        meets the constraints; after that many failures the entire image is
        returned.
    scope: Optional `str` for name scope.

  Returns:
    The cropped image `Tensor`, decoded with 3 channels.
  """
  with tf.name_scope(scope, 'distorted_bounding_box_crop', [image_bytes, bbox]):
    jpeg_shape = tf.image.extract_jpeg_shape(image_bytes)
    # The third return value (the sampled box for visualization) is unused.
    begin, size, _ = tf.image.sample_distorted_bounding_box(
        jpeg_shape,
        bounding_boxes=bbox,
        min_object_covered=min_object_covered,
        aspect_ratio_range=aspect_ratio_range,
        area_range=area_range,
        max_attempts=max_attempts,
        use_image_if_no_bounding_boxes=True)

    # `begin` and `size` are (y, x, channel) triples; the channel entry is
    # dropped because the crop window only needs the spatial extents.
    top, left, _ = tf.unstack(begin)
    height, width, _ = tf.unstack(size)
    window = tf.stack([top, left, height, width])
    return tf.image.decode_and_crop_jpeg(image_bytes, window, channels=3)
76
77
def _at_least_x_are_equal(a, b, x):
  """Returns a boolean `Tensor`: do `a` and `b` match in at least `x` places?"""
  # Count element-wise matches by summing the equality mask as integers.
  num_equal = tf.reduce_sum(tf.cast(tf.equal(a, b), tf.int32))
  return tf.greater_equal(num_equal, x)
83
84
def _decode_and_random_crop(image_bytes, image_size):
  """Randomly crops the JPEG and resizes the crop to image_size x image_size.

  If the sampled crop has the same shape as the full image in all three
  dimensions, the random crop is treated as unsuccessful and the padded
  center crop is used instead.
  """
  # A single box covering the entire image.
  whole_image_box = tf.constant(
      [0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
  cropped = distorted_bounding_box_crop(
      image_bytes,
      whole_image_box,
      min_object_covered=0.1,
      aspect_ratio_range=(3. / 4, 4. / 3.),
      area_range=(0.08, 1.0),
      max_attempts=10,
      scope=None)

  full_shape = tf.image.extract_jpeg_shape(image_bytes)
  crop_is_whole_image = _at_least_x_are_equal(
      full_shape, tf.shape(cropped), 3)

  def _center_crop_fallback():
    return _decode_and_center_crop(image_bytes, image_size)

  def _resize_random_crop():
    # resize_bicubic works on batches, so wrap the image and unwrap the result.
    return tf.image.resize_bicubic([cropped], [image_size, image_size])[0]

  return tf.cond(
      crop_is_whole_image, _center_crop_fallback, _resize_random_crop)
106
107
def _decode_and_center_crop(image_bytes, image_size):
  """Decodes a centered square crop and resizes it to image_size x image_size.

  The square's side is image_size / (image_size + CROP_PADDING) of the
  shorter image side, truncated to an integer.
  """
  jpeg_shape = tf.image.extract_jpeg_shape(image_bytes)
  height = jpeg_shape[0]
  width = jpeg_shape[1]

  crop_fraction = image_size / (image_size + CROP_PADDING)
  shorter_side = tf.cast(tf.minimum(height, width), tf.float32)
  crop_size = tf.cast(crop_fraction * shorter_side, tf.int32)

  # The +1 rounds the half-margin up when the leftover margin is odd.
  top = ((height - crop_size) + 1) // 2
  left = ((width - crop_size) + 1) // 2
  window = tf.stack([top, left, crop_size, crop_size])

  cropped = tf.image.decode_and_crop_jpeg(image_bytes, window, channels=3)
  # resize_bicubic works on batches, so wrap the image and unwrap the result.
  return tf.image.resize_bicubic([cropped], [image_size, image_size])[0]
127
128
def _flip(image):
  """Applies `tf.image.random_flip_left_right` to `image`."""
  return tf.image.random_flip_left_right(image)
133
134
def preprocess_for_train(image_bytes, use_bfloat16, image_size=IMAGE_SIZE):
  """Preprocesses the given image for training.

  Applies a random crop (resized to `image_size` x `image_size`) followed by
  a random horizontal flip, then converts the result to the requested dtype.

  Args:
    image_bytes: `Tensor` representing an image binary of arbitrary size.
    use_bfloat16: `bool` for whether to use bfloat16 (float32 otherwise).
    image_size: side length of the square output image.

  Returns:
    A preprocessed image `Tensor` reshaped to `[image_size, image_size, 3]`.
  """
  image = _decode_and_random_crop(image_bytes, image_size)
  image = _flip(image)
  image = tf.reshape(image, [image_size, image_size, 3])
  output_dtype = tf.bfloat16 if use_bfloat16 else tf.float32
  return tf.image.convert_image_dtype(image, dtype=output_dtype)
152
153
def preprocess_for_eval(image_bytes, use_bfloat16, image_size=IMAGE_SIZE):
  """Preprocesses the given image for evaluation.

  Takes the padded center crop resized to `image_size` x `image_size`, then
  converts the result to the requested dtype.

  Args:
    image_bytes: `Tensor` representing an image binary of arbitrary size.
    use_bfloat16: `bool` for whether to use bfloat16 (float32 otherwise).
    image_size: side length of the square output image.

  Returns:
    A preprocessed image `Tensor` reshaped to `[image_size, image_size, 3]`.
  """
  output_dtype = tf.bfloat16 if use_bfloat16 else tf.float32
  center_crop = _decode_and_center_crop(image_bytes, image_size)
  center_crop = tf.reshape(center_crop, [image_size, image_size, 3])
  return tf.image.convert_image_dtype(center_crop, dtype=output_dtype)
170
171
def preprocess_image(image_bytes,
                     is_training=False,
                     use_bfloat16=False,
                     image_size=IMAGE_SIZE):
  """Dispatches to the training or evaluation preprocessing pipeline.

  Args:
    image_bytes: `Tensor` representing an image binary of arbitrary size.
    is_training: `bool`; when true use `preprocess_for_train`, otherwise
        `preprocess_for_eval`.
    use_bfloat16: `bool` for whether to use bfloat16.
    image_size: image size.

  Returns:
    A preprocessed image `Tensor`.
  """
  preprocess_fn = preprocess_for_train if is_training else preprocess_for_eval
  return preprocess_fn(image_bytes, use_bfloat16, image_size)
191