• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# coding=utf-8
2#
3# Copyright (c) 2025 Huawei Device Co., Ltd.
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16from collections.abc import Callable, Iterable
17from typing import Any, TypeGuard, TypeVar
18
19from typing_extensions import override
20
21from taihe.semantics.declarations import (
22    CallbackTypeRefDecl,
23    DeclarationRefDecl,
24    EnumDecl,
25    EnumItemDecl,
26    GenericTypeRefDecl,
27    GlobFuncDecl,
28    IfaceDecl,
29    IfaceMethodDecl,
30    IfaceParentDecl,
31    LongTypeRefDecl,
32    NamedDecl,
33    PackageDecl,
34    PackageGroup,
35    PackageRefDecl,
36    ShortTypeRefDecl,
37    StructDecl,
38    TypeDecl,
39    TypeRefDecl,
40    UnionDecl,
41)
42from taihe.semantics.types import (
43    BUILTIN_GENERICS,
44    BUILTIN_TYPES,
45    CallbackType,
46    ScalarKind,
47    ScalarType,
48    StringType,
49    Type,
50    UserType,
51)
52from taihe.semantics.visitor import RecursiveDeclVisitor
53from taihe.utils.diagnostics import DiagnosticsManager
54from taihe.utils.exceptions import (
55    DeclarationNotInScopeError,
56    DeclNotExistError,
57    DeclRedefError,
58    DuplicateExtendsWarn,
59    EnumValueError,
60    GenericArgumentsError,
61    NotATypeError,
62    PackageNotExistError,
63    PackageNotInScopeError,
64    RecursiveReferenceError,
65    SymbolConflictWithNamespaceError,
66    TypeUsageError,
67)
68
69
def analyze_semantics(pg: PackageGroup, diag: DiagnosticsManager):
    """Run every semantic-analysis pass over ``pg``, reporting through ``diag``.

    The namespace-conflict check runs first, then the visitor passes in a
    fixed order.  Import/type resolution must come before the passes that
    read ``maybe_resolved_ty`` (enum type checking and recursion detection).
    """
    _check_decl_confilct_with_namespace(pg, diag)
    ordered_passes = (
        _ResolveImportsPass,
        _CheckFieldNameCollisionErrorPass,
        _CheckEnumTypePass,
        _CheckRecursiveInclusionPass,
    )
    for make_pass in ordered_passes:
        make_pass(diag).handle_decl(pg)
77
78
def _check_decl_confilct_with_namespace(
    pg: PackageGroup,
    diag: DiagnosticsManager,
):
    """Report declarations whose qualified name collides with a namespace.

    A package named ``a.b.c`` opens the namespaces ``a.b.c``, ``a.b`` and
    ``a``.  A declaration ``x`` inside package ``p`` is an error when
    ``p.x`` is one of those namespaces; the error receives every package
    that opens the colliding namespace.
    """
    owners_by_namespace: dict[str, list[PackageDecl]] = {}
    for package in pg.packages:
        segments = package.name.split(".")
        # Longest prefix first: "a.b.c" -> "a.b.c", "a.b", "a".
        for depth in range(len(segments), 0, -1):
            prefix = ".".join(segments[:depth])
            owners_by_namespace.setdefault(prefix, []).append(package)

    for package in pg.packages:
        for decl in package.declarations:
            qualified = f"{package.name}.{decl.name}"
            owners = owners_by_namespace.get(qualified)
            if owners:
                diag.emit(SymbolConflictWithNamespaceError(decl, qualified, owners))
101
102
class _ResolveImportsPass(RecursiveDeclVisitor):
    """Resolve package imports, declaration imports and type references.

    Every reference is resolved at most once: its ``is_resolved`` flag is set
    before any lookup, so visiting it again (for example because another
    reference depends on it) never repeats a diagnostic.  When a lookup
    fails an error is emitted and the ``maybe_resolved_*`` attribute is left
    unset.
    """

    diag: DiagnosticsManager

    def __init__(self, diag: DiagnosticsManager):
        # Only non-None while visit_package_group / visit_package_decl run.
        self._current_pkg_group: PackageGroup | None = None
        self._current_pkg: PackageDecl | None = None
        self.diag = diag

    @property
    def pkg(self) -> PackageDecl:
        """The package currently being visited.

        Raises:
            RuntimeError: if accessed outside ``visit_package_decl``.
        """
        if self._current_pkg is None:
            raise RuntimeError("_ResolveImportsPass.pkg accessed outside a package")
        return self._current_pkg

    @property
    def pkg_group(self) -> PackageGroup:
        """The package group currently being visited.

        Raises:
            RuntimeError: if accessed outside ``visit_package_group``.
        """
        if self._current_pkg_group is None:
            raise RuntimeError(
                "_ResolveImportsPass.pkg_group accessed outside a package group"
            )
        return self._current_pkg_group

    @override
    def visit_package_decl(self, p: PackageDecl) -> None:
        self._current_pkg = p
        try:
            super().visit_package_decl(p)
        finally:
            self._current_pkg = None

    @override
    def visit_package_group(self, g: PackageGroup) -> None:
        self._current_pkg_group = g
        try:
            super().visit_package_group(g)
        finally:
            self._current_pkg_group = None

    @override
    def visit_package_ref_decl(self, d: PackageRefDecl) -> None:
        """Bind ``d`` to the package named ``d.symbol`` in the current group."""
        if d.is_resolved:
            return
        d.is_resolved = True

        pkg = self.pkg_group.lookup(d.symbol)
        if pkg is None:
            self.diag.emit(PackageNotExistError(d.symbol, loc=d.loc))
            return

        d.maybe_resolved_pkg = pkg

    @override
    def visit_declaration_ref_decl(self, d: DeclarationRefDecl) -> None:
        """Bind ``d`` to declaration ``d.symbol`` of the package ``d.pkg_ref``."""
        if d.is_resolved:
            return
        d.is_resolved = True

        # The package reference must be resolved before it can be searched.
        self.handle_decl(d.pkg_ref)
        pkg = d.pkg_ref.maybe_resolved_pkg
        if pkg is None:
            # Already reported while resolving the package reference.
            return

        decl = pkg.lookup(d.symbol)
        if decl is None:
            self.diag.emit(DeclNotExistError(d.symbol, loc=d.loc))
            return

        d.maybe_resolved_decl = decl

    @override
    def visit_long_type_ref_decl(self, d: LongTypeRefDecl) -> None:
        """Resolve a qualified type reference ``pkname.symbol``.

        ``pkname`` must name a package imported by the current package, and
        ``symbol`` must be a type declaration in that package.
        """
        if d.is_resolved:
            return
        d.is_resolved = True

        pkg_import = self.pkg.lookup_pkg_import(d.pkname)
        if pkg_import is None:
            self.diag.emit(PackageNotInScopeError(d.pkname, loc=d.loc))
            return

        # Resolve the import on demand so the result does not depend on
        # whether the import was visited before this reference.  This is a
        # no-op when it is already resolved.
        self.handle_decl(pkg_import.pkg_ref)
        pkg = pkg_import.pkg_ref.maybe_resolved_pkg
        if pkg is None:
            # Already reported while resolving the import.
            return

        decl = pkg.lookup(d.symbol)
        if decl is None:
            self.diag.emit(DeclNotExistError(d.symbol, loc=d.loc))
            return

        if not isinstance(decl, TypeDecl):
            self.diag.emit(NotATypeError(d.symbol, loc=d.loc))
            return

        d.maybe_resolved_ty = decl.as_type(d)

    @override
    def visit_short_type_ref_decl(self, d: ShortTypeRefDecl) -> None:
        """Resolve an unqualified type reference ``symbol``.

        Lookup order: builtin types, then declarations of the current
        package, then declarations imported into the current package.
        """
        if d.is_resolved:
            return
        d.is_resolved = True

        builder = BUILTIN_TYPES.get(d.symbol)
        if builder is not None:
            d.maybe_resolved_ty = builder(d)
            return

        decl = self.pkg.lookup(d.symbol)
        if decl is not None:
            if not isinstance(decl, TypeDecl):
                self.diag.emit(NotATypeError(d.symbol, loc=d.loc))
                return
            d.maybe_resolved_ty = decl.as_type(d)
            return

        decl_import = self.pkg.lookup_decl_import(d.symbol)
        if decl_import is None:
            self.diag.emit(DeclarationNotInScopeError(d.symbol, loc=d.loc))
            return

        # Resolve the import on demand (no-op when already resolved).
        self.handle_decl(decl_import.decl_ref)
        decl = decl_import.decl_ref.maybe_resolved_decl
        if decl is None:
            # Already reported while resolving the import.
            return

        if not isinstance(decl, TypeDecl):
            self.diag.emit(NotATypeError(d.symbol, loc=d.loc))
            return

        d.maybe_resolved_ty = decl.as_type(d)

    @override
    def visit_generic_type_ref_decl(self, d: GenericTypeRefDecl) -> None:
        """Resolve ``symbol<args...>`` against the builtin generic types."""
        if d.is_resolved:
            return
        d.is_resolved = True

        # The base visitor is expected to visit (and so resolve) the type
        # argument references, whose results are read below.
        super().visit_generic_type_ref_decl(d)

        args_ty: list[Type] = []
        for arg_ty_ref in d.args_ty_ref:
            arg_ty = arg_ty_ref.maybe_resolved_ty
            if arg_ty is None:
                # Already reported while resolving the argument.
                return
            args_ty.append(arg_ty)

        builder = BUILTIN_GENERICS.get(d.symbol)
        if builder is None:
            self.diag.emit(DeclarationNotInScopeError(d.symbol, loc=d.loc))
            return

        try:
            d.maybe_resolved_ty = builder(d, *args_ty)
        except TypeError:
            # NOTE(review): a TypeError here is taken to mean the builder got
            # the wrong number of type arguments.
            self.diag.emit(GenericArgumentsError(d.text, loc=d.loc))

    @override
    def visit_callback_type_ref_decl(self, d: CallbackTypeRefDecl) -> None:
        """Resolve a callback type from its parameter and return type refs."""
        if d.is_resolved:
            return
        d.is_resolved = True

        # The base visitor is expected to resolve the parameter and return
        # type references, whose results are read below.
        super().visit_callback_type_ref_decl(d)

        return_ty: Type | None = None
        if d.return_ty_ref is not None:
            return_ty = d.return_ty_ref.maybe_resolved_ty
            if return_ty is None:
                # Already reported while resolving the return type.
                return

        params_ty: list[Type] = []
        for param in d.params:
            param_ty = param.ty_ref.maybe_resolved_ty
            if param_ty is None:
                # Already reported while resolving the parameter type.
                return
            params_ty.append(param_ty)

        d.maybe_resolved_ty = CallbackType(d, return_ty, tuple(params_ty))
300
301
class _CheckFieldNameCollisionErrorPass(RecursiveDeclVisitor):
    """Report names declared more than once within the same scope.

    Checked scopes: parameters of global functions and interface methods,
    enum items, struct fields, union fields, interface methods, and the
    top-level declarations of each package.
    """

    diag: DiagnosticsManager

    def __init__(self, diag: DiagnosticsManager):
        self.diag = diag

    @override
    def visit_glob_func_decl(self, d: GlobFuncDecl) -> None:
        self.check_collision_helper(d.params)
        return super().visit_glob_func_decl(d)

    @override
    def visit_iface_func_decl(self, d: IfaceMethodDecl) -> None:
        self.check_collision_helper(d.params)
        return super().visit_iface_func_decl(d)

    @override
    def visit_enum_decl(self, d: EnumDecl) -> None:
        self.check_collision_helper(d.items)
        return super().visit_enum_decl(d)

    @override
    def visit_struct_decl(self, d: StructDecl) -> None:
        self.check_collision_helper(d.fields)
        return super().visit_struct_decl(d)

    @override
    def visit_union_decl(self, d: UnionDecl) -> None:
        self.check_collision_helper(d.fields)
        return super().visit_union_decl(d)

    @override
    def visit_iface_decl(self, d: IfaceDecl) -> None:
        self.check_collision_helper(d.methods)
        return super().visit_iface_decl(d)

    @override
    def visit_package_decl(self, p: PackageDecl) -> None:
        self.check_collision_helper(p.declarations)
        return super().visit_package_decl(p)

    def check_collision_helper(self, children: Iterable[NamedDecl]):
        """Emit ``DeclRedefError(first, duplicate)`` for each reused name.

        Every later declaration with an already-seen name is reported
        against the *first* declaration of that name.
        """
        first_by_name: dict[str, NamedDecl] = {}
        for child in children:
            first = first_by_name.setdefault(child.name, child)
            # Identity, not equality: any *other* declaration object with
            # the same name is a redefinition, whatever its __eq__ says.
            if first is not child:
                self.diag.emit(DeclRedefError(first, child))
350
351
352class _CheckEnumTypePass(RecursiveDeclVisitor):
353    """Validated enum item types."""
354
355    diag: DiagnosticsManager
356
357    def __init__(self, diag: DiagnosticsManager):
358        self.diag = diag
359
360    def visit_enum_decl(self, d: EnumDecl) -> None:
361        def is_int(val: Any) -> TypeGuard[int]:
362            return not isinstance(val, bool) and isinstance(val, int)
363
364        valid: Callable[[Any], bool]
365        increment: Callable[[Any, EnumItemDecl], Any]
366        default: Callable[[EnumItemDecl], Any]
367
368        match d.ty_ref.maybe_resolved_ty:
369            case ScalarType(_, ScalarKind.I8):
370                valid = lambda val: is_int(val) and -(2**7) <= val < 2**7
371                increment = lambda prev, item: prev + 1
372                default = lambda item: 0
373            case ScalarType(_, ScalarKind.I16):
374                valid = lambda val: is_int(val) and -(2**15) <= val < 2**15
375                increment = lambda prev, item: prev + 1
376                default = lambda item: 0
377            case ScalarType(_, ScalarKind.I32):
378                valid = lambda val: is_int(val) and -(2**31) <= val < 2**31
379                increment = lambda prev, item: prev + 1
380                default = lambda item: 0
381            case ScalarType(_, ScalarKind.I64):
382                valid = lambda val: is_int(val) and -(2**63) <= val < 2**63
383                increment = lambda prev, item: prev + 1
384                default = lambda item: 0
385            case ScalarType(_, ScalarKind.U8):
386                valid = lambda val: is_int(val) and 0 <= val < 2**8
387                increment = lambda prev, item: prev + 1
388                default = lambda item: 0
389            case ScalarType(_, ScalarKind.U16):
390                valid = lambda val: is_int(val) and 0 <= val < 2**16
391                increment = lambda prev, item: prev + 1
392                default = lambda item: 0
393            case ScalarType(_, ScalarKind.U32):
394                valid = lambda val: is_int(val) and 0 <= val < 2**32
395                increment = lambda prev, item: prev + 1
396                default = lambda item: 0
397            case ScalarType(_, ScalarKind.U64):
398                valid = lambda val: is_int(val) and 0 <= val < 2**64
399                increment = lambda prev, item: prev + 1
400                default = lambda item: 0
401            case ScalarType(_, ScalarKind.BOOL):
402                valid = lambda val: isinstance(val, bool)
403                increment = lambda prev, item: False
404                default = lambda item: False
405            case ScalarType(_, ScalarKind.F32):
406                valid = lambda val: isinstance(val, float)
407                increment = lambda prev, item: 0.0
408                default = lambda item: 0.0
409            case ScalarType(_, ScalarKind.F64):
410                valid = lambda val: isinstance(val, float)
411                increment = lambda prev, item: 0.0
412                default = lambda item: 0.0
413            case StringType():
414                valid = lambda val: isinstance(val, str)
415                increment = lambda prev, item: item.name
416                default = lambda item: item.name
417            case None:
418                return
419            case _:
420                self.diag.emit(TypeUsageError(d.ty_ref))
421                return
422
423        prev = None
424        for item in d.items:
425            if item.value is None:
426                item.value = default(item) if prev is None else increment(prev, item)
427            if not valid(item.value):
428                self.diag.emit(EnumValueError(item, d))
429                prev = None
430            else:
431                prev = item.value
432
433
class _CheckRecursiveInclusionPass(RecursiveDeclVisitor):
    """Report type declarations that reference themselves through a cycle.

    Builds a dependency graph with one node per enum, interface, struct and
    union declaration.  Struct/union fields whose type is a user type, and
    interface parents, become edges; every cycle found by ``detect_cycles``
    is reported as a ``RecursiveReferenceError``.  While collecting
    interface parents this pass also reports parents that are not
    interfaces (``TypeUsageError``) and warns about repeated parents
    (``DuplicateExtendsWarn``).
    """

    diag: DiagnosticsManager

    def __init__(self, diag: DiagnosticsManager):
        self.diag = diag
        # Dependency graph: declaration -> outgoing edges.  Each edge is
        # ((owning declaration, type reference creating the dependency),
        #  referenced declaration).
        self.type_table: dict[
            TypeDecl,
            list[tuple[tuple[TypeDecl, TypeRefDecl], TypeDecl]],
        ] = {}

    @override
    def visit_package_group(self, g: PackageGroup) -> None:
        # Build a fresh graph for this package group, then report its cycles.
        self.type_table = {}
        super().visit_package_group(g)
        for cycle in detect_cycles(self.type_table):
            # The error takes the cycle's final edge first, followed by the
            # remaining edges in reverse path order.
            last, *other = cycle[::-1]
            self.diag.emit(RecursiveReferenceError(last, other))

    @override
    def visit_enum_decl(self, d: EnumDecl) -> None:
        # Enums contribute a node but no outgoing edges.
        self.type_table.setdefault(d, [])

    @override
    def visit_iface_decl(self, d: IfaceDecl) -> None:
        parent_edges = self.type_table.setdefault(d, [])
        # First IfaceParentDecl seen for each parent interface.
        first_parent: dict[IfaceDecl, IfaceParentDecl] = {}
        for parent in d.parents:
            parent_ty = parent.ty_ref.maybe_resolved_ty
            if parent_ty is None:
                # Unresolved; the resolution pass reports that.
                continue
            if not isinstance(parent_ty, UserType):
                self.diag.emit(TypeUsageError(parent.ty_ref))
                continue
            parent_iface = parent_ty.ty_decl
            if not isinstance(parent_iface, IfaceDecl):
                self.diag.emit(TypeUsageError(parent.ty_ref))
                continue
            parent_edges.append(((d, parent.ty_ref), parent_iface))
            prev = first_parent.setdefault(parent_iface, parent)
            if prev is not parent:
                self.diag.emit(
                    DuplicateExtendsWarn(
                        d,
                        parent_iface,
                        loc=parent.ty_ref.loc,
                        prev_loc=prev.ty_ref.loc,
                    )
                )

    @override
    def visit_struct_decl(self, d: StructDecl) -> None:
        field_edges = self.type_table.setdefault(d, [])
        for f in d.fields:
            if isinstance(ty := f.ty_ref.maybe_resolved_ty, UserType):
                field_edges.append(((d, f.ty_ref), ty.ty_decl))

    @override
    def visit_union_decl(self, d: UnionDecl) -> None:
        field_edges = self.type_table.setdefault(d, [])
        for i in d.fields:
            # Union fields may carry no type; those add no dependency.
            if i.ty_ref is None:
                continue
            if isinstance(ty := i.ty_ref.maybe_resolved_ty, UserType):
                field_edges.append(((d, i.ty_ref), ty.ty_decl))
494
495
V = TypeVar("V")
E = TypeVar("E")


def detect_cycles(graph: dict[V, list[tuple[E, V]]]) -> list[list[E]]:
    """Return every simple cycle of a directed multigraph, as lists of edges.

    ``graph`` maps each node to its outgoing ``(edge, target)`` pairs.
    Parallel edges are distinguished, so two edges between the same nodes
    yield two cycles.  Targets that are not keys of ``graph`` have no
    outgoing edges, cannot lie on a cycle, and are ignored.

    Each cycle is found once, from its earliest node in ``graph``'s
    iteration order, and starts with that node's outgoing edge.  All cycles
    are enumerated (there may be exponentially many), and the search
    recurses once per path step.

    Example:
    >>> graph = {
    ...     "A": [("A.b_0", "B")],
    ...     "B": [("B.c_0", "C")],
    ...     "C": [("C.a_0", "A"), ("C.a_1", "A")],
    ... }
    >>> detect_cycles(graph)
    [['A.b_0', 'B.c_0', 'C.a_0'], ['A.b_0', 'B.c_0', 'C.a_1']]
    """
    order = {node: index for index, node in enumerate(graph)}
    adjacency = [
        [(edge, order[target]) for edge, target in out_edges if target in order]
        for out_edges in graph.values()
    ]
    cycles: list[list[E]] = []
    on_path = [False] * len(adjacency)
    path: list[E] = []

    def visit(node: int, start: int) -> None:
        # Only nodes at or after `start` are explored, so every cycle is
        # discovered exactly once: from its lowest-ordered node.
        if node < start:
            return
        if on_path[node]:
            if node == start:
                cycles.append(path.copy())
            return
        on_path[node] = True
        for edge, target in adjacency[node]:
            path.append(edge)
            visit(target, start)
            path.pop()
        on_path[node] = False

    for start in range(len(adjacency)):
        visit(start, start)

    return cycles