• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2023 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""pyboost utils."""
16
17import os
18import logging
19from gen_utils import safe_load_yaml
20
21
def is_optional_param(op_arg):
    """
    Check whether an op argument is an optional init argument.

    :param op_arg: argument descriptor exposing ``as_init_arg`` and ``default``.
    :return: bool, True only when the argument is an init arg whose default renders as 'None'.
    """
    # Guard clause: non-init arguments are never treated as optional here.
    if not op_arg.as_init_arg:
        return False
    return str(op_arg.default) == 'None'
26
27
def is_tensor(op_arg):
    """
    Check whether an op argument is declared as a single tensor.

    :param op_arg: argument descriptor exposing ``arg_dtype``.
    :return: bool, True when ``arg_dtype`` is exactly 'tensor'.
    """
    return op_arg.arg_dtype == 'tensor'
32
33
def is_tensor_list(op_arg):
    """
    Check whether an op argument is a list or tuple of tensors.

    :param op_arg: argument descriptor exposing ``arg_dtype``.
    :return: bool, True for 'list[tensor]' or 'tuple[tensor]'.
    """
    tensor_sequence_types = ('list[tensor]', 'tuple[tensor]')
    return op_arg.arg_dtype in tensor_sequence_types
38
39
def is_list(op_arg):
    """
    Check whether an op argument is a sequence (tuple or list) type.

    The accepted set mirrors the tuple/list dtypes handled by the converters in this
    module, so 'list[float]' is recognised just like 'tuple[float]'.

    :param op_arg: argument descriptor exposing ``arg_dtype``.
    :return: bool, True for tuple[...] or list[...] of int, float, bool or tensor.
    """
    return op_arg.arg_dtype in ('tuple[int]', 'tuple[float]', 'tuple[bool]', 'tuple[tensor]',
                                'list[int]', 'list[float]', 'list[bool]', 'list[tensor]')
45
46
def get_index(index: int):
    """
    Build the C++ index constant name for a position number.

    :param index: int, position number.
    :return: str, e.g. 'kIndex0' for 0.
    """
    return f"kIndex{index}"
54
55
def get_convert_type_str(dtype: str, optional):
    """
    Map a yaml argument dtype to the name of its conversion function.

    :param dtype: str, yaml dtype such as 'int', 'tensor' or 'tuple[int]'.
    :param optional: truthy when the argument may be None; selects the ``*Optional`` converters.
    :return: str, converter name (e.g. 'ToInt', 'ToIntListOptional<py::tuple>').
    :raises TypeError: if ``dtype`` has no converter for the requested optionality.
    """
    # Extend these tables when new argument types are supported.
    required_converters = {
        'int': 'ToInt',
        'float': 'ToFloat',
        'bool': 'ToBool',
        'number': 'ToScalar',
        'tensor': 'ToTensor',
        'str': 'ToString',
        'type': 'ToDtype',
        'tuple[int]': 'ToIntList<py::tuple>',
        'tuple[float]': 'ToFloatList<py::tuple>',
        'tuple[bool]': 'ToBoolList<py::tuple>',
        'tuple[tensor]': 'ToTensorList<py::tuple>',
        'list[int]': 'ToIntList<py::list>',
        'list[float]': 'ToFloatList<py::list>',
        'list[bool]': 'ToBoolList<py::list>',
        'list[tensor]': 'ToTensorList<py::list>',
    }
    # Note: there is no optional converter for plain 'bool'.
    optional_converters = {
        'int': 'ToIntOptional',
        'float': 'ToFloatOptional',
        'number': 'ToScalarOptional',
        'tensor': 'ToTensorOptional',
        'type': 'ToDtypeOptional',
        'str': 'ToStringOptional',
        'tuple[int]': 'ToIntListOptional<py::tuple>',
        'tuple[float]': 'ToFloatListOptional<py::tuple>',
        'tuple[bool]': 'ToBoolListOptional<py::tuple>',
        'tuple[tensor]': 'ToTensorListOptional<py::tuple>',
        'list[int]': 'ToIntListOptional<py::list>',
        'list[float]': 'ToFloatListOptional<py::list>',
        'list[bool]': 'ToBoolListOptional<py::list>',
        'list[tensor]': 'ToTensorListOptional<py::list>',
    }
    if optional:
        converter = optional_converters.get(dtype)
        if converter is None:
            raise TypeError(f"Unsupported convert optional type {dtype} for args.")
        return converter
    converter = required_converters.get(dtype)
    if converter is None:
        raise TypeError(f"Unsupported convert type {dtype} for args.")
    return converter
101
102
def get_value_convert_type_str(dtype: str, optional):
    """
    Map a yaml argument dtype to the conversion function used for value inputs.

    Every supported tuple dtype goes through the generic ``ToValueTuple`` /
    ``ToValueTupleOptional`` converters; list[...] dtypes are not in these tables.

    :param dtype: str, yaml dtype such as 'int' or 'tuple[int]'.
    :param optional: truthy when the argument may be None; selects the ``*Optional`` converters.
    :return: str, converter name.
    :raises TypeError: if ``dtype`` has no converter for the requested optionality.
    """
    # Extend these tables when new argument types are supported.
    plain_converters = {
        'int': 'ToInt',
        'float': 'ToFloat',
        'bool': 'ToBool',
        'number': 'ToScalar',
        'tensor': 'ToTensor',
        'str': 'ToString',
        'type': 'ToDtype',
        'tuple[int]': 'ToValueTuple',
        'tuple[float]': 'ToValueTuple',
        'tuple[bool]': 'ToValueTuple',
        'tuple[tensor]': 'ToValueTuple',
    }
    nullable_converters = {
        'int': 'ToIntOptional',
        'float': 'ToFloatOptional',
        'number': 'ToScalarOptional',
        'tensor': 'ToTensorOptional',
        'type': 'ToDtypeOptional',
        'str': 'ToStringOptional',
        'tuple[int]': 'ToValueTupleOptional',
        'tuple[float]': 'ToValueTupleOptional',
        'tuple[bool]': 'ToValueTupleOptional',
        'tuple[tensor]': 'ToValueTupleOptional',
    }
    if not optional:
        name = plain_converters.get(dtype)
        if name is None:
            raise TypeError(f"Unsupported convert type {dtype} for args.")
        return name
    name = nullable_converters.get(dtype)
    if name is None:
        raise TypeError(f"Unsupported convert optional type {dtype} for args.")
    return name
140
141
def tuple_input_to_cpp_type(dtype: str):
    """
    Return the C++ element type for a tuple/list yaml dtype.

    :param dtype: str, e.g. 'tuple[int]' or 'list[tensor]'.
    :return: str element type (e.g. 'int64_t', 'TensorPtr'), or None if unsupported.
    """
    element_types = {
        'tuple[int]': 'int64_t',
        'list[int]': 'int64_t',
        'tuple[float]': 'float',
        'list[float]': 'float',
        'tuple[bool]': 'bool',
        'list[bool]': 'bool',
        'tuple[tensor]': 'TensorPtr',
        'list[tensor]': 'TensorPtr',
        # Only the tuple form has a string mapping.
        'tuple[str]': 'string',
    }
    return element_types.get(dtype)
160
161
def number_input_to_cpp_type(dtype: str):
    """
    Return the C++ type for a scalar yaml dtype.

    :param dtype: str, one of 'int', 'float', 'bool', 'str'.
    :return: str C++ type name, or None for any other dtype.
    """
    scalar_types = {
        'int': 'int64_t',
        'float': 'float',
        'bool': 'bool',
        'str': 'string',
    }
    return scalar_types.get(dtype)
170
171
def get_input_dtype(dtype: str, optional):
    """
    Map a yaml argument dtype to its C++ input type name.

    :param dtype: str, yaml dtype such as 'int', 'tensor', 'tuple[int]' or 'list[int]'.
    :param optional: truthy when the argument may be None; the type is then wrapped
        in ``std::optional<...>``.
    :return: str, C++ type name (e.g. 'Int64ImmPtr', 'std::optional<ValueTuplePtr>').
    :raises TypeError: if ``dtype`` is not supported for the requested optionality.
    """
    # add more type here
    value_tuple = 'ValueTuplePtr'
    type_convert = {
        'int': 'Int64ImmPtr',
        'float': 'FP32ImmPtr',
        'bool': 'BoolImmPtr',
        'number': 'ScalarPtr',
        'str': 'StringImmPtr',
        'tensor': 'BaseTensorPtr',
        'tuple[int]': value_tuple,
        'tuple[float]': value_tuple,
        'tuple[bool]': value_tuple,
        'tuple[tensor]': value_tuple,
        'list[int]': value_tuple,
        'list[float]': value_tuple,
        'list[bool]': value_tuple,
        'list[tensor]': value_tuple,
    }
    # Derive the optional table from the required one so both accept the same dtypes;
    # this makes optional list[...] arguments map to std::optional<ValueTuplePtr>
    # exactly like optional tuple[...] arguments.
    optional_type_convert = {key: f'std::optional<{value}>' for key, value in type_convert.items()}
    if optional:
        if dtype in optional_type_convert:
            return optional_type_convert[dtype]
        raise TypeError(f"""Unsupported convert optional type {dtype} for args.""")
    if dtype in type_convert:
        return type_convert[dtype]
    raise TypeError(f"""Unsupported convert type {dtype} for args.""")
214
215
def is_cube(class_name):
    """
    Check whether a primitive class is in the fixed cube-op set.

    :param class_name: str, primitive class name.
    :return: bool, True for 'Bmm', 'Baddbmm', 'MatMulExt' or 'Mv'.
    """
    return class_name in {'Bmm', 'Baddbmm', 'MatMulExt', 'Mv'}
221
222
def get_return_type(dtype: str):
    """
    Map a yaml return dtype to its C++ return type.

    :param dtype: str, one of 'tensor', 'tuple[tensor]', 'list[tensor]'.
    :return: str, C++ type name.
    :raises TypeError: for any other dtype.
    """
    # add more type here
    tensor_vector = 'std::vector<tensor::TensorPtr>'
    return_types = {
        'tensor': 'tensor::TensorPtr',
        'tuple[tensor]': tensor_vector,
        'list[tensor]': tensor_vector,
    }
    cpp_type = return_types.get(dtype)
    if cpp_type is not None:
        return cpp_type
    raise TypeError(f"Unsupported convert type {dtype} for args.")
236
237
def get_disable_flag(yaml_def):
    """
    Read the 'disable' flag from a class or functional-api yaml definition.

    :param yaml_def: dict-like yaml definition, or None.
    :return: bool, the 'disable' value, or False when the definition or key is absent.
    :raises TypeError: if 'disable' is present but is not exactly True or False.
    """
    if yaml_def is None:
        return False
    flag = yaml_def.get("disable")
    if flag is None:
        return False
    # bool cannot be subclassed, so this accepts exactly True/False (not 0/1 or strings).
    if not isinstance(flag, bool):
        raise TypeError(f"The disable label for function should be True or False, but get {flag}.")
    return flag
250
251
def get_op_name(operator_name, class_def):
    """
    Get the op name used for the Python Primitive class or the C++ OpDef.

    :param operator_name: str, snake_case operator name, e.g. 'add_ext'.
    :param class_def: dict-like class definition, or None.
    :return: str, ``class_def['name']`` when set, otherwise the CamelCase form of
        ``operator_name`` (e.g. 'AddExt').
    """
    camel_name = ''.join(map(str.capitalize, operator_name.split('_')))
    if class_def is None:
        return camel_name
    override = class_def.get("name")
    return camel_name if override is None else override
262
263
def get_pyboost_name(operator_name):
    """
    Prefix an operator name with 'pyboost_'.

    :param operator_name: str, operator name.
    :return: str, e.g. 'pyboost_add' for 'add'.
    """
    prefix = 'pyboost_'
    return prefix + operator_name
266
267
def convert_python_func_name_to_c(func_name: str) -> str:
    """
    Convert a snake_case Python function name to CamelCase.

    :param func_name: str, e.g. 'batch_mat_mul'.
    :return: str, e.g. 'BatchMatMul'.
    """
    return ''.join(map(str.capitalize, func_name.split('_')))
270
271
def get_const_number_convert(arg_name, op_arg):
    """
    Emit a C++ statement that extracts a scalar constant into ``<arg_name>_imm``.

    :param arg_name: str, name of the C++ value variable.
    :param op_arg: argument descriptor exposing ``arg_dtype`` and ``is_type_id``.
    :return: str, one line of C++ ending in a newline; when ``is_type_id`` is truthy
        the value is cast to ``TypeId``.
    """
    cpp_type = number_input_to_cpp_type(op_arg.arg_dtype)
    if not op_arg.is_type_id:
        return f"auto {arg_name}_imm = GetValue<{cpp_type}>({arg_name});\n"
    return f"TypeId {arg_name}_imm = static_cast<TypeId>(GetValue<{cpp_type}>({arg_name}));\n"
277
278
def get_tuple_input_convert(arg_name, arg_type):
    """
    Emit a C++ statement converting a value tuple into ``std::vector``.

    :param arg_name: str, name of the C++ value-tuple variable.
    :param arg_type: str, yaml tuple/list dtype such as 'tuple[int]'.
    :return: str, one line of C++ declaring ``<arg_name>_vector``, ending in a newline.
    """
    element_type = tuple_input_to_cpp_type(arg_type)
    # Tensor elements are emitted as BaseTensorPtr, matching the tensor input type
    # chosen elsewhere in this module (e.g. 'tensor' -> 'BaseTensorPtr').
    if element_type == "TensorPtr":
        element_type = "BaseTensorPtr"
    return (f"std::vector<{element_type}> {arg_name}_vector = "
            f"ConvertValueTupleToVector<{element_type}>({arg_name});\n")
290
291
def is_pyboost_enable(operator_data):
    """
    Check whether an operator's yaml definition enables dispatch.

    :param operator_data: dict-like operator definition.
    :return: bool, True when ``operator_data['dispatch']['enable']`` is truthy.
    """
    if 'dispatch' not in operator_data:
        return False
    return bool(operator_data['dispatch'].get('enable'))
299
300
def convert_types(inputs):
    """
    Convert yaml arg/return dtypes into C++ type names for aclnn code generation.

    :param inputs: dict mapping a name to its yaml definition (a dict with a 'dtype' key),
        e.g. the 'args' or 'returns' section of an op yaml. None or empty is treated as
        having no entries.
    :return: tuple ``(dtypes, flag)``. ``dtypes`` maps each name to its converted type string:
        tuple[int]/tuple[float]/tuple[bool] become std::vector types, 'number' becomes
        'ScalarPtr', 'int' becomes 'int64_t', and anything else is kept as-is.
        ``flag`` is True when any entry's yaml dtype is not 'tensor'.
    """
    inputs_dtypes = {}
    flag = False
    if not inputs:
        # A missing yaml section (e.g. op_yaml.get('returns') returning None) has no entries.
        return inputs_dtypes, flag
    tuple_element_types = {
        'int': 'std::vector<int64_t>',
        'float': 'std::vector<float>',
        'bool': 'std::vector<uint8_t>',
    }
    for name, define in inputs.items():
        dtype = define.get('dtype')
        # The non-tensor check looks at the original yaml dtype, before conversion.
        if dtype != 'tensor':
            flag = True
        if 'tuple' in dtype:
            element = dtype.split('[')[1].strip(']')
            if element == 'tensor':
                logging.info("Not support tuple[tensor] input.")
            elif element in tuple_element_types:
                dtype = tuple_element_types[element]
            else:
                logging.warning("Not support tuple[%s] input.", element)
        if dtype == 'number':
            dtype = 'ScalarPtr'
        elif dtype == 'int':
            dtype = 'int64_t'
        inputs_dtypes[name] = dtype
    return inputs_dtypes, flag
326
327
def get_dtypes(op_yaml):
    """
    Convert an op yaml's 'args' and 'returns' dtypes via ``convert_types``.

    :param op_yaml: dict-like op definition with 'args' and 'returns' sections.
    :return: tuple ``(inputs_dtypes, outputs_dtypes, none_tensor_exist)``, where the last
        item is True when any arg or return is not a plain 'tensor'.
    """
    args_section = op_yaml.get('args')
    returns_section = op_yaml.get('returns')
    inputs_dtypes, has_non_tensor_arg = convert_types(args_section)
    outputs_dtypes, has_non_tensor_return = convert_types(returns_section)
    return inputs_dtypes, outputs_dtypes, has_non_tensor_arg or has_non_tensor_return
336
337
class AclnnUtils:
    """
    Lookup helpers for aclnn interface names.

    ``aclnn_map`` is loaded once, when this class body executes, from
    ``mindspore/python/mindspore/ops_generate/aclnn_config.yaml`` resolved against the
    directory four levels above this file. It is presumably a mapping of primitive class
    name to aclnn interface name (NOTE(review): confirm against the yaml).
    """
    work_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../../")
    aclnn_map = safe_load_yaml(os.path.join(work_path, "./mindspore/python/mindspore/ops_generate/aclnn_config.yaml"))

    @staticmethod
    def get_aclnn_interface(class_name):
        """
        Return the aclnn interface name for a primitive class.

        :param class_name: str, primitive class name.
        :return: str, the entry configured in ``aclnn_map`` when present,
            otherwise ``"aclnn" + class_name``.
        """
        configured = AclnnUtils.aclnn_map
        if class_name not in configured.keys():
            return "aclnn" + class_name
        return configured[class_name]
355