# coding=utf-8
#
# Copyright (c) 2025 Huawei Device Co., Ltd.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from taihe.codegen.abi.analyses import (
    GlobFuncABIInfo,
    PackageABIInfo,
)
from taihe.codegen.abi.writer import CHeaderWriter
from taihe.codegen.cpp.analyses import (
    GlobFuncCppUserInfo,
    PackageCppInfo,
    PackageCppUserInfo,
    TypeCppInfo,
)
from taihe.semantics.declarations import (
    GlobFuncDecl,
    PackageDecl,
    PackageGroup,
)
from taihe.utils.analyses import AnalysisManager
from taihe.utils.outputs import FileKind, OutputManager


class CppUserHeadersGenerator:
    """Write one C++ user header per package.

    Each header written to ``include/<PackageCppUserInfo.header>`` includes:

    - the package's C++ header;
    - ``taihe/common.hpp``;
    - the package ABI header;
    - the ``impl_headers`` of every parameter and return type of the
      package's functions.

    For each global function, the header then contains an ``inline`` C++
    wrapper that calls the function's mangled ABI name.
    """

    def __init__(self, om: OutputManager, am: AnalysisManager):
        # Destination for generated files.
        self.om = om
        # Source of the per-declaration analysis results (`*Info.get`).
        self.am = am

    def generate(self, pg: PackageGroup):
        """Emit a user header for every package in ``pg``."""
        for package in pg.packages:
            self.gen_package_file(package)

    def gen_package_file(self, pkg: PackageDecl):
        """Write the user header for a single package ``pkg``."""
        abi_info = PackageABIInfo.get(self.am, pkg)
        cpp_info = PackageCppInfo.get(self.am, pkg)
        user_info = PackageCppUserInfo.get(self.am, pkg)
        header_path = f"include/{user_info.header}"
        with CHeaderWriter(self.om, header_path, FileKind.CPP_HEADER) as target:
            # Type declarations come from the package's own C++ header.
            target.add_include(cpp_info.header)
            # The function wrappers below need the common runtime header and
            # the ABI declarations they forward to.
            target.add_include("taihe/common.hpp")
            target.add_include(abi_info.header)
            for func in pkg.functions:
                self._include_signature_headers(func, target)
                self.gen_func(func, target)

    def _include_signature_headers(
        self,
        func: GlobFuncDecl,
        target: CHeaderWriter,
    ):
        """Add the ``impl_headers`` of each parameter type, then the return type."""
        signature_refs = [param.ty_ref for param in func.params]
        if ret_ref := func.return_ty_ref:
            signature_refs.append(ret_ref)
        for ty_ref in signature_refs:
            type_info = TypeCppInfo.get(self.am, ty_ref.resolved_ty)
            target.add_include(*type_info.impl_headers)

    def gen_func(
        self,
        func: GlobFuncDecl,
        pkg_cpp_target: CHeaderWriter,
    ):
        """Write an inline C++ wrapper for ``func`` inside its user namespace.

        The wrapper's parameters use each type's ``as_param`` spelling. Every
        argument is passed through ``pass_into_abi`` before the mangled ABI
        function is called. A non-void result is wrapped with
        ``return_from_abi``.
        """
        abi_info = GlobFuncABIInfo.get(self.am, func)
        user_info = GlobFuncCppUserInfo.get(self.am, func)

        cpp_params = []
        abi_args = []
        for param in func.params:
            param_ty = TypeCppInfo.get(self.am, param.ty_ref.resolved_ty)
            cpp_params.append(f"{param_ty.as_param} {param.name}")
            abi_args.append(param_ty.pass_into_abi(param.name))
        abi_call = f"{abi_info.mangled_name}({', '.join(abi_args)})"

        ret_ref = func.return_ty_ref
        if not ret_ref:
            return_decl = "void"
            body_expr = abi_call
        else:
            ret_ty = TypeCppInfo.get(self.am, ret_ref.resolved_ty)
            return_decl = ret_ty.as_owner
            body_expr = ret_ty.return_from_abi(abi_call)

        with pkg_cpp_target.indented(
            f"namespace {user_info.namespace} {{",
            f"}}",
            indent="",
        ):
            with pkg_cpp_target.indented(
                f"inline {return_decl} {user_info.call_name}({', '.join(cpp_params)}) {{",
                f"}}",
            ):
                # `return <void expression>;` is legal C++ in a void function,
                # so the same statement form serves both branches above.
                pkg_cpp_target.writelns(
                    f"return {body_expr};",
                )