• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# coding=utf-8
2#
3# Copyright (c) 2025 Huawei Device Co., Ltd.
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Convert AST to IR."""
17
18from collections.abc import Iterable
19from typing import Any
20
21from typing_extensions import override
22
23from taihe.parse.antlr.TaiheAST import TaiheAST as ast
24from taihe.parse.antlr.TaiheVisitor import TaiheVisitor as Visitor
25from taihe.parse.ast_generation import generate_ast
26from taihe.semantics.declarations import (
27    AttrItemDecl,
28    CallbackTypeRefDecl,
29    DeclarationImportDecl,
30    DeclarationRefDecl,
31    EnumDecl,
32    EnumItemDecl,
33    GenericTypeRefDecl,
34    GlobFuncDecl,
35    IfaceDecl,
36    IfaceMethodDecl,
37    IfaceParentDecl,
38    LongTypeRefDecl,
39    PackageDecl,
40    PackageImportDecl,
41    PackageRefDecl,
42    ParamDecl,
43    ShortTypeRefDecl,
44    StructDecl,
45    StructFieldDecl,
46    UnionDecl,
47    UnionFieldDecl,
48)
49from taihe.utils.diagnostics import DiagnosticsManager
50from taihe.utils.exceptions import InvalidPackageNameError
51from taihe.utils.sources import SourceBase, SourceLocation
52
53
def pkg2str(pkg_name: ast.PkgName) -> str:
    """Render a parsed package name in its dotted string form.

    The ``text`` of every element of ``pkg_name.parts`` is taken in order and
    the pieces are joined with ``"."`` (parts ``foo`` and ``bar`` give
    ``"foo.bar"``).
    """
    pieces = [part.text for part in pkg_name.parts]
    return ".".join(pieces)
56
57
def is_valid_pkg_name(name: str) -> bool:
    """Return whether *name* is a well-formed dotted package name.

    *name* is split on ``"."``. Every resulting segment must be non-empty,
    start with a letter or ``_``, and continue with letters, digits or ``_``.
    An empty *name*, or one with a leading, trailing or doubled dot, is
    therefore rejected.

    ``str.isalpha``/``str.isalnum`` are Unicode-aware, so non-ASCII letters
    and digits are accepted too. NOTE(review): confirm this is intended for
    names that end up in generated code.
    """

    def _segment_ok(segment: str) -> bool:
        # An empty segment comes from "", ".x", "x." or "x..y".
        if segment == "":
            return False
        head, tail = segment[0], segment[1:]
        if not (head.isalpha() or head == "_"):
            return False
        return all(ch.isalnum() or ch == "_" for ch in tail)

    return all(_segment_ok(segment) for segment in name.split("."))
68
69
class ExprEvaluator(Visitor):
    """Evaluate constant expressions of the Taihe AST to Python values.

    Each ``visit_*`` method handles one expression node kind and returns a
    ``bool``, ``int``, ``float`` or ``str``. Operands are evaluated
    recursively through ``self.visit``.

    Operators are looked up in a table keyed by the operator token's text.
    An operator missing from a table therefore raises ``KeyError``.
    """

    # Bool Expr

    @override
    def visit_literal_bool_expr(self, node: ast.LiteralBoolExpr) -> bool:
        """Map the ``true``/``false`` literal to the Python boolean."""
        return {
            "true": True,
            "false": False,
        }[node.val.text]

    @override
    def visit_int_comparison_bool_expr(self, node: ast.IntComparisonBoolExpr) -> bool:
        """Compare two operands, both coerced to ``int``, with ``node.op``."""
        return {
            ">": int.__gt__,
            "<": int.__lt__,
            ">=": int.__ge__,
            "<=": int.__le__,
            "==": int.__eq__,
            "!=": int.__ne__,
        }[node.op.text](
            int(self.visit(node.left)),
            int(self.visit(node.right)),
        )

    @override
    def visit_float_comparison_bool_expr(
        self, node: ast.FloatComparisonBoolExpr
    ) -> bool:
        """Compare two operands, both coerced to ``float``, with ``node.op``."""
        return {
            ">": float.__gt__,
            "<": float.__lt__,
            ">=": float.__ge__,
            "<=": float.__le__,
            "==": float.__eq__,
            "!=": float.__ne__,
        }[node.op.text](
            float(self.visit(node.left)),
            float(self.visit(node.right)),
        )

    @override
    def visit_unary_bool_expr(self, node: ast.UnaryBoolExpr) -> bool:
        """Logically negate the operand.

        ``node.op`` is not inspected; negation is assumed to be the only unary
        boolean operator. NOTE(review): confirm against the grammar.
        """
        return not self.visit(node.expr)

    @override
    def visit_binary_bool_expr(self, node: ast.BinaryBoolExpr) -> bool:
        """Combine two operands, coerced to ``bool``, with ``&&`` or ``||``.

        Both operands are always evaluated; there is no short-circuiting.
        """
        return {
            "&&": bool.__and__,
            "||": bool.__or__,
        }[node.op.text](
            bool(self.visit(node.left)),
            bool(self.visit(node.right)),
        )

    @override
    def visit_parenthesis_bool_expr(self, node: ast.ParenthesisBoolExpr) -> bool:
        """Evaluate the parenthesized inner expression."""
        return self.visit(node.expr)

    @override
    def visit_conditional_bool_expr(self, node: ast.ConditionalBoolExpr) -> bool:
        """Evaluate only the branch selected by ``node.cond``."""
        return (
            self.visit(node.then_expr)
            if self.visit(node.cond)
            else self.visit(node.else_expr)
        )

    # Int Expr

    @override
    def visit_literal_int_expr(self, node: ast.LiteralIntExpr) -> int:
        """Parse an integer literal in binary, octal, hexadecimal or decimal.

        The ``0b``/``0o``/``0x`` radix prefix is matched case-insensitively.
        Once the base is explicit, ``int`` accepts the prefix in either case
        (``0X1F`` -> 31). Previously an uppercase prefix fell through to
        ``int(text)`` and raised ``ValueError``. Literals without a prefix
        are decimal.
        """
        text = node.val.text
        base = {"0b": 2, "0o": 8, "0x": 16}.get(text[:2].lower(), 10)
        return int(text, base)

    @override
    def visit_parenthesis_int_expr(self, node: ast.ParenthesisIntExpr) -> int:
        """Evaluate the parenthesized inner expression."""
        return self.visit(node.expr)

    @override
    def visit_conditional_int_expr(self, node: ast.ConditionalIntExpr) -> int:
        """Evaluate only the branch selected by ``node.cond``."""
        return (
            self.visit(node.then_expr)
            if self.visit(node.cond)
            else self.visit(node.else_expr)
        )

    @override
    def visit_unary_int_expr(self, node: ast.UnaryIntExpr) -> int:
        """Apply unary ``-``, ``+`` or bitwise ``~`` to an ``int`` operand."""
        return {
            "-": int.__neg__,
            "+": int.__pos__,
            "~": int.__invert__,
        }[node.op.text](
            int(self.visit(node.expr)),
        )

    @override
    def visit_binary_int_expr(self, node: ast.BinaryIntExpr) -> int:
        """Apply a binary arithmetic/bitwise operator to two ``int`` operands.

        ``/`` and ``%`` follow Python semantics: floor division rounding toward
        negative infinity, and a remainder taking the divisor's sign. This
        differs from C-style truncation for negative operands.
        NOTE(review): confirm this is the intended semantics for generated
        code. A zero divisor raises ``ZeroDivisionError``.
        """
        return {
            "+": int.__add__,
            "-": int.__sub__,
            "*": int.__mul__,
            "/": int.__floordiv__,
            "%": int.__mod__,
            "<<": int.__lshift__,
            ">>": int.__rshift__,
            "&": int.__and__,
            "|": int.__or__,
            "^": int.__xor__,
        }[node.op.text](
            int(self.visit(node.left)),
            int(self.visit(node.right)),
        )

    @override
    def visit_binary_int_shift_expr(self, node: ast.BinaryIntShiftExpr) -> int:
        """Apply a shift whose direction is given by a single ``<``/``>``.

        ``node.ch`` holds one angle bracket. Presumably the grammar spells
        shifts as adjacent angle-bracket tokens, separate from ``<<``/``>>``
        in ``visit_binary_int_expr``. NOTE(review): confirm.
        """
        return {
            "<": int.__lshift__,
            ">": int.__rshift__,
        }[node.ch.text](
            int(self.visit(node.left)),
            int(self.visit(node.right)),
        )

    # Float Expr

    @override
    def visit_literal_float_expr(self, node: ast.LiteralFloatExpr) -> float:
        """Parse a floating-point literal with ``float``."""
        return float(node.val.text)

    @override
    def visit_parenthesis_float_expr(self, node: ast.ParenthesisFloatExpr) -> float:
        """Evaluate the parenthesized inner expression."""
        return self.visit(node.expr)

    @override
    def visit_conditional_float_expr(self, node: ast.ConditionalFloatExpr) -> float:
        """Evaluate only the branch selected by ``node.cond``."""
        return (
            self.visit(node.then_expr)
            if self.visit(node.cond)
            else self.visit(node.else_expr)
        )

    @override
    def visit_unary_float_expr(self, node: ast.UnaryFloatExpr) -> float:
        """Apply unary ``-`` or ``+`` to a ``float`` operand."""
        return {
            "-": float.__neg__,
            "+": float.__pos__,
        }[node.op.text](
            float(self.visit(node.expr)),
        )

    @override
    def visit_binary_float_expr(self, node: ast.BinaryFloatExpr) -> float:
        """Apply ``+``, ``-``, ``*`` or true division ``/`` to two floats."""
        return {
            "+": float.__add__,
            "-": float.__sub__,
            "*": float.__mul__,
            "/": float.__truediv__,
        }[node.op.text](
            float(self.visit(node.left)),
            float(self.visit(node.right)),
        )

    # String Expr

    @override
    def visit_literal_string_expr(self, node: ast.LiteralStringExpr) -> str:
        """Strip the surrounding quotes and decode backslash escapes.

        The ``unicode_escape`` codec decodes its *bytes* input as Latin-1.
        Feeding it UTF-8 bytes (as before) corrupted every non-ASCII
        character (``é`` became ``Ã©``). Encoding as Latin-1 with
        ``backslashreplace`` keeps Latin-1 characters byte-exact. Every other
        character is turned into a ``\\uXXXX``/``\\UXXXXXXXX`` escape, which
        the codec then decodes back to the original character.
        """
        body = node.val.text[1:-1]
        return body.encode("latin-1", "backslashreplace").decode("unicode_escape")

    @override
    def visit_literal_doc_string_expr(self, node: ast.LiteralDocStringExpr) -> str:
        """Strip the three-character delimiters; escapes are left as written."""
        return node.val.text[3:-3]

    @override
    def visit_binary_string_expr(self, node: ast.BinaryStringExpr) -> str:
        """Concatenate the two string operands."""
        return self.visit(node.left) + self.visit(node.right)

    @override
    def visit_any_expr(self, node: ast.AnyExpr) -> Any:
        """Evaluate the wrapped expression, whatever its type."""
        return self.visit(node.expr)
257
class AstConverter(ExprEvaluator):
    """Translate a parsed Taihe AST into intermediate-representation (IR) decls.

    There is one ``visit_*`` method per AST node kind. Each returns the
    matching IR declaration; the two ``use`` node kinds instead yield their
    imports. Constant expressions, such as attribute arguments and enum
    values, are evaluated by the inherited ``ExprEvaluator`` methods.

    Declarations that contain errors are discarded. Every collection of
    child nodes is walked through ``self.diag.for_each``, which presumably
    reports a failing child and moves on to the next one.
    NOTE(review): confirm against ``DiagnosticsManager.for_each``.
    """

    source: SourceBase
    diag: DiagnosticsManager

    def __init__(self, source: SourceBase, diag: DiagnosticsManager):
        # The source buffer being converted and the diagnostics sink.
        self.source, self.diag = source, diag

    def _attach_attrs(self, target: Any, attr_nodes: Iterable[Any]) -> None:
        """Convert each attribute node and add it to *target* via ``add_attr``."""
        self.diag.for_each(attr_nodes, lambda attr: target.add_attr(self.visit(attr)))

    # Type References

    @override
    def visit_long_type(self, node: ast.LongType) -> LongTypeRefDecl:
        """Reference a type by its package-qualified name."""
        ref = LongTypeRefDecl(node.loc, pkg2str(node.pkg_name), str(node.decl_name))
        self._attach_attrs(ref, node.forward_attrs)
        return ref

    @override
    def visit_short_type(self, node: ast.ShortType) -> ShortTypeRefDecl:
        """Reference a type by its unqualified name."""
        ref = ShortTypeRefDecl(node.loc, str(node.decl_name))
        self._attach_attrs(ref, node.forward_attrs)
        return ref

    @override
    def visit_generic_type(self, node: ast.GenericType) -> GenericTypeRefDecl:
        """Reference a generic type, converting its type arguments in order."""
        ref = GenericTypeRefDecl(node.loc, str(node.decl_name))
        self.diag.for_each(node.args, lambda arg: ref.add_arg_ty_ref(self.visit(arg)))
        self._attach_attrs(ref, node.forward_attrs)
        return ref

    @override
    def visit_callback_type(self, node: ast.CallbackType) -> CallbackTypeRefDecl:
        """Reference a callback type; the return type is optional."""
        ret_node = node.return_ty
        ref = (
            CallbackTypeRefDecl(node.loc, self.visit(ret_node))
            if ret_node
            else CallbackTypeRefDecl(node.loc)
        )
        self.diag.for_each(node.parameters, lambda prm: ref.add_param(self.visit(prm)))
        self._attach_attrs(ref, node.forward_attrs)
        return ref

    # Uses

    @override
    def visit_use_package(self, node: ast.UsePackage) -> Iterable[PackageImportDecl]:
        """Yield the single package import described by *node*.

        If an alias is present, it supplies both the import's name and its
        source location.
        """
        target = PackageRefDecl(node.pkg_name.loc, pkg2str(node.pkg_name))
        alias = node.pkg_alias
        if not alias:
            yield PackageImportDecl(target)
            return
        yield PackageImportDecl(target, name=str(alias), loc=alias.loc)

    @override
    def visit_use_symbol(self, node: ast.UseSymbol) -> Iterable[DeclarationImportDecl]:
        """Yield one declaration import per imported name, in source order.

        All names share one ``PackageRefDecl`` for their package. An alias,
        when present, supplies the import's name and source location.
        """
        owner = PackageRefDecl(node.pkg_name.loc, pkg2str(node.pkg_name))
        for pair in node.decl_alias_pairs:
            decl_ref = DeclarationRefDecl(
                pair.decl_name.loc, str(pair.decl_name), owner
            )
            alias = pair.decl_alias
            if alias:
                yield DeclarationImportDecl(decl_ref, name=str(alias), loc=alias.loc)
            else:
                yield DeclarationImportDecl(decl_ref)

    # Declarations

    @override
    def visit_struct_property(self, node: ast.StructProperty) -> StructFieldDecl:
        """Convert one struct field together with its type."""
        field_decl = StructFieldDecl(node.name.loc, str(node.name), self.visit(node.ty))
        self._attach_attrs(field_decl, node.forward_attrs)
        return field_decl

    @override
    def visit_struct(self, node: ast.Struct) -> StructDecl:
        """Convert a struct: fields first, then outer, then inner attributes."""
        decl = StructDecl(node.name.loc, str(node.name))
        self.diag.for_each(node.fields, lambda fld: decl.add_field(self.visit(fld)))
        self._attach_attrs(decl, node.forward_attrs)
        self._attach_attrs(decl, node.inner_attrs)
        return decl

    @override
    def visit_enum_property(self, node: ast.EnumProperty) -> EnumItemDecl:
        """Convert one enum item; its value expression is optional."""
        item = (
            EnumItemDecl(node.name.loc, str(node.name), self.visit(node.val))
            if node.val
            else EnumItemDecl(node.name.loc, str(node.name))
        )
        self._attach_attrs(item, node.forward_attrs)
        return item

    @override
    def visit_enum(self, node: ast.Enum) -> EnumDecl:
        """Convert an enum with its underlying type, items and attributes."""
        decl = EnumDecl(node.name.loc, str(node.name), self.visit(node.enum_ty))
        self.diag.for_each(node.fields, lambda itm: decl.add_item(self.visit(itm)))
        self._attach_attrs(decl, node.forward_attrs)
        # NOTE(review): struct/union/interface also consume ``inner_attrs``;
        # confirm whether ``ast.Enum`` has them and they are meant to be skipped.
        return decl

    @override
    def visit_union_property(self, node: ast.UnionProperty) -> UnionFieldDecl:
        """Convert one union field; its type is optional."""
        field_ty = node.ty
        field_decl = (
            UnionFieldDecl(node.name.loc, str(node.name), self.visit(field_ty))
            if field_ty
            else UnionFieldDecl(node.name.loc, str(node.name))
        )
        self._attach_attrs(field_decl, node.forward_attrs)
        return field_decl

    @override
    def visit_union(self, node: ast.Union) -> UnionDecl:
        """Convert a union: fields first, then outer, then inner attributes."""
        decl = UnionDecl(node.name.loc, str(node.name))
        self.diag.for_each(node.fields, lambda fld: decl.add_field(self.visit(fld)))
        self._attach_attrs(decl, node.forward_attrs)
        self._attach_attrs(decl, node.inner_attrs)
        return decl

    @override
    def visit_parameter(self, node: ast.Parameter) -> ParamDecl:
        """Convert one function/method parameter together with its type."""
        param = ParamDecl(node.name.loc, str(node.name), self.visit(node.ty))
        self._attach_attrs(param, node.forward_attrs)
        return param

    @override
    def visit_interface_function(self, node: ast.InterfaceFunction) -> IfaceMethodDecl:
        """Convert an interface method; its return type is optional."""
        ret_node = node.return_ty
        method = (
            IfaceMethodDecl(node.name.loc, str(node.name), self.visit(ret_node))
            if ret_node
            else IfaceMethodDecl(node.name.loc, str(node.name))
        )
        self.diag.for_each(node.parameters, lambda prm: method.add_param(self.visit(prm)))
        self._attach_attrs(method, node.forward_attrs)
        return method

    @override
    def visit_interface_parent(self, node: ast.InterfaceParent) -> IfaceParentDecl:
        """Wrap an ``extends`` entry, located at its type reference."""
        return IfaceParentDecl(node.ty.loc, self.visit(node.ty))

    @override
    def visit_interface(self, node: ast.Interface) -> IfaceDecl:
        """Convert an interface.

        The order is: methods, parent interfaces, outer attributes, then
        inner attributes.
        """
        decl = IfaceDecl(node.name.loc, str(node.name))
        self.diag.for_each(node.fields, lambda fld: decl.add_method(self.visit(fld)))
        self.diag.for_each(node.extends, lambda par: decl.add_parent(self.visit(par)))
        self._attach_attrs(decl, node.forward_attrs)
        self._attach_attrs(decl, node.inner_attrs)
        return decl

    @override
    def visit_global_function(self, node: ast.GlobalFunction) -> GlobFuncDecl:
        """Convert a free function; its return type is optional."""
        ret_node = node.return_ty
        func = (
            GlobFuncDecl(node.name.loc, str(node.name), self.visit(ret_node))
            if ret_node
            else GlobFuncDecl(node.name.loc, str(node.name))
        )
        self.diag.for_each(node.parameters, lambda prm: func.add_param(self.visit(prm)))
        self._attach_attrs(func, node.forward_attrs)
        return func

    # Attributes

    def visit_attr(self, node: ast.DeclAttr | ast.ScopeAttr) -> AttrItemDecl:
        """Evaluate an attribute's arguments and build its IR item.

        ``ast.NamedAttrArg`` arguments go into the keyword mapping, keyed by
        name. All other arguments are collected positionally, in source order.
        """
        positional: list[Any] = []
        named: dict[str, Any] = {}
        for arg in node.args:
            value = self.visit(arg.val)
            if isinstance(arg, ast.NamedAttrArg):
                named[str(arg.name)] = value
            else:
                positional.append(value)
        return AttrItemDecl(node.name.loc, str(node.name), positional, named)

    @override
    def visit_decl_attr(self, node: ast.DeclAttr) -> AttrItemDecl:
        """Delegate to the shared ``visit_attr`` conversion."""
        return self.visit_attr(node)

    @override
    def visit_scope_attr(self, node: ast.ScopeAttr) -> AttrItemDecl:
        """Delegate to the shared ``visit_attr`` conversion."""
        return self.visit_attr(node)

    # Package

    @override
    def visit_spec(self, node: ast.Spec) -> PackageDecl:
        """Convert a whole source file into its package declaration.

        The package name comes from ``self.source.pkg_name``, not from the
        AST. Conversion order: imports, declarations, then package-level
        inner attributes.

        Raises:
            InvalidPackageNameError: If ``self.source.pkg_name`` fails
                ``is_valid_pkg_name``.
        """
        name = self.source.pkg_name
        if not is_valid_pkg_name(name):
            raise InvalidPackageNameError(name, loc=SourceLocation(self.source))
        pkg = PackageDecl(name, SourceLocation(self.source))
        for use in node.uses:
            # Each ``use`` node yields zero or more import declarations.
            self.diag.for_each(self.visit(use), pkg.add_import)
        self.diag.for_each(
            node.fields, lambda member: pkg.add_declaration(self.visit(member))
        )
        self._attach_attrs(pkg, node.inner_attrs)
        return pkg

    def convert(self) -> PackageDecl:
        """Parse ``self.source`` and convert it into a package declaration.

        Returns:
            PackageDecl: The package with every import and declaration found
            in the source buffer.

        Raises:
            InvalidPackageNameError: If the package name is invalid.
        """
        return self.visit_spec(generate_ast(self.source, self.diag))