# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Built-in py_transforms_utils functions.
"""
import io
import math
import numbers
import random
import colorsys
import numpy as np
from PIL import Image, ImageOps, ImageEnhance, __version__

from .utils import Inter
from ..core.py_util_helpers import is_numpy

augment_error_message = "img should be PIL image. Got {}. Use Decode() for encoded data or ToPIL() for decoded data."


def is_pil(img):
    """
    Check if the input image is PIL format.

    Args:
        img: Image to be checked.

    Returns:
        Bool, True if input is PIL image.
    """
    return isinstance(img, Image.Image)


def normalize(img, mean, std, pad_channel=False, dtype="float32"):
    """
    Normalize the image between [0, 1] with respect to mean and standard deviation.

    Args:
        img (numpy.ndarray): Image array of shape CHW to be normalized.
        mean (list): List of mean values for each channel, w.r.t. channel order.
        std (list): List of standard deviations for each channel, w.r.t. channel order.
        pad_channel (bool): Whether to pad an extra channel with value zero.
        dtype (str): Output datatype of normalize, only takes effect when pad_channel is True (default is "float32").

    Returns:
        img (numpy.ndarray), Normalized image.
    """
    if not is_numpy(img):
        raise TypeError("img should be NumPy image. Got {}.".format(type(img)))

    if img.ndim != 3:
        raise TypeError('img dimension should be 3. Got {}.'.format(img.ndim))

    if np.issubdtype(img.dtype, np.integer):
        raise NotImplementedError("Unsupported image datatype: [{}], please execute [ToTensor] before [Normalize]."
                                  .format(img.dtype))

    num_channels = img.shape[0]  # shape is (C, H, W)

    if len(mean) != len(std):
        raise ValueError("Length of mean and std must be equal.")
    # if length equal to 1, adjust the mean and std arrays to have the correct
    # number of channels (replicate the values)
    if len(mean) == 1:
        mean = [mean[0]] * num_channels
        std = [std[0]] * num_channels
    elif len(mean) != num_channels:
        raise ValueError("Length of mean and std must both be 1 or equal to the number of channels({0})."
                         .format(num_channels))

    mean = np.array(mean, dtype=img.dtype)
    std = np.array(std, dtype=img.dtype)

    image = (img - mean[:, None, None]) / std[:, None, None]
    if pad_channel:
        zeros = np.zeros([1, image.shape[1], image.shape[2]], dtype=np.float32)
        image = np.concatenate((image, zeros), axis=0)
        if dtype == "float16":
            image = image.astype(np.float16)
    return image

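# Illustrative usage sketch (not part of the original module): how normalize()
# might be called on a CHW float32 array; the mean/std values are assumptions
# chosen only for illustration.
def _example_normalize():
    chw_img = np.random.rand(3, 32, 32).astype(np.float32)  # already decoded and converted to CHW float
    out = normalize(chw_img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    # the output keeps the CHW layout; an extra zero channel is only added when pad_channel=True
    return out.shape  # (3, 32, 32)

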
def decode(img):
    """
    Decode the input image to PIL image format in RGB mode.

    Args:
        img: Image to be decoded.

    Returns:
        img (PIL image), Decoded image in RGB mode.
    """

    try:
        data = io.BytesIO(img)
        img = Image.open(data)
        return img.convert('RGB')
    except IOError as e:
        raise ValueError("{0}\n: Failed to decode given image.".format(e))
    except AttributeError as e:
        raise ValueError("{0}\n: Failed to decode, Image might already be decoded.".format(e))


def hwc_to_chw(img):
    """
    Transpose the input image; shape (H, W, C) to shape (C, H, W).

    Args:
        img (numpy.ndarray): Image to be converted.

    Returns:
        img (numpy.ndarray), Converted image.
    """
    if not is_numpy(img):
        raise TypeError('img should be NumPy array. Got {}.'.format(type(img)))
    if img.ndim != 3:
        raise TypeError('img dimension should be 3. Got {}.'.format(img.ndim))
    return img.transpose(2, 0, 1).copy()


def to_tensor(img, output_type):
    """
    Change the input image (PIL image or NumPy image array) to NumPy format.

    Args:
        img (Union[PIL image, numpy.ndarray]): Image to be converted.
        output_type: The datatype of the NumPy output. e.g. np.float32

    Returns:
        img (numpy.ndarray), Converted image.
    """
    if not (is_pil(img) or is_numpy(img)):
        raise TypeError("img should be PIL image or NumPy array. Got {}.".format(type(img)))

    img = np.asarray(img)
    if img.ndim not in (2, 3):
        raise TypeError("img dimension should be 2 or 3. Got {}.".format(img.ndim))

    if img.ndim == 2:
        img = img[:, :, None]

    img = hwc_to_chw(img)

    img = img / 255.
    return to_type(img, output_type)

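# Illustrative usage sketch (not part of the original module): converting a small
# PIL image to a CHW float32 array with to_tensor(); the image size is arbitrary.
def _example_to_tensor():
    pil_img = Image.new('RGB', (8, 8), color=(255, 0, 0))  # a tiny solid-red RGB image
    chw = to_tensor(pil_img, np.float32)
    # HWC uint8 in [0, 255] becomes CHW float32 scaled to [0, 1]
    return chw.shape, chw.dtype  # ((3, 8, 8), dtype('float32'))

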
def to_pil(img):
    """
    Convert the input image to PIL format.

    Args:
        img: Image to be converted.

    Returns:
        img (PIL image), Converted image.
    """
    if not is_pil(img):
        if not isinstance(img, np.ndarray):
            raise TypeError("The input of ToPIL should be ndarray. Got {}.".format(type(img)))
        return Image.fromarray(img)
    return img


def horizontal_flip(img):
    """
    Flip the input image horizontally.

    Args:
        img (PIL image): Image to be flipped horizontally.

    Returns:
        img (PIL image), Horizontally flipped image.
    """
    if not is_pil(img):
        raise TypeError(augment_error_message.format(type(img)))

    return img.transpose(Image.FLIP_LEFT_RIGHT)


def vertical_flip(img):
    """
    Flip the input image vertically.

    Args:
        img (PIL image): Image to be flipped vertically.

    Returns:
        img (PIL image), Vertically flipped image.
    """
    if not is_pil(img):
        raise TypeError(augment_error_message.format(type(img)))

    return img.transpose(Image.FLIP_TOP_BOTTOM)


def random_horizontal_flip(img, prob):
    """
    Randomly flip the input image horizontally.

    Args:
        img (PIL image): Image to be flipped.
        prob (float): Probability of the image being flipped.
            If the given probability is above the sampled random probability, the image is flipped.

    Returns:
        img (PIL image), Converted image.
    """
    if not is_pil(img):
        raise TypeError(augment_error_message.format(type(img)))

    if prob > random.random():
        img = horizontal_flip(img)
    return img


def random_vertical_flip(img, prob):
    """
    Randomly flip the input image vertically.

    Args:
        img (PIL image): Image to be flipped.
        prob (float): Probability of the image being flipped.
            If the given probability is above the sampled random probability, the image is flipped.

    Returns:
        img (PIL image), Converted image.
    """
    if not is_pil(img):
        raise TypeError(augment_error_message.format(type(img)))

    if prob > random.random():
        img = vertical_flip(img)
    return img


def crop(img, top, left, height, width):
    """
    Crop the input PIL image.

    Args:
        img (PIL image): Image to be cropped. (0,0) denotes the top left corner of the image,
            in the directions of (width, height).
        top (int): Vertical component of the top left corner of the crop box.
        left (int): Horizontal component of the top left corner of the crop box.
        height (int): Height of the crop box.
        width (int): Width of the crop box.

    Returns:
        img (PIL image), Cropped image.
    """
    if not is_pil(img):
        raise TypeError(augment_error_message.format(type(img)))

    return img.crop((left, top, left + width, top + height))


def resize(img, size, interpolation=Inter.BILINEAR):
    """
    Resize the input PIL image to the desired size.

    Args:
        img (PIL image): Image to be resized.
        size (Union[int, sequence]): The output size of the resized image.
            If size is an integer, the smaller edge of the image will be resized to this value,
            keeping the image aspect ratio the same.
            If size is a sequence of (height, width), this will be the desired output size.
        interpolation (interpolation mode): Image interpolation mode. Default is Inter.BILINEAR = 2.

    Returns:
        img (PIL image), Resized image.
    """
    if not is_pil(img):
        raise TypeError(augment_error_message.format(type(img)))
    if not (isinstance(size, int) or (isinstance(size, (list, tuple)) and len(size) == 2)):
        raise TypeError('Size should be a single number or a list/tuple (h, w) of length 2. '
                        'Got {}.'.format(size))

    if isinstance(size, int):
        img_width, img_height = img.size
        aspect_ratio = img_width / img_height  # maintain the aspect ratio
        if (img_width <= img_height and img_width == size) or \
                (img_height <= img_width and img_height == size):
            return img
        if img_width < img_height:
            out_width = size
            out_height = int(size / aspect_ratio)
            return img.resize((out_width, out_height), interpolation)
        out_height = size
        out_width = int(size * aspect_ratio)
        return img.resize((out_width, out_height), interpolation)
    return img.resize(size[::-1], interpolation)

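# Illustrative usage sketch (not part of the original module): resizing by the
# shorter edge with an integer size versus an explicit (height, width) pair.
# The concrete sizes below are arbitrary.
def _example_resize():
    pil_img = Image.new('RGB', (400, 200))          # (width, height) = (400, 200)
    short_edge = resize(pil_img, 100)               # shorter edge -> 100, aspect ratio kept: (200, 100)
    exact = resize(pil_img, (120, 160))             # (height, width) -> PIL size (160, 120)
    return short_edge.size, exact.size

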
def center_crop(img, size):
    """
    Crop the input PIL image at the center to the given size.

    Args:
        img (PIL image): Image to be cropped.
        size (Union[int, tuple]): The size of the crop box.
            If size is an integer, a square crop of size (size, size) is returned.
            If size is a sequence of length 2, it should be (height, width).

    Returns:
        img (PIL image), Cropped image.
    """
    if not is_pil(img):
        raise TypeError(augment_error_message.format(type(img)))

    if isinstance(size, int):
        size = (size, size)
    img_width, img_height = img.size
    crop_height, crop_width = size
    crop_top = int(round((img_height - crop_height) / 2.))
    crop_left = int(round((img_width - crop_width) / 2.))
    return crop(img, crop_top, crop_left, crop_height, crop_width)


def random_resize_crop(img, size, scale, ratio, interpolation=Inter.BILINEAR, max_attempts=10):
    """
    Crop the input PIL image to a random size and aspect ratio.

    Args:
        img (PIL image): Image to be randomly cropped and resized.
        size (Union[int, sequence]): The size of the output image.
            If size is an integer, a square crop of size (size, size) is returned.
            If size is a sequence of length 2, it should be (height, width).
        scale (tuple): Range (min, max) of the crop area, relative to the area of the original image.
        ratio (tuple): Range (min, max) of the aspect ratio of the crop.
        interpolation (interpolation mode): Image interpolation mode. Default is Inter.BILINEAR = 2.
        max_attempts (int): The maximum number of attempts to propose a valid crop area. Default 10.
            If exceeded, fall back to center cropping instead.

    Returns:
        img (PIL image), Randomly cropped and resized image.
    """
    if not is_pil(img):
        raise TypeError(augment_error_message.format(type(img)))
    if isinstance(size, int):
        size = (size, size)
    elif isinstance(size, (tuple, list)) and len(size) == 2:
        size = size
    else:
        raise TypeError("Size should be a single integer or a list/tuple (h, w) of length 2.")

    if scale[0] > scale[1] or ratio[0] > ratio[1]:
        raise ValueError("Range should be in the order of (min, max).")

    def _input_to_factor(img, scale, ratio):
        img_width, img_height = img.size
        img_area = img_width * img_height

        for _ in range(max_attempts):
            crop_area = random.uniform(scale[0], scale[1]) * img_area
            # in case of non-symmetrical aspect ratios,
            # use uniform distribution on a logarithmic scale.
            log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
            aspect_ratio = math.exp(random.uniform(*log_ratio))

            width = int(round(math.sqrt(crop_area * aspect_ratio)))
            height = int(round(width / aspect_ratio))

            if 0 < width <= img_width and 0 < height <= img_height:
                top = random.randint(0, img_height - height)
                left = random.randint(0, img_width - width)
                return top, left, height, width

        # exceeding max_attempts, use center crop
        img_ratio = img_width / img_height
        if img_ratio < ratio[0]:
            width = img_width
            height = int(round(width / ratio[0]))
        elif img_ratio > ratio[1]:
            height = img_height
            width = int(round(height * ratio[1]))
        else:
            width = img_width
            height = img_height
        top = int(round((img_height - height) / 2.))
        left = int(round((img_width - width) / 2.))
        return top, left, height, width

    top, left, height, width = _input_to_factor(img, scale, ratio)
    img = crop(img, top, left, height, width)
    img = resize(img, size, interpolation)
    return img


def random_crop(img, size, padding, pad_if_needed, fill_value, padding_mode):
    """
    Crop the input PIL image at a random location.

    Args:
        img (PIL image): Image to be randomly cropped.
        size (Union[int, sequence]): The output size of the cropped image.
            If size is an integer, a square crop of size (size, size) is returned.
            If size is a sequence of length 2, it should be (height, width).
        padding (Union[int, sequence], optional): The number of pixels to pad the image.
            If a single number is provided, it pads all borders with this value.
            If a tuple or list of 2 values is provided, it pads the (left and top)
            with the first value and (right and bottom) with the second value.
            If 4 values are provided as a list or tuple,
            it pads the left, top, right and bottom respectively.
            Default is None.
        pad_if_needed (bool): Pad the image if either side is smaller than
            the given output size. Default is False.
        fill_value (Union[int, tuple]): The pixel intensity of the borders if
            the padding_mode is 'constant'. If it is a 3-tuple, it is used to
            fill R, G, B channels respectively.
        padding_mode (str): The method of padding. Can be any of ['constant', 'edge', 'reflect', 'symmetric'].

              - 'constant', means it fills the border with constant values
              - 'edge', means it pads with the last value on the edge
              - 'reflect', means it reflects the values on the edge omitting the last
                value of edge
              - 'symmetric', means it reflects the values on the edge repeating the last
                value of edge

    Returns:
        PIL image, Cropped image.
    """
    if not is_pil(img):
        raise TypeError(augment_error_message.format(type(img)))
    if isinstance(size, int):
        size = (size, size)
    elif isinstance(size, (tuple, list)) and len(size) == 2:
        size = size
    else:
        raise TypeError("Size should be a single integer or a list/tuple (h, w) of length 2.")

    def _input_to_factor(img, size):
        img_width, img_height = img.size
        height, width = size
        if height > img_height or width > img_width:
            raise ValueError("Crop size {} is larger than input image size {}.".format(size, (img_height, img_width)))

        if width == img_width and height == img_height:
            return 0, 0, img_height, img_width

        top = random.randint(0, img_height - height)
        left = random.randint(0, img_width - width)
        return top, left, height, width

    if padding is not None:
        img = pad(img, padding, fill_value, padding_mode)
    # pad width when needed, img.size (width, height), crop size (height, width)
    if pad_if_needed and img.size[0] < size[1]:
        img = pad(img, (size[1] - img.size[0], 0), fill_value, padding_mode)
    # pad height when needed
    if pad_if_needed and img.size[1] < size[0]:
        img = pad(img, (0, size[0] - img.size[1]), fill_value, padding_mode)

    top, left, height, width = _input_to_factor(img, size)
    return crop(img, top, left, height, width)

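# Illustrative usage sketch (not part of the original module): a random crop with
# constant zero padding, matching the argument order of random_crop() above.
# The crop size and padding amount are arbitrary.
def _example_random_crop():
    pil_img = Image.new('RGB', (64, 48))
    out = random_crop(pil_img, size=(32, 32), padding=4, pad_if_needed=False,
                      fill_value=0, padding_mode='constant')
    return out.size  # (32, 32) in PIL's (width, height) order

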
def adjust_brightness(img, brightness_factor):
    """
    Adjust brightness of an image.

    Args:
        img (PIL image): Image to be adjusted.
        brightness_factor (float): A non-negative number indicating the factor by which
            the brightness is adjusted. 0 gives a black image, 1 gives the original.

    Returns:
        img (PIL image), Brightness adjusted image.
    """
    if not is_pil(img):
        raise TypeError(augment_error_message.format(type(img)))

    enhancer = ImageEnhance.Brightness(img)
    img = enhancer.enhance(brightness_factor)
    return img


def adjust_contrast(img, contrast_factor):
    """
    Adjust contrast of an image.

    Args:
        img (PIL image): PIL image to be adjusted.
        contrast_factor (float): A non-negative number indicating the factor by which
            the contrast is adjusted. 0 gives a solid gray image, 1 gives the original.

    Returns:
        img (PIL image), Contrast adjusted image.
    """
    if not is_pil(img):
        raise TypeError(augment_error_message.format(type(img)))

    enhancer = ImageEnhance.Contrast(img)
    img = enhancer.enhance(contrast_factor)
    return img


def adjust_saturation(img, saturation_factor):
    """
    Adjust saturation of an image.

    Args:
        img (PIL image): PIL image to be adjusted.
        saturation_factor (float): A non-negative number indicating the factor by which
            the saturation is adjusted. 0 gives a black and white image, 1 gives
            the original.

    Returns:
        img (PIL image), Saturation adjusted image.
    """
    if not is_pil(img):
        raise TypeError(augment_error_message.format(type(img)))

    enhancer = ImageEnhance.Color(img)
    img = enhancer.enhance(saturation_factor)
    return img


def adjust_hue(img, hue_factor):
    """
    Adjust hue of an image. The hue is adjusted by shifting the H channel after the image is converted to HSV.

    Args:
        img (PIL image): PIL image to be adjusted.
        hue_factor (float): Amount to shift the hue channel. Value should be in
            [-0.5, 0.5]. 0.5 and -0.5 give a complete reversal of the hue channel,
            because hue wraps around when rotated 360 degrees.
            0 means no shift and gives the original image, while both -0.5 and 0.5
            give an image with complementary colors.

    Returns:
        img (PIL image), Hue adjusted image.
    """
    image = img
    image_hue_factor = hue_factor
    if not -0.5 <= image_hue_factor <= 0.5:
        raise ValueError('image_hue_factor {} is not in [-0.5, 0.5].'.format(image_hue_factor))

    if not is_pil(image):
        raise TypeError(augment_error_message.format(type(image)))

    mode = image.mode
    if mode in {'L', '1', 'I', 'F'}:
        return image

    hue, saturation, value = img.convert('HSV').split()

    np_hue = np.array(hue, dtype=np.uint8)

    with np.errstate(over='ignore'):
        np_hue += np.uint8(image_hue_factor * 255)
    hue = Image.fromarray(np_hue, 'L')

    image = Image.merge('HSV', (hue, saturation, value)).convert(mode)
    return image

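# Illustrative sketch (not part of the original module): the hue shift in adjust_hue()
# relies on uint8 wraparound, so hue_factor maps to adding the truncated value
# uint8(hue_factor * 255) to the H channel modulo 256. The sample values are arbitrary.
def _example_adjust_hue_wraparound():
    h = np.array([200, 250], dtype=np.uint8)      # two hue samples from the H channel
    with np.errstate(over='ignore'):
        h += np.uint8(0.5 * 255)                  # +127 with wraparound: 327 % 256, 377 % 256
    return h.tolist()                             # [71, 121]

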
def to_type(img, output_type):
    """
    Convert the NumPy image array to desired NumPy dtype.

    Args:
        img (numpy): NumPy image to cast to desired NumPy dtype.
        output_type (Numpy datatype): NumPy dtype to cast to.

    Returns:
        img (numpy.ndarray), Converted image.
    """
    if not is_numpy(img):
        raise TypeError("img should be NumPy image. Got {}.".format(type(img)))

    try:
        return img.astype(output_type)
    except Exception:
        raise RuntimeError("output_type: " + str(output_type) + " is not a valid datatype.")


def rotate(img, angle, resample, expand, center, fill_value):
    """
    Rotate the input PIL image by angle.

    Args:
        img (PIL image): Image to be rotated.
        angle (int or float): Rotation angle in degrees, counter-clockwise.
        resample (Union[Inter.NEAREST, Inter.BILINEAR, Inter.BICUBIC], optional): An optional resampling filter.
            If omitted, or if the image has mode "1" or "P", it is set to be Inter.NEAREST.
        expand (bool, optional):  Optional expansion flag. If set to True, expand the output
            image to make it large enough to hold the entire rotated image.
            If set to False or omitted, make the output image the same size as the input.
            Note that the expand flag assumes rotation around the center and no translation.
        center (tuple, optional): Optional center of rotation (a 2-tuple).
            Origin is the top left corner.
        fill_value (Union[int, tuple]): Optional fill color for the area outside the rotated image.
            If it is a 3-tuple, it is used for R, G, B channels respectively.
            If it is an integer, it is used for all RGB channels.

    Returns:
        img (PIL image), Rotated image.

    https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.Image.rotate
    """
    if not is_pil(img):
        raise TypeError(augment_error_message.format(type(img)))

    if isinstance(fill_value, int):
        fill_value = tuple([fill_value] * 3)

    return img.rotate(angle, resample, expand, center, fillcolor=fill_value)


def random_color_adjust(img, brightness, contrast, saturation, hue):
    """
    Randomly adjust the brightness, contrast, saturation, and hue of an image.

    Args:
        img (PIL image): Image to have its color adjusted randomly.
        brightness (Union[float, tuple]): Brightness adjustment factor. Cannot be negative.
            If it is a float, the factor is uniformly chosen from the range [max(0, 1-brightness), 1+brightness].
            If it is a sequence, it should be [min, max] for the range.
        contrast (Union[float, tuple]): Contrast adjustment factor. Cannot be negative.
            If it is a float, the factor is uniformly chosen from the range [max(0, 1-contrast), 1+contrast].
            If it is a sequence, it should be [min, max] for the range.
        saturation (Union[float, tuple]): Saturation adjustment factor. Cannot be negative.
            If it is a float, the factor is uniformly chosen from the range [max(0, 1-saturation), 1+saturation].
            If it is a sequence, it should be [min, max] for the range.
        hue (Union[float, tuple]): Hue adjustment factor.
            If it is a float, the range will be [-hue, hue]. Value should be 0 <= hue <= 0.5.
            If it is a sequence, it should be [min, max] where -0.5 <= min <= max <= 0.5.

    Returns:
        img (PIL image), Image after random adjustment of its color.
    """
    if not is_pil(img):
        raise TypeError(augment_error_message.format(type(img)))

    def _input_to_factor(value, input_name, center=1, bound=(0, float('inf')), non_negative=True):
        if isinstance(value, numbers.Number):
            if value < 0:
                raise ValueError("The input value of {} cannot be negative.".format(input_name))
            # convert value into a range
            value = [center - value, center + value]
            if non_negative:
                value[0] = max(0, value[0])
        elif isinstance(value, (list, tuple)) and len(value) == 2:
            if not bound[0] <= value[0] <= value[1] <= bound[1]:
                raise ValueError("Please check your value range of {} is valid and "
                                 "within the bound {}.".format(input_name, bound))
        else:
            raise TypeError("Input of {} should be either a single value, or a list/tuple of "
                            "length 2.".format(input_name))
        factor = random.uniform(value[0], value[1])
        return factor

    brightness_factor = _input_to_factor(brightness, 'brightness')
    contrast_factor = _input_to_factor(contrast, 'contrast')
    saturation_factor = _input_to_factor(saturation, 'saturation')
    hue_factor = _input_to_factor(hue, 'hue', center=0, bound=(-0.5, 0.5), non_negative=False)

    transforms = []
    transforms.append(lambda img: adjust_brightness(img, brightness_factor))
    transforms.append(lambda img: adjust_contrast(img, contrast_factor))
    transforms.append(lambda img: adjust_saturation(img, saturation_factor))
    transforms.append(lambda img: adjust_hue(img, hue_factor))

    # apply color adjustments in a random order
    random.shuffle(transforms)
    for transform in transforms:
        img = transform(img)

    return img


def random_rotation(img, degrees, resample, expand, center, fill_value):
    """
    Rotate the input PIL image by a random angle.

    Args:
        img (PIL image): Image to be rotated.
        degrees (Union[int, float, sequence]): Range of random rotation degrees.
            If `degrees` is a number, the range will be converted to (-degrees, degrees).
            If `degrees` is a sequence, it should be (min, max).
        resample (Union[Inter.NEAREST, Inter.BILINEAR, Inter.BICUBIC], optional): An optional resampling filter.
            If omitted, or if the image has mode "1" or "P", it is set to be Inter.NEAREST.
        expand (bool, optional):  Optional expansion flag. If set to True, expand the output
            image to make it large enough to hold the entire rotated image.
            If set to False or omitted, make the output image the same size as the input.
            Note that the expand flag assumes rotation around the center and no translation.
        center (tuple, optional): Optional center of rotation (a 2-tuple).
            Origin is the top left corner.
        fill_value (Union[int, tuple]): Optional fill color for the area outside the rotated image.
            If it is a 3-tuple, it is used for R, G, B channels respectively.
            If it is an integer, it is used for all RGB channels.

    Returns:
        img (PIL image), Rotated image.

    https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.Image.rotate
    """
    if not is_pil(img):
        raise TypeError(augment_error_message.format(type(img)))

    if isinstance(degrees, numbers.Number):
        if degrees < 0:
            raise ValueError("If degrees is a single number, it cannot be negative.")
        degrees = (-degrees, degrees)
    elif isinstance(degrees, (list, tuple)):
        if len(degrees) != 2:
            raise ValueError("If degrees is a sequence, the length must be 2.")
    else:
        raise TypeError("Degrees must be a single non-negative number or a sequence.")

    angle = random.uniform(degrees[0], degrees[1])
    return rotate(img, angle, resample, expand, center, fill_value)


def five_crop(img, size):
    """
    Generate 5 cropped images (one central and four corners).

    Args:
        img (PIL image): PIL image to be cropped.
        size (Union[int, sequence]): The output size of the crop.
            If size is an integer, a square crop of size (size, size) is returned.
            If size is a sequence of length 2, it should be (height, width).

    Returns:
        img_tuple (tuple), a tuple of 5 PIL images
            (top_left, top_right, bottom_left, bottom_right, center).
    """
    if not is_pil(img):
        raise TypeError(augment_error_message.format(type(img)))

    if isinstance(size, int):
        size = (size, size)
    elif isinstance(size, (tuple, list)) and len(size) == 2:
        size = size
    else:
        raise TypeError("Size should be a single number or a list/tuple (h, w) of length 2.")

    # PIL image.size returns in (width, height) order
    img_width, img_height = img.size
    crop_height, crop_width = size
    if crop_height > img_height or crop_width > img_width:
        raise ValueError("Crop size {} is larger than input image size {}.".format(size, (img_height, img_width)))
    center = center_crop(img, (crop_height, crop_width))
    top_left = img.crop((0, 0, crop_width, crop_height))
    top_right = img.crop((img_width - crop_width, 0, img_width, crop_height))
    bottom_left = img.crop((0, img_height - crop_height, crop_width, img_height))
    bottom_right = img.crop((img_width - crop_width, img_height - crop_height, img_width, img_height))

    return top_left, top_right, bottom_left, bottom_right, center


def ten_crop(img, size, use_vertical_flip=False):
    """
    Generate 10 cropped images (first 5 from FiveCrop, second 5 from their flipped version).

    The default is horizontal flipping, use_vertical_flip=False.

    Args:
        img (PIL image): PIL image to be cropped.
        size (Union[int, sequence]): The output size of the crop.
            If size is an integer, a square crop of size (size, size) is returned.
            If size is a sequence of length 2, it should be (height, width).
        use_vertical_flip (bool): Flip the image vertically instead of horizontally if set to True.

    Returns:
        img_tuple (tuple), a tuple of 10 PIL images
            (top_left, top_right, bottom_left, bottom_right, center) of original image +
            (top_left, top_right, bottom_left, bottom_right, center) of flipped image.
    """
    if not is_pil(img):
        raise TypeError(augment_error_message.format(type(img)))

    if isinstance(size, int):
        size = (size, size)
    elif isinstance(size, (tuple, list)) and len(size) == 2:
        size = size
    else:
        raise TypeError("Size should be a single number or a list/tuple (h, w) of length 2.")

    first_five_crop = five_crop(img, size)

    if use_vertical_flip:
        img = vertical_flip(img)
    else:
        img = horizontal_flip(img)

    second_five_crop = five_crop(img, size)

    return first_five_crop + second_five_crop

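# Illustrative usage sketch (not part of the original module): ten_crop() yields five
# crops of the original image followed by five crops of its horizontally flipped copy.
# The input and crop sizes are arbitrary.
def _example_ten_crop():
    pil_img = Image.new('RGB', (64, 64))
    crops = ten_crop(pil_img, size=32)
    return len(crops), crops[0].size  # (10, (32, 32))

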
def grayscale(img, num_output_channels):
    """
    Convert the input PIL image to a grayscale image.

    Args:
        img (PIL image): PIL image to be converted to grayscale.
        num_output_channels (int): Number of channels of the output grayscale image (1 or 3).

    Returns:
        img (PIL image), grayscaled image.
    """
    if not is_pil(img):
        raise TypeError(augment_error_message.format(type(img)))

    if num_output_channels == 1:
        img = img.convert('L')
    elif num_output_channels == 3:
        # each channel is the same grayscale layer
        img = img.convert('L')
        np_gray = np.array(img, dtype=np.uint8)
        np_img = np.dstack([np_gray, np_gray, np_gray])
        img = Image.fromarray(np_img, 'RGB')
    else:
        raise ValueError('num_output_channels should be either 1 or 3. Got {}.'.format(num_output_channels))

    return img


def pad(img, padding, fill_value, padding_mode):
    """
    Pad the image according to padding parameters.

    Args:
        img (PIL image): Image to be padded.
        padding (Union[int, sequence], optional): The number of pixels to pad the image.
            If a single number is provided, it pads all borders with this value.
            If a tuple or list of 2 values is provided, it pads the (left and top)
            with the first value and (right and bottom) with the second value.
            If 4 values are provided as a list or tuple,
            it pads the left, top, right and bottom respectively.
            Default is None.
        fill_value (Union[int, tuple]): The pixel intensity of the borders if
            the padding_mode is "constant". If it is a 3-tuple, it is used to
            fill R, G, B channels respectively.
        padding_mode (str): The method of padding. Can be any of ['constant', 'edge', 'reflect', 'symmetric'].

              - 'constant', means it fills the border with constant values
              - 'edge', means it pads with the last value on the edge
              - 'reflect', means it reflects the values on the edge omitting the last
                value of edge
              - 'symmetric', means it reflects the values on the edge repeating the last
                value of edge

    Returns:
        img (PIL image), Padded image.
    """
    if not is_pil(img):
        raise TypeError(augment_error_message.format(type(img)))

    if isinstance(padding, numbers.Number):
        top = bottom = left = right = padding

    elif isinstance(padding, (tuple, list)):
        if len(padding) == 2:
            left = top = padding[0]
            right = bottom = padding[1]
        elif len(padding) == 4:
            left = padding[0]
            top = padding[1]
            right = padding[2]
            bottom = padding[3]
        else:
            raise ValueError("The size of the padding list or tuple should be 2 or 4.")
    else:
        raise TypeError("Padding can be any of: a number, a tuple or list of size 2 or 4.")

    if not isinstance(fill_value, (numbers.Number, str, tuple)):
        raise TypeError("fill_value can be any of: an integer, a string or a tuple.")

    if padding_mode not in ['constant', 'edge', 'reflect', 'symmetric']:
        raise ValueError("Padding mode should be 'constant', 'edge', 'reflect', or 'symmetric'.")

    if padding_mode == 'constant':
        if img.mode == 'P':
            palette = img.getpalette()
            image = ImageOps.expand(img, border=(left, top, right, bottom), fill=fill_value)
            image.putpalette(palette)
            return image
        return ImageOps.expand(img, border=(left, top, right, bottom), fill=fill_value)

    if img.mode == 'P':
        palette = img.getpalette()
        img = np.asarray(img)
        img = np.pad(img, ((top, bottom), (left, right)), padding_mode)
        img = Image.fromarray(img)
        img.putpalette(palette)
        return img

    img = np.asarray(img)
    if len(img.shape) == 3:
        img = np.pad(img, ((top, bottom), (left, right), (0, 0)), padding_mode)
    if len(img.shape) == 2:
        img = np.pad(img, ((top, bottom), (left, right)), padding_mode)

    return Image.fromarray(img)

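# Illustrative usage sketch (not part of the original module): padding the same image
# with a constant border versus a reflected border. The pad widths are arbitrary.
def _example_pad():
    pil_img = Image.new('RGB', (10, 10))
    constant = pad(pil_img, padding=(2, 3), fill_value=0, padding_mode='constant')   # left/top +2, right/bottom +3
    reflected = pad(pil_img, padding=4, fill_value=0, padding_mode='reflect')        # all borders +4
    return constant.size, reflected.size  # ((15, 15), (18, 18))

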
def get_perspective_params(img, distortion_scale):
    """Helper function to get parameters for RandomPerspective.
    """
    img_width, img_height = img.size
    distorted_half_width = int(img_width / 2 * distortion_scale)
    distorted_half_height = int(img_height / 2 * distortion_scale)
    top_left = (random.randint(0, distorted_half_width),
                random.randint(0, distorted_half_height))
    top_right = (random.randint(img_width - distorted_half_width - 1, img_width - 1),
                 random.randint(0, distorted_half_height))
    bottom_right = (random.randint(img_width - distorted_half_width - 1, img_width - 1),
                    random.randint(img_height - distorted_half_height - 1, img_height - 1))
    bottom_left = (random.randint(0, distorted_half_width),
                   random.randint(img_height - distorted_half_height - 1, img_height - 1))

    start_points = [(0, 0), (img_width - 1, 0), (img_width - 1, img_height - 1), (0, img_height - 1)]
    end_points = [top_left, top_right, bottom_right, bottom_left]
    return start_points, end_points


def perspective(img, start_points, end_points, interpolation=Inter.BICUBIC):
    """
    Apply a perspective transformation to the input PIL image.

    Args:
        img (PIL image): PIL image to apply the perspective transformation to.
        start_points (list): List of [top_left, top_right, bottom_right, bottom_left] of the original image.
        end_points (list): List of [top_left, top_right, bottom_right, bottom_left] of the transformed image.
        interpolation (interpolation mode): Image interpolation mode. Default is Inter.BICUBIC = 3.

    Returns:
        img (PIL image), Image after the perspective transformation.
    """

    def _input_to_coeffs(original_points, transformed_points):
        # Get the coefficients (a, b, c, d, e, f, g, h) for the perspective transforms.
        # According to "Using Projective Geometry to Correct a Camera" from AMS.
        # http://www.ams.org/publicoutreach/feature-column/fc-2013-03
        # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Geometry.c#L377

        matrix = []
        for pt1, pt2 in zip(transformed_points, original_points):
            matrix.append([pt1[0], pt1[1], 1, 0, 0, 0, -pt2[0] * pt1[0], -pt2[0] * pt1[1]])
            matrix.append([0, 0, 0, pt1[0], pt1[1], 1, -pt2[1] * pt1[0], -pt2[1] * pt1[1]])
        matrix_a = np.array(matrix, dtype=np.float64)
        matrix_b = np.array(original_points, dtype=np.float64).reshape(8)
        res = np.linalg.lstsq(matrix_a, matrix_b, rcond=None)[0]
        return res.tolist()

    if not is_pil(img):
        raise TypeError(augment_error_message.format(type(img)))

    coeffs = _input_to_coeffs(start_points, end_points)
    return img.transform(img.size, Image.PERSPECTIVE, coeffs, interpolation)

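# Illustrative usage sketch (not part of the original module): pairing
# get_perspective_params() with perspective() for a random warp. The image size
# and distortion_scale are arbitrary.
def _example_perspective():
    pil_img = Image.new('RGB', (32, 24))
    start_points, end_points = get_perspective_params(pil_img, distortion_scale=0.3)
    warped = perspective(pil_img, start_points, end_points)
    return warped.size  # output keeps the input size: (32, 24)

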
def get_erase_params(np_img, scale, ratio, value, bounded, max_attempts):
    """Helper function to get parameters for RandomErasing/Cutout.
    """
    if not is_numpy(np_img):
        raise TypeError('img should be NumPy array. Got {}.'.format(type(np_img)))

    image_c, image_h, image_w = np_img.shape
    area = image_h * image_w

    for _ in range(max_attempts):
        erase_area = random.uniform(scale[0], scale[1]) * area
        aspect_ratio = random.uniform(ratio[0], ratio[1])
        erase_w = int(round(math.sqrt(erase_area * aspect_ratio)))
        erase_h = int(round(erase_w / aspect_ratio))
        erase_shape = (image_c, erase_h, erase_w)

        if erase_h < image_h and erase_w < image_w:
            if bounded:
                i = random.randint(0, image_h - erase_h)
                j = random.randint(0, image_w - erase_w)
            else:
                def clip(x, lower, upper):
                    return max(lower, min(x, upper))

                x = random.randint(0, image_w)
                y = random.randint(0, image_h)
                x1 = clip(x - erase_w // 2, 0, image_w)
                x2 = clip(x + erase_w // 2, 0, image_w)
                y1 = clip(y - erase_h // 2, 0, image_h)
                y2 = clip(y + erase_h // 2, 0, image_h)

                i, j, erase_h, erase_w = y1, x1, y2 - y1, x2 - x1

            if isinstance(value, numbers.Number):
                erase_value = value
            elif isinstance(value, (str, bytes)):
                erase_value = np.random.normal(loc=0.0, scale=1.0, size=erase_shape)
            elif isinstance(value, (tuple, list)) and len(value) == 3:
                value = np.array(value)
                erase_value = np.multiply(np.ones(erase_shape), value[:, None, None])
            else:
                raise ValueError("The value for erasing should be either a single value, or a string "
                                 "'random', or a sequence of 3 elements for RGB respectively.")

            return i, j, erase_h, erase_w, erase_value

    # exceeding max_attempts, return original image
    return 0, 0, image_h, image_w, np_img


def erase(np_img, i, j, height, width, erase_value, inplace=False):
    """
    Erase the pixels, within a selected rectangle region, to the given value. Applied on the input NumPy image array.

    Args:
        np_img (numpy.ndarray): NumPy image array of shape (C, H, W) to be erased.
        i (int): The height component of the top left corner (height, width).
        j (int): The width component of the top left corner (height, width).
        height (int): Height of the erased region.
        width (int): Width of the erased region.
        erase_value: Erase value returned from the helper function get_erase_params().
        inplace (bool, optional): Apply this transform inplace. Default is False.

    Returns:
        np_img (numpy.ndarray), Erased NumPy image array.
    """
    if not is_numpy(np_img):
        raise TypeError('img should be NumPy array. Got {}.'.format(type(np_img)))

    if not inplace:
        np_img = np_img.copy()
    # (i, j) here are the coordinates of axes (height, width) as in CHW
    np_img[:, i:i + height, j:j + width] = erase_value
    return np_img

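# Illustrative usage sketch (not part of the original module): a bounded (cutout-style)
# random erase on a CHW float array, wiring get_erase_params() into erase().
# The scale/ratio ranges and erase value are arbitrary.
def _example_erase():
    chw_img = np.random.rand(3, 32, 32).astype(np.float32)
    i, j, h, w, value = get_erase_params(chw_img, scale=(0.02, 0.2), ratio=(0.3, 3.3),
                                         value=0, bounded=True, max_attempts=10)
    erased = erase(chw_img, i, j, h, w, value)
    return erased.shape  # (3, 32, 32), with an h x w rectangle set to the erase value

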
def linear_transform(np_img, transformation_matrix, mean_vector):
    """
    Apply linear transformation to the input NumPy image array, given a square transformation matrix and a mean_vector.

    The transformation first flattens the input array and subtracts mean_vector from it, then computes the
    dot product with the transformation matrix, and reshapes it back to its original shape.

    Args:
        np_img (numpy.ndarray): NumPy image array of shape (C, H, W) to be linear transformed.
        transformation_matrix (numpy.ndarray): a square transformation matrix of shape (D, D), D = C x H x W.
        mean_vector (numpy.ndarray): a NumPy ndarray of shape (D,) where D = C x H x W.

    Returns:
        np_img (numpy.ndarray), Linear transformed image.
    """
    if not is_numpy(np_img):
        raise TypeError('img should be NumPy array. Got {}.'.format(type(np_img)))
    if transformation_matrix.shape[0] != transformation_matrix.shape[1]:
        raise ValueError("transformation_matrix should be a square matrix. "
                         "Got shape {} instead.".format(transformation_matrix.shape))
    if np.prod(np_img.shape) != transformation_matrix.shape[0]:
        raise ValueError("transformation_matrix shape {0} not compatible with "
                         "Numpy image shape {1}.".format(transformation_matrix.shape, np_img.shape))
    if mean_vector.shape[0] != transformation_matrix.shape[0]:
        raise ValueError("mean_vector length {0} should match either one dimension of the square "
                         "transformation_matrix {1}.".format(mean_vector.shape[0], transformation_matrix.shape))
    zero_centered_img = np_img.reshape(1, -1) - mean_vector
    transformed_img = np.dot(zero_centered_img, transformation_matrix)
    if transformed_img.size != np_img.size:
        raise ValueError("Linear transform failed, input shape should match with transformation_matrix.")
    transformed_img = transformed_img.reshape(np_img.shape)
    return transformed_img

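# Illustrative sketch (not part of the original module): with an identity matrix and a
# zero mean vector, linear_transform() returns the input unchanged; a real use case
# would pass a whitening matrix computed offline. Shapes here are arbitrary.
def _example_linear_transform():
    chw_img = np.random.rand(3, 4, 4).astype(np.float32)
    dim = chw_img.size                                  # D = C x H x W = 48
    identity = np.eye(dim, dtype=np.float32)
    zero_mean = np.zeros(dim, dtype=np.float32)
    out = linear_transform(chw_img, identity, zero_mean)
    return np.allclose(out, chw_img)                    # True

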
def random_affine(img, angle, translations, scale, shear, resample, fill_value=0):
    """
    Apply a random affine transformation to the input PIL image.

    Args:
        img (PIL image): Image to apply the affine transformation to.
        angle (Union[int, float]): Rotation angle in degrees, clockwise.
        translations (sequence): Translations in horizontal and vertical axis.
        scale (float): Scale parameter, a single number.
        shear (Union[float, sequence]): Shear amount parallel to X axis and Y axis.
        resample (Union[Inter.NEAREST, Inter.BILINEAR, Inter.BICUBIC], optional): An optional resampling filter.
        fill_value (Union[tuple, int], optional): Optional fill_value to fill the area outside the transform
            in the output image. Used only in Pillow versions 5.0.0 and above.
            If None, no filling is performed.

    Returns:
        img (PIL image), Randomly affine transformed image.

    """
    if not is_pil(img):
        raise ValueError("Input image should be a Pillow image.")

    # rotation
    angle = random.uniform(angle[0], angle[1])

    # translation
    if translations is not None:
        max_dx = translations[0] * img.size[0]
        max_dy = translations[1] * img.size[1]
        translations = (np.round(random.uniform(-max_dx, max_dx)),
                        np.round(random.uniform(-max_dy, max_dy)))
    else:
        translations = (0, 0)

    # scale
    if scale is not None:
        scale = random.uniform(scale[0], scale[1])
    else:
        scale = 1.0

    # shear
    if shear is not None:
        if len(shear) == 2:
            shear = [random.uniform(shear[0], shear[1]), 0.]
        elif len(shear) == 4:
            shear = [random.uniform(shear[0], shear[1]),
                     random.uniform(shear[2], shear[3])]
    else:
        shear = 0.0

    output_size = img.size
    center = (img.size[0] * 0.5 + 0.5, img.size[1] * 0.5 + 0.5)

    angle = math.radians(angle)
    if isinstance(shear, (tuple, list)) and len(shear) == 2:
        shear = [math.radians(s) for s in shear]
    elif isinstance(shear, numbers.Number):
        shear = math.radians(shear)
        shear = [shear, 0]
    else:
        raise ValueError(
            "Shear should be a single value or a tuple/list containing " +
            "two values. Got {}.".format(shear))

    scale = 1.0 / scale

    # Inverted rotation matrix with scale and shear
    d = math.cos(angle + shear[0]) * math.cos(angle + shear[1]) + \
        math.sin(angle + shear[0]) * math.sin(angle + shear[1])
    matrix = [
        math.cos(angle + shear[0]), math.sin(angle + shear[0]), 0,
        -math.sin(angle + shear[1]), math.cos(angle + shear[1]), 0
    ]
    matrix = [scale / d * m for m in matrix]

    # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
    matrix[2] += matrix[0] * (-center[0] - translations[0]) + matrix[1] * (-center[1] - translations[1])
    matrix[5] += matrix[3] * (-center[0] - translations[0]) + matrix[4] * (-center[1] - translations[1])

    # Apply center translation: C * RSS^-1 * C^-1 * T^-1
    matrix[2] += center[0]
    matrix[5] += center[1]

    # compare the major version number so that Pillow 10+ is still recognized as >= 5
    if int(__version__.split('.')[0]) >= 5:
        kwargs = {"fillcolor": fill_value}
    else:
        kwargs = {}
    return img.transform(output_size, Image.AFFINE, matrix, resample, **kwargs)


def mix_up_single(batch_size, img, label, alpha=0.2):
    """
    Apply the mix-up transformation to images and labels within a single batch. One-hot encoding
    should be done before this.

    Args:
        batch_size (int): The batch size of the dataset.
        img (numpy.ndarray): NumPy batch of images to apply the mix-up transformation to.
        label (numpy.ndarray): NumPy batch of labels to apply the mix-up transformation to.
        alpha (float): The mix up rate.

    Returns:
        mix_img (numpy.ndarray): NumPy batch of images after the mix-up transformation.
        mix_label (numpy.ndarray): NumPy batch of labels after the mix-up transformation.
    """

    def cir_shift(data):
        index = list(range(1, batch_size)) + [0]
        data = data[index, ...]
        return data

    lam = np.random.beta(alpha, alpha, batch_size)
    lam_img = lam.reshape((batch_size, 1, 1, 1))
    mix_img = lam_img * img + (1 - lam_img) * cir_shift(img)

    lam_label = lam.reshape((batch_size, 1))
    mix_label = lam_label * label + (1 - lam_label) * cir_shift(label)

    return mix_img, mix_label

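# Illustrative usage sketch (not part of the original module): mixing a small NHWC batch
# with its circularly shifted copy; labels are assumed to be one-hot encoded already.
# Batch size, image size and alpha are arbitrary.
def _example_mix_up_single():
    batch_size, num_classes = 4, 10
    imgs = np.random.rand(batch_size, 32, 32, 3).astype(np.float32)
    labels = np.eye(num_classes, dtype=np.float32)[np.random.randint(0, num_classes, batch_size)]
    mix_imgs, mix_labels = mix_up_single(batch_size, imgs, labels, alpha=0.2)
    return mix_imgs.shape, mix_labels.shape  # ((4, 32, 32, 3), (4, 10))

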
1197def mix_up_muti(tmp, batch_size, img, label, alpha=0.2):
1198    """
1199    Apply mix up transformation to image and label in continuous batch, one hot encoding should done before this.
1200
1201    Args:
1202        tmp (class object): mainly for saving the tmp parameter.
1203        batch_size (int): the batch size of dataset.
1204        img (numpy.ndarray): NumPy image to be applied mix up transformation.
1205        label (numpy.ndarray): NumPy label to be applied mix up transformation.
1206        alpha (float):  refer to the mix up rate.
1207
1208    Returns:
1209        mix_img (numpy.ndarray): NumPy image after being applied mix up transformation.
1210        mix_label (numpy.ndarray): NumPy label after being applied mix up transformation.
1211    """
1212    lam = np.random.beta(alpha, alpha, batch_size)
1213    if tmp.is_first:
1214        lam = np.ones(batch_size)
1215        tmp.is_first = False
1216
1217    lam_img = lam.reshape((batch_size, 1, 1, 1))
1218    mix_img = lam_img * img + (1 - lam_img) * tmp.image
1219
1220    lam_label = lam.reshape(batch_size, 1)
1221    mix_label = lam_label * label + (1 - lam_label) * tmp.label
1222    tmp.image = mix_img
1223    tmp.label = mix_label
1224
1225    return mix_img, mix_label
1226
1227
1228def rgb_to_bgr(np_rgb_img, is_hwc):
1229    """
1230    Convert RGB img to BGR img.
1231
1232    Args:
1233        np_rgb_img (numpy.ndarray): NumPy RGB image array of shape (H, W, C) or (C, H, W) to be converted.
1234        is_hwc (Bool): If True, the shape of np_hsv_img is (H, W, C), otherwise must be (C, H, W).
1235
1236    Returns:
1237        np_bgr_img (numpy.ndarray), NumPy BGR image with same type of np_rgb_img.
1238    """
1239    if is_hwc:
1240        np_bgr_img = np_rgb_img[:, :, ::-1]
1241    else:
1242        np_bgr_img = np_rgb_img[::-1, :, :]
1243    return np_bgr_img
1244
1245
1246def rgb_to_bgrs(np_rgb_imgs, is_hwc):
1247    """
1248    Convert RGB imgs to BGR imgs.
1249
1250    Args:
1251        np_rgb_imgs (numpy.ndarray): NumPy RGB images array of shape (H, W, C) or (N, H, W, C),
1252                                      or (C, H, W) or (N, C, H, W) to be converted.
1253        is_hwc (Bool): If True, the shape of np_rgb_imgs is (H, W, C) or (N, H, W, C);
1254                       If False, the shape of np_rgb_imgs is (C, H, W) or (N, C, H, W).
1255
1256    Returns:
1257        np_bgr_imgs (numpy.ndarray), NumPy BGR images with same type of np_rgb_imgs.
1258    """
1259    if not is_numpy(np_rgb_imgs):
1260        raise TypeError("img should be NumPy image. Got {}".format(type(np_rgb_imgs)))
1261
1262    if not isinstance(is_hwc, bool):
1263        raise TypeError("is_hwc should be bool type. Got {}.".format(type(is_hwc)))
1264
1265    shape_size = len(np_rgb_imgs.shape)
1266
1267    if not shape_size in (3, 4):
1268        raise TypeError("img shape should be (H, W, C)/(N, H, W, C)/(C ,H, W)/(N, C, H, W). "
1269                        "Got {}.".format(np_rgb_imgs.shape))
1270
1271    if shape_size == 3:
1272        batch_size = 0
1273        if is_hwc:
1274            num_channels = np_rgb_imgs.shape[2]
1275        else:
1276            num_channels = np_rgb_imgs.shape[0]
1277    else:
1278        batch_size = np_rgb_imgs.shape[0]
1279        if is_hwc:
1280            num_channels = np_rgb_imgs.shape[3]
1281        else:
1282            num_channels = np_rgb_imgs.shape[1]
1283
1284    if num_channels != 3:
1285        raise TypeError("img should be 3 channels RGB img. Got {} channels.".format(num_channels))
1286    if batch_size == 0:
1287        return rgb_to_bgr(np_rgb_imgs, is_hwc)
1288    return np.array([rgb_to_bgr(img, is_hwc) for img in np_rgb_imgs])
1289
1290
1291def rgb_to_hsv(np_rgb_img, is_hwc):
1292    """
1293    Convert RGB img to HSV img.
1294
1295    Args:
1296        np_rgb_img (numpy.ndarray): NumPy RGB image array of shape (H, W, C) or (C, H, W) to be converted.
1297        is_hwc (Bool): If True, the shape of np_hsv_img is (H, W, C), otherwise must be (C, H, W).
1298
1299    Returns:
1300        np_hsv_img (numpy.ndarray), NumPy HSV image with same type of np_rgb_img.
1301    """
1302    if is_hwc:
1303        r, g, b = np_rgb_img[:, :, 0], np_rgb_img[:, :, 1], np_rgb_img[:, :, 2]
1304    else:
1305        r, g, b = np_rgb_img[0, :, :], np_rgb_img[1, :, :], np_rgb_img[2, :, :]
1306    to_hsv = np.vectorize(colorsys.rgb_to_hsv)
1307    h, s, v = to_hsv(r, g, b)
1308    if is_hwc:
1309        axis = 2
1310    else:
1311        axis = 0
1312    np_hsv_img = np.stack((h, s, v), axis=axis)
1313    return np_hsv_img
1314
1315
def rgb_to_hsvs(np_rgb_imgs, is_hwc):
    """
    Convert RGB imgs to HSV imgs.

    Args:
        np_rgb_imgs (numpy.ndarray): NumPy RGB images array of shape (H, W, C) or (N, H, W, C),
                                     or (C, H, W) or (N, C, H, W) to be converted.
        is_hwc (bool): If True, the shape of np_rgb_imgs is (H, W, C) or (N, H, W, C);
                       If False, the shape of np_rgb_imgs is (C, H, W) or (N, C, H, W).

    Returns:
        np_hsv_imgs (numpy.ndarray), NumPy HSV images with the same type as np_rgb_imgs.
    """
    if not is_numpy(np_rgb_imgs):
        raise TypeError("img should be NumPy image. Got {}.".format(type(np_rgb_imgs)))

    if not isinstance(is_hwc, bool):
        raise TypeError("is_hwc should be bool type. Got {}.".format(type(is_hwc)))

    shape_size = len(np_rgb_imgs.shape)

    if shape_size not in (3, 4):
        raise TypeError("img shape should be (H, W, C)/(N, H, W, C)/(C, H, W)/(N, C, H, W). "
                        "Got {}.".format(np_rgb_imgs.shape))

    if shape_size == 3:
        batch_size = 0
        if is_hwc:
            num_channels = np_rgb_imgs.shape[2]
        else:
            num_channels = np_rgb_imgs.shape[0]
    else:
        batch_size = np_rgb_imgs.shape[0]
        if is_hwc:
            num_channels = np_rgb_imgs.shape[3]
        else:
            num_channels = np_rgb_imgs.shape[1]

    if num_channels != 3:
        raise TypeError("img should be a 3-channel RGB img. Got {} channels.".format(num_channels))
    if batch_size == 0:
        return rgb_to_hsv(np_rgb_imgs, is_hwc)
    return np.array([rgb_to_hsv(img, is_hwc) for img in np_rgb_imgs])


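# Sketch of the batch path (illustrative, not in the original source): a
# channel-first batch of shape (N, C, H, W) is converted image by image and the
# results are re-stacked, so the output shape matches the input.
#
#     >>> batch_chw = np.random.rand(4, 3, 8, 8)      # (N, C, H, W)
#     >>> rgb_to_hsvs(batch_chw, is_hwc=False).shape
#     (4, 3, 8, 8)

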
def hsv_to_rgb(np_hsv_img, is_hwc):
    """
    Convert HSV img to RGB img.

    Args:
        np_hsv_img (numpy.ndarray): NumPy HSV image array of shape (H, W, C) or (C, H, W) to be converted.
        is_hwc (bool): If True, the shape of np_hsv_img is (H, W, C), otherwise it must be (C, H, W).

    Returns:
        np_rgb_img (numpy.ndarray), NumPy RGB image with the same shape as np_hsv_img.
    """
    if is_hwc:
        h, s, v = np_hsv_img[:, :, 0], np_hsv_img[:, :, 1], np_hsv_img[:, :, 2]
    else:
        h, s, v = np_hsv_img[0, :, :], np_hsv_img[1, :, :], np_hsv_img[2, :, :]
    to_rgb = np.vectorize(colorsys.hsv_to_rgb)
    r, g, b = to_rgb(h, s, v)

    if is_hwc:
        axis = 2
    else:
        axis = 0
    np_rgb_img = np.stack((r, g, b), axis=axis)
    return np_rgb_img


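# Roundtrip sketch (assumption, for illustration only): hsv_to_rgb inverts
# rgb_to_hsv up to floating-point error for values in [0, 1].
#
#     >>> rgb = np.random.rand(8, 8, 3)
#     >>> back = hsv_to_rgb(rgb_to_hsv(rgb, is_hwc=True), is_hwc=True)
#     >>> bool(np.allclose(rgb, back))
#     True

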
def hsv_to_rgbs(np_hsv_imgs, is_hwc):
    """
    Convert HSV imgs to RGB imgs.

    Args:
        np_hsv_imgs (numpy.ndarray): NumPy HSV images array of shape (H, W, C) or (N, H, W, C),
                                     or (C, H, W) or (N, C, H, W) to be converted.
        is_hwc (bool): If True, the shape of np_hsv_imgs is (H, W, C) or (N, H, W, C);
                       If False, the shape of np_hsv_imgs is (C, H, W) or (N, C, H, W).

    Returns:
        np_rgb_imgs (numpy.ndarray), NumPy RGB images with the same type as np_hsv_imgs.
    """
    if not is_numpy(np_hsv_imgs):
        raise TypeError("img should be NumPy image. Got {}.".format(type(np_hsv_imgs)))

    if not isinstance(is_hwc, bool):
        raise TypeError("is_hwc should be bool type. Got {}.".format(type(is_hwc)))

    shape_size = len(np_hsv_imgs.shape)

    if shape_size not in (3, 4):
        raise TypeError("img shape should be (H, W, C)/(N, H, W, C)/(C, H, W)/(N, C, H, W). "
                        "Got {}.".format(np_hsv_imgs.shape))

    if shape_size == 3:
        batch_size = 0
        if is_hwc:
            num_channels = np_hsv_imgs.shape[2]
        else:
            num_channels = np_hsv_imgs.shape[0]
    else:
        batch_size = np_hsv_imgs.shape[0]
        if is_hwc:
            num_channels = np_hsv_imgs.shape[3]
        else:
            num_channels = np_hsv_imgs.shape[1]

    if num_channels != 3:
        raise TypeError("img should be a 3-channel HSV img. Got {} channels.".format(num_channels))
    if batch_size == 0:
        return hsv_to_rgb(np_hsv_imgs, is_hwc)
    return np.array([hsv_to_rgb(img, is_hwc) for img in np_hsv_imgs])


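# Batch roundtrip sketch (illustrative assumption): rgb_to_hsvs followed by
# hsv_to_rgbs should reproduce the original batch up to floating-point error.
#
#     >>> batch = np.random.rand(2, 8, 8, 3)          # (N, H, W, C)
#     >>> restored = hsv_to_rgbs(rgb_to_hsvs(batch, is_hwc=True), is_hwc=True)
#     >>> bool(np.allclose(batch, restored))
#     True

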
def random_color(img, degrees):
    """
    Adjust the color of the input PIL image by a random degree.

    Args:
        img (PIL image): Image to be color adjusted.
        degrees (sequence): Range of random color adjustment degrees.
            It should be in (min, max) format (default=(0.1, 1.9)).

    Returns:
        img (PIL image), Color adjusted image.
    """

    if not is_pil(img):
        raise TypeError(augment_error_message.format(type(img)))

    v = (degrees[1] - degrees[0]) * random.random() + degrees[0]
    return ImageEnhance.Color(img).enhance(v)


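# Usage sketch (not part of the original module): the enhancement factor is drawn
# uniformly from [degrees[0], degrees[1]); for ImageEnhance.Color a factor of 0.0
# gives a grayscale image and 1.0 the original image.
#
#     >>> pil_img = Image.new("RGB", (8, 8), (200, 50, 50))
#     >>> out = random_color(pil_img, (0.1, 1.9))
#     >>> out.size
#     (8, 8)

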
def random_sharpness(img, degrees):
    """
    Adjust the sharpness of the input PIL image by a random degree.

    Args:
        img (PIL image): Image to be sharpness adjusted.
        degrees (sequence): Range of random sharpness adjustment degrees.
            It should be in (min, max) format (default=(0.1, 1.9)).

    Returns:
        img (PIL image), Sharpness adjusted image.
    """

    if not is_pil(img):
        raise TypeError(augment_error_message.format(type(img)))

    v = (degrees[1] - degrees[0]) * random.random() + degrees[0]
    return ImageEnhance.Sharpness(img).enhance(v)


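# Usage sketch (assumption, for illustration): for ImageEnhance.Sharpness a factor
# of 0.0 gives a blurred image, 1.0 the original image and 2.0 a sharpened image,
# so the default range (0.1, 1.9) covers both directions.
#
#     >>> pil_img = Image.new("RGB", (8, 8), (128, 128, 128))
#     >>> random_sharpness(pil_img, (0.1, 1.9)).size
#     (8, 8)

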
def adjust_gamma(img, gamma, gain):
    """
    Adjust gamma of the input PIL image.

    Args:
        img (PIL image): Image to be augmented with AdjustGamma.
        gamma (float): Non-negative real number, same as gamma in the equation.
        gain (float, optional): The constant multiplier.

    Returns:
        img (PIL image), Augmented image.

    """

    if not is_pil(img):
        raise TypeError("img should be PIL image. Got {}.".format(type(img)))

    gamma_table = [(255 + 1 - 1e-3) * gain * pow(x / 255., gamma) for x in range(256)]
    if len(img.split()) == 3:
        gamma_table = gamma_table * 3
        img = img.point(gamma_table)
    elif len(img.split()) == 1:
        img = img.point(gamma_table)
    return img


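# The lookup table above implements out ~ 256 * gain * (in / 255) ** gamma per
# channel, with the 256 factor kept just below 256 so the result stays in range.
# A minimal sketch, assuming an 8-bit "L" image and gain=1:
#
#     >>> gray = Image.new("L", (2, 2), 128)
#     >>> out = adjust_gamma(gray, gamma=2.0, gain=1.0)
#     >>> out.getpixel((0, 0))        # about 256 * (128 / 255) ** 2, i.e. roughly 64

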
def auto_contrast(img, cutoff, ignore):
    """
    Automatically maximize the contrast of the input PIL image.

    Args:
        img (PIL image): Image to be augmented with AutoContrast.
        cutoff (float, optional): Percent of pixels to cut off from the histogram (default=0.0).
        ignore (Union[int, sequence], optional): Pixel values to ignore (default=None).

    Returns:
        img (PIL image), Augmented image.

    """

    if not is_pil(img):
        raise TypeError(augment_error_message.format(type(img)))

    return ImageOps.autocontrast(img, cutoff, ignore)


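# Usage sketch (illustrative, not in the original source): autocontrast remaps the
# darkest remaining pixel to 0 and the lightest to 255 after cutting off `cutoff`
# percent of the histogram from each end.
#
#     >>> low_contrast = Image.new("L", (2, 1))
#     >>> low_contrast.putdata([0, 51])
#     >>> list(auto_contrast(low_contrast, cutoff=0.0, ignore=None).getdata())
#     [0, 255]

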
def invert_color(img):
    """
    Invert colors of input PIL image.

    Args:
        img (PIL image): Image to be color inverted.

    Returns:
        img (PIL image), Color inverted image.

    """

    if not is_pil(img):
        raise TypeError(augment_error_message.format(type(img)))

    return ImageOps.invert(img)


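# Usage sketch (illustrative only): inversion maps each 8-bit channel value v to
# 255 - v, so white becomes black and vice versa.
#
#     >>> white = Image.new("RGB", (2, 2), (255, 255, 255))
#     >>> invert_color(white).getpixel((0, 0))
#     (0, 0, 0)

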
def equalize(img):
    """
    Equalize the histogram of input PIL image.

    Args:
        img (PIL image): Image to be equalized.

    Returns:
        img (PIL image), Equalized image.

    """

    if not is_pil(img):
        raise TypeError(augment_error_message.format(type(img)))

    return ImageOps.equalize(img)


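# Usage sketch (illustrative only): equalization spreads the image histogram so
# that pixel intensities cover the 0-255 range more uniformly; the image size and
# mode are preserved.
#
#     >>> skewed = Image.new("L", (4, 4), 30)
#     >>> equalize(skewed).size
#     (4, 4)

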
def uniform_augment(img, transforms, num_ops):
    """
    Uniformly select and apply a number of transforms sequentially from
    a list of transforms. Randomly assigns a probability to each transform for
    each image to decide whether to apply it or not.
    All the transforms in the transform list must have the same input/output data type.

    Args:
        img: Image to be transformed.
        transforms (list): List of transformations to be chosen from to apply.
        num_ops (int): Number of transforms to apply sequentially.

    Returns:
        img, Transformed image.

    """

    op_idx = np.random.choice(len(transforms), size=num_ops, replace=False)
    for idx in op_idx:
        AugmentOp = transforms[idx]
        pr = random.random()
        if random.random() < pr:
            img = AugmentOp(img.copy())

    return img

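# Usage sketch (illustrative, not part of the original module): num_ops distinct
# transforms are sampled without replacement, and each one is then applied or
# skipped based on a per-image random draw, so repeated calls on the same input
# can yield different results. Any callables with matching input/output types
# work; PIL-based helpers from this module are used here as an assumed example.
#
#     >>> ops = [invert_color, equalize, lambda im: random_color(im, (0.1, 1.9))]
#     >>> pil_img = Image.new("RGB", (8, 8), (120, 200, 30))
#     >>> out = uniform_augment(pil_img, ops, num_ops=2)
#     >>> out.size
#     (8, 8)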