• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020-2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Check parameters."""
16from __future__ import absolute_import
17
18import re
19import inspect
20import math
21from types import FunctionType, MethodType
22from functools import reduce, wraps
23from itertools import repeat
24from collections.abc import Iterable
25import numpy as np
26
27from mindspore import context
28from mindspore import log as logger
29from mindspore.common import dtype as mstype
30from mindspore._c_expression import Tensor as Tensor_
31
32
# Binary relation codes, interpreted by _check_binary_rel and _format_str_one_value.
EQ, NE, LT, LE, GT, GE = range(1, 7)  # ==, !=, <, <=, >, >=
# Scalar range codes, interpreted by _check_inc_rel and _format_str_two_value.
INC_NEITHER, INC_LEFT, INC_RIGHT, INC_BOTH = range(7, 11)  # (), [), (], []
# Collection membership codes: `in` / `not in`.
IN, NOT_IN = 11, 12
47
48
def _check_binary_rel(val1, val2, rel):
    """
    Evaluate the binary relation ``val1 <rel> val2``.

    `rel` is one of EQ, NE, LT, LE, GT, GE, IN, NOT_IN; any other code yields False.
    """
    comparisons = (
        (EQ, lambda left, right: left == right),
        (NE, lambda left, right: left != right),
        (LT, lambda left, right: left < right),
        (LE, lambda left, right: left <= right),
        (GT, lambda left, right: left > right),
        (GE, lambda left, right: left >= right),
        (IN, lambda left, right: left in right),
        (NOT_IN, lambda left, right: left not in right),
    )
    for code, compare in comparisons:
        if rel == code:
            return compare(val1, val2)
    return False
69
70
def _check_inc_rel(val, lower, upper, rel):
    """
    Check whether `val` lies in the interval between `lower` and `upper`.

    `rel` selects which endpoints are included:
    INC_NEITHER -> (lower, upper), INC_LEFT -> [lower, upper),
    INC_RIGHT -> (lower, upper], INC_BOTH -> [lower, upper].
    Any other code yields False.

    The bounds are tested with positive chained comparisons. A negated form such as
    ``not (val <= lower or val >= upper)`` would evaluate to True for a value that is
    unordered relative to the bounds (e.g. NaN), letting it pass every range check.
    """
    if rel == INC_NEITHER:
        return bool(lower < val < upper)
    if rel == INC_LEFT:
        return bool(lower <= val < upper)
    if rel == INC_RIGHT:
        return bool(lower < val <= upper)
    if rel == INC_BOTH:
        return bool(lower <= val <= upper)

    return False
83
84
def _format_str_one_value(value, rel):
    """
    Render the relation `rel` applied to `value` for use in error messages, e.g. ``">= 0"``.

    Returns an empty string for an unknown relation code.
    """
    symbols = (
        (EQ, "="),
        (NE, "!="),
        (LT, "<"),
        (LE, "<="),
        (GT, ">"),
        (GE, ">="),
        (IN, "in"),
        (NOT_IN, "not in"),
    )
    for code, symbol in symbols:
        if rel == code:
            return f"{symbol} {value}"
    return ""
105
106
def _format_str_two_value(val1, val2, rel):
    """
    Render the interval selected by `rel` in bracket notation, e.g. ``"[0, 1)"``.

    Returns an empty string for an unknown range code.
    """
    brackets = (
        (INC_NEITHER, "(", ")"),
        (INC_LEFT, "[", ")"),
        (INC_RIGHT, "(", "]"),
        (INC_BOTH, "[", "]"),
    )
    for code, opening, closing in brackets:
        if rel == code:
            return f"{opening}{val1}, {val2}{closing}"
    return ""
119
120
def _check_3d_int_or_tuple(arg_name, arg_value, prim_name, allow_five=False, ret_five=False,
                           greater_zero=True, third_one=False, three_input=False):
    """
    Checks whether an argument is a positive int or tuple with 3 or 5(when allow_five is True) positive int elements.

    Args:
        arg_name (str): Argument name used in error messages.
        arg_value (Union[int, tuple]): An int (repeated into every element) or a tuple of 3 or 5 elements.
        prim_name (str): Operator name used in error messages.
        allow_five (bool): Whether a 5-element tuple is accepted. Default: False.
        ret_five (bool): If True, return 5 elements (int and 3-element forms are prefixed with ``(1, 1)``);
            otherwise return 3 elements (a 5-element input keeps only its last three). Default: False.
        greater_zero (bool): If True every returned element must be > 0, otherwise >= 0. Default: True.
        third_one (bool): If True, the third-from-last returned element must be 1
            (the error message calls it the "depth"). Default: False.
        three_input (bool): If True, a tuple input must have exactly 3 elements. Default: False.

    Returns:
        tuple, the normalized 3- or 5-element tuple.

    Raises:
        TypeError: If `arg_value` is not an int or a tuple (raised by `check_value_type`, which
            also rejects bool here because bool is not listed).
        ValueError: If the tuple length is not allowed, a returned element is not a non-bool int
            of the required sign, or the `third_one` requirement fails.

    Note:
        Only the returned elements are sign-checked: when a 5-element tuple is reduced to
        3 elements (ret_five=False), its first two elements are discarded unvalidated.
    """

    def _raise_message(third_one_flag=False, three_input_flag=False):
        # `ret_value` here is the enclosing function's variable; the third_one_flag branch is only
        # reached from _check_third_one, after `ret_value` has been assigned below.
        if third_one_flag:
            raise ValueError(f"For '{prim_name}', the depth of parameter '{arg_name}' must be 1, " \
                             f"but got {ret_value[-3]}.")
        if three_input_flag:
            raise ValueError(f"For '{prim_name}', the parameter '{arg_name}' must be an positive integer " \
                             f"or a tuple of three positive integer, but got {arg_value}.")
        raise ValueError(f"For '{prim_name}', the parameter '{arg_name}' must be an positive integer or " \
                         f"a tuple of three {'or five ' if allow_five else ''}positive integer, but got {arg_value}")

    def _get_return_value():
        def _check():
            # Tuple lengths: 5 is allowed only with allow_five, otherwise exactly 3 is required.
            if not isinstance(arg_value, int):
                if len(arg_value) == 5:
                    if not allow_five:
                        _raise_message()
                elif not len(arg_value) == 3:
                    _raise_message()

        _check()
        # Normalize to the requested length (5 with ret_five, else 3).
        if isinstance(arg_value, int):
            ret = (1, 1, arg_value, arg_value, arg_value) if ret_five else (arg_value, arg_value, arg_value)
        elif len(arg_value) == 3:
            ret = (1, 1, arg_value[0], arg_value[1], arg_value[2]) if ret_five else arg_value
        else: # case: len(arg_value) == 5
            ret = arg_value if ret_five else (arg_value[2], arg_value[3], arg_value[4])

        return ret

    def _check_value(ret_value):
        # Every element must be a non-bool int: > 0 when greater_zero, >= 0 otherwise.
        for item in ret_value:
            if isinstance(item, int) and not isinstance(item, bool):
                if greater_zero and item > 0:
                    continue
                if not greater_zero and item >= 0:
                    continue
            _raise_message()

    def _check_third_one(ret_value):
        if third_one:
            if ret_value[-3] != 1:
                _raise_message(third_one_flag=third_one)

    check_value_type(arg_name, arg_value, (int, tuple), prim_name)
    if three_input and isinstance(arg_value, tuple):
        if len(arg_value) != 3:
            _raise_message(three_input_flag=three_input)
    ret_value = _get_return_value()
    _check_value(ret_value)
    _check_third_one(ret_value)

    return tuple(ret_value)
179
180
181def _check_dup(axes):
182    for item in axes:
183        count = 0
184        for item2 in axes:
185            if item == item2:
186                count += 1
187
188        if count > 1:
189            raise ValueError(f"The element of parameter 'axis' can not be duplicate, but got {axes}.")
190
191
192def _check_number(arg_value, value, rel, arg_type=int, arg_name=None, prim_name=None):
193    """
194    Check argument integer.
195
196    Usage:
197    - arg_value = _check_number(arg_value, 2, GT, int, "value", None)
198    """
199    prim_name = f"For \'{prim_name}\', the " if prim_name else 'The '
200    arg_name = f"\'{arg_name}\'" if arg_name else 'input value'
201
202    def _check_param():
203        prim_info = f'{prim_name}' + f'{arg_name}'
204        if isinstance(arg_value, arg_type):
205            if math.isinf(arg_value) or math.isnan(arg_value) or np.isinf(arg_value) or np.isnan(arg_value):
206                raise ValueError(f"{prim_info} must be a legal value, but got '{arg_value}'.")
207        else:
208            raise TypeError(f"{prim_info} must be {arg_type.__name__}, but got '{type(arg_value).__name__}'")
209
210        type_mismatch = not isinstance(arg_value, arg_type) or isinstance(arg_value, bool)
211        rel_ret = _check_binary_rel(arg_value, value, rel)
212        if type_mismatch or not rel_ret:
213            rel_str = _format_str_one_value(value, rel)
214            msg = f"{prim_info} must be {arg_type.__name__} and must {rel_str}, " \
215                  f"but got '{arg_value}' with type '{type(arg_value).__name__}'."
216            if type_mismatch:
217                raise TypeError(msg)
218            raise ValueError(msg)
219
220    _check_param()
221    return arg_value
222
223
def check_is_number(arg_value, arg_type, arg_name=None, prim_name=None):
    """
    Check that `arg_value` is an `arg_type` instance (bool excluded) and not inf/NaN.

    Integers skip the inf/NaN test: an int can never be inf or NaN, and converting a very
    large int to float for that test would raise OverflowError.

    Args:
        arg_value: The value to check.
        arg_type (type): The required type, e.g. `int` or `float`.
        arg_name (str): Argument name used in error messages. Default: None.
        prim_name (str): Operator name used in error messages. Default: None.

    Returns:
        The unchanged `arg_value`.

    Raises:
        TypeError: If `arg_value` is not an instance of `arg_type`, or is a bool.
        ValueError: If a non-integer `arg_value` is inf or NaN.

    Usage:
    - number = check_is_number(number, int)
    - number = check_is_number(number, int, "bias")
    - number = check_is_number(number, int, "bias", "bias_class")
    """
    prim_name = f"For \'{prim_name}\', the" if prim_name else 'The'
    arg_name = f"\'{arg_name}\'" if arg_name else 'input value'

    if not isinstance(arg_value, arg_type) or isinstance(arg_value, bool):
        raise TypeError(f"{prim_name} type of {arg_name} must be '{arg_type.__name__}', " \
                        f"but got '{type(arg_value).__name__}'.")
    if not isinstance(arg_value, int) and (math.isinf(arg_value) or math.isnan(arg_value)):
        raise ValueError(f"{prim_name} {arg_name} must be a legal float, but got '{arg_value}'.")
    return arg_value
245
246
def check_number_range(arg_value, lower_limit, upper_limit, rel, value_type, arg_name=None, prim_name=None):
    """
    Check whether a numeric value lies in the range bounded by `lower_limit` and `upper_limit`.

    Args:
        arg_value: The value to check. Instances of `value_type`, `np.ndarray` and `np.generic`
            are accepted; bool is always rejected.
        lower_limit: Lower bound of the range.
        upper_limit: Upper bound of the range.
        rel: Which endpoints are included: INC_NEITHER -> (lower, upper), INC_LEFT -> [lower, upper),
            INC_RIGHT -> (lower, upper], INC_BOTH -> [lower, upper].
        value_type (type): The expected Python type, e.g. `int` or `float`.
        arg_name (str): Argument name used in error messages. Default: None.
        prim_name (str): Operator name used in error messages. Default: None.

    Returns:
        The unchanged `arg_value`.

    Raises:
        TypeError: If `arg_value` has an unsupported type or is a bool.
        ValueError: If `_check_inc_rel` rejects the value for the selected range.

    Usage:
    - number = check_number_range(number, 0.0, 1.0, INC_NEITHER, float, "number") # number in (0.0, 1.0)
    - number = check_number_range(number, 0, 1, INC_BOTH, int, "number") # number in [0, 1]
    """
    prim_name = f"For \'{prim_name}\', the" if prim_name else 'The'
    arg_name = f"\'{arg_name}\'" if arg_name else 'input value'

    type_mismatch = not isinstance(arg_value, (np.ndarray, np.generic, value_type)) or isinstance(arg_value, bool)
    if type_mismatch:
        raise TypeError(f"{prim_name} {arg_name} must be '{value_type.__name__}',  " \
                        f"but got '{type(arg_value).__name__}'.")

    if not _check_inc_rel(arg_value, lower_limit, upper_limit, rel):
        rel_str = _format_str_two_value(lower_limit, upper_limit, rel)
        raise ValueError(f"{prim_name} {arg_name} must be in range of {rel_str}, " \
                         f"but got {arg_value} with type '{type(arg_value).__name__}'.")
    return arg_value
270
271
def check(arg_name, arg_value, value_name, value, rel=EQ, prim_name=None, excp_cls=ValueError):
    """
    Check that ``arg_value <rel> value`` holds and return `arg_value`.

    Meant for ints or lists/tuples of ints. Floats are compared exactly, with no tolerance,
    so this is not suitable for them.

    Raises:
        excp_cls: If the relation does not hold (ValueError by default). `arg_name` is quoted
            in the message unless it contains a space.
    """
    if _check_binary_rel(arg_value, value, rel):
        return arg_value
    relation = _format_str_one_value(f'{value_name}: {value}', rel)
    prefix = f'For \'{prim_name}\', the' if prim_name else "The"
    if " " in arg_name:
        subject = f"{prefix} {arg_name}"
    else:
        subject = f"{prefix} \'{arg_name}\'"
    raise excp_cls(f'{subject} should be {relation}, but got {arg_value}.')
286
287
def check_int(arg_value, value, rel, arg_name=None, prim_name=None):
    """
    Validate that the int `arg_value` satisfies ``arg_value <rel> value`` and return it.

    Delegates to `_check_number` with `arg_type=int`.

    Usage:
    - number = check_int(number, 0, GE, "number", None) # number >= 0
    """
    return _check_number(arg_value, value, rel, arg_type=int, arg_name=arg_name, prim_name=prim_name)
296
297
def check_is_int(arg_value, arg_name=None, prim_name=None):
    """
    Check that the input value is an int (bool excluded) and return it.

    Delegates to `check_is_number` with `arg_type=int`.

    Usage:
    - number = check_is_int(number)
    - number = check_is_int(number, "bias")
    - number = check_is_int(number, "bias", "bias_class")
    """
    return check_is_number(arg_value, int, arg_name, prim_name)
308
309
def check_equal_int(arg_value, value, arg_name=None, prim_name=None):
    """
    Validate that the int `arg_value` equals `value` and return it.

    Delegates to `_check_number` with `rel=EQ` and `arg_type=int`.

    Usage:
    - number = check_equal_int(number, 0, "number", None) # number == 0
    """
    return _check_number(arg_value, value, EQ, arg_type=int, arg_name=arg_name, prim_name=prim_name)
318
319
def check_positive_int(arg_value, arg_name=None, prim_name=None):
    """
    Validate that `arg_value` is a positive int (arg_value > 0) and return it.

    Usage:
    - number = check_positive_int(number)
    - number = check_positive_int(number, "bias")
    """
    return _check_number(arg_value, 0, GT, arg_type=int, arg_name=arg_name, prim_name=prim_name)
329
330
def check_positive_int_sequence(sequence, arg_name=None, prim_name=None):
    """
    Validate that every element of `sequence` is an int greater than 0 and return `sequence`.

    Each element is reported as ``<arg_name>[<index>]`` in error messages (the literal
    ``arg_name[<index>]`` when `arg_name` is not given).

    Usage:
    - sequence = check_positive_int_sequence(sequence)
    - sequence = check_positive_int_sequence(sequence, "dims")
    """
    base_name = arg_name if arg_name else 'arg_name'
    for index, element in enumerate(sequence):
        _check_number(element, 0, GT, int, f"{base_name}[{index}]", prim_name)
    return sequence
344
345
def check_negative_int(arg_value, arg_name=None, prim_name=None):
    """
    Validate that `arg_value` is a negative int (arg_value < 0) and return it.

    Usage:
    - number = check_negative_int(number)
    - number = check_negative_int(number, "bias")
    """
    return _check_number(arg_value, 0, LT, arg_type=int, arg_name=arg_name, prim_name=prim_name)
355
356
def check_non_positive_int(arg_value, arg_name=None, prim_name=None):
    """
    Check that the argument is a non-positive integer, i.e. arg_value <= 0.

    Usage:
    - number = check_non_positive_int(number)
    - number = check_non_positive_int(number, "bias")
    """
    return _check_number(arg_value, 0, LE, int, arg_name, prim_name)
366
367
def check_non_negative_int(arg_value, arg_name=None, prim_name=None):
    """
    Validate that `arg_value` is a non-negative int (arg_value >= 0) and return it.

    Usage:
    - number = check_non_negative_int(number)
    - number = check_non_negative_int(number, "bias")
    """
    return _check_number(arg_value, 0, GE, arg_type=int, arg_name=arg_name, prim_name=prim_name)
377
378
def check_non_negative_int_sequence(sequence, arg_name=None, prim_name=None):
    """
    Check that every element of `sequence` is a non-negative int (element >= 0).

    Each element is reported as ``<arg_name>[<index>]`` in error messages (the literal
    ``arg_name[<index>]`` when `arg_name` is not given).

    Returns:
        The unchanged `sequence`.

    Usage:
    - sequence = check_non_negative_int_sequence(sequence)
    - sequence = check_non_negative_int_sequence(sequence, "dims")
    """
    base_name = arg_name if arg_name else 'arg_name'
    for idx, element in enumerate(sequence):
        _check_number(element, 0, GE, int, f"{base_name}[{idx}]", prim_name)
    return sequence
392
393
def check_float(arg_value, value, rel, arg_name=None, prim_name=None):
    """
    Validate that the float `arg_value` satisfies ``arg_value <rel> value`` and return it.

    Delegates to `_check_number` with `arg_type=float`.

    Usage:
    - number = check_float(number, 0.0, GE, "number", None) # number >= 0
    """
    return _check_number(arg_value, value, rel, arg_type=float, arg_name=arg_name, prim_name=prim_name)
402
403
def check_is_float(arg_value, arg_name=None, prim_name=None):
    """
    Validate that the input value is a float and return it.

    Thin wrapper over `check_is_number` with `arg_type=float`.

    Usage:
    - number = check_is_float(number)
    - number = check_is_float(number, "bias")
    - number = check_is_float(number, "bias", "bias_class")
    """
    return check_is_number(arg_value, arg_type=float, arg_name=arg_name, prim_name=prim_name)
414
415
def check_positive_float(arg_value, arg_name=None, prim_name=None):
    """
    Validate that `arg_value` is a positive float (arg_value > 0) and return it.

    Usage:
    - number = check_positive_float(number)
    - number = check_positive_float(number, "bias")
    - number = check_positive_float(number, "bias", "bias_class")
    """
    return _check_number(arg_value, 0, GT, arg_type=float, arg_name=arg_name, prim_name=prim_name)
426
427
def check_positive_float_sequence(sequence, arg_name=None, prim_name=None):
    """
    Validate that every element of `sequence` is a float greater than 0 and return `sequence`.

    Each element is reported as ``<arg_name>[<index>]`` in error messages (the literal
    ``arg_name[<index>]`` when `arg_name` is not given).

    Usage:
    - sequence = check_positive_float_sequence(sequence)
    - sequence = check_positive_float_sequence(sequence, "dims")
    """
    base_name = arg_name if arg_name else 'arg_name'
    for index, element in enumerate(sequence):
        _check_number(element, 0, GT, float, f"{base_name}[{index}]", prim_name)
    return sequence
441
442
def check_negative_float(arg_value, arg_name=None, prim_name=None):
    """
    Validate that `arg_value` is a negative float (arg_value < 0) and return it.

    Usage:
    - number = check_negative_float(number)
    - number = check_negative_float(number, "bias")
    """
    return _check_number(arg_value, 0, LT, arg_type=float, arg_name=arg_name, prim_name=prim_name)
452
453
def check_non_positive_float(arg_value, arg_name=None, prim_name=None):
    """
    Check that the argument is a non-positive float, i.e. arg_value <= 0.

    Usage:
    - number = check_non_positive_float(number)
    - number = check_non_positive_float(number, "bias")
    """
    return _check_number(arg_value, 0, LE, float, arg_name, prim_name)
463
464
def check_non_negative_float(arg_value, arg_name=None, prim_name=None):
    """
    Validate that `arg_value` is a non-negative float (arg_value >= 0) and return it.

    Usage:
    - number = check_non_negative_float(number)
    - number = check_non_negative_float(number, "bias")
    """
    return _check_number(arg_value, 0, GE, arg_type=float, arg_name=arg_name, prim_name=prim_name)
474
475
def check_number(arg_name, arg_value, value, rel, prim_name):
    """
    Check that ``arg_value <rel> value`` holds and return `arg_value`.

    Raises:
        ValueError: If the relation does not hold.
    """
    if _check_binary_rel(arg_value, value, rel):
        return arg_value
    relation = _format_str_one_value(value, rel)
    raise ValueError(f'For \'{prim_name}\', the argument \'{arg_name}\' '
                     f'must {relation}, but got {arg_value}.')
485
486
def check_isinstance(arg_name, arg_value, classes):
    """
    Return `arg_value` if it is an instance of `classes`.

    Raises:
        ValueError: If `arg_value` is not an instance of `classes` (note: ValueError, not TypeError).
    """
    if isinstance(arg_value, classes):
        return arg_value
    raise ValueError(f'The parameter \'{arg_name}\' must be isinstance of {classes}, but got {arg_value}.')
494
495
def check_bool(arg_value, arg_name=None, prim_name=None):
    """
    Return `arg_value` if it is a bool.

    Raises:
        TypeError: If `arg_value` is not a bool (ints such as 0/1 are rejected).

    Usage:
    - has_bias = check_bool(has_bias)
    - has_bias = check_bool(has_bias, "has_bias")
    """
    if isinstance(arg_value, bool):
        return arg_value
    subject = f"For '{prim_name}', the" if prim_name else 'The'
    name = f"'{arg_name}'" if arg_name else 'input value'
    raise TypeError(f"{subject} {name} must be a bool, but got {type(arg_value).__name__}.")
512
513
def check_int_range(arg_value, lower_limit, upper_limit, rel, arg_name=None, prim_name=None):
    """
    Check whether an int value lies in the range bounded by `lower_limit` and `upper_limit`.

    `rel` selects which endpoints are included: INC_NEITHER -> (lower, upper),
    INC_LEFT -> [lower, upper), INC_RIGHT -> (lower, upper], INC_BOTH -> [lower, upper].
    Delegates to `check_number_range` with `value_type=int`.

    Usage:
    - number = check_int_range(number, 0, 1, INC_NEITHER) # number in (0, 1)
    - number = check_int_range(number, 0, 1, INC_BOTH, "number") # number in [0, 1]
    """
    return check_number_range(arg_value, lower_limit, upper_limit, rel, int, arg_name, prim_name)
523
524
def check_float_range(arg_value, lower_limit, upper_limit, rel, arg_name=None, prim_name=None):
    """
    Check whether a float value lies in the range bounded by `lower_limit` and `upper_limit`.

    `rel` selects which endpoints are included: INC_NEITHER -> (lower, upper),
    INC_LEFT -> [lower, upper), INC_RIGHT -> (lower, upper], INC_BOTH -> [lower, upper].
    Delegates to `check_number_range` with `value_type=float`.

    Usage:
    - number = check_float_range(number, 0.0, 1.0, INC_NEITHER) # number in (0.0, 1.0)
    - number = check_float_range(number, 0.0, 1.0, INC_BOTH, "number") # number in [0.0, 1.0]
    """
    return check_number_range(arg_value, lower_limit, upper_limit, rel, float, arg_name, prim_name)
534
535
def check_string(arg_value, valid_values, arg_name=None, prim_name=None):
    """
    Return `arg_value` if it is a str contained in `valid_values`.

    Raises:
        ValueError: If `arg_value` is not a str or not one of `valid_values`.

    Usage:
    - method = check_string(method, ["string1", "string2", "string3"], "method")
    """
    if isinstance(arg_value, str) and arg_value in valid_values:
        return arg_value
    name = arg_name if arg_name else "parameter"
    prefix = f'For \'{prim_name}\', the' if prim_name else "The"
    raise ValueError(f"{prefix} '{name}' must be str and must be in '{valid_values}',"
                     f" but got '{arg_value}'.")
552
553
def check_str_by_regular(target, reg=None, flag=re.ASCII, prim_name=None):
    """
    Check that the string `target` matches the regular expression `reg` (via `re.match`).

    When `reg` is None, a named-string pattern is used: one or more word characters followed
    by any run of ASCII letters, digits, underscores and dots.

    Returns:
        True if `target` matches.

    Raises:
        ValueError: If `target` does not match.
    """
    pattern = r"^\w+[0-9a-zA-Z\_\.]*$" if reg is None else reg
    if re.match(pattern, target, flag) is not None:
        return True
    prefix = f"For '{prim_name}', the" if prim_name else "The"
    raise ValueError(f"{prefix} '{target}' is illegal, it must be match regular'{pattern}' by flags'{flag}.'")
562
563
564# pylint: disable=missing-docstring
def check_str_and_none_by_regular(target, reg=None, flag=re.ASCII, prim_name=None):
    """
    Check that the string `target` matches `reg`, with a default pattern that also accepts "".

    Despite the name, `target` must still be a string (it is passed to `re.match`). When `reg`
    is None, the default pattern allows zero or more word characters followed by any run of
    ASCII letters, digits, underscores, dots and hyphens, so the empty string is accepted.

    Returns:
        True if `target` matches.

    Raises:
        ValueError: If `target` does not match.
    """
    pattern = r"^\w*[0-9a-zA-Z\_\.\-]*$" if reg is None else reg
    if re.match(pattern, target, flag) is not None:
        return True
    prefix = f"For '{prim_name}', the" if prim_name else "The"
    raise ValueError(f"{prefix} '{target}' is illegal, it must be match regular'{pattern}' by flags'{flag}.'")
573
574
def check_file_name_by_regular(target, reg=None, prim_name=None):
    """
    Check whether a file name is legitimate.

    `target` must be a str that does not end with a path separator ("\\" or "/") and that
    matches `reg`. When `reg` is None, the name may contain only ASCII letters, digits and
    the characters ``@ _ - . : / \\``.

    Returns:
        True if all checks pass.

    Raises:
        TypeError: If `target` is not a str.
        ValueError: If `target` looks like a directory path or does not match the pattern.
    """
    prefix = f"For '{prim_name}', the" if prim_name else "The"
    if not isinstance(target, str):
        raise TypeError(f"{prefix} '{target}' must be string, but got {type(target)}.")
    if target.endswith(("\\", "/")):
        raise ValueError(f"{prefix} '{target}' cannot be a directory path.")
    pattern = reg if reg is not None else r"^[0-9a-zA-Z@\_\-\.\:\/\\]+$"
    if re.match(pattern, target) is None:
        raise ValueError(f"{prefix} '{target}' is illegal, it must be match regular '{pattern}'.")

    return True
590
591
def check_pad_value_by_mode(pad_mode, padding, prim_name):
    """
    Validate `padding` against `pad_mode` and return it.

    Only pad_mode 'pad' may carry a non-zero padding.

    Raises:
        ValueError: If pad_mode is not 'pad' and padding is not 0.
    """
    must_be_zero = pad_mode != 'pad'
    if must_be_zero and padding != 0:
        raise ValueError(f"For '{prim_name}', padding must be zero when pad_mode is '{pad_mode}',"
                         f" but got {padding}.")
    return padding
598
599
def check_subclass(arg_name, type_, template_types, prim_name, addition_error_info=None):
    """
    Check that `type_` matches at least one of `template_types`.

    A template that is an `mstype.Type` matches when `mstype._issubclass_(type_, template)` is
    truthy; any other template matches only if it is the very same object as `type_`. A single
    non-iterable template is wrapped in a tuple.

    Raises:
        TypeError: If no template matches. `addition_error_info`, when given, is appended to
            the message after a space.
    """
    if not isinstance(template_types, Iterable):
        template_types = (template_types,)

    def _matches(candidate):
        if isinstance(candidate, mstype.Type):
            return mstype._issubclass_(type_, candidate)  # pylint: disable=W0212
        return type_ is candidate

    if any(_matches(candidate) for candidate in template_types):
        return

    extra = '' if addition_error_info is None else ' ' + addition_error_info
    if isinstance(type_, (tuple, list)):
        type_str = f"type '{type(type_).__name__}'"
    else:
        type_str = str(type_)
    choice = 'one of ' if len(template_types) > 1 else ''
    expected = ', '.join(str(x) for x in template_types)
    raise TypeError(f"For '{prim_name}', the element of '{arg_name}'"
                    f" must be {choice}"
                    f"{expected}, but got {type_str}"
                    f"{extra}.The supported data types depend on the hardware that"
                    f" executes the operator, for more details, please refer to the MindSpore official "
                    f"website to get more information about the data type.")
625
626
def check_valid_input(arg_name, arg_value, prim_name):
    """
    Check that `arg_value` is not None and return it.

    Raises:
        ValueError: If `arg_value` is None.
    """
    if arg_value is None:
        # The space after the quoted name was missing, producing "'x'can not be None".
        raise ValueError(f"For \'{prim_name}\', the argument '{arg_name}' " \
                         f"can not be None, but got {arg_value}.")
    return arg_value
635
636
def check_types_same_and_valid(args, valid_values, prim_name):
    """
    Check that every type in `args` is valid and that all of them are equal.

    Args:
        args: Mapping of argument name -> type object.
        valid_values: Allowed types, forwarded to `check_subclass`.
        prim_name (str): Operator name used in error messages.

    Raises:
        TypeError: If a type is not allowed (raised by `check_subclass`), or differs from the
            first argument's type. An empty `args` also raises TypeError, from `reduce`.
    """
    def _validated(item):
        name, type_obj = item
        check_subclass(name, type_obj, valid_values, prim_name)
        return (name, type_obj)

    def _same_as_first(first, other):
        first_name, first_type = first
        other_name, other_type = other
        if first_type != other_type:
            raise TypeError(f"For '{prim_name}', the type of '{other_name}' should be same as '{first_name}',"
                            f" but got '{first_name}' with type {first_type}"
                            f" and '{other_name}' with type {other_type}.")
        # Keep returning the first pair so every later argument is compared against it.
        return first

    # Lazy generator: each argument is validated just before it is compared.
    reduce(_same_as_first, (_validated(item) for item in args.items()))
657
658
def check_tensors_dtypes_same_and_valid(args, valid_dtypes, prim_name):
    """
    Check that the tensor types in `args` are all equal and valid.

    Each entry of `valid_dtypes` (a single dtype is also accepted) is wrapped in
    `mstype.TensorType`, and the check is delegated to `check_types_same_and_valid`.
    """
    if not isinstance(valid_dtypes, Iterable):
        valid_dtypes = [valid_dtypes]
    check_types_same_and_valid(args, list(map(mstype.TensorType, valid_dtypes)), prim_name)
664
665
def check_tensor_dtype_valid(arg_name, arg_type, valid_dtypes, prim_name):
    """
    Check that the tensor type `arg_type` is accepted for the given `valid_dtypes`.

    Each entry of `valid_dtypes` (a single dtype is also accepted) is wrapped in
    `mstype.TensorType`, and the check is delegated to `check_subclass`.
    """
    if not isinstance(valid_dtypes, Iterable):
        valid_dtypes = [valid_dtypes]
    check_subclass(arg_name, arg_type, list(map(mstype.TensorType, valid_dtypes)), prim_name)
671
672
def check_scalar_or_tensor_types_same(args, valid_values, prim_name, allow_mix=False):
    """
    Checks whether the types of inputs are the same. If the input args are tensors, checks their element types.
    If `allow_mix` is True, Tensor(float32) and float32 are type compatible, otherwise an exception will be raised.

    Args:
        args: Mapping of argument name -> scalar type or tensor type.
        valid_values: Allowed (element) types.
        prim_name (str): Operator name used in error messages.
        allow_mix (bool): Whether a tensor type may be compared with a scalar type by element type.

    Raises:
        TypeError: If an (element) type is not in `valid_values`, or two arguments' types differ.
    """
    def _is_tensor(type_obj):
        return isinstance(type_obj, type(mstype.tensor_type))

    def _validated(item):
        name, type_obj = item
        elem_type = type_obj.element_type() if _is_tensor(type_obj) else type_obj
        if elem_type not in valid_values:
            raise TypeError(f'For \'{prim_name}\', the type of \'{name}\' must be in {valid_values},'
                            f' but got {elem_type}.')
        return item

    def _same_as_first(first, other):
        first_name, first_type = first
        other_name, other_type = other
        first_is_tensor, other_is_tensor = _is_tensor(first_type), _is_tensor(other_type)
        mix_rejected = False
        if first_is_tensor and other_is_tensor:
            first_type = first_type.element_type()
            other_type = other_type.element_type()
        elif first_is_tensor or other_is_tensor:
            # Exactly one side is a tensor: compare element types only when mixing is allowed.
            if allow_mix:
                if first_is_tensor:
                    first_type = first_type.element_type()
                if other_is_tensor:
                    other_type = other_type.element_type()
            else:
                mix_rejected = True

        if mix_rejected or first_type != other_type:
            raise TypeError(f"For '{prim_name}', the type of '{other_name}' must be same as '{first_name}',"
                            f" but got '{first_name}' with type {first_type}"
                            f" and '{other_name}' with type {other_type}.")
        return first

    reduce(_same_as_first, (_validated(item) for item in args.items()))
711
712
def check_value_type(arg_name, arg_value, valid_types, prim_name=None):
    """
    Check that `arg_value` is an instance of one of `valid_types` and return it.

    `valid_types` may be a single type or an iterable of types. Because bool is a subclass of
    int, a bool is accepted only when `bool` itself is listed: `check_value_type('x', True, [int])`
    fails and `check_value_type('x', True, [bool, int])` passes. A float whose type list does not
    contain `float` is rounded to 6 decimal places before the isinstance check, and the rounded
    value is what is returned if a broader listed type accepts it.

    Raises:
        TypeError: If `arg_value` is not an instance of any of `valid_types`.
    """
    if not isinstance(valid_types, Iterable):
        valid_types = (valid_types,)
    accepted = tuple(valid_types)

    def _reject(value):
        """Raise the TypeError naming the accepted types and the actual type of `value`."""
        names = [t.__name__ if hasattr(t, '__name__') else str(t) for t in valid_types]
        several = len(valid_types) > 1
        prefix = f"For '{prim_name}', the" if prim_name else "The"
        raise TypeError(f'{prefix} type of \'{arg_name}\' should be {"one of " if several else ""}'
                        f'\'{names if several else names[0]}\', '
                        f'but got type \'{type(value).__name__}\'.')

    if isinstance(arg_value, bool) and bool not in accepted:
        _reject(arg_value)
    if isinstance(arg_value, float) and float not in accepted:
        arg_value = round(arg_value, 6)
    if not isinstance(arg_value, accepted):
        _reject(arg_value)
    return arg_value
737
738
def check_type_name(arg_name, arg_type, valid_types, prim_name):
    """
    Check that `arg_type` is one of `valid_types` and return it.

    A tensor type is first replaced by its element type, so the returned value is the element
    type in that case. `valid_types` may be a single type or an iterable of types.

    Raises:
        TypeError: If the (element) type is not in `valid_types`.
    """
    if not isinstance(valid_types, Iterable):
        valid_types = (valid_types,)
    if isinstance(arg_type, type(mstype.tensor_type)):
        arg_type = arg_type.element_type()
    if arg_type in valid_types:
        return arg_type

    names = [t.__name__ if hasattr(t, '__name__') else t for t in valid_types]
    several = len(valid_types) > 1
    prefix = f"For '{prim_name}', the" if prim_name else "The"
    got = arg_type.__name__ if hasattr(arg_type, '__name__') else repr(arg_type)
    raise TypeError(f"{prefix} '{arg_name}' should be {'one of ' if several else ''}"
                    f"{names if several else names[0]}, "
                    f"but got '{got}'.")
759
760
def check_reduce_shape(ori_shape, shape, axis, prim_name, arg_name1, arg_name2):
    """
    Check that `shape` equals `ori_shape` with the dimensions listed in `axis` removed.

    `axis` may be a single index or an iterable of indices. Indices are compared against the
    non-negative positions of `ori_shape`, so negative axis values never remove a dimension.

    Raises:
        ValueError: If `shape` does not equal the reduced `ori_shape`.
    """
    reduced_axes = axis if isinstance(axis, Iterable) else (axis,)
    expected = [dim for idx, dim in enumerate(ori_shape) if idx not in reduced_axes]
    if list(shape) == expected:
        return
    raise ValueError(f"For '{prim_name}', "
                     f"the shape of parameter '{arg_name1}' reduce on 'axis': {axis} must "
                     f"be equal to the shape of '{arg_name2}': {shape}, but got {ori_shape}.")
770
771
def check_astype_dtype(dtype):
    """
    Check whether `dtype` is a valid input for Tensor.astype, and convert it to an mstype.

    Accepted forms:
    - a str (case-insensitive) listed in `mstype.__dtype__` or one of "int", "float", "bool";
      it is converted through `np.dtype` and `mstype.pytype_to_dtype`;
    - a Python type, converted with `mstype.pytype_to_dtype`;
    - an mstype already in `mstype.number_type` or equal to `mstype.bool_`, returned as is.

    Raises:
        TypeError: If `dtype` is none of the above.
    """
    all_types = mstype.__dtype__ + ["int", "float", "bool"]
    if isinstance(dtype, str):
        lowered = dtype.lower()
        if lowered not in all_types:
            raise TypeError(f"For Tensor.astype, the input type must be one of {all_types}, but got '{dtype}'.")
        return mstype.pytype_to_dtype(np.dtype(lowered))
    if isinstance(dtype, type):
        return mstype.pytype_to_dtype(dtype)
    supported = mstype.number_type + (mstype.bool_,)
    if dtype not in supported:
        raise TypeError(f"For Tensor.astype, the input type must be one of {supported},"
                        f" but got '{dtype}'.")
    return dtype
785
786
def check_transpose_axis(axes, ndim):
    """
    Check the axis argument for tensor.transpose and return the permutation to use.

    Args:
        axes (tuple): the positional arguments given to transpose. Possible forms are:
            empty, `(None,)`, a single tuple/list/int, or a series of integers.
        ndim (int): the number of dimensions of the tensor.

    Returns:
        - `(ndim-1, ..., 0)` when no axes (or only None) are given.
        - A single tuple argument as-is, or a single list argument converted to a tuple.
        - A single int argument wrapped in a 1-tuple.
        - Otherwise `axes` itself.

    Raises:
        ValueError: if a single int, or a series of integers, is given whose count
            differs from `ndim`.
        TypeError: if a single argument is given that is not a tuple, list or int.
    """
    def _require_ndim_entries():
        # A series of ints (or a lone int) must provide one entry per dimension.
        if len(axes) == ndim:
            return
        raise ValueError(f"For Tensor.transpose, the number of axes must be equal to the dimension of Tensor, "
                         f"but got {len(axes)} in the number of axes.")

    if not axes or (len(axes) == 1 and axes[0] is None):
        return tuple(reversed(range(ndim)))

    if len(axes) != 1:
        _require_ndim_entries()
        return axes

    perm = axes[0]
    if isinstance(perm, list):
        return tuple(perm)
    if isinstance(perm, int):
        _require_ndim_entries()
        return (perm,)
    if isinstance(perm, tuple):
        return perm
    raise TypeError(f"For Tensor.transpose, the parameter 'axes' must be a tuple/list, "
                    f"or series of integer, but got {type(axes[0])}")
814
815
def check_reshape_shp(shp):
    """
    Check the shape argument for tensor.reshape.

    Args:
        shp (tuple): the positional arguments given to reshape. Possible forms are:
            a series of integers, or a single int, tuple or list.

    Returns:
        - `shp` itself when it holds zero, or several, items, or a single int.
        - The single tuple argument.
        - The single list argument converted to a tuple.

    Raises:
        TypeError: if a single argument is given that is not an int, tuple or list.
    """
    if len(shp) != 1:
        return shp

    only = shp[0]
    if isinstance(only, int):
        return shp
    if isinstance(only, list):
        return tuple(only)
    if isinstance(only, tuple):
        return only
    raise TypeError(
        f"For Tensor.reshape, the parameter 'shape' must be an integer, or tuple/list, "
        f"or series of integer, but got {type(shp[0])}")
834
835
def check_flatten_order(order):
    """
    Check flatten function input order.

    Raises:
        TypeError: if `order` is not a string.
        ValueError: if `order` is neither 'C' nor 'F'.
    """
    if isinstance(order, str):
        if order in ('C', 'F'):
            return
        raise ValueError(f"For Tensor.flatten, the parameter 'order' must be 'C' or 'F', but got '{order}'")
    raise TypeError(f"For Tensor.flatten, the parameter 'order' must be a string, but got {type(order)}")
842
843
def check_swapaxes_axis(axes, ndim):
    """
    Check all the axes argument for ops.swapaxes.

    Args:
        axes (Union[int, tuple[int], list[int]]): the axis or axes to validate.
        ndim (int): the number of dimensions.

    Returns:
        A non-negative int for an int input, or a tuple of non-negative ints for a
        tuple/list input.

    Raises:
        TypeError: if `axes` (or one of its elements) has an unsupported type.
        ValueError: if an axis is out of range (raised by `check_axis_in_range`).
    """
    if isinstance(axes, int):
        return check_axis_in_range(axes, ndim)
    if not isinstance(axes, (tuple, list)):
        raise TypeError(f"For ops.swapaxes, the argument 'axes' must be integer, list or tuple for check, "
                        f"but got {type(axes)}.")
    normalized = []
    for axis in axes:
        if not isinstance(axis, int):
            raise TypeError(f"For ops.swapaxes, the axis argument must be integer, but got {type(axis)}.")
        # check_axis_in_range both validates and returns (axis + ndim) % ndim.
        normalized.append(check_axis_in_range(axis, ndim))
    return tuple(normalized)
859
860
def prepare_shape_for_squeeze(shape, axes):
    """
    Creates the squeezed new shape based on the tensor and given axes.

    Args:
        shape (tuple): the shape of the tensor.
        axes (Union[int, tuple(int), list(int)]): the axes whose dimensions are to be
            squeezed. Negative values count from the end, and repeated entries are ignored.
            An empty tuple/list squeezes nothing.

    Returns:
        new_shape (tuple): the shape with the selected size-1 dimensions removed.

    Raises:
        ValueError: if an axis is out of range, or a selected dimension is not 1.
        TypeError: if `axes` is not an int, tuple or list.
    """
    ndim = len(shape)

    def _validate(axis):
        if axis >= ndim or axis < -ndim:
            raise ValueError(f"For Tensor.squeeze, the 'axis' must be in the range of [-{ndim}, {ndim}), "
                             f"but got {axis}.")

    if isinstance(axes, int):
        _validate(axes)
        selected = (axes,)
    elif isinstance(axes, (list, tuple)):
        for axis in axes:
            _validate(axis)
        # Drop duplicates while keeping first-seen order.
        selected = tuple(dict.fromkeys(axes))
    else:
        raise TypeError(f"For Tensor.squeeze, the parameter 'axes' must be one of [int, tuple, list], "
                        f"but got {type(axes)}")

    kept = []
    for idx, size in enumerate(shape):
        # An axis may be given either as a non-negative or a negative index.
        chosen = idx in selected or (idx - ndim) in selected
        if not chosen:
            kept.append(size)
        elif size != 1:
            raise ValueError(f"For Tensor.squeeze, the shape of parameter 'axis' {selected} must be 1, but got {size}.")
    return tuple(kept)
911
912
def check_axis_in_range(axis, ndim):
    """
    Checks that `axis` is within the bounds of `ndim` and returns it as a non-negative index.

    Args:
        axis (int): the axis, in the range [-ndim, ndim).
        ndim (int): the number of dimensions.

    Returns:
        int: `(axis + ndim) % ndim`.

    Raises:
        TypeError: if `axis` is not an int.
        ValueError: if `axis` is outside [-ndim, ndim).
    """
    if not isinstance(axis, int):
        raise TypeError(f'The axes must be integers, but got {type(axis)}')
    if not -ndim <= axis < ndim:
        raise ValueError(f"The 'axis' must be in the range of [-{ndim}, {ndim}), but got {axis}.")
    return (axis + ndim) % ndim
924
925
def check_axis_valid(axes, ndim):
    """
    Checks axes are valid given ndim, and returns them as a tuple of non-negative axes
    that can be passed to the built-in operator.

    Args:
        axes (Union[None, int, tuple[int], list[int]]): the axes to check; None selects all.
        ndim (int): the number of dimensions.

    Returns:
        tuple[int]:
            - `tuple(range(ndim))` when `axes` is None.
            - A 1-tuple for a single int.
            - Otherwise the canonicalized axes.

    Raises:
        TypeError, ValueError: from `check_axis_in_range`, for a non-integer or
            out-of-range axis. Sequences are additionally passed to `_check_dup`
            (presumably rejecting duplicated axes).
    """
    if axes is None:
        return tuple(range(ndim))
    if not isinstance(axes, (tuple, list)):
        return (check_axis_in_range(axes, ndim),)
    normalized = tuple(check_axis_in_range(axis, ndim) for axis in axes)
    _check_dup(normalized)
    return normalized
947
948
def max_(*args):
    """
    Return the maximum value of the input parameter.

    Forwards all positional arguments to the built-in `max`, so both `max_(a, b, ...)`
    and `max_(iterable)` are accepted.
    """
    largest = max(*args)
    return largest
952
953
def min_(*args):
    """
    Return the minimum value of the input parameter.

    Forwards all positional arguments to the built-in `min`, so both `min_(a, b, ...)`
    and `min_(iterable)` are accepted.
    """
    smallest = min(*args)
    return smallest
957
958
def is_stub_tensor(tensor):
    """Return True if `tensor` has a `stub` attribute, i.e. it is treated as a stub tensor."""
    try:
        _ = tensor.stub
    except AttributeError:
        return False
    return True
961
962
def expanded_shape(ndim, axis_size, axis):
    """
    Returns a shape with size = 1 for all dimensions
    except at axis.

    Only a non-negative `axis` in [0, ndim) selects a dimension; any other value
    yields a shape of all ones.
    """
    dims = []
    for dim in range(ndim):
        dims.append(axis_size if dim == axis else 1)
    return tuple(dims)
969
970
def tuple_slice(tup, start, end):
    """Return `tup[start:end]`, using normal slice semantics (None and negative bounds allowed)."""
    window = slice(start, end)
    return tup[window]
974
975
def infer_out_shape(*shapes):
    """
    Returns shape of output after broadcasting. Raises ValueError if shapes cannot be broadcast.

    Shapes are right-aligned, and missing leading dimensions are treated as 1. On each
    axis, every size must be 1 or the broadcast size. The broadcast size is 0 if any
    input has size 0 on that axis, and the maximum size otherwise.

    Args:
        *shapes (tuple[int]): the shapes to broadcast together.

    Returns:
        tuple[int]: the broadcast shape; `()` when no shapes are given.

    Raises:
        ValueError: if the shapes are not broadcast-compatible.
    """
    # `default=0` makes a call without shapes return () instead of crashing in max().
    max_len = max((len(it) for it in shapes), default=0)
    shape_out = ()
    for i in range(max_len):
        items = []
        for it in shapes:
            # Position of axis i inside this (right-aligned) shape; negative means padded.
            pos = i - (max_len - len(it))
            items.append(it[pos] if pos >= 0 else 1)
        max_size = 0 if 0 in items else max(items)
        for item in items:
            if item not in (1, max_size):
                raise ValueError(f'For Tensor, the dimension on each axis must be 1 or the max value on the axis '
                                 f'to support broadcasting, but got shapes {shapes}')
        shape_out += (max_size,)
    return shape_out
994
995
def check_axis_type(axis, type_int=True, type_tuple=True, type_list=True):
    """
    Check axis argument type.

    Args:
        axis: the axis argument to check.
        type_int (bool): whether a single int is accepted.
        type_tuple (bool): whether a tuple of ints is accepted.
        type_list (bool): whether a list of ints is accepted.

    Returns:
        bool: True when `axis` has one of the accepted forms.

    Raises:
        TypeError: if an element of an accepted tuple/list is not an int, or `axis`
            has none of the accepted forms.
    """
    if type_int and isinstance(axis, int):
        return True
    if (type_tuple and isinstance(axis, tuple)) or (type_list and isinstance(axis, list)):
        for ax in axis:
            if not isinstance(ax, int):
                raise TypeError(f"For Tensor.ptp, each axis must be integer, but got {type(ax)} in {axis}.")
        return True

    # Join the accepted names so the message has no dangling ", " separator.
    accepted = [name for name, enabled in (("int", type_int), ("tuple", type_tuple), ("list", type_list))
                if enabled]
    type_str = ", ".join(accepted)
    raise TypeError(f"For Tensor.ptp, the axis should be {type_str}, but got {type(axis)}.")
1014
1015
def check_and_canonicalize_axes(axes, ndim):
    """
    Check whether the types and values of input axes are valid, and canonicalize them.

    Args:
        axes (Union[int, tuple[int]]): a tuple of axes, or a single axis (which is wrapped
            into a 1-tuple).
        ndim (int): the number of dimensions.

    Returns:
        tuple[int]: the axes with negative values shifted by `ndim`.

    Raises:
        TypeError: if an axis is not an int.
        ValueError: if an axis is outside [-ndim, ndim).
        Duplicates are passed to `_check_dup`, which presumably rejects them.
    """
    axes = axes if isinstance(axes, tuple) else (axes,)
    canonical = []
    for ax in axes:
        if not isinstance(ax, int):
            raise TypeError(f"Each axis should be integer, but got {type(ax)} in {axes}.")
        if not -ndim <= ax < ndim:
            raise ValueError(f"The 'axis' must be in the range of [-{ndim}, {ndim}), but got {ax}.")
        canonical.append(ax + ndim if ax < 0 else ax)
    new_axes = tuple(canonical)
    _check_dup(new_axes)
    return new_axes
1032
1033
def check_type_support(dtype, device, supported_dtypes):
    """
    Checks whether the data type is supported.

    Returns True when `dtype` is in `supported_dtypes`. Otherwise it returns True only
    if the current 'device_target' context differs from `device`, because the
    restriction applies only to that device.
    """
    if dtype in supported_dtypes:
        return True
    return context.get_context('device_target') != device
1037
1038
def check_sparse_tensor_input(indices, values, shape):
    """
    Common input check for SparseTensors.

    Raises:
        TypeError: if `indices` or `values` is neither a `Tensor_` nor a stub tensor,
            or `shape` is not a tuple.
    """
    for name, arg in (("indices", indices), ("values", values)):
        if isinstance(arg, Tensor_) or is_stub_tensor(arg):
            continue
        raise TypeError(f"For SparseTensors, '{name}' must be Tensor, but got {type(arg)}.")
    if isinstance(shape, tuple):
        return
    raise TypeError(f"For SparseTensors, 'shape' must be tuple, but got {type(shape)}.")
1047
1048
def check_csr_tensor_input(indptr, indices, values, shape):
    """
    Checks inputs type for CSRTensor.

    Validates `indptr` here, then delegates to the common sparse-tensor check for the
    remaining arguments.

    Raises:
        TypeError: if any input has an unsupported type.
    """
    indptr_is_tensor = isinstance(indptr, Tensor_) or is_stub_tensor(indptr)
    if not indptr_is_tensor:
        raise TypeError(f"For CSRTensor, 'indptr' must be Tensor, but got {type(indptr)}.")
    check_sparse_tensor_input(indices, values, shape)
1054
1055
def check_csr_tensor_shape(indptr_shp, indices_shp, values_shp, csr_shp):
    """
    Checks input tensors' shapes for CSRTensor.

    Args:
        indptr_shp (tuple): shape of the row-pointer tensor.
        indices_shp (tuple): shape of the column-index tensor.
        values_shp (tuple): shape of the values tensor.
        csr_shp (tuple): the dense shape of the CSRTensor.

    Returns:
        None. An all-empty tensor (all three input shapes `(0,)`) is accepted without
        further checks.

    Raises:
        TypeError: if an element of `csr_shp` is not an int.
        ValueError: if any shape constraint is violated.
    """
    # Support empty sparse tensor
    if (indptr_shp == (0,)) and (indices_shp == (0,)) and (values_shp == (0,)):
        return
    shape_size = 1
    val_shp_size = 1
    for item in csr_shp:
        # Check the type before the value, so a non-integer element raises the
        # intended TypeError instead of failing inside the `<=` comparison.
        if not isinstance(item, int):
            raise TypeError(f"For CSRTensor, the element type of shape must be int, but got {type(item)}")
        if item <= 0:
            raise ValueError(f"For CSRTensor, the element of shape must be positive, but got {item}")
        shape_size *= item
    for item in values_shp:
        if item <= 0:
            raise ValueError(f"The element of shape must be positive, but got {item}")
        val_shp_size *= item
    if shape_size < val_shp_size:
        raise ValueError(f"Shape total size: {shape_size} is too small to hold {val_shp_size} non-zero values.")
    if len(indices_shp) != 1:
        raise ValueError(f"For CSRTensor, indices must be a 1-dimensional tensor, "
                         f"but got a {len(indices_shp)} dimension tensor.")
    if len(indptr_shp) != 1:
        raise ValueError(f"For CSRTensor, indptr must be a 1-dimensional tensor, "
                         f"but got a {len(indptr_shp)} dimension tensor.")
    if csr_shp[0] + 1 != indptr_shp[0]:
        raise ValueError(f"For CSRTensor, indptr must have length (1 + shape[0]), "
                         f"but got: {indptr_shp[0]}")
    if indices_shp[0] != values_shp[0]:
        err_msg1 = "For CSRTensor, indices and values must equal in their shape, "
        err_msg2 = f"but got indices shape: {indices_shp[0]}, values shape: {values_shp[0]}."
        raise ValueError(err_msg1 + err_msg2)
    if len(values_shp) + 1 != len(csr_shp):
        raise ValueError(f"Values' dimension should equal to CSRTensor's dimension - 1, but got "
                         f"Values' dimension: {len(values_shp)} , CSRTensor's dimension: "
                         f"{len(csr_shp)}")
    if values_shp[1:] != csr_shp[2:]:
        raise ValueError(f"CSRTensor's shape[2: ] must be equal to value's shape[1: ], "
                         f"but CSRTensor's shape[2: ] got: {csr_shp[2:]} and value's shape[1: ] "
                         f"got: {values_shp[1:]}")
1096
1097
def check_csr_tensor_dtype(indptr_dtype, indices_dtype):
    """
    Checks input tensors' data types for CSRTensor.

    Raises:
        TypeError: if `indptr_dtype` or `indices_dtype` is not int16, int32 or int64.
    """
    allowed = (mstype.int16, mstype.int32, mstype.int64)
    for label, dtype in (("indptr", indptr_dtype), ("indices", indices_dtype)):
        if dtype not in allowed:
            raise TypeError(f"For CSRTensor, {label} must have int16 or int32 or int64 data type, "
                            f"but got {dtype}.")
1106
1107
def check_coo_tensor_input(indices, values, shape):
    """Checks inputs type for COOTensor; only the common sparse-tensor checks apply."""
    return check_sparse_tensor_input(indices, values, shape)
1111
1112
def check_coo_tensor_shape(indices_shp, values_shp, coo_shp):
    """
    Checks input tensors' shapes for COOTensor.

    Args:
        indices_shp (tuple): shape of the indices tensor, expected `(nnz, 2)`.
        values_shp (tuple): shape of the values tensor, expected `(nnz,)`.
        coo_shp (tuple): the dense 2-D shape of the COOTensor.

    Returns:
        None. When both `indices_shp` and `values_shp` are `(0,)`, the tensor is treated
        as empty and only the length of `coo_shp` is checked.

    Raises:
        TypeError: if an element of `coo_shp` is not an int.
        ValueError: if any shape constraint is violated.
    """
    if len(coo_shp) != 2:
        raise ValueError(f"For COOTensor, the length of 'shape' must be 2, but got {coo_shp}.")
    if (indices_shp == (0,)) and (values_shp == (0,)):
        return
    shp_mul = 1
    for sh in coo_shp:
        # Type first, so non-integers get the intended TypeError rather than
        # failing inside the `<=` comparison.
        if not isinstance(sh, int):
            raise TypeError(f"For COOTensor, the element type of 'shape' must be int, but got {type(sh)}")
        if sh <= 0:
            raise ValueError(f"For COOTensor, the element of 'shape' must be positive, but got {sh} in {coo_shp}.")
        shp_mul *= sh
    # Validate the rank of 'values' before indexing values_shp[0] below.
    if len(values_shp) != 1:
        raise ValueError(f"For COOTensor, 'values' must be a 1-dimensional tensor, but got a {len(values_shp)}"
                         f"-dimensional tensor.")
    if shp_mul < values_shp[0]:
        raise ValueError(f"For COOTensor, shape is too small: ({shp_mul}) to hold all values({values_shp[0]}).")
    if len(indices_shp) != 2:
        raise ValueError(f"For COOTensor, 'indices' must be a 2-dimensional tensor, but got a {len(indices_shp)}"
                         f"-dimensional tensor.")
    if indices_shp[0] != values_shp[0]:
        raise ValueError(f"For COOTensor, 'indices.shape[0]' must be equal to 'values.shape[0]', but got "
                         f"'indices.shape[0]' = {indices_shp[0]} and 'values.shape[0]' = {values_shp[0]}.")
    if indices_shp[1] != 2:
        raise ValueError(f"For COOTensor, 'indices.shape[1]' must be 2, but got {indices_shp[1]}.")
1139
1140
def check_coo_tensor_dtype(indices_dtype):
    """
    Checks input tensors' data types for COOTensor.

    Raises:
        TypeError: if `indices_dtype` is not int16, int32 or int64.
    """
    accepted = (mstype.int16, mstype.int32, mstype.int64)
    if indices_dtype in accepted:
        return
    raise TypeError(f"For COOTensor, the type of 'indices' must be one of [int16, int32, int64], but got "
                    f"{indices_dtype}.")
1146
1147
def check_element_type_of_iterable(arg_name, arg_value, valid_types, prim_name=None):
    """
    Check the type of every element of an iterable object (list or tuple; dicts are rejected).

    Args:
        arg_name (str): the argument name used in error messages.
        arg_value (Union[list, tuple]): the value to check.
        valid_types (Sequence[type]): the accepted element types.
        prim_name (str): optional operator name used in error messages.

    Raises:
        TypeError: if `arg_value` is not a list/tuple (via `check_value_type`), or an
            element is not an instance of any of `valid_types`.
    """
    check_value_type(arg_name, arg_value, [list, tuple], prim_name)
    names = [t.__name__ if hasattr(t, '__name__') else str(t) for t in valid_types]
    many = len(valid_types) > 1
    prefix = f"For '{prim_name}', the" if prim_name else "The"
    allowed = tuple(valid_types)
    for element in arg_value:
        if isinstance(element, allowed):
            continue
        raise TypeError(f"{prefix} type of '{arg_name}' should be {'one of ' if many else ''}"
                        f"{names if many else names[0]}, "
                        f"but got '{element}' with type '{type(element).__name__}'.")
1159
1160
def check_element_type_of_dict(arg_name, arg_value, key_types, value_types, prim_name=None):
    """
    Check the type of key and value of a dict.

    Args:
        arg_name (str): the argument name used in error messages.
        arg_value (dict): the dict to check.
        key_types (Sequence[type]): accepted key types.
        value_types (Sequence[type]): accepted value types.
        prim_name (str): optional operator name used in error messages.

    Raises:
        TypeError: if `arg_value` is not a dict, or any key or value has an unexpected
            type. All keys are checked before any value.
    """
    check_value_type(arg_name, arg_value, [dict], prim_name)
    prefix = f"For '{prim_name}', the" if prim_name else "The"

    def _verify(elements, types):
        # Shared check for keys and values; only the candidate types differ.
        names = [t.__name__ if hasattr(t, '__name__') else str(t) for t in types]
        many = len(types) > 1
        for element in elements:
            if not isinstance(element, tuple(types)):
                raise TypeError(f"{prefix} type of '{arg_name}' should be {'one of ' if many else ''}"
                                f"{names if many else names[0]}, "
                                f"but got '{element}' with type '{type(element).__name__}'.")

    _verify(arg_value.keys(), key_types)
    _verify(arg_value.values(), value_types)
1180
1181
def check_size_and_element_type_of_tuple(arg_name, arg_value, expect_size, expect_element_type, prim_name=None):
    """
    Check the size and element type of a tuple.

    Args:
        arg_name (str): the argument name used in error messages.
        arg_value (tuple): the value to check.
        expect_size (int): the required number of elements.
        expect_element_type (type): the required type of every element.
        prim_name (str): optional operator name used in error messages.

    Raises:
        TypeError, ValueError: raised by the delegated checks when `arg_value` is not a
            tuple, has the wrong size, or contains an element of the wrong type.
    """
    check_value_type(arg_name, arg_value, [tuple], prim_name)
    check_equal_int(len(arg_value), expect_size, arg_name + ' size', prim_name)
    # Pass the caller's argument name (not a string literal) so element-type errors
    # name the offending argument.
    check_element_type_of_iterable(arg_name, arg_value, [expect_element_type], prim_name)
1187
1188
def _check_symbol(dyn_input, net_input, index, symbolic_shape_data):
    """
    Check symbolic shape values.

    Compares the actual shape of one network input against the symbolic shape declared
    on the corresponding dynamic input.

    Args:
        dyn_input: the declared input. Its `symbolic_shape` is iterated; entries that are
            dicts are treated as converted `Symbol` objects, and other entries are skipped.
        net_input: the actual input; only its `shape` is read.
        index (int): position of this input, reported 1-based in error messages.
        symbolic_shape_data (dict): state shared across calls. Under the key
            "unique_id_value_map" it records the first value seen for each symbol "id".

    Raises:
        ValueError: if a value differs from an earlier value of the same symbol "id", is
            below "min" or above "max", or does not satisfy the divisor/remainder rule.
    """
    actual_shape = net_input.shape
    for i, sym in enumerate(dyn_input.symbolic_shape):
        # the Symbol is converted to dict; non-dict entries carry no constraints here.
        if not isinstance(sym, dict):
            continue
        # the value of symbols with same "id" should be equal: the first value seen for
        # an id is stored in the shared map and later occurrences must match it.
        if "id" in sym:
            sym_id = sym["id"]
            k_idval = "unique_id_value_map"
            if k_idval not in symbolic_shape_data:
                symbolic_shape_data[k_idval] = {}
            unique_id_value = symbolic_shape_data[k_idval]
            if sym_id not in unique_id_value:
                unique_id_value[sym_id] = actual_shape[i]
            elif unique_id_value[sym_id] != actual_shape[i]:
                raise ValueError(
                    f"The {i + 1}th shape value of {index + 1}th actual input args is a unique symbol, all values must "
                    f"be the same. The previous value is {unique_id_value[sym_id]}, but the current value is "
                    f"{actual_shape[i]}. Actual shape: {actual_shape}, axis: {i}.")
        # check the value in range [min, max]; each bound is only checked when present.
        if "min" in sym and actual_shape[i] < sym["min"]:
            raise ValueError(
                f"The {i + 1}th shape value of {index + 1}th actual input args must be greater than or equal to the "
                f"'min' value '{sym['min']}' of `Symbol`, but got '{actual_shape[i]}'.  Actual shape: {actual_shape}, "
                f"axis: {i}.")
        if "max" in sym and actual_shape[i] > sym["max"]:
            raise ValueError(
                f"The {i + 1}th shape value of {index + 1}th actual input args must be less than or equal to the "
                f"'max' value '{sym['max']}' of `Symbol`, but got '{actual_shape[i]}'. Actual shape: {actual_shape}, "
                f"axis: {i}.")
        # check the shape item that satisfies the "divisor * N + remainder, N >= 1".
        # With the defaults (divisor 1, remainder 0) this only requires the value to be >= 1.
        d = sym.get("divisor", 1)
        r = sym.get("remainder", 0)
        if actual_shape[i] < d or actual_shape[i] % d != r:
            raise ValueError(
                f"The {i + 1}th shape value of {index + 1}th actual input args must be match the 'divisor'(d) and "
                f"'remainder'(r) of `Symbol`. The value should be 'd * N + r' for 'N > 0', got d={d} and r={r}, but "
                f"actual shape value is '{actual_shape[i]}'. Actual shape: {actual_shape}, axis: {i}")
1229
1230
def check_symbolic_shape(dynamic_inputs, actual_inputs):
    """
    Check the symbolic shape of the actual inputs against the declared dynamic inputs.

    Nested tuples/lists are walked recursively in lockstep using `zip`, so the walk
    stops at the shorter sequence. Every item with a `symbolic_shape` attribute is
    checked by `_check_symbol`. One state dict is shared by all those calls.
    """
    shared_state = {}

    def _walk(declared, actual):
        """Recursively pair declared inputs with actual inputs and check each leaf."""
        for position, (dyn, real) in enumerate(zip(declared, actual)):
            if isinstance(dyn, (tuple, list)):
                _walk(dyn, real)
            elif hasattr(dyn, "symbolic_shape"):
                _check_symbol(dyn, real, position, shared_state)

    _walk(dynamic_inputs, actual_inputs)
1244
1245
def check_input_format(input_param):
    """Return `input_param` if it is "NCHW", the only supported data format; otherwise raise ValueError."""
    if input_param != "NCHW":
        raise ValueError(f"The data format must be NCHW, but got {input_param}.")
    return input_param
1251
1252
1253def _expand_tuple(n_dimensions):
1254    """To expand an int number to tuple."""
1255
1256    def convert(m):
1257        if not isinstance(m, tuple):
1258            if isinstance(m, int) and not isinstance(m, bool):
1259                return tuple(repeat(m, n_dimensions))
1260            raise TypeError(f"When expanding an int number to tuple, input type must be integer or tuple[int], " \
1261                            f"but got {type(m)}")
1262
1263        if not len(m) is n_dimensions:
1264            raise TypeError(f"When expanding an int number to tuple, input tuple dimension must be {n_dimensions}, " \
1265                            f"but got {m}")
1266
1267        for i in m:
1268            if not isinstance(i, int) or isinstance(i, bool):
1269                raise TypeError(f"When expanding an int number to tuple, " \
1270                                f"the type of element in input tuple must be an integer, but got {type(i)}.")
1271        return m
1272
1273    return convert
1274
1275
1276def _check_data_type_valid(data, valid_type):
1277    """Check data type valid."""
1278    if valid_type is None:
1279        return data is None
1280    if isinstance(data, valid_type):
1281        if hasattr(data, 'size') and data.size == 0:
1282            msg = "The input data can not be empty."
1283            logger.critical(msg)
1284            raise ValueError(msg)
1285        return True
1286    return False
1287
1288
def check_input_data(*data, data_class):
    """
    Input data check.

    Recursively walks `data`, descending into lists, tuples and dict values. Every leaf
    must satisfy `_check_data_type_valid` for `data_class`, or for at least one entry
    when `data_class` is a tuple/list.

    Raises:
        TypeError: if a leaf matches none of the accepted types.
        ValueError: propagated from `_check_data_type_valid` for empty data.
    """
    for item in data:
        if isinstance(item, (list, tuple, dict)):
            children = item.values() if isinstance(item, dict) else item
            for child in children:
                check_input_data(child, data_class=data_class)
            continue

        if isinstance(data_class, (tuple, list)):
            matched = any(_check_data_type_valid(item, data_type) for data_type in data_class)
        else:
            matched = _check_data_type_valid(item, data_class)
        if matched:
            continue

        if isinstance(data_class, (tuple, list)):
            data_class_str = tuple(i.__name__ if hasattr(i, '__name__') else i for i in data_class)
        elif data_class is None:
            data_class_str = data_class
        else:
            data_class_str = data_class.__name__
        raise TypeError(f'The types of input data must be in the Union({data_class_str}, '
                        f'tuple[{data_class_str}], list[{data_class_str}], dict[{data_class_str}]), '
                        f'but got type {item if item is None else type(item).__name__}.')
1309
1310
def check_input_dataset(*dataset, dataset_type):
    """Input dataset check: True only if at least one dataset is given and all are `dataset_type` instances."""
    return bool(dataset) and all(isinstance(item, dataset_type) for item in dataset)
1319
1320
def check_output_data(data):
    """Output data check: raise RuntimeError when the network output `data` is None."""
    if data is not None:
        return
    raise RuntimeError('The output data can not be None, please check your net or input data.')
1325
1326
# Converters that expand an int, or validate a tuple, to 1-, 2- and 3-tuples.
once, twice, triple = map(_expand_tuple, (1, 2, 3))
1330
1331
def args_type_check(*type_args, **type_kwargs):
    """
    Check whether input data type is correct.

    This is a decorator factory. The expected types are bound to the decorated
    function's signature with `inspect.Signature.bind_partial`. On every call, each
    argument that has an expected type must be None or an instance of that type.
    Ordinary named parameters are checked by name, and so are keywords collected by a
    `**kwargs`-style parameter, whatever that parameter is called.

    Raises:
        TypeError: at call time, if an argument is neither None nor an instance of its
            expected type.
    """

    def type_check(func):
        sig = inspect.signature(func)
        # Names of `**name` parameters; their bound value is a dict of extra keywords.
        var_keyword = {name for name, param in sig.parameters.items()
                       if param.kind is inspect.Parameter.VAR_KEYWORD}

        def _flatten(arguments):
            """Merge the entries of a `**kwargs` dict into the top-level name mapping."""
            flat = {}
            for name, value in arguments.items():
                if name in var_keyword:
                    flat.update(value)
                else:
                    flat[name] = value
            return flat

        # Computed once at decoration time. The wrapper only reads it, so calls never
        # mutate shared state.
        bound_types = _flatten(sig.bind_partial(*type_args, **type_kwargs).arguments)

        @wraps(func)
        def wrapper(*args, **kwargs):
            argument_dict = _flatten(sig.bind(*args, **kwargs).arguments)
            for name, value in argument_dict.items():
                if name in bound_types:
                    if value is not None and not isinstance(value, bound_types[name]):
                        raise TypeError(f"The parameter '{name}' must be {bound_types[name]}, but got {type(value)}")
            return func(*args, **kwargs)

        return wrapper

    return type_check
1357
1358
def check_hook_fn(hook_type, hook_fn):
    """
    Check hook fn.

    Validates that `hook_fn` can be registered through the hook API named `hook_type`.

    Args:
        hook_type (str): name of the hook API, used in messages; "register_hook" also
            enforces the argument count.
        hook_fn: the user hook. It must be a Python function or a bound method.

    Returns:
        bool: False (after logging a warning) when not in PyNative mode; True when
        `hook_fn` passes every check.

    Raises:
        TypeError: in any of these cases:
            - `hook_fn` is not a function or method.
            - `hook_fn` is decorated with '@jit'.
            - For "register_hook", its positional parameter count (excluding a bound
              `self`) is not 1.
    """
    if context.get_context("mode") != context.PYNATIVE_MODE:
        logger.warning(f"'{hook_type}' function is only supported in pynative mode, you can use "
                       f"context.set_context to set pynative mode.")
        return False

    if not isinstance(hook_fn, (FunctionType, MethodType)):
        raise TypeError(f"When using '{hook_type}(hook_fn)', the type of 'hook_fn' must be python "
                        f"function, but got {type(hook_fn)}.")

    if hook_fn.__code__.co_name == "staging_specialize":
        raise TypeError(f"Decorating hook function {hook_fn.__name__} with '@jit' is not supported.")

    if hook_type == "register_hook":
        arg_count = hook_fn.__code__.co_argcount
        # A bound method's code object still counts `self`, which Python supplies
        # automatically when the hook is called.
        if isinstance(hook_fn, MethodType):
            arg_count -= 1
        if arg_count != 1:
            raise TypeError(f"Tensor hook function {hook_fn.__name__} arg num is not equal to 1.")

    return True
1377
# NOTE(review): module-level dict that is only initialised here; its readers and
# writers are outside this view, so confirm its purpose and key/value schema.
_set_record = {}
1379