"""
    ast
    ~~~

    The `ast` module helps Python applications to process trees of the Python
    abstract syntax grammar.  The abstract syntax itself might change with
    each Python release; this module helps to find out programmatically what
    the current grammar looks like and allows modifications of it.

    An abstract syntax tree can be generated by passing `ast.PyCF_ONLY_AST` as
    a flag to the `compile()` builtin function or by using the `parse()`
    function from this module.  The result will be a tree of objects whose
    classes all inherit from `ast.AST`.

    A modified abstract syntax tree can be compiled into a Python code object
    using the built-in `compile()` function.

    Additionally various helper functions are provided that make working with
    the trees simpler.  The main intention of the helper functions and this
    module in general is to provide an easy to use interface for libraries
    that work tightly with the python syntax (template engines for example).


    :copyright: Copyright 2008 by Armin Ronacher.
    :license: Python License.
"""
from _ast import *


def parse(source, filename='<unknown>', mode='exec', *,
          type_comments=False, feature_version=None):
    """
    Parse the source into an AST node.
    Equivalent to compile(source, filename, mode, PyCF_ONLY_AST).
    Pass type_comments=True to get back type comments where the syntax allows.

    *feature_version* may be None (use the current grammar), an int giving
    the minor version of Python 3.x, or a ``(major, minor)`` tuple.  A tuple
    whose major part is not 3 raises ValueError.
    """
    flags = PyCF_ONLY_AST
    if type_comments:
        flags |= PyCF_TYPE_COMMENTS
    if isinstance(feature_version, tuple):
        major, minor = feature_version  # Should be a 2-tuple.
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would silently accept e.g. (2, 7).
        if major != 3:
            raise ValueError(f"Unsupported major version: {major}")
        feature_version = minor
    elif feature_version is None:
        feature_version = -1
    # Else it should be an int giving the minor version for 3.x.
    return compile(source, filename, mode, flags,
                   _feature_version=feature_version)


def literal_eval(node_or_string):
    """
    Safely evaluate an expression node or a string containing a Python
    expression.  The string or node provided may only consist of the following
    Python literal structures: strings, bytes, numbers, tuples, lists, dicts,
    sets, booleans, and None.

    Leading spaces and tabs in a string argument are ignored.  A node that is
    not one of the supported structures raises ValueError.
    """
    if isinstance(node_or_string, str):
        # Leading indentation would otherwise make the 'eval' parse fail with
        # IndentationError even though the expression itself is a literal.
        node_or_string = parse(node_or_string.lstrip(" \t"), mode='eval')
    if isinstance(node_or_string, Expression):
        node_or_string = node_or_string.body

    def _raise_malformed_node(node):
        raise ValueError('malformed node or string: ' + repr(node))

    def _convert_num(node):
        # The exact type() test rejects bool, which is an int subclass.
        if isinstance(node, Constant) and type(node.value) in (int, float, complex):
            return node.value
        _raise_malformed_node(node)

    def _convert_signed_num(node):
        # Allow a single leading unary + or - in front of a numeric constant.
        if isinstance(node, UnaryOp) and isinstance(node.op, (UAdd, USub)):
            operand = _convert_num(node.operand)
            if isinstance(node.op, UAdd):
                return + operand
            else:
                return - operand
        return _convert_num(node)

    def _convert(node):
        if isinstance(node, Constant):
            return node.value
        elif isinstance(node, Tuple):
            return tuple(map(_convert, node.elts))
        elif isinstance(node, List):
            return list(map(_convert, node.elts))
        elif isinstance(node, Set):
            return set(map(_convert, node.elts))
        elif isinstance(node, Dict):
            # zip() would silently drop the surplus of a hand-built Dict whose
            # key and value lists differ in length.
            if len(node.keys) != len(node.values):
                _raise_malformed_node(node)
            return dict(zip(map(_convert, node.keys),
                            map(_convert, node.values)))
        elif isinstance(node, BinOp) and isinstance(node.op, (Add, Sub)):
            # The only binary operation accepted is a complex literal written
            # as real +/- imaginary, e.g. ``-1+2j``.
            left = _convert_signed_num(node.left)
            right = _convert_num(node.right)
            if isinstance(left, (int, float)) and isinstance(right, complex):
                if isinstance(node.op, Add):
                    return left + right
                else:
                    return left - right
        return _convert_signed_num(node)

    return _convert(node_or_string)


def dump(node, annotate_fields=True, include_attributes=False):
    """
    Return a formatted dump of the tree in node.  This is mainly useful for
    debugging purposes.  If annotate_fields is true (by default),
    the returned string will show the names and the values for fields.
    If annotate_fields is false, the result string will be more compact by
    omitting unambiguous field names.  Attributes such as line
    numbers and column offsets are not dumped by default.  If this is wanted,
    include_attributes can be set to true.

    Raises TypeError if *node* is not an AST instance.
    """
    def _format(node):
        if isinstance(node, AST):
            args = []
            keywords = annotate_fields
            for field in node._fields:
                try:
                    value = getattr(node, field)
                except AttributeError:
                    # Once a field is missing, later positional values would be
                    # ambiguous, so switch to name=value form from here on.
                    keywords = True
                else:
                    if keywords:
                        args.append('%s=%s' % (field, _format(value)))
                    else:
                        args.append(_format(value))
            if include_attributes and node._attributes:
                for a in node._attributes:
                    try:
                        args.append('%s=%s' % (a, _format(getattr(node, a))))
                    except AttributeError:
                        pass
            return '%s(%s)' % (node.__class__.__name__, ', '.join(args))
        elif isinstance(node, list):
            return '[%s]' % ', '.join(_format(x) for x in node)
        return repr(node)

    if not isinstance(node, AST):
        raise TypeError('expected AST, got %r' % node.__class__.__name__)
    return _format(node)


def copy_location(new_node, old_node):
    """
    Copy source location (`lineno`, `col_offset`, `end_lineno`, and `end_col_offset`
    attributes) from *old_node* to *new_node* if possible, and return *new_node*.

    An attribute is copied only when both node types declare it and
    *old_node* actually has it set.
    """
    for attr in 'lineno', 'col_offset', 'end_lineno', 'end_col_offset':
        if attr in old_node._attributes and attr in new_node._attributes \
           and hasattr(old_node, attr):
            setattr(new_node, attr, getattr(old_node, attr))
    return new_node


def fix_missing_locations(node):
    """
    When you compile a node tree with compile(), the compiler expects lineno and
    col_offset attributes for every node that supports them.  This is rather
    tedious to fill in for generated nodes, so this helper adds these attributes
    recursively where not already set, by setting them to the values of the
    parent node.  It works recursively starting at *node* and returns *node*.
    The root falls back to line 1, column 0 when it has no location itself.
    """
    def _fix(node, lineno, col_offset, end_lineno, end_col_offset):
        # For each location attribute: fill it in from the parent if missing,
        # otherwise adopt the node's own value for its children.
        if 'lineno' in node._attributes:
            if not hasattr(node, 'lineno'):
                node.lineno = lineno
            else:
                lineno = node.lineno
        if 'end_lineno' in node._attributes:
            if not hasattr(node, 'end_lineno'):
                node.end_lineno = end_lineno
            else:
                end_lineno = node.end_lineno
        if 'col_offset' in node._attributes:
            if not hasattr(node, 'col_offset'):
                node.col_offset = col_offset
            else:
                col_offset = node.col_offset
        if 'end_col_offset' in node._attributes:
            if not hasattr(node, 'end_col_offset'):
                node.end_col_offset = end_col_offset
            else:
                end_col_offset = node.end_col_offset
        for child in iter_child_nodes(node):
            _fix(child, lineno, col_offset, end_lineno, end_col_offset)
    _fix(node, 1, 0, 1, 0)
    return node


def increment_lineno(node, n=1):
    """
    Increment the line number and end line number of each node in the tree
    starting at *node* by *n*.  This is useful to "move code" to a different
    location in a file.  `TypeIgnore` nodes are shifted as well.  Returns
    *node*.
    """
    for child in walk(node):
        # TypeIgnore carries lineno as a field, not as an attribute, so the
        # _attributes test below would skip it.
        if isinstance(child, TypeIgnore):
            child.lineno = getattr(child, 'lineno', 0) + n
            continue
        if 'lineno' in child._attributes:
            child.lineno = getattr(child, 'lineno', 0) + n
        if 'end_lineno' in child._attributes:
            end_lineno = getattr(child, 'end_lineno', 0)
            # end_lineno may be explicitly None; adding to it would raise
            # TypeError, so leave it as-is.
            if end_lineno is not None:
                child.end_lineno = end_lineno + n
    return node


def iter_fields(node):
    """
    Yield a tuple of ``(fieldname, value)`` for each field in ``node._fields``
    that is present on *node*.
    """
    for field in node._fields:
        try:
            yield field, getattr(node, field)
        except AttributeError:
            pass


def iter_child_nodes(node):
    """
    Yield all direct child nodes of *node*, that is, all fields that are nodes
    and all items of fields that are lists of nodes.
    """
    for name, field in iter_fields(node):
        if isinstance(field, AST):
            yield field
        elif isinstance(field, list):
            for item in field:
                if isinstance(item, AST):
                    yield item


def get_docstring(node, clean=True):
    """
    Return the docstring for the given node or None if no docstring can
    be found.  If the node provided does not have docstrings a TypeError
    will be raised.

    If *clean* is `True`, all tabs are expanded to spaces and any whitespace
    that can be uniformly removed from the second line onwards is removed.
    """
    if not isinstance(node, (AsyncFunctionDef, FunctionDef, ClassDef, Module)):
        raise TypeError("%r can't have docstrings" % node.__class__.__name__)
    if not (node.body and isinstance(node.body[0], Expr)):
        return None
    node = node.body[0].value
    # String literals are parsed as Constant nodes holding a str.
    if isinstance(node, Constant) and isinstance(node.value, str):
        text = node.value
    else:
        return None
    if clean:
        import inspect
        text = inspect.cleandoc(text)
    return text


def _splitlines_no_ff(source):
    r"""Split a string into lines ignoring form feed and other chars.

    This mimics how the Python parser splits source code: only '\r\n', '\r'
    and '\n' end a line (unlike str.splitlines(), which also splits on form
    feed and other separators).  Line terminators are kept.
    """
    lines = []
    start = 0
    idx = 0
    size = len(source)
    while idx < size:
        c = source[idx]
        idx += 1
        if c == '\r':
            # Keep \r\n together as a single terminator.
            if idx < size and source[idx] == '\n':
                idx += 1
        elif c != '\n':
            continue
        # Slice out the whole line instead of growing a string per character.
        lines.append(source[start:idx])
        start = idx
    if start < size:
        lines.append(source[start:])
    return lines


def _pad_whitespace(source):
    r"""Replace all chars except '\f\t' in a line with spaces."""
    # Form feeds and tabs are kept so the padded text lines up the same way.
    return ''.join(c if c in '\f\t' else ' ' for c in source)


def get_source_segment(source, node, *, padded=False):
    """Get source code segment of the *source* that generated *node*.

    If some location information (`lineno`, `end_lineno`, `col_offset`,
    or `end_col_offset`) is missing or None, return None.

    If *padded* is `True`, the first line of a multi-line statement will
    be padded with spaces to match its original position.
    """
    try:
        lineno = node.lineno
        end_lineno = node.end_lineno
        col_offset = node.col_offset
        end_col_offset = node.end_col_offset
    except AttributeError:
        return None
    # A location attribute can be present but None; the arithmetic below
    # would then raise TypeError instead of reporting "no location".
    if any(v is None for v in (lineno, end_lineno, col_offset, end_col_offset)):
        return None
    # AST line numbers are 1-based; the list of lines is 0-based.
    lineno -= 1
    end_lineno -= 1

    lines = _splitlines_no_ff(source)
    # Column offsets are applied to the UTF-8 encoding of each line.
    if end_lineno == lineno:
        return lines[lineno].encode()[col_offset:end_col_offset].decode()

    if padded:
        padding = _pad_whitespace(lines[lineno].encode()[:col_offset].decode())
    else:
        padding = ''

    first = padding + lines[lineno].encode()[col_offset:].decode()
    last = lines[end_lineno].encode()[:end_col_offset].decode()
    lines = lines[lineno+1:end_lineno]

    lines.insert(0, first)
    lines.append(last)
    return ''.join(lines)


def walk(node):
    """
    Recursively yield all descendant nodes in the tree starting at *node*
    (including *node* itself), in no specified order.  This is useful if you
    only want to modify nodes in place and don't care about the context.
    """
    from collections import deque
    todo = deque([node])
    while todo:
        node = todo.popleft()
        todo.extend(iter_child_nodes(node))
        yield node


class NodeVisitor(object):
    """
    A node visitor base class that walks the abstract syntax tree and calls a
    visitor function for every node found.  This function may return a value
    which is forwarded by the `visit` method.

    This class is meant to be subclassed, with the subclass adding visitor
    methods.

    By default the visitor functions for the nodes are ``'visit_'`` +
    class name of the node.  So a `BinOp` node visit function would
    be `visit_BinOp`.  This behavior can be changed by overriding
    the `visit` method.  If no visitor function exists for a node
    the `generic_visit` visitor is used instead.

    Don't use the `NodeVisitor` if you want to apply changes to nodes during
    traversing.  For this a special visitor exists (`NodeTransformer`) that
    allows modifications.
    """

    def visit(self, node):
        """Visit a node."""
        method = 'visit_' + node.__class__.__name__
        visitor = getattr(self, method, self.generic_visit)
        return visitor(node)

    def generic_visit(self, node):
        """Called if no explicit visitor function exists for a node."""
        for field, value in iter_fields(node):
            if isinstance(value, list):
                for item in value:
                    if isinstance(item, AST):
                        self.visit(item)
            elif isinstance(value, AST):
                self.visit(value)

    def visit_Constant(self, node):
        # Backward compatibility: dispatch to a legacy visit_Num / visit_Str /
        # ... method if the subclass defines one, warning that it is
        # deprecated.  _const_node_type_names is a module-level table defined
        # further down in this file.
        value = node.value
        type_name = _const_node_type_names.get(type(value))
        if type_name is None:
            # Fall back to isinstance() for subclasses of the listed types.
            for cls, name in _const_node_type_names.items():
                if isinstance(value, cls):
                    type_name = name
                    break
        if type_name is not None:
            method = 'visit_' + type_name
            try:
                visitor = getattr(self, method)
            except AttributeError:
                pass
            else:
                import warnings
                warnings.warn(f"{method} is deprecated; add visit_Constant",
                              PendingDeprecationWarning, 2)
                return visitor(node)
        return self.generic_visit(node)


class NodeTransformer(NodeVisitor):
    """
    A :class:`NodeVisitor` subclass that walks the abstract syntax tree and
    allows modification of nodes.

    The `NodeTransformer` will walk the AST and use the return value of the
    visitor methods to replace or remove the old node.  If the return value of
    the visitor method is ``None``, the node will be removed from its location,
    otherwise it is replaced with the return value.  The return value may be the
    original node in which case no replacement takes place.

    Here is an example transformer that rewrites all occurrences of name lookups
    (``foo``) to ``data['foo']``::

       class RewriteName(NodeTransformer):

           def visit_Name(self, node):
               return copy_location(Subscript(
                   value=Name(id='data', ctx=Load()),
                   slice=Index(value=Constant(value=node.id)),
                   ctx=node.ctx
               ), node)

    Keep in mind that if the node you're operating on has child nodes you must
    either transform the child nodes yourself or call the :meth:`generic_visit`
    method for the node first.

    For nodes that were part of a collection of statements (that applies to all
    statement nodes), the visitor may also return a list of nodes rather than
    just a single node.

    Usually you use the transformer like this::

       node = YourTransformer().visit(node)
    """

    def generic_visit(self, node):
        for field, old_value in iter_fields(node):
            if isinstance(old_value, list):
                new_values = []
                for value in old_value:
                    if isinstance(value, AST):
                        value = self.visit(value)
                        if value is None:
                            continue
                        elif not isinstance(value, AST):
                            # A visitor may return a sequence of nodes to
                            # splice in place of one list item.
                            new_values.extend(value)
                            continue
                    new_values.append(value)
                # Mutate the existing list in place.
                old_value[:] = new_values
            elif isinstance(old_value, AST):
                new_node = self.visit(old_value)
                if new_node is None:
                    delattr(node, field)
                else:
                    setattr(node, field, new_node)
        return node


# The following code is for backward compatibility.
# It will be removed in future.
# Shared accessors behind the deprecated ``n`` and ``s`` aliases that
# older code uses to read or write a Constant's ``value``.
def _getter(self):
    """Return ``self.value``."""
    return self.value

def _setter(self, value):
    """Assign ``self.value``."""
    self.value = value

Constant.n = property(_getter, _setter)
Constant.s = property(_getter, _setter)

class _ABC(type):
    """Metaclass for the deprecated constant node classes below.

    It makes ``isinstance(node, Num)`` (and friends) test whether *node* is a
    Constant whose ``value`` has one of the Python types mapped to that class
    in ``_const_types`` (minus any excluded by ``_const_types_not``).
    """

    def __instancecheck__(cls, inst):
        if not isinstance(inst, Constant):
            return False
        if cls in _const_types:
            try:
                value = inst.value
            except AttributeError:
                return False
            else:
                return (
                    isinstance(value, _const_types[cls]) and
                    not isinstance(value, _const_types_not.get(cls, ()))
                )
        # Subclasses of the deprecated classes use ordinary class membership.
        return type.__instancecheck__(cls, inst)

def _new(cls, *args, **kwargs):
    """Construct a deprecated constant node.

    The deprecated classes themselves produce a plain Constant; their
    subclasses produce an instance of the subclass.  Passing the same field
    both positionally and by keyword raises TypeError.
    """
    for key in kwargs:
        if key not in cls._fields:
            # arbitrary keyword arguments are accepted
            continue
        pos = cls._fields.index(key)
        if pos < len(args):
            # Without this check the keyword would silently win, e.g.
            # Num(1, n=2) produced a Constant whose value is 2.
            raise TypeError(f"{cls.__name__} got multiple values for argument {key!r}")
    if cls in _const_types:
        return Constant(*args, **kwargs)
    return Constant.__new__(cls, *args, **kwargs)

class Num(Constant, metaclass=_ABC):
    _fields = ('n',)
    __new__ = _new

class Str(Constant, metaclass=_ABC):
    _fields = ('s',)
    __new__ = _new

class Bytes(Constant, metaclass=_ABC):
    _fields = ('s',)
    __new__ = _new

class NameConstant(Constant, metaclass=_ABC):
    __new__ = _new

class Ellipsis(Constant, metaclass=_ABC):
    _fields = ()

    def __new__(cls, *args, **kwargs):
        # Ellipsis() has no fields of its own; it always wraps the ``...``
        # singleton.
        if cls is Ellipsis:
            return Constant(..., *args, **kwargs)
        return Constant.__new__(cls, *args, **kwargs)

# Python value types that each deprecated class matches in isinstance().
_const_types = {
    Num: (int, float, complex),
    Str: (str,),
    Bytes: (bytes,),
    NameConstant: (type(None), bool),
    Ellipsis: (type(...),),
}
# Types excluded despite matching above: bool is an int subclass, but a
# boolean Constant is a NameConstant rather than a Num.
_const_types_not = {
    Num: (bool,),
}
# Legacy visitor-name suffix for each value type, used when dispatching a
# Constant to a deprecated visit_Num / visit_Str / ... method.
_const_node_type_names = {
    bool: 'NameConstant',  # should be before int
    type(None): 'NameConstant',
    int: 'Num',
    float: 'Num',
    complex: 'Num',
    str: 'Str',
    bytes: 'Bytes',
    type(...): 'Ellipsis',
}