1# Copyright 2020-2021 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================ 15"""Check parameters.""" 16from __future__ import absolute_import 17 18import re 19import inspect 20import math 21from types import FunctionType, MethodType 22from functools import reduce, wraps 23from itertools import repeat 24from collections.abc import Iterable 25import numpy as np 26 27from mindspore import context 28from mindspore import log as logger 29from mindspore.common import dtype as mstype 30from mindspore._c_expression import Tensor as Tensor_ 31 32 33EQ = 1 # == 34NE = 2 # != 35LT = 3 # < 36LE = 4 # <= 37GT = 5 # > 38GE = 6 # >= 39# scalar range check 40INC_NEITHER = 7 # (), include neither 41INC_LEFT = 8 # [), include left 42INC_RIGHT = 9 # (], include right 43INC_BOTH = 10 # [], include both 44# collection in, not in 45IN = 11 46NOT_IN = 12 47 48 49def _check_binary_rel(val1, val2, rel): 50 """check binary relation""" 51 if rel == EQ: 52 return val1 == val2 53 if rel == NE: 54 return val1 != val2 55 if rel == LT: 56 return val1 < val2 57 if rel == LE: 58 return val1 <= val2 59 if rel == GT: 60 return val1 > val2 61 if rel == GE: 62 return val1 >= val2 63 if rel == IN: 64 return val1 in val2 65 if rel == NOT_IN: 66 return val1 not in val2 67 68 return False 69 70 71def _check_inc_rel(val, lower, upper, rel): 72 """check include relation""" 73 if rel == INC_NEITHER: 74 return not (val <= lower or val >= upper) 75 if rel == INC_LEFT: 76 return not (val < lower or val >= upper) 77 if rel == INC_RIGHT: 78 return not (val <= lower or val > upper) 79 if rel == INC_BOTH: 80 return not (val < lower or val > upper) 81 82 return False 83 84 85def _format_str_one_value(value, rel): 86 """format string""" 87 if rel == EQ: 88 return f"= {value}" 89 if rel == NE: 90 return f"!= {value}" 91 if rel == LT: 92 return f"< {value}" 93 if rel == LE: 94 return f"<= {value}" 95 if rel == GT: 96 return f"> {value}" 97 if rel == GE: 98 return f">= {value}" 99 if rel == IN: 100 return f"in {value}" 101 if rel == NOT_IN: 102 return f"not in {value}" 103 104 return "" 105 106 107def _format_str_two_value(val1, val2, rel): 108 """format string""" 109 if rel == INC_NEITHER: 110 return f"({val1}, {val2})" 111 if rel == INC_LEFT: 112 return f"[{val1}, {val2})" 113 if rel == INC_RIGHT: 114 return f"({val1}, {val2}]" 115 if rel == INC_BOTH: 116 return f"[{val1}, {val2}]" 117 118 return "" 119 120 121def _check_3d_int_or_tuple(arg_name, arg_value, prim_name, allow_five=False, ret_five=False, 122 greater_zero=True, third_one=False, three_input=False): 123 """ 124 Checks whether an argument is a positive int or tuple with 3 or 5(when allow_five is True) positive int elements. 125 """ 126 127 def _raise_message(third_one_flag=False, three_input_flag=False): 128 if third_one_flag: 129 raise ValueError(f"For '{prim_name}', the depth of parameter '{arg_name}' must be 1, " \ 130 f"but got {ret_value[-3]}.") 131 if three_input_flag: 132 raise ValueError(f"For '{prim_name}', the parameter '{arg_name}' must be an positive integer " \ 133 f"or a tuple of three positive integer, but got {arg_value}.") 134 raise ValueError(f"For '{prim_name}', the parameter '{arg_name}' must be an positive integer or " \ 135 f"a tuple of three {'or five ' if allow_five else ''}positive integer, but got {arg_value}") 136 137 def _get_return_value(): 138 def _check(): 139 if not isinstance(arg_value, int): 140 if len(arg_value) == 5: 141 if not allow_five: 142 _raise_message() 143 elif not len(arg_value) == 3: 144 _raise_message() 145 146 _check() 147 if isinstance(arg_value, int): 148 ret = (1, 1, arg_value, arg_value, arg_value) if ret_five else (arg_value, arg_value, arg_value) 149 elif len(arg_value) == 3: 150 ret = (1, 1, arg_value[0], arg_value[1], arg_value[2]) if ret_five else arg_value 151 else: # case: len(arg_value) == 5 152 ret = arg_value if ret_five else (arg_value[2], arg_value[3], arg_value[4]) 153 154 return ret 155 156 def _check_value(ret_value): 157 for item in ret_value: 158 if isinstance(item, int) and not isinstance(item, bool): 159 if greater_zero and item > 0: 160 continue 161 if not greater_zero and item >= 0: 162 continue 163 _raise_message() 164 165 def _check_third_one(ret_value): 166 if third_one: 167 if ret_value[-3] != 1: 168 _raise_message(third_one_flag=third_one) 169 170 check_value_type(arg_name, arg_value, (int, tuple), prim_name) 171 if three_input and isinstance(arg_value, tuple): 172 if len(arg_value) != 3: 173 _raise_message(three_input_flag=three_input) 174 ret_value = _get_return_value() 175 _check_value(ret_value) 176 _check_third_one(ret_value) 177 178 return tuple(ret_value) 179 180 181def _check_dup(axes): 182 for item in axes: 183 count = 0 184 for item2 in axes: 185 if item == item2: 186 count += 1 187 188 if count > 1: 189 raise ValueError(f"The element of parameter 'axis' can not be duplicate, but got {axes}.") 190 191 192def _check_number(arg_value, value, rel, arg_type=int, arg_name=None, prim_name=None): 193 """ 194 Check argument integer. 195 196 Usage: 197 - arg_value = _check_number(arg_value, 2, GT, int, "value", None) 198 """ 199 prim_name = f"For \'{prim_name}\', the " if prim_name else 'The ' 200 arg_name = f"\'{arg_name}\'" if arg_name else 'input value' 201 202 def _check_param(): 203 prim_info = f'{prim_name}' + f'{arg_name}' 204 if isinstance(arg_value, arg_type): 205 if math.isinf(arg_value) or math.isnan(arg_value) or np.isinf(arg_value) or np.isnan(arg_value): 206 raise ValueError(f"{prim_info} must be a legal value, but got '{arg_value}'.") 207 else: 208 raise TypeError(f"{prim_info} must be {arg_type.__name__}, but got '{type(arg_value).__name__}'") 209 210 type_mismatch = not isinstance(arg_value, arg_type) or isinstance(arg_value, bool) 211 rel_ret = _check_binary_rel(arg_value, value, rel) 212 if type_mismatch or not rel_ret: 213 rel_str = _format_str_one_value(value, rel) 214 msg = f"{prim_info} must be {arg_type.__name__} and must {rel_str}, " \ 215 f"but got '{arg_value}' with type '{type(arg_value).__name__}'." 216 if type_mismatch: 217 raise TypeError(msg) 218 raise ValueError(msg) 219 220 _check_param() 221 return arg_value 222 223 224def check_is_number(arg_value, arg_type, arg_name=None, prim_name=None): 225 """ 226 Checks input value is float type or not. 227 228 Usage: 229 - number = check_is_number(number, int) 230 - number = check_is_number(number, int, "bias") 231 - number = check_is_number(number, int, "bias", "bias_class") 232 """ 233 prim_name = f"For \'{prim_name}\', the" if prim_name else 'The' 234 arg_name = f"\'{arg_name}\'" if arg_name else 'input value' 235 236 def _check_param(): 237 if isinstance(arg_value, arg_type) and not isinstance(arg_value, bool): 238 if math.isinf(arg_value) or math.isnan(arg_value) or np.isinf(arg_value) or np.isnan(arg_value): 239 raise ValueError(f"{prim_name} {arg_name} must be a legal float, but got '{arg_value}'.") 240 else: 241 raise TypeError(f"{prim_name} type of {arg_name} must be '{arg_type.__name__}', " \ 242 f"but got '{type(arg_value).__name__}'.") 243 _check_param() 244 return arg_value 245 246 247def check_number_range(arg_value, lower_limit, upper_limit, rel, value_type, arg_name=None, prim_name=None): 248 """ 249 Method for checking whether an int value is in some range. 250 251 Usage: 252 - number = check_number_range(number, 0.0, 1.0, INC_NEITHER, "number", float) # number in [0.0, 1.0] 253 - number = check_number_range(number, 0, 1, INC_NEITHER, "number", int) # number in [0, 1] 254 """ 255 prim_name = f"For \'{prim_name}\', the" if prim_name else 'The' 256 arg_name = f"\'{arg_name}\'" if arg_name else 'input value' 257 258 def _check_param(): 259 type_mismatch = not isinstance(arg_value, (np.ndarray, np.generic, value_type)) or isinstance(arg_value, bool) 260 if type_mismatch: 261 raise TypeError(f"{prim_name} {arg_name} must be '{value_type.__name__}', " \ 262 f"but got '{type(arg_value).__name__}'.") 263 264 if not _check_inc_rel(arg_value, lower_limit, upper_limit, rel): 265 rel_str = _format_str_two_value(lower_limit, upper_limit, rel) 266 raise ValueError(f"{prim_name} {arg_name} must be in range of {rel_str}, " \ 267 f"but got {arg_value} with type '{type(arg_value).__name__}'.") 268 _check_param() 269 return arg_value 270 271 272def check(arg_name, arg_value, value_name, value, rel=EQ, prim_name=None, excp_cls=ValueError): 273 """ 274 Method for judging relation between two int values or list/tuple made up of ints. 275 This method is not suitable for judging relation between floats, since it does not consider float error. 276 """ 277 def _check(): 278 if not _check_binary_rel(arg_value, value, rel): 279 rel_str = _format_str_one_value(f'{value_name}: {value}', rel) 280 msg_prefix = f'For \'{prim_name}\', the' if prim_name else "The" 281 msg_subject = f"{msg_prefix} \'{arg_name}\'" if " " not in arg_name else f"{msg_prefix} {arg_name}" 282 raise excp_cls(f'{msg_subject} should be {rel_str}, but got {arg_value}.') 283 284 _check() 285 return arg_value 286 287 288def check_int(arg_value, value, rel, arg_name=None, prim_name=None): 289 """ 290 Checks input integer value `arg_value` compare to `value`. 291 292 Usage: 293 - number = check_int(number, 0, GE, "number", None) # number >= 0 294 """ 295 return _check_number(arg_value, value, rel, int, arg_name, prim_name) 296 297 298def check_is_int(arg_value, arg_name=None, prim_name=None): 299 """ 300 Checks input value is float type or not. 301 302 Usage: 303 - number = check_is_int(number, int) 304 - number = check_is_int(number, int, "bias") 305 - number = check_is_int(number, int, "bias", "bias_class") 306 """ 307 return check_is_number(arg_value, int, arg_name, prim_name) 308 309 310def check_equal_int(arg_value, value, arg_name=None, prim_name=None): 311 """ 312 Checks input integer value `arg_value` compare to `value`. 313 314 Usage: 315 - number = check_equal_int(number, 0, "number", None) # number == 0 316 """ 317 return _check_number(arg_value, value, EQ, int, arg_name, prim_name) 318 319 320def check_positive_int(arg_value, arg_name=None, prim_name=None): 321 """ 322 Check argument is positive integer, which mean arg_value > 0. 323 324 Usage: 325 - number = check_positive_int(number) 326 - number = check_positive_int(number, "bias") 327 """ 328 return _check_number(arg_value, 0, GT, int, arg_name, prim_name) 329 330 331def check_positive_int_sequence(sequence, arg_name=None, prim_name=None): 332 """ 333 Check argument is positive int sequence, which mean all element > 0 in sequence. 334 335 Usage: 336 - sequence = check_positive_int_sequence(sequence) 337 - sequence = check_positive_int_sequence(sequence, "dims") 338 """ 339 for idx in range(len(sequence)): 340 element = sequence[idx] 341 arg_idx = f"{arg_name if arg_name else 'arg_name'}[{idx}]" 342 _check_number(element, 0, GT, int, arg_idx, prim_name) 343 return sequence 344 345 346def check_negative_int(arg_value, arg_name=None, prim_name=None): 347 """ 348 Check argument is negative integer, which mean arg_value < 0. 349 350 Usage: 351 - number = check_negative_int(number) 352 - number = check_negative_int(number, "bias") 353 """ 354 return _check_number(arg_value, 0, LT, int, arg_name, prim_name) 355 356 357def check_non_positive_int(arg_value, arg_name=None, prim_name=None): 358 """ 359 Check argument is non-negative integer, which mean arg_value <= 0. 360 361 Usage: 362 - number = check_non_positive_int(number) 363 - number = check_non_positive_int(number, "bias") 364 """ 365 return _check_number(arg_value, 0, LE, int, arg_name, prim_name) 366 367 368def check_non_negative_int(arg_value, arg_name=None, prim_name=None): 369 """ 370 Check argument is non-negative integer, which mean arg_value >= 0. 371 372 Usage: 373 - number = check_non_negative_int(number) 374 - number = check_non_negative_int(number, "bias") 375 """ 376 return _check_number(arg_value, 0, GE, int, arg_name, prim_name) 377 378 379def check_non_negative_int_sequence(sequence, arg_name=None, prim_name=None): 380 """ 381 Check argument is positive sequence, which mean all element >= 0 in sequence. 382 383 Usage: 384 - sequence = check_non_negative_int_sequence(sequence) 385 - sequence = check_non_negative_int_sequence(sequence, "dims") 386 """ 387 for idx in range(len(sequence)): 388 element = sequence[idx] 389 arg_idx = f"{arg_name if arg_name else 'arg_name'}[{idx}]" 390 _check_number(element, 0, GE, int, arg_idx, prim_name) 391 return sequence 392 393 394def check_float(arg_value, value, rel, arg_name=None, prim_name=None): 395 """ 396 Checks input float value `arg_value` compare to `value`. 397 398 Usage: 399 - number = check_float(number, 0.0, GE, "number", None) # number >= 0 400 """ 401 return _check_number(arg_value, value, rel, float, arg_name, prim_name) 402 403 404def check_is_float(arg_value, arg_name=None, prim_name=None): 405 """ 406 Checks input value is float type or not. 407 408 Usage: 409 - number = check_is_float(number) 410 - number = check_is_float(number, "bias") 411 - number = check_is_float(number, "bias", "bias_class") 412 """ 413 return check_is_number(arg_value, float, arg_name, prim_name) 414 415 416def check_positive_float(arg_value, arg_name=None, prim_name=None): 417 """ 418 Check argument is positive float, which mean arg_value > 0. 419 420 Usage: 421 - number = check_positive_float(number) 422 - number = check_positive_float(number, "bias") 423 - number = check_positive_float(number, "bias", "bias_class") 424 """ 425 return _check_number(arg_value, 0, GT, float, arg_name, prim_name) 426 427 428def check_positive_float_sequence(sequence, arg_name=None, prim_name=None): 429 """ 430 Check argument is positive sequence, which mean all element > 0 in sequence. 431 432 Usage: 433 - sequence = check_positive_float_sequence(sequence) 434 - sequence = check_positive_float_sequence(sequence, "dims") 435 """ 436 for idx in range(len(sequence)): 437 element = sequence[idx] 438 arg_idx = f"{arg_name if arg_name else 'arg_name'}[{idx}]" 439 _check_number(element, 0, GT, float, arg_idx, prim_name) 440 return sequence 441 442 443def check_negative_float(arg_value, arg_name=None, prim_name=None): 444 """ 445 Check argument is negative float, which mean arg_value < 0. 446 447 Usage: 448 - number = check_negative_float(number) 449 - number = check_negative_float(number, "bias") 450 """ 451 return _check_number(arg_value, 0, LT, float, arg_name, prim_name) 452 453 454def check_non_positive_float(arg_value, arg_name=None, prim_name=None): 455 """ 456 Check argument is non-negative float, which mean arg_value <= 0. 457 458 Usage: 459 - number = check_non_positive_float(number) 460 - number = check_non_positive_float(number, "bias") 461 """ 462 return _check_number(arg_value, 0, LE, float, arg_name, prim_name) 463 464 465def check_non_negative_float(arg_value, arg_name=None, prim_name=None): 466 """ 467 Check argument is non-negative float, which mean arg_value >= 0. 468 469 Usage: 470 - number = check_non_negative_float(number) 471 - number = check_non_negative_float(number, "bias") 472 """ 473 return _check_number(arg_value, 0, GE, float, arg_name, prim_name) 474 475 476def check_number(arg_name, arg_value, value, rel, prim_name): 477 """Number value judgment.""" 478 def _check(): 479 if not _check_binary_rel(arg_value, value, rel): 480 rel_str = _format_str_one_value(value, rel) 481 raise ValueError(f'For \'{prim_name}\', the argument \'{arg_name}\' ' \ 482 f'must {rel_str}, but got {arg_value}.') 483 _check() 484 return arg_value 485 486 487def check_isinstance(arg_name, arg_value, classes): 488 """Check arg isinstance of classes""" 489 def _check(): 490 if not isinstance(arg_value, classes): 491 raise ValueError(f'The parameter \'{arg_name}\' must be isinstance of {classes}, but got {arg_value}.') 492 _check() 493 return arg_value 494 495 496def check_bool(arg_value, arg_name=None, prim_name=None): 497 """ 498 Check argument is instance of bool. 499 500 Usage: 501 - has_bias = check_bool(has_bias) 502 - has_bias = check_bool(has_bias, "has_bias") 503 """ 504 prim_name = f"For '{prim_name}', the" if prim_name else 'The' 505 arg_name = f"'{arg_name}'" if arg_name else 'input value' 506 507 def _check(): 508 if not isinstance(arg_value, bool): 509 raise TypeError(f"{prim_name} {arg_name} must be a bool, but got {type(arg_value).__name__}.") 510 _check() 511 return arg_value 512 513 514def check_int_range(arg_value, lower_limit, upper_limit, rel, arg_name=None, prim_name=None): 515 """ 516 Method for checking whether input value is in int range. 517 518 Usage: 519 - number = check_int_range(number, 0, 1, INC_NEITHER) # number in [0, 1] 520 - number = check_int_range(number, 0, 1, INC_NEITHER, "number") # number in [0, 1] 521 """ 522 return check_number_range(arg_value, lower_limit, upper_limit, rel, int, arg_name, prim_name) 523 524 525def check_float_range(arg_value, lower_limit, upper_limit, rel, arg_name=None, prim_name=None): 526 """ 527 Method for checking whether input value is in float range. 528 529 Usage: 530 - number = check_float_range(number, 0.0, 1.0, INC_NEITHER) # number in [0.0, 1.0] 531 - number = check_float_range(number, 0.0, 1.0, INC_NEITHER, "number") # number in [0.0, 1.0] 532 """ 533 return check_number_range(arg_value, lower_limit, upper_limit, rel, float, arg_name, prim_name) 534 535 536def check_string(arg_value, valid_values, arg_name=None, prim_name=None): 537 """ 538 Check whether string is in some value list. 539 540 Usage: 541 - method = check_string(method, ["string1", "string2", "string3"], "method") 542 """ 543 arg_name = arg_name if arg_name else "parameter" 544 msg_prefix = f'For \'{prim_name}\', the' if prim_name else "The" 545 546 def _check(): 547 if not (isinstance(arg_value, str) and arg_value in valid_values): 548 raise ValueError(f"{msg_prefix} '{arg_name}' must be str and must be in '{valid_values}'," \ 549 f" but got '{arg_value}'.") 550 _check() 551 return arg_value 552 553 554def check_str_by_regular(target, reg=None, flag=re.ASCII, prim_name=None): 555 if reg is None: 556 # Named string regular expression 557 reg = r"^\w+[0-9a-zA-Z\_\.]*$" 558 if re.match(reg, target, flag) is None: 559 prim_name = f"For '{prim_name}', the" if prim_name else "The" 560 raise ValueError(f"{prim_name} '{target}' is illegal, it must be match regular'{reg}' by flags'{flag}.'") 561 return True 562 563 564# pylint: disable=missing-docstring 565def check_str_and_none_by_regular(target, reg=None, flag=re.ASCII, prim_name=None): 566 if reg is None: 567 # Named string regular expression 568 reg = r"^\w*[0-9a-zA-Z\_\.\-]*$" 569 if re.match(reg, target, flag) is None: 570 prim_name = f"For '{prim_name}', the" if prim_name else "The" 571 raise ValueError(f"{prim_name} '{target}' is illegal, it must be match regular'{reg}' by flags'{flag}.'") 572 return True 573 574 575def check_file_name_by_regular(target, reg=None, prim_name=None): 576 """Check whether file name is legitimate.""" 577 if not isinstance(target, str): 578 prim_name = f"For '{prim_name}', the" if prim_name else "The" 579 raise TypeError(f"{prim_name} '{target}' must be string, but got {type(target)}.") 580 if target.endswith("\\") or target.endswith("/"): 581 prim_name = f"For '{prim_name}', the" if prim_name else "The" 582 raise ValueError(f"{prim_name} '{target}' cannot be a directory path.") 583 if reg is None: 584 reg = r"^[0-9a-zA-Z@\_\-\.\:\/\\]+$" 585 if re.match(reg, target) is None: 586 prim_name = f"For '{prim_name}', the" if prim_name else "The" 587 raise ValueError(f"{prim_name} '{target}' is illegal, it must be match regular '{reg}'.") 588 589 return True 590 591 592def check_pad_value_by_mode(pad_mode, padding, prim_name): 593 """Validates value of padding according to pad_mode""" 594 if pad_mode != 'pad' and padding != 0: 595 raise ValueError(f"For '{prim_name}', padding must be zero when pad_mode is '{pad_mode}'," \ 596 f" but got {padding}.") 597 return padding 598 599 600def check_subclass(arg_name, type_, template_types, prim_name, addition_error_info=None): 601 """Checks whether some type is subclass of another type""" 602 if not isinstance(template_types, Iterable): 603 template_types = (template_types,) 604 hit = False 605 for template_type in template_types: 606 if isinstance(template_type, mstype.Type): 607 if mstype._issubclass_(type_, template_type): # pylint: disable=W0212 608 hit = True 609 break 610 elif type_ is template_type: 611 hit = True 612 break 613 if not hit: 614 if addition_error_info is None: 615 addition_error_info = '' 616 else: 617 addition_error_info = ' ' + addition_error_info 618 type_str = (f"type '{type(type_).__name__}'" if isinstance(type_, (tuple, list)) else str(type_)) 619 raise TypeError(f"For '{prim_name}', the element of '{arg_name}'" \ 620 f" must be {'one of ' if len(template_types) > 1 else ''}" \ 621 f"{', '.join((str(x) for x in template_types))}, but got {type_str}" \ 622 f"{addition_error_info}.The supported data types depend on the hardware that" \ 623 f" executes the operator, for more details, please refer to the MindSpore official " \ 624 f"website to get more information about the data type.") 625 626 627def check_valid_input(arg_name, arg_value, prim_name): 628 """Checks valid value.""" 629 def _check(): 630 if arg_value is None: 631 raise ValueError(f"For \'{prim_name}\', the argument '{arg_name}'" \ 632 f"can not be None, but got {arg_value}.") 633 _check() 634 return arg_value 635 636 637def check_types_same_and_valid(args, valid_values, prim_name): 638 """Checks whether the types of inputs are the same and valid.""" 639 640 def _check_type_valid(arg): 641 arg_key, arg_val = arg 642 elem_type = arg_val 643 check_subclass(arg_key, elem_type, valid_values, prim_name) 644 return (arg_key, elem_type) 645 646 def _check_types_same(arg1, arg2): 647 arg1_name, arg1_type = arg1 648 arg2_name, arg2_type = arg2 649 if arg1_type != arg2_type: 650 raise TypeError(f"For '{prim_name}', the type of '{arg2_name}' should be same as '{arg1_name}'," \ 651 f" but got '{arg1_name}' with type {arg1_type}" \ 652 f" and '{arg2_name}' with type {arg2_type}.") 653 return arg1 654 655 elem_types = map(_check_type_valid, args.items()) 656 reduce(_check_types_same, elem_types) 657 658 659def check_tensors_dtypes_same_and_valid(args, valid_dtypes, prim_name): 660 """Checks whether the element types of input tensors are the same and valid.""" 661 valid_dtypes = valid_dtypes if isinstance(valid_dtypes, Iterable) else [valid_dtypes] 662 tensor_types = [mstype.TensorType(t) for t in valid_dtypes] 663 check_types_same_and_valid(args, tensor_types, prim_name) 664 665 666def check_tensor_dtype_valid(arg_name, arg_type, valid_dtypes, prim_name): 667 """Checks whether the element types of input tensors are valid.""" 668 valid_dtypes = valid_dtypes if isinstance(valid_dtypes, Iterable) else [valid_dtypes] 669 tensor_types = [mstype.TensorType(t) for t in valid_dtypes] 670 check_subclass(arg_name, arg_type, tensor_types, prim_name) 671 672 673def check_scalar_or_tensor_types_same(args, valid_values, prim_name, allow_mix=False): 674 """ 675 Checks whether the types of inputs are the same. If the input args are tensors, checks their element types. 676 If `allow_mix` is True, Tensor(float32) and float32 are type compatible, otherwise an exception will be raised. 677 """ 678 679 def _check_argument_type(arg): 680 arg_key, arg_val = arg 681 if isinstance(arg_val, type(mstype.tensor_type)): 682 arg_val = arg_val.element_type() 683 if arg_val not in valid_values: 684 raise TypeError(f'For \'{prim_name}\', the type of \'{arg_key}\' must be in {valid_values},' \ 685 f' but got {arg_val}.') 686 return arg 687 688 def _check_types_same(arg1, arg2): 689 arg1_name, arg1_type = arg1 690 arg2_name, arg2_type = arg2 691 except_flag = False 692 if isinstance(arg1_type, type(mstype.tensor_type)) and isinstance(arg2_type, type(mstype.tensor_type)): 693 arg1_type = arg1_type.element_type() 694 arg2_type = arg2_type.element_type() 695 elif not (isinstance(arg1_type, type(mstype.tensor_type)) or isinstance(arg2_type, type(mstype.tensor_type))): 696 pass 697 elif allow_mix: 698 arg1_type = arg1_type.element_type() if isinstance(arg1_type, type(mstype.tensor_type)) else arg1_type 699 arg2_type = arg2_type.element_type() if isinstance(arg2_type, type(mstype.tensor_type)) else arg2_type 700 else: 701 except_flag = True 702 703 if except_flag or arg1_type != arg2_type: 704 raise TypeError(f"For '{prim_name}', the type of '{arg2_name}' must be same as '{arg1_name}'," \ 705 f" but got '{arg1_name}' with type {arg1_type}" \ 706 f" and '{arg2_name}' with type {arg2_type}.") 707 return arg1 708 709 args_map = map(_check_argument_type, args.items()) 710 reduce(_check_types_same, args_map) 711 712 713def check_value_type(arg_name, arg_value, valid_types, prim_name=None): 714 """Checks whether a value is instance of some types.""" 715 valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,) 716 717 def raise_error_msg(cond, arg_value): 718 """func for raising error message when check failed""" 719 if not cond: 720 return 721 type_names = [t.__name__ if hasattr(t, '__name__') else str(t) for t in valid_types] 722 num_types = len(valid_types) 723 msg_prefix = f"For '{prim_name}', the" if prim_name else "The" 724 raise TypeError(f'{msg_prefix} type of \'{arg_name}\' should be {"one of " if num_types > 1 else ""}' \ 725 f'\'{type_names if num_types > 1 else type_names[0]}\', ' \ 726 f'but got type \'{type(arg_value).__name__}\'.') 727 728 # Notice: bool is subclass of int, so `check_value_type('x', True, [int])` will check fail, and 729 # `check_value_type('x', True, [bool, int])` will check pass 730 cond = isinstance(arg_value, bool) and bool not in tuple(valid_types) 731 raise_error_msg(cond, arg_value) 732 if isinstance(arg_value, float) and float not in tuple(valid_types): 733 arg_value = round(arg_value, 6) 734 cond = not isinstance(arg_value, tuple(valid_types)) 735 raise_error_msg(cond, arg_value) 736 return arg_value 737 738 739def check_type_name(arg_name, arg_type, valid_types, prim_name): 740 """Checks whether a type in some specified types""" 741 valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,) 742 743 def raise_error_msg(cond, arg_type): 744 """func for raising error message when check failed""" 745 if not cond: 746 return 747 type_names = [t.__name__ if hasattr(t, '__name__') else t for t in valid_types] 748 num_types = len(valid_types) 749 msg_prefix = f"For '{prim_name}', the" if prim_name else "The" 750 raise TypeError(f"{msg_prefix} '{arg_name}' should be {'one of ' if num_types > 1 else ''}" \ 751 f"{type_names if num_types > 1 else type_names[0]}, " \ 752 f"but got '{arg_type.__name__ if hasattr(arg_type, '__name__') else repr(arg_type)}'.") 753 754 if isinstance(arg_type, type(mstype.tensor_type)): 755 arg_type = arg_type.element_type() 756 cond = arg_type not in valid_types 757 raise_error_msg(cond, arg_type) 758 return arg_type 759 760 761def check_reduce_shape(ori_shape, shape, axis, prim_name, arg_name1, arg_name2): 762 """Checks whether shape is ori_shape reduced on axis""" 763 axis_origin = axis 764 axis = axis if isinstance(axis, Iterable) else (axis,) 765 exp_shape = [ori_shape[i] for i in range(len(ori_shape)) if i not in axis] 766 if list(shape) != exp_shape: 767 raise ValueError(f"For '{prim_name}', " \ 768 f"the shape of parameter '{arg_name1}' reduce on 'axis': {axis_origin} must " \ 769 f"be equal to the shape of '{arg_name2}': {shape}, but got {ori_shape}.") 770 771 772def check_astype_dtype(dtype): 773 """Check whether dtype is a valid input, and convert to mstype""" 774 all_types = mstype.__dtype__ + ["int", "float", "bool"] 775 if isinstance(dtype, str): 776 if dtype.lower() not in all_types: 777 raise TypeError(f"For Tensor.astype, the input type must be one of {all_types}, but got '{dtype}'.") 778 dtype = mstype.pytype_to_dtype(np.dtype(dtype.lower())) 779 elif isinstance(dtype, type): 780 dtype = mstype.pytype_to_dtype(dtype) 781 elif not dtype in mstype.number_type + (mstype.bool_,): 782 raise TypeError(f"For Tensor.astype, the input type must be one of {mstype.number_type + (mstype.bool_,)}," \ 783 f" but got '{dtype}'.") 784 return dtype 785 786 787def check_transpose_axis(axes, ndim): 788 """Check the axis argument for tensor.transpose""" 789 def _check_dim(): 790 # if multiple arguments provided, it must be `ndim` number of ints 791 if len(axes) != ndim: 792 raise ValueError(f"For Tensor.transpose, the number of axes must be equal to the dimension of Tensor, " \ 793 f"but got {len(axes)} in the number of axes.") 794 795 if not axes or (len(axes) == 1 and axes[0] is None): 796 return tuple(range(ndim-1, -1, -1)) 797 798 if len(axes) == 1: 799 perm = axes[0] 800 # if only one argument provided, it must be tuple or list 801 if isinstance(perm, list): 802 perm = tuple(perm) 803 elif isinstance(perm, int): 804 perm = (perm,) 805 _check_dim() 806 else: 807 if not isinstance(perm, tuple): 808 raise TypeError(f"For Tensor.transpose, the parameter 'axes' must be a tuple/list, " \ 809 f"or series of integer, but got {type(axes[0])}") 810 return perm 811 812 _check_dim() 813 return axes 814 815 816def check_reshape_shp(shp): 817 """Check the shape argument for tensor.reshape""" 818 819 if len(shp) == 1: 820 new_shape = shp[0] 821 # if only one argument provided, it must be int, tuple or list 822 if isinstance(new_shape, int): 823 return shp 824 if isinstance(new_shape, list): 825 new_shape = tuple(new_shape) 826 else: 827 if not isinstance(new_shape, tuple): 828 raise TypeError( 829 f"For Tensor.reshape, the parameter 'shape' must be an integer, or tuple/list, " \ 830 f"or series of integer, but got {type(shp[0])}") 831 return new_shape 832 833 return shp 834 835 836def check_flatten_order(order): 837 """Check flatten function input order""" 838 if not isinstance(order, str): 839 raise TypeError(f"For Tensor.flatten, the parameter 'order' must be a string, but got {type(order)}") 840 if order not in ('C', 'F'): 841 raise ValueError(f"For Tensor.flatten, the parameter 'order' must be 'C' or 'F', but got '{order}'") 842 843 844def check_swapaxes_axis(axes, ndim): 845 """Check all the axes argument for ops.swapaxes""" 846 if isinstance(axes, int): 847 return check_axis_in_range(axes, ndim) 848 if isinstance(axes, (tuple, list)): 849 for axis in axes: 850 if not isinstance(axis, int): 851 raise TypeError(f"For ops.swapaxes, the axis argument must be integer, but got {type(axis)}.") 852 check_axis_in_range(axis, ndim) 853 tmp = () 854 for x in axes: 855 tmp = tmp + ((x + ndim) % ndim,) 856 return tmp 857 raise TypeError(f"For ops.swapaxes, the argument 'axes' must be integer, list or tuple for check, " \ 858 f"but got {type(axes)}.") 859 860 861def prepare_shape_for_squeeze(shape, axes): 862 """ 863 Creates the squeezed new shape based on the tensor and given axes. 864 865 Args: 866 shape (tuple): the shape of the tensor 867 axes Union[int, tuple(int), list(int)]: the axes with dimensions need to 868 be squeezed. 869 870 Returns: 871 new_shape(tuple): the shape with dimensions squeezed. 872 """ 873 new_shape = () 874 ndim = len(shape) 875 876 def _check(axes, ndim): 877 if axes >= ndim or axes < -ndim: 878 raise ValueError(f"For Tensor.squeeze, the 'axis' must be in the range of [-{ndim}, {ndim}), " \ 879 f"but got {axes}.") 880 881 def _check_for(axes, ndim): 882 for axis in axes: 883 _check(axis, ndim) 884 885 if isinstance(axes, int): 886 _check(axes, ndim) 887 axes = (axes,) 888 elif isinstance(axes, (list, tuple)): 889 _check_for(axes, ndim) 890 new_axes = () 891 for item in axes: 892 if item not in new_axes: 893 new_axes += (item,) 894 axes = new_axes 895 else: 896 raise TypeError(f"For Tensor.squeeze, the parameter 'axes' must be one of [int, tuple, list], " \ 897 f"but got {type(axes)}") 898 899 def _check_axis(s, idx, axes, ndim): 900 # if an axis is selected with shape entry greater than one, an error is raised. 901 if s != 1 and ((idx in axes) or (idx - ndim in axes)): 902 raise ValueError(f"For Tensor.squeeze, the shape of parameter 'axis' {axes} must be 1, but got {s}.") 903 904 for idx in range(ndim): 905 s = shape[idx] 906 _check_axis(s, idx, axes, ndim) 907 if s != 1 or (idx not in axes) and (idx - ndim not in axes): 908 new_shape = new_shape + (s,) 909 910 return new_shape 911 912 913def check_axis_in_range(axis, ndim): 914 """Checks axes are with the bounds of ndim""" 915 def _check(): 916 if not isinstance(axis, int): 917 raise TypeError(f'The axes must be integers, but got {type(axis)}') 918 919 if axis >= ndim or axis < -ndim: 920 raise ValueError(f"The 'axis' must be in the range of [-{ndim}, {ndim}), but got {axis}.") 921 922 _check() 923 return (axis + ndim) % ndim 924 925 926def check_axis_valid(axes, ndim): 927 """ 928 Checks axes are valid given ndim, and returns axes that can be passed 929 to the built-in operator (non-negative, int or tuple) 930 """ 931 def _check_range(axes): 932 for axis in axes: 933 check_axis_in_range(axis, ndim) 934 935 if axes is None: 936 axes = tuple(range(ndim)) 937 return axes 938 if isinstance(axes, (tuple, list)): 939 _check_range(axes) 940 tmp = () 941 for x in axes: 942 tmp = tmp + ((x + ndim) % ndim,) 943 _check_dup(tmp) 944 return tmp 945 check_axis_in_range(axes, ndim) 946 return (axes % ndim,) 947 948 949def max_(*args): 950 """Return the maximum value of the input parameter.""" 951 return max(*args) 952 953 954def min_(*args): 955 """Return the minimum value of the input parameter.""" 956 return min(*args) 957 958 959def is_stub_tensor(tensor): 960 return hasattr(tensor, "stub") 961 962 963def expanded_shape(ndim, axis_size, axis): 964 """ 965 Returns a shape with size = 1 for all dimensions 966 except at axis. 967 """ 968 return tuple(axis_size if i == axis else 1 for i in range(ndim)) 969 970 971def tuple_slice(tup, start, end): 972 """get sliced tuple from start and end.""" 973 return tup[start:end] 974 975 976def infer_out_shape(*shapes): 977 """ 978 Returns shape of output after broadcasting. Raises ValueError if shapes cannot be broadcast. 979 """ 980 def _check(items, max_size, shapes): 981 for item in items: 982 if item not in (1, max_size): 983 raise ValueError(f'For Tensor, the dimension on each axis must be 1 or the max value on the axis' \ 984 f'to support broadcasting, but got shapes {shapes,}') 985 shape_out = () 986 max_len = max([len(it) for it in shapes]) 987 for i in range(max_len): 988 items = [it[i-(max_len-len(it))] if i - (max_len - len(it)) 989 >= 0 else 1 for it in shapes] 990 max_size = 0 if 0 in items else max(items) 991 _check(items, max_size, shapes) 992 shape_out = shape_out + (max_size,) 993 return shape_out 994 995 996def check_axis_type(axis, type_int=True, type_tuple=True, type_list=True): 997 """Check axis argument type.""" 998 if type_int and isinstance(axis, int): 999 return True 1000 if (type_tuple and isinstance(axis, tuple)) or (type_list and isinstance(axis, list)): 1001 for ax in axis: 1002 if not isinstance(ax, int): 1003 raise TypeError(f"For Tensor.ptp, each axis must be integer, but got {type(ax)} in {axis}.") 1004 return True 1005 1006 type_str = "" 1007 if type_int: 1008 type_str += "int, " 1009 if type_tuple: 1010 type_str += "tuple, " 1011 if type_list: 1012 type_str += "list, " 1013 raise TypeError(f"For Tensor.ptp, the axis should be {type_str}, but got {type(axis)}.") 1014 1015 1016def check_and_canonicalize_axes(axes, ndim): 1017 """Check whether the types and values of input axes are valid.""" 1018 def _check(axes, ax, ndim): 1019 if not isinstance(ax, int): 1020 raise TypeError(f"Each axis should be integer, but got {type(ax)} in {axes}.") 1021 if ax >= ndim or ax < -ndim: 1022 raise ValueError(f"The 'axis' must be in the range of [-{ndim}, {ndim}), but got {ax}.") 1023 1024 axes = axes if isinstance(axes, tuple) else (axes,) 1025 new_axes = () 1026 for ax in axes: 1027 _check(axes, ax, ndim) 1028 ax = ax if ax >= 0 else ax + ndim 1029 new_axes += (ax,) 1030 _check_dup(new_axes) 1031 return new_axes 1032 1033 1034def check_type_support(dtype, device, supported_dtypes): 1035 """Checks whether the data type is supported.""" 1036 return dtype in supported_dtypes or not context.get_context('device_target') == device 1037 1038 1039def check_sparse_tensor_input(indices, values, shape): 1040 """Common input check for SparseTensors.""" 1041 if not isinstance(indices, Tensor_) and not is_stub_tensor(indices): 1042 raise TypeError(f"For SparseTensors, 'indices' must be Tensor, but got {type(indices)}.") 1043 if not isinstance(values, Tensor_) and not is_stub_tensor(values): 1044 raise TypeError(f"For SparseTensors, 'values' must be Tensor, but got {type(values)}.") 1045 if not isinstance(shape, tuple): 1046 raise TypeError(f"For SparseTensors, 'shape' must be tuple, but got {type(shape)}.") 1047 1048 1049def check_csr_tensor_input(indptr, indices, values, shape): 1050 """Checks inputs type for CSRTensor.""" 1051 if not isinstance(indptr, Tensor_) and not is_stub_tensor(indptr): 1052 raise TypeError(f"For CSRTensor, 'indptr' must be Tensor, but got {type(indptr)}.") 1053 check_sparse_tensor_input(indices, values, shape) 1054 1055 1056def check_csr_tensor_shape(indptr_shp, indices_shp, values_shp, csr_shp): 1057 """Checks input tensors' shapes for CSRTensor.""" 1058 # Support empty sparse tensor 1059 if (indptr_shp == (0,)) and (indices_shp == (0,)) and (values_shp == (0,)): 1060 return 1061 shape_size = 1 1062 val_shp_size = 1 1063 for item in csr_shp: 1064 if item <= 0: 1065 raise ValueError(f"For CSRTensor, the element of shape must be positive, but got {item}") 1066 if not isinstance(item, int): 1067 raise TypeError(f"For CSRTensor, the element type of shape must be int, but got {type(item)}") 1068 shape_size *= item 1069 for item in values_shp: 1070 if item <= 0: 1071 raise ValueError(f"The element of shape must be positive, but got {item}") 1072 val_shp_size *= item 1073 if shape_size < val_shp_size: 1074 raise ValueError(f"Shape total size: {shape_size} is too small to hold {val_shp_size} non-zero values.") 1075 if len(indices_shp) != 1: 1076 raise ValueError(f"For CSRTensor, indices must be a 1-dimensional tensor, " \ 1077 f"but got a {len(indices_shp)} dimension tensor.") 1078 if len(indptr_shp) != 1: 1079 raise ValueError(f"For CSRTensor, indptr must be a 1-dimensional tensor, " \ 1080 f"but got a {len(indptr_shp)} dimension tensor.") 1081 if csr_shp[0] + 1 != indptr_shp[0]: 1082 raise ValueError(f"For CSRTensor, indptr must have length (1 + shape[0]), " \ 1083 f"but got: {indptr_shp[0]}") 1084 if indices_shp[0] != values_shp[0]: 1085 err_msg1 = "For CSRTensor, indices and values must equal in their shape, " 1086 err_msg2 = f"but got indices shape: {indices_shp[0]}, values shape: {values_shp[0]}." 1087 raise ValueError(err_msg1 + err_msg2) 1088 if len(values_shp) + 1 != len(csr_shp): 1089 raise ValueError(f"Values' dimension should equal to CSRTensor's dimension - 1, but got" \ 1090 f"Values' dimension: {len(values_shp)} , CSRTensor's dimension: " \ 1091 f"{len(csr_shp)}") 1092 if values_shp[1:] != csr_shp[2:]: 1093 raise ValueError(f"CSRTensor's shape[2: ] must be equal to value's shape[1: ]," \ 1094 f"but CSRTensor's shape[2: ] got: {csr_shp[2: ]} and value's shape[1: ]" \ 1095 f"got: {values_shp[1: ]}") 1096 1097 1098def check_csr_tensor_dtype(indptr_dtype, indices_dtype): 1099 """Checks input tensors' data types for CSRTensor.""" 1100 if indptr_dtype not in (mstype.int16, mstype.int32, mstype.int64): 1101 raise TypeError(f"For CSRTensor, indptr must have int16 or int32 or int64 data type, " \ 1102 f"but got {indptr_dtype}.") 1103 if indices_dtype not in (mstype.int16, mstype.int32, mstype.int64): 1104 raise TypeError(f"For CSRTensor, indices must have int16 or int32 or int64 data type, " \ 1105 f"but got {indices_dtype}.") 1106 1107 1108def check_coo_tensor_input(indices, values, shape): 1109 """Checks inputs type for COOTensor.""" 1110 check_sparse_tensor_input(indices, values, shape) 1111 1112 1113def check_coo_tensor_shape(indices_shp, values_shp, coo_shp): 1114 """Checks input tensors' shapes for COOTensor.""" 1115 if len(coo_shp) != 2: 1116 raise ValueError(f"For COOTensor, the length of 'shape' must be 2, but got {coo_shp}.") 1117 if (indices_shp == (0,)) and (values_shp == (0,)): 1118 return 1119 shp_mul = 1 1120 for sh in coo_shp: 1121 if sh <= 0: 1122 raise ValueError(f"For COOTensor, the element of 'shape' must be positive, but got {sh} in {coo_shp}.") 1123 if not isinstance(sh, int): 1124 raise TypeError(f"For COOTensor, the element type of 'shape' must be int, but got {type(sh)}") 1125 shp_mul *= sh 1126 if shp_mul < values_shp[0]: 1127 raise ValueError(f"For COOTensor, shape is too small: ({shp_mul}) to hold all values({values_shp[0]}).") 1128 if len(indices_shp) != 2: 1129 raise ValueError(f"For COOTensor, 'indices' must be a 2-dimensional tensor, but got a {len(indices_shp)}" \ 1130 f"-dimensional tensor.") 1131 if len(values_shp) != 1: 1132 raise ValueError(f"For COOTensor, 'values' must be a 1-dimensional tensor, but got a {len(values_shp)}" \ 1133 f"-dimensional tensor.") 1134 if indices_shp[0] != values_shp[0]: 1135 raise ValueError(f"For COOTensor, 'indices.shape[0]' must be euqal to 'values.shape[0]', but got " \ 1136 f"'indices.shape[0]' = {indices_shp[0]} and 'values.shape[0]' = {values_shp[0]}.") 1137 if indices_shp[1] != 2: 1138 raise ValueError(f"For COOTensor, 'indices.shape[1]' must be 2, but got {indices_shp[1]}.") 1139 1140 1141def check_coo_tensor_dtype(indices_dtype): 1142 """Checks input tensors' data types for COOTensor.""" 1143 if indices_dtype not in (mstype.int16, mstype.int32, mstype.int64): 1144 raise TypeError(f"For COOTensor, the type of 'indices' must be one of [int16, int32, int64], but got " \ 1145 f"{indices_dtype}.") 1146 1147 1148def check_element_type_of_iterable(arg_name, arg_value, valid_types, prim_name=None): 1149 """Check type of the element of a iterabel object, except dict.""" 1150 check_value_type(arg_name, arg_value, [list, tuple], prim_name) 1151 type_names = [t.__name__ if hasattr(t, '__name__') else str(t) for t in valid_types] 1152 num_types = len(valid_types) 1153 msg_prefix = f"For '{prim_name}', the" if prim_name else "The" 1154 for element in arg_value: 1155 if not isinstance(element, tuple(valid_types)): 1156 raise TypeError(f"{msg_prefix} type of '{arg_name}' should be {'one of ' if num_types > 1 else ''}" \ 1157 f"{type_names if num_types > 1 else type_names[0]}, " \ 1158 f"but got '{element}' with type '{type(element).__name__}'.") 1159 1160 1161def check_element_type_of_dict(arg_name, arg_value, key_types, value_types, prim_name=None): 1162 """Check the type of key and value of a dict.""" 1163 check_value_type(arg_name, arg_value, [dict], prim_name) 1164 msg_prefix = f"For '{prim_name}', the" if prim_name else "The" 1165 type_names = [t.__name__ if hasattr(t, '__name__') else str(t) for t in key_types] 1166 num_types = len(key_types) 1167 for element in arg_value.keys(): 1168 if not isinstance(element, tuple(key_types)): 1169 raise TypeError(f"{msg_prefix} type of '{arg_name}' should be {'one of ' if num_types > 1 else ''}" \ 1170 f"{type_names if num_types > 1 else type_names[0]}, " \ 1171 f"but got '{element}' with type '{type(element).__name__}'.") 1172 1173 type_names = [t.__name__ if hasattr(t, '__name__') else str(t) for t in value_types] 1174 num_types = len(value_types) 1175 for element in arg_value.values(): 1176 if not isinstance(element, tuple(value_types)): 1177 raise TypeError(f"{msg_prefix} type of '{arg_name}' should be {'one of ' if num_types > 1 else ''}" \ 1178 f"{type_names if num_types > 1 else type_names[0]}, " \ 1179 f"but got '{element}' with type '{type(element).__name__}'.") 1180 1181 1182def check_size_and_element_type_of_tuple(arg_name, arg_value, expect_size, expect_element_type, prim_name=None): 1183 """Check the size and element type of a tuple.""" 1184 check_value_type(arg_name, arg_value, [tuple], prim_name) 1185 check_equal_int(len(arg_value), expect_size, arg_name + ' size', prim_name) 1186 check_element_type_of_iterable('arg_name', arg_value, [expect_element_type], prim_name) 1187 1188 1189def _check_symbol(dyn_input, net_input, index, symbolic_shape_data): 1190 """Check symbolic shape values.""" 1191 actual_shape = net_input.shape 1192 for i, sym in enumerate(dyn_input.symbolic_shape): 1193 # the Symbol is converted to dict 1194 if not isinstance(sym, dict): 1195 continue 1196 # the value of symbols with same "id" should be equal. 1197 if "id" in sym: 1198 sym_id = sym["id"] 1199 k_idval = "unique_id_value_map" 1200 if k_idval not in symbolic_shape_data: 1201 symbolic_shape_data[k_idval] = {} 1202 unique_id_value = symbolic_shape_data[k_idval] 1203 if sym_id not in unique_id_value: 1204 unique_id_value[sym_id] = actual_shape[i] 1205 elif unique_id_value[sym_id] != actual_shape[i]: 1206 raise ValueError( 1207 f"The {i + 1}th shape value of {index + 1}th actual input args is a unique symbol, all values must " 1208 f"be the same. The previous value is {unique_id_value[sym_id]}, but the current value is " 1209 f"{actual_shape[i]}. Actual shape: {actual_shape}, axis: {i}.") 1210 # check the value in range [min, max]. 1211 if "min" in sym and actual_shape[i] < sym["min"]: 1212 raise ValueError( 1213 f"The {i + 1}th shape value of {index + 1}th actual input args must be greater than or equal to the " 1214 f"'min' value '{sym['min']}' of `Symbol`, but got '{actual_shape[i]}'. Actual shape: {actual_shape}, " 1215 f"axis: {i}.") 1216 if "max" in sym and actual_shape[i] > sym["max"]: 1217 raise ValueError( 1218 f"The {i + 1}th shape value of {index + 1}th actual input args must be less than or equal to the " 1219 f"'max' value '{sym['max']}' of `Symbol`, but got '{actual_shape[i]}'. Actual shape: {actual_shape}, " 1220 f"axis: {i}.") 1221 # check the shape item that satisfies the "divisor * N + remainder, N >= 1". 1222 d = sym.get("divisor", 1) 1223 r = sym.get("remainder", 0) 1224 if actual_shape[i] < d or actual_shape[i] % d != r: 1225 raise ValueError( 1226 f"The {i + 1}th shape value of {index + 1}th actual input args must be match the 'divisor'(d) and " 1227 f"'remainder'(r) of `Symbol`. The value should be 'd * N + r' for 'N > 0', got d={d} and r={r}, but " 1228 f"actual shape value is '{actual_shape[i]}'. Actual shape: {actual_shape}, axis: {i}") 1229 1230 1231def check_symbolic_shape(dynamic_inputs, actual_inputs): 1232 """Check the symboic shape""" 1233 symbolic_shape_data = {} 1234 1235 def run_check(dyn_inputs, net_inputs): 1236 """the real checking function""" 1237 for index, (dyn_input, net_input) in enumerate(zip(dyn_inputs, net_inputs)): 1238 if isinstance(dyn_input, (tuple, list)): 1239 run_check(dyn_input, net_input) 1240 elif hasattr(dyn_input, "symbolic_shape"): 1241 _check_symbol(dyn_input, net_input, index, symbolic_shape_data) 1242 1243 run_check(dynamic_inputs, actual_inputs) 1244 1245 1246def check_input_format(input_param): 1247 """Judge input format.""" 1248 if input_param == "NCHW": 1249 return input_param 1250 raise ValueError(f"The data format must be NCHW, but got {input_param}.") 1251 1252 1253def _expand_tuple(n_dimensions): 1254 """To expand an int number to tuple.""" 1255 1256 def convert(m): 1257 if not isinstance(m, tuple): 1258 if isinstance(m, int) and not isinstance(m, bool): 1259 return tuple(repeat(m, n_dimensions)) 1260 raise TypeError(f"When expanding an int number to tuple, input type must be integer or tuple[int], " \ 1261 f"but got {type(m)}") 1262 1263 if not len(m) is n_dimensions: 1264 raise TypeError(f"When expanding an int number to tuple, input tuple dimension must be {n_dimensions}, " \ 1265 f"but got {m}") 1266 1267 for i in m: 1268 if not isinstance(i, int) or isinstance(i, bool): 1269 raise TypeError(f"When expanding an int number to tuple, " \ 1270 f"the type of element in input tuple must be an integer, but got {type(i)}.") 1271 return m 1272 1273 return convert 1274 1275 1276def _check_data_type_valid(data, valid_type): 1277 """Check data type valid.""" 1278 if valid_type is None: 1279 return data is None 1280 if isinstance(data, valid_type): 1281 if hasattr(data, 'size') and data.size == 0: 1282 msg = "The input data can not be empty." 1283 logger.critical(msg) 1284 raise ValueError(msg) 1285 return True 1286 return False 1287 1288 1289def check_input_data(*data, data_class): 1290 """Input data check.""" 1291 for item in data: 1292 if isinstance(item, (list, tuple)): 1293 for v in item: 1294 check_input_data(v, data_class=data_class) 1295 elif isinstance(item, dict): 1296 for v in item.values(): 1297 check_input_data(v, data_class=data_class) 1298 else: 1299 if isinstance(data_class, (tuple, list)): 1300 ret = True in tuple(_check_data_type_valid(item, data_type) for data_type in data_class) 1301 else: 1302 ret = _check_data_type_valid(item, data_class) 1303 if not ret: 1304 data_class_str = tuple(i.__name__ if hasattr(i, '__name__') else i for i in data_class) if isinstance( 1305 data_class, (tuple, list)) else (data_class if data_class is None else data_class.__name__) 1306 raise TypeError(f'The types of input data must be in the Union({data_class_str}, ' \ 1307 f'tuple[{data_class_str}], list[{data_class_str}], dict[{data_class_str}]), ' \ 1308 f'but got type {item if item is None else type(item).__name__}.') 1309 1310 1311def check_input_dataset(*dataset, dataset_type): 1312 """Input dataset check.""" 1313 if not dataset: 1314 return False 1315 for item in dataset: 1316 if not isinstance(item, dataset_type): 1317 return False 1318 return True 1319 1320 1321def check_output_data(data): 1322 """Output data check.""" 1323 if data is None: 1324 raise RuntimeError('The output data can not be None, please check your net or input data.') 1325 1326 1327once = _expand_tuple(1) 1328twice = _expand_tuple(2) 1329triple = _expand_tuple(3) 1330 1331 1332def args_type_check(*type_args, **type_kwargs): 1333 """Check whether input data type is correct.""" 1334 1335 def type_check(func): 1336 sig = inspect.signature(func) 1337 bound_types = sig.bind_partial(*type_args, **type_kwargs).arguments 1338 1339 @wraps(func) 1340 def wrapper(*args, **kwargs): 1341 nonlocal bound_types 1342 bound_values = sig.bind(*args, **kwargs) 1343 argument_dict = bound_values.arguments 1344 if "kwargs" in bound_types: 1345 bound_types = bound_types["kwargs"] 1346 if "kwargs" in argument_dict: 1347 argument_dict = argument_dict["kwargs"] 1348 for name, value in argument_dict.items(): 1349 if name in bound_types: 1350 if value is not None and not isinstance(value, bound_types[name]): 1351 raise TypeError(f"The parameter '{name}' must be {bound_types[name]}, but got {type(value)}") 1352 return func(*args, **kwargs) 1353 1354 return wrapper 1355 1356 return type_check 1357 1358 1359def check_hook_fn(hook_type, hook_fn): 1360 """Check hook fn""" 1361 if context.get_context("mode") != context.PYNATIVE_MODE: 1362 logger.warning(f"'{hook_type}' function is only supported in pynative mode, you can use " 1363 f"context.set_context to set pynative mode.") 1364 return False 1365 1366 if not isinstance(hook_fn, (FunctionType, MethodType)): 1367 raise TypeError(f"When using 'hook_type(hook_fn)', the type of 'hook_fn' must be python " 1368 f"function, but got {type(hook_fn)}.") 1369 1370 if hook_fn.__code__.co_name == "staging_specialize": 1371 raise TypeError(f"Decorating hook function {hook_fn.__name__} with '@jit' is not supported.") 1372 1373 if hook_type == "register_hook" and hook_fn.__code__.co_argcount != 1: 1374 raise TypeError(f"Tensor hook function {hook_fn.__name__} arg num is not equal to 1.") 1375 1376 return True 1377 1378_set_record = {} 1379