• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# coding=utf-8
2#
3# Copyright (c) 2025 Huawei Device Co., Ltd.
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
"""Implements the classic visitor pattern for core types and declarations.

In most cases, you only need to call `TypeVisitor.handle_type` or
`DeclVisitor.handle_decl`.

Design:
- Each visitable node implements an `_accept` method, which delegates to the
    corresponding `visit_xxx` method of the visitor.
- By default, a `visit_xxx` method bubbles up towards the root of the
    hierarchy by calling the `visit_*` method of the node's more general kind.
- `TypeVisitor.visit_type` and `DeclVisitor.visit_decl` are the roots of the
    hierarchy; both raise `NotImplementedError` unless overridden.
- `RecursiveDeclVisitor` additionally visits the "children" of each
    declaration (parameters, fields, type references, imports, ...) before
    bubbling up.
"""
29
import sys
from typing import TYPE_CHECKING, Generic, TypeVar

from typing_extensions import override
33
34if TYPE_CHECKING:
35    from taihe.semantics.declarations import (
36        CallbackTypeRefDecl,
37        Decl,
38        DeclarationImportDecl,
39        DeclarationRefDecl,
40        DeclProtocol,
41        EnumDecl,
42        EnumItemDecl,
43        GenericTypeRefDecl,
44        GlobFuncDecl,
45        IfaceDecl,
46        IfaceMethodDecl,
47        IfaceParentDecl,
48        ImportDecl,
49        LongTypeRefDecl,
50        PackageDecl,
51        PackageGroup,
52        PackageImportDecl,
53        PackageRefDecl,
54        ParamDecl,
55        ShortTypeRefDecl,
56        StructDecl,
57        StructFieldDecl,
58        TypeDecl,
59        TypeRefDecl,
60        UnionDecl,
61        UnionFieldDecl,
62    )
63    from taihe.semantics.types import (
64        ArrayType,
65        BuiltinType,
66        CallbackType,
67        EnumType,
68        GenericType,
69        IfaceType,
70        MapType,
71        OpaqueType,
72        OptionalType,
73        ScalarType,
74        SetType,
75        StringType,
76        StructType,
77        Type,
78        TypeProtocol,
79        UnionType,
80        UserType,
81        VectorType,
82    )
83
# Result type returned by the `visit_*` / `handle_*` methods; `TypeVisitor`
# and `DeclVisitor` are generic in it.
T = TypeVar("T")
85
86
class TypeVisitor(Generic[T]):
    """A base visitor for dispatching over the type hierarchy.

    Each `visit_*` method defaults to delegating to the `visit_*` method of
    the type's more general kind (e.g. `visit_array_type` ->
    `visit_generic_type` -> `visit_type`), so overriding a general method
    covers all of its more specific kinds. The root, `visit_type`, raises
    `NotImplementedError` unless overridden.

    This visitor never descends into nested types on its own; an override
    must call `handle_type` explicitly to do so.

    Usage:
    - Override specific `visit_*` methods for custom behavior.
    - Call `handle_type(t)` to start visiting a type.
    """

    visiting: "TypeProtocol | None" = None
    """The current node being visited. Only for debug use."""

    def handle_type(self, t: "TypeProtocol") -> T:
        """The entrance for visiting: dispatch `t` to its `visit_*` method.

        `self.visiting` is set to `t` for the duration of the call and then
        restored, so nested calls behave like a stack.

        If the visit raises an `Exception`, a line naming this visitor and the
        node is printed to stderr and the exception is re-raised unchanged.
        Because every nested `handle_type` call does this, a failure deep in
        a traversal prints one line per enclosing node.
        """
        previous = self.visiting
        self.visiting = t
        try:
            return t._accept(self)  # type: ignore
        except Exception:
            # Not a bare `except:`: KeyboardInterrupt / SystemExit are not
            # internal errors and should propagate without this message.
            # stderr keeps the message next to the traceback and out of any
            # output the program writes to stdout.
            print(
                f"Internal error from {self.__class__.__name__} while handling {self.visiting}",
                file=sys.stderr,
            )
            raise
        finally:
            self.visiting = previous

    def visit_type(self, t: "Type") -> T:
        """The fallback method which handles the most general type.

        Note that `TypeRef` is NOT a `Type`.
        """
        raise NotImplementedError

    ### Built-in types ###

    def visit_builtin_type(self, t: "BuiltinType") -> T:
        return self.visit_type(t)

    def visit_scalar_type(self, t: "ScalarType") -> T:
        return self.visit_builtin_type(t)

    def visit_string_type(self, t: "StringType") -> T:
        return self.visit_builtin_type(t)

    def visit_opaque_type(self, t: "OpaqueType") -> T:
        return self.visit_builtin_type(t)

    ### UserTypes ###

    def visit_user_type(self, t: "UserType") -> T:
        return self.visit_type(t)

    def visit_enum_type(self, t: "EnumType") -> T:
        return self.visit_user_type(t)

    def visit_struct_type(self, t: "StructType") -> T:
        return self.visit_user_type(t)

    def visit_union_type(self, t: "UnionType") -> T:
        return self.visit_user_type(t)

    def visit_iface_type(self, t: "IfaceType") -> T:
        return self.visit_user_type(t)

    ### Generic Types ###

    def visit_callback_type(self, t: "CallbackType") -> T:
        # Falls back directly to `visit_type`, NOT to `visit_generic_type`.
        return self.visit_type(t)

    def visit_generic_type(self, t: "GenericType") -> T:
        return self.visit_type(t)

    def visit_array_type(self, t: "ArrayType") -> T:
        return self.visit_generic_type(t)

    def visit_optional_type(self, t: "OptionalType") -> T:
        return self.visit_generic_type(t)

    def visit_vector_type(self, t: "VectorType") -> T:
        return self.visit_generic_type(t)

    def visit_map_type(self, t: "MapType") -> T:
        return self.visit_generic_type(t)

    def visit_set_type(self, t: "SetType") -> T:
        return self.visit_generic_type(t)
175
176
class DeclVisitor(Generic[T]):
    """A base visitor for dispatching over the declaration hierarchy.

    Each `visit_*` method defaults to delegating to the `visit_*` method of
    the declaration's more general kind (e.g. `visit_struct_decl` ->
    `visit_type_decl` -> `visit_decl`). The root, `visit_decl`, raises
    `NotImplementedError` unless overridden. `visit_package_group` has no
    fallback and raises `NotImplementedError` directly.

    This base class does NOT descend into sub-declarations on its own; use
    `RecursiveDeclVisitor`, or call `handle_decl` from an override, for that.

    Usage:
    - Override `visit_*` methods for specific behavior.
    - Call `handle_decl(d)` to start visiting a declaration.
    """

    visiting: "DeclProtocol | None" = None
    """The current node being visited. Only for debug use."""

    def handle_decl(self, d: "DeclProtocol") -> T:
        """The entrance for visiting anything "acceptable".

        `self.visiting` is set to `d` for the duration of the call and then
        restored, so nested calls behave like a stack.

        If the visit raises an `Exception`, a line naming this visitor and the
        node is printed to stderr and the exception is re-raised unchanged.
        Because every nested `handle_decl` call does this, a failure deep in
        a traversal prints one line per enclosing node.
        """
        previous = self.visiting
        self.visiting = d
        try:
            return d._accept(self)  # type: ignore
        except Exception:
            # Not a bare `except:`: KeyboardInterrupt / SystemExit are not
            # internal errors and should propagate without this message.
            # stderr keeps the message next to the traceback and out of any
            # output the program writes to stdout.
            print(
                f"Internal error from {self.__class__.__name__} while handling {self.visiting}",
                file=sys.stderr,
            )
            raise
        finally:
            self.visiting = previous

    def visit_decl(self, d: "Decl") -> T:
        """The fallback method which handles the most general cases."""
        raise NotImplementedError

    def visit_param_decl(self, d: "ParamDecl") -> T:
        return self.visit_decl(d)

    ### Type References ###

    def visit_type_ref_decl(self, d: "TypeRefDecl") -> T:
        return self.visit_decl(d)

    def visit_short_type_ref_decl(self, d: "ShortTypeRefDecl") -> T:
        return self.visit_type_ref_decl(d)

    def visit_long_type_ref_decl(self, d: "LongTypeRefDecl") -> T:
        return self.visit_type_ref_decl(d)

    def visit_generic_type_ref_decl(self, d: "GenericTypeRefDecl") -> T:
        return self.visit_type_ref_decl(d)

    def visit_callback_type_ref_decl(self, d: "CallbackTypeRefDecl") -> T:
        return self.visit_type_ref_decl(d)

    ### Other References ###

    def visit_package_ref_decl(self, d: "PackageRefDecl") -> T:
        return self.visit_decl(d)

    def visit_declaration_ref_decl(self, d: "DeclarationRefDecl") -> T:
        return self.visit_decl(d)

    ### Imports ###

    def visit_import_decl(self, d: "ImportDecl") -> T:
        return self.visit_decl(d)

    def visit_package_import_decl(self, d: "PackageImportDecl") -> T:
        return self.visit_import_decl(d)

    def visit_decl_import_decl(self, d: "DeclarationImportDecl") -> T:
        return self.visit_import_decl(d)

    ### Package Level Function ###

    def visit_glob_func_decl(self, d: "GlobFuncDecl") -> T:
        return self.visit_decl(d)

    ### Package Level Type ###

    def visit_type_decl(self, d: "TypeDecl") -> T:
        return self.visit_decl(d)

    ### Enum ###

    def visit_enum_item_decl(self, d: "EnumItemDecl") -> T:
        return self.visit_decl(d)

    def visit_enum_decl(self, d: "EnumDecl") -> T:
        return self.visit_type_decl(d)

    ### Struct ###

    def visit_struct_field_decl(self, d: "StructFieldDecl") -> T:
        return self.visit_decl(d)

    def visit_struct_decl(self, d: "StructDecl") -> T:
        return self.visit_type_decl(d)

    ### Union ###

    def visit_union_field_decl(self, d: "UnionFieldDecl") -> T:
        return self.visit_decl(d)

    def visit_union_decl(self, d: "UnionDecl") -> T:
        return self.visit_type_decl(d)

    ### Interface ###

    def visit_iface_parent_decl(self, d: "IfaceParentDecl") -> T:
        return self.visit_decl(d)

    # Note: named `..._func_decl` although it receives an `IfaceMethodDecl`.
    def visit_iface_func_decl(self, d: "IfaceMethodDecl") -> T:
        return self.visit_decl(d)

    def visit_iface_decl(self, d: "IfaceDecl") -> T:
        return self.visit_type_decl(d)

    ### Package ###

    def visit_package_decl(self, p: "PackageDecl") -> T:
        return self.visit_decl(p)

    def visit_package_group(self, g: "PackageGroup") -> T:
        # No fallback to `visit_decl`: subclasses must handle groups explicitly.
        raise NotImplementedError
300
301
class RecursiveDeclVisitor(DeclVisitor[None]):
    """A declaration visitor that walks the whole tree below the starting node.

    Every `visit_*` override below first passes the node's sub-declarations
    (parameters, type references, fields, items, imports, ...) to
    `handle_decl`, then forwards the node to the `visit_*` method of its more
    general kind, ending at `visit_decl`, which is a no-op here.

    To hook into the traversal, override `visit_decl` or a specific `visit_*`
    method. An override that does not call the inherited implementation skips
    that node's children.
    """

    def _handle_if_present(self, child: "DeclProtocol | None") -> None:
        """Visit `child` only when it is set (truthy), e.g. an optional return type."""
        if not child:
            return
        self.handle_decl(child)

    @override
    def visit_decl(self, d: "Decl") -> None:
        """End of every bubbling chain; deliberately does nothing."""
        return None

    @override
    def visit_param_decl(self, d: "ParamDecl") -> None:
        param_type = d.ty_ref
        self.handle_decl(param_type)
        return self.visit_decl(d)

    ### Type References ###

    @override
    def visit_type_ref_decl(self, d: "TypeRefDecl") -> None:
        return self.visit_decl(d)

    @override
    def visit_short_type_ref_decl(self, d: "ShortTypeRefDecl") -> None:
        return self.visit_type_ref_decl(d)

    @override
    def visit_long_type_ref_decl(self, d: "LongTypeRefDecl") -> None:
        return self.visit_type_ref_decl(d)

    @override
    def visit_generic_type_ref_decl(self, d: "GenericTypeRefDecl") -> None:
        # Type arguments, e.g. the element type of an array reference.
        for arg_ref in d.args_ty_ref:
            self.handle_decl(arg_ref)
        return self.visit_type_ref_decl(d)

    @override
    def visit_callback_type_ref_decl(self, d: "CallbackTypeRefDecl") -> None:
        for param in d.params:
            self.handle_decl(param)
        self._handle_if_present(d.return_ty_ref)
        return self.visit_type_ref_decl(d)

    ### Other References ###

    @override
    def visit_package_ref_decl(self, d: "PackageRefDecl") -> None:
        return self.visit_decl(d)

    @override
    def visit_declaration_ref_decl(self, d: "DeclarationRefDecl") -> None:
        owning_pkg = d.pkg_ref
        self.handle_decl(owning_pkg)
        return self.visit_decl(d)

    ### Imports ###

    @override
    def visit_import_decl(self, d: "ImportDecl") -> None:
        return self.visit_decl(d)

    @override
    def visit_package_import_decl(self, d: "PackageImportDecl") -> None:
        imported_pkg = d.pkg_ref
        self.handle_decl(imported_pkg)
        return self.visit_import_decl(d)

    @override
    def visit_decl_import_decl(self, d: "DeclarationImportDecl") -> None:
        imported_decl = d.decl_ref
        self.handle_decl(imported_decl)
        return self.visit_import_decl(d)

    ### Functions ###

    @override
    def visit_glob_func_decl(self, d: "GlobFuncDecl") -> None:
        for param in d.params:
            self.handle_decl(param)
        self._handle_if_present(d.return_ty_ref)
        return self.visit_decl(d)

    ### Type "Decl" ###

    @override
    def visit_type_decl(self, d: "TypeDecl") -> None:
        return self.visit_decl(d)

    ### Enum ###

    @override
    def visit_enum_item_decl(self, d: "EnumItemDecl") -> None:
        return self.visit_decl(d)

    @override
    def visit_enum_decl(self, d: "EnumDecl") -> None:
        underlying = d.ty_ref
        self.handle_decl(underlying)
        for item in d.items:
            self.handle_decl(item)
        return self.visit_type_decl(d)

    ### Struct ###

    @override
    def visit_struct_field_decl(self, d: "StructFieldDecl") -> None:
        field_type = d.ty_ref
        self.handle_decl(field_type)
        return self.visit_decl(d)

    @override
    def visit_struct_decl(self, d: "StructDecl") -> None:
        for field in d.fields:
            self.handle_decl(field)
        return self.visit_type_decl(d)

    ### Union ###

    @override
    def visit_union_field_decl(self, d: "UnionFieldDecl") -> None:
        # A union field's type reference is optional.
        self._handle_if_present(d.ty_ref)
        return self.visit_decl(d)

    @override
    def visit_union_decl(self, d: "UnionDecl") -> None:
        for field in d.fields:
            self.handle_decl(field)
        return self.visit_type_decl(d)

    ### Interface ###

    @override
    def visit_iface_parent_decl(self, d: "IfaceParentDecl") -> None:
        parent_type = d.ty_ref
        self.handle_decl(parent_type)
        return self.visit_decl(d)

    @override
    def visit_iface_func_decl(self, d: "IfaceMethodDecl") -> None:
        for param in d.params:
            self.handle_decl(param)
        self._handle_if_present(d.return_ty_ref)
        return self.visit_decl(d)

    @override
    def visit_iface_decl(self, d: "IfaceDecl") -> None:
        for parent in d.parents:
            self.handle_decl(parent)
        for method in d.methods:
            self.handle_decl(method)
        return self.visit_type_decl(d)

    ### Package ###

    @override
    def visit_package_decl(self, p: "PackageDecl") -> None:
        # Imports first (package imports, then declaration imports), then the
        # package's own declarations.
        for pkg_import in p.pkg_imports:
            self.handle_decl(pkg_import)
        for decl_import in p.decl_imports:
            self.handle_decl(decl_import)
        for member in p.declarations:
            self.handle_decl(member)
        return self.visit_decl(p)

    @override
    def visit_package_group(self, g: "PackageGroup") -> None:
        for package in g.packages:
            self.handle_decl(package)
        return None
484        return self.visit_decl(p)
485
486    @override
487    def visit_package_group(self, g: "PackageGroup") -> None:
488        for i in g.packages:
489            self.handle_decl(i)