# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Check parameters.
"""

_UINT32_MAX = 2 ** 32 - 1
_UINT64_MAX = 2 ** 64 - 1


def _type_name(classes):
    """Return a readable name for a class or a (possibly nested) tuple of classes.

    ``isinstance`` accepts either a single class or a tuple of classes, but a
    tuple has no ``__name__``; tuples are therefore rendered as ``"A or B"``.
    """
    if isinstance(classes, tuple):
        return " or ".join(_type_name(cls) for cls in classes)
    # Fall back to str() for objects without __name__ (e.g. typing unions).
    return getattr(classes, "__name__", str(classes))


def check_isinstance(arg_name, arg_value, classes, enable_none=False):
    """Check that arg_value is an instance of classes.

    Note: bool is a subclass of int in Python, so True/False pass an int check.

    Args:
        arg_name (str): Argument name used in the error message.
        arg_value: Value to check.
        classes (type or tuple[type]): Accepted type(s), as for ``isinstance``.
        enable_none (bool): If True, ``None`` is accepted without further checks.

    Returns:
        arg_value, unchanged.

    Raises:
        TypeError: If arg_value is not an instance of classes.
    """
    if enable_none and arg_value is None:
        return arg_value
    if not isinstance(arg_value, classes):
        raise TypeError(f"{arg_name} must be {_type_name(classes)}, but got {type(arg_value)}.")
    return arg_value


def check_list_of_element(arg_name, arg_value, ele_classes, enable_none=False):
    """Check that arg_value is a list whose every element is an instance of ele_classes.

    Args:
        arg_name (str): Argument name used in error messages.
        arg_value: Value to check.
        ele_classes (type or tuple[type]): Accepted element type(s).
        enable_none (bool): If True, ``None`` is accepted without further checks.

    Returns:
        arg_value, unchanged.

    Raises:
        TypeError: If arg_value is not a list, or an element has the wrong type.
    """
    if enable_none and arg_value is None:
        return arg_value
    if not isinstance(arg_value, list):
        raise TypeError(f"{arg_name} must be list, but got {type(arg_value)}.")
    for i, element in enumerate(arg_value):
        if not isinstance(element, ele_classes):
            raise TypeError(f"{arg_name} element must be {_type_name(ele_classes)}, but got "
                            f"{type(element)} at index {i}.")
    return arg_value


def _check_uint_range(arg_name, arg_value, max_value, max_label):
    """Check arg_value is an int in [0, max_value]; max_label names the bound in the error."""
    check_isinstance(arg_name, arg_value, int)
    if not 0 <= arg_value <= max_value:
        raise ValueError(f"{arg_name} value should be in range [0, {max_label}], but got {arg_value}")
    return arg_value


def check_uint32_number_range(arg_name, arg_value):
    """Check that arg_value is an int in [0, UINT32_MAX].

    Returns:
        arg_value, unchanged.

    Raises:
        TypeError: If arg_value is not an int.
        ValueError: If arg_value is out of range.
    """
    return _check_uint_range(arg_name, arg_value, _UINT32_MAX, "UINT32_MAX")


def check_uint64_number_range(arg_name, arg_value):
    """Check that arg_value is an int in [0, UINT64_MAX].

    Returns:
        arg_value, unchanged.

    Raises:
        TypeError: If arg_value is not an int.
        ValueError: If arg_value is out of range.
    """
    return _check_uint_range(arg_name, arg_value, _UINT64_MAX, "UINT64_MAX")


def check_input_shape(input_shape_name, input_shape, enable_none=False):
    """Check that input_shape has type dict{str: list[int]}.

    Args:
        input_shape_name (str): Argument name used in error messages.
        input_shape: Value to check.
        enable_none (bool): If True, ``None`` is accepted without further checks.

    Returns:
        input_shape, unchanged.

    Raises:
        TypeError: If input_shape, any key, any value, or any value element has the wrong type.
    """
    if enable_none and input_shape is None:
        return input_shape
    if not isinstance(input_shape, dict):
        raise TypeError(f"{input_shape_name} must be dict, but got {type(input_shape)}.")
    for key, shape in input_shape.items():
        if not isinstance(key, str):
            # Report the offending key's type (previously reported the dict's type).
            raise TypeError(f"{input_shape_name} key must be str, but got {type(key)}.")
        if not isinstance(shape, list):
            raise TypeError(f"{input_shape_name} value must be list, but got "
                            f"{type(shape)} at key {key}.")
        for j, element in enumerate(shape):
            if not isinstance(element, int):
                raise TypeError(f"{input_shape_name} value's element must be int, but got "
                                f"{type(element)} at index {j}.")
    return input_shape


def check_config_info(config_info_name, config_info, enable_none=False):
    """Check that config_info has type dict{str: str}.

    Args:
        config_info_name (str): Argument name used in error messages.
        config_info: Value to check.
        enable_none (bool): If True, ``None`` is accepted without further checks.

    Returns:
        config_info, unchanged.

    Raises:
        TypeError: If config_info, any key, or any value has the wrong type.
    """
    if enable_none and config_info is None:
        return config_info
    if not isinstance(config_info, dict):
        raise TypeError(f"{config_info_name} must be dict, but got {type(config_info)}.")
    for key, value in config_info.items():
        if not isinstance(key, str):
            raise TypeError(f"{config_info_name} key must be str, but got {type(key)} at key {key}.")
        if not isinstance(value, str):
            raise TypeError(f"{config_info_name} val must be str, but got "
                            f"{type(value)} at key {key}.")
    return config_info


def check_tensor_input_param(shape=None, device=None):
    """Check the shape and device arguments used to construct a tensor.

    Args:
        shape (list[int] or tuple[int], optional): Tensor shape; skipped if None.
        device (str, optional): Device string of the form 'ascend' or 'ascend:<index>';
            skipped if None.

    Raises:
        TypeError: If shape is not a list/tuple of int, or device is not a str of
            the accepted form.
    """
    if shape is not None:
        if not isinstance(shape, (list, tuple)):
            raise TypeError(f"shape must be list or tuple, but got {type(shape)}.")
        for i, element in enumerate(shape):
            if not isinstance(element, int):
                raise TypeError(f"shape element must be int, but got {type(element)} at index {i}.")
    if device is None:
        return
    if not isinstance(device, str):
        raise TypeError(f"device must be str, but got {type(device)}.")
    # str.split always yields at least one part, so split_device[0] always exists.
    split_device = device.split(":")
    if len(split_device) > 2:
        raise TypeError("device must be 'ascend:index', eg: 'ascend:0'")
    if split_device[0] != "ascend":
        raise TypeError("now only support ascend device.")
    # isdecimal (not isdigit) matches exactly what int() can parse; isdigit also
    # accepts characters such as superscripts that int() rejects.
    if len(split_device) == 2 and not split_device[1].isdecimal():
        raise TypeError("device id should >= 0.")