# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for downloading and reading MNIST data (deprecated).

This module and all its submodules are deprecated.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import gzip
import os

import numpy
from six.moves import urllib
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import random_seed
from tensorflow.python.platform import gfile
from tensorflow.python.util.deprecation import deprecated

_Datasets = collections.namedtuple('_Datasets', ['train', 'validation', 'test'])

# CVDF mirror of http://yann.lecun.com/exdb/mnist/
DEFAULT_SOURCE_URL = 'https://storage.googleapis.com/cvdf-datasets/mnist/'

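# Individual archives are fetched by appending an archive name to this base,
# e.g. DEFAULT_SOURCE_URL + 'train-images-idx3-ubyte.gz' (see read_data_sets
# below).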

def _read32(bytestream):
  # IDX files store integers big-endian (MSB first), hence the byte-order
  # flip on the uint32 dtype.
  dt = numpy.dtype(numpy.uint32).newbyteorder('>')
  return numpy.frombuffer(bytestream.read(4), dtype=dt)[0]

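# A minimal usage sketch for _read32 ('sample.gz' is a hypothetical file; the
# first four bytes of any IDX archive hold its magic number):
#
#   with gzip.open('sample.gz', 'rb') as f:
#     magic = _read32(f)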

@deprecated(None, 'Please use tf.data to implement this functionality.')
def _extract_images(f):
  """Extract the images into a 4D uint8 numpy array [index, y, x, depth].

  Args:
    f: A file object that can be passed into a gzip reader.

  Returns:
    data: A 4D uint8 numpy array [index, y, x, depth].

  Raises:
    ValueError: If the bytestream does not start with 2051.
  """
  print('Extracting', f.name)
  with gzip.GzipFile(fileobj=f) as bytestream:
    magic = _read32(bytestream)
    if magic != 2051:
      raise ValueError('Invalid magic number %d in MNIST image file: %s' %
                       (magic, f.name))
    num_images = _read32(bytestream)
    rows = _read32(bytestream)
    cols = _read32(bytestream)
    buf = bytestream.read(rows * cols * num_images)
    data = numpy.frombuffer(buf, dtype=numpy.uint8)
    data = data.reshape(num_images, rows, cols, 1)
    return data

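# For the standard MNIST archives this yields uint8 arrays of shape
# (num_images, 28, 28, 1), e.g. (60000, 28, 28, 1) for the training set.
# A minimal sketch, assuming the archive is already in the current directory:
#
#   with open('train-images-idx3-ubyte.gz', 'rb') as f:
#     images = _extract_images(f)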

@deprecated(None, 'Please use tf.one_hot on tensors.')
def _dense_to_one_hot(labels_dense, num_classes):
  """Convert class labels from scalars to one-hot vectors."""
  num_labels = labels_dense.shape[0]
  index_offset = numpy.arange(num_labels) * num_classes
  labels_one_hot = numpy.zeros((num_labels, num_classes))
  labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
  return labels_one_hot

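# Worked example: for labels_dense = numpy.array([0, 2]) and num_classes = 3,
# index_offset is [0, 3], so flat positions 0 and 5 are set, giving
# [[1., 0., 0.], [0., 0., 1.]].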

@deprecated(None, 'Please use tf.data to implement this functionality.')
def _extract_labels(f, one_hot=False, num_classes=10):
  """Extract the labels into a 1D uint8 numpy array [index].

  Args:
    f: A file object that can be passed into a gzip reader.
    one_hot: Whether to one-hot encode the result.
    num_classes: Number of classes for the one-hot encoding.

  Returns:
    labels: a 1D uint8 numpy array.

  Raises:
    ValueError: If the bytestream doesn't start with 2049.
  """
  print('Extracting', f.name)
  with gzip.GzipFile(fileobj=f) as bytestream:
    magic = _read32(bytestream)
    if magic != 2049:
      raise ValueError('Invalid magic number %d in MNIST label file: %s' %
                       (magic, f.name))
    num_items = _read32(bytestream)
    buf = bytestream.read(num_items)
    labels = numpy.frombuffer(buf, dtype=numpy.uint8)
    if one_hot:
      return _dense_to_one_hot(labels, num_classes)
    return labels

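# Each label is a single byte, so the training archive decodes to shape
# (60000,), or (60000, 10) with one_hot=True. A minimal sketch, assuming the
# archive is already in the current directory:
#
#   with open('train-labels-idx1-ubyte.gz', 'rb') as f:
#     labels = _extract_labels(f, one_hot=True)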

class _DataSet(object):
  """Container class for a dataset (deprecated).

  THIS CLASS IS DEPRECATED.
  """

  @deprecated(None, 'Please use alternatives such as official/mnist/_DataSet.py'
              ' from tensorflow/models.')
  def __init__(self,
               images,
               labels,
               fake_data=False,
               one_hot=False,
               dtype=dtypes.float32,
               reshape=True,
               seed=None):
    """Construct a _DataSet.

    The `one_hot` arg is used only if `fake_data` is true.  `dtype` can be
    either `uint8` to leave the input as `[0, 255]`, or `float32` to rescale
    into `[0, 1]`.  The `seed` arg provides for convenient deterministic
    testing.

    Args:
      images: The images.
      labels: The labels.
      fake_data: Ignore images and labels, use fake data.
      one_hot: Bool, return the labels as one-hot vectors (if True) or ints
        (if False).
      dtype: Output image dtype. One of [uint8, float32]. `uint8` output has
        range [0, 255]. `float32` output has range [0, 1].
      reshape: Bool. If True, images are flattened into 1-D vectors.
      seed: The random seed to use.
    """
    seed1, seed2 = random_seed.get_seed(seed)
    # If op-level seed is not set, use whatever graph-level seed is returned.
    numpy.random.seed(seed1 if seed is None else seed2)
    dtype = dtypes.as_dtype(dtype).base_dtype
    if dtype not in (dtypes.uint8, dtypes.float32):
      raise TypeError('Invalid image dtype %r, expected uint8 or float32' %
                      dtype)
    if fake_data:
      self._num_examples = 10000
      self.one_hot = one_hot
    else:
      assert images.shape[0] == labels.shape[0], (
          'images.shape: %s labels.shape: %s' % (images.shape, labels.shape))
      self._num_examples = images.shape[0]

      # Convert shape from [num examples, rows, columns, depth]
      # to [num examples, rows*columns] (assuming depth == 1).
      if reshape:
        assert images.shape[3] == 1
        images = images.reshape(images.shape[0],
                                images.shape[1] * images.shape[2])
      if dtype == dtypes.float32:
        # Convert from [0, 255] -> [0.0, 1.0].
        images = images.astype(numpy.float32)
        images = numpy.multiply(images, 1.0 / 255.0)
    self._images = images
    self._labels = labels
    self._epochs_completed = 0
    self._index_in_epoch = 0

  @property
  def images(self):
    return self._images

  @property
  def labels(self):
    return self._labels

  @property
  def num_examples(self):
    return self._num_examples

  @property
  def epochs_completed(self):
    return self._epochs_completed

  def next_batch(self, batch_size, fake_data=False, shuffle=True):
    """Return the next `batch_size` examples from this data set."""
    if fake_data:
      fake_image = [1] * 784
      if self.one_hot:
        fake_label = [1] + [0] * 9
      else:
        fake_label = 0
      return [fake_image for _ in xrange(batch_size)
             ], [fake_label for _ in xrange(batch_size)]
    start = self._index_in_epoch
    # Shuffle for the first epoch.
    if self._epochs_completed == 0 and start == 0 and shuffle:
      perm0 = numpy.arange(self._num_examples)
      numpy.random.shuffle(perm0)
      self._images = self.images[perm0]
      self._labels = self.labels[perm0]
    # Go to the next epoch.
    if start + batch_size > self._num_examples:
      # Finished epoch.
      self._epochs_completed += 1
      # Get the rest of the examples in this epoch.
      rest_num_examples = self._num_examples - start
      images_rest_part = self._images[start:self._num_examples]
      labels_rest_part = self._labels[start:self._num_examples]
      # Shuffle the data.
      if shuffle:
        perm = numpy.arange(self._num_examples)
        numpy.random.shuffle(perm)
        self._images = self.images[perm]
        self._labels = self.labels[perm]
      # Start next epoch.
      start = 0
      self._index_in_epoch = batch_size - rest_num_examples
      end = self._index_in_epoch
      images_new_part = self._images[start:end]
      labels_new_part = self._labels[start:end]
      return numpy.concatenate((images_rest_part, images_new_part),
                               axis=0), numpy.concatenate(
                                   (labels_rest_part, labels_new_part), axis=0)
    else:
      self._index_in_epoch += batch_size
      end = self._index_in_epoch
      return self._images[start:end], self._labels[start:end]

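# How next_batch crosses an epoch boundary (hypothetical numbers): with 100
# examples and batch_size=30, the fourth call returns the last 10 examples of
# the first epoch concatenated with the first 20 of the next (reshuffled)
# epoch, and epochs_completed increments to 1.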

@deprecated(None, 'Please write your own downloading logic.')
def _maybe_download(filename, work_directory, source_url):
  """Download the data from source url, unless it's already here.

  Args:
    filename: string, name of the file in the directory.
    work_directory: string, path to working directory.
    source_url: url to download from if file doesn't exist.

  Returns:
    Path to resulting file.
  """
  if not gfile.Exists(work_directory):
    gfile.MakeDirs(work_directory)
  filepath = os.path.join(work_directory, filename)
  if not gfile.Exists(filepath):
    urllib.request.urlretrieve(source_url, filepath)
    with gfile.GFile(filepath) as f:
      size = f.size()
    print('Successfully downloaded', filename, size, 'bytes.')
  return filepath

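# A minimal usage sketch ('/tmp/mnist' is an arbitrary scratch directory):
#
#   path = _maybe_download('train-images-idx3-ubyte.gz', '/tmp/mnist',
#                          DEFAULT_SOURCE_URL + 'train-images-idx3-ubyte.gz')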

@deprecated(None, 'Please use alternatives such as:'
            ' tensorflow_datasets.load(\'mnist\')')
def read_data_sets(train_dir,
                   fake_data=False,
                   one_hot=False,
                   dtype=dtypes.float32,
                   reshape=True,
                   validation_size=5000,
                   seed=None,
                   source_url=DEFAULT_SOURCE_URL):
  """Download (if needed) and read MNIST into train/validation/test sets.

  Returns a _Datasets namedtuple of _DataSet objects; the first
  `validation_size` training examples are split off as the validation set.
  """
  if fake_data:

    def fake():
      return _DataSet([], [],
                      fake_data=True,
                      one_hot=one_hot,
                      dtype=dtype,
                      seed=seed)

    train = fake()
    validation = fake()
    test = fake()
    return _Datasets(train=train, validation=validation, test=test)

  if not source_url:  # empty string check
    source_url = DEFAULT_SOURCE_URL

  train_images_file = 'train-images-idx3-ubyte.gz'
  train_labels_file = 'train-labels-idx1-ubyte.gz'
  test_images_file = 't10k-images-idx3-ubyte.gz'
  test_labels_file = 't10k-labels-idx1-ubyte.gz'

  local_file = _maybe_download(train_images_file, train_dir,
                               source_url + train_images_file)
  with gfile.Open(local_file, 'rb') as f:
    train_images = _extract_images(f)

  local_file = _maybe_download(train_labels_file, train_dir,
                               source_url + train_labels_file)
  with gfile.Open(local_file, 'rb') as f:
    train_labels = _extract_labels(f, one_hot=one_hot)

  local_file = _maybe_download(test_images_file, train_dir,
                               source_url + test_images_file)
  with gfile.Open(local_file, 'rb') as f:
    test_images = _extract_images(f)

  local_file = _maybe_download(test_labels_file, train_dir,
                               source_url + test_labels_file)
  with gfile.Open(local_file, 'rb') as f:
    test_labels = _extract_labels(f, one_hot=one_hot)

  if not 0 <= validation_size <= len(train_images):
    raise ValueError(
        'Validation size should be between 0 and {}. Received: {}.'.format(
            len(train_images), validation_size))

  validation_images = train_images[:validation_size]
  validation_labels = train_labels[:validation_size]
  train_images = train_images[validation_size:]
  train_labels = train_labels[validation_size:]

  options = dict(dtype=dtype, reshape=reshape, seed=seed)

  train = _DataSet(train_images, train_labels, **options)
  validation = _DataSet(validation_images, validation_labels, **options)
  test = _DataSet(test_images, test_labels, **options)

  return _Datasets(train=train, validation=validation, test=test)

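# A minimal usage sketch ('/tmp/mnist' is an arbitrary cache directory):
#
#   mnist = read_data_sets('/tmp/mnist', one_hot=True)
#   batch_xs, batch_ys = mnist.train.next_batch(100)
#   print(mnist.validation.num_examples)  # 5000 with the default split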