• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2023 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Generate operator definition from ops.yaml
17"""
18import os
19import re
20import shutil
21import pathlib
22import logging
23import gen_utils
24from gen_utils import (py_licence_str, cc_license_str, check_change_and_replace_file, merge_files,
25                       merge_files_append, safe_load_yaml, convert_dtype_str, write_file)
26from pyboost_utils import get_pyboost_name, is_pyboost_enable, AclnnUtils, get_dtypes
27import template
28from template import CppTemplate
29from gen_pyboost_func import gen_pyboost_code
30from gen_aclnn_implement import gen_aclnn_kernel
31
32
33def _get_op_name(yaml_key, yaml_value):
34    """
35    Get op name for python class Primitive or c++ OpDef name.
36    """
37    # If has class item, use the specified item.
38    class_def = yaml_value.get("class")
39    if class_def is not None:
40        class_name_specify = class_def.get("name")
41        if class_name_specify is not None:
42            return class_name_specify
43    # Else use the default rule generate class name.
44    op_name = yaml_key
45    class_name_normal = ''.join(word.capitalize() for word in op_name.split('_'))
46    return class_name_normal
47
48
49def _get_op_func_name(yaml_key, yaml_value):
50    func_def = yaml_value.get('function')
51    func_name = yaml_key
52
53    if func_def is not None:
54        item = func_def.get("name")
55        if item is not None:
56            func_name = item
57    return func_name
58
59
60def _auto_generate_class_disabled(yaml_value):
61    """Check whether class can be auto generated."""
62    if 'class' not in yaml_value.keys():
63        return False
64    class_def = yaml_value.get("class")
65    if 'disable' not in class_def.keys():
66        return False
67    disable_item = class_def.get("disable")
68    if disable_item is True:
69        return True
70    if disable_item is False:
71        return False
72    raise TypeError(f"The disable label for class should be True or False, but get {disable_item}.")
73
74
75def _auto_generate_func_disabled(yaml_value):
76    """Check whether function can be auto generated."""
77    if 'function' not in yaml_value.keys():
78        return False
79    func_def = yaml_value.get('function')
80    if 'disable' not in func_def.keys():
81        return False
82    disable_item = func_def.get("disable")
83    if disable_item is True:
84        return True
85    if disable_item is False:
86        return False
87    raise TypeError(f"The disable label for function should be True or False, but get {disable_item}.")
88
89
def signature_get_rw_label(arg_name, write_list, read_list, ref_list):
    """
    Generate the rw keyword fragment of ``sig.make_sig`` for one argument.

    The lists are checked in the order write, read, ref and the first one containing
    ``arg_name`` decides the label; an unlisted argument yields an empty string.
    """
    candidates = (
        (write_list, ', sig.sig_rw.RW_WRITE'),
        (read_list, ', sig.sig_rw.RW_READ'),
        (ref_list, ', sig.sig_rw.RW_REF'),
    )
    for names, label in candidates:
        if arg_name in names:
            return label
    return ''
104
105
def signature_get_rw_label_cc(rw_op_name, write_list, read_list, ref_list):
    """
    Generate the C++ ``SignatureEnumRW`` value for one argument.

    If a name appears in several lists the later category wins: ref overrides read,
    which overrides write. Names in no list map to ``kRWDefault``.
    """
    if rw_op_name in ref_list:
        label = 'kRWRef'
    elif rw_op_name in read_list:
        label = 'kRWRead'
    elif rw_op_name in write_list:
        label = 'kRWWrite'
    else:
        label = 'kRWDefault'
    return 'SignatureEnumRW::' + label
121
122
def signature_get_enum_dtype_cc(index):
    """
    Generate the C++ ``SignatureEnumDType`` value for a dtype group index.

    Group 0 maps to ``kDType`` and groups 1..9 to ``kDType1``..``kDType9``. Any other value,
    including None for an argument outside every dtype group, maps to
    ``kDTypeEmptyDefaultValue``.
    """
    known = {group: 'kDType' + (str(group) if group else '') for group in range(10)}
    return 'SignatureEnumDType::' + known.get(index, 'kDTypeEmptyDefaultValue')
141
142
def signature_get_dtype_label(index):
    """
    Generate the ``dtype=`` keyword fragment of ``sig.make_sig`` for a dtype group index.

    Group 0 maps to ``sig.sig_dtype.T``; a positive group ``n`` maps to ``sig.sig_dtype.Tn``.
    """
    suffix = f"{index}" if index > 0 else ''
    return "dtype=sig.sig_dtype.T" + suffix
151
152
def get_same_dtype_groups(args_signature, args_name):
    """
    Assign a dtype group index to every input from ``args_signature['dtype_group']``.

    ``dtype_group`` is a string such as ``"(x, y), (z)"``; names inside one pair of
    parentheses share an index. Inputs not mentioned in any group get an index of their own.

    Args:
        args_signature (dict): The ``args_signature`` yaml entry, or None.
        args_name (list[str]): Input names, in order; indices are assigned in this order.

    Returns:
        tuple[dict, int]: Mapping of name -> group index, and the number of groups. Both are
        empty/0 when there is no signature or no ``dtype_group``.
    """
    groups = {}
    group_count = 0
    if args_signature is None:
        return groups, group_count
    dtype_group = args_signature.get('dtype_group')
    if dtype_group is None:
        return groups, group_count

    declared = [text.replace(' ', '').split(",") for text in re.findall(r'\((.*?)\)', dtype_group)]
    for name in args_name:
        if name in groups:
            continue
        # The first declared group containing the name wins; otherwise it stands alone.
        members = next((group for group in declared if name in group), [name])
        for member in members:
            groups[member] = group_count
        group_count += 1
    return groups, group_count
182
183
def generate_py_op_signature(op_name, args_signature, args_name, args_default):
    """
    Generate the ``__mindspore_signature__`` class attribute of a python Primitive.

    Args:
        op_name (str): Primitive class name, used in error messages.
        args_signature (dict): The ``args_signature`` yaml entry, or None.
        args_name (list[str]): Names of the op inputs (non prim_init args), in order.
        args_default (dict): Mapping of input name -> default value source text; None is
            treated as an empty mapping.

    Returns:
        str: Class-body code assigning ``__mindspore_signature__``, or '' when the op has
        neither a signature nor input defaults.

    Raises:
        ValueError: If ``rw_*`` or ``dtype_group`` name an argument that is not an input.
    """
    def _check_signature_arg_valid(op_name, sig_arg_names, args_names):
        for sig_arg_name in sig_arg_names:
            if sig_arg_name not in args_names:
                raise ValueError(f"Op {op_name} has no input arg named '{sig_arg_name}'!")

    if args_default is None:
        args_default = {}
    if args_signature is None and not args_default:
        return ''

    signature_code = "    __mindspore_signature__ = "

    # Init rw.
    write_list = []
    read_list = []
    ref_list = []
    if args_signature is not None:
        rw_write = args_signature.get('rw_write')
        rw_read = args_signature.get('rw_read')
        rw_ref = args_signature.get('rw_ref')
        if rw_write is not None:
            write_list = rw_write.replace(' ', '').split(",")
            _check_signature_arg_valid(op_name, write_list, args_name)
        if rw_read is not None:
            read_list = rw_read.replace(' ', '').split(",")
            _check_signature_arg_valid(op_name, read_list, args_name)
        if rw_ref is not None:
            ref_list = rw_ref.replace(' ', '').split(",")
            _check_signature_arg_valid(op_name, ref_list, args_name)
    # Init dtype group.
    same_dtype_groups, dtype_count = get_same_dtype_groups(args_signature, args_name)
    _check_signature_arg_valid(op_name, list(same_dtype_groups.keys()), args_name)
    # All inputs share one dtype group and nothing else is configured: emit the compact
    # form, a plain tuple of sig_dtype values.
    if dtype_count == 1 and not any([write_list, read_list, ref_list, args_default]):
        dtype_items = ', '.join(['sig.sig_dtype.T'] * len(args_name))
        # `(x)` is just `x` in Python; a one-input op needs `(x,)` to remain a tuple.
        if len(args_name) == 1:
            dtype_items += ','
        return signature_code + f'({dtype_items})\n\n'

    # Set sig.make_sig.
    signature_code += " (\n"
    for arg_name in args_name:
        signature_code += f"        sig.make_sig('{arg_name}'"
        signature_code += signature_get_rw_label(arg_name, write_list, read_list, ref_list)
        if arg_name in same_dtype_groups:
            signature_code += ", " + signature_get_dtype_label(same_dtype_groups[arg_name])
        if arg_name in args_default:
            signature_code += ", default=" + str(args_default[arg_name])
        signature_code += "),\n"
    signature_code += "    )\n\n"
    return signature_code
238
239
def generate_cc_op_signature(args_signature, args_name):
    """
    Generate the C++ ``Signature(...)`` initializer entries of an OpDef.

    Args:
        args_signature (dict): The ``args_signature`` yaml entry, or None.
        args_name (list[str]): Names of the op inputs, in order.

    Returns:
        str: One ``Signature(...)`` entry per input, or '' when no signature is configured.
    """
    if args_signature is None:
        return ''

    def _split_names(key):
        value = args_signature.get(key)
        return [] if value is None else value.replace(' ', '').split(",")

    write_list = _split_names('rw_write')
    read_list = _split_names('rw_read')
    ref_list = _split_names('rw_ref')
    same_dtype_groups, _ = get_same_dtype_groups(args_signature, args_name)

    entries = []
    for arg_name in args_name:
        enum_rw = signature_get_rw_label_cc(arg_name, write_list, read_list, ref_list)
        # Inputs outside any dtype group pass None and get kDTypeEmptyDefaultValue.
        enum_dtype = signature_get_enum_dtype_cc(same_dtype_groups.get(arg_name))
        entries.append(f'Signature("{arg_name}", {enum_rw}, '
                       f'         SignatureEnumKind::kKindPositionalKeyword, nullptr, {enum_dtype}),\n ')
    return ''.join(entries)
273
274
def generate_py_op_deprecated(deprecated):
    """
    Generate the ``@deprecated(...)`` decorator line placed above a Primitive's ``__init__``.

    Args:
        deprecated (dict): The ``deprecated`` yaml entry with ``version``, ``substitute`` and
            ``use_substitute`` keys, or None.

    Returns:
        str: The decorator source line, or '' when ``deprecated`` is None.

    Raises:
        ValueError: If a required key is missing/None or ``use_substitute`` is not a bool.
    """
    if deprecated is None:
        return ''
    fields = {}
    for key in ("version", "substitute", "use_substitute"):
        value = deprecated.get(key)
        if value is None:
            raise ValueError(f"The {key} of deprecated can't be None.")
        fields[key] = value
    use_substitute = fields["use_substitute"]
    if not isinstance(use_substitute, bool):
        raise ValueError(f"The use_substitute must be True or False, but got {use_substitute}")
    return f'    @deprecated("{fields["version"]}", "{fields["substitute"]}", {use_substitute})\n'
295
296
297def _normalize_func_description_fromat(description):
298    """
299    Process description.
300    """
301    if not description:
302        return description
303    lines = description.split("\n")
304    if len(lines) == 1:
305        return description
306    # Add line indentation to other lines after the first line
307    for i in range(1, len(lines)):
308        indent = "    " if lines[i] else ""
309        lines[i] = indent + lines[i]
310    # Remove trailing blank lines
311    lines = lines if lines[-1] != "" else lines[:-1]
312    description = "\n".join(lines)
313    return description
314
315
316def _get_op_description(operator_name, doc_str):
317    """
318    Generate ops api description.
319    """
320    if doc_str is None:
321        print(f"Description is None, op_name: {operator_name}")
322        return ""
323    description = doc_str.get(operator_name)
324    if description is None:
325        print(f"Description is None, op_name: {operator_name}")
326        return ""
327    description = description.get("description")
328    if description is None:
329        print(f"Description is None, op_name: {operator_name}")
330        return ""
331    return _normalize_func_description_fromat(description)
332
333
def generate_py_op_func(yaml_data, doc_data):
    """
    Generate the python functional API of every operator whose function is not disabled.

    Args:
        yaml_data (dict): Operator definitions keyed by yaml op name.
        doc_data (dict): Operator docs keyed by yaml op name, or None.

    Returns:
        str: Python source with one function per operator. The function body depends on
        the op: no prim_init args -> call the module-level ``<op>_op`` instance (created
        here when class generation is disabled); prim_init args with dispatch enabled ->
        call ``<op>_impl``; otherwise build a cached primitive via ``_get_cache_prim``.
    """
    gen_py = ''
    for operator_name, operator_data in yaml_data.items():
        if _auto_generate_func_disabled(operator_data):
            continue
        func_name = _get_op_func_name(operator_name, operator_data)
        args = operator_data.get('args')
        class_name = _get_op_name(operator_name, operator_data)

        func_args = []
        prim_init_args = []
        prim_call_args = []
        for arg_name, arg_info in args.items():
            if 'default' in arg_info:
                func_args.append(f"{arg_name}={arg_info.get('default')}")
            else:
                func_args.append(f"{arg_name}")
            # prim_init args go to the primitive constructor, the rest to its call.
            target = prim_init_args if arg_info.get('prim_init') else prim_call_args
            target.append(arg_name)

        description = _get_op_description(operator_name, doc_data)
        header = (f"\n\ndef {func_name}({', '.join(func_args)}):\n"
                  '    r"""\n'
                  f"    {description}\n"
                  '    """\n')
        call_list = ', '.join(prim_call_args)

        if not prim_init_args:
            if _auto_generate_class_disabled(operator_data):
                gen_py += f"\n{operator_name}_op={class_name}()"
            body = f"    return {operator_name}_op({call_list})\n"
        else:
            dispatch = operator_data.get("dispatch")
            if dispatch is not None and dispatch.get("enable"):
                body = f"    return {operator_name}_impl({', '.join(args)})\n"
            else:
                body = (f"    {operator_name}_op = _get_cache_prim({class_name})({', '.join(prim_init_args)})\n"
                        f"    return {operator_name}_op({call_list})\n")
        gen_py += header + body

    return gen_py
399
400
def get_dtype(arg_info):
    """
    Return the yaml ``dtype`` of an argument, mapping ``TypeId`` to ``int``.

    TypeId values are currently represented by plain ints.
    """
    dtype = arg_info.get('dtype')
    return 'int' if dtype == 'TypeId' else dtype
407
408
def process_args(class_name, args):
    """
    Split an operator's yaml args into primitive init args and call inputs.

    Args:
        class_name (str): Primitive class name, forwarded to the assign-string helper.
        args (dict): Mapping of arg name -> yaml arg info, in declaration order.

    Returns:
        tuple: ``(inputs_name, inputs_default, args_name, args_assign,
        init_args_with_default, args_handlers)``:
        input names; input name -> default; prim_init arg names; ``__init__`` body lines
        setting each prim_init arg; ``__init__`` parameter strings (with ``=default`` when
        given); input name -> arg_handler name.
    """
    inputs_name, args_name, args_assign, init_args_with_default = [], [], [], []
    inputs_default, args_handlers = {}, {}
    for arg_name, arg_info in args.items():
        dtype = get_dtype(arg_info)
        has_default = 'default' in arg_info
        arg_handler = arg_info.get('arg_handler')

        if not arg_info.get('prim_init'):
            # Call input: record its name, default value and arg handler.
            inputs_name.append(arg_name)
            if has_default:
                inputs_default[arg_name] = arg_info.get('default')
            if arg_handler:
                args_handlers[arg_name] = arg_handler
            continue

        # Primitive init arg: part of __init__'s signature, stored via _set_prim_arg*.
        args_name.append(arg_name)
        init_args_with_default.append(
            f"{arg_name}={arg_info.get('default')}" if has_default else f"{arg_name}")
        value_expr = gen_utils.get_assign_str_by_type_it(class_name, arg_info, arg_name, dtype)
        if arg_handler:
            args_assign.append(f'        self._set_prim_arg_with_handler("{arg_name}", {value_expr}, {arg_handler})')
        else:
            args_assign.append(f'        self._set_prim_arg("{arg_name}", {value_expr})')

    return inputs_name, inputs_default, args_name, args_assign, init_args_with_default, args_handlers
457
458
def generate_pyboost_import_header(yaml_data):
    """
    Generate ``from mindspore._c_expression import <pyboost name>`` lines.

    One line is emitted for every operator for which ``is_pyboost_enable`` is true, in
    yaml order.
    """
    import_template = CppTemplate("from mindspore._c_expression import $var\n")
    return ''.join(import_template.replace(var=get_pyboost_name(name))
                   for name, data in yaml_data.items() if is_pyboost_enable(data))
471
472
473def _generate_class_description(class_name, func_name, input_args, init_args, func_disabled, doc_str):
474    """Generate description for every primitive definition."""
475    if func_disabled:
476        # if function disabled, function name is equal to operator_name
477        description = _get_op_description(func_name, doc_str)
478        description = f"""    r\"\"\"
479    {description}
480    \"\"\"
481"""
482        return description
483
484    # If function is an released API, refer to the function doc.
485    description_str = f"""    r\"\"\"
486    .. code-block::
487
488        prim = ops.{class_name}({', '.join(init_args)})
489        out = prim({', '.join(input_args)})
490
491    is equivalent to
492
493    .. code-block::
494
495        ops.{func_name}({", ".join(input_args + init_args)})
496
497    Refer to :func:`mindspore.ops.{func_name}` for more details.
498    \"\"\"
499"""
500    return description_str
501
502
def get_init_code(init_code, operator_data):
    """
    Append ``self.add_prim_attr`` calls for the op's ``labels`` to the ``__init__`` body.

    Args:
        init_code (str): Existing ``__init__`` body lines (may be empty).
        operator_data (dict): The operator's yaml definition.

    Returns:
        str: The ``__init__`` body, or an indented ``pass`` when it would be empty.
    """
    parts = [init_code] if init_code != "" else []
    labels = operator_data.get('labels')
    if labels is not None:
        parts.append('\n'.join(f'        self.add_prim_attr("{key}", {value})' for key, value in labels.items()))
    body = '\n'.join(parts)
    return body if body != "" else "        pass"
516
517
def generate_py_primitive(yaml_data, doc_str):
    """
    Generate python Primitive class definitions for all ops whose class is not disabled.

    Args:
        yaml_data (dict): Operator definitions keyed by yaml op name.
        doc_str (dict): Operator docs keyed by yaml op name, or None.

    Returns:
        str: Python source with one ``Primitive`` subclass per op. Ops without prim_init
        args also get a module-level ``<op>_op`` instance.
    """

    def _generate_arg_handler(class_name, arg, arg_handler, is_optional):
        """Wrap an input in its arg_handler call; optional inputs let None pass through."""
        arg_handler_call = f"""{arg_handler}('{class_name}', '{arg}', {arg})"""
        if is_optional:
            arg_handler_call = f"""{arg} if {arg} is None else {arg_handler_call}"""
        return arg_handler_call

    gen_py = ''
    for operator_name, operator_data in yaml_data.items():
        if _auto_generate_class_disabled(operator_data):
            continue
        class_name = _get_op_name(operator_name, operator_data)
        func_name = _get_op_func_name(operator_name, operator_data)
        pyboost_func_name = get_pyboost_name(operator_name)
        args = operator_data.get('args')
        inputs_args, inputs_default, init_args, args_assign, init_args_with_default, args_handlers = \
            process_args(class_name, args)
        init_code = '\n'.join(args_assign)
        signature_code = generate_py_op_signature(class_name, operator_data.get('args_signature'), inputs_args,
                                                  inputs_default)
        deprecated_code = generate_py_op_deprecated(operator_data.get('deprecated'))
        init_code = get_init_code(init_code, operator_data)
        primitive_code = f"""\n
class {class_name}(Primitive):\n"""
        func_disabled = _auto_generate_func_disabled(operator_data)
        primitive_code += _generate_class_description(class_name, func_name, inputs_args, init_args, func_disabled,
                                                      doc_str)
        if signature_code != "":
            primitive_code += signature_code
        if deprecated_code != "":
            primitive_code += deprecated_code
        primitive_code += """    @prim_arg_register
    def __init__(self"""
        if init_args_with_default:
            primitive_code += ", " + ', '.join(init_args_with_default)
        call_args = [f"{name}={inputs_default[name]}" if name in inputs_default else name
                     for name in inputs_args]
        primitive_code += f"""):
{init_code}

    def __call__(self, {', '.join(call_args)}):"""
        is_pyboost = is_pyboost_enable(operator_data)
        if is_pyboost:
            primitive_code += f"""
          return _convert_stub({pyboost_func_name}(self, ["""
        else:
            primitive_code += """
          return super().__call__("""

        # Inputs (wrapped by their arg handler, if any) followed by the stored init args.
        call_values = []
        for arg in inputs_args:
            if arg in args_handlers:
                is_optional = inputs_default.get(arg) == "None"
                call_values.append(_generate_arg_handler(class_name, arg, args_handlers[arg], is_optional))
            else:
                call_values.append(arg)
        call_values.extend(f'self.{arg}' for arg in init_args)
        # Joining a single list avoids a dangling leading ", " (invalid generated Python)
        # for ops that have init args but no inputs.
        primitive_code += ', '.join(call_values)

        if is_pyboost:
            primitive_code += """]))"""
        else:
            primitive_code += """)
"""

        gen_py += primitive_code
        if not init_args:
            prim_op_object = f"""\n
{operator_name}_op={class_name}()
"""
            gen_py += prim_op_object
    return gen_py
598
599
def generate_op_name_opdef(yaml_data):
    """
    Generate the C++ header declaring a ``kName<Op>`` string constant for every operator.
    """
    lines = [
        "",
        "#ifndef MINDSPORE_CORE_OP_NAME_H_",
        "#define MINDSPORE_CORE_OP_NAME_H_",
        "",
        "namespace mindspore::ops {",
    ]
    for operator_name, operator_data in yaml_data.items():
        k_name_op = _get_op_name(operator_name, operator_data)
        lines.append(f'constexpr auto kName{k_name_op} = "{k_name_op}";')
    lines += [
        "}  // namespace mindspore::ops",
        "",
        "#endif  // MINDSPORE_CORE_OP_NAME_H_",
    ]
    return "\n".join(lines) + "\n"
625
626
def generate_op_prim_opdef(yaml_data):
    """
    Generate the C++ header defining a ``kPrim<Op>`` PrimitivePtr global for every operator.
    """
    header_lines = [
        "",
        "#ifndef MINDSPORE_CORE_OPS_GEN_OPS_PRIMITIVE_H_",
        "#define MINDSPORE_CORE_OPS_GEN_OPS_PRIMITIVE_H_",
        "",
        "#include <memory>",
        '#include "ir/anf.h"',
        '#include "ir/primitive.h"',
        '#include "ops/auto_generate/gen_ops_name.h"',
        '#include "mindapi/base/macros.h"',
        "",
        "namespace mindspore::prim {",
    ]
    prim_lines = [
        f"GVAR_DEF(PrimitivePtr, kPrim{name}, std::make_shared<Primitive>(ops::kName{name}))"
        for name in (_get_op_name(op, data) for op, data in yaml_data.items())
    ]
    footer_lines = [
        "}  // namespace mindspore::prim",
        "#endif  // MINDSPORE_CORE_OPS_GEN_OPS_PRIMITIVE_H_",
    ]
    return "\n".join(header_lines + prim_lines + footer_lines) + "\n"
656
657
def generate_lite_ops(yaml_data):
    """
    Generate lite BaseOperator classes with a setter/getter for every prim_init argument.

    Args:
        yaml_data (dict): Operator definitions keyed by yaml op name.

    Returns:
        tuple[str, str]: The generated header and source file contents.
    """
    # yaml dtype -> C++ attribute type; dtypes not listed here are emitted verbatim.
    cpp_dtype_map = {
        "str": "std::string",
        "tuple[str]": "std::vector<std::string>",
        "list[str]": "std::vector<std::string>",
        "tuple[int]": "std::vector<int64_t>",
        "list[int]": "std::vector<int64_t>",
        "tuple[float]": "std::vector<float>",
        "list[float]": "std::vector<float>",
        "tuple[bool]": "std::vector<bool>",
        "list[bool]": "std::vector<bool>",
        "int": "int64_t",
    }

    lite_ops_h_head = """
#ifndef MINDSPORE_CORE_OPS_GEN_LITE_OPS_H_
#define MINDSPORE_CORE_OPS_GEN_LITE_OPS_H_

#include <vector>
#include "ops/base_operator.h"
#include "ops/auto_generate/gen_ops_name.h"

namespace mindspore::ops {
"""

    lite_ops_h_end = """}  // namespace mindspore::ops
#endif  // MINDSPORE_CORE_OPS_GEN_LITE_OPS_H_
"""

    lite_ops_cc_head = """
#include "ops/auto_generate/gen_lite_ops.h"
#include "mindapi/src/helper.h"
#include "ops/primitive_c.h"
#include "ops/base_operator.h"
#include "abstract/abstract_value.h"

namespace mindspore::ops {
"""

    lite_ops_cc_end = "}  // namespace mindspore::ops\n    "

    lite_ops_h_gen = lite_ops_h_head
    lite_ops_cc_gen = lite_ops_cc_head
    for operator_name, operator_data in yaml_data.items():
        op_name = _get_op_name(operator_name, operator_data)
        lite_ops_h_gen += f"""class MIND_API {op_name} : public BaseOperator {{
 public:
  MIND_API_BASE_MEMBER({op_name});
  {op_name}() : BaseOperator(kName{op_name}) {{}}\n"""
        for arg_name, arg_info in operator_data.get('args').items():
            # Only prim_init args become attributes with setters/getters.
            if not arg_info.get('prim_init'):
                continue
            dtype = get_dtype(arg_info)
            dtype = cpp_dtype_map.get(dtype, dtype)
            lite_ops_h_gen += f"  void set_{arg_name}(const {dtype} &{arg_name});\n"
            lite_ops_h_gen += f"  {dtype} get_{arg_name}() const;\n"

            lite_ops_cc_gen += (f"void {op_name}::set_{arg_name}(const {dtype} &{arg_name}) "
                                f'            {{ (void)this->AddAttr("{arg_name}", api::MakeValue({arg_name})); }}\n\n')
            lite_ops_cc_gen += (f"{dtype} {op_name}::get_{arg_name}() const "
                                f'            {{ return GetValue<{dtype}>(GetAttr("{arg_name}")); }}\n\n')

        lite_ops_cc_gen += f"REGISTER_PRIMITIVE_C(kName{op_name}, {op_name});\n"
        lite_ops_cc_gen += f"MIND_API_OPERATOR_IMPL({op_name}, BaseOperator);\n\n"
        lite_ops_h_gen += "};\n\n"
    lite_ops_h_gen += lite_ops_h_end
    lite_ops_cc_gen += lite_ops_cc_end
    return lite_ops_h_gen, lite_ops_cc_gen
735
736
def generate_cc_opdef(yaml_data):
    """
    Generate the C++ OpDef definitions and the global ``gOpDefTable`` for all operators.

    Args:
        yaml_data (dict): Operator definitions keyed by yaml op name.

    Returns:
        str: The generated source: includes, one OpDef per operator (rendered with
        ``template.OP_PROTO_TEMPLATE``) and the class-name -> OpDef table.

    Raises:
        ValueError: If a return is declared ``inplace`` of an argument the op does not have.
    """
    def _to_cc_dtype(dtype):
        """Convert a yaml dtype such as ``tuple[int]`` to its C++ enum name ``DT_TUPLE_INT``."""
        return 'DT_' + dtype.replace('[', '_').replace(']', '').upper()

    gen_cc_code = '\n\nnamespace mindspore::ops {'
    gen_opdef_map = '\nstd::unordered_map<std::string, OpDefPtr> gOpDefTable = {'
    gen_include = ('\n\n#include "ops/auto_generate/gen_ops_def.h"'
                   '\n#include "mindspore/core/ir/signature.h"')

    for operator_name, operator_data in yaml_data.items():
        args = operator_data.get('args')
        class_name = _get_op_name(operator_name, operator_data)
        inputs_args, _, _, _, _, _ = process_args(class_name, args)
        signature_code = generate_cc_op_signature(operator_data.get('args_signature'), inputs_args)
        returns = operator_data.get('returns')
        dispatch = operator_data.get("dispatch")
        # dispatch not defined in yaml or dispatch.enable==False -> "false".
        enable_dispatch_str = "true" if dispatch and dispatch.get("enable") else "false"
        is_view_str = "true" if operator_data.get('view') else "false"

        gen_include += f"""\n#include "ops/ops_func_impl/{operator_name}.h\""""
        gen_opdef_map += f"""\n  {{"{class_name}", &g{class_name}}},"""

        cc_index_str = ''
        input_args_str = ''
        args_dict = {}
        for i, (arg_name, arg_info) in enumerate(args.items()):
            args_dict[arg_name] = i
            cc_index_str += f"""{{"{arg_name}", {i}}},\n"""
            cc_dtype_str = convert_dtype_str(get_dtype(arg_info))

            is_prim_init = 1 if arg_info.get('prim_init') else 0
            arg_handler = arg_info.get('arg_handler')
            arg_handler_str = "" if arg_handler is None else arg_handler

            type_cast = arg_info.get('type_cast')
            type_cast_str = "" if type_cast is None else \
                ', '.join(_to_cc_dtype(cast.strip()) for cast in type_cast.split(","))

            # default: None is regarded as an optional argument.
            is_optional = 'default' in arg_info and arg_info.get('default') == "None"
            is_optional_str = "true" if is_optional else "false"

            input_args_str += f"""\n    {{/*.arg_name_=*/"{arg_name}", /*.arg_dtype_=*/{cc_dtype_str}, """ + \
                              f"""/*.as_init_arg_=*/{is_prim_init}, /*.arg_handler_=*/"{arg_handler_str}", """ + \
                              f"""/*.cast_dtype_ =*/{{{type_cast_str}}}, /*.is_optional_=*/{is_optional_str}}},"""

        # Process outputs.
        return_args_str = ''
        for return_name, return_info in returns.items():
            ref_name = return_info.get('inplace')
            if ref_name is None:
                ref_index_str = -1
            elif ref_name in args_dict:
                ref_index_str = args_dict[ref_name]
            else:
                # Previously this emitted a literal `None` into the generated C++.
                raise ValueError(f"Op {class_name}: return '{return_name}' is declared inplace of "
                                 f"unknown input '{ref_name}'.")
            cc_return_type_str = _to_cc_dtype(return_info.get('dtype'))
            return_args_str += f"""{{/*.arg_name_=*/"{return_name}", /*.arg_dtype_=*/{cc_return_type_str},
            /*.inplace_input_index_=*/{ref_index_str}}},\n"""

        op_def_cc = template.OP_PROTO_TEMPLATE.replace(class_name=class_name, input_args=input_args_str,
                                                       return_args=return_args_str, signatures=signature_code,
                                                       indexes=cc_index_str, enable_dispatch=enable_dispatch_str,
                                                       is_view=is_view_str)
        gen_cc_code += op_def_cc
    gen_opdef_map += '\n};'
    gen_cc_code += gen_opdef_map

    cc_opdef_end = '\n}  // namespace mindspore::ops\n'
    return gen_include + gen_cc_code + cc_opdef_end
822
823
# Header of the generated `*_ops_prim.py` module (Primitive classes).
ops_py_prim_header = "\n".join([
    "",
    '"""Operators definition generated by gen_ops.py, includes primitive classes."""',
    "",
    "from mindspore.ops.primitive import Primitive, prim_arg_register",
    "from mindspore.ops import signature as sig",
    "from mindspore.common import dtype as mstype",
    "from mindspore.common._decorator import deprecated",
    "from mindspore.ops._primitive_cache import _get_cache_prim",
    "from mindspore.ops.auto_generate.gen_arg_dtype_cast import type_it",
    "from mindspore.ops.auto_generate.gen_arg_handler import *",
    "from mindspore._c_expression import OpDtype",
    "from mindspore.common._stub_tensor import _convert_stub",
    "",
])
837
838
839ops_py_def_header = f"""
840\"\"\"Operators definition generated by gen_ops.py, includes functions.\"\"\"
841
842from .gen_ops_prim import *
843from .pyboost_inner_prim import *
844from mindspore.ops.operations.manually_defined.ops_def import *
845from mindspore.ops._primitive_cache import _get_cache_prim
846"""
847
848
def generate_ops_prim_file(work_path, yaml_str, doc_str, file_pre):
    """
    Generate ``ops/auto_generate/{file_pre}_ops_prim.py`` holding the primitive classes.

    The file content is the python licence, ``ops_py_prim_header``, the pyboost import
    header and the generated primitive definitions, in that order. It is written to a
    ``tmp_`` file first, which is then handed to ``check_change_and_replace_file``
    together with the final path.
    """
    rel_dir = 'mindspore/python/mindspore/ops/auto_generate/'
    target_path = os.path.join(work_path, f'{rel_dir}{file_pre}_ops_prim.py')
    staging_path = os.path.join(work_path, f'{rel_dir}tmp_{file_pre}_ops_prim.py')
    # Tuple elements are evaluated left to right, so the import header is built
    # before the primitive definitions, as before.
    content = "".join((py_licence_str,
                       ops_py_prim_header,
                       generate_pyboost_import_header(yaml_str),
                       generate_py_primitive(yaml_str, doc_str)))
    write_file(staging_path, content)
    check_change_and_replace_file(target_path, staging_path)
856
857
def generate_ops_def_file(work_path, yaml_str, doc_str, file_pre):
    """
    Generate ``ops/auto_generate/{file_pre}_ops_def.py`` holding the functional op wrappers.

    The file content is the python licence, ``ops_py_def_header`` and the output of
    ``generate_py_op_func``. It is staged in a ``tmp_`` file, which is then passed to
    ``check_change_and_replace_file`` together with the final path.
    """
    rel_dir = 'mindspore/python/mindspore/ops/auto_generate/'
    target_path = os.path.join(work_path, f'{rel_dir}{file_pre}_ops_def.py')
    staging_path = os.path.join(work_path, f'{rel_dir}tmp_{file_pre}_ops_def.py')
    func_code = generate_py_op_func(yaml_str, doc_str)
    write_file(staging_path, "".join((py_licence_str, ops_py_def_header, func_code)))
    check_change_and_replace_file(target_path, staging_path)
864
865
def generate_ops_py_files(work_path, yaml_str, doc_str, file_pre):
    """
    Generate both python op files for ``file_pre``.

    The primitive-class file is generated first, then the functional (ops_def) file.
    """
    for emit in (generate_ops_prim_file, generate_ops_def_file):
        emit(work_path, yaml_str, doc_str, file_pre)
872
873
def generate_ops_cc_files(work_path, yaml_str):
    """
    Generate the C++ op sources under ``mindspore/core/ops/auto_generate``.

    Files produced, in this order: ``gen_ops_def.cc``, ``gen_ops_primitive.h``,
    ``gen_lite_ops.h``, ``gen_lite_ops.cc`` and ``gen_ops_name.h``. Each gets the C++
    licence prepended, is written to a ``tmp_``-prefixed file, and is then passed to
    ``check_change_and_replace_file`` with its final path.
    """
    rel_dir = 'mindspore/core/ops/auto_generate/'

    def _emit(file_name, code):
        # Stage into tmp_<file_name>, then let the shared helper update the real file.
        final_path = os.path.join(work_path, rel_dir + file_name)
        staging_path = os.path.join(work_path, rel_dir + 'tmp_' + file_name)
        write_file(staging_path, cc_license_str + code)
        check_change_and_replace_file(final_path, staging_path)

    _emit('gen_ops_def.cc', generate_cc_opdef(yaml_str))
    _emit('gen_ops_primitive.h', generate_op_prim_opdef(yaml_str))

    # The lite header and source come from a single generator call.
    lite_h_code, lite_cc_code = generate_lite_ops(yaml_str)
    _emit('gen_lite_ops.h', lite_h_code)
    _emit('gen_lite_ops.cc', lite_cc_code)

    _emit('gen_ops_name.h', generate_op_name_opdef(yaml_str))
911
912
def generate_op_labels(yaml_data):
    """
    Generate python source defining the ``op_labels`` dict.

    Every op that has a ``labels`` entry that is not None (an empty mapping included)
    contributes one line of the form ``"<ClassName>": {"<label>": <value>, ...},``.
    Label values are inserted verbatim as python source text. Ops without ``labels``
    are skipped.

    Args:
        yaml_data (dict): Parsed op definitions keyed by op name.

    Returns:
        str: Source text starting with ``op_labels = {`` and ending with ``}``.
    """
    entries = []
    for op_key, op_data in yaml_data.items():
        op_labels = op_data.get('labels')
        if op_labels is None:
            continue
        class_name = _get_op_name(op_key, op_data)
        label_items = ", ".join(f'"{label}": {value}' for label, value in op_labels.items())
        entries.append(f'\n    "{class_name}": {{{label_items}}},')
    return "op_labels = {" + "".join(entries) + "\n}"
929
930
def generate_op_arg_default_value(yaml_data):
    """
    Generate python source defining the ``op_args_default_value`` dict.

    For each op, the args whose yaml ``default`` is not None are collected as
    ``{arg_name: default}`` under the op's class name. Default values are inserted
    verbatim as python source text, so a yaml string ``None`` is emitted as the
    literal ``None``. Ops without any defaulted arg are omitted, and an op with no
    ``args`` section (missing or null) is treated as having no args instead of raising.

    Args:
        yaml_data (dict): Parsed op definitions keyed by op name.

    Returns:
        str: A module docstring and mstype import, followed by the dict definition.
    """
    default_py_header = f"""\"\"\"Operator labels and args default value.\"\"\"
from mindspore.common import dtype as mstype\n\n"""

    gen_default_py = default_py_header + "op_args_default_value = {"
    for operator_name, operator_data in yaml_data.items():
        # An op may declare no args at all (key absent or null in yaml).
        args = operator_data.get('args') or {}
        arg_default_dict = {arg_name: arg_info.get('default')
                            for arg_name, arg_info in args.items()
                            if arg_info.get('default') is not None}
        if not arg_default_dict:
            continue
        class_name = _get_op_name(operator_name, operator_data)
        defaults_str = ", ".join(f'"{key}": {value}' for key, value in arg_default_dict.items())
        gen_default_py += f'\n    "{class_name}": {{{defaults_str}}},'
    gen_default_py += "\n}"
    return gen_default_py
955
956
def generate_create_instance_helper_file(work_path, yaml_str):
    """
    Generate the python helper ``ops/auto_generate/cpp_create_prim_instance_helper.py``.

    The file holds the python licence, the ``op_args_default_value`` dict and the
    ``op_labels`` dict. It is staged in a ``tmp_`` file, which is then passed to
    ``check_change_and_replace_file``.
    """
    helper_dir = os.path.join(work_path, 'mindspore/python/mindspore/ops/auto_generate')
    # Labels are generated before defaults, matching the historical call order.
    labels_code = generate_op_labels(yaml_str)
    defaults_code = generate_op_arg_default_value(yaml_str)
    content = "".join([py_licence_str, "\n", defaults_code, "\n\n", labels_code, "\n"])
    staging_path = os.path.join(helper_dir, 'tmp_cpp_create_prim_instance_helper.py')
    write_file(staging_path, content)
    check_change_and_replace_file(os.path.join(helper_dir, 'cpp_create_prim_instance_helper.py'), staging_path)
968
969
def generate_aclnn_reg_code(yaml_data):
    """
    Generate the C++ source that registers aclnn KernelMods for dispatchable ops.

    Ops are processed only when ``dispatch.enable`` is truthy and ``dispatch.Ascend`` is
    not set (a KernelMod named there is provided by yaml and is not auto generated).
    For each such op:
      * if the third value returned by ``get_dtypes`` (``none_tensor_exist``) is truthy,
        ``gen_aclnn_kernel`` is called for the op and no registration line is emitted;
      * otherwise a ``MS_ACLNN_COMMON_KERNEL_FACTORY_REG(class, aclnn_api, n)`` line is
        emitted, where ``n`` is the number of yaml args plus returns.

    Args:
        yaml_data (dict): Parsed op definitions keyed by op name. The same mapping is
            forwarded to ``gen_aclnn_kernel`` (previously ``ops.yaml`` was re-read from
            disk here, which was redundant and could diverge from the caller's data).

    Returns:
        str: C++ source text without the licence header.
    """
    reg_code = f"""
#include "plugin/device/ascend/kernel/opapi/aclnn_kernel_mod.h"

namespace mindspore {{
namespace kernel {{
"""
    for operator_name, operator_data in yaml_data.items():
        dispatch = operator_data.get("dispatch")
        if not dispatch or not dispatch.get("enable"):
            continue
        if dispatch.get("Ascend") is not None:  # KernelMod is provided by yaml, don't auto generate it.
            continue
        _, _, none_tensor_exist = get_dtypes(operator_data)
        if none_tensor_exist:
            gen_aclnn_kernel(operator_name, yaml_data, auto=True)
            continue
        # Same naming rule as the python/C++ op definitions (class.name overrides CamelCase).
        class_name = _get_op_name(operator_name, operator_data)
        inputs_outputs_num = (len(operator_data.get("args") or {}) +
                              len(operator_data.get("returns") or {}))
        aclnn_name = AclnnUtils.get_aclnn_interface(class_name)
        reg_code += f"""
MS_ACLNN_COMMON_KERNEL_FACTORY_REG({class_name}, {aclnn_name}, {inputs_outputs_num});"""
    reg_code += f"""
}}  // namespace kernel
}}  // namespace mindspore
"""
    return reg_code
1007
1008
def generate_aclnn_reg_file(work_path, yaml_str):
    """
    Generate the aclnn KernelMod register source ``aclnn_kernel_register_auto.cc``.

    The code from ``generate_aclnn_reg_code`` gets the C++ licence prepended and is
    written to ``tmp_aclnn_kernel_register.cc``. That file is then passed to
    ``check_change_and_replace_file`` with the final path.

    Args:
        work_path (str): Base directory that the ``mindspore/...`` paths are joined onto.
            It no longer needs a trailing separator, because ``os.path.join`` is used
            instead of string concatenation.
        yaml_str (dict): Parsed op definitions keyed by op name.
    """
    opapi_dir = os.path.join(work_path, 'mindspore/ccsrc/plugin/device/ascend/kernel/opapi')
    tmp_register_file = os.path.join(opapi_dir, 'tmp_aclnn_kernel_register.cc')
    register_file = os.path.join(opapi_dir, 'aclnn_kernel_register_auto.cc')
    reg_code = generate_aclnn_reg_code(yaml_str)
    write_file(tmp_register_file, cc_license_str + reg_code)
    check_change_and_replace_file(register_file, tmp_register_file)
1018
1019
def generate_arg_handler_files(work_path):
    """
    Copy the arg handler and arg dtype cast helpers into ``ops/auto_generate``.

    ``ops_generate/arg_handler.py`` is published as ``gen_arg_handler.py`` and
    ``ops_generate/arg_dtype_cast.py`` as ``gen_arg_dtype_cast.py``. Each source is
    copied to a ``tmp_`` file, which is then passed to ``check_change_and_replace_file``
    with the final path.

    Args:
        work_path (str): Base directory that the ``mindspore/...`` paths are joined onto.
    """
    dst_dir = os.path.join(work_path, 'mindspore/python/mindspore/ops/auto_generate')
    src_arg_handler_path = os.path.join(work_path, 'mindspore/python/mindspore/ops_generate/arg_handler.py')
    dst_arg_handler_path = os.path.join(dst_dir, 'gen_arg_handler.py')
    tmp_dst_arg_handler_path = os.path.join(dst_dir, 'tmp_gen_arg_handler.py')
    # exist_ok avoids the race between an exists() check and makedirs() when another
    # build step creates the directory concurrently.
    os.makedirs(dst_dir, exist_ok=True)
    shutil.copy(src_arg_handler_path, tmp_dst_arg_handler_path)
    check_change_and_replace_file(dst_arg_handler_path, tmp_dst_arg_handler_path)

    src_arg_dtype_cast_path = os.path.join(work_path, 'mindspore/python/mindspore/ops_generate/arg_dtype_cast.py')
    dst_arg_dtype_cast_path = os.path.join(dst_dir, 'gen_arg_dtype_cast.py')
    tmp_arg_dtype_cast_path = os.path.join(dst_dir, 'tmp_arg_dtype_cast.py')
    shutil.copy(src_arg_dtype_cast_path, tmp_arg_dtype_cast_path)
    check_change_and_replace_file(dst_arg_dtype_cast_path, tmp_arg_dtype_cast_path)
1038
1039
def main():
    """
    Merge the op yaml fragments and generate every auto-generated source from them.

    Steps, in order:
      1. merge ``ops_def/*op.yaml`` into ``ops_generate/ops.yaml``, append
         ``ops_def/infer/*op.yaml`` to it, and merge ``ops_def/doc/*doc.yaml`` into
         ``ops_generate/ops_doc.yaml``;
      2. make sure ``mindspore/core/ops/auto_generate/`` exists;
      3. copy the arg handler helpers;
      4. load both merged yaml files;
      5. emit the python op files, C++ op files, prim instance helper, pyboost code
         and the aclnn kernel register.
    """
    work_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../../../')

    ops_generate_dir = os.path.join(work_path, 'mindspore/python/mindspore/ops_generate/')
    ops_yaml_path = os.path.join(ops_generate_dir, 'ops.yaml')
    doc_yaml_path = os.path.join(ops_generate_dir, 'ops_doc.yaml')

    ops_def_dir = os.path.join(work_path, 'mindspore/core/ops/ops_def/')
    merge_files(ops_def_dir, ops_yaml_path, '*op.yaml')
    merge_files_append(os.path.join(ops_def_dir, 'infer/'), ops_yaml_path, '*op.yaml')
    merge_files(os.path.join(ops_def_dir, 'doc/'), doc_yaml_path, '*doc.yaml')

    auto_generate_cc_dir = os.path.join(work_path, 'mindspore/core/ops/auto_generate/')
    pathlib.Path(auto_generate_cc_dir).mkdir(parents=True, exist_ok=True)

    generate_arg_handler_files(work_path)

    ops_yaml_str = safe_load_yaml(ops_yaml_path)
    doc_yaml_str = safe_load_yaml(doc_yaml_path)

    generate_ops_py_files(work_path, ops_yaml_str, doc_yaml_str, "gen")
    generate_ops_cc_files(work_path, ops_yaml_str)
    generate_create_instance_helper_file(work_path, ops_yaml_str)
    gen_pyboost_code(work_path, ops_yaml_str, doc_yaml_str)
    generate_aclnn_reg_file(work_path, ops_yaml_str)
1077
1078
if __name__ == "__main__":
    try:
        main()
    # pylint: disable=broad-except
    except Exception as e:
        logging.critical("Auto generate failed, err info: %s", e)
        # Re-raise so the process exits non-zero with a traceback. Otherwise the build
        # would carry on with stale or missing generated files.
        raise
1085