# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Validators for TensorOps.
"""
import numbers
from functools import wraps
import numpy as np
from mindspore._c_dataengine import TensorOp, TensorOperation

from mindspore.dataset.core.validator_helpers import check_value, check_uint8, FLOAT_MIN_INTEGER, FLOAT_MAX_INTEGER, \
    check_pos_float32, check_float32, check_2tuple, check_range, check_positive, INT32_MAX, INT32_MIN, \
    parse_user_args, type_check, type_check_list, check_c_tensor_op, UINT8_MAX, check_value_normalize_std, \
    check_value_cutoff, check_value_ratio, check_odd, check_non_negative_float32
from .utils import Inter, Border, ImageBatchFormat, ConvertMode, SliceMode


def check_crop_size(size):
    """Wrapper method to check the parameters of crop size.

    Args:
        size: A single int, or a list/tuple ``(h, w)`` of two ints, each in [1, FLOAT_MAX_INTEGER].

    Raises:
        TypeError: If ``size`` is neither an int nor a list/tuple of length 2.
    """
    type_check(size, (int, list, tuple), "size")
    if isinstance(size, int):
        check_value(size, (1, FLOAT_MAX_INTEGER))
    elif isinstance(size, (tuple, list)) and len(size) == 2:
        for index, value in enumerate(size):
            type_check(value, (int,), "size[{}]".format(index))
            check_value(value, (1, FLOAT_MAX_INTEGER))
    else:
        raise TypeError("Size should be a single integer or a list/tuple (h, w) of length 2.")


def check_crop_coordinates(coordinates):
    """Wrapper method to check the parameters of crop coordinates.

    Args:
        coordinates: A list/tuple ``(y, x)`` of two ints, each in [0, INT32_MAX].

    Raises:
        TypeError: If ``coordinates`` is not a list/tuple of length 2.
    """
    type_check(coordinates, (list, tuple), "coordinates")
    if isinstance(coordinates, (tuple, list)) and len(coordinates) == 2:
        for index, value in enumerate(coordinates):
            type_check(value, (int,), "coordinates[{}]".format(index))
            check_value(value, (0, INT32_MAX), "coordinates[{}]".format(index))
    else:
        raise TypeError("Coordinates should be a list/tuple (y, x) of length 2.")


def check_cut_mix_batch_c(method):
    """Wrapper method to check the parameters of CutMixBatch."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [image_batch_format, alpha, prob], _ = parse_user_args(method, *args, **kwargs)
        type_check(image_batch_format, (ImageBatchFormat,), "image_batch_format")
        type_check(alpha, (int, float), "alpha")
        type_check(prob, (int, float), "prob")
        check_pos_float32(alpha)
        check_positive(alpha, "alpha")
        check_value(prob, [0, 1], "prob")
        return method(self, *args, **kwargs)

    return new_method


def check_resize_size(size):
    """Wrapper method to check the parameters of resize.

    Args:
        size: A single int in [1, FLOAT_MAX_INTEGER], or a list/tuple ``(h, w)`` of two ints
            in [1, INT32_MAX].

    Raises:
        TypeError: If ``size`` is neither an int nor a list/tuple of length 2.
    """
    if isinstance(size, int):
        check_value(size, (1, FLOAT_MAX_INTEGER))
    elif isinstance(size, (tuple, list)) and len(size) == 2:
        for i, value in enumerate(size):
            type_check(value, (int,), "size at dim {0}".format(i))
            check_value(value, (1, INT32_MAX), "size at dim {0}".format(i))
    else:
        raise TypeError("Size should be a single integer or a list/tuple (h, w) of length 2.")


def check_mix_up_batch_c(method):
    """Wrapper method to check the parameters of MixUpBatch."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [alpha], _ = parse_user_args(method, *args, **kwargs)
        type_check(alpha, (int, float), "alpha")
        check_positive(alpha, "alpha")
        check_pos_float32(alpha)

        return method(self, *args, **kwargs)

    return new_method


def check_normalize_c_param(mean, std):
    """Check mean/std for the C++ Normalize ops (per-channel values on a [0, 255] scale).

    Args:
        mean: List/tuple of per-channel means, each checked against [0, 255].
        std: List/tuple of per-channel standard deviations, each checked by
            ``check_value_normalize_std`` against [0, 255].

    Raises:
        ValueError: If ``mean`` and ``std`` have different lengths.
    """
    type_check(mean, (list, tuple), "mean")
    type_check(std, (list, tuple), "std")
    if len(mean) != len(std):
        raise ValueError("Length of mean and std must be equal.")
    for mean_value in mean:
        check_value(mean_value, [0, 255], "mean_value")
    for std_value in std:
        check_value_normalize_std(std_value, [0, 255], "std_value")


def check_normalize_py_param(mean, std):
    """Check mean/std for the Python Normalize ops (per-channel values on a [0, 1] scale).

    Args:
        mean: List/tuple of per-channel means, each checked against [0., 1.].
        std: List/tuple of per-channel standard deviations, each checked by
            ``check_value_normalize_std`` against [0., 1.].

    Raises:
        ValueError: If ``mean`` and ``std`` have different lengths.
    """
    type_check(mean, (list, tuple), "mean")
    type_check(std, (list, tuple), "std")
    if len(mean) != len(std):
        raise ValueError("Length of mean and std must be equal.")
    for mean_value in mean:
        check_value(mean_value, [0., 1.], "mean_value")
    for std_value in std:
        check_value(mean_value, [0., 1.], "mean_value") if False else None
        check_value_normalize_std(std_value, [0., 1.], "std_value")


def check_fill_value(fill_value):
    """Check a pixel fill value: a single uint8 int or a 3-tuple of uint8 ints.

    Raises:
        TypeError: If ``fill_value`` is neither an int nor a tuple of length 3.
    """
    if isinstance(fill_value, int):
        check_uint8(fill_value)
    elif isinstance(fill_value, tuple) and len(fill_value) == 3:
        for value in fill_value:
            check_uint8(value)
    else:
        raise TypeError("fill_value should be a single integer or a 3-tuple.")


def check_padding(padding):
    """Parsing the padding arguments and check if it is legal.

    Args:
        padding: A number in [0, INT32_MAX], or a list/tuple of 2 or 4 ints in [0, INT32_MAX].

    Raises:
        ValueError: If a list/tuple ``padding`` has a length other than 2 or 4.
    """
    type_check(padding, (tuple, list, numbers.Number), "padding")
    if isinstance(padding, numbers.Number):
        check_value(padding, (0, INT32_MAX), "padding")
    if isinstance(padding, (tuple, list)):
        if len(padding) not in (2, 4):
            raise ValueError("The size of the padding list or tuple should be 2 or 4.")
        for i, pad_value in enumerate(padding):
            type_check(pad_value, (int,), "padding[{}]".format(i))
            check_value(pad_value, (0, INT32_MAX), "pad_value")


def check_degrees(degrees):
    """Check if the `degrees` is legal.

    Args:
        degrees: A non-negative number, or a (min, max) list/tuple of two float32-range numbers.

    Raises:
        ValueError: If a sequence ``degrees`` is given in (max, min) order.
        TypeError: If a sequence ``degrees`` does not have length 2.
    """
    type_check(degrees, (int, float, list, tuple), "degrees")
    if isinstance(degrees, (int, float)):
        check_non_negative_float32(degrees, "degrees")
    elif isinstance(degrees, (list, tuple)):
        if len(degrees) == 2:
            type_check_list(degrees, (int, float), "degrees")
            for value in degrees:
                check_float32(value, "degrees")
            if degrees[0] > degrees[1]:
                raise ValueError("degrees should be in (min,max) format. Got (max,min).")
        else:
            raise TypeError("If degrees is a sequence, the length must be 2.")


def check_random_color_adjust_param(value, input_name, center=1, bound=(0, FLOAT_MAX_INTEGER), non_negative=True):
    """Check the parameters in random color adjust operation.

    A single number must be non-negative. A sequence must be a (min, max) pair that lies
    within ``bound``; ``bound`` is only applied to sequences here.

    Note:
        ``center`` and ``non_negative`` are accepted for call-site compatibility but are not
        used by this check.

    Raises:
        ValueError: If a number is negative, or a pair is given in (max, min) order.
        TypeError: If a sequence does not have length 2.
    """
    type_check(value, (numbers.Number, list, tuple), input_name)
    if isinstance(value, numbers.Number):
        if value < 0:
            raise ValueError("The input value of {} cannot be negative.".format(input_name))
    elif isinstance(value, (list, tuple)):
        if len(value) != 2:
            raise TypeError("If {0} is a sequence, the length must be 2.".format(input_name))
        if value[0] > value[1]:
            raise ValueError("{0} value should be in (min,max) format. Got ({1}, {2}).".format(input_name,
                                                                                                value[0], value[1]))
        check_range(value, bound)


def check_erasing_value(value):
    """Check the erasing value: a number, the string 'random', or a sequence of 3 elements.

    Raises:
        ValueError: If ``value`` matches none of the accepted forms.
    """
    if not (isinstance(value, (numbers.Number,)) or
            (isinstance(value, (str,)) and value == 'random') or
            (isinstance(value, (tuple, list)) and len(value) == 3)):
        raise ValueError("The value for erasing should be either a single value, "
                         "or a string 'random', or a sequence of 3 elements for RGB respectively.")


def check_crop(method):
    """A wrapper that wraps a parameter checker around the original function(crop operation)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [coordinates, size], _ = parse_user_args(method, *args, **kwargs)
        check_crop_coordinates(coordinates)
        check_crop_size(size)

        return method(self, *args, **kwargs)

    return new_method


def check_center_crop(method):
    """A wrapper that wraps a parameter checker around the original function(center crop operation)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [size], _ = parse_user_args(method, *args, **kwargs)
        check_crop_size(size)

        return method(self, *args, **kwargs)

    return new_method


def check_five_crop(method):
    """A wrapper that wraps a parameter checker around the original function(five crop operation)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [size], _ = parse_user_args(method, *args, **kwargs)
        check_crop_size(size)

        return method(self, *args, **kwargs)

    return new_method


def check_posterize(method):
    """A wrapper that wraps a parameter checker around the original function(posterize operation)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [bits], _ = parse_user_args(method, *args, **kwargs)
        if bits is not None:
            type_check(bits, (list, tuple, int), "bits")
            if isinstance(bits, int):
                check_value(bits, [1, 8])
            if isinstance(bits, (list, tuple)):
                if len(bits) != 2:
                    raise TypeError("Size of bits should be a single integer or a list/tuple (min, max) of length 2.")
                for item in bits:
                    check_uint8(item, "bits")
                # also checks if min <= max
                check_range(bits, [1, 8])
        return method(self, *args, **kwargs)

    return new_method


def check_resize_interpolation(method):
    """A wrapper that wraps a parameter checker around the original function(resize interpolation operation)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [size, interpolation], _ = parse_user_args(method, *args, **kwargs)
        if interpolation is None:
            # KeyError (not TypeError) is what this check has always raised; kept for callers.
            raise KeyError("Interpolation should not be None")
        check_resize_size(size)
        type_check(interpolation, (Inter,), "interpolation")

        return method(self, *args, **kwargs)

    return new_method


def check_resize(method):
    """A wrapper that wraps a parameter checker around the original function(resize operation)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [size], _ = parse_user_args(method, *args, **kwargs)
        check_resize_size(size)

        return method(self, *args, **kwargs)

    return new_method


def check_size_scale_ration_max_attempts_paras(size, scale, ratio, max_attempts):
    """Wrapper method to check the parameters of RandomCropDecodeResize and SoftDvppDecodeRandomCropResizeJpeg.

    ``scale``, ``ratio`` and ``max_attempts`` are optional (skipped when None); ``scale`` and
    ``ratio`` must be (min, max) pairs, and ``ratio`` must be strictly positive at both ends.
    """

    check_crop_size(size)
    if scale is not None:
        type_check(scale, (tuple, list), "scale")
        if len(scale) != 2:
            raise TypeError("scale should be a list/tuple of length 2.")
        type_check_list(scale, (float, int), "scale")
        if scale[0] > scale[1]:
            raise ValueError("scale should be in (min,max) format. Got (max,min).")
        check_range(scale, [0, FLOAT_MAX_INTEGER])
        check_positive(scale[1], "scale[1]")
    if ratio is not None:
        type_check(ratio, (tuple, list), "ratio")
        if len(ratio) != 2:
            raise TypeError("ratio should be a list/tuple of length 2.")
        type_check_list(ratio, (float, int), "ratio")
        if ratio[0] > ratio[1]:
            raise ValueError("ratio should be in (min,max) format. Got (max,min).")
        check_range(ratio, [0, FLOAT_MAX_INTEGER])
        check_positive(ratio[0], "ratio[0]")
        check_positive(ratio[1], "ratio[1]")
    if max_attempts is not None:
        check_value(max_attempts, (1, FLOAT_MAX_INTEGER))


def check_random_resize_crop(method):
    """A wrapper that wraps a parameter checker around the original function(random resize crop operation)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [size, scale, ratio, interpolation, max_attempts], _ = parse_user_args(method, *args, **kwargs)
        if interpolation is not None:
            type_check(interpolation, (Inter,), "interpolation")
        check_size_scale_ration_max_attempts_paras(size, scale, ratio, max_attempts)

        return method(self, *args, **kwargs)

    return new_method


def check_prob(method):
    """A wrapper that wraps a parameter checker (to confirm probability) around the original function."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [prob], _ = parse_user_args(method, *args, **kwargs)
        type_check(prob, (float, int,), "prob")
        check_value(prob, [0., 1.], "prob")

        return method(self, *args, **kwargs)

    return new_method


def check_normalize_c(method):
    """A wrapper that wraps a parameter checker around the original function(normalize operation written in C++)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [mean, std], _ = parse_user_args(method, *args, **kwargs)
        check_normalize_c_param(mean, std)

        return method(self, *args, **kwargs)

    return new_method


def check_normalize_py(method):
    """A wrapper that wraps a parameter checker around the original function(normalize operation written in Python)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [mean, std], _ = parse_user_args(method, *args, **kwargs)
        check_normalize_py_param(mean, std)

        return method(self, *args, **kwargs)

    return new_method


def check_normalizepad_c(method):
    """A wrapper that wraps a parameter checker around the original function(normalizepad written in C++)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [mean, std, dtype], _ = parse_user_args(method, *args, **kwargs)
        check_normalize_c_param(mean, std)
        if not isinstance(dtype, str):
            raise TypeError("dtype should be string.")
        if dtype not in ["float32", "float16"]:
            raise ValueError("dtype only support float32 or float16.")

        return method(self, *args, **kwargs)

    return new_method


def check_normalizepad_py(method):
    """A wrapper that wraps a parameter checker around the original function(normalizepad written in Python)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [mean, std, dtype], _ = parse_user_args(method, *args, **kwargs)
        check_normalize_py_param(mean, std)
        if not isinstance(dtype, str):
            raise TypeError("dtype should be string.")
        if dtype not in ["float32", "float16"]:
            raise ValueError("dtype only support float32 or float16.")

        return method(self, *args, **kwargs)

    return new_method


def check_random_crop(method):
    """Wrapper method to check the parameters of random crop."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [size, padding, pad_if_needed, fill_value, padding_mode], _ = parse_user_args(method, *args, **kwargs)
        check_crop_size(size)
        type_check(pad_if_needed, (bool,), "pad_if_needed")
        if padding is not None:
            check_padding(padding)
        if fill_value is not None:
            check_fill_value(fill_value)
        if padding_mode is not None:
            type_check(padding_mode, (Border,), "padding_mode")

        return method(self, *args, **kwargs)

    return new_method


def check_random_color_adjust(method):
    """Wrapper method to check the parameters of random color adjust."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [brightness, contrast, saturation, hue], _ = parse_user_args(method, *args, **kwargs)
        check_random_color_adjust_param(brightness, "brightness")
        check_random_color_adjust_param(contrast, "contrast")
        check_random_color_adjust_param(saturation, "saturation")
        check_random_color_adjust_param(hue, 'hue', center=0, bound=(-0.5, 0.5), non_negative=False)

        return method(self, *args, **kwargs)

    return new_method


def check_resample_expand_center_fill_value_params(resample, expand, center, fill_value):
    """Check the shared resample/expand/center/fill_value arguments of the rotation ops.

    Args:
        resample: Must be an ``Inter`` member.
        expand: Must be a bool.
        center: Optional 2-tuple of numbers, each in [INT32_MIN, INT32_MAX]; skipped when None.
        fill_value: Checked by ``check_fill_value``.
    """
    type_check(resample, (Inter,), "resample")
    type_check(expand, (bool,), "expand")
    if center is not None:
        check_2tuple(center, "center")
        for value in center:
            type_check(value, (int, float), "center")
            check_value(value, [INT32_MIN, INT32_MAX], "center")
    check_fill_value(fill_value)


def check_random_rotation(method):
    """Wrapper method to check the parameters of random rotation."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [degrees, resample, expand, center, fill_value], _ = parse_user_args(method, *args, **kwargs)
        check_degrees(degrees)
        check_resample_expand_center_fill_value_params(resample, expand, center, fill_value)

        return method(self, *args, **kwargs)

    return new_method


def check_rotate(method):
    """Wrapper method to check the parameters of rotate."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [degrees, resample, expand, center, fill_value], _ = parse_user_args(method, *args, **kwargs)
        type_check(degrees, (float, int), "degrees")
        check_float32(degrees, "degrees")
        check_resample_expand_center_fill_value_params(resample, expand, center, fill_value)

        return method(self, *args, **kwargs)

    return new_method


def check_ten_crop(method):
    """Wrapper method to check the parameters of crop."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [size, use_vertical_flip], _ = parse_user_args(method, *args, **kwargs)
        check_crop_size(size)

        if use_vertical_flip is not None:
            type_check(use_vertical_flip, (bool,), "use_vertical_flip")

        return method(self, *args, **kwargs)

    return new_method


def check_num_channels(method):
    """Wrapper method to check the parameters of number of channels."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [num_output_channels], _ = parse_user_args(method, *args, **kwargs)
        if num_output_channels is not None:
            if num_output_channels not in (1, 3):
                # Trailing space matters: the two literals are concatenated.
                raise ValueError("Number of channels of the output grayscale image "
                                 "should be either 1 or 3. Got {0}.".format(num_output_channels))

        return method(self, *args, **kwargs)

    return new_method


def check_pad(method):
    """Wrapper method to check the parameters of random pad."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [padding, fill_value, padding_mode], _ = parse_user_args(method, *args, **kwargs)
        check_padding(padding)
        check_fill_value(fill_value)
        type_check(padding_mode, (Border,), "padding_mode")

        return method(self, *args, **kwargs)

    return new_method


def check_slice_patches(method):
    """Wrapper method to check the parameters of slice patches."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [num_height, num_width, slice_mode, fill_value], _ = parse_user_args(method, *args, **kwargs)
        if num_height is not None:
            type_check(num_height, (int,), "num_height")
            check_value(num_height, (1, INT32_MAX), "num_height")
        if num_width is not None:
            type_check(num_width, (int,), "num_width")
            check_value(num_width, (1, INT32_MAX), "num_width")
        if slice_mode is not None:
            type_check(slice_mode, (SliceMode,), "slice_mode")
        if fill_value is not None:
            type_check(fill_value, (int,), "fill_value")
            check_value(fill_value, [0, 255], "fill_value")
        return method(self, *args, **kwargs)

    return new_method


def check_random_perspective(method):
    """Wrapper method to check the parameters of random perspective."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [distortion_scale, prob, interpolation], _ = parse_user_args(method, *args, **kwargs)

        type_check(distortion_scale, (float,), "distortion_scale")
        type_check(prob, (float,), "prob")
        check_value(distortion_scale, [0., 1.], "distortion_scale")
        check_value(prob, [0., 1.], "prob")
        type_check(interpolation, (Inter,), "interpolation")

        return method(self, *args, **kwargs)

    return new_method


def check_mix_up(method):
    """Wrapper method to check the parameters of mix up."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [batch_size, alpha, is_single], _ = parse_user_args(method, *args, **kwargs)
        type_check(is_single, (bool,), "is_single")
        type_check(batch_size, (int,), "batch_size")
        type_check(alpha, (int, float), "alpha")
        check_value(batch_size, (1, FLOAT_MAX_INTEGER))
        check_positive(alpha, "alpha")
        return method(self, *args, **kwargs)

    return new_method


def check_rgb_to_bgr(method):
    """Wrapper method to check the parameters of rgb_to_bgr."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [is_hwc], _ = parse_user_args(method, *args, **kwargs)
        type_check(is_hwc, (bool,), "is_hwc")
        return method(self, *args, **kwargs)

    return new_method


def check_rgb_to_hsv(method):
    """Wrapper method to check the parameters of rgb_to_hsv."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [is_hwc], _ = parse_user_args(method, *args, **kwargs)
        type_check(is_hwc, (bool,), "is_hwc")
        return method(self, *args, **kwargs)

    return new_method


def check_hsv_to_rgb(method):
    """Wrapper method to check the parameters of hsv_to_rgb."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [is_hwc], _ = parse_user_args(method, *args, **kwargs)
        type_check(is_hwc, (bool,), "is_hwc")
        return method(self, *args, **kwargs)

    return new_method


def check_random_erasing(method):
    """Wrapper method to check the parameters of random erasing."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [prob, scale, ratio, value, inplace, max_attempts], _ = parse_user_args(method, *args, **kwargs)

        # Type / shape checks first, then value-range checks.
        type_check(prob, (float, int,), "prob")
        type_check_list(scale, (float, int,), "scale")
        if len(scale) != 2:
            raise TypeError("scale should be a list or tuple of length 2.")
        type_check_list(ratio, (float, int,), "ratio")
        if len(ratio) != 2:
            raise TypeError("ratio should be a list or tuple of length 2.")
        type_check(value, (int, list, tuple, str), "value")
        type_check(inplace, (bool,), "inplace")
        type_check(max_attempts, (int,), "max_attempts")
        check_erasing_value(value)

        check_value(prob, [0., 1.], "prob")
        if scale[0] > scale[1]:
            raise ValueError("scale should be in (min,max) format. Got (max,min).")
        check_range(scale, [0, FLOAT_MAX_INTEGER])
        check_positive(scale[1], "scale[1]")
        if ratio[0] > ratio[1]:
            raise ValueError("ratio should be in (min,max) format. Got (max,min).")
        check_value_ratio(ratio[0], [0, FLOAT_MAX_INTEGER])
        check_value_ratio(ratio[1], [0, FLOAT_MAX_INTEGER])
        if isinstance(value, int):
            check_value(value, (0, 255))
        if isinstance(value, (list, tuple)):
            for item in value:
                type_check(item, (int,), "value")
                check_value(item, [0, 255], "value")
        check_value(max_attempts, (1, FLOAT_MAX_INTEGER))

        return method(self, *args, **kwargs)

    return new_method


def check_cutout(method):
    """Wrapper method to check the parameters of cutout operation."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [length, num_patches], _ = parse_user_args(method, *args, **kwargs)
        type_check(length, (int,), "length")
        type_check(num_patches, (int,), "num_patches")
        check_value(length, (1, FLOAT_MAX_INTEGER))
        check_value(num_patches, (1, FLOAT_MAX_INTEGER))

        return method(self, *args, **kwargs)

    return new_method


def check_linear_transform(method):
    """Wrapper method to check the parameters of linear transform.

    ``transformation_matrix`` must be a 2-D square ndarray and ``mean_vector`` a 1-D ndarray
    whose length equals the matrix dimension.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [transformation_matrix, mean_vector], _ = parse_user_args(method, *args, **kwargs)
        type_check(transformation_matrix, (np.ndarray,), "transformation_matrix")
        type_check(mean_vector, (np.ndarray,), "mean_vector")

        # Check ndim before indexing .shape so malformed input raises ValueError, not IndexError.
        if transformation_matrix.ndim != 2 or transformation_matrix.shape[0] != transformation_matrix.shape[1]:
            raise ValueError("transformation_matrix should be a square matrix. "
                             "Got shape {} instead.".format(transformation_matrix.shape))
        if mean_vector.ndim != 1:
            raise ValueError("mean_vector should be a 1-D array. "
                             "Got shape {} instead.".format(mean_vector.shape))
        if mean_vector.shape[0] != transformation_matrix.shape[0]:
            raise ValueError("mean_vector length {0} should match either one dimension of the square "
                             "transformation_matrix {1}.".format(mean_vector.shape[0], transformation_matrix.shape))

        return method(self, *args, **kwargs)

    return new_method


def check_random_affine(method):
    """Wrapper method to check the parameters of random affine."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [degrees, translate, scale, shear, resample, fill_value], _ = parse_user_args(method, *args, **kwargs)
        check_degrees(degrees)

        if translate is not None:
            type_check(translate, (list, tuple), "translate")
            type_check_list(translate, (int, float), "translate")
            if len(translate) != 2 and len(translate) != 4:
                raise TypeError("translate should be a list or tuple of length 2 or 4.")
            for i, t in enumerate(translate):
                check_value(t, [-1.0, 1.0], "translate at {0}".format(i))

        if scale is not None:
            type_check(scale, (tuple, list), "scale")
            type_check_list(scale, (int, float), "scale")
            if len(scale) == 2:
                if scale[0] > scale[1]:
                    raise ValueError("Input scale[1] must be equal to or greater than scale[0].")
                check_range(scale, [0, FLOAT_MAX_INTEGER])
                check_positive(scale[1], "scale[1]")
            else:
                raise TypeError("scale should be a list or tuple of length 2.")

        if shear is not None:
            type_check(shear, (numbers.Number, tuple, list), "shear")
            if isinstance(shear, numbers.Number):
                check_positive(shear, "shear")
            else:
                type_check_list(shear, (int, float), "shear")
                if len(shear) not in (2, 4):
                    raise TypeError("shear must be of length 2 or 4.")
                if len(shear) == 2 and shear[0] > shear[1]:
                    raise ValueError("Input shear[1] must be equal to or greater than shear[0]")
                if len(shear) == 4 and (shear[0] > shear[1] or shear[2] > shear[3]):
                    raise ValueError("Input shear[1] must be equal to or greater than shear[0] and "
                                     "shear[3] must be equal to or greater than shear[2].")

        type_check(resample, (Inter,), "resample")

        if fill_value is not None:
            check_fill_value(fill_value)

        return method(self, *args, **kwargs)

    return new_method


def check_rescale(method):
    """Wrapper method to check the parameters of rescale."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [rescale, shift], _ = parse_user_args(method, *args, **kwargs)
        type_check(rescale, (numbers.Number,), "rescale")
        type_check(shift, (numbers.Number,), "shift")
        check_float32(rescale)
        check_float32(shift)

        return method(self, *args, **kwargs)

    return new_method


def check_uniform_augment_cpp(method):
    """Wrapper method to check the parameters of UniformAugment C++ op."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [transforms, num_ops], _ = parse_user_args(method, *args, **kwargs)
        type_check(num_ops, (int,), "num_ops")
        check_positive(num_ops, "num_ops")

        if num_ops > len(transforms):
            raise ValueError("num_ops is greater than transforms list size.")
        # Objects exposing a parse() method are replaced by its result before the type check.
        parsed_transforms = []
        for op in transforms:
            if op and getattr(op, 'parse', None):
                parsed_transforms.append(op.parse())
            else:
                parsed_transforms.append(op)
        type_check(parsed_transforms, (list, tuple,), "transforms")
        for index, arg in enumerate(parsed_transforms):
            if not isinstance(arg, (TensorOp, TensorOperation)):
                raise TypeError("Type of Transforms[{0}] must be c_transform, but got {1}".format(index, type(arg)))

        return method(self, *args, **kwargs)

    return new_method


def check_bounding_box_augment_cpp(method):
    """Wrapper method to check the parameters of BoundingBoxAugment C++ op."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [transform, ratio], _ = parse_user_args(method, *args, **kwargs)
        type_check(ratio, (float, int), "ratio")
        check_value(ratio, [0., 1.], "ratio")
        # A transform exposing parse() is checked via its parsed form.
        if transform and getattr(transform, 'parse', None):
            transform = transform.parse()
        type_check(transform, (TensorOp, TensorOperation), "transform")
        return method(self, *args, **kwargs)

    return new_method


def check_adjust_gamma(method):
    """Wrapper method to check the parameters of AdjustGamma ops (Python and C++)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [gamma, gain], _ = parse_user_args(method, *args, **kwargs)
        type_check(gamma, (float, int), "gamma")
        check_value(gamma, (0, FLOAT_MAX_INTEGER))
        if gain is not None:
            type_check(gain, (float, int), "gain")
            check_value(gain, (FLOAT_MIN_INTEGER, FLOAT_MAX_INTEGER))
        return method(self, *args, **kwargs)

    return new_method


def check_auto_contrast(method):
    """Wrapper method to check the parameters of AutoContrast ops (Python and C++)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [cutoff, ignore], _ = parse_user_args(method, *args, **kwargs)
        type_check(cutoff, (int, float), "cutoff")
        check_value_cutoff(cutoff, [0, 50], "cutoff")
        if ignore is not None:
            type_check(ignore, (list, tuple, int), "ignore")
            if isinstance(ignore, int):
                check_value(ignore, [0, 255], "ignore")
            if isinstance(ignore, (list, tuple)):
                for item in ignore:
                    type_check(item, (int,), "item")
                    check_value(item, [0, 255], "ignore")
        return method(self, *args, **kwargs)

    return new_method


def check_uniform_augment_py(method):
    """Wrapper method to check the parameters of Python UniformAugment op."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [transforms, num_ops], _ = parse_user_args(method, *args, **kwargs)
        type_check(transforms, (list,), "transforms")

        if not transforms:
            raise ValueError("transforms list is empty.")

        for transform in transforms:
            if isinstance(transform, TensorOp):
                raise ValueError("transform list only accepts Python operations.")

        type_check(num_ops, (int,), "num_ops")
        check_positive(num_ops, "num_ops")
        if num_ops > len(transforms):
            raise ValueError("num_ops cannot be greater than the length of transforms list.")

        return method(self, *args, **kwargs)

    return new_method


def check_positive_degrees(method):
    """A wrapper method to check degrees parameter in RandomSharpness and RandomColor ops (Python and C++)"""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [degrees], _ = parse_user_args(method, *args, **kwargs)

        if degrees is not None:
            if not isinstance(degrees, (list, tuple)):
                raise TypeError("degrees must be either a tuple or a list.")
            type_check_list(degrees, (int, float), "degrees")
            if len(degrees) != 2:
                raise ValueError("degrees must be a sequence with length 2.")
            for degree in degrees:
                check_value(degree, (0, FLOAT_MAX_INTEGER))
            if degrees[0] > degrees[1]:
                raise ValueError("degrees should be in (min,max) format. Got (max,min).")

        return method(self, *args, **kwargs)

    return new_method


def check_random_select_subpolicy_op(method):
    """Wrapper method to check the parameters of RandomSelectSubpolicyOp.

    ``policy`` must be a non-empty list of non-empty lists of ``(op, prob)`` 2-tuples, where
    ``op`` passes ``check_c_tensor_op`` and ``prob`` lies in [0, 1].
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [policy], _ = parse_user_args(method, *args, **kwargs)
        type_check(policy, (list,), "policy")
        if not policy:
            raise ValueError("policy can not be empty.")
        for sub_ind, sub in enumerate(policy):
            # Format the index itself, not a one-element list (which rendered as "policy[[0]]").
            type_check(sub, (list,), "policy[{0}]".format(sub_ind))
            if not sub:
                raise ValueError("policy[{0}] can not be empty.".format(sub_ind))
            for op_ind, tp in enumerate(sub):
                check_2tuple(tp, "policy[{0}][{1}]".format(sub_ind, op_ind))
                check_c_tensor_op(tp[0], "op of (op, prob) in policy[{0}][{1}]".format(sub_ind, op_ind))
                check_value(tp[1], (0, 1), "prob of (op, prob) policy[{0}][{1}]".format(sub_ind, op_ind))

        return method(self, *args, **kwargs)

    return new_method


def check_soft_dvpp_decode_random_crop_resize_jpeg(method):
    """Wrapper method to check the parameters of SoftDvppDecodeRandomCropResizeJpeg."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [size, scale, ratio, max_attempts], _ = parse_user_args(method, *args, **kwargs)
        check_size_scale_ration_max_attempts_paras(size, scale, ratio, max_attempts)

        return method(self, *args, **kwargs)

    return new_method


def check_random_solarize(method):
    """Wrapper method to check the parameters of RandomSolarizeOp."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [threshold], _ = parse_user_args(method, *args, **kwargs)

        type_check(threshold, (tuple,), "threshold")
        type_check_list(threshold, (int,), "threshold")
        if len(threshold) != 2:
            raise ValueError("threshold must be a sequence of two numbers.")
        for element in threshold:
            check_value(element, (0, UINT8_MAX))
        if threshold[1] < threshold[0]:
            raise ValueError("threshold must be in min max format numbers.")

        return method(self, *args, **kwargs)

    return new_method


def check_gaussian_blur(method):
    """Wrapper method to check the parameters of GaussianBlur.

    ``kernel_size`` is a positive odd int or a pair of them; ``sigma`` (optional) is a number
    or a pair of numbers, each in [0, FLOAT_MAX_INTEGER].
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [kernel_size, sigma], _ = parse_user_args(method, *args, **kwargs)

        type_check(kernel_size, (int, list, tuple), "kernel_size")
        if isinstance(kernel_size, int):
            check_value(kernel_size, (1, FLOAT_MAX_INTEGER), "kernel_size")
            check_odd(kernel_size, "kernel_size")
        elif isinstance(kernel_size, (list, tuple)) and len(kernel_size) == 2:
            for index, value in enumerate(kernel_size):
                type_check(value, (int,), "kernel_size[{}]".format(index))
                check_value(value, (1, FLOAT_MAX_INTEGER), "kernel_size")
                check_odd(value, "kernel_size[{}]".format(index))
        else:
            raise TypeError(
                "Kernel size should be a single integer or a list/tuple (kernel_width, kernel_height) of length 2.")

        if sigma is not None:
            type_check(sigma, (numbers.Number, list, tuple), "sigma")
            if isinstance(sigma, numbers.Number):
                check_value(sigma, (0, FLOAT_MAX_INTEGER), "sigma")
            elif isinstance(sigma, (list, tuple)) and len(sigma) == 2:
                for index, value in enumerate(sigma):
                    # Report the offending argument as sigma[i] (was mislabelled size[i]).
                    type_check(value, (numbers.Number,), "sigma[{}]".format(index))
                    check_value(value, (0, FLOAT_MAX_INTEGER), "sigma")
            else:
                raise TypeError("Sigma should be a single number or a list/tuple of length 2 for width and height.")

        return method(self, *args, **kwargs)

    return new_method


def check_convert_color(method):
    """Wrapper method to check the parameters of convertcolor."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [convert_mode], _ = parse_user_args(method, *args, **kwargs)
        if convert_mode is not None:
            type_check(convert_mode, (ConvertMode,), "convert_mode")
        return method(self, *args, **kwargs)

    return new_method