• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2023 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Template."""
16import re
17import os
18
19
class CppTemplate:
    """
    Template for generating C++ (or Python) code from a text pattern.

    Placeholders are written ``$name`` or ``${name}``. The braced form may also
    be written ``${,name}`` and/or ``${name,}``: for a non-empty list value a
    leading ``","`` and/or a trailing ``", "`` is emitted around the joined
    items, and nothing at all is emitted for an empty list.

    A placeholder preceded only by whitespace on its line is expanded in
    block mode: each line of each value item is emitted on its own line,
    prefixed with that leading whitespace, and trailing whitespace at the end
    of the inserted text is stripped (comma markers are ignored in this mode).
    Any other placeholder is expanded inline: a list is joined with ``", "``
    and any other value is converted with ``str``.
    """
    # Group 1: leading whitespace, present only when the placeholder starts a
    # line (its presence selects block mode). Group 2: the bare identifier, or
    # the braced form with an optional leading/trailing comma.
    regular_str = r"(^[^\n\S]*)?\$([^\d\W]\w*|\{,?[^\d\W]\w*\,?})"
    regular_match = re.compile(regular_str, re.MULTILINE)

    def __init__(self, code_pattern):
        # Raw pattern text containing the placeholders.
        self.code_pattern = code_pattern

    @classmethod
    def load_from_file(cls, file_path):
        """
        Create a template from the full contents of a file.

        The file is decoded as UTF-8 explicitly so the result does not depend
        on the platform's default (locale) encoding.

        :param file_path: path of the template file.
        :return: a template whose pattern is the file content.
        """
        with open(file_path, "r", encoding="utf-8") as f:
            return cls(f.read())

    def replace(self, **kwargs):
        """
        Substitute every placeholder in the pattern with its keyword value.

        :param kwargs: placeholder name -> value. ``list`` values are expanded
            as described in the class docstring; other values go through ``str``.
        :return: the generated code.
        :raises TypeError: if a placeholder has no matching keyword argument.
        """

        def find(key: str):
            if key in kwargs:
                return kwargs[key]
            raise TypeError(f"{key} should be in kwargs!")

        def add_indent(indent, var):
            # Put every line of every item on its own line, prefixed with the
            # placeholder's indentation; drop the trailing newline/whitespace.
            return "".join([indent + line + "\n" for data in var for line in str(data).splitlines()]).rstrip()

        def extract_variable(key):
            # Strip the optional braces and comma markers; return the value and
            # the separators to emit before/after a non-empty inline list.
            start = ""
            end = ""
            if key[0] == "{":
                key = key[1:-1]
                if key[0] == ",":
                    start = ","
                    key = key[1:]
                if key[-1] == ",":
                    end = ", "
                    key = key[:-1]
            return find(key), start, end

        def match_rule(match):
            indent = match.group(1)
            key = match.group(2)
            var, start, end = extract_variable(key)
            if indent is not None:
                # Block mode: placeholder is the first non-blank token on its line.
                if not isinstance(var, list):
                    return add_indent(indent, [var])
                return add_indent(indent, var)
            if isinstance(var, list):
                code = ", ".join(str(x) for x in var)
                if not var:
                    # An empty list also suppresses the comma markers.
                    return code
                return start + code + end
            return str(var)

        return self.regular_match.sub(match_rule, self.code_pattern)
79
80
# Newline string constant exported for the code generators.
NEW_LINE = "\n"
# Directory four levels above the one containing this file; the template
# paths in this module ("./mindspore/...") are joined onto it.
WORK_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "../../../../",
)
83
# Python-side pyboost primitive wrapper.
# Placeholders: class_name, input_args, process_func, func_impl_name.
# The rendered class subclasses ``${class_name}Prim_`` and passes the call
# result through ``_convert_stub``. A module-level ``${func_impl_name}_impl``
# instance is created from it. ``${process_func}`` starts its line, so each of
# its lines is re-indented to eight spaces.
PYTHON_PRIM_TEMPLATE = CppTemplate(code_pattern="""

class _Pyboost${class_name}Prim(${class_name}Prim_):
    def __call__(self, ${input_args}):
        ${process_func}
        return _convert_stub(super().__call__(${input_args}))


${func_impl_name}_impl = _Pyboost${class_name}Prim()
""")
94
# Import lines needed by code rendered from PYTHON_PRIM_TEMPLATE (which calls
# ``_convert_stub``); presumably emitted at the top of the generated
# primitive module — confirm against the generator.
# Plain string: the former f-prefix had no placeholders to interpolate.
IMPORT_PYBOOST_PRIM_HEADER = """
from mindspore.common._stub_tensor import _convert_stub
from mindspore.ops.auto_generate.gen_arg_handler import *
"""

# Import lines for the generated pyboost functional-API module (presumably its
# file header — confirm against the generator). Ends with a blank line.
IMPORT_PYBOOST_FUNC_HEADER = """
from mindspore.common import dtype as mstype
from mindspore.ops.auto_generate.pyboost_inner_prim import *

"""
105
# Per-op pybind11 registration snippet. It binds the Python class
# "${class_name}Prim_" to ${class_name}PrimAdapter (``__call__`` -> Call). It
# also exposes mindspore::pynative::${pyboost_cfunc_name} as the module
# function "${pyboost_op_name}". That function's docstring previously read
# "Encrypt the data." for every op; it now names the op like the adapter does.
REGISTER_DEFINE_TEMPLATE = CppTemplate(
    """
    (void)py::class_<${class_name}PrimAdapter, PrimitiveFunctionAdapter, std::shared_ptr<${class_name}PrimAdapter>>(
      *m, "${class_name}Prim_")
      .def(py::init<>())
      .def("__call__", &${class_name}PrimAdapter::Call, "Call ${class_name} op.");
    m->def("${pyboost_op_name}", &mindspore::pynative::${pyboost_cfunc_name}, "Call ${class_name} op.");""")
# Wraps the concatenated registration snippets in RegisterPyBoostFunction.
REGISTER_TEMPLATE = CppTemplate("void RegisterPyBoostFunction(py::module *m) {${register_func}\n}")
114
# One grad-op registration line per op: maps ${pyboost_op_name} to
# mindspore::runtime::${pyboost_cfunc_name} through MS_REG_PYBOOST_GRAD_OP.
REGISTER_PYBOOST_GRAD_DEFINE_TEMPLATE = CppTemplate(
    code_pattern="MS_REG_PYBOOST_GRAD_OP(${pyboost_op_name}, "
                 "mindspore::runtime::${pyboost_cfunc_name});\n"
)
# Pass-through template: renders only the ``register_func`` value.
REGISTER_PYBOOST_GRAD_TEMPLATE = CppTemplate(code_pattern="${register_func}")
118
# Templates read from the pipeline, runtime, ops_generate and kernel template
# directories. Each file is read at import time, so a missing file makes
# importing this module fail.
PYBOOST_FUNCTION_TEMPLATE = CppTemplate.load_from_file(
    file_path=os.path.join(
        WORK_PATH,
        "./mindspore/ccsrc/pipeline/pynative/op_function/template/"
        "pyboost_function.tpl",
    )
)

PYBOOST_HEADER_TEMPLATE = CppTemplate.load_from_file(
    file_path=os.path.join(
        WORK_PATH,
        "./mindspore/ccsrc/pipeline/pynative/op_function/template/"
        "pyboost_function_header.tpl",
    )
)

PYBOOST_GRAD_FUNCTION_TEMPLATE = CppTemplate.load_from_file(
    file_path=os.path.join(
        WORK_PATH,
        "./mindspore/ccsrc/runtime/pynative/op_function/template/"
        "pyboost_grad_function.tpl",
    )
)

PYBOOST_GRAD_HEADER_TEMPLATE = CppTemplate.load_from_file(
    file_path=os.path.join(
        WORK_PATH,
        "./mindspore/ccsrc/runtime/pynative/op_function/template/"
        "pyboost_grad_function_header.tpl",
    )
)

GEN_OPS_DEF_HEADER_TEMPLATE = CppTemplate.load_from_file(
    file_path=os.path.join(
        WORK_PATH,
        "./mindspore/python/mindspore/ops_generate/"
        "gen_ops_def_header.tpl",
    )
)

PYBOOST_BASE_OP_DEFINE_TEMPLATE = CppTemplate.load_from_file(
    file_path=os.path.join(
        WORK_PATH,
        "./mindspore/ccsrc/kernel/pyboost/template/"
        "pyboost_op_header.tpl",
    )
)

PYBOOST_OP_REGISTER_TEMPLATE = CppTemplate.load_from_file(
    file_path=os.path.join(
        WORK_PATH,
        "./mindspore/ccsrc/kernel/pyboost/template/"
        "pyboost_op_register.tpl",
    )
)
140
# Ascend op generate: templates from the Ascend plugin's pyboost template
# directory, plus the shared kernel view template.
PYBOOST_ASCEND_OP_HEADER_TEMPLATE = CppTemplate.load_from_file(
    file_path=os.path.join(
        WORK_PATH,
        "./mindspore/ccsrc/plugin/device/ascend/kernel/pyboost/template/"
        "pyboost_aclnn_header_template.tpl",
    )
)

PYBOOST_ASCEND_OP_SOURCE_TEMPLATE = CppTemplate.load_from_file(
    file_path=os.path.join(
        WORK_PATH,
        "./mindspore/ccsrc/plugin/device/ascend/kernel/pyboost/template/"
        "pyboost_aclnn_source_template.tpl",
    )
)

PYBOOST_ASCEND_CALL_TEMPLATE = CppTemplate.load_from_file(
    file_path=os.path.join(
        WORK_PATH,
        "./mindspore/ccsrc/plugin/device/ascend/kernel/pyboost/template/"
        "pyboost_ascend_call_template.tpl",
    )
)

PYBOOST_ASCEND_VIEW_CALL_TEMPLATE = CppTemplate.load_from_file(
    file_path=os.path.join(
        WORK_PATH,
        "./mindspore/ccsrc/kernel/pyboost/template/"
        "pyboost_view_template.tpl",
    )
)

PYBOOST_ASCEND_CUSTOMIZE_CALL_TEMPLATE = CppTemplate.load_from_file(
    file_path=os.path.join(
        WORK_PATH,
        "./mindspore/ccsrc/plugin/device/ascend/kernel/pyboost/template/"
        "pyboost_ascend_customize_call_template.tpl",
    )
)
163
# GPU op generate: templates from the GPU plugin's pyboost template directory,
# plus the shared kernel view template (same file the other backends load).
PYBOOST_GPU_OP_HEADER_TEMPLATE = CppTemplate.load_from_file(
    file_path=os.path.join(
        WORK_PATH,
        "./mindspore/ccsrc/plugin/device/gpu/kernel/pyboost/template/"
        "pyboost_gpu_header_template.tpl",
    )
)

PYBOOST_GPU_OP_SOURCE_TEMPLATE = CppTemplate.load_from_file(
    file_path=os.path.join(
        WORK_PATH,
        "./mindspore/ccsrc/plugin/device/gpu/kernel/pyboost/template/"
        "pyboost_gpu_source_template.tpl",
    )
)

PYBOOST_GPU_CALL_TEMPLATE = CppTemplate.load_from_file(
    file_path=os.path.join(
        WORK_PATH,
        "./mindspore/ccsrc/plugin/device/gpu/kernel/pyboost/template/"
        "pyboost_gpu_call_template.tpl",
    )
)

PYBOOST_GPU_VIEW_CALL_TEMPLATE = CppTemplate.load_from_file(
    file_path=os.path.join(
        WORK_PATH,
        "./mindspore/ccsrc/kernel/pyboost/template/"
        "pyboost_view_template.tpl",
    )
)

PYBOOST_GPU_CUSTOMIZE_CALL_TEMPLATE = CppTemplate.load_from_file(
    file_path=os.path.join(
        WORK_PATH,
        "./mindspore/ccsrc/plugin/device/gpu/kernel/pyboost/template/"
        "pyboost_gpu_customize_call_template.tpl",
    )
)
185
# CPU op generate: templates from the CPU plugin's pyboost template directory,
# plus the shared kernel view template (same file the other backends load).
PYBOOST_CPU_OP_HEADER_TEMPLATE = CppTemplate.load_from_file(
    file_path=os.path.join(
        WORK_PATH,
        "./mindspore/ccsrc/plugin/device/cpu/kernel/pyboost/template/"
        "pyboost_cpu_header_template.tpl",
    )
)

PYBOOST_CPU_OP_SOURCE_TEMPLATE = CppTemplate.load_from_file(
    file_path=os.path.join(
        WORK_PATH,
        "./mindspore/ccsrc/plugin/device/cpu/kernel/pyboost/template/"
        "pyboost_cpu_source_template.tpl",
    )
)

PYBOOST_CPU_CALL_TEMPLATE = CppTemplate.load_from_file(
    file_path=os.path.join(
        WORK_PATH,
        "./mindspore/ccsrc/plugin/device/cpu/kernel/pyboost/template/"
        "pyboost_cpu_call_template.tpl",
    )
)

PYBOOST_CPU_VIEW_CALL_TEMPLATE = CppTemplate.load_from_file(
    file_path=os.path.join(
        WORK_PATH,
        "./mindspore/ccsrc/kernel/pyboost/template/"
        "pyboost_view_template.tpl",
    )
)

PYBOOST_CPU_CUSTOMIZE_CALL_TEMPLATE = CppTemplate.load_from_file(
    file_path=os.path.join(
        WORK_PATH,
        "./mindspore/ccsrc/plugin/device/cpu/kernel/pyboost/template/"
        "pyboost_cpu_customize_call_template.tpl",
    )
)
207
# One import line pulling ``${class_name}Prim_`` from mindspore._c_expression.
PYBOOST_PY_FUNC_IMPORT_HEADEAR = CppTemplate(
    """from mindspore._c_expression import ${class_name}Prim_\n"""
)
# Correctly spelled alias of the name above. The misspelled original is kept
# so existing importers continue to work.
PYBOOST_PY_FUNC_IMPORT_HEADER = PYBOOST_PY_FUNC_IMPORT_HEADEAR
211
# Python functional-API wrapper. It defines ``${func_name}(${func_args})`` with
# a raw docstring built from ``${description}`` and forwards ``${input_args}``
# to ``${func_impl_name}_impl``. ``${description}`` starts its line, so each of
# its lines is re-indented to four spaces. The rendered text ends with two
# newlines.
PYBOOST_PY_FUNC_TEMPLATE = CppTemplate(code_pattern="""
def ${func_name}(${func_args}):
    r\"\"\"
    ${description}
    \"\"\"
    return ${func_impl_name}_impl(${input_args})\n\n""")
218
# C++ op definition for one op. It declares g${class_name}FuncImpl and fills
# the OpDef g${class_name} fields in order: name, args, returns, signatures,
# indexes, func impl, enable_dispatch and is_view.
# ${input_args}, ${return_args}, ${signatures} and ${indexes} start their
# lines, so list items / lines are each emitted on their own line with that
# four-space indent.
OP_PROTO_TEMPLATE = CppTemplate(code_pattern="""
${class_name}FuncImpl g${class_name}FuncImpl;
OpDef g${class_name} = {
  /*.name_=*/"${class_name}",
  /*.args_=*/ {
    ${input_args}
  },
  /* .returns_ = */ {
    ${return_args}
  },
  /*.signatures_ =*/ {
    ${signatures}
  },
  /*.indexes_ =*/ {
    ${indexes}
  },
  /*.func_impl_=*/g${class_name}FuncImpl,
  /*.enable_dispatch_ =*/${enable_dispatch},
  /*.is_view_ =*/${is_view},
};
""")
240