# Copyright 2019-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Validators for image processing operations.
"""
import numbers
from functools import wraps
import numpy as np

from mindspore._c_dataengine import TensorOp, TensorOperation
from mindspore._c_expression import typing
from mindspore.dataset.core.validator_helpers import check_value, check_uint8, FLOAT_MIN_INTEGER, FLOAT_MAX_INTEGER, \
    check_pos_float32, check_float32, check_2tuple, check_range, check_positive, INT32_MAX, INT32_MIN, \
    parse_user_args, type_check, type_check_list, check_c_tensor_op, UINT8_MAX, UINT8_MIN, check_value_normalize_std, \
    check_value_cutoff, check_value_ratio, check_odd, check_non_negative_float32, check_non_negative_int32, \
    check_pos_int32, check_int32, check_tensor_op, deprecator_factory, check_valid_str
from mindspore.dataset.transforms.validators import check_transform_op_type
from .utils import Inter, Border, ImageBatchFormat, ConvertMode, SliceMode, AutoAugmentPolicy


def check_affine(method):
    """Wrapper method to check the parameters of Affine.

    Validates degrees in [-180, 180], a length-2 translate with each element in
    [-1.0, 1.0], a positive scale, a shear that is a number or a length-2
    sequence with values in [-180, 180], an `Inter` resample and a fill_value.
    """
    @wraps(method)
    def new_method(self, *args, **kwargs):
        [degrees, translate, scale, shear, resample, fill_value], _ = parse_user_args(method, *args, **kwargs)

        type_check(degrees, (int, float), "degrees")
        check_value(degrees, [-180, 180], "degrees")

        type_check(translate, (list, tuple), "translate")
        if len(translate) != 2:
            raise TypeError("The length of translate should be 2.")
        for i, t in enumerate(translate):
            type_check(t, (int, float), "translate[{}]".format(i))
            check_value(t, [-1.0, 1.0], "translate[{}]".format(i))

        type_check(scale, (int, float), "scale")
        check_positive(scale, "scale")

        type_check(shear, (numbers.Number, tuple, list), "shear")
        if isinstance(shear, (list, tuple)):
            if len(shear) != 2:
                raise TypeError("The length of shear should be 2.")
            for i, shear_value in enumerate(shear):
                type_check(shear_value, (int, float), "shear[{}]".format(i))
                check_value(shear_value, [-180, 180], "shear[{}]".format(i))
        else:
            check_value(shear, [-180, 180], "shear")

        type_check(resample, (Inter,), "resample")

        check_fill_value(fill_value)

        return method(self, *args, **kwargs)

    return new_method


def check_crop_size(size):
    """Check a crop size: a single int or a list/tuple (h, w) of two ints.

    Each value is range-checked against (1, FLOAT_MAX_INTEGER).

    Raises:
        TypeError: if `size` is neither an int nor a length-2 list/tuple.
    """
    type_check(size, (int, list, tuple), "size")
    if isinstance(size, int):
        check_value(size, (1, FLOAT_MAX_INTEGER))
    elif isinstance(size, (tuple, list)) and len(size) == 2:
        for index, value in enumerate(size):
            type_check(value, (int,), "size[{}]".format(index))
            check_value(value, (1, FLOAT_MAX_INTEGER))
    else:
        raise TypeError("Size should be a single integer or a list/tuple (h, w) of length 2.")


def check_crop_coordinates(coordinates):
    """Check crop coordinates: a list/tuple (y, x) of two ints in [0, INT32_MAX].

    Raises:
        TypeError: if `coordinates` is not a length-2 list/tuple.
    """
    type_check(coordinates, (list, tuple), "coordinates")
    if isinstance(coordinates, (tuple, list)) and len(coordinates) == 2:
        for index, value in enumerate(coordinates):
            type_check(value, (int,), "coordinates[{}]".format(index))
            check_value(value, (0, INT32_MAX), "coordinates[{}]".format(index))
    else:
        raise TypeError("Coordinates should be a list/tuple (y, x) of length 2.")


def check_cut_mix_batch_c(method):
    """Wrapper method to check the parameters of CutMixBatch."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [image_batch_format, alpha, prob], _ = parse_user_args(method, *args, **kwargs)
        type_check(image_batch_format, (ImageBatchFormat,), "image_batch_format")
        type_check(alpha, (int, float), "alpha")
        type_check(prob, (int, float), "prob")
        check_pos_float32(alpha)
        check_positive(alpha, "alpha")
        check_value(prob, [0, 1], "prob")
        return method(self, *args, **kwargs)

    return new_method


def check_resize_size(size):
    """Check a resize size: a single int or a list/tuple (h, w) of two ints.

    A single int is range-checked against (1, FLOAT_MAX_INTEGER); each element of
    a pair is range-checked against (1, INT32_MAX).

    Raises:
        TypeError: if `size` is neither an int nor a length-2 list/tuple.
    """
    if isinstance(size, int):
        check_value(size, (1, FLOAT_MAX_INTEGER))
    elif isinstance(size, (tuple, list)) and len(size) == 2:
        for i, value in enumerate(size):
            type_check(value, (int,), "size at dim {0}".format(i))
            check_value(value, (1, INT32_MAX), "size at dim {0}".format(i))
    else:
        raise TypeError("Size should be a single integer or a list/tuple (h, w) of length 2.")


def check_mix_up_batch_c(method):
    """Wrapper method to check the parameters of MixUpBatch."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [alpha], _ = parse_user_args(method, *args, **kwargs)
        type_check(alpha, (int, float), "alpha")
        check_positive(alpha, "alpha")
        check_pos_float32(alpha)

        return method(self, *args, **kwargs)

    return new_method


def check_normalize_param(mean, std):
    """Check the parameters of Normalize and NormalizePad operations.

    `mean` and `std` must be lists/tuples of equal length whose elements are
    int/float. Each mean is checked against [0, 255] and each std with
    check_value_normalize_std against [0, 255].

    Raises:
        ValueError: if `mean` and `std` differ in length.
    """
    type_check(mean, (list, tuple), "mean")
    type_check(std, (list, tuple), "std")
    if len(mean) != len(std):
        raise ValueError("Length of mean and std must be equal.")
    for i, mean_value in enumerate(mean):
        type_check(mean_value, (int, float), "mean[{}]".format(i))
        check_value(mean_value, [0, 255], "mean[{}]".format(i))
    for j, std_value in enumerate(std):
        type_check(std_value, (int, float), "std[{}]".format(j))
        check_value_normalize_std(std_value, [0, 255], "std[{}]".format(j))


def check_normalize_c_param(mean, std):
    """Check mean/std for the C++ Normalize variants (values on a [0, 255] scale).

    Unlike check_normalize_param, element types are not checked explicitly.

    Raises:
        ValueError: if `mean` and `std` differ in length.
    """
    type_check(mean, (list, tuple), "mean")
    type_check(std, (list, tuple), "std")
    if len(mean) != len(std):
        raise ValueError("Length of mean and std must be equal.")
    for mean_value in mean:
        check_value(mean_value, [0, 255], "mean_value")
    for std_value in std:
        check_value_normalize_std(std_value, [0, 255], "std_value")


def check_normalize_py_param(mean, std):
    """Check mean/std for the Python Normalize variants (values on a [0, 1] scale).

    Raises:
        ValueError: if `mean` and `std` differ in length.
    """
    type_check(mean, (list, tuple), "mean")
    type_check(std, (list, tuple), "std")
    if len(mean) != len(std):
        raise ValueError("Length of mean and std must be equal.")
    for mean_value in mean:
        check_value(mean_value, [0., 1.], "mean_value")
    for std_value in std:
        check_value_normalize_std(std_value, [0., 1.], "std_value")


def check_fill_value(fill_value):
    """Check a pixel fill value: a single int or a 3-tuple, each checked by check_uint8.

    Raises:
        TypeError: if `fill_value` is neither an int nor a tuple of length 3
            (lists are rejected).
    """
    if isinstance(fill_value, int):
        check_uint8(fill_value, "fill_value")
    elif isinstance(fill_value, tuple) and len(fill_value) == 3:
        for i, value in enumerate(fill_value):
            check_uint8(value, "fill_value[{0}]".format(i))
    else:
        raise TypeError("fill_value should be a single integer or a 3-tuple.")


def check_padding(padding):
    """Parse the padding argument and check that it is legal.

    Accepts a non-negative int, or a list/tuple of 2 or 4 non-negative ints.

    Raises:
        ValueError: if a padding sequence does not have length 2 or 4.
    """
    type_check(padding, (tuple, list, numbers.Number), "padding")
    if isinstance(padding, numbers.Number):
        # Any number passes the first check; only ints are actually allowed.
        type_check(padding, (int,), "padding")
        check_value(padding, (0, INT32_MAX), "padding")
    if isinstance(padding, (tuple, list)):
        if len(padding) not in (2, 4):
            raise ValueError("The size of the padding list or tuple should be 2 or 4.")
        for i, pad_value in enumerate(padding):
            type_check(pad_value, (int,), "padding[{}]".format(i))
            check_value(pad_value, (0, INT32_MAX), "pad_value")


def check_degrees(degrees):
    """Check that `degrees` is legal.

    Accepts a non-negative number, or a (min, max) list/tuple of two float32-range
    numbers with min <= max.

    Raises:
        ValueError: if a sequence is given in (max, min) order.
        TypeError: if a sequence does not have length 2.
    """
    type_check(degrees, (int, float, list, tuple), "degrees")
    if isinstance(degrees, (int, float)):
        check_non_negative_float32(degrees, "degrees")
    elif isinstance(degrees, (list, tuple)):
        if len(degrees) == 2:
            type_check_list(degrees, (int, float), "degrees")
            for value in degrees:
                check_float32(value, "degrees")
            if degrees[0] > degrees[1]:
                raise ValueError("degrees should be in (min,max) format. Got (max,min).")
        else:
            raise TypeError("If degrees is a sequence, the length must be 2.")


def check_random_color_adjust_param(value, input_name, center=1, bound=(0, FLOAT_MAX_INTEGER), non_negative=True):
    """Check one of the parameters of a random color adjust operation.

    A number must be non-negative. A sequence must have length 2, be in
    (min, max) order and lie within `bound` (checked by check_range).

    NOTE: `center` and `non_negative` are accepted but not used here. A scalar
    value is only checked for non-negativity; it is not range-checked against
    `bound`.

    Raises:
        ValueError: on a negative number or a sequence in (max, min) order.
        TypeError: if a sequence does not have length 2.
    """
    type_check(value, (numbers.Number, list, tuple), input_name)
    if isinstance(value, numbers.Number):
        if value < 0:
            raise ValueError("The input value of {} cannot be negative.".format(input_name))
    elif isinstance(value, (list, tuple)):
        if len(value) != 2:
            raise TypeError("If {0} is a sequence, the length must be 2.".format(input_name))
        if value[0] > value[1]:
            raise ValueError("{0} value should be in (min,max) format. Got ({1}, {2}).".format(input_name,
                                                                                                value[0], value[1]))
        check_range(value, bound)


def check_erasing_value(value):
    """Check an erasing fill value: a number, the string 'random', or a 3-element sequence.

    Only the shape of `value` is checked here; element ranges are checked by the
    caller.

    Raises:
        ValueError: if `value` is none of the accepted forms.
    """
    if not (isinstance(value, (numbers.Number,)) or
            (isinstance(value, (str,)) and value == 'random') or
            (isinstance(value, (tuple, list)) and len(value) == 3)):
        raise ValueError("The value for erasing should be either a single value, "
                         "or a string 'random', or a sequence of 3 elements for RGB respectively.")


def check_crop(method):
    """A wrapper that wraps a parameter checker around the original function(crop operation)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [coordinates, size], _ = parse_user_args(method, *args, **kwargs)
        check_crop_coordinates(coordinates)
        check_crop_size(size)

        return method(self, *args, **kwargs)

    return new_method


def check_center_crop(method):
    """A wrapper that wraps a parameter checker around the original function(center crop operation)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [size], _ = parse_user_args(method, *args, **kwargs)
        check_crop_size(size)

        return method(self, *args, **kwargs)

    return new_method


def check_five_crop(method):
    """A wrapper that wraps a parameter checker around the original function(five crop operation)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [size], _ = parse_user_args(method, *args, **kwargs)
        check_crop_size(size)

        return method(self, *args, **kwargs)

    return new_method


def check_erase(method):
    """Wrapper method to check the parameters of erase operation.

    A scalar `value` is broadcast to a 3-tuple before validation. Each of the
    three channels is then checked against [UINT8_MIN, UINT8_MAX].
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [top, left, height, width, value, inplace], _ = parse_user_args(
            method, *args, **kwargs)
        check_non_negative_int32(top, "top")
        check_non_negative_int32(left, "left")
        check_pos_int32(height, "height")
        check_pos_int32(width, "width")
        type_check(inplace, (bool,), "inplace")
        type_check(value, (float, int, tuple), "value")
        if isinstance(value, (float, int)):
            # Only this local copy is broadcast; the caller's argument is left untouched.
            value = tuple([value] * 3)
        type_check_list(value, (float, int), "value")
        if isinstance(value, tuple) and len(value) == 3:
            for i, val in enumerate(value):
                check_value(val, (UINT8_MIN, UINT8_MAX), "value[{}]".format(i))
        else:
            raise TypeError("value should be a single integer/float or a 3-tuple.")

        return method(self, *args, **kwargs)

    return new_method


def check_random_posterize(method):
    """A wrapper that wraps a parameter checker around the original function(posterize operation).

    `bits` may be None, an int in [1, 8], or a (min, max) list/tuple within [1, 8].
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [bits], _ = parse_user_args(method, *args, **kwargs)
        if bits is not None:
            type_check(bits, (list, tuple, int), "bits")
            if isinstance(bits, int):
                check_value(bits, [1, 8])
            if isinstance(bits, (list, tuple)):
                if len(bits) != 2:
                    raise TypeError("Size of bits should be a single integer or a list/tuple (min, max) of length 2.")
                for item in bits:
                    check_uint8(item, "bits")
                # also checks if min <= max
                check_range(bits, [1, 8])
        return method(self, *args, **kwargs)

    return new_method


def check_posterize(method):
    """A wrapper that wraps a parameter checker around the original function(posterize operation)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [bits], _ = parse_user_args(method, *args, **kwargs)
        type_check(bits, (int,), "bits")
        check_value(bits, [0, 8], "bits")
        return method(self, *args, **kwargs)

    return new_method


def check_resize_interpolation(method):
    """A wrapper that wraps a parameter checker around the original function(resize interpolation operation).

    Raises KeyError (not TypeError) when `interpolation` is None.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [size, interpolation], _ = parse_user_args(method, *args, **kwargs)
        if interpolation is None:
            raise KeyError("Interpolation should not be None")
        check_resize_size(size)
        type_check(interpolation, (Inter,), "interpolation")

        return method(self, *args, **kwargs)

    return new_method


def check_device_target(method):
    """A wrapper that checks `device_target` with check_valid_str against ["CPU", "Ascend"]."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [device_target], _ = parse_user_args(method, *args, **kwargs)
        check_valid_str(device_target, ["CPU", "Ascend"], "device_target")
        return method(self, *args, **kwargs)
    return new_method


def check_resized_crop(method):
    """A wrapper that wraps a parameter checker around the original function(ResizedCrop operation)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [top, left, height, width, size, interpolation], _ = parse_user_args(method, *args, **kwargs)
        check_non_negative_int32(top, "top")
        check_non_negative_int32(left, "left")
        check_pos_int32(height, "height")
        check_pos_int32(width, "width")
        type_check(interpolation, (Inter,), "interpolation")
        check_crop_size(size)

        return method(self, *args, **kwargs)
    return new_method


def check_resize(method):
    """A wrapper that wraps a parameter checker around the original function(resize operation)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [size], _ = parse_user_args(method, *args, **kwargs)
        check_resize_size(size)

        return method(self, *args, **kwargs)

    return new_method


def check_size_scale_ration_max_attempts_paras(size, scale, ratio, max_attempts):
    """Check size/scale/ratio/max_attempts shared by RandomCropDecodeResize-style ops.

    `scale` and `ratio`, when given, must be length-2 (min, max) sequences.
    `scale` must lie within [0, FLOAT_MAX_INTEGER] with a positive upper bound.
    Both `ratio` elements must be positive float32 values. `max_attempts`, when
    given, must be a positive int32.
    """

    check_crop_size(size)
    if scale is not None:
        type_check(scale, (tuple, list), "scale")
        if len(scale) != 2:
            raise TypeError("scale should be a list/tuple of length 2.")
        type_check_list(scale, (float, int), "scale")
        if scale[0] > scale[1]:
            raise ValueError("scale should be in (min,max) format. Got (max,min).")
        check_range(scale, [0, FLOAT_MAX_INTEGER])
        check_positive(scale[1], "scale[1]")
    if ratio is not None:
        type_check(ratio, (tuple, list), "ratio")
        if len(ratio) != 2:
            raise TypeError("ratio should be a list/tuple of length 2.")
        check_pos_float32(ratio[0], "ratio[0]")
        check_pos_float32(ratio[1], "ratio[1]")
        if ratio[0] > ratio[1]:
            raise ValueError("ratio should be in (min,max) format. Got (max,min).")
    if max_attempts is not None:
        check_pos_int32(max_attempts, "max_attempts")


def check_random_adjust_sharpness(method):
    """Wrapper method to check the parameters of RandomAdjustSharpness."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [degree, prob], _ = parse_user_args(method, *args, **kwargs)
        type_check(degree, (float, int), "degree")
        check_non_negative_float32(degree, "degree")
        type_check(prob, (float, int), "prob")
        check_value(prob, [0., 1.], "prob")

        return method(self, *args, **kwargs)

    return new_method


def check_random_resize_crop(method):
    """A wrapper that wraps a parameter checker around the original function(random resize crop operation)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [size, scale, ratio, interpolation, max_attempts], _ = parse_user_args(method, *args, **kwargs)
        if interpolation is not None:
            type_check(interpolation, (Inter,), "interpolation")
        check_size_scale_ration_max_attempts_paras(size, scale, ratio, max_attempts)

        return method(self, *args, **kwargs)

    return new_method


def check_random_auto_contrast(method):
    """Wrapper method to check the parameters of Python RandomAutoContrast op.

    `ignore` may be None, an int in [0, 255], or a list/tuple of such ints.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [cutoff, ignore, prob], _ = parse_user_args(method, *args, **kwargs)
        type_check(cutoff, (int, float), "cutoff")
        check_value_cutoff(cutoff, [0, 50], "cutoff")
        if ignore is not None:
            type_check(ignore, (list, tuple, int), "ignore")
            if isinstance(ignore, int):
                check_value(ignore, [0, 255], "ignore")
            if isinstance(ignore, (list, tuple)):
                for item in ignore:
                    type_check(item, (int,), "item")
                    check_value(item, [0, 255], "ignore")
        type_check(prob, (float, int,), "prob")
        check_value(prob, [0., 1.], "prob")

        return method(self, *args, **kwargs)

    return new_method


def check_prob(method):
    """A wrapper that wraps a parameter checker (to confirm probability) around the original function."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [prob], _ = parse_user_args(method, *args, **kwargs)
        type_check(prob, (float, int,), "prob")
        check_value(prob, [0., 1.], "prob")

        return method(self, *args, **kwargs)

    return new_method


def check_alpha(method):
    """A wrapper method to check alpha parameter in RandomLighting."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [alpha], _ = parse_user_args(method, *args, **kwargs)
        type_check(alpha, (float, int,), "alpha")
        check_non_negative_float32(alpha, "alpha")

        return method(self, *args, **kwargs)

    return new_method


def check_normalize(method):
    """A wrapper that checks Normalize's mean, std and is_hwc arguments."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [mean, std, is_hwc], _ = parse_user_args(method, *args, **kwargs)
        check_normalize_param(mean, std)
        type_check(is_hwc, (bool,), "is_hwc")
        return method(self, *args, **kwargs)

    return new_method


def check_normalize_py(method):
    """A wrapper that wraps a parameter checker around the original function(normalize operation written in Python)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [mean, std], _ = parse_user_args(method, *args, **kwargs)
        check_normalize_py_param(mean, std)

        return method(self, *args, **kwargs)

    return new_method


def check_normalize_c(method):
    """A wrapper that wraps a parameter checker around the original function(normalize operation written in C++)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [mean, std], _ = parse_user_args(method, *args, **kwargs)
        check_normalize_c_param(mean, std)

        return method(self, *args, **kwargs)

    return new_method


def check_normalizepad(method):
    """A wrapper that checks NormalizePad's mean, std, dtype and is_hwc arguments.

    `dtype` must be the string "float32" or "float16".
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [mean, std, dtype, is_hwc], _ = parse_user_args(method, *args, **kwargs)
        check_normalize_param(mean, std)
        type_check(is_hwc, (bool,), "is_hwc")
        if not isinstance(dtype, str):
            raise TypeError("dtype should be string.")
        if dtype not in ["float32", "float16"]:
            raise ValueError("dtype only supports float32 or float16.")

        return method(self, *args, **kwargs)

    return new_method


def check_normalizepad_c(method):
    """A wrapper that wraps a parameter checker around the original function(normalizepad written in C++)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [mean, std, dtype], _ = parse_user_args(method, *args, **kwargs)
        check_normalize_c_param(mean, std)
        if not isinstance(dtype, str):
            raise TypeError("dtype should be string.")
        if dtype not in ["float32", "float16"]:
            raise ValueError("dtype only support float32 or float16.")

        return method(self, *args, **kwargs)

    return new_method


def check_normalizepad_py(method):
    """A wrapper that wraps a parameter checker around the original function(normalizepad written in Python)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [mean, std, dtype], _ = parse_user_args(method, *args, **kwargs)
        check_normalize_py_param(mean, std)
        if not isinstance(dtype, str):
            raise TypeError("dtype should be string.")
        if dtype not in ["float32", "float16"]:
            raise ValueError("dtype only support float32 or float16.")

        return method(self, *args, **kwargs)

    return new_method


def check_random_crop(method):
    """Wrapper method to check the parameters of random crop.

    padding, fill_value and padding_mode are only validated when not None.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [size, padding, pad_if_needed, fill_value, padding_mode], _ = parse_user_args(method, *args, **kwargs)
        check_crop_size(size)
        type_check(pad_if_needed, (bool,), "pad_if_needed")
        if padding is not None:
            check_padding(padding)
        if fill_value is not None:
            check_fill_value(fill_value)
        if padding_mode is not None:
            type_check(padding_mode, (Border,), "padding_mode")

        return method(self, *args, **kwargs)

    return new_method


def check_random_color_adjust(method):
    """Wrapper method to check the parameters of random color adjust."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [brightness, contrast, saturation, hue], _ = parse_user_args(method, *args, **kwargs)
        check_random_color_adjust_param(brightness, "brightness")
        check_random_color_adjust_param(contrast, "contrast")
        check_random_color_adjust_param(saturation, "saturation")
        check_random_color_adjust_param(hue, 'hue', center=0, bound=(-0.5, 0.5), non_negative=False)

        return method(self, *args, **kwargs)

    return new_method


def check_resample_expand_center_fill_value_params(resample, expand, center, fill_value):
    """Check the resample/expand/center/fill_value arguments shared by rotation ops.

    `resample` must be an `Inter`, `expand` a bool, and `center` (optional) a
    2-tuple of numbers in [INT32_MIN, INT32_MAX]. `fill_value` is checked by
    check_fill_value.
    """
    type_check(resample, (Inter,), "resample")
    type_check(expand, (bool,), "expand")
    if center is not None:
        check_2tuple(center, "center")
        for value in center:
            type_check(value, (int, float), "center")
            check_value(value, [INT32_MIN, INT32_MAX], "center")
    check_fill_value(fill_value)


def check_random_rotation(method):
    """Wrapper method to check the parameters of random rotation."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [degrees, resample, expand, center, fill_value], _ = parse_user_args(method, *args, **kwargs)
        check_degrees(degrees)
        check_resample_expand_center_fill_value_params(resample, expand, center, fill_value)

        return method(self, *args, **kwargs)

    return new_method


def check_rotate(method):
    """Wrapper method to check the parameters of rotate."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [degrees, resample, expand, center, fill_value], _ = parse_user_args(method, *args, **kwargs)
        type_check(degrees, (float, int), "degrees")
        check_float32(degrees, "degrees")
        check_resample_expand_center_fill_value_params(resample, expand, center, fill_value)

        return method(self, *args, **kwargs)

    return new_method


def check_ten_crop(method):
    """Wrapper method to check the parameters of TenCrop."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [size, use_vertical_flip], _ = parse_user_args(method, *args, **kwargs)
        check_crop_size(size)

        if use_vertical_flip is not None:
            type_check(use_vertical_flip, (bool,), "use_vertical_flip")

        return method(self, *args, **kwargs)

    return new_method


def check_num_channels(method):
    """Wrapper method to check the parameters of number of channels.

    `num_output_channels` must be the int 1 or 3.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [num_output_channels], _ = parse_user_args(method, *args, **kwargs)
        type_check(num_output_channels, (int,), "num_output_channels")
        if num_output_channels not in (1, 3):
            # Adjacent literals are concatenated, so the trailing space is required.
            raise ValueError("Number of channels of the output grayscale image "
                             "should be either 1 or 3. Got {0}.".format(num_output_channels))

        return method(self, *args, **kwargs)

    return new_method


def check_pad(method):
    """Wrapper method to check the parameters of Pad."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [padding, fill_value, padding_mode], _ = parse_user_args(method, *args, **kwargs)
        check_padding(padding)
        check_fill_value(fill_value)
        type_check(padding_mode, (Border,), "padding_mode")

        return method(self, *args, **kwargs)

    return new_method


def check_pad_to_size(method):
    """Wrapper method to check the parameters of PadToSize.

    `size` is a positive int or a length-2 sequence of positive ints. `offset`
    (optional) is a non-negative int or an empty/length-2 sequence of
    non-negative ints.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [size, offset, fill_value, padding_mode], _ = parse_user_args(method, *args, **kwargs)

        type_check(size, (int, list, tuple), "size")
        if isinstance(size, int):
            check_pos_int32(size, "size")
        else:
            if len(size) != 2:
                raise ValueError("The size must be a sequence of length 2.")
            for i, value in enumerate(size):
                check_pos_int32(value, "size{0}".format(i))

        if offset is not None:
            type_check(offset, (int, list, tuple), "offset")
            if isinstance(offset, int):
                check_non_negative_int32(offset, "offset")
            else:
                if len(offset) not in [0, 2]:
                    raise ValueError("The offset must be empty or a sequence of length 2.")
                for i, offset_value in enumerate(offset):
                    check_non_negative_int32(offset_value, "offset{0}".format(i))

        check_fill_value(fill_value)
        type_check(padding_mode, (Border,), "padding_mode")

        return method(self, *args, **kwargs)

    return new_method


def check_perspective(method):
    """Wrapper method to check the parameters of Perspective.

    `start_points` and `end_points` must each be 4 points. Each point is a
    length-2 list/tuple of int32 values.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [start_points, end_points, interpolation], _ = parse_user_args(method, *args, **kwargs)

        type_check_list(start_points, (list, tuple), "start_points")
        type_check_list(end_points, (list, tuple), "end_points")

        if len(start_points) != 4:
            raise TypeError("start_points should be a list or tuple of length 4.")
        for i, element in enumerate(start_points):
            type_check(element, (list, tuple), "start_points[{}]".format(i))
            if len(element) != 2:
                raise TypeError("start_points[{}] should be a list or tuple of length 2.".format(i))
            check_int32(element[0], "start_points[{}][0]".format(i))
            check_int32(element[1], "start_points[{}][1]".format(i))
        if len(end_points) != 4:
            raise TypeError("end_points should be a list or tuple of length 4.")
        for i, element in enumerate(end_points):
            type_check(element, (list, tuple), "end_points[{}]".format(i))
            if len(element) != 2:
                raise TypeError("end_points[{}] should be a list or tuple of length 2.".format(i))
            check_int32(element[0], "end_points[{}][0]".format(i))
            check_int32(element[1], "end_points[{}][1]".format(i))

        type_check(interpolation, (Inter,), "interpolation")

        return method(self, *args, **kwargs)

    return new_method


def check_slice_patches(method):
    """Wrapper method to check the parameters of slice patches.

    Every argument is optional and only validated when not None.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [num_height, num_width, slice_mode, fill_value], _ = parse_user_args(method, *args, **kwargs)
        if num_height is not None:
            type_check(num_height, (int,), "num_height")
            check_value(num_height, (1, INT32_MAX), "num_height")
        if num_width is not None:
            type_check(num_width, (int,), "num_width")
            check_value(num_width, (1, INT32_MAX), "num_width")
        if slice_mode is not None:
            type_check(slice_mode, (SliceMode,), "slice_mode")
        if fill_value is not None:
            type_check(fill_value, (int,), "fill_value")
            check_value(fill_value, [0, 255], "fill_value")
        return method(self, *args, **kwargs)

    return new_method


def check_random_perspective(method):
    """Wrapper method to check the parameters of random perspective.

    distortion_scale and prob must be floats (ints are rejected) in [0, 1].
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [distortion_scale, prob, interpolation], _ = parse_user_args(method, *args, **kwargs)

        type_check(distortion_scale, (float,), "distortion_scale")
        type_check(prob, (float,), "prob")
        check_value(distortion_scale, [0., 1.], "distortion_scale")
        check_value(prob, [0., 1.], "prob")
        type_check(interpolation, (Inter,), "interpolation")

        return method(self, *args, **kwargs)

    return new_method


def check_mix_up(method):
    """Wrapper method to check the parameters of mix up."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [batch_size, alpha, is_single], _ = parse_user_args(method, *args, **kwargs)
        type_check(is_single, (bool,), "is_single")
        type_check(batch_size, (int,), "batch_size")
        type_check(alpha, (int, float), "alpha")
        check_value(batch_size, (1, FLOAT_MAX_INTEGER))
        check_positive(alpha, "alpha")
        return method(self, *args, **kwargs)

    return new_method


def check_rgb_to_bgr(method):
    """Wrapper method to check the parameters of rgb_to_bgr."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [is_hwc], _ = parse_user_args(method, *args, **kwargs)
        type_check(is_hwc, (bool,), "is_hwc")
        return method(self, *args, **kwargs)

    return new_method


def check_rgb_to_hsv(method):
    """Wrapper method to check the parameters of rgb_to_hsv."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [is_hwc], _ = parse_user_args(method, *args, **kwargs)
        type_check(is_hwc, (bool,), "is_hwc")
        return method(self, *args, **kwargs)

    return new_method


def check_hsv_to_rgb(method):
    """Wrapper method to check the parameters of hsv_to_rgb."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [is_hwc], _ = parse_user_args(method, *args, **kwargs)
        type_check(is_hwc, (bool,), "is_hwc")
        return method(self, *args, **kwargs)

    return new_method


def check_random_erasing(method):
    """Wrapper method to check the parameters of random erasing.

    scale and ratio are length-2 (min, max) sequences. value is an int in
    [0, 255], 'random', or a 3-element sequence of ints in [0, 255].
    max_attempts is checked against (1, FLOAT_MAX_INTEGER).
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [prob, scale, ratio, value, inplace, max_attempts], _ = parse_user_args(method, *args, **kwargs)

        # Type/shape checks first, then value-range checks.
        type_check(prob, (float, int,), "prob")
        type_check_list(scale, (float, int,), "scale")
        if len(scale) != 2:
            raise TypeError("scale should be a list or tuple of length 2.")
        type_check_list(ratio, (float, int,), "ratio")
        if len(ratio) != 2:
            raise TypeError("ratio should be a list or tuple of length 2.")
        type_check(value, (int, list, tuple, str), "value")
        type_check(inplace, (bool,), "inplace")
        type_check(max_attempts, (int,), "max_attempts")
        check_erasing_value(value)

        check_value(prob, [0., 1.], "prob")
        if scale[0] > scale[1]:
            raise ValueError("scale should be in (min,max) format. Got (max,min).")
        check_range(scale, [0, FLOAT_MAX_INTEGER])
        check_positive(scale[1], "scale[1]")
        if ratio[0] > ratio[1]:
            raise ValueError("ratio should be in (min,max) format. Got (max,min).")
        check_value_ratio(ratio[0], [0, FLOAT_MAX_INTEGER])
        check_value_ratio(ratio[1], [0, FLOAT_MAX_INTEGER])
        if isinstance(value, int):
            check_value(value, (0, 255))
        if isinstance(value, (list, tuple)):
            for item in value:
                type_check(item, (int,), "value")
                check_value(item, [0, 255], "value")
        check_value(max_attempts, (1, FLOAT_MAX_INTEGER))

        return method(self, *args, **kwargs)

    return new_method


def check_cutout_new(method):
    """Wrapper method to check the parameters of cutout operation (variant with is_hwc)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [length, num_patches, is_hwc], _ = parse_user_args(method, *args, **kwargs)
        type_check(length, (int,), "length")
        type_check(num_patches, (int,), "num_patches")
        type_check(is_hwc, (bool,), "is_hwc")
        check_value(length, (1, FLOAT_MAX_INTEGER))
        check_value(num_patches, (1, FLOAT_MAX_INTEGER))

        return method(self, *args, **kwargs)

    return new_method


def check_cutout(method):
    """Wrapper method to check the parameters of cutout operation."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [length, num_patches], _ = parse_user_args(method, *args, **kwargs)
        type_check(length, (int,), "length")
        type_check(num_patches, (int,), "num_patches")
        check_value(length, (1, FLOAT_MAX_INTEGER))
        check_value(num_patches, (1, FLOAT_MAX_INTEGER))

        return method(self, *args, **kwargs)

    return new_method


def check_decode(method):
    """Wrapper method to check the parameters of decode operation."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [to_pil], _ = parse_user_args(method, *args, **kwargs)
        type_check(to_pil, (bool,), "to_pil")

        return method(self, *args, **kwargs)

    return new_method


def check_linear_transform(method):
    """Wrapper method to check the parameters of linear transform."""

967 @wraps(method) 968 def new_method(self, *args, **kwargs): 969 [transformation_matrix, mean_vector], _ = parse_user_args(method, *args, **kwargs) 970 type_check(transformation_matrix, (np.ndarray,), "transformation_matrix") 971 type_check(mean_vector, (np.ndarray,), "mean_vector") 972 973 if transformation_matrix.shape[0] != transformation_matrix.shape[1]: 974 raise ValueError("transformation_matrix should be a square matrix. " 975 "Got shape {} instead.".format(transformation_matrix.shape)) 976 if mean_vector.shape[0] != transformation_matrix.shape[0]: 977 raise ValueError("mean_vector length {0} should match either one dimension of the square" 978 "transformation_matrix {1}.".format(mean_vector.shape[0], transformation_matrix.shape)) 979 980 return method(self, *args, **kwargs) 981 982 return new_method 983 984 985def check_random_affine(method): 986 """Wrapper method to check the parameters of random affine.""" 987 988 @wraps(method) 989 def new_method(self, *args, **kwargs): 990 [degrees, translate, scale, shear, resample, fill_value], _ = parse_user_args(method, *args, **kwargs) 991 check_degrees(degrees) 992 993 if translate is not None: 994 type_check(translate, (list, tuple), "translate") 995 type_check_list(translate, (int, float), "translate") 996 if len(translate) != 2 and len(translate) != 4: 997 raise TypeError("translate should be a list or tuple of length 2 or 4.") 998 for i, t in enumerate(translate): 999 check_value(t, [-1.0, 1.0], "translate at {0}".format(i)) 1000 1001 if scale is not None: 1002 type_check(scale, (tuple, list), "scale") 1003 type_check_list(scale, (int, float), "scale") 1004 if len(scale) == 2: 1005 if scale[0] > scale[1]: 1006 raise ValueError("Input scale[1] must be equal to or greater than scale[0].") 1007 check_range(scale, [0, FLOAT_MAX_INTEGER]) 1008 check_positive(scale[1], "scale[1]") 1009 else: 1010 raise TypeError("scale should be a list or tuple of length 2.") 1011 1012 if shear is not None: 1013 type_check(shear, 
(numbers.Number, tuple, list), "shear") 1014 if isinstance(shear, numbers.Number): 1015 check_positive(shear, "shear") 1016 else: 1017 type_check_list(shear, (int, float), "shear") 1018 if len(shear) not in (2, 4): 1019 raise TypeError("shear must be of length 2 or 4.") 1020 if len(shear) == 2 and shear[0] > shear[1]: 1021 raise ValueError("Input shear[1] must be equal to or greater than shear[0]") 1022 if len(shear) == 4 and (shear[0] > shear[1] or shear[2] > shear[3]): 1023 raise ValueError("Input shear[1] must be equal to or greater than shear[0] and " 1024 "shear[3] must be equal to or greater than shear[2].") 1025 1026 type_check(resample, (Inter,), "resample") 1027 1028 if fill_value is not None: 1029 check_fill_value(fill_value) 1030 1031 return method(self, *args, **kwargs) 1032 1033 return new_method 1034 1035 1036def check_rescale(method): 1037 """Wrapper method to check the parameters of rescale.""" 1038 1039 @wraps(method) 1040 def new_method(self, *args, **kwargs): 1041 [rescale, shift], _ = parse_user_args(method, *args, **kwargs) 1042 type_check(rescale, (numbers.Number,), "rescale") 1043 type_check(shift, (numbers.Number,), "shift") 1044 check_float32(rescale, "rescale") 1045 check_float32(shift, "shift") 1046 1047 return method(self, *args, **kwargs) 1048 1049 return new_method 1050 1051 1052def check_uniform_augment_cpp(method): 1053 """Wrapper method to check the parameters of UniformAugment C++ op.""" 1054 1055 @wraps(method) 1056 def new_method(self, *args, **kwargs): 1057 [transforms, num_ops], _ = parse_user_args(method, *args, **kwargs) 1058 type_check(num_ops, (int,), "num_ops") 1059 check_positive(num_ops, "num_ops") 1060 1061 if num_ops > len(transforms): 1062 raise ValueError("num_ops is greater than transforms list size.") 1063 parsed_transforms = [] 1064 for op in transforms: 1065 if op and getattr(op, 'parse', None): 1066 parsed_transforms.append(op.parse()) 1067 else: 1068 parsed_transforms.append(op) 1069 
type_check(parsed_transforms, (list, tuple,), "transforms") 1070 for index, arg in enumerate(parsed_transforms): 1071 if not isinstance(arg, (TensorOp, TensorOperation)): 1072 raise TypeError("Type of Transforms[{0}] must be c_transform, but got {1}".format(index, type(arg))) 1073 1074 return method(self, *args, **kwargs) 1075 1076 return new_method 1077 1078 1079def check_uniform_augment(method): 1080 """Wrapper method to check the parameters of UniformAugment Unified op.""" 1081 1082 @wraps(method) 1083 def new_method(self, *args, **kwargs): 1084 [transforms, num_ops], _ = parse_user_args(method, *args, **kwargs) 1085 type_check(num_ops, (int,), "num_ops") 1086 check_positive(num_ops, "num_ops") 1087 1088 if num_ops > len(transforms): 1089 raise ValueError("num_ops is greater than transforms list size.") 1090 1091 type_check(transforms, (list, tuple,), "transforms list") 1092 if not transforms: 1093 raise ValueError("transforms list can not be empty.") 1094 for ind, op in enumerate(transforms): 1095 check_tensor_op(op, "transforms[{0}]".format(ind)) 1096 check_transform_op_type(ind, op) 1097 1098 return method(self, *args, **kwargs) 1099 1100 return new_method 1101 1102 1103def check_bounding_box_augment_cpp(method): 1104 """Wrapper method to check the parameters of BoundingBoxAugment C++ op.""" 1105 1106 @wraps(method) 1107 def new_method(self, *args, **kwargs): 1108 [transform, ratio], _ = parse_user_args(method, *args, **kwargs) 1109 type_check(ratio, (float, int), "ratio") 1110 check_value(ratio, [0., 1.], "ratio") 1111 if transform and getattr(transform, 'parse', None): 1112 transform = transform.parse() 1113 type_check(transform, (TensorOp, TensorOperation), "transform") 1114 return method(self, *args, **kwargs) 1115 1116 return new_method 1117 1118 1119def check_adjust_brightness(method): 1120 """Wrapper method to check the parameters of AdjustBrightness ops (Python and C++).""" 1121 1122 @wraps(method) 1123 def new_method(self, *args, **kwargs): 1124 
[brightness_factor], _ = parse_user_args(method, *args, **kwargs) 1125 type_check(brightness_factor, (float, int), "brightness_factor") 1126 check_value(brightness_factor, (0, FLOAT_MAX_INTEGER), "brightness_factor") 1127 return method(self, *args, **kwargs) 1128 1129 return new_method 1130 1131 1132def check_adjust_contrast(method): 1133 """Wrapper method to check the parameters of AdjustContrast ops (Python and C++).""" 1134 1135 @wraps(method) 1136 def new_method(self, *args, **kwargs): 1137 [contrast_factor], _ = parse_user_args(method, *args, **kwargs) 1138 type_check(contrast_factor, (float, int), "contrast_factor") 1139 check_value(contrast_factor, (0, FLOAT_MAX_INTEGER), "contrast_factor") 1140 return method(self, *args, **kwargs) 1141 1142 return new_method 1143 1144 1145def check_adjust_gamma(method): 1146 """Wrapper method to check the parameters of AdjustGamma ops (Python and C++).""" 1147 1148 @wraps(method) 1149 def new_method(self, *args, **kwargs): 1150 [gamma, gain], _ = parse_user_args(method, *args, **kwargs) 1151 type_check(gamma, (float, int), "gamma") 1152 check_value(gamma, (0, FLOAT_MAX_INTEGER)) 1153 if gain is not None: 1154 type_check(gain, (float, int), "gain") 1155 check_value(gain, (FLOAT_MIN_INTEGER, FLOAT_MAX_INTEGER)) 1156 return method(self, *args, **kwargs) 1157 1158 return new_method 1159 1160 1161def check_adjust_hue(method): 1162 """Wrapper method to check the parameters of AdjustHue ops (Python and C++).""" 1163 1164 @wraps(method) 1165 def new_method(self, *args, **kwargs): 1166 [hue_factor], _ = parse_user_args(method, *args, **kwargs) 1167 type_check(hue_factor, (float, int), "hue_factor") 1168 check_value(hue_factor, (-0.5, 0.5), "hue_factor") 1169 return method(self, *args, **kwargs) 1170 1171 return new_method 1172 1173 1174def check_adjust_saturation(method): 1175 """Wrapper method to check the parameters of AdjustSaturation ops (Python and C++).""" 1176 1177 @wraps(method) 1178 def new_method(self, *args, **kwargs): 
1179 [saturation_factor], _ = parse_user_args(method, *args, **kwargs) 1180 type_check(saturation_factor, (float, int), "saturation_factor") 1181 check_value(saturation_factor, (0, FLOAT_MAX_INTEGER)) 1182 return method(self, *args, **kwargs) 1183 1184 return new_method 1185 1186 1187def check_adjust_sharpness(method): 1188 """Wrapper method to check the parameters of AdjustSharpness ops (Python and C++).""" 1189 1190 @wraps(method) 1191 def new_method(self, *args, **kwargs): 1192 [sharpness_factor], _ = parse_user_args(method, *args, **kwargs) 1193 type_check(sharpness_factor, (float, int), "sharpness_factor") 1194 check_value(sharpness_factor, (0, FLOAT_MAX_INTEGER)) 1195 return method(self, *args, **kwargs) 1196 1197 return new_method 1198 1199 1200def check_auto_contrast(method): 1201 """Wrapper method to check the parameters of AutoContrast ops (Python and C++).""" 1202 1203 @wraps(method) 1204 def new_method(self, *args, **kwargs): 1205 [cutoff, ignore], _ = parse_user_args(method, *args, **kwargs) 1206 type_check(cutoff, (int, float), "cutoff") 1207 check_value_cutoff(cutoff, [0, 50], "cutoff") 1208 if ignore is not None: 1209 type_check(ignore, (list, tuple, int), "ignore") 1210 if isinstance(ignore, int): 1211 check_value(ignore, [0, 255], "ignore") 1212 if isinstance(ignore, (list, tuple)): 1213 for item in ignore: 1214 type_check(item, (int,), "item") 1215 check_value(item, [0, 255], "ignore") 1216 return method(self, *args, **kwargs) 1217 1218 return new_method 1219 1220 1221def check_uniform_augment_py(method): 1222 """Wrapper method to check the parameters of Python UniformAugment op.""" 1223 1224 @wraps(method) 1225 def new_method(self, *args, **kwargs): 1226 [transforms, num_ops], _ = parse_user_args(method, *args, **kwargs) 1227 type_check(transforms, (list,), "transforms") 1228 1229 if not transforms: 1230 raise ValueError("transforms list is empty.") 1231 1232 for transform in transforms: 1233 if isinstance(transform, TensorOp): 1234 raise 
ValueError("transform list only accepts Python operations.") 1235 1236 type_check(num_ops, (int,), "num_ops") 1237 check_positive(num_ops, "num_ops") 1238 if num_ops > len(transforms): 1239 raise ValueError("num_ops cannot be greater than the length of transforms list.") 1240 1241 return method(self, *args, **kwargs) 1242 1243 return new_method 1244 1245 1246def check_positive_degrees(method): 1247 """A wrapper method to check degrees parameter in RandomSharpness and RandomColor ops (Python and C++)""" 1248 1249 @wraps(method) 1250 def new_method(self, *args, **kwargs): 1251 [degrees], _ = parse_user_args(method, *args, **kwargs) 1252 1253 if degrees is not None: 1254 if not isinstance(degrees, (list, tuple)): 1255 raise TypeError("degrees must be either a tuple or a list.") 1256 type_check_list(degrees, (int, float), "degrees") 1257 if len(degrees) != 2: 1258 raise ValueError("degrees must be a sequence with length 2.") 1259 for degree in degrees: 1260 check_value(degree, (0, FLOAT_MAX_INTEGER)) 1261 if degrees[0] > degrees[1]: 1262 raise ValueError("degrees should be in (min,max) format. 
Got (max,min).") 1263 1264 return method(self, *args, **kwargs) 1265 1266 return new_method 1267 1268 1269def check_random_select_subpolicy_op(method): 1270 """Wrapper method to check the parameters of RandomSelectSubpolicyOp.""" 1271 1272 @wraps(method) 1273 def new_method(self, *args, **kwargs): 1274 [policy], _ = parse_user_args(method, *args, **kwargs) 1275 type_check(policy, (list,), "policy") 1276 if not policy: 1277 raise ValueError("policy can not be empty.") 1278 for sub_ind, sub in enumerate(policy): 1279 type_check(sub, (list,), "policy[{0}]".format([sub_ind])) 1280 if not sub: 1281 raise ValueError("policy[{0}] can not be empty.".format(sub_ind)) 1282 for op_ind, tp in enumerate(sub): 1283 check_2tuple(tp, "policy[{0}][{1}]".format(sub_ind, op_ind)) 1284 check_c_tensor_op(tp[0], "op of (op, prob) in policy[{0}][{1}]".format(sub_ind, op_ind)) 1285 check_value(tp[1], (0, 1), "prob of (op, prob) policy[{0}][{1}]".format(sub_ind, op_ind)) 1286 1287 return method(self, *args, **kwargs) 1288 1289 return new_method 1290 1291 1292def check_random_solarize(method): 1293 """Wrapper method to check the parameters of RandomSolarizeOp.""" 1294 1295 @wraps(method) 1296 def new_method(self, *args, **kwargs): 1297 [threshold], _ = parse_user_args(method, *args, **kwargs) 1298 1299 type_check(threshold, (tuple,), "threshold") 1300 type_check_list(threshold, (int,), "threshold") 1301 if len(threshold) != 2: 1302 raise ValueError("threshold must be a sequence of two numbers.") 1303 for element in threshold: 1304 check_value(element, (0, UINT8_MAX)) 1305 if threshold[1] < threshold[0]: 1306 raise ValueError("threshold must be in min max format numbers.") 1307 1308 return method(self, *args, **kwargs) 1309 1310 return new_method 1311 1312 1313def check_gaussian_blur(method): 1314 """Wrapper method to check the parameters of GaussianBlur.""" 1315 1316 @wraps(method) 1317 def new_method(self, *args, **kwargs): 1318 [kernel_size, sigma], _ = parse_user_args(method, *args, 
**kwargs) 1319 1320 type_check(kernel_size, (int, list, tuple), "kernel_size") 1321 if isinstance(kernel_size, int): 1322 check_value(kernel_size, (1, FLOAT_MAX_INTEGER), "kernel_size") 1323 check_odd(kernel_size, "kernel_size") 1324 elif isinstance(kernel_size, (list, tuple)) and len(kernel_size) == 2: 1325 for index, value in enumerate(kernel_size): 1326 type_check(value, (int,), "kernel_size[{}]".format(index)) 1327 check_value(value, (1, FLOAT_MAX_INTEGER), "kernel_size") 1328 check_odd(value, "kernel_size[{}]".format(index)) 1329 else: 1330 raise TypeError( 1331 "Kernel size should be a single integer or a list/tuple (kernel_width, kernel_height) of length 2.") 1332 1333 if sigma is not None: 1334 type_check(sigma, (numbers.Number, list, tuple), "sigma") 1335 if isinstance(sigma, numbers.Number): 1336 check_value(sigma, (0, FLOAT_MAX_INTEGER), "sigma") 1337 elif isinstance(sigma, (list, tuple)) and len(sigma) == 2: 1338 for index, value in enumerate(sigma): 1339 type_check(value, (numbers.Number,), "size[{}]".format(index)) 1340 check_value(value, (0, FLOAT_MAX_INTEGER), "sigma") 1341 else: 1342 raise TypeError("Sigma should be a single number or a list/tuple of length 2 for width and height.") 1343 1344 return method(self, *args, **kwargs) 1345 1346 return new_method 1347 1348 1349def check_convert_color(method): 1350 """Wrapper method to check the parameters of convertcolor.""" 1351 1352 @wraps(method) 1353 def new_method(self, *args, **kwargs): 1354 [convert_mode], _ = parse_user_args(method, *args, **kwargs) 1355 type_check(convert_mode, (ConvertMode,), "convert_mode") 1356 return method(self, *args, **kwargs) 1357 1358 return new_method 1359 1360 1361def check_auto_augment(method): 1362 """Wrapper method to check the parameters of AutoAugment.""" 1363 1364 @wraps(method) 1365 def new_method(self, *args, **kwargs): 1366 [policy, interpolation, fill_value], _ = parse_user_args(method, *args, **kwargs) 1367 1368 type_check(policy, (AutoAugmentPolicy,), 
"policy") 1369 type_check(interpolation, (Inter,), "interpolation") 1370 check_fill_value(fill_value) 1371 return method(self, *args, **kwargs) 1372 1373 return new_method 1374 1375 1376def check_to_tensor(method): 1377 """Wrapper method to check the parameters of ToTensor.""" 1378 1379 @wraps(method) 1380 def new_method(self, *args, **kwargs): 1381 [output_type], _ = parse_user_args(method, *args, **kwargs) 1382 1383 # Check if output_type is mindspore.dtype 1384 if isinstance(output_type, (typing.Type,)): 1385 return method(self, *args, **kwargs) 1386 1387 # Special case: Check if output_type is None (which is invalid) 1388 if output_type is None: 1389 # Use type_check to raise error with descriptive error message 1390 type_check(output_type, (typing.Type, np.dtype,), "output_type") 1391 1392 try: 1393 # Check if output_type can be converted to numpy type 1394 _ = np.dtype(output_type) 1395 except (TypeError, ValueError): 1396 # Use type_check to raise error with descriptive error message 1397 type_check(output_type, (typing.Type, np.dtype,), "output_type") 1398 1399 return method(self, *args, **kwargs) 1400 1401 return new_method 1402 1403 1404def deprecated_c_vision(substitute_name=None, substitute_module=None): 1405 """Decorator for version 1.8 deprecation warning for legacy mindspore.dataset.vision.c_transforms operation. 1406 1407 Args: 1408 substitute_name (str, optional): The substitute name for deprecated operation. 1409 substitute_module (str, optional): The substitute module for deprecated operation. 1410 """ 1411 return deprecator_factory("1.8", "mindspore.dataset.vision.c_transforms", "mindspore.dataset.vision", 1412 substitute_name, substitute_module) 1413 1414 1415def deprecated_py_vision(substitute_name=None, substitute_module=None): 1416 """Decorator for version 1.8 deprecation warning for legacy mindspore.dataset.vision.py_transforms operation. 1417 1418 Args: 1419 substitute_name (str, optional): The substitute name for deprecated operation. 
1420 substitute_module (str, optional): The substitute module for deprecated operation. 1421 """ 1422 return deprecator_factory("1.8", "mindspore.dataset.vision.py_transforms", "mindspore.dataset.vision", 1423 substitute_name, substitute_module) 1424 1425 1426def check_solarize(method): 1427 """Wrapper method to check the parameters of SolarizeOp.""" 1428 1429 @wraps(method) 1430 def new_method(self, *args, **kwargs): 1431 1432 [threshold], _ = parse_user_args(method, *args, **kwargs) 1433 type_check(threshold, (float, int, list, tuple), "threshold") 1434 if isinstance(threshold, (float, int)): 1435 threshold = (threshold, threshold) 1436 type_check_list(threshold, (float, int), "threshold") 1437 if len(threshold) != 2: 1438 raise TypeError("threshold must be a single number or sequence of two numbers.") 1439 for i, value in enumerate(threshold): 1440 check_value(value, (UINT8_MIN, UINT8_MAX), "threshold[{}]".format(i)) 1441 if threshold[1] < threshold[0]: 1442 raise ValueError("threshold must be in order of (min, max).") 1443 1444 return method(self, *args, **kwargs) 1445 1446 return new_method 1447 1448 1449def check_trivial_augment_wide(method): 1450 """Wrapper method to check the parameters of TrivialAugmentWide.""" 1451 1452 @wraps(method) 1453 def new_method(self, *args, **kwargs): 1454 [num_magnitude_bins, interpolation, fill_value], _ = parse_user_args(method, *args, **kwargs) 1455 type_check(num_magnitude_bins, (int,), "num_magnitude_bins") 1456 check_value(num_magnitude_bins, (2, FLOAT_MAX_INTEGER), "num_magnitude_bins") 1457 type_check(interpolation, (Inter,), "interpolation") 1458 check_fill_value(fill_value) 1459 return method(self, *args, **kwargs) 1460 1461 return new_method 1462 1463 1464def check_rand_augment(method): 1465 """Wrapper method to check the parameters of RandAugment.""" 1466 1467 @wraps(method) 1468 def new_method(self, *args, **kwargs): 1469 [num_ops, magnitude, num_magnitude_bins, interpolation, fill_value], _ = 
parse_user_args(method, *args, 1470 **kwargs) 1471 1472 type_check(num_ops, (int,), "num_ops") 1473 check_value(num_ops, (0, FLOAT_MAX_INTEGER), "num_ops") 1474 type_check(num_magnitude_bins, (int,), "num_magnitude_bins") 1475 check_value(num_magnitude_bins, (2, FLOAT_MAX_INTEGER), "num_magnitude_bins") 1476 type_check(magnitude, (int,), "magnitude") 1477 check_value(magnitude, (0, num_magnitude_bins), "magnitude", right_open_interval=True) 1478 type_check(interpolation, (Inter,), "interpolation") 1479 check_fill_value(fill_value) 1480 return method(self, *args, **kwargs) 1481 1482 return new_method 1483