1#!/usr/bin/env python3 2# Copyright (c) 2025 Huawei Device Co., Ltd. 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14 15 16import re 17import sys 18from functools import cache 19from pathlib import Path 20from subprocess import check_call 21 22from taihe.utils.resources import ResourceLocator, ResourceType 23 24ANTLR_PKG = "taihe.parse.antlr" 25CURRENT_DIR = Path(__file__).parent.resolve() 26ANTLR_PATH = CURRENT_DIR / ANTLR_PKG.replace(".", "/") 27G4_FILE = CURRENT_DIR / "Taihe.g4" 28 29# HACK: The parent module `taihe.parse` imports the code which will be generated by us soon. 30# We directly import ANTLR-generated module, without initializing the parent module. 
sys.path.insert(0, str(ANTLR_PATH))

# Suffix carried by the context class names looked up on the parser
# (e.g. "FooContext"); stripped to obtain the AST node kind.
_CONTEXT_SUFFIX = "Context"

# Zero-width match in front of every uppercase letter except one at position 0.
_CAMEL_BOUNDARY = re.compile(r"(?<!^)(?=[A-Z])")


@cache
def get_parser():
    """Return the ``TaiheParser`` class from the ANTLR output directory.

    The import is performed lazily, at first call, because ``TaiheParser`` is
    resolved through the ``sys.path`` entry added above and is expected to be
    produced by ``run_antlr()`` first. The result is cached.
    """
    from TaiheParser import TaiheParser

    return TaiheParser


def get_hint(attr_kind):
    """Return the quoted type annotation for a label/node kind.

    ``FooLst`` becomes ``List["TaiheAST.Foo"]``, ``FooOpt`` becomes
    ``Optional["TaiheAST.Foo"]`` and anything else becomes ``"TaiheAST.Foo"``.
    """
    for suffix, wrapper in (("Lst", "List"), ("Opt", "Optional")):
        if attr_kind.endswith(suffix):
            return f'{wrapper}["TaiheAST.{attr_kind[: -len(suffix)]}"]'
    return f'"TaiheAST.{attr_kind}"'


def get_attr_pairs(ctx):
    """Yield ``(kind, name)`` for each public instance attribute of ``ctx``.

    Attribute names are expected to look like ``Kind_name``; the name is split
    at the first underscore. Attributes starting with ``_`` and the attribute
    called ``parser`` are skipped.

    Raises:
        ValueError: if a remaining attribute name contains no underscore.
    """
    for attr_full_name in ctx.__dict__:
        if attr_full_name.startswith("_") or attr_full_name == "parser":
            continue
        attr_kind, separator, attr_name = attr_full_name.partition("_")
        if not separator:
            # Previously this surfaced as an opaque unpacking error in the caller.
            raise ValueError(
                f"label {attr_full_name!r} on {type(ctx).__name__} "
                f"does not follow the 'Kind_name' convention"
            )
        yield attr_kind, attr_name


def snake_case(name):
    """Convert CamelCase to snake_case.

    An underscore is inserted before every uppercase letter except the first
    character, so consecutive capitals are split too ("TOKEN" -> "t_o_k_e_n").
    """
    return _CAMEL_BOUNDARY.sub("_", name).lower()


class Inspector:
    """Placeholder passed as the second argument when instantiating contexts.

    It only carries the attributes below, all set to ``None``.
    NOTE(review): presumably these are the fields the ANTLR runtime copies or
    reads from a parent/source context -- confirm against the runtime version
    in use.
    """

    def __init__(self):
        self.parentCtx = None
        self.invokingState = None
        self.children = None
        self.start = None
        self.stop = None


def _iter_rule_contexts(parser):
    """Yield ``(node_kind, ctx_type)`` for every entry of ``parser.ruleNames``."""
    for rule_name in parser.ruleNames:
        node_kind = rule_name[0].upper() + rule_name[1:]
        yield node_kind, getattr(parser, node_kind + _CONTEXT_SUFFIX)


def _node_kind_of(ctx_type):
    """Return the class name of ``ctx_type`` without its trailing ``Context``."""
    return ctx_type.__name__[: -len(_CONTEXT_SUFFIX)]


def _render_ast(parser):
    """Return the complete source text of ``TaiheAST.py`` for ``parser``."""
    parts = [
        "from dataclasses import dataclass\n"
        "from typing import Any, Union, List, Optional\n"
        "\n"
        "from taihe.utils.sources import SourceLocation\n"
        "\n"
        "\n"
        "class TaiheAST:\n"
        "    @dataclass(kw_only=True)\n"
        "    class any:\n"
        "        loc: SourceLocation\n"
        "\n"
        "        def _accept(self, visitor) -> Any:\n"
        "            raise NotImplementedError()\n"
        "\n"
        "\n"
        "    @dataclass\n"
        "    class TOKEN(any):\n"
        "        text: str\n"
        "\n"
        "        def __str__(self):\n"
        "            return self.text\n"
        "\n"
        "        def _accept(self, visitor) -> Any:\n"
        "            return visitor.visit_token(self)\n"
        "\n"
    ]
    # Breadth-first walk: all rule contexts first, then the subclasses found
    # along the way. This reproduces the order obtained by appending to the
    # work list while iterating it, without mutating a list under iteration.
    current_level = list(_iter_rule_contexts(parser))
    while current_level:
        next_level = []
        for node_kind, ctx_type in current_level:
            subclasses = ctx_type.__subclasses__()
            if subclasses:
                # A context with subclasses becomes a Union alias of them;
                # each subclass is queued to get its own dataclass later.
                parts.append(f"    {node_kind} = Union[\n")
                for sub_type in subclasses:
                    sub_kind = _node_kind_of(sub_type)
                    parts.append(f"        {get_hint(sub_kind)},\n")
                    next_level.append((sub_kind, sub_type))
                parts.append("    ]\n\n")
            else:
                # Instantiate the context once so that its label attributes
                # show up in the instance ``__dict__``.
                ctx = ctx_type(None, Inspector())
                parts.append(f"    @dataclass\n    class {node_kind}(any):\n")
                parts.extend(
                    f"        {attr_name}: {get_hint(attr_kind)}\n"
                    for attr_kind, attr_name in get_attr_pairs(ctx)
                )
                parts.append(
                    "\n"
                    "        def _accept(self, visitor) -> Any:\n"
                    f"            return visitor.visit_{snake_case(node_kind)}(self)\n"
                    "\n"
                )
        current_level = next_level
    return "".join(parts)


def _render_visitor(parser):
    """Return the complete source text of ``TaiheVisitor.py`` for ``parser``."""
    parts = [
        f"from {ANTLR_PKG}.TaiheAST import TaiheAST\n"
        "\n"
        "from typing import Any\n"
        "\n"
        "\n"
        "class TaiheVisitor:\n"
        "    def visit(self, node: TaiheAST.any) -> Any:\n"
        "        return node._accept(self)\n"
        "\n"
        "    def visit_token(self, node: TaiheAST.TOKEN) -> Any:\n"
        "        raise NotImplementedError()\n"
        "\n"
    ]
    for node_kind, ctx_type in _iter_rule_contexts(parser):
        base_method = f"visit_{snake_case(node_kind)}"
        # Each direct subclass gets a method that defaults to the base rule's one.
        for sub_type in ctx_type.__subclasses__():
            sub_kind = _node_kind_of(sub_type)
            parts.append(
                f"    def visit_{snake_case(sub_kind)}(self, node: TaiheAST.{sub_kind}) -> Any:\n"
                f"        return self.{base_method}(node)\n"
                "\n"
            )
        parts.append(
            f"    def {base_method}(self, node: TaiheAST.{node_kind}) -> Any:\n"
            "        raise NotImplementedError()\n"
            "\n"
        )
    return "".join(parts)


def generate_ast():
    """Write ``TaiheAST.py`` (dataclass-based AST definitions) into ``ANTLR_PATH``.

    The text is rendered completely before the file is opened, so a failure
    while inspecting the parser does not leave a truncated module behind.
    """
    source = _render_ast(get_parser())
    with open(ANTLR_PATH / "TaiheAST.py", "w", encoding="utf-8") as file:
        file.write(source)


def generate_visitor():
    """Write ``TaiheVisitor.py`` (the base visitor class) into ``ANTLR_PATH``.

    The text is rendered completely before the file is opened, so a failure
    while inspecting the parser does not leave a truncated module behind.
    """
    source = _render_visitor(get_parser())
    with open(ANTLR_PATH / "TaiheVisitor.py", "w", encoding="utf-8") as file:
        file.write(source)


def run_antlr():
    """Invoke the ANTLR tool on ``G4_FILE``, writing Python3 output to ``ANTLR_PATH``.

    Raises:
        subprocess.CalledProcessError: if the ``java`` process exits non-zero.
    """
    locator = ResourceLocator.detect()
    jar = locator.get(ResourceType.DEV_ANTLR)
    args = [
        "java",
        "-cp",
        str(jar),
        "org.antlr.v4.Tool",
        "-Dlanguage=Python3",
        "-no-listener",
        str(G4_FILE),
        "-o",
        str(ANTLR_PATH),
    ]
    # NOTE(review): ``env={}`` means the child inherits nothing from the
    # caller's environment (no PATH, JAVA_HOME, ...). Presumably intentional
    # to isolate the build -- confirm, since ``java`` must then be reachable
    # without the user's PATH.
    check_call(args, env={})


def main():
    """Regenerate the ANTLR parser, then the AST and visitor modules derived from it."""
    run_antlr()
    generate_ast()
    generate_visitor()


if __name__ == "__main__":
    main()