# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""pyboost utils.

Helpers for the pyboost code generator: classify op arguments read from the op
yaml definitions and map their yaml dtypes onto C++ type names and converter
function names used in the generated code.
"""

import os
import logging
from gen_utils import safe_load_yaml


def is_optional_param(op_arg):
    """
    Check whether an op argument is an init argument whose default is None.
    :param op_arg: op argument object exposing ``as_init_arg`` and ``default``.
    :return: bool
    """
    # str() makes both a real None and the literal string 'None' match.
    return bool(op_arg.as_init_arg) and str(op_arg.default) == 'None'


def is_tensor(op_arg):
    """Return True when the op argument's yaml dtype is 'tensor'."""
    return op_arg.arg_dtype == 'tensor'


def is_tensor_list(op_arg):
    """Return True when the op argument's yaml dtype is a list or tuple of tensors."""
    return op_arg.arg_dtype in ('list[tensor]', 'tuple[tensor]')


def is_list(op_arg):
    """
    Check whether the op argument's yaml dtype is a tuple/list sequence type.
    :param op_arg: op argument object exposing ``arg_dtype``.
    :return: bool
    """
    # 'list[float]' is listed so this agrees with the dtype tables below, which
    # all handle it alongside 'tuple[float]'.
    return op_arg.arg_dtype in ('tuple[int]', 'tuple[float]', 'tuple[bool]', 'tuple[tensor]',
                                'list[int]', 'list[float]', 'list[bool]', 'list[tensor]')


def get_index(index: int):
    """
    Build the C++ index constant name for a positional index.
    :param index: position of the input.
    :return: str, e.g. ``kIndex0`` for ``0``.
    """
    return "kIndex" + str(index)


def _lookup_convert_type(dtype, optional, native_type_convert, optional_type_convert):
    """
    Look up ``dtype`` in the optional or the native table, selected by ``optional``.
    :param dtype: yaml dtype string.
    :param optional: truthy to use ``optional_type_convert``, else ``native_type_convert``.
    :return: the mapped string.
    :raises TypeError: if ``dtype`` is not present in the selected table.
    """
    if optional:
        if dtype in optional_type_convert:
            return optional_type_convert[dtype]
        raise TypeError(f"Unsupported convert optional type {dtype} for args.")
    if dtype in native_type_convert:
        return native_type_convert[dtype]
    raise TypeError(f"Unsupported convert type {dtype} for args.")


def get_convert_type_str(dtype: str, optional):
    """
    Get the name of the C++ converter used to convert a python arg of ``dtype``.
    :param dtype: yaml dtype string.
    :param optional: whether the argument is optional.
    :return: converter name, e.g. ``ToIntList<py::tuple>``.
    :raises TypeError: if ``dtype`` is unsupported for the requested optionality.
    """
    # add more type here
    native_type_convert = {
        'int': 'ToInt',
        'float': 'ToFloat',
        'bool': 'ToBool',
        'number': 'ToScalar',
        'tuple[int]': 'ToIntList<py::tuple>',
        'tuple[float]': 'ToFloatList<py::tuple>',
        'tuple[bool]': 'ToBoolList<py::tuple>',
        'tuple[tensor]': 'ToTensorList<py::tuple>',
        'list[int]': 'ToIntList<py::list>',
        'list[float]': 'ToFloatList<py::list>',
        'list[bool]': 'ToBoolList<py::list>',
        'list[tensor]': 'ToTensorList<py::list>',
        'tensor': 'ToTensor',
        'str': 'ToString',
        'type': 'ToDtype',
    }
    optional_type_convert = {
        'int': 'ToIntOptional',
        'float': 'ToFloatOptional',
        'number': 'ToScalarOptional',
        'tensor': 'ToTensorOptional',
        'type': 'ToDtypeOptional',
        'str': 'ToStringOptional',
        'tuple[int]': 'ToIntListOptional<py::tuple>',
        'tuple[float]': 'ToFloatListOptional<py::tuple>',
        'tuple[bool]': 'ToBoolListOptional<py::tuple>',
        'tuple[tensor]': 'ToTensorListOptional<py::tuple>',
        'list[int]': 'ToIntListOptional<py::list>',
        'list[float]': 'ToFloatListOptional<py::list>',
        'list[bool]': 'ToBoolListOptional<py::list>',
        'list[tensor]': 'ToTensorListOptional<py::list>',
    }
    return _lookup_convert_type(dtype, optional, native_type_convert, optional_type_convert)


def get_value_convert_type_str(dtype: str, optional):
    """
    Get the name of the converter used when the argument is held as a Value.
    :param dtype: yaml dtype string.
    :param optional: whether the argument is optional.
    :return: converter name; all tuple dtypes map to ``ToValueTuple[Optional]``.
    :raises TypeError: if ``dtype`` is unsupported for the requested optionality.
    """
    # add more type here
    native_type_convert = {
        'int': 'ToInt',
        'float': 'ToFloat',
        'bool': 'ToBool',
        'number': 'ToScalar',
        'tensor': 'ToTensor',
        'str': 'ToString',
        'type': 'ToDtype',
        'tuple[int]': 'ToValueTuple',
        'tuple[float]': 'ToValueTuple',
        'tuple[bool]': 'ToValueTuple',
        'tuple[tensor]': 'ToValueTuple',
    }
    optional_type_convert = {
        'int': 'ToIntOptional',
        'float': 'ToFloatOptional',
        'number': 'ToScalarOptional',
        'tensor': 'ToTensorOptional',
        'type': 'ToDtypeOptional',
        'str': 'ToStringOptional',
        'tuple[int]': 'ToValueTupleOptional',
        'tuple[float]': 'ToValueTupleOptional',
        'tuple[bool]': 'ToValueTupleOptional',
        'tuple[tensor]': 'ToValueTupleOptional',
    }
    return _lookup_convert_type(dtype, optional, native_type_convert, optional_type_convert)


def tuple_input_to_cpp_type(dtype: str):
    """
    Map a tuple/list yaml dtype to the C++ element type name.
    :param dtype: yaml dtype string such as ``tuple[int]``.
    :return: element type name, or None if ``dtype`` is not a known sequence dtype.
    """
    types_map = {
        'tuple[int]': 'int64_t',
        'tuple[float]': 'float',
        'tuple[bool]': 'bool',
        'tuple[str]': 'string',
        'tuple[tensor]': 'TensorPtr',
        'list[int]': 'int64_t',
        'list[float]': 'float',
        'list[bool]': 'bool',
        'list[tensor]': 'TensorPtr',
    }
    return types_map.get(dtype)


def number_input_to_cpp_type(dtype: str):
    """
    Map a scalar yaml dtype to the C++ type name.
    :param dtype: yaml dtype string such as ``int``.
    :return: C++ type name, or None if ``dtype`` is not a known scalar dtype.
    """
    types_map = {
        'int': 'int64_t',
        'float': 'float',
        'bool': 'bool',
        'str': 'string'
    }
    return types_map.get(dtype)


def get_input_dtype(dtype: str, optional):
    """
    Get the C++ input parameter type for a yaml dtype.
    :param dtype: yaml dtype string.
    :param optional: whether the argument is optional (wraps the type in std::optional).
    :return: C++ type name.
    :raises TypeError: if ``dtype`` is unsupported for the requested optionality.
    """
    # add more type here
    value_tuple = 'ValueTuplePtr'
    type_convert = {
        'int': 'Int64ImmPtr',
        'float': 'FP32ImmPtr',
        'bool': 'BoolImmPtr',
        'number': 'ScalarPtr',
        'str': 'StringImmPtr',
        'tensor': 'BaseTensorPtr',
        'tuple[int]': value_tuple,
        'tuple[float]': value_tuple,
        'tuple[bool]': value_tuple,
        'tuple[tensor]': value_tuple,
        'list[int]': value_tuple,
        'list[float]': value_tuple,
        'list[bool]': value_tuple,
        'list[tensor]': value_tuple,
    }
    value_tuple_optional = 'std::optional<ValueTuplePtr>'
    optional_type_convert = {
        'int': 'std::optional<Int64ImmPtr>',
        'float': 'std::optional<FP32ImmPtr>',
        'bool': 'std::optional<BoolImmPtr>',
        'number': 'std::optional<ScalarPtr>',
        'str': 'std::optional<StringImmPtr>',
        'tensor': 'std::optional<BaseTensorPtr>',
        'tuple[int]': value_tuple_optional,
        'tuple[float]': value_tuple_optional,
        'tuple[bool]': value_tuple_optional,
        'tuple[tensor]': value_tuple_optional,
    }
    return _lookup_convert_type(dtype, optional, type_convert, optional_type_convert)


def is_cube(class_name):
    """Return True when ``class_name`` is one of the ops in the hard-coded cube set."""
    return class_name in {'Bmm', 'Baddbmm', 'MatMulExt', 'Mv'}


def get_return_type(dtype: str):
    """
    Map a yaml return dtype to the C++ return type.
    :param dtype: yaml dtype string.
    :return: C++ type name.
    :raises TypeError: if ``dtype`` is not a tensor or tensor sequence.
    """
    # add more type here
    type_convert = {
        'tuple[tensor]': 'std::vector<tensor::TensorPtr>',
        'list[tensor]': 'std::vector<tensor::TensorPtr>',
        'tensor': 'tensor::TensorPtr',
    }
    return _lookup_convert_type(dtype, False, type_convert, {})


def get_disable_flag(yaml_def):
    """
    Get class or functional api disable generate flag.
    :param yaml_def: yaml mapping that may contain a ``disable`` key, or None.
    :return: the ``disable`` value, or False when absent.
    :raises TypeError: if ``disable`` is present but is not a bool.
    """
    disable_flag = False
    if yaml_def is not None:
        item = yaml_def.get("disable")
        if item is not None:
            # bool has exactly two instances and cannot be subclassed, so this
            # accepts only True/False.
            if not isinstance(item, bool):
                raise TypeError(f"The disable label for function should be True or False, but get {item}.")
            disable_flag = item
    return disable_flag


def get_op_name(operator_name, class_def):
    """
    Get op name for python class Primitive or c++ OpDef name.
    :param operator_name: snake_case operator name.
    :param class_def: yaml mapping that may override the name via ``name``, or None.
    :return: the override if given, otherwise the CamelCase form of ``operator_name``.
    """
    class_name = convert_python_func_name_to_c(operator_name)
    if class_def is not None:
        item = class_def.get("name")
        if item is not None:
            class_name = item
    return class_name


def get_pyboost_name(operator_name):
    """Return the pyboost function name for ``operator_name``."""
    return 'pyboost_' + operator_name


def convert_python_func_name_to_c(func_name: str) -> str:
    """Convert a snake_case name to CamelCase, e.g. ``add_ext`` -> ``AddExt``."""
    return ''.join(word.capitalize() for word in func_name.split('_'))


def get_const_number_convert(arg_name, op_arg):
    """
    Generate the C++ statement that extracts a scalar value from ``arg_name``.
    :param arg_name: C++ variable name holding the value.
    :param op_arg: op argument exposing ``arg_dtype`` and ``is_type_id``.
    :return: one line of C++ code ending with a newline.
    """
    cpp_type = number_input_to_cpp_type(op_arg.arg_dtype)
    if op_arg.is_type_id:
        return f"TypeId {arg_name}_imm = static_cast<TypeId>(GetValue<{cpp_type}>({arg_name}));\n"
    return f"auto {arg_name}_imm = GetValue<{cpp_type}>({arg_name});\n"


def get_tuple_input_convert(arg_name, arg_type):
    """
    convert tuple input.
    :param arg_name: C++ variable name holding the ValueTuple.
    :param arg_type: yaml dtype string such as ``tuple[int]``.
    :return: one line of C++ code building ``<arg_name>_vector``.
    """
    cpp_type = tuple_input_to_cpp_type(arg_type)
    # Tensor sequences are converted to BaseTensorPtr vectors in the generated code.
    if cpp_type == "TensorPtr":
        cpp_type = "BaseTensorPtr"
    return f"std::vector<{cpp_type}> {arg_name}_vector = ConvertValueTupleToVector<{cpp_type}>({arg_name});\n"


def is_pyboost_enable(operator_data):
    """
    Check whether pyboost generation is enabled for an op.
    :param operator_data: op yaml mapping.
    :return: True when ``dispatch.enable`` is truthy; False when ``dispatch`` is
        missing, empty or null.
    """
    dispatch = operator_data.get('dispatch')
    if not dispatch:
        return False
    return bool(dispatch.get('enable'))


def convert_types(inputs):
    """
    Convert yaml dtypes to the types used for acl.
    :param inputs: mapping of name -> yaml description dict with a 'dtype' key;
        None (a missing yaml section) is treated as an empty mapping.
    :return: tuple ``(dtypes, flag)``. ``dtypes`` maps each name to its converted
        type string. ``flag`` is True when any original dtype is not 'tensor'.
    """
    # NOTE(review): bool tuples become std::vector<uint8_t>, presumably to avoid the
    # std::vector<bool> specialization — confirm.
    vector_types = {
        'int': 'std::vector<int64_t>',
        'float': 'std::vector<float>',
        'bool': 'std::vector<uint8_t>',
    }
    inputs_dtypes = {}
    flag = False
    for name, desc in (inputs or {}).items():
        dtype = desc.get('dtype')
        if dtype != 'tensor':
            flag = True
        if 'tuple' in dtype:
            # 'tuple[int]' -> 'int'
            element = dtype.split('[')[1].strip(']')
            if element == 'tensor':
                logging.info("Not support tuple[tensor] input.")
            elif element in vector_types:
                dtype = vector_types[element]
            else:
                logging.warning("Not support tuple[%s] input.", element)
        elif dtype == 'number':
            dtype = 'ScalarPtr'
        elif dtype == 'int':
            dtype = 'int64_t'
        inputs_dtypes[name] = dtype
    return inputs_dtypes, flag


def get_dtypes(op_yaml):
    """
    Get op inputs and outputs dtypes.
    :param op_yaml: op yaml mapping with 'args' and 'returns' sections.
    :return: tuple ``(inputs_dtypes, outputs_dtypes, none_tensor_exist)``.
    """
    inputs_dtypes, flag_in = convert_types(op_yaml.get('args'))
    outputs_dtypes, flag_out = convert_types(op_yaml.get('returns'))
    none_tensor_exist = flag_in or flag_out
    return inputs_dtypes, outputs_dtypes, none_tensor_exist


class AclnnUtils:
    """
    aclnn utils: resolves the aclnn interface name for an op class.
    """
    work_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../../")
    aclnn_map = safe_load_yaml(os.path.join(work_path, "./mindspore/python/mindspore/ops_generate/aclnn_config.yaml"))

    @staticmethod
    def get_aclnn_interface(class_name):
        """
        get aclnn interface name.
        :param class_name: op class name.
        :return: the name configured in aclnn_config.yaml for ``class_name`` if
            present, otherwise ``"aclnn" + class_name``.
        """
        # NOTE(review): safe_load_yaml presumably yields None for an empty file (as
        # yaml.safe_load does); fall back to the default naming in that case.
        aclnn_map = AclnnUtils.aclnn_map or {}
        if class_name in aclnn_map:
            return aclnn_map[class_name]
        return "aclnn" + class_name