• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2022 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Check parameters.
17"""
18
19
def check_isinstance(arg_name, arg_value, classes, enable_none=False):
    """Check that ``arg_value`` is an instance of ``classes``.

    Args:
        arg_name: Name of the argument; used only to build the error message.
        arg_value: The value to validate.
        classes: A type, or a tuple of types, as accepted by ``isinstance``.
        enable_none: When truthy, ``None`` is accepted and returned unchanged.

    Returns:
        ``arg_value`` itself when the check passes.

    Raises:
        TypeError: If ``arg_value`` is not an instance of ``classes``.
    """
    if enable_none and arg_value is None:
        return arg_value
    if not isinstance(arg_value, classes):
        # A tuple of types has no ``__name__``; build the expected-type text
        # explicitly so the caller gets the intended TypeError rather than an
        # AttributeError raised while formatting the message.
        if isinstance(classes, tuple):
            expected = " or ".join(getattr(cls, "__name__", str(cls)) for cls in classes)
        else:
            expected = getattr(classes, "__name__", str(classes))
        raise TypeError(f"{arg_name} must be {expected}, but got {type(arg_value)}.")
    return arg_value
28
29
def check_list_of_element(arg_name, arg_value, ele_classes, enable_none=False):
    """Check that ``arg_value`` is a list whose elements are all ``ele_classes``.

    Args:
        arg_name: Name of the argument; used only to build the error message.
        arg_value: The value to validate.
        ele_classes: A type, or a tuple of types, every element must be an
            instance of (as accepted by ``isinstance``).
        enable_none: When truthy, ``None`` is accepted and returned unchanged.

    Returns:
        ``arg_value`` itself when the check passes.

    Raises:
        TypeError: If ``arg_value`` is not a list, or if any element is not an
            instance of ``ele_classes``.
    """
    if enable_none and arg_value is None:
        return arg_value
    if not isinstance(arg_value, list):
        raise TypeError(f"{arg_name} must be list, but got {type(arg_value)}.")
    for i, element in enumerate(arg_value):
        if not isinstance(element, ele_classes):
            # A tuple of types has no ``__name__``; format it explicitly so the
            # failure is reported as TypeError instead of AttributeError.
            if isinstance(ele_classes, tuple):
                expected = " or ".join(getattr(cls, "__name__", str(cls)) for cls in ele_classes)
            else:
                expected = getattr(ele_classes, "__name__", str(ele_classes))
            raise TypeError(f"{arg_name} element must be {expected}, but got "
                            f"{type(element)} at index {i}.")
    return arg_value
42
43
def check_uint32_number_range(arg_name, arg_value):
    """Validate that ``arg_value`` is an int within [0, 2**32 - 1].

    The type is checked first via ``check_isinstance(arg_name, arg_value, int)``;
    the range is checked afterwards.

    Raises:
        ValueError: If ``arg_value`` is negative or greater than 2**32 - 1.
    """
    check_isinstance(arg_name, arg_value, int)
    uint32_max = (1 << 32) - 1
    if not 0 <= arg_value <= uint32_max:
        raise ValueError(f"{arg_name} value should be in range [0, UINT32_MAX], but got {arg_value}")
49
50
def check_uint64_number_range(arg_name, arg_value):
    """Validate that ``arg_value`` is an int within [0, 2**64 - 1].

    The type is checked first via ``check_isinstance(arg_name, arg_value, int)``;
    the range is checked afterwards.

    Raises:
        ValueError: If ``arg_value`` is negative or greater than 2**64 - 1.
    """
    check_isinstance(arg_name, arg_value, int)
    uint64_max = (1 << 64) - 1
    if not 0 <= arg_value <= uint64_max:
        raise ValueError(f"{arg_name} value should be in range [0, UINT64_MAX], but got {arg_value}")
56
57
def check_input_shape(input_shape_name, input_shape, enable_none=False):
    """Check that ``input_shape`` has the form ``dict{str: list[int]}``.

    Args:
        input_shape_name: Name of the argument; used only in error messages.
        input_shape: The value to validate.
        enable_none: When truthy, ``None`` is accepted and returned unchanged.

    Returns:
        ``input_shape`` itself when the check passes.

    Raises:
        TypeError: If ``input_shape`` is not a dict, a key is not a str, a
            value is not a list, or a list element is not an int.
    """
    if enable_none and input_shape is None:
        return input_shape
    if not isinstance(input_shape, dict):
        raise TypeError(f"{input_shape_name} must be dict, but got {type(input_shape)}.")
    for key, dims in input_shape.items():
        if not isinstance(key, str):
            # Report the offending key's type (not the dict's type) so the
            # message actually tells the caller what was wrong.
            raise TypeError(f"{input_shape_name} key must be str, but got {type(key)}.")
        if not isinstance(dims, list):
            raise TypeError(f"{input_shape_name} value must be list, but got "
                            f"{type(dims)} at key {key}.")
        for j, element in enumerate(dims):
            if not isinstance(element, int):
                raise TypeError(f"{input_shape_name} value's element must be int, but got "
                                f"{type(element)} at index {j}.")
    return input_shape
76
77
def check_config_info(config_info_name, config_info, enable_none=False):
    """Check that ``config_info`` has the form ``dict{str: str}``.

    Args:
        config_info_name: Name of the argument; used only in error messages.
        config_info: The value to validate.
        enable_none: When truthy, ``None`` is accepted and returned unchanged.

    Returns:
        ``config_info`` itself when the check passes.

    Raises:
        TypeError: If ``config_info`` is not a dict, or any key or value is
            not a str.
    """
    if enable_none and config_info is None:
        return config_info
    if not isinstance(config_info, dict):
        raise TypeError(f"{config_info_name} must be dict, but got {type(config_info)}.")
    for name, setting in config_info.items():
        # The key is validated before its value, entry by entry.
        if not isinstance(name, str):
            raise TypeError(f"{config_info_name} key must be str, but got {type(name)} at key {name}.")
        if isinstance(setting, str):
            continue
        raise TypeError(f"{config_info_name} val must be str, but got "
                        f"{type(setting)} at key {name}.")
    return config_info
92
93
def check_tensor_input_param(shape=None, device=None):
    """Validate the optional ``shape`` and ``device`` parameters.

    Args:
        shape: ``None``, or a list/tuple whose elements are all int.
        device: ``None``, or a str of the form ``"ascend"`` or
            ``"ascend:<index>"`` where ``<index>`` consists of decimal digits.

    Returns:
        None.

    Raises:
        TypeError: If ``shape`` or ``device`` has the wrong type, or if
            ``device`` is malformed. TypeError is kept for malformed device
            strings so existing callers that catch it keep working.
    """
    if shape is not None:
        if not isinstance(shape, (list, tuple)):
            raise TypeError(f"shape must be list or tuple, but got {type(shape)}.")
        for i, element in enumerate(shape):
            if not isinstance(element, int):
                raise TypeError(f"shape element must be int, but got {type(element)} at index {i}.")
    if device is None:
        return
    if not isinstance(device, str):
        raise TypeError(f"device must be str, but got {type(device)}.")
    # str.split with an explicit separator always yields at least one item,
    # so indexing [0] below is safe.
    split_device = device.split(":")
    if len(split_device) > 2:
        raise TypeError("device must be 'ascend:index', eg: 'ascend:0'")
    if split_device[0] != "ascend":
        raise TypeError("now only support ascend device.")
    # isdecimal() rather than isdigit(): isdigit() also accepts characters such
    # as superscript digits ('²') that int() cannot parse. An empty index
    # ("ascend:") is rejected by both.
    if len(split_device) == 2 and not split_device[1].isdecimal():
        raise TypeError("device id should >= 0.")
113