# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras image dataset loading utilities."""
# pylint: disable=g-classes-have-attributes

import multiprocessing
import os

import numpy as np

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops


def index_directory(directory,
                    labels,
                    formats,
                    class_names=None,
                    shuffle=True,
                    seed=None,
                    follow_links=False):
  """Make list of all files in the subdirs of `directory`, with their labels.

  Args:
    directory: The target directory (string).
    labels: Either "inferred"
      (labels are generated from the directory structure),
      None (no labels),
      or a list/tuple of integer labels of the same size as the number of
      valid files found in the directory. Labels should be sorted according
      to the alphanumeric order of the image file paths
      (obtained via `os.walk(directory)` in Python).
    formats: Allowlist of file extensions to index (e.g. ".jpg", ".txt").
    class_names: Only valid if "labels" is "inferred". This is the explicit
      list of class names (must match names of subdirectories). Used
      to control the order of the classes
      (otherwise alphanumerical order is used).
    shuffle: Whether to shuffle the data. Default: True.
      If set to False, sorts the data in alphanumeric order.
    seed: Optional random seed for shuffling.
    follow_links: Whether to visit subdirectories pointed to by symlinks.

  Returns:
    tuple (file_paths, labels, class_names).
      file_paths: list of file paths (strings).
      labels: list of matching integer labels (same length as file_paths)
      class_names: names of the classes corresponding to these labels, in
        order.
  """
  # `labels` is reassigned below (to an int32 array in the inferred/no-label
  # case), so remember up front whether the caller asked for no labels.
  # Without this, the "Found %d files." message below was unreachable and the
  # no-label case misleadingly printed "belonging to 1 classes".
  no_labels = labels is None
  if no_labels:
    # In the no-label case, index from the parent directory down.
    subdirs = ['']
    class_names = subdirs
  else:
    subdirs = []
    for subdir in sorted(os.listdir(directory)):
      if os.path.isdir(os.path.join(directory, subdir)):
        subdirs.append(subdir)
    if not class_names:
      class_names = subdirs
    else:
      if set(class_names) != set(subdirs):
        raise ValueError(
            'The `class_names` passed did not match the '
            'names of the subdirectories of the target directory. '
            'Expected: %s, but received: %s' %
            (subdirs, class_names))
  class_indices = dict(zip(class_names, range(len(class_names))))

  # Build an index of the files in the different class subfolders, using one
  # thread-pool task per subdirectory. The pool is shut down in a `finally`
  # block so worker threads are not leaked when indexing raises (e.g. on a
  # label/file count mismatch below).
  pool = multiprocessing.pool.ThreadPool()
  try:
    results = []
    filenames = []
    for dirpath in (os.path.join(directory, subdir) for subdir in subdirs):
      results.append(
          pool.apply_async(index_subdirectory,
                           (dirpath, class_indices, follow_links, formats)))
    labels_list = []
    for res in results:
      partial_filenames, partial_labels = res.get()
      labels_list.append(partial_labels)
      filenames += partial_filenames
    if labels not in ('inferred', None):
      # User-provided labels: they must line up one-to-one with the files.
      if len(labels) != len(filenames):
        raise ValueError('Expected the lengths of `labels` to match the number '
                         'of files in the target directory. len(labels) is %s '
                         'while we found %s files in %s.' % (
                             len(labels), len(filenames), directory))
    else:
      # Inferred (or absent) labels: concatenate the per-subdirectory label
      # lists in the same order the filenames were concatenated above.
      i = 0
      labels = np.zeros((len(filenames),), dtype='int32')
      for partial_labels in labels_list:
        labels[i:i + len(partial_labels)] = partial_labels
        i += len(partial_labels)

    if no_labels:
      print('Found %d files.' % (len(filenames),))
    else:
      print('Found %d files belonging to %d classes.' %
            (len(filenames), len(class_names)))
  finally:
    pool.close()
    pool.join()
  file_paths = [os.path.join(directory, fname) for fname in filenames]

  if shuffle:
    # Shuffle globally to erase macro-structure. The same seeded RNG state is
    # used for both lists so paths and labels stay aligned after shuffling.
    if seed is None:
      seed = np.random.randint(1e6)
    rng = np.random.RandomState(seed)
    rng.shuffle(file_paths)
    rng = np.random.RandomState(seed)
    rng.shuffle(labels)
  return file_paths, labels, class_names


def iter_valid_files(directory, follow_links, formats):
  """Yield `(root, filename)` pairs for files under `directory`.

  Walks `directory` (optionally following symlinks) in sorted order, yielding
  only files whose lowercased name ends with one of `formats`.

  Args:
    directory: string, directory to walk.
    follow_links: boolean, whether `os.walk` follows symlinked directories.
    formats: tuple of allowed file extensions (e.g. ('.jpg', '.png')).

  Yields:
    tuple `(root, fname)` where `root` is the containing directory and
    `fname` is the file's base name.
  """
  walk = os.walk(directory, followlinks=follow_links)
  # Sort directories by path, and files within each directory by name, so the
  # iteration order is deterministic across runs and platforms.
  for root, _, files in sorted(walk, key=lambda x: x[0]):
    for fname in sorted(files):
      if fname.lower().endswith(formats):
        yield root, fname


def index_subdirectory(directory, class_indices, follow_links, formats):
  """Recursively walks directory and list image paths and their class index.

  Args:
    directory: string, target directory.
    class_indices: dict mapping class names to their index.
    follow_links: boolean, whether to recursively follow subdirectories
      (if False, we only list top-level images in `directory`).
    formats: Allowlist of file extensions to index (e.g. ".jpg", ".txt").

  Returns:
    tuple `(filenames, labels)`. `filenames` is a list of relative file
    paths, and `labels` is a list of integer labels corresponding to these
    files.
  """
  dirname = os.path.basename(directory)
  valid_files = iter_valid_files(directory, follow_links, formats)
  labels = []
  filenames = []
  for root, fname in valid_files:
    # All files under this subdirectory share its class index.
    labels.append(class_indices[dirname])
    absolute_path = os.path.join(root, fname)
    # Paths are returned relative to the parent of `directory`, i.e. they
    # keep the class-subdirectory prefix (e.g. "cats/img1.jpg").
    relative_path = os.path.join(
        dirname, os.path.relpath(absolute_path, directory))
    filenames.append(relative_path)
  return filenames, labels


def get_training_or_validation_split(samples, labels, validation_split, subset):
  """Potentially restrict samples & labels to a training or validation split.

  Args:
    samples: List of elements.
    labels: List of corresponding labels.
    validation_split: Float, fraction of data to reserve for validation.
    subset: Subset of the data to return.
      Either "training", "validation", or None. If None, we return all of the
      data.

  Returns:
    tuple (samples, labels), potentially restricted to the specified subset.
  """
  if not validation_split:
    return samples, labels

  num_val_samples = int(validation_split * len(samples))
  # Use an explicit split index rather than negative slicing: when
  # `num_val_samples` is 0 (tiny datasets / small split fractions),
  # `samples[:-0]` would be EMPTY and `samples[-0:]` would be the FULL list,
  # which is the opposite of the intended split.
  split_index = len(samples) - num_val_samples
  if subset == 'training':
    print('Using %d files for training.' % (len(samples) - num_val_samples,))
    samples = samples[:split_index]
    labels = labels[:split_index]
  elif subset == 'validation':
    print('Using %d files for validation.' % (num_val_samples,))
    samples = samples[split_index:]
    labels = labels[split_index:]
  else:
    raise ValueError('`subset` must be either "training" '
                     'or "validation", received: %s' % (subset,))
  return samples, labels


def labels_to_dataset(labels, label_mode, num_classes):
  """Create a tf.data.Dataset from the list/tuple of labels.

  Args:
    labels: list/tuple of labels to be converted into a tf.data.Dataset.
    label_mode:
      - 'binary' indicates that the labels (there can be only 2) are encoded as
        `float32` scalars with values 0 or 1 (e.g. for `binary_crossentropy`).
      - 'categorical' means that the labels are mapped into a categorical
        vector. (e.g. for `categorical_crossentropy` loss).
    num_classes: number of classes of labels. Only used when `label_mode` is
      'categorical' (one-hot depth).

  Returns:
    A `tf.data.Dataset` of the labels. For any `label_mode` other than
    'binary' or 'categorical', labels are passed through unchanged (integer
    labels, e.g. for `sparse_categorical_crossentropy`).
  """
  label_ds = dataset_ops.Dataset.from_tensor_slices(labels)
  if label_mode == 'binary':
    # Cast to float32 and add a trailing axis of size 1: shape () -> (1,).
    label_ds = label_ds.map(
        lambda x: array_ops.expand_dims(math_ops.cast(x, 'float32'), axis=-1))
  elif label_mode == 'categorical':
    label_ds = label_ds.map(lambda x: array_ops.one_hot(x, num_classes))
  return label_ds


def check_validation_split_arg(validation_split, subset, shuffle, seed):
  """Raise errors in case of invalid argument values.

  Args:
    validation_split: float between 0 and 1, fraction of data to reserve for
      validation.
    subset: One of "training" or "validation". Only used if `validation_split`
      is set.
    shuffle: Whether to shuffle the data. Either True or False.
    seed: random seed for shuffling and transformations.

  Raises:
    ValueError: if `validation_split` is out of range, if `subset` and
      `validation_split` are not set together, if `subset` has an invalid
      value, or if shuffling a split without a fixed `seed` (which would let
      the training and validation subsets overlap).
  """
  if validation_split and not 0 < validation_split < 1:
    raise ValueError(
        '`validation_split` must be between 0 and 1, received: %s' %
        (validation_split,))
  if (validation_split or subset) and not (validation_split and subset):
    raise ValueError(
        'If `subset` is set, `validation_split` must be set, and inversely.')
  if subset not in ('training', 'validation', None):
    raise ValueError('`subset` must be either "training" '
                     'or "validation", received: %s' % (subset,))
  if validation_split and shuffle and seed is None:
    raise ValueError(
        'If using `validation_split` and shuffling the data, you must provide '
        'a `seed` argument, to make sure that there is no overlap between the '
        'training and validation subset.')