# coding=utf-8
#
# Copyright (c) 2025 Huawei Device Co., Ltd.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implements the classic visitor pattern for core types and declarations.

In most cases, you need to call `TypeVisitor.handle_type` or
`DeclVisitor.handle_decl`.

Design:
- Each visitable node implements the `_accept` method, which delegates to the
  corresponding `visit_xxx` method of the visitor passed to it.
- In `TypeVisitor` and `DeclVisitor`, each `visit_xxx` method by default
  forwards to the `visit_xxx` method of its super kind, bubbling up towards
  the root of the hierarchy.
- `TypeVisitor.visit_type` and `DeclVisitor.visit_decl` are the roots of the
  hierarchies; they raise `NotImplementedError` unless overridden.
- `RecursiveDeclVisitor` additionally visits the child nodes of each
  declaration (through `handle_decl`) before bubbling up.
"""

import sys
from typing import TYPE_CHECKING, Generic, TypeVar

from typing_extensions import override

if TYPE_CHECKING:
    from taihe.semantics.declarations import (
        CallbackTypeRefDecl,
        Decl,
        DeclarationImportDecl,
        DeclarationRefDecl,
        DeclProtocol,
        EnumDecl,
        EnumItemDecl,
        GenericTypeRefDecl,
        GlobFuncDecl,
        IfaceDecl,
        IfaceMethodDecl,
        IfaceParentDecl,
        ImportDecl,
        LongTypeRefDecl,
        PackageDecl,
        PackageGroup,
        PackageImportDecl,
        PackageRefDecl,
        ParamDecl,
        ShortTypeRefDecl,
        StructDecl,
        StructFieldDecl,
        TypeDecl,
        TypeRefDecl,
        UnionDecl,
        UnionFieldDecl,
    )
    from taihe.semantics.types import (
        ArrayType,
        BuiltinType,
        CallbackType,
        EnumType,
        GenericType,
        IfaceType,
        MapType,
        OpaqueType,
        OptionalType,
        ScalarType,
        SetType,
        StringType,
        StructType,
        Type,
        TypeProtocol,
        UnionType,
        UserType,
        VectorType,
    )

T = TypeVar("T")


def _report_internal_error(visitor: object, node: object) -> None:
    """Print which visitor failed on which node, as a debugging breadcrumb.

    Written to stderr so that it never mixes with regular program output.
    """
    print(
        f"Internal error from {visitor.__class__.__name__} while handling {node}",
        file=sys.stderr,
    )


class TypeVisitor(Generic[T]):
    """A base visitor for dispatching over the type hierarchy.

    This visitor allows handling different types by defining specific visit
    methods. Unlike `RecursiveDeclVisitor`, it does NOT visit any nested nodes:
    each `visit_*` method only forwards to the method of its super type.

    Usage:
    - Override specific `visit_*` methods for custom behavior.
    - Call `handle_type(t)` to start visiting a type.
    """

    visiting: "TypeProtocol | None" = None
    """The current node being visited. Only for debug use."""

    def handle_type(self, t: "TypeProtocol") -> T:
        """The entrance for visiting.

        Dispatches through `t._accept(self)` and returns its result. While the
        call is in progress `self.visiting` is `t`; the previous value is
        restored afterwards, even on error.

        Raises:
            Whatever the dispatched `visit_*` method raises. For ordinary
            exceptions (subclasses of `Exception`) a diagnostic line naming
            this visitor and `t` is printed to stderr before re-raising;
            `KeyboardInterrupt`/`SystemExit` propagate silently.
        """
        r = self.visiting
        self.visiting = t
        try:
            return t._accept(self)  # type: ignore
        except Exception:
            # When `handle_*` calls are nested, every level prints one line,
            # giving a trail from the innermost failing node outwards.
            _report_internal_error(self, self.visiting)
            raise
        finally:
            self.visiting = r

    def visit_type(self, t: "Type") -> T:
        """The fallback method which handles the most general type.

        Note that `TypeRef` is NOT a `Type`.

        Raises:
            NotImplementedError: unless overridden by a subclass.
        """
        raise NotImplementedError

    ### Built-in types ###

    def visit_builtin_type(self, t: "BuiltinType") -> T:
        return self.visit_type(t)

    def visit_scalar_type(self, t: "ScalarType") -> T:
        return self.visit_builtin_type(t)

    def visit_string_type(self, t: "StringType") -> T:
        return self.visit_builtin_type(t)

    def visit_opaque_type(self, t: "OpaqueType") -> T:
        return self.visit_builtin_type(t)

    ### UserTypes ###

    def visit_user_type(self, t: "UserType") -> T:
        return self.visit_type(t)

    def visit_enum_type(self, t: "EnumType") -> T:
        return self.visit_user_type(t)

    def visit_struct_type(self, t: "StructType") -> T:
        return self.visit_user_type(t)

    def visit_union_type(self, t: "UnionType") -> T:
        return self.visit_user_type(t)

    def visit_iface_type(self, t: "IfaceType") -> T:
        return self.visit_user_type(t)

    ### Generic Types ###

    def visit_callback_type(self, t: "CallbackType") -> T:
        # Callbacks bubble up directly to `visit_type`, not through
        # `visit_generic_type`.
        return self.visit_type(t)

    def visit_generic_type(self, t: "GenericType") -> T:
        return self.visit_type(t)

    def visit_array_type(self, t: "ArrayType") -> T:
        return self.visit_generic_type(t)

    def visit_optional_type(self, t: "OptionalType") -> T:
        return self.visit_generic_type(t)

    def visit_vector_type(self, t: "VectorType") -> T:
        return self.visit_generic_type(t)

    def visit_map_type(self, t: "MapType") -> T:
        return self.visit_generic_type(t)

    def visit_set_type(self, t: "SetType") -> T:
        return self.visit_generic_type(t)


class DeclVisitor(Generic[T]):
    """A base visitor for dispatching over the declaration hierarchy.

    Each `visit_*` method only forwards to the method of its super kind; it
    does NOT descend into nested declarations. Use `RecursiveDeclVisitor` for
    a full-tree traversal.

    Usage:
    - Override `visit_*` methods for specific behavior.
    - Call `handle_decl(d)` to start visiting a declaration.
    """

    visiting: "DeclProtocol | None" = None
    """The current node being visited. Only for debug use."""

    def handle_decl(self, d: "DeclProtocol") -> T:
        """The entrance for visiting anything "acceptable".

        Dispatches through `d._accept(self)` and returns its result. While the
        call is in progress `self.visiting` is `d`; the previous value is
        restored afterwards, even on error.

        Raises:
            Whatever the dispatched `visit_*` method raises. For ordinary
            exceptions (subclasses of `Exception`) a diagnostic line naming
            this visitor and `d` is printed to stderr before re-raising;
            `KeyboardInterrupt`/`SystemExit` propagate silently.
        """
        r = self.visiting
        self.visiting = d
        try:
            return d._accept(self)  # type: ignore
        except Exception:
            # When `handle_*` calls are nested, every level prints one line,
            # giving a trail from the innermost failing node outwards.
            _report_internal_error(self, self.visiting)
            raise
        finally:
            self.visiting = r

    def visit_decl(self, d: "Decl") -> T:
        """The fallback method which handles the most general cases.

        Raises:
            NotImplementedError: unless overridden by a subclass.
        """
        raise NotImplementedError

    def visit_param_decl(self, d: "ParamDecl") -> T:
        return self.visit_decl(d)

    ### Type References ###

    def visit_type_ref_decl(self, d: "TypeRefDecl") -> T:
        return self.visit_decl(d)

    def visit_short_type_ref_decl(self, d: "ShortTypeRefDecl") -> T:
        return self.visit_type_ref_decl(d)

    def visit_long_type_ref_decl(self, d: "LongTypeRefDecl") -> T:
        return self.visit_type_ref_decl(d)

    def visit_generic_type_ref_decl(self, d: "GenericTypeRefDecl") -> T:
        return self.visit_type_ref_decl(d)

    def visit_callback_type_ref_decl(self, d: "CallbackTypeRefDecl") -> T:
        return self.visit_type_ref_decl(d)

    ### Other References ###

    def visit_package_ref_decl(self, d: "PackageRefDecl") -> T:
        return self.visit_decl(d)

    def visit_declaration_ref_decl(self, d: "DeclarationRefDecl") -> T:
        return self.visit_decl(d)

    ### Imports ###

    def visit_import_decl(self, d: "ImportDecl") -> T:
        return self.visit_decl(d)

    def visit_package_import_decl(self, d: "PackageImportDecl") -> T:
        return self.visit_import_decl(d)

    def visit_decl_import_decl(self, d: "DeclarationImportDecl") -> T:
        return self.visit_import_decl(d)

    ### Package Level Function ###

    def visit_glob_func_decl(self, d: "GlobFuncDecl") -> T:
        return self.visit_decl(d)

    ### Package Level Type ###

    def visit_type_decl(self, d: "TypeDecl") -> T:
        return self.visit_decl(d)

    ### Enum ###

    def visit_enum_item_decl(self, d: "EnumItemDecl") -> T:
        return self.visit_decl(d)

    def visit_enum_decl(self, d: "EnumDecl") -> T:
        return self.visit_type_decl(d)

    ### Struct ###

    def visit_struct_field_decl(self, d: "StructFieldDecl") -> T:
        return self.visit_decl(d)

    def visit_struct_decl(self, d: "StructDecl") -> T:
        return self.visit_type_decl(d)

    ### Union ###

    def visit_union_field_decl(self, d: "UnionFieldDecl") -> T:
        return self.visit_decl(d)

    def visit_union_decl(self, d: "UnionDecl") -> T:
        return self.visit_type_decl(d)

    ### Interface ###

    def visit_iface_parent_decl(self, d: "IfaceParentDecl") -> T:
        return self.visit_decl(d)

    def visit_iface_func_decl(self, d: "IfaceMethodDecl") -> T:
        return self.visit_decl(d)

    def visit_iface_decl(self, d: "IfaceDecl") -> T:
        return self.visit_type_decl(d)

    ### Package ###

    def visit_package_decl(self, p: "PackageDecl") -> T:
        return self.visit_decl(p)

    def visit_package_group(self, g: "PackageGroup") -> T:
        # A package group has no super kind to bubble up to.
        raise NotImplementedError


class RecursiveDeclVisitor(DeclVisitor[None]):
    """A visitor that recursively traverses all declarations and their sub-declarations.

    Every `visit_*` method first calls `handle_decl` on the node's children
    (type references, parameters, fields, items, methods, imports, ...) and
    then bubbles up to the super kind, so children are visited before their
    parent. `visit_decl` is a no-op by default.

    This class is useful for full-tree traversal scenarios.
    """

    @override
    def visit_decl(self, d: "Decl") -> None:
        pass

    @override
    def visit_param_decl(self, d: "ParamDecl") -> None:
        self.handle_decl(d.ty_ref)

        return self.visit_decl(d)

    ### Type References ###

    @override
    def visit_type_ref_decl(self, d: "TypeRefDecl") -> None:
        return self.visit_decl(d)

    @override
    def visit_short_type_ref_decl(self, d: "ShortTypeRefDecl") -> None:
        return self.visit_type_ref_decl(d)

    @override
    def visit_long_type_ref_decl(self, d: "LongTypeRefDecl") -> None:
        return self.visit_type_ref_decl(d)

    @override
    def visit_generic_type_ref_decl(self, d: "GenericTypeRefDecl") -> None:
        for i in d.args_ty_ref:
            self.handle_decl(i)

        return self.visit_type_ref_decl(d)

    @override
    def visit_callback_type_ref_decl(self, d: "CallbackTypeRefDecl") -> None:
        for i in d.params:
            self.handle_decl(i)

        if d.return_ty_ref:
            self.handle_decl(d.return_ty_ref)

        return self.visit_type_ref_decl(d)

    ### Other References ###

    @override
    def visit_package_ref_decl(self, d: "PackageRefDecl") -> None:
        return self.visit_decl(d)

    @override
    def visit_declaration_ref_decl(self, d: "DeclarationRefDecl") -> None:
        self.handle_decl(d.pkg_ref)

        return self.visit_decl(d)

    ### Imports ###

    @override
    def visit_import_decl(self, d: "ImportDecl") -> None:
        return self.visit_decl(d)

    @override
    def visit_package_import_decl(self, d: "PackageImportDecl") -> None:
        self.handle_decl(d.pkg_ref)

        return self.visit_import_decl(d)

    @override
    def visit_decl_import_decl(self, d: "DeclarationImportDecl") -> None:
        self.handle_decl(d.decl_ref)

        return self.visit_import_decl(d)

    ### Functions ###

    @override
    def visit_glob_func_decl(self, d: "GlobFuncDecl") -> None:
        for i in d.params:
            self.handle_decl(i)

        if d.return_ty_ref:
            self.handle_decl(d.return_ty_ref)

        return self.visit_decl(d)

    ### Type "Decl" ###

    @override
    def visit_type_decl(self, d: "TypeDecl") -> None:
        return self.visit_decl(d)

    ### Enum ###

    @override
    def visit_enum_item_decl(self, d: "EnumItemDecl") -> None:
        return self.visit_decl(d)

    @override
    def visit_enum_decl(self, d: "EnumDecl") -> None:
        self.handle_decl(d.ty_ref)

        for i in d.items:
            self.handle_decl(i)

        return self.visit_type_decl(d)

    ### Struct ###

    @override
    def visit_struct_field_decl(self, d: "StructFieldDecl") -> None:
        self.handle_decl(d.ty_ref)

        return self.visit_decl(d)

    @override
    def visit_struct_decl(self, d: "StructDecl") -> None:
        for i in d.fields:
            self.handle_decl(i)

        return self.visit_type_decl(d)

    ### Union ###

    @override
    def visit_union_field_decl(self, d: "UnionFieldDecl") -> None:
        # A union field may carry no type reference.
        if d.ty_ref:
            self.handle_decl(d.ty_ref)

        return self.visit_decl(d)

    @override
    def visit_union_decl(self, d: "UnionDecl") -> None:
        for i in d.fields:
            self.handle_decl(i)

        return self.visit_type_decl(d)

    ### Interface ###

    @override
    def visit_iface_parent_decl(self, d: "IfaceParentDecl") -> None:
        self.handle_decl(d.ty_ref)

        return self.visit_decl(d)

    @override
    def visit_iface_func_decl(self, d: "IfaceMethodDecl") -> None:
        for i in d.params:
            self.handle_decl(i)

        if d.return_ty_ref:
            self.handle_decl(d.return_ty_ref)

        return self.visit_decl(d)

    @override
    def visit_iface_decl(self, d: "IfaceDecl") -> None:
        for i in d.parents:
            self.handle_decl(i)

        for i in d.methods:
            self.handle_decl(i)

        return self.visit_type_decl(d)

    ### Package ###

    @override
    def visit_package_decl(self, p: "PackageDecl") -> None:
        for i in p.pkg_imports:
            self.handle_decl(i)

        for i in p.decl_imports:
            self.handle_decl(i)

        for i in p.declarations:
            self.handle_decl(i)

        return self.visit_decl(p)

    @override
    def visit_package_group(self, g: "PackageGroup") -> None:
        for i in g.packages:
            self.handle_decl(i)