• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#  Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3#  Licensed under the Apache License, Version 2.0 (the "License");
4#  you may not use this file except in compliance with the License.
5#  You may obtain a copy of the License at
6#
7#   http://www.apache.org/licenses/LICENSE-2.0
8#
9#  Unless required by applicable law or agreed to in writing, software
10#  distributed under the License is distributed on an "AS IS" BASIS,
11#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12#  See the License for the specific language governing permissions and
13#  limitations under the License.
14# ==============================================================================
15"""tf.data.Dataset interface to the MNIST dataset.
16
17 This is cloned from
18 https://github.com/tensorflow/models/blob/master/official/r1/mnist/dataset.py
19"""
20
21from __future__ import absolute_import
22from __future__ import division
23from __future__ import print_function
24
25import gzip
26import os
27import shutil
28import tempfile
29
30import numpy as np
31from six.moves import urllib
32import tensorflow as tf
33
34
def read32(bytestream):
  """Read 4 bytes from bytestream as a big-endian unsigned 32-bit integer.

  Args:
    bytestream: A binary file-like object; `read(4)` is called on it once.

  Returns:
    The decoded value as a numpy unsigned 32-bit scalar.

  Raises:
    ValueError: If fewer than 4 bytes could be read (e.g. a truncated file).
  """
  raw = bytestream.read(4)
  # Fail with an explicit message instead of numpy's buffer-size error (1-3
  # bytes) or an IndexError on the empty result (0 bytes).
  if len(raw) != 4:
    raise ValueError(
        'Expected 4 bytes for a 32-bit integer but read %d' % len(raw))
  dt = np.dtype(np.uint32).newbyteorder('>')
  return np.frombuffer(raw, dtype=dt)[0]
39
40
def check_image_file_header(filename):
  """Verify that `filename` starts with a valid MNIST image-file header.

  The header is consumed as four consecutive 32-bit integers via `read32`:
  magic number, image count, row count and column count.

  Args:
    filename: Path of the image file; opened in binary mode with tf.gfile.

  Raises:
    ValueError: If the magic number is not 2051, or if the declared image
      size is not 28x28.
  """
  with tf.gfile.Open(filename, 'rb') as f:
    # All four fields are read, in file order, before anything is validated.
    header = [read32(f) for _ in range(4)]
    # header[1] is the image count, which is not needed for validation.
    magic, rows, cols = header[0], header[2], header[3]
    if magic != 2051:
      raise ValueError(
          'Invalid magic number %d in MNIST file %s' % (magic, f.name))
    if not (rows == 28 and cols == 28):
      raise ValueError(
          'Invalid MNIST file %s: Expected 28x28 images, found %dx%d' %
          (f.name, rows, cols))
55
56
def check_labels_file_header(filename):
  """Verify that `filename` starts with a valid MNIST label-file header.

  Args:
    filename: Path of the label file; opened in binary mode with tf.gfile.

  Raises:
    ValueError: If the magic number read from the file is not 2049.
  """
  with tf.gfile.Open(filename, 'rb') as f:
    # Two header fields are read in order: the magic number, then the item
    # count (consumed but not checked).
    magic, _ = read32(f), read32(f)
    if magic != 2049:
      raise ValueError(
          'Invalid magic number %d in MNIST file %s' % (magic, f.name))
65
66
def download(directory, filename):
  """Download (and unzip) a file from the MNIST dataset if not already done.

  Args:
    directory: Directory in which the decompressed file is stored; created
      if it does not exist.
    filename: Name of the decompressed file. The archive fetched is
      `filename + '.gz'` from the CVDF mirror.

  Returns:
    Path of the decompressed file, i.e. `os.path.join(directory, filename)`.
  """
  filepath = os.path.join(directory, filename)
  if tf.gfile.Exists(filepath):
    return filepath
  if not tf.gfile.Exists(directory):
    tf.gfile.MakeDirs(directory)
  # CVDF mirror of http://yann.lecun.com/exdb/mnist/
  url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz'
  fd, zipped_filepath = tempfile.mkstemp(suffix='.gz')
  # mkstemp returns an already-open OS-level descriptor. Only the path is
  # needed here, so close it right away instead of leaking it.
  os.close(fd)
  try:
    print('Downloading %s to %s' % (url, zipped_filepath))
    urllib.request.urlretrieve(url, zipped_filepath)
    with gzip.open(zipped_filepath, 'rb') as f_in, \
        tf.gfile.Open(filepath, 'wb') as f_out:
      shutil.copyfileobj(f_in, f_out)
  finally:
    # Remove the temporary archive even if the download or decompression
    # failed, so failed attempts do not pile up in the temp directory.
    os.remove(zipped_filepath)
  return filepath
84
85
def dataset(directory, images_file, labels_file):
  """Download and parse one MNIST split into a tf.data.Dataset.

  Args:
    directory: Directory where the data files are stored or downloaded to.
    images_file: Name of the decompressed image file.
    labels_file: Name of the decompressed label file.

  Returns:
    A tf.data.Dataset zipping the image dataset with the label dataset. Each
    image is a float32 tensor of shape [784] scaled by 1/255; each label is
    a scalar int32 tensor.
  """
  images_file = download(directory, images_file)
  labels_file = download(directory, labels_file)

  check_image_file_header(images_file)
  check_labels_file_header(labels_file)

  def decode_image(image):
    # Raw record bytes -> uint8 -> float32 vector of 784 pixels.
    pixels = tf.reshape(
        tf.cast(tf.decode_raw(image, tf.uint8), tf.float32), [784])
    # Normalize from [0, 255] to [0.0, 1.0]
    return pixels / 255.0

  def decode_label(label):
    # tf.string -> [tf.uint8] -> scalar -> int32
    scalar = tf.reshape(tf.decode_raw(label, tf.uint8), [])
    return tf.to_int32(scalar)

  # One record per image (28 * 28 bytes), skipping the 16-byte header.
  raw_images = tf.data.FixedLengthRecordDataset(
      images_file, 28 * 28, header_bytes=16)
  images = raw_images.map(decode_image)

  # One record per label (a single byte), skipping the 8-byte header.
  raw_labels = tf.data.FixedLengthRecordDataset(
      labels_file, 1, header_bytes=8)
  labels = raw_labels.map(decode_label)

  return tf.data.Dataset.zip((images, labels))
112
113
def train(directory):
  """Return the MNIST training data as a tf.data.Dataset.

  Args:
    directory: Directory passed through to `dataset` for locating or
      downloading the data files.
  """
  images_name = 'train-images-idx3-ubyte'
  labels_name = 'train-labels-idx1-ubyte'
  return dataset(directory, images_name, labels_name)
118
119
def test(directory):
  """Return the MNIST test data as a tf.data.Dataset.

  Args:
    directory: Directory passed through to `dataset` for locating or
      downloading the data files.
  """
  images_name = 't10k-images-idx3-ubyte'
  labels_name = 't10k-labels-idx1-ubyte'
  return dataset(directory, images_name, labels_name)
123