# coding=utf-8
#
# Copyright (c) 2025 Huawei Device Co., Ltd.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Convert AST to IR."""

from collections.abc import Iterable
from typing import Any

from typing_extensions import override

from taihe.parse.antlr.TaiheAST import TaiheAST as ast
from taihe.parse.antlr.TaiheVisitor import TaiheVisitor as Visitor
from taihe.parse.ast_generation import generate_ast
from taihe.semantics.declarations import (
    AttrItemDecl,
    CallbackTypeRefDecl,
    DeclarationImportDecl,
    DeclarationRefDecl,
    EnumDecl,
    EnumItemDecl,
    GenericTypeRefDecl,
    GlobFuncDecl,
    IfaceDecl,
    IfaceMethodDecl,
    IfaceParentDecl,
    LongTypeRefDecl,
    PackageDecl,
    PackageImportDecl,
    PackageRefDecl,
    ParamDecl,
    ShortTypeRefDecl,
    StructDecl,
    StructFieldDecl,
    UnionDecl,
    UnionFieldDecl,
)
from taihe.utils.diagnostics import DiagnosticsManager
from taihe.utils.exceptions import InvalidPackageNameError
from taihe.utils.sources import SourceBase, SourceLocation


def pkg2str(pkg_name: ast.PkgName) -> str:
    """Render a package-name node as a dotted string built from its parts."""
    segments = [token.text for token in pkg_name.parts]
    return ".".join(segments)


def is_valid_pkg_name(name: str) -> bool:
    """Tell whether every dot-separated segment of ``name`` is well formed.

    A segment is accepted when it is non-empty, starts with an alphabetic
    character (per ``str.isalpha``) or an underscore, and continues only with
    alphanumeric characters (per ``str.isalnum``) or underscores.
    """

    def segment_ok(segment: str) -> bool:
        # An empty segment comes from a leading, trailing or doubled dot,
        # or from an empty name.
        if segment == "":
            return False
        head, tail = segment[0], segment[1:]
        if head != "_" and not head.isalpha():
            return False
        return not any(ch != "_" and not ch.isalnum() for ch in tail)

    return all(segment_ok(segment) for segment in name.split("."))


class ExprEvaluator(Visitor):
    """Evaluates expression nodes of the AST to plain Python values.

    Each ``visit_*_expr`` method handles one kind of expression node and
    returns the resulting ``bool``, ``int``, ``float`` or ``str``.
    Sub-expressions are evaluated by calling ``self.visit`` on them.
    """

    # Bool Expr

    @override
    def visit_literal_bool_expr(self, node: ast.LiteralBoolExpr) -> bool:
        """Map the literal text ``true`` / ``false`` to a Python boolean."""
        return {
            "true": True,
            "false": False,
        }[node.val.text]

    @override
    def visit_int_comparison_bool_expr(self, node: ast.IntComparisonBoolExpr) -> bool:
        """Compare two integer sub-expressions with the node's operator."""
        return {
            ">": int.__gt__,
            "<": int.__lt__,
            ">=": int.__ge__,
            "<=": int.__le__,
            "==": int.__eq__,
            "!=": int.__ne__,
        }[node.op.text](
            int(self.visit(node.left)),
            int(self.visit(node.right)),
        )

    @override
    def visit_float_comparison_bool_expr(
        self, node: ast.FloatComparisonBoolExpr
    ) -> bool:
        """Compare two float sub-expressions with the node's operator."""
        return {
            ">": float.__gt__,
            "<": float.__lt__,
            ">=": float.__ge__,
            "<=": float.__le__,
            "==": float.__eq__,
            "!=": float.__ne__,
        }[node.op.text](
            float(self.visit(node.left)),
            float(self.visit(node.right)),
        )

    @override
    def visit_unary_bool_expr(self, node: ast.UnaryBoolExpr) -> bool:
        """Return the logical negation of the operand."""
        return not self.visit(node.expr)

    @override
    def visit_binary_bool_expr(self, node: ast.BinaryBoolExpr) -> bool:
        """Combine two boolean sub-expressions with ``&&`` or ``||``.

        Both operands are always evaluated before the operator is applied;
        there is no short-circuiting.
        """
        return {
            "&&": bool.__and__,
            "||": bool.__or__,
        }[node.op.text](
            bool(self.visit(node.left)),
            bool(self.visit(node.right)),
        )

    @override
    def visit_parenthesis_bool_expr(self, node: ast.ParenthesisBoolExpr) -> bool:
        """Evaluate the expression wrapped by the parentheses."""
        return self.visit(node.expr)

    @override
    def visit_conditional_bool_expr(self, node: ast.ConditionalBoolExpr) -> bool:
        """Evaluate the condition, then only the branch it selects."""
        return (
            self.visit(node.then_expr)
            if self.visit(node.cond)
            else self.visit(node.else_expr)
        )

    # Int Expr

    @override
    def visit_literal_int_expr(self, node: ast.LiteralIntExpr) -> int:
        """Parse an integer literal.

        The prefixes ``0b``, ``0o`` and ``0x`` select base 2, 8 and 16;
        any other text is parsed as decimal.
        """
        text = node.val.text
        if text.startswith("0b"):
            return int(text, 2)
        if text.startswith("0o"):
            return int(text, 8)
        if text.startswith("0x"):
            return int(text, 16)
        return int(text)

    @override
    def visit_parenthesis_int_expr(self, node: ast.ParenthesisIntExpr) -> int:
        """Evaluate the expression wrapped by the parentheses."""
        return self.visit(node.expr)

    @override
    def visit_conditional_int_expr(self, node: ast.ConditionalIntExpr) -> int:
        """Evaluate the condition, then only the branch it selects."""
        return (
            self.visit(node.then_expr)
            if self.visit(node.cond)
            else self.visit(node.else_expr)
        )

    @override
    def visit_unary_int_expr(self, node: ast.UnaryIntExpr) -> int:
        """Apply unary ``-``, ``+`` or ``~`` to an integer operand."""
        return {
            "-": int.__neg__,
            "+": int.__pos__,
            "~": int.__invert__,
        }[node.op.text](
            int(self.visit(node.expr)),
        )

    @override
    def visit_binary_int_expr(self, node: ast.BinaryIntExpr) -> int:
        """Apply a binary arithmetic or bitwise operator to two integers.

        ``/`` is Python floor division, so the quotient is rounded towards
        negative infinity.

        Raises:
            ZeroDivisionError: If the right operand of ``/`` or ``%`` is zero.
        """
        return {
            "+": int.__add__,
            "-": int.__sub__,
            "*": int.__mul__,
            "/": int.__floordiv__,
            "%": int.__mod__,
            "<<": int.__lshift__,
            ">>": int.__rshift__,
            "&": int.__and__,
            "|": int.__or__,
            "^": int.__xor__,
        }[node.op.text](
            int(self.visit(node.left)),
            int(self.visit(node.right)),
        )

    @override
    def visit_binary_int_shift_expr(self, node: ast.BinaryIntShiftExpr) -> int:
        """Shift the left operand by the right operand.

        The direction is taken from ``node.ch.text``: ``<`` shifts left and
        ``>`` shifts right.
        """
        # NOTE(review): presumably the grammar spells shift operators as
        # repeated "<" / ">" characters and exposes one of them as ``ch`` -
        # confirm against the grammar.
        return {
            "<": int.__lshift__,
            ">": int.__rshift__,
        }[node.ch.text](
            int(self.visit(node.left)),
            int(self.visit(node.right)),
        )

    # Float Expr

    @override
    def visit_literal_float_expr(self, node: ast.LiteralFloatExpr) -> float:
        """Parse a float literal from its source text."""
        return float(node.val.text)

    @override
    def visit_parenthesis_float_expr(self, node: ast.ParenthesisFloatExpr) -> float:
        """Evaluate the expression wrapped by the parentheses."""
        return self.visit(node.expr)

    @override
    def visit_conditional_float_expr(self, node: ast.ConditionalFloatExpr) -> float:
        """Evaluate the condition, then only the branch it selects."""
        return (
            self.visit(node.then_expr)
            if self.visit(node.cond)
            else self.visit(node.else_expr)
        )

    @override
    def visit_unary_float_expr(self, node: ast.UnaryFloatExpr) -> float:
        """Apply unary ``-`` or ``+`` to a float operand."""
        return {
            "-": float.__neg__,
            "+": float.__pos__,
        }[node.op.text](
            float(self.visit(node.expr)),
        )

    @override
    def visit_binary_float_expr(self, node: ast.BinaryFloatExpr) -> float:
        """Apply ``+``, ``-``, ``*`` or ``/`` (true division) to two floats.

        Raises:
            ZeroDivisionError: If the right operand of ``/`` is zero.
        """
        return {
            "+": float.__add__,
            "-": float.__sub__,
            "*": float.__mul__,
            "/": float.__truediv__,
        }[node.op.text](
            float(self.visit(node.left)),
            float(self.visit(node.right)),
        )

    # String Expr

    @override
    def visit_literal_string_expr(self, node: ast.LiteralStringExpr) -> str:
        """Strip the surrounding quotes and resolve backslash escapes.

        Non-ASCII characters in the literal are preserved as written.
        """
        raw = node.val.text[1:-1]
        # The ``unicode_escape`` codec decodes its input bytes as Latin-1, so
        # encoding with UTF-8 first would turn every non-ASCII character into
        # mojibake (e.g. "é" became "Ã©"). Encoding as Latin-1 with
        # ``backslashreplace`` keeps Latin-1 characters as single bytes and
        # rewrites everything else as \uXXXX / \UXXXXXXXX escapes, which the
        # decoder then restores to the original characters.
        return raw.encode("latin-1", "backslashreplace").decode("unicode_escape")

    @override
    def visit_literal_doc_string_expr(self, node: ast.LiteralDocStringExpr) -> str:
        """Strip the three-character delimiters; escapes are not processed."""
        return node.val.text[3:-3]

    @override
    def visit_binary_string_expr(self, node: ast.BinaryStringExpr) -> str:
        """Concatenate the left and right string sub-expressions."""
        return self.visit(node.left) + self.visit(node.right)

    @override
    def visit_any_expr(self, node: ast.AnyExpr) -> Any:
        """Evaluate the wrapped expression, whatever its kind."""
        return self.visit(node.expr)


class AstConverter(ExprEvaluator):
    """Converts a node on AST to the intermediate representation.

    Note that declarations with errors are discarded.
    """

    source: SourceBase
    diag: DiagnosticsManager

    def __init__(self, source: SourceBase, diag: DiagnosticsManager):
        """Store the source to convert and the diagnostics manager to use."""
        self.source = source
        self.diag = diag

    # Type References

    @override
    def visit_long_type(self, node: ast.LongType) -> LongTypeRefDecl:
        """Build a package-qualified type reference with its forward attrs."""
        d = LongTypeRefDecl(node.loc, pkg2str(node.pkg_name), str(node.decl_name))
        self.diag.for_each(node.forward_attrs, lambda a: d.add_attr(self.visit(a)))
        return d

    @override
    def visit_short_type(self, node: ast.ShortType) -> ShortTypeRefDecl:
        """Build an unqualified type reference with its forward attrs."""
        d = ShortTypeRefDecl(node.loc, str(node.decl_name))
        self.diag.for_each(node.forward_attrs, lambda a: d.add_attr(self.visit(a)))
        return d

    @override
    def visit_generic_type(self, node: ast.GenericType) -> GenericTypeRefDecl:
        """Build a generic type reference with its type args and attrs."""
        d = GenericTypeRefDecl(node.loc, str(node.decl_name))
        self.diag.for_each(node.args, lambda a: d.add_arg_ty_ref(self.visit(a)))
        self.diag.for_each(node.forward_attrs, lambda a: d.add_attr(self.visit(a)))
        return d

@override 293 def visit_callback_type(self, node: ast.CallbackType) -> CallbackTypeRefDecl: 294 if ty := node.return_ty: 295 d = CallbackTypeRefDecl(node.loc, self.visit(ty)) 296 else: 297 d = CallbackTypeRefDecl(node.loc) 298 self.diag.for_each(node.parameters, lambda p: d.add_param(self.visit(p))) 299 self.diag.for_each(node.forward_attrs, lambda a: d.add_attr(self.visit(a))) 300 return d 301 302 # Uses 303 304 @override 305 def visit_use_package(self, node: ast.UsePackage) -> Iterable[PackageImportDecl]: 306 p_ref = PackageRefDecl(node.pkg_name.loc, pkg2str(node.pkg_name)) 307 if node.pkg_alias: 308 d = PackageImportDecl( 309 p_ref, 310 name=str(node.pkg_alias), 311 loc=node.pkg_alias.loc, 312 ) 313 else: 314 d = PackageImportDecl( 315 p_ref, 316 ) 317 yield d 318 319 @override 320 def visit_use_symbol(self, node: ast.UseSymbol) -> Iterable[DeclarationImportDecl]: 321 p_ref = PackageRefDecl(node.pkg_name.loc, pkg2str(node.pkg_name)) 322 for p in node.decl_alias_pairs: 323 d_ref = DeclarationRefDecl(p.decl_name.loc, str(p.decl_name), p_ref) 324 if p.decl_alias: 325 d = DeclarationImportDecl( 326 d_ref, 327 name=str(p.decl_alias), 328 loc=p.decl_alias.loc, 329 ) 330 else: 331 d = DeclarationImportDecl( 332 d_ref, 333 ) 334 yield d 335 336 # Declarations 337 338 @override 339 def visit_struct_property(self, node: ast.StructProperty) -> StructFieldDecl: 340 d = StructFieldDecl(node.name.loc, str(node.name), self.visit(node.ty)) 341 self.diag.for_each(node.forward_attrs, lambda a: d.add_attr(self.visit(a))) 342 return d 343 344 @override 345 def visit_struct(self, node: ast.Struct) -> StructDecl: 346 d = StructDecl(node.name.loc, str(node.name)) 347 self.diag.for_each(node.fields, lambda f: d.add_field(self.visit(f))) 348 self.diag.for_each(node.forward_attrs, lambda a: d.add_attr(self.visit(a))) 349 self.diag.for_each(node.inner_attrs, lambda a: d.add_attr(self.visit(a))) 350 return d 351 352 @override 353 def visit_enum_property(self, node: ast.EnumProperty) -> 
EnumItemDecl: 354 if node.val: 355 d = EnumItemDecl(node.name.loc, str(node.name), self.visit(node.val)) 356 else: 357 d = EnumItemDecl(node.name.loc, str(node.name)) 358 self.diag.for_each(node.forward_attrs, lambda a: d.add_attr(self.visit(a))) 359 return d 360 361 @override 362 def visit_enum(self, node: ast.Enum) -> EnumDecl: 363 d = EnumDecl(node.name.loc, str(node.name), self.visit(node.enum_ty)) 364 self.diag.for_each(node.fields, lambda a: d.add_item(self.visit(a))) 365 self.diag.for_each(node.forward_attrs, lambda a: d.add_attr(self.visit(a))) 366 return d 367 368 @override 369 def visit_union_property(self, node: ast.UnionProperty) -> UnionFieldDecl: 370 if ty := node.ty: 371 d = UnionFieldDecl(node.name.loc, str(node.name), self.visit(ty)) 372 else: 373 d = UnionFieldDecl(node.name.loc, str(node.name)) 374 self.diag.for_each(node.forward_attrs, lambda a: d.add_attr(self.visit(a))) 375 return d 376 377 @override 378 def visit_union(self, node: ast.Union) -> UnionDecl: 379 d = UnionDecl(node.name.loc, str(node.name)) 380 self.diag.for_each(node.fields, lambda f: d.add_field(self.visit(f))) 381 self.diag.for_each(node.forward_attrs, lambda a: d.add_attr(self.visit(a))) 382 self.diag.for_each(node.inner_attrs, lambda a: d.add_attr(self.visit(a))) 383 return d 384 385 @override 386 def visit_parameter(self, node: ast.Parameter) -> ParamDecl: 387 d = ParamDecl(node.name.loc, str(node.name), self.visit(node.ty)) 388 self.diag.for_each(node.forward_attrs, lambda a: d.add_attr(self.visit(a))) 389 return d 390 391 @override 392 def visit_interface_function(self, node: ast.InterfaceFunction) -> IfaceMethodDecl: 393 if ty := node.return_ty: 394 d = IfaceMethodDecl(node.name.loc, str(node.name), self.visit(ty)) 395 else: 396 d = IfaceMethodDecl(node.name.loc, str(node.name)) 397 self.diag.for_each(node.parameters, lambda p: d.add_param(self.visit(p))) 398 self.diag.for_each(node.forward_attrs, lambda a: d.add_attr(self.visit(a))) 399 return d 400 401 @override 402 
def visit_interface_parent(self, node: ast.InterfaceParent) -> IfaceParentDecl: 403 p = IfaceParentDecl(node.ty.loc, self.visit(node.ty)) 404 return p 405 406 @override 407 def visit_interface(self, node: ast.Interface) -> IfaceDecl: 408 d = IfaceDecl(node.name.loc, str(node.name)) 409 self.diag.for_each(node.fields, lambda f: d.add_method(self.visit(f))) 410 self.diag.for_each(node.extends, lambda i: d.add_parent(self.visit(i))) 411 self.diag.for_each(node.forward_attrs, lambda a: d.add_attr(self.visit(a))) 412 self.diag.for_each(node.inner_attrs, lambda a: d.add_attr(self.visit(a))) 413 return d 414 415 @override 416 def visit_global_function(self, node: ast.GlobalFunction) -> GlobFuncDecl: 417 if ty := node.return_ty: 418 d = GlobFuncDecl(node.name.loc, str(node.name), self.visit(ty)) 419 else: 420 d = GlobFuncDecl(node.name.loc, str(node.name)) 421 self.diag.for_each(node.parameters, lambda p: d.add_param(self.visit(p))) 422 self.diag.for_each(node.forward_attrs, lambda a: d.add_attr(self.visit(a))) 423 return d 424 425 # Attributes 426 427 def visit_attr(self, node: ast.DeclAttr | ast.ScopeAttr) -> AttrItemDecl: 428 args: list[Any] = [] 429 kwargs: dict[str, Any] = {} 430 for arg in node.args: 431 if isinstance(arg, ast.NamedAttrArg): 432 kwargs[str(arg.name)] = self.visit(arg.val) 433 else: 434 args.append(self.visit(arg.val)) 435 d = AttrItemDecl(node.name.loc, str(node.name), args, kwargs) 436 return d 437 438 @override 439 def visit_decl_attr(self, node: ast.DeclAttr) -> AttrItemDecl: 440 return self.visit_attr(node) 441 442 @override 443 def visit_scope_attr(self, node: ast.ScopeAttr) -> AttrItemDecl: 444 return self.visit_attr(node) 445 446 # Package 447 448 @override 449 def visit_spec(self, node: ast.Spec) -> PackageDecl: 450 if not is_valid_pkg_name(self.source.pkg_name): 451 raise InvalidPackageNameError( 452 self.source.pkg_name, 453 loc=SourceLocation(self.source), 454 ) 455 pkg = PackageDecl(self.source.pkg_name, SourceLocation(self.source)) 456 
for u in node.uses: 457 self.diag.for_each(self.visit(u), pkg.add_import) 458 self.diag.for_each(node.fields, lambda n: pkg.add_declaration(self.visit(n))) 459 self.diag.for_each(node.inner_attrs, lambda a: pkg.add_attr(self.visit(a))) 460 return pkg 461 462 def convert(self) -> PackageDecl: 463 """Converts the whole source code buffer to a package. 464 465 Returns: 466 PackageDecl: The package declaration containing all declarations 467 and imports from the source code. 468 469 Raises: 470 InvalidPackageNameError: If the package name is invalid. 471 """ 472 ast_nodes = generate_ast(self.source, self.diag) 473 return self.visit_spec(ast_nodes)