• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020-2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Check parameters."""
16
17import re
18import inspect
19import math
20from enum import Enum
21from functools import reduce, wraps
22from itertools import repeat, zip_longest
23from collections import deque
24from collections.abc import Iterable
25import numpy as np
26from mindspore import context
27from mindspore import log as logger
28from mindspore.common import dtype as mstype
29from mindspore._c_expression import Tensor as Tensor_
30
31
class Rel(Enum):
    """Numerical relationship between variables, logical relationship enumeration definition of range."""
    # Binary scalar comparisons, evaluated as predicate(x, y).
    EQ = 1  # ==
    NE = 2  # !=
    LT = 3  # <
    LE = 4  # <=
    GT = 5  # >
    GE = 6  # >=
    # Range membership, evaluated as predicate(x, lower, upper); brackets mark inclusive bounds.
    INC_NEITHER = 7  # (), include neither
    INC_LEFT = 8  # [), include left
    INC_RIGHT = 9  # (], include right
    INC_BOTH = 10  # [], include both
    # Collection membership, evaluated as predicate(x, collection).
    IN = 11
    NOT_IN = 12

    @staticmethod
    def get_strs(rel):
        """Return the message template registered for `rel` in `rel_strs`, or "" if there is none."""
        if rel in rel_strs:
            return rel_strs[rel]
        return ""

    @staticmethod
    def get_fns(rel):
        """Return the predicate registered for `rel` in `rel_fns`; an unknown `rel` gets an always-False one."""
        try:
            return rel_fns[rel]
        except KeyError:
            return lambda *_: False


# Predicate implementing each relation.
rel_fns = {
    Rel.EQ: lambda lhs, rhs: lhs == rhs,
    Rel.NE: lambda lhs, rhs: lhs != rhs,
    Rel.LT: lambda lhs, rhs: lhs < rhs,
    Rel.LE: lambda lhs, rhs: lhs <= rhs,
    Rel.GT: lambda lhs, rhs: lhs > rhs,
    Rel.GE: lambda lhs, rhs: lhs >= rhs,
    Rel.INC_NEITHER: lambda val, low, high: low < val < high,
    Rel.INC_LEFT: lambda val, low, high: low <= val < high,
    Rel.INC_RIGHT: lambda val, low, high: low < val <= high,
    Rel.INC_BOTH: lambda val, low, high: low <= val <= high,
    Rel.IN: lambda item, container: item in container,
    Rel.NOT_IN: lambda item, container: item not in container,
}

# str.format templates used to describe each relation in error messages.
rel_strs = dict([
    (Rel.EQ, "= {}"),
    (Rel.NE, "!= {}"),
    (Rel.LT, "< {}"),
    (Rel.LE, "<= {}"),
    (Rel.GT, "> {}"),
    (Rel.GE, ">= {}"),
    (Rel.INC_NEITHER, "({}, {})"),
    (Rel.INC_LEFT, "[{}, {})"),
    (Rel.INC_RIGHT, "({}, {}]"),
    (Rel.INC_BOTH, "[{}, {}]"),
    (Rel.IN, "in {}"),
    (Rel.NOT_IN, "not in {}"),
])
97
98
def _check_3d_int_or_tuple(arg_name, arg_value, prim_name, allow_five=False, ret_five=False,
                           greater_zero=True, third_one=False, three_input=False):
    """
    Validate a 3D attribute given as an int or a tuple of 3 (or 5, when `allow_five` is True) ints.

    Args:
        arg_name (str): attribute name used in error messages.
        arg_value (Union[int, tuple]): the value to validate.
        prim_name (str): primitive name used in error messages.
        allow_five (bool): also accept 5-element tuples.
        ret_five (bool): return a 5-tuple (scalars / 3-tuples are prefixed with two 1s) instead of a 3-tuple.
        greater_zero (bool): require every element > 0 when True, >= 0 when False.
        third_one (bool): require the depth entry (third-from-last element) of the result to be 1.
        three_input (bool): reject tuples whose length is not exactly 3.

    Returns:
        tuple: the normalized 3- or 5-element tuple.

    Raises:
        TypeError: if `arg_value` is not an int or tuple (raised by `Validator.check_value_type`).
        ValueError: if the length or any element is invalid.
    """

    def _fail(depth_error=False, three_only=False):
        # Always raises; the flags only choose which message is reported.
        if depth_error:
            raise ValueError(f"For '{prim_name}' the depth of attr '{arg_name}' should be 1, but got {normalized[-3]}")
        if three_only:
            raise ValueError(f"For '{prim_name}' attr '{arg_name}' should be an positive int number or a tuple of "
                             f"three positive int numbers, but got {arg_value}")
        raise ValueError(f"For '{prim_name}' attr '{arg_name}' should be an positive int number or a tuple of three "
                         f"{'or five ' if allow_five else ''}positive int numbers, but got {arg_value}")

    def _normalize():
        # Expand a scalar, or reshape a 3-/5-tuple, into the requested output length.
        if isinstance(arg_value, int):
            triple = (arg_value,) * 3
            return (1, 1) + triple if ret_five else triple
        size = len(arg_value)
        if size == 3:
            return (1, 1) + arg_value if ret_five else arg_value
        if size == 5 and allow_five:
            return arg_value if ret_five else arg_value[1:4]
        return _fail()

    def _element_ok(item):
        # bool is a subclass of int but is never an acceptable element.
        if not isinstance(item, int) or isinstance(item, bool):
            return False
        return item > 0 if greater_zero else item >= 0

    Validator.check_value_type(arg_name, arg_value, (int, tuple), prim_name)
    if three_input and isinstance(arg_value, tuple) and len(arg_value) != 3:
        _fail(three_only=True)
    normalized = _normalize()
    if not all(_element_ok(item) for item in normalized):
        _fail()
    if third_one and normalized[-3] != 1:
        _fail(depth_error=True)
    return tuple(normalized)
145
146
def check_number(arg_value, value, rel, arg_type=int, arg_name=None, prim_name=None):
    """
    Check that `arg_value` is an `arg_type` instance satisfying `rel` against `value`.

    Example:
    - number = check_number(number, 0, Rel.GE, int, "number") # number >= 0

    Args:
        arg_value: the value to check.
        value: the right-hand side of the relation.
        rel (Rel): the relation `arg_value` must satisfy.
        arg_type (type): the required type of `arg_value`. Default: int.
        arg_name (str): name used in error messages. Default: None.
        prim_name (str): primitive name used in error messages. Default: None.

    Returns:
        The unchanged `arg_value`.

    Raises:
        TypeError: if `arg_value` is not an `arg_type` instance, or is a bool.
        ValueError: if `arg_value` is inf/nan or does not satisfy `rel`.
    """
    rel_fn = Rel.get_fns(rel)
    prim_name = f'in `{prim_name}`' if prim_name else ''
    arg_name = f'`{arg_name}`' if arg_name else ''

    if not isinstance(arg_value, arg_type):
        raise TypeError(f'{arg_name} {prim_name} must be {arg_type.__name__}, but got `{type(arg_value).__name__}`')
    # Python ints are never inf/nan, and math.isinf converts to float, which raises
    # OverflowError for very large ints, so only non-int values are probed.
    if not isinstance(arg_value, int) and (math.isinf(arg_value) or math.isnan(arg_value)):
        raise ValueError(f'{arg_name} {prim_name} must be legal value, but got `{arg_value}`.')

    # bool passes isinstance(..., int) but is reported as a type error.
    is_bool = isinstance(arg_value, bool)
    if is_bool or not rel_fn(arg_value, value):
        rel_str = Rel.get_strs(rel).format(value)
        error_cls = TypeError if is_bool else ValueError
        raise error_cls(f'{arg_name} {prim_name} should be an {arg_type.__name__} and must {rel_str}, '
                        f'but got `{arg_value}` with type `{type(arg_value).__name__}`.')

    return arg_value
172
173
def check_is_number(arg_value, arg_type, arg_name=None, prim_name=None):
    """
    Check that `arg_value` is an `arg_type` instance (bool rejected) and is neither inf nor nan.

    Usage:
    - number = check_is_number(number, int)
    - number = check_is_number(number, int, "bias")
    - number = check_is_number(number, int, "bias", "bias_class")

    Args:
        arg_value: the value to check.
        arg_type (type): the required type.
        arg_name (str): name used in error messages. Default: None.
        prim_name (str): primitive name used in error messages. Default: None.

    Returns:
        The unchanged `arg_value`.

    Raises:
        TypeError: if `arg_value` is not an `arg_type` instance, or is a bool.
        ValueError: if `arg_value` is inf or nan.
    """
    prim_name = f"For \'{prim_name}\', the" if prim_name else 'The'
    arg_name = f"\'{arg_name}\'" if arg_name else 'input value'
    if not isinstance(arg_value, arg_type) or isinstance(arg_value, bool):
        raise TypeError(f'{prim_name} type of {arg_name} must be {arg_type.__name__}, but got `{type(arg_value).__name__}`')
    # Python ints are always finite; probing them with math.isinf would raise OverflowError
    # for values too large to convert to float.
    if not isinstance(arg_value, int) and (math.isinf(arg_value) or math.isnan(arg_value)):
        raise ValueError(f'{prim_name} {arg_name} must be legal float, but got `{arg_value}`.')
    return arg_value
190
191
def check_number_range(arg_value, lower_limit, upper_limit, rel, value_type, arg_name=None, prim_name=None):
    """
    Method for checking whether a number lies in the range given by `lower_limit`, `upper_limit` and `rel`.

    Usage:
    - number = check_number_range(number, 0.0, 1.0, Rel.INC_NEITHER, float, "number") # number in (0.0, 1.0)
    - number = check_number_range(number, 0, 1, Rel.INC_BOTH, int, "number") # number in [0, 1]

    Args:
        arg_value: the value to check; numpy arrays/scalars are also accepted by the type check.
        lower_limit: lower bound of the range.
        upper_limit: upper bound of the range.
        rel (Rel): one of the INC_* relations, selecting which bounds are inclusive.
        value_type (type): the expected Python type.
        arg_name (str): name used in error messages. Default: None.
        prim_name (str): primitive name used in error messages. Default: None.

    Returns:
        The unchanged `arg_value`.

    Raises:
        TypeError: if `arg_value` is a bool, or is neither a numpy array/scalar nor a `value_type`.
        ValueError: if `arg_value` is outside the range.
    """
    rel_fn = Rel.get_fns(rel)
    prim_name = f'in `{prim_name}`' if prim_name else ''
    arg_name = f'`{arg_name}`' if arg_name else ''
    type_mismatch = not isinstance(arg_value, (np.ndarray, np.generic, value_type)) or isinstance(arg_value, bool)
    if type_mismatch:
        raise TypeError("{} {} must be `{}`,  but got `{}`.".format(
            arg_name, prim_name, value_type.__name__, type(arg_value).__name__))
    if not rel_fn(arg_value, lower_limit, upper_limit):
        rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
        try:
            shown = format(arg_value, '.3e')
        except (TypeError, ValueError, OverflowError):
            # e.g. ints too large for float conversion: fall back to the plain representation
            # so the intended ValueError is still raised.
            shown = str(arg_value)
        raise ValueError("{} {} should be in range of {}, but got {} with type `{}`.".format(
            arg_name, prim_name, rel_str, shown, type(arg_value).__name__))
    return arg_value
212
213
214class Validator:
215    """validator for checking input parameters"""
216
217    @staticmethod
218    def check(arg_name, arg_value, value_name, value, rel=Rel.EQ, prim_name=None, excp_cls=ValueError):
219        """
220        Method for judging relation between two int values or list/tuple made up of ints.
221        This method is not suitable for judging relation between floats, since it does not consider float error.
222        """
223        rel_fn = Rel.get_fns(rel)
224        if not rel_fn(arg_value, value):
225            rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}')
226            msg_prefix = f'For \'{prim_name}\', the' if prim_name else "The"
227            raise excp_cls(f'{msg_prefix} \'{arg_name}\' should be {rel_str}, but got {arg_value}.')
228        return arg_value
229
230    @staticmethod
231    def check_int(arg_value, value, rel, arg_name=None, prim_name=None):
232        """
233        Checks input integer value `arg_value` compare to `value`.
234
235        Usage:
236        - number = check_int(number, 0, Rel.GE, "number", None) # number >= 0
237        """
238        return check_number(arg_value, value, rel, int, arg_name, prim_name)
239
240    @staticmethod
241    def check_is_int(arg_value, arg_name=None, prim_name=None):
242        """
243        Checks input value is float type or not.
244
245        Usage:
246        - number = check_is_int(number, int)
247        - number = check_is_int(number, int, "bias")
248        - number = check_is_int(number, int, "bias", "bias_class")
249        """
250        return check_is_number(arg_value, int, arg_name, prim_name)
251
252    @staticmethod
253    def check_equal_int(arg_value, value, arg_name=None, prim_name=None):
254        """
255        Checks input integer value `arg_value` compare to `value`.
256
257        Usage:
258        - number = check_int(number, 0, Rel.GE, "number", None) # number >= 0
259        """
260        return check_number(arg_value, value, Rel.EQ, int, arg_name, prim_name)
261
262    @staticmethod
263    def check_positive_int(arg_value, arg_name=None, prim_name=None):
264        """
265        Check argument is positive integer, which mean arg_value > 0.
266
267        Usage:
268        - number = check_positive_int(number)
269        - number = check_positive_int(number, "bias")
270        """
271        return check_number(arg_value, 0, Rel.GT, int, arg_name, prim_name)
272
273    @staticmethod
274    def check_negative_int(arg_value, arg_name=None, prim_name=None):
275        """
276        Check argument is negative integer, which mean arg_value < 0.
277
278        Usage:
279        - number = check_negative_int(number)
280        - number = check_negative_int(number, "bias")
281        """
282        return check_number(arg_value, 0, Rel.LT, int, arg_name, prim_name)
283
284    @staticmethod
285    def check_non_positive_int(arg_value, arg_name=None, prim_name=None):
286        """
287        Check argument is non-negative integer, which mean arg_value <= 0.
288
289        Usage:
290        - number = check_non_positive_int(number)
291        - number = check_non_positive_int(number, "bias")
292        """
293        return check_number(arg_value, 0, Rel.LE, int, arg_name, prim_name)
294
295    @staticmethod
296    def check_non_negative_int(arg_value, arg_name=None, prim_name=None):
297        """
298        Check argument is non-negative integer, which mean arg_value >= 0.
299
300        Usage:
301        - number = check_non_negative_int(number)
302        - number = check_non_negative_int(number, "bias")
303        """
304        return check_number(arg_value, 0, Rel.GE, int, arg_name, prim_name)
305
306    @staticmethod
307    def check_float(arg_value, value, rel, arg_name=None, prim_name=None):
308        """
309        Checks input float value `arg_value` compare to `value`.
310
311        Usage:
312        - number = check_float(number, 0.0, Rel.GE, "number", None) # number >= 0
313        """
314        return check_number(arg_value, value, rel, float, arg_name, prim_name)
315
316    @staticmethod
317    def check_is_float(arg_value, arg_name=None, prim_name=None):
318        """
319        Checks input value is float type or not.
320
321        Usage:
322        - number = check_is_float(number, int)
323        - number = check_is_float(number, int, "bias")
324        - number = check_is_float(number, int, "bias", "bias_class")
325        """
326        return check_is_number(arg_value, float, arg_name, prim_name)
327
328    @staticmethod
329    def check_positive_float(arg_value, arg_name=None, prim_name=None):
330        """
331        Check argument is positive float, which mean arg_value > 0.
332
333        Usage:
334        - number = check_positive_float(number)
335        - number = check_positive_float(number, "bias")
336        - number = check_positive_float(number, "bias", "bias_class")
337        """
338        return check_number(arg_value, 0, Rel.GT, float, arg_name, prim_name)
339
340    @staticmethod
341    def check_negative_float(arg_value, arg_name=None, prim_name=None):
342        """
343        Check argument is negative float, which mean arg_value < 0.
344
345        Usage:
346        - number = check_negative_float(number)
347        - number = check_negative_float(number, "bias")
348        """
349        return check_number(arg_value, 0, Rel.LT, float, arg_name, prim_name)
350
351    @staticmethod
352    def check_non_positive_float(arg_value, arg_name=None, prim_name=None):
353        """
354        Check argument is non-negative float, which mean arg_value <= 0.
355
356        Usage:
357        - number = check_non_positive_float(number)
358        - number = check_non_positive_float(number, "bias")
359        """
360        return check_number(arg_value, 0, Rel.LE, float, arg_name, prim_name)
361
362    @staticmethod
363    def check_non_negative_float(arg_value, arg_name=None, prim_name=None):
364        """
365        Check argument is non-negative float, which mean arg_value >= 0.
366
367        Usage:
368        - number = check_non_negative_float(number)
369        - number = check_non_negative_float(number, "bias")
370        """
371        return check_number(arg_value, 0, Rel.GE, float, arg_name, prim_name)
372
373    @staticmethod
374    def check_number(arg_name, arg_value, value, rel, prim_name):
375        """Number value judgment."""
376        rel_fn = Rel.get_fns(rel)
377        if not rel_fn(arg_value, value):
378            rel_str = Rel.get_strs(rel).format(value)
379            raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must {rel_str}, but got {arg_value}.')
380        return arg_value
381
382    @staticmethod
383    def check_isinstance(arg_name, arg_value, classes):
384        """Check arg isinstance of classes"""
385        if not isinstance(arg_value, classes):
386            raise ValueError(f'The `{arg_name}` should be isinstance of {classes}, but got {arg_value}.')
387        return arg_value
388
389    @staticmethod
390    def check_bool(arg_value, arg_name=None, prim_name=None):
391        """
392        Check argument is instance of bool.
393
394        Usage:
395        - has_bias = check_bool(has_bias)
396        - has_bias = check_bool(has_bias, "has_bias")
397        """
398        if not isinstance(arg_value, bool):
399            if prim_name and arg_name:
400                msg_prefix = f"For '{prim_name}', the '{arg_name}'"
401            elif prim_name and arg_name is None:
402                msg_prefix = f"For '{prim_name}', Parameter"
403            else:
404                msg_prefix = "Parameter"
405            raise TypeError(f"{msg_prefix} should be a bool, but got {type(arg_value).__name__}.")
406        return arg_value
407
408    @staticmethod
409    def check_int_range(arg_value, lower_limit, upper_limit, rel, arg_name=None, prim_name=None):
410        """
411        Method for checking whether input value is in int range.
412
413        Usage:
414        - number = check_int_range(number, 0, 1, Rel.INC_NEITHER) # number in [0, 1]
415        - number = check_int_range(number, 0, 1, Rel.INC_NEITHER, "number") # number in [0, 1]
416        """
417        return check_number_range(arg_value, lower_limit, upper_limit, rel, int, arg_name, prim_name)
418
419    @staticmethod
420    def check_float_range(arg_value, lower_limit, upper_limit, rel, arg_name=None, prim_name=None):
421        """
422        Method for checking whether input value is in float range.
423
424        Usage:
425        - number = check_float_range(number, 0.0, 1.0, Rel.INC_NEITHER) # number in [0.0, 1.0]
426        - number = check_float_range(number, 0.0, 1.0, Rel.INC_NEITHER, "number") # number in [0.0, 1.0]
427        """
428        return check_number_range(arg_value, lower_limit, upper_limit, rel, float, arg_name, prim_name)
429
430    @staticmethod
431    def check_string(arg_value, valid_values, arg_name=None, prim_name=None):
432        """
433        Check whether string is in some value list.
434
435        Usage:
436        - method = check_string(method, ["string1", "string2", "string3"], "method")
437        """
438        if isinstance(arg_value, str) and arg_value in valid_values:
439            return arg_value
440        arg_name = arg_name if arg_name else "Parameter"
441        msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
442        raise ValueError(f"{msg_prefix} '{arg_name}' should be str and must be in '{valid_values}',"
443                         f" but got '{arg_value}'.")
444
445    @staticmethod
446    def check_str_by_regular(target, reg=None, flag=re.ASCII, prim_name=None):
447        if reg is None:
448            # Named string regular expression
449            reg = r"^\w+[0-9a-zA-Z\_\.]*$"
450        if re.match(reg, target, flag) is None:
451            prim_name = f'in `{prim_name}`' if prim_name else ""
452            raise ValueError("'{}' {} is illegal, it should be match regular'{}' by flags'{}.'".format(
453                target, prim_name, reg, flag))
454        return True
455
456    @staticmethod
457    def check_file_name_by_regular(target, reg=None, prim_name=None):
458        """Check whether file name is legitimate."""
459        if not isinstance(target, str):
460            raise ValueError("Args file_name {} must be string, please check it".format(target))
461        if target.endswith("\\") or target.endswith("/"):
462            raise ValueError("File name cannot be a directory path.")
463        if reg is None:
464            reg = r"^[0-9a-zA-Z\_\-\.\:\/\\]+$"
465        if re.match(reg, target) is None:
466            prim_name = f'in `{prim_name}`' if prim_name else ""
467            raise ValueError("'{}' {} is illegal, it should be match regular'{}'.".format(
468                target, prim_name, reg))
469
470        return True
471
472    @staticmethod
473    def check_pad_value_by_mode(pad_mode, padding, prim_name):
474        """Validates value of padding according to pad_mode"""
475        if pad_mode != 'pad' and padding != 0:
476            raise ValueError(f"For '{prim_name}', padding must be zero when pad_mode is '{pad_mode}'.")
477        return padding
478
479    @staticmethod
480    def check_subclass(arg_name, type_, template_types, prim_name, addition_error_info=None):
481        """Checks whether some type is subclass of another type"""
482        if not isinstance(template_types, Iterable):
483            template_types = (template_types,)
484        hit = False
485        for template_type in template_types:
486            if isinstance(template_type, mstype.Type):
487                if mstype.issubclass_(type_, template_type):
488                    hit = True
489                    break
490            elif type_ is template_type:
491                hit = True
492                break
493        if not hit:
494            if addition_error_info is None:
495                addition_error_info = ''
496            type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_)
497            raise TypeError(f"For '{prim_name}', the type of '{arg_name}'"
498                            f" should be {'one of ' if len(template_types) > 1 else ''}"
499                            f"{', '.join((str(x) for x in template_types))}, but got {type_str}"
500                            f" {addition_error_info}. The supported data types depend on the hardware that"
501                            f" executes the operator, please refer the official api document to get"
502                            f" more information about the data type.")
503
504    @staticmethod
505    def check_valid_input(arg_name, arg_value, prim_name):
506        """Checks valid value."""
507        if arg_value is None:
508            raise ValueError(f"For \'{prim_name}\', the '{arg_name}' can not be None, but got {arg_value}.")
509        return arg_value
510
511    @staticmethod
512    def check_types_same_and_valid(args, valid_values, prim_name):
513        """Checks whether the types of inputs are the same and valid."""
514
515        def _check_type_valid(arg):
516            arg_key, arg_val = arg
517            elem_type = arg_val
518            Validator.check_subclass(arg_key, elem_type, valid_values, prim_name)
519            return (arg_key, elem_type)
520
521        def _check_types_same(arg1, arg2):
522            arg1_name, arg1_type = arg1
523            arg2_name, arg2_type = arg2
524            if arg1_type != arg2_type:
525                raise TypeError(f"For '{prim_name}', type of '{arg2_name}' should be same as '{arg1_name}',"
526                                f" but got '{arg1_name}' with type {arg1_type}"
527                                f" and '{arg2_name}' with type {arg2_type}.")
528            return arg1
529
530        elem_types = map(_check_type_valid, args.items())
531        reduce(_check_types_same, elem_types)
532
533    @staticmethod
534    def check_tensors_dtypes_same_and_valid(args, valid_dtypes, prim_name):
535        """Checks whether the element types of input tensors are the same and valid."""
536        valid_dtypes = valid_dtypes if isinstance(valid_dtypes, Iterable) else [valid_dtypes]
537        tensor_types = [mstype.tensor_type(t) for t in valid_dtypes]
538        Validator.check_types_same_and_valid(args, tensor_types, prim_name)
539
540    @staticmethod
541    def check_tensor_dtype_valid(arg_name, arg_type, valid_dtypes, prim_name):
542        """Checks whether the element types of input tensors are valid."""
543        valid_dtypes = valid_dtypes if isinstance(valid_dtypes, Iterable) else [valid_dtypes]
544        tensor_types = [mstype.tensor_type(t) for t in valid_dtypes]
545        Validator.check_subclass(arg_name, arg_type, tensor_types, prim_name)
546
547    @staticmethod
548    def check_scalar_or_tensor_types_same(args, valid_values, prim_name, allow_mix=False):
549        """
550        Checks whether the types of inputs are the same. If the input args are tensors, checks their element types.
551        If `allow_mix` is True, Tensor(float32) and float32 are type compatible, otherwise an exception will be raised.
552        """
553
554        def _check_argument_type(arg):
555            arg_key, arg_val = arg
556            if isinstance(arg_val, type(mstype.tensor)):
557                arg_val = arg_val.element_type()
558            if not arg_val in valid_values:
559                raise TypeError(f'For \'{prim_name}\', the type of `{arg_key}` should be in {valid_values},'
560                                f' but got {arg_val}.')
561            return arg
562
563        def _check_types_same(arg1, arg2):
564            arg1_name, arg1_type = arg1
565            arg2_name, arg2_type = arg2
566            except_flag = False
567            if isinstance(arg1_type, type(mstype.tensor)) and isinstance(arg2_type, type(mstype.tensor)):
568                arg1_type = arg1_type.element_type()
569                arg2_type = arg2_type.element_type()
570            elif not (isinstance(arg1_type, type(mstype.tensor)) or isinstance(arg2_type, type(mstype.tensor))):
571                pass
572            elif allow_mix:
573                arg1_type = arg1_type.element_type() if isinstance(arg1_type, type(mstype.tensor)) else arg1_type
574                arg2_type = arg2_type.element_type() if isinstance(arg2_type, type(mstype.tensor)) else arg2_type
575            else:
576                except_flag = True
577
578            if except_flag or arg1_type != arg2_type:
579                raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,'
580                                f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.')
581            return arg1
582
583        reduce(_check_types_same, map(_check_argument_type, args.items()))
584
585    @staticmethod
586    def check_value_type(arg_name, arg_value, valid_types, prim_name=None):
587        """Checks whether a value is instance of some types."""
588        valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,)
589
590        def raise_error_msg():
591            """func for raising error message when check failed"""
592            type_names = [t.__name__ if hasattr(t, '__name__') else str(t) for t in valid_types]
593            num_types = len(valid_types)
594            msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
595            raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
596                            f'{type_names if num_types > 1 else type_names[0]}, '
597                            f'but got {arg_value} with type {type(arg_value).__name__}.')
598
599        # Notice: bool is subclass of int, so `check_value_type('x', True, [int])` will check fail, and
600        #         `check_value_type('x', True, [bool, int])` will check pass
601        if isinstance(arg_value, bool) and bool not in tuple(valid_types):
602            raise_error_msg()
603        if not isinstance(arg_value, tuple(valid_types)):
604            raise_error_msg()
605        return arg_value
606
607    @staticmethod
608    def check_type_name(arg_name, arg_type, valid_types, prim_name):
609        """Checks whether a type in some specified types"""
610        valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,)
611
612        def raise_error_msg():
613            """func for raising error message when check failed"""
614            type_names = [t.__name__ if hasattr(t, '__name__') else t for t in valid_types]
615            num_types = len(valid_types)
616            msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
617            raise TypeError(f"{msg_prefix} '{arg_name}' should be {'one of ' if num_types > 1 else ''}"
618                            f"{type_names if num_types > 1 else type_names[0]}, "
619                            f"but got {arg_type.__name__ if hasattr(arg_type, '__name__') else repr(arg_type)}.")
620
621        if isinstance(arg_type, type(mstype.tensor)):
622            arg_type = arg_type.element_type()
623        if arg_type not in valid_types:
624            raise_error_msg()
625        return arg_type
626
627    @staticmethod
628    def check_reduce_shape(ori_shape, shape, axis, prim_name):
629        """Checks whether shape is ori_shape reduced on axis"""
630        axis = axis if isinstance(axis, Iterable) else (axis,)
631        exp_shape = [ori_shape[i] for i in range(len(ori_shape)) if i not in axis]
632        if list(shape) != exp_shape:
633            raise ValueError(f"For '{prim_name}', the origin shape {ori_shape} reduce on {axis} should be "
634                             f"{tuple(exp_shape)}, but got {shape}.")
635
636    @staticmethod
637    def check_astype_dtype(dtype):
638        """Check whether dtype is a valid input, and convert to mstype"""
639        all_types = mstype.__dtype__ + ["int", "float", "bool"]
640        if isinstance(dtype, str):
641            if dtype.lower() not in all_types:
642                raise TypeError(f"`{dtype}` not understood.")
643            dtype = mstype.pytype_to_dtype(np.dtype(dtype.lower()))
644        elif isinstance(dtype, type):
645            dtype = mstype.pytype_to_dtype(dtype)
646        elif not dtype in mstype.number_type + (mstype.bool_,):
647            raise TypeError(f"`{dtype}` not understood.")
648        return dtype
649
650    @staticmethod
651    def check_transpose_axis(axes, ndim):
652        """Check the axis argument for tensor.transpose"""
653        if not axes or (len(axes) == 1 and axes[0] is None):
654            return tuple(range(ndim-1, -1, -1))
655
656        if len(axes) == 1:
657            perm = axes[0]
658            # if only one argument provided, it must be tuple or list
659            if isinstance(perm, list):
660                perm = tuple(perm)
661            else:
662                if not isinstance(perm, tuple):
663                    raise TypeError(f"The `axes` should be a tuple/list, or series of int, but got {type(axes[0])}")
664            return perm
665
666        # if multiple arguments provided, it must be `ndim` number of ints
667        if len(axes) != ndim:
668            raise ValueError("The number of axes must equal to the dimension of tensor.")
669        return axes
670
671    @staticmethod
672    def check_reshape_shp(shp):
673        """Check the shape argument for tensor.reshape"""
674
675        if len(shp) == 1:
676            new_shape = shp[0]
677            # if only one argument provided, it must be int, tuple or list
678            if isinstance(new_shape, int):
679                return shp
680            if isinstance(new_shape, list):
681                new_shape = tuple(new_shape)
682            else:
683                if not isinstance(new_shape, tuple):
684                    raise TypeError(
685                        f"The `shape` should be an int, or tuple/list, or series of int, but got {type(shp[0])}")
686            return new_shape
687
688        return shp
689
690    @staticmethod
691    def check_flatten_order(order):
692        """Check flatten function input order"""
693        if not isinstance(order, str):
694            raise TypeError(f"The order variable should be a string, but got {type(order)}")
695        if order not in ('C', 'F'):
696            raise ValueError(f"only `C` and `F` are supported as order, but got {order}")
697        return order
698
699    @staticmethod
700    def check_swapaxes_axis(axes, ndim):
701        """Check all the axes argument for tensor.swapaxes"""
702        if isinstance(axes, int):
703            Validator.check_axis_in_range(axes, ndim)
704            return axes % ndim
705        if isinstance(axes, (tuple, list)):
706            for axis in axes:
707                if not isinstance(axis, int):
708                    raise TypeError(f"axis argument should be integer, but got {type(axis)}.")
709                Validator.check_axis_in_range(axis, ndim)
710            axes = tuple(map(lambda x: x % ndim, axes))
711            return axes
712        raise TypeError(f"axes should be integer, list or tuple for check, but got {type(axes)}.")
713
714    @staticmethod
715    def prepare_shape_for_squeeze(shape, axes):
716        """
717        Creates the squeezed new shape based on the tensor and given axes.
718
719        Args:
720            shape (tuple): the shape of the tensor
721            axes Union[int, tuple(int), list(int)]: the axes with dimensions need to
722                be squeezed.
723
724        Returns:
725            new_shape(tuple): the shape with dimensions squeezed.
726        """
727        new_shape = []
728        ndim = len(shape)
729
730        # Convert to set
731        if isinstance(axes, int):
732            if axes >= ndim or axes < -ndim:
733                raise ValueError(f"axis {axes} is out of bounds for tensor of dimension {ndim}")
734            axes = {axes}
735
736        elif isinstance(axes, (list, tuple)):
737            for axis in axes:
738                if axis >= ndim or axis < -ndim:
739                    raise ValueError(f"axis {axis} is out of bounds for tensor of dimension {ndim}")
740            axes = set(axes)
741
742        else:
743            raise TypeError(f"only int, tuple and list are allowed for axes, but got {type(axes)}")
744
745        for idx, s in enumerate(shape):
746            if s != 1 or (idx not in axes) and (idx - ndim not in axes):
747                new_shape.append(s)
748            # if an axis is selected with shape entry greater than one, an error is raised.
749            if s != 1 and ((idx in axes) or (idx - ndim in axes)):
750                raise ValueError(f"axis {axes} has shape entry {s} > 1, cannot be squeezed.")
751        return tuple(new_shape)
752
753    @staticmethod
754    def check_axis_in_range(axis, ndim):
755        """Checks axes are with the bounds of ndim"""
756        if not isinstance(axis, int):
757            raise TypeError(f'axes should be integers, not {type(axis)}')
758        if not -ndim <= axis < ndim:
759            raise ValueError(f'axis {axis} is out of bounds for array of dimension {ndim}')
760        return axis % ndim
761
762    @staticmethod
763    def check_axis_valid(axes, ndim):
764        """
765        Checks axes are valid given ndim, and returns axes that can be passed
766        to the built-in operator (non-negative, int or tuple)
767        """
768        if axes is None:
769            axes = tuple(range(ndim))
770            return axes
771        if isinstance(axes, (tuple, list)):
772            for axis in axes:
773                Validator.check_axis_in_range(axis, ndim)
774            axes = tuple(map(lambda x: x % ndim, axes))
775            if any(axes.count(el) > 1 for el in axes):
776                raise ValueError('duplicate value in "axis"')
777            return axes
778        Validator.check_axis_in_range(axes, ndim)
779        return (axes % ndim,)
780
781    @staticmethod
782    def max_(*args):
783        return max(*args)
784
785    @staticmethod
786    def min_(*args):
787        return min(*args)
788
789    @staticmethod
790    def expanded_shape(ndim, axis_size, axis):
791        """
792        Returns a shape with size = 1 for all dimensions
793        except at axis.
794        """
795        return tuple(axis_size if i == axis else 1 for i in range(ndim))
796
797    @staticmethod
798    def tuple_slice(tup, start, end):
799        """get sliced tuple from start and end."""
800        return tup[start:end]
801
802    @staticmethod
803    def infer_out_shape(*shapes):
804        """
805        Returns shape of output after broadcasting. Raises ValueError if shapes cannot be broadcast.
806        """
807        shape_out = deque()
808        reversed_shapes = map(reversed, shapes)
809        for items in zip_longest(*reversed_shapes, fillvalue=1):
810            max_size = 0 if 0 in items else max(items)
811            if any(item not in (1, max_size) for item in items):
812                raise ValueError(f'operands could not be broadcast together with shapes {*shapes,}')
813            shape_out.appendleft(max_size)
814        return tuple(shape_out)
815
816    @staticmethod
817    def get_log2_size(size):
818        return math.ceil(math.log2(size))
819
820    @staticmethod
821    def check_axis_type(axis, type_int=True, type_tuple=True, type_list=True):
822        """Check axis argument type."""
823        if type_int and isinstance(axis, int):
824            return True
825        if (type_tuple and isinstance(axis, tuple)) or (type_list and isinstance(axis, list)):
826            for ax in axis:
827                if not isinstance(ax, int):
828                    raise TypeError(f"Each axis should be integer, but got {type(ax)} in {axis}.")
829            return True
830
831        type_str = ""
832        if type_int:
833            type_str += "int, "
834        if type_tuple:
835            type_str += "tuple, "
836        if type_list:
837            type_str += "list, "
838        raise TypeError(f"Axis should be {type_str}but got {type(axis)}.")
839
840    @staticmethod
841    def check_and_canonicalize_axes(axes, ndim):
842        """Check whether the types and values of input axes are valid."""
843        axes = axes if isinstance(axes, tuple) else (axes,)
844        new_axes = ()
845        for ax in axes:
846            if not isinstance(ax, int):
847                raise TypeError((f"Each axis should be integer, but got {type(ax)} in {axes}."))
848            if not -ndim <= ax < ndim:
849                raise ValueError(f'axis {ax} is out of bounds for array of dimension {ndim}')
850            ax = ax if ax >= 0 else ax + ndim
851            new_axes += (ax,)
852        if any(new_axes.count(el) > 1 for el in new_axes):
853            raise ValueError('duplicate value in "axis"')
854        return new_axes
855
856    @staticmethod
857    def empty_compile(dtype, shape):
858        """Returns an empty Tensor."""
859        return Tensor_(dtype, shape)
860
861    @staticmethod
862    def check_type_support(dtype, device, supported_dtypes):
863        return dtype in supported_dtypes or not context.get_context('device_target') == device
864
865
def check_input_format(input_param):
    """
    Judge input format: return `input_param` if it is "NCHW".

    Raises:
        ValueError: for any other value.
    """
    if input_param != "NCHW":
        raise ValueError("The data format must be NCHW.")
    return input_param
871
872
873def _expand_tuple(n_dimensions):
874    """To expand a int number to tuple."""
875
876    def convert(m):
877        if not isinstance(m, tuple):
878            if isinstance(m, int) and not isinstance(m, bool):
879                return tuple(repeat(m, n_dimensions))
880            raise TypeError("Input type must be int or tuple[int].")
881
882        if not len(m) is n_dimensions:
883            raise TypeError("Input tuple dimension is incorrect.")
884
885        for i in m:
886            if not isinstance(i, int) or isinstance(i, bool):
887                raise TypeError("Incorrect type inside of a tuple, must be int!")
888        return m
889
890    return convert
891
892
893def _check_data_type_valid(data, valid_type):
894    """Check data type valid."""
895    if valid_type is None:
896        return data is None
897    if isinstance(data, valid_type):
898        if hasattr(data, 'size') and data.size == 0:
899            msg = "Please provide non-empty data."
900            logger.error(msg)
901            raise ValueError(msg)
902        return True
903    return False
904
905
def check_input_data(*data, data_class):
    """
    Input data check.

    Walks `data` recursively: lists and tuples are traversed element by element,
    dicts by their values. Every other item must satisfy `data_class`, which is
    either one entry or a tuple/list of entries. Each entry is a type, or None to
    accept None.

    Raises:
        ValueError: if a leaf item matches no entry of `data_class`. Also raised
            (by `_check_data_type_valid`) for matching items whose `size` is 0.
    """
    for item in data:
        if isinstance(item, (list, tuple)):
            check_input_data(*item, data_class=data_class)
            continue
        if isinstance(item, dict):
            check_input_data(*item.values(), data_class=data_class)
            continue

        several = isinstance(data_class, (tuple, list))
        if several:
            # A list (not a generator) keeps every candidate evaluated, as before.
            matched = any([_check_data_type_valid(item, candidate) for candidate in data_class])
        else:
            matched = _check_data_type_valid(item, data_class)
        if matched:
            continue

        if several:
            data_class_str = tuple(i.__name__ if hasattr(i, '__name__') else i for i in data_class)
        elif data_class is None:
            data_class_str = data_class
        else:
            data_class_str = data_class.__name__
        raise ValueError(f'Please provide as model inputs either a single or '
                         f'a tuple or a list or a dict of {data_class_str}, '
                         f'but got part data type is {item if item is None else type(item).__name__}.')
927
928
def check_output_data(data):
    """
    Output data check.

    Raises:
        RuntimeError: if `data` is None.
    """
    if data is not None:
        return
    raise RuntimeError(f'Executor return data {data}, please check your net or input data.')
933
934
# Converters that expand an int (or validate a tuple) to 1, 2 and 3 dimensions.
once, twice, triple = (_expand_tuple(n) for n in (1, 2, 3))
938
939
def args_type_check(*type_args, **type_kwargs):
    """
    Check whether input data type is correct.

    Decorator factory. Expected types are given the same way the decorated
    function's arguments would be passed (positionally or by keyword). Keys passed
    through the function's var-keyword parameter (e.g. ``**kwargs``) are checked
    individually alongside its named parameters. ``None`` values are never checked.

    Raises:
        TypeError: at call time, if a non-None argument is not an instance of
            its declared type.
    """

    def type_check(func):
        sig = inspect.signature(func)
        # Name of the function's **kwargs-style parameter, whatever it is called.
        var_keyword = next((param.name for param in sig.parameters.values()
                            if param.kind is inspect.Parameter.VAR_KEYWORD), None)

        def flatten(arguments):
            """Merge the var-keyword dict into the top-level name -> value mapping."""
            flat = {name: value for name, value in arguments.items() if name != var_keyword}
            flat.update(arguments.get(var_keyword, {}))
            return flat

        # Built once at decoration time instead of rebinding a closure variable
        # on the first call (which also dropped types declared for named params).
        bound_types = flatten(sig.bind_partial(*type_args, **type_kwargs).arguments)

        @wraps(func)
        def wrapper(*args, **kwargs):
            argument_dict = flatten(sig.bind(*args, **kwargs).arguments)
            for name, value in argument_dict.items():
                if name in bound_types:
                    if value is not None and not isinstance(value, bound_types[name]):
                        raise TypeError('Argument {} must be {}'.format(name, bound_types[name]))
            return func(*args, **kwargs)

        return wrapper

    return type_check
965
966
# Module-wide record of values already set for parameters declared non-renewable
# by `args_unreset_check`; maps parameter name -> first value passed.
_set_record = {}


def args_unreset_check(*unreset_args, **unreset_kwargs):
    """
    Check the entered non repeatable setting properties.

    Decorator factory. Parameters named in the decorator (positionally or by
    keyword, including keys of the function's var-keyword parameter) may be
    passed at most once across all calls. A name is recorded in `_set_record`
    when it is first passed.

    Raises:
        TypeError: at call time, if a declared parameter was already recorded.
    """

    def unreset_check(func):
        sig = inspect.signature(func)
        # Name of the function's **kwargs-style parameter, whatever it is called.
        var_keyword = next((param.name for param in sig.parameters.values()
                            if param.kind is inspect.Parameter.VAR_KEYWORD), None)

        def flatten(arguments):
            """Merge the var-keyword dict into the top-level name -> value mapping."""
            flat = {name: value for name, value in arguments.items() if name != var_keyword}
            flat.update(arguments.get(var_keyword, {}))
            return flat

        # Built once at decoration time instead of rebinding a closure variable on first call.
        bound_unreset = flatten(sig.bind_partial(*unreset_args, **unreset_kwargs).arguments)

        @wraps(func)
        def wrapper(*args, **kwargs):
            argument_dict = flatten(sig.bind(*args, **kwargs).arguments)
            # Validate every argument before recording any, so a rejected call
            # leaves no partial record. Only names this decorator declared are
            # guarded; the original raised KeyError for names recorded elsewhere.
            for name in argument_dict:
                if name in bound_unreset and name in _set_record:
                    raise TypeError('Argument {} is non-renewable parameter {}.'.format(name, bound_unreset[name]))
            for name, value in argument_dict.items():
                if name in bound_unreset:
                    _set_record[name] = value
            return func(*args, **kwargs)

        return wrapper

    return unreset_check
996