# NOTE: removed source-viewer navigation artifacts ("Home / Line# / Scopes /
# Navigate / Raw / Download") that were captured with the page and are not
# part of the Python source.
1# Copyright 2020-2023 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15
16"""Operators info register."""
17from __future__ import absolute_import
18from __future__ import division
19
20import inspect
21import json
22import os
23import functools
24import platform
25import hashlib
26import shutil
27
28from mindspore._c_expression import Oplib
29from mindspore import _checkparam as validator
30from mindspore import log as logger
31
32if platform.system() == "Linux":
33    import fcntl
34
35# path of built-in op info register.
36BUILT_IN_OPS_REGISTER_PATH = "mindspore/ops/_op_impl"
37BUILT_IN_CUSTOM_OPS_REGISTER_PATH = "mindspore/ops/_op_impl/_custom_op"
38
39KEY_NAME = "name"
40ASCEND_CUSTOM_OPP_PATH = "ASCEND_CUSTOM_OPP_PATH"
41
42
def _get_reg_info_attr(op_info, attr_name, default_value=None):
    """Return the 'defaultValue' of the attr named `attr_name` in `op_info`.

    Args:
        op_info (dict): Operator registration information. Its "attr" entry,
            when present, is a list of dicts each carrying a "name" key.
        attr_name (str): Name of the attribute to look up.
        default_value: Value returned when no attr matches. Default: ``None``.

    Returns:
        The matched attr's "defaultValue" (which may itself be None),
        otherwise `default_value`.
    """
    # Plain iteration: the index produced by the previous enumerate() call
    # was never used.
    for item in op_info.get("attr", []):
        if item.get(KEY_NAME) == attr_name:
            return item.get("defaultValue")
    return default_value
49
50
class _CustomInstaller:
    """Save custom op registration information to a json file which will be used by GE.

    Files are laid out under ``./vendors/ms`` following the Ascend custom opp
    directory convention, and that directory is prepended to the
    ``ASCEND_CUSTOM_OPP_PATH`` environment variable so it can be discovered.
    """
    reg_info_hash = []  # used to avoid writing the same reg info to file multiple times
    copied_paths = []  # used to avoid copying the same file multiple times

    def __init__(self, op_info, func=None):
        self.op_info = op_info
        self.func = func
        # For dsl ops the implementation function's name identifies the op;
        # otherwise fall back to the name carried in the reg info.
        self.op_type = op_info.get("op_name") if not func else func.__name__
        vendor_name = "ms"
        custom_dir = os.path.join(os.path.realpath("./"), "vendors", vendor_name)
        self._set_env(custom_dir)
        op_impl_dir = os.path.join(custom_dir, "op_impl")
        self.ai_core_config_dir = os.path.join(op_impl_dir, "ai_core", "tbe", "config")
        self.ai_core_impl_dir = os.path.join(op_impl_dir, "ai_core", "tbe", vendor_name + "_impl")
        self.ai_cpu_config_dir = os.path.join(op_impl_dir, "cpu", "config")
        self.ai_cpu_impl_dir = os.path.join(op_impl_dir, "cpu", "aicpu_kernel", "impl")

    @staticmethod
    def _set_env(custom_opp_path):
        """Prepend `custom_opp_path` to the ASCEND_CUSTOM_OPP_PATH env var if absent."""
        if not os.environ.get(ASCEND_CUSTOM_OPP_PATH):
            os.environ[ASCEND_CUSTOM_OPP_PATH] = custom_opp_path
        else:
            paths = os.environ[ASCEND_CUSTOM_OPP_PATH].split(':')
            if custom_opp_path not in paths:
                os.environ[ASCEND_CUSTOM_OPP_PATH] = custom_opp_path + ':' + os.environ[ASCEND_CUSTOM_OPP_PATH]

    @staticmethod
    def _create_dir(*dir_names):
        """Create each directory (with parents) if it does not already exist."""
        for dir_name in dir_names:
            if not os.path.isdir(dir_name):
                try:
                    os.makedirs(dir_name, exist_ok=True)
                except FileExistsError:
                    # Path was created concurrently (or exists as a file);
                    # preserve the original best-effort behavior. FileExistsError
                    # replaces the former magic-number check `errno == 17`.
                    pass

    @staticmethod
    def _copy_file(src_path, dst_dir):
        """Copy `src_path` into `dst_dir` once, guarded by an exclusive file lock."""
        if not os.path.exists(src_path) or src_path in _CustomInstaller.copied_paths:
            return
        _CustomInstaller.copied_paths.append(src_path)
        if os.path.isfile(src_path):
            lock_file = os.path.join(dst_dir, "file.lock")
            # The lock serializes copies across processes writing the same dir.
            with os.fdopen(os.open(lock_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600), 'w') as f:
                fcntl.flock(f.fileno(), fcntl.LOCK_EX)
                shutil.copy(src_path, dst_dir)

    def check(self):
        """Return True if this reg info still needs to be written to file."""
        if platform.system() != "Linux":
            # fcntl-based locking is only available on Linux.
            return False
        if not os.environ.get("MS_DEV_CUSTOM_OPP_PATH"):
            # only process the first time import the mindspore module
            return False
        if self.op_info.get("target") in ["GPU", "CPU"]:
            # Only Ascend reg info is consumed by GE.
            return False
        sha256 = hashlib.sha256()
        value = json.dumps(self.op_info, sort_keys=True).encode()
        sha256.update(value)
        hash_value = sha256.hexdigest()
        if hash_value in _CustomInstaller.reg_info_hash:
            return False
        _CustomInstaller.reg_info_hash.append(hash_value)
        return True

    def _find_ai_cpu_so_path(self, so_file):
        """Return the absolute path of the aicpu so file, or "" if not found."""
        current_path = os.path.dirname(os.path.abspath(__file__))
        search_paths = [current_path + "/../lib", current_path + "/../lib/plugin/ascend"]
        for path in search_paths:
            so_path = os.path.join(path, so_file)
            if os.path.exists(so_path):
                return so_path
        logger.warning("For Custom op '{}', can not find the aicpu so file '{}' in the following directories:\n{}"
                       .format(self.op_type, so_file, "\n".join(search_paths)))
        return ""

    def _gen_ai_core_reg_info(self, imply_path, func_name):
        """Generate the ai_core reg info dict written to the opp config file."""

        def _get_dtype_format(idx):
            # Collect the dtype/format column `idx` across all registered
            # dtype_format rows. A row with an empty entry disables the whole
            # column (returned as None); once disabled, later rows are ignored
            # (the previous code crashed with None.append in that case).
            data_type = []
            data_format = []
            for dtype_format in self.op_info.get("dtype_format", []):
                dtype, fmt = dtype_format[idx][0], dtype_format[idx][1]
                if not dtype:
                    data_type = None
                elif data_type is not None:
                    data_type.append(dtype)
                if not fmt:
                    data_format = None
                elif data_format is not None:
                    # GE expects "ND" in place of "DefaultFormat".
                    data_format.append("ND" if fmt == "DefaultFormat" else fmt)
            return data_type, data_format

        op_info = {"opFile": {"value": os.path.splitext(os.path.basename(imply_path))[0]},
                   "opInterface": {"value": func_name}}
        # attr
        attrs_name = []
        for item in self.op_info.get("attr", []):
            attr_name = item.get(KEY_NAME)
            attrs_name.append(attr_name)
            # Copy everything except the name itself into "attr_<name>".
            op_info["attr_" + attr_name] = {k: v for k, v in item.items() if k != KEY_NAME}
        if attrs_name:
            op_info["attr"] = {"list": ",".join(attrs_name)}
        # input and output
        inputs = self.op_info.get("inputs", [])
        outputs = self.op_info.get("outputs", [])
        input_num = len(inputs)
        output_num = len(outputs)
        for i in range(input_num + output_num):
            item = inputs[i] if i < input_num else outputs[i - input_num]
            key = "input" if i < input_num else "output"
            key += str(item.get("index"))
            op_info[key] = {KEY_NAME: item.get(KEY_NAME),
                            "paramType": item.get("paramType", "required"),
                            "shape": item.get("shape", "all")}
            dtype, formats = _get_dtype_format(i)
            if dtype:
                op_info[key]["dtype"] = ",".join(dtype)
            if formats:
                op_info[key]["format"] = ",".join(formats)
        return op_info

    @staticmethod
    def _gen_ai_cpu_reg_info(so_file):
        """Generate the aicpu reg info dict pointing at `so_file`."""
        op_info = {"opInfo": {"computeCost": "100",
                              "engine": "DNN_VM_AICPU",
                              "flagAsync": "False",
                              "flagPartial": "False",
                              "functionName": "RunCpuKernel",
                              "kernelSo": so_file,
                              "opKernelLib": "CUSTAICPUKernel",
                              "userDefined": "True"}}
        return op_info

    def _save_op_info(self, dst_dir, file_name, op_info):
        """Merge `op_info` for this op type into the json file under `dst_dir`."""
        repo = {}
        save_path = os.path.join(dst_dir, file_name)
        lock_file = os.path.join(dst_dir, "file.lock")
        with os.fdopen(os.open(lock_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600), 'w') as f:
            fcntl.flock(f.fileno(), fcntl.LOCK_EX)
            if os.path.isfile(save_path):
                # Keep reg info previously written by other ops/processes.
                with open(save_path, 'r') as fr:
                    json_str = fr.read()
                json_str = "{}" if json_str == "" else json_str
                repo = json.loads(json_str)
            repo.update({self.op_type: op_info})
            with os.fdopen(os.open(save_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600), 'w') as fw:
                json.dump(repo, fw, sort_keys=True, indent=4, separators=(',', ':'))

    def run(self):
        """Save reg info to file."""
        if not self.check():
            return
        so_name = _get_reg_info_attr(self.op_info, "cust_aicpu")
        if so_name:
            _CustomInstaller._create_dir(self.ai_cpu_config_dir, self.ai_cpu_impl_dir)
            # copy so file
            so_file = "lib" + so_name + ".so"
            imply_path = self._find_ai_cpu_so_path(so_file)
            self._copy_file(imply_path, self.ai_cpu_impl_dir)
            # generate and copy reg info file
            op_info = self._gen_ai_cpu_reg_info(so_file)
            self._save_op_info(self.ai_cpu_config_dir, "cust_aicpu_kernel.json", op_info)
        else:
            _CustomInstaller._create_dir(self.ai_core_config_dir, self.ai_core_impl_dir)
            # copy dsl file (the previous code issued this copy twice; the
            # duplicate call has been removed)
            imply_path = os.path.realpath(inspect.getfile(self.func))
            self._copy_file(imply_path, self.ai_core_impl_dir)
            # generate and save one reg info file per supported soc version
            op_info = self._gen_ai_core_reg_info(imply_path, self.func.__name__)
            for arc_name in ["ascend910", "ascend910b", "ascend910c", "ascend310p"]:
                arc_dir = os.path.join(self.ai_core_config_dir, arc_name)
                _CustomInstaller._create_dir(arc_dir)
                self._save_op_info(arc_dir, "aic-{}-ops-info.json".format(arc_name), op_info)
242
243
def op_info_register(op_info):
    r"""
    A decorator which is used to register an operator.

    Note:
        'op_info' should represent the operator information by string with json format.
        The 'op_info' will be added into oplib.

    Args:
        op_info (Union[str, dict]): operator information in json format.

    Examples:
        >>> from mindspore.ops import op_info_register, TBERegOp, DataType
        >>> abs_op_info = TBERegOp("Abs") \
        ...    .fusion_type("ELEMWISE") \
        ...    .async_flag(False) \
        ...    .binfile_name("abs.so") \
        ...    .compute_cost(10) \
        ...    .kernel_name("abs") \
        ...    .partial_flag(True) \
        ...    .op_pattern("formatAgnostic") \
        ...    .input(0, "x", None, "required", None) \
        ...    .output(0, "y", True, "required", "all") \
        ...    .dtype_format(DataType.F16_None, DataType.F16_None) \
        ...    .dtype_format(DataType.F32_None, DataType.F32_None) \
        ...    .dtype_format(DataType.I32_None, DataType.I32_None) \
        ...    .get_op_info()
        >>>
        >>> @op_info_register(abs_op_info)
        ... def _abs_tbe():
        ...    return
        ...

    Returns:
        Function, returns a decorator for op info register.

    Raises:
        ValueError: If the reg info fails to be registered into oplib.
    """

    def register_decorator(func):
        if isinstance(op_info, dict):
            op_info_real = json.dumps(op_info)
        else:
            op_info_real = op_info
        validator.check_value_type("op_info", op_info_real, [str])
        op_lib = Oplib()
        file_path = os.path.realpath(inspect.getfile(func))
        # keep the path custom ops implementation.
        if BUILT_IN_CUSTOM_OPS_REGISTER_PATH in file_path:
            imply_path = file_path
        else:
            imply_path = "" if BUILT_IN_OPS_REGISTER_PATH in file_path else file_path
        if not op_lib.reg_op(op_info_real, imply_path):
            raise ValueError('Invalid op info {}:\n{}\n'.format(file_path, op_info_real))

        # functools.wraps preserves the wrapped function's metadata
        # (__name__, __doc__, ...), consistent with custom_info_register.
        @functools.wraps(func)
        def wrapped_function(*args, **kwargs):
            return func(*args, **kwargs)

        return wrapped_function

    return register_decorator
303
304
def custom_info_register(*reg_info):
    r"""
    A decorator which is used to bind the registration information to the `func` parameter of
    :class:`mindspore.ops.Custom`.

    Note:
        The 'reg_info' will be added into oplib.

    Args:
        reg_info (tuple[str, dict]): Each item represents registration information in json format.

    Returns:
        Function, returns a decorator for op info register.

    Raises:
        TypeError: If `reg_info` is not a tuple.

    Examples:
        >>> from mindspore.ops import custom_info_register, CustomRegOp, DataType
        >>> custom_func_ascend_info = CustomRegOp() \
        ...     .input(0, "x", "dynamic") \
        ...     .output(0, "y") \
        ...     .dtype_format(DataType.F16_Default, DataType.F16_Default) \
        ...     .dtype_format(DataType.F32_Default, DataType.F32_Default) \
        ...     .target("Ascend") \
        ...     .get_op_info()
        >>>
        >>> @custom_info_register(custom_func_ascend_info)
        ... def custom_func(x):
        ...     pass
    """

    def decorator(func):
        # Attach the raw reg info tuple so mindspore.ops.Custom can read it
        # back from the function object later.
        func.reg_info = reg_info
        if reg_info and isinstance(reg_info[0], dict):
            first_info = reg_info[0]
            # ai_cpu should be parsed inside CustomRegOp, skip it here
            if not _get_reg_info_attr(first_info, "cust_aicpu"):
                _CustomInstaller(first_info, func).run()

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)

        return wrapper

    return decorator
353
354
class RegOp:
    """
    Base class for op info register.

    Args:
        op_name (str): Name of operator.

    Raises:
        ValueError: If `op_name` is not a string, or is empty/whitespace-only.
    """

    def __init__(self, op_name=""):
        if not isinstance(op_name, str):
            raise ValueError("op name value must be string")
        if not op_name.strip():
            raise ValueError("op name is empty")
        self.op_name = op_name
        self.inputs = []
        self.outputs = []
        # Attributes ending with '_' are exported by get_op_info() with the
        # trailing underscore stripped (backward compatibility).
        self.attr_ = []
        self.fusion_type_ = ''
        self.dtype_format_ = []

    def _is_string(self, value):
        """
        Check if the value is a str type.

        Args:
            value: Parameter to be checked.

        Returns:
            bool, always True when no exception is raised.

        Raises:
            TypeError: If the type of value is not a str.
        """
        if not isinstance(value, str):
            raise TypeError("%s value must be str" % str(value))
        return True

    def _is_int(self, value):
        """
        Check if the value is an int.

        Args:
            value: Parameter to be checked.

        Returns:
            bool, always True when no exception is raised.

        Raises:
            TypeError: If the type of value is not an int.
        """
        if not isinstance(value, int):
            raise TypeError("%s value must be int" % str(value))
        return True

    def _is_bool(self, value):
        """
        Check if the value is a bool.

        Args:
            value: Parameter to be checked.

        Returns:
            bool, always True when no exception is raised.

        Raises:
            TypeError: If the type of value is not a bool.
        """
        if not isinstance(value, bool):
            raise TypeError("%s value must be bool" % str(value))
        return True

    @staticmethod
    def _is_list(value):
        """
        Check if the value is a list.

        Args:
            value: Parameter to be checked.

        Returns:
            bool, always True when no exception is raised.

        Raises:
            TypeError: If the type of value is not a list.
        """
        if not isinstance(value, list):
            raise TypeError("%s value must be list" % str(value))
        return True

    def _check_param(self, param_list, key_list, fn_list, kwargs):
        """
        Check if the parameter type is correct.

        Args:
            param_list (list): Parameter list to be checked.
            key_list (list): The keys of output dict.
            fn_list (list): Function used for parameter checking. If the function list has only one element,
                            all parameters will use the same function.
            kwargs (dict): Other parameter information.

        Returns:
            dict, keys from `key_list` mapped to the non-None checked params,
            merged with `kwargs`.

        Raises:
            TypeError: If the type of value is not list.
            ValueError: If the size of param list is not equal to the size of key list, or
                        the size of param list is not equal to the size of function list.
        """
        for i in [param_list, key_list, fn_list]:
            if not isinstance(i, list):
                raise TypeError("%s value must be list type" % str(i))
        if len(param_list) != len(key_list) or (len(fn_list) != 1 and len(param_list) != len(fn_list)):
            raise ValueError("param_list size {}, key_list size {}, must be equal.And fn_list size {}.".
                             format(len(param_list), len(key_list), len(fn_list)))
        out_dict = {}
        for idx, element in enumerate(param_list):
            if element is not None:
                # A single-element fn_list validates every parameter.
                if len(fn_list) == 1:
                    fn_list[0](element)
                else:
                    fn_list[idx](element)
                out_dict[key_list[idx]] = element
        if kwargs:
            out_dict = dict(out_dict, **kwargs)
        return out_dict

    def fusion_type(self, fusion_type):
        """
        Fusion type of the operator.

        Args:
            fusion_type (str): Value of fusion type.

        Returns:
            RegOp, the instance itself, for chained registration calls.
        """
        self._is_string(fusion_type)
        self.fusion_type_ = fusion_type
        return self

    def dtype_format(self, *args):
        """
        A dtype and format supported by the operator.

        Args:
            args (tuple): Value of dtype and format.

        Raises:
            ValueError: If the size of args is not equal to input size plus output size,
                        or an element is not a tuple of two or three elements.
            TypeError: If an element of a dtype/format tuple is not a str.
        """
        if len(self.inputs) + len(self.outputs) != len(args):
            raise ValueError("input size add output size must be equal to dtype format size")
        dtype_format = []
        for arg in args:
            if not isinstance(arg, tuple) or (len(arg) != 2 and len(arg) != 3):
                raise ValueError("dtype and format value must be tuple of two or three elements")
            self._is_string(arg[0])
            self._is_string(arg[1])
            if len(arg) == 3:
                if self._is_string(arg[2]):
                    dtype_format.append(arg)
            else:
                dtype_format.append(arg)
        self.dtype_format_.append(tuple(dtype_format))
        return self

    def get_op_info(self):
        """
        Return all registration information for this instance.

        The '_' character ending the key is removed here for compatibility with previous version.

        Key will be unified into an underlined form later.
        """
        # Loop-invariant translation table, built once instead of per key.
        key_dic = {"dynamic_shape_support": "dynamicShapeSupport",
                   "dynamic_rank_support": "dynamicRankSupport",
                   "dynamic_compile_static": "dynamicCompileStatic",
                   "need_check_support": "needCheckSupport",
                   "dynamic_format": "dynamicFormat"
                   }
        op_info = {}
        for key, value in self.__dict__.items():
            if isinstance(key, str) and key.endswith('_'):
                key = key.rstrip('_')
                key = key_dic.get(key, key)
            op_info[key] = value
        return op_info
525
526
class CpuRegOp(RegOp):
    """Class for Cpu op info register"""

    def __init__(self, op_name):
        super(CpuRegOp, self).__init__(op_name)
        self.imply_type = "CPU"

    def _checked_io(self, index, name, param_type, kwargs):
        """Validate and assemble one input/output description dict."""
        return self._check_param([index, name, param_type],
                                 ["index", "name", "paramType"],
                                 [self._is_int, self._is_string, self._is_string],
                                 kwargs)

    def input(self, index=None, name=None, param_type=None, **kwargs):
        """
        Register Cpu op input information.

        Args:
            index (int): Order of the input. Default: ``None`` .
            name (str): Name of the input. Default: ``None`` .
            param_type (str): Param type of the input. Default: ``None`` .
            kwargs (dict): Other information of the input.
        """
        self.inputs.append(self._checked_io(index, name, param_type, kwargs))
        return self

    def output(self, index=None, name=None, param_type=None, **kwargs):
        """
        Register AiCPU op output information.

        Args:
            index (int): Order of the output. Default: ``None`` .
            name (str): Name of the output. Default: ``None`` .
            param_type (str): Param type of the output. Default: ``None`` .
            kwargs (dict): Other information of the output.
        """
        self.outputs.append(self._checked_io(index, name, param_type, kwargs))
        return self

    def attr(self, name=None, value_type=None, value=None, **kwargs):
        """
        Register AiCPU op attribute information.

        Args:
            name (str): Name of the attribute. Default: ``None`` .
            value_type (str): Value type of the attribute. Default: ``None`` .
            value (str): Value of the attribute. Default: ``None`` .
            kwargs (dict): Other information of the attribute.
        """
        checked = self._check_param([name, value_type, value],
                                    ["name", "type", "value"],
                                    [self._is_string], kwargs)
        self.attr_.append(checked)
        return self
584
585
class AkgRegOp(RegOp):
    """Class for Akg op info register."""

    def __init__(self, op_name, processor):
        super(AkgRegOp, self).__init__(op_name)
        self.imply_type = "AKG"
        # Target processor for the Akg backend, e.g. "CUDA", "AiCore" or "CPU".
        self.processor = processor

    def input(self, index=None, name=None, param_type=None, **kwargs):
        """
        Register Akg op input information.

        Args:
            index (int): Order of the input. Default: ``None`` .
            name (str): Name of the input. Default: ``None`` .
            param_type (str): Param type of the input. Default: ``None`` .
            kwargs (dict): Other information of the input.
        """
        checked = self._check_param([index, name, param_type],
                                    ["index", "name", "paramType"],
                                    [self._is_int, self._is_string, self._is_string],
                                    kwargs)
        self.inputs.append(checked)
        return self

    def output(self, index=None, name=None, **kwargs):
        """
        Register Akg op output information.

        Args:
            index (int): Order of the output. Default: ``None`` .
            name (str): Name of the output. Default: ``None`` .
            kwargs (dict): Other information of the output.
        """
        checked = self._check_param([index, name],
                                    ["index", "name"],
                                    [self._is_int, self._is_string],
                                    kwargs)
        self.outputs.append(checked)
        return self

    def attr(self, name=None, param_type=None, value_type=None, **kwargs):
        """
        Register Akg op attribute information.

        Args:
            name (str): Name of the attribute. Default: ``None`` .
            param_type (str): Param type of the attribute. Default: ``None`` .
            value_type (str): Value type of the attribute. Default: ``None`` .
            kwargs (dict): Other information of the attribute.
        """
        checked = self._check_param([name, param_type, value_type],
                                    ["name", "paramType", "type"],
                                    [self._is_string], kwargs)
        self.attr_.append(checked)
        return self
643
644
class AkgGpuRegOp(AkgRegOp):
    """Class for AkgGpu op info register"""

    def __init__(self, op_name):
        # Akg ops targeting GPU use the "CUDA" processor.
        super().__init__(op_name, "CUDA")
650
651
class AkgAscendRegOp(AkgRegOp):
    """Class for AkgAscend op info register"""

    def __init__(self, op_name):
        # Akg ops targeting Ascend use the "AiCore" processor.
        super().__init__(op_name, "AiCore")
657
658
class AkgCpuRegOp(AkgRegOp):
    """Class for AkgCpu op info register"""

    def __init__(self, op_name):
        # Akg ops targeting host CPU use the "CPU" processor.
        super().__init__(op_name, "CPU")
664
665
class AiCPURegOp(CpuRegOp):
    r"""
    Class for AiCPU operator information registration.

    Args:
        op_name (str): Name of operator.

    Examples:
        >>> from mindspore.ops import AiCPURegOp, DataType
        >>> stack_op_info = AiCPURegOp("Stack") \
        ...    .fusion_type("OPAQUE") \
        ...    .attr("axis", "int") \
        ...    .input(0, "x", "dynamic") \
        ...    .output(0, "y", "required") \
        ...    .dtype_format(DataType.I8_Default, DataType.I8_Default) \
        ...    .dtype_format(DataType.I16_Default, DataType.I16_Default) \
        ...    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
        ...    .dtype_format(DataType.I64_Default, DataType.I64_Default) \
        ...    .dtype_format(DataType.U8_Default, DataType.U8_Default) \
        ...    .dtype_format(DataType.U16_Default, DataType.U16_Default) \
        ...    .dtype_format(DataType.U32_Default, DataType.U32_Default) \
        ...    .dtype_format(DataType.U64_Default, DataType.U64_Default) \
        ...    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
        ...    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
        ...    .dtype_format(DataType.F64_Default, DataType.F64_Default) \
        ...    .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
        ...    .get_op_info()
        >>>
    """

    def __init__(self, op_name):
        # Reuse CpuRegOp's input/output/attr handling, but tag the op for the
        # AiCPU implementation backend.
        super().__init__(op_name)
        self.imply_type = "AiCPU"
699
700
701class TBERegOp(RegOp):
702    r"""
703    Class for TBE operator information registration. TBE (Tensor Boost Engine) is the Ascend operator development
704    tool, which is extended on the basis of the TVM framework to develop custom operators.
705
706    Args:
707        op_name (str): Name of operator.
708
709    Examples:
710        >>> from mindspore.ops import TBERegOp, DataType
711        >>> op_name_op_info = TBERegOp("OpName") \
712        ...    .fusion_type("ELEMWISE") \
713        ...    .async_flag(False) \
714        ...    .binfile_name("op_name.so") \
715        ...    .compute_cost(10) \
716        ...    .kernel_name("op_name") \
717        ...    .partial_flag(True) \
718        ...    .op_pattern("formatAgnostic") \
719        ...    .need_check_supported(True) \
720        ...    .dynamic_shape(True) \
721        ...    .dynamic_rank_support(True) \
722        ...    .dynamic_compile_static(True) \
723        ...    .attr("format", "required", "str", "all") \
724        ...    .input(0, "x1", None, "required", None) \
725        ...    .input(0, "x2", None, "required", None) \
726        ...    .input(1, "axis", None, "required", None) \
727        ...    .output(0, "y", True, "required", "all") \
728        ...    .real_input_index([1, 0]) \
729        ...    .input_to_attr_index([2]) \
730        ...    .unknown_shape_formats(["ND", "ND", "ND", "ND"]) \
731        ...    .reshape_type("NC") \
732        ...    .is_dynamic_format(True) \
733        ...    .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None, DataType.F16_None) \
734        ...    .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None, DataType.F32_None) \
735        ...    .dtype_format(DataType.I32_None, DataType.I32_None, DataType.I32_None, DataType.I32_None) \
736        ...    .get_op_info()
737        >>>
738    """
739
    def __init__(self, op_name):
        """Initialize TBE registration fields with their defaults.

        Args:
            op_name (str): Name of operator.
        """
        super(TBERegOp, self).__init__(op_name)
        self.imply_type = "TBE"
        # Fields ending with '_' are exported by RegOp.get_op_info() with the
        # trailing underscore stripped (some are also renamed to camelCase).
        self.async_flag_ = False
        self.binfile_ = ''
        self.compute_cost_ = 10
        self.kernel_ = ''
        self.partial_flag_ = False
        self.reshape_type_ = ''
        self.dynamic_rank_support_ = False
        self.dynamic_shape_support_ = False
        self.dynamic_compile_static_ = False
        self.need_check_support_ = False
        self.dynamic_format_ = False
        self.op_pattern_ = ""
        self.real_input_index_ = []
        self.input_to_attr_index_ = []
        self.unknown_shape_formats_ = []
758
759    def unknown_shape_formats(self, unknown_shape_formats):
760        """
761        Description data arrangement of operator input / output tensor in dynamic shape scene.
762
763        Args:
764            unknown_shape_formats (list): Description data arrangement of operator input / output tensor in dynamic
765                                          shape scene.
766        """
767        RegOp._is_list(unknown_shape_formats)
768        self.unknown_shape_formats_.append(unknown_shape_formats)
769        return self
770
771    def dynamic_rank_support(self, dynamic_rank_support):
772        """
773        Description whether the operator supports dynamic rank (dynamic dimension).
774
775        Args:
776            dynamic_rank_support (bool): Description whether the operator supports dynamic rank (dynamic dimension).
777                                         True: indicates that dynamic rank is supported, and the operator supports
778                                         shape (- 2), which is used to determine whether dynamic is performed.
779                                         False: indicates that the operator does not support dynamic rank.
780                                         Default: ``False`` .
781        """
782        if self._is_bool(dynamic_rank_support):
783            self.dynamic_rank_support_ = dynamic_rank_support
784        return self
785
786    def real_input_index(self, real_input_index):
787        """
788        Description operator front end and tbe operator input mapping.
789
790        Args:
791            real_input_index (list): Value of real_input_index. Default: ``()`` .
792        """
793        RegOp._is_list(real_input_index)
794        self.real_input_index_ = real_input_index
795        return self
796
797    def input_to_attr_index(self, input_to_attr_index):
798        """
799        Description the index of input need to cast to attr.
800
801        Args:
802            input_to_attr_index (list): Value of input_to_attr_index. Default: ``()`` .
803        """
804        RegOp._is_list(input_to_attr_index)
805        self.input_to_attr_index_ = input_to_attr_index
806        return self
807
808    def async_flag(self, async_flag=False):
809        """
810        Define the calculation efficiency of the operator, whether the asynchronous calculation is supported.
811
812        Args:
813            async_flag (bool): Value of async flag. Default: ``False`` .
814        """
815        self._is_bool(async_flag)
816        self.async_flag_ = async_flag
817        return self
818
819    def binfile_name(self, binfile_name):
820        """
821        Set the binary file name of the operator, it is optional.
822
823        Args:
824            binfile_name (str): The binary file name of the operator.
825        """
826        self._is_string(binfile_name)
827        self.binfile_ = binfile_name
828        return self
829
830    def compute_cost(self, compute_cost=10):
831        """
832        Define the calculation efficiency of operator, which refers to the value of the cost model
833        in the tiling module.
834
835        Args:
836            compute_cost (int): Value of compute cost. Default: ``10`` .
837        """
838        self._is_int(compute_cost)
839        self.compute_cost_ = compute_cost
840        return self
841
842    def kernel_name(self, kernel_name):
843        """
844        The name of operator kernel.
845
846        Args:
847            kernel_name (str): Name of operator kernel.
848        """
849        self._is_string(kernel_name)
850        self.kernel_ = kernel_name
851        return self
852
853    def partial_flag(self, partial_flag=True):
854        """
855        Define the calculation efficiency of operator, whether the partial calculation is supported.
856
857        Args:
858            partial_flag (bool): Value of partial flag. Default: ``True`` .
859        """
860        self._is_bool(partial_flag)
861        self.partial_flag_ = partial_flag
862        return self
863
864    def reshape_type(self, reshape_type):
865        """
866        Reshape type of operator.
867
868        Args:
869            reshape_type (str): Value of reshape type. For example, if the input shape is :math:`(2, 3)`
870                and `reshape_type` is set to "CH", then the new shape is :math:`(1, 2, 3, 1)`.
871                "CH" means the C and H dimensions are kept and
872                new dimensions are added for N and W dimension.
873        """
874        self._is_string(reshape_type)
875        self.reshape_type_ = reshape_type
876        return self
877
878    def dynamic_shape(self, dynamic_shape=False):
879        """
880        Whether the operator supports dynamic shape.
881
882        Args:
883            dynamic_shape (bool): Value of dynamic shape. Default: ``False`` .
884        """
885        self._is_bool(dynamic_shape)
886        self.dynamic_shape_support_ = dynamic_shape
887        return self
888
889    def dynamic_compile_static(self, dynamic_compile_static=False):
890        """
891        Whether the operator supports dynamic compile static.
892
893        Args:
894            dynamic_compile_static (bool): Value of dynamic compile static. Default: ``False`` .
895        """
896        if self._is_bool(dynamic_compile_static):
897            self.dynamic_compile_static_ = dynamic_compile_static
898        return self
899
900    def need_check_supported(self, need_check_supported=False):
901        """
902        Whether the operator needs check supports.
903
904        Args:
905            need_check_supported (bool): Value of need_check_supported. Default: ``False`` .
906        """
907        if self._is_bool(need_check_supported):
908            self.need_check_support_ = need_check_supported
909        return self
910
911    def is_dynamic_format(self, is_dynamic_format=False):
912        """
913        Whether the operator needs calop_select_format api.
914
915        Args:
916            is_dynamic_format (bool): Value of is_dynamic_format. Default: ``False`` .
917        """
918        if self._is_bool(is_dynamic_format):
919            self.dynamic_format_ = is_dynamic_format
920        return self
921
922    def op_pattern(self, pattern=None):
923        """
924        The behavior type of operator, such as broadcast, reduce and so on.
925
926        Args:
927            pattern (str): Value of op pattern, e.g. "broadcast", "reduce". Default: ``None`` .
928        """
929        if pattern is not None and self._is_string(pattern):
930            self.op_pattern_ = pattern
931        return self
932
933    def attr(self, name=None, param_type=None, value_type=None, value=None, default_value=None, **kwargs):
934        """
935        Register TBE op attribute information.
936
937        Args:
938            name (str): Name of the attribute. Default: ``None`` .
939            param_type (str): Param type of the attribute. Default: ``None`` .
940            value_type (str): Type of the attribute. Default: ``None`` .
941            value (str): Value of the attribute. Default: ``None`` .
942            default_value (str): Default value of attribute. Default: ``None`` .
943            kwargs (dict): Other information of the attribute.
944        """
945        param_list = [name, param_type, value_type, value, default_value]
946        key_list = ["name", "paramType", "type", "value", "defaultValue"]
947        fn_list = [self._is_string]
948        attr_dict = self._check_param(param_list, key_list, fn_list, kwargs)
949        self.attr_.append(attr_dict)
950        return self
951
952    def input(self, index=None, name=None, need_compile=None, param_type=None, shape=None, value_depend=None, **kwargs):
953        """
954        Register TBE op input information.
955
956        Args:
957            index (int): Order of the input. Default: ``None`` .
958            name (str): Name of the input. Default: ``None`` .
959            need_compile (bool): Whether the input needs to be compiled or not. Default: ``None`` .
960            param_type (str): Type of the input. Default: ``None`` .
961            shape (str): Shape of the input. Default: ``None`` .
962            value_depend (str): Whether the input is constant value depend. Default: ``None`` .
963            kwargs (dict): Other information of the input.
964        """
965        param_list = [index, name, need_compile, param_type, shape, value_depend]
966        key_list = ["index", "name", "needCompile", "paramType", "shape", "valueDepend"]
967        fn_list = [self._is_int, self._is_string, self._is_bool, self._is_string, self._is_string, self._is_string]
968        input_dict = self._check_param(param_list, key_list, fn_list, kwargs)
969        value_depend_values = ("ignored", "optional", "required")
970        if value_depend and value_depend.lower() not in value_depend_values:
971            raise ValueError("Operator {} input{}'s value_depend's value ({}) is not in {}.".
972                             format(self.op_name, index, value_depend, value_depend_values))
973        self.inputs.append(input_dict)
974        return self
975
976    def output(self, index=None, name=None, need_compile=None, param_type=None, shape=None, **kwargs):
977        """
978        Register TBE op output information.
979
980        Args:
981            index (int): Order of the output. Default: ``None`` .
982            name (str): Name of the output. Default: ``None`` .
983            need_compile (bool): Whether the output needs to be compiled or not. Default: ``None`` .
984            param_type (str): Type of the output. Default: ``None`` .
985            shape (str): Shape of the output. Default: ``None`` .
986            kwargs (dict): Other information of the output.
987        """
988        param_list = [index, name, need_compile, param_type, shape]
989        key_list = ["index", "name", "need_compile", "paramType", "shape"]
990        fn_list = [self._is_int, self._is_string, self._is_bool, self._is_string, self._is_string]
991        output_dict = self._check_param(param_list, key_list, fn_list, kwargs)
992        self.outputs.append(output_dict)
993        return self
994
995
class CustomRegOp(RegOp):
    r"""
    Builder for the registration information of the `func` parameter of :class:`mindspore.ops.Custom`.
    The registration information mainly specifies the supported data types and formats of the input
    and output tensors, the attributes and the target of `func`.

    Args:
        op_name (str): kernel name. The name will be record in the reg_op_name attr of the kernel node.
            Besides, the operator will generate a unique name automatically to identify the reg info.
            Default: ``"Custom"`` .

    Examples:
        >>> from mindspore.ops import CustomRegOp, DataType
        >>> custom_op_ascend_info = CustomRegOp() \
        ...     .input(0, "x", "dynamic") \
        ...     .output(0, "y") \
        ...     .dtype_format(DataType.F16_Default, DataType.F16_Default) \
        ...     .dtype_format(DataType.F32_Default, DataType.F32_Default) \
        ...     .target("Ascend") \
        ...     .get_op_info()
    """

    def __init__(self, op_name="Custom"):
        super().__init__(op_name)
        self.target_ = "UnKnown"

    def input(self, index=None, name=None, param_type="required", **kwargs):
        """
        Describe one input tensor of `func`. Invoke this method once per input tensor, in order: if
        `func` has two input tensors, call it twice consecutively. The information is collected as a
        dict {"index": `index`, "name": `name`, "param_type": `param_type`}; any key whose value is
        ``None`` is omitted from the dict.

        Args:
            index (int): Index of the input, starting from 0 (0 is the first input tensor, 1 the
                second, and so on). Default: ``None`` .
            name (str): Name of the `index` 'th input. Default: ``None`` .
            param_type (str): Parameter type of the `index` 'th input, one of
                [``"required"`` , ``"dynamic"`` , ``"optional"`` ]. Default: ``"required"`` .

                - ``"required"``: the input exists and is exactly one tensor.
                - ``"dynamic"``: the input exists and may be multiple tensors, such as the input of
                  AddN.
                - ``"optional"``: the input is a single tensor or may be absent.

            kwargs (dict): Other information of the input, used for extension.

        Raises:
            TypeError: If `index` is neither int nor None.
            TypeError: If `name` is neither str nor None.
            TypeError: If `param_type` is neither str nor None.
        """
        values = [index, name, param_type]
        keys = ["index", "name", "paramType"]
        checkers = [self._is_int, self._is_string, self._is_string]
        self.inputs.append(self._check_param(values, keys, checkers, kwargs))
        return self

    def output(self, index=None, name=None, param_type="required", **kwargs):
        """
        Describe one output tensor of `func`. Invoke this method once per output tensor, in order: if
        `func` has two output tensors, call it twice consecutively. The information is collected as a
        dict {"index": `index`, "name": `name`, "param_type": `param_type`}; any key whose value is
        ``None`` is omitted from the dict.

        Args:
            index (int): Index of the output, starting from 0 (0 is the first output tensor, 1 the
                second, and so on). Default: ``None`` .
            name (str): Name of the `index` 'th output. Default: ``None`` .
            param_type (str): Parameter type of the `index` 'th output, one of
                [``"required"`` , ``"dynamic"`` , ``"optional"`` ]. Default: ``"required"`` .

                - ``"required"``: the output exists and is exactly one tensor.
                - ``"dynamic"``: the output exists and may be multiple tensors.
                - ``"optional"``: the output is a single tensor or may be absent.

            kwargs (dict): Other information of the output, used for extension.

        Raises:
            TypeError: If `index` is neither int nor None.
            TypeError: If `name` is neither str nor None.
            TypeError: If `param_type` is neither str nor None.
        """
        values = [index, name, param_type]
        keys = ["index", "name", "paramType"]
        checkers = [self._is_int, self._is_string, self._is_string]
        self.outputs.append(self._check_param(values, keys, checkers, kwargs))
        return self

    def dtype_format(self, *args):
        """
        Describe one supported combination of data type and format for every input and output tensor
        of `func`. Invoke this method after `input` and `output`, as shown in the class example.

        Args:
            args (tuple): One (data type, format) pair per tensor; the length of `args` must equal
                the number of registered input tensors plus output tensors. Each item is itself a
                tuple of two str, the data type and format of a tensor.
                :class:`mindspore.ops.DataType` provides many predefined pairs, for example,
                `DataType.F16_Default` means data type float16 with the default format.

        Raises:
            ValueError: If the size of `args` does not equal the sum of input tensors and output
                tensors.
        """
        expected = len(self.inputs) + len(self.outputs)
        if len(args) != expected:
            raise ValueError("The size of 'args' must be equal to the sum of input tensors and output tensors, but got "
                             "{} vs {}".format(len(args), expected))
        return super().dtype_format(*args)

    def attr(self, name=None, param_type=None, value_type=None, default_value=None, **kwargs):
        """
        Describe one attribute of `func`. Invoke this method once per attribute: if `func` has two
        attributes, call it twice consecutively. The information is collected as a dict
        {"name": `name`, "param_type": `param_type`, "value_type": `value_type`,
        "default_value": `default_value`}; any key whose value is ``None`` is omitted from the dict.

        Args:
            name (str): Name of the attribute. Default: ``None`` .
            param_type (str): Parameter type of the attribute, one of ["required", "optional"].
                Default: ``None`` .

                - "required": a value must be provided, either as a default value in the
                  registration information or as an input when calling the Custom operator.
                - "optional": providing a value is not mandatory.

            value_type (str): Value type of the attribute, one of ["int", "str", "bool", "float",
                "listInt", "listStr", "listBool", "listFloat"] — the string representation of the
                corresponding Python type (scalar or list of that scalar). Default: ``None`` .
            default_value (str): Default value of the attribute, expressed as a string and used
                together with `value_type`. For a float attribute defaulting to 1.0, `value_type`
                is "float" and `default_value` is "1.0". For a list-of-int attribute defaulting to
                [1, 2, 3], `value_type` is "listInt" and `default_value` is "1,2,3" (items joined
                by ','). ``None`` means the attribute has no default value. It is used for "akg",
                "aicpu" and "tbe" Custom operators currently. Default: ``None`` .
            kwargs (dict): Other information of the attribute, used for extension.

        Raises:
            TypeError: If `name` is neither str nor None.
            TypeError: If `param_type` is neither str nor None.
            TypeError: If `value_type` is neither str nor None.
            TypeError: If `default_value` is neither str nor None.
        """
        values = [name, param_type, value_type, default_value]
        keys = ["name", "paramType", "type", "defaultValue"]
        checkers = [self._is_string]
        self.attr_.append(self._check_param(values, keys, checkers, kwargs))
        return self

    def target(self, target=None):
        """
        Specify the device target this registration information applies to.

        Args:
            target (str): Device target for current operator information, one of
                ["Ascend", "GPU", "CPU"]. The same `func` of :class:`mindspore.ops.Custom` may
                support different data types and formats on different targets; use `target` to tell
                them apart. ``None`` means the target will be inferred automatically inside
                :class:`mindspore.ops.Custom`. Default: ``None`` .

        Raises:
            TypeError: If `target` is neither str nor None.
        """
        if target is not None:
            self._is_string(target)
        self.target_ = target
        return self

    def get_op_info(self):
        """
        Return the collected registration information as a dict. Invoke this method last on the
        `CustomRegOp` instance, as shown in the class example.
        """
        # Strip the trailing-underscore naming convention from internal attribute names;
        # rstrip is a no-op for names without trailing underscores.
        op_info = {key.rstrip('_') if isinstance(key, str) else key: value
                   for key, value in self.__dict__.items()}
        if _get_reg_info_attr(op_info, "cust_aicpu"):
            _CustomInstaller(op_info).run()
        return op_info
1200
1201
1202class DataType:
1203    r"""
1204    Various combinations of dtype and format of Ascend ops.
1205
1206    current support:
1207
1208    .. code-block::
1209
1210        None_None = ("", "")
1211        None_Default = ("", "DefaultFormat")
1212        BOOL_None = ("bool", "")
1213        BOOL_Default = ("bool", "DefaultFormat")
1214        BOOL_5HD = ("bool", "NC1HWC0")
1215        BOOL_FracZ = ("bool", "FRACTAL_Z")
1216        BOOL_FracNZ = ("bool", "FRACTAL_NZ")
1217        BOOL_C1HWNCoC0 = ("bool", "C1HWNCoC0")
1218        BOOL_NCHW = ("bool", "NCHW")
1219        BOOL_NHWC = ("bool", "NHWC")
1220        BOOL_HWCN = ("bool", "HWCN")
1221        BOOL_NDHWC = ("bool", "NDHWC")
1222        BOOL_ChannelLast = ("bool", "ChannelLast")
1223
1224        I8_None = ("int8", "")
1225        I8_Default = ("int8", "DefaultFormat")
1226        I8_5HD = ("int8", "NC1HWC0")
1227        I8_FracZ = ("int8", "FRACTAL_Z")
1228        I8_FracNZ = ("int8", "FRACTAL_NZ")
1229        I8_C1HWNCoC0 = ("int8", "C1HWNCoC0")
1230        I8_NCHW = ("int8", "NCHW")
1231        I8_NHWC = ("int8", "NHWC")
1232        I8_HWCN = ("int8", "HWCN")
1233        I8_NDHWC = ("int8", "NDHWC")
1234        I8_ChannelLast = ("int8", "ChannelLast")
1235        I8_NDC1HWC0 = ("int8", "NDC1HWC0")
1236
1237        U8_None = ("uint8", "")
1238        U8_Default = ("uint8", "DefaultFormat")
1239        U8_5HD = ("uint8", "NC1HWC0")
1240        U8_FracZ = ("uint8", "FRACTAL_Z")
1241        U8_FracNZ = ("uint8", "FRACTAL_NZ")
1242        U8_C1HWNCoC0 = ("uint8", "C1HWNCoC0")
1243        U8_NCHW = ("uint8", "NCHW")
1244        U8_NHWC = ("uint8", "NHWC")
1245        U8_HWCN = ("uint8", "HWCN")
1246        U8_NDHWC = ("uint8", "NDHWC")
1247        U8_ChannelLast = ("uint8", "ChannelLast")
1248        U8_NDC1HWC0 = ("uint8", "NDC1HWC0")
1249
1250        I16_None = ("int16", "")
1251        I16_Default = ("int16", "DefaultFormat")
1252        I16_5HD = ("int16", "NC1HWC0")
1253        I16_FracZ = ("int16", "FRACTAL_Z")
1254        I16_FracNZ = ("int16", "FRACTAL_NZ")
1255        I16_C1HWNCoC0 = ("int16", "C1HWNCoC0")
1256        I16_NCHW = ("int16", "NCHW")
1257        I16_NHWC = ("int16", "NHWC")
1258        I16_HWCN = ("int16", "HWCN")
1259        I16_NDHWC = ("int16", "NDHWC")
1260        I16_ChannelLast = ("int16", "ChannelLast")
1261
1262        U16_None = ("uint16", "")
1263        U16_Default = ("uint16", "DefaultFormat")
1264        U16_5HD = ("uint16", "NC1HWC0")
1265        U16_FracZ = ("uint16", "FRACTAL_Z")
1266        U16_FracNZ = ("uint16", "FRACTAL_NZ")
1267        U16_C1HWNCoC0 = ("uint16", "C1HWNCoC0")
1268        U16_NCHW = ("uint16", "NCHW")
1269        U16_NHWC = ("uint16", "NHWC")
1270        U16_HWCN = ("uint16", "HWCN")
1271        U16_NDHWC = ("uint16", "NDHWC")
1272        U16_ChannelLast = ("uint16", "ChannelLast")
1273
1274        I32_None = ("int32", "")
1275        I32_Default = ("int32", "DefaultFormat")
1276        I32_5HD = ("int32", "NC1HWC0")
1277        I32_FracZ = ("int32", "FRACTAL_Z")
1278        I32_FracNZ = ("int32", "FRACTAL_NZ")
1279        I32_C1HWNCoC0 = ("int32", "C1HWNCoC0")
1280        I32_NCHW = ("int32", "NCHW")
1281        I32_NHWC = ("int32", "NHWC")
1282        I32_HWCN = ("int32", "HWCN")
1283        I32_NDHWC = ("int32", "NDHWC")
1284        I32_ChannelLast = ("int32", "ChannelLast")
1285
1286        U32_None = ("uint32", "")
1287        U32_Default = ("uint32", "DefaultFormat")
1288        U32_5HD = ("uint32", "NC1HWC0")
1289        U32_FracZ = ("uint32", "FRACTAL_Z")
1290        U32_FracNZ = ("uint32", "FRACTAL_NZ")
1291        U32_C1HWNCoC0 = ("uint32", "C1HWNCoC0")
1292        U32_NCHW = ("uint32", "NCHW")
1293        U32_NHWC = ("uint32", "NHWC")
1294        U32_HWCN = ("uint32", "HWCN")
1295        U32_NDHWC = ("uint32", "NDHWC")
1296        U32_ChannelLast = ("uint32", "ChannelLast")
1297
1298        I64_None = ("int64", "")
1299        I64_Default = ("int64", "DefaultFormat")
1300        I64_5HD = ("int64", "NC1HWC0")
1301        I64_FracZ = ("int64", "FRACTAL_Z")
1302        I64_FracNZ = ("int64", "FRACTAL_NZ")
1303        I64_C1HWNCoC0 = ("int64", "C1HWNCoC0")
1304        I64_NCHW = ("int64", "NCHW")
1305        I64_NHWC = ("int64", "NHWC")
1306        I64_HWCN = ("int64", "HWCN")
1307        I64_NDHWC = ("int64", "NDHWC")
1308        I64_ChannelLast = ("int64", "ChannelLast")
1309
1310        U64_None = ("uint64", "")
1311        U64_Default = ("uint64", "DefaultFormat")
1312        U64_5HD = ("uint64", "NC1HWC0")
1313        U64_FracZ = ("uint64", "FRACTAL_Z")
1314        U64_FracNZ = ("uint64", "FRACTAL_NZ")
1315        U64_C1HWNCoC0 = ("uint64", "C1HWNCoC0")
1316        U64_NCHW = ("uint64", "NCHW")
1317        U64_NHWC = ("uint64", "NHWC")
1318        U64_HWCN = ("uint64", "HWCN")
1319        U64_NDHWC = ("uint64", "NDHWC")
1320        U64_ChannelLast = ("uint64", "ChannelLast")
1321
1322        F16_None = ("float16", "")
1323        F16_Default = ("float16", "DefaultFormat")
1324        F16_5HD = ("float16", "NC1HWC0")
1325        F16_FracZ = ("float16", "FRACTAL_Z")
1326        F16_FracNZ = ("float16", "FRACTAL_NZ")
1327        F16_C1HWNCoC0 = ("float16", "C1HWNCoC0")
1328        F16_NCHW = ("float16", "NCHW")
1329        F16_NHWC = ("float16", "NHWC")
1330        F16_HWCN = ("float16", "HWCN")
1331        F16_NDHWC = ("float16", "NDHWC")
1332        F16_NCDHW = ("float16", "NCDHW")
1333        F16_DHWCN = ("float16", "DHWCN")
1334        F16_NDC1HWC0 = ("float16", "NDC1HWC0")
1335        F16_FRACTAL_Z_3D = ("float16", "FRACTAL_Z_3D")
1336        F16_FracZNLSTM = ("float16", "FRACTAL_ZN_LSTM")
1337        F16_FracZNRNN = ("float16", "FRACTAL_ZN_RNN")
1338        F16_ND_RNNBIAS = ("float16", "ND_RNN_BIAS")
1339        F16_ChannelLast = ("float16", "ChannelLast")
1340
1341        F32_None = ("float32", "")
1342        F32_Default = ("float32", "DefaultFormat")
1343        F32_5HD = ("float32", "NC1HWC0")
1344        F32_FracZ = ("float32", "FRACTAL_Z")
1345        F32_FracNZ = ("float32", "FRACTAL_NZ")
1346        F32_C1HWNCoC0 = ("float32", "C1HWNCoC0")
1347        F32_NCHW = ("float32", "NCHW")
1348        F32_NHWC = ("float32", "NHWC")
1349        F32_HWCN = ("float32", "HWCN")
1350        F32_NDHWC = ("float32", "NDHWC")
1351        F32_NCDHW = ("float32", "NCDHW")
1352        F32_DHWCN = ("float32", "DHWCN")
1353        F32_NDC1HWC0 = ("float32", "NDC1HWC0")
1354        F32_FRACTAL_Z_3D = ("float32", "FRACTAL_Z_3D")
1355        F32_FracZNLSTM = ("float32", "FRACTAL_ZN_LSTM")
1356        F32_FracZNRNN = ("float32", "FRACTAL_ZN_RNN")
1357        F32_ND_RNNBIAS = ("float32", "ND_RNN_BIAS")
1358        F32_ChannelLast = ("float32", "ChannelLast")
1359
1360        F64_None = ("float64", "")
1361        F64_Default = ("float64", "DefaultFormat")
1362        F64_5HD = ("float64", "NC1HWC0")
1363        F64_FracZ = ("float64", "FRACTAL_Z")
1364        F64_FracNZ = ("float64", "FRACTAL_NZ")
1365        F64_C1HWNCoC0 = ("float64", "C1HWNCoC0")
1366        F64_NCHW = ("float64", "NCHW")
1367        F64_NHWC = ("float64", "NHWC")
1368        F64_HWCN = ("float64", "HWCN")
1369        F64_NDHWC = ("float64", "NDHWC")
1370        F64_ChannelLast = ("float64", "ChannelLast")
1371
1372        C64_Default = ("complex64", "DefaultFormat")
1373        C128_Default = ("complex128", "DefaultFormat")
1374    """
1375
1376    None_None = ("", "")
1377    None_Default = ("", "DefaultFormat")
1378
1379    BOOL_None = ("bool", "")
1380    BOOL_Default = ("bool", "DefaultFormat")
1381    BOOL_5HD = ("bool", "NC1HWC0")
1382    BOOL_FracZ = ("bool", "FRACTAL_Z")
1383    BOOL_FracNZ = ("bool", "FRACTAL_NZ")
1384    BOOL_C1HWNCoC0 = ("bool", "C1HWNCoC0")
1385    BOOL_NCHW = ("bool", "NCHW")
1386    BOOL_NHWC = ("bool", "NHWC")
1387    BOOL_HWCN = ("bool", "HWCN")
1388    BOOL_NDHWC = ("bool", "NDHWC")
1389    BOOL_ChannelLast = ("bool", "ChannelLast")
1390    BOOL_Default_Tuple = ("bool", "DefaultFormat", "tuple")
1391    BOOL_Default_List = ("bool", "DefaultFormat", "list")
1392
    # ------------------------------------------------------------------
    # Predefined (dtype, format[, object_type]) tuples used when declaring
    # operator inputs/outputs in op registration info.
    #
    # Naming scheme: <dtype abbrev>_<format alias>[ _Tuple | _List ]
    #   * "_None"  -> empty-string format (format left unspecified)
    #   * "_5HD"   -> "NC1HWC0" (Ascend 5-D layout alias)
    #   * "_FracZ" / "_FracNZ" -> "FRACTAL_Z" / "FRACTAL_NZ"
    #   * "_Tuple" / "_List"   -> third element marks the object type for
    #                             sequence (tuple/list) inputs/outputs.
    # NOTE(review): I8_5HD/I8_NC1HWC0 and U8_5HD/U8_NC1HWC0 are duplicate
    # aliases for the same ("...", "NC1HWC0") value — kept for backward
    # compatibility with existing reg-info files; do not remove either.
    # ------------------------------------------------------------------

    # ---- int8 ----
    I8_None = ("int8", "")
    I8_Default = ("int8", "DefaultFormat")
    I8_5HD = ("int8", "NC1HWC0")
    I8_FracZ = ("int8", "FRACTAL_Z")
    I8_FracNZ = ("int8", "FRACTAL_NZ")
    I8_C1HWNCoC0 = ("int8", "C1HWNCoC0")
    I8_NCHW = ("int8", "NCHW")
    I8_NHWC = ("int8", "NHWC")
    I8_HWCN = ("int8", "HWCN")
    I8_NDHWC = ("int8", "NDHWC")
    I8_NCDHW = ("int8", "NCDHW")
    I8_ChannelLast = ("int8", "ChannelLast")
    I8_NDC1HWC0 = ("int8", "NDC1HWC0")
    I8_NC1HWC0 = ("int8", "NC1HWC0")
    I8_Default_Tuple = ("int8", "DefaultFormat", "tuple")
    I8_Default_List = ("int8", "DefaultFormat", "list")

    # ---- uint8 ----
    U8_None = ("uint8", "")
    U8_Default = ("uint8", "DefaultFormat")
    U8_5HD = ("uint8", "NC1HWC0")
    U8_FracZ = ("uint8", "FRACTAL_Z")
    U8_FracNZ = ("uint8", "FRACTAL_NZ")
    U8_C1HWNCoC0 = ("uint8", "C1HWNCoC0")
    U8_NCHW = ("uint8", "NCHW")
    U8_NHWC = ("uint8", "NHWC")
    U8_HWCN = ("uint8", "HWCN")
    U8_NDHWC = ("uint8", "NDHWC")
    U8_NCDHW = ("uint8", "NCDHW")
    U8_ChannelLast = ("uint8", "ChannelLast")
    U8_NDC1HWC0 = ("uint8", "NDC1HWC0")
    U8_NC1HWC0 = ("uint8", "NC1HWC0")
    U8_Default_Tuple = ("uint8", "DefaultFormat", "tuple")
    U8_Default_List = ("uint8", "DefaultFormat", "list")

    # ---- int16 ----
    # (no NCDHW / NDC1HWC0 variants declared for 16-bit integer types)
    I16_None = ("int16", "")
    I16_Default = ("int16", "DefaultFormat")
    I16_5HD = ("int16", "NC1HWC0")
    I16_FracZ = ("int16", "FRACTAL_Z")
    I16_FracNZ = ("int16", "FRACTAL_NZ")
    I16_C1HWNCoC0 = ("int16", "C1HWNCoC0")
    I16_NCHW = ("int16", "NCHW")
    I16_NHWC = ("int16", "NHWC")
    I16_HWCN = ("int16", "HWCN")
    I16_NDHWC = ("int16", "NDHWC")
    I16_ChannelLast = ("int16", "ChannelLast")
    I16_Default_Tuple = ("int16", "DefaultFormat", "tuple")
    I16_Default_List = ("int16", "DefaultFormat", "list")

    # ---- uint16 ----
    U16_None = ("uint16", "")
    U16_Default = ("uint16", "DefaultFormat")
    U16_5HD = ("uint16", "NC1HWC0")
    U16_FracZ = ("uint16", "FRACTAL_Z")
    U16_FracNZ = ("uint16", "FRACTAL_NZ")
    U16_C1HWNCoC0 = ("uint16", "C1HWNCoC0")
    U16_NCHW = ("uint16", "NCHW")
    U16_NHWC = ("uint16", "NHWC")
    U16_HWCN = ("uint16", "HWCN")
    U16_NDHWC = ("uint16", "NDHWC")
    U16_ChannelLast = ("uint16", "ChannelLast")
    U16_Default_Tuple = ("uint16", "DefaultFormat", "tuple")
    U16_Default_List = ("uint16", "DefaultFormat", "list")

    # ---- int32 ----
    I32_None = ("int32", "")
    I32_Default = ("int32", "DefaultFormat")
    I32_5HD = ("int32", "NC1HWC0")
    I32_FracZ = ("int32", "FRACTAL_Z")
    I32_FracNZ = ("int32", "FRACTAL_NZ")
    I32_C1HWNCoC0 = ("int32", "C1HWNCoC0")
    I32_NCHW = ("int32", "NCHW")
    I32_NHWC = ("int32", "NHWC")
    I32_HWCN = ("int32", "HWCN")
    I32_NDHWC = ("int32", "NDHWC")
    I32_NDC1HWC0 = ("int32", "NDC1HWC0")
    I32_NCDHW = ("int32", "NCDHW")
    I32_ChannelLast = ("int32", "ChannelLast")
    I32_Default_Tuple = ("int32", "DefaultFormat", "tuple")
    I32_Default_List = ("int32", "DefaultFormat", "list")

    # ---- uint32 ----
    U32_None = ("uint32", "")
    U32_Default = ("uint32", "DefaultFormat")
    U32_5HD = ("uint32", "NC1HWC0")
    U32_FracZ = ("uint32", "FRACTAL_Z")
    U32_FracNZ = ("uint32", "FRACTAL_NZ")
    U32_C1HWNCoC0 = ("uint32", "C1HWNCoC0")
    U32_NCHW = ("uint32", "NCHW")
    U32_NHWC = ("uint32", "NHWC")
    U32_HWCN = ("uint32", "HWCN")
    U32_NDHWC = ("uint32", "NDHWC")
    U32_ChannelLast = ("uint32", "ChannelLast")
    U32_Default_Tuple = ("uint32", "DefaultFormat", "tuple")
    U32_Default_List = ("uint32", "DefaultFormat", "list")

    # ---- int64 ----
    I64_None = ("int64", "")
    I64_Default = ("int64", "DefaultFormat")
    I64_5HD = ("int64", "NC1HWC0")
    I64_FracZ = ("int64", "FRACTAL_Z")
    I64_FracNZ = ("int64", "FRACTAL_NZ")
    I64_C1HWNCoC0 = ("int64", "C1HWNCoC0")
    I64_NCHW = ("int64", "NCHW")
    I64_NHWC = ("int64", "NHWC")
    I64_HWCN = ("int64", "HWCN")
    I64_NDHWC = ("int64", "NDHWC")
    I64_ChannelLast = ("int64", "ChannelLast")
    I64_Default_Tuple = ("int64", "DefaultFormat", "tuple")
    I64_Default_List = ("int64", "DefaultFormat", "list")

    # ---- uint64 ----
    U64_None = ("uint64", "")
    U64_Default = ("uint64", "DefaultFormat")
    U64_5HD = ("uint64", "NC1HWC0")
    U64_FracZ = ("uint64", "FRACTAL_Z")
    U64_FracNZ = ("uint64", "FRACTAL_NZ")
    U64_C1HWNCoC0 = ("uint64", "C1HWNCoC0")
    U64_NCHW = ("uint64", "NCHW")
    U64_NHWC = ("uint64", "NHWC")
    U64_HWCN = ("uint64", "HWCN")
    U64_NDHWC = ("uint64", "NDHWC")
    U64_ChannelLast = ("uint64", "ChannelLast")
    U64_Default_Tuple = ("uint64", "DefaultFormat", "tuple")
    U64_Default_List = ("uint64", "DefaultFormat", "list")

    # ---- float16 ----
    # Float types additionally expose 3-D/RNN layouts (FRACTAL_Z_3D,
    # FRACTAL_ZN_LSTM, FRACTAL_ZN_RNN, ND_RNN_BIAS, DHWCN).
    F16_None = ("float16", "")
    F16_Default = ("float16", "DefaultFormat")
    F16_5HD = ("float16", "NC1HWC0")
    F16_FracZ = ("float16", "FRACTAL_Z")
    F16_FracNZ = ("float16", "FRACTAL_NZ")
    F16_C1HWNCoC0 = ("float16", "C1HWNCoC0")
    F16_NCHW = ("float16", "NCHW")
    F16_NHWC = ("float16", "NHWC")
    F16_HWCN = ("float16", "HWCN")
    F16_NDHWC = ("float16", "NDHWC")
    F16_NCDHW = ("float16", "NCDHW")
    F16_DHWCN = ("float16", "DHWCN")
    F16_NDC1HWC0 = ("float16", "NDC1HWC0")
    F16_FRACTAL_Z_3D = ("float16", "FRACTAL_Z_3D")
    F16_FracZNLSTM = ("float16", "FRACTAL_ZN_LSTM")
    F16_FracZNRNN = ("float16", "FRACTAL_ZN_RNN")
    F16_ND_RNNBIAS = ("float16", "ND_RNN_BIAS")
    F16_ChannelLast = ("float16", "ChannelLast")
    F16_Default_Tuple = ("float16", "DefaultFormat", "tuple")
    F16_Default_List = ("float16", "DefaultFormat", "list")

    # ---- float32 ----
    F32_None = ("float32", "")
    F32_Default = ("float32", "DefaultFormat")
    F32_5HD = ("float32", "NC1HWC0")
    F32_FracZ = ("float32", "FRACTAL_Z")
    F32_FracNZ = ("float32", "FRACTAL_NZ")
    F32_C1HWNCoC0 = ("float32", "C1HWNCoC0")
    F32_NCHW = ("float32", "NCHW")
    F32_NHWC = ("float32", "NHWC")
    F32_HWCN = ("float32", "HWCN")
    F32_NDHWC = ("float32", "NDHWC")
    F32_NCDHW = ("float32", "NCDHW")
    F32_DHWCN = ("float32", "DHWCN")
    F32_NDC1HWC0 = ("float32", "NDC1HWC0")
    F32_FRACTAL_Z_3D = ("float32", "FRACTAL_Z_3D")
    F32_FracZNLSTM = ("float32", "FRACTAL_ZN_LSTM")
    F32_FracZNRNN = ("float32", "FRACTAL_ZN_RNN")
    F32_ND_RNNBIAS = ("float32", "ND_RNN_BIAS")
    F32_ChannelLast = ("float32", "ChannelLast")
    F32_Default_Tuple = ("float32", "DefaultFormat", "tuple")
    F32_Default_List = ("float32", "DefaultFormat", "list")

    # ---- float64 ----
    F64_None = ("float64", "")
    F64_Default = ("float64", "DefaultFormat")
    F64_5HD = ("float64", "NC1HWC0")
    F64_FracZ = ("float64", "FRACTAL_Z")
    F64_FracNZ = ("float64", "FRACTAL_NZ")
    F64_C1HWNCoC0 = ("float64", "C1HWNCoC0")
    F64_NCHW = ("float64", "NCHW")
    F64_NHWC = ("float64", "NHWC")
    F64_HWCN = ("float64", "HWCN")
    F64_NDHWC = ("float64", "NDHWC")
    F64_ChannelLast = ("float64", "ChannelLast")
    F64_Default_Tuple = ("float64", "DefaultFormat", "tuple")
    F64_Default_List = ("float64", "DefaultFormat", "list")

    # ---- complex ----
    # Complex types only support DefaultFormat (no layout-specific variants).
    C64_Default = ("complex64", "DefaultFormat")
    C128_Default = ("complex128", "DefaultFormat")
    C64_Default_Tuple = ("complex64", "DefaultFormat", "tuple")
    C128_Default_Tuple = ("complex128", "DefaultFormat", "tuple")
