# coding=utf-8
#
# Copyright (c) 2025 Huawei Device Co., Ltd.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re

from taihe.codegen.abi.analyses import (
    GlobFuncABIInfo,
    IfaceABIInfo,
    PackageABIInfo,
    TypeABIInfo,
)
from taihe.codegen.abi.writer import CHeaderWriter, CSourceWriter
from taihe.codegen.cpp.analyses import (
    IfaceMethodCppInfo,
    PackageCppInfo,
    TypeCppInfo,
)
from taihe.semantics.declarations import (
    GlobFuncDecl,
    IfaceDecl,
    IfaceMethodDecl,
    PackageDecl,
    PackageGroup,
)
from taihe.semantics.types import (
    IfaceType,
)
from taihe.utils.analyses import AbstractAnalysis, AnalysisManager
from taihe.utils.outputs import FileKind, OutputManager


class PackageCppImplInfo(AbstractAnalysis[PackageDecl]):
    """File names of the C++ implementation artifacts generated for a package.

    Attributes:
        header: ``<package name>.impl.hpp``.
        source: ``<package name>.impl.cpp``.
    """

    def __init__(self, am: AnalysisManager, p: PackageDecl) -> None:
        super().__init__(am, p)
        self.header = f"{p.name}.impl.hpp"
        self.source = f"{p.name}.impl.cpp"


class GlobFuncCppImplInfo(AbstractAnalysis[GlobFuncDecl]):
    """Name of the export macro generated for a global function.

    Attributes:
        macro: ``TH_EXPORT_CPP_API_<function name>``.
    """

    def __init__(self, am: AnalysisManager, f: GlobFuncDecl) -> None:
        super().__init__(am, f)
        self.macro = f"TH_EXPORT_CPP_API_{f.name}"


class CppImplHeadersGenerator:
    """Emit one ``include/<package>.impl.hpp`` header per package.

    Each header defines, for every global function of the package, a macro
    that expands to the ABI-level function wrapping a user-supplied C++
    implementation.
    """

    def __init__(self, om: OutputManager, am: AnalysisManager):
        self.om = om
        self.am = am

    def generate(self, pg: PackageGroup) -> None:
        """Generate the implementation header of every package in ``pg``."""
        for pkg in pg.packages:
            self.gen_package_file(pkg)

    def gen_package_file(self, pkg: PackageDecl) -> None:
        """Write the implementation header for ``pkg``.

        The header includes ``taihe/common.hpp``, the package ABI header and
        the impl headers of every parameter/return type, followed by one
        export macro per global function.
        """
        pkg_abi_info = PackageABIInfo.get(self.am, pkg)
        pkg_cpp_impl_info = PackageCppImplInfo.get(self.am, pkg)
        with CHeaderWriter(
            self.om,
            f"include/{pkg_cpp_impl_info.header}",
            FileKind.CPP_HEADER,
        ) as pkg_cpp_impl_target:
            pkg_cpp_impl_target.add_include("taihe/common.hpp")
            pkg_cpp_impl_target.add_include(pkg_abi_info.header)
            for func in pkg.functions:
                # Parameter types first, then the return type (if any), so the
                # include requests are issued in declaration order.
                used_types = [param.ty_ref.resolved_ty for param in func.params]
                if return_ty_ref := func.return_ty_ref:
                    used_types.append(return_ty_ref.resolved_ty)
                for ty in used_types:
                    type_cpp_info = TypeCppInfo.get(self.am, ty)
                    pkg_cpp_impl_target.add_include(*type_cpp_info.impl_headers)
                self.gen_func(func, pkg_cpp_impl_target)

    def gen_func(
        self,
        func: GlobFuncDecl,
        pkg_cpp_impl_target: CHeaderWriter,
    ) -> None:
        """Write the export macro for the global function ``func``.

        The macro takes the name of the C++ implementation as its single
        argument and expands to a function named after the ABI mangled name.
        That function forwards every ABI parameter through the type's
        ``pass_from_abi`` expression and, for non-void functions, wraps the
        call in the return type's ``return_into_abi`` expression.
        """
        func_abi_info = GlobFuncABIInfo.get(self.am, func)
        func_cpp_impl_info = GlobFuncCppImplInfo.get(self.am, func)
        # Name of the macro parameter standing in for the user implementation.
        func_impl = "CPP_FUNC_IMPL"
        args_from_abi = []
        abi_params = []
        for param in func.params:
            type_cpp_info = TypeCppInfo.get(self.am, param.ty_ref.resolved_ty)
            type_abi_info = TypeABIInfo.get(self.am, param.ty_ref.resolved_ty)
            args_from_abi.append(type_cpp_info.pass_from_abi(param.name))
            abi_params.append(f"{type_abi_info.as_param} {param.name}")
        args_from_abi_str = ", ".join(args_from_abi)
        abi_params_str = ", ".join(abi_params)
        cpp_result = f"{func_impl}({args_from_abi_str})"
        if return_ty_ref := func.return_ty_ref:
            type_cpp_info = TypeCppInfo.get(self.am, return_ty_ref.resolved_ty)
            type_abi_info = TypeABIInfo.get(self.am, return_ty_ref.resolved_ty)
            abi_return_ty_name = type_abi_info.as_owner
            abi_result = type_cpp_info.return_into_abi(cpp_result)
        else:
            abi_return_ty_name = "void"
            abi_result = cpp_result
        # NOTE(review): the leading whitespace inside the continuation lines is
        # reproduced exactly as found; confirm the intended macro indentation.
        pkg_cpp_impl_target.writelns(
            f"#define {func_cpp_impl_info.macro}({func_impl}) \\",
            f" {abi_return_ty_name} {func_abi_info.mangled_name}({abi_params_str}) {{ \\",
            f" return {abi_result}; \\",
            " }",
        )


class CppImplSourcesGenerator:
    """Emit one ``temp/<package>.impl.cpp`` template source per package.

    The template contains stub implementation classes for the package's
    interfaces, stub global functions, and the export-macro invocations that
    bind those functions to the ABI.
    """

    # A (possibly ``::``-prefixed) qualified C++ name such as ``::a::b::c``.
    _QUALIFIED_NAME = re.compile(
        r"(::)?([A-Za-z_][A-Za-z_0-9]*::)*[A-Za-z_][A-Za-z_0-9]*"
    )

    def __init__(self, om: OutputManager, am: AnalysisManager):
        self.om = om
        self.am = am
        # Namespaces for which a ``using namespace`` line has been written in
        # the file currently being generated; ``mask`` strips them.
        self.using_namespaces: list[str] = []

    @property
    def make_holder(self) -> str:
        """``taihe::make_holder`` with any active using-namespace removed."""
        return self.mask("taihe::make_holder")

    @property
    def runtime_error(self) -> str:
        """``std::runtime_error`` with any active using-namespace removed."""
        return self.mask("std::runtime_error")

    def mask(self, cpp_type: str) -> str:
        """Strip active using-namespace prefixes from names in ``cpp_type``.

        Every qualified name found in ``cpp_type`` is checked against
        ``self.using_namespaces`` in order; the first namespace that prefixes
        the name, either as ``ns::`` or as ``::ns::``, is removed. Names that
        match no namespace are left untouched.
        """

        def replace_ns(match):
            matched = match.group(0)
            for ns in self.using_namespaces:
                # Try the relative form before the globally qualified one.
                for prefix in (f"{ns}::", f"::{ns}::"):
                    if matched.startswith(prefix):
                        return matched[len(prefix):]
            return matched

        return self._QUALIFIED_NAME.sub(replace_ns, cpp_type)

    def generate(self, pg: PackageGroup) -> None:
        """Generate the template source of every package in ``pg``."""
        for pkg in pg.packages:
            self.gen_package_file(pkg)

    def gen_package_file(self, pkg: PackageDecl) -> None:
        """Write the template implementation source for ``pkg``."""
        pkg_cpp_info = PackageCppInfo.get(self.am, pkg)
        pkg_cpp_impl_info = PackageCppImplInfo.get(self.am, pkg)
        with CSourceWriter(
            self.om,
            f"temp/{pkg_cpp_impl_info.source}",
            FileKind.TEMPLATE,
        ) as pkg_cpp_impl_target:
            pkg_cpp_impl_target.add_include(pkg_cpp_info.header)
            pkg_cpp_impl_target.add_include(pkg_cpp_impl_info.header)
            pkg_cpp_impl_target.add_include("taihe/runtime.hpp")
            pkg_cpp_impl_target.add_include("stdexcept")
            pkg_cpp_impl_target.newline()
            # Start each file with no namespaces in effect.
            self.using_namespaces = []
            pkg_cpp_impl_target.newline()
            self.gen_anonymous_namespace_block(pkg, pkg_cpp_impl_target)
            pkg_cpp_impl_target.newline()
            pkg_cpp_impl_target.writelns(
                "// Since these macros are auto-generate, lint will cause false positive.",
                "// NOLINTBEGIN",
            )
            for func in pkg.functions:
                self.gen_func_macro(func, pkg_cpp_impl_target)
            pkg_cpp_impl_target.writelns(
                "// NOLINTEND",
            )
            self.using_namespaces = []

    def gen_using_namespace(
        self,
        pkg_cpp_impl_target: CSourceWriter,
        namespace: str,
    ) -> None:
        """Write ``using namespace <namespace>;`` and record it for ``mask``."""
        pkg_cpp_impl_target.writelns(
            f"using namespace {namespace};",
        )
        self.using_namespaces.append(namespace)

    def gen_anonymous_namespace_block(
        self,
        pkg: PackageDecl,
        pkg_cpp_impl_target: CSourceWriter,
    ) -> None:
        """Write an anonymous namespace holding all stubs of ``pkg``.

        Interface implementation classes come first, then the global
        function stubs; each is preceded by a blank line.
        """
        # NOTE(review): the spacing before "// namespace" is reproduced
        # exactly as found; confirm against the desired output format.
        with pkg_cpp_impl_target.indented(
            "namespace {",
            "} // namespace",
            indent="",
        ):
            pkg_cpp_impl_target.writelns(
                "// To be implemented.",
            )
            for iface in pkg.interfaces:
                pkg_cpp_impl_target.newline()
                self.gen_iface(iface, pkg_cpp_impl_target)
            for func in pkg.functions:
                pkg_cpp_impl_target.newline()
                self.gen_func_impl(func, pkg_cpp_impl_target)

    def gen_iface(
        self,
        iface: IfaceDecl,
        pkg_cpp_impl_target: CSourceWriter,
    ) -> None:
        """Write the stub class ``<iface name>Impl``.

        The class gets a public default constructor and one stub method for
        every method of every entry in the interface's ``ancestor_dict``.
        """
        iface_abi_info = IfaceABIInfo.get(self.am, iface)
        impl_name = f"{iface.name}Impl"
        with pkg_cpp_impl_target.indented(
            f"class {impl_name} {{",
            "};",
        ):
            pkg_cpp_impl_target.writelns(
                "public:",
            )
            with pkg_cpp_impl_target.indented(
                f"{impl_name}() {{",
                "}",
            ):
                pkg_cpp_impl_target.writelns(
                    "// Don't forget to implement the constructor.",
                )
            for ancestor in iface_abi_info.ancestor_dict:
                for func in ancestor.methods:
                    pkg_cpp_impl_target.newline()
                    self.gen_method_impl(func, pkg_cpp_impl_target)

    def gen_method_impl(
        self,
        func: IfaceMethodDecl,
        pkg_cpp_impl_target: CSourceWriter,
    ) -> None:
        """Write the stub body of interface method ``func``.

        The stub is named after the method's C++ ``impl_name``.
        """
        method_cpp_info = IfaceMethodCppInfo.get(self.am, func)
        self._gen_impl_stub(method_cpp_info.impl_name, func, pkg_cpp_impl_target)

    def gen_func_impl(
        self,
        func: GlobFuncDecl,
        pkg_cpp_impl_target: CSourceWriter,
    ) -> None:
        """Write the stub body of global function ``func``, named ``func.name``."""
        self._gen_impl_stub(func.name, func, pkg_cpp_impl_target)

    def _gen_impl_stub(
        self,
        func_cpp_impl_name: str,
        func: "GlobFuncDecl | IfaceMethodDecl",
        pkg_cpp_impl_target: CSourceWriter,
    ) -> None:
        """Write a C++ stub definition shared by methods and global functions.

        Args:
            func_cpp_impl_name: Name of the generated C++ function.
            func: Declaration supplying ``params`` and ``return_ty_ref``.
            pkg_cpp_impl_target: Writer receiving the generated lines.

        A stub returning an interface type returns a ``make_holder`` of the
        matching ``<Iface>Impl`` class; every other stub throws a
        "not implemented" runtime error via ``TH_THROW``.
        """
        cpp_params = []
        for param in func.params:
            type_cpp_info = TypeCppInfo.get(self.am, param.ty_ref.resolved_ty)
            cpp_params.append(f"{self.mask(type_cpp_info.as_param)} {param.name}")
        cpp_params_str = ", ".join(cpp_params)

        return_ty_ref = func.return_ty_ref
        if return_ty_ref:
            type_cpp_info = TypeCppInfo.get(self.am, return_ty_ref.resolved_ty)
            cpp_return_ty_name = self.mask(type_cpp_info.as_owner)
        else:
            cpp_return_ty_name = "void"

        with pkg_cpp_impl_target.indented(
            f"{cpp_return_ty_name} {func_cpp_impl_name}({cpp_params_str}) {{",
            "}",
        ):
            if return_ty_ref and isinstance(return_ty_ref.resolved_ty, IfaceType):
                impl_name = f"{return_ty_ref.resolved_ty.ty_decl.name}Impl"
                pkg_cpp_impl_target.writelns(
                    "// The parameters in the make_holder function should be of the same type",
                    "// as the parameters in the constructor of the actual implementation class.",
                    f"return {self.make_holder}<{impl_name}, {cpp_return_ty_name}>();",
                )
            else:
                pkg_cpp_impl_target.writelns(
                    f'TH_THROW({self.runtime_error}, "{func_cpp_impl_name} not implemented");',
                )

    def gen_func_macro(
        self,
        func: GlobFuncDecl,
        pkg_cpp_impl_target: CSourceWriter,
    ) -> None:
        """Write the export-macro invocation binding ``func`` to its stub."""
        func_cpp_impl_info = GlobFuncCppImplInfo.get(self.am, func)
        pkg_cpp_impl_target.writelns(
            f"{func_cpp_impl_info.macro}({func.name});",
        )