# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras image dataset loading utilities."""
# pylint: disable=g-classes-have-attributes

import numpy as np

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.keras.layers.preprocessing import image_preprocessing
from tensorflow.python.keras.preprocessing import dataset_utils
from tensorflow.python.keras.preprocessing import image as keras_image_ops
from tensorflow.python.ops import image_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.util.tf_export import keras_export


# File extensions passed to `dataset_utils.index_directory` as the image
# formats to collect.
ALLOWLIST_FORMATS = ('.bmp', '.gif', '.jpeg', '.jpg', '.png')


@keras_export('keras.utils.image_dataset_from_directory',
              'keras.preprocessing.image_dataset_from_directory',
              v1=[])
def image_dataset_from_directory(directory,
                                 labels='inferred',
                                 label_mode='int',
                                 class_names=None,
                                 color_mode='rgb',
                                 batch_size=32,
                                 image_size=(256, 256),
                                 shuffle=True,
                                 seed=None,
                                 validation_split=None,
                                 subset=None,
                                 interpolation='bilinear',
                                 follow_links=False,
                                 crop_to_aspect_ratio=False,
                                 **kwargs):
  """Generates a `tf.data.Dataset` from image files in a directory.

  If your directory structure is:

  ```
  main_directory/
  ...class_a/
  ......a_image_1.jpg
  ......a_image_2.jpg
  ...class_b/
  ......b_image_1.jpg
  ......b_image_2.jpg
  ```

  Then calling `image_dataset_from_directory(main_directory, labels='inferred')`
  will return a `tf.data.Dataset` that yields batches of images from
  the subdirectories `class_a` and `class_b`, together with labels
  0 and 1 (0 corresponding to `class_a` and 1 corresponding to `class_b`).

  Supported image formats: jpeg, png, bmp, gif.
  Animated gifs are truncated to the first frame.

  Args:
    directory: Directory where the data is located.
        If `labels` is "inferred", it should contain
        subdirectories, each containing images for a class.
        Otherwise, the directory structure is ignored.
    labels: Either "inferred"
        (labels are generated from the directory structure),
        None (no labels),
        or a list/tuple of integer labels of the same size as the number of
        image files found in the directory. Labels should be sorted according
        to the alphanumeric order of the image file paths
        (obtained via `os.walk(directory)` in Python).
    label_mode:
        - 'int': means that the labels are encoded as integers
            (e.g. for `sparse_categorical_crossentropy` loss).
        - 'categorical' means that the labels are
            encoded as a categorical vector
            (e.g. for `categorical_crossentropy` loss).
        - 'binary' means that the labels (there can be only 2)
            are encoded as `float32` scalars with values 0 or 1
            (e.g. for `binary_crossentropy`).
        - None (no labels).
    class_names: Only valid if "labels" is "inferred". This is the explicit
        list of class names (must match names of subdirectories). Used
        to control the order of the classes
        (otherwise alphanumerical order is used).
    color_mode: One of "grayscale", "rgb", "rgba". Default: "rgb".
        Whether the images will be converted to
        have 1, 3, or 4 channels.
    batch_size: Size of the batches of data. Default: 32.
    image_size: Size to resize images to after they are read from disk.
        Defaults to `(256, 256)`.
        Since the pipeline processes batches of images that must all have
        the same size, this must be provided.
    shuffle: Whether to shuffle the data. Default: True.
        If set to False, sorts the data in alphanumeric order.
    seed: Optional random seed for shuffling and transformations.
    validation_split: Optional float between 0 and 1,
        fraction of data to reserve for validation.
    subset: One of "training" or "validation".
        Only used if `validation_split` is set.
    interpolation: String, the interpolation method used when resizing images.
      Defaults to `bilinear`. Supports `bilinear`, `nearest`, `bicubic`,
      `area`, `lanczos3`, `lanczos5`, `gaussian`, `mitchellcubic`.
    follow_links: Whether to visit subdirectories pointed to by symlinks.
        Defaults to False.
    crop_to_aspect_ratio: If True, resize the images without aspect
      ratio distortion. When the original aspect ratio differs from the target
      aspect ratio, the output image will be cropped so as to return the largest
      possible window in the image (of size `image_size`) that matches
      the target aspect ratio. By default (`crop_to_aspect_ratio=False`),
      aspect ratio may not be preserved.
    **kwargs: Legacy keyword arguments. Only `smart_resize` is accepted, as a
      legacy alias for `crop_to_aspect_ratio` (it takes precedence when
      given); any other keyword raises `TypeError`.

  Returns:
    A `tf.data.Dataset` object.
      - If `label_mode` is None, it yields `float32` tensors of shape
        `(batch_size, image_size[0], image_size[1], num_channels)`,
        encoding images (see below for rules regarding `num_channels`).
      - Otherwise, it yields a tuple `(images, labels)`, where `images`
        has shape `(batch_size, image_size[0], image_size[1], num_channels)`,
        and `labels` follows the format described below.
    The returned dataset also carries two Python attributes:
    `class_names` (the class names in label order) and `file_paths`
    (the image paths in the order they are fed into the dataset).

  Rules regarding labels format:
    - if `label_mode` is `int`, the labels are an `int32` tensor of shape
      `(batch_size,)`.
    - if `label_mode` is `binary`, the labels are a `float32` tensor of
      1s and 0s of shape `(batch_size, 1)`.
    - if `label_mode` is `categorical`, the labels are a `float32` tensor
      of shape `(batch_size, num_classes)`, representing a one-hot
      encoding of the class index.

  Rules regarding number of channels in the yielded images:
    - if `color_mode` is `grayscale`,
      there's 1 channel in the image tensors.
    - if `color_mode` is `rgb`,
      there are 3 channels in the image tensors.
    - if `color_mode` is `rgba`,
      there are 4 channels in the image tensors.

  Raises:
    TypeError: If an unknown keyword argument is passed.
    ValueError: If `labels`, `label_mode`, `class_names` or `color_mode` is
      invalid, if `label_mode="binary"` is used with a number of classes
      other than 2, or if no images are found.
  """
  if 'smart_resize' in kwargs:
    # Legacy spelling of `crop_to_aspect_ratio`.
    crop_to_aspect_ratio = kwargs.pop('smart_resize')
  if kwargs:
    raise TypeError(f'Unknown keywords argument(s): {tuple(kwargs.keys())}')

  # Compare against the two sentinel values by identity/type rather than
  # with `in (...)`: the tuple-membership test falls back to `==`, which is
  # elementwise (and then ambiguous) for array-like inputs. Those inputs
  # should reach the descriptive error below instead.
  explicit_labels = labels is not None and not (
      isinstance(labels, str) and labels == 'inferred')
  if explicit_labels:
    if not isinstance(labels, (list, tuple)):
      raise ValueError(
          '`labels` argument should be a list/tuple of integer labels, of '
          'the same size as the number of image files in the target '
          'directory. If you wish to infer the labels from the subdirectory '
          'names in the target directory, pass `labels="inferred"`. '
          'If you wish to get a dataset that only contains images '
          '(no labels), pass `label_mode=None`.')
    if class_names:
      raise ValueError('You can only pass `class_names` if the labels are '
                       'inferred from the subdirectory names in the target '
                       'directory (`labels="inferred"`).')
  if label_mode not in {'int', 'categorical', 'binary', None}:
    raise ValueError(
        '`label_mode` argument must be one of "int", "categorical", "binary", '
        f'or None. Received: {label_mode}')
  # Disabling labels through either argument disables them through both.
  if labels is None or label_mode is None:
    labels = None
    label_mode = None

  if color_mode == 'rgb':
    num_channels = 3
  elif color_mode == 'rgba':
    num_channels = 4
  elif color_mode == 'grayscale':
    num_channels = 1
  else:
    raise ValueError(
        '`color_mode` must be one of {"rgb", "rgba", "grayscale"}. '
        f'Received: {color_mode}')

  interpolation = image_preprocessing.get_interpolation(interpolation)
  dataset_utils.check_validation_split_arg(
      validation_split, subset, shuffle, seed)

  # The same seed is handed to `index_directory` and to the dataset shuffle
  # below, so draw one here when the caller did not supply it.
  if seed is None:
    seed = np.random.randint(1e6)
  image_paths, labels, class_names = dataset_utils.index_directory(
      directory,
      labels,
      formats=ALLOWLIST_FORMATS,
      class_names=class_names,
      shuffle=shuffle,
      seed=seed,
      follow_links=follow_links)

  if label_mode == 'binary' and len(class_names) != 2:
    raise ValueError(
        'When passing `label_mode="binary"`, there must be exactly 2 '
        f'classes. Found the following classes: {class_names}')

  image_paths, labels = dataset_utils.get_training_or_validation_split(
      image_paths, labels, validation_split, subset)
  if not image_paths:
    raise ValueError(f'No images found in directory {directory}. '
                     f'Allowed formats: {ALLOWLIST_FORMATS}')

  dataset = paths_and_labels_to_dataset(
      image_paths=image_paths,
      image_size=image_size,
      num_channels=num_channels,
      labels=labels,
      label_mode=label_mode,
      num_classes=len(class_names),
      interpolation=interpolation,
      crop_to_aspect_ratio=crop_to_aspect_ratio)
  if shuffle:
    # Shuffle locally at each iteration, with a bounded buffer of a few
    # batches' worth of elements.
    dataset = dataset.shuffle(buffer_size=batch_size * 8, seed=seed)
  dataset = dataset.batch(batch_size)
  # Users may need to reference `class_names`.
  dataset.class_names = class_names
  # Include file paths for images as attribute.
  dataset.file_paths = image_paths
  return dataset


def paths_and_labels_to_dataset(image_paths,
                                image_size,
                                num_channels,
                                labels,
                                label_mode,
                                num_classes,
                                interpolation,
                                crop_to_aspect_ratio=False):
  """Constructs a dataset of images and labels.

  Args:
    image_paths: List of image file paths.
    image_size: Target size forwarded to `load_image`; its first two entries
      become the image's static height and width.
    num_channels: Number of channels each image is decoded with.
    labels: Labels aligned with `image_paths`. Only used when `label_mode`
      is truthy.
    label_mode: Label encoding ("int", "categorical", "binary") or None.
      When falsy, the returned dataset yields images only.
    num_classes: Number of classes, forwarded to
      `dataset_utils.labels_to_dataset`.
    interpolation: Resize method forwarded to `load_image`.
    crop_to_aspect_ratio: Forwarded to `load_image`.

  Returns:
    An unbatched `tf.data.Dataset` yielding one image per element, or
    `(image, label)` tuples when `label_mode` is set.
  """
  # TODO(fchollet): consider making num_parallel_calls settable
  path_ds = dataset_ops.Dataset.from_tensor_slices(image_paths)
  args = (image_size, num_channels, interpolation, crop_to_aspect_ratio)
  img_ds = path_ds.map(
      lambda x: load_image(x, *args))
  if label_mode:
    label_ds = dataset_utils.labels_to_dataset(labels, label_mode, num_classes)
    img_ds = dataset_ops.Dataset.zip((img_ds, label_ds))
  return img_ds


def load_image(path, image_size, num_channels, interpolation,
               crop_to_aspect_ratio=False):
  """Loads an image from a path, decodes it and resizes it.

  Args:
    path: Path of the image file (a string tensor element when called from
      `paths_and_labels_to_dataset`).
    image_size: Target size; `image_size[0]` and `image_size[1]` become the
      static height and width of the result.
    num_channels: Number of channels to decode the image with.
    interpolation: Interpolation method used for resizing.
    crop_to_aspect_ratio: If True, resize with `smart_resize`, which crops to
      keep the target aspect ratio. Otherwise resize directly, which may
      distort the aspect ratio.

  Returns:
    Image tensor with static shape
    `(image_size[0], image_size[1], num_channels)`.
  """
  img = io_ops.read_file(path)
  # `expand_animations=False` keeps only the first frame of animated GIFs,
  # so every decoded image is 3-D.
  img = image_ops.decode_image(
      img, channels=num_channels, expand_animations=False)
  if crop_to_aspect_ratio:
    img = keras_image_ops.smart_resize(img, image_size,
                                       interpolation=interpolation)
  else:
    img = image_ops.resize_images_v2(img, image_size, method=interpolation)
  # Pin the static shape so that downstream batching sees a fully defined
  # element shape.
  img.set_shape((image_size[0], image_size[1], num_channels))
  return img