# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
General Validators.
"""
import inspect
from multiprocessing import cpu_count
import os
import numpy as np

import mindspore._c_dataengine as cde

# POS_INT_MIN is used to limit values from starting from 0
POS_INT_MIN = 1
UINT8_MAX = 255
UINT8_MIN = 0
UINT32_MAX = 4294967295
UINT32_MIN = 0
UINT64_MAX = 18446744073709551615
UINT64_MIN = 0
INT32_MAX = 2147483647
INT32_MIN = -2147483648
INT64_MAX = 9223372036854775807
INT64_MIN = -9223372036854775808
FLOAT_MAX_INTEGER = 16777216
FLOAT_MIN_INTEGER = -16777216
DOUBLE_MAX_INTEGER = 9007199254740992
DOUBLE_MIN_INTEGER = -9007199254740992

valid_detype = [
    "bool", "int8", "int16", "int32", "int64", "uint8", "uint16",
    "uint32", "uint64", "float16", "float32", "float64", "string"
]


def is_iterable(obj):
    """
    Helper function to check if object is iterable.

    Args:
        obj (any): object to check if iterable

    Returns:
        bool, true if object is iterable
    """
    try:
        iter(obj)
    except TypeError:
        return False
    return True


def pad_arg_name(arg_name):
    """
    Appends a space to the arg_name (if not empty)

    :param arg_name: the input string
    :return: the padded string
    """
    if arg_name != "":
        arg_name = arg_name + " "
    return arg_name


def check_value(value, valid_range, arg_name="", left_open_interval=False, right_open_interval=False):
    """
    Validates a value is within a desired range with left and right interval open or close.

    :param value: the value to be validated.
    :param valid_range: the desired range, indexed as valid_range[0] (lower bound) and valid_range[1] (upper bound).
    :param arg_name: name of the variable to be validated.
    :param left_open_interval: True for left interval open and False for close.
    :param right_open_interval: True for right interval open and False for close.
    :return: Exception: ValueError when the validation fails, nothing otherwise.
    """
    arg_name = pad_arg_name(arg_name)
    lower, upper = valid_range[0], valid_range[1]

    # an open bound also rejects the bound itself, a closed bound accepts it
    below = value <= lower if left_open_interval else value < lower
    above = value >= upper if right_open_interval else value > upper

    if below or above:
        left_bracket = "(" if left_open_interval else "["
        right_bracket = ")" if right_open_interval else "]"
        raise ValueError(
            "Input {0}is not within the required interval of {1}{2}, {3}{4}.".format(arg_name, left_bracket, lower,
                                                                                     upper, right_bracket))


def check_value_cutoff(value, valid_range, arg_name=""):
    """
    Validates a value is within a desired range [inclusive, exclusive).

    :param value: the value to be validated
    :param valid_range: the desired range
    :param arg_name: name of the variable to be validated
    :return: Exception: when the validation fails, nothing otherwise.
    """
    check_value(value, valid_range, arg_name, False, True)


def check_value_ratio(value, valid_range, arg_name=""):
    """
    Validates a value is within a desired range (exclusive, inclusive].

    :param value: the value to be validated
    :param valid_range: the desired range
    :param arg_name: name of the variable to be validated
    :return: Exception: when the validation fails, nothing otherwise.
    """
    check_value(value, valid_range, arg_name, True, False)


def check_value_normalize_std(value, valid_range, arg_name=""):
    """
    Validates a value is within a desired range (exclusive, inclusive].

    :param value: the value to be validated
    :param valid_range: the desired range
    :param arg_name: name of the variable to be validated
    :return: Exception: when the validation fails, nothing otherwise.
    """
    check_value(value, valid_range, arg_name, True, False)


def check_range(values, valid_range, arg_name=""):
    """
    Validates the boundaries of a range are ordered and within a desired range [inclusive, inclusive].

    :param values: the two values to be validated
    :param valid_range: the desired range
    :param arg_name: name of the variable to be validated
    :return: Exception: when the validation fails, nothing otherwise.
    """
    arg_name = pad_arg_name(arg_name)
    if not valid_range[0] <= values[0] <= values[1] <= valid_range[1]:
        raise ValueError(
            "Input {0}is not within the required interval of [{1}, {2}].".format(arg_name, valid_range[0],
                                                                                 valid_range[1]))


def check_positive(value, arg_name=""):
    """
    Validates the value of a variable is positive.

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: when the validation fails, nothing otherwise.
    """
    arg_name = pad_arg_name(arg_name)
    if value <= 0:
        raise ValueError("Input {0}must be greater than 0.".format(arg_name))


def check_int32_not_zero(value, arg_name=""):
    """
    Validates the value of a variable is a non-zero int within the range of int32.

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: TypeError for a non-int (or bool) value, ValueError when the value is 0 or
        outside [INT32_MIN, INT32_MAX], nothing otherwise.
    """
    arg_name = pad_arg_name(arg_name)
    type_check(value, (int,), arg_name)
    if value < INT32_MIN or value > INT32_MAX or value == 0:
        raise ValueError(
            "Input {0}is not within the required interval of [-2147483648, 0) and (0, 2147483647].".format(arg_name))


def check_odd(value, arg_name=""):
    """
    Validates the value of a variable is odd (value % 2 == 1).

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: ValueError when the value is not odd, nothing otherwise.
    """
    arg_name = pad_arg_name(arg_name)
    if value % 2 != 1:
        raise ValueError(
            "Input {0}is not an odd value.".format(arg_name))


def check_2tuple(value, arg_name=""):
    """
    Validates a variable is a tuple with two entries.

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: when the validation fails, nothing otherwise.
    """
    if not (isinstance(value, tuple) and len(value) == 2):
        raise ValueError("Value {0} needs to be a 2-tuple.".format(arg_name))


def check_int32(value, arg_name=""):
    """
    Validates the value of a variable is within the range of int32.

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: when the validation fails, nothing otherwise.
    """
    type_check(value, (int,), arg_name)
    check_value(value, [INT32_MIN, INT32_MAX], arg_name)


def check_uint8(value, arg_name=""):
    """
    Validates the value of a variable is within the range of uint8.

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: when the validation fails, nothing otherwise.
    """
    type_check(value, (int,), arg_name)
    # forward arg_name so the range error names the offending argument
    check_value(value, [UINT8_MIN, UINT8_MAX], arg_name)


def check_uint32(value, arg_name=""):
    """
    Validates the value of a variable is within the range of uint32.

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: when the validation fails, nothing otherwise.
    """
    type_check(value, (int,), arg_name)
    # forward arg_name so the range error names the offending argument
    check_value(value, [UINT32_MIN, UINT32_MAX], arg_name)


def check_pos_uint32(value, arg_name=""):
    """
    Validates the value of a variable is within the range of positive uint32.

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: when the validation fails, nothing otherwise.
    """
    type_check(value, (int,), arg_name)
    # forward arg_name so the range error names the offending argument
    check_value(value, [POS_INT_MIN, UINT32_MAX], arg_name)


def check_pos_int32(value, arg_name=""):
    """
    Validates the value of a variable is within the range of positive int32.

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: when the validation fails, nothing otherwise.
    """
    type_check(value, (int,), arg_name)
    check_value(value, [POS_INT_MIN, INT32_MAX], arg_name)


def check_uint64(value, arg_name=""):
    """
    Validates the value of a variable is within the range of uint64.

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: when the validation fails, nothing otherwise.
    """
    type_check(value, (int,), arg_name)
    # forward arg_name so the range error names the offending argument
    check_value(value, [UINT64_MIN, UINT64_MAX], arg_name)


def check_pos_int64(value, arg_name=""):
    """
    Validates the value of a variable is within the range of positive int64.

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: when the validation fails, nothing otherwise.
    """
    type_check(value, (int,), arg_name)
    # forward arg_name so the range error names the offending argument
    check_value(value, [POS_INT_MIN, INT64_MAX], arg_name)


def check_non_negative_int32(value, arg_name=""):
    """
    Validates the value of a variable is within the range of non negative int32.

    :param value: the value of the variable.
    :param arg_name: name of the variable to be validated.
    :return: Exception: when the validation fails, nothing otherwise.
    """
    check_value(value, [UINT32_MIN, INT32_MAX], arg_name)


def check_float32(value, arg_name=""):
    """
    Validates the value of a variable is within the range of float32.

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: when the validation fails, nothing otherwise.
    """
    check_value(value, [FLOAT_MIN_INTEGER, FLOAT_MAX_INTEGER], arg_name)


def check_float64(value, arg_name=""):
    """
    Validates the value of a variable is within the range of float64.

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: when the validation fails, nothing otherwise.
    """
    check_value(value, [DOUBLE_MIN_INTEGER, DOUBLE_MAX_INTEGER], arg_name)


def check_pos_float32(value, arg_name=""):
    """
    Validates the value of a variable is within the range of positive float32.

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: when the validation fails, nothing otherwise.
    """
    # left-open interval: 0 itself is rejected
    check_value(value, [UINT32_MIN, FLOAT_MAX_INTEGER], arg_name, True)


def check_pos_float64(value, arg_name=""):
    """
    Validates the value of a variable is within the range of positive float64.

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: when the validation fails, nothing otherwise.
    """
    # left-open interval: 0 itself is rejected
    check_value(value, [UINT64_MIN, DOUBLE_MAX_INTEGER], arg_name, True)


def check_non_negative_float32(value, arg_name=""):
    """
    Validates the value of a variable is within the range of non negative float32.

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: when the validation fails, nothing otherwise.
    """
    check_value(value, [UINT32_MIN, FLOAT_MAX_INTEGER], arg_name)


def check_non_negative_float64(value, arg_name=""):
    """
    Validates the value of a variable is within the range of non negative float64.

    :param value: the value of the variable
    :param arg_name: name of the variable to be validated
    :return: Exception: when the validation fails, nothing otherwise.
    """
    check_value(value, [UINT32_MIN, DOUBLE_MAX_INTEGER], arg_name)


def check_float32_not_zero(value, arg_name=""):
    """
    Validates the value of a variable is non-zero and within [FLOAT_MIN_INTEGER, FLOAT_MAX_INTEGER].

    :param value: the value of the variable (int or float; bool is rejected)
    :param arg_name: name of the variable to be validated
    :return: Exception: TypeError for a value that is neither int nor float, ValueError when the value is 0 or
        outside the range, nothing otherwise.
    """
    arg_name = pad_arg_name(arg_name)
    # a float validator must accept floats; ints stay accepted for backward compatibility
    type_check(value, (int, float), arg_name)
    if value < FLOAT_MIN_INTEGER or value > FLOAT_MAX_INTEGER or value == 0:
        raise ValueError(
            "Input {0}is not within the required interval of [-16777216, 0) and (0, 16777216].".format(arg_name))


def check_valid_detype(type_):
    """
    Validates if a type is a DE Type.

    :param type_: the type_ to be validated
    :return: Exception: when the type is not a DE type, True otherwise.
    """
    if type_ not in valid_detype:
        raise TypeError("Unknown column type.")
    return True


def check_valid_str(value, valid_strings, arg_name=""):
    """
    Validates the content stored in a string.

    :param value: the value to be validated
    :param valid_strings: a list/set of valid strings
    :param arg_name: name of the variable to be validated
    :return: Exception: TypeError when value is not a str, ValueError when it is not one of valid_strings,
        nothing otherwise.
    """
    type_check(value, (str,), arg_name)
    if value not in valid_strings:
        raise ValueError("Input {0} is not within the valid set of {1}.".format(arg_name, str(valid_strings)))


def check_columns(columns, name):
    """
    Validate strings in column_names.

    Args:
        columns (Union[list, str]): a column name or a list of column names.
        name (str): name of columns.

    Returns:
        Exception: when the value is not correct, otherwise nothing.
    """
    type_check(columns, (list, str), name)
    if isinstance(columns, str):
        if not columns:
            raise ValueError("{0} should not be an empty str.".format(name))
    elif isinstance(columns, list):
        if not columns:
            raise ValueError("{0} should not be empty.".format(name))
        for i, column_name in enumerate(columns):
            if not column_name:
                raise ValueError("{0}[{1}] should not be empty.".format(name, i))

        col_names = ["{0}[{1}]".format(name, i) for i in range(len(columns))]
        type_check_list(columns, (str,), col_names)
        # a set drops duplicates, so a shorter set means a repeated column name
        if len(set(columns)) != len(columns):
            raise ValueError("Every column name should not be same with others in column_names.")


def parse_user_args(method, *args, **kwargs):
    """
    Parse user arguments in a function.

    Args:
        method (method): a callable function.
        args: user passed args.
        kwargs: user passed kwargs.

    Returns:
        user_filled_args (list): values of what the user passed in for the arguments.
        ba.arguments (Ordered Dict): ordered dict of parameter and argument for what the user has passed.
    """
    sig = inspect.signature(method)
    if 'self' in sig.parameters or 'cls' in sig.parameters:
        # bind `method` itself as a placeholder for self/cls, then drop that first parameter
        ba = sig.bind(method, *args, **kwargs)
        ba.apply_defaults()
        params = list(sig.parameters.keys())[1:]
    else:
        ba = sig.bind(*args, **kwargs)
        ba.apply_defaults()
        params = list(sig.parameters.keys())

    user_filled_args = [ba.arguments.get(arg_value) for arg_value in params]
    return user_filled_args, ba.arguments


def type_check_list(args, types, arg_names):
    """
    Check the type of each parameter in the list.

    Args:
        args (Union[list, tuple]): a list or tuple of any variable.
        types (tuple): tuple of all valid types for arg.
        arg_names (Union[list, tuple of str, str]): the names of args; a single str is expanded
            to "<arg_names>[i]" for every element.

    Returns:
        Exception: when the type is not correct, otherwise nothing.
    """
    type_check(args, (list, tuple,), arg_names)
    if len(args) != len(arg_names) and not isinstance(arg_names, str):
        raise ValueError("List of arguments is not the same length as argument_names.")
    if isinstance(arg_names, str):
        arg_names = ["{0}[{1}]".format(arg_names, i) for i in range(len(args))]
    for arg, arg_name in zip(args, arg_names):
        type_check(arg, types, arg_name)


def type_check(arg, types, arg_name):
    """
    Check the type of the parameter.

    Args:
        arg (Any) : any variable.
        types (tuple): tuple of all valid types for arg.
        arg_name (str): the name of arg.

    Returns:
        Exception: when the validation fails, otherwise nothing.
    """
    # show an empty string as "" so it stays visible in the error message
    print_value = '\"\"' if repr(arg) == repr('') else arg

    # handle special case of booleans being a subclass of ints
    if int in types and bool not in types:
        if isinstance(arg, bool):
            raise TypeError("Argument {0} with value {1} is not of type {2}, but got {3}.".format(arg_name, print_value,
                                                                                                  types, type(arg)))
    if not isinstance(arg, types):
        raise TypeError("Argument {0} with value {1} is not of type {2}, but got {3}.".format(arg_name, print_value,
                                                                                              list(types), type(arg)))


def check_filename(path):
    """
    check the filename in the path.

    Args:
        path (str): the path.

    Returns:
        Exception: TypeError when path is not a str, ValueError when the file name contains a forbidden
            symbol or starts/ends with a space.
    """
    if not isinstance(path, str):
        raise TypeError("path: {} is not string".format(path))
    filename = os.path.basename(os.path.realpath(path))
    forbidden_symbols = set(r'\/:*?"<>|`&\';')

    if set(filename) & forbidden_symbols:
        raise ValueError(r"filename should not contain \/:*?\"<>|`&;\'")

    if filename.startswith(' ') or filename.endswith(' '):
        raise ValueError("filename should not start/end with space.")


def check_dir(dataset_dir):
    """
    Validates if the argument is a directory.

    :param dataset_dir: string containing directory path
    :return: Exception: when the validation fails, nothing otherwise.
    """
    type_check(dataset_dir, (str,), "dataset_dir")
    if not os.path.isdir(dataset_dir) or not os.access(dataset_dir, os.R_OK):
        raise ValueError("The folder {} does not exist or is not a directory or permission denied!".format(dataset_dir))


def check_list_same_size(list1, list2, list1_name="", list2_name=""):
    """
    Validates the two lists as the same size.

    :param list1: the first list to be validated
    :param list2: the second list to be validated
    :param list1_name: name of the list1
    :param list2_name: name of the list2
    :return: Exception: when the two lists are not the same size, nothing otherwise.
    """
    if len(list1) != len(list2):
        raise ValueError("The size of {0} should be the same as that of {1}.".format(list1_name, list2_name))


def check_file(dataset_file):
    """
    Validates if the argument is a valid file name.

    :param dataset_file: string containing file path
    :return: Exception: when the validation fails, nothing otherwise.
    """
    check_filename(dataset_file)
    dataset_file = os.path.realpath(dataset_file)
    if not os.path.isfile(dataset_file) or not os.access(dataset_file, os.R_OK):
        raise ValueError("The file {} does not exist or permission denied!".format(dataset_file))


def check_sampler_shuffle_shard_options(param_dict):
    """
    Check for valid shuffle, sampler, num_shards, and shard_id inputs.

    Args:
        param_dict (dict): param_dict.

    Returns:
        Exception: ValueError or RuntimeError if error.
    """
    shuffle, sampler = param_dict.get('shuffle'), param_dict.get('sampler')
    num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id')
    num_samples = param_dict.get('num_samples')

    # a sampler is mutually exclusive with shuffle, sharding and num_samples
    if sampler is not None:
        if shuffle is not None:
            raise RuntimeError("sampler and shuffle cannot be specified at the same time.")
        if num_shards is not None or shard_id is not None:
            raise RuntimeError("sampler and sharding cannot be specified at the same time.")
        if num_samples is not None:
            raise RuntimeError("sampler and num_samples cannot be specified at the same time.")

    # num_shards and shard_id must be given together
    if num_shards is not None:
        check_pos_int32(num_shards, "num_shards")
        if shard_id is None:
            raise RuntimeError("num_shards is specified and currently requires shard_id as well.")
        check_value(shard_id, [0, num_shards - 1], "shard_id")

    if num_shards is None and shard_id is not None:
        raise RuntimeError("shard_id is specified but num_shards is not.")


def check_padding_options(param_dict):
    """
    Check for valid padded_sample and num_padded of padded samples.

    Args:
        param_dict (dict): param_dict.

    Returns:
        Exception: ValueError or RuntimeError if error.
    """

    columns_list = param_dict.get('columns_list')
    padded_sample, num_padded = param_dict.get('padded_sample'), param_dict.get('num_padded')
    if padded_sample is not None:
        if num_padded is None:
            raise RuntimeError("padded_sample is specified and requires num_padded as well.")
        if num_padded < 0:
            raise ValueError("num_padded is invalid, num_padded={}.".format(num_padded))
        if columns_list is None:
            raise RuntimeError("padded_sample is specified and requires columns_list as well.")
        # every listed column must be present in the padded sample
        for column in columns_list:
            if column not in padded_sample:
                raise ValueError("padded_sample cannot match columns_list.")
    if padded_sample is None and num_padded is not None:
        raise RuntimeError("num_padded is specified but padded_sample is not.")


def check_num_parallel_workers(value):
    """
    Validates the value for num_parallel_workers.

    :param value: an integer corresponding to the number of parallel workers
    :return: Exception: when the validation fails, nothing otherwise.
    """
    type_check(value, (int,), "num_parallel_workers")
    if value < 1 or value > cpu_count():
        raise ValueError("num_parallel_workers exceeds the boundary between 1 and {}!".format(cpu_count()))


def check_num_samples(value):
    """
    Validates number of samples are valid.

    :param value: an integer corresponding to the number of samples.
    :return: Exception: when the validation fails, nothing otherwise.
    """
    type_check(value, (int,), "num_samples")
    if value < 0 or value > INT64_MAX:
        raise ValueError(
            "num_samples exceeds the boundary between {} and {}(INT64_MAX)!".format(0, INT64_MAX))


def validate_dataset_param_value(param_list, param_dict, param_type):
    """
    Validates the dataset parameters named in param_list whose value in param_dict is not None.

    num_parallel_workers and num_samples are checked by their dedicated validators; every other
    parameter is type-checked against param_type.

    :param param_list: a list of parameter names.
    :param param_dict: a dictionary containing parameter names and their values.
    :param param_type: the expected type of the parameters.
    :return: Exception: when the validation fails, nothing otherwise.
    """
    for param_name in param_list:
        if param_dict.get(param_name) is not None:
            if param_name == 'num_parallel_workers':
                check_num_parallel_workers(param_dict.get(param_name))
            # elif keeps num_parallel_workers from also falling through to the generic type check
            elif param_name == 'num_samples':
                check_num_samples(param_dict.get(param_name))
            else:
                type_check(param_dict.get(param_name), (param_type,), param_name)


def check_gnn_list_of_pair_or_ndarray(param, param_name):
    """
    Check if the input parameter is a list of tuple or numpy.ndarray.

    Args:
        param (Union[list[tuple], nd.ndarray]): param.
        param_name (str): param_name.

    Returns:
        Exception: TypeError or ValueError if error.
    """
    type_check(param, (list, np.ndarray), param_name)
    if isinstance(param, list):
        param_names = ["node_list[{0}]".format(i) for i in range(len(param))]
        type_check_list(param, (tuple,), param_names)
        for idx, pair in enumerate(param):
            if not len(pair) == 2:
                raise ValueError("Each member in {0} must be a pair which means length == 2. Got length {1}".format(
                    param_names[idx], len(pair)))
            column_names = ["node_list[{0}], number #{1} element".format(idx, i+1) for i in range(len(pair))]
            type_check_list(pair, (int,), column_names)
    elif isinstance(param, np.ndarray):
        if param.ndim != 2:
            raise ValueError("Input ndarray must be in dimension 2. Got {0}".format(param.ndim))
        if param.shape[1] != 2:
            raise ValueError("Each member in {0} must be a pair which means length == 2. Got length {1}".format(
                param_name, param.shape[1]))
        if not param.dtype == np.int32:
            raise TypeError("Each member in {0} should be of type int32. Got {1}.".format(
                param_name, param.dtype))


def check_gnn_list_or_ndarray(param, param_name):
    """
    Check if the input parameter is list or numpy.ndarray.

    Args:
        param (Union[list, nd.ndarray]): param.
        param_name (str): param_name.

    Returns:
        Exception: TypeError if error.
    """

    type_check(param, (list, np.ndarray), param_name)
    if isinstance(param, list):
        param_names = ["param_{0}".format(i) for i in range(len(param))]
        type_check_list(param, (int,), param_names)

    elif isinstance(param, np.ndarray):
        if not param.dtype == np.int32:
            raise TypeError("Each member in {0} should be of type int32. Got {1}.".format(
                param_name, param.dtype))


def check_tensor_op(param, param_name):
    """check whether param is a tensor op or a callable Python function"""
    if not isinstance(param, cde.TensorOp) and not callable(param) and not getattr(param, 'parse', None):
        raise TypeError("{0} is neither a c_transform op (TensorOperation) nor a callable pyfunc.".format(param_name))


def check_c_tensor_op(param, param_name):
    """check whether param is a tensor op or a callable Python function but not a py_transform"""
    # a callable whose string form mentions "py_transform" is rejected
    if callable(param) and str(param).find("py_transform") >= 0:
        raise TypeError("{0} is a py_transform op which is not allow to use.".format(param_name))
    if not isinstance(param, cde.TensorOp) and not callable(param) and not getattr(param, 'parse', None):
        raise TypeError("{0} is neither a c_transform op (TensorOperation) nor a callable pyfunc.".format(param_name))


def replace_none(value, default):
    """ replaces None with a default value."""
    return value if value is not None else default