• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# coding=utf-8
2#
3# Copyright (c) 2025 Huawei Device Co., Ltd.
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16from taihe.codegen.abi.analyses import (
17    GlobFuncABIInfo,
18    PackageABIInfo,
19    TypeABIInfo,
20)
21from taihe.codegen.abi.writer import CHeaderWriter, CSourceWriter
22from taihe.semantics.declarations import (
23    GlobFuncDecl,
24    PackageDecl,
25    PackageGroup,
26)
27from taihe.utils.analyses import AbstractAnalysis, AnalysisManager
28from taihe.utils.outputs import FileKind, OutputManager
29
30
class PackageCImplInfo(AbstractAnalysis[PackageDecl]):
    """File names of the C implementation artifacts derived from a package.

    Attributes:
        header: ``<package name>.impl.h``.
        source: ``<package name>.impl.c``.
    """

    def __init__(self, am: AnalysisManager, p: PackageDecl) -> None:
        super().__init__(am, p)
        # Both artifacts share the "<package name>.impl" stem and differ only
        # in their extension.
        stem = f"{p.name}.impl"
        self.header = f"{stem}.h"
        self.source = f"{stem}.c"
35        self.source = f"{p.name}.impl.c"
36
37
class GlobFuncCImplInfo(AbstractAnalysis[GlobFuncDecl]):
    """Name of the C export macro associated with a global function.

    Attributes:
        macro: ``TH_EXPORT_C_API_<function name>``.
    """

    def __init__(self, am: AnalysisManager, f: GlobFuncDecl) -> None:
        super().__init__(am, f)
        macro_prefix = "TH_EXPORT_C_API_"
        self.macro = f"{macro_prefix}{f.name}"
42
43
class CImplHeadersGenerator:
    """Generate the C implementation header of every package in a group.

    For each package one header is written to ``include/<PackageCImplInfo.header>``.
    It contains one export macro per global function of the package.
    """

    def __init__(self, om: OutputManager, am: AnalysisManager):
        self.om = om
        self.am = am

    def generate(self, pg: PackageGroup):
        """Write the implementation header of each package in ``pg``."""
        for pkg in pg.packages:
            self.gen_package_file(pkg)

    def gen_package_file(self, pkg: PackageDecl):
        """Write the implementation header for ``pkg``.

        The header includes ``taihe/common.h`` and the package ABI header,
        plus the ``impl_headers`` of every parameter type and return type
        used by the package's functions, followed by one export macro per
        function.
        """
        pkg_c_impl_info = PackageCImplInfo.get(self.am, pkg)
        pkg_abi_info = PackageABIInfo.get(self.am, pkg)
        with CHeaderWriter(
            self.om,
            f"include/{pkg_c_impl_info.header}",
            FileKind.C_HEADER,
        ) as pkg_c_impl_target:
            pkg_c_impl_target.add_include("taihe/common.h", pkg_abi_info.header)
            for func in pkg.functions:
                for param in func.params:
                    type_abi_info = TypeABIInfo.get(self.am, param.ty_ref.resolved_ty)
                    pkg_c_impl_target.add_include(*type_abi_info.impl_headers)
                if return_ty_ref := func.return_ty_ref:
                    type_abi_info = TypeABIInfo.get(self.am, return_ty_ref.resolved_ty)
                    pkg_c_impl_target.add_include(*type_abi_info.impl_headers)
                self.gen_func(func, pkg_c_impl_target)

    def gen_func(
        self,
        func: GlobFuncDecl,
        pkg_c_impl_target: CHeaderWriter,
    ):
        """Write the export macro for ``func`` into ``pkg_c_impl_target``.

        The macro (named by ``GlobFuncCImplInfo.macro``) takes the user's
        implementation function as its single argument and expands to the
        definition of the function with the ABI mangled name, which forwards
        all of its parameters to that implementation.
        """
        func_abi_info = GlobFuncABIInfo.get(self.am, func)
        func_c_impl_info = GlobFuncCImplInfo.get(self.am, func)
        func_impl = "C_FUNC_IMPL"
        params = []
        args = []
        for param in func.params:
            type_abi_info = TypeABIInfo.get(self.am, param.ty_ref.resolved_ty)
            params.append(f"{type_abi_info.as_param} {param.name}")
            args.append(param.name)
        params_str = ", ".join(params)
        args_str = ", ".join(args)
        call_expr = f"{func_impl}({args_str})"
        if return_ty_ref := func.return_ty_ref:
            type_abi_info = TypeABIInfo.get(self.am, return_ty_ref.resolved_ty)
            return_ty_name = type_abi_info.as_owner
            forward_stmt = f"return {call_expr};"
        else:
            return_ty_name = "void"
            # ISO C forbids `return <expr>;` inside a function returning void,
            # so a void wrapper only calls the implementation.
            forward_stmt = f"{call_expr};"
        pkg_c_impl_target.writelns(
            f"#define {func_c_impl_info.macro}({func_impl}) \\",
            f"    {return_ty_name} {func_abi_info.mangled_name}({params_str}) {{ \\",
            f"        {forward_stmt} \\",
            "    }",
        )
98
99
class CImplSourcesGenerator:
    """Generate a C source template for every package in a group.

    For each package one file is written to ``temp/<PackageCImplInfo.source>``
    with ``FileKind.TEMPLATE``. It contains a stub function and the matching
    export-macro invocation for each global function of the package.
    """

    def __init__(self, om: OutputManager, am: AnalysisManager):
        self.om = om
        self.am = am

    def generate(self, pg: PackageGroup):
        """Write the source template of each package in ``pg``."""
        for package in pg.packages:
            self.gen_package_file(package)

    def gen_package_file(self, pkg: PackageDecl):
        """Write the source template for ``pkg``.

        The template includes the package's implementation header and then
        one stub per function of the package.
        """
        impl_info = PackageCImplInfo.get(self.am, pkg)
        template_path = f"temp/{impl_info.source}"
        with CSourceWriter(self.om, template_path, FileKind.TEMPLATE) as writer:
            writer.add_include(impl_info.header)
            for func in pkg.functions:
                self.gen_func(func, writer)

    def gen_func(
        self,
        func: GlobFuncDecl,
        pkg_c_impl_target: CSourceWriter,
    ):
        """Write the stub for ``func`` and its export-macro invocation.

        The stub is named ``<function name>_impl``, its body is a single
        ``// TODO`` comment, and it is followed by a line that passes the
        stub to the function's export macro.
        """
        impl_macro_info = GlobFuncCImplInfo.get(self.am, func)
        stub_name = f"{func.name}_impl"

        # One "<type> <name>" entry per parameter, comma separated.
        params_str = ", ".join(
            f"{TypeABIInfo.get(self.am, param.ty_ref.resolved_ty).as_param} {param.name}"
            for param in func.params
        )

        ret_ref = func.return_ty_ref
        if ret_ref:
            return_ty_name = TypeABIInfo.get(self.am, ret_ref.resolved_ty).as_owner
        else:
            return_ty_name = "void"

        opening = f"{return_ty_name} {stub_name}({params_str}) {{"
        with pkg_c_impl_target.indented(opening, "}"):
            pkg_c_impl_target.writelns("// TODO")
        pkg_c_impl_target.writelns(f"{impl_macro_info.macro}({stub_name});")