1# Copyright 2020-2023 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================ 15 16"""Operators info register.""" 17from __future__ import absolute_import 18from __future__ import division 19 20import inspect 21import json 22import os 23import functools 24import platform 25import hashlib 26import shutil 27 28from mindspore._c_expression import Oplib 29from mindspore import _checkparam as validator 30from mindspore import log as logger 31 32if platform.system() == "Linux": 33 import fcntl 34 35# path of built-in op info register. 
36BUILT_IN_OPS_REGISTER_PATH = "mindspore/ops/_op_impl" 37BUILT_IN_CUSTOM_OPS_REGISTER_PATH = "mindspore/ops/_op_impl/_custom_op" 38 39KEY_NAME = "name" 40ASCEND_CUSTOM_OPP_PATH = "ASCEND_CUSTOM_OPP_PATH" 41 42 43def _get_reg_info_attr(op_info, attr_name, default_value=None): 44 """get attr value""" 45 for _, item in enumerate(op_info.get("attr", [])): 46 if item.get(KEY_NAME) == attr_name: 47 return item.get("defaultValue") 48 return default_value 49 50 51class _CustomInstaller: 52 """save custom op registration information to a json file which will be used by GE""" 53 reg_info_hash = [] # used to avoid writing the same reg info to file multiple times 54 copied_paths = [] # used to avoid copying the same file multiple times 55 56 def __init__(self, op_info, func=None): 57 self.op_info = op_info 58 self.func = func 59 self.op_type = op_info.get("op_name") if not func else func.__name__ 60 vendor_name = "ms" 61 custom_dir = os.path.join(os.path.realpath("./"), "vendors", vendor_name) 62 self._set_env(custom_dir) 63 op_impl_dir = os.path.join(custom_dir, "op_impl") 64 self.ai_core_config_dir = os.path.join(op_impl_dir, "ai_core", "tbe", "config") 65 self.ai_core_impl_dir = os.path.join(op_impl_dir, "ai_core", "tbe", vendor_name + "_impl") 66 self.ai_cpu_config_dir = os.path.join(op_impl_dir, "cpu", "config") 67 self.ai_cpu_impl_dir = os.path.join(op_impl_dir, "cpu", "aicpu_kernel", "impl") 68 69 @staticmethod 70 def _set_env(custom_opp_path): 71 """set custom file path to env""" 72 if not os.environ.get(ASCEND_CUSTOM_OPP_PATH): 73 os.environ[ASCEND_CUSTOM_OPP_PATH] = custom_opp_path 74 else: 75 paths = os.environ[ASCEND_CUSTOM_OPP_PATH].split(':') 76 if custom_opp_path not in paths: 77 os.environ[ASCEND_CUSTOM_OPP_PATH] = custom_opp_path + ':' + os.environ[ASCEND_CUSTOM_OPP_PATH] 78 79 @staticmethod 80 def _create_dir(*dir_names): 81 """create directory""" 82 for dir_name in dir_names: 83 if not os.path.isdir(dir_name): 84 try: 85 os.makedirs(dir_name, exist_ok=True) 
86 except OSError as err: 87 if err.errno == 17: # File exists 88 pass 89 else: 90 raise err 91 92 @staticmethod 93 def _copy_file(src_path, dst_dir): 94 """copy file""" 95 if not os.path.exists(src_path) or src_path in _CustomInstaller.copied_paths: 96 return 97 _CustomInstaller.copied_paths.append(src_path) 98 if os.path.isfile(src_path): 99 lock_file = os.path.join(dst_dir, "file.lock") 100 with os.fdopen(os.open(lock_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600), 'w') as f: 101 fcntl.flock(f.fileno(), fcntl.LOCK_EX) 102 shutil.copy(src_path, dst_dir) 103 104 def check(self): 105 """check if the reg info need written""" 106 if platform.system() != "Linux": 107 return False 108 if not os.environ.get("MS_DEV_CUSTOM_OPP_PATH"): 109 # only process the first time import the mindspore module 110 return False 111 if self.op_info.get("target") in ["GPU", "CPU"]: 112 return False 113 sha256 = hashlib.sha256() 114 value = json.dumps(self.op_info, sort_keys=True).encode() 115 sha256.update(value) 116 hash_value = sha256.hexdigest() 117 if hash_value in _CustomInstaller.reg_info_hash: 118 return False 119 _CustomInstaller.reg_info_hash.append(hash_value) 120 return True 121 122 def _find_ai_cpu_so_path(self, so_file): 123 """find the absolute path of so""" 124 current_path = os.path.dirname(os.path.abspath(__file__)) 125 search_paths = [current_path + "/../lib", current_path + "/../lib/plugin/ascend"] 126 for path in search_paths: 127 so_path = os.path.join(path, so_file) 128 if os.path.exists(so_path): 129 return so_path 130 logger.warning("For Custom op '{}', can not find the aicpu so file '{}' in the following directories:\n{}" 131 .format(self.op_type, so_file, "\n".join(search_paths))) 132 return "" 133 134 def _gen_ai_core_reg_info(self, imply_path, func_name): 135 """generate reg info""" 136 137 def _get_dtype_format(idx): 138 data_type = [] 139 data_format = [] 140 for _, dtype_format in enumerate(self.op_info.get("dtype_format", [])): 141 if not 
dtype_format[idx][0]: 142 data_type = None 143 else: 144 data_type.append(dtype_format[idx][0]) 145 if not dtype_format[idx][1]: 146 data_format = None 147 else: 148 if dtype_format[idx][1] == "DefaultFormat": 149 data_format.append("ND") 150 else: 151 data_format.append(dtype_format[idx][1]) 152 return data_type, data_format 153 154 op_info = {"opFile": {"value": os.path.splitext(os.path.basename(imply_path))[0]}, 155 "opInterface": {"value": func_name}} 156 # attr 157 attrs_name = [] 158 for _, item in enumerate(self.op_info.get("attr", [])): 159 attr_name = item.get(KEY_NAME) 160 attrs_name.append(attr_name) 161 key = "attr_" + attr_name 162 op_info[key] = {} 163 for k, v in item.items(): 164 if k != KEY_NAME: 165 op_info[key][k] = v 166 if attrs_name: 167 op_info["attr"] = {"list": ",".join(attrs_name)} 168 # input and output 169 inputs = self.op_info.get("inputs", []) 170 outputs = self.op_info.get("outputs", []) 171 input_num = len(inputs) 172 output_num = len(outputs) 173 for i in range(input_num + output_num): 174 item = inputs[i] if i < input_num else outputs[i - input_num] 175 key = "input" if i < input_num else "output" 176 key += str(item.get("index")) 177 op_info[key] = {KEY_NAME: item.get(KEY_NAME), 178 "paramType": item.get("paramType", "required"), 179 "shape": item.get("shape", "all")} 180 dtype, formats = _get_dtype_format(i) 181 if dtype: 182 op_info[key]["dtype"] = ",".join(dtype) 183 if formats: 184 op_info[key]["format"] = ",".join(formats) 185 return op_info 186 187 @staticmethod 188 def _gen_ai_cpu_reg_info(so_file): 189 """generate reg info""" 190 op_info = {"opInfo": {"computeCost": "100", 191 "engine": "DNN_VM_AICPU", 192 "flagAsync": "False", 193 "flagPartial": "False", 194 "functionName": "RunCpuKernel", 195 "kernelSo": so_file, 196 "opKernelLib": "CUSTAICPUKernel", 197 "userDefined": "True"}} 198 return op_info 199 200 def _save_op_info(self, dst_dir, file_name, op_info): 201 """save op info file""" 202 repo = {} 203 save_path = 
os.path.join(dst_dir, file_name) 204 lock_file = os.path.join(dst_dir, "file.lock") 205 with os.fdopen(os.open(lock_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600), 'w') as f: 206 fcntl.flock(f.fileno(), fcntl.LOCK_EX) 207 if os.path.isfile(save_path): 208 with open(save_path, 'r') as fr: 209 json_str = fr.read() 210 json_str = "{}" if json_str == "" else json_str 211 repo = json.loads(json_str) 212 repo.update({self.op_type: op_info}) 213 with os.fdopen(os.open(save_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600), 'w') as fw: 214 json.dump(repo, fw, sort_keys=True, indent=4, separators=(',', ':')) 215 216 def run(self): 217 """save reg info to file""" 218 if not self.check(): 219 return 220 so_name = _get_reg_info_attr(self.op_info, "cust_aicpu") 221 if so_name: 222 _CustomInstaller._create_dir(self.ai_cpu_config_dir, self.ai_cpu_impl_dir) 223 # copy so file 224 so_file = "lib" + so_name + ".so" 225 imply_path = self._find_ai_cpu_so_path(so_file) 226 self._copy_file(imply_path, self.ai_cpu_impl_dir) 227 # generate and copy reg info file 228 op_info = self._gen_ai_cpu_reg_info(so_file) 229 self._save_op_info(self.ai_cpu_config_dir, "cust_aicpu_kernel.json", op_info) 230 else: 231 _CustomInstaller._create_dir(self.ai_core_config_dir, self.ai_core_impl_dir) 232 # copy dsl file 233 imply_path = os.path.realpath(inspect.getfile(self.func)) 234 self._copy_file(imply_path, self.ai_core_impl_dir) 235 # generate and copy reg info file 236 op_info = self._gen_ai_core_reg_info(imply_path, self.func.__name__) 237 self._copy_file(imply_path, self.ai_core_impl_dir) 238 for arc_name in ["ascend910", "ascend910b", "ascend910c", "ascend310p"]: 239 arc_dir = os.path.join(self.ai_core_config_dir, arc_name) 240 _CustomInstaller._create_dir(arc_dir) 241 self._save_op_info(arc_dir, "aic-{}-ops-info.json".format(arc_name), op_info) 242 243 244def op_info_register(op_info): 245 r""" 246 A decorator which is used to register an operator. 
247 248 Note: 249 'op_info' should represent the operator information by string with json format. 250 The 'op_info' will be added into oplib. 251 252 Args: 253 op_info (Union[str, dict]): operator information in json format. 254 255 Examples: 256 >>> from mindspore.ops import op_info_register, TBERegOp, DataType 257 >>> abs_op_info = TBERegOp("Abs") \ 258 ... .fusion_type("ELEMWISE") \ 259 ... .async_flag(False) \ 260 ... .binfile_name("abs.so") \ 261 ... .compute_cost(10) \ 262 ... .kernel_name("abs") \ 263 ... .partial_flag(True) \ 264 ... .op_pattern("formatAgnostic") \ 265 ... .input(0, "x", None, "required", None) \ 266 ... .output(0, "y", True, "required", "all") \ 267 ... .dtype_format(DataType.F16_None, DataType.F16_None) \ 268 ... .dtype_format(DataType.F32_None, DataType.F32_None) \ 269 ... .dtype_format(DataType.I32_None, DataType.I32_None) \ 270 ... .get_op_info() 271 >>> 272 >>> @op_info_register(abs_op_info) 273 ... def _abs_tbe(): 274 ... return 275 ... 276 277 Returns: 278 Function, returns a decorator for op info register. 279 """ 280 281 def register_decorator(func): 282 if isinstance(op_info, dict): 283 op_info_real = json.dumps(op_info) 284 else: 285 op_info_real = op_info 286 validator.check_value_type("op_info", op_info_real, [str]) 287 op_lib = Oplib() 288 file_path = os.path.realpath(inspect.getfile(func)) 289 # keep the path custom ops implementation. 
290 if BUILT_IN_CUSTOM_OPS_REGISTER_PATH in file_path: 291 imply_path = file_path 292 else: 293 imply_path = "" if BUILT_IN_OPS_REGISTER_PATH in file_path else file_path 294 if not op_lib.reg_op(op_info_real, imply_path): 295 raise ValueError('Invalid op info {}:\n{}\n'.format(file_path, op_info_real)) 296 297 def wrapped_function(*args, **kwargs): 298 return func(*args, **kwargs) 299 300 return wrapped_function 301 302 return register_decorator 303 304 305def custom_info_register(*reg_info): 306 r""" 307 A decorator which is used to bind the registration information to the `func` parameter of 308 :class:`mindspore.ops.Custom`. 309 310 Note: 311 The 'reg_info' will be added into oplib. 312 313 Args: 314 reg_info (tuple[str, dict]): Each item represents registration information in json format. 315 316 Returns: 317 Function, returns a decorator for op info register. 318 319 Raises: 320 TypeError: If `reg_info` is not a tuple. 321 322 Examples: 323 >>> from mindspore.ops import custom_info_register, CustomRegOp, DataType 324 >>> custom_func_ascend_info = CustomRegOp() \ 325 ... .input(0, "x", "dynamic") \ 326 ... .output(0, "y") \ 327 ... .dtype_format(DataType.F16_Default, DataType.F16_Default) \ 328 ... .dtype_format(DataType.F32_Default, DataType.F32_Default) \ 329 ... .target("Ascend") \ 330 ... .get_op_info() 331 >>> 332 >>> @custom_info_register(custom_func_ascend_info) 333 ... def custom_func(x): 334 ... pass 335 """ 336 337 def decorator(func): 338 setattr(func, "reg_info", reg_info) 339 if reg_info: 340 used_reg_info = reg_info[0] 341 if isinstance(used_reg_info, dict): 342 # ai_cpu should be parsed inside CustomRegOp, skip it here 343 if not _get_reg_info_attr(used_reg_info, "cust_aicpu"): 344 _CustomInstaller(used_reg_info, func).run() 345 346 @functools.wraps(func) 347 def wrapper(*args, **kwargs): 348 return func(*args, **kwargs) 349 350 return wrapper 351 352 return decorator 353 354 355class RegOp: 356 """ 357 Base class for op info register. 
358 359 Args: 360 op_name (str): Name of operator. 361 """ 362 363 def __init__(self, op_name=""): 364 if not isinstance(op_name, str): 365 raise ValueError("op name value must be string") 366 if not op_name.strip(): 367 raise ValueError("op name is empty") 368 self.op_name = op_name 369 self.inputs = [] 370 self.outputs = [] 371 self.attr_ = [] 372 self.fusion_type_ = '' 373 self.dtype_format_ = [] 374 375 def _is_string(self, value): 376 """ 377 Check if the value is a str type. 378 379 Args: 380 value: Parameter to be checked. 381 382 Raises: 383 TypeError: If the type of value is not a str. 384 """ 385 if not isinstance(value, str): 386 raise TypeError("%s value must be str" % str(value)) 387 return True 388 389 def _is_int(self, value): 390 """ 391 Check if the value is an int. 392 393 Args: 394 value: Parameter to be checked. 395 396 Raises: 397 TypeError: If the type of value is not an int. 398 """ 399 if not isinstance(value, int): 400 raise TypeError("%s value must be int" % str(value)) 401 return True 402 403 def _is_bool(self, value): 404 """ 405 Check if the value is a bool. 406 407 Args: 408 value: Parameter to be checked. 409 410 Raises: 411 TypeError: If the type of value is not a bool. 412 """ 413 if not isinstance(value, bool): 414 raise TypeError("%s value must be bool" % str(value)) 415 return True 416 417 @staticmethod 418 def _is_list(value): 419 """ 420 Check if the value is a list. 421 422 Args: 423 value: Parameter to be checked. 424 425 Raises: 426 TypeError: If the type of value is not a list. 427 """ 428 if not isinstance(value, list): 429 raise TypeError("%s value must be list" % str(value)) 430 return True 431 432 def _check_param(self, param_list, key_list, fn_list, kwargs): 433 """ 434 Check if the parameter type is correct. 435 436 Args: 437 param_list (list): Parameter list to be checked. 438 key_list (list): The keys of output dict. 439 fn_list (list): Function used for parameter checking. 
If the function list has only one element, 440 all parameters will use the same function. 441 kwargs (dict): Other parameter information. 442 443 Raises: 444 TypeError: If the type of value is not list. 445 ValueError: If the size of param list is not equal to the size of key list, or 446 the size of param list is not equal to the size of function list. 447 """ 448 for i in [param_list, key_list, fn_list]: 449 if not isinstance(i, list): 450 raise TypeError("%s value must be list type" % str(i)) 451 if len(param_list) != len(key_list) or (len(fn_list) != 1 and len(param_list) != len(fn_list)): 452 raise ValueError("param_list size {}, key_list size {}, must be equal.And fn_list size {}.". 453 format(len(param_list), len(key_list), len(fn_list))) 454 out_dict = {} 455 for idx, element in enumerate(param_list): 456 if element is not None: 457 if len(fn_list) == 1: 458 fn_list[0](element) 459 else: 460 fn_list[idx](element) 461 out_dict[key_list[idx]] = element 462 if kwargs: 463 out_dict = dict(out_dict, **kwargs) 464 return out_dict 465 466 def fusion_type(self, fusion_type): 467 """ 468 Fusion type of the operator. 469 470 Args: 471 fusion_type (str): Value of fusion type. 472 """ 473 self._is_string(fusion_type) 474 self.fusion_type_ = fusion_type 475 return self 476 477 def dtype_format(self, *args): 478 """ 479 A dtype and format supported by the operator. 480 481 Args: 482 args (tuple): Value of dtype and format. 483 484 Raises: 485 ValueError: If the size of args not equal to input size add output size. 486 TypeError: If the type of args is not tuple. 
487 """ 488 if len(self.inputs) + len(self.outputs) != len(args): 489 raise ValueError("input size add output size must be equal to dtype format size") 490 dtype_format = [] 491 for arg in args: 492 if not isinstance(arg, tuple) or (len(arg) != 2 and len(arg) != 3): 493 raise ValueError("dtype and format value must be tuple of two or three elements") 494 self._is_string(arg[0]) 495 self._is_string(arg[1]) 496 if len(arg) == 3: 497 if self._is_string(arg[2]): 498 dtype_format.append(arg) 499 else: 500 dtype_format.append(arg) 501 self.dtype_format_.append(tuple(dtype_format)) 502 return self 503 504 def get_op_info(self): 505 """ 506 Return all registration information for this instance. 507 508 The '_' character ending the key is removed here for compatibility with previous version. 509 510 Key will be unified into an underlined form later. 511 """ 512 op_info = {} 513 for key, value in self.__dict__.items(): 514 if isinstance(key, str) and key.endswith('_'): 515 key = key.rstrip('_') 516 key_dic = {"dynamic_shape_support": "dynamicShapeSupport", 517 "dynamic_rank_support": "dynamicRankSupport", 518 "dynamic_compile_static": "dynamicCompileStatic", 519 "need_check_support": "needCheckSupport", 520 "dynamic_format": "dynamicFormat" 521 } 522 key = key_dic.get(key, key) 523 op_info[key] = value 524 return op_info 525 526 527class CpuRegOp(RegOp): 528 """Class for Cpu op info register""" 529 530 def __init__(self, op_name): 531 super(CpuRegOp, self).__init__(op_name) 532 self.imply_type = "CPU" 533 534 def input(self, index=None, name=None, param_type=None, **kwargs): 535 """ 536 Register Cpu op input information. 537 538 Args: 539 index (int): Order of the input. Default: ``None`` . 540 name (str): Name of the input. Default: ``None`` . 541 param_type (str): Param type of the input. Default: ``None`` . 542 kwargs (dict): Other information of the input. 
543 """ 544 param_list = [index, name, param_type] 545 key_list = ["index", "name", "paramType"] 546 fn_list = [self._is_int, self._is_string, self._is_string] 547 input_dict = self._check_param(param_list, key_list, fn_list, kwargs) 548 self.inputs.append(input_dict) 549 return self 550 551 def output(self, index=None, name=None, param_type=None, **kwargs): 552 """ 553 Register AiCPU op output information. 554 555 Args: 556 index (int): Order of the output. Default: ``None`` . 557 name (str): Name of the output. Default: ``None`` . 558 param_type (str): Param type of the output. Default: ``None`` . 559 kwargs (dict): Other information of the output. 560 """ 561 param_list = [index, name, param_type] 562 key_list = ["index", "name", "paramType"] 563 fn_list = [self._is_int, self._is_string, self._is_string] 564 output_dict = self._check_param(param_list, key_list, fn_list, kwargs) 565 self.outputs.append(output_dict) 566 return self 567 568 def attr(self, name=None, value_type=None, value=None, **kwargs): 569 """ 570 Register AiCPU op attribute information. 571 572 Args: 573 name (str): Name of the attribute. Default: ``None`` . 574 value_type (str): Value type of the attribute. Default: ``None`` . 575 value (str): Value of the attribute. Default: ``None`` . 576 kwargs (dict): Other information of the attribute. 577 """ 578 param_list = [name, value_type, value] 579 key_list = ["name", "type", "value"] 580 fn_list = [self._is_string] 581 attr_dict = self._check_param(param_list, key_list, fn_list, kwargs) 582 self.attr_.append(attr_dict) 583 return self 584 585 586class AkgRegOp(RegOp): 587 """Class for Akg op info register.""" 588 589 def __init__(self, op_name, processor): 590 super(AkgRegOp, self).__init__(op_name) 591 self.imply_type = "AKG" 592 self.processor = processor 593 594 def input(self, index=None, name=None, param_type=None, **kwargs): 595 """ 596 Register Akg op input information. 597 598 Args: 599 index (int): Order of the input. Default: ``None`` . 
600 name (str): Name of the input. Default: ``None`` . 601 param_type (str): Param type of the input. Default: ``None`` . 602 kwargs (dict): Other information of the input. 603 """ 604 param_list = [index, name, param_type] 605 key_list = ["index", "name", "paramType"] 606 fn_list = [self._is_int, self._is_string, self._is_string] 607 input_dict = self._check_param(param_list, key_list, fn_list, kwargs) 608 self.inputs.append(input_dict) 609 return self 610 611 def output(self, index=None, name=None, **kwargs): 612 """ 613 Register Akg op output information. 614 615 Args: 616 index (int): Order of the output. Default: ``None`` . 617 name (str): Name of the output. Default: ``None`` . 618 kwargs (dict): Other information of the output. 619 """ 620 param_list = [index, name] 621 key_list = ["index", "name"] 622 fn_list = [self._is_int, self._is_string] 623 output_dict = self._check_param(param_list, key_list, fn_list, kwargs) 624 self.outputs.append(output_dict) 625 return self 626 627 def attr(self, name=None, param_type=None, value_type=None, **kwargs): 628 """ 629 Register Akg op attribute information. 630 631 Args: 632 name (str): Name of the attribute. Default: ``None`` . 633 param_type (str): Param type of the attribute. Default: ``None`` . 634 value_type (str): Value type of the attribute. Default: ``None`` . 635 kwargs (dict): Other information of the attribute. 
636 """ 637 param_list = [name, param_type, value_type] 638 key_list = ["name", "paramType", "type"] 639 fn_list = [self._is_string] 640 attr_dict = self._check_param(param_list, key_list, fn_list, kwargs) 641 self.attr_.append(attr_dict) 642 return self 643 644 645class AkgGpuRegOp(AkgRegOp): 646 """Class for AkgGpu op info register""" 647 648 def __init__(self, op_name): 649 super(AkgGpuRegOp, self).__init__(op_name, "CUDA") 650 651 652class AkgAscendRegOp(AkgRegOp): 653 """Class for AkgAscend op info register""" 654 655 def __init__(self, op_name): 656 super(AkgAscendRegOp, self).__init__(op_name, "AiCore") 657 658 659class AkgCpuRegOp(AkgRegOp): 660 """Class for AkgCpu op info register""" 661 662 def __init__(self, op_name): 663 super(AkgCpuRegOp, self).__init__(op_name, "CPU") 664 665 666class AiCPURegOp(CpuRegOp): 667 r""" 668 Class for AiCPU operator information registration. 669 670 Args: 671 op_name (str): Name of operator. 672 673 Examples: 674 >>> from mindspore.ops import AiCPURegOp, DataType 675 >>> stack_op_info = AiCPURegOp("Stack") \ 676 ... .fusion_type("OPAQUE") \ 677 ... .attr("axis", "int") \ 678 ... .input(0, "x", "dynamic") \ 679 ... .output(0, "y", "required") \ 680 ... .dtype_format(DataType.I8_Default, DataType.I8_Default) \ 681 ... .dtype_format(DataType.I16_Default, DataType.I16_Default) \ 682 ... .dtype_format(DataType.I32_Default, DataType.I32_Default) \ 683 ... .dtype_format(DataType.I64_Default, DataType.I64_Default) \ 684 ... .dtype_format(DataType.U8_Default, DataType.U8_Default) \ 685 ... .dtype_format(DataType.U16_Default, DataType.U16_Default) \ 686 ... .dtype_format(DataType.U32_Default, DataType.U32_Default) \ 687 ... .dtype_format(DataType.U64_Default, DataType.U64_Default) \ 688 ... .dtype_format(DataType.F16_Default, DataType.F16_Default) \ 689 ... .dtype_format(DataType.F32_Default, DataType.F32_Default) \ 690 ... .dtype_format(DataType.F64_Default, DataType.F64_Default) \ 691 ... 
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \ 692 ... .get_op_info() 693 >>> 694 """ 695 696 def __init__(self, op_name): 697 super(AiCPURegOp, self).__init__(op_name) 698 self.imply_type = "AiCPU" 699 700 701class TBERegOp(RegOp): 702 r""" 703 Class for TBE operator information registration. TBE (Tensor Boost Engine) is the Ascend operator development 704 tool, which is extended on the basis of the TVM framework to develop custom operators. 705 706 Args: 707 op_name (str): Name of operator. 708 709 Examples: 710 >>> from mindspore.ops import TBERegOp, DataType 711 >>> op_name_op_info = TBERegOp("OpName") \ 712 ... .fusion_type("ELEMWISE") \ 713 ... .async_flag(False) \ 714 ... .binfile_name("op_name.so") \ 715 ... .compute_cost(10) \ 716 ... .kernel_name("op_name") \ 717 ... .partial_flag(True) \ 718 ... .op_pattern("formatAgnostic") \ 719 ... .need_check_supported(True) \ 720 ... .dynamic_shape(True) \ 721 ... .dynamic_rank_support(True) \ 722 ... .dynamic_compile_static(True) \ 723 ... .attr("format", "required", "str", "all") \ 724 ... .input(0, "x1", None, "required", None) \ 725 ... .input(0, "x2", None, "required", None) \ 726 ... .input(1, "axis", None, "required", None) \ 727 ... .output(0, "y", True, "required", "all") \ 728 ... .real_input_index([1, 0]) \ 729 ... .input_to_attr_index([2]) \ 730 ... .unknown_shape_formats(["ND", "ND", "ND", "ND"]) \ 731 ... .reshape_type("NC") \ 732 ... .is_dynamic_format(True) \ 733 ... .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None, DataType.F16_None) \ 734 ... .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None, DataType.F32_None) \ 735 ... .dtype_format(DataType.I32_None, DataType.I32_None, DataType.I32_None, DataType.I32_None) \ 736 ... 
.get_op_info() 737 >>> 738 """ 739 740 def __init__(self, op_name): 741 super(TBERegOp, self).__init__(op_name) 742 self.imply_type = "TBE" 743 self.async_flag_ = False 744 self.binfile_ = '' 745 self.compute_cost_ = 10 746 self.kernel_ = '' 747 self.partial_flag_ = False 748 self.reshape_type_ = '' 749 self.dynamic_rank_support_ = False 750 self.dynamic_shape_support_ = False 751 self.dynamic_compile_static_ = False 752 self.need_check_support_ = False 753 self.dynamic_format_ = False 754 self.op_pattern_ = "" 755 self.real_input_index_ = [] 756 self.input_to_attr_index_ = [] 757 self.unknown_shape_formats_ = [] 758 759 def unknown_shape_formats(self, unknown_shape_formats): 760 """ 761 Description data arrangement of operator input / output tensor in dynamic shape scene. 762 763 Args: 764 unknown_shape_formats (list): Description data arrangement of operator input / output tensor in dynamic 765 shape scene. 766 """ 767 RegOp._is_list(unknown_shape_formats) 768 self.unknown_shape_formats_.append(unknown_shape_formats) 769 return self 770 771 def dynamic_rank_support(self, dynamic_rank_support): 772 """ 773 Description whether the operator supports dynamic rank (dynamic dimension). 774 775 Args: 776 dynamic_rank_support (bool): Description whether the operator supports dynamic rank (dynamic dimension). 777 True: indicates that dynamic rank is supported, and the operator supports 778 shape (- 2), which is used to determine whether dynamic is performed. 779 False: indicates that the operator does not support dynamic rank. 780 Default: ``False`` . 781 """ 782 if self._is_bool(dynamic_rank_support): 783 self.dynamic_rank_support_ = dynamic_rank_support 784 return self 785 786 def real_input_index(self, real_input_index): 787 """ 788 Description operator front end and tbe operator input mapping. 789 790 Args: 791 real_input_index (list): Value of real_input_index. Default: ``()`` . 
792 """ 793 RegOp._is_list(real_input_index) 794 self.real_input_index_ = real_input_index 795 return self 796 797 def input_to_attr_index(self, input_to_attr_index): 798 """ 799 Description the index of input need to cast to attr. 800 801 Args: 802 input_to_attr_index (list): Value of input_to_attr_index. Default: ``()`` . 803 """ 804 RegOp._is_list(input_to_attr_index) 805 self.input_to_attr_index_ = input_to_attr_index 806 return self 807 808 def async_flag(self, async_flag=False): 809 """ 810 Define the calculation efficiency of the operator, whether the asynchronous calculation is supported. 811 812 Args: 813 async_flag (bool): Value of async flag. Default: ``False`` . 814 """ 815 self._is_bool(async_flag) 816 self.async_flag_ = async_flag 817 return self 818 819 def binfile_name(self, binfile_name): 820 """ 821 Set the binary file name of the operator, it is optional. 822 823 Args: 824 binfile_name (str): The binary file name of the operator. 825 """ 826 self._is_string(binfile_name) 827 self.binfile_ = binfile_name 828 return self 829 830 def compute_cost(self, compute_cost=10): 831 """ 832 Define the calculation efficiency of operator, which refers to the value of the cost model 833 in the tiling module. 834 835 Args: 836 compute_cost (int): Value of compute cost. Default: ``10`` . 837 """ 838 self._is_int(compute_cost) 839 self.compute_cost_ = compute_cost 840 return self 841 842 def kernel_name(self, kernel_name): 843 """ 844 The name of operator kernel. 845 846 Args: 847 kernel_name (str): Name of operator kernel. 848 """ 849 self._is_string(kernel_name) 850 self.kernel_ = kernel_name 851 return self 852 853 def partial_flag(self, partial_flag=True): 854 """ 855 Define the calculation efficiency of operator, whether the partial calculation is supported. 856 857 Args: 858 partial_flag (bool): Value of partial flag. Default: ``True`` . 
"""
        self._is_bool(partial_flag)
        self.partial_flag_ = partial_flag
        return self

    def reshape_type(self, reshape_type):
        """
        Reshape type of operator.

        Args:
            reshape_type (str): Value of reshape type. For example, if the input shape is :math:`(2, 3)`
                and `reshape_type` is set to "CH", then the new shape is :math:`(1, 2, 3, 1)`.
                "CH" means the C and H dimensions are kept and
                new dimensions are added for N and W dimension.
        """
        self._is_string(reshape_type)
        self.reshape_type_ = reshape_type
        return self

    def dynamic_shape(self, dynamic_shape=False):
        """
        Whether the operator supports dynamic shape.

        Args:
            dynamic_shape (bool): Value of dynamic shape. Default: ``False`` .
        """
        self._is_bool(dynamic_shape)
        # Stored under `dynamic_shape_support_` (not `dynamic_shape_`).
        self.dynamic_shape_support_ = dynamic_shape
        return self

    def dynamic_compile_static(self, dynamic_compile_static=False):
        """
        Whether the operator supports dynamic compile static.

        Args:
            dynamic_compile_static (bool): Value of dynamic compile static. Default: ``False`` .
        """
        # NOTE(review): unlike `dynamic_shape`, the assignment is guarded by the return value of
        # `_is_bool` (defined outside this chunk). This is equivalent to an unconditional assignment
        # only if `_is_bool` returns a truthy value for valid input -- confirm.
        if self._is_bool(dynamic_compile_static):
            self.dynamic_compile_static_ = dynamic_compile_static
        return self

    def need_check_supported(self, need_check_supported=False):
        """
        Whether the operator needs to check whether it is supported.

        Args:
            need_check_supported (bool): Value of need_check_supported. Default: ``False`` .
        """
        # Stored under `need_check_support_`. The assignment only happens if `_is_bool` returns a
        # truthy value (defined outside this chunk -- confirm it returns True on valid input).
        if self._is_bool(need_check_supported):
            self.need_check_support_ = need_check_supported
        return self

    def is_dynamic_format(self, is_dynamic_format=False):
        """
        Whether the operator needs to call the ``op_select_format`` API.

        Args:
            is_dynamic_format (bool): Value of is_dynamic_format. Default: ``False`` .
        """
        # Stored under `dynamic_format_`. The assignment only happens if `_is_bool` returns a
        # truthy value (defined outside this chunk -- confirm it returns True on valid input).
        if self._is_bool(is_dynamic_format):
            self.dynamic_format_ = is_dynamic_format
        return self

    def op_pattern(self, pattern=None):
        """
        The behavior type of operator, such as broadcast, reduce and so on.

        Args:
            pattern (str): Value of op pattern, e.g. "broadcast", "reduce". If ``None`` , the stored
                pattern is left unchanged. Default: ``None`` .
        """
        if pattern is not None and self._is_string(pattern):
            self.op_pattern_ = pattern
        return self

    def attr(self, name=None, param_type=None, value_type=None, value=None, default_value=None, **kwargs):
        """
        Register TBE op attribute information.

        Args:
            name (str): Name of the attribute. Default: ``None`` .
            param_type (str): Param type of the attribute. Default: ``None`` .
            value_type (str): Type of the attribute. Default: ``None`` .
            value (str): Value of the attribute. Default: ``None`` .
            default_value (str): Default value of attribute. Default: ``None`` .
            kwargs (dict): Other information of the attribute.
        """
        param_list = [name, param_type, value_type, value, default_value]
        # Keys of the registered attribute dict, positionally matching `param_list`.
        key_list = ["name", "paramType", "type", "value", "defaultValue"]
        # A single checker for five parameters: how it is applied is decided by `_check_param`
        # (outside this chunk) -- presumably every non-None value must be a str; confirm.
        fn_list = [self._is_string]
        attr_dict = self._check_param(param_list, key_list, fn_list, kwargs)
        self.attr_.append(attr_dict)
        return self

    def input(self, index=None, name=None, need_compile=None, param_type=None, shape=None, value_depend=None, **kwargs):
        """
        Register TBE op input information.

        Args:
            index (int): Order of the input. Default: ``None`` .
            name (str): Name of the input. Default: ``None`` .
            need_compile (bool): Whether the input needs to be compiled or not. Default: ``None`` .
            param_type (str): Type of the input. Default: ``None`` .
            shape (str): Shape of the input. Default: ``None`` .
            value_depend (str): Whether the input is constant value depend. If given (non-empty), it must be
                one of "ignored", "optional" or "required", compared case-insensitively. Default: ``None`` .
            kwargs (dict): Other information of the input.

        Raises:
            ValueError: If `value_depend` is non-empty and not one of "ignored", "optional" or "required".
        """
        param_list = [index, name, need_compile, param_type, shape, value_depend]
        key_list = ["index", "name", "needCompile", "paramType", "shape", "valueDepend"]
        fn_list = [self._is_int, self._is_string, self._is_bool, self._is_string, self._is_string, self._is_string]
        input_dict = self._check_param(param_list, key_list, fn_list, kwargs)
        # Validate the value (not just the type) of value_depend; the dict is only recorded on success.
        value_depend_values = ("ignored", "optional", "required")
        if value_depend and value_depend.lower() not in value_depend_values:
            raise ValueError("Operator {} input{}'s value_depend's value ({}) is not in {}.".
                             format(self.op_name, index, value_depend, value_depend_values))
        self.inputs.append(input_dict)
        return self

    def output(self, index=None, name=None, need_compile=None, param_type=None, shape=None, **kwargs):
        """
        Register TBE op output information.

        Args:
            index (int): Order of the output. Default: ``None`` .
            name (str): Name of the output. Default: ``None`` .
            need_compile (bool): Whether the output needs to be compiled or not. Default: ``None`` .
            param_type (str): Type of the output. Default: ``None`` .
            shape (str): Shape of the output. Default: ``None`` .
            kwargs (dict): Other information of the output.
        """
        param_list = [index, name, need_compile, param_type, shape]
        # NOTE(review): `input` registers this flag under "needCompile", but outputs use "need_compile".
        # Confirm which key the consumer of the op info reads for outputs before changing either one.
        key_list = ["index", "name", "need_compile", "paramType", "shape"]
        fn_list = [self._is_int, self._is_string, self._is_bool, self._is_string, self._is_string]
        output_dict = self._check_param(param_list, key_list, fn_list, kwargs)
        self.outputs.append(output_dict)
        return self


class CustomRegOp(RegOp):
    r"""
    Class used for generating the registration information for the `func` parameter of :class:`mindspore.ops.Custom`.
    The registration information mainly specifies the supported data types and formats of input and output tensors,
    attributes and target of `func`.

    Args:
        op_name (str): kernel name. The name will be recorded in the `reg_op_name` attr of the kernel node.
            Besides, the operator will generate a unique name automatically to identify the reg info.
            Default: ``"Custom"`` .

    Examples:
        >>> from mindspore.ops import CustomRegOp, DataType
        >>> custom_op_ascend_info = CustomRegOp() \
        ...     .input(0, "x", "dynamic") \
        ...     .output(0, "y") \
        ...     .dtype_format(DataType.F16_Default, DataType.F16_Default) \
        ...     .dtype_format(DataType.F32_Default, DataType.F32_Default) \
        ...     .target("Ascend") \
        ...     .get_op_info()
    """

    def __init__(self, op_name="Custom"):
        super(CustomRegOp, self).__init__(op_name)
        # Placeholder value until `target()` is called.
        self.target_ = "UnKnown"

    def input(self, index=None, name=None, param_type="required", **kwargs):
        """
        Specifies the input tensor information for the `func` parameter of :class:`mindspore.ops.Custom`. Each
        invocation of this function will generate one input tensor information, that means, if `func` has two input
        tensors, then this function should be invoked two times continuously. The input tensor information will be
        generated as a dict: {"index": `index`, "name": `name`, "paramType": `param_type`}.

        Args:
            index (int): Index of the input, starts from 0. 0 means the first input tensor, 1 means the second input
                tensor and so on. If ``None`` , key "index" will not appear in the input tensor information dict.
                Default: ``None`` .
            name (str): Name of the `index` 'th input. If ``None`` , key "name" will not appear in the input tensor
                information dict. Default: ``None`` .
            param_type (str): Parameter type of the `index` 'th input, can be one of
                [``"required"`` , ``"dynamic"`` , ``"optional"`` ]. If ``None`` , key "paramType" will not appear in
                the input tensor information dict. Default: ``"required"`` .

                - ``"required"``: means the `index` 'th input exists and can only be a single tensor.
                - ``"dynamic"``: means the `index` 'th input exists and may be multiple tensors, such as the input of
                  AddN.
                - ``"optional"``: means the `index` 'th input may exist and be a single tensor or may not exist.

            kwargs (dict): Other information of the input, used for extension.

        Returns:
            CustomRegOp, this instance, so calls can be chained.

        Raises:
            TypeError: If `index` is neither int nor None.
            TypeError: If `name` is neither str nor None.
            TypeError: If `param_type` is neither str nor None.
        """
        param_list = [index, name, param_type]
        key_list = ["index", "name", "paramType"]
        # `param_type` is only type-checked here; this method does not validate it against the listed choices.
        fn_list = [self._is_int, self._is_string, self._is_string]
        input_dict = self._check_param(param_list, key_list, fn_list, kwargs)
        self.inputs.append(input_dict)
        return self

    def output(self, index=None, name=None, param_type="required", **kwargs):
        """
        Specifies the output tensor information for the `func` parameter of :class:`mindspore.ops.Custom`. Each
        invocation of this function will generate one output tensor information, which means, if `func` has two output
        tensors, then this function should be invoked two times continuously. The output tensor information will be
        generated as a dict: {"index": `index`, "name": `name`, "paramType": `param_type`}.

        Args:
            index (int): Index of the output, starts from 0. 0 means the first output tensor, 1 means the second output
                tensor and so on. If ``None`` , key "index" will not appear in the output tensor information dict.
                Default: ``None`` .
            name (str): Name of the `index` 'th output. If ``None`` , key "name" will not appear in the output tensor
                information dict. Default: ``None`` .
            param_type (str): Parameter type of the `index` 'th output, can be one of
                [ ``"required"`` , ``"dynamic"`` , ``"optional"`` ]. If ``None`` , key "paramType" will not appear in
                the output tensor information dict. Default: ``"required"`` .

                - ``"required"``: means the `index` 'th output exists and can only be a single tensor.
                - ``"dynamic"``: means the `index` 'th output exists and may be multiple tensors.
                - ``"optional"``: means the `index` 'th output may exist and be a single tensor or may not exist.

            kwargs (dict): Other information of the output, used for extension.

        Returns:
            CustomRegOp, this instance, so calls can be chained.

        Raises:
            TypeError: If `index` is neither int nor None.
            TypeError: If `name` is neither str nor None.
            TypeError: If `param_type` is neither str nor None.
        """
        param_list = [index, name, param_type]
        key_list = ["index", "name", "paramType"]
        # `param_type` is only type-checked here; this method does not validate it against the listed choices.
        fn_list = [self._is_int, self._is_string, self._is_string]
        output_dict = self._check_param(param_list, key_list, fn_list, kwargs)
        self.outputs.append(output_dict)
        return self

    def dtype_format(self, *args):
        """
        Specifies the supported data type and format of each input tensor and output tensor for the `func` parameter
        of :class:`mindspore.ops.Custom`. This function should be invoked after `input` and `output` function as shown
        in the above example.

        Args:
            args (tuple): A tuple of (data type, format) pair, the length of `args` should be equal to the sum of input
                tensors and output tensors. Each item in `args` is also a tuple, tuple[0] and tuple[1] are both str
                type, which specifies the data type and format of a tensor respectively. :class:`mindspore.ops.DataType`
                provides many predefined (data type, format) combinations, for example, `DataType.F16_Default` means the
                data type is float16 and the format is default format.

        Returns:
            The value returned by the parent class's `dtype_format`.

        Raises:
            ValueError: If the size of `args` is not equal to the sum of input tensors and output tensors.
        """
        # One (dtype, format) entry is required per registered input and output, which is why
        # `input`/`output` must be called before this method.
        io_nums = len(self.inputs) + len(self.outputs)
        if len(args) != io_nums:
            raise ValueError("The size of 'args' must be equal to the sum of input tensors and output tensors, but got "
                             "{} vs {}".format(len(args), io_nums))
        return super(CustomRegOp, self).dtype_format(*args)

    def attr(self, name=None, param_type=None, value_type=None, default_value=None, **kwargs):
        """
        Specifies the attributes information for the `func` parameter of :class:`mindspore.ops.Custom`. Each
        invocation of this function will generate one attribute information, that means, if `func` has two attributes,
        then this function should be invoked two times continuously. The attributes information will be
        generated as a dict: {"name": `name`, "paramType": `param_type`, "type": `value_type`, "defaultValue":
        `default_value`}.

        Args:
            name (str): Name of the attribute. If ``None`` , key "name" will not appear in the attribute
                information dict. Default: ``None`` .
            param_type (str): Parameter type of the attribute, can be one of ["required", "optional"]. If ``None`` ,
                key "paramType" will not appear in the attribute information dict. Default: ``None`` .

                - "required": means must provide a value for this attribute either by setting a default value in the
                  registration information or providing an input value when calling the Custom operator.
                - "optional": means does not have to provide a value for this attribute.

            value_type (str): Value type of the attribute, can be one of ["int", "str", "bool", "float", "listInt",
                "listStr", "listBool", "listFloat"]. If ``None`` , key "type" will not appear in the attribute
                information dict. Default: ``None`` .

                - "int": string representation of Python type int.
                - "str": string representation of Python type str.
                - "bool": string representation of Python type bool.
                - "float": string representation of Python type float.
                - "listInt": string representation of Python type list of int.
                - "listStr": string representation of Python type list of str.
                - "listBool": string representation of Python type list of bool.
                - "listFloat": string representation of Python type list of float.

            default_value (str): Default value of the attribute. `default_value` and `value_type` are used together.
                If the real default value of the attribute is float type with value 1.0, then the `value_type` should be
                "float" and `default_value` should be "1.0". If the real default value of the attribute is a list of int
                with value [1, 2, 3], then the `value_type` should be "listInt" and `default_value` should be "1,2,3",
                each item should split by ','. If ``None`` , means the attribute has no default value and key
                "defaultValue" will not appear in the attribute information dict. It is used for "akg",
                "aicpu" and "tbe" Custom operators currently. Default: ``None`` .
            kwargs (dict): Other information of the attribute, used for extension.

        Returns:
            CustomRegOp, this instance, so calls can be chained.

        Raises:
            TypeError: If `name` is neither str nor None.
            TypeError: If `param_type` is neither str nor None.
            TypeError: If `value_type` is neither str nor None.
            TypeError: If `default_value` is neither str nor None.
        """
        param_list = [name, param_type, value_type, default_value]
        key_list = ["name", "paramType", "type", "defaultValue"]
        # A single checker for all four parameters; how it is applied is decided by `_check_param`
        # (outside this chunk) -- per the Raises section, every non-None value must be a str.
        fn_list = [self._is_string]
        attr_dict = self._check_param(param_list, key_list, fn_list, kwargs)
        self.attr_.append(attr_dict)
        return self

    def target(self, target=None):
        """
        Specifies the target that this registration information is used for.

        Args:
            target (str): Device target for current operator information, should be one of ["Ascend", "GPU", "CPU"].
                For the same `func` of :class:`mindspore.ops.Custom`, it may support different data types and formats
                on different targets, use `target` to specify which target that this registration information is used
                for. If ``None`` , it will be inferred automatically inside :class:`mindspore.ops.Custom`.
                Default: ``None`` .

        Returns:
            CustomRegOp, this instance, so calls can be chained.

        Raises:
            TypeError: If `target` is neither str nor None.
        """
        if target is not None:
            self._is_string(target)
        # Stored even when None, replacing the "UnKnown" placeholder set in __init__.
        self.target_ = target
        return self

    def get_op_info(self):
        """
        Return the generated registration information as a dict. This function should be invoked at last on the
        `CustomRegOp` instance as shown in the above example.

        Returns:
            dict, one entry per instance attribute. Attribute names ending with ``_`` have their trailing
            underscores stripped (e.g. ``target_`` becomes ``target``); other names are kept as is.
        """
        op_info = {}
        for k, v in self.__dict__.items():
            # Registration fields are stored with a trailing underscore (e.g. `attr_`, `target_`);
            # publish them without it.
            if isinstance(k, str) and k.endswith('_'):
                k = k.rstrip('_')
            op_info[k] = v
        # If an attribute named "cust_aicpu" was registered with a truthy default value, also run
        # `_CustomInstaller`, which sets ASCEND_CUSTOM_OPP_PATH in os.environ and saves the registration
        # information to files for GE.
        if _get_reg_info_attr(op_info, "cust_aicpu"):
            _CustomInstaller(op_info).run()
        return op_info


class DataType:
    r"""
    Various combinations of dtype and format of Ascend ops.

    Each attribute is a tuple whose first element is the dtype name and whose second element is the format
    name. The ``*_Default_Tuple`` / ``*_Default_List`` attributes carry a third element, ``"tuple"`` or
    ``"list"``. The listing below is not exhaustive; see the class attributes for the full set.

    Currently supported combinations include:

    .. code-block::

        None_None = ("", "")
        None_Default = ("", "DefaultFormat")
        BOOL_None = ("bool", "")
        BOOL_Default = ("bool", "DefaultFormat")
        BOOL_5HD = ("bool", "NC1HWC0")
        BOOL_FracZ = ("bool", "FRACTAL_Z")
        BOOL_FracNZ = ("bool", "FRACTAL_NZ")
        BOOL_C1HWNCoC0 = ("bool", "C1HWNCoC0")
        BOOL_NCHW = ("bool", "NCHW")
        BOOL_NHWC = ("bool", "NHWC")
        BOOL_HWCN = ("bool", "HWCN")
        BOOL_NDHWC = ("bool", "NDHWC")
        BOOL_ChannelLast = ("bool", "ChannelLast")

        I8_None = ("int8", "")
        I8_Default = ("int8", "DefaultFormat")
        I8_5HD = ("int8", "NC1HWC0")
        I8_FracZ = ("int8", "FRACTAL_Z")
        I8_FracNZ = ("int8", "FRACTAL_NZ")
        I8_C1HWNCoC0 = ("int8", "C1HWNCoC0")
        I8_NCHW = ("int8", "NCHW")
        I8_NHWC = ("int8", "NHWC")
        I8_HWCN = ("int8", "HWCN")
        I8_NDHWC = ("int8", "NDHWC")
        I8_ChannelLast = ("int8", "ChannelLast")
        I8_NDC1HWC0 = ("int8", "NDC1HWC0")

        U8_None = ("uint8", "")
        U8_Default = ("uint8", "DefaultFormat")
        U8_5HD = ("uint8", "NC1HWC0")
        U8_FracZ = ("uint8", "FRACTAL_Z")
        U8_FracNZ = ("uint8", "FRACTAL_NZ")
        U8_C1HWNCoC0 = ("uint8", "C1HWNCoC0")
        U8_NCHW = ("uint8", "NCHW")
        U8_NHWC = ("uint8", "NHWC")
        U8_HWCN = ("uint8", "HWCN")
        U8_NDHWC = ("uint8", "NDHWC")
        U8_ChannelLast = ("uint8", "ChannelLast")
        U8_NDC1HWC0 = ("uint8", "NDC1HWC0")

        I16_None = ("int16", "")
        I16_Default = ("int16", "DefaultFormat")
        I16_5HD = ("int16", "NC1HWC0")
        I16_FracZ = ("int16", "FRACTAL_Z")
        I16_FracNZ = ("int16", "FRACTAL_NZ")
        I16_C1HWNCoC0 = ("int16", "C1HWNCoC0")
        I16_NCHW = ("int16", "NCHW")
        I16_NHWC = ("int16", "NHWC")
        I16_HWCN = ("int16", "HWCN")
        I16_NDHWC = ("int16", "NDHWC")
        I16_ChannelLast = ("int16", "ChannelLast")

        U16_None = ("uint16", "")
        U16_Default = ("uint16", "DefaultFormat")
        U16_5HD = ("uint16", "NC1HWC0")
        U16_FracZ = ("uint16", "FRACTAL_Z")
        U16_FracNZ = ("uint16", "FRACTAL_NZ")
        U16_C1HWNCoC0 = ("uint16", "C1HWNCoC0")
        U16_NCHW = ("uint16", "NCHW")
        U16_NHWC = ("uint16", "NHWC")
        U16_HWCN = ("uint16", "HWCN")
        U16_NDHWC = ("uint16", "NDHWC")
        U16_ChannelLast = ("uint16", "ChannelLast")

        I32_None = ("int32", "")
        I32_Default = ("int32", "DefaultFormat")
        I32_5HD = ("int32", "NC1HWC0")
        I32_FracZ = ("int32", "FRACTAL_Z")
        I32_FracNZ = ("int32", "FRACTAL_NZ")
        I32_C1HWNCoC0 = ("int32", "C1HWNCoC0")
        I32_NCHW = ("int32", "NCHW")
        I32_NHWC = ("int32", "NHWC")
        I32_HWCN = ("int32", "HWCN")
        I32_NDHWC = ("int32", "NDHWC")
        I32_ChannelLast = ("int32", "ChannelLast")

        U32_None = ("uint32", "")
        U32_Default = ("uint32", "DefaultFormat")
        U32_5HD = ("uint32", "NC1HWC0")
        U32_FracZ = ("uint32", "FRACTAL_Z")
        U32_FracNZ = ("uint32", "FRACTAL_NZ")
        U32_C1HWNCoC0 = ("uint32", "C1HWNCoC0")
        U32_NCHW = ("uint32", "NCHW")
        U32_NHWC = ("uint32", "NHWC")
        U32_HWCN = ("uint32", "HWCN")
        U32_NDHWC = ("uint32", "NDHWC")
        U32_ChannelLast = ("uint32", "ChannelLast")

        I64_None = ("int64", "")
        I64_Default = ("int64", "DefaultFormat")
        I64_5HD = ("int64", "NC1HWC0")
        I64_FracZ = ("int64", "FRACTAL_Z")
        I64_FracNZ = ("int64", "FRACTAL_NZ")
        I64_C1HWNCoC0 = ("int64", "C1HWNCoC0")
        I64_NCHW = ("int64", "NCHW")
        I64_NHWC = ("int64", "NHWC")
        I64_HWCN = ("int64", "HWCN")
        I64_NDHWC = ("int64", "NDHWC")
        I64_ChannelLast = ("int64", "ChannelLast")

        U64_None = ("uint64", "")
        U64_Default = ("uint64", "DefaultFormat")
        U64_5HD = ("uint64", "NC1HWC0")
        U64_FracZ = ("uint64", "FRACTAL_Z")
        U64_FracNZ = ("uint64", "FRACTAL_NZ")
        U64_C1HWNCoC0 = ("uint64", "C1HWNCoC0")
        U64_NCHW = ("uint64", "NCHW")
        U64_NHWC = ("uint64", "NHWC")
        U64_HWCN = ("uint64", "HWCN")
        U64_NDHWC = ("uint64", "NDHWC")
        U64_ChannelLast = ("uint64", "ChannelLast")

        F16_None = ("float16", "")
        F16_Default = ("float16", "DefaultFormat")
        F16_5HD = ("float16", "NC1HWC0")
        F16_FracZ = ("float16", "FRACTAL_Z")
        F16_FracNZ = ("float16", "FRACTAL_NZ")
        F16_C1HWNCoC0 = ("float16", "C1HWNCoC0")
        F16_NCHW = ("float16", "NCHW")
        F16_NHWC = ("float16", "NHWC")
        F16_HWCN = ("float16", "HWCN")
        F16_NDHWC = ("float16", "NDHWC")
        F16_NCDHW = ("float16", "NCDHW")
        F16_DHWCN = ("float16", "DHWCN")
        F16_NDC1HWC0 = ("float16", "NDC1HWC0")
        F16_FRACTAL_Z_3D = ("float16", "FRACTAL_Z_3D")
        F16_FracZNLSTM = ("float16", "FRACTAL_ZN_LSTM")
        F16_FracZNRNN = ("float16", "FRACTAL_ZN_RNN")
        F16_ND_RNNBIAS = ("float16", "ND_RNN_BIAS")
        F16_ChannelLast = ("float16", "ChannelLast")

        F32_None = ("float32", "")
        F32_Default = ("float32", "DefaultFormat")
        F32_5HD = ("float32", "NC1HWC0")
        F32_FracZ = ("float32", "FRACTAL_Z")
        F32_FracNZ = ("float32", "FRACTAL_NZ")
        F32_C1HWNCoC0 = ("float32", "C1HWNCoC0")
        F32_NCHW = ("float32", "NCHW")
        F32_NHWC = ("float32", "NHWC")
        F32_HWCN = ("float32", "HWCN")
        F32_NDHWC = ("float32", "NDHWC")
        F32_NCDHW = ("float32", "NCDHW")
        F32_DHWCN = ("float32", "DHWCN")
        F32_NDC1HWC0 = ("float32", "NDC1HWC0")
        F32_FRACTAL_Z_3D = ("float32", "FRACTAL_Z_3D")
        F32_FracZNLSTM = ("float32", "FRACTAL_ZN_LSTM")
        F32_FracZNRNN = ("float32", "FRACTAL_ZN_RNN")
        F32_ND_RNNBIAS = ("float32", "ND_RNN_BIAS")
        F32_ChannelLast = ("float32", "ChannelLast")

        F64_None = ("float64", "")
        F64_Default = ("float64", "DefaultFormat")
        F64_5HD = ("float64", "NC1HWC0")
        F64_FracZ = ("float64", "FRACTAL_Z")
        F64_FracNZ = ("float64", "FRACTAL_NZ")
        F64_C1HWNCoC0 = ("float64", "C1HWNCoC0")
        F64_NCHW = ("float64", "NCHW")
        F64_NHWC = ("float64", "NHWC")
        F64_HWCN = ("float64", "HWCN")
        F64_NDHWC = ("float64", "NDHWC")
        F64_ChannelLast = ("float64", "ChannelLast")

        C64_Default = ("complex64", "DefaultFormat")
        C128_Default = ("complex128", "DefaultFormat")
    """

    # Each value is (dtype, format); "" means unspecified. The *_Default_Tuple / *_Default_List
    # entries add a third element, "tuple" or "list" -- presumably marking tuple/list-typed
    # arguments; confirm with the consumers of these values.
    None_None = ("", "")
    None_Default = ("", "DefaultFormat")

    BOOL_None = ("bool", "")
    BOOL_Default = ("bool", "DefaultFormat")
    BOOL_5HD = ("bool", "NC1HWC0")
    BOOL_FracZ = ("bool", "FRACTAL_Z")
    BOOL_FracNZ = ("bool", "FRACTAL_NZ")
    BOOL_C1HWNCoC0 = ("bool", "C1HWNCoC0")
    BOOL_NCHW = ("bool", "NCHW")
    BOOL_NHWC = ("bool", "NHWC")
    BOOL_HWCN = ("bool", "HWCN")
    BOOL_NDHWC = ("bool", "NDHWC")
    BOOL_ChannelLast = ("bool", "ChannelLast")
    BOOL_Default_Tuple = ("bool", "DefaultFormat", "tuple")
    BOOL_Default_List = ("bool", "DefaultFormat", "list")

    I8_None = ("int8", "")
    I8_Default = ("int8", "DefaultFormat")
    I8_5HD = ("int8", "NC1HWC0")
    I8_FracZ = ("int8", "FRACTAL_Z")
    I8_FracNZ = ("int8", "FRACTAL_NZ")
    I8_C1HWNCoC0 = ("int8", "C1HWNCoC0")
    I8_NCHW = ("int8", "NCHW")
    I8_NHWC = ("int8", "NHWC")
    I8_HWCN = ("int8", "HWCN")
    I8_NDHWC = ("int8", "NDHWC")
    I8_NCDHW = ("int8", "NCDHW")
    I8_ChannelLast = ("int8", "ChannelLast")
    I8_NDC1HWC0 = ("int8", "NDC1HWC0")
    # Same value as I8_5HD.
    I8_NC1HWC0 = ("int8", "NC1HWC0")
    I8_Default_Tuple = ("int8", "DefaultFormat", "tuple")
    I8_Default_List = ("int8", "DefaultFormat", "list")

    U8_None = ("uint8", "")
    U8_Default = ("uint8", "DefaultFormat")
    U8_5HD = ("uint8", "NC1HWC0")
    U8_FracZ = ("uint8", "FRACTAL_Z")
    U8_FracNZ = ("uint8", "FRACTAL_NZ")
    U8_C1HWNCoC0 = ("uint8", "C1HWNCoC0")
    U8_NCHW = ("uint8", "NCHW")
    U8_NHWC = ("uint8", "NHWC")
    U8_HWCN = ("uint8", "HWCN")
    U8_NDHWC = ("uint8", "NDHWC")
    U8_NCDHW = ("uint8", "NCDHW")
    U8_ChannelLast = ("uint8", "ChannelLast")
    U8_NDC1HWC0 = ("uint8", "NDC1HWC0")
    # Same value as U8_5HD.
    U8_NC1HWC0 = ("uint8", "NC1HWC0")
    U8_Default_Tuple = ("uint8", "DefaultFormat", "tuple")
    U8_Default_List = ("uint8", "DefaultFormat", "list")

    I16_None = ("int16", "")
    I16_Default = ("int16", "DefaultFormat")
    I16_5HD = ("int16", "NC1HWC0")
    I16_FracZ = ("int16", "FRACTAL_Z")
    I16_FracNZ = ("int16", "FRACTAL_NZ")
    I16_C1HWNCoC0 = ("int16", "C1HWNCoC0")
    I16_NCHW = ("int16", "NCHW")
    I16_NHWC = ("int16", "NHWC")
    I16_HWCN = ("int16", "HWCN")
    I16_NDHWC = ("int16", "NDHWC")
    I16_ChannelLast = ("int16", "ChannelLast")
    I16_Default_Tuple = ("int16", "DefaultFormat", "tuple")
    I16_Default_List = ("int16", "DefaultFormat", "list")

    U16_None = ("uint16", "")
    U16_Default = ("uint16", "DefaultFormat")
    U16_5HD = ("uint16", "NC1HWC0")
    U16_FracZ = ("uint16", "FRACTAL_Z")
    U16_FracNZ = ("uint16", "FRACTAL_NZ")
    U16_C1HWNCoC0 = ("uint16", "C1HWNCoC0")
    U16_NCHW = ("uint16", "NCHW")
    U16_NHWC = ("uint16", "NHWC")
    U16_HWCN = ("uint16", "HWCN")
    U16_NDHWC = ("uint16", "NDHWC")
    U16_ChannelLast = ("uint16", "ChannelLast")
    U16_Default_Tuple = ("uint16", "DefaultFormat", "tuple")
    U16_Default_List = ("uint16", "DefaultFormat", "list")

    I32_None = ("int32", "")
    I32_Default = ("int32", "DefaultFormat")
    I32_5HD = ("int32", "NC1HWC0")
    I32_FracZ = ("int32", "FRACTAL_Z")
    I32_FracNZ = ("int32", "FRACTAL_NZ")
    I32_C1HWNCoC0 = ("int32", "C1HWNCoC0")
    I32_NCHW = ("int32", "NCHW")
    I32_NHWC = ("int32", "NHWC")
    I32_HWCN = ("int32", "HWCN")
    I32_NDHWC = ("int32", "NDHWC")
    I32_NDC1HWC0 = ("int32", "NDC1HWC0")
    I32_NCDHW = ("int32", "NCDHW")
    I32_ChannelLast = ("int32", "ChannelLast")
    I32_Default_Tuple = ("int32", "DefaultFormat", "tuple")
    I32_Default_List = ("int32", "DefaultFormat", "list")

    U32_None = ("uint32", "")
    U32_Default = ("uint32", "DefaultFormat")
    U32_5HD = ("uint32", "NC1HWC0")
    U32_FracZ = ("uint32", "FRACTAL_Z")
    U32_FracNZ = ("uint32", "FRACTAL_NZ")
    U32_C1HWNCoC0 = ("uint32", "C1HWNCoC0")
    U32_NCHW = ("uint32", "NCHW")
    U32_NHWC = ("uint32", "NHWC")
    U32_HWCN = ("uint32", "HWCN")
    U32_NDHWC = ("uint32", "NDHWC")
    U32_ChannelLast = ("uint32", "ChannelLast")
    U32_Default_Tuple = ("uint32", "DefaultFormat", "tuple")
    U32_Default_List = ("uint32", "DefaultFormat", "list")

    I64_None = ("int64", "")
    I64_Default = ("int64", "DefaultFormat")
    I64_5HD = ("int64", "NC1HWC0")
    I64_FracZ = ("int64", "FRACTAL_Z")
    I64_FracNZ = ("int64", "FRACTAL_NZ")
    I64_C1HWNCoC0 = ("int64", "C1HWNCoC0")
    I64_NCHW = ("int64", "NCHW")
    I64_NHWC = ("int64", "NHWC")
    I64_HWCN = ("int64", "HWCN")
    I64_NDHWC = ("int64", "NDHWC")
    I64_ChannelLast = ("int64", "ChannelLast")
    I64_Default_Tuple = ("int64", "DefaultFormat", "tuple")
    I64_Default_List = ("int64", "DefaultFormat", "list")

    U64_None = ("uint64", "")
    U64_Default = ("uint64", "DefaultFormat")
    U64_5HD = ("uint64", "NC1HWC0")
    U64_FracZ = ("uint64", "FRACTAL_Z")
    U64_FracNZ = ("uint64", "FRACTAL_NZ")
    U64_C1HWNCoC0 = ("uint64", "C1HWNCoC0")
    U64_NCHW = ("uint64", "NCHW")
    U64_NHWC = ("uint64", "NHWC")
    U64_HWCN = ("uint64", "HWCN")
    U64_NDHWC = ("uint64", "NDHWC")
    U64_ChannelLast = ("uint64", "ChannelLast")
    U64_Default_Tuple = ("uint64", "DefaultFormat", "tuple")
    U64_Default_List = ("uint64", "DefaultFormat", "list")

    F16_None = ("float16", "")
    F16_Default = ("float16", "DefaultFormat")
    F16_5HD = ("float16", "NC1HWC0")
    F16_FracZ = ("float16", "FRACTAL_Z")
    F16_FracNZ = ("float16", "FRACTAL_NZ")
    F16_C1HWNCoC0 = ("float16", "C1HWNCoC0")
    F16_NCHW = ("float16", "NCHW")
    F16_NHWC = ("float16", "NHWC")
    F16_HWCN = ("float16", "HWCN")
    F16_NDHWC = ("float16", "NDHWC")
    F16_NCDHW = ("float16", "NCDHW")
    F16_DHWCN = ("float16", "DHWCN")
    F16_NDC1HWC0 = ("float16", "NDC1HWC0")
    F16_FRACTAL_Z_3D = ("float16", "FRACTAL_Z_3D")
    F16_FracZNLSTM = ("float16", "FRACTAL_ZN_LSTM")
    F16_FracZNRNN = ("float16", "FRACTAL_ZN_RNN")
    F16_ND_RNNBIAS = ("float16", "ND_RNN_BIAS")
    F16_ChannelLast = ("float16", "ChannelLast")
    F16_Default_Tuple = ("float16", "DefaultFormat", "tuple")
    F16_Default_List = ("float16", "DefaultFormat", "list")

    F32_None = ("float32", "")
    F32_Default = ("float32", "DefaultFormat")
    F32_5HD = ("float32", "NC1HWC0")
    F32_FracZ = ("float32", "FRACTAL_Z")
    F32_FracNZ = ("float32", "FRACTAL_NZ")
    F32_C1HWNCoC0 = ("float32", "C1HWNCoC0")
    F32_NCHW = ("float32", "NCHW")
    F32_NHWC = ("float32", "NHWC")
    F32_HWCN = ("float32", "HWCN")
    F32_NDHWC = ("float32", "NDHWC")
    F32_NCDHW = ("float32", "NCDHW")
    F32_DHWCN = ("float32", "DHWCN")
    F32_NDC1HWC0 = ("float32", "NDC1HWC0")
    F32_FRACTAL_Z_3D = ("float32", "FRACTAL_Z_3D")
    F32_FracZNLSTM = ("float32", "FRACTAL_ZN_LSTM")
    F32_FracZNRNN = ("float32", "FRACTAL_ZN_RNN")
    F32_ND_RNNBIAS = ("float32", "ND_RNN_BIAS")
    F32_ChannelLast = ("float32", "ChannelLast")
    F32_Default_Tuple = ("float32", "DefaultFormat", "tuple")
    F32_Default_List = ("float32", "DefaultFormat", "list")

    F64_None = ("float64", "")
    F64_Default = ("float64", "DefaultFormat")
    F64_5HD = ("float64", "NC1HWC0")
    F64_FracZ = ("float64", "FRACTAL_Z")
    F64_FracNZ = ("float64", "FRACTAL_NZ")
    F64_C1HWNCoC0 = ("float64", "C1HWNCoC0")
    F64_NCHW = ("float64", "NCHW")
    F64_NHWC = ("float64", "NHWC")
    F64_HWCN = ("float64", "HWCN")
    F64_NDHWC = ("float64", "NDHWC")
    F64_ChannelLast = ("float64", "ChannelLast")
    F64_Default_Tuple = ("float64", "DefaultFormat", "tuple")
    F64_Default_List = ("float64", "DefaultFormat", "list")

    C64_Default = ("complex64", "DefaultFormat")
    C128_Default = ("complex128", "DefaultFormat")
    C64_Default_Tuple = ("complex64", "DefaultFormat", "tuple")
    C128_Default_Tuple = ("complex128", "DefaultFormat", "tuple")