• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# Copyright (c) 2025 Huawei Device Co., Ltd.
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15
16import re
17import sys
18from functools import cache
19from pathlib import Path
20from subprocess import check_call
21
22from taihe.utils.resources import ResourceLocator, ResourceType
23
# Dotted name of the package that receives the generated modules.
ANTLR_PKG = "taihe.parse.antlr"
# Absolute directory containing this script.
CURRENT_DIR = Path(__file__).parent.resolve()
# ANTLR_PKG expressed as a directory below CURRENT_DIR.
ANTLR_PATH = CURRENT_DIR.joinpath(*ANTLR_PKG.split("."))
# Grammar file that sits next to this script.
G4_FILE = CURRENT_DIR / "Taihe.g4"

# HACK: The parent package `taihe.parse` imports modules that this script is
# about to generate, so it cannot be initialized yet. Put the output directory
# on `sys.path` so the ANTLR-generated module can be imported directly by its
# bare name, without initializing the parent package.
sys.path.insert(0, str(ANTLR_PATH))
32
33
@cache
def get_parser():
    """Return the ``TaiheParser`` class from the generated ``TaiheParser`` module.

    The import is deferred to call time and resolved by the module's bare
    name, which relies on the generated-code directory being on ``sys.path``.
    ``@cache`` means the import statement runs on the first call only.
    """
    from TaiheParser import TaiheParser as parser_class

    return parser_class
39
40
def get_hint(attr_kind):
    """Return the type-hint source text for an attribute kind.

    A kind ending in ``Lst`` becomes a ``List[...]`` hint and a kind ending in
    ``Opt`` becomes an ``Optional[...]`` hint; in both cases the three-letter
    suffix is stripped from the referenced name. Any other kind is returned as
    a plain quoted ``TaiheAST.<kind>`` forward reference.
    """
    suffix = attr_kind[-3:]
    base_kind = attr_kind[:-3]
    if suffix == "Lst":
        return f'List["TaiheAST.{base_kind}"]'
    if suffix == "Opt":
        return f'Optional["TaiheAST.{base_kind}"]'
    return f'"TaiheAST.{attr_kind}"'
47
48
def get_attr_pairs(ctx):
    """Yield the split name of every public instance attribute of *ctx*.

    Walks ``ctx.__dict__`` in insertion order, skipping names that start with
    an underscore and the attribute named ``parser``. Each remaining name is
    split once at its first underscore, so ``"TypeOpt_ret_ty"`` is yielded as
    ``["TypeOpt", "ret_ty"]``. A name containing no underscore is yielded as a
    one-element list.

    Args:
        ctx: Any object that has an instance ``__dict__``.

    Yields:
        list[str]: The result of ``name.split("_", 1)`` for each kept name.
    """
    # Only the attribute names matter; the stored values are never inspected.
    for attr_full_name in ctx.__dict__:
        if attr_full_name.startswith("_") or attr_full_name == "parser":
            continue
        yield attr_full_name.split("_", 1)
53
54
def snake_case(name):
    """Convert a CamelCase identifier to snake_case.

    An underscore is inserted before every ASCII uppercase letter except one
    at the very start of the string, and the result is then lowercased.
    """
    # Zero-width match: any position that is not the string start and is
    # immediately followed by an uppercase ASCII letter.
    word_boundary = re.compile(r"(?<!^)(?=[A-Z])")
    separated = word_boundary.sub("_", name)
    return separated.lower()
58
59
class Inspector:
    """Bare attribute holder whose five fields are all ``None``.

    ``generate_ast`` passes an instance as the second constructor argument
    when it instantiates a generated context class.
    """

    # NOTE(review): the field names look like the ones the ANTLR runtime reads
    # from a parent/source context -- confirm against the runtime in use.
    def __init__(self):
        # Targets are assigned left to right, so the attribute order is
        # parentCtx, invokingState, children, start, stop.
        self.parentCtx = self.invokingState = self.children = None
        self.start = self.stop = None
67
68
def generate_ast():
    """Generate ``TaiheAST.py`` inside ``ANTLR_PATH``.

    The emitted module defines a ``TaiheAST`` namespace class containing:

    * ``any`` -- keyword-only dataclass base with a ``loc`` field;
    * ``TOKEN`` -- dataclass with a ``text`` field, dispatching to
      ``visitor.visit_token``;
    * for every parser context class that has subclasses, a ``Union[...]``
      alias listing those subclasses (each subclass is then processed as a
      node of its own);
    * for every context class without subclasses, a dataclass whose fields
      come from the ``<Kind>_<name>`` instance attributes of a freshly
      constructed context, dispatching to ``visitor.visit_<snake_case>``.

    The file is written as UTF-8 explicitly so the result does not depend on
    the platform's locale encoding (Python reads source files as UTF-8 by
    default).
    """
    with open(ANTLR_PATH / "TaiheAST.py", "w", encoding="utf-8") as file:
        # Fixed module prologue: imports plus the `any` and `TOKEN` bases.
        file.write(
            "from dataclasses import dataclass\n"
            "from typing import Any, Union, List, Optional\n"
            "\n"
            "from taihe.utils.sources import SourceLocation\n"
            "\n"
            "\n"
            "class TaiheAST:\n"
            "    @dataclass(kw_only=True)\n"
            "    class any:\n"
            "        loc: SourceLocation\n"
            "\n"
            "        def _accept(self, visitor) -> Any:\n"
            "            raise NotImplementedError()\n"
            "\n"
            "\n"
            "    @dataclass\n"
            "    class TOKEN(any):\n"
            "        text: str\n"
            "\n"
            "        def __str__(self):\n"
            "            return self.text\n"
            "\n"
            "        def _accept(self, visitor) -> Any:\n"
            "            return visitor.visit_token(self)\n"
            "\n"
        )
        parser = get_parser()

        # Seed the work list with one (node kind, context class) pair per
        # grammar rule: rule `foo` maps to the class `FooContext`.
        type_list = []
        for rule_name in parser.ruleNames:
            node_kind = rule_name[0].upper() + rule_name[1:]
            type_list.append((node_kind, getattr(parser, node_kind + "Context")))

        # NOTE: `type_list` deliberately grows while it is being iterated.
        # Subclasses discovered below are appended and visited later by this
        # same loop, in discovery order.
        for node_kind, ctx_type in type_list:
            subclasses = ctx_type.__subclasses__()
            if subclasses:
                # A context with subclasses becomes a Union alias of them.
                file.write(f"    {node_kind} = Union[\n")
                for sub_type in subclasses:
                    # Drop the trailing 7 characters ("Context") of the name.
                    attr_kind = sub_type.__name__[: -len("Context")]
                    type_list.append((attr_kind, sub_type))
                    file.write(f"        {get_hint(attr_kind)},\n")
                file.write("    ]\n\n")
            else:
                # Instantiate the context once just to discover which
                # attributes its constructor creates.
                ctx = ctx_type(None, Inspector())
                file.write("    @dataclass\n" f"    class {node_kind}(any):\n")
                for attr_kind, attr_name in get_attr_pairs(ctx):
                    file.write(f"        {attr_name}: {get_hint(attr_kind)}\n")
                file.write(
                    "\n"
                    "        def _accept(self, visitor) -> Any:\n"
                    f"            return visitor.visit_{snake_case(node_kind)}(self)\n"
                    "\n"
                )
128
129
def generate_visitor():
    """Generate ``TaiheVisitor.py`` inside ``ANTLR_PATH``.

    The emitted ``TaiheVisitor`` class has a generic ``visit`` that calls
    ``node._accept(self)``, a ``visit_token`` that raises
    ``NotImplementedError``, and for every grammar rule:

    * one ``visit_<sub>`` method per direct subclass of the rule's context
      class, forwarding to the rule's own ``visit_<rule>`` method;
    * the ``visit_<rule>`` method itself, raising ``NotImplementedError``.

    The file is written as UTF-8 explicitly so the result does not depend on
    the platform's locale encoding (Python reads source files as UTF-8 by
    default).
    """
    with open(ANTLR_PATH / "TaiheVisitor.py", "w", encoding="utf-8") as file:
        # Fixed module prologue: imports plus the generic entry points.
        file.write(
            f"from {ANTLR_PKG}.TaiheAST import TaiheAST\n"
            "\n"
            "from typing import Any\n"
            "\n"
            "\n"
            "class TaiheVisitor:\n"
            "    def visit(self, node: TaiheAST.any) -> Any:\n"
            "        return node._accept(self)\n"
            "\n"
            "    def visit_token(self, node: TaiheAST.TOKEN) -> Any:\n"
            "        raise NotImplementedError()\n"
            "\n"
        )
        parser = get_parser()

        # One (node kind, context class) pair per grammar rule: rule `foo`
        # maps to the class `FooContext`.
        type_list = []
        for rule_name in parser.ruleNames:
            node_kind = rule_name[0].upper() + rule_name[1:]
            type_list.append((node_kind, getattr(parser, node_kind + "Context")))

        for node_kind, ctx_type in type_list:
            node_method = snake_case(node_kind)
            # Iterating an empty subclass list is a no-op, so no guard is
            # needed for contexts without subclasses.
            for sub_type in ctx_type.__subclasses__():
                # Drop the trailing 7 characters ("Context") of the name.
                attr_kind = sub_type.__name__[: -len("Context")]
                file.write(
                    f"    def visit_{snake_case(attr_kind)}(self, node: TaiheAST.{attr_kind}) -> Any:\n"
                    f"        return self.visit_{node_method}(node)\n"
                    "\n"
                )
            file.write(
                f"    def visit_{node_method}(self, node: TaiheAST.{node_kind}) -> Any:\n"
                "        raise NotImplementedError()\n"
                "\n"
            )
169
170
def run_antlr():
    """Run the ANTLR tool on ``G4_FILE``, writing its output to ``ANTLR_PATH``.

    The jar is obtained from the detected ``ResourceLocator`` and executed via
    ``java`` with the Python3 language target and listener generation turned
    off. ``check_call`` raises ``CalledProcessError`` if the process exits
    with a non-zero status.
    """
    antlr_jar = ResourceLocator.detect().get(ResourceType.DEV_ANTLR)
    command = [
        "java",
        "-cp",
        str(antlr_jar),
        "org.antlr.v4.Tool",
        "-Dlanguage=Python3",
        "-no-listener",
        G4_FILE,
        "-o",
        ANTLR_PATH,
    ]
    # The child process receives an empty environment (nothing is inherited).
    # NOTE(review): presumably intentional for reproducible runs -- confirm
    # that `java` is still found without the caller's PATH.
    check_call(command, env={})
177
178
def main():
    """Run the full generation pipeline.

    ANTLR runs first because the two generators that follow import the parser
    module it produces.
    """
    pipeline = (run_antlr, generate_ast, generate_visitor)
    for step in pipeline:
        step()


if __name__ == "__main__":
    main()
187