# mako/_ast_util.py
# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
#
# This module is part of Mako and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""
    ast
    ~~~

    The `ast` module helps Python applications to process trees of the Python
    abstract syntax grammar.  The abstract syntax itself might change with
    each Python release; this module helps to find out programmatically what
    the current grammar looks like and allows modifications of it.

    An abstract syntax tree can be generated by passing `ast.PyCF_ONLY_AST` as
    a flag to the `compile()` builtin function or by using the `parse()`
    function from this module.  The result will be a tree of objects whose
    classes all inherit from `ast.AST`.

    A modified abstract syntax tree can be compiled into a Python code object
    using the built-in `compile()` function.

    Additionally various helper functions are provided that make working with
    the trees simpler.  The main intention of the helper functions and this
    module in general is to provide an easy to use interface for libraries
    that work tightly with the python syntax (template engines for example).


    :copyright: Copyright 2008 by Armin Ronacher.
    :license: Python License.
"""
from _ast import *
from mako.compat import arg_stringname

BOOLOP_SYMBOLS = {
    And: 'and',
    Or: 'or'
}

BINOP_SYMBOLS = {
    Add: '+',
    Sub: '-',
    Mult: '*',
    Div: '/',
    FloorDiv: '//',
    Mod: '%',
    Pow: '**',
    LShift: '<<',
    RShift: '>>',
    BitOr: '|',
    BitAnd: '&',
    BitXor: '^'
}

try:
    # The matrix multiplication operator only exists on newer Pythons; older
    # `_ast` modules do not define `MatMult` at all.
    BINOP_SYMBOLS[MatMult] = '@'
except NameError:
    pass

CMPOP_SYMBOLS = {
    Eq: '==',
    Gt: '>',
    GtE: '>=',
    In: 'in',
    Is: 'is',
    IsNot: 'is not',
    Lt: '<',
    LtE: '<=',
    NotEq: '!=',
    NotIn: 'not in'
}

UNARYOP_SYMBOLS = {
    Invert: '~',
    Not: 'not',
    UAdd: '+',
    USub: '-'
}

ALL_SYMBOLS = {}
ALL_SYMBOLS.update(BOOLOP_SYMBOLS)
ALL_SYMBOLS.update(BINOP_SYMBOLS)
ALL_SYMBOLS.update(CMPOP_SYMBOLS)
ALL_SYMBOLS.update(UNARYOP_SYMBOLS)


def parse(expr, filename='<unknown>', mode='exec'):
    """Parse an expression into an AST node."""
    return compile(expr, filename, mode, PyCF_ONLY_AST)


def to_source(node, indent_with=' ' * 4):
    """
    This function can convert a node tree back into python sourcecode.  This
    is useful for debugging purposes, especially if you're dealing with custom
    asts not generated by python itself.

    It could be that the sourcecode is evaluable when the AST itself is not
    compilable / evaluable.  The reason for this is that the AST contains some
    more data than regular sourcecode does, which is dropped during
    conversion.

    Each level of indentation is replaced with `indent_with`.  By default this
    parameter is equal to four spaces as suggested by PEP 8, but it might be
    adjusted to match the application's styleguide.
    """
    generator = SourceGenerator(indent_with)
    generator.visit(node)
    return ''.join(generator.result)


def dump(node):
    """
    A very verbose representation of the node passed.  This is useful for
    debugging purposes.

    :raises TypeError: if `node` is not an `AST` instance.
    """
    def _format(node):
        if isinstance(node, AST):
            return '%s(%s)' % (node.__class__.__name__,
                               ', '.join('%s=%s' % (a, _format(b))
                                         for a, b in iter_fields(node)))
        elif isinstance(node, list):
            return '[%s]' % ', '.join(_format(x) for x in node)
        return repr(node)
    if not isinstance(node, AST):
        raise TypeError('expected AST, got %r' % node.__class__.__name__)
    return _format(node)


def copy_location(new_node, old_node):
    """
    Copy the source location hints (`lineno`, `col_offset`, and
    `end_lineno` / `end_col_offset` on Python versions whose nodes define
    them) from the old to the new node where both node types support them,
    and return the new node.
    """
    for attr in 'lineno', 'col_offset', 'end_lineno', 'end_col_offset':
        if attr in old_node._attributes and attr in new_node._attributes \
                and hasattr(old_node, attr):
            setattr(new_node, attr, getattr(old_node, attr))
    return new_node


def fix_missing_locations(node):
    """
    Some nodes require a line number and the column offset.  Without that
    information the compiler will abort the compilation.  Because it can be
    a dull task to add appropriate line numbers and column offsets when
    adding new nodes this function can help.  It copies the line number and
    column offset of the parent node to the child nodes without this
    information.

    Unlike `copy_location` this works recursively and won't touch nodes that
    already have location information.
    """
    def _fix(node, lineno, col_offset):
        if 'lineno' in node._attributes:
            if not hasattr(node, 'lineno'):
                node.lineno = lineno
            else:
                lineno = node.lineno
        if 'col_offset' in node._attributes:
            if not hasattr(node, 'col_offset'):
                node.col_offset = col_offset
            else:
                col_offset = node.col_offset
        for child in iter_child_nodes(node):
            _fix(child, lineno, col_offset)
    _fix(node, 1, 0)
    return node


def increment_lineno(node, n=1):
    """
    Increment the line numbers of `node` and all of its descendants by `n`
    if they have line number attributes.  This is useful to "move code" to
    a different location in a file.  Returns `node`.
    """
    # `walk` yields the root node itself followed by every descendant.
    for child in walk(node):
        if 'lineno' in child._attributes:
            child.lineno = getattr(child, 'lineno', 0) + n
        # Keep end positions consistent with the shifted start positions on
        # node types that carry them.
        if 'end_lineno' in child._attributes \
                and getattr(child, 'end_lineno', None) is not None:
            child.end_lineno += n
    return node


def iter_fields(node):
    """Iterate over all fields of a node, only yielding existing fields."""
    # CPython 2.5 compat
    if not hasattr(node, '_fields') or not node._fields:
        return
    for field in node._fields:
        try:
            yield field, getattr(node, field)
        except AttributeError:
            pass


def get_fields(node):
    """Like `iter_fields` but returns a dict."""
    return dict(iter_fields(node))


def iter_child_nodes(node):
    """Iterate over all direct child nodes of a node."""
    for name, field in iter_fields(node):
        if isinstance(field, AST):
            yield field
        elif isinstance(field, list):
            for item in field:
                if isinstance(item, AST):
                    yield item


def get_child_nodes(node):
    """Like `iter_child_nodes` but returns a list."""
    return list(iter_child_nodes(node))


def get_compile_mode(node):
    """
    Get the mode for `compile` of a given node.  If the node is not a `mod`
    node (`Expression`, `Module` etc.) a `TypeError` is raised.

    Returns ``'eval'`` for `Expression`, ``'single'`` for `Interactive` and
    ``'exec'`` for every other `mod` node (e.g. `Module`).
    """
    if not isinstance(node, mod):
        raise TypeError('expected mod node, got %r' % node.__class__.__name__)
    return {
        Expression: 'eval',
        Interactive: 'single'
    }.get(node.__class__, 'exec')


def get_docstring(node):
    """
    Return the docstring for the given node or `None` if no docstring can be
    found.  If the node provided does not accept docstrings a `TypeError`
    will be raised.
    """
    if not isinstance(node, (FunctionDef, ClassDef, Module)):
        raise TypeError("%r can't have docstrings" % node.__class__.__name__)
    # A docstring is an expression *statement* wrapping a string literal.
    if not node.body or not isinstance(node.body[0], Expr):
        return None
    value = node.body[0].value
    # Compare class names so this works whether or not the running `_ast`
    # module defines `Str` / `Constant`.
    kind = value.__class__.__name__
    if kind == 'Str':
        # Older Pythons represent string literals as `Str` nodes.
        return value.s
    if kind == 'Constant' and isinstance(value.value, str):
        # Newer Pythons represent every literal as a `Constant` node.
        return value.value
    return None


def walk(node):
    """
    Iterate over all nodes, starting with `node` itself, in breadth-first
    order.  This is useful if you only want to modify nodes in place and
    don't care about the context or the order the nodes are returned.
    """
    from collections import deque
    todo = deque([node])
    while todo:
        node = todo.popleft()
        todo.extend(iter_child_nodes(node))
        yield node


class NodeVisitor(object):
    """
    Walks the abstract syntax tree and calls visitor functions for every node
    found.  The visitor functions may return values which will be forwarded
    by the `visit` method.

    By default the visitor functions for the nodes are ``'visit_'`` +
    class name of the node.  So a `TryFinally` node visit function would
    be `visit_TryFinally`.  This behavior can be changed by overriding
    the `get_visitor` method.  If no visitor function exists for a node
    (return value `None`) the `generic_visit` visitor is used instead.

    Don't use the `NodeVisitor` if you want to apply changes to nodes during
    traversing.  For this a special visitor exists (`NodeTransformer`) that
    allows modifications.
    """

    def get_visitor(self, node):
        """
        Return the visitor function for this node or `None` if no visitor
        exists for this node.  In that case the generic visit function is
        used instead.
        """
        method = 'visit_' + node.__class__.__name__
        return getattr(self, method, None)

    def visit(self, node):
        """Visit a node."""
        f = self.get_visitor(node)
        if f is not None:
            return f(node)
        return self.generic_visit(node)

    def generic_visit(self, node):
        """Called if no explicit visitor function exists for a node."""
        for field, value in iter_fields(node):
            if isinstance(value, list):
                for item in value:
                    if isinstance(item, AST):
                        self.visit(item)
            elif isinstance(value, AST):
                self.visit(value)


class NodeTransformer(NodeVisitor):
    """
    Walks the abstract syntax tree and allows modifications of nodes.

    The `NodeTransformer` will walk the AST and use the return value of the
    visitor functions to replace or remove the old node.  If the return
    value of the visitor function is `None` the node will be removed
    from the previous location otherwise it's replaced with the return
    value.  The return value may be the original node in which case no
    replacement takes place.

    Here an example transformer that rewrites all `foo` to `data['foo']`::

        class RewriteName(NodeTransformer):

            def visit_Name(self, node):
                return copy_location(Subscript(
                    value=Name(id='data', ctx=Load()),
                    slice=Index(value=Str(s=node.id)),
                    ctx=node.ctx
                ), node)

    Keep in mind that if the node you're operating on has child nodes
    you must either transform the child nodes yourself or call the generic
    visit function for the node first.

    Nodes that were part of a collection of statements (that applies to
    all statement nodes) may also return a list of nodes rather than just
    a single node.

    Usually you use the transformer like this::

        node = YourTransformer().visit(node)
    """

    def generic_visit(self, node):
        """
        Visit every child of `node`, replacing, removing or (inside list
        fields) splicing in whatever the visitors return.  Returns `node`.
        """
        for field, old_value in iter_fields(node):
            if isinstance(old_value, list):
                new_values = []
                for value in old_value:
                    if isinstance(value, AST):
                        value = self.visit(value)
                        if value is None:
                            continue
                        elif not isinstance(value, AST):
                            # A visitor returned a sequence of nodes:
                            # splice all of them into the list.
                            new_values.extend(value)
                            continue
                    new_values.append(value)
                # Mutate in place so the node keeps the same list object.
                old_value[:] = new_values
            elif isinstance(old_value, AST):
                new_node = self.visit(old_value)
                if new_node is None:
                    delattr(node, field)
                else:
                    setattr(node, field, new_node)
        return node


class SourceGenerator(NodeVisitor):
    """
    This visitor is able to transform a well formed syntax tree into python
    sourcecode.  For more details have a look at the docstring of the
    `to_source` function.

    Output fragments are collected in the `result` list; join them to get
    the generated source.
    """

    def __init__(self, indent_with):
        self.result = []
        self.indent_with = indent_with
        self.indentation = 0
        self.new_lines = 0

    def write(self, x):
        """Append `x`, first flushing any pending newlines + indentation."""
        if self.new_lines:
            if self.result:
                self.result.append('\n' * self.new_lines)
            self.result.append(self.indent_with * self.indentation)
            self.new_lines = 0
        self.result.append(x)

    def newline(self, n=1):
        """Request at least `n` line breaks before the next write."""
        self.new_lines = max(self.new_lines, n)

    def body(self, statements):
        """Write `statements` one indentation level deeper."""
        # Every statement visitor starts with its own `newline()` call, so no
        # line break needs to be requested here.
        self.indentation += 1
        for stmt in statements:
            self.visit(stmt)
        self.indentation -= 1

    def body_or_else(self, node):
        self.body(node.body)
        if node.orelse:
            self.newline()
            self.write('else:')
            self.body(node.orelse)

    def signature(self, node):
        """Write the parameter list described by an `arguments` node."""
        want_comma = []

        def write_comma():
            if want_comma:
                self.write(', ')
            else:
                want_comma.append(True)

        # Positional-only parameters (only present on newer Pythons) share
        # `defaults` with the regular positional parameters; the defaults
        # are right-aligned over both groups.
        posonlyargs = list(getattr(node, 'posonlyargs', None) or ())
        positional = posonlyargs + list(node.args)
        padding = [None] * (len(positional) - len(node.defaults))
        pairs = zip(positional, padding + list(node.defaults))
        for idx, (arg, default) in enumerate(pairs):
            write_comma()
            self.visit(arg)
            if default is not None:
                self.write('=')
                self.visit(default)
            if idx == len(posonlyargs) - 1:
                write_comma()
                self.write('/')

        kwonlyargs = list(getattr(node, 'kwonlyargs', None) or ())
        if node.vararg is not None:
            write_comma()
            self.write('*' + arg_stringname(node.vararg))
        elif kwonlyargs:
            # Without *args a bare "*" introduces keyword-only parameters.
            write_comma()
            self.write('*')
        for arg, default in zip(kwonlyargs, node.kw_defaults):
            write_comma()
            self.visit(arg)
            if default is not None:
                self.write('=')
                self.visit(default)
        if node.kwarg is not None:
            write_comma()
            self.write('**' + arg_stringname(node.kwarg))

    def decorators(self, node):
        for decorator in node.decorator_list:
            self.newline()
            self.write('@')
            self.visit(decorator)

    def _call_arguments(self, node, separator):
        """
        Write the keyword arguments and the legacy ``*args`` / ``**kwargs``
        of a `Call` or `ClassDef` node, calling `separator` before each item.

        Missing `keywords`, `starargs` and `kwargs` attributes (absent on
        some Python versions) are treated as empty.  A keyword whose `arg`
        is `None` is a ``**mapping`` expansion.
        """
        for keyword in getattr(node, 'keywords', None) or ():
            separator()
            if keyword.arg is None:
                self.write('**')
            else:
                self.write(keyword.arg + '=')
            self.visit(keyword.value)
        starargs = getattr(node, 'starargs', None)
        if starargs is not None:
            separator()
            self.write('*')
            self.visit(starargs)
        kwargs = getattr(node, 'kwargs', None)
        if kwargs is not None:
            separator()
            self.write('**')
            self.visit(kwargs)

    # Statements

    def visit_Assign(self, node):
        self.newline()
        # Several targets mean a chained assignment: "a = b = value".
        for target in node.targets:
            self.visit(target)
            self.write(' = ')
        self.visit(node.value)

    def visit_AugAssign(self, node):
        self.newline()
        self.visit(node.target)
        self.write(' %s= ' % BINOP_SYMBOLS[type(node.op)])
        self.visit(node.value)

    def visit_ImportFrom(self, node):
        self.newline()
        # `module` is None for "from . import x"; `level` may be unset.
        self.write('from %s%s import ' % ('.' * (node.level or 0),
                                          node.module or ''))
        for idx, item in enumerate(node.names):
            if idx:
                self.write(', ')
            self.visit(item)

    def visit_Import(self, node):
        self.newline()
        self.write('import ')
        for idx, item in enumerate(node.names):
            if idx:
                self.write(', ')
            self.visit(item)

    def visit_Expr(self, node):
        self.newline()
        self.generic_visit(node)

    def visit_FunctionDef(self, node):
        self.newline(n=2)
        self.decorators(node)
        self.newline()
        self.write('def %s(' % node.name)
        self.signature(node.args)
        self.write('):')
        self.body(node.body)

    def visit_ClassDef(self, node):
        have_args = []

        def paren_or_comma():
            if have_args:
                self.write(', ')
            else:
                have_args.append(True)
                self.write('(')

        self.newline(n=3)
        self.decorators(node)
        self.newline()
        self.write('class %s' % node.name)
        for base in node.bases:
            paren_or_comma()
            self.visit(base)
        self._call_arguments(node, paren_or_comma)
        self.write('):' if have_args else ':')
        self.body(node.body)

    def visit_If(self, node):
        self.newline()
        self.write('if ')
        self.visit(node.test)
        self.write(':')
        self.body(node.body)
        while True:
            else_ = node.orelse
            if len(else_) == 1 and isinstance(else_[0], If):
                # A lone nested If in the else branch is rendered as elif.
                node = else_[0]
                self.newline()
                self.write('elif ')
                self.visit(node.test)
                self.write(':')
                self.body(node.body)
            else:
                if else_:
                    self.newline()
                    self.write('else:')
                    self.body(else_)
                break

    def visit_For(self, node):
        self.newline()
        self.write('for ')
        self.visit(node.target)
        self.write(' in ')
        self.visit(node.iter)
        self.write(':')
        self.body_or_else(node)

    def visit_While(self, node):
        self.newline()
        self.write('while ')
        self.visit(node.test)
        self.write(':')
        self.body_or_else(node)

    def visit_With(self, node):
        self.newline()
        self.write('with ')
        # Newer Pythons keep a list of `withitem`s; older ones put a single
        # context_expr / optional_vars pair directly on the With node.
        items = getattr(node, 'items', None)
        if items is None:
            items = [node]
        for idx, item in enumerate(items):
            if idx:
                self.write(', ')
            self.visit(item.context_expr)
            if item.optional_vars is not None:
                self.write(' as ')
                self.visit(item.optional_vars)
        self.write(':')
        self.body(node.body)

    def visit_Pass(self, node):
        self.newline()
        self.write('pass')

    def visit_Print(self, node):
        # XXX: python 2 only
        self.newline()
        self.write('print')
        want_comma = False
        if node.dest is not None:
            self.write(' >> ')
            self.visit(node.dest)
            want_comma = True
        for value in node.values:
            self.write(', ' if want_comma else ' ')
            self.visit(value)
            want_comma = True
        if not node.nl:
            self.write(',')

    def visit_Delete(self, node):
        self.newline()
        self.write('del ')
        for idx, target in enumerate(node.targets):
            if idx:
                self.write(', ')
            self.visit(target)

    def visit_TryExcept(self, node):
        self.newline()
        self.write('try:')
        self.body(node.body)
        for handler in node.handlers:
            self.visit(handler)
        if node.orelse:
            self.newline()
            self.write('else:')
            self.body(node.orelse)

    def visit_TryFinally(self, node):
        self.newline()
        self.write('try:')
        self.body(node.body)
        self.newline()
        self.write('finally:')
        self.body(node.finalbody)

    def visit_Try(self, node):
        # Newer Pythons use one node for try/except/else/finally.
        self.newline()
        self.write('try:')
        self.body(node.body)
        for handler in node.handlers:
            self.visit(handler)
        if node.orelse:
            self.newline()
            self.write('else:')
            self.body(node.orelse)
        if node.finalbody:
            self.newline()
            self.write('finally:')
            self.body(node.finalbody)

    def visit_Global(self, node):
        self.newline()
        self.write('global ' + ', '.join(node.names))

    def visit_Nonlocal(self, node):
        self.newline()
        self.write('nonlocal ' + ', '.join(node.names))

    def visit_Return(self, node):
        self.newline()
        if node.value is None:
            self.write('return')
        else:
            self.write('return ')
            self.visit(node.value)

    def visit_Break(self, node):
        self.newline()
        self.write('break')

    def visit_Continue(self, node):
        self.newline()
        self.write('continue')

    def visit_Raise(self, node):
        # Python 3 nodes carry `exc` / `cause`; Python 2 nodes carry
        # `type` / `inst` / `tback`.
        self.newline()
        self.write('raise')
        if hasattr(node, 'exc') and node.exc is not None:
            self.write(' ')
            self.visit(node.exc)
            if node.cause is not None:
                self.write(' from ')
                self.visit(node.cause)
        elif hasattr(node, 'type') and node.type is not None:
            self.write(' ')
            self.visit(node.type)
            if node.inst is not None:
                self.write(', ')
                self.visit(node.inst)
            if node.tback is not None:
                self.write(', ')
                self.visit(node.tback)

    # Expressions

    def visit_Attribute(self, node):
        self.visit(node.value)
        self.write('.' + node.attr)

    def visit_Call(self, node):
        want_comma = []

        def write_comma():
            if want_comma:
                self.write(', ')
            else:
                want_comma.append(True)

        self.visit(node.func)
        self.write('(')
        for arg in node.args:
            write_comma()
            self.visit(arg)
        self._call_arguments(node, write_comma)
        self.write(')')

    def visit_Name(self, node):
        self.write(node.id)

    def visit_NameConstant(self, node):
        self.write(str(node.value))

    def visit_Constant(self, node):
        # Newer Pythons parse every literal (numbers, strings, bytes, True,
        # False, None, Ellipsis) into a single Constant node.
        self.write(repr(node.value))

    def visit_arg(self, node):
        self.write(node.arg)

    def visit_Str(self, node):
        self.write(repr(node.s))

    def visit_Bytes(self, node):
        self.write(repr(node.s))

    def visit_Num(self, node):
        self.write(repr(node.n))

    def visit_Tuple(self, node):
        self.write('(')
        idx = -1
        for idx, item in enumerate(node.elts):
            if idx:
                self.write(', ')
            self.visit(item)
        # idx stays -1 for "()" and is 0 for a one-element tuple, which
        # needs a trailing comma: "(x,)".
        self.write(')' if idx else ',)')

    def sequence_visit(left, right):
        def visit(self, node):
            self.write(left)
            for idx, item in enumerate(node.elts):
                if idx:
                    self.write(', ')
                self.visit(item)
            self.write(right)
        return visit

    visit_List = sequence_visit('[', ']')
    visit_Set = sequence_visit('{', '}')
    del sequence_visit

    def visit_Dict(self, node):
        self.write('{')
        for idx, (key, value) in enumerate(zip(node.keys, node.values)):
            if idx:
                self.write(', ')
            if key is None:
                # A missing key marks a "**mapping" expansion.
                self.write('**')
                self.visit(value)
            else:
                self.visit(key)
                self.write(': ')
                self.visit(value)
        self.write('}')

    def visit_BinOp(self, node):
        self.write('(')
        self.visit(node.left)
        self.write(' %s ' % BINOP_SYMBOLS[type(node.op)])
        self.visit(node.right)
        self.write(')')

    def visit_BoolOp(self, node):
        self.write('(')
        for idx, value in enumerate(node.values):
            if idx:
                self.write(' %s ' % BOOLOP_SYMBOLS[type(node.op)])
            self.visit(value)
        self.write(')')

    def visit_Compare(self, node):
        self.write('(')
        self.visit(node.left)
        for op, right in zip(node.ops, node.comparators):
            self.write(' %s ' % CMPOP_SYMBOLS[type(op)])
            self.visit(right)
        self.write(')')

    def visit_UnaryOp(self, node):
        self.write('(')
        op = UNARYOP_SYMBOLS[type(node.op)]
        self.write(op)
        if op == 'not':
            self.write(' ')
        self.visit(node.operand)
        self.write(')')

    def visit_Subscript(self, node):
        self.visit(node.value)
        self.write('[')
        self.visit(node.slice)
        self.write(']')

    def visit_Slice(self, node):
        if node.lower is not None:
            self.visit(node.lower)
        self.write(':')
        if node.upper is not None:
            self.visit(node.upper)
        if node.step is not None:
            self.write(':')
            # A step spelled as the name `None` is left out: "x[::]".
            if not (isinstance(node.step, Name) and node.step.id == 'None'):
                self.visit(node.step)

    def visit_ExtSlice(self, node):
        for idx, item in enumerate(node.dims):
            if idx:
                self.write(', ')
            self.visit(item)

    def visit_Yield(self, node):
        # Parenthesized so the yield stays valid inside any enclosing
        # expression, e.g. "(1 + (yield x))".
        self.write('(yield')
        if node.value is not None:
            self.write(' ')
            self.visit(node.value)
        self.write(')')

    def visit_Lambda(self, node):
        # Parenthesized: otherwise the lambda body would swallow whatever
        # is written after it, e.g. the call in "(lambda: x)()".
        self.write('(lambda ')
        self.signature(node.args)
        self.write(': ')
        self.visit(node.body)
        self.write(')')

    def visit_Ellipsis(self, node):
        self.write('Ellipsis')

    def generator_visit(left, right):
        def visit(self, node):
            self.write(left)
            self.visit(node.elt)
            for comprehension in node.generators:
                self.visit(comprehension)
            self.write(right)
        return visit

    visit_ListComp = generator_visit('[', ']')
    visit_GeneratorExp = generator_visit('(', ')')
    visit_SetComp = generator_visit('{', '}')
    del generator_visit

    def visit_DictComp(self, node):
        self.write('{')
        self.visit(node.key)
        self.write(': ')
        self.visit(node.value)
        for comprehension in node.generators:
            self.visit(comprehension)
        self.write('}')

    def visit_IfExp(self, node):
        # Parenthesized like BinOp: a conditional expression binds looser
        # than the operators that may surround it.
        self.write('(')
        self.visit(node.body)
        self.write(' if ')
        self.visit(node.test)
        self.write(' else ')
        self.visit(node.orelse)
        self.write(')')

    def visit_Starred(self, node):
        self.write('*')
        self.visit(node.value)

    def visit_Repr(self, node):
        # XXX: python 2 only
        self.write('`')
        self.visit(node.value)
        self.write('`')

    # Helper Nodes

    def visit_alias(self, node):
        self.write(node.name)
        if node.asname is not None:
            self.write(' as ' + node.asname)

    def visit_comprehension(self, node):
        self.write(' for ')
        self.visit(node.target)
        self.write(' in ')
        self.visit(node.iter)
        for if_ in node.ifs:
            self.write(' if ')
            self.visit(if_)

    def visit_excepthandler(self, node):
        self.newline()
        self.write('except')
        if node.type is not None:
            self.write(' ')
            self.visit(node.type)
            if node.name is not None:
                self.write(' as ')
                # The bound name is a Name node on Python 2, a str on 3.
                if isinstance(node.name, AST):
                    self.visit(node.name)
                else:
                    self.write(node.name)
        self.write(':')
        self.body(node.body)

    # `get_visitor` dispatches on the class name, which is `ExceptHandler`
    # in the `_ast` module; keep the historical lowercase name as well.
    visit_ExceptHandler = visit_excepthandler