• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# coding=utf-8
2#
3# Copyright (c) 2025 Huawei Device Co., Ltd.
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16from taihe.codegen.abi.analyses import (
17    GlobFuncABIInfo,
18    PackageABIInfo,
19)
20from taihe.codegen.abi.writer import CHeaderWriter
21from taihe.codegen.cpp.analyses import (
22    GlobFuncCppUserInfo,
23    PackageCppInfo,
24    PackageCppUserInfo,
25    TypeCppInfo,
26)
27from taihe.semantics.declarations import (
28    GlobFuncDecl,
29    PackageDecl,
30    PackageGroup,
31)
32from taihe.utils.analyses import AnalysisManager
33from taihe.utils.outputs import FileKind, OutputManager
34
35
class CppUserHeadersGenerator:
    """Generate one C++ "user" header per package.

    Each generated header, written to ``include/<PackageCppUserInfo.header>``:

    * includes the package's C++ type header, ``taihe/common.hpp``, the
      package's ABI header, and the ``impl_headers`` of every parameter and
      return type used by the package's global functions;
    * defines, for every global function, an ``inline`` C++ wrapper that
      passes each argument through ``TypeCppInfo.pass_into_abi``, calls the
      function's mangled ABI symbol, and (when there is a return type) wraps
      the call with ``TypeCppInfo.return_from_abi``.
    """

    def __init__(self, om: OutputManager, am: AnalysisManager) -> None:
        # Output manager handed to every CHeaderWriter this generator creates.
        self.om = om
        # Analysis manager used to look up the ABI / C++ info of declarations.
        self.am = am

    def generate(self, pg: PackageGroup) -> None:
        """Generate the user header for every package in ``pg``."""
        for pkg in pg.packages:
            self.gen_package_file(pkg)

    def gen_package_file(self, pkg: PackageDecl) -> None:
        """Write the user header for ``pkg``, including all of its functions."""
        pkg_abi_info = PackageABIInfo.get(self.am, pkg)
        pkg_cpp_info = PackageCppInfo.get(self.am, pkg)
        pkg_cpp_user_info = PackageCppUserInfo.get(self.am, pkg)
        with CHeaderWriter(
            self.om,
            f"include/{pkg_cpp_user_info.header}",
            FileKind.CPP_HEADER,
        ) as pkg_cpp_target:
            # types
            pkg_cpp_target.add_include(pkg_cpp_info.header)
            # functions
            pkg_cpp_target.add_include("taihe/common.hpp")
            pkg_cpp_target.add_include(pkg_abi_info.header)
            for func in pkg.functions:
                self._add_func_includes(func, pkg_cpp_target)
                self.gen_func(func, pkg_cpp_target)

    def _add_func_includes(
        self,
        func: GlobFuncDecl,
        pkg_cpp_target: CHeaderWriter,
    ) -> None:
        """Add the impl headers of every parameter type, then the return type, of ``func``."""
        for param in func.params:
            type_cpp_info = TypeCppInfo.get(self.am, param.ty_ref.resolved_ty)
            pkg_cpp_target.add_include(*type_cpp_info.impl_headers)
        if return_ty_ref := func.return_ty_ref:
            type_cpp_info = TypeCppInfo.get(self.am, return_ty_ref.resolved_ty)
            pkg_cpp_target.add_include(*type_cpp_info.impl_headers)

    def gen_func(
        self,
        func: GlobFuncDecl,
        pkg_cpp_target: CHeaderWriter,
    ) -> None:
        """Emit an inline C++ wrapper for the global function ``func``.

        The wrapper is written inside ``namespace <namespace> { ... }`` and is
        named ``GlobFuncCppUserInfo.call_name``. Each parameter is declared
        with its ``as_param`` type. The body is a single ``return`` of the
        mangled ABI call, converted with ``return_from_abi`` when ``func``
        has a return type. Otherwise the wrapper's return type is ``void``.
        """
        func_abi_info = GlobFuncABIInfo.get(self.am, func)
        func_cpp_user_info = GlobFuncCppUserInfo.get(self.am, func)
        params_cpp = []
        args_into_abi = []
        for param in func.params:
            type_cpp_info = TypeCppInfo.get(self.am, param.ty_ref.resolved_ty)
            params_cpp.append(f"{type_cpp_info.as_param} {param.name}")
            args_into_abi.append(type_cpp_info.pass_into_abi(param.name))
        params_cpp_str = ", ".join(params_cpp)
        args_into_abi_str = ", ".join(args_into_abi)
        abi_result = f"{func_abi_info.mangled_name}({args_into_abi_str})"
        if return_ty_ref := func.return_ty_ref:
            type_cpp_info = TypeCppInfo.get(self.am, return_ty_ref.resolved_ty)
            cpp_return_ty_name = type_cpp_info.as_owner
            cpp_result = type_cpp_info.return_from_abi(abi_result)
        else:
            cpp_return_ty_name = "void"
            # Presumably the ABI function also returns void here. C++ allows
            # `return f();` with a void-typed f() inside a void function, so
            # the same `return` line below serves both cases.
            cpp_result = abi_result
        with pkg_cpp_target.indented(
            f"namespace {func_cpp_user_info.namespace} {{",
            "}",
            indent="",
        ):
            with pkg_cpp_target.indented(
                f"inline {cpp_return_ty_name} {func_cpp_user_info.call_name}({params_cpp_str}) {{",
                "}",
            ):
                pkg_cpp_target.writelns(
                    f"return {cpp_result};",
                )