"""Unit tests for antlr3.treewizard.

Covers computeTokenTypes, TreePatternLexer, TreePatternParser and
TreeWizard.
"""

from io import StringIO
import os
import unittest

from antlr3.tree import CommonTreeAdaptor, CommonTree, INVALID_TOKEN_TYPE
from antlr3.treewizard import (
    TreeWizard, computeTokenTypes,
    TreePatternLexer, EOF, ID, BEGIN, END, PERCENT, COLON, DOT, ARG,
    TreePatternParser,
    TreePattern, WildcardTreePattern, TreePatternTreeAdaptor,
)


class TestComputeTokenTypes(unittest.TestCase):
    """Test case for the computeTokenTypes function."""

    def testNone(self):
        """computeTokenTypes(None) -> {}"""

        typeMap = computeTokenTypes(None)
        self.assertIsInstance(typeMap, dict)
        self.assertEqual(typeMap, {})

    def testList(self):
        """computeTokenTypes(['a', 'b']) -> { 'a': 0, 'b': 1 }"""

        typeMap = computeTokenTypes(['a', 'b'])
        self.assertIsInstance(typeMap, dict)
        self.assertEqual(typeMap, {'a': 0, 'b': 1})


class TestTreePatternLexer(unittest.TestCase):
    """Test case for the TreePatternLexer class."""

    def _checkFirstToken(self, text, expectedType, expectedSval='',
                         expectError=False):
        """Lex `text` and check the first token.

        Asserts the type returned by the first nextToken() call, the
        lexer's `sval` afterwards, and whether the lexer's `error` flag
        is set.
        """
        lexer = TreePatternLexer(text)
        # Not named `type`, to avoid shadowing the builtin.
        tokenType = lexer.nextToken()
        self.assertEqual(tokenType, expectedType)
        self.assertEqual(lexer.sval, expectedSval)
        if expectError:
            self.assertTrue(lexer.error)
        else:
            self.assertFalse(lexer.error)

    def testBegin(self):
        """TreePatternLexer(): '('"""

        self._checkFirstToken('(', BEGIN)

    def testEnd(self):
        """TreePatternLexer(): ')'"""

        self._checkFirstToken(')', END)

    def testPercent(self):
        """TreePatternLexer(): '%'"""

        self._checkFirstToken('%', PERCENT)

    def testDot(self):
        """TreePatternLexer(): '.'"""

        self._checkFirstToken('.', DOT)

    def testColon(self):
        """TreePatternLexer(): ':'"""

        self._checkFirstToken(':', COLON)

    def testEOF(self):
        """TreePatternLexer(): EOF"""

        # Input consisting only of whitespace lexes straight to EOF.
        self._checkFirstToken(' \n \r \t ', EOF)

    def testID(self):
        """TreePatternLexer(): ID"""

        self._checkFirstToken('_foo12_bar', ID, '_foo12_bar')

    def testARG(self):
        """TreePatternLexer(): ARG"""

        # The escaped bracket (\]) ends up as a plain ']' in sval, while
        # the backslash-n sequence is passed through unchanged.
        self._checkFirstToken(r'[ \]bla\n]', ARG, r' ]bla\n')

    def testError(self):
        """TreePatternLexer(): error"""

        # A character the lexer does not accept yields EOF with the
        # error flag set.
        self._checkFirstToken('1', EOF, expectError=True)


class TestTreePatternParser(unittest.TestCase):
    """Test case for the TreePatternParser class."""

    def setUp(self):
        """Set up the test fixture.

        We need a tree adaptor, use CommonTreeAdaptor.
        And a constant list of token names.

        """

        self.adaptor = CommonTreeAdaptor()
        # A token's type is its index in this list (A == 5, ID == 10, ...).
        self.tokens = [
            "", "", "", "", "", "A", "B", "C", "D", "E", "ID", "VAR"
        ]
        self.wizard = TreeWizard(self.adaptor, tokenNames=self.tokens)

    def _parsePattern(self, pattern, adaptor=None):
        """Return the result of parsing `pattern` with TreePatternParser.

        Uses the fixture's adaptor unless another `adaptor` is given.
        """
        if adaptor is None:
            adaptor = self.adaptor
        lexer = TreePatternLexer(pattern)
        parser = TreePatternParser(lexer, self.wizard, adaptor)
        return parser.pattern()

    def testSingleNode(self):
        """TreePatternParser: 'ID'"""
        tree = self._parsePattern('ID')
        self.assertIsInstance(tree, CommonTree)
        self.assertEqual(tree.getType(), 10)
        self.assertEqual(tree.getText(), 'ID')

    def testSingleNodeWithArg(self):
        """TreePatternParser: 'ID[foo]'"""
        tree = self._parsePattern('ID[foo]')
        self.assertIsInstance(tree, CommonTree)
        self.assertEqual(tree.getType(), 10)
        # The bracketed argument becomes the node text.
        self.assertEqual(tree.getText(), 'foo')

    def testSingleLevelTree(self):
        """TreePatternParser: '(A B)'"""
        tree = self._parsePattern('(A B)')
        self.assertIsInstance(tree, CommonTree)
        self.assertEqual(tree.getType(), 5)
        self.assertEqual(tree.getText(), 'A')
        self.assertEqual(tree.getChildCount(), 1)
        self.assertEqual(tree.getChild(0).getType(), 6)
        self.assertEqual(tree.getChild(0).getText(), 'B')

    def testNil(self):
        """TreePatternParser: 'nil'"""
        tree = self._parsePattern('nil')
        self.assertIsInstance(tree, CommonTree)
        self.assertEqual(tree.getType(), 0)
        self.assertIsNone(tree.getText())

    def testWildcard(self):
        """TreePatternParser: '(.)'"""
        tree = self._parsePattern('(.)')
        self.assertIsInstance(tree, WildcardTreePattern)

    def testLabel(self):
        """TreePatternParser: '(%a:A)'"""
        # Uses TreePatternTreeAdaptor rather than the fixture's
        # CommonTreeAdaptor; the result is then a TreePattern with a label.
        tree = self._parsePattern('(%a:A)', TreePatternTreeAdaptor())
        self.assertIsInstance(tree, TreePattern)
        self.assertEqual(tree.label, 'a')

    def testError1(self):
        """TreePatternParser: ')'"""
        tree = self._parsePattern(')')
        self.assertIsNone(tree)

    def testError2(self):
        """TreePatternParser: '()'"""
        tree = self._parsePattern('()')
        self.assertIsNone(tree)

    def testError3(self):
        """TreePatternParser: '(A ])'"""
        tree = self._parsePattern('(A ])')
        self.assertIsNone(tree)


class TestTreeWizard(unittest.TestCase):
    """Test case for the TreeWizard class."""

    def setUp(self):
        """Set up the test fixture.

        We need a tree adaptor, use CommonTreeAdaptor.
        And a constant list of token names.

        """

        self.adaptor = CommonTreeAdaptor()
        # A token's type is its index in this list (A == 5, VAR == 11, ...).
        self.tokens = [
            "", "", "", "", "", "A", "B", "C", "D", "E", "ID", "VAR"
        ]

    def testInit(self):
        """TreeWizard.__init__()"""

        wiz = TreeWizard(
            self.adaptor,
            tokenNames=['a', 'b']
        )

        self.assertIs(wiz.adaptor, self.adaptor)
        self.assertEqual(
            wiz.tokenNameToTypeMap,
            {'a': 0, 'b': 1}
        )

    def testGetTokenType(self):
        """TreeWizard.getTokenType()"""

        wiz = TreeWizard(
            self.adaptor,
            tokenNames=self.tokens
        )

        self.assertEqual(
            wiz.getTokenType('A'),
            5
        )

        self.assertEqual(
            wiz.getTokenType('VAR'),
            11
        )

        self.assertEqual(
            wiz.getTokenType('invalid'),
            INVALID_TOKEN_TYPE
        )

    def testSingleNode(self):
        wiz = TreeWizard(self.adaptor, self.tokens)
        t = wiz.create("ID")
        found = t.toStringTree()
        expecting = "ID"
        self.assertEqual(expecting, found)

    def testSingleNodeWithArg(self):
        wiz = TreeWizard(self.adaptor, self.tokens)
        t = wiz.create("ID[foo]")
        found = t.toStringTree()
        expecting = "foo"
        self.assertEqual(expecting, found)

    def testSingleNodeTree(self):
        wiz = TreeWizard(self.adaptor, self.tokens)
        t = wiz.create("(A)")
        found = t.toStringTree()
        expecting = "A"
        self.assertEqual(expecting, found)

    def testSingleLevelTree(self):
        wiz = TreeWizard(self.adaptor, self.tokens)
        t = wiz.create("(A B C D)")
        found = t.toStringTree()
        expecting = "(A B C D)"
        self.assertEqual(expecting, found)

    def testListTree(self):
        # A nil-rooted tree renders as a flat list of its children.
        wiz = TreeWizard(self.adaptor, self.tokens)
        t = wiz.create("(nil A B C)")
        found = t.toStringTree()
        expecting = "A B C"
        self.assertEqual(expecting, found)

    def testInvalidListTree(self):
        # Several nodes without an enclosing "(...)" are not a valid
        # pattern: create() returns None.
        wiz = TreeWizard(self.adaptor, self.tokens)
        t = wiz.create("A B C")
        self.assertIsNone(t)

    def testDoubleLevelTree(self):
        wiz = TreeWizard(self.adaptor, self.tokens)
        t = wiz.create("(A (B C) (B D) E)")
        found = t.toStringTree()
        expecting = "(A (B C) (B D) E)"
        self.assertEqual(expecting, found)

    def __simplifyIndexMap(self, indexMap):
        """Return a copy of `indexMap` with every node replaced by str(node).

        Stringifying the nodes makes the map easy to compare against a
        literal dict of expected values.
        """
        return {
            ttype: [str(node) for node in nodes]
            for ttype, nodes in indexMap.items()
        }

    def testSingleNodeIndex(self):
        wiz = TreeWizard(self.adaptor, self.tokens)
        tree = wiz.create("ID")
        indexMap = wiz.index(tree)
        found = self.__simplifyIndexMap(indexMap)
        expecting = {10: ["ID"]}
        self.assertEqual(expecting, found)

    def testNoRepeatsIndex(self):
        wiz = TreeWizard(self.adaptor, self.tokens)
        tree = wiz.create("(A B C D)")
        indexMap = wiz.index(tree)
        found = self.__simplifyIndexMap(indexMap)
        expecting = {8: ['D'], 6: ['B'], 7: ['C'], 5: ['A']}
        self.assertEqual(expecting, found)

    def testRepeatsIndex(self):
        wiz = TreeWizard(self.adaptor, self.tokens)
        tree = wiz.create("(A B (A C B) B D D)")
        indexMap = wiz.index(tree)
        found = self.__simplifyIndexMap(indexMap)
        expecting = {
            8: ['D', 'D'],
            6: ['B', 'B', 'B'],
            7: ['C'],
            5: ['A', 'A'],
        }
        self.assertEqual(expecting, found)

    def testNoRepeatsVisit(self):
        wiz = TreeWizard(self.adaptor, self.tokens)
        tree = wiz.create("(A B C D)")

        elements = []

        def visitor(node, parent, childIndex, labels):
            elements.append(str(node))

        wiz.visit(tree, wiz.getTokenType("B"), visitor)

        expecting = ['B']
        self.assertEqual(expecting, elements)

    def testNoRepeatsVisit2(self):
        wiz = TreeWizard(self.adaptor, self.tokens)
        tree = wiz.create("(A B (A C B) B D D)")

        elements = []

        def visitor(node, parent, childIndex, labels):
            elements.append(str(node))

        wiz.visit(tree, wiz.getTokenType("C"), visitor)

        expecting = ['C']
        self.assertEqual(expecting, elements)

    def testRepeatsVisit(self):
        wiz = TreeWizard(self.adaptor, self.tokens)
        tree = wiz.create("(A B (A C B) B D D)")

        elements = []

        def visitor(node, parent, childIndex, labels):
            elements.append(str(node))

        wiz.visit(tree, wiz.getTokenType("B"), visitor)

        expecting = ['B', 'B', 'B']
        self.assertEqual(expecting, elements)

    def testRepeatsVisit2(self):
        wiz = TreeWizard(self.adaptor, self.tokens)
        tree = wiz.create("(A B (A C B) B D D)")

        elements = []

        def visitor(node, parent, childIndex, labels):
            elements.append(str(node))

        wiz.visit(tree, wiz.getTokenType("A"), visitor)

        expecting = ['A', 'A']
        self.assertEqual(expecting, elements)

    def testRepeatsVisitWithContext(self):
        wiz = TreeWizard(self.adaptor, self.tokens)
        tree = wiz.create("(A B (A C B) B D D)")

        elements = []

        def visitor(node, parent, childIndex, labels):
            # Record each hit as node@parent[childIndex].
            elements.append('{}@{}[{}]'.format(node, parent, childIndex))

        wiz.visit(tree, wiz.getTokenType("B"), visitor)

        expecting = ['B@A[0]', 'B@A[1]', 'B@A[2]']
        self.assertEqual(expecting, elements)

    def testRepeatsVisitWithNullParentAndContext(self):
        wiz = TreeWizard(self.adaptor, self.tokens)
        tree = wiz.create("(A B (A C B) B D D)")

        elements = []

        def visitor(node, parent, childIndex, labels):
            # The overall root is visited without a parent; show it as 'nil'.
            elements.append(
                '{}@{}[{}]'.format(
                    node, parent or 'nil', childIndex)
            )

        wiz.visit(tree, wiz.getTokenType("A"), visitor)

        expecting = ['A@nil[0]', 'A@A[1]']
        self.assertEqual(expecting, elements)

    def testVisitPattern(self):
        wiz = TreeWizard(self.adaptor, self.tokens)
        tree = wiz.create("(A B C (A B) D)")

        elements = []

        def visitor(node, parent, childIndex, labels):
            elements.append(
                str(node)
            )

        wiz.visit(tree, '(A B)', visitor)

        expecting = ['A']  # shouldn't match overall root, just (A B)
        self.assertEqual(expecting, elements)

    def testVisitPatternMultiple(self):
        wiz = TreeWizard(self.adaptor, self.tokens)
        tree = wiz.create("(A B C (A B) (D (A B)))")

        elements = []

        def visitor(node, parent, childIndex, labels):
            elements.append(
                '{}@{}[{}]'.format(node, parent or 'nil', childIndex)
            )

        wiz.visit(tree, '(A B)', visitor)

        expecting = ['A@A[2]', 'A@D[0]']
        self.assertEqual(expecting, elements)

    def testVisitPatternMultipleWithLabels(self):
        wiz = TreeWizard(self.adaptor, self.tokens)
        tree = wiz.create("(A B C (A[foo] B[bar]) (D (A[big] B[dog])))")

        elements = []

        def visitor(node, parent, childIndex, labels):
            # labels maps the pattern's %a / %b labels to the matched nodes.
            elements.append(
                '{}@{}[{}]{}&{}'.format(
                    node,
                    parent or 'nil',
                    childIndex,
                    labels['a'],
                    labels['b'],
                )
            )

        wiz.visit(tree, '(%a:A %b:B)', visitor)

        expecting = ['foo@A[2]foo&bar', 'big@D[0]big&dog']
        self.assertEqual(expecting, elements)

    def testParse(self):
        wiz = TreeWizard(self.adaptor, self.tokens)
        t = wiz.create("(A B C)")
        valid = wiz.parse(t, "(A B C)")
        self.assertTrue(valid)

    def testParseSingleNode(self):
        wiz = TreeWizard(self.adaptor, self.tokens)
        t = wiz.create("A")
        valid = wiz.parse(t, "A")
        self.assertTrue(valid)

    def testParseSingleNodeFails(self):
        wiz = TreeWizard(self.adaptor, self.tokens)
        t = wiz.create("A")
        valid = wiz.parse(t, "B")
        self.assertFalse(valid)

    def testParseFlatTree(self):
        wiz = TreeWizard(self.adaptor, self.tokens)
        t = wiz.create("(nil A B C)")
        valid = wiz.parse(t, "(nil A B C)")
        self.assertTrue(valid)

    def testParseFlatTreeFails(self):
        wiz = TreeWizard(self.adaptor, self.tokens)
        t = wiz.create("(nil A B C)")
        valid = wiz.parse(t, "(nil A B)")
        self.assertFalse(valid)

    def testParseFlatTreeFails2(self):
        wiz = TreeWizard(self.adaptor, self.tokens)
        t = wiz.create("(nil A B C)")
        valid = wiz.parse(t, "(nil A B A)")
        self.assertFalse(valid)

    def testWildcard(self):
        wiz = TreeWizard(self.adaptor, self.tokens)
        t = wiz.create("(A B C)")
        valid = wiz.parse(t, "(A . .)")
        self.assertTrue(valid)

    def testParseWithText(self):
        wiz = TreeWizard(self.adaptor, self.tokens)
        t = wiz.create("(A B[foo] C[bar])")
        # C pattern has no text arg so despite [bar] in t, no need
        # to match text--check structure only.
        valid = wiz.parse(t, "(A B[foo] C)")
        self.assertTrue(valid)

    def testParseWithText2(self):
        wiz = TreeWizard(self.adaptor, self.tokens)
        t = wiz.create("(A B[T__32] (C (D E[a])))")
        # The match result is not asserted; this test only pins that the
        # tree still renders as it was built after parse() has run.
        wiz.parse(t, "(A B[foo] C)")
        self.assertEqual("(A T__32 (C (D a)))", t.toStringTree())

    def testParseWithTextFails(self):
        wiz = TreeWizard(self.adaptor, self.tokens)
        t = wiz.create("(A B C)")
        valid = wiz.parse(t, "(A[foo] B C)")
        self.assertFalse(valid)  # fails

    def testParseLabels(self):
        wiz = TreeWizard(self.adaptor, self.tokens)
        t = wiz.create("(A B C)")
        labels = {}
        valid = wiz.parse(t, "(%a:A %b:B %c:C)", labels)
        self.assertTrue(valid)
        self.assertEqual("A", str(labels["a"]))
        self.assertEqual("B", str(labels["b"]))
        self.assertEqual("C", str(labels["c"]))

    def testParseWithWildcardLabels(self):
        wiz = TreeWizard(self.adaptor, self.tokens)
        t = wiz.create("(A B C)")
        labels = {}
        valid = wiz.parse(t, "(A %b:. %c:.)", labels)
        self.assertTrue(valid)
        self.assertEqual("B", str(labels["b"]))
        self.assertEqual("C", str(labels["c"]))

    def testParseLabelsAndTestText(self):
        wiz = TreeWizard(self.adaptor, self.tokens)
        t = wiz.create("(A B[foo] C)")
        labels = {}
        valid = wiz.parse(t, "(%a:A %b:B[foo] %c:C)", labels)
        self.assertTrue(valid)
        self.assertEqual("A", str(labels["a"]))
        self.assertEqual("foo", str(labels["b"]))
        self.assertEqual("C", str(labels["c"]))

    def testParseLabelsInNestedTree(self):
        wiz = TreeWizard(self.adaptor, self.tokens)
        t = wiz.create("(A (B C) (D E))")
        labels = {}
        valid = wiz.parse(t, "(%a:A (%b:B %c:C) (%d:D %e:E) )", labels)
        self.assertTrue(valid)
        self.assertEqual("A", str(labels["a"]))
        self.assertEqual("B", str(labels["b"]))
        self.assertEqual("C", str(labels["c"]))
        self.assertEqual("D", str(labels["d"]))
        self.assertEqual("E", str(labels["e"]))

    def testEquals(self):
        wiz = TreeWizard(self.adaptor, self.tokens)
        t1 = wiz.create("(A B C)")
        t2 = wiz.create("(A B C)")
        same = wiz.equals(t1, t2)
        self.assertTrue(same)

    def testEqualsWithText(self):
        wiz = TreeWizard(self.adaptor, self.tokens)
        t1 = wiz.create("(A B[foo] C)")
        t2 = wiz.create("(A B[foo] C)")
        same = wiz.equals(t1, t2)
        self.assertTrue(same)

    def testEqualsWithMismatchedText(self):
        wiz = TreeWizard(self.adaptor, self.tokens)
        t1 = wiz.create("(A B[foo] C)")
        t2 = wiz.create("(A B C)")
        same = wiz.equals(t1, t2)
        self.assertFalse(same)

    def testEqualsWithMismatchedList(self):
        wiz = TreeWizard(self.adaptor, self.tokens)
        t1 = wiz.create("(A B C)")
        t2 = wiz.create("(A B A)")
        same = wiz.equals(t1, t2)
        self.assertFalse(same)

    def testEqualsWithMismatchedListLength(self):
        wiz = TreeWizard(self.adaptor, self.tokens)
        t1 = wiz.create("(A B C)")
        t2 = wiz.create("(A B)")
        same = wiz.equals(t1, t2)
        self.assertFalse(same)

    def testFindPattern(self):
        wiz = TreeWizard(self.adaptor, self.tokens)
        t = wiz.create("(A B C (A[foo] B[bar]) (D (A[big] B[dog])))")
        subtrees = wiz.find(t, "(A B)")
        found = [str(node) for node in subtrees]
        expecting = ['foo', 'big']
        self.assertEqual(expecting, found)

    def testFindTokenType(self):
        wiz = TreeWizard(self.adaptor, self.tokens)
        t = wiz.create("(A B C (A[foo] B[bar]) (D (A[big] B[dog])))")
        subtrees = wiz.find(t, wiz.getTokenType('A'))
        found = [str(node) for node in subtrees]
        expecting = ['A', 'foo', 'big']
        self.assertEqual(expecting, found)


if __name__ == "__main__":
    unittest.main(testRunner=unittest.TextTestRunner(verbosity=2))