• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# coding=utf-8
2#
3# Copyright (c) 2025 Huawei Device Co., Ltd.
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16import re
17
18from taihe.codegen.abi.analyses import (
19    GlobFuncABIInfo,
20    IfaceABIInfo,
21    PackageABIInfo,
22    TypeABIInfo,
23)
24from taihe.codegen.abi.writer import CHeaderWriter, CSourceWriter
25from taihe.codegen.cpp.analyses import (
26    IfaceMethodCppInfo,
27    PackageCppInfo,
28    TypeCppInfo,
29)
30from taihe.semantics.declarations import (
31    GlobFuncDecl,
32    IfaceDecl,
33    IfaceMethodDecl,
34    PackageDecl,
35    PackageGroup,
36)
37from taihe.semantics.types import (
38    IfaceType,
39)
40from taihe.utils.analyses import AbstractAnalysis, AnalysisManager
41from taihe.utils.outputs import FileKind, OutputManager
42
43
class PackageCppImplInfo(AbstractAnalysis[PackageDecl]):
    """File names of the C++ implementation scaffolding for one package.

    Attributes:
        header: ``<package>.impl.hpp``, the header that holds the per-function
            export macros.
        source: ``<package>.impl.cpp``, the implementation template file.
    """

    def __init__(self, am: AnalysisManager, p: PackageDecl) -> None:
        super().__init__(am, p)
        stem = p.name
        self.header = f"{stem}.impl.hpp"
        self.source = f"{stem}.impl.cpp"
49
50
class GlobFuncCppImplInfo(AbstractAnalysis[GlobFuncDecl]):
    """Naming information for a global function's C++ export macro.

    Attributes:
        macro: ``TH_EXPORT_CPP_API_<function name>``, the name of the macro
            that wires a user-written C++ implementation to the ABI symbol.
    """

    def __init__(self, am: AnalysisManager, f: GlobFuncDecl) -> None:
        super().__init__(am, f)
        macro_prefix = "TH_EXPORT_CPP_API_"
        self.macro = f"{macro_prefix}{f.name}"
55
56
class CppImplHeadersGenerator:
    """Write ``include/<package>.impl.hpp`` headers with export macros.

    For every global function of a package the header gets one
    ``#define TH_EXPORT_CPP_API_<name>(CPP_FUNC_IMPL)`` macro which expands
    to the ABI-mangled function definition forwarding to ``CPP_FUNC_IMPL``.
    """

    def __init__(self, om: OutputManager, am: AnalysisManager):
        self.om = om
        self.am = am

    def generate(self, pg: PackageGroup):
        """Write one implementation header for each package in ``pg``."""
        for package in pg.packages:
            self.gen_package_file(package)

    def gen_package_file(self, pkg: PackageDecl):
        """Write the implementation header for a single package.

        Includes ``taihe/common.hpp``, the package ABI header, and the
        ``impl_headers`` of every parameter/return type, then one macro per
        global function.
        """
        abi_info = PackageABIInfo.get(self.am, pkg)
        impl_info = PackageCppImplInfo.get(self.am, pkg)
        with CHeaderWriter(
            self.om,
            f"include/{impl_info.header}",
            FileKind.CPP_HEADER,
        ) as target:
            target.add_include("taihe/common.hpp")
            target.add_include(abi_info.header)
            for func in pkg.functions:
                # Types used in the signature must be declared before the
                # macro is expanded, so pull in their impl headers first.
                signature_refs = [param.ty_ref for param in func.params]
                if func.return_ty_ref:
                    signature_refs.append(func.return_ty_ref)
                for ty_ref in signature_refs:
                    cpp_info = TypeCppInfo.get(self.am, ty_ref.resolved_ty)
                    target.add_include(*cpp_info.impl_headers)
                self.gen_func(func, target)

    def gen_func(
        self,
        func: GlobFuncDecl,
        pkg_cpp_impl_target: CHeaderWriter,
    ):
        """Emit the ``#define`` export macro for one global function.

        The macro parameter ``CPP_FUNC_IMPL`` names the user's C++
        implementation.  The expansion defines the ABI-mangled function,
        wraps each argument with the type's ``pass_from_abi`` form, and
        returns the call (wrapped with ``return_into_abi`` when the function
        has a return type, returned as-is for ``void``).
        """
        abi_info = GlobFuncABIInfo.get(self.am, func)
        impl_info = GlobFuncCppImplInfo.get(self.am, func)
        impl_param = "CPP_FUNC_IMPL"

        forwarded_args: list[str] = []
        abi_signature: list[str] = []
        for param in func.params:
            resolved = param.ty_ref.resolved_ty
            cpp_ty = TypeCppInfo.get(self.am, resolved)
            abi_ty = TypeABIInfo.get(self.am, resolved)
            forwarded_args.append(cpp_ty.pass_from_abi(param.name))
            abi_signature.append(f"{abi_ty.as_param} {param.name}")

        call_expr = f"{impl_param}({', '.join(forwarded_args)})"

        ret_ref = func.return_ty_ref
        if not ret_ref:
            ret_ty_name = "void"
            ret_expr = call_expr
        else:
            ret_cpp_ty = TypeCppInfo.get(self.am, ret_ref.resolved_ty)
            ret_abi_ty = TypeABIInfo.get(self.am, ret_ref.resolved_ty)
            ret_ty_name = ret_abi_ty.as_owner
            ret_expr = ret_cpp_ty.return_into_abi(call_expr)

        signature_str = ", ".join(abi_signature)
        pkg_cpp_impl_target.writelns(
            f"#define {impl_info.macro}({impl_param}) \\",
            f"    {ret_ty_name} {abi_info.mangled_name}({signature_str}) {{ \\",
            f"        return {ret_expr}; \\",
            "    }",
        )
116        )
117
118
class CppImplSourcesGenerator:
    """Write ``temp/<package>.impl.cpp`` implementation templates.

    Each template contains an anonymous namespace holding a skeleton
    ``<Iface>Impl`` class per interface and a stub per global function,
    followed by one ``TH_EXPORT_CPP_API_<name>(<name>);`` line per global
    function.
    """

    # An optionally ``::``-rooted, optionally namespace-qualified identifier.
    _QUALIFIED_NAME_RE = re.compile(
        r"(::)?([A-Za-z_][A-Za-z_0-9]*::)*[A-Za-z_][A-Za-z_0-9]*"
    )

    def __init__(self, om: OutputManager, am: AnalysisManager):
        self.om = om
        self.am = am
        # Namespaces made visible via ``using namespace``; ``mask`` strips
        # them from qualified names in the emitted code.
        self.using_namespaces: list[str] = []

    @property
    def make_holder(self):
        """``taihe::make_holder`` spelled relative to ``using_namespaces``."""
        return self.mask("taihe::make_holder")

    @property
    def runtime_error(self):
        """``std::runtime_error`` spelled relative to ``using_namespaces``."""
        return self.mask("std::runtime_error")

    def mask(self, cpp_type: str):
        """Shorten qualified names in ``cpp_type`` using ``using_namespaces``.

        Every (optionally ``::``-rooted) qualified identifier in ``cpp_type``
        that starts with ``<ns>::`` or ``::<ns>::`` for some ``ns`` in
        ``self.using_namespaces`` has that prefix removed.  The first
        namespace in list order that matches wins.  Text other than
        identifiers (``<``, ``>``, ``,``, spaces, ...) is left unchanged.
        """

        def replace_ns(match) -> str:
            matched = match.group(0)
            for ns in self.using_namespaces:
                for prefix in (f"{ns}::", f"::{ns}::"):
                    if matched.startswith(prefix):
                        return matched[len(prefix) :]
            return matched

        return self._QUALIFIED_NAME_RE.sub(replace_ns, cpp_type)

    def generate(self, pg: PackageGroup):
        """Write one implementation template for each package in ``pg``."""
        for pkg in pg.packages:
            self.gen_package_file(pkg)

    def gen_package_file(self, pkg: PackageDecl):
        """Write the implementation template for a single package."""
        pkg_cpp_info = PackageCppInfo.get(self.am, pkg)
        pkg_cpp_impl_info = PackageCppImplInfo.get(self.am, pkg)
        with CSourceWriter(
            self.om,
            f"temp/{pkg_cpp_impl_info.source}",
            FileKind.TEMPLATE,
        ) as pkg_cpp_impl_target:
            pkg_cpp_impl_target.add_include(pkg_cpp_info.header)
            pkg_cpp_impl_target.add_include(pkg_cpp_impl_info.header)
            pkg_cpp_impl_target.add_include("taihe/runtime.hpp")
            pkg_cpp_impl_target.add_include("stdexcept")
            pkg_cpp_impl_target.newline()
            # No ``using namespace`` directives are currently emitted, so
            # ``mask`` leaves every name fully qualified.
            self.using_namespaces = []
            # NOTE(review): this second blank line looks like a leftover slot
            # for ``gen_using_namespace`` output; kept to preserve the output.
            pkg_cpp_impl_target.newline()
            self.gen_anonymous_namespace_block(pkg, pkg_cpp_impl_target)
            pkg_cpp_impl_target.newline()
            pkg_cpp_impl_target.writelns(
                "// Since these macros are auto-generate, lint will cause false positive.",
                "// NOLINTBEGIN",
            )
            for func in pkg.functions:
                self.gen_func_macro(func, pkg_cpp_impl_target)
            pkg_cpp_impl_target.writelns(
                "// NOLINTEND",
            )
            self.using_namespaces = []

    def gen_using_namespace(
        self,
        pkg_cpp_impl_target: CSourceWriter,
        namespace: str,
    ):
        """Emit ``using namespace <namespace>;`` and record it for ``mask``."""
        pkg_cpp_impl_target.writelns(
            f"using namespace {namespace};",
        )
        self.using_namespaces.append(namespace)

    def gen_anonymous_namespace_block(
        self,
        pkg: PackageDecl,
        pkg_cpp_impl_target: CSourceWriter,
    ):
        """Emit the anonymous namespace with all interface and function stubs."""
        with pkg_cpp_impl_target.indented(
            "namespace {",
            "}  // namespace",
            indent="",
        ):
            pkg_cpp_impl_target.writelns(
                "// To be implemented.",
            )
            for iface in pkg.interfaces:
                pkg_cpp_impl_target.newline()
                self.gen_iface(iface, pkg_cpp_impl_target)
            for func in pkg.functions:
                pkg_cpp_impl_target.newline()
                self.gen_func_impl(func, pkg_cpp_impl_target)

    def gen_iface(
        self,
        iface: IfaceDecl,
        pkg_cpp_impl_target: CSourceWriter,
    ):
        """Emit a skeleton ``<Name>Impl`` class for an interface.

        The class gets an empty public constructor plus a stub for every
        method of each entry in ``IfaceABIInfo.ancestor_dict`` (presumably
        the interface itself and its ancestors; confirm in the ABI analysis).
        """
        iface_abi_info = IfaceABIInfo.get(self.am, iface)
        impl_name = f"{iface.name}Impl"
        with pkg_cpp_impl_target.indented(
            f"class {impl_name} {{",
            "};",
        ):
            pkg_cpp_impl_target.writelns(
                "public:",
            )
            with pkg_cpp_impl_target.indented(
                f"{impl_name}() {{",
                "}",
            ):
                pkg_cpp_impl_target.writelns(
                    "// Don't forget to implement the constructor.",
                )
            for ancestor in iface_abi_info.ancestor_dict:
                for func in ancestor.methods:
                    pkg_cpp_impl_target.newline()
                    self.gen_method_impl(func, pkg_cpp_impl_target)

    def gen_method_impl(
        self,
        func: IfaceMethodDecl,
        pkg_cpp_impl_target: CSourceWriter,
    ):
        """Emit a stub member function for an interface method."""
        method_cpp_info = IfaceMethodCppInfo.get(self.am, func)
        self._gen_stub_function(
            method_cpp_info.impl_name, func, pkg_cpp_impl_target
        )

    def gen_func_impl(
        self,
        func: GlobFuncDecl,
        pkg_cpp_impl_target: CSourceWriter,
    ):
        """Emit a stub free function for a global function."""
        self._gen_stub_function(func.name, func, pkg_cpp_impl_target)

    def _gen_stub_function(
        self,
        cpp_name: str,
        func: "GlobFuncDecl | IfaceMethodDecl",
        pkg_cpp_impl_target: CSourceWriter,
    ):
        """Emit a placeholder C++ function named ``cpp_name`` for ``func``.

        Functions that return an interface get a ``make_holder`` call for the
        matching ``<Iface>Impl`` class.  Every other function body throws
        ``runtime_error`` via ``TH_THROW`` with a "not implemented" message.
        """
        cpp_params = []
        for param in func.params:
            type_cpp_info = TypeCppInfo.get(self.am, param.ty_ref.resolved_ty)
            cpp_params.append(f"{self.mask(type_cpp_info.as_param)} {param.name}")
        cpp_params_str = ", ".join(cpp_params)
        if return_ty_ref := func.return_ty_ref:
            type_cpp_info = TypeCppInfo.get(self.am, return_ty_ref.resolved_ty)
            cpp_return_ty_name = self.mask(type_cpp_info.as_owner)
        else:
            cpp_return_ty_name = "void"
        with pkg_cpp_impl_target.indented(
            f"{cpp_return_ty_name} {cpp_name}({cpp_params_str}) {{",
            "}",
        ):
            if return_ty_ref and isinstance(return_ty_ref.resolved_ty, IfaceType):
                impl_name = f"{return_ty_ref.resolved_ty.ty_decl.name}Impl"
                pkg_cpp_impl_target.writelns(
                    "// The parameters in the make_holder function should be of the same type",
                    "// as the parameters in the constructor of the actual implementation class.",
                    f"return {self.make_holder}<{impl_name}, {cpp_return_ty_name}>();",
                )
            else:
                pkg_cpp_impl_target.writelns(
                    f'TH_THROW({self.runtime_error}, "{cpp_name} not implemented");',
                )

    def gen_func_macro(
        self,
        func: GlobFuncDecl,
        pkg_cpp_impl_target: CSourceWriter,
    ):
        """Emit the ``TH_EXPORT_CPP_API_<name>(<name>);`` line for ``func``."""
        func_cpp_impl_info = GlobFuncCppImplInfo.get(self.am, func)
        pkg_cpp_impl_target.writelns(
            f"{func_cpp_impl_info.macro}({func.name});",
        )