• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""
2    ast
3    ~~~
4
5    The `ast` module helps Python applications to process trees of the Python
6    abstract syntax grammar.  The abstract syntax itself might change with
7    each Python release; this module helps to find out programmatically what
8    the current grammar looks like and allows modifications of it.
9
10    An abstract syntax tree can be generated by passing `ast.PyCF_ONLY_AST` as
11    a flag to the `compile()` builtin function or by using the `parse()`
12    function from this module.  The result will be a tree of objects whose
13    classes all inherit from `ast.AST`.
14
15    A modified abstract syntax tree can be compiled into a Python code object
16    using the built-in `compile()` function.
17
    Additionally, various helper functions are provided that make working with
    the trees simpler.  The main intention of the helper functions and this
    module in general is to provide an easy-to-use interface for libraries
    that work tightly with the Python syntax (template engines for example).
22
23
24    :copyright: Copyright 2008 by Armin Ronacher.
25    :license: Python License.
26"""
27from _ast import *
28
29
def parse(source, filename='<unknown>', mode='exec', *,
          type_comments=False, feature_version=None):
    """
    Parse the source into an AST node.
    Equivalent to compile(source, filename, mode, PyCF_ONLY_AST).
    Pass type_comments=True to get back type comments where the syntax allows.

    *feature_version* may be ``None`` (the default, no particular grammar
    version requested), a ``(3, minor)`` tuple, or an int giving the minor
    version of Python 3 whose grammar should be used.

    Raises ValueError if *feature_version* is a tuple that is not a
    ``(3, minor)`` pair.
    """
    flags = PyCF_ONLY_AST
    if type_comments:
        flags |= PyCF_TYPE_COMMENTS
    if isinstance(feature_version, tuple):
        # Explicit check instead of ``assert`` so the validation is not
        # stripped when Python runs with -O.
        if len(feature_version) != 2 or feature_version[0] != 3:
            raise ValueError(
                'feature_version must be a (3, minor) tuple, got %r'
                % (feature_version,))
        feature_version = feature_version[1]
    elif feature_version is None:
        feature_version = -1
    # Otherwise it should already be an int giving the minor version for 3.x.
    return compile(source, filename, mode, flags,
                   _feature_version=feature_version)
49
50
def literal_eval(node_or_string):
    """
    Safely evaluate an expression node or a string containing a Python
    expression.  The string or node provided may only consist of the following
    Python literal structures: strings, bytes, numbers, tuples, lists, dicts,
    sets, booleans, and None.  The empty-set spelling ``set()`` is accepted
    too, and leading spaces and tabs in a string argument are ignored.

    Raises ValueError if the input contains anything other than those
    literals (including a hand-built ``Dict`` node whose ``keys`` and
    ``values`` differ in length).  Parsing a string argument may also raise
    SyntaxError.
    """
    if isinstance(node_or_string, str):
        # Leading indentation would otherwise make parse() raise
        # IndentationError for an otherwise valid literal such as " 1".
        node_or_string = parse(node_or_string.lstrip(" \t"), mode='eval')
    if isinstance(node_or_string, Expression):
        node_or_string = node_or_string.body

    def _raise_malformed_node(node):
        raise ValueError('malformed node or string: ' + repr(node))

    def _convert_num(node):
        # bool is excluded on purpose: type() is compared exactly.
        if not isinstance(node, Constant) or type(node.value) not in (int, float, complex):
            _raise_malformed_node(node)
        return node.value

    def _convert_signed_num(node):
        if isinstance(node, UnaryOp) and isinstance(node.op, (UAdd, USub)):
            operand = _convert_num(node.operand)
            if isinstance(node.op, UAdd):
                return + operand
            else:
                return - operand
        return _convert_num(node)

    def _convert(node):
        if isinstance(node, Constant):
            return node.value
        elif isinstance(node, Tuple):
            return tuple(map(_convert, node.elts))
        elif isinstance(node, List):
            return list(map(_convert, node.elts))
        elif isinstance(node, Set):
            return set(map(_convert, node.elts))
        elif (isinstance(node, Call) and isinstance(node.func, Name) and
              node.func.id == 'set' and node.args == node.keywords == []):
            # ``set()`` is the only way to write an empty set literal.
            return set()
        elif isinstance(node, Dict):
            # zip() would silently drop unmatched entries of a hand-built node.
            if len(node.keys) != len(node.values):
                _raise_malformed_node(node)
            return dict(zip(map(_convert, node.keys),
                            map(_convert, node.values)))
        elif isinstance(node, BinOp) and isinstance(node.op, (Add, Sub)):
            # Only real +/- complex (e.g. ``1+2j``) is allowed as arithmetic.
            left = _convert_signed_num(node.left)
            right = _convert_num(node.right)
            if isinstance(left, (int, float)) and isinstance(right, complex):
                if isinstance(node.op, Add):
                    return left + right
                else:
                    return left - right
        return _convert_signed_num(node)

    return _convert(node_or_string)
97
98
def dump(node, annotate_fields=True, include_attributes=False, *, indent=None):
    """
    Return a formatted dump of the tree in node.  This is mainly useful for
    debugging purposes.  If annotate_fields is true (by default),
    the returned string will show the names and the values for fields.
    If annotate_fields is false, the result string will be more compact by
    omitting unambiguous field names.  Attributes such as line
    numbers and column offsets are not dumped by default.  If this is wanted,
    include_attributes can be set to true.

    If *indent* is an integer (that many spaces per level) or a string,
    nodes and non-empty lists are pretty-printed across several lines with
    that indent; small, flat nodes still stay on one line.  The default
    ``None`` produces the single-line representation.

    Raises TypeError if *node* is not an AST instance.
    """
    def _format(node, level=0):
        # Returns (text, simple); *simple* marks output short and flat
        # enough that a parent may keep it on a single line.
        if indent is not None:
            level += 1
            prefix = '\n' + indent * level
            sep = ',\n' + indent * level
        else:
            prefix = ''
            sep = ', '
        if isinstance(node, AST):
            args = []
            allsimple = True
            keywords = annotate_fields
            for field in node._fields:
                try:
                    value = getattr(node, field)
                except AttributeError:
                    # After a missing field, positional output would be
                    # ambiguous, so switch to name=value for the rest.
                    keywords = True
                    continue
                value, simple = _format(value, level)
                allsimple = allsimple and simple
                if keywords:
                    args.append('%s=%s' % (field, value))
                else:
                    args.append(value)
            if include_attributes and node._attributes:
                for a in node._attributes:
                    try:
                        value = getattr(node, a)
                    except AttributeError:
                        continue
                    value, simple = _format(value, level)
                    allsimple = allsimple and simple
                    args.append('%s=%s' % (a, value))
            if allsimple and len(args) <= 3:
                return '%s(%s)' % (node.__class__.__name__, ', '.join(args)), not args
            return '%s(%s%s)' % (node.__class__.__name__, prefix, sep.join(args)), False
        elif isinstance(node, list):
            if not node:
                return '[]', True
            return '[%s%s]' % (prefix, sep.join(_format(x, level)[0] for x in node)), False
        return repr(node), True

    if not isinstance(node, AST):
        raise TypeError('expected AST, got %r' % node.__class__.__name__)
    if indent is not None and not isinstance(indent, str):
        indent = ' ' * indent
    return _format(node)[0]
136
137
def copy_location(new_node, old_node):
    """
    Copy the source position of *old_node* onto *new_node* and return
    *new_node*.

    Each of ``lineno``, ``col_offset``, ``end_lineno`` and ``end_col_offset``
    is copied only when both node classes declare it in ``_attributes`` and
    *old_node* actually has a value for it.
    """
    position_attrs = ('lineno', 'col_offset', 'end_lineno', 'end_col_offset')
    for name in position_attrs:
        if name not in old_node._attributes:
            continue
        if name not in new_node._attributes:
            continue
        if not hasattr(old_node, name):
            continue
        setattr(new_node, name, getattr(old_node, name))
    return new_node
148
149
def fix_missing_locations(node):
    """
    When you compile a node tree with compile(), the compiler expects lineno and
    col_offset attributes for every node that supports them.  This is rather
    tedious to fill in for generated nodes, so this helper adds these attributes
    recursively where not already set, by setting them to the values of the
    parent node.  It works recursively starting at *node*.

    The optional ``end_lineno`` / ``end_col_offset`` attributes are filled in
    when they are absent *or* ``None``.  The root falls back to line 1,
    column 0.  The tree is modified in place and *node* is returned.
    """
    def _fix(node, lineno, col_offset, end_lineno, end_col_offset):
        if 'lineno' in node._attributes:
            if not hasattr(node, 'lineno'):
                node.lineno = lineno
            else:
                lineno = node.lineno
        if 'end_lineno' in node._attributes:
            # End positions are optional and may be present but None; a
            # plain hasattr() test would never fill those in.
            if getattr(node, 'end_lineno', None) is None:
                node.end_lineno = end_lineno
            else:
                end_lineno = node.end_lineno
        if 'col_offset' in node._attributes:
            if not hasattr(node, 'col_offset'):
                node.col_offset = col_offset
            else:
                col_offset = node.col_offset
        if 'end_col_offset' in node._attributes:
            if getattr(node, 'end_col_offset', None) is None:
                node.end_col_offset = end_col_offset
            else:
                end_col_offset = node.end_col_offset
        for child in iter_child_nodes(node):
            _fix(child, lineno, col_offset, end_lineno, end_col_offset)
    _fix(node, 1, 0, 1, 0)
    return node
183
184
def increment_lineno(node, n=1):
    """
    Increment the line number and end line number of each node in the tree
    starting at *node* by *n*. This is useful to "move code" to a different
    location in a file.

    A node that lacks ``lineno`` is treated as if it had line 0.  An
    ``end_lineno`` that is ``None`` is left as ``None``.  ``TypeIgnore``
    nodes, which store ``lineno`` as a field rather than an attribute, are
    shifted too.  Returns *node*.
    """
    for child in walk(node):
        # TypeIgnore keeps lineno in _fields, not _attributes, so the
        # attribute checks below would skip it.
        if isinstance(child, TypeIgnore):
            child.lineno = getattr(child, 'lineno', 0) + n
            continue
        if 'lineno' in child._attributes:
            child.lineno = getattr(child, 'lineno', 0) + n
        if 'end_lineno' in child._attributes:
            end_lineno = getattr(child, 'end_lineno', 0)
            # end_lineno is optional; None + n would raise TypeError.
            if end_lineno is not None:
                child.end_lineno = end_lineno + n
    return node
197
198
def iter_fields(node):
    """
    Yield a ``(fieldname, value)`` tuple for each name in ``node._fields``
    that is actually set on *node*; unset fields are skipped silently.
    """
    absent = object()
    for name in node._fields:
        value = getattr(node, name, absent)
        if value is not absent:
            yield name, value
209
210
def iter_child_nodes(node):
    """
    Yield every direct child node of *node*: each field whose value is a
    node, plus each node item of a field whose value is a list.  Non-node
    list items and scalar fields are skipped.
    """
    for _name, value in iter_fields(node):
        if isinstance(value, list):
            yield from (item for item in value if isinstance(item, AST))
        elif isinstance(value, AST):
            yield value
223
224
def get_docstring(node, clean=True):
    """
    Return the docstring for the given node or None if no docstring can
    be found.  If the node provided does not have docstrings a TypeError
    will be raised.

    A docstring is the first statement of the body when that statement is
    an expression whose value is a string constant.

    If *clean* is `True`, all tabs are expanded to spaces and any whitespace
    that can be uniformly removed from the second line onwards is removed.
    """
    if not isinstance(node, (AsyncFunctionDef, FunctionDef, ClassDef, Module)):
        raise TypeError("%r can't have docstrings" % node.__class__.__name__)
    if not (node.body and isinstance(node.body[0], Expr)):
        return None
    first = node.body[0].value
    # Legacy Str nodes are Constant instances holding a str, so this single
    # check also covers them without relying on the deprecated alias.
    if not (isinstance(first, Constant) and isinstance(first.value, str)):
        return None
    text = first.value
    if clean:
        import inspect
        text = inspect.cleandoc(text)
    return text
249
250
251def _splitlines_no_ff(source):
252    """Split a string into lines ignoring form feed and other chars.
253
254    This mimics how the Python parser splits source code.
255    """
256    idx = 0
257    lines = []
258    next_line = ''
259    while idx < len(source):
260        c = source[idx]
261        next_line += c
262        idx += 1
263        # Keep \r\n together
264        if c == '\r' and idx < len(source) and source[idx] == '\n':
265            next_line += '\n'
266            idx += 1
267        if c in '\r\n':
268            lines.append(next_line)
269            next_line = ''
270
271    if next_line:
272        lines.append(next_line)
273    return lines
274
275
276def _pad_whitespace(source):
277    """Replace all chars except '\f\t' in a line with spaces."""
278    result = ''
279    for c in source:
280        if c in '\f\t':
281            result += c
282        else:
283            result += ' '
284    return result
285
286
def get_source_segment(source, node, *, padded=False):
    """Get source code segment of the *source* that generated *node*.

    If some location information (`lineno`, `end_lineno`, `col_offset`,
    or `end_col_offset`) is missing or ``None``, return None.

    If *padded* is `True`, the first line of a multi-line statement will
    be padded with spaces to match its original position.

    Column offsets are treated as UTF-8 byte offsets: each line is encoded
    before slicing and decoded afterwards.
    """
    try:
        lineno = node.lineno
        end_lineno = node.end_lineno
        col_offset = node.col_offset
        end_col_offset = node.end_col_offset
    except AttributeError:
        return None
    # The end_* attributes are optional and may be present but None;
    # arithmetic on them below would otherwise raise TypeError.
    if (lineno is None or end_lineno is None
            or col_offset is None or end_col_offset is None):
        return None
    lineno -= 1  # zero-based indexes into the list of lines
    end_lineno -= 1

    lines = _splitlines_no_ff(source)
    if end_lineno == lineno:
        return lines[lineno].encode()[col_offset:end_col_offset].decode()

    if padded:
        padding = _pad_whitespace(lines[lineno].encode()[:col_offset].decode())
    else:
        padding = ''

    first = padding + lines[lineno].encode()[col_offset:].decode()
    last = lines[end_lineno].encode()[:end_col_offset].decode()
    middle = lines[lineno + 1:end_lineno]
    return ''.join([first, *middle, last])
320
321
def walk(node):
    """
    Recursively yield all descendant nodes in the tree starting at *node*
    (including *node* itself), in no specified order.  This is useful if you
    only want to modify nodes in place and don't care about the context.

    This implementation happens to go breadth-first, queueing a node's
    children before the node itself is yielded; callers should not rely on
    the order.
    """
    from collections import deque
    queue = deque()
    queue.append(node)
    while queue:
        current = queue.popleft()
        children = iter_child_nodes(current)
        queue.extend(children)
        yield current
334
335
class NodeVisitor(object):
    """
    A node visitor base class that walks the abstract syntax tree and calls a
    visitor function for every node found.  This function may return a value
    which is forwarded by the `visit` method.

    This class is meant to be subclassed, with the subclass adding visitor
    methods.

    By default the visitor function for a node is ``'visit_'`` + the class
    name of the node, so a `Try` node is handled by `visit_Try`.  This
    behavior can be changed by overriding the `visit` method.  If no visitor
    function exists for a node, the `generic_visit` visitor is used instead.

    Don't use the `NodeVisitor` if you want to apply changes to nodes during
    traversal.  For this a special visitor exists (`NodeTransformer`) that
    allows modifications.
    """

    def visit(self, node):
        """Visit *node* via ``visit_<ClassName>``, else `generic_visit`."""
        handler = getattr(self, 'visit_' + node.__class__.__name__,
                          self.generic_visit)
        return handler(node)

    def generic_visit(self, node):
        """Visit every child node; used when no specific visitor exists."""
        for _name, value in iter_fields(node):
            if isinstance(value, AST):
                self.visit(value)
            elif isinstance(value, list):
                for item in value:
                    if isinstance(item, AST):
                        self.visit(item)

    def visit_Constant(self, node):
        """Route a Constant to a legacy ``visit_Num``/``visit_Str``/...
        method if the subclass still defines one (emitting a
        PendingDeprecationWarning); otherwise visit it generically."""
        value = node.value
        legacy = _const_node_type_names.get(type(value))
        if legacy is None:
            # Fall back to isinstance() for subclasses of the value types,
            # in the mapping's insertion order.
            for cls, name in _const_node_type_names.items():
                if isinstance(value, cls):
                    legacy = name
                    break
        if legacy is None:
            return self.generic_visit(node)
        method = 'visit_' + legacy
        missing = object()
        legacy_visitor = getattr(self, method, missing)
        if legacy_visitor is missing:
            return self.generic_visit(node)
        import warnings
        # stacklevel 2 attributes the warning to this method's caller.
        warnings.warn(f"{method} is deprecated; add visit_Constant",
                      PendingDeprecationWarning, 2)
        return legacy_visitor(node)
392
393
class NodeTransformer(NodeVisitor):
    """
    A :class:`NodeVisitor` subclass that walks the abstract syntax tree and
    allows modification of nodes.

    The `NodeTransformer` walks the AST and uses the return value of each
    visitor method to replace or remove the old node.  If a visitor method
    returns ``None``, the node is removed from its location (for a node held
    in a single field, that field is deleted from the parent); otherwise it
    is replaced with the return value.  Returning the original node leaves
    it in place.

    Here is an example transformer that rewrites all occurrences of name lookups
    (``foo``) to ``data['foo']``::

       class RewriteName(NodeTransformer):

           def visit_Name(self, node):
               return copy_location(Subscript(
                   value=Name(id='data', ctx=Load()),
                   slice=Index(value=Constant(value=node.id)),
                   ctx=node.ctx
               ), node)

    Keep in mind that if the node you're operating on has child nodes you must
    either transform the child nodes yourself or call the :meth:`generic_visit`
    method for the node first.

    For nodes that were part of a collection of statements (that applies to all
    statement nodes), the visitor may also return a list of nodes rather than
    just a single node; its items are spliced into the parent list.

    Usually you use the transformer like this::

       node = YourTransformer().visit(node)
    """

    def generic_visit(self, node):
        """Transform each child of *node* in place and return *node*."""
        for name, current in iter_fields(node):
            if isinstance(current, AST):
                replacement = self.visit(current)
                if replacement is None:
                    delattr(node, name)
                else:
                    setattr(node, name, replacement)
            elif isinstance(current, list):
                kept = []
                for item in current:
                    if not isinstance(item, AST):
                        kept.append(item)
                        continue
                    result = self.visit(item)
                    if result is None:
                        continue
                    if isinstance(result, AST):
                        kept.append(result)
                    else:
                        # A non-node result is an iterable of replacements.
                        kept.extend(result)
                # Slice assignment keeps the parent's list object identity.
                current[:] = kept
        return node
451
452
453# The following code is for backward compatibility.
454# It will be removed in future.
455
456def _getter(self):
457    return self.value
458
459def _setter(self, value):
460    self.value = value
461
462Constant.n = property(_getter, _setter)
463Constant.s = property(_getter, _setter)
464
class _ABC(type):
    """Metaclass for the deprecated node classes (Num, Str, ...).

    For classes listed in ``_const_types``, ``isinstance(x, cls)`` is true
    when *x* is a ``Constant`` whose ``value`` has one of the mapped types
    and none of the types excluded in ``_const_types_not``.
    """

    def __instancecheck__(cls, inst):
        if not isinstance(inst, Constant):
            return False
        if cls not in _const_types:
            # User subclasses fall back to ordinary isinstance semantics.
            return type.__instancecheck__(cls, inst)
        try:
            value = inst.value
        except AttributeError:
            return False
        excluded = _const_types_not.get(cls, ())
        return isinstance(value, _const_types[cls]) and not isinstance(value, excluded)
481
def _new(cls, *args, **kwargs):
    """
    Shared ``__new__`` for the deprecated node classes.

    Instantiating a class listed in ``_const_types`` returns a plain
    ``Constant`` built from the same arguments.  Any other class (e.g. a
    user subclass) is created normally via ``Constant.__new__``.

    Raises TypeError when one of ``cls._fields`` is given both positionally
    and by keyword (e.g. ``Num(1, n=2)``).  Without this check the keyword,
    applied through the ``n``/``s`` property, silently overwrites the
    positional value.
    """
    for key in kwargs:
        if key not in cls._fields:
            # Arbitrary extra keyword arguments are accepted by AST nodes.
            continue
        pos = cls._fields.index(key)
        if pos < len(args):
            raise TypeError(f"{cls.__name__} got multiple values for argument {key!r}")
    if cls in _const_types:
        return Constant(*args, **kwargs)
    return Constant.__new__(cls, *args, **kwargs)
486
class Num(Constant, metaclass=_ABC):
    """Deprecated alias for a numeric ``Constant``; instantiating it
    returns a ``Constant``."""
    __new__ = _new
    _fields = ('n',)
490
class Str(Constant, metaclass=_ABC):
    """Deprecated alias for a ``str`` ``Constant``; instantiating it
    returns a ``Constant``."""
    __new__ = _new
    _fields = ('s',)
494
class Bytes(Constant, metaclass=_ABC):
    """Deprecated alias for a ``bytes`` ``Constant``; instantiating it
    returns a ``Constant``."""
    __new__ = _new
    _fields = ('s',)
498
class NameConstant(Constant, metaclass=_ABC):
    """Deprecated alias for a ``None`` / ``True`` / ``False`` ``Constant``;
    instantiating it returns a ``Constant``."""
    __new__ = _new
501
class Ellipsis(Constant, metaclass=_ABC):
    """Deprecated alias for the ``...`` ``Constant``; instantiating this
    class itself returns ``Constant(...)``."""
    _fields = ()

    def __new__(cls, *args, **kwargs):
        if cls is not Ellipsis:
            # Subclasses are constructed as themselves.
            return Constant.__new__(cls, *args, **kwargs)
        return Constant(..., *args, **kwargs)
509
# Deprecated node class -> the Constant.value types it stands for.
_const_types = dict([
    (Num, (int, float, complex)),
    (Str, (str,)),
    (Bytes, (bytes,)),
    (NameConstant, (type(None), bool)),
    (Ellipsis, (type(...),)),
])
# Value types a class must reject even though the table above matches them:
# bool is a subclass of int, yet booleans belong to NameConstant.
_const_types_not = {Num: (bool,)}
# Value type -> legacy visitor-name suffix used by NodeVisitor.visit_Constant.
# Insertion order is significant for its isinstance() fallback; bool is
# listed before int because it is an int subclass.
_const_node_type_names = dict([
    (bool, 'NameConstant'),
    (type(None), 'NameConstant'),
    (int, 'Num'),
    (float, 'Num'),
    (complex, 'Num'),
    (str, 'Str'),
    (bytes, 'Bytes'),
    (type(...), 'Ellipsis'),
])
530