• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2023 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Generate pyboost function from pyboost_op.yaml
17"""
18
19import os
20import re
21import pathlib
22from dataclasses import dataclass
23import pyboost_utils
24from pyboost_utils import get_convert_type_str, get_input_dtype, get_return_type, tuple_input_to_cpp_type, \
25    number_input_to_cpp_type, get_const_number_convert, get_tuple_input_convert, get_pyboost_name, is_cube, \
26    AclnnUtils, get_disable_flag, is_optional_param, get_value_convert_type_str, is_pyboost_enable
27import template
28from template import CppTemplate
29from op_proto import OpProto
30from gen_utils import check_change_and_replace_file, py_licence_str, write_file
31
32
@dataclass()
class FuncHeaderData:
    """
    Inputs consumed by ``generate_pyboost_op_header_code``.

    Fields keep their original order, so instances may be built positionally or
    by keyword.
    """
    # Root directory that every entry of ``code_generate_path`` is joined onto.
    work_path: str
    # Header templates (objects exposing ``replace(**kwargs)``), zipped with code_generate_path.
    op_header_template_path: list
    # Output directories relative to work_path, one per template.
    code_generate_path: list
    # Value substituted for ``op_name`` (and, upper-cased, for ``op_name_upper``).
    op_name_str: str
    # File stem of the generated header (``<operator_name>.h``).
    operator_name: str
    # Typed C++ call arguments substituted for ``call_args_with_type``.
    call_args_with_type: list
    # C++ return type substituted for ``return_type``.
    cpp_func_return: str
42
43
def generate_pyboost_base_op_header_code(work_path, op_name_str, operator_name, call_args_with_type, cpp_func_return):
    """
    Render the PyBoost base-op header for one operator and write it out.

    The target is ``<work_path>/mindspore/ccsrc/kernel/pyboost/auto_generate/<operator_name>.h``.
    The rendered text is first written to a ``tmp_``-prefixed sibling, which is then
    passed to ``check_change_and_replace_file`` together with the final path
    (presumably it only replaces the file when the content changed — see gen_utils).

    :param work_path: repository root the output directory is joined onto.
    :param op_name_str: op name for the template's ``op_name`` (also upper-cased for ``op_name_upper``).
    :param operator_name: file stem of the generated header.
    :param call_args_with_type: typed C++ call arguments for the template's ``call_args``.
    :param cpp_func_return: C++ return type for the template's ``return_type``.
    """
    header_text = template.PYBOOST_BASE_OP_DEFINE_TEMPLATE.replace(
        op_name=op_name_str,
        op_name_upper=op_name_str.upper(),
        call_args=call_args_with_type,
        return_type=cpp_func_return,
    )
    header_dir = os.path.join(work_path, "mindspore/ccsrc/kernel/pyboost/auto_generate/")
    pathlib.Path(header_dir).mkdir(parents=True, exist_ok=True)
    header_file = operator_name + ".h"
    staging_path = os.path.join(header_dir, "tmp_" + header_file)
    final_path = os.path.join(header_dir, header_file)
    write_file(staging_path, header_text)
    check_change_and_replace_file(final_path, staging_path)
56
57
def generate_pyboost_op_header_code(header_data: FuncHeaderData):
    """
    Render and write one PyBoost op header per (template, output directory) pair.

    ``header_data.op_header_template_path`` and ``header_data.code_generate_path``
    are walked in lockstep with ``zip`` (surplus entries of the longer list are
    ignored). Each rendered header goes to
    ``<work_path>/<gen_path>/<operator_name>.h`` via a ``tmp_``-prefixed file that is
    handed to ``check_change_and_replace_file``.

    :param header_data: bundle of paths, names and C++ signature pieces.
    """
    data = header_data
    for header_template, gen_path in zip(data.op_header_template_path, data.code_generate_path):
        rendered = header_template.replace(op_name=data.op_name_str,
                                           op_name_upper=data.op_name_str.upper(),
                                           operator_name=data.operator_name,
                                           call_args_with_type=data.call_args_with_type,
                                           return_type=data.cpp_func_return)
        out_dir = os.path.join(data.work_path, gen_path)
        pathlib.Path(out_dir).mkdir(parents=True, exist_ok=True)
        header_file = data.operator_name + ".h"
        staging_path = os.path.join(out_dir, "tmp_" + header_file)
        final_path = os.path.join(out_dir, header_file)
        write_file(staging_path, rendered)
        check_change_and_replace_file(final_path, staging_path)
73
74
class TemplatePaths:
    """
    Container for the template lists used by PyBoost source generation.

    Every constructor argument is stored unchanged on the attribute of the same
    name. ``generate_pyboost_op_source_code`` zips the call, source, view and
    custom template lists with ``code_generate_path``, so those lists are expected
    to be index-aligned per backend.
    """

    def __init__(self, op_header_template_path, op_call_template_path, op_source_template_path, op_custom_template_path,
                 op_view_template_path, code_generate_path):
        # Templates for the generated op headers.
        self.op_header_template_path = op_header_template_path
        # Templates for the generic (non-custom, non-view) call implementation.
        self.op_call_template_path = op_call_template_path
        # Templates for the whole generated .cc source file.
        self.op_source_template_path = op_source_template_path
        # Templates used when a backend provides a customize function.
        self.op_custom_template_path = op_custom_template_path
        # Templates used for view ops.
        self.op_view_template_path = op_view_template_path
        # Output directories (relative to the work path), one per backend.
        self.code_generate_path = code_generate_path
88
89
def generate_malloc_input(need_malloc_tensors):
    """
    Generate the C++ statement that mallocs device memory for op inputs.

    :param need_malloc_tensors: iterable of C++ argument names; each item is
        formatted the same way an f-string would format it.
    :return: ``PyBoostUtils::MallocOpInputs(device_context, <a>, <b>, ...);``
        followed by a newline, or ``''`` when the joined argument list is empty.
    """
    args_list = ', '.join(f'{item}' for item in need_malloc_tensors)
    if not args_list:
        return ''
    return f'PyBoostUtils::MallocOpInputs(device_context, {args_list});\n'
104
105
def generate_get_inputs_kernel_tensors(call_args):
    """
    Generate the C++ statement that collects address info for the op inputs.

    :param call_args: iterable of C++ argument names; each item is formatted the
        same way an f-string would format it.
    :return: a ``const auto &input_address_info = PyBoostUtils::GetAddressInfo(...)``
        statement followed by a newline, or ``''`` when the joined argument list is empty.
    """
    args_list = ', '.join(f'{item}' for item in call_args)
    if not args_list:
        return ''
    return (f'const auto &input_address_info = PyBoostUtils::GetAddressInfo('
            f'device_context, op->stream_id(), op->input_abs(), {args_list});\n')
121
122
def generate_create_input_address(need_malloc_tensors):
    """
    Generate the C++ statement that prepares (creates addresses for) op inputs.

    :param need_malloc_tensors: iterable of C++ argument names; each item is
        formatted the same way an f-string would format it.
    :return: ``PyBoostUtils::PrepareOpInputs(device_context_, op->stream_id(), ...);``
        followed by a newline, or ``''`` when the joined argument list is empty.
    """
    args_list = ', '.join(f'{item}' for item in need_malloc_tensors)
    if not args_list:
        return ''
    return f'PyBoostUtils::PrepareOpInputs(device_context_, op->stream_id(), {args_list});\n'
133
134
def generate_tensor_cpu_cast_input_code(call_args_with_tensor, call_tensors):
    """
    Generate CPU cast statements for tensor and tensor-vector call arguments.

    For each argument that is a tensor (present in ``call_tensors``) or a tensor
    vector (name ends with ``_vector``), a ``real_<arg>`` C++ variable is declared
    by casting the argument to the dtype selected for its input position.
    - A tensor argument is replaced by ``real_<arg>``.
    - A tensor-vector argument is replaced by
      ``PyBoostUtils::ConvertTensorVectorToTuple(real_<arg>)``.
    - Every other argument is passed through unchanged.

    An argument that matches both conditions is treated as a plain tensor. The two
    checks used to run independently, which emitted ``real_<arg>`` twice (a C++
    redefinition) and referenced a non-existent ``real_real_<arg>``.

    :param call_args_with_tensor: list of C++ argument names, in input order.
    :param call_tensors: container of the argument names that are tensors.
    :return: tuple ``(cast_input, real_call_args_tensor)``.
        - ``cast_input`` is the generated C++ (``''`` when nothing needs casting).
        - ``real_call_args_tensor`` is a new list of argument expressions.
        The input list is not modified.
    """
    cast_lines = []
    real_call_args_tensor = call_args_with_tensor.copy()
    for i, arg in enumerate(call_args_with_tensor):
        if arg in call_tensors:
            real_call_args_tensor[i] = f'real_{arg}'
        elif arg.endswith("_vector"):
            real_call_args_tensor[i] = f'PyBoostUtils::ConvertTensorVectorToTuple(real_{arg})'
        else:
            continue
        cast_lines.append(f'const auto &real_{arg} = PyBoostUtils::CastTensor({arg}, '
                          f'select_kernel.input_type()[{i}].dtype, "CPU");\n')
    if not cast_lines:
        return "", real_call_args_tensor
    # The cast statements reference ``select_kernel``, so declare it first.
    cast_input = "auto &select_kernel = kernel_attr_pair.second;\n" + ''.join(cast_lines)
    return cast_input, real_call_args_tensor
153
154
def generate_pyboost_op_source_code(work_path, op_proto, template_paths, converter):
    """
    Generate the PyBoost op source file (``<operator_name lower-cased>.cc``) per backend.

    One file is written per entry of ``template_paths.code_generate_path``. The
    call, source, view and custom template lists are zipped with it. The backend
    is detected from the generate path (``'ascend'``, ``'cpu'`` or ``'gpu'``
    substring). The call implementation is chosen in this order:

    1. the backend's customize function, when ``op_proto.<backend>`` is not ``'default'``;
    2. the view template, when ``op_proto.is_view``;
    3. the generic call template. This adds CPU cast code and, on Ascend for ops
       where ``is_cube`` holds, a cube math type argument.

    Output goes to a ``tmp_``-prefixed file that is handed to
    ``check_change_and_replace_file``.

    :param work_path: repository root that each generate path is joined onto.
    :param op_proto: parsed op definition.
    :param template_paths: ``TemplatePaths`` with the per-backend templates.
    :param converter: object exposing the converted call arguments/outputs used to fill the templates.
    """
    operator_name = converter.functional_name
    op_name_str = op_proto.class_name
    proto_operator_name = op_proto.operator_name
    file_stem = operator_name.lower()
    # Names of arguments whose C++ type is a (possibly optional) tensor.
    call_args_tensor = [arg_name for arg_type, arg_name in zip(converter.call_args_types, converter.call_args)
                        if arg_type in ("BaseTensorPtr", "std::optional<BaseTensorPtr>")]
    # These snippets depend only on the converter, not on the backend, so build them once.
    malloc_inputs = generate_malloc_input(converter.need_malloc_tensors)
    create_input_address = generate_create_input_address(converter.need_malloc_tensors)
    get_inputs_kernel_tensors = generate_get_inputs_kernel_tensors(converter.call_args_with_tensor)

    for call_tpl, src_tpl, view_tpl, cus_tpl, gen_path in zip(template_paths.op_call_template_path,
                                                              template_paths.op_source_template_path,
                                                              template_paths.op_view_template_path,
                                                              template_paths.op_custom_template_path,
                                                              template_paths.code_generate_path):
        is_ascend = 'ascend' in gen_path
        is_cpu = 'cpu' in gen_path
        is_gpu = 'gpu' in gen_path

        customize_include = ''
        register_custom_kernel = ''
        if is_ascend and op_proto.ascend != 'default':
            call_impl = cus_tpl.replace(call_args=converter.call_args,
                                        return_values=converter.call_func_outputs,
                                        customize_func=op_proto.ascend + "Customize")
            customize_include = f'#include "plugin/device/ascend/kernel/pyboost/customize/{file_stem}.h"'
        elif is_cpu and op_proto.cpu != 'default':
            call_impl = cus_tpl.replace(call_args=converter.call_args,
                                        return_values=converter.call_func_outputs,
                                        customize_func=op_proto.cpu + "Customize")
            customize_include = f'#include "plugin/device/cpu/kernel/pyboost/customize/{file_stem}.h"'
            register_custom_kernel = f"MS_REG_PYBOOST_CPU_CUSTOM_KERNEL({op_name_str});"
        elif is_gpu and op_proto.gpu != 'default':
            call_impl = cus_tpl.replace(call_args=converter.call_args,
                                        return_values=converter.call_func_outputs,
                                        customize_func=op_proto.gpu + "Customize")
            customize_include = f'#include "plugin/device/gpu/kernel/pyboost/customize/{file_stem}.h"'
            register_custom_kernel = f"MS_REG_PYBOOST_GPU_CUSTOM_KERNEL({op_name_str});"
        elif op_proto.is_view:
            # A vector of outputs ("outputs_") needs the tuple variant of the abstract setter.
            if converter.call_func_outputs == "outputs_":
                set_output_abs = "SetOutputTupleAbstract();"
            else:
                set_output_abs = "SetOutputAbstract();"
            call_impl = view_tpl.replace(op_name=op_proto.class_name,
                                         call_args=converter.call_args,
                                         call_tensors=call_args_tensor,
                                         return_values=converter.call_func_outputs,
                                         input=converter.call_args[0],
                                         set_output_abs=set_output_abs)
            customize_include = f'#include "mindspore/core/ops/view/{proto_operator_name}_strides_calc.h"'
        else:
            cast_input_code, real_call_args_tensor = generate_tensor_cpu_cast_input_code(
                converter.call_args_with_tensor, call_args_tensor)
            get_cube_math_type = ''
            cube_math_type = ''
            if is_ascend and is_cube(op_proto.class_name):
                get_cube_math_type = ('// cubeMathType: 0 - KEEP_DTYPE, 1 - ALLOW_FP32_DOWN_PRECISION\n'
                                      'auto cube_math_type = GetCubeMathType();')
                cube_math_type = ', cube_math_type'
            aclnn_name = AclnnUtils.get_aclnn_interface(op_name_str)
            real_output = ', ' + converter.op_outputs
            # When an in-place process is generated, outputs are not appended to the call.
            if converter.inplace_process != '':
                real_output = ''
            customize_include = '#include "ops/auto_generate/gen_ops_primitive.h"'

            call_impl = call_tpl.replace(aclnn_name=aclnn_name,
                                         call_args=converter.call_args,
                                         call_tensors=call_args_tensor,
                                         value_tuple_convert=converter.value_tuple_convert,
                                         const_number_convert=converter.const_number_convert,
                                         create_input_address=create_input_address,
                                         tensor_list_convert=converter.tensor_list_convert,
                                         call_args_with_tensor=converter.call_args_with_tensor,
                                         malloc_inputs=malloc_inputs,
                                         get_inputs_kernel_tensors=get_inputs_kernel_tensors,
                                         get_cube_math_type=get_cube_math_type,
                                         cube_math_type=cube_math_type,
                                         real_call_args=converter.call_args_after_convert,
                                         return_values=converter.call_func_outputs,
                                         outputs=real_output,
                                         inplace_process=converter.inplace_process,
                                         cast_input_code=cast_input_code,
                                         real_call_args_tensor=real_call_args_tensor,
                                         class_name=op_proto.class_name,
                                         op_name_str=op_name_str)

        pyboost_op_source_str = src_tpl.replace(op_name=op_name_str,
                                                operator_name=operator_name,
                                                call_args_with_type=converter.call_args_with_types,
                                                return_type=converter.cpp_func_return,
                                                customize_include=customize_include,
                                                call_impl=call_impl,
                                                register_custom_kernel=register_custom_kernel)
        out_dir = os.path.join(work_path, gen_path)
        tmp_op_source_file_path = os.path.join(out_dir, "tmp_" + file_stem + ".cc")
        dst_op_source_file_path = os.path.join(out_dir, file_stem + ".cc")
        write_file(tmp_op_source_file_path, pyboost_op_source_str)
        check_change_and_replace_file(dst_op_source_file_path, tmp_op_source_file_path)
264
265
def generate_pyboost_op_register_source_code(work_path, all_ops, all_operator_names):
    """
    Generate ``op_register.cc`` under ``mindspore/ccsrc/kernel/pyboost/auto_generate/``.

    The file includes every generated op header and explicitly instantiates
    ``OpFactory`` for every op. It is written via a ``tmp_op_register.cc`` file that
    is handed to ``check_change_and_replace_file``.

    :param work_path: repository root the output directory is joined onto.
    :param all_ops: op class names; one ``template class OpFactory<Op>;`` line each.
    :param all_operator_names: header file stems; one ``#include`` line each.
    """
    factory_str = ''.join(f"template class OpFactory<{op}>;\n" for op in all_ops)
    include_str = ''.join(f'#include "kernel/pyboost/auto_generate/{stem}.h"\n' for stem in all_operator_names)
    register_text = template.PYBOOST_OP_REGISTER_TEMPLATE.replace(op_includes=include_str,
                                                                  op_factory_templates=factory_str)
    out_dir = os.path.join(work_path, "mindspore/ccsrc/kernel/pyboost/auto_generate/")
    pathlib.Path(out_dir).mkdir(parents=True, exist_ok=True)
    file_name = "op_register.cc"
    staging_path = os.path.join(out_dir, "tmp_" + file_name)
    final_path = os.path.join(out_dir, file_name)
    write_file(staging_path, register_text)
    check_change_and_replace_file(final_path, staging_path)
282
283
def generate_pyboost_op_return_code(op_proto):
    """
    Compute the C++ return types of an op.

    :param op_proto: parsed op definition; each entry of ``op_proto.returns`` is
        mapped with ``get_return_type(return_obj.arg_dtype)``.
    :return: tuple ``(returns_type, cpp_func_return)``.
        - ``returns_type`` lists the per-output C++ types.
        - ``cpp_func_return`` is that single type for one output, otherwise
          ``std::tuple<T1,T2,...>``.
    :raises Exception: if the op declares no returns.
    """
    returns_type = [get_return_type(return_obj.arg_dtype) for return_obj in op_proto.returns]
    if not returns_type:
        raise Exception("No return")
    if len(returns_type) == 1:
        cpp_func_return = returns_type[0]
    else:
        # Template arguments take angle brackets; "std::tuple(A,B)" is not a valid C++ type.
        cpp_func_return = "std::tuple<" + ','.join(returns_type) + ">"
    return returns_type, cpp_func_return
298
299
def generate_pyboost_op_func_return_type(op_proto):
    """
    Compute the C++ return type of the PyBoost function, using base-tensor types.

    Each return's C++ type (from ``get_return_type``) is mapped to its
    ``BaseTensorPtr`` counterpart:
    - ``tensor::TensorPtr`` becomes ``tensor::BaseTensorPtr``;
    - ``std::vector<tensor::TensorPtr>`` becomes ``std::vector<tensor::BaseTensorPtr>``.

    :param op_proto: parsed op definition.
    :return: the single mapped type for one output, otherwise ``std::tuple<T1,T2,...>``.
    :raises Exception: if a return type has no base-tensor mapping, or the op declares no returns.
    """
    type_convert_to_base = {
        'std::vector<tensor::TensorPtr>': 'std::vector<tensor::BaseTensorPtr>',
        'tensor::TensorPtr': 'tensor::BaseTensorPtr'
    }
    returns_type = []
    for return_obj in op_proto.returns:
        temp_return = get_return_type(return_obj.arg_dtype)
        if temp_return not in type_convert_to_base:
            # Name the offending type instead of the misleading "Not return found".
            raise Exception("Not support return type {}".format(temp_return))
        returns_type.append(type_convert_to_base[temp_return])
    if not returns_type:
        raise Exception("Not return found")
    if len(returns_type) == 1:
        return returns_type[0]
    return "std::tuple<" + ','.join(returns_type) + ">"
322
323
def generate_pyboost_outputs(op_proto):
    """
    Generate the C++ expressions naming an op's outputs.

    :param op_proto: parsed op definition; return types come from ``get_return_type``.
    :return: tuple ``(op_outputs, call_outputs)``:
        - one ``tensor::TensorPtr`` output: ``('outputs[0]', 'outputs_[0]')``;
        - one ``std::vector<tensor::TensorPtr>`` output: ``('outputs', 'outputs_')``;
        - N > 1 outputs: ``('outputs[0],...,outputs[N-1]',
          'std::make_tuple(outputs_[0],...,outputs_[N-1])')``;
        - no outputs: ``('', '')``.
    :raises Exception: if a single output has any other return type.
    """
    returns_type = [get_return_type(return_obj.arg_dtype) for return_obj in op_proto.returns]
    num_outputs = len(returns_type)

    if num_outputs == 1:
        single_type = returns_type[0]
        if single_type == 'tensor::TensorPtr':
            return 'outputs[0]', 'outputs_[0]'
        if single_type == "std::vector<tensor::TensorPtr>":
            return 'outputs', 'outputs_'
        raise Exception("Not support return type {}".format(single_type))

    if num_outputs > 1:
        op_outputs = ','.join(f'outputs[{i}]' for i in range(num_outputs))
        call_outputs = "std::make_tuple(" + ','.join(f'outputs_[{i}]' for i in range(num_outputs)) + ")"
        return op_outputs, call_outputs

    return '', ''
354
355
def generate_ops_header_files(work_path, yaml_data):
    """
    Generate ``mindspore/core/ops/auto_generate/gen_ops_def.h``.

    The header gets one ``MS_EXPORT extern OpDef g<ClassName>;`` declaration per
    op in ``yaml_data``. It is written via ``tmp_gen_ops_def.h`` and
    ``check_change_and_replace_file``.

    :param work_path: repository root the output directory is joined onto.
    :param yaml_data: mapping of op name -> op data, as passed to ``OpProto.load_from_yaml``.
    :return: void
    """
    extern_template = CppTemplate("MS_EXPORT extern OpDef g${op_name};\n")
    extern_lines = []
    for op_key, op_yaml in yaml_data.items():
        proto = OpProto.load_from_yaml(op_key, op_yaml)
        extern_lines.append(extern_template.replace(op_name=proto.class_name))
    header_text = template.GEN_OPS_DEF_HEADER_TEMPLATE.replace(extern_variable=''.join(extern_lines))

    out_dir = os.path.join(work_path, "mindspore/core/ops/auto_generate")
    pathlib.Path(out_dir).mkdir(parents=True, exist_ok=True)
    staging_path = os.path.join(out_dir, "tmp_gen_ops_def.h")
    final_path = os.path.join(out_dir, "gen_ops_def.h")
    write_file(staging_path, header_text)
    check_change_and_replace_file(final_path, staging_path)
374
375
def generate_parser_func(op_proto: OpProto) -> str:
    """
    Generate the C++ statements that parse each op argument from ``args``.

    One ``auto <arg> = converter.<ConvertFunc>(args, <index>);`` line is emitted
    per entry of ``op_proto.op_args``. The convert function is chosen by
    ``get_convert_type_str`` from the argument dtype (``'type'`` for type-id
    arguments) and whether the argument is optional.

    The op proto is left untouched. The previous version overwrote
    ``arg.arg_dtype`` with ``'type'`` as a side effect; a local value is used instead.

    :param op_proto: parsed op definition.
    :return: str
    """
    convert_template = CppTemplate("auto $arg_name = converter.${convert_func}(args, $arg_index);\n")
    parser_func_str = ''
    for index, arg in enumerate(op_proto.op_args):
        is_optional = is_optional_param(arg)
        arg_dtype = 'type' if arg.is_type_id else arg.arg_dtype
        convert_type_str = get_convert_type_str(arg_dtype, is_optional)
        parser_func_str += convert_template.replace(arg_name=arg.arg_name, convert_func=convert_type_str,
                                                    arg_index=pyboost_utils.get_index(index))
    return parser_func_str
392
393
def get_convert_tensor_template():
    """
    Build the templates that convert a stub-node argument into a C++ value.

    Both templates take ``${output}``, ``${input}`` and ``${need_contiguous}``:
    - the first calls ``PyNativeAlgo::Common::ConvertStubNodeToTensor``;
    - the second calls ``PyNativeAlgo::Common::ConvertStubNodeToValueTuple``.
    Both pass ``op_run_info->requires_grad`` as the last argument.

    :return: ``(convert_to_tensor_template, convert_to_tensor_list_template)``.
    """
    call_tail = '(${input}, ${need_contiguous}, op_run_info->requires_grad);\n'
    tensor_tpl = CppTemplate('auto ${output} = PyNativeAlgo::Common::ConvertStubNodeToTensor' + call_tail)
    tensor_list_tpl = CppTemplate('auto ${output} = PyNativeAlgo::Common::ConvertStubNodeToValueTuple' + call_tail)
    return tensor_tpl, tensor_list_tpl
405
406
def generate_pyboost_functions(work_path, yaml_data):
    """
    Generate pyboost functions file from yaml.

    For every op whose proto has ``is_dispatch`` set, the following are emitted:
    - a function body from ``PYBOOST_FUNCTION_TEMPLATE`` (argument parsing,
      stub-to-tensor conversion, optional-to-value conversion);
    - a registration entry from ``REGISTER_DEFINE_TEMPLATE``;
    - an ``#include`` of the op's generated header.

    The result is written to
    ``<work_path>/mindspore/ccsrc/pipeline/pynative/op_function/auto_generate/pyboost_functions.cc``
    via a ``tmp_`` file and ``check_change_and_replace_file``.

    :param work_path: repository root the output directory is joined onto.
    :param yaml_data: mapping of op name -> op data, as passed to ``OpProto.load_from_yaml``.
    """
    pyboost_func_str = ''
    pyboost_func_pybind_def = ''
    pyboost_func_include_headers_str = ''
    pyboost_func_include_header_template = CppTemplate("#include \"kernel/pyboost/auto_generate/${operator_name}.h\"\n")
    # These templates/prefixes are the same for every op and argument: build them once.
    convert_to_tensor_template, convert_to_tensor_list_template = get_convert_tensor_template()
    convert_optional_to_value_template = CppTemplate(
        "auto ${output} = PyNativeAlgo::PyBoost::OptionalToValue(${input});\n")
    cast_str = 'cast_'
    value_str = '_value'

    for yaml_op_name, operator_data in yaml_data.items():
        op_proto = OpProto.load_from_yaml(yaml_op_name, operator_data)
        if not op_proto.is_dispatch:
            continue
        op_def_name_str = f"g{op_proto.class_name}"
        operator_name = op_proto.operator_name
        op_name_str = op_proto.class_name
        op_args_str = [op_arg.arg_name for op_arg in op_proto.op_args]
        parser_body_str = generate_parser_func(op_proto)

        grad_args_str = []
        call_args_str = []
        cast_args_str = []
        convert_stub_str = ''
        optional_to_value_str = ''
        # view/aclnn op no need to contiguous tensor.
        need_contiguous = 'false' if op_proto.is_view else 'true'
        for op_arg in op_proto.op_args:
            is_optional = is_optional_param(op_arg)
            if pyboost_utils.is_tensor(op_arg):
                stub_template, stub_suffix = convert_to_tensor_template, '_tensor'
            elif pyboost_utils.is_tensor_list(op_arg):
                # Optional tensor lists are handled like optional tensors.
                stub_template, stub_suffix = convert_to_tensor_list_template, '_tensor_list'
            else:
                stub_template, stub_suffix = None, ''

            if stub_template is None:
                # Plain value: passed through by name; an optional one is unwrapped from the raw argument.
                call_arg = op_arg.arg_name
                optional_input = call_arg
            else:
                # Tensor / tensor list: convert the stub node first; an optional one is
                # named ``<arg>_optional`` and unwrapped from its cast result.
                call_arg = op_arg.arg_name + ('_optional' if is_optional else stub_suffix)
                convert_stub_str += stub_template.replace(output=call_arg,
                                                          input=op_arg.arg_name,
                                                          need_contiguous=need_contiguous)
                optional_input = cast_str + call_arg
            cast_arg = cast_str + call_arg
            grad_arg = cast_arg
            if is_optional:
                grad_arg = op_arg.arg_name + value_str
                optional_to_value_str += convert_optional_to_value_template.replace(input=optional_input,
                                                                                    output=grad_arg)
            grad_args_str.append(grad_arg)
            call_args_str.append(call_arg)
            cast_args_str.append(cast_arg)

        type_num, same_type = gen_signature_same_type_table(op_proto.indexes, operator_data)
        pyboost_func_str += template.PYBOOST_FUNCTION_TEMPLATE.replace(func_name=op_proto.pyboost_function_name,
                                                                       op_def_name=op_def_name_str, same_type=same_type,
                                                                       type_num=type_num, parser_body=parser_body_str,
                                                                       op_name=op_name_str,
                                                                       convert_stub=convert_stub_str,
                                                                       optional_to_value=optional_to_value_str,
                                                                       call_args=call_args_str, grad_args=grad_args_str,
                                                                       cast_args=cast_args_str, op_args=op_args_str,
                                                                       class_name=op_proto.class_name)
        pyboost_func_str = pyboost_func_str + template.NEW_LINE + template.NEW_LINE
        pyboost_func_pybind_def += template.REGISTER_DEFINE_TEMPLATE.replace(
            pyboost_op_name=get_pyboost_name(op_proto.operator_name),
            pyboost_cfunc_name=op_proto.pyboost_function_name, class_name=op_proto.class_name)
        pyboost_func_include_headers_str += pyboost_func_include_header_template.replace(operator_name=operator_name)

    register_func_str = template.REGISTER_TEMPLATE.replace(register_func=pyboost_func_pybind_def)
    pyboost_func_file = template.PYBOOST_HEADER_TEMPLATE.replace(include_op_header=pyboost_func_include_headers_str,
                                                                 function_body=pyboost_func_str,
                                                                 register_function_body=register_func_str)
    dir_path = os.path.join(work_path, "mindspore/ccsrc/pipeline/pynative/op_function/auto_generate")
    pathlib.Path(dir_path).mkdir(parents=True, exist_ok=True)
    tmp_file_path = os.path.join(dir_path, "tmp_pyboost_functions.cc")
    dst_file_path = os.path.join(dir_path, "pyboost_functions.cc")
    write_file(tmp_file_path, pyboost_func_file)
    check_change_and_replace_file(dst_file_path, tmp_file_path)
525
526
def convert_value_type(op_proto: OpProto) -> str:
    """
    Generate statements that read each op argument from ``op_runner_info->inputs``.

    One ``auto <arg> = ValueConverter::<ConvertFunc>(op_runner_info->inputs, <index>);``
    line is emitted per entry of ``op_proto.op_args``. The convert function comes
    from ``get_value_convert_type_str`` applied to the argument dtype and whether
    the argument is optional.

    :param op_proto: parsed op definition.
    :return: str
    """
    convert_template = CppTemplate(
        "auto $arg_name = ValueConverter::${convert_func}(op_runner_info->inputs, $arg_index);\n")
    statements = [
        convert_template.replace(arg_name=arg.arg_name,
                                 convert_func=get_value_convert_type_str(arg.arg_dtype, is_optional_param(arg)),
                                 arg_index=pyboost_utils.get_index(position))
        for position, arg in enumerate(op_proto.op_args)
    ]
    return ''.join(statements)
542
543
def contiguous_tensor_value(op_proto: OpProto) -> str:
    """
    Generate statements that make tensor-typed op arguments contiguous.

    For each argument whose dtype is ``'tensor'`` or ``'tuple[tensor]'``, this emits
    ``<arg> = ValueConverter::ContiguousTensorValue(op_runner_info, <arg>);``.
    View ops get an empty string.

    :param op_proto: parsed op definition.
    :return: str
    """
    if op_proto.is_view:
        # View ops skip the contiguous conversion entirely.
        return ''
    contiguous_template = CppTemplate(
        "$arg_name = ValueConverter::ContiguousTensorValue(op_runner_info, $arg_name);\n")
    tensor_dtypes = {'tensor', 'tuple[tensor]'}
    return ''.join(contiguous_template.replace(arg_name=arg.arg_name)
                   for arg in op_proto.op_args if arg.arg_dtype in tensor_dtypes)
562
563
def generate_pyboost_grad_functions(work_path, yaml_data):
    """
    Generate pyboost grad functions file from yaml.

    Ops are skipped unless ``is_pyboost_enable(op data)`` holds and the loaded
    proto has ``is_dispatch`` set. Each remaining op gets:
    - a ``PYBOOST_GRAD_FUNCTION_TEMPLATE`` body (value conversion plus contiguous
      conversion of tensor arguments);
    - a ``REGISTER_PYBOOST_GRAD_DEFINE_TEMPLATE`` entry;
    - an ``#include`` of its generated header.

    The result is written to
    ``<work_path>/mindspore/ccsrc/runtime/pynative/op_function/auto_generate/pyboost_grad_functions.cc``
    via a ``tmp_`` file and ``check_change_and_replace_file``.

    :param work_path: repository root the output directory is joined onto.
    :param yaml_data: mapping of op name -> op data, as passed to ``OpProto.load_from_yaml``.
    """
    pyboost_func_str = ''
    pyboost_func_reg_def = ''
    pyboost_func_include_headers_str = ''
    pyboost_func_include_header_template = CppTemplate("#include \"kernel/pyboost/auto_generate/${operator_name}.h\"\n")
    for yaml_op_name, operator_data in yaml_data.items():
        if not is_pyboost_enable(operator_data):
            continue
        op_proto = OpProto.load_from_yaml(yaml_op_name, operator_data)
        if not op_proto.is_dispatch:
            continue
        operator_name = op_proto.operator_name
        op_name_str = op_proto.class_name
        op_args_str = [op_arg.arg_name for op_arg in op_proto.op_args]
        convert_value_type_str = convert_value_type(op_proto) + contiguous_tensor_value(op_proto)
        # Arguments are forwarded by name in declaration order; keep a separate list object.
        call_args_str = list(op_args_str)
        pyboost_func_str += template.PYBOOST_GRAD_FUNCTION_TEMPLATE.replace(func_name=op_proto.pyboost_function_name,
                                                                            op_name=op_name_str,
                                                                            op_args=op_args_str,
                                                                            convert_body=convert_value_type_str,
                                                                            call_args=call_args_str)
        pyboost_func_str = pyboost_func_str + template.NEW_LINE
        pyboost_func_reg_def += template.REGISTER_PYBOOST_GRAD_DEFINE_TEMPLATE.replace(
            pyboost_op_name=op_proto.class_name,
            pyboost_cfunc_name=op_proto.pyboost_function_name)
        pyboost_func_include_headers_str += pyboost_func_include_header_template.replace(operator_name=operator_name)

    register_func_str = template.REGISTER_PYBOOST_GRAD_TEMPLATE.replace(register_func=pyboost_func_reg_def)
    pyboost_func_file = \
        template.PYBOOST_GRAD_HEADER_TEMPLATE.replace(include_op_header=pyboost_func_include_headers_str,
                                                      function_body=pyboost_func_str,
                                                      register_function_body=register_func_str)
    dir_path = os.path.join(work_path, "mindspore/ccsrc/runtime/pynative/op_function/auto_generate")
    pathlib.Path(dir_path).mkdir(parents=True, exist_ok=True)
    tmp_file_path = os.path.join(dir_path, "tmp_pyboost_grad_functions.cc")
    dst_file_path = os.path.join(dir_path, "pyboost_grad_functions.cc")
    write_file(tmp_file_path, pyboost_func_file)
    check_change_and_replace_file(dst_file_path, tmp_file_path)
610
611
def generate_inplace_process_cpp_code(op_proto):
    """
    Generate the C++ snippet that sets an in-place (ref) output's device address
    from an input tensor.

    Only the first element of ``op_proto.returns`` with a non-empty ``inplace``
    is handled; any later in-place returns are ignored.
    :param op_proto: op prototype. Each element of ``returns`` has an ``inplace``
        string: either '' or a name ``x``, in which case the snippet reads
        ``x_tensor->device_address()``.
    :return: str, a comment line followed by one ``set_device_address``
        statement, or '' when no return is in-place
    """
    for out_idx, ret in enumerate(op_proto.returns):
        if ret.inplace == '':
            continue
        return ('// RefOps update output by input tensor\n'
                f'outputs_[{out_idx}]->set_device_address('
                f'{ret.inplace}_tensor->device_address()); ')
    return ''
625
626
def get_auto_generate_template():
    """
    Collect the per-backend templates and output directories used by the pyboost op generator.

    Every list is ordered Ascend, GPU, CPU, so a given index selects the same
    backend in all of them.
    :return: TemplatePaths
    """
    header_templates = [
        template.PYBOOST_ASCEND_OP_HEADER_TEMPLATE,
        template.PYBOOST_GPU_OP_HEADER_TEMPLATE,
        template.PYBOOST_CPU_OP_HEADER_TEMPLATE,
    ]
    call_templates = [
        template.PYBOOST_ASCEND_CALL_TEMPLATE,
        template.PYBOOST_GPU_CALL_TEMPLATE,
        template.PYBOOST_CPU_CALL_TEMPLATE,
    ]
    source_templates = [
        template.PYBOOST_ASCEND_OP_SOURCE_TEMPLATE,
        template.PYBOOST_GPU_OP_SOURCE_TEMPLATE,
        template.PYBOOST_CPU_OP_SOURCE_TEMPLATE,
    ]
    customize_templates = [
        template.PYBOOST_ASCEND_CUSTOMIZE_CALL_TEMPLATE,
        template.PYBOOST_GPU_CUSTOMIZE_CALL_TEMPLATE,
        template.PYBOOST_CPU_CUSTOMIZE_CALL_TEMPLATE,
    ]
    view_templates = [
        template.PYBOOST_ASCEND_VIEW_CALL_TEMPLATE,
        template.PYBOOST_GPU_VIEW_CALL_TEMPLATE,
        template.PYBOOST_CPU_VIEW_CALL_TEMPLATE,
    ]
    output_dirs = [
        "mindspore/ccsrc/plugin/device/ascend/kernel/pyboost/auto_generate/",
        "mindspore/ccsrc/plugin/device/gpu/kernel/pyboost/auto_generate/",
        "mindspore/ccsrc/plugin/device/cpu/kernel/pyboost/auto_generate/",
    ]
    return TemplatePaths(header_templates, call_templates, source_templates,
                         customize_templates, view_templates, output_dirs)
649
650
class OpTemplateConverter:
    """
    Pre-compute, for one op prototype, the argument and return strings that the
    pyboost C++ templates are filled with.
    """

    def __init__(self, op_proto):
        args = op_proto.op_args
        self.op_proto = op_proto
        self.op_name = op_proto.class_name
        self.functional_name = op_proto.operator_name
        # Parameter names as they appear in the generated C++ signature.
        self.call_args = self.parse_original_call_args(args)
        self.call_args_types = self.parse_call_args_types(args)
        self.call_args_with_types = self.parse_call_args_with_types(self.call_args, self.call_args_types)
        (self.need_malloc_tensors,
         self.tensor_list_convert,
         self.call_args_with_tensor) = self.parse_need_malloc_tensors(args, self.call_args)
        (self.call_args_after_convert,
         self.value_tuple_convert,
         self.const_number_convert) = self.op_args_converter(args, self.call_args)
        self.cpp_func_return = generate_pyboost_op_func_return_type(op_proto)
        self.op_outputs, self.call_func_outputs = generate_pyboost_outputs(op_proto)
        self.inplace_process = generate_inplace_process_cpp_code(op_proto)

    @staticmethod
    def parse_call_args_types(op_args):
        """
        Map every op argument to its C++ parameter type.
        :param op_args: op arguments of the prototype
        :return: list of type names from ``get_input_dtype``, which receives the
            argument's dtype and whether the argument is optional
        """
        return [get_input_dtype(arg.arg_dtype, is_optional_param(arg)) for arg in op_args]

    @staticmethod
    def parse_call_args_with_types(call_args, call_args_types):
        """
        Pair types with names as ``const <type> &<name>`` declarations.
        :param call_args: parameter names
        :param call_args_types: parameter types, aligned with ``call_args``
        :return: list of declarations; pairing stops at the shorter input
        """
        return ["const " + type_name + " &" + arg_name
                for type_name, arg_name in zip(call_args_types, call_args)]

    @staticmethod
    def parse_need_malloc_tensors(op_args, call_args):
        """
        Work out which arguments are tensors that need device memory.
        :param op_args: op arguments of the prototype
        :param call_args: parameter names aligned with ``op_args``
        :return: tuple (need_malloc_tensors, tensor_list_convert, call_args_with_tensor):
            - need_malloc_tensors: names of tensors / tensor vectors to allocate
            - tensor_list_convert: conversion snippets for tensor-list arguments
            - call_args_with_tensor: call names after tensor-list conversion
        """
        malloc_names = []
        list_converts = []
        forwarded_names = []
        for op_arg, call_arg in zip(op_args, call_args):
            if pyboost_utils.is_tensor(op_arg):
                name = op_arg.arg_name + "_tensor"
                malloc_names.append(name)
            elif tuple_input_to_cpp_type(op_arg.arg_dtype) and pyboost_utils.is_tensor_list(op_arg):
                # Tensor lists are converted into a "<name>_vector" local first.
                name = call_arg + "_vector"
                malloc_names.append(name)
                list_converts.append(get_tuple_input_convert(call_arg, op_arg.arg_dtype))
            else:
                name = call_arg
            forwarded_names.append(name)
        return malloc_names, list_converts, forwarded_names

    @staticmethod
    def parse_original_call_args(op_args):
        """
        Derive C++ parameter names from the op arguments.

        Tensors get a ``_tensor`` suffix, tensor lists a ``_tensor_list`` suffix,
        and everything else keeps its plain name.
        :param op_args: op arguments of the prototype
        :return: list of parameter names
        """
        def name_suffix(arg):
            if pyboost_utils.is_tensor(arg):
                return "_tensor"
            if pyboost_utils.is_tensor_list(arg):
                return "_tensor_list"
            return ""

        return [op_arg.arg_name + name_suffix(op_arg) for op_arg in op_args]

    @staticmethod
    def op_args_converter(op_args, call_args):
        """
        Plan the conversion of ValuePtr arguments to C++ data types.
        :param op_args: op arguments of the prototype
        :param call_args: parameter names aligned with ``op_args``
        :return: tuple (call_args_after_convert, value_tuple_convert, const_number_convert):
            - call_args_after_convert: names used after conversion
            - value_tuple_convert: tuple conversion snippets, each list led by a
              comment line when non-empty
            - const_number_convert: scalar conversion snippets, same convention
        """
        converted_names = []
        tuple_converts = []
        number_converts = []
        for op_arg, call_arg in zip(op_args, call_args):
            dtype = op_arg.arg_dtype
            if number_input_to_cpp_type(dtype):
                converted_names.append(call_arg + "_imm")
                number_converts.append(get_const_number_convert(call_arg, op_arg))
            elif tuple_input_to_cpp_type(dtype):
                converted_names.append(call_arg + "_vector")
                tuple_converts.append(get_tuple_input_convert(call_arg, dtype))
            else:
                converted_names.append(call_arg)
        if number_converts:
            number_converts = ['// Convert ValuePtr to c++ scalar\n'] + number_converts
        if tuple_converts:
            tuple_converts = ['// ValueTuple to std::vector\n'] + tuple_converts
        return converted_names, tuple_converts, number_converts
757
758
def delete_residual_files(work_path, all_operator_name, code_generate_path_list):
    """
    Delete generated files that belong to operators which are no longer generated.

    A file belongs to operator ``name`` when the part of its file name before the
    first ``.`` equals ``name``. ``op_register.cc`` is never deleted. Directories
    that do not exist are skipped.
    :param work_path: root directory the generate paths are joined onto
    :param all_operator_name: iterable of operator names that are still generated
    :param code_generate_path_list: generate directories relative to ``work_path``.
        ``mindspore/ccsrc/kernel/pyboost/auto_generate/`` is always scanned too.
        The caller's list is not modified.
    """
    live_ops = set(all_operator_name)
    # Build a new list rather than appending to the caller's, so the caller's
    # list does not silently grow (and repeated calls do not scan duplicates).
    scan_dirs = list(code_generate_path_list) + ["mindspore/ccsrc/kernel/pyboost/auto_generate/"]
    for rel_dir in scan_dirs:
        gen_dir = os.path.join(work_path, rel_dir)
        if not os.path.exists(gen_dir):
            continue
        for file_name in os.listdir(gen_dir):
            if file_name == "op_register.cc":
                continue
            if file_name.split(".")[0] in live_ops:
                continue
            file_path = os.path.join(gen_dir, file_name)
            if os.path.exists(file_path):
                os.remove(file_path)
781
def generate_pyboost_op_cpp_code(work_path, yaml_data):
    """
    Generate pyboost op cpp code from yaml.

    For every op whose prototype is dispatched, this generates the base op
    header, the per-backend op headers and the per-backend op sources. Afterwards,
    generated files of ops not collected in this run are deleted, and
    ``generate_pyboost_op_register_source_code`` is called with the collected names.
    :param work_path: root directory the generated paths are joined onto
    :param yaml_data: dict of op name -> op yaml data
    """
    # The template collection does not depend on the op, so build it once. Binding
    # it before the loop also keeps it defined for the cleanup step below when no
    # op is dispatched.
    template_paths = get_auto_generate_template()
    all_op_names = []
    all_functional_names = []
    all_operator_name = []
    for operator_name, operator_data in yaml_data.items():
        op_proto = OpProto.load_from_yaml(operator_name, operator_data)
        if not op_proto.is_dispatch:
            continue
        converter = OpTemplateConverter(op_proto)
        functional_name = converter.functional_name
        op_name_str = converter.op_name

        all_op_names.append(op_name_str)
        all_operator_name.append(operator_name)
        all_functional_names.append(functional_name)

        call_args_with_types = converter.call_args_with_types
        cpp_func_return = converter.cpp_func_return

        generate_pyboost_base_op_header_code(work_path, op_name_str, functional_name, call_args_with_types,
                                             cpp_func_return)
        header_data = FuncHeaderData(work_path, template_paths.op_header_template_path,
                                     template_paths.code_generate_path, op_name_str,
                                     functional_name, call_args_with_types, cpp_func_return)
        generate_pyboost_op_header_code(header_data)
        generate_pyboost_op_source_code(work_path, op_proto, template_paths, converter)
    delete_residual_files(work_path, all_operator_name, template_paths.code_generate_path)
    generate_pyboost_op_register_source_code(work_path, all_op_names, all_functional_names)
816
817
def gen_pyboost_inner_prim(work_path, op_yaml_data):
    """
    Generate ``pyboost_inner_prim.py`` with Python primitive classes for pyboost ops.

    Only ops whose prototype has both ``is_pyboost`` and ``prim_init`` set are
    emitted. Each argument whose ``arg_handler`` is set and is not
    ``dtype_to_type_id`` gets a ``converted_<arg> = <handler>(<arg>)`` line, and
    the converted name is what gets forwarded.
    :param work_path: root directory the output path is joined onto
    :param op_yaml_data: dict of op name -> op yaml data
    :return: None
    """
    header = py_licence_str + template.IMPORT_PYBOOST_PRIM_HEADER
    body = ''
    for operator_name, operator_data in op_yaml_data.items():
        op_proto = OpProto.load_from_yaml(operator_name, operator_data)
        if not (op_proto.is_pyboost and op_proto.prim_init):
            continue
        header += template.PYBOOST_PY_FUNC_IMPORT_HEADEAR.replace(class_name=op_proto.class_name)
        input_args = []
        processed_args = []
        conversions = []
        for arg_name, arg_info in operator_data.get('args').items():
            handler = arg_info.get('arg_handler')
            input_args.append(arg_name)
            if handler is None or handler == 'dtype_to_type_id':
                processed_args.append(arg_name)
            else:
                conversions.append(f"converted_{arg_name} = {handler}({arg_name})\n")
                processed_args.append('converted_' + arg_name)
        body += template.PYTHON_PRIM_TEMPLATE.replace(class_name=op_proto.class_name, input_args=input_args,
                                                      process_func=''.join(conversions),
                                                      func_impl_name=operator_name,
                                                      processed_args=processed_args)
    out_dir = os.path.join(work_path, "mindspore/python/mindspore/ops/auto_generate")
    pathlib.Path(out_dir).mkdir(parents=True, exist_ok=True)
    dst_file_path = os.path.join(out_dir, "pyboost_inner_prim.py")
    tmp_file_path = os.path.join(out_dir, "tmp_pyboost_inner_prim.py")
    write_file(tmp_file_path, header + body)
    check_change_and_replace_file(dst_file_path, tmp_file_path)
855
856
def process_args(args):
    """
    Build the parameter list and forwarded-argument list of a generated Python function.

    - Arguments with a non-None ``init`` become parameters only: ``name`` when
      ``init`` is ``'NO_VALUE'``, otherwise ``name=<init>``. They are not forwarded.
    - Other arguments become ``name`` or ``name=<default>`` (when a ``default`` key
      exists) and are forwarded.
    - A forwarded argument is renamed to ``converted_<name>`` when its
      ``arg_handler`` is set and is not ``'dtype_to_type_id'``.
    :param args: dict of argument name -> argument yaml info
    :return: tuple (func_args, input_args)
    """
    func_args = []
    input_args = []
    for name, info in args.items():
        init_value = info.get('init')
        if init_value is not None:
            func_args.append(name if init_value == 'NO_VALUE' else f"{name}={init_value}")
            continue
        default_suffix = '=' + str(info.get('default')) if 'default' in info else ''
        func_args.append(name + default_suffix)
        handler = info.get('arg_handler')
        needs_conversion = handler is not None and handler != 'dtype_to_type_id'
        input_args.append('converted_' + name if needs_conversion else name)
    return func_args, input_args
882
883
def gen_pyboost_py_func(work_path, op_yaml_data, doc_data):
    """
    Generate ``gen_extend_func.py`` with Python wrappers for ``*_ext`` pyboost ops.

    An op is emitted only when all of these hold:

    - its prototype is pyboost;
    - its ``function`` entry, if present, is not disabled;
    - its function name (``function.name`` when given, else the op name) ends with ``_ext``.

    The wrapper is named without the ``_ext`` suffix. A trailing ``_`` is
    additionally stripped to get the implementation name.
    :param work_path: root directory the output path is joined onto
    :param op_yaml_data: dict of op name -> op yaml data
    :param doc_data: dict of op name -> doc entry whose ``description`` is passed to the template
    :return: None
    """
    py_header = py_licence_str + template.IMPORT_PYBOOST_FUNC_HEADER
    descriptions = {name: entry.get("description") for name, entry in doc_data.items()}
    gen_py = ''
    for operator_name, operator_data in op_yaml_data.items():
        op_proto = OpProto.load_from_yaml(operator_name, operator_data)
        if not op_proto.is_pyboost:
            continue
        func_name = operator_name
        func_def = operator_data.get('function')
        if func_def is not None:
            if get_disable_flag(func_def):
                continue
            custom_name = func_def.get("name")
            if custom_name is not None:
                func_name = custom_name
        # Only "*_ext" functions are exported here, under their suffix-less name.
        if not func_name.endswith("_ext"):
            continue
        func_name = func_name[:-len("_ext")]
        func_impl_name = func_name[:-1] if func_name.endswith("_") else func_name
        func_args, input_args = process_args(operator_data.get('args'))
        gen_py += template.PYBOOST_PY_FUNC_TEMPLATE.replace(func_name=func_name,
                                                            description=descriptions.get(operator_name),
                                                            func_args=func_args,
                                                            func_impl_name=func_impl_name,
                                                            input_args=input_args)
    out_dir = os.path.join(work_path, "mindspore/python/mindspore/ops/auto_generate")
    pathlib.Path(out_dir).mkdir(parents=True, exist_ok=True)
    dst_file_path = os.path.join(out_dir, "gen_extend_func.py")
    tmp_file_path = os.path.join(out_dir, "tmp_gen_extend_func.py")
    write_file(tmp_file_path, py_header + gen_py)
    check_change_and_replace_file(dst_file_path, tmp_file_path)
926
927
def gen_signature_same_type_table(args_map, operator_data):
    """
    Build the same-dtype signature table from an op's ``args_signature.dtype_group``.

    ``dtype_group`` lists groups of argument names in parentheses, e.g.
    ``"(x, y), (z)"``. Each group becomes a brace-enclosed list of argument indices.
    :param args_map: dict mapping argument name to its index
    :param operator_data: yaml data of one operator
    :return: tuple (type_num, signature_table). ``type_num`` is the number of
        groups; ``signature_table`` is e.g. ``"{0, 1}, {2}"``. Returns (0, '') when
        there is no ``args_signature``, no ``dtype_group``, or no parenthesized group.
    :raises KeyError: if a group names an argument missing from ``args_map``
    """
    args_signature = operator_data.get('args_signature')
    if args_signature is None:
        return 0, ''
    dtype_group = args_signature.get('dtype_group')
    if dtype_group is None:
        return 0, ''
    groups = []
    for item in re.findall(r'\((.*?)\)', dtype_group):
        indexes = [f"{args_map[arg]}" for arg in item.replace(' ', '').split(",")]
        groups.append('{' + ', '.join(indexes) + '}')
    return len(groups), ', '.join(groups)
953
954
def gen_pyboost_code(work_path, ops_yaml_data, doc_yaml_data):
    """
    Run every pyboost generator, in a fixed order, over the op yaml data.

    :param work_path: root directory the generated paths are joined onto
    :param ops_yaml_data: dict of op name -> op yaml data
    :param doc_yaml_data: dict of op name -> doc entry, used only by ``gen_pyboost_py_func``
    :return: None
    """
    gen_pyboost_inner_prim(work_path, ops_yaml_data)  # pyboost inner prim (Python)
    gen_pyboost_py_func(work_path, ops_yaml_data, doc_yaml_data)  # pyboost py func (Python)
    generate_ops_header_files(work_path, ops_yaml_data)  # ops header files
    generate_pyboost_functions(work_path, ops_yaml_data)  # pyboost functions
    generate_pyboost_grad_functions(work_path, ops_yaml_data)  # pyboost grad functions
    generate_pyboost_op_cpp_code(work_path, ops_yaml_data)  # pyboost backend C++ code
969