1# Copyright 2020-2021 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================ 15"""Check parameters.""" 16 17import re 18import inspect 19import math 20from enum import Enum 21from functools import reduce, wraps 22from itertools import repeat, zip_longest 23from collections import deque 24from collections.abc import Iterable 25import numpy as np 26from mindspore import context 27from mindspore import log as logger 28from mindspore.common import dtype as mstype 29from mindspore._c_expression import Tensor as Tensor_ 30 31 32class Rel(Enum): 33 34 """Numerical relationship between variables, logical relationship enumeration definition of range.""" 35 # scalar compare 36 EQ = 1 # == 37 NE = 2 # != 38 LT = 3 # < 39 LE = 4 # <= 40 GT = 5 # > 41 GE = 6 # >= 42 # scalar range check 43 INC_NEITHER = 7 # (), include neither 44 INC_LEFT = 8 # [), include left 45 INC_RIGHT = 9 # (], include right 46 INC_BOTH = 10 # [], include both 47 # collection in, not in 48 IN = 11 49 NOT_IN = 12 50 51 @staticmethod 52 def get_strs(rel): 53 """Get value from rel_strs.""" 54 return rel_strs.get(rel, "") 55 56 @staticmethod 57 def get_fns(rel): 58 """Get value from rel_fns.""" 59 return rel_fns.get(rel, lambda *args: False) 60 61 62rel_fns = { 63 # scalar compare 64 Rel.EQ: lambda x, y: x == y, 65 Rel.NE: lambda x, y: x != y, 66 Rel.LT: lambda x, y: x < y, 67 Rel.LE: lambda x, y: x <= y, 68 Rel.GT: lambda x, y: x > y, 69 Rel.GE: lambda x, y: x >= y, 70 # scalar range check 71 Rel.INC_NEITHER: lambda x, lower, upper: (lower < x < upper), 72 Rel.INC_LEFT: lambda x, lower, upper: (lower <= x < upper), 73 Rel.INC_RIGHT: lambda x, lower, upper: (lower < x <= upper), 74 Rel.INC_BOTH: lambda x, lower, upper: (lower <= x <= upper), 75 # collection in, not in 76 Rel.IN: lambda x, y: x in y, 77 Rel.NOT_IN: lambda x, y: x not in y, 78} 79 80rel_strs = { 81 # scalar compare 82 Rel.EQ: "= {}", 83 Rel.NE: "!= {}", 84 Rel.LT: "< {}", 85 Rel.LE: "<= {}", 86 Rel.GT: "> {}", 87 Rel.GE: ">= {}", 88 # scalar range check 89 Rel.INC_NEITHER: "({}, {})", 90 Rel.INC_LEFT: "[{}, {})", 91 Rel.INC_RIGHT: "({}, {}]", 92 Rel.INC_BOTH: "[{}, {}]", 93 # collection in, not in 94 Rel.IN: "in {}", 95 Rel.NOT_IN: "not in {}", 96} 97 98 99def _check_3d_int_or_tuple(arg_name, arg_value, prim_name, allow_five=False, ret_five=False, 100 greater_zero=True, third_one=False, three_input=False): 101 """ 102 Checks whether an argument is a positive int or tuple with 3 or 5(when allow_five is True) positive int elements. 103 """ 104 105 def _raise_message(third_one_flag=False, three_input_flag=False): 106 if third_one_flag: 107 raise ValueError(f"For '{prim_name}' the depth of attr '{arg_name}' should be 1, but got {ret_value[-3]}") 108 if three_input_flag: 109 raise ValueError(f"For '{prim_name}' attr '{arg_name}' should be an positive int number or a tuple of " 110 f"three positive int numbers, but got {arg_value}") 111 raise ValueError(f"For '{prim_name}' attr '{arg_name}' should be an positive int number or a tuple of three " 112 f"{'or five ' if allow_five else ''}positive int numbers, but got {arg_value}") 113 114 def _get_return_value(): 115 if isinstance(arg_value, int): 116 ret = (1, 1, arg_value, arg_value, arg_value) if ret_five else (arg_value, arg_value, arg_value) 117 elif len(arg_value) == 3: 118 ret = (1, 1, arg_value[0], arg_value[1], arg_value[2]) if ret_five else arg_value 119 elif len(arg_value) == 5: 120 if not allow_five: 121 _raise_message() 122 ret = arg_value if ret_five else (arg_value[1], arg_value[2], arg_value[3]) 123 else: 124 _raise_message() 125 return ret 126 127 Validator.check_value_type(arg_name, arg_value, (int, tuple), prim_name) 128 if three_input and isinstance(arg_value, tuple): 129 if len(arg_value) != 3: 130 _raise_message(three_input_flag=three_input) 131 ret_value = _get_return_value() 132 for item in ret_value: 133 if isinstance(item, int) and not isinstance(item, bool): 134 if greater_zero and item > 0: 135 continue 136 if not greater_zero and item >= 0: 137 continue 138 _raise_message() 139 140 if third_one: 141 if ret_value[-3] != 1: 142 _raise_message(third_one_flag=third_one) 143 144 return tuple(ret_value) 145 146 147def check_number(arg_value, value, rel, arg_type=int, arg_name=None, prim_name=None): 148 """ 149 Check argument integer. 150 151 Example: 152 - number = check_number(number, 0, Rel.GE, "number", None) # number >= 0 153 """ 154 rel_fn = Rel.get_fns(rel) 155 prim_name = f'in `{prim_name}`' if prim_name else '' 156 arg_name = f'`{arg_name}`' if arg_name else '' 157 158 if isinstance(arg_value, arg_type): 159 if math.isinf(arg_value) or math.isnan(arg_value) or np.isinf(arg_value) or np.isnan(arg_value): 160 raise ValueError(f'{arg_name} {prim_name} must be legal value, but got `{arg_value}`.') 161 else: 162 raise TypeError(f'{arg_name} {prim_name} must be {arg_type.__name__}, but got `{type(arg_value).__name__}`') 163 164 type_mismatch = not isinstance(arg_value, arg_type) or isinstance(arg_value, bool) 165 type_except = TypeError if type_mismatch else ValueError 166 if type_mismatch or not rel_fn(arg_value, value): 167 rel_str = Rel.get_strs(rel).format(value) 168 raise type_except(f'{arg_name} {prim_name} should be an {arg_type.__name__} and must {rel_str}, ' 169 f'but got `{arg_value}` with type `{type(arg_value).__name__}`.') 170 171 return arg_value 172 173 174def check_is_number(arg_value, arg_type, arg_name=None, prim_name=None): 175 """ 176 Checks input value is float type or not. 177 178 Usage: 179 - number = check_is_number(number, int) 180 - number = check_is_number(number, int, "bias") 181 - number = check_is_number(number, int, "bias", "bias_class") 182 """ 183 prim_name = f"For \'{prim_name}\', the" if prim_name else 'The' 184 arg_name = f"\'{arg_name}\'" if arg_name else 'input value' 185 if isinstance(arg_value, arg_type) and not isinstance(arg_value, bool): 186 if math.isinf(arg_value) or math.isnan(arg_value) or np.isinf(arg_value) or np.isnan(arg_value): 187 raise ValueError(f'{prim_name} {arg_name} must be legal float, but got `{arg_value}`.') 188 return arg_value 189 raise TypeError(f'{prim_name} type of {arg_name} must be {arg_type.__name__}, but got `{type(arg_value).__name__}`') 190 191 192def check_number_range(arg_value, lower_limit, upper_limit, rel, value_type, arg_name=None, prim_name=None): 193 """ 194 Method for checking whether an int value is in some range. 195 196 Usage: 197 - number = check_number_range(number, 0.0, 1.0, Rel.INC_NEITHER, "number", float) # number in [0.0, 1.0] 198 - number = check_number_range(number, 0, 1, Rel.INC_NEITHER, "number", int) # number in [0, 1] 199 """ 200 rel_fn = Rel.get_fns(rel) 201 prim_name = f'in `{prim_name}`' if prim_name else '' 202 arg_name = f'`{arg_name}`' if arg_name else '' 203 type_mismatch = not isinstance(arg_value, (np.ndarray, np.generic, value_type)) or isinstance(arg_value, bool) 204 if type_mismatch: 205 raise TypeError("{} {} must be `{}`, but got `{}`.".format( 206 arg_name, prim_name, value_type.__name__, type(arg_value).__name__)) 207 if not rel_fn(arg_value, lower_limit, upper_limit): 208 rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit) 209 raise ValueError("{} {} should be in range of {}, but got {:.3e} with type `{}`.".format( 210 arg_name, prim_name, rel_str, arg_value, type(arg_value).__name__)) 211 return arg_value 212 213 214class Validator: 215 """validator for checking input parameters""" 216 217 @staticmethod 218 def check(arg_name, arg_value, value_name, value, rel=Rel.EQ, prim_name=None, excp_cls=ValueError): 219 """ 220 Method for judging relation between two int values or list/tuple made up of ints. 221 This method is not suitable for judging relation between floats, since it does not consider float error. 222 """ 223 rel_fn = Rel.get_fns(rel) 224 if not rel_fn(arg_value, value): 225 rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}') 226 msg_prefix = f'For \'{prim_name}\', the' if prim_name else "The" 227 raise excp_cls(f'{msg_prefix} \'{arg_name}\' should be {rel_str}, but got {arg_value}.') 228 return arg_value 229 230 @staticmethod 231 def check_int(arg_value, value, rel, arg_name=None, prim_name=None): 232 """ 233 Checks input integer value `arg_value` compare to `value`. 234 235 Usage: 236 - number = check_int(number, 0, Rel.GE, "number", None) # number >= 0 237 """ 238 return check_number(arg_value, value, rel, int, arg_name, prim_name) 239 240 @staticmethod 241 def check_is_int(arg_value, arg_name=None, prim_name=None): 242 """ 243 Checks input value is float type or not. 244 245 Usage: 246 - number = check_is_int(number, int) 247 - number = check_is_int(number, int, "bias") 248 - number = check_is_int(number, int, "bias", "bias_class") 249 """ 250 return check_is_number(arg_value, int, arg_name, prim_name) 251 252 @staticmethod 253 def check_equal_int(arg_value, value, arg_name=None, prim_name=None): 254 """ 255 Checks input integer value `arg_value` compare to `value`. 256 257 Usage: 258 - number = check_int(number, 0, Rel.GE, "number", None) # number >= 0 259 """ 260 return check_number(arg_value, value, Rel.EQ, int, arg_name, prim_name) 261 262 @staticmethod 263 def check_positive_int(arg_value, arg_name=None, prim_name=None): 264 """ 265 Check argument is positive integer, which mean arg_value > 0. 266 267 Usage: 268 - number = check_positive_int(number) 269 - number = check_positive_int(number, "bias") 270 """ 271 return check_number(arg_value, 0, Rel.GT, int, arg_name, prim_name) 272 273 @staticmethod 274 def check_negative_int(arg_value, arg_name=None, prim_name=None): 275 """ 276 Check argument is negative integer, which mean arg_value < 0. 277 278 Usage: 279 - number = check_negative_int(number) 280 - number = check_negative_int(number, "bias") 281 """ 282 return check_number(arg_value, 0, Rel.LT, int, arg_name, prim_name) 283 284 @staticmethod 285 def check_non_positive_int(arg_value, arg_name=None, prim_name=None): 286 """ 287 Check argument is non-negative integer, which mean arg_value <= 0. 288 289 Usage: 290 - number = check_non_positive_int(number) 291 - number = check_non_positive_int(number, "bias") 292 """ 293 return check_number(arg_value, 0, Rel.LE, int, arg_name, prim_name) 294 295 @staticmethod 296 def check_non_negative_int(arg_value, arg_name=None, prim_name=None): 297 """ 298 Check argument is non-negative integer, which mean arg_value >= 0. 299 300 Usage: 301 - number = check_non_negative_int(number) 302 - number = check_non_negative_int(number, "bias") 303 """ 304 return check_number(arg_value, 0, Rel.GE, int, arg_name, prim_name) 305 306 @staticmethod 307 def check_float(arg_value, value, rel, arg_name=None, prim_name=None): 308 """ 309 Checks input float value `arg_value` compare to `value`. 310 311 Usage: 312 - number = check_float(number, 0.0, Rel.GE, "number", None) # number >= 0 313 """ 314 return check_number(arg_value, value, rel, float, arg_name, prim_name) 315 316 @staticmethod 317 def check_is_float(arg_value, arg_name=None, prim_name=None): 318 """ 319 Checks input value is float type or not. 320 321 Usage: 322 - number = check_is_float(number, int) 323 - number = check_is_float(number, int, "bias") 324 - number = check_is_float(number, int, "bias", "bias_class") 325 """ 326 return check_is_number(arg_value, float, arg_name, prim_name) 327 328 @staticmethod 329 def check_positive_float(arg_value, arg_name=None, prim_name=None): 330 """ 331 Check argument is positive float, which mean arg_value > 0. 332 333 Usage: 334 - number = check_positive_float(number) 335 - number = check_positive_float(number, "bias") 336 - number = check_positive_float(number, "bias", "bias_class") 337 """ 338 return check_number(arg_value, 0, Rel.GT, float, arg_name, prim_name) 339 340 @staticmethod 341 def check_negative_float(arg_value, arg_name=None, prim_name=None): 342 """ 343 Check argument is negative float, which mean arg_value < 0. 344 345 Usage: 346 - number = check_negative_float(number) 347 - number = check_negative_float(number, "bias") 348 """ 349 return check_number(arg_value, 0, Rel.LT, float, arg_name, prim_name) 350 351 @staticmethod 352 def check_non_positive_float(arg_value, arg_name=None, prim_name=None): 353 """ 354 Check argument is non-negative float, which mean arg_value <= 0. 355 356 Usage: 357 - number = check_non_positive_float(number) 358 - number = check_non_positive_float(number, "bias") 359 """ 360 return check_number(arg_value, 0, Rel.LE, float, arg_name, prim_name) 361 362 @staticmethod 363 def check_non_negative_float(arg_value, arg_name=None, prim_name=None): 364 """ 365 Check argument is non-negative float, which mean arg_value >= 0. 366 367 Usage: 368 - number = check_non_negative_float(number) 369 - number = check_non_negative_float(number, "bias") 370 """ 371 return check_number(arg_value, 0, Rel.GE, float, arg_name, prim_name) 372 373 @staticmethod 374 def check_number(arg_name, arg_value, value, rel, prim_name): 375 """Number value judgment.""" 376 rel_fn = Rel.get_fns(rel) 377 if not rel_fn(arg_value, value): 378 rel_str = Rel.get_strs(rel).format(value) 379 raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must {rel_str}, but got {arg_value}.') 380 return arg_value 381 382 @staticmethod 383 def check_isinstance(arg_name, arg_value, classes): 384 """Check arg isinstance of classes""" 385 if not isinstance(arg_value, classes): 386 raise ValueError(f'The `{arg_name}` should be isinstance of {classes}, but got {arg_value}.') 387 return arg_value 388 389 @staticmethod 390 def check_bool(arg_value, arg_name=None, prim_name=None): 391 """ 392 Check argument is instance of bool. 393 394 Usage: 395 - has_bias = check_bool(has_bias) 396 - has_bias = check_bool(has_bias, "has_bias") 397 """ 398 if not isinstance(arg_value, bool): 399 if prim_name and arg_name: 400 msg_prefix = f"For '{prim_name}', the '{arg_name}'" 401 elif prim_name and arg_name is None: 402 msg_prefix = f"For '{prim_name}', Parameter" 403 else: 404 msg_prefix = "Parameter" 405 raise TypeError(f"{msg_prefix} should be a bool, but got {type(arg_value).__name__}.") 406 return arg_value 407 408 @staticmethod 409 def check_int_range(arg_value, lower_limit, upper_limit, rel, arg_name=None, prim_name=None): 410 """ 411 Method for checking whether input value is in int range. 412 413 Usage: 414 - number = check_int_range(number, 0, 1, Rel.INC_NEITHER) # number in [0, 1] 415 - number = check_int_range(number, 0, 1, Rel.INC_NEITHER, "number") # number in [0, 1] 416 """ 417 return check_number_range(arg_value, lower_limit, upper_limit, rel, int, arg_name, prim_name) 418 419 @staticmethod 420 def check_float_range(arg_value, lower_limit, upper_limit, rel, arg_name=None, prim_name=None): 421 """ 422 Method for checking whether input value is in float range. 423 424 Usage: 425 - number = check_float_range(number, 0.0, 1.0, Rel.INC_NEITHER) # number in [0.0, 1.0] 426 - number = check_float_range(number, 0.0, 1.0, Rel.INC_NEITHER, "number") # number in [0.0, 1.0] 427 """ 428 return check_number_range(arg_value, lower_limit, upper_limit, rel, float, arg_name, prim_name) 429 430 @staticmethod 431 def check_string(arg_value, valid_values, arg_name=None, prim_name=None): 432 """ 433 Check whether string is in some value list. 434 435 Usage: 436 - method = check_string(method, ["string1", "string2", "string3"], "method") 437 """ 438 if isinstance(arg_value, str) and arg_value in valid_values: 439 return arg_value 440 arg_name = arg_name if arg_name else "Parameter" 441 msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The" 442 raise ValueError(f"{msg_prefix} '{arg_name}' should be str and must be in '{valid_values}'," 443 f" but got '{arg_value}'.") 444 445 @staticmethod 446 def check_str_by_regular(target, reg=None, flag=re.ASCII, prim_name=None): 447 if reg is None: 448 # Named string regular expression 449 reg = r"^\w+[0-9a-zA-Z\_\.]*$" 450 if re.match(reg, target, flag) is None: 451 prim_name = f'in `{prim_name}`' if prim_name else "" 452 raise ValueError("'{}' {} is illegal, it should be match regular'{}' by flags'{}.'".format( 453 target, prim_name, reg, flag)) 454 return True 455 456 @staticmethod 457 def check_file_name_by_regular(target, reg=None, prim_name=None): 458 """Check whether file name is legitimate.""" 459 if not isinstance(target, str): 460 raise ValueError("Args file_name {} must be string, please check it".format(target)) 461 if target.endswith("\\") or target.endswith("/"): 462 raise ValueError("File name cannot be a directory path.") 463 if reg is None: 464 reg = r"^[0-9a-zA-Z\_\-\.\:\/\\]+$" 465 if re.match(reg, target) is None: 466 prim_name = f'in `{prim_name}`' if prim_name else "" 467 raise ValueError("'{}' {} is illegal, it should be match regular'{}'.".format( 468 target, prim_name, reg)) 469 470 return True 471 472 @staticmethod 473 def check_pad_value_by_mode(pad_mode, padding, prim_name): 474 """Validates value of padding according to pad_mode""" 475 if pad_mode != 'pad' and padding != 0: 476 raise ValueError(f"For '{prim_name}', padding must be zero when pad_mode is '{pad_mode}'.") 477 return padding 478 479 @staticmethod 480 def check_subclass(arg_name, type_, template_types, prim_name, addition_error_info=None): 481 """Checks whether some type is subclass of another type""" 482 if not isinstance(template_types, Iterable): 483 template_types = (template_types,) 484 hit = False 485 for template_type in template_types: 486 if isinstance(template_type, mstype.Type): 487 if mstype.issubclass_(type_, template_type): 488 hit = True 489 break 490 elif type_ is template_type: 491 hit = True 492 break 493 if not hit: 494 if addition_error_info is None: 495 addition_error_info = '' 496 type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_) 497 raise TypeError(f"For '{prim_name}', the type of '{arg_name}'" 498 f" should be {'one of ' if len(template_types) > 1 else ''}" 499 f"{', '.join((str(x) for x in template_types))}, but got {type_str}" 500 f" {addition_error_info}. The supported data types depend on the hardware that" 501 f" executes the operator, please refer the official api document to get" 502 f" more information about the data type.") 503 504 @staticmethod 505 def check_valid_input(arg_name, arg_value, prim_name): 506 """Checks valid value.""" 507 if arg_value is None: 508 raise ValueError(f"For \'{prim_name}\', the '{arg_name}' can not be None, but got {arg_value}.") 509 return arg_value 510 511 @staticmethod 512 def check_types_same_and_valid(args, valid_values, prim_name): 513 """Checks whether the types of inputs are the same and valid.""" 514 515 def _check_type_valid(arg): 516 arg_key, arg_val = arg 517 elem_type = arg_val 518 Validator.check_subclass(arg_key, elem_type, valid_values, prim_name) 519 return (arg_key, elem_type) 520 521 def _check_types_same(arg1, arg2): 522 arg1_name, arg1_type = arg1 523 arg2_name, arg2_type = arg2 524 if arg1_type != arg2_type: 525 raise TypeError(f"For '{prim_name}', type of '{arg2_name}' should be same as '{arg1_name}'," 526 f" but got '{arg1_name}' with type {arg1_type}" 527 f" and '{arg2_name}' with type {arg2_type}.") 528 return arg1 529 530 elem_types = map(_check_type_valid, args.items()) 531 reduce(_check_types_same, elem_types) 532 533 @staticmethod 534 def check_tensors_dtypes_same_and_valid(args, valid_dtypes, prim_name): 535 """Checks whether the element types of input tensors are the same and valid.""" 536 valid_dtypes = valid_dtypes if isinstance(valid_dtypes, Iterable) else [valid_dtypes] 537 tensor_types = [mstype.tensor_type(t) for t in valid_dtypes] 538 Validator.check_types_same_and_valid(args, tensor_types, prim_name) 539 540 @staticmethod 541 def check_tensor_dtype_valid(arg_name, arg_type, valid_dtypes, prim_name): 542 """Checks whether the element types of input tensors are valid.""" 543 valid_dtypes = valid_dtypes if isinstance(valid_dtypes, Iterable) else [valid_dtypes] 544 tensor_types = [mstype.tensor_type(t) for t in valid_dtypes] 545 Validator.check_subclass(arg_name, arg_type, tensor_types, prim_name) 546 547 @staticmethod 548 def check_scalar_or_tensor_types_same(args, valid_values, prim_name, allow_mix=False): 549 """ 550 Checks whether the types of inputs are the same. If the input args are tensors, checks their element types. 551 If `allow_mix` is True, Tensor(float32) and float32 are type compatible, otherwise an exception will be raised. 552 """ 553 554 def _check_argument_type(arg): 555 arg_key, arg_val = arg 556 if isinstance(arg_val, type(mstype.tensor)): 557 arg_val = arg_val.element_type() 558 if not arg_val in valid_values: 559 raise TypeError(f'For \'{prim_name}\', the type of `{arg_key}` should be in {valid_values},' 560 f' but got {arg_val}.') 561 return arg 562 563 def _check_types_same(arg1, arg2): 564 arg1_name, arg1_type = arg1 565 arg2_name, arg2_type = arg2 566 except_flag = False 567 if isinstance(arg1_type, type(mstype.tensor)) and isinstance(arg2_type, type(mstype.tensor)): 568 arg1_type = arg1_type.element_type() 569 arg2_type = arg2_type.element_type() 570 elif not (isinstance(arg1_type, type(mstype.tensor)) or isinstance(arg2_type, type(mstype.tensor))): 571 pass 572 elif allow_mix: 573 arg1_type = arg1_type.element_type() if isinstance(arg1_type, type(mstype.tensor)) else arg1_type 574 arg2_type = arg2_type.element_type() if isinstance(arg2_type, type(mstype.tensor)) else arg2_type 575 else: 576 except_flag = True 577 578 if except_flag or arg1_type != arg2_type: 579 raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,' 580 f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.') 581 return arg1 582 583 reduce(_check_types_same, map(_check_argument_type, args.items())) 584 585 @staticmethod 586 def check_value_type(arg_name, arg_value, valid_types, prim_name=None): 587 """Checks whether a value is instance of some types.""" 588 valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,) 589 590 def raise_error_msg(): 591 """func for raising error message when check failed""" 592 type_names = [t.__name__ if hasattr(t, '__name__') else str(t) for t in valid_types] 593 num_types = len(valid_types) 594 msg_prefix = f"For '{prim_name}', the" if prim_name else "The" 595 raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {"one of " if num_types > 1 else ""}' 596 f'{type_names if num_types > 1 else type_names[0]}, ' 597 f'but got {arg_value} with type {type(arg_value).__name__}.') 598 599 # Notice: bool is subclass of int, so `check_value_type('x', True, [int])` will check fail, and 600 # `check_value_type('x', True, [bool, int])` will check pass 601 if isinstance(arg_value, bool) and bool not in tuple(valid_types): 602 raise_error_msg() 603 if not isinstance(arg_value, tuple(valid_types)): 604 raise_error_msg() 605 return arg_value 606 607 @staticmethod 608 def check_type_name(arg_name, arg_type, valid_types, prim_name): 609 """Checks whether a type in some specified types""" 610 valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,) 611 612 def raise_error_msg(): 613 """func for raising error message when check failed""" 614 type_names = [t.__name__ if hasattr(t, '__name__') else t for t in valid_types] 615 num_types = len(valid_types) 616 msg_prefix = f"For '{prim_name}', the" if prim_name else "The" 617 raise TypeError(f"{msg_prefix} '{arg_name}' should be {'one of ' if num_types > 1 else ''}" 618 f"{type_names if num_types > 1 else type_names[0]}, " 619 f"but got {arg_type.__name__ if hasattr(arg_type, '__name__') else repr(arg_type)}.") 620 621 if isinstance(arg_type, type(mstype.tensor)): 622 arg_type = arg_type.element_type() 623 if arg_type not in valid_types: 624 raise_error_msg() 625 return arg_type 626 627 @staticmethod 628 def check_reduce_shape(ori_shape, shape, axis, prim_name): 629 """Checks whether shape is ori_shape reduced on axis""" 630 axis = axis if isinstance(axis, Iterable) else (axis,) 631 exp_shape = [ori_shape[i] for i in range(len(ori_shape)) if i not in axis] 632 if list(shape) != exp_shape: 633 raise ValueError(f"For '{prim_name}', the origin shape {ori_shape} reduce on {axis} should be " 634 f"{tuple(exp_shape)}, but got {shape}.") 635 636 @staticmethod 637 def check_astype_dtype(dtype): 638 """Check whether dtype is a valid input, and convert to mstype""" 639 all_types = mstype.__dtype__ + ["int", "float", "bool"] 640 if isinstance(dtype, str): 641 if dtype.lower() not in all_types: 642 raise TypeError(f"`{dtype}` not understood.") 643 dtype = mstype.pytype_to_dtype(np.dtype(dtype.lower())) 644 elif isinstance(dtype, type): 645 dtype = mstype.pytype_to_dtype(dtype) 646 elif not dtype in mstype.number_type + (mstype.bool_,): 647 raise TypeError(f"`{dtype}` not understood.") 648 return dtype 649 650 @staticmethod 651 def check_transpose_axis(axes, ndim): 652 """Check the axis argument for tensor.transpose""" 653 if not axes or (len(axes) == 1 and axes[0] is None): 654 return tuple(range(ndim-1, -1, -1)) 655 656 if len(axes) == 1: 657 perm = axes[0] 658 # if only one argument provided, it must be tuple or list 659 if isinstance(perm, list): 660 perm = tuple(perm) 661 else: 662 if not isinstance(perm, tuple): 663 raise TypeError(f"The `axes` should be a tuple/list, or series of int, but got {type(axes[0])}") 664 return perm 665 666 # if multiple arguments provided, it must be `ndim` number of ints 667 if len(axes) != ndim: 668 raise ValueError("The number of axes must equal to the dimension of tensor.") 669 return axes 670 671 @staticmethod 672 def check_reshape_shp(shp): 673 """Check the shape argument for tensor.reshape""" 674 675 if len(shp) == 1: 676 new_shape = shp[0] 677 # if only one argument provided, it must be int, tuple or list 678 if isinstance(new_shape, int): 679 return shp 680 if isinstance(new_shape, list): 681 new_shape = tuple(new_shape) 682 else: 683 if not isinstance(new_shape, tuple): 684 raise TypeError( 685 f"The `shape` should be an int, or tuple/list, or series of int, but got {type(shp[0])}") 686 return new_shape 687 688 return shp 689 690 @staticmethod 691 def check_flatten_order(order): 692 """Check flatten function input order""" 693 if not isinstance(order, str): 694 raise TypeError(f"The order variable should be a string, but got {type(order)}") 695 if order not in ('C', 'F'): 696 raise ValueError(f"only `C` and `F` are supported as order, but got {order}") 697 return order 698 699 @staticmethod 700 def check_swapaxes_axis(axes, ndim): 701 """Check all the axes argument for tensor.swapaxes""" 702 if isinstance(axes, int): 703 Validator.check_axis_in_range(axes, ndim) 704 return axes % ndim 705 if isinstance(axes, (tuple, list)): 706 for axis in axes: 707 if not isinstance(axis, int): 708 raise TypeError(f"axis argument should be integer, but got {type(axis)}.") 709 Validator.check_axis_in_range(axis, ndim) 710 axes = tuple(map(lambda x: x % ndim, axes)) 711 return axes 712 raise TypeError(f"axes should be integer, list or tuple for check, but got {type(axes)}.") 713 714 @staticmethod 715 def prepare_shape_for_squeeze(shape, axes): 716 """ 717 Creates the squeezed new shape based on the tensor and given axes. 718 719 Args: 720 shape (tuple): the shape of the tensor 721 axes Union[int, tuple(int), list(int)]: the axes with dimensions need to 722 be squeezed. 723 724 Returns: 725 new_shape(tuple): the shape with dimensions squeezed. 726 """ 727 new_shape = [] 728 ndim = len(shape) 729 730 # Convert to set 731 if isinstance(axes, int): 732 if axes >= ndim or axes < -ndim: 733 raise ValueError(f"axis {axes} is out of bounds for tensor of dimension {ndim}") 734 axes = {axes} 735 736 elif isinstance(axes, (list, tuple)): 737 for axis in axes: 738 if axis >= ndim or axis < -ndim: 739 raise ValueError(f"axis {axis} is out of bounds for tensor of dimension {ndim}") 740 axes = set(axes) 741 742 else: 743 raise TypeError(f"only int, tuple and list are allowed for axes, but got {type(axes)}") 744 745 for idx, s in enumerate(shape): 746 if s != 1 or (idx not in axes) and (idx - ndim not in axes): 747 new_shape.append(s) 748 # if an axis is selected with shape entry greater than one, an error is raised. 749 if s != 1 and ((idx in axes) or (idx - ndim in axes)): 750 raise ValueError(f"axis {axes} has shape entry {s} > 1, cannot be squeezed.") 751 return tuple(new_shape) 752 753 @staticmethod 754 def check_axis_in_range(axis, ndim): 755 """Checks axes are with the bounds of ndim""" 756 if not isinstance(axis, int): 757 raise TypeError(f'axes should be integers, not {type(axis)}') 758 if not -ndim <= axis < ndim: 759 raise ValueError(f'axis {axis} is out of bounds for array of dimension {ndim}') 760 return axis % ndim 761 762 @staticmethod 763 def check_axis_valid(axes, ndim): 764 """ 765 Checks axes are valid given ndim, and returns axes that can be passed 766 to the built-in operator (non-negative, int or tuple) 767 """ 768 if axes is None: 769 axes = tuple(range(ndim)) 770 return axes 771 if isinstance(axes, (tuple, list)): 772 for axis in axes: 773 Validator.check_axis_in_range(axis, ndim) 774 axes = tuple(map(lambda x: x % ndim, axes)) 775 if any(axes.count(el) > 1 for el in axes): 776 raise ValueError('duplicate value in "axis"') 777 return axes 778 Validator.check_axis_in_range(axes, ndim) 779 return (axes % ndim,) 780 781 @staticmethod 782 def max_(*args): 783 return max(*args) 784 785 @staticmethod 786 def min_(*args): 787 return min(*args) 788 789 @staticmethod 790 def expanded_shape(ndim, axis_size, axis): 791 """ 792 Returns a shape with size = 1 for all dimensions 793 except at axis. 794 """ 795 return tuple(axis_size if i == axis else 1 for i in range(ndim)) 796 797 @staticmethod 798 def tuple_slice(tup, start, end): 799 """get sliced tuple from start and end.""" 800 return tup[start:end] 801 802 @staticmethod 803 def infer_out_shape(*shapes): 804 """ 805 Returns shape of output after broadcasting. Raises ValueError if shapes cannot be broadcast. 806 """ 807 shape_out = deque() 808 reversed_shapes = map(reversed, shapes) 809 for items in zip_longest(*reversed_shapes, fillvalue=1): 810 max_size = 0 if 0 in items else max(items) 811 if any(item not in (1, max_size) for item in items): 812 raise ValueError(f'operands could not be broadcast together with shapes {*shapes,}') 813 shape_out.appendleft(max_size) 814 return tuple(shape_out) 815 816 @staticmethod 817 def get_log2_size(size): 818 return math.ceil(math.log2(size)) 819 820 @staticmethod 821 def check_axis_type(axis, type_int=True, type_tuple=True, type_list=True): 822 """Check axis argument type.""" 823 if type_int and isinstance(axis, int): 824 return True 825 if (type_tuple and isinstance(axis, tuple)) or (type_list and isinstance(axis, list)): 826 for ax in axis: 827 if not isinstance(ax, int): 828 raise TypeError(f"Each axis should be integer, but got {type(ax)} in {axis}.") 829 return True 830 831 type_str = "" 832 if type_int: 833 type_str += "int, " 834 if type_tuple: 835 type_str += "tuple, " 836 if type_list: 837 type_str += "list, " 838 raise TypeError(f"Axis should be {type_str}but got {type(axis)}.") 839 840 @staticmethod 841 def check_and_canonicalize_axes(axes, ndim): 842 """Check whether the types and values of input axes are valid.""" 843 axes = axes if isinstance(axes, tuple) else (axes,) 844 new_axes = () 845 for ax in axes: 846 if not isinstance(ax, int): 847 raise TypeError((f"Each axis should be integer, but got {type(ax)} in {axes}.")) 848 if not -ndim <= ax < ndim: 849 raise ValueError(f'axis {ax} is out of bounds for array of dimension {ndim}') 850 ax = ax if ax >= 0 else ax + ndim 851 new_axes += (ax,) 852 if any(new_axes.count(el) > 1 for el in new_axes): 853 raise ValueError('duplicate value in "axis"') 854 return new_axes 855 856 @staticmethod 857 def empty_compile(dtype, shape): 858 """Returns an empty Tensor.""" 859 return Tensor_(dtype, shape) 860 861 @staticmethod 862 def check_type_support(dtype, device, supported_dtypes): 863 return dtype in supported_dtypes or not context.get_context('device_target') == device 864 865 866def check_input_format(input_param): 867 """Judge input format.""" 868 if input_param == "NCHW": 869 return input_param 870 raise ValueError("The data format must be NCHW.") 871 872 873def _expand_tuple(n_dimensions): 874 """To expand a int number to tuple.""" 875 876 def convert(m): 877 if not isinstance(m, tuple): 878 if isinstance(m, int) and not isinstance(m, bool): 879 return tuple(repeat(m, n_dimensions)) 880 raise TypeError("Input type must be int or tuple[int].") 881 882 if not len(m) is n_dimensions: 883 raise TypeError("Input tuple dimension is incorrect.") 884 885 for i in m: 886 if not isinstance(i, int) or isinstance(i, bool): 887 raise TypeError("Incorrect type inside of a tuple, must be int!") 888 return m 889 890 return convert 891 892 893def _check_data_type_valid(data, valid_type): 894 """Check data type valid.""" 895 if valid_type is None: 896 return data is None 897 if isinstance(data, valid_type): 898 if hasattr(data, 'size') and data.size == 0: 899 msg = "Please provide non-empty data." 900 logger.error(msg) 901 raise ValueError(msg) 902 return True 903 return False 904 905 906def check_input_data(*data, data_class): 907 """Input data check.""" 908 for item in data: 909 if isinstance(item, (list, tuple)): 910 for v in item: 911 check_input_data(v, data_class=data_class) 912 elif isinstance(item, dict): 913 for v in item.values(): 914 check_input_data(v, data_class=data_class) 915 else: 916 if isinstance(data_class, (tuple, list)): 917 ret = True in tuple(_check_data_type_valid(item, data_type) for data_type in data_class) 918 else: 919 ret = _check_data_type_valid(item, data_class) 920 if not ret: 921 data_class_str = tuple(i.__name__ if hasattr(i, '__name__') else i for i in data_class) \ 922 if isinstance(data_class, (tuple, list)) else \ 923 (data_class if data_class is None else data_class.__name__) 924 raise ValueError(f'Please provide as model inputs either a single or ' 925 f'a tuple or a list or a dict of {data_class_str}, ' 926 f'but got part data type is {item if item is None else type(item).__name__}.') 927 928 929def check_output_data(data): 930 """Output data check.""" 931 if data is None: 932 raise RuntimeError('Executor return data ' + str(data) + ', please check your net or input data.') 933 934 935once = _expand_tuple(1) 936twice = _expand_tuple(2) 937triple = _expand_tuple(3) 938 939 940def args_type_check(*type_args, **type_kwargs): 941 """Check whether input data type is correct.""" 942 943 def type_check(func): 944 sig = inspect.signature(func) 945 bound_types = sig.bind_partial(*type_args, **type_kwargs).arguments 946 947 @wraps(func) 948 def wrapper(*args, **kwargs): 949 nonlocal bound_types 950 bound_values = sig.bind(*args, **kwargs) 951 argument_dict = bound_values.arguments 952 if "kwargs" in bound_types: 953 bound_types = bound_types["kwargs"] 954 if "kwargs" in argument_dict: 955 argument_dict = argument_dict["kwargs"] 956 for name, value in argument_dict.items(): 957 if name in bound_types: 958 if value is not None and not isinstance(value, bound_types[name]): 959 raise TypeError('Argument {} must be {}'.format(name, bound_types[name])) 960 return func(*args, **kwargs) 961 962 return wrapper 963 964 return type_check 965 966 967_set_record = {} 968 969 970def args_unreset_check(*unreset_args, **unreset_kwargs): 971 """Check the entered non repeatable setting properties.""" 972 973 def unreset_check(func): 974 sig = inspect.signature(func) 975 bound_unreset = sig.bind_partial(*unreset_args, **unreset_kwargs).arguments 976 977 @wraps(func) 978 def wrapper(*args, **kwargs): 979 nonlocal bound_unreset 980 bound_values = sig.bind(*args, **kwargs) 981 argument_dict = bound_values.arguments 982 if "kwargs" in bound_unreset: 983 bound_unreset = bound_unreset["kwargs"] 984 if "kwargs" in argument_dict: 985 argument_dict = argument_dict["kwargs"] 986 for name, value in argument_dict.items(): 987 if name in _set_record.keys(): 988 raise TypeError('Argument {} is non-renewable parameter {}.'.format(name, bound_unreset[name])) 989 if name in bound_unreset: 990 _set_record[name] = value 991 return func(*args, **kwargs) 992 993 return wrapper 994 995 return unreset_check 996