1""" 2 ast 3 ~~~ 4 5 The `ast` module helps Python applications to process trees of the Python 6 abstract syntax grammar. The abstract syntax itself might change with 7 each Python release; this module helps to find out programmatically what 8 the current grammar looks like and allows modifications of it. 9 10 An abstract syntax tree can be generated by passing `ast.PyCF_ONLY_AST` as 11 a flag to the `compile()` builtin function or by using the `parse()` 12 function from this module. The result will be a tree of objects whose 13 classes all inherit from `ast.AST`. 14 15 A modified abstract syntax tree can be compiled into a Python code object 16 using the built-in `compile()` function. 17 18 Additionally various helper functions are provided that make working with 19 the trees simpler. The main intention of the helper functions and this 20 module in general is to provide an easy to use interface for libraries 21 that work tightly with the python syntax (template engines for example). 22 23 24 :copyright: Copyright 2008 by Armin Ronacher. 25 :license: Python License. 26""" 27import sys 28import re 29from _ast import * 30from contextlib import contextmanager, nullcontext 31from enum import IntEnum, auto, _simple_enum 32 33 34def parse(source, filename='<unknown>', mode='exec', *, 35 type_comments=False, feature_version=None, optimize=-1): 36 """ 37 Parse the source into an AST node. 38 Equivalent to compile(source, filename, mode, PyCF_ONLY_AST). 39 Pass type_comments=True to get back type comments where the syntax allows. 40 """ 41 flags = PyCF_ONLY_AST 42 if optimize > 0: 43 flags |= PyCF_OPTIMIZED_AST 44 if type_comments: 45 flags |= PyCF_TYPE_COMMENTS 46 if feature_version is None: 47 feature_version = -1 48 elif isinstance(feature_version, tuple): 49 major, minor = feature_version # Should be a 2-tuple. 50 if major != 3: 51 raise ValueError(f"Unsupported major version: {major}") 52 feature_version = minor 53 # Else it should be an int giving the minor version for 3.x. 54 return compile(source, filename, mode, flags, 55 _feature_version=feature_version, optimize=optimize) 56 57 58def literal_eval(node_or_string): 59 """ 60 Evaluate an expression node or a string containing only a Python 61 expression. The string or node provided may only consist of the following 62 Python literal structures: strings, bytes, numbers, tuples, lists, dicts, 63 sets, booleans, and None. 64 65 Caution: A complex expression can overflow the C stack and cause a crash. 66 """ 67 if isinstance(node_or_string, str): 68 node_or_string = parse(node_or_string.lstrip(" \t"), mode='eval') 69 if isinstance(node_or_string, Expression): 70 node_or_string = node_or_string.body 71 def _raise_malformed_node(node): 72 msg = "malformed node or string" 73 if lno := getattr(node, 'lineno', None): 74 msg += f' on line {lno}' 75 raise ValueError(msg + f': {node!r}') 76 def _convert_num(node): 77 if not isinstance(node, Constant) or type(node.value) not in (int, float, complex): 78 _raise_malformed_node(node) 79 return node.value 80 def _convert_signed_num(node): 81 if isinstance(node, UnaryOp) and isinstance(node.op, (UAdd, USub)): 82 operand = _convert_num(node.operand) 83 if isinstance(node.op, UAdd): 84 return + operand 85 else: 86 return - operand 87 return _convert_num(node) 88 def _convert(node): 89 if isinstance(node, Constant): 90 return node.value 91 elif isinstance(node, Tuple): 92 return tuple(map(_convert, node.elts)) 93 elif isinstance(node, List): 94 return list(map(_convert, node.elts)) 95 elif isinstance(node, Set): 96 return set(map(_convert, node.elts)) 97 elif (isinstance(node, Call) and isinstance(node.func, Name) and 98 node.func.id == 'set' and node.args == node.keywords == []): 99 return set() 100 elif isinstance(node, Dict): 101 if len(node.keys) != len(node.values): 102 _raise_malformed_node(node) 103 return dict(zip(map(_convert, node.keys), 104 map(_convert, node.values))) 105 elif isinstance(node, BinOp) and isinstance(node.op, (Add, Sub)): 106 left = _convert_signed_num(node.left) 107 right = _convert_num(node.right) 108 if isinstance(left, (int, float)) and isinstance(right, complex): 109 if isinstance(node.op, Add): 110 return left + right 111 else: 112 return left - right 113 return _convert_signed_num(node) 114 return _convert(node_or_string) 115 116 117def dump( 118 node, annotate_fields=True, include_attributes=False, 119 *, 120 indent=None, show_empty=False, 121): 122 """ 123 Return a formatted dump of the tree in node. This is mainly useful for 124 debugging purposes. If annotate_fields is true (by default), 125 the returned string will show the names and the values for fields. 126 If annotate_fields is false, the result string will be more compact by 127 omitting unambiguous field names. Attributes such as line 128 numbers and column offsets are not dumped by default. If this is wanted, 129 include_attributes can be set to true. If indent is a non-negative 130 integer or string, then the tree will be pretty-printed with that indent 131 level. None (the default) selects the single line representation. 132 If show_empty is False, then empty lists and fields that are None 133 will be omitted from the output for better readability. 134 """ 135 def _format(node, level=0): 136 if indent is not None: 137 level += 1 138 prefix = '\n' + indent * level 139 sep = ',\n' + indent * level 140 else: 141 prefix = '' 142 sep = ', ' 143 if isinstance(node, AST): 144 cls = type(node) 145 args = [] 146 args_buffer = [] 147 allsimple = True 148 keywords = annotate_fields 149 for name in node._fields: 150 try: 151 value = getattr(node, name) 152 except AttributeError: 153 keywords = True 154 continue 155 if value is None and getattr(cls, name, ...) is None: 156 keywords = True 157 continue 158 if ( 159 not show_empty 160 and (value is None or value == []) 161 # Special cases: 162 # `Constant(value=None)` and `MatchSingleton(value=None)` 163 and not isinstance(node, (Constant, MatchSingleton)) 164 ): 165 args_buffer.append(repr(value)) 166 continue 167 elif not keywords: 168 args.extend(args_buffer) 169 args_buffer = [] 170 value, simple = _format(value, level) 171 allsimple = allsimple and simple 172 if keywords: 173 args.append('%s=%s' % (name, value)) 174 else: 175 args.append(value) 176 if include_attributes and node._attributes: 177 for name in node._attributes: 178 try: 179 value = getattr(node, name) 180 except AttributeError: 181 continue 182 if value is None and getattr(cls, name, ...) is None: 183 continue 184 value, simple = _format(value, level) 185 allsimple = allsimple and simple 186 args.append('%s=%s' % (name, value)) 187 if allsimple and len(args) <= 3: 188 return '%s(%s)' % (node.__class__.__name__, ', '.join(args)), not args 189 return '%s(%s%s)' % (node.__class__.__name__, prefix, sep.join(args)), False 190 elif isinstance(node, list): 191 if not node: 192 return '[]', True 193 return '[%s%s]' % (prefix, sep.join(_format(x, level)[0] for x in node)), False 194 return repr(node), True 195 196 if not isinstance(node, AST): 197 raise TypeError('expected AST, got %r' % node.__class__.__name__) 198 if indent is not None and not isinstance(indent, str): 199 indent = ' ' * indent 200 return _format(node)[0] 201 202 203def copy_location(new_node, old_node): 204 """ 205 Copy source location (`lineno`, `col_offset`, `end_lineno`, and `end_col_offset` 206 attributes) from *old_node* to *new_node* if possible, and return *new_node*. 207 """ 208 for attr in 'lineno', 'col_offset', 'end_lineno', 'end_col_offset': 209 if attr in old_node._attributes and attr in new_node._attributes: 210 value = getattr(old_node, attr, None) 211 # end_lineno and end_col_offset are optional attributes, and they 212 # should be copied whether the value is None or not. 213 if value is not None or ( 214 hasattr(old_node, attr) and attr.startswith("end_") 215 ): 216 setattr(new_node, attr, value) 217 return new_node 218 219 220def fix_missing_locations(node): 221 """ 222 When you compile a node tree with compile(), the compiler expects lineno and 223 col_offset attributes for every node that supports them. This is rather 224 tedious to fill in for generated nodes, so this helper adds these attributes 225 recursively where not already set, by setting them to the values of the 226 parent node. It works recursively starting at *node*. 227 """ 228 def _fix(node, lineno, col_offset, end_lineno, end_col_offset): 229 if 'lineno' in node._attributes: 230 if not hasattr(node, 'lineno'): 231 node.lineno = lineno 232 else: 233 lineno = node.lineno 234 if 'end_lineno' in node._attributes: 235 if getattr(node, 'end_lineno', None) is None: 236 node.end_lineno = end_lineno 237 else: 238 end_lineno = node.end_lineno 239 if 'col_offset' in node._attributes: 240 if not hasattr(node, 'col_offset'): 241 node.col_offset = col_offset 242 else: 243 col_offset = node.col_offset 244 if 'end_col_offset' in node._attributes: 245 if getattr(node, 'end_col_offset', None) is None: 246 node.end_col_offset = end_col_offset 247 else: 248 end_col_offset = node.end_col_offset 249 for child in iter_child_nodes(node): 250 _fix(child, lineno, col_offset, end_lineno, end_col_offset) 251 _fix(node, 1, 0, 1, 0) 252 return node 253 254 255def increment_lineno(node, n=1): 256 """ 257 Increment the line number and end line number of each node in the tree 258 starting at *node* by *n*. This is useful to "move code" to a different 259 location in a file. 260 """ 261 for child in walk(node): 262 # TypeIgnore is a special case where lineno is not an attribute 263 # but rather a field of the node itself. 264 if isinstance(child, TypeIgnore): 265 child.lineno = getattr(child, 'lineno', 0) + n 266 continue 267 268 if 'lineno' in child._attributes: 269 child.lineno = getattr(child, 'lineno', 0) + n 270 if ( 271 "end_lineno" in child._attributes 272 and (end_lineno := getattr(child, "end_lineno", 0)) is not None 273 ): 274 child.end_lineno = end_lineno + n 275 return node 276 277 278def iter_fields(node): 279 """ 280 Yield a tuple of ``(fieldname, value)`` for each field in ``node._fields`` 281 that is present on *node*. 282 """ 283 for field in node._fields: 284 try: 285 yield field, getattr(node, field) 286 except AttributeError: 287 pass 288 289 290def iter_child_nodes(node): 291 """ 292 Yield all direct child nodes of *node*, that is, all fields that are nodes 293 and all items of fields that are lists of nodes. 294 """ 295 for name, field in iter_fields(node): 296 if isinstance(field, AST): 297 yield field 298 elif isinstance(field, list): 299 for item in field: 300 if isinstance(item, AST): 301 yield item 302 303 304def get_docstring(node, clean=True): 305 """ 306 Return the docstring for the given node or None if no docstring can 307 be found. If the node provided does not have docstrings a TypeError 308 will be raised. 309 310 If *clean* is `True`, all tabs are expanded to spaces and any whitespace 311 that can be uniformly removed from the second line onwards is removed. 312 """ 313 if not isinstance(node, (AsyncFunctionDef, FunctionDef, ClassDef, Module)): 314 raise TypeError("%r can't have docstrings" % node.__class__.__name__) 315 if not(node.body and isinstance(node.body[0], Expr)): 316 return None 317 node = node.body[0].value 318 if isinstance(node, Constant) and isinstance(node.value, str): 319 text = node.value 320 else: 321 return None 322 if clean: 323 import inspect 324 text = inspect.cleandoc(text) 325 return text 326 327 328_line_pattern = re.compile(r"(.*?(?:\r\n|\n|\r|$))") 329def _splitlines_no_ff(source, maxlines=None): 330 """Split a string into lines ignoring form feed and other chars. 331 332 This mimics how the Python parser splits source code. 333 """ 334 lines = [] 335 for lineno, match in enumerate(_line_pattern.finditer(source), 1): 336 if maxlines is not None and lineno > maxlines: 337 break 338 lines.append(match[0]) 339 return lines 340 341 342def _pad_whitespace(source): 343 r"""Replace all chars except '\f\t' in a line with spaces.""" 344 result = '' 345 for c in source: 346 if c in '\f\t': 347 result += c 348 else: 349 result += ' ' 350 return result 351 352 353def get_source_segment(source, node, *, padded=False): 354 """Get source code segment of the *source* that generated *node*. 355 356 If some location information (`lineno`, `end_lineno`, `col_offset`, 357 or `end_col_offset`) is missing, return None. 358 359 If *padded* is `True`, the first line of a multi-line statement will 360 be padded with spaces to match its original position. 361 """ 362 try: 363 if node.end_lineno is None or node.end_col_offset is None: 364 return None 365 lineno = node.lineno - 1 366 end_lineno = node.end_lineno - 1 367 col_offset = node.col_offset 368 end_col_offset = node.end_col_offset 369 except AttributeError: 370 return None 371 372 lines = _splitlines_no_ff(source, maxlines=end_lineno+1) 373 if end_lineno == lineno: 374 return lines[lineno].encode()[col_offset:end_col_offset].decode() 375 376 if padded: 377 padding = _pad_whitespace(lines[lineno].encode()[:col_offset].decode()) 378 else: 379 padding = '' 380 381 first = padding + lines[lineno].encode()[col_offset:].decode() 382 last = lines[end_lineno].encode()[:end_col_offset].decode() 383 lines = lines[lineno+1:end_lineno] 384 385 lines.insert(0, first) 386 lines.append(last) 387 return ''.join(lines) 388 389 390def walk(node): 391 """ 392 Recursively yield all descendant nodes in the tree starting at *node* 393 (including *node* itself), in no specified order. This is useful if you 394 only want to modify nodes in place and don't care about the context. 395 """ 396 from collections import deque 397 todo = deque([node]) 398 while todo: 399 node = todo.popleft() 400 todo.extend(iter_child_nodes(node)) 401 yield node 402 403 404class NodeVisitor(object): 405 """ 406 A node visitor base class that walks the abstract syntax tree and calls a 407 visitor function for every node found. This function may return a value 408 which is forwarded by the `visit` method. 409 410 This class is meant to be subclassed, with the subclass adding visitor 411 methods. 412 413 Per default the visitor functions for the nodes are ``'visit_'`` + 414 class name of the node. So a `TryFinally` node visit function would 415 be `visit_TryFinally`. This behavior can be changed by overriding 416 the `visit` method. If no visitor function exists for a node 417 (return value `None`) the `generic_visit` visitor is used instead. 418 419 Don't use the `NodeVisitor` if you want to apply changes to nodes during 420 traversing. For this a special visitor exists (`NodeTransformer`) that 421 allows modifications. 422 """ 423 424 def visit(self, node): 425 """Visit a node.""" 426 method = 'visit_' + node.__class__.__name__ 427 visitor = getattr(self, method, self.generic_visit) 428 return visitor(node) 429 430 def generic_visit(self, node): 431 """Called if no explicit visitor function exists for a node.""" 432 for field, value in iter_fields(node): 433 if isinstance(value, list): 434 for item in value: 435 if isinstance(item, AST): 436 self.visit(item) 437 elif isinstance(value, AST): 438 self.visit(value) 439 440 def visit_Constant(self, node): 441 value = node.value 442 type_name = _const_node_type_names.get(type(value)) 443 if type_name is None: 444 for cls, name in _const_node_type_names.items(): 445 if isinstance(value, cls): 446 type_name = name 447 break 448 if type_name is not None: 449 method = 'visit_' + type_name 450 try: 451 visitor = getattr(self, method) 452 except AttributeError: 453 pass 454 else: 455 import warnings 456 warnings.warn(f"{method} is deprecated; add visit_Constant", 457 DeprecationWarning, 2) 458 return visitor(node) 459 return self.generic_visit(node) 460 461 462class NodeTransformer(NodeVisitor): 463 """ 464 A :class:`NodeVisitor` subclass that walks the abstract syntax tree and 465 allows modification of nodes. 466 467 The `NodeTransformer` will walk the AST and use the return value of the 468 visitor methods to replace or remove the old node. If the return value of 469 the visitor method is ``None``, the node will be removed from its location, 470 otherwise it is replaced with the return value. The return value may be the 471 original node in which case no replacement takes place. 472 473 Here is an example transformer that rewrites all occurrences of name lookups 474 (``foo``) to ``data['foo']``:: 475 476 class RewriteName(NodeTransformer): 477 478 def visit_Name(self, node): 479 return Subscript( 480 value=Name(id='data', ctx=Load()), 481 slice=Constant(value=node.id), 482 ctx=node.ctx 483 ) 484 485 Keep in mind that if the node you're operating on has child nodes you must 486 either transform the child nodes yourself or call the :meth:`generic_visit` 487 method for the node first. 488 489 For nodes that were part of a collection of statements (that applies to all 490 statement nodes), the visitor may also return a list of nodes rather than 491 just a single node. 492 493 Usually you use the transformer like this:: 494 495 node = YourTransformer().visit(node) 496 """ 497 498 def generic_visit(self, node): 499 for field, old_value in iter_fields(node): 500 if isinstance(old_value, list): 501 new_values = [] 502 for value in old_value: 503 if isinstance(value, AST): 504 value = self.visit(value) 505 if value is None: 506 continue 507 elif not isinstance(value, AST): 508 new_values.extend(value) 509 continue 510 new_values.append(value) 511 old_value[:] = new_values 512 elif isinstance(old_value, AST): 513 new_node = self.visit(old_value) 514 if new_node is None: 515 delattr(node, field) 516 else: 517 setattr(node, field, new_node) 518 return node 519 520 521_DEPRECATED_VALUE_ALIAS_MESSAGE = ( 522 "{name} is deprecated and will be removed in Python {remove}; use value instead" 523) 524_DEPRECATED_CLASS_MESSAGE = ( 525 "{name} is deprecated and will be removed in Python {remove}; " 526 "use ast.Constant instead" 527) 528 529 530# If the ast module is loaded more than once, only add deprecated methods once 531if not hasattr(Constant, 'n'): 532 # The following code is for backward compatibility. 533 # It will be removed in future. 534 535 def _n_getter(self): 536 """Deprecated. Use value instead.""" 537 import warnings 538 warnings._deprecated( 539 "Attribute n", message=_DEPRECATED_VALUE_ALIAS_MESSAGE, remove=(3, 14) 540 ) 541 return self.value 542 543 def _n_setter(self, value): 544 import warnings 545 warnings._deprecated( 546 "Attribute n", message=_DEPRECATED_VALUE_ALIAS_MESSAGE, remove=(3, 14) 547 ) 548 self.value = value 549 550 def _s_getter(self): 551 """Deprecated. Use value instead.""" 552 import warnings 553 warnings._deprecated( 554 "Attribute s", message=_DEPRECATED_VALUE_ALIAS_MESSAGE, remove=(3, 14) 555 ) 556 return self.value 557 558 def _s_setter(self, value): 559 import warnings 560 warnings._deprecated( 561 "Attribute s", message=_DEPRECATED_VALUE_ALIAS_MESSAGE, remove=(3, 14) 562 ) 563 self.value = value 564 565 Constant.n = property(_n_getter, _n_setter) 566 Constant.s = property(_s_getter, _s_setter) 567 568class _ABC(type): 569 570 def __init__(cls, *args): 571 cls.__doc__ = """Deprecated AST node class. Use ast.Constant instead""" 572 573 def __instancecheck__(cls, inst): 574 if cls in _const_types: 575 import warnings 576 warnings._deprecated( 577 f"ast.{cls.__qualname__}", 578 message=_DEPRECATED_CLASS_MESSAGE, 579 remove=(3, 14) 580 ) 581 if not isinstance(inst, Constant): 582 return False 583 if cls in _const_types: 584 try: 585 value = inst.value 586 except AttributeError: 587 return False 588 else: 589 return ( 590 isinstance(value, _const_types[cls]) and 591 not isinstance(value, _const_types_not.get(cls, ())) 592 ) 593 return type.__instancecheck__(cls, inst) 594 595def _new(cls, *args, **kwargs): 596 for key in kwargs: 597 if key not in cls._fields: 598 # arbitrary keyword arguments are accepted 599 continue 600 pos = cls._fields.index(key) 601 if pos < len(args): 602 raise TypeError(f"{cls.__name__} got multiple values for argument {key!r}") 603 if cls in _const_types: 604 import warnings 605 warnings._deprecated( 606 f"ast.{cls.__qualname__}", message=_DEPRECATED_CLASS_MESSAGE, remove=(3, 14) 607 ) 608 return Constant(*args, **kwargs) 609 return Constant.__new__(cls, *args, **kwargs) 610 611class Num(Constant, metaclass=_ABC): 612 _fields = ('n',) 613 __new__ = _new 614 615class Str(Constant, metaclass=_ABC): 616 _fields = ('s',) 617 __new__ = _new 618 619class Bytes(Constant, metaclass=_ABC): 620 _fields = ('s',) 621 __new__ = _new 622 623class NameConstant(Constant, metaclass=_ABC): 624 __new__ = _new 625 626class Ellipsis(Constant, metaclass=_ABC): 627 _fields = () 628 629 def __new__(cls, *args, **kwargs): 630 if cls is _ast_Ellipsis: 631 import warnings 632 warnings._deprecated( 633 "ast.Ellipsis", message=_DEPRECATED_CLASS_MESSAGE, remove=(3, 14) 634 ) 635 return Constant(..., *args, **kwargs) 636 return Constant.__new__(cls, *args, **kwargs) 637 638# Keep another reference to Ellipsis in the global namespace 639# so it can be referenced in Ellipsis.__new__ 640# (The original "Ellipsis" name is removed from the global namespace later on) 641_ast_Ellipsis = Ellipsis 642 643_const_types = { 644 Num: (int, float, complex), 645 Str: (str,), 646 Bytes: (bytes,), 647 NameConstant: (type(None), bool), 648 Ellipsis: (type(...),), 649} 650_const_types_not = { 651 Num: (bool,), 652} 653 654_const_node_type_names = { 655 bool: 'NameConstant', # should be before int 656 type(None): 'NameConstant', 657 int: 'Num', 658 float: 'Num', 659 complex: 'Num', 660 str: 'Str', 661 bytes: 'Bytes', 662 type(...): 'Ellipsis', 663} 664 665class slice(AST): 666 """Deprecated AST node class.""" 667 668class Index(slice): 669 """Deprecated AST node class. Use the index value directly instead.""" 670 def __new__(cls, value, **kwargs): 671 return value 672 673class ExtSlice(slice): 674 """Deprecated AST node class. Use ast.Tuple instead.""" 675 def __new__(cls, dims=(), **kwargs): 676 return Tuple(list(dims), Load(), **kwargs) 677 678# If the ast module is loaded more than once, only add deprecated methods once 679if not hasattr(Tuple, 'dims'): 680 # The following code is for backward compatibility. 681 # It will be removed in future. 682 683 def _dims_getter(self): 684 """Deprecated. Use elts instead.""" 685 return self.elts 686 687 def _dims_setter(self, value): 688 self.elts = value 689 690 Tuple.dims = property(_dims_getter, _dims_setter) 691 692class Suite(mod): 693 """Deprecated AST node class. Unused in Python 3.""" 694 695class AugLoad(expr_context): 696 """Deprecated AST node class. Unused in Python 3.""" 697 698class AugStore(expr_context): 699 """Deprecated AST node class. Unused in Python 3.""" 700 701class Param(expr_context): 702 """Deprecated AST node class. Unused in Python 3.""" 703 704 705# Large float and imaginary literals get turned into infinities in the AST. 706# We unparse those infinities to INFSTR. 707_INFSTR = "1e" + repr(sys.float_info.max_10_exp + 1) 708 709@_simple_enum(IntEnum) 710class _Precedence: 711 """Precedence table that originated from python grammar.""" 712 713 NAMED_EXPR = auto() # <target> := <expr1> 714 TUPLE = auto() # <expr1>, <expr2> 715 YIELD = auto() # 'yield', 'yield from' 716 TEST = auto() # 'if'-'else', 'lambda' 717 OR = auto() # 'or' 718 AND = auto() # 'and' 719 NOT = auto() # 'not' 720 CMP = auto() # '<', '>', '==', '>=', '<=', '!=', 721 # 'in', 'not in', 'is', 'is not' 722 EXPR = auto() 723 BOR = EXPR # '|' 724 BXOR = auto() # '^' 725 BAND = auto() # '&' 726 SHIFT = auto() # '<<', '>>' 727 ARITH = auto() # '+', '-' 728 TERM = auto() # '*', '@', '/', '%', '//' 729 FACTOR = auto() # unary '+', '-', '~' 730 POWER = auto() # '**' 731 AWAIT = auto() # 'await' 732 ATOM = auto() 733 734 def next(self): 735 try: 736 return self.__class__(self + 1) 737 except ValueError: 738 return self 739 740 741_SINGLE_QUOTES = ("'", '"') 742_MULTI_QUOTES = ('"""', "'''") 743_ALL_QUOTES = (*_SINGLE_QUOTES, *_MULTI_QUOTES) 744 745class _Unparser(NodeVisitor): 746 """Methods in this class recursively traverse an AST and 747 output source code for the abstract syntax; original formatting 748 is disregarded.""" 749 750 def __init__(self): 751 self._source = [] 752 self._precedences = {} 753 self._type_ignores = {} 754 self._indent = 0 755 self._in_try_star = False 756 757 def interleave(self, inter, f, seq): 758 """Call f on each item in seq, calling inter() in between.""" 759 seq = iter(seq) 760 try: 761 f(next(seq)) 762 except StopIteration: 763 pass 764 else: 765 for x in seq: 766 inter() 767 f(x) 768 769 def items_view(self, traverser, items): 770 """Traverse and separate the given *items* with a comma and append it to 771 the buffer. If *items* is a single item sequence, a trailing comma 772 will be added.""" 773 if len(items) == 1: 774 traverser(items[0]) 775 self.write(",") 776 else: 777 self.interleave(lambda: self.write(", "), traverser, items) 778 779 def maybe_newline(self): 780 """Adds a newline if it isn't the start of generated source""" 781 if self._source: 782 self.write("\n") 783 784 def fill(self, text=""): 785 """Indent a piece of text and append it, according to the current 786 indentation level""" 787 self.maybe_newline() 788 self.write(" " * self._indent + text) 789 790 def write(self, *text): 791 """Add new source parts""" 792 self._source.extend(text) 793 794 @contextmanager 795 def buffered(self, buffer = None): 796 if buffer is None: 797 buffer = [] 798 799 original_source = self._source 800 self._source = buffer 801 yield buffer 802 self._source = original_source 803 804 @contextmanager 805 def block(self, *, extra = None): 806 """A context manager for preparing the source for blocks. It adds 807 the character':', increases the indentation on enter and decreases 808 the indentation on exit. If *extra* is given, it will be directly 809 appended after the colon character. 810 """ 811 self.write(":") 812 if extra: 813 self.write(extra) 814 self._indent += 1 815 yield 816 self._indent -= 1 817 818 @contextmanager 819 def delimit(self, start, end): 820 """A context manager for preparing the source for expressions. It adds 821 *start* to the buffer and enters, after exit it adds *end*.""" 822 823 self.write(start) 824 yield 825 self.write(end) 826 827 def delimit_if(self, start, end, condition): 828 if condition: 829 return self.delimit(start, end) 830 else: 831 return nullcontext() 832 833 def require_parens(self, precedence, node): 834 """Shortcut to adding precedence related parens""" 835 return self.delimit_if("(", ")", self.get_precedence(node) > precedence) 836 837 def get_precedence(self, node): 838 return self._precedences.get(node, _Precedence.TEST) 839 840 def set_precedence(self, precedence, *nodes): 841 for node in nodes: 842 self._precedences[node] = precedence 843 844 def get_raw_docstring(self, node): 845 """If a docstring node is found in the body of the *node* parameter, 846 return that docstring node, None otherwise. 847 848 Logic mirrored from ``_PyAST_GetDocString``.""" 849 if not isinstance( 850 node, (AsyncFunctionDef, FunctionDef, ClassDef, Module) 851 ) or len(node.body) < 1: 852 return None 853 node = node.body[0] 854 if not isinstance(node, Expr): 855 return None 856 node = node.value 857 if isinstance(node, Constant) and isinstance(node.value, str): 858 return node 859 860 def get_type_comment(self, node): 861 comment = self._type_ignores.get(node.lineno) or node.type_comment 862 if comment is not None: 863 return f" # type: {comment}" 864 865 def traverse(self, node): 866 if isinstance(node, list): 867 for item in node: 868 self.traverse(item) 869 else: 870 super().visit(node) 871 872 # Note: as visit() resets the output text, do NOT rely on 873 # NodeVisitor.generic_visit to handle any nodes (as it calls back in to 874 # the subclass visit() method, which resets self._source to an empty list) 875 def visit(self, node): 876 """Outputs a source code string that, if converted back to an ast 877 (using ast.parse) will generate an AST equivalent to *node*""" 878 self._source = [] 879 self.traverse(node) 880 return "".join(self._source) 881 882 def _write_docstring_and_traverse_body(self, node): 883 if (docstring := self.get_raw_docstring(node)): 884 self._write_docstring(docstring) 885 self.traverse(node.body[1:]) 886 else: 887 self.traverse(node.body) 888 889 def visit_Module(self, node): 890 self._type_ignores = { 891 ignore.lineno: f"ignore{ignore.tag}" 892 for ignore in node.type_ignores 893 } 894 self._write_docstring_and_traverse_body(node) 895 self._type_ignores.clear() 896 897 def visit_FunctionType(self, node): 898 with self.delimit("(", ")"): 899 self.interleave( 900 lambda: self.write(", "), self.traverse, node.argtypes 901 ) 902 903 self.write(" -> ") 904 self.traverse(node.returns) 905 906 def visit_Expr(self, node): 907 self.fill() 908 self.set_precedence(_Precedence.YIELD, node.value) 909 self.traverse(node.value) 910 911 def visit_NamedExpr(self, node): 912 with self.require_parens(_Precedence.NAMED_EXPR, node): 913 self.set_precedence(_Precedence.ATOM, node.target, node.value) 914 self.traverse(node.target) 915 self.write(" := ") 916 self.traverse(node.value) 917 918 def visit_Import(self, node): 919 self.fill("import ") 920 self.interleave(lambda: self.write(", "), self.traverse, node.names) 921 922 def visit_ImportFrom(self, node): 923 self.fill("from ") 924 self.write("." * (node.level or 0)) 925 if node.module: 926 self.write(node.module) 927 self.write(" import ") 928 self.interleave(lambda: self.write(", "), self.traverse, node.names) 929 930 def visit_Assign(self, node): 931 self.fill() 932 for target in node.targets: 933 self.set_precedence(_Precedence.TUPLE, target) 934 self.traverse(target) 935 self.write(" = ") 936 self.traverse(node.value) 937 if type_comment := self.get_type_comment(node): 938 self.write(type_comment) 939 940 def visit_AugAssign(self, node): 941 self.fill() 942 self.traverse(node.target) 943 self.write(" " + self.binop[node.op.__class__.__name__] + "= ") 944 self.traverse(node.value) 945 946 def visit_AnnAssign(self, node): 947 self.fill() 948 with self.delimit_if("(", ")", not node.simple and isinstance(node.target, Name)): 949 self.traverse(node.target) 950 self.write(": ") 951 self.traverse(node.annotation) 952 if node.value: 953 self.write(" = ") 954 self.traverse(node.value) 955 956 def visit_Return(self, node): 957 self.fill("return") 958 if node.value: 959 self.write(" ") 960 self.traverse(node.value) 961 962 def visit_Pass(self, node): 963 self.fill("pass") 964 965 def visit_Break(self, node): 966 self.fill("break") 967 968 def visit_Continue(self, node): 969 self.fill("continue") 970 971 def visit_Delete(self, node): 972 self.fill("del ") 973 self.interleave(lambda: self.write(", "), self.traverse, node.targets) 974 975 def visit_Assert(self, node): 976 self.fill("assert ") 977 self.traverse(node.test) 978 if node.msg: 979 self.write(", ") 980 self.traverse(node.msg) 981 982 def visit_Global(self, node): 983 self.fill("global ") 984 self.interleave(lambda: self.write(", "), self.write, node.names) 985 986 def visit_Nonlocal(self, node): 987 self.fill("nonlocal ") 988 self.interleave(lambda: self.write(", "), self.write, node.names) 989 990 def visit_Await(self, node): 991 with self.require_parens(_Precedence.AWAIT, node): 992 self.write("await") 993 if node.value: 994 self.write(" ") 995 self.set_precedence(_Precedence.ATOM, node.value) 996 self.traverse(node.value) 997 998 def visit_Yield(self, node): 999 with self.require_parens(_Precedence.YIELD, node): 1000 self.write("yield") 1001 if node.value: 1002 self.write(" ") 1003 self.set_precedence(_Precedence.ATOM, node.value) 1004 self.traverse(node.value) 1005 1006 def visit_YieldFrom(self, node): 1007 with self.require_parens(_Precedence.YIELD, node): 1008 self.write("yield from ") 1009 if not node.value: 1010 raise ValueError("Node can't be used without a value attribute.") 1011 self.set_precedence(_Precedence.ATOM, node.value) 1012 self.traverse(node.value) 1013 1014 def visit_Raise(self, node): 1015 self.fill("raise") 1016 if not node.exc: 1017 if node.cause: 1018 raise ValueError(f"Node can't use cause without an exception.") 1019 return 1020 self.write(" ") 1021 self.traverse(node.exc) 1022 if node.cause: 1023 self.write(" from ") 1024 self.traverse(node.cause) 1025 1026 def do_visit_try(self, node): 1027 self.fill("try") 1028 with self.block(): 1029 self.traverse(node.body) 1030 for ex in node.handlers: 1031 self.traverse(ex) 1032 if node.orelse: 1033 self.fill("else") 1034 with self.block(): 1035 self.traverse(node.orelse) 1036 if node.finalbody: 1037 self.fill("finally") 1038 with self.block(): 1039 self.traverse(node.finalbody) 1040 1041 def visit_Try(self, node): 1042 prev_in_try_star = self._in_try_star 1043 try: 1044 self._in_try_star = False 1045 self.do_visit_try(node) 1046 finally: 1047 self._in_try_star = prev_in_try_star 1048 1049 def visit_TryStar(self, node): 1050 prev_in_try_star = self._in_try_star 1051 try: 1052 self._in_try_star = True 1053 self.do_visit_try(node) 1054 finally: 1055 self._in_try_star = prev_in_try_star 1056 1057 def visit_ExceptHandler(self, node): 1058 self.fill("except*" if self._in_try_star else "except") 1059 if node.type: 1060 self.write(" ") 1061 self.traverse(node.type) 1062 if node.name: 1063 self.write(" as ") 1064 self.write(node.name) 1065 with self.block(): 1066 self.traverse(node.body) 1067 1068 def visit_ClassDef(self, node): 1069 self.maybe_newline() 1070 for deco in node.decorator_list: 1071 self.fill("@") 1072 self.traverse(deco) 1073 self.fill("class " + node.name) 1074 if hasattr(node, "type_params"): 1075 self._type_params_helper(node.type_params) 1076 with self.delimit_if("(", ")", condition = node.bases or node.keywords): 1077 comma = False 1078 for e in node.bases: 1079 if comma: 1080 self.write(", ") 1081 else: 1082 comma = True 1083 self.traverse(e) 1084 for e in node.keywords: 1085 if comma: 1086 self.write(", ") 1087 else: 1088 comma = True 1089 self.traverse(e) 1090 1091 with self.block(): 1092 self._write_docstring_and_traverse_body(node) 1093 1094 def visit_FunctionDef(self, node): 1095 self._function_helper(node, "def") 1096 1097 def visit_AsyncFunctionDef(self, node): 1098 self._function_helper(node, "async def") 1099 1100 def _function_helper(self, node, fill_suffix): 1101 self.maybe_newline() 1102 for deco in node.decorator_list: 1103 self.fill("@") 1104 self.traverse(deco) 1105 def_str = fill_suffix + " " + node.name 1106 self.fill(def_str) 1107 if hasattr(node, "type_params"): 1108 self._type_params_helper(node.type_params) 1109 with self.delimit("(", ")"): 1110 self.traverse(node.args) 1111 if node.returns: 1112 self.write(" -> ") 1113 self.traverse(node.returns) 1114 with self.block(extra=self.get_type_comment(node)): 1115 self._write_docstring_and_traverse_body(node) 1116 1117 def _type_params_helper(self, type_params): 1118 if type_params is not None and len(type_params) > 0: 1119 with self.delimit("[", "]"): 1120 self.interleave(lambda: self.write(", "), self.traverse, type_params) 1121 1122 def visit_TypeVar(self, node): 1123 self.write(node.name) 1124 if node.bound: 1125 self.write(": ") 1126 self.traverse(node.bound) 1127 if node.default_value: 1128 self.write(" = ") 1129 self.traverse(node.default_value) 1130 1131 def visit_TypeVarTuple(self, node): 1132 self.write("*" + node.name) 1133 if node.default_value: 1134 self.write(" = ") 1135 self.traverse(node.default_value) 1136 1137 def visit_ParamSpec(self, node): 1138 self.write("**" + node.name) 1139 if node.default_value: 1140 self.write(" = ") 1141 self.traverse(node.default_value) 1142 1143 def visit_TypeAlias(self, node): 1144 self.fill("type ") 1145 self.traverse(node.name) 1146 self._type_params_helper(node.type_params) 1147 self.write(" = ") 1148 self.traverse(node.value) 1149 1150 def visit_For(self, node): 1151 self._for_helper("for ", node) 1152 1153 def visit_AsyncFor(self, node): 1154 self._for_helper("async for ", node) 1155 1156 def _for_helper(self, fill, node): 1157 self.fill(fill) 1158 self.set_precedence(_Precedence.TUPLE, node.target) 1159 self.traverse(node.target) 1160 self.write(" in ") 1161 self.traverse(node.iter) 1162 with self.block(extra=self.get_type_comment(node)): 1163 self.traverse(node.body) 1164 if node.orelse: 1165 self.fill("else") 1166 with self.block(): 1167 self.traverse(node.orelse) 1168 1169 def visit_If(self, node): 1170 self.fill("if ") 1171 self.traverse(node.test) 1172 with self.block(): 1173 self.traverse(node.body) 1174 # collapse nested ifs into equivalent elifs. 1175 while node.orelse and len(node.orelse) == 1 and isinstance(node.orelse[0], If): 1176 node = node.orelse[0] 1177 self.fill("elif ") 1178 self.traverse(node.test) 1179 with self.block(): 1180 self.traverse(node.body) 1181 # final else 1182 if node.orelse: 1183 self.fill("else") 1184 with self.block(): 1185 self.traverse(node.orelse) 1186 1187 def visit_While(self, node): 1188 self.fill("while ") 1189 self.traverse(node.test) 1190 with self.block(): 1191 self.traverse(node.body) 1192 if node.orelse: 1193 self.fill("else") 1194 with self.block(): 1195 self.traverse(node.orelse) 1196 1197 def visit_With(self, node): 1198 self.fill("with ") 1199 self.interleave(lambda: self.write(", "), self.traverse, node.items) 1200 with self.block(extra=self.get_type_comment(node)): 1201 self.traverse(node.body) 1202 1203 def visit_AsyncWith(self, node): 1204 self.fill("async with ") 1205 self.interleave(lambda: self.write(", "), self.traverse, node.items) 1206 with self.block(extra=self.get_type_comment(node)): 1207 self.traverse(node.body) 1208 1209 def _str_literal_helper( 1210 self, string, *, quote_types=_ALL_QUOTES, escape_special_whitespace=False 1211 ): 1212 """Helper for writing string literals, minimizing escapes. 1213 Returns the tuple (string literal to write, possible quote types). 1214 """ 1215 def escape_char(c): 1216 # \n and \t are non-printable, but we only escape them if 1217 # escape_special_whitespace is True 1218 if not escape_special_whitespace and c in "\n\t": 1219 return c 1220 # Always escape backslashes and other non-printable characters 1221 if c == "\\" or not c.isprintable(): 1222 return c.encode("unicode_escape").decode("ascii") 1223 return c 1224 1225 escaped_string = "".join(map(escape_char, string)) 1226 possible_quotes = quote_types 1227 if "\n" in escaped_string: 1228 possible_quotes = [q for q in possible_quotes if q in _MULTI_QUOTES] 1229 possible_quotes = [q for q in possible_quotes if q not in escaped_string] 1230 if not possible_quotes: 1231 # If there aren't any possible_quotes, fallback to using repr 1232 # on the original string. Try to use a quote from quote_types, 1233 # e.g., so that we use triple quotes for docstrings. 1234 string = repr(string) 1235 quote = next((q for q in quote_types if string[0] in q), string[0]) 1236 return string[1:-1], [quote] 1237 if escaped_string: 1238 # Sort so that we prefer '''"''' over """\"""" 1239 possible_quotes.sort(key=lambda q: q[0] == escaped_string[-1]) 1240 # If we're using triple quotes and we'd need to escape a final 1241 # quote, escape it 1242 if possible_quotes[0][0] == escaped_string[-1]: 1243 assert len(possible_quotes[0]) == 3 1244 escaped_string = escaped_string[:-1] + "\\" + escaped_string[-1] 1245 return escaped_string, possible_quotes 1246 1247 def _write_str_avoiding_backslashes(self, string, *, quote_types=_ALL_QUOTES): 1248 """Write string literal value with a best effort attempt to avoid backslashes.""" 1249 string, quote_types = self._str_literal_helper(string, quote_types=quote_types) 1250 quote_type = quote_types[0] 1251 self.write(f"{quote_type}{string}{quote_type}") 1252 1253 def visit_JoinedStr(self, node): 1254 self.write("f") 1255 1256 fstring_parts = [] 1257 for value in node.values: 1258 with self.buffered() as buffer: 1259 self._write_fstring_inner(value) 1260 fstring_parts.append( 1261 ("".join(buffer), isinstance(value, Constant)) 1262 ) 1263 1264 new_fstring_parts = [] 1265 quote_types = list(_ALL_QUOTES) 1266 fallback_to_repr = False 1267 for value, is_constant in fstring_parts: 1268 if is_constant: 1269 value, new_quote_types = self._str_literal_helper( 1270 value, 1271 quote_types=quote_types, 1272 escape_special_whitespace=True, 1273 ) 1274 if set(new_quote_types).isdisjoint(quote_types): 1275 fallback_to_repr = True 1276 break 1277 quote_types = new_quote_types 1278 elif "\n" in value: 1279 quote_types = [q for q in quote_types if q in _MULTI_QUOTES] 1280 assert quote_types 1281 new_fstring_parts.append(value) 1282 1283 if fallback_to_repr: 1284 # If we weren't able to find a quote type that works for all parts 1285 # of the JoinedStr, fallback to using repr and triple single quotes. 1286 quote_types = ["'''"] 1287 new_fstring_parts.clear() 1288 for value, is_constant in fstring_parts: 1289 if is_constant: 1290 value = repr('"' + value) # force repr to use single quotes 1291 expected_prefix = "'\"" 1292 assert value.startswith(expected_prefix), repr(value) 1293 value = value[len(expected_prefix):-1] 1294 new_fstring_parts.append(value) 1295 1296 value = "".join(new_fstring_parts) 1297 quote_type = quote_types[0] 1298 self.write(f"{quote_type}{value}{quote_type}") 1299 1300 def _write_fstring_inner(self, node, is_format_spec=False): 1301 if isinstance(node, JoinedStr): 1302 # for both the f-string itself, and format_spec 1303 for value in node.values: 1304 self._write_fstring_inner(value, is_format_spec=is_format_spec) 1305 elif isinstance(node, Constant) and isinstance(node.value, str): 1306 value = node.value.replace("{", "{{").replace("}", "}}") 1307 1308 if is_format_spec: 1309 value = value.replace("\\", "\\\\") 1310 value = value.replace("'", "\\'") 1311 value = value.replace('"', '\\"') 1312 value = value.replace("\n", "\\n") 1313 self.write(value) 1314 elif isinstance(node, FormattedValue): 1315 self.visit_FormattedValue(node) 1316 else: 1317 raise ValueError(f"Unexpected node inside JoinedStr, {node!r}") 1318 1319 def visit_FormattedValue(self, node): 1320 def unparse_inner(inner): 1321 unparser = type(self)() 1322 unparser.set_precedence(_Precedence.TEST.next(), inner) 1323 return unparser.visit(inner) 1324 1325 with self.delimit("{", "}"): 1326 expr = unparse_inner(node.value) 1327 if expr.startswith("{"): 1328 # Separate pair of opening brackets as "{ {" 1329 self.write(" ") 1330 self.write(expr) 1331 if node.conversion != -1: 1332 self.write(f"!{chr(node.conversion)}") 1333 if node.format_spec: 1334 self.write(":") 1335 self._write_fstring_inner(node.format_spec, is_format_spec=True) 1336 1337 def visit_Name(self, node): 1338 self.write(node.id) 1339 1340 def _write_docstring(self, node): 1341 self.fill() 1342 if node.kind == "u": 1343 self.write("u") 1344 self._write_str_avoiding_backslashes(node.value, quote_types=_MULTI_QUOTES) 1345 1346 def _write_constant(self, value): 1347 if isinstance(value, (float, complex)): 1348 # Substitute overflowing decimal literal for AST infinities, 1349 # and inf - inf for NaNs. 1350 self.write( 1351 repr(value) 1352 .replace("inf", _INFSTR) 1353 .replace("nan", f"({_INFSTR}-{_INFSTR})") 1354 ) 1355 else: 1356 self.write(repr(value)) 1357 1358 def visit_Constant(self, node): 1359 value = node.value 1360 if isinstance(value, tuple): 1361 with self.delimit("(", ")"): 1362 self.items_view(self._write_constant, value) 1363 elif value is ...: 1364 self.write("...") 1365 else: 1366 if node.kind == "u": 1367 self.write("u") 1368 self._write_constant(node.value) 1369 1370 def visit_List(self, node): 1371 with self.delimit("[", "]"): 1372 self.interleave(lambda: self.write(", "), self.traverse, node.elts) 1373 1374 def visit_ListComp(self, node): 1375 with self.delimit("[", "]"): 1376 self.traverse(node.elt) 1377 for gen in node.generators: 1378 self.traverse(gen) 1379 1380 def visit_GeneratorExp(self, node): 1381 with self.delimit("(", ")"): 1382 self.traverse(node.elt) 1383 for gen in node.generators: 1384 self.traverse(gen) 1385 1386 def visit_SetComp(self, node): 1387 with self.delimit("{", "}"): 1388 self.traverse(node.elt) 1389 for gen in node.generators: 1390 self.traverse(gen) 1391 1392 def visit_DictComp(self, node): 1393 with self.delimit("{", "}"): 1394 self.traverse(node.key) 1395 self.write(": ") 1396 self.traverse(node.value) 1397 for gen in node.generators: 1398 self.traverse(gen) 1399 1400 def visit_comprehension(self, node): 1401 if node.is_async: 1402 self.write(" async for ") 1403 else: 1404 self.write(" for ") 1405 self.set_precedence(_Precedence.TUPLE, node.target) 1406 self.traverse(node.target) 1407 self.write(" in ") 1408 self.set_precedence(_Precedence.TEST.next(), node.iter, *node.ifs) 1409 self.traverse(node.iter) 1410 for if_clause in node.ifs: 1411 self.write(" if ") 1412 self.traverse(if_clause) 1413 1414 def visit_IfExp(self, node): 1415 with self.require_parens(_Precedence.TEST, node): 1416 self.set_precedence(_Precedence.TEST.next(), node.body, node.test) 1417 self.traverse(node.body) 1418 self.write(" if ") 1419 self.traverse(node.test) 1420 self.write(" else ") 1421 self.set_precedence(_Precedence.TEST, node.orelse) 1422 self.traverse(node.orelse) 1423 1424 def visit_Set(self, node): 1425 if node.elts: 1426 with self.delimit("{", "}"): 1427 self.interleave(lambda: self.write(", "), self.traverse, node.elts) 1428 else: 1429 # `{}` would be interpreted as a dictionary literal, and 1430 # `set` might be shadowed. Thus: 1431 self.write('{*()}') 1432 1433 def visit_Dict(self, node): 1434 def write_key_value_pair(k, v): 1435 self.traverse(k) 1436 self.write(": ") 1437 self.traverse(v) 1438 1439 def write_item(item): 1440 k, v = item 1441 if k is None: 1442 # for dictionary unpacking operator in dicts {**{'y': 2}} 1443 # see PEP 448 for details 1444 self.write("**") 1445 self.set_precedence(_Precedence.EXPR, v) 1446 self.traverse(v) 1447 else: 1448 write_key_value_pair(k, v) 1449 1450 with self.delimit("{", "}"): 1451 self.interleave( 1452 lambda: self.write(", "), write_item, zip(node.keys, node.values) 1453 ) 1454 1455 def visit_Tuple(self, node): 1456 with self.delimit_if( 1457 "(", 1458 ")", 1459 len(node.elts) == 0 or self.get_precedence(node) > _Precedence.TUPLE 1460 ): 1461 self.items_view(self.traverse, node.elts) 1462 1463 unop = {"Invert": "~", "Not": "not", "UAdd": "+", "USub": "-"} 1464 unop_precedence = { 1465 "not": _Precedence.NOT, 1466 "~": _Precedence.FACTOR, 1467 "+": _Precedence.FACTOR, 1468 "-": _Precedence.FACTOR, 1469 } 1470 1471 def visit_UnaryOp(self, node): 1472 operator = self.unop[node.op.__class__.__name__] 1473 operator_precedence = self.unop_precedence[operator] 1474 with self.require_parens(operator_precedence, node): 1475 self.write(operator) 1476 # factor prefixes (+, -, ~) shouldn't be separated 1477 # from the value they belong, (e.g: +1 instead of + 1) 1478 if operator_precedence is not _Precedence.FACTOR: 1479 self.write(" ") 1480 self.set_precedence(operator_precedence, node.operand) 1481 self.traverse(node.operand) 1482 1483 binop = { 1484 "Add": "+", 1485 "Sub": "-", 1486 "Mult": "*", 1487 "MatMult": "@", 1488 "Div": "/", 1489 "Mod": "%", 1490 "LShift": "<<", 1491 "RShift": ">>", 1492 "BitOr": "|", 1493 "BitXor": "^", 1494 "BitAnd": "&", 1495 "FloorDiv": "//", 1496 "Pow": "**", 1497 } 1498 1499 binop_precedence = { 1500 "+": _Precedence.ARITH, 1501 "-": _Precedence.ARITH, 1502 "*": _Precedence.TERM, 1503 "@": _Precedence.TERM, 1504 "/": _Precedence.TERM, 1505 "%": _Precedence.TERM, 1506 "<<": _Precedence.SHIFT, 1507 ">>": _Precedence.SHIFT, 1508 "|": _Precedence.BOR, 1509 "^": _Precedence.BXOR, 1510 "&": _Precedence.BAND, 1511 "//": _Precedence.TERM, 1512 "**": _Precedence.POWER, 1513 } 1514 1515 binop_rassoc = frozenset(("**",)) 1516 def visit_BinOp(self, node): 1517 operator = self.binop[node.op.__class__.__name__] 1518 operator_precedence = self.binop_precedence[operator] 1519 with self.require_parens(operator_precedence, node): 1520 if operator in self.binop_rassoc: 1521 left_precedence = operator_precedence.next() 1522 right_precedence = operator_precedence 1523 else: 1524 left_precedence = operator_precedence 1525 right_precedence = operator_precedence.next() 1526 1527 self.set_precedence(left_precedence, node.left) 1528 self.traverse(node.left) 1529 self.write(f" {operator} ") 1530 self.set_precedence(right_precedence, node.right) 1531 self.traverse(node.right) 1532 1533 cmpops = { 1534 "Eq": "==", 1535 "NotEq": "!=", 1536 "Lt": "<", 1537 "LtE": "<=", 1538 "Gt": ">", 1539 "GtE": ">=", 1540 "Is": "is", 1541 "IsNot": "is not", 1542 "In": "in", 1543 "NotIn": "not in", 1544 } 1545 1546 def visit_Compare(self, node): 1547 with self.require_parens(_Precedence.CMP, node): 1548 self.set_precedence(_Precedence.CMP.next(), node.left, *node.comparators) 1549 self.traverse(node.left) 1550 for o, e in zip(node.ops, node.comparators): 1551 self.write(" " + self.cmpops[o.__class__.__name__] + " ") 1552 self.traverse(e) 1553 1554 boolops = {"And": "and", "Or": "or"} 1555 boolop_precedence = {"and": _Precedence.AND, "or": _Precedence.OR} 1556 1557 def visit_BoolOp(self, node): 1558 operator = self.boolops[node.op.__class__.__name__] 1559 operator_precedence = self.boolop_precedence[operator] 1560 1561 def increasing_level_traverse(node): 1562 nonlocal operator_precedence 1563 operator_precedence = operator_precedence.next() 1564 self.set_precedence(operator_precedence, node) 1565 self.traverse(node) 1566 1567 with self.require_parens(operator_precedence, node): 1568 s = f" {operator} " 1569 self.interleave(lambda: self.write(s), increasing_level_traverse, node.values) 1570 1571 def visit_Attribute(self, node): 1572 self.set_precedence(_Precedence.ATOM, node.value) 1573 self.traverse(node.value) 1574 # Special case: 3.__abs__() is a syntax error, so if node.value 1575 # is an integer literal then we need to either parenthesize 1576 # it or add an extra space to get 3 .__abs__(). 1577 if isinstance(node.value, Constant) and isinstance(node.value.value, int): 1578 self.write(" ") 1579 self.write(".") 1580 self.write(node.attr) 1581 1582 def visit_Call(self, node): 1583 self.set_precedence(_Precedence.ATOM, node.func) 1584 self.traverse(node.func) 1585 with self.delimit("(", ")"): 1586 comma = False 1587 for e in node.args: 1588 if comma: 1589 self.write(", ") 1590 else: 1591 comma = True 1592 self.traverse(e) 1593 for e in node.keywords: 1594 if comma: 1595 self.write(", ") 1596 else: 1597 comma = True 1598 self.traverse(e) 1599 1600 def visit_Subscript(self, node): 1601 def is_non_empty_tuple(slice_value): 1602 return ( 1603 isinstance(slice_value, Tuple) 1604 and slice_value.elts 1605 ) 1606 1607 self.set_precedence(_Precedence.ATOM, node.value) 1608 self.traverse(node.value) 1609 with self.delimit("[", "]"): 1610 if is_non_empty_tuple(node.slice): 1611 # parentheses can be omitted if the tuple isn't empty 1612 self.items_view(self.traverse, node.slice.elts) 1613 else: 1614 self.traverse(node.slice) 1615 1616 def visit_Starred(self, node): 1617 self.write("*") 1618 self.set_precedence(_Precedence.EXPR, node.value) 1619 self.traverse(node.value) 1620 1621 def visit_Ellipsis(self, node): 1622 self.write("...") 1623 1624 def visit_Slice(self, node): 1625 if node.lower: 1626 self.traverse(node.lower) 1627 self.write(":") 1628 if node.upper: 1629 self.traverse(node.upper) 1630 if node.step: 1631 self.write(":") 1632 self.traverse(node.step) 1633 1634 def visit_Match(self, node): 1635 self.fill("match ") 1636 self.traverse(node.subject) 1637 with self.block(): 1638 for case in node.cases: 1639 self.traverse(case) 1640 1641 def visit_arg(self, node): 1642 self.write(node.arg) 1643 if node.annotation: 1644 self.write(": ") 1645 self.traverse(node.annotation) 1646 1647 def visit_arguments(self, node): 1648 first = True 1649 # normal arguments 1650 all_args = node.posonlyargs + node.args 1651 defaults = [None] * (len(all_args) - len(node.defaults)) + node.defaults 1652 for index, elements in enumerate(zip(all_args, defaults), 1): 1653 a, d = elements 1654 if first: 1655 first = False 1656 else: 1657 self.write(", ") 1658 self.traverse(a) 1659 if d: 1660 self.write("=") 1661 self.traverse(d) 1662 if index == len(node.posonlyargs): 1663 self.write(", /") 1664 1665 # varargs, or bare '*' if no varargs but keyword-only arguments present 1666 if node.vararg or node.kwonlyargs: 1667 if first: 1668 first = False 1669 else: 1670 self.write(", ") 1671 self.write("*") 1672 if node.vararg: 1673 self.write(node.vararg.arg) 1674 if node.vararg.annotation: 1675 self.write(": ") 1676 self.traverse(node.vararg.annotation) 1677 1678 # keyword-only arguments 1679 if node.kwonlyargs: 1680 for a, d in zip(node.kwonlyargs, node.kw_defaults): 1681 self.write(", ") 1682 self.traverse(a) 1683 if d: 1684 self.write("=") 1685 self.traverse(d) 1686 1687 # kwargs 1688 if node.kwarg: 1689 if first: 1690 first = False 1691 else: 1692 self.write(", ") 1693 self.write("**" + node.kwarg.arg) 1694 if node.kwarg.annotation: 1695 self.write(": ") 1696 self.traverse(node.kwarg.annotation) 1697 1698 def visit_keyword(self, node): 1699 if node.arg is None: 1700 self.write("**") 1701 else: 1702 self.write(node.arg) 1703 self.write("=") 1704 self.traverse(node.value) 1705 1706 def visit_Lambda(self, node): 1707 with self.require_parens(_Precedence.TEST, node): 1708 self.write("lambda") 1709 with self.buffered() as buffer: 1710 self.traverse(node.args) 1711 if buffer: 1712 self.write(" ", *buffer) 1713 self.write(": ") 1714 self.set_precedence(_Precedence.TEST, node.body) 1715 self.traverse(node.body) 1716 1717 def visit_alias(self, node): 1718 self.write(node.name) 1719 if node.asname: 1720 self.write(" as " + node.asname) 1721 1722 def visit_withitem(self, node): 1723 self.traverse(node.context_expr) 1724 if node.optional_vars: 1725 self.write(" as ") 1726 self.traverse(node.optional_vars) 1727 1728 def visit_match_case(self, node): 1729 self.fill("case ") 1730 self.traverse(node.pattern) 1731 if node.guard: 1732 self.write(" if ") 1733 self.traverse(node.guard) 1734 with self.block(): 1735 self.traverse(node.body) 1736 1737 def visit_MatchValue(self, node): 1738 self.traverse(node.value) 1739 1740 def visit_MatchSingleton(self, node): 1741 self._write_constant(node.value) 1742 1743 def visit_MatchSequence(self, node): 1744 with self.delimit("[", "]"): 1745 self.interleave( 1746 lambda: self.write(", "), self.traverse, node.patterns 1747 ) 1748 1749 def visit_MatchStar(self, node): 1750 name = node.name 1751 if name is None: 1752 name = "_" 1753 self.write(f"*{name}") 1754 1755 def visit_MatchMapping(self, node): 1756 def write_key_pattern_pair(pair): 1757 k, p = pair 1758 self.traverse(k) 1759 self.write(": ") 1760 self.traverse(p) 1761 1762 with self.delimit("{", "}"): 1763 keys = node.keys 1764 self.interleave( 1765 lambda: self.write(", "), 1766 write_key_pattern_pair, 1767 zip(keys, node.patterns, strict=True), 1768 ) 1769 rest = node.rest 1770 if rest is not None: 1771 if keys: 1772 self.write(", ") 1773 self.write(f"**{rest}") 1774 1775 def visit_MatchClass(self, node): 1776 self.set_precedence(_Precedence.ATOM, node.cls) 1777 self.traverse(node.cls) 1778 with self.delimit("(", ")"): 1779 patterns = node.patterns 1780 self.interleave( 1781 lambda: self.write(", "), self.traverse, patterns 1782 ) 1783 attrs = node.kwd_attrs 1784 if attrs: 1785 def write_attr_pattern(pair): 1786 attr, pattern = pair 1787 self.write(f"{attr}=") 1788 self.traverse(pattern) 1789 1790 if patterns: 1791 self.write(", ") 1792 self.interleave( 1793 lambda: self.write(", "), 1794 write_attr_pattern, 1795 zip(attrs, node.kwd_patterns, strict=True), 1796 ) 1797 1798 def visit_MatchAs(self, node): 1799 name = node.name 1800 pattern = node.pattern 1801 if name is None: 1802 self.write("_") 1803 elif pattern is None: 1804 self.write(node.name) 1805 else: 1806 with self.require_parens(_Precedence.TEST, node): 1807 self.set_precedence(_Precedence.BOR, node.pattern) 1808 self.traverse(node.pattern) 1809 self.write(f" as {node.name}") 1810 1811 def visit_MatchOr(self, node): 1812 with self.require_parens(_Precedence.BOR, node): 1813 self.set_precedence(_Precedence.BOR.next(), *node.patterns) 1814 self.interleave(lambda: self.write(" | "), self.traverse, node.patterns) 1815 1816def unparse(ast_obj): 1817 unparser = _Unparser() 1818 return unparser.visit(ast_obj) 1819 1820 1821_deprecated_globals = { 1822 name: globals().pop(name) 1823 for name in ('Num', 'Str', 'Bytes', 'NameConstant', 'Ellipsis') 1824} 1825 1826def __getattr__(name): 1827 if name in _deprecated_globals: 1828 globals()[name] = value = _deprecated_globals[name] 1829 import warnings 1830 warnings._deprecated( 1831 f"ast.{name}", message=_DEPRECATED_CLASS_MESSAGE, remove=(3, 14) 1832 ) 1833 return value 1834 raise AttributeError(f"module 'ast' has no attribute '{name}'") 1835 1836 1837def main(): 1838 import argparse 1839 1840 parser = argparse.ArgumentParser(prog='python -m ast') 1841 parser.add_argument('infile', nargs='?', default='-', 1842 help='the file to parse; defaults to stdin') 1843 parser.add_argument('-m', '--mode', default='exec', 1844 choices=('exec', 'single', 'eval', 'func_type'), 1845 help='specify what kind of code must be parsed') 1846 parser.add_argument('--no-type-comments', default=True, action='store_false', 1847 help="don't add information about type comments") 1848 parser.add_argument('-a', '--include-attributes', action='store_true', 1849 help='include attributes such as line numbers and ' 1850 'column offsets') 1851 parser.add_argument('-i', '--indent', type=int, default=3, 1852 help='indentation of nodes (number of spaces)') 1853 args = parser.parse_args() 1854 1855 if args.infile == '-': 1856 name = '<stdin>' 1857 source = sys.stdin.buffer.read() 1858 else: 1859 name = args.infile 1860 with open(args.infile, 'rb') as infile: 1861 source = infile.read() 1862 tree = parse(source, name, args.mode, type_comments=args.no_type_comments) 1863 print(dump(tree, include_attributes=args.include_attributes, indent=args.indent)) 1864 1865if __name__ == '__main__': 1866 main() 1867