• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""tbe common"""
16import os
17
18
class TBEException(Exception):
    """tbe exception class.

    Args:
        err_msg: Error message; ``str()`` of the exception returns it as text.
    """

    def __init__(self, err_msg):
        # Hand the message (not ``self``) to the base class so ``args`` holds the
        # message. Passing ``self`` made ``args == (self,)``, which gives a
        # self-referential repr and makes pickling recurse without end.
        super().__init__(err_msg)
        self.__error_msg = err_msg

    def __str__(self):
        # Coerce so a non-string message cannot make ``str()`` raise TypeError.
        return str(self.__error_msg)
28
29
def get_built_in_impl_path():
    """Locate the directory holding the built-in TBE operator implementations.

    If the ``TBE_IMPL_PATH`` environment variable is set, its value is used as-is.
    Otherwise the HiAI install location and then the Ascend install location are
    probed, and the first one that exists on disk is returned.

    Returns:
        str: The resolved implementation path.

    Raises:
        ValueError: If ``TBE_IMPL_PATH`` is set to an empty string, or is unset and
            neither install location exists.
    """
    env_path = os.environ.get("TBE_IMPL_PATH")
    if env_path is not None:
        resolved = env_path
    else:
        candidates = (
            '/usr/local/HiAI/runtime/ops/op_impl/built-in/ai_core/tbe/',
            '/usr/local/Ascend/opp/op_impl/built-in/ai_core/tbe/',
        )
        resolved = next((path for path in candidates if os.path.exists(path)), None)
    if not resolved:
        raise ValueError("Can not find the env TBE_IMPL_PATH")
    return resolved
43
44
def _check_arg_info(item):
    """
    Validate that a single input/output descriptor carries every mandatory key.

    ``shape`` and ``ori_shape`` only need to be present (an empty value is accepted).
    ``format``, ``ori_format``, ``dtype`` and ``param_type`` must be present and truthy.
    Keys are checked in that order, so the first problem found is the one reported.

    Args:
        item (dict): The descriptor to validate.

    Raises:
        ValueError: If a mandatory key is missing, or is empty where that is not allowed.
    """
    requirements = (
        ('shape', False),
        ('ori_shape', False),
        ('format', True),
        ('ori_format', True),
        ('dtype', True),
        ('param_type', True),
    )
    for key, must_be_truthy in requirements:
        if key in item and (not must_be_truthy or item[key]):
            continue
        raise ValueError("Json string Errors, key:{} not found.".format(key))
67
68
def get_input_output(io_info, args):
    """
    Parse the ``inputs`` or ``outputs`` section of an op-info JSON into positional args.

    ``io_info`` has one entry per op parameter, and each entry is a list of
    descriptor dicts. For each entry:

    * A single descriptor whose ``param_type`` is not ``'dynamic'`` is appended to
      ``args`` directly. If it is not ``valid``, ``None`` is appended instead.
    * Several descriptors, or a single ``'dynamic'`` one, are collected and appended
      together as one tuple. Invalid descriptors in such a group contribute ``None``.

    Valid descriptors are checked with ``_check_arg_info`` and emitted as copies
    without their ``valid`` and ``name`` keys. The caller's dicts are not modified.

    Args:
        io_info (list): Input or output info; a list of lists of descriptor dicts.
        args (list): Output list; parsed arguments are appended to it in place.

    Raises:
        ValueError: If a descriptor has no ``valid`` key, or a valid descriptor
            fails ``_check_arg_info``.
    """
    for item in io_info:
        grouped = len(item) > 1
        arg = []
        for info in item:
            if 'valid' not in info:
                raise ValueError("Json string Errors, key:valid not found.")
            if not info['valid']:
                if grouped:
                    arg.append(None)
                else:
                    args.append(None)
                continue
            _check_arg_info(info)
            # Emit a stripped copy instead of deleting keys in place. Mutating the
            # caller's dicts made a second parse of the same op_info fail with
            # "key:valid not found". It also raised a bare KeyError when ``name``
            # was absent.
            parsed = {key: value for key, value in info.items() if key not in ('valid', 'name')}
            if grouped or parsed['param_type'] == 'dynamic':
                arg.append(parsed)
            else:
                args.append(parsed)
        if arg:
            args.append(tuple(arg))
103
104
def get_attr(attr_info, args):
    """
    Parse the ``attrs`` section of an op-info JSON into positional args.

    For each attribute marked ``valid``, its ``value`` is appended to ``args``,
    except for the attribute named ``isRef``. Attributes that are not valid are
    skipped.

    Args:
        attr_info (list): Attribute dicts, each with a ``valid`` key and, when valid,
            a ``value`` key (and normally a ``name``).
        args (list): Output list; attribute values are appended to it in place.

    Raises:
        ValueError: If an attribute has no ``valid`` key, or a valid attribute has
            no ``value`` key.
    """
    for item in attr_info:
        if 'valid' not in item:
            # Report malformed JSON as ValueError, as get_input_output does,
            # instead of leaking a bare KeyError.
            raise ValueError("Json string Errors, attr key:valid not found.")
        if not item['valid']:
            continue
        if 'value' not in item:
            raise ValueError("Json string Errors, attr key:value not found.")
        # A missing name cannot be "isRef", so the value is kept rather than
        # raising KeyError.
        if item.get('name') != "isRef":
            args.append(item['value'])
122
123
def get_args(op_info, arg_type):
    """
    Extract the positional arguments for one section of an op-info dict.

    ``'inputs'`` and ``'outputs'`` sections go to ``get_input_output``, and
    ``'attrs'`` goes to ``get_attr``. Any other section name, or an empty or falsy
    section, yields an empty list.

    Args:
        op_info (dict): Op info dict.
        arg_type (str): Name of the section to parse.

    Returns:
        list: The parsed arguments.

    Raises:
        ValueError: If ``arg_type`` is not a key of ``op_info``, or the section
            parser rejects its content.
    """
    if arg_type not in op_info:
        raise ValueError("Json string Errors, key:{} not found.".format(arg_type))
    parsed = []
    section = op_info[arg_type]
    if section:
        if arg_type in ('inputs', 'outputs'):
            get_input_output(section, parsed)
        elif arg_type == 'attrs':
            get_attr(section, parsed)
    return parsed
148
149
def check_kernel_info(kernel_info):
    """
    Validate the top-level structure of a kernel-info dict.

    ``kernel_info['op_info']`` must exist and be truthy, and must itself contain
    truthy ``name`` and ``kernel_name`` entries. Checks run in that order.

    Args:
        kernel_info (dict): Parsed kernel-info JSON.

    Raises:
        ValueError: On the first missing or empty required key.
    """
    if 'op_info' not in kernel_info or not kernel_info['op_info']:
        raise ValueError("Json string Errors, key:op_info not found.")
    op_info = kernel_info['op_info']
    for required_key in ('name', 'kernel_name'):
        if required_key not in op_info or not op_info[required_key]:
            raise ValueError("Json string Errors, key:{} not found.".format(required_key))
157