""" @package antlr3.tree
@brief ANTLR3 runtime package, treewizard module

A utility module to create ASTs at runtime.
See <http://www.antlr.org/wiki/display/~admin/2007/07/02/Exploring+Concept+of+TreeWizard> for an overview. Note that the API of the Python implementation is slightly different.

"""

# begin[licence]
#
# [The "BSD licence"]
# Copyright (c) 2005-2012 Terence Parr
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# 1. Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
# 3. The name of the author may not be used to endorse or promote products
#    derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
# OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
# IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
# NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
# THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# end[licence]

from .constants import INVALID_TOKEN_TYPE
from .tokens import CommonToken
from .tree import CommonTree, CommonTreeAdaptor


def computeTokenTypes(tokenNames):
    """
    Compute a dict that is an inverted index of
    tokenNames (which maps int token types to names).

    tokenNames is a sequence whose position is the token type and whose
    element is the token name. The result maps each name back to its
    position. An empty dict is returned when tokenNames is None or empty.
    If a name occurs more than once, the last position wins.
    """

    if not tokenNames:
        return {}

    return {name: ttype for ttype, name in enumerate(tokenNames)}


## token types for pattern parser
EOF = -1
BEGIN = 1
END = 2
ID = 3
ARG = 4
PERCENT = 5
COLON = 6
DOT = 7


class TreePatternLexer(object):
    """
    Tokenizer for tree patterns such as "(A %x:B[foo] .)".

    nextToken() returns one of the module level token type constants.
    For ID and ARG tokens the matched text is left in self.sval.
    Lexical errors are reported by setting self.error to True and
    returning EOF; no exception is raised.
    """

    def __init__(self, pattern):
        ## The tree pattern to lex like "(A B C)"
        self.pattern = pattern

        ## Index into input string
        self.p = -1

        ## Current char (a one character string, or the int EOF)
        self.c = None

        ## How long is the pattern in char?
        self.n = len(pattern)

        ## Set when token type is ID or ARG
        self.sval = None

        ## Set to True once the lexer has hit input it cannot tokenize
        self.error = False

        # load the first character into self.c
        self.consume()


    __idStartChar = frozenset(
        'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ_'
        )
    __idChar = __idStartChar | frozenset('0123456789')

    ## Characters skipped between tokens
    _whitespace = frozenset(' \n\r\t')

    ## Tokens that consist of exactly one character
    _singleCharTokens = {
        '(': BEGIN,
        ')': END,
        '%': PERCENT,
        ':': COLON,
        '.': DOT,
        }

    def nextToken(self):
        """
        Scan and return the type of the next token.

        Whitespace is skipped. self.sval is reset to "" and then filled
        with the token text for ID and ARG tokens. Returns EOF at the end
        of the pattern. On a character that cannot start a token, or on a
        "[" argument that is never closed, self.error is set to True and
        EOF is returned.
        """

        self.sval = ""
        while self.c != EOF:
            ch = self.c

            if ch in self._whitespace:
                self.consume()
                continue

            if ch in self.__idStartChar:
                return self._scanId()

            ttype = self._singleCharTokens.get(ch)
            if ttype is not None:
                self.consume()
                return ttype

            if ch == '[':
                return self._scanArg()

            # anything else is not part of the pattern language
            self.consume()
            self.error = True
            return EOF

        return EOF


    def _scanId(self):
        """Scan an identifier starting at the current char; return ID."""

        start = self.p
        self.consume()
        # at end of input self.c is the int EOF, which is never a member
        # of the identifier character set, so the loop stops there too
        while self.c in self.__idChar:
            self.consume()

        self.sval = self.pattern[start:self.p]
        return ID


    def _scanArg(self):
        """
        Scan "[x]" starting at the "[" and return ARG with sval set to x.

        A backslash followed by "]" yields a literal "]". A backslash
        followed by any other character is kept together with that
        character. If the pattern ends before the closing "]" the lexer
        sets self.error and returns EOF instead of failing while trying
        to append the EOF marker to the text.
        """

        self.consume()  # skip '['
        chars = []
        while self.c != ']':
            if self.c == '\\':
                self.consume()
                if self.c != ']' and self.c != EOF:
                    chars.append('\\')

            if self.c == EOF:
                # unterminated argument
                self.error = True
                return EOF

            chars.append(self.c)
            self.consume()

        self.consume()  # skip ']'
        self.sval = "".join(chars)
        return ARG


    def consume(self):
        """Advance to the next char; self.c becomes EOF past the end."""

        self.p += 1
        if self.p >= self.n:
            self.c = EOF

        else:
            self.c = self.pattern[self.p]


class TreePatternParser(object):
    """
    Recursive descent parser that turns the tokens of a TreePatternLexer
    into a tree built with the given adaptor.

    Token names are translated to token types through wizard.getTokenType.
    All parse methods return None when the pattern is malformed.
    """

    def __init__(self, tokenizer, wizard, adaptor):
        self.tokenizer = tokenizer
        self.wizard = wizard
        self.adaptor = adaptor

        ## Type of the current lookahead token
        self.ttype = tokenizer.nextToken() # kickstart


    def _advance(self):
        """Fetch the next token from the tokenizer into self.ttype."""

        self.ttype = self.tokenizer.nextToken()


    def pattern(self):
        """
        Parse a complete pattern: either a tree "(root child ...)" or a
        single node. Returns the resulting tree/node, or None if the
        pattern is invalid. A single node followed by further tokens is
        rejected.
        """

        if self.ttype == BEGIN:
            return self.parseTree()

        if self.ttype == ID:
            node = self.parseNode()
            if self.ttype == EOF:
                return node

            return None # extra junk on end

        return None


    def parseTree(self):
        """
        Parse "(root child1 ... childN)" where each child is a node or a
        nested tree. Returns the root with its children attached, or None
        if the tree is malformed.
        """

        if self.ttype != BEGIN:
            return None

        self._advance()
        root = self.parseNode()
        if root is None:
            return None

        while self.ttype in (BEGIN, ID, PERCENT, DOT):
            if self.ttype == BEGIN:
                child = self.parseTree()

            else:
                child = self.parseNode()

            # A malformed child invalidates the whole pattern. The nested
            # tree case used to hand its None result on to addChild, so
            # the failure could go unnoticed.
            if child is None:
                return None

            self.adaptor.addChild(root, child)

        if self.ttype != END:
            return None

        self._advance()
        return root


    def parseNode(self):
        """
        Parse a single node: an optional "%label:" prefix followed by
        either the wildcard ".", the name "nil", or a token name with an
        optional "[text]" argument. Returns the node, or None if the
        input is malformed or the token name is unknown to the wizard.
        """

        # "%label:" prefix
        label = None

        if self.ttype == PERCENT:
            self._advance()
            if self.ttype != ID:
                return None

            label = self.tokenizer.sval
            self._advance()
            if self.ttype != COLON:
                return None

            self._advance() # move to ID following colon

        # Wildcard?
        if self.ttype == DOT:
            self._advance()
            wildcardPayload = CommonToken(0, ".")
            node = WildcardTreePattern(wildcardPayload)
            if label is not None:
                node.label = label
            return node

        # "ID" or "ID[arg]"
        if self.ttype != ID:
            return None

        tokenName = self.tokenizer.sval
        self._advance()

        if tokenName == "nil":
            return self.adaptor.nil()

        # the node text is the token name unless an argument overrides it
        text = tokenName
        arg = None
        if self.ttype == ARG:
            arg = self.tokenizer.sval
            text = arg
            self._advance()

        # create node
        treeNodeType = self.wizard.getTokenType(tokenName)
        if treeNodeType == INVALID_TOKEN_TYPE:
            return None

        node = self.adaptor.createFromType(treeNodeType, text)

        # label and text-arg flags only exist on TreePattern nodes; other
        # node classes produced by the adaptor are returned as they are
        if isinstance(node, TreePattern):
            if label is not None:
                node.label = label

            if arg is not None:
                node.hasTextArg = True

        return node


class TreePattern(CommonTree):
    """
    When using %label:TOKENNAME in a tree for parse(), we must
    track the label.
    """

    def __init__(self, payload):
        super().__init__(payload)

        ## Label given with "%label:", or None
        self.label = None

        ## True when the pattern node was written with a "[text]" argument
        self.hasTextArg = None


    def toString(self):
        """Return the node text, prefixed with "%label:" if labelled."""

        text = super().toString()
        if self.label:
            return '%' + self.label + ':' + text

        return text


class WildcardTreePattern(TreePattern):
    """Pattern node created for the "." wildcard."""
    pass


class TreePatternTreeAdaptor(CommonTreeAdaptor):
    """This adaptor creates TreePattern objects for use during scan()"""

    def createWithPayload(self, payload):
        return TreePattern(payload)


class TreeWizard(object):
    """
    Build and navigate trees with this object.  Must know about the names
    of tokens so you have to pass in a map or array of token names (from which
    this class can build the map).  I.e., Token DECL means nothing unless the
    class can translate it to a token type.

    In order to create nodes and navigate, this class needs a TreeAdaptor.

    This class can build a token type -> node index for repeated use or for
    iterating over the various nodes with a particular type.

    This class works in conjunction with the TreeAdaptor rather than moving
    all this functionality into the adaptor.  An adaptor helps build and
    navigate trees using methods.  This class helps you do it with string
    patterns like "(A B C)".  You can create a tree from that pattern or
    match subtrees against it.
    """

    def __init__(self, adaptor=None, tokenNames=None, typeMap=None):
        """
        adaptor defaults to a new CommonTreeAdaptor. The name -> type map
        is either computed from tokenNames or taken from typeMap as given.

        Raises ValueError if both tokenNames and typeMap are supplied.
        """

        self.adaptor = CommonTreeAdaptor() if adaptor is None else adaptor

        if typeMap is None:
            self.tokenNameToTypeMap = computeTokenTypes(tokenNames)

        else:
            if tokenNames:
                raise ValueError("Can't have both tokenNames and typeMap")

            self.tokenNameToTypeMap = typeMap


    def getTokenType(self, tokenName):
        """
        Using the map of token names to token types, return the type.

        Returns INVALID_TOKEN_TYPE for names that are not in the map.
        """

        if tokenName in self.tokenNameToTypeMap:
            return self.tokenNameToTypeMap[tokenName]

        return INVALID_TOKEN_TYPE


    def create(self, pattern):
        """
        Create a tree or node from the indicated tree pattern that closely
        follows ANTLR tree grammar tree element syntax:

        (root child1 ... child2).

        You can also just pass in a node: ID

        Any node can have a text argument: ID[foo]
        (notice there are no quotes around foo--it's clear it's a string).

        nil is a special name meaning "give me a nil node".  Useful for
        making lists: (nil A B C) is a list of A B C.

        Returns None if the pattern is invalid.
        """

        lexer = TreePatternLexer(pattern)
        return TreePatternParser(lexer, self, self.adaptor).pattern()


    def index(self, tree):
        """Walk the entire tree and make a node name to nodes mapping.

        For now, use recursion but later nonrecursive version may be
        more efficient.  Returns a dict int -> list where the list is
        of your AST node type.  The int is the token type of the node.
        """

        m = {}
        self._index(tree, m)
        return m


    def _index(self, t, m):
        """Do the work for index"""

        if t is None:
            return

        m.setdefault(self.adaptor.getType(t), []).append(t)

        for i in range(self.adaptor.getChildCount(t)):
            self._index(self.adaptor.getChild(t, i), m)


    def find(self, tree, what):
        """Return a list of matching token.

        what may either be an integer specifying the token type to find or
        a string with a pattern that must be matched.

        For a pattern that is invalid, nil-rooted or wildcard-rooted,
        None is returned instead of a list.

        Raises TypeError if what is neither an integer nor a string.
        """

        if isinstance(what, int):
            return self._findTokenType(tree, what)

        if isinstance(what, str):
            return self._findPattern(tree, what)

        raise TypeError("'what' must be string or integer")


    def _findTokenType(self, t, ttype):
        """Return a List of tree nodes with token type ttype"""

        nodes = []

        def collect(tree, parent, childIndex, labels):
            nodes.append(tree)

        self.visit(t, ttype, collect)
        return nodes


    def _findPattern(self, t, pattern):
        """
        Return a List of subtrees matching pattern, or None if the pattern
        is invalid, nil-rooted or wildcard-rooted.
        """

        # Create a TreePattern from the pattern
        lexer = TreePatternLexer(pattern)
        tpattern = TreePatternParser(
            lexer, self, TreePatternTreeAdaptor()).pattern()

        # don't allow invalid patterns
        if (tpattern is None or tpattern.isNil()
            or isinstance(tpattern, WildcardTreePattern)):
            return None

        subtrees = []

        def collect(tree, parent, childIndex, label):
            if self._parse(tree, tpattern, None):
                subtrees.append(tree)

        # only nodes with the pattern's root type can match at all
        self.visit(t, tpattern.getType(), collect)
        return subtrees


    def visit(self, tree, what, visitor):
        """Visit every node in tree matching what, invoking the visitor.

        The visitor is called as visitor(node, parent, childIndex, labels).

        If what is a string, it is parsed as a pattern and only matching
        subtrees will be visited.
        The implementation uses the root node of the pattern in combination
        with visit(t, ttype, visitor) so nil-rooted patterns are not allowed.
        Patterns with wildcard roots are also not allowed.

        If what is an integer, it is used as a token type and visit will match
        all nodes of that type (this is faster than the pattern match).
        The labels arg of the visitor action method is never set (it's None)
        since using a token type rather than a pattern doesn't let us set a
        label.

        Raises TypeError if what is neither an integer nor a string.
        """

        if isinstance(what, int):
            self._visitType(tree, None, 0, what, visitor)

        elif isinstance(what, str):
            self._visitPattern(tree, what, visitor)

        else:
            raise TypeError("'what' must be string or integer")


    def _visitType(self, t, parent, childIndex, ttype, visitor):
        """Do the recursive work for visit"""

        if t is None:
            return

        if self.adaptor.getType(t) == ttype:
            visitor(t, parent, childIndex, None)

        for i in range(self.adaptor.getChildCount(t)):
            self._visitType(self.adaptor.getChild(t, i), t, i, ttype, visitor)


    def _visitPattern(self, tree, pattern, visitor):
        """
        For all subtrees that match the pattern, execute the visit action.

        Nothing is visited if the pattern is invalid, nil-rooted or
        wildcard-rooted.
        """

        # Create a TreePattern from the pattern
        lexer = TreePatternLexer(pattern)
        tpattern = TreePatternParser(
            lexer, self, TreePatternTreeAdaptor()).pattern()

        # don't allow invalid patterns
        if (tpattern is None or tpattern.isNil()
            or isinstance(tpattern, WildcardTreePattern)):
            return

        def rootvisitor(node, parent, childIndex, _unused):
            # every candidate gets its own fresh label map
            labels = {}
            if self._parse(node, tpattern, labels):
                visitor(node, parent, childIndex, labels)

        self.visit(tree, tpattern.getType(), rootvisitor)


    def parse(self, t, pattern, labels=None):
        """
        Given a pattern like (ASSIGN %lhs:ID %rhs:.) with optional labels
        on the various nodes and '.' (dot) as the node/subtree wildcard,
        return true if the pattern matches and fill the labels Map with
        the labels pointing at the appropriate nodes.  Return false if
        the pattern is malformed or the tree does not match.

        If a node specifies a text arg in pattern, then that must match
        for that node in t.
        """

        lexer = TreePatternLexer(pattern)
        tpattern = TreePatternParser(
            lexer, self, TreePatternTreeAdaptor()).pattern()

        return self._parse(t, tpattern, labels)


    def _parse(self, t1, tpattern, labels):
        """
        Do the work for parse. Check to see if the tpattern fits the
        structure and token types in t1.  Check text if the pattern has
        text arguments on nodes.  Fill labels map with pointers to nodes
        in tree matched against nodes in pattern with labels.

        Labels are recorded as nodes are matched, so the map may hold
        entries from a partial match even when False is returned.
        """

        # make sure both are non-null
        if t1 is None or tpattern is None:
            return False

        adaptor = self.adaptor

        # check roots (wildcard matches anything)
        if not isinstance(tpattern, WildcardTreePattern):
            if adaptor.getType(t1) != tpattern.getType():
                return False

            # if pattern has text, check node text
            if (tpattern.hasTextArg
                and adaptor.getText(t1) != tpattern.getText()):
                return False

        if tpattern.label is not None and labels is not None:
            # map label in pattern to node in t1
            labels[tpattern.label] = t1

        # check children
        n1 = adaptor.getChildCount(t1)
        if n1 != tpattern.getChildCount():
            return False

        # all() stops at the first child that does not match
        return all(
            self._parse(adaptor.getChild(t1, i), tpattern.getChild(i), labels)
            for i in range(n1)
            )


    def equals(self, t1, t2, adaptor=None):
        """
        Compare t1 and t2; return true if token types/text, structure match
        exactly.
        The trees are examined in their entirety so that (A B) does not match
        (A B C) nor (A (B C)).

        adaptor defaults to the wizard's own adaptor.
        """

        if adaptor is None:
            adaptor = self.adaptor

        return self._equals(t1, t2, adaptor)


    def _equals(self, t1, t2, adaptor):
        """Do the recursive work for equals"""

        # make sure both are non-null
        if t1 is None or t2 is None:
            return False

        # check roots
        if adaptor.getType(t1) != adaptor.getType(t2):
            return False

        if adaptor.getText(t1) != adaptor.getText(t2):
            return False

        # check children
        n1 = adaptor.getChildCount(t1)
        if n1 != adaptor.getChildCount(t2):
            return False

        # all() stops at the first pair of children that differ
        return all(
            self._equals(adaptor.getChild(t1, i), adaptor.getChild(t2, i),
                         adaptor)
            for i in range(n1)
            )