1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""tf.data.Dataset interface to the MNIST dataset. 16 17 This is cloned from 18 https://github.com/tensorflow/models/blob/master/official/r1/mnist/dataset.py 19""" 20 21from __future__ import absolute_import 22from __future__ import division 23from __future__ import print_function 24 25import gzip 26import os 27import shutil 28import tempfile 29 30import numpy as np 31from six.moves import urllib 32import tensorflow as tf 33 34 35def read32(bytestream): 36 """Read 4 bytes from bytestream as an unsigned 32-bit integer.""" 37 dt = np.dtype(np.uint32).newbyteorder('>') 38 return np.frombuffer(bytestream.read(4), dtype=dt)[0] 39 40 41def check_image_file_header(filename): 42 """Validate that filename corresponds to images for the MNIST dataset.""" 43 with tf.gfile.Open(filename, 'rb') as f: 44 magic = read32(f) 45 read32(f) # num_images, unused 46 rows = read32(f) 47 cols = read32(f) 48 if magic != 2051: 49 raise ValueError('Invalid magic number %d in MNIST file %s' % (magic, 50 f.name)) 51 if rows != 28 or cols != 28: 52 raise ValueError( 53 'Invalid MNIST file %s: Expected 28x28 images, found %dx%d' % 54 (f.name, rows, cols)) 55 56 57def check_labels_file_header(filename): 58 """Validate that filename corresponds to labels for the MNIST dataset.""" 59 with tf.gfile.Open(filename, 'rb') as f: 60 magic = read32(f) 61 read32(f) # num_items, unused 62 if magic != 2049: 63 raise ValueError('Invalid magic number %d in MNIST file %s' % (magic, 64 f.name)) 65 66 67def download(directory, filename): 68 """Download (and unzip) a file from the MNIST dataset if not already done.""" 69 filepath = os.path.join(directory, filename) 70 if tf.gfile.Exists(filepath): 71 return filepath 72 if not tf.gfile.Exists(directory): 73 tf.gfile.MakeDirs(directory) 74 # CVDF mirror of http://yann.lecun.com/exdb/mnist/ 75 url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz' 76 _, zipped_filepath = tempfile.mkstemp(suffix='.gz') 77 print('Downloading %s to %s' % (url, zipped_filepath)) 78 urllib.request.urlretrieve(url, zipped_filepath) 79 with gzip.open(zipped_filepath, 'rb') as f_in, \ 80 tf.gfile.Open(filepath, 'wb') as f_out: 81 shutil.copyfileobj(f_in, f_out) 82 os.remove(zipped_filepath) 83 return filepath 84 85 86def dataset(directory, images_file, labels_file): 87 """Download and parse MNIST dataset.""" 88 89 images_file = download(directory, images_file) 90 labels_file = download(directory, labels_file) 91 92 check_image_file_header(images_file) 93 check_labels_file_header(labels_file) 94 95 def decode_image(image): 96 # Normalize from [0, 255] to [0.0, 1.0] 97 image = tf.decode_raw(image, tf.uint8) 98 image = tf.cast(image, tf.float32) 99 image = tf.reshape(image, [784]) 100 return image / 255.0 101 102 def decode_label(label): 103 label = tf.decode_raw(label, tf.uint8) # tf.string -> [tf.uint8] 104 label = tf.reshape(label, []) # label is a scalar 105 return tf.to_int32(label) 106 107 images = tf.data.FixedLengthRecordDataset( 108 images_file, 28 * 28, header_bytes=16).map(decode_image) 109 labels = tf.data.FixedLengthRecordDataset( 110 labels_file, 1, header_bytes=8).map(decode_label) 111 return tf.data.Dataset.zip((images, labels)) 112 113 114def train(directory): 115 """tf.data.Dataset object for MNIST training data.""" 116 return dataset(directory, 'train-images-idx3-ubyte', 117 'train-labels-idx1-ubyte') 118 119 120def test(directory): 121 """tf.data.Dataset object for MNIST test data.""" 122 return dataset(directory, 't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte') 123