1# coding=utf-8 2# 3# Copyright (c) 2025 Huawei Device Co., Ltd. 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15 16from collections.abc import Callable, Iterable 17from typing import Any, TypeGuard, TypeVar 18 19from typing_extensions import override 20 21from taihe.semantics.declarations import ( 22 CallbackTypeRefDecl, 23 DeclarationRefDecl, 24 EnumDecl, 25 EnumItemDecl, 26 GenericTypeRefDecl, 27 GlobFuncDecl, 28 IfaceDecl, 29 IfaceMethodDecl, 30 IfaceParentDecl, 31 LongTypeRefDecl, 32 NamedDecl, 33 PackageDecl, 34 PackageGroup, 35 PackageRefDecl, 36 ShortTypeRefDecl, 37 StructDecl, 38 TypeDecl, 39 TypeRefDecl, 40 UnionDecl, 41) 42from taihe.semantics.types import ( 43 BUILTIN_GENERICS, 44 BUILTIN_TYPES, 45 CallbackType, 46 ScalarKind, 47 ScalarType, 48 StringType, 49 Type, 50 UserType, 51) 52from taihe.semantics.visitor import RecursiveDeclVisitor 53from taihe.utils.diagnostics import DiagnosticsManager 54from taihe.utils.exceptions import ( 55 DeclarationNotInScopeError, 56 DeclNotExistError, 57 DeclRedefError, 58 DuplicateExtendsWarn, 59 EnumValueError, 60 GenericArgumentsError, 61 NotATypeError, 62 PackageNotExistError, 63 PackageNotInScopeError, 64 RecursiveReferenceError, 65 SymbolConflictWithNamespaceError, 66 TypeUsageError, 67) 68 69 70def analyze_semantics(pg: PackageGroup, diag: DiagnosticsManager): 71 """Runs semantic analysis passes on the given package group.""" 72 _check_decl_confilct_with_namespace(pg, diag) 73 _ResolveImportsPass(diag).handle_decl(pg) 74 _CheckFieldNameCollisionErrorPass(diag).handle_decl(pg) 75 _CheckEnumTypePass(diag).handle_decl(pg) 76 _CheckRecursiveInclusionPass(diag).handle_decl(pg) 77 78 79def _check_decl_confilct_with_namespace( 80 pg: PackageGroup, 81 diag: DiagnosticsManager, 82): 83 """Checks for declarations conflicts with namespaces.""" 84 namespaces: dict[str, list[PackageDecl]] = {} 85 for pkg in pg.packages: 86 pkg_name = pkg.name 87 # package "a.b.c" -> namespaces ["a.b.c", "a.b", "a"] 88 while True: 89 namespaces.setdefault(pkg_name, []).append(pkg) 90 splited = pkg_name.rsplit(".", maxsplit=1) 91 if len(splited) == 2: 92 pkg_name = splited[0] 93 else: 94 break 95 96 for p in pg.packages: 97 for d in p.declarations: 98 name = p.name + "." + d.name 99 if packages := namespaces.get(name, []): 100 diag.emit(SymbolConflictWithNamespaceError(d, name, packages)) 101 102 103class _ResolveImportsPass(RecursiveDeclVisitor): 104 """Resolves imports and type references within a package group.""" 105 106 diag: DiagnosticsManager 107 108 def __init__(self, diag: DiagnosticsManager): 109 self._current_pkg_group = None 110 self._current_pkg = None # Always points to the current package. 111 self.diag = diag 112 113 @property 114 def pkg(self) -> PackageDecl: 115 pass 116 return self._current_pkg 117 118 @property 119 def pkg_group(self) -> PackageGroup: 120 pass 121 return self._current_pkg_group 122 123 @override 124 def visit_package_decl(self, p: PackageDecl) -> None: 125 self._current_pkg = p 126 super().visit_package_decl(p) 127 self._current_pkg = None 128 129 @override 130 def visit_package_group(self, g: PackageGroup) -> None: 131 self._current_pkg_group = g 132 super().visit_package_group(g) 133 self._current_pkg_group = None 134 135 @override 136 def visit_package_ref_decl(self, d: PackageRefDecl) -> None: 137 if d.is_resolved: 138 return 139 d.is_resolved = True 140 141 pkg = self.pkg_group.lookup(d.symbol) 142 143 if pkg is None: 144 self.diag.emit(PackageNotExistError(d.symbol, loc=d.loc)) 145 return 146 147 d.maybe_resolved_pkg = pkg 148 149 @override 150 def visit_declaration_ref_decl(self, d: DeclarationRefDecl) -> None: 151 if d.is_resolved: 152 return 153 d.is_resolved = True 154 155 self.handle_decl(d.pkg_ref) 156 157 pkg = d.pkg_ref.maybe_resolved_pkg 158 159 if pkg is None: 160 # No need to repeatedly throw exceptions 161 return 162 163 decl = pkg.lookup(d.symbol) 164 165 if decl is None: 166 self.diag.emit(DeclNotExistError(d.symbol, loc=d.loc)) 167 return 168 169 d.maybe_resolved_decl = decl 170 171 @override 172 def visit_long_type_ref_decl(self, d: LongTypeRefDecl) -> None: 173 if d.is_resolved: 174 return 175 d.is_resolved = True 176 177 # Find the corresponding imported package according to the package name 178 pkg_import = self.pkg.lookup_pkg_import(d.pkname) 179 180 if pkg_import is None: 181 self.diag.emit(PackageNotInScopeError(d.pkname, loc=d.loc)) 182 return 183 184 # Then find the corresponding type declaration from the package 185 pkg = pkg_import.pkg_ref.maybe_resolved_pkg 186 187 if pkg is None: 188 # No need to repeatedly throw exceptions 189 return 190 191 decl = pkg.lookup(d.symbol) 192 193 if decl is None: 194 self.diag.emit(DeclNotExistError(d.symbol, loc=d.loc)) 195 return 196 197 if not isinstance(decl, TypeDecl): 198 self.diag.emit(NotATypeError(d.symbol, loc=d.loc)) 199 return 200 201 d.maybe_resolved_ty = decl.as_type(d) 202 203 @override 204 def visit_short_type_ref_decl(self, d: ShortTypeRefDecl) -> None: 205 if d.is_resolved: 206 return 207 d.is_resolved = True 208 209 # Find Builtin Types 210 builder = BUILTIN_TYPES.get(d.symbol) 211 212 if builder: 213 d.maybe_resolved_ty = builder(d) 214 return 215 216 # Find types declared in the current package 217 decl = self.pkg.lookup(d.symbol) 218 219 if decl: 220 if not isinstance(decl, TypeDecl): 221 self.diag.emit(NotATypeError(d.symbol, loc=d.loc)) 222 return 223 224 d.maybe_resolved_ty = decl.as_type(d) 225 return 226 227 # Look for imported type declarations 228 decl_import = self.pkg.lookup_decl_import(d.symbol) 229 230 if decl_import is None: 231 self.diag.emit(DeclarationNotInScopeError(d.symbol, loc=d.loc)) 232 return 233 234 decl = decl_import.decl_ref.maybe_resolved_decl 235 236 if decl is None: 237 # No need to repeatedly throw exceptions 238 return 239 240 if not isinstance(decl, TypeDecl): 241 self.diag.emit(NotATypeError(d.symbol, loc=d.loc)) 242 return 243 244 d.maybe_resolved_ty = decl.as_type(d) 245 246 @override 247 def visit_generic_type_ref_decl(self, d: GenericTypeRefDecl) -> None: 248 if d.is_resolved: 249 return 250 d.is_resolved = True 251 252 super().visit_generic_type_ref_decl(d) 253 254 args_ty: list[Type] = [] 255 for arg_ty_ref in d.args_ty_ref: 256 arg_ty = arg_ty_ref.maybe_resolved_ty 257 if arg_ty is None: 258 # No need to repeatedly throw exceptions 259 return 260 args_ty.append(arg_ty) 261 262 decl_name = d.symbol 263 264 builder = BUILTIN_GENERICS.get(decl_name) 265 266 if builder is None: 267 self.diag.emit(DeclarationNotInScopeError(decl_name, loc=d.loc)) 268 return 269 270 try: 271 d.maybe_resolved_ty = builder(d, *args_ty) 272 except TypeError: 273 self.diag.emit(GenericArgumentsError(d.text, loc=d.loc)) 274 275 @override 276 def visit_callback_type_ref_decl(self, d: CallbackTypeRefDecl) -> None: 277 if d.is_resolved: 278 return 279 d.is_resolved = True 280 281 super().visit_callback_type_ref_decl(d) 282 283 if d.return_ty_ref: 284 return_ty = d.return_ty_ref.maybe_resolved_ty 285 if return_ty is None: 286 # No need to repeatedly throw exceptions 287 return 288 else: 289 return_ty = None 290 291 params_ty: list[Type] = [] 292 for param in d.params: 293 arg_ty = param.ty_ref.maybe_resolved_ty 294 if arg_ty is None: 295 # No need to repeatedly throw exceptions 296 return 297 params_ty.append(arg_ty) 298 299 d.maybe_resolved_ty = CallbackType(d, return_ty, tuple(params_ty)) 300 301 302class _CheckFieldNameCollisionErrorPass(RecursiveDeclVisitor): 303 """Check for duplicate field names in declarations and name anonymous declarations.""" 304 305 diag: DiagnosticsManager 306 307 def __init__(self, diag: DiagnosticsManager): 308 self.diag = diag 309 310 @override 311 def visit_glob_func_decl(self, d: GlobFuncDecl) -> None: 312 self.check_collision_helper(d.params) 313 return super().visit_glob_func_decl(d) 314 315 @override 316 def visit_iface_func_decl(self, d: IfaceMethodDecl) -> None: 317 self.check_collision_helper(d.params) 318 return super().visit_iface_func_decl(d) 319 320 @override 321 def visit_enum_decl(self, d: EnumDecl) -> None: 322 self.check_collision_helper(d.items) 323 return super().visit_enum_decl(d) 324 325 @override 326 def visit_struct_decl(self, d: StructDecl) -> None: 327 self.check_collision_helper(d.fields) 328 return super().visit_struct_decl(d) 329 330 @override 331 def visit_union_decl(self, d: UnionDecl) -> None: 332 self.check_collision_helper(d.fields) 333 return super().visit_union_decl(d) 334 335 @override 336 def visit_iface_decl(self, d: IfaceDecl) -> None: 337 self.check_collision_helper(d.methods) 338 return super().visit_iface_decl(d) 339 340 @override 341 def visit_package_decl(self, p: PackageDecl) -> None: 342 self.check_collision_helper(p.declarations) 343 return super().visit_package_decl(p) 344 345 def check_collision_helper(self, children: Iterable[NamedDecl]): 346 names: dict[str, NamedDecl] = {} 347 for f in children: 348 if (prev := names.setdefault(f.name, f)) != f: 349 self.diag.emit(DeclRedefError(prev, f)) 350 351 352class _CheckEnumTypePass(RecursiveDeclVisitor): 353 """Validated enum item types.""" 354 355 diag: DiagnosticsManager 356 357 def __init__(self, diag: DiagnosticsManager): 358 self.diag = diag 359 360 def visit_enum_decl(self, d: EnumDecl) -> None: 361 def is_int(val: Any) -> TypeGuard[int]: 362 return not isinstance(val, bool) and isinstance(val, int) 363 364 valid: Callable[[Any], bool] 365 increment: Callable[[Any, EnumItemDecl], Any] 366 default: Callable[[EnumItemDecl], Any] 367 368 match d.ty_ref.maybe_resolved_ty: 369 case ScalarType(_, ScalarKind.I8): 370 valid = lambda val: is_int(val) and -(2**7) <= val < 2**7 371 increment = lambda prev, item: prev + 1 372 default = lambda item: 0 373 case ScalarType(_, ScalarKind.I16): 374 valid = lambda val: is_int(val) and -(2**15) <= val < 2**15 375 increment = lambda prev, item: prev + 1 376 default = lambda item: 0 377 case ScalarType(_, ScalarKind.I32): 378 valid = lambda val: is_int(val) and -(2**31) <= val < 2**31 379 increment = lambda prev, item: prev + 1 380 default = lambda item: 0 381 case ScalarType(_, ScalarKind.I64): 382 valid = lambda val: is_int(val) and -(2**63) <= val < 2**63 383 increment = lambda prev, item: prev + 1 384 default = lambda item: 0 385 case ScalarType(_, ScalarKind.U8): 386 valid = lambda val: is_int(val) and 0 <= val < 2**8 387 increment = lambda prev, item: prev + 1 388 default = lambda item: 0 389 case ScalarType(_, ScalarKind.U16): 390 valid = lambda val: is_int(val) and 0 <= val < 2**16 391 increment = lambda prev, item: prev + 1 392 default = lambda item: 0 393 case ScalarType(_, ScalarKind.U32): 394 valid = lambda val: is_int(val) and 0 <= val < 2**32 395 increment = lambda prev, item: prev + 1 396 default = lambda item: 0 397 case ScalarType(_, ScalarKind.U64): 398 valid = lambda val: is_int(val) and 0 <= val < 2**64 399 increment = lambda prev, item: prev + 1 400 default = lambda item: 0 401 case ScalarType(_, ScalarKind.BOOL): 402 valid = lambda val: isinstance(val, bool) 403 increment = lambda prev, item: False 404 default = lambda item: False 405 case ScalarType(_, ScalarKind.F32): 406 valid = lambda val: isinstance(val, float) 407 increment = lambda prev, item: 0.0 408 default = lambda item: 0.0 409 case ScalarType(_, ScalarKind.F64): 410 valid = lambda val: isinstance(val, float) 411 increment = lambda prev, item: 0.0 412 default = lambda item: 0.0 413 case StringType(): 414 valid = lambda val: isinstance(val, str) 415 increment = lambda prev, item: item.name 416 default = lambda item: item.name 417 case None: 418 return 419 case _: 420 self.diag.emit(TypeUsageError(d.ty_ref)) 421 return 422 423 prev = None 424 for item in d.items: 425 if item.value is None: 426 item.value = default(item) if prev is None else increment(prev, item) 427 if not valid(item.value): 428 self.diag.emit(EnumValueError(item, d)) 429 prev = None 430 else: 431 prev = item.value 432 433 434class _CheckRecursiveInclusionPass(RecursiveDeclVisitor): 435 """Validates struct fields for type correctness and cycles.""" 436 437 diag: DiagnosticsManager 438 439 def __init__(self, diag: DiagnosticsManager): 440 self.diag = diag 441 self.type_table: dict[ 442 TypeDecl, 443 list[tuple[tuple[TypeDecl, TypeRefDecl], TypeDecl]], 444 ] = {} 445 446 def visit_package_group(self, g: PackageGroup) -> None: 447 self.type_table = {} 448 super().visit_package_group(g) 449 cycles = detect_cycles(self.type_table) 450 for cycle in cycles: 451 last, *other = cycle[::-1] 452 self.diag.emit(RecursiveReferenceError(last, other)) 453 454 def visit_enum_decl(self, d: EnumDecl) -> None: 455 self.type_table[d] = [] 456 457 def visit_iface_decl(self, d: IfaceDecl) -> None: 458 parent_iface_list = self.type_table.setdefault(d, []) 459 parent_iface_dict: dict[IfaceDecl, IfaceParentDecl] = {} 460 for parent in d.parents: 461 if (parent_ty := parent.ty_ref.maybe_resolved_ty) is None: 462 continue 463 if not isinstance(parent_ty, UserType): 464 self.diag.emit(TypeUsageError(parent.ty_ref)) 465 continue 466 if not isinstance(parent_iface := parent_ty.ty_decl, IfaceDecl): 467 self.diag.emit(TypeUsageError(parent.ty_ref)) 468 continue 469 parent_iface_list.append(((d, parent.ty_ref), parent_iface)) 470 prev = parent_iface_dict.setdefault(parent_iface, parent) 471 if prev != parent: 472 self.diag.emit( 473 DuplicateExtendsWarn( 474 d, 475 parent_iface, 476 loc=parent.ty_ref.loc, 477 prev_loc=prev.ty_ref.loc, 478 ) 479 ) 480 481 def visit_struct_decl(self, d: StructDecl) -> None: 482 type_list = self.type_table.setdefault(d, []) 483 for f in d.fields: 484 if isinstance(ty := f.ty_ref.maybe_resolved_ty, UserType): 485 type_list.append(((d, f.ty_ref), ty.ty_decl)) 486 487 def visit_union_decl(self, d: UnionDecl) -> None: 488 type_list = self.type_table.setdefault(d, []) 489 for i in d.fields: 490 if i.ty_ref is None: 491 continue 492 if isinstance(ty := i.ty_ref.maybe_resolved_ty, UserType): 493 type_list.append(((d, i.ty_ref), ty.ty_decl)) 494 495 496V = TypeVar("V") 497E = TypeVar("E") 498 499 500def detect_cycles(graph: dict[V, list[tuple[E, V]]]) -> list[list[E]]: 501 """Detects and returns all cycles in a directed graph. 502 503 Example: 504 ------- 505 >>> graph = { 506 "A": [("A.b_0", "B")], 507 "B": [("B.c_0", "C")], 508 "C": [("C.a_0", "A"), ("C.a_1", "A")], 509 } 510 >>> detect_cycles(graph) 511 [["A.b_0", "B.c_0", "C.a_0"], ["A.b_0", "B.c_0", "C.a_1"]] 512 """ 513 cycles: list[list[E]] = [] 514 515 order = {point: i for i, point in enumerate(graph)} 516 glist = [ 517 [(edge, order[child]) for edge, child in children] 518 for children in graph.values() 519 ] 520 visited = [False for _ in glist] 521 edges: list[E] = [] 522 523 def visit(i: int): 524 if i < k: 525 return 526 if visited[i]: 527 if i == k: 528 cycles.append(edges.copy()) 529 return 530 visited[i] = True 531 for edge, j in glist[i]: 532 edges.append(edge) 533 visit(j) 534 edges.pop() 535 visited[i] = False 536 537 for k in range(len(glist)): 538 visit(k) 539 540 return cycles