# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Routine for decoding the CIFAR-10 binary file format."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os
22
23from six.moves import xrange  # pylint: disable=redefined-builtin
24import tensorflow as tf
25
26# Process images of this size. Note that this differs from the original CIFAR
27# image size of 32 x 32. If one alters this number, then the entire model
28# architecture will change and any model would need to be retrained.
29IMAGE_SIZE = 24
30
31# Global constants describing the CIFAR-10 data set.
32NUM_CLASSES = 10
33NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000
34NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000
35
36
37def read_cifar10(filename_queue):
38  """Reads and parses examples from CIFAR10 data files.
39
40  Recommendation: if you want N-way read parallelism, call this function
41  N times.  This will give you N independent Readers reading different
42  files & positions within those files, which will give better mixing of
43  examples.
44
45  Args:
46    filename_queue: A queue of strings with the filenames to read from.
47
48  Returns:
49    An object representing a single example, with the following fields:
50      height: number of rows in the result (32)
51      width: number of columns in the result (32)
52      depth: number of color channels in the result (3)
53      key: a scalar string Tensor describing the filename & record number
54        for this example.
55      label: an int32 Tensor with the label in the range 0..9.
56      uint8image: a [height, width, depth] uint8 Tensor with the image data
57  """
58
59  class CIFAR10Record(object):
60    pass
61
62  result = CIFAR10Record()
63
64  # Dimensions of the images in the CIFAR-10 dataset.
65  # See http://www.cs.toronto.edu/~kriz/cifar.html for a description of the
66  # input format.
67  label_bytes = 1  # 2 for CIFAR-100
68  result.height = 32
69  result.width = 32
70  result.depth = 3
71  image_bytes = result.height * result.width * result.depth
72  # Every record consists of a label followed by the image, with a
73  # fixed number of bytes for each.
74  record_bytes = label_bytes + image_bytes
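  # Each record is therefore 1 + 32 * 32 * 3 = 3073 bytes: a single label byte
  # followed by the red, green and blue channel planes of the image, stored
  # channel-major.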

  # Read a record, getting filenames from the filename_queue.  No
  # header or footer in the CIFAR-10 format, so we leave header_bytes
  # and footer_bytes at their default of 0.
  reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
  result.key, value = reader.read(filename_queue)

  # Convert from a string to a vector of uint8 that is record_bytes long.
  record_bytes = tf.decode_raw(value, tf.uint8)

  # The first bytes represent the label, which we convert from uint8->int32.
  result.label = tf.cast(
      tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32)

  # The remaining bytes after the label represent the image, which we reshape
  # from [depth * height * width] to [depth, height, width].
  depth_major = tf.reshape(
      tf.strided_slice(record_bytes, [label_bytes],
                       [label_bytes + image_bytes]),
      [result.depth, result.height, result.width])
  # Convert from [depth, height, width] to [height, width, depth].
  result.uint8image = tf.transpose(depth_major, [1, 2, 0])

  return result
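
# A minimal sketch (an illustrative assumption, not part of the original
# pipeline) of the N-way read parallelism recommended above: each call to
# read_cifar10() adds an independent reader pulling from the shared filename
# queue.
#
#   read_inputs = [read_cifar10(filename_queue) for _ in xrange(4)]
#   images_and_labels = [(tf.cast(r.uint8image, tf.float32), r.label)
#                        for r in read_inputs]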


def _generate_image_and_label_batch(image, label, min_queue_examples,
                                    batch_size, shuffle):
  """Construct a queued batch of images and labels.

  Args:
    image: 3-D Tensor of [height, width, 3] of type float32.
    label: 1-D Tensor of type int32.
    min_queue_examples: int32, minimum number of samples to retain
      in the queue that provides batches of examples.
    batch_size: Number of images per batch.
    shuffle: boolean indicating whether to use a shuffling queue.

  Returns:
    images: Images. 4D tensor of [batch_size, height, width, 3] size.
    labels: Labels. 1D tensor of [batch_size] size.
  """
  # Create a queue that shuffles the examples, and then
  # read 'batch_size' images + labels from the example queue.
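  # For tf.train.shuffle_batch, 'capacity' bounds how large the prefetch queue
  # may grow, while 'min_after_dequeue' sets how many elements must remain
  # after each dequeue and therefore how well the examples are mixed; the
  # extra 3 * batch_size of capacity gives the threads headroom to keep the
  # queue topped up.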
  num_preprocess_threads = 16
  if shuffle:
    images, label_batch = tf.train.shuffle_batch(
        [image, label],
        batch_size=batch_size,
        num_threads=num_preprocess_threads,
        capacity=min_queue_examples + 3 * batch_size,
        min_after_dequeue=min_queue_examples)
  else:
    images, label_batch = tf.train.batch(
        [image, label],
        batch_size=batch_size,
        num_threads=num_preprocess_threads,
        capacity=min_queue_examples + 3 * batch_size)

  # Display the training images in the visualizer.
  tf.summary.image('images', images)

  return images, tf.reshape(label_batch, [batch_size])


def distorted_inputs(data_dir, batch_size):
  """Construct distorted input for CIFAR training using the Reader ops.

  Args:
    data_dir: Path to the CIFAR-10 data directory.
    batch_size: Number of images per batch.

  Returns:
    images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
    labels: Labels. 1D tensor of [batch_size] size.
  """
  filenames = [
      os.path.join(data_dir, 'data_batch_%d.bin' % i) for i in xrange(1, 6)
  ]
  for f in filenames:
    if not tf.gfile.Exists(f):
      raise ValueError('Failed to find file: ' + f)

  # Create a queue that produces the filenames to read.
  filename_queue = tf.train.string_input_producer(filenames)

  # Read examples from files in the filename queue.
  read_input = read_cifar10(filename_queue)
  reshaped_image = tf.cast(read_input.uint8image, tf.float32)

  height = IMAGE_SIZE
  width = IMAGE_SIZE

  # Image processing for training the network. Note the many random
  # distortions applied to the image.

  # Randomly crop a [height, width] section of the image.
  distorted_image = tf.random_crop(reshaped_image, [height, width, 3])

  # Randomly flip the image horizontally.
  distorted_image = tf.image.random_flip_left_right(distorted_image)

  # Because these operations are not commutative, consider randomizing
  # the order of their operation.
  distorted_image = tf.image.random_brightness(distorted_image, max_delta=63)
  distorted_image = tf.image.random_contrast(
      distorted_image, lower=0.2, upper=1.8)
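  # One possible way to randomize the order (a sketch only, not applied here)
  # is to flip a coin and let tf.cond pick which distortion runs first:
  #   brightness_first = tf.less(tf.random_uniform([], 0, 1), 0.5)
  #   distorted_image = tf.cond(
  #       brightness_first,
  #       lambda: tf.image.random_contrast(
  #           tf.image.random_brightness(distorted_image, max_delta=63),
  #           lower=0.2, upper=1.8),
  #       lambda: tf.image.random_brightness(
  #           tf.image.random_contrast(distorted_image, lower=0.2, upper=1.8),
  #           max_delta=63))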

  # Subtract off the mean and divide by the standard deviation of the pixels.
  float_image = tf.image.per_image_standardization(distorted_image)

  # Set the shapes of tensors.
  float_image.set_shape([height, width, 3])
  read_input.label.set_shape([1])

  # Ensure that the random shuffling has good mixing properties.
  min_fraction_of_examples_in_queue = 0.4
  min_queue_examples = int(
      NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN * min_fraction_of_examples_in_queue)
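  # With the constants above this evaluates to 0.4 * 50000 = 20000 images.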
  print('Filling queue with %d CIFAR images before starting to train. '
        'This will take a few minutes.' % min_queue_examples)

  # Generate a batch of images and labels by building up a queue of examples.
  return _generate_image_and_label_batch(
      float_image,
      read_input.label,
      min_queue_examples,
      batch_size,
      shuffle=True)


def inputs(eval_data, data_dir, batch_size):
  """Construct input for CIFAR evaluation using the Reader ops.

  Args:
    eval_data: bool, indicating if one should use the train or eval data set.
    data_dir: Path to the CIFAR-10 data directory.
    batch_size: Number of images per batch.

  Returns:
    images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
    labels: Labels. 1D tensor of [batch_size] size.
  """
  if not eval_data:
    filenames = [
        os.path.join(data_dir, 'data_batch_%d.bin' % i) for i in xrange(1, 6)
    ]
    num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
  else:
    filenames = [os.path.join(data_dir, 'test_batch.bin')]
    num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_EVAL

  for f in filenames:
    if not tf.gfile.Exists(f):
      raise ValueError('Failed to find file: ' + f)

  # Create a queue that produces the filenames to read.
  filename_queue = tf.train.string_input_producer(filenames)

  # Read examples from files in the filename queue.
  read_input = read_cifar10(filename_queue)
  reshaped_image = tf.cast(read_input.uint8image, tf.float32)

  height = IMAGE_SIZE
  width = IMAGE_SIZE

  # Image processing for evaluation.
  # Crop the central [height, width] of the image.
  resized_image = tf.image.resize_image_with_crop_or_pad(
      reshaped_image, height, width)

  # Subtract off the mean and divide by the standard deviation of the pixels.
  float_image = tf.image.per_image_standardization(resized_image)

  # Set the shapes of tensors.
  float_image.set_shape([height, width, 3])
  read_input.label.set_shape([1])

  # Ensure that the random shuffling has good mixing properties.
  min_fraction_of_examples_in_queue = 0.4
  min_queue_examples = int(
      num_examples_per_epoch * min_fraction_of_examples_in_queue)

  # Generate a batch of images and labels by building up a queue of examples.
  return _generate_image_and_label_batch(
      float_image,
      read_input.label,
      min_queue_examples,
      batch_size,
      shuffle=False)
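

# A minimal usage sketch (assumed driver code, not part of this module; the
# data path and batch size are placeholders): the tensors returned above are
# fed by input queues, so a consumer has to start the queue runners before
# evaluating them.
#
#   images, labels = distorted_inputs('/tmp/cifar10_data/cifar-10-batches-bin',
#                                     batch_size=128)
#   with tf.Session() as sess:
#     coord = tf.train.Coordinator()
#     threads = tf.train.start_queue_runners(sess=sess, coord=coord)
#     image_batch, label_batch = sess.run([images, labels])
#     coord.request_stop()
#     coord.join(threads)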