# coding=utf-8
#
# Copyright (c) 2025 Huawei Device Co., Ltd.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Generators for the C implementation scaffolding of each package.

For every package this module emits:

* ``include/<pkg>.impl.h`` -- one ``TH_EXPORT_C_API_<func>`` macro per global
  function; expanding the macro defines the ABI-mangled function, which
  forwards its arguments to a user-supplied implementation function;
* ``temp/<pkg>.impl.c`` -- a template source containing an empty
  ``<func>_impl`` stub per function plus the matching export-macro invocation.
"""

from taihe.codegen.abi.analyses import (
    GlobFuncABIInfo,
    PackageABIInfo,
    TypeABIInfo,
)
from taihe.codegen.abi.writer import CHeaderWriter, CSourceWriter
from taihe.semantics.declarations import (
    GlobFuncDecl,
    PackageDecl,
    PackageGroup,
)
from taihe.utils.analyses import AbstractAnalysis, AnalysisManager
from taihe.utils.outputs import FileKind, OutputManager


class PackageCImplInfo(AbstractAnalysis[PackageDecl]):
    """File names of the C implementation header and source for a package."""

    def __init__(self, am: AnalysisManager, p: PackageDecl) -> None:
        super().__init__(am, p)
        # Header holding the export macros: "<package name>.impl.h".
        self.header = f"{p.name}.impl.h"
        # Template source holding the implementation stubs: "<package name>.impl.c".
        self.source = f"{p.name}.impl.c"


class GlobFuncCImplInfo(AbstractAnalysis[GlobFuncDecl]):
    """C implementation naming information for a global function."""

    def __init__(self, am: AnalysisManager, f: GlobFuncDecl) -> None:
        super().__init__(am, f)
        # Name of the macro that binds a user implementation to `f`.
        self.macro = f"TH_EXPORT_C_API_{f.name}"


def _c_param_list(am: AnalysisManager, func: GlobFuncDecl) -> str:
    """Render the C parameter list of ``func``, without the parentheses.

    Each parameter is rendered as ``<ABI param type> <name>``. A function with
    no parameters yields ``"void"``: before C23 an empty ``()`` in a function
    definition declares no prototype, while ``(void)`` is a proper prototype
    and stays compatible with either form of declaration.
    """
    params = [
        f"{TypeABIInfo.get(am, param.ty_ref.resolved_ty).as_param} {param.name}"
        for param in func.params
    ]
    return ", ".join(params) if params else "void"


def _c_return_type(am: AnalysisManager, func: GlobFuncDecl) -> str:
    """Return the C return type of ``func``: its ABI owner type, or ``void``."""
    if return_ty_ref := func.return_ty_ref:
        return TypeABIInfo.get(am, return_ty_ref.resolved_ty).as_owner
    return "void"


class CImplHeadersGenerator:
    """Emit ``include/<pkg>.impl.h`` with one export macro per global function."""

    def __init__(self, om: OutputManager, am: AnalysisManager):
        self.om = om
        self.am = am

    def generate(self, pg: PackageGroup):
        """Generate an implementation header for every package in ``pg``."""
        for pkg in pg.packages:
            self.gen_package_file(pkg)

    def gen_package_file(self, pkg: PackageDecl):
        """Write the implementation header for ``pkg``.

        The header includes ``taihe/common.h``, the package's ABI header, and
        the implementation headers of every parameter and return type used by
        the package's functions, followed by one export macro per function.
        """
        pkg_c_impl_info = PackageCImplInfo.get(self.am, pkg)
        pkg_abi_info = PackageABIInfo.get(self.am, pkg)
        with CHeaderWriter(
            self.om,
            f"include/{pkg_c_impl_info.header}",
            FileKind.C_HEADER,
        ) as pkg_c_impl_target:
            pkg_c_impl_target.add_include("taihe/common.h", pkg_abi_info.header)
            for func in pkg.functions:
                for param in func.params:
                    type_abi_info = TypeABIInfo.get(self.am, param.ty_ref.resolved_ty)
                    pkg_c_impl_target.add_include(*type_abi_info.impl_headers)
                if return_ty_ref := func.return_ty_ref:
                    type_abi_info = TypeABIInfo.get(
                        self.am, return_ty_ref.resolved_ty
                    )
                    pkg_c_impl_target.add_include(*type_abi_info.impl_headers)
                self.gen_func(func, pkg_c_impl_target)

    def gen_func(
        self,
        func: GlobFuncDecl,
        pkg_c_impl_target: CHeaderWriter,
    ):
        """Emit the ``TH_EXPORT_C_API_<func>`` macro for ``func``.

        ``TH_EXPORT_C_API_<func>(impl)`` expands to a definition of the
        ABI-mangled function whose body forwards all arguments to ``impl``.
        """
        func_abi_info = GlobFuncABIInfo.get(self.am, func)
        func_c_impl_info = GlobFuncCImplInfo.get(self.am, func)
        func_impl = "C_FUNC_IMPL"
        params_str = _c_param_list(self.am, func)
        args_str = ", ".join(param.name for param in func.params)
        return_ty_name = _c_return_type(self.am, func)
        call = f"{func_impl}({args_str});"
        # ISO C forbids `return <expr>;` in a void function, even when <expr>
        # is itself void, so only emit `return` when there is a value.
        forward = f"return {call}" if func.return_ty_ref else call
        pkg_c_impl_target.writelns(
            f"#define {func_c_impl_info.macro}({func_impl}) \\",
            f"    {return_ty_name} {func_abi_info.mangled_name}({params_str}) {{ \\",
            f"        {forward} \\",
            "    }",
        )


class CImplSourcesGenerator:
    """Emit ``temp/<pkg>.impl.c`` templates containing stub implementations."""

    def __init__(self, om: OutputManager, am: AnalysisManager):
        self.om = om
        self.am = am

    def generate(self, pg: PackageGroup):
        """Generate an implementation template for every package in ``pg``."""
        for pkg in pg.packages:
            self.gen_package_file(pkg)

    def gen_package_file(self, pkg: PackageDecl):
        """Write the implementation template for ``pkg``.

        The template includes the package's implementation header and holds one
        stub per global function.
        """
        pkg_c_impl_info = PackageCImplInfo.get(self.am, pkg)
        with CSourceWriter(
            self.om,
            f"temp/{pkg_c_impl_info.source}",
            FileKind.TEMPLATE,
        ) as pkg_c_impl_target:
            pkg_c_impl_target.add_include(pkg_c_impl_info.header)
            for func in pkg.functions:
                self.gen_func(func, pkg_c_impl_target)

    def gen_func(
        self,
        func: GlobFuncDecl,
        pkg_c_impl_target: CSourceWriter,
    ):
        """Emit an empty ``<func>_impl`` stub and bind it via the export macro."""
        func_c_impl_info = GlobFuncCImplInfo.get(self.am, func)
        func_c_impl_name = f"{func.name}_impl"
        # Same signature rendering as the header generator, so the stub matches
        # the parameters the export macro forwards.
        params_str = _c_param_list(self.am, func)
        return_ty_name = _c_return_type(self.am, func)
        with pkg_c_impl_target.indented(
            f"{return_ty_name} {func_c_impl_name}({params_str}) {{",
            "}",
        ):
            pkg_c_impl_target.writelns(
                "// TODO",
            )
        pkg_c_impl_target.writelns(
            f"{func_c_impl_info.macro}({func_c_impl_name});",
        )