# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=invalid-name
# pylint: disable=g-import-not-at-top
"""Set of tools for real-time data augmentation on image data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from keras_preprocessing import image
try:
  from scipy import linalg  # pylint: disable=unused-import
  from scipy import ndimage  # pylint: disable=unused-import
except ImportError:
  pass

from tensorflow.python.keras import backend
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import keras_export

random_rotation = image.random_rotation
random_shift = image.random_shift
random_shear = image.random_shear
random_zoom = image.random_zoom
apply_channel_shift = image.apply_channel_shift
random_channel_shift = image.random_channel_shift
apply_brightness_shift = image.apply_brightness_shift
random_brightness = image.random_brightness
apply_affine_transform = image.apply_affine_transform
load_img = image.load_img


@keras_export('keras.preprocessing.image.array_to_img')
def array_to_img(x, data_format=None, scale=True, dtype=None):
  """Converts a 3D Numpy array to a PIL Image instance.

  Usage:

  >>> img = np.random.random(size=(100, 100, 3))
  >>> try:
  ...   from PIL import Image
  ...   pil_img = tf.keras.preprocessing.image.array_to_img(img)
  ... except ImportError:
  ...   pass

  Arguments:
      x: Input Numpy array.
      data_format: Image data format, can be either "channels_first" or
        "channels_last". Defaults to `None`, which gets the data format from
        the Keras backend.
      scale: Whether to rescale image values to be within `[0, 255]`. Defaults
        to `True`.
      dtype: Dtype to use. Defaults to `None`, which gets the float type from
        the Keras backend.

  Returns:
      A PIL Image instance.

  Raises:
      ImportError: if PIL is not available.
      ValueError: if invalid `x` or `data_format` is passed.
  """

  if data_format is None:
    data_format = backend.image_data_format()
  kwargs = {}
  if 'dtype' in tf_inspect.getfullargspec(image.array_to_img)[0]:
    if dtype is None:
      dtype = backend.floatx()
    kwargs['dtype'] = dtype
  return image.array_to_img(x, data_format=data_format, scale=scale, **kwargs)


@keras_export('keras.preprocessing.image.img_to_array')
def img_to_array(img, data_format=None, dtype=None):
  """Converts a PIL Image instance to a Numpy array.
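
  Usage (a minimal round-trip sketch mirroring the `array_to_img` example
  above; assumes PIL is available):

  >>> img_data = np.random.random(size=(100, 100, 3))
  >>> try:
  ...   from PIL import Image
  ...   pil_img = tf.keras.preprocessing.image.array_to_img(img_data)
  ...   array = tf.keras.preprocessing.image.img_to_array(pil_img)
  ... except ImportError:
  ...   pass
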
  Arguments:
      img: PIL Image instance.
      data_format: Image data format,
          either "channels_first" or "channels_last".
      dtype: Dtype to use for the returned array.

  Returns:
      A 3D Numpy array.

  Raises:
      ValueError: if invalid `img` or `data_format` is passed.
  """

  if data_format is None:
    data_format = backend.image_data_format()
  kwargs = {}
  if 'dtype' in tf_inspect.getfullargspec(image.img_to_array)[0]:
    if dtype is None:
      dtype = backend.floatx()
    kwargs['dtype'] = dtype
  return image.img_to_array(img, data_format=data_format, **kwargs)


@keras_export('keras.preprocessing.image.save_img')
def save_img(path,
             x,
             data_format=None,
             file_format=None,
             scale=True,
             **kwargs):
  """Saves an image stored as a Numpy array to a path or file object.
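
  Usage (a minimal sketch writing a random image to an in-memory PNG buffer,
  so nothing is persisted to disk; assumes PIL is available):

  >>> import io
  >>> buf = io.BytesIO()
  >>> img_data = np.random.random(size=(100, 100, 3))
  >>> try:
  ...   from PIL import Image
  ...   tf.keras.preprocessing.image.save_img(buf, img_data, file_format='png')
  ... except ImportError:
  ...   pass
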
  Arguments:
      path: Path or file object.
      x: Numpy array.
      data_format: Image data format,
          either "channels_first" or "channels_last".
      file_format: Optional file format override. If omitted, the
          format to use is determined from the filename extension.
          If a file object was used instead of a filename, this
          parameter should always be used.
      scale: Whether to rescale image values to be within `[0, 255]`.
      **kwargs: Additional keyword arguments passed to `PIL.Image.save()`.
  """
  if data_format is None:
    data_format = backend.image_data_format()
  image.save_img(path,
                 x,
                 data_format=data_format,
                 file_format=file_format,
                 scale=scale, **kwargs)


@keras_export('keras.preprocessing.image.Iterator')
class Iterator(image.Iterator, data_utils.Sequence):
  """Base class for image data iterators, usable as a Keras `Sequence`."""


@keras_export('keras.preprocessing.image.DirectoryIterator')
class DirectoryIterator(image.DirectoryIterator, Iterator):
  """Iterator capable of reading images from a directory on disk.
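
  A minimal construction sketch (assumes an existing `data/train` directory
  laid out as `data/train/<class_name>/<image>.jpg`; in practice these
  iterators are usually obtained via `ImageDataGenerator.flow_from_directory`
  rather than built directly):

  ```python
  datagen = ImageDataGenerator(rescale=1./255)
  iterator = DirectoryIterator(
      'data/train', datagen, target_size=(150, 150), class_mode='categorical')
  x_batch, y_batch = next(iterator)
  ```
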
  Arguments:
      directory: Path to the directory to read images from.
          Each subdirectory in this directory will be
          considered to contain images from one class,
          or alternatively you could specify class subdirectories
          via the `classes` argument.
      image_data_generator: Instance of `ImageDataGenerator`
          to use for random transformations and normalization.
      target_size: tuple of integers, dimensions to resize input images to.
      color_mode: One of `"rgb"`, `"rgba"`, `"grayscale"`.
          Color mode to read images.
      classes: Optional list of strings, names of subdirectories
          containing images from each class (e.g. `["dogs", "cats"]`).
          It will be computed automatically if not set.
      class_mode: Mode for yielding the targets:
          `"binary"`: binary targets (if there are only two classes),
          `"categorical"`: categorical targets,
          `"sparse"`: integer targets,
          `"input"`: targets are images identical to input images (mainly
              used to work with autoencoders),
          `None`: no targets get yielded (only input images are yielded).
      batch_size: Integer, size of a batch.
      shuffle: Boolean, whether to shuffle the data between epochs.
      seed: Random seed for data shuffling.
      data_format: String, one of `channels_first`, `channels_last`.
      save_to_dir: Optional directory where to save the pictures
          being yielded, in a viewable format. This is useful
          for visualizing the random transformations being
          applied, for debugging purposes.
      save_prefix: String prefix to use for saving sample
          images (if `save_to_dir` is set).
      save_format: Format to use for saving sample images
          (if `save_to_dir` is set).
      subset: Subset of data (`"training"` or `"validation"`) if
          `validation_split` is set in `ImageDataGenerator`.
      interpolation: Interpolation method used to resample the image if the
          target size is different from that of the loaded image.
          Supported methods are "nearest", "bilinear", and "bicubic".
          If PIL version 1.1.3 or newer is installed, "lanczos" is also
          supported. If PIL version 3.4.0 or newer is installed, "box" and
          "hamming" are also supported. By default, "nearest" is used.
      dtype: Dtype to use for generated arrays.
  """

  def __init__(self, directory, image_data_generator,
               target_size=(256, 256),
               color_mode='rgb',
               classes=None,
               class_mode='categorical',
               batch_size=32,
               shuffle=True,
               seed=None,
               data_format=None,
               save_to_dir=None,
               save_prefix='',
               save_format='png',
               follow_links=False,
               subset=None,
               interpolation='nearest',
               dtype=None):
    if data_format is None:
      data_format = backend.image_data_format()
    kwargs = {}
    if 'dtype' in tf_inspect.getfullargspec(
        image.ImageDataGenerator.__init__)[0]:
      if dtype is None:
        dtype = backend.floatx()
      kwargs['dtype'] = dtype
    super(DirectoryIterator, self).__init__(
        directory, image_data_generator,
        target_size=target_size,
        color_mode=color_mode,
        classes=classes,
        class_mode=class_mode,
        batch_size=batch_size,
        shuffle=shuffle,
        seed=seed,
        data_format=data_format,
        save_to_dir=save_to_dir,
        save_prefix=save_prefix,
        save_format=save_format,
        follow_links=follow_links,
        subset=subset,
        interpolation=interpolation,
        **kwargs)


@keras_export('keras.preprocessing.image.NumpyArrayIterator')
class NumpyArrayIterator(image.NumpyArrayIterator, Iterator):
  """Iterator yielding data from a Numpy array.
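
  A minimal sketch with random data (in practice these iterators are usually
  obtained via `ImageDataGenerator.flow` rather than built directly):

  ```python
  x = np.random.random(size=(20, 32, 32, 3))
  y = np.random.randint(2, size=(20,))
  datagen = ImageDataGenerator(horizontal_flip=True)
  iterator = NumpyArrayIterator(x, y, datagen, batch_size=10, shuffle=True)
  x_batch, y_batch = next(iterator)
  ```
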
  Arguments:
      x: Numpy array of input data or tuple.
          If tuple, the second element is either
          another numpy array or a list of numpy arrays,
          each of which gets passed
          through as an output without any modifications.
      y: Numpy array of targets data.
      image_data_generator: Instance of `ImageDataGenerator`
          to use for random transformations and normalization.
      batch_size: Integer, size of a batch.
      shuffle: Boolean, whether to shuffle the data between epochs.
      sample_weight: Numpy array of sample weights.
      seed: Random seed for data shuffling.
      data_format: String, one of `channels_first`, `channels_last`.
      save_to_dir: Optional directory where to save the pictures
          being yielded, in a viewable format. This is useful
          for visualizing the random transformations being
          applied, for debugging purposes.
      save_prefix: String prefix to use for saving sample
          images (if `save_to_dir` is set).
      save_format: Format to use for saving sample images
          (if `save_to_dir` is set).
      subset: Subset of data (`"training"` or `"validation"`) if
          `validation_split` is set in `ImageDataGenerator`.
      dtype: Dtype to use for the generated arrays.
  """

  def __init__(self, x, y, image_data_generator,
               batch_size=32,
               shuffle=False,
               sample_weight=None,
               seed=None,
               data_format=None,
               save_to_dir=None,
               save_prefix='',
               save_format='png',
               subset=None,
               dtype=None):
    if data_format is None:
      data_format = backend.image_data_format()
    kwargs = {}
    if 'dtype' in tf_inspect.getfullargspec(
        image.NumpyArrayIterator.__init__)[0]:
      if dtype is None:
        dtype = backend.floatx()
      kwargs['dtype'] = dtype
    super(NumpyArrayIterator, self).__init__(
        x, y, image_data_generator,
        batch_size=batch_size,
        shuffle=shuffle,
        sample_weight=sample_weight,
        seed=seed,
        data_format=data_format,
        save_to_dir=save_to_dir,
        save_prefix=save_prefix,
        save_format=save_format,
        subset=subset,
        **kwargs)


@keras_export('keras.preprocessing.image.ImageDataGenerator')
class ImageDataGenerator(image.ImageDataGenerator):
  """Generate batches of tensor image data with real-time data augmentation.

  The data will be looped over (in batches).
  Arguments:
      featurewise_center: Boolean.
          Set input mean to 0 over the dataset, feature-wise.
      samplewise_center: Boolean. Set each sample mean to 0.
      featurewise_std_normalization: Boolean.
          Divide inputs by std of the dataset, feature-wise.
      samplewise_std_normalization: Boolean. Divide each input by its std.
      zca_epsilon: epsilon for ZCA whitening. Default is 1e-6.
      zca_whitening: Boolean. Apply ZCA whitening.
      rotation_range: Int. Degree range for random rotations.
      width_shift_range: Float, 1-D array-like or int
          - float: fraction of total width, if < 1, or pixels if >= 1.
          - 1-D array-like: random elements from the array.
          - int: integer number of pixels from interval
              `(-width_shift_range, +width_shift_range)`
          - With `width_shift_range=2` possible values
              are integers `[-1, 0, +1]`,
              same as with `width_shift_range=[-1, 0, +1]`,
              while with `width_shift_range=1.0` possible values are floats
              in the interval [-1.0, +1.0).
      height_shift_range: Float, 1-D array-like or int
          - float: fraction of total height, if < 1, or pixels if >= 1.
          - 1-D array-like: random elements from the array.
          - int: integer number of pixels from interval
              `(-height_shift_range, +height_shift_range)`
          - With `height_shift_range=2` possible values
              are integers `[-1, 0, +1]`,
              same as with `height_shift_range=[-1, 0, +1]`,
              while with `height_shift_range=1.0` possible values are floats
              in the interval [-1.0, +1.0).
      brightness_range: Tuple or list of two floats. Range for picking
          a brightness shift value from.
      shear_range: Float. Shear intensity
          (shear angle in counter-clockwise direction in degrees).
      zoom_range: Float or [lower, upper]. Range for random zoom.
          If a float, `[lower, upper] = [1-zoom_range, 1+zoom_range]`.
      channel_shift_range: Float. Range for random channel shifts.
      fill_mode: One of {"constant", "nearest", "reflect" or "wrap"}.
          Default is 'nearest'.
          Points outside the boundaries of the input are filled
          according to the given mode:
          - 'constant': kkkkkkkk|abcd|kkkkkkkk (cval=k)
          - 'nearest':  aaaaaaaa|abcd|dddddddd
          - 'reflect':  abcddcba|abcd|dcbaabcd
          - 'wrap':  abcdabcd|abcd|abcdabcd
      cval: Float or Int.
          Value used for points outside the boundaries
          when `fill_mode = "constant"`.
      horizontal_flip: Boolean. Randomly flip inputs horizontally.
      vertical_flip: Boolean. Randomly flip inputs vertically.
      rescale: Rescaling factor. Defaults to None.
          If None or 0, no rescaling is applied,
          otherwise we multiply the data by the value provided
          (after applying all other transformations).
      preprocessing_function: function that will be applied on each input.
          The function will run after the image is resized and augmented.
          The function should take one argument:
          one image (Numpy tensor with rank 3),
          and should output a Numpy tensor with the same shape.
      data_format: Image data format,
          either "channels_first" or "channels_last".
          "channels_last" mode means that the images should have shape
          `(samples, height, width, channels)`,
          "channels_first" mode means that the images should have shape
          `(samples, channels, height, width)`.
          It defaults to the `image_data_format` value found in your
          Keras config file at `~/.keras/keras.json`.
          If you never set it, then it will be "channels_last".
      validation_split: Float. Fraction of images reserved for validation
          (strictly between 0 and 1).
      dtype: Dtype to use for the generated arrays.

  Examples:

  Example of using `.flow(x, y)`:

  ```python
  (x_train, y_train), (x_test, y_test) = cifar10.load_data()
  y_train = np_utils.to_categorical(y_train, num_classes)
  y_test = np_utils.to_categorical(y_test, num_classes)
  datagen = ImageDataGenerator(
      featurewise_center=True,
      featurewise_std_normalization=True,
      rotation_range=20,
      width_shift_range=0.2,
      height_shift_range=0.2,
      horizontal_flip=True)
  # compute quantities required for featurewise normalization
  # (std, mean, and principal components if ZCA whitening is applied)
  datagen.fit(x_train)
  # fits the model on batches with real-time data augmentation:
  model.fit_generator(datagen.flow(x_train, y_train, batch_size=32),
                      steps_per_epoch=len(x_train) / 32, epochs=epochs)
  # here's a more "manual" example
  for e in range(epochs):
      print('Epoch', e)
      batches = 0
      for x_batch, y_batch in datagen.flow(x_train, y_train, batch_size=32):
          model.fit(x_batch, y_batch)
          batches += 1
          if batches >= len(x_train) / 32:
              # we need to break the loop by hand because
              # the generator loops indefinitely
              break
  ```

  Example of using `.flow_from_directory(directory)`:

  ```python
  train_datagen = ImageDataGenerator(
          rescale=1./255,
          shear_range=0.2,
          zoom_range=0.2,
          horizontal_flip=True)
  test_datagen = ImageDataGenerator(rescale=1./255)
  train_generator = train_datagen.flow_from_directory(
          'data/train',
          target_size=(150, 150),
          batch_size=32,
          class_mode='binary')
  validation_generator = test_datagen.flow_from_directory(
          'data/validation',
          target_size=(150, 150),
          batch_size=32,
          class_mode='binary')
  model.fit_generator(
          train_generator,
          steps_per_epoch=2000,
          epochs=50,
          validation_data=validation_generator,
          validation_steps=800)
  ```

  Example of transforming images and masks together:

  ```python
  # we create two instances with the same arguments
  data_gen_args = dict(featurewise_center=True,
                       featurewise_std_normalization=True,
                       rotation_range=90,
                       width_shift_range=0.1,
                       height_shift_range=0.1,
                       zoom_range=0.2)
  image_datagen = ImageDataGenerator(**data_gen_args)
  mask_datagen = ImageDataGenerator(**data_gen_args)
  # Provide the same seed and keyword arguments to the fit and flow methods
  seed = 1
  image_datagen.fit(images, augment=True, seed=seed)
  mask_datagen.fit(masks, augment=True, seed=seed)
  image_generator = image_datagen.flow_from_directory(
      'data/images',
      class_mode=None,
      seed=seed)
  mask_generator = mask_datagen.flow_from_directory(
      'data/masks',
      class_mode=None,
      seed=seed)
  # combine generators into one which yields image and masks
  train_generator = zip(image_generator, mask_generator)
  model.fit_generator(
      train_generator,
      steps_per_epoch=2000,
      epochs=50)
  ```
  """

  def __init__(self,
               featurewise_center=False,
               samplewise_center=False,
               featurewise_std_normalization=False,
               samplewise_std_normalization=False,
               zca_whitening=False,
               zca_epsilon=1e-6,
               rotation_range=0,
               width_shift_range=0.,
               height_shift_range=0.,
               brightness_range=None,
               shear_range=0.,
               zoom_range=0.,
               channel_shift_range=0.,
               fill_mode='nearest',
               cval=0.,
               horizontal_flip=False,
               vertical_flip=False,
               rescale=None,
               preprocessing_function=None,
               data_format=None,
               validation_split=0.0,
               dtype=None):
    if data_format is None:
      data_format = backend.image_data_format()
    kwargs = {}
    if 'dtype' in tf_inspect.getfullargspec(
        image.ImageDataGenerator.__init__)[0]:
      if dtype is None:
        dtype = backend.floatx()
      kwargs['dtype'] = dtype
    super(ImageDataGenerator, self).__init__(
        featurewise_center=featurewise_center,
        samplewise_center=samplewise_center,
        featurewise_std_normalization=featurewise_std_normalization,
        samplewise_std_normalization=samplewise_std_normalization,
        zca_whitening=zca_whitening,
        zca_epsilon=zca_epsilon,
        rotation_range=rotation_range,
        width_shift_range=width_shift_range,
        height_shift_range=height_shift_range,
        brightness_range=brightness_range,
        shear_range=shear_range,
        zoom_range=zoom_range,
        channel_shift_range=channel_shift_range,
        fill_mode=fill_mode,
        cval=cval,
        horizontal_flip=horizontal_flip,
        vertical_flip=vertical_flip,
        rescale=rescale,
        preprocessing_function=preprocessing_function,
        data_format=data_format,
        validation_split=validation_split,
        **kwargs)


keras_export('keras.preprocessing.image.random_rotation')(random_rotation)
keras_export('keras.preprocessing.image.random_shift')(random_shift)
keras_export('keras.preprocessing.image.random_shear')(random_shear)
keras_export('keras.preprocessing.image.random_zoom')(random_zoom)
keras_export(
    'keras.preprocessing.image.apply_channel_shift')(apply_channel_shift)
keras_export(
    'keras.preprocessing.image.random_channel_shift')(random_channel_shift)
keras_export(
    'keras.preprocessing.image.apply_brightness_shift')(apply_brightness_shift)
keras_export('keras.preprocessing.image.random_brightness')(random_brightness)
keras_export(
    'keras.preprocessing.image.apply_affine_transform')(apply_affine_transform)
keras_export('keras.preprocessing.image.load_img')(load_img)