# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""tbe common"""
import os


class TBEException(Exception):
    """Exception carrying a TBE error message; ``str()`` returns that message."""

    def __init__(self, err_msg):
        # Pass the message (not ``self``) to the base class: with ``self`` in
        # ``args``, ``repr()`` of the exception recursed without bound.
        super().__init__(err_msg)
        self.__error_msg = err_msg

    def __str__(self):
        # ``str()`` guards against non-string messages, which would otherwise
        # make ``__str__`` raise TypeError.
        return str(self.__error_msg)


def get_built_in_impl_path():
    """
    Get the built-in TBE implementation path.

    The ``TBE_IMPL_PATH`` environment variable takes precedence. Only when it
    is unset are the default install locations probed, in order.

    Returns:
        str: The TBE implementation path.

    Raises:
        ValueError: If ``TBE_IMPL_PATH`` is set to an empty string, or it is
            unset and none of the default install locations exist.
    """
    tbe_impl_path = os.environ.get("TBE_IMPL_PATH")
    if tbe_impl_path is None:
        candidates = (
            '/usr/local/HiAI/runtime/ops/op_impl/built-in/ai_core/tbe/',
            '/usr/local/Ascend/opp/op_impl/built-in/ai_core/tbe/',
        )
        tbe_impl_path = next((path for path in candidates if os.path.exists(path)), None)
    if not tbe_impl_path:
        raise ValueError("Can not find the env TBE_IMPL_PATH")
    return tbe_impl_path


def _check_arg_info(item):
    """
    Check that a tensor description dict contains the required keys.

    ``shape`` and ``ori_shape`` only need to be present (their value may be
    empty); ``format``, ``ori_format``, ``dtype`` and ``param_type`` must be
    present and non-empty.

    Args:
        item (dict): The tensor description to check.

    Raises:
        ValueError: If a required key is missing, or empty where that is
            not allowed.
    """
    for key in ('shape', 'ori_shape'):
        if key not in item:
            raise ValueError("Json string Errors, key:{} not found.".format(key))
    for key in ('format', 'ori_format', 'dtype', 'param_type'):
        # ``get`` returns None for a missing key, so this also covers "absent".
        if not item.get(key):
            raise ValueError("Json string Errors, key:{} not found.".format(key))


def get_input_output(io_info, args):
    """
    Parse input or output descriptions and append them to ``args``.

    Each element of ``io_info`` is a sequence of tensor description dicts for
    one parameter. Valid descriptions are checked and have their ``valid`` and
    ``name`` keys removed in place. A parameter with several descriptions, or
    a single description whose ``param_type`` is ``'dynamic'``, is appended as
    a tuple (invalid entries inside it become ``None``). Otherwise the single
    description is appended directly, or ``None`` when it is not valid.

    Args:
        io_info (list): input or output info, a sequence of sequences of dicts.
        args (list): the arguments list, extended in place.

    Raises:
        ValueError: If a description lacks ``valid`` or a required key.
    """
    for item in io_info:
        arg = []
        for info in item:
            if 'valid' not in info:
                raise ValueError("Json string Errors, key:valid not found.")
            if info['valid']:
                _check_arg_info(info)
                del info['valid']
                # 'name' is not validated above; tolerate its absence instead of
                # failing with a bare KeyError.
                info.pop('name', None)
                if len(item) > 1 or info['param_type'] == 'dynamic':
                    arg.append(info)
                else:
                    args.append(info)
            elif len(item) > 1:
                arg.append(None)
            else:
                args.append(None)
        if arg:
            args.append(tuple(arg))


def get_attr(attr_info, args):
    """
    Parse attribute descriptions and append their values to ``args``.

    Attributes whose ``valid`` flag is falsy are skipped, as is the attribute
    named ``isRef``.

    Args:
        attr_info (list): attribute info, a sequence of dicts.
        args (list): the arguments list, extended in place.

    Raises:
        ValueError: If an attribute lacks ``valid``, or a valid one lacks
            ``value``.
    """
    for item in attr_info:
        if 'valid' not in item:
            raise ValueError("Json string Errors, attr key:valid not found.")
        if not item['valid']:
            continue
        if 'value' not in item:
            raise ValueError("Json string Errors, attr key:value not found.")
        if item.get('name') != "isRef":
            args.append(item['value'])


def get_args(op_info, arg_type):
    """
    Parse the ``arg_type`` section of an op info dict into an argument list.

    Args:
        op_info (dict): Op info dict.
        arg_type (str): Section to parse: 'inputs', 'outputs' or 'attrs'.
            Any other key that is present yields an empty list.

    Returns:
        list: The parsed arguments (empty when the section is empty).

    Raises:
        ValueError: If ``arg_type`` is not a key of ``op_info``, or the
            section contains malformed entries.
    """
    if arg_type not in op_info:
        raise ValueError("Json string Errors, key:{} not found.".format(arg_type))
    args = []
    if not op_info[arg_type]:
        return args

    arg_info = op_info[arg_type]
    if arg_type in ['inputs', 'outputs']:
        get_input_output(arg_info, args)
    elif arg_type == 'attrs':
        get_attr(arg_info, args)

    return args


def check_kernel_info(kernel_info):
    """
    Check that a kernel info dict has a non-empty ``op_info`` with non-empty
    ``name`` and ``kernel_name`` entries.

    Args:
        kernel_info (dict): The kernel info to check.

    Raises:
        ValueError: If any of those keys is missing or empty.
    """
    if 'op_info' not in kernel_info or not kernel_info['op_info']:
        raise ValueError("Json string Errors, key:op_info not found.")
    op_info = kernel_info['op_info']
    for key in ('name', 'kernel_name'):
        if key not in op_info or not op_info[key]:
            raise ValueError("Json string Errors, key:{} not found.".format(key))