• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Keras image dataset loading utilities."""
16# pylint: disable=g-classes-have-attributes
17
18import numpy as np
19
20from tensorflow.python.data.ops import dataset_ops
21from tensorflow.python.keras.layers.preprocessing import image_preprocessing
22from tensorflow.python.keras.preprocessing import dataset_utils
23from tensorflow.python.keras.preprocessing import image as keras_image_ops
24from tensorflow.python.ops import image_ops
25from tensorflow.python.ops import io_ops
26from tensorflow.python.util.tf_export import keras_export
27
28
# File extensions passed as `formats` to `dataset_utils.index_directory` when
# `image_dataset_from_directory` indexes the target directory.
ALLOWLIST_FORMATS = ('.bmp', '.gif', '.jpeg', '.jpg', '.png')
30
31
@keras_export('keras.utils.image_dataset_from_directory',
              'keras.preprocessing.image_dataset_from_directory',
              v1=[])
def image_dataset_from_directory(directory,
                                 labels='inferred',
                                 label_mode='int',
                                 class_names=None,
                                 color_mode='rgb',
                                 batch_size=32,
                                 image_size=(256, 256),
                                 shuffle=True,
                                 seed=None,
                                 validation_split=None,
                                 subset=None,
                                 interpolation='bilinear',
                                 follow_links=False,
                                 crop_to_aspect_ratio=False,
                                 **kwargs):
  """Generates a `tf.data.Dataset` from image files in a directory.

  If your directory structure is:

  ```
  main_directory/
  ...class_a/
  ......a_image_1.jpg
  ......a_image_2.jpg
  ...class_b/
  ......b_image_1.jpg
  ......b_image_2.jpg
  ```

  Then calling `image_dataset_from_directory(main_directory, labels='inferred')`
  will return a `tf.data.Dataset` that yields batches of images from
  the subdirectories `class_a` and `class_b`, together with labels
  0 and 1 (0 corresponding to `class_a` and 1 corresponding to `class_b`).

  Supported image formats: jpeg, png, bmp, gif.
  Animated gifs are truncated to the first frame.

  Args:
    directory: Directory where the data is located.
        If `labels` is "inferred", it should contain
        subdirectories, each containing images for a class.
        Otherwise, the directory structure is ignored.
    labels: Either "inferred"
        (labels are generated from the directory structure),
        None (no labels),
        or a list/tuple of integer labels of the same size as the number of
        image files found in the directory. Labels should be sorted according
        to the alphanumeric order of the image file paths
        (obtained via `os.walk(directory)` in Python).
    label_mode:
        - 'int': means that the labels are encoded as integers
            (e.g. for `sparse_categorical_crossentropy` loss).
        - 'categorical' means that the labels are
            encoded as a categorical vector
            (e.g. for `categorical_crossentropy` loss).
        - 'binary' means that the labels (there can be only 2)
            are encoded as `float32` scalars with values 0 or 1
            (e.g. for `binary_crossentropy`).
        - None (no labels).
    class_names: Only valid if "labels" is "inferred". This is the explicit
        list of class names (must match names of subdirectories). Used
        to control the order of the classes
        (otherwise alphanumerical order is used).
    color_mode: One of "grayscale", "rgb", "rgba". Default: "rgb".
        Whether the images will be converted to
        have 1, 3, or 4 channels.
    batch_size: Size of the batches of data. Default: 32.
    image_size: Size to resize images to after they are read from disk.
        Defaults to `(256, 256)`.
        Since the pipeline processes batches of images that must all have
        the same size, this must be provided.
    shuffle: Whether to shuffle the data. Default: True.
        If set to False, sorts the data in alphanumeric order.
    seed: Optional random seed for shuffling and transformations.
        If None, a random seed is drawn with `np.random.randint`.
    validation_split: Optional float between 0 and 1,
        fraction of data to reserve for validation.
    subset: One of "training" or "validation".
        Only used if `validation_split` is set.
    interpolation: String, the interpolation method used when resizing images.
      Defaults to `bilinear`. Supports `bilinear`, `nearest`, `bicubic`,
      `area`, `lanczos3`, `lanczos5`, `gaussian`, `mitchellcubic`.
    follow_links: Whether to visit subdirectories pointed to by symlinks.
        Defaults to False.
    crop_to_aspect_ratio: If True, resize the images without aspect
      ratio distortion. When the original aspect ratio differs from the target
      aspect ratio, the output image will be cropped so as to return the largest
      possible window in the image (of size `image_size`) that matches
      the target aspect ratio. By default (`crop_to_aspect_ratio=False`),
      aspect ratio may not be preserved.
    **kwargs: Legacy keyword arguments. Only `smart_resize` is accepted; when
      given, its value overrides `crop_to_aspect_ratio`.

  Returns:
    A `tf.data.Dataset` object.
      - If `label_mode` is None, it yields `float32` tensors of shape
        `(batch_size, image_size[0], image_size[1], num_channels)`,
        encoding images (see below for rules regarding `num_channels`).
      - Otherwise, it yields a tuple `(images, labels)`, where `images`
        has shape `(batch_size, image_size[0], image_size[1], num_channels)`,
        and `labels` follows the format described below.
    The returned dataset also carries two extra attributes: `class_names`
    (the class names used for indexing) and `file_paths` (the image paths
    kept after the training/validation split).

  Rules regarding labels format:
    - if `label_mode` is `int`, the labels are an `int32` tensor of shape
      `(batch_size,)`.
    - if `label_mode` is `binary`, the labels are a `float32` tensor of
      1s and 0s of shape `(batch_size, 1)`.
    - if `label_mode` is `categorical`, the labels are a `float32` tensor
      of shape `(batch_size, num_classes)`, representing a one-hot
      encoding of the class index.

  Rules regarding number of channels in the yielded images:
    - if `color_mode` is `grayscale`,
      there's 1 channel in the image tensors.
    - if `color_mode` is `rgb`,
      there are 3 channels in the image tensors.
    - if `color_mode` is `rgba`,
      there are 4 channels in the image tensors.

  Raises:
    TypeError: If `kwargs` contains anything other than `smart_resize`.
    ValueError: If `labels` is neither "inferred", None, nor a list/tuple;
      if `class_names` is passed together with explicit `labels`;
      if `label_mode` or `color_mode` is not one of the supported values;
      if `label_mode` is "binary" and the number of classes is not 2;
      or if no images remain after indexing and splitting.
  """
  # `smart_resize` is the legacy name of `crop_to_aspect_ratio`.
  if 'smart_resize' in kwargs:
    crop_to_aspect_ratio = kwargs.pop('smart_resize')
  if kwargs:
    raise TypeError(f'Unknown keywords argument(s): {tuple(kwargs.keys())}')
  if labels not in ('inferred', None):
    if not isinstance(labels, (list, tuple)):
      raise ValueError(
          '`labels` argument should be a list/tuple of integer labels, of '
          'the same size as the number of image files in the target '
          'directory. If you wish to infer the labels from the subdirectory '
          'names in the target directory, pass `labels="inferred"`. '
          'If you wish to get a dataset that only contains images '
          '(no labels), pass `label_mode=None`.')
    if class_names:
      raise ValueError('You can only pass `class_names` if the labels are '
                       'inferred from the subdirectory names in the target '
                       'directory (`labels="inferred"`).')
  if label_mode not in {'int', 'categorical', 'binary', None}:
    raise ValueError(
        '`label_mode` argument must be one of "int", "categorical", "binary", '
        'or None. Received: %s' % (label_mode,))
  # Disabling either of `labels` / `label_mode` disables both.
  if labels is None or label_mode is None:
    labels = None
    label_mode = None
  if color_mode == 'rgb':
    num_channels = 3
  elif color_mode == 'rgba':
    num_channels = 4
  elif color_mode == 'grayscale':
    num_channels = 1
  else:
    raise ValueError(
        '`color_mode` must be one of {"rgb", "rgba", "grayscale"}. '
        'Received: %s' % (color_mode,))
  interpolation = image_preprocessing.get_interpolation(interpolation)
  dataset_utils.check_validation_split_arg(
      validation_split, subset, shuffle, seed)

  # The same seed is used for indexing and for the per-iteration shuffle below,
  # so pick one now if the caller did not supply it.
  if seed is None:
    seed = np.random.randint(1e6)
  image_paths, labels, class_names = dataset_utils.index_directory(
      directory,
      labels,
      formats=ALLOWLIST_FORMATS,
      class_names=class_names,
      shuffle=shuffle,
      seed=seed,
      follow_links=follow_links)

  if label_mode == 'binary' and len(class_names) != 2:
    raise ValueError(
        'When passing `label_mode="binary"`, there must be exactly 2 classes. '
        'Found the following classes: %s' % (class_names,))

  image_paths, labels = dataset_utils.get_training_or_validation_split(
      image_paths, labels, validation_split, subset)
  if not image_paths:
    raise ValueError('No images found.')

  dataset = paths_and_labels_to_dataset(
      image_paths=image_paths,
      image_size=image_size,
      num_channels=num_channels,
      labels=labels,
      label_mode=label_mode,
      num_classes=len(class_names),
      interpolation=interpolation,
      crop_to_aspect_ratio=crop_to_aspect_ratio)
  if shuffle:
    # Shuffle locally at each iteration
    dataset = dataset.shuffle(buffer_size=batch_size * 8, seed=seed)
  dataset = dataset.batch(batch_size)
  # Users may need to reference `class_names`.
  dataset.class_names = class_names
  # Include file paths for images as attribute.
  dataset.file_paths = image_paths
  return dataset
229
230
def paths_and_labels_to_dataset(image_paths,
                                image_size,
                                num_channels,
                                labels,
                                label_mode,
                                num_classes,
                                interpolation,
                                crop_to_aspect_ratio=False):
  """Builds a dataset of loaded images, zipped with labels when requested.

  Args:
    image_paths: Image file paths; sliced into one dataset element per path.
    image_size: Target size forwarded to `load_image`.
    num_channels: Number of channels forwarded to `load_image`.
    labels: Labels forwarded to `dataset_utils.labels_to_dataset`; unused
      when `label_mode` is falsy.
    label_mode: Label encoding mode. When falsy, no labels are attached.
    num_classes: Number of classes forwarded to
      `dataset_utils.labels_to_dataset`.
    interpolation: Interpolation method forwarded to `load_image`.
    crop_to_aspect_ratio: Forwarded to `load_image`. Defaults to False.

  Returns:
    A dataset of images if `label_mode` is falsy, otherwise a dataset of
    `(image, label)` pairs.
  """
  # TODO(fchollet): consider making num_parallel_calls settable

  def _load(path):
    return load_image(path, image_size, num_channels, interpolation,
                      crop_to_aspect_ratio)

  images = dataset_ops.Dataset.from_tensor_slices(image_paths).map(_load)
  if not label_mode:
    return images
  targets = dataset_utils.labels_to_dataset(labels, label_mode, num_classes)
  return dataset_ops.Dataset.zip((images, targets))
249
250
def load_image(path, image_size, num_channels, interpolation,
               crop_to_aspect_ratio=False):
  """Reads the file at `path`, decodes it as an image and resizes it.

  Args:
    path: Path of the image file to read.
    image_size: Target size; indexed as `(image_size[0], image_size[1])`.
    num_channels: Number of channels requested from the decoder.
    interpolation: Interpolation method passed to the resize op.
    crop_to_aspect_ratio: If True, resize with `smart_resize`; otherwise
      resize with `resize_images_v2`. Defaults to False.

  Returns:
    The resized image, with its static shape set to
    `(image_size[0], image_size[1], num_channels)`.
  """
  encoded = io_ops.read_file(path)
  # `expand_animations=False` is passed through to the decoder unchanged.
  decoded = image_ops.decode_image(
      encoded, channels=num_channels, expand_animations=False)
  if not crop_to_aspect_ratio:
    resized = image_ops.resize_images_v2(
        decoded, image_size, method=interpolation)
  else:
    resized = keras_image_ops.smart_resize(
        decoded, image_size, interpolation=interpolation)
  # Pin the static shape on the resized tensor.
  static_shape = (image_size[0], image_size[1], num_channels)
  resized.set_shape(static_shape)
  return resized
264