• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# coding=utf-8
2#
3# Copyright (c) 2025 Huawei Device Co., Ltd.
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Format the IDL files."""
17
18from collections.abc import Callable
19from typing import TYPE_CHECKING, Any
20
21from typing_extensions import override
22
23from taihe.semantics.visitor import DeclVisitor
24from taihe.utils.diagnostics import AnsiStyle
25from taihe.utils.outputs import BaseWriter
26
27if TYPE_CHECKING:
28    from taihe.semantics.declarations import (
29        CallbackTypeRefDecl,
30        Decl,
31        DeclarationImportDecl,
32        DeclarationRefDecl,
33        EnumDecl,
34        EnumItemDecl,
35        GenericTypeRefDecl,
36        GlobFuncDecl,
37        IfaceDecl,
38        IfaceMethodDecl,
39        IfaceParentDecl,
40        LongTypeRefDecl,
41        PackageDecl,
42        PackageGroup,
43        PackageImportDecl,
44        PackageRefDecl,
45        ParamDecl,
46        ShortTypeRefDecl,
47        StructDecl,
48        StructFieldDecl,
49        TypeRefDecl,
50        UnionDecl,
51        UnionFieldDecl,
52    )
53
# Signature of PrettyFormatter's styling hooks (as_keyword / as_attr /
# as_comment): take a text fragment and return it either unchanged or wrapped
# in ANSI color codes.
WrapF = Callable[[str], str]
55
56
class PrettyFormatter(DeclVisitor[str]):
    """Render small IDL fragments as strings.

    Visiting a type-reference declaration returns its IDL spelling, preceded
    by any attributes. The ``get_*`` helpers format the remaining inline
    pieces (parameters, interface parents, package/declaration references,
    attribute lists and literal values) so that ``PrettyPrinter`` can
    assemble them into complete lines.
    """

    # Styling hooks chosen in ``__init__``: ANSI-colorizing wrappers when
    # colorizing, otherwise the identity function.
    as_keyword: WrapF
    as_attr: WrapF
    as_comment: WrapF

    def __init__(self, show_resolved: bool = False, colorize: bool = False):
        """Configure the formatter.

        Args:
            show_resolved: If true, resolved references are followed by a
                ``/* ... */`` comment describing what they resolved to.
            colorize: If true, keywords, attributes and comments are wrapped
                in ANSI color codes (cyan, magenta and green respectively).
        """
        self.show_resolved = show_resolved
        if colorize:
            self.as_keyword = lambda s: f"{AnsiStyle.CYAN}{s}{AnsiStyle.RESET}"
            self.as_attr = lambda s: f"{AnsiStyle.MAGENTA}{s}{AnsiStyle.RESET}"
            self.as_comment = lambda s: f"{AnsiStyle.GREEN}{s}{AnsiStyle.RESET}"
        else:
            self.as_keyword = lambda s: s
            self.as_attr = lambda s: s
            self.as_comment = lambda s: s

    def with_attr(self, d: "Decl", s: str) -> str:
        """Prefix ``s`` with ``d``'s attributes written inline as ``@attr``.

        Returns ``s`` unchanged when ``d`` has no attributes.
        """
        if not d.attrs:
            return s
        fmt_attrs = " ".join(
            self.as_attr(f"@{item}") for item in self.get_format_attr(d)
        )
        return f"{fmt_attrs} {s}"

    def get_type_ref_decl(self, d: "TypeRefDecl") -> str:
        """Format a type reference by dispatching to the matching ``visit_*``.

        When ``show_resolved`` is set and ``d`` has been resolved, a comment
        is appended containing the resolved type's signature, or
        ``<error type>`` if no resolved type is available.
        """
        type_ref_repr = self.handle_decl(d)
        if not d.is_resolved or not self.show_resolved:
            return type_ref_repr
        real_type = (
            d.maybe_resolved_ty.signature if d.maybe_resolved_ty else "<error type>"
        )
        comment = self.as_comment(f"/* {real_type} */")
        return f"{type_ref_repr} {comment}"

    @override
    def visit_long_type_ref_decl(self, d: "LongTypeRefDecl") -> str:
        """Format a package-qualified type reference as ``pkname.symbol``."""
        return self.with_attr(d, f"{d.pkname}.{d.symbol}")

    @override
    def visit_short_type_ref_decl(self, d: "ShortTypeRefDecl") -> str:
        """Format an unqualified type reference as its bare symbol."""
        return self.with_attr(d, d.symbol)

    @override
    def visit_generic_type_ref_decl(self, d: "GenericTypeRefDecl") -> str:
        """Format a generic type reference as ``symbol<arg, ...>``."""
        args_fmt = ", ".join(map(self.get_type_ref_decl, d.args_ty_ref))
        return self.with_attr(d, f"{d.symbol}<{args_fmt}>")

    @override
    def visit_callback_type_ref_decl(self, d: "CallbackTypeRefDecl") -> str:
        """Format a callback type as ``(params) => ret``.

        A missing return type is rendered as ``void``.
        """
        fmt_args = ", ".join(map(self.get_param_decl, d.params))
        ret = self.get_type_ref_decl(d.return_ty_ref) if d.return_ty_ref else "void"
        return self.with_attr(d, f"({fmt_args}) => {ret}")

    def get_package_ref_decl(self, d: "PackageRefDecl") -> str:
        """Format a package reference by its symbol.

        When ``show_resolved`` is set and the reference is resolved, the
        resolved package's description (or ``<error package>``) is appended
        as a comment.
        """
        package_ref_repr = d.symbol
        if not d.is_resolved or not self.show_resolved:
            return package_ref_repr
        real_package = (
            d.maybe_resolved_pkg.description
            if d.maybe_resolved_pkg
            else "<error package>"
        )
        comment = self.as_comment(f"/* {real_package} */")
        return f"{package_ref_repr} {comment}"

    def get_declaration_ref_decl(self, d: "DeclarationRefDecl") -> str:
        """Format a declaration reference by its symbol.

        When ``show_resolved`` is set and the reference is resolved, the
        resolved declaration's description (or ``<error declaration>``) is
        appended as a comment.
        """
        decl_ref_repr = d.symbol
        if not d.is_resolved or not self.show_resolved:
            return decl_ref_repr
        real_decl = (
            d.maybe_resolved_decl.description
            if d.maybe_resolved_decl
            else "<error declaration>"
        )
        comment = self.as_comment(f"/* {real_decl} */")
        return f"{decl_ref_repr} {comment}"

    def get_parent_decl(self, d: "IfaceParentDecl") -> str:
        """Format one entry of an interface's parent list."""
        res = self.get_type_ref_decl(d.ty_ref)
        return self.with_attr(d, res)

    def get_param_decl(self, d: "ParamDecl") -> str:
        """Format a parameter as ``name: Type`` with inline attributes."""
        res = f"{d.name}: {self.get_type_ref_decl(d.ty_ref)}"
        return self.with_attr(d, res)

    def get_value(self, obj: Any) -> str:
        """Format an attribute argument or enum value as an IDL literal.

        Strings become double-quoted literals with escapes, booleans become
        ``true``/``false``, and ints and floats are written in decimal.

        Raises:
            TypeError: If ``obj`` is not a ``str``, ``bool``, ``int`` or
                ``float``.
        """
        if isinstance(obj, str):
            # ``unicode_escape`` escapes backslashes, control characters and
            # non-ASCII characters but leaves double quotes untouched, so they
            # must be escaped separately or the literal would end early.
            escaped = obj.encode("unicode_escape").decode("utf-8")
            return '"' + escaped.replace('"', '\\"') + '"'
        # bool must be tested before int: bool is a subclass of int.
        if isinstance(obj, bool):
            return "true" if obj else "false"
        if isinstance(obj, int):
            return f"{obj:d}"
        if isinstance(obj, float):
            fixed = f"{obj:f}"
            # ``:f`` rounds to six fractional digits. Keep that historical
            # spelling whenever it still denotes the same value (this also
            # covers inf / -inf); ``obj != obj`` lets NaN keep it too.
            if obj != obj or float(fixed) == obj:
                return fixed
            # Otherwise (e.g. 1e-7 -> "0.000000") write the shortest
            # round-tripping value in positional notation instead.
            from decimal import Decimal

            exact = format(Decimal(repr(obj)), "f")
            return exact if "." in exact else exact + ".0"
        raise TypeError(f"Unsupported type: {type(obj)}")

    def get_format_attr(self, d: "Decl") -> list[str]:
        """Return each attribute of ``d`` as ``name`` or ``name(arg, ...)``.

        The leading ``@`` / ``@!`` marker is left for the caller to add.
        """
        formatted_attributes: list[str] = []
        for key, items in d.attrs.items():
            for item in items:
                if item.args:
                    values_fmt = ", ".join(map(self.get_value, item.args))
                    formatted_attributes.append(f"{key}({values_fmt})")
                else:
                    formatted_attributes.append(key)
        return formatted_attributes
163
164
class PrettyPrinter(DeclVisitor[None]):
    """Write declarations to a ``BaseWriter`` as IDL source text.

    This class handles line-level layout: attribute lines, braces and
    indentation. The inline text of type references, parameters and literal
    values comes from a ``PrettyFormatter``.
    """

    def __init__(
        self,
        out: BaseWriter,
        show_resolved: bool = False,
        colorize: bool = False,
    ):
        """Create a printer that writes to ``out``.

        Args:
            out: Destination writer.
            show_resolved: Forwarded to ``PrettyFormatter``; annotates
                resolved references with comments.
            colorize: Forwarded to ``PrettyFormatter``; emits ANSI colors.
        """
        self.out = out
        self.fmt = PrettyFormatter(show_resolved, colorize)

    def _write_attr_lines(self, d: "Decl", prefix: str) -> None:
        """Write each attribute of ``d`` on its own line after ``prefix``."""
        for item in self.fmt.get_format_attr(d):
            self.out.writeln(self.fmt.as_attr(f"{prefix}{item}"))

    def _write_braced(self, prologue: str, members: Any) -> None:
        """Write ``prologue``, then each of ``members`` indented, then ``}``.

        When ``members`` is empty, ``prologue}`` is written on a single line.
        """
        if members:
            with self.out.indented(prologue, "}"):
                for member in members:
                    self.handle_decl(member)
        else:
            self.out.writeln(prologue + "}")

    def _format_signature(self, d: "GlobFuncDecl | IfaceMethodDecl") -> str:
        """Return ``name(params): ret`` for a function or interface method.

        A missing return type is rendered as ``void``.
        """
        fmt_args = ", ".join(map(self.fmt.get_param_decl, d.params))
        ret = self.fmt.get_type_ref_decl(d.return_ty_ref) if d.return_ty_ref else "void"
        return f"{d.name}({fmt_args}): {ret}"

    def write_pkg_attr(self, d: "PackageDecl"):
        """Write package-level attributes, one per line, as ``@!attr``."""
        self._write_attr_lines(d, "@!")

    def write_attr(self, d: "Decl"):
        """Write declaration attributes, one per line, as ``@attr``."""
        self._write_attr_lines(d, "@")

    @override
    def visit_package_import_decl(self, d: "PackageImportDecl"):
        """Write ``use pkg;`` or, for an aliased import, ``use pkg as name;``."""
        self.write_attr(d)

        use_kw = self.fmt.as_keyword("use")
        target = self.fmt.get_package_ref_decl(d.pkg_ref)
        if d.is_alias():
            as_kw = self.fmt.as_keyword("as")
            target = f"{target} {as_kw} {d.name}"

        self.out.writeln(f"{use_kw} {target};")

    @override
    def visit_decl_import_decl(self, d: "DeclarationImportDecl"):
        """Write ``from pkg use decl;`` (``... use decl as name;`` if aliased)."""
        self.write_attr(d)

        from_kw = self.fmt.as_keyword("from")
        use_kw = self.fmt.as_keyword("use")
        target = self.fmt.get_declaration_ref_decl(d.decl_ref)
        if d.is_alias():
            as_kw = self.fmt.as_keyword("as")
            target = f"{target} {as_kw} {d.name}"

        pkg = self.fmt.get_package_ref_decl(d.decl_ref.pkg_ref)
        self.out.writeln(f"{from_kw} {pkg} {use_kw} {target};")

    @override
    def visit_glob_func_decl(self, d: "GlobFuncDecl"):
        """Write a global function as ``function name(params): ret;``."""
        self.write_attr(d)

        func_kw = self.fmt.as_keyword("function")
        self.out.writeln(f"{func_kw} {self._format_signature(d)};")

    @override
    def visit_enum_item_decl(self, d: "EnumItemDecl") -> None:
        """Write ``name,`` or, when a value is present, ``name = value,``."""
        self.write_attr(d)

        if d.value is None:
            self.out.writeln(f"{d.name},")
        else:
            self.out.writeln(f"{d.name} = {self.fmt.get_value(d.value)},")

    @override
    def visit_enum_decl(self, d: "EnumDecl") -> None:
        """Write ``enum name: Type { items }``."""
        self.write_attr(d)

        enum_kw = self.fmt.as_keyword("enum")
        ty = self.fmt.get_type_ref_decl(d.ty_ref)
        self._write_braced(f"{enum_kw} {d.name}: {ty} {{", d.items)

    @override
    def visit_union_field_decl(self, d: "UnionFieldDecl"):
        """Write ``name: Type;``, or just ``name;`` for a field with no type."""
        self.write_attr(d)

        if d.ty_ref:
            self.out.writeln(f"{d.name}: {self.fmt.get_type_ref_decl(d.ty_ref)};")
        else:
            self.out.writeln(f"{d.name};")

    @override
    def visit_union_decl(self, d: "UnionDecl"):
        """Write ``union name { fields }``."""
        self.write_attr(d)

        union_kw = self.fmt.as_keyword("union")
        self._write_braced(f"{union_kw} {d.name} {{", d.fields)

    @override
    def visit_struct_field_decl(self, d: "StructFieldDecl"):
        """Write a struct field as ``name: Type;``."""
        self.write_attr(d)

        self.out.writeln(f"{d.name}: {self.fmt.get_type_ref_decl(d.ty_ref)};")

    @override
    def visit_struct_decl(self, d: "StructDecl"):
        """Write ``struct name { fields }``."""
        self.write_attr(d)

        struct_kw = self.fmt.as_keyword("struct")
        self._write_braced(f"{struct_kw} {d.name} {{", d.fields)

    @override
    def visit_iface_func_decl(self, d: "IfaceMethodDecl"):
        """Write an interface method as ``name(params): ret;``."""
        self.write_attr(d)

        self.out.writeln(f"{self._format_signature(d)};")

    @override
    def visit_iface_decl(self, d: "IfaceDecl"):
        """Write ``interface name[: Parent, ...] { methods }``."""
        self.write_attr(d)

        iface_kw = self.fmt.as_keyword("interface")
        header = d.name
        if d.parents:
            parents = ", ".join(map(self.fmt.get_parent_decl, d.parents))
            header = f"{d.name}: {parents}"
        self._write_braced(f"{iface_kw} {header} {{", d.methods)

    @override
    def visit_package_decl(self, p: "PackageDecl"):
        """Write a package: a ``// name`` header, its attributes, its imports,
        then its declarations.
        """
        # Route the header through as_comment so that, like the /* ... */
        # comments, it is colorized when colorize is on.
        self.out.writeln(self.fmt.as_comment(f"// {p.name}"))
        self.write_pkg_attr(p)
        for d in p.pkg_imports:
            self.handle_decl(d)
        for d in p.decl_imports:
            self.handle_decl(d)
        for d in p.declarations:
            self.handle_decl(d)

    @override
    def visit_package_group(self, g: "PackageGroup"):
        """Write every package in ``g``, separated by blank lines."""
        for i, p in enumerate(g.packages):
            if i != 0:
                self.out.newline()
            self.handle_decl(p)