# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
General Validators.
"""
from __future__ import absolute_import

import inspect
from multiprocessing import cpu_count
import os

import mindspore._c_dataengine as cde
from mindspore import log as logger

# POS_INT_MIN is used to limit values from starting from 0
POS_INT_MIN = 1
UINT8_MAX = 255
UINT8_MIN = 0
UINT32_MAX = 4294967295
UINT32_MIN = 0
UINT64_MAX = 18446744073709551615
UINT64_MIN = 0
INT32_MAX = 2147483647
INT32_MIN = -2147483648
INT64_MAX = 9223372036854775807
INT64_MIN = -9223372036854775808
FLOAT_MAX_INTEGER = 16777216
FLOAT_MIN_INTEGER = -16777216
DOUBLE_MAX_INTEGER = 9007199254740992
DOUBLE_MIN_INTEGER = -9007199254740992

valid_detype = [
    "bool", "int8", "int16", "int32", "int64", "uint8", "uint16",
    "uint32", "uint64", "float16", "float32", "float64", "string"
]


def is_iterable(obj):
    """
    Helper function to check if object is iterable.

    Args:
        obj (any): object to check if iterable

    Returns:
        bool, true if object iterable
    """
    try:
        iter(obj)
    except TypeError:
        return False
    return True


def pad_arg_name(arg_name):
    """
    Appends a space to the arg_name (if not empty)

    :param arg_name: the input string
    :return: the padded string
    """
    if arg_name != "":
        arg_name = arg_name + " "
    return arg_name


def check_value(value, valid_range, arg_name="", left_open_interval=False, right_open_interval=False):
    """
    Validates a value is within a desired range with left and right interval open or close.

    :param value: the value to be validated.
    :param valid_range: the desired range, indexed as valid_range[0] (lower) and valid_range[1] (upper).
    :param arg_name: name of the variable to be validated.
    :param left_open_interval: True for left interval open and False for close.
    :param right_open_interval: True for right interval open and False for close.
    :return: Exception: ValueError when the validation fails, nothing otherwise.
    """
    arg_name = pad_arg_name(arg_name)
    if not left_open_interval and not right_open_interval:
        if value < valid_range[0] or value > valid_range[1]:
            raise ValueError(
                "Input {0}is not within the required interval of [{1}, {2}].".format(arg_name, valid_range[0],
                                                                                     valid_range[1]))
    elif left_open_interval and not right_open_interval:
        if value <= valid_range[0] or value > valid_range[1]:
            raise ValueError(
                "Input {0}is not within the required interval of ({1}, {2}].".format(arg_name, valid_range[0],
                                                                                     valid_range[1]))
    elif not left_open_interval and right_open_interval:
        if value < valid_range[0] or value >= valid_range[1]:
            raise ValueError(
                "Input {0}is not within the required interval of [{1}, {2}).".format(arg_name, valid_range[0],
                                                                                     valid_range[1]))
    else:
        if value <= valid_range[0] or value >= valid_range[1]:
            raise ValueError(
                "Input {0}is not within the required interval of ({1}, {2}).".format(arg_name, valid_range[0],
                                                                                     valid_range[1]))


def check_value_cutoff(value, valid_range, arg_name=""):
    """
    Validates a value is within a desired range [inclusive, exclusive).

    :param value: the value to be validated
    :param valid_range: the desired range
    :param arg_name: name of the variable to be validated
    :return: Exception: when the validation fails, nothing otherwise.
    """
    check_value(value, valid_range, arg_name, False, True)


def check_value_ratio(value, valid_range, arg_name=""):
    """
    Validates a value is within a desired range (exclusive, inclusive].

    :param value: the value to be validated
    :param valid_range: the desired range
    :param arg_name: name of the variable to be validated
    :return: Exception: when the validation fails, nothing otherwise.
    """
    check_value(value, valid_range, arg_name, True, False)


def check_value_normalize_std(value, valid_range, arg_name=""):
    """
    Validates a value is within a desired range (exclusive, inclusive].

    :param value: the value to be validated
    :param valid_range: the desired range
    :param arg_name: name of the variable to be validated
    :return: Exception: when the validation fails, nothing otherwise.
    """
    check_value(value, valid_range, arg_name, True, False)


def check_range(values, valid_range, arg_name=""):
    """
    Validates the boundaries of a range are within a desired range [inclusive, inclusive].

    The check also fails when values[0] is greater than values[1].

    :param values: the two values to be validated
    :param valid_range: the desired range
    :param arg_name: name of the variable to be validated
    :return: Exception: when the validation fails, nothing otherwise.
    """
    arg_name = pad_arg_name(arg_name)
    if not valid_range[0] <= values[0] <= values[1] <= valid_range[1]:
        raise ValueError(
            "Input {0}is not within the required interval of [{1}, {2}].".format(arg_name, valid_range[0],
                                                                                 valid_range[1]))


def check_positive(value, arg_name=""):
    """
    Validates the value of a variable is positive.

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: when the validation fails, nothing otherwise.
    """
    arg_name = pad_arg_name(arg_name)
    if value <= 0:
        raise ValueError("Input {0}must be greater than 0.".format(arg_name))


def check_int32_not_zero(value, arg_name=""):
    """
    Validates the value of a variable is a non-zero int within the range of int32.

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: when the validation fails, nothing otherwise.
    """
    arg_name = pad_arg_name(arg_name)
    type_check(value, (int,), arg_name)
    if value < INT32_MIN or value > INT32_MAX or value == 0:
        raise ValueError(
            "Input {0}is not within the required interval of [-2147483648, 0) and (0, 2147483647].".format(arg_name))


def check_odd(value, arg_name=""):
    """
    Validates the value of a variable is odd (value % 2 equals 1).

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: when the validation fails, nothing otherwise.
    """
    arg_name = pad_arg_name(arg_name)
    if value % 2 != 1:
        raise ValueError(
            "Input {0}is not an odd value.".format(arg_name))


def check_2tuple(value, arg_name=""):
    """
    Validates a variable is a tuple with two entries.

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: when the validation fails, nothing otherwise.
    """
    if not (isinstance(value, tuple) and len(value) == 2):
        raise ValueError("Value {0} needs to be a 2-tuple.".format(arg_name))


def check_int32(value, arg_name=""):
    """
    Validates the value of a variable is within the range of int32.

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: when the validation fails, nothing otherwise.
    """
    type_check(value, (int,), arg_name)
    check_value(value, [INT32_MIN, INT32_MAX], arg_name)


def check_uint8(value, arg_name=""):
    """
    Validates the value of a variable is within the range of uint8.

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: when the validation fails, nothing otherwise.
    """
    type_check(value, (int,), arg_name)
    check_value(value, [UINT8_MIN, UINT8_MAX], arg_name)


def check_uint32(value, arg_name=""):
    """
    Validates the value of a variable is within the range of uint32.

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: when the validation fails, nothing otherwise.
    """
    type_check(value, (int,), arg_name)
    check_value(value, [UINT32_MIN, UINT32_MAX], arg_name)


def check_pos_uint32(value, arg_name=""):
    """
    Validates the value of a variable is within the range of positive uint32.

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: when the validation fails, nothing otherwise.
    """
    type_check(value, (int,), arg_name)
    check_value(value, [POS_INT_MIN, UINT32_MAX], arg_name)


def check_pos_int32(value, arg_name=""):
    """
    Validates the value of a variable is within the range of positive int32, i.e. [1, INT32_MAX].

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: when the validation fails, nothing otherwise.
    """
    type_check(value, (int,), arg_name)
    check_value(value, [POS_INT_MIN, INT32_MAX], arg_name)


def check_uint64(value, arg_name=""):
    """
    Validates the value of a variable is within the range of uint64.

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: when the validation fails, nothing otherwise.
    """
    type_check(value, (int,), arg_name)
    check_value(value, [UINT64_MIN, UINT64_MAX], arg_name)


def check_pos_int64(value, arg_name=""):
    """
    Validates the value of a variable is within the range of positive int64, i.e. [1, INT64_MAX].

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: when the validation fails, nothing otherwise.
    """
    type_check(value, (int,), arg_name)
    check_value(value, [POS_INT_MIN, INT64_MAX], arg_name)


def check_non_negative_int32(value, arg_name=""):
    """
    Validates the value of a variable is within the range of non negative int32.

    :param value: the value of the variable.
    :param arg_name: name of the variable to be validated.
    :return: Exception: when the validation fails, nothing otherwise.
    """
    type_check(value, (int,), arg_name)
    check_value(value, [0, INT32_MAX], arg_name)


def check_float32(value, arg_name=""):
    """
    Validates the value of a variable is within [FLOAT_MIN_INTEGER, FLOAT_MAX_INTEGER].

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: when the validation fails, nothing otherwise.
    """
    type_check(value, (float, int), arg_name)
    check_value(value, [FLOAT_MIN_INTEGER, FLOAT_MAX_INTEGER], arg_name)


def check_float64(value, arg_name=""):
    """
    Validates the value of a variable is within [DOUBLE_MIN_INTEGER, DOUBLE_MAX_INTEGER].

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: when the validation fails, nothing otherwise.
    """
    type_check(value, (float, int), arg_name)
    check_value(value, [DOUBLE_MIN_INTEGER, DOUBLE_MAX_INTEGER], arg_name)


def check_pos_float32(value, arg_name=""):
    """
    Validates the value of a variable is within (0, FLOAT_MAX_INTEGER].

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: when the validation fails, nothing otherwise.
    """
    type_check(value, (float, int), arg_name)
    # The trailing True makes the lower bound exclusive, so 0 is rejected.
    check_value(value, [UINT32_MIN, FLOAT_MAX_INTEGER], arg_name, True)


def check_pos_float64(value, arg_name=""):
    """
    Validates the value of a variable is within (0, DOUBLE_MAX_INTEGER].

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: when the validation fails, nothing otherwise.
    """
    type_check(value, (float, int), arg_name)
    # The trailing True makes the lower bound exclusive, so 0 is rejected.
    check_value(value, [UINT64_MIN, DOUBLE_MAX_INTEGER], arg_name, True)


def check_non_negative_float32(value, arg_name=""):
    """
    Validates the value of a variable is within [0, FLOAT_MAX_INTEGER].

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: when the validation fails, nothing otherwise.
    """
    type_check(value, (float, int), arg_name)
    check_value(value, [UINT32_MIN, FLOAT_MAX_INTEGER], arg_name)


def check_non_negative_float64(value, arg_name=""):
    """
    Validates the value of a variable is within [0, DOUBLE_MAX_INTEGER].

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: when the validation fails, nothing otherwise.
    """
    type_check(value, (float, int), arg_name)
    check_value(value, [UINT32_MIN, DOUBLE_MAX_INTEGER], arg_name)


def check_float32_not_zero(value, arg_name=""):
    """
    Validates the value of a variable is non-zero and within [FLOAT_MIN_INTEGER, FLOAT_MAX_INTEGER].

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: when the validation fails, nothing otherwise.
    """
    arg_name = pad_arg_name(arg_name)
    type_check(value, (float, int), arg_name)
    if value < FLOAT_MIN_INTEGER or value > FLOAT_MAX_INTEGER or value == 0:
        raise ValueError(
            "Input {0}is not within the required interval of [-16777216, 0) and (0, 16777216].".format(arg_name))


def check_valid_detype(type_):
    """
    Validates if a type is a DE Type.

    :param type_: the type_ to be validated
    :return: Exception: when the type is not a DE type, True otherwise.
    """
    if type_ not in valid_detype:
        raise TypeError("Unknown column type.")
    return True


def check_valid_str(value, valid_strings, arg_name=""):
    """
    Validates the content stored in a string.

    :param value: the value to be validated
    :param valid_strings: a list/set of valid strings
    :param arg_name: name of the variable to be validated
    :return: Exception: TypeError when value is not a str, ValueError when it is not one of
        valid_strings, nothing otherwise.
    """
    type_check(value, (str,), arg_name)
    if value not in valid_strings:
        raise ValueError("Input {0} is not within the valid set of {1}.".format(arg_name, str(valid_strings)))


def check_valid_list_tuple(value, valid_list_tuple, data_type, arg_name=""):
    """
    Validate value in valid_list_tuple.

    Args:
        value (Union[list, tuple]): the value to be validated.
        valid_list_tuple (Union[list, tuple]): the collection of valid values; the length of its
            first entry is the length required of value.
        data_type (tuple): tuple of all valid types for the elements of value.
        arg_name (str): the names of value.

    Returns:
        Exception: when the value is not correct, otherwise nothing.
    """
    valid_length = len(valid_list_tuple[0])
    type_check(value, (list, tuple), arg_name)
    type_check_list(value, data_type, arg_name)
    if len(value) != valid_length:
        raise ValueError("Input {0} is a list or tuple of length {1}.".format(arg_name, valid_length))
    if value not in valid_list_tuple:
        raise ValueError(
            "Input {0}{1} is not within the valid set of {2}.".format(arg_name, value, valid_list_tuple))


def check_columns(columns, name):
    """
    Validate strings in column_names.

    Args:
        columns (Union[list, str]): a column name or a list of column names.
        name (str): name of columns.

    Returns:
        Exception: when the value is not correct, otherwise nothing.
    """
    type_check(columns, (list, str), name)
    if isinstance(columns, str):
        if not columns:
            raise ValueError("{0} should not be an empty str.".format(name))
    elif isinstance(columns, list):
        if not columns:
            raise ValueError("{0} should not be empty.".format(name))
        for i, column_name in enumerate(columns):
            if not column_name:
                raise ValueError("{0}[{1}] should not be empty.".format(name, i))

        col_names = ["{0}[{1}]".format(name, i) for i in range(len(columns))]
        type_check_list(columns, (str,), col_names)
        if len(set(columns)) != len(columns):
            raise ValueError("Every column name should not be same with others in column_names.")


def parse_user_args(method, *args, **kwargs):
    """
    Parse user arguments in a function.

    Args:
        method (method): a callable function.
        args: user passed args.
        kwargs: user passed kwargs.

    Returns:
        user_filled_args (list): values of what the user passed in for the arguments.
        ba.arguments (Ordered Dict): ordered dict of parameter and argument for what the user has passed.
    """
    sig = inspect.signature(method)
    if 'self' in sig.parameters or 'cls' in sig.parameters:
        # method itself is bound as a placeholder for the first parameter,
        # which is then dropped from the returned value list.
        ba = sig.bind(method, *args, **kwargs)
        ba.apply_defaults()
        params = list(sig.parameters.keys())[1:]
    else:
        ba = sig.bind(*args, **kwargs)
        ba.apply_defaults()
        params = list(sig.parameters.keys())

    user_filled_args = [ba.arguments.get(arg_value) for arg_value in params]
    return user_filled_args, ba.arguments


def type_check_list(args, types, arg_names):
    """
    Check the type of each parameter in the list.

    Args:
        args (Union[list, tuple]): a list or tuple of any variable.
        types (tuple): tuple of all valid types for arg.
        arg_names (Union[str, list, tuple of str]): the names of args. A single str is expanded
            to "name[i]" for every element of args.

    Returns:
        Exception: when the type is not correct, otherwise nothing.
    """
    type_check(args, (list, tuple,), arg_names)
    if len(args) != len(arg_names) and not isinstance(arg_names, str):
        raise ValueError("List of arguments is not the same length as argument_names.")
    if isinstance(arg_names, str):
        arg_names = ["{0}[{1}]".format(arg_names, i) for i in range(len(args))]
    for arg, arg_name in zip(args, arg_names):
        type_check(arg, types, arg_name)


def type_check(arg, types, arg_name):
    """
    Check the type of the parameter.

    Args:
        arg (Any) : any variable.
        types (tuple): tuple of all valid types for arg.
        arg_name (str): the name of arg.

    Returns:
        Exception: when the validation fails, otherwise nothing.
    """

    if int in types and bool not in types:
        if isinstance(arg, bool):
            # handle special case of booleans being a subclass of ints
            print_value = '\"\"' if repr(arg) == repr('') else arg
            raise TypeError("Argument {0} with value {1} is not of type {2}, but got {3}.".format(arg_name, print_value,
                                                                                                  types, type(arg)))
    if not isinstance(arg, types):
        print_value = '\"\"' if repr(arg) == repr('') else arg
        raise TypeError("Argument {0} with value {1} is not of type {2}, but got {3}.".format(arg_name, print_value,
                                                                                              list(types), type(arg)))


def check_filename(path):
    """
    check the filename in the path.

    Args:
        path (str): the path.

    Returns:
        Exception: TypeError when path is not a str, ValueError when the file name contains a
            forbidden symbol or starts/ends with a space.
    """
    if not isinstance(path, str):
        raise TypeError("path: {} is not string".format(path))
    filename = os.path.basename(os.path.realpath(path))
    forbidden_symbols = set(r'\/:*?"<>|`&\';')

    if set(filename) & forbidden_symbols:
        raise ValueError(r"filename should not contain \/:*?\"<>|`&;\'")

    if filename.startswith(' ') or filename.endswith(' '):
        raise ValueError("filename should not start/end with space.")


def check_dir(dataset_dir):
    """
    Validates if the argument is a directory.

    :param dataset_dir: string containing directory path
    :return: Exception: when the validation fails, nothing otherwise.
    """
    type_check(dataset_dir, (str,), "dataset_dir")
    if not os.path.isdir(dataset_dir) or not os.access(dataset_dir, os.R_OK):
        raise ValueError("The folder {} does not exist or is not a directory or permission denied!".format(dataset_dir))


def check_list_same_size(list1, list2, list1_name="", list2_name=""):
    """
    Validates the two lists as the same size.

    :param list1: the first list to be validated
    :param list2: the second list to be validated
    :param list1_name: name of the list1
    :param list2_name: name of the list2
    :return: Exception: when the two lists are not the same size, nothing otherwise.
    """
    if len(list1) != len(list2):
        raise ValueError("The size of {0} should be the same as that of {1}.".format(list1_name, list2_name))


def check_file(dataset_file):
    """
    Validates if the argument is a valid file name.

    :param dataset_file: string containing file path
    :return: Exception: when the validation fails, nothing otherwise.
    """
    check_filename(dataset_file)
    dataset_file = os.path.realpath(dataset_file)
    if not os.path.isfile(dataset_file) or not os.access(dataset_file, os.R_OK):
        raise ValueError("The file {} does not exist or permission denied!".format(dataset_file))


def check_sampler_shuffle_shard_options(param_dict):
    """
    Check for valid shuffle, sampler, num_shards, and shard_id inputs.

    Args:
        param_dict (dict): param_dict.

    Returns:
        Exception: ValueError or RuntimeError if error.
    """
    shuffle, sampler = param_dict.get('shuffle'), param_dict.get('sampler')
    num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id')
    num_samples = param_dict.get('num_samples')

    if sampler is not None:
        if shuffle is not None:
            raise RuntimeError("sampler and shuffle cannot be specified at the same time.")
        if num_shards is not None or shard_id is not None:
            raise RuntimeError("sampler and sharding cannot be specified at the same time.")
        if num_samples is not None:
            raise RuntimeError("sampler and num_samples cannot be specified at the same time.")

    if num_shards is not None:
        check_pos_int32(num_shards, "num_shards")
        if shard_id is None:
            raise RuntimeError("num_shards is specified and currently requires shard_id as well.")
        check_value(shard_id, [0, num_shards - 1], "shard_id")

    if num_shards is None and shard_id is not None:
        raise RuntimeError("shard_id is specified but num_shards is not.")


def check_padding_options(param_dict):
    """
    Check for valid padded_sample and num_padded of padded samples.

    Args:
        param_dict (dict): param_dict.

    Returns:
        Exception: ValueError or RuntimeError if error.
    """

    columns_list = param_dict.get('columns_list')
    padded_sample, num_padded = param_dict.get('padded_sample'), param_dict.get('num_padded')
    if padded_sample is not None:
        if num_padded is None:
            raise RuntimeError("padded_sample is specified and requires num_padded as well.")
        if num_padded < 0:
            raise ValueError("num_padded is invalid, num_padded={}.".format(num_padded))
        if columns_list is None:
            raise RuntimeError("padded_sample is specified and requires columns_list as well.")
        for column in columns_list:
            if column not in padded_sample:
                raise ValueError("padded_sample cannot match columns_list.")
    if padded_sample is None and num_padded is not None:
        raise RuntimeError("num_padded is specified but padded_sample is not.")


def check_num_parallel_workers(value):
    """
    Validates the value for num_parallel_workers.

    :param value: an integer corresponding to the number of parallel workers
    :return: Exception: when the validation fails, nothing otherwise.
    """
    type_check(value, (int,), "num_parallel_workers")
    if value < 1 or value > cpu_count():
        raise ValueError("num_parallel_workers exceeds the boundary between 1 and {}!".format(cpu_count()))


def check_num_samples(value):
    """
    Validates number of samples are valid.

    :param value: an integer corresponding to the number of samples.
    :return: Exception: when the validation fails, nothing otherwise.
    """
    type_check(value, (int,), "num_samples")
    if value < 0 or value > INT64_MAX:
        raise ValueError(
            "num_samples exceeds the boundary between {} and {}(INT64_MAX)!".format(0, INT64_MAX))


def validate_dataset_param_value(param_list, param_dict, param_type):
    """
    Validates the dataset parameters named in param_list that are present (not None) in param_dict.

    'num_parallel_workers' and 'num_samples' are validated by their dedicated checkers; every
    other parameter is type-checked against param_type.

    :param param_list: a list of parameter names.
    :param param_dict: a dictionary containing parameter names and their values.
    :param param_type: the expected type of the parameters.
    :return: Exception: when the validation fails, nothing otherwise.
    """
    for param_name in param_list:
        if param_dict.get(param_name) is not None:
            if param_name == 'num_parallel_workers':
                check_num_parallel_workers(param_dict.get(param_name))
            # elif (not a second if): a parameter handled by a dedicated checker must not
            # also fall through to the generic param_type check below.
            elif param_name == 'num_samples':
                check_num_samples(param_dict.get(param_name))
            else:
                type_check(param_dict.get(param_name), (param_type,), param_name)


def check_tensor_op(param, param_name):
    """check whether param is a tensor op or a callable Python function"""
    if not isinstance(param, cde.TensorOp) and not callable(param) and not getattr(param, 'parse', None):
        raise TypeError("{0} is neither a transforms op (TensorOperation) nor a callable pyfunc.".format(param_name))


def check_c_tensor_op(param, param_name):
    """check whether param is a tensor op or a callable Python function but not a py_transform"""
    if callable(param) and str(param).find("py_transform") >= 0:
        raise TypeError("{0} is a py_transform op which is not allowed to use.".format(param_name))
    if not isinstance(param, cde.TensorOp) and not callable(param) and not getattr(param, 'parse', None):
        raise TypeError("{0} is neither a c_transform op (TensorOperation) nor a callable pyfunc.".format(param_name))


def replace_none(value, default):
    """ replaces None with a default value."""
    return value if value is not None else default


def check_dataset_num_shards_shard_id(num_shards, shard_id):
    """
    Validates that num_shards and shard_id are given together and that shard_id is less than num_shards.

    :param num_shards: the number of shards, or None.
    :param shard_id: the shard id, or None.
    :return: Exception: when the validation fails, nothing otherwise.
    """
    if (num_shards is None) != (shard_id is None):
        # These two parameters appear together.
        raise ValueError("num_shards and shard_id need to be passed in together.")
    if num_shards is not None:
        check_pos_int32(num_shards, "num_shards")
        if shard_id >= num_shards:
            raise ValueError("shard_id should be less than num_shards.")


def deprecator_factory(version, old_module, new_module, substitute_name=None, substitute_module=None):
    """Decorator factory function for deprecated operation to log deprecation warning message.

    Args:
        version (str): Version that the operation is deprecated.
        old_module (str): Old module for deprecated operation.
        new_module (str): New module for deprecated operation.
        substitute_name (str, optional): The substitute name for deprecated operation.
        substitute_module (str, optional): The substitute module for deprecated operation.

    Returns:
        function, a decorator which logs the deprecation warning on every call and then
        forwards the call to the decorated operation.
    """

    def decorator(op):
        """Wrap op so that each call logs the deprecation warning first."""

        def wrapper(*args, **kwargs):
            """Log the deprecation warning, then call op and return its result."""
            # Get operation class name for operation class which applies decorator to __init__()
            name = str(op).split()[1].split(".")[0]
            # Build message
            message = f"'{name}' from {old_module} is deprecated from version {version}" \
                      " and will be removed in a future version."
            message += f" Use '{substitute_name}'" if substitute_name else f" Use '{name}'"
            message += f" from {substitute_module} instead." if substitute_module \
                else f" from {new_module} instead."

            # Log warning message
            logger.warning(message)

            ret = op(*args, **kwargs)
            return ret

        return wrapper

    return decorator


def check_dict(data, key_type, value_type, param_name):
    """ check key and value type in dict."""
    if data is not None:
        if not isinstance(data, dict):
            raise TypeError("{0} should be dict type, but got: {1}".format(param_name, type(data)))

        for key, value in data.items():
            if not isinstance(key, key_type):
                raise TypeError("key '{0}' in parameter {1} should be {2} type, but got: {3}"
                                .format(key, param_name, key_type, type(key)))
            if not isinstance(value, value_type):
                raise TypeError("value of '{0}' in parameter {1} should be {2} type, but got: {3}"
                                .format(key, param_name, value_type, type(value)))


def check_feature_shape(data, shape, param_name):
    """
    Check that every value in the dict data is 2-dimensional with first dimension equal to shape.

    Nothing is checked when data is not a dict.

    :param data: dict whose values expose a `shape` attribute.
    :param shape: the required size of the first dimension.
    :param param_name: name of the parameter, used in the error message.
    :return: Exception: ValueError when the validation fails, nothing otherwise.
    """
    if isinstance(data, dict):
        for key, value in data.items():
            if len(value.shape) != 2 or value.shape[0] != shape:
                raise ValueError("Shape of '{0}' in '{1}' should be of 2 dimension, and shape of first dimension "
                                 "should be: {2}, but got: {3}.".format(key, param_name, shape, value.shape))