• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Keras image dataset loading utilities."""
16# pylint: disable=g-classes-have-attributes
17
18import multiprocessing
19import os
20
21import numpy as np
22
23from tensorflow.python.data.ops import dataset_ops
24from tensorflow.python.ops import array_ops
25from tensorflow.python.ops import math_ops
26
27
def index_directory(directory,
                    labels,
                    formats,
                    class_names=None,
                    shuffle=True,
                    seed=None,
                    follow_links=False):
  """Make list of all files in the subdirs of `directory`, with their labels.

  Args:
    directory: The target directory (string).
    labels: Either "inferred"
        (labels are generated from the directory structure),
        None (no labels),
        or a list/tuple of integer labels of the same size as the number of
        valid files found in the directory. Labels should be sorted according
        to the alphanumeric order of the image file paths
        (obtained via `os.walk(directory)` in Python).
    formats: Allowlist of file extensions to index (e.g. ".jpg", ".txt").
    class_names: Only used when `labels` is not None. This is the explicit
        list of class names (must match names of subdirectories). Used
        to control the order of the classes
        (otherwise alphanumerical order is used).
    shuffle: Whether to shuffle the data. Default: True.
        If set to False, sorts the data in alphanumeric order.
    seed: Optional random seed for shuffling.
    follow_links: Whether to visit subdirectories pointed to by symlinks.

  Returns:
    tuple (file_paths, labels, class_names).
      file_paths: list of file paths (strings).
      labels: matching integer labels (same length as file_paths). This is an
        int32 numpy array when `labels` was "inferred" or None, and a new
        list holding the caller's values when explicit labels were passed.
      class_names: names of the classes corresponding to these labels, in
        order. When `labels` is None this is `['']`.

  Raises:
    ValueError: if `class_names` does not match the subdirectories of
      `directory`, or if explicit `labels` do not have one entry per file
      found.
  """
  # Imported here because `import multiprocessing` on its own does not load
  # the `multiprocessing.pool` submodule.
  from multiprocessing.pool import ThreadPool  # pylint: disable=g-import-not-at-top

  # Remember how labels were requested: `labels` itself is rebound below.
  no_labels = labels is None
  # The isinstance guard avoids comparing an array-like `labels` with a string.
  labels_inferred = isinstance(labels, str) and labels == 'inferred'
  explicit_labels = not (no_labels or labels_inferred)

  if no_labels:
    # in the no-label case, index from the parent directory down.
    subdirs = ['']
    class_names = subdirs
  else:
    subdirs = [
        entry for entry in sorted(os.listdir(directory))
        if os.path.isdir(os.path.join(directory, entry))
    ]
    if not class_names:
      class_names = subdirs
    elif set(class_names) != set(subdirs):
      raise ValueError(
          'The `class_names` passed did not match the '
          'names of the subdirectories of the target directory. '
          'Expected: %s, but received: %s' %
          (subdirs, class_names))
  class_indices = {name: index for index, name in enumerate(class_names)}

  # Build an index of the files in the different class subfolders,
  # one task per subfolder.
  filenames = []
  labels_list = []
  pool = ThreadPool()
  try:
    pending = [
        pool.apply_async(
            index_subdirectory,
            (os.path.join(directory, subdir), class_indices, follow_links,
             formats))
        for subdir in subdirs
    ]
    # Collect in submission order so the result follows `subdirs` order.
    for task in pending:
      partial_filenames, partial_labels = task.get()
      labels_list.append(partial_labels)
      filenames += partial_filenames
  finally:
    # Release the worker threads even if a task raised.
    pool.close()
    pool.join()

  if explicit_labels:
    if len(labels) != len(filenames):
      raise ValueError('Expected the lengths of `labels` to match the number '
                       'of files in the target directory. len(labels) is %s '
                       'while we found %s files in %s.' % (
                           len(labels), len(filenames), directory))
    # Work on a copy: shuffling below happens in place, which would otherwise
    # mutate the caller's list and fail outright on a tuple.
    labels = list(labels)
  else:
    labels = np.array(
        [label for partial_labels in labels_list for label in partial_labels],
        dtype='int32')

  if no_labels:
    print('Found %d files.' % (len(filenames),))
  else:
    print('Found %d files belonging to %d classes.' %
          (len(filenames), len(class_names)))
  file_paths = [os.path.join(directory, fname) for fname in filenames]

  if shuffle:
    # Shuffle globally to erase macro-structure
    if seed is None:
      seed = np.random.randint(int(1e6))
    # Two generators seeded identically apply the same permutation to both
    # sequences, keeping paths and labels aligned.
    rng = np.random.RandomState(seed)
    rng.shuffle(file_paths)
    rng = np.random.RandomState(seed)
    rng.shuffle(labels)
  return file_paths, labels, class_names
128
129
def iter_valid_files(directory, follow_links, formats):
  """Yield `(root, fname)` for each file under `directory` with an allowed extension.

  Directories are visited in sorted order of their path, and the files of
  each directory in sorted order of their name. The extension check is done
  on the lower-cased file name.

  Args:
    directory: string, directory to walk recursively.
    follow_links: boolean, passed to `os.walk` as `followlinks`.
    formats: Allowlist of file extensions (e.g. ".jpg", ".txt"). May be a
      single string or any iterable of strings.

  Yields:
    tuple `(root, fname)`: the directory containing the file and the file
      name.
  """
  # `str.endswith` only accepts a str or a tuple of str, so normalize lists
  # and other iterables.
  if not isinstance(formats, str):
    formats = tuple(formats)
  walk = os.walk(directory, followlinks=follow_links)
  for root, _, files in sorted(walk, key=lambda x: x[0]):
    for fname in sorted(files):
      if fname.lower().endswith(formats):
        yield root, fname
136
137
def index_subdirectory(directory, class_indices, follow_links, formats):
  """Lists the files of one class directory along with their class index.

  Args:
    directory: string, path of the class directory to index. Its base name is
      used as the key into `class_indices`.
    class_indices: dict mapping class names to their index.
    follow_links: boolean, forwarded to `iter_valid_files`, which passes it
      to `os.walk` as `followlinks`.
    formats: Allowlist of file extensions to index (e.g. ".jpg", ".txt").

  Returns:
    tuple `(filenames, labels)`. `filenames` is a list of file paths made of
      the base name of `directory` joined with the path of the file relative
      to `directory`, and `labels` is a list of integer labels corresponding
      to these files.
  """
  class_dir = os.path.basename(directory)
  filenames = []
  labels = []
  for parent, name in iter_valid_files(directory, follow_links, formats):
    # The lookup stays inside the loop so that a missing key only raises
    # when at least one file was found.
    labels.append(class_indices[class_dir])
    path_in_class = os.path.relpath(os.path.join(parent, name), directory)
    filenames.append(os.path.join(class_dir, path_in_class))
  return filenames, labels
164
165
def get_training_or_validation_split(samples, labels, validation_split, subset):
  """Potentially restrict samples & labels to a training or validation split.

  The validation subset is taken from the end of the data and the training
  subset is everything before it.

  Args:
    samples: List of elements.
    labels: List of corresponding labels.
    validation_split: Float, fraction of data to reserve for validation.
      If falsy, the data is returned unchanged.
    subset: Subset of the data to return.
      Either "training" or "validation". Only consulted when
      `validation_split` is set.

  Returns:
    tuple (samples, labels), potentially restricted to the specified subset.

  Raises:
    ValueError: if `validation_split` is set and `subset` is neither
      "training" nor "validation".
  """
  if not validation_split:
    return samples, labels

  num_val_samples = int(validation_split * len(samples))
  # Slice at an explicit index instead of using `-num_val_samples`: when the
  # validation count rounds down to 0, `x[:-0]` would be empty and `x[-0:]`
  # would be the whole sequence, swapping the two subsets.
  split_index = max(len(samples) - num_val_samples, 0)
  if subset == 'training':
    print('Using %d files for training.' % (len(samples) - num_val_samples,))
    samples = samples[:split_index]
    labels = labels[:split_index]
  elif subset == 'validation':
    print('Using %d files for validation.' % (num_val_samples,))
    samples = samples[split_index:]
    labels = labels[split_index:]
  else:
    raise ValueError('`subset` must be either "training" '
                     'or "validation", received: %s' % (subset,))
  return samples, labels
196
197
def labels_to_dataset(labels, label_mode, num_classes):
  """Builds a dataset of labels, encoded according to `label_mode`.

  Args:
    labels: list/tuple of labels to be converted into a tf.data.Dataset.
    label_mode: how each label is encoded.
    - 'binary': each label is cast to `float32` and given a trailing axis of
      size 1 (e.g. for `binary_crossentropy`).
    - 'categorical': each label is one-hot encoded over `num_classes`
      (e.g. for `categorical_crossentropy` loss).
    - any other value: labels are left as they are.
    num_classes: number of classes of labels. Only used in 'categorical'
      mode.

  Returns:
    The dataset of (possibly re-encoded) labels.
  """

  def _as_binary(label):
    return array_ops.expand_dims(math_ops.cast(label, 'float32'), axis=-1)

  def _as_one_hot(label):
    return array_ops.one_hot(label, num_classes)

  dataset = dataset_ops.Dataset.from_tensor_slices(labels)
  if label_mode == 'binary':
    return dataset.map(_as_binary)
  if label_mode == 'categorical':
    return dataset.map(_as_one_hot)
  return dataset
217
218
def check_validation_split_arg(validation_split, subset, shuffle, seed):
  """Validates the arguments that control the training/validation split.

  The checks run in the order listed under "Raises"; the first failing one
  raises.

  Args:
    validation_split: float between 0 and 1, fraction of data to reserve for
      validation. A falsy value means no split.
    subset: One of "training" or "validation". Only used if `validation_split`
      is set.
    shuffle: Whether to shuffle the data. Either True or False.
    seed: random seed for shuffling and transformations.

  Raises:
    ValueError: if `validation_split` is set but not strictly between 0 and
      1; if exactly one of `validation_split` and `subset` is set; if
      `subset` is not "training", "validation" or None; or if
      `validation_split` and `shuffle` are both set while `seed` is None.
  """
  split_requested = bool(validation_split)
  subset_requested = bool(subset)

  if split_requested:
    within_bounds = 0 < validation_split < 1
    if not within_bounds:
      raise ValueError(
          '`validation_split` must be between 0 and 1, received: %s' %
          (validation_split,))

  # The two arguments only make sense together.
  if split_requested != subset_requested:
    raise ValueError(
        'If `subset` is set, `validation_split` must be set, and inversely.')

  allowed_subsets = ('training', 'validation', None)
  if subset not in allowed_subsets:
    raise ValueError('`subset` must be either "training" '
                     'or "validation", received: %s' % (subset,))

  if split_requested and shuffle and seed is None:
    raise ValueError(
        'If using `validation_split` and shuffling the data, you must provide '
        'a `seed` argument, to make sure that there is no overlap between the '
        'training and validation subset.')
245