#!/usr/bin/env python3 # Copyright (c) 2025 Huawei Device Co., Ltd. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import re import sys from functools import cache from pathlib import Path from subprocess import check_call from taihe.utils.resources import ResourceLocator, ResourceType ANTLR_PKG = "taihe.parse.antlr" CURRENT_DIR = Path(__file__).parent.resolve() ANTLR_PATH = CURRENT_DIR / ANTLR_PKG.replace(".", "/") G4_FILE = CURRENT_DIR / "Taihe.g4" # HACK: The parent module `taihe.parse` imports the code which will be generated by us soon. # We directly import ANTLR-generated module, without initializing the parent module. sys.path.insert(0, str(ANTLR_PATH)) @cache def get_parser(): from TaiheParser import TaiheParser return TaiheParser def get_hint(attr_kind): if attr_kind.endswith("Lst"): return f'List["TaiheAST.{attr_kind[:-3]}"]' if attr_kind.endswith("Opt"): return f'Optional["TaiheAST.{attr_kind[:-3]}"]' return f'"TaiheAST.{attr_kind}"' def get_attr_pairs(ctx): for attr_full_name, attr_ctx in ctx.__dict__.items(): if not attr_full_name.startswith("_") and attr_full_name != "parser": yield attr_full_name.split("_", 1) def snake_case(name): """Convert CamelCase to snake_case.""" return re.sub(r"(? 
Any:\n"
            f" raise NotImplementedError()\n"
            f"\n"
            f"\n"
            f" @dataclass\n"
            f" class TOKEN(any):\n"
            f" text: str\n"
            f"\n"
            f" def __str__(self):\n"
            f" return self.text\n"
            f"\n"
            f" def _accept(self, visitor) -> Any:\n"
            f" return visitor.visit_token(self)\n"
            f"\n"
        )

        parser = get_parser()

        # Seed the work list with one (NodeKind, context class) entry per
        # grammar rule: rule ``foo`` maps to node kind ``Foo`` and the parser
        # attribute ``FooContext``.
        type_list = []
        for rule_name in parser.ruleNames:
            node_kind = rule_name[0].upper() + rule_name[1:]
            ctx_kind = node_kind + "Context"
            ctx_type = getattr(parser, ctx_kind)
            type_list.append((node_kind, ctx_type))

        # NOTE: type_list is appended to while this loop iterates over it.
        # That is intentional: subclasses discovered below are queued so that
        # a later iteration of this same loop emits a dataclass for each one.
        for node_kind, ctx_type in type_list:
            subclasses = ctx_type.__subclasses__()
            if subclasses:
                # A context class with subclasses is emitted as a Union alias
                # over its subclasses' node types.
                file.write(f" {node_kind} = Union[\n")
                for sub_type in subclasses:
                    sub_kind = sub_type.__name__
                    # Drop the 7-character "Context" suffix from the class name.
                    attr_kind = sub_kind[:-7]
                    attr_hint = get_hint(attr_kind)
                    type_list.append((attr_kind, sub_type))
                    file.write(f" {attr_hint},\n")
                file.write(f" ]\n" f"\n")
            else:
                # Build a throwaway context (first argument None, second an
                # ``Inspector`` instance, which is defined elsewhere in this
                # file) purely so its constructor populates the instance
                # attributes that get_attr_pairs enumerates.
                ctx = ctx_type(None, Inspector())
                file.write(f" @dataclass\n" f" class {node_kind}(any):\n")
                for attr_kind, attr_name in get_attr_pairs(ctx):
                    attr_hint = get_hint(attr_kind)
                    file.write(f" {attr_name}: {attr_hint}\n")
                file.write(
                    f"\n"
                    f" def _accept(self, visitor) -> Any:\n"
                    f" return visitor.visit_{snake_case(node_kind)}(self)\n"
                    f"\n"
                )


def generate_visitor():
    """Write ``TaiheVisitor.py`` into ``ANTLR_PATH``.

    The emitted ``TaiheVisitor`` class has a ``visit`` entry point that
    dispatches through ``node._accept(self)``, a ``visit_token`` method, and
    one ``visit_<rule>`` method per grammar rule that raises
    ``NotImplementedError``. For a rule whose context class has subclasses,
    each subclass additionally gets a ``visit_<subclass>`` method that
    forwards to the rule's method, so overriding the rule's method handles
    all of its subclasses.

    NOTE(review): every generated line below starts with a single space, so
    the emitted ``def`` lines and their bodies share one indentation level,
    and the output would not parse. The literals look whitespace-collapsed;
    restore the original indentation from version control before relying on
    the output.
    """
    with open(ANTLR_PATH / "TaiheVisitor.py", "w") as file:
        file.write(
            f"from {ANTLR_PKG}.TaiheAST import TaiheAST\n"
            f"\n"
            f"from typing import Any\n"
            f"\n"
            f"\n"
            f"class TaiheVisitor:\n"
            f" def visit(self, node: TaiheAST.any) -> Any:\n"
            f" return node._accept(self)\n"
            f"\n"
            f" def visit_token(self, node: TaiheAST.TOKEN) -> Any:\n"
            f" raise NotImplementedError()\n"
            f"\n"
        )

        parser = get_parser()

        # Same rule -> (NodeKind, context class) mapping as the AST
        # generator. Subclasses are not queued here: they only need the
        # forwarding methods written in the loop below.
        type_list = []
        for rule_name in parser.ruleNames:
            node_kind = rule_name[0].upper() + rule_name[1:]
            ctx_kind = node_kind + "Context"
            ctx_type = getattr(parser, ctx_kind)
            type_list.append((node_kind, ctx_type))

        for node_kind, ctx_type in type_list:
            subclasses = ctx_type.__subclasses__()
            if subclasses:
                for sub_type in subclasses:
                    sub_kind = sub_type.__name__
                    # Drop the 7-character "Context" suffix from the class name.
                    attr_kind = sub_kind[:-7]
                    # NOTE(review): in this copy the first f-string below is
                    # split across two physical lines, which Python rejects
                    # (a non-triple-quoted literal cannot contain a raw
                    # newline). Presumably it was a single line; verify.
                    file.write(
                        f" 
def visit_{snake_case(attr_kind)}(self, node: TaiheAST.{attr_kind}) -> Any:\n"
                        f" return self.visit_{snake_case(node_kind)}(node)\n"
                        f"\n"
                    )
            # Every rule gets a default method that subclasses override.
            file.write(
                f" def visit_{snake_case(node_kind)}(self, node: TaiheAST.{node_kind}) -> Any:\n"
                f" raise NotImplementedError()\n"
                f"\n"
            )


def run_antlr():
    """Run the ANTLR tool on ``Taihe.g4``, emitting a Python 3 parser into ``ANTLR_PATH``.

    The ANTLR jar is located through ``ResourceLocator.detect()`` as the
    ``ResourceType.DEV_ANTLR`` resource. ``-no-listener`` suppresses ANTLR's
    listener output.

    Raises:
        subprocess.CalledProcessError: if the tool exits with a non-zero status.
    """
    locator = ResourceLocator.detect()
    jar = locator.get(ResourceType.DEV_ANTLR)
    args = ["java", "-cp", str(jar), "org.antlr.v4.Tool"]
    args += ["-Dlanguage=Python3", "-no-listener", G4_FILE, "-o", ANTLR_PATH]
    # NOTE(review): env={} starts java with an empty environment. With no PATH,
    # subprocess resolves "java" against os.defpath only (typically
    # /bin:/usr/bin), and variables such as JAVA_HOME are not passed through.
    # Confirm this hermetic behavior is intended.
    check_call(args, env={})


def main():
    """Regenerate the ANTLR parser, then the AST and visitor modules built from it."""
    # Order matters: the generators call get_parser(), which imports the
    # TaiheParser module that run_antlr() produces.
    run_antlr()
    generate_ast()
    generate_visitor()


if __name__ == "__main__":
    main()