• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Validators for TensorOps.
16"""
17import numbers
18from functools import wraps
19import numpy as np
20from mindspore._c_dataengine import TensorOp, TensorOperation
21
22from mindspore.dataset.core.validator_helpers import check_value, check_uint8, FLOAT_MIN_INTEGER, FLOAT_MAX_INTEGER, \
23    check_pos_float32, check_float32, check_2tuple, check_range, check_positive, INT32_MAX, INT32_MIN, \
24    parse_user_args, type_check, type_check_list, check_c_tensor_op, UINT8_MAX, check_value_normalize_std, \
25    check_value_cutoff, check_value_ratio, check_odd, check_non_negative_float32
26from .utils import Inter, Border, ImageBatchFormat, ConvertMode, SliceMode
27
28
def check_crop_size(size):
    """Validate a crop ``size`` argument.

    ``size`` must be either a single int or a list/tuple ``(h, w)`` of exactly two
    ints; every value is range-checked against ``[1, FLOAT_MAX_INTEGER]`` by
    ``check_value``.

    Raises:
        TypeError: if ``size`` is neither an int nor a 2-element list/tuple.
    """
    type_check(size, (int, list, tuple), "size")
    if isinstance(size, int):
        check_value(size, (1, FLOAT_MAX_INTEGER))
        return
    if not (isinstance(size, (tuple, list)) and len(size) == 2):
        raise TypeError("Size should be a single integer or a list/tuple (h, w) of length 2.")
    for pos, dim in enumerate(size):
        type_check(dim, (int,), "size[{}]".format(pos))
        check_value(dim, (1, FLOAT_MAX_INTEGER))
40
41
def check_crop_coordinates(coordinates):
    """Validate the ``coordinates`` argument of a crop operation.

    ``coordinates`` must be a list/tuple ``(y, x)`` of exactly two ints; each one is
    range-checked against ``[0, INT32_MAX]`` by ``check_value``.

    Raises:
        TypeError: if ``coordinates`` is not a 2-element list/tuple.
    """
    type_check(coordinates, (list, tuple), "coordinates")
    if not (isinstance(coordinates, (tuple, list)) and len(coordinates) == 2):
        raise TypeError("Coordinates should be a list/tuple (y, x) of length 2.")
    for index, value in enumerate(coordinates):
        label = "coordinates[{}]".format(index)
        type_check(value, (int,), label)
        check_value(value, (0, INT32_MAX), label)
51
52
def check_cut_mix_batch_c(method):
    """Decorate a CutMixBatch method so its arguments are validated before it runs.

    Expected arguments, in order:
        image_batch_format: an ``ImageBatchFormat`` member.
        alpha: int or float, validated by ``check_pos_float32`` and ``check_positive``.
        prob: int or float within ``[0, 1]``.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        params, _ = parse_user_args(method, *args, **kwargs)
        image_batch_format, alpha, prob = params

        # Type checks run first, so a wrong type is reported before a range problem.
        type_check(image_batch_format, (ImageBatchFormat,), "image_batch_format")
        type_check(alpha, (int, float), "alpha")
        type_check(prob, (int, float), "prob")

        check_pos_float32(alpha)
        check_positive(alpha, "alpha")
        check_value(prob, [0, 1], "prob")
        return method(self, *args, **kwargs)

    return wrapper
68
69
def check_resize_size(size):
    """Validate a resize ``size``: one int, or a list/tuple ``(h, w)`` of two ints.

    A single int is range-checked against ``[1, FLOAT_MAX_INTEGER]``. Each element
    of a pair must be an int in ``[1, INT32_MAX]``.

    Raises:
        TypeError: for any other type or sequence length.
    """
    if isinstance(size, int):
        check_value(size, (1, FLOAT_MAX_INTEGER))
    elif isinstance(size, (tuple, list)) and len(size) == 2:
        for dim_idx, dim in enumerate(size):
            label = "size at dim {0}".format(dim_idx)
            type_check(dim, (int,), label)
            check_value(dim, (1, INT32_MAX), label)
    else:
        raise TypeError("Size should be a single integer or a list/tuple (h, w) of length 2.")
80
81
def check_mix_up_batch_c(method):
    """Decorate a MixUpBatch method so that ``alpha`` is validated before it runs.

    ``alpha`` must be an int or float, and is then passed through
    ``check_positive`` and ``check_pos_float32``, in that order.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        (alpha,), _ = parse_user_args(method, *args, **kwargs)
        type_check(alpha, (int, float), "alpha")
        check_positive(alpha, "alpha")
        check_pos_float32(alpha)
        return method(self, *args, **kwargs)

    return wrapper
95
96
def check_normalize_c_param(mean, std):
    """Validate ``mean`` and ``std`` for the C++ normalize validators.

    Rules:
        - Both must be lists/tuples of the same length.
        - Every mean value is range-checked against ``[0, 255]``.
        - Every std value goes through ``check_value_normalize_std`` with the same range.

    Raises:
        ValueError: if the lengths differ.
    """
    for seq, seq_name in ((mean, "mean"), (std, "std")):
        type_check(seq, (list, tuple), seq_name)
    if len(mean) != len(std):
        raise ValueError("Length of mean and std must be equal.")
    for m in mean:
        check_value(m, [0, 255], "mean_value")
    for s in std:
        check_value_normalize_std(s, [0, 255], "std_value")
106
107
def check_normalize_py_param(mean, std):
    """Validate ``mean`` and ``std`` for the Python normalize validators.

    Rules:
        - Both must be lists/tuples of the same length.
        - Every mean value is range-checked against ``[0., 1.]``.
        - Every std value goes through ``check_value_normalize_std`` with the same range.

    Raises:
        ValueError: if the lengths differ.
    """
    for seq, seq_name in ((mean, "mean"), (std, "std")):
        type_check(seq, (list, tuple), seq_name)
    if len(mean) != len(std):
        raise ValueError("Length of mean and std must be equal.")
    for m in mean:
        check_value(m, [0., 1.], "mean_value")
    for s in std:
        check_value_normalize_std(s, [0., 1.], "std_value")
117
118
def check_fill_value(fill_value):
    """Validate a fill value: a single int, or a tuple of exactly three values.

    Every value, whether scalar or tuple element, goes through ``check_uint8``.
    Lists are not accepted.

    Raises:
        TypeError: for any other type or tuple length, including ``None``.
    """
    if isinstance(fill_value, int):
        check_uint8(fill_value)
        return
    if isinstance(fill_value, tuple) and len(fill_value) == 3:
        for channel_value in fill_value:
            check_uint8(channel_value)
        return
    raise TypeError("fill_value should be a single integer or a 3-tuple.")
127
128
def check_padding(padding):
    """Parse the ``padding`` argument and check that it is legal.

    ``padding`` may be a single number in ``[0, INT32_MAX]``, or a list/tuple of 2
    or 4 ints, each in ``[0, INT32_MAX]``. Type errors come from ``type_check`` and
    range errors from ``check_value``.

    Raises:
        ValueError: if a list/tuple ``padding`` does not have 2 or 4 elements.
    """
    type_check(padding, (tuple, list, numbers.Number), "padding")
    if isinstance(padding, numbers.Number):
        check_value(padding, (0, INT32_MAX), "padding")
    elif isinstance(padding, (tuple, list)):
        if len(padding) not in (2, 4):
            raise ValueError("The size of the padding list or tuple should be 2 or 4.")
        for i, pad_value in enumerate(padding):
            # Name the offending element by its index in both the type and range checks,
            # instead of the internal loop name "pad_value".
            element_name = "padding[{}]".format(i)
            type_check(pad_value, (int,), element_name)
            check_value(pad_value, (0, INT32_MAX), element_name)
140
141
def check_degrees(degrees):
    """Check that ``degrees`` is legal.

    Accepted forms:
        - A scalar int/float, which must pass ``check_non_negative_float32``.
        - A list/tuple of exactly two numbers ``(min, max)``. Each must pass
          ``check_float32``, and ``min`` may not exceed ``max``.

    Raises:
        TypeError: if a sequence does not have exactly two elements.
        ValueError: if the pair is given as ``(max, min)``.
    """
    type_check(degrees, (int, float, list, tuple), "degrees")
    if isinstance(degrees, (int, float)):
        check_non_negative_float32(degrees, "degrees")
    elif isinstance(degrees, (list, tuple)):
        if len(degrees) != 2:
            raise TypeError("If degrees is a sequence, the length must be 2.")
        type_check_list(degrees, (int, float), "degrees")
        for endpoint in degrees:
            check_float32(endpoint, "degrees")
        low, high = degrees
        if low > high:
            raise ValueError("degrees should be in (min,max) format. Got (max,min).")
156
157
def check_random_color_adjust_param(value, input_name, center=1, bound=(0, FLOAT_MAX_INTEGER), non_negative=True):
    """Check one factor of a random color adjust operation.

    Accepted forms:
        - A scalar number, which must be non-negative.
        - A ``(min, max)`` list/tuple with ``min <= max``, range-checked against
          ``bound`` by ``check_range``.

    Note:
        ``center`` and ``non_negative`` are not read by this function, and ``bound``
        is only applied to the sequence form.

    Raises:
        TypeError: if a sequence does not have exactly two elements.
        ValueError: for a negative scalar or a ``(max, min)`` pair.
    """
    type_check(value, (numbers.Number, list, tuple), input_name)
    if isinstance(value, numbers.Number):
        if value < 0:
            raise ValueError("The input value of {} cannot be negative.".format(input_name))
        return
    if isinstance(value, (list, tuple)):
        if len(value) != 2:
            raise TypeError("If {0} is a sequence, the length must be 2.".format(input_name))
        lo, hi = value
        if lo > hi:
            raise ValueError("{0} value should be in (min,max) format. Got ({1}, {2}).".format(input_name,
                                                                                               lo, hi))
        check_range(value, bound)
171
172
def check_erasing_value(value):
    """Check the fill ``value`` used by random erasing.

    Valid forms:
        - Any number.
        - The string ``'random'``.
        - A tuple/list of exactly three elements.

    Raises:
        ValueError: for any other input.
    """
    is_number = isinstance(value, numbers.Number)
    is_random_keyword = isinstance(value, str) and value == 'random'
    is_triple = isinstance(value, (tuple, list)) and len(value) == 3
    if is_number or is_random_keyword or is_triple:
        return
    raise ValueError("The value for erasing should be either a single value, "
                     "or a string 'random', or a sequence of 3 elements for RGB respectively.")
179
180
def check_crop(method):
    """Decorate a crop method so ``coordinates`` and ``size`` are validated first."""

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        params, _ = parse_user_args(method, *args, **kwargs)
        coordinates, size = params
        check_crop_coordinates(coordinates)
        check_crop_size(size)
        return method(self, *args, **kwargs)

    return wrapper
193
194
def check_center_crop(method):
    """Decorate a center-crop method so ``size`` is validated by ``check_crop_size`` first."""

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        (size,), _ = parse_user_args(method, *args, **kwargs)
        check_crop_size(size)
        return method(self, *args, **kwargs)

    return wrapper
206
207
def check_five_crop(method):
    """Decorate a five-crop method so ``size`` is validated by ``check_crop_size`` first."""

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        (size,), _ = parse_user_args(method, *args, **kwargs)
        check_crop_size(size)
        return method(self, *args, **kwargs)

    return wrapper
219
220
def check_posterize(method):
    """Decorate a posterize method so that ``bits`` is validated before it runs.

    Accepted forms of ``bits``:
        - ``None``, which skips validation.
        - A single int in ``[1, 8]``.
        - A list/tuple ``(min, max)`` of two values. Each must pass ``check_uint8``,
          and the pair is then checked against ``[1, 8]`` by ``check_range``.

    Raises:
        TypeError: if a list/tuple ``bits`` does not have exactly two elements.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        (bits,), _ = parse_user_args(method, *args, **kwargs)
        if bits is not None:
            type_check(bits, (list, tuple, int), "bits")
            if isinstance(bits, int):
                check_value(bits, [1, 8])
            elif isinstance(bits, (list, tuple)):
                if len(bits) != 2:
                    raise TypeError("Size of bits should be a single integer or a list/tuple (min, max) of length 2.")
                for bound in bits:
                    check_uint8(bound, "bits")
                # check_range also verifies min <= max.
                check_range(bits, [1, 8])
        return method(self, *args, **kwargs)

    return wrapper
241
242
def check_resize_interpolation(method):
    """Decorate a resize method that takes ``size`` and ``interpolation``.

    Checks run in this order:
        1. ``interpolation`` must not be ``None``; a ``KeyError`` is raised if it is.
        2. ``size`` is validated by ``check_resize_size``.
        3. ``interpolation`` must be an ``Inter`` member.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        params, _ = parse_user_args(method, *args, **kwargs)
        size, interpolation = params
        # NOTE(review): KeyError is an unusual type for a bad argument value; left unchanged.
        if interpolation is None:
            raise KeyError("Interpolation should not be None")
        check_resize_size(size)
        type_check(interpolation, (Inter,), "interpolation")
        return method(self, *args, **kwargs)

    return wrapper
257
258
def check_resize(method):
    """Decorate a resize method so ``size`` is validated by ``check_resize_size`` first."""

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        (size,), _ = parse_user_args(method, *args, **kwargs)
        check_resize_size(size)
        return method(self, *args, **kwargs)

    return wrapper
270
271
def check_size_scale_ration_max_attempts_paras(size, scale, ratio, max_attempts):
    """Check the parameters of RandomCropDecodeResize and SoftDvppDecodeRandomCropResizeJpeg.

    This is also used by ``check_random_resize_crop`` in this module.

    Args:
        size: crop size, validated by ``check_crop_size``.
        scale: optional ``(min, max)`` list/tuple of numbers. Requires ``min <= max``
            and values within ``[0, FLOAT_MAX_INTEGER]``; ``max`` must be positive.
        ratio: optional ``(min, max)`` list/tuple of numbers. Requires ``min <= max``
            and values within ``[0, FLOAT_MAX_INTEGER]``; both bounds must be positive.
        max_attempts: optional int in ``[1, FLOAT_MAX_INTEGER]``.

    Raises:
        TypeError: if ``scale``/``ratio`` is not a 2-element list/tuple, or if
            ``max_attempts`` is not an int.
        ValueError: if ``scale``/``ratio`` is given as ``(max, min)``.
    """
    check_crop_size(size)
    if scale is not None:
        type_check(scale, (tuple, list), "scale")
        if len(scale) != 2:
            raise TypeError("scale should be a list/tuple of length 2.")
        type_check_list(scale, (float, int), "scale")
        if scale[0] > scale[1]:
            raise ValueError("scale should be in (min,max) format. Got (max,min).")
        check_range(scale, [0, FLOAT_MAX_INTEGER])
        check_positive(scale[1], "scale[1]")
    if ratio is not None:
        type_check(ratio, (tuple, list), "ratio")
        if len(ratio) != 2:
            raise TypeError("ratio should be a list/tuple of length 2.")
        type_check_list(ratio, (float, int), "ratio")
        if ratio[0] > ratio[1]:
            raise ValueError("ratio should be in (min,max) format. Got (max,min).")
        check_range(ratio, [0, FLOAT_MAX_INTEGER])
        check_positive(ratio[0], "ratio[0]")
        check_positive(ratio[1], "ratio[1]")
    if max_attempts is not None:
        # max_attempts is a count. check_value alone would accept e.g. 10.5, so require
        # an int, as check_random_erasing does for its own max_attempts.
        type_check(max_attempts, (int,), "max_attempts")
        check_value(max_attempts, (1, FLOAT_MAX_INTEGER))
297
298
def check_random_resize_crop(method):
    """Decorate a random-resize-crop method so its arguments are validated first.

    ``interpolation``, when not ``None``, must be an ``Inter`` member. The remaining
    arguments (``size``, ``scale``, ``ratio``, ``max_attempts``) are delegated to
    ``check_size_scale_ration_max_attempts_paras``.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        params, _ = parse_user_args(method, *args, **kwargs)
        size, scale, ratio, interpolation, max_attempts = params
        if interpolation is not None:
            type_check(interpolation, (Inter,), "interpolation")
        check_size_scale_ration_max_attempts_paras(size, scale, ratio, max_attempts)
        return method(self, *args, **kwargs)

    return wrapper
312
313
def check_prob(method):
    """Decorate a method whose single argument ``prob`` must be a number in ``[0, 1]``."""

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        (prob,), _ = parse_user_args(method, *args, **kwargs)
        type_check(prob, (float, int,), "prob")
        check_value(prob, [0., 1.], "prob")
        return method(self, *args, **kwargs)

    return wrapper
326
327
def check_normalize_c(method):
    """Decorate the C++ normalize method so ``mean``/``std`` are validated first.

    Validation is delegated to ``check_normalize_c_param``.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        params, _ = parse_user_args(method, *args, **kwargs)
        mean, std = params
        check_normalize_c_param(mean, std)
        return method(self, *args, **kwargs)

    return wrapper
339
340
def check_normalize_py(method):
    """Decorate the Python normalize method so ``mean``/``std`` are validated first.

    Validation is delegated to ``check_normalize_py_param``.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        params, _ = parse_user_args(method, *args, **kwargs)
        mean, std = params
        check_normalize_py_param(mean, std)
        return method(self, *args, **kwargs)

    return wrapper
352
353
def check_normalizepad_c(method):
    """Decorate the C++ normalize-pad method so its arguments are validated first.

    ``mean``/``std`` go through ``check_normalize_c_param``. ``dtype`` must be the
    string ``"float32"`` or ``"float16"``.

    Raises:
        TypeError: if ``dtype`` is not a string.
        ValueError: if ``dtype`` is any other string.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        params, _ = parse_user_args(method, *args, **kwargs)
        mean, std, dtype = params
        check_normalize_c_param(mean, std)
        if not isinstance(dtype, str):
            raise TypeError("dtype should be string.")
        if dtype not in ("float32", "float16"):
            raise ValueError("dtype only support float32 or float16.")
        return method(self, *args, **kwargs)

    return wrapper
369
370
def check_normalizepad_py(method):
    """Decorate the Python normalize-pad method so its arguments are validated first.

    ``mean``/``std`` go through ``check_normalize_py_param``. ``dtype`` must be the
    string ``"float32"`` or ``"float16"``.

    Raises:
        TypeError: if ``dtype`` is not a string.
        ValueError: if ``dtype`` is any other string.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        params, _ = parse_user_args(method, *args, **kwargs)
        mean, std, dtype = params
        check_normalize_py_param(mean, std)
        if not isinstance(dtype, str):
            raise TypeError("dtype should be string.")
        if dtype not in ("float32", "float16"):
            raise ValueError("dtype only support float32 or float16.")
        return method(self, *args, **kwargs)

    return wrapper
386
387
def check_random_crop(method):
    """Decorate a random-crop method so its arguments are validated before it runs.

    ``size`` and ``pad_if_needed`` (a bool) are always checked. ``padding``,
    ``fill_value`` and ``padding_mode`` (a ``Border`` member) are only checked when
    they are not ``None``.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        params, _ = parse_user_args(method, *args, **kwargs)
        size, padding, pad_if_needed, fill_value, padding_mode = params

        check_crop_size(size)
        type_check(pad_if_needed, (bool,), "pad_if_needed")

        if padding is not None:
            check_padding(padding)
        if fill_value is not None:
            check_fill_value(fill_value)
        if padding_mode is not None:
            type_check(padding_mode, (Border,), "padding_mode")
        return method(self, *args, **kwargs)

    return wrapper
406
407
def check_random_color_adjust(method):
    """Decorate a random-color-adjust method so each adjustment factor is validated.

    Brightness, contrast and saturation use the default settings of
    ``check_random_color_adjust_param``. Hue uses ``center=0``,
    ``bound=(-0.5, 0.5)`` and ``non_negative=False``.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        params, _ = parse_user_args(method, *args, **kwargs)
        brightness, contrast, saturation, hue = params
        for factor, factor_name in ((brightness, "brightness"),
                                    (contrast, "contrast"),
                                    (saturation, "saturation")):
            check_random_color_adjust_param(factor, factor_name)
        check_random_color_adjust_param(hue, 'hue', center=0, bound=(-0.5, 0.5), non_negative=False)
        return method(self, *args, **kwargs)

    return wrapper
422
423
def check_resample_expand_center_fill_value_params(resample, expand, center, fill_value):
    """Validate the arguments shared by ``check_random_rotation`` and ``check_rotate``.

    Args:
        resample: must be an ``Inter`` member.
        expand: must be a bool.
        center: ``None``, or a 2-tuple (per ``check_2tuple``) of ints/floats in
            ``[INT32_MIN, INT32_MAX]``.
        fill_value: always passed to ``check_fill_value``, even when ``None``.
    """
    type_check(resample, (Inter,), "resample")
    type_check(expand, (bool,), "expand")
    if center is not None:
        check_2tuple(center, "center")
        for coord in center:
            type_check(coord, (int, float), "center")
            check_value(coord, [INT32_MIN, INT32_MAX], "center")
    check_fill_value(fill_value)
433
434
def check_random_rotation(method):
    """Decorate a random-rotation method so its arguments are validated first.

    ``degrees`` goes through ``check_degrees``. The remaining arguments go through
    ``check_resample_expand_center_fill_value_params``.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        params, _ = parse_user_args(method, *args, **kwargs)
        degrees, resample, expand, center, fill_value = params
        check_degrees(degrees)
        check_resample_expand_center_fill_value_params(resample, expand, center, fill_value)
        return method(self, *args, **kwargs)

    return wrapper
447
448
def check_rotate(method):
    """Decorate a rotate method so its arguments are validated first.

    ``degrees`` must be a single float/int that passes ``check_float32``. The
    remaining arguments go through ``check_resample_expand_center_fill_value_params``.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        params, _ = parse_user_args(method, *args, **kwargs)
        degrees, resample, expand, center, fill_value = params
        type_check(degrees, (float, int), "degrees")
        check_float32(degrees, "degrees")
        check_resample_expand_center_fill_value_params(resample, expand, center, fill_value)
        return method(self, *args, **kwargs)

    return wrapper
462
463
def check_ten_crop(method):
    """Decorate a ten-crop method so its arguments are validated first.

    ``size`` goes through ``check_crop_size``. ``use_vertical_flip``, when not
    ``None``, must be a bool.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        params, _ = parse_user_args(method, *args, **kwargs)
        size, use_vertical_flip = params
        check_crop_size(size)
        if use_vertical_flip is not None:
            type_check(use_vertical_flip, (bool,), "use_vertical_flip")
        return method(self, *args, **kwargs)

    return wrapper
478
479
def check_num_channels(method):
    """Decorate a method so that ``num_output_channels`` is validated before it runs.

    ``num_output_channels`` may be ``None`` (not checked), ``1`` or ``3``.

    Raises:
        ValueError: for any other value.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [num_output_channels], _ = parse_user_args(method, *args, **kwargs)
        if num_output_channels is not None and num_output_channels not in (1, 3):
            # The two literals are concatenated, so the separating space must be explicit
            # (previously the message read "...grayscale imageshould...").
            raise ValueError("Number of channels of the output grayscale image "
                             "should be either 1 or 3. Got {0}.".format(num_output_channels))

        return method(self, *args, **kwargs)

    return new_method
494
495
def check_pad(method):
    """Decorate a pad method so its arguments are validated first.

    Each argument is always checked:
        - ``padding`` by ``check_padding``.
        - ``fill_value`` by ``check_fill_value``.
        - ``padding_mode`` must be a ``Border`` member.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        params, _ = parse_user_args(method, *args, **kwargs)
        padding, fill_value, padding_mode = params
        check_padding(padding)
        check_fill_value(fill_value)
        type_check(padding_mode, (Border,), "padding_mode")
        return method(self, *args, **kwargs)

    return wrapper
509
510
def check_slice_patches(method):
    """Decorate a slice-patches method so its arguments are validated first.

    Each argument is only checked when it is not ``None``:
        - ``num_height`` and ``num_width`` must be ints in ``[1, INT32_MAX]``.
        - ``slice_mode`` must be a ``SliceMode`` member.
        - ``fill_value`` must be an int in ``[0, 255]``.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        params, _ = parse_user_args(method, *args, **kwargs)
        num_height, num_width, slice_mode, fill_value = params
        for count, count_name in ((num_height, "num_height"), (num_width, "num_width")):
            if count is not None:
                type_check(count, (int,), count_name)
                check_value(count, (1, INT32_MAX), count_name)
        if slice_mode is not None:
            type_check(slice_mode, (SliceMode,), "slice_mode")
        if fill_value is not None:
            type_check(fill_value, (int,), "fill_value")
            check_value(fill_value, [0, 255], "fill_value")
        return method(self, *args, **kwargs)

    return wrapper
531
532
def check_random_perspective(method):
    """Decorate a random-perspective method so its arguments are validated first.

    ``distortion_scale`` and ``prob`` must be numbers (float or int) in ``[0, 1]``.
    ``interpolation`` must be an ``Inter`` member.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [distortion_scale, prob, interpolation], _ = parse_user_args(method, *args, **kwargs)

        # Accept ints as well as floats (e.g. prob=1), matching check_prob in this module.
        type_check(distortion_scale, (float, int), "distortion_scale")
        type_check(prob, (float, int), "prob")
        check_value(distortion_scale, [0., 1.], "distortion_scale")
        check_value(prob, [0., 1.], "prob")
        type_check(interpolation, (Inter,), "interpolation")

        return method(self, *args, **kwargs)

    return new_method
549
550
def check_mix_up(method):
    """Decorate a mix-up method so its arguments are validated first.

    Rules:
        - ``is_single`` must be a bool.
        - ``batch_size`` must be an int in ``[1, FLOAT_MAX_INTEGER]``.
        - ``alpha`` must be an int/float that passes ``check_positive``.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        params, _ = parse_user_args(method, *args, **kwargs)
        batch_size, alpha, is_single = params
        type_check(is_single, (bool,), "is_single")
        type_check(batch_size, (int,), "batch_size")
        type_check(alpha, (int, float), "alpha")
        check_value(batch_size, (1, FLOAT_MAX_INTEGER))
        check_positive(alpha, "alpha")
        return method(self, *args, **kwargs)

    return wrapper
565
566
def check_rgb_to_bgr(method):
    """Decorate an rgb_to_bgr method so that ``is_hwc`` must be a bool."""

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        (is_hwc,), _ = parse_user_args(method, *args, **kwargs)
        type_check(is_hwc, (bool,), "is_hwc")
        return method(self, *args, **kwargs)

    return wrapper
577
578
def check_rgb_to_hsv(method):
    """Decorate an rgb_to_hsv method so that ``is_hwc`` must be a bool."""

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        (is_hwc,), _ = parse_user_args(method, *args, **kwargs)
        type_check(is_hwc, (bool,), "is_hwc")
        return method(self, *args, **kwargs)

    return wrapper
589
590
def check_hsv_to_rgb(method):
    """Decorate an hsv_to_rgb method so that ``is_hwc`` must be a bool."""

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        (is_hwc,), _ = parse_user_args(method, *args, **kwargs)
        type_check(is_hwc, (bool,), "is_hwc")
        return method(self, *args, **kwargs)

    return wrapper
601
602
def check_random_erasing(method):
    """Decorate a random-erasing method so its arguments are validated before it runs.

    Arguments, in order:
        prob: a number in ``[0, 1]``.
        scale: a ``(min, max)`` pair of numbers.
        ratio: a ``(min, max)`` pair of numbers.
        value: an int in ``[0, 255]``, a 3-element list/tuple of ints in ``[0, 255]``,
            or the string ``'random'``.
        inplace: a bool.
        max_attempts: an int in ``[1, FLOAT_MAX_INTEGER]``.

    All type and shape checks run before any range check.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        params, _ = parse_user_args(method, *args, **kwargs)
        prob, scale, ratio, value, inplace, max_attempts = params

        # Types and lengths.
        type_check(prob, (float, int,), "prob")
        for pair, pair_name in ((scale, "scale"), (ratio, "ratio")):
            type_check_list(pair, (float, int,), pair_name)
            if len(pair) != 2:
                raise TypeError("{} should be a list or tuple of length 2.".format(pair_name))
        type_check(value, (int, list, tuple, str), "value")
        type_check(inplace, (bool,), "inplace")
        type_check(max_attempts, (int,), "max_attempts")
        check_erasing_value(value)

        # Ranges and orderings.
        check_value(prob, [0., 1.], "prob")
        scale_lo, scale_hi = scale
        if scale_lo > scale_hi:
            raise ValueError("scale should be in (min,max) format. Got (max,min).")
        check_range(scale, [0, FLOAT_MAX_INTEGER])
        check_positive(scale_hi, "scale[1]")
        ratio_lo, ratio_hi = ratio
        if ratio_lo > ratio_hi:
            raise ValueError("ratio should be in (min,max) format. Got (max,min).")
        for ratio_bound in (ratio_lo, ratio_hi):
            check_value_ratio(ratio_bound, [0, FLOAT_MAX_INTEGER])
        if isinstance(value, int):
            check_value(value, (0, 255))
        elif isinstance(value, (list, tuple)):
            for channel in value:
                type_check(channel, (int,), "value")
                check_value(channel, [0, 255], "value")
        check_value(max_attempts, (1, FLOAT_MAX_INTEGER))

        return method(self, *args, **kwargs)

    return wrapper
642
643
def check_cutout(method):
    """Decorate a cutout method so its arguments are validated first.

    ``length`` and ``num_patches`` must both be ints in ``[1, FLOAT_MAX_INTEGER]``.
    Both type checks run before either range check.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        params, _ = parse_user_args(method, *args, **kwargs)
        length, num_patches = params
        named_args = (("length", length), ("num_patches", num_patches))
        for arg_name, arg_value in named_args:
            type_check(arg_value, (int,), arg_name)
        for _, arg_value in named_args:
            check_value(arg_value, (1, FLOAT_MAX_INTEGER))
        return method(self, *args, **kwargs)

    return wrapper
658
659
def check_linear_transform(method):
    """Decorate a linear-transform method so its arguments are validated first.

    Rules:
        - ``transformation_matrix`` must be a 2-D square ``np.ndarray``.
        - ``mean_vector`` must be an ``np.ndarray`` of rank >= 1.
        - ``mean_vector.shape[0]`` must equal the matrix dimension.

    Raises:
        ValueError: if any of these shape constraints is violated.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [transformation_matrix, mean_vector], _ = parse_user_args(method, *args, **kwargs)
        type_check(transformation_matrix, (np.ndarray,), "transformation_matrix")
        type_check(mean_vector, (np.ndarray,), "mean_vector")

        # Check ranks before indexing .shape: a 0-d or 1-d matrix (or a 0-d mean vector)
        # would otherwise fail with a bare IndexError instead of a descriptive error.
        if transformation_matrix.ndim != 2 or transformation_matrix.shape[0] != transformation_matrix.shape[1]:
            raise ValueError("transformation_matrix should be a square matrix. "
                             "Got shape {} instead.".format(transformation_matrix.shape))
        if mean_vector.ndim == 0:
            raise ValueError("mean_vector should be a vector. "
                             "Got shape {} instead.".format(mean_vector.shape))
        if mean_vector.shape[0] != transformation_matrix.shape[0]:
            raise ValueError("mean_vector length {0} should match either one dimension of the square "
                             "transformation_matrix {1}.".format(mean_vector.shape[0], transformation_matrix.shape))

        return method(self, *args, **kwargs)

    return new_method
679
680
def _check_affine_translate(translate):
    """Validate ``translate``: a list/tuple of 2 or 4 numbers, each in ``[-1.0, 1.0]``."""
    type_check(translate, (list, tuple), "translate")
    type_check_list(translate, (int, float), "translate")
    if len(translate) not in (2, 4):
        raise TypeError("translate should be a list or tuple of length 2 or 4.")
    for idx, offset in enumerate(translate):
        check_value(offset, [-1.0, 1.0], "translate at {0}".format(idx))


def _check_affine_scale(scale):
    """Validate ``scale``: a ``(min, max)`` pair within ``[0, FLOAT_MAX_INTEGER]`` with ``max`` positive."""
    type_check(scale, (tuple, list), "scale")
    type_check_list(scale, (int, float), "scale")
    if len(scale) != 2:
        raise TypeError("scale should be a list or tuple of length 2.")
    if scale[0] > scale[1]:
        raise ValueError("Input scale[1] must be equal to or greater than scale[0].")
    check_range(scale, [0, FLOAT_MAX_INTEGER])
    check_positive(scale[1], "scale[1]")


def _check_affine_shear(shear):
    """Validate ``shear``: a number passing ``check_positive``, or 2/4 ordered numbers."""
    type_check(shear, (numbers.Number, tuple, list), "shear")
    if isinstance(shear, numbers.Number):
        check_positive(shear, "shear")
        return
    type_check_list(shear, (int, float), "shear")
    if len(shear) not in (2, 4):
        raise TypeError("shear must be of length 2 or 4.")
    if len(shear) == 2 and shear[0] > shear[1]:
        raise ValueError("Input shear[1] must be equal to or greater than shear[0]")
    if len(shear) == 4 and (shear[0] > shear[1] or shear[2] > shear[3]):
        raise ValueError("Input shear[1] must be equal to or greater than shear[0] and "
                         "shear[3] must be equal to or greater than shear[2].")


def check_random_affine(method):
    """Decorate a random-affine method so its arguments are validated before it runs.

    Checks run in this order:
        1. ``degrees`` by ``check_degrees``.
        2. ``translate``, ``scale`` and ``shear``, each only when not ``None``.
        3. ``resample``, which must be an ``Inter`` member.
        4. ``fill_value``, by ``check_fill_value`` when not ``None``.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        params, _ = parse_user_args(method, *args, **kwargs)
        degrees, translate, scale, shear, resample, fill_value = params
        check_degrees(degrees)

        if translate is not None:
            _check_affine_translate(translate)
        if scale is not None:
            _check_affine_scale(scale)
        if shear is not None:
            _check_affine_shear(shear)

        type_check(resample, (Inter,), "resample")

        if fill_value is not None:
            check_fill_value(fill_value)

        return method(self, *args, **kwargs)

    return wrapper
730
731
def check_rescale(method):
    """Decorate a rescale method so ``rescale`` and ``shift`` are validated first.

    Both must be numbers, and then both must pass ``check_float32``. Both type
    checks run before either float32 check.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        params, _ = parse_user_args(method, *args, **kwargs)
        rescale, shift = params
        for arg_value, arg_name in ((rescale, "rescale"), (shift, "shift")):
            type_check(arg_value, (numbers.Number,), arg_name)
        for arg_value in (rescale, shift):
            check_float32(arg_value)
        return method(self, *args, **kwargs)

    return wrapper
746
747
def check_uniform_augment_cpp(method):
    """Wrapper method to check the parameters of UniformAugment C++ op.

    Rules:
        - ``num_ops`` must be a positive int, no larger than ``len(transforms)``.
        - ``transforms`` must be a list/tuple. Items exposing a truthy ``parse``
          attribute are replaced by ``item.parse()``.
        - Every (parsed) item must be a ``TensorOp`` or ``TensorOperation``.

    Raises:
        ValueError: if ``num_ops`` exceeds the number of transforms.
        TypeError: if a (parsed) transform has the wrong type.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [transforms, num_ops], _ = parse_user_args(method, *args, **kwargs)
        type_check(num_ops, (int,), "num_ops")
        check_positive(num_ops, "num_ops")
        # Validate the container itself before calling len() or iterating it. The old
        # check ran on the rebuilt list, which is always a list, so it could never fail.
        type_check(transforms, (list, tuple,), "transforms")

        if num_ops > len(transforms):
            raise ValueError("num_ops is greater than transforms list size.")
        parsed_transforms = [op.parse() if (op and getattr(op, 'parse', None)) else op
                             for op in transforms]
        for index, arg in enumerate(parsed_transforms):
            if not isinstance(arg, (TensorOp, TensorOperation)):
                raise TypeError("Type of Transforms[{0}] must be c_transform, but got {1}".format(index, type(arg)))

        return method(self, *args, **kwargs)

    return new_method
773
774
def check_bounding_box_augment_cpp(method):
    """Decorate the BoundingBoxAugment C++ method so its arguments are validated first.

    ``ratio`` must be a number in ``[0, 1]``. ``transform`` must be a ``TensorOp``
    or ``TensorOperation``, after ``transform.parse()`` is applied when the object
    exposes a truthy ``parse`` attribute. The original arguments are forwarded
    unchanged.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        params, _ = parse_user_args(method, *args, **kwargs)
        transform, ratio = params
        type_check(ratio, (float, int), "ratio")
        check_value(ratio, [0., 1.], "ratio")
        parse_attr = getattr(transform, 'parse', None) if transform else None
        if parse_attr:
            transform = transform.parse()
        type_check(transform, (TensorOp, TensorOperation), "transform")
        return method(self, *args, **kwargs)

    return wrapper
789
790
def check_adjust_gamma(method):
    """Decorate an AdjustGamma method (Python or C++) so its arguments are validated.

    ``gamma`` must be a number in ``[0, FLOAT_MAX_INTEGER]``. ``gain``, when not
    ``None``, must be a number in ``[FLOAT_MIN_INTEGER, FLOAT_MAX_INTEGER]``.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        params, _ = parse_user_args(method, *args, **kwargs)
        gamma, gain = params
        type_check(gamma, (float, int), "gamma")
        check_value(gamma, (0, FLOAT_MAX_INTEGER))
        if gain is None:
            return method(self, *args, **kwargs)
        type_check(gain, (float, int), "gain")
        check_value(gain, (FLOAT_MIN_INTEGER, FLOAT_MAX_INTEGER))
        return method(self, *args, **kwargs)

    return wrapper
805
806
def check_auto_contrast(method):
    """Wrapper method to check the parameters of AutoContrast ops (Python and C++).

    `cutoff` must be an int/float accepted by check_value_cutoff for [0, 50].
    `ignore` is optional. It may be an int, or a list/tuple of ints, and each
    value is checked against [0, 255].
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [cutoff, ignore], _ = parse_user_args(method, *args, **kwargs)
        type_check(cutoff, (int, float), "cutoff")
        check_value_cutoff(cutoff, [0, 50], "cutoff")
        if ignore is not None:
            type_check(ignore, (list, tuple, int), "ignore")
        if isinstance(ignore, int):
            check_value(ignore, [0, 255], "ignore")
        elif isinstance(ignore, (list, tuple)):
            for index, item in enumerate(ignore):
                # Name the failing element (e.g. "ignore[2]") instead of a generic "item".
                item_name = "ignore[{}]".format(index)
                type_check(item, (int,), item_name)
                check_value(item, [0, 255], item_name)
        return method(self, *args, **kwargs)

    return new_method
826
827
def check_uniform_augment_py(method):
    """Wrapper method to check the parameters of Python UniformAugment op.

    `transforms` must be a non-empty list that contains no C++ tensor
    operations (TensorOp or TensorOperation). `num_ops` must be a positive int
    no larger than ``len(transforms)``.

    Raises:
        TypeError: if `transforms` is not a list or `num_ops` is not an int.
        ValueError: if `transforms` is empty, contains a C++ op, or `num_ops`
            is out of range.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [transforms, num_ops], _ = parse_user_args(method, *args, **kwargs)
        type_check(transforms, (list,), "transforms")

        if not transforms:
            raise ValueError("transforms list is empty.")

        # Reject both C++ op flavours. check_uniform_augment_cpp accepts
        # TensorOp and TensorOperation alike as C++ ops.
        for transform in transforms:
            if isinstance(transform, (TensorOp, TensorOperation)):
                raise ValueError("transform list only accepts Python operations.")

        type_check(num_ops, (int,), "num_ops")
        check_positive(num_ops, "num_ops")
        if num_ops > len(transforms):
            raise ValueError("num_ops cannot be greater than the length of transforms list.")

        return method(self, *args, **kwargs)

    return new_method
851
852
def check_positive_degrees(method):
    """A wrapper method to check degrees parameter in RandomSharpness and RandomColor ops (Python and C++)

    `degrees` may be None. Otherwise it must be a list/tuple of exactly two
    int/float values. Each value is checked against (0, FLOAT_MAX_INTEGER),
    and the pair must be ordered (min, max).
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        (degrees,), _ = parse_user_args(method, *args, **kwargs)

        # Nothing to validate when the caller keeps the default.
        if degrees is None:
            return method(self, *args, **kwargs)

        if not isinstance(degrees, (list, tuple)):
            raise TypeError("degrees must be either a tuple or a list.")
        type_check_list(degrees, (int, float), "degrees")
        if len(degrees) != 2:
            raise ValueError("degrees must be a sequence with length 2.")
        for bound in degrees:
            check_value(bound, (0, FLOAT_MAX_INTEGER))
        lower, upper = degrees
        if lower > upper:
            raise ValueError("degrees should be in (min,max) format. Got (max,min).")

        return method(self, *args, **kwargs)

    return new_method
874
875
def check_random_select_subpolicy_op(method):
    """Wrapper method to check the parameters of RandomSelectSubpolicyOp.

    `policy` must be a non-empty list of non-empty lists (sub-policies). Each
    sub-policy entry is validated by check_2tuple. Its first element must pass
    check_c_tensor_op, and its second element (the probability) is checked
    against (0, 1).
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [policy], _ = parse_user_args(method, *args, **kwargs)
        type_check(policy, (list,), "policy")
        if not policy:
            raise ValueError("policy can not be empty.")
        for sub_ind, sub in enumerate(policy):
            # Format the index itself. Formatting [sub_ind] produced "policy[[0]]".
            type_check(sub, (list,), "policy[{0}]".format(sub_ind))
            if not sub:
                raise ValueError("policy[{0}] can not be empty.".format(sub_ind))
            for op_ind, tp in enumerate(sub):
                check_2tuple(tp, "policy[{0}][{1}]".format(sub_ind, op_ind))
                check_c_tensor_op(tp[0], "op of (op, prob) in policy[{0}][{1}]".format(sub_ind, op_ind))
                check_value(tp[1], (0, 1), "prob of (op, prob) policy[{0}][{1}]".format(sub_ind, op_ind))

        return method(self, *args, **kwargs)

    return new_method
897
898
def check_soft_dvpp_decode_random_crop_resize_jpeg(method):
    """Wrapper method to check the parameters of SoftDvppDecodeRandomCropResizeJpeg.

    Delegates validation of (size, scale, ratio, max_attempts) to
    check_size_scale_ration_max_attempts_paras, then calls the wrapped method.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        user_args, _ = parse_user_args(method, *args, **kwargs)
        size, scale, ratio, max_attempts = user_args
        check_size_scale_ration_max_attempts_paras(size, scale, ratio, max_attempts)
        return method(self, *args, **kwargs)

    return new_method
910
911
def check_random_solarize(method):
    """Wrapper method to check the parameters of RandomSolarizeOp.

    `threshold` must be a tuple of exactly two ints. Each int is checked
    against (0, UINT8_MAX), and the pair must be ordered (min, max).
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        (threshold,), _ = parse_user_args(method, *args, **kwargs)

        type_check(threshold, (tuple,), "threshold")
        type_check_list(threshold, (int,), "threshold")
        if len(threshold) != 2:
            raise ValueError("threshold must be a sequence of two numbers.")
        for bound in threshold:
            check_value(bound, (0, UINT8_MAX))
        low, high = threshold
        if high < low:
            raise ValueError("threshold must be in min max format numbers.")

        return method(self, *args, **kwargs)

    return new_method
931
932
def check_gaussian_blur(method):
    """Wrapper method to check the parameters of GaussianBlur.

    `kernel_size` must be an int, or a list/tuple of two ints. Each value is
    checked against (1, FLOAT_MAX_INTEGER) and must be odd (check_odd).
    `sigma` is optional. When given, it must be a number, or a list/tuple of two
    numbers, each checked against (0, FLOAT_MAX_INTEGER).

    Raises:
        TypeError: if `kernel_size` or `sigma` has an unsupported type or length.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [kernel_size, sigma], _ = parse_user_args(method, *args, **kwargs)

        type_check(kernel_size, (int, list, tuple), "kernel_size")
        if isinstance(kernel_size, int):
            check_value(kernel_size, (1, FLOAT_MAX_INTEGER), "kernel_size")
            check_odd(kernel_size, "kernel_size")
        elif isinstance(kernel_size, (list, tuple)) and len(kernel_size) == 2:
            for index, value in enumerate(kernel_size):
                # Use the indexed name in every check so errors point at the bad element.
                value_name = "kernel_size[{}]".format(index)
                type_check(value, (int,), value_name)
                check_value(value, (1, FLOAT_MAX_INTEGER), value_name)
                check_odd(value, value_name)
        else:
            raise TypeError(
                "Kernel size should be a single integer or a list/tuple (kernel_width, kernel_height) of length 2.")

        if sigma is not None:
            type_check(sigma, (numbers.Number, list, tuple), "sigma")
            if isinstance(sigma, numbers.Number):
                check_value(sigma, (0, FLOAT_MAX_INTEGER), "sigma")
            elif isinstance(sigma, (list, tuple)) and len(sigma) == 2:
                for index, value in enumerate(sigma):
                    # Previously mislabelled as "size[i]".
                    value_name = "sigma[{}]".format(index)
                    type_check(value, (numbers.Number,), value_name)
                    check_value(value, (0, FLOAT_MAX_INTEGER), value_name)
            else:
                raise TypeError("Sigma should be a single number or a list/tuple of length 2 for width and height.")

        return method(self, *args, **kwargs)

    return new_method
967
968
def check_convert_color(method):
    """Wrapper method to check the parameters of convertcolor.

    When `convert_mode` is not None it must be a ConvertMode member.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        (convert_mode,), _ = parse_user_args(method, *args, **kwargs)
        # A None mode is passed through unchecked.
        if convert_mode is None:
            return method(self, *args, **kwargs)
        type_check(convert_mode, (ConvertMode,), "convert_mode")
        return method(self, *args, **kwargs)

    return new_method
980