# coding=utf-8
#
# Copyright (c) 2025 Huawei Device Co., Ltd.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import TYPE_CHECKING, ClassVar

if TYPE_CHECKING:
    from taihe.driver.contexts import CompilerInstance


class Backend(ABC):
    """Abstract language backend driven through three stages.

    The stage contracts below imply the intended order:
    ``post_process`` (may rewrite the IR), then ``validate`` (checks only),
    then ``generate`` (emits output). Subclasses must implement
    ``__init__`` and ``generate``; the other two hooks default to no-ops.
    """

    @abstractmethod
    def __init__(self, instance: "CompilerInstance"):
        """Set up the backend for the given compiler instance."""

    def post_process(self):
        """Hook that runs on the IR right after parsing.

        This is the only stage in which a backend may rewrite the IR in place.
        The default implementation does nothing.
        """

    def validate(self):
        """Hook that checks the IR once post-processing has finished.

        A backend MUST NOT modify the IR here.
        The default implementation does nothing.
        """

    @abstractmethod
    def generate(self):
        """Write the backend's output files.

        A backend MUST NOT modify the IR or report errors here:
        - any IR rewriting belongs in `post_process`;
        - any error reporting belongs in `validate`.
        """


class BackendConfig(ABC):
    """Static description of a backend: its name, its dependencies, and a factory."""

    NAME: ClassVar[str]
    "The name of the backend."

    # NOTE: this list object is shared by every subclass that does not
    # override it; subclasses should assign a new list rather than mutate it.
    DEPS: ClassVar[list[str]] = []
    "List of backends that the current backend depends on."

    @abstractmethod
    def __init__(self):
        """Initialize the backend configuration."""

    @abstractmethod
    def construct(self, instance: "CompilerInstance") -> Backend:
        """Build the `Backend` described by this configuration for ``instance``."""
BackendConfigT = type[BackendConfig]


class BackendRegistry:
    """Name-indexed registry of `BackendConfig` classes."""

    def __init__(self):
        # Maps each backend's `NAME` to the config class registered for it.
        self._factories: dict[str, BackendConfigT] = {}

    def register(self, factory: BackendConfigT):
        """Register ``factory`` under ``factory.NAME``.

        Registering the same class again under its own name is a no-op.

        Raises:
            ValueError: if a different class is already registered under
                the same name.
        """
        name = factory.NAME
        # `setdefault` inserts only when the name is new and returns whatever is
        # stored afterwards, so a different object means a conflicting registration.
        if (existing := self._factories.setdefault(name, factory)) is not factory:
            raise ValueError(
                f"backend {name!r} cannot be registered as {factory.__name__} "
                f"because it is already registered as {existing.__name__}"
            )

    def get_backend_names(self) -> list[str]:
        """Return the registered backend names, in registration order."""
        return list(self._factories.keys())

    def clear(self):
        """Remove every registered backend."""
        self._factories.clear()

    def collect_required_backends(self, names: Iterable[str]) -> list[BackendConfigT]:
        """Resolve ``names`` plus their transitive ``DEPS`` into a dependency-ordered list.

        Every backend appears exactly once, after all of the backends it
        depends on. Apart from that constraint, the order follows the
        order of ``names`` and of each ``DEPS`` list.

        Raises:
            KeyError: if a requested backend, or one of its dependencies,
                is not registered.
            ValueError: if the dependency graph contains a cycle (which would
                make a dependency-first ordering impossible).
        """
        result: list[BackendConfigT] = []
        done: set[str] = set()
        # Names on the current DFS path; meeting one again means a cycle.
        path: list[str] = []

        def add(name: str) -> None:
            if name in done:
                return
            if name in path:
                cycle = " -> ".join([*path[path.index(name):], name])
                raise ValueError(f"cyclic backend dependency: {cycle}")
            factory = self._factories.get(name)
            if factory is None:
                raise KeyError(f"unknown backend {name!r}")

            path.append(name)
            for dep in factory.DEPS:
                add(dep)
            path.pop()
            # Appended only after all dependencies, giving a post-order.
            done.add(name)
            result.append(factory)

        for name in names:
            add(name)

        return result

    def register_all(self):
        """Register the built-in backends: ABI, C++, ANI bridge, and pretty-print."""
        # Imported lazily, presumably to avoid import cycles / import cost at
        # module load — NOTE(review): confirm.
        from taihe.codegen.abi import (
            AbiHeaderBackendConfig,
            AbiSourcesBackendConfig,
            CAuthorBackendConfig,
        )
        from taihe.codegen.ani import AniBridgeBackendConfig
        from taihe.codegen.cpp import (
            CppAuthorBackendConfig,
            CppCommonHeadersBackendConfig,
            CppUserHeadersBackendConfig,
        )
        from taihe.semantics import PrettyPrintBackendConfig

        backends = [
            # abi
            AbiHeaderBackendConfig,
            AbiSourcesBackendConfig,
            CAuthorBackendConfig,
            # cpp
            CppCommonHeadersBackendConfig,
            CppAuthorBackendConfig,
            CppUserHeadersBackendConfig,
            # ani
            AniBridgeBackendConfig,
            # pretty print
            PrettyPrintBackendConfig,
        ]

        for b in backends:
            self.register(b)