"""
    ast
    ~~~

    The `ast` module helps Python applications to process trees of the Python
    abstract syntax grammar.  The abstract syntax itself might change with
    each Python release; this module helps to find out programmatically what
    the current grammar looks like and allows modifications of it.

    An abstract syntax tree can be generated by passing `ast.PyCF_ONLY_AST` as
    a flag to the `compile()` builtin function or by using the `parse()`
    function from this module.  The result will be a tree of objects whose
    classes all inherit from `ast.AST`.

    A modified abstract syntax tree can be compiled into a Python code object
    using the built-in `compile()` function.

    Additionally, various helper functions are provided that make working
    with the trees simpler.  The main intention of the helper functions and
    this module in general is to provide an easy-to-use interface for
    libraries that work tightly with the Python syntax (template engines,
    for example).


    :copyright: Copyright 2008 by Armin Ronacher.
    :license: Python License.
"""
from _ast import *
28
29
def parse(source, filename='<unknown>', mode='exec', *,
          type_comments=False, feature_version=None):
    """
    Parse the source into an AST node.

    Equivalent to compile(source, filename, mode, PyCF_ONLY_AST).
    Pass type_comments=True to get back type comments where the syntax allows.

    *feature_version* may be None (passed on to compile() as -1), an int
    giving the minor version of a 3.x release, or a ``(major, minor)`` tuple
    whose major component must be 3.

    Raises ValueError if a tuple *feature_version* names a major version
    other than 3.
    """
    flags = PyCF_ONLY_AST
    if type_comments:
        flags |= PyCF_TYPE_COMMENTS
    if isinstance(feature_version, tuple):
        major, minor = feature_version  # Should be a 2-tuple.
        # Validate explicitly: an `assert` would be stripped under -O and a
        # non-3 major version would then be silently accepted.
        if major != 3:
            raise ValueError(f"Unsupported major version: {major}")
        feature_version = minor
    elif feature_version is None:
        feature_version = -1
    # Else it should be an int giving the minor version for 3.x.
    return compile(source, filename, mode, flags,
                   _feature_version=feature_version)
49
50
def literal_eval(node_or_string):
    """
    Safely evaluate an expression node or a string containing a Python
    expression.  The string or node provided may only consist of the following
    Python literal structures: strings, bytes, numbers, tuples, lists, dicts,
    sets, booleans, and None.

    A string is first parsed in 'eval' mode and an `Expression` node is
    unwrapped to its body.  Signed numbers (``-1``) and the sum or difference
    of a real number and a complex number (``1+2j``) are accepted as well.

    Raises ValueError for any node that is not one of the supported forms.
    """
    if isinstance(node_or_string, str):
        node_or_string = parse(node_or_string, mode='eval')
    if isinstance(node_or_string, Expression):
        node_or_string = node_or_string.body

    def _raise_malformed_node(node):
        raise ValueError(f'malformed node or string: {node!r}')

    def _convert_num(node):
        # Exact type comparison: subclasses such as bool are rejected here.
        if not isinstance(node, Constant) or type(node.value) not in (int, float, complex):
            _raise_malformed_node(node)
        return node.value

    def _convert_signed_num(node):
        # Accept a number optionally preceded by a unary + or -.
        if isinstance(node, UnaryOp) and isinstance(node.op, (UAdd, USub)):
            operand = _convert_num(node.operand)
            if isinstance(node.op, UAdd):
                return + operand
            else:
                return - operand
        return _convert_num(node)

    def _convert(node):
        if isinstance(node, Constant):
            return node.value
        elif isinstance(node, Tuple):
            return tuple(map(_convert, node.elts))
        elif isinstance(node, List):
            return list(map(_convert, node.elts))
        elif isinstance(node, Set):
            return set(map(_convert, node.elts))
        elif isinstance(node, Dict):
            if len(node.keys) != len(node.values):
                _raise_malformed_node(node)
            return dict(zip(map(_convert, node.keys),
                            map(_convert, node.values)))
        elif isinstance(node, BinOp) and isinstance(node.op, (Add, Sub)):
            # Only "real +/- complex" is evaluated; every other binary
            # operation falls through and is rejected below.
            left = _convert_signed_num(node.left)
            right = _convert_num(node.right)
            if isinstance(left, (int, float)) and isinstance(right, complex):
                if isinstance(node.op, Add):
                    return left + right
                else:
                    return left - right
        return _convert_signed_num(node)

    return _convert(node_or_string)
100
101
def dump(node, annotate_fields=True, include_attributes=False):
    """
    Return a formatted dump of the tree in node.  This is mainly useful for
    debugging purposes.  If annotate_fields is true (by default),
    the returned string will show the names and the values for fields.
    If annotate_fields is false, the result string will be more compact by
    omitting unambiguous field names.  Attributes such as line
    numbers and column offsets are not dumped by default.  If this is wanted,
    include_attributes can be set to true.

    Raises TypeError if *node* is not an AST instance.
    """
    def _format(node):
        if isinstance(node, AST):
            args = []
            keywords = annotate_fields
            for field in node._fields:
                try:
                    value = getattr(node, field)
                except AttributeError:
                    # A field was skipped, so the remaining values can no
                    # longer be identified by position: switch to name=value.
                    keywords = True
                else:
                    if keywords:
                        args.append('%s=%s' % (field, _format(value)))
                    else:
                        args.append(_format(value))
            if include_attributes and node._attributes:
                for a in node._attributes:
                    try:
                        args.append('%s=%s' % (a, _format(getattr(node, a))))
                    except AttributeError:
                        # Attributes that are not set are simply left out.
                        pass
            return '%s(%s)' % (node.__class__.__name__, ', '.join(args))
        elif isinstance(node, list):
            return '[%s]' % ', '.join(_format(x) for x in node)
        return repr(node)

    if not isinstance(node, AST):
        raise TypeError('expected AST, got %r' % node.__class__.__name__)
    return _format(node)
139
140
def copy_location(new_node, old_node):
    """
    Copy source location (`lineno`, `col_offset`, `end_lineno`, and `end_col_offset`
    attributes) from *old_node* to *new_node* if possible, and return *new_node*.

    An attribute is copied only when both node types declare it in
    ``_attributes`` and *old_node* actually has it set.
    """
    for attr in 'lineno', 'col_offset', 'end_lineno', 'end_col_offset':
        if (attr in old_node._attributes and attr in new_node._attributes
                and hasattr(old_node, attr)):
            setattr(new_node, attr, getattr(old_node, attr))
    return new_node
151
152
def fix_missing_locations(node):
    """
    When you compile a node tree with compile(), the compiler expects lineno and
    col_offset attributes for every node that supports them.  This is rather
    tedious to fill in for generated nodes, so this helper adds these attributes
    recursively where not already set, by setting them to the values of the
    parent node.  It works recursively starting at *node*.

    The root starts from line 1, column 0 for both the start and end
    position.  Returns *node*.
    """
    def _fix(node, lineno, col_offset, end_lineno, end_col_offset):
        if 'lineno' in node._attributes:
            if not hasattr(node, 'lineno'):
                node.lineno = lineno
            else:
                lineno = node.lineno
        if 'end_lineno' in node._attributes:
            # An end position of None is treated as missing too; a plain
            # hasattr() check would leave such a None in place.
            if getattr(node, 'end_lineno', None) is None:
                node.end_lineno = end_lineno
            else:
                end_lineno = node.end_lineno
        if 'col_offset' in node._attributes:
            if not hasattr(node, 'col_offset'):
                node.col_offset = col_offset
            else:
                col_offset = node.col_offset
        if 'end_col_offset' in node._attributes:
            if getattr(node, 'end_col_offset', None) is None:
                node.end_col_offset = end_col_offset
            else:
                end_col_offset = node.end_col_offset
        for child in iter_child_nodes(node):
            _fix(child, lineno, col_offset, end_lineno, end_col_offset)

    _fix(node, 1, 0, 1, 0)
    return node
186
187
def increment_lineno(node, n=1):
    """
    Increment the line number and end line number of each node in the tree
    starting at *node* by *n*. This is useful to "move code" to a different
    location in a file.

    Returns *node*.
    """
    for child in walk(node):
        # TypeIgnore keeps its line number in a field rather than in
        # _attributes, so the attribute checks below would skip it.
        if isinstance(child, TypeIgnore):
            child.lineno = getattr(child, 'lineno', 0) + n
            continue
        if 'lineno' in child._attributes:
            child.lineno = getattr(child, 'lineno', 0) + n
        if 'end_lineno' in child._attributes:
            end_lineno = getattr(child, 'end_lineno', 0)
            # An end line of None cannot be shifted; leave it as it is
            # instead of failing on None + n.
            if end_lineno is not None:
                child.end_lineno = end_lineno + n
    return node
200
201
def iter_fields(node):
    """
    Yield a tuple of ``(fieldname, value)`` for each field in ``node._fields``
    that is present on *node*.

    Fields named in ``_fields`` but not set on the instance are skipped.
    """
    for field in node._fields:
        try:
            yield field, getattr(node, field)
        except AttributeError:
            # The field is declared but has no value on this node.
            pass
212
213
def iter_child_nodes(node):
    """
    Yield all direct child nodes of *node*, that is, all fields that are nodes
    and all items of fields that are lists of nodes.

    List items that are not AST instances are skipped.
    """
    for _name, field in iter_fields(node):
        if isinstance(field, AST):
            yield field
        elif isinstance(field, list):
            for item in field:
                if isinstance(item, AST):
                    yield item
226
227
def get_docstring(node, clean=True):
    """
    Return the docstring for the given node or None if no docstring can
    be found.  If the node provided does not have docstrings a TypeError
    will be raised.

    If *clean* is `True`, the text is passed through `inspect.cleandoc()`.
    """
    if not isinstance(node, (AsyncFunctionDef, FunctionDef, ClassDef, Module)):
        raise TypeError("%r can't have docstrings" % node.__class__.__name__)
    if not (node.body and isinstance(node.body[0], Expr)):
        return None
    node = node.body[0].value
    # A string Constant covers everything the legacy `Str` check matched
    # (a Constant whose value is a str), so no separate test is needed.
    if isinstance(node, Constant) and isinstance(node.value, str):
        text = node.value
    else:
        return None
    if clean:
        import inspect
        text = inspect.cleandoc(text)
    return text
252
253
def _splitlines_no_ff(source):
    """Split a string into lines ignoring form feed and other chars.

    This mimics how the Python parser splits source code.

    Only ``\\n``, ``\\r`` and ``\\r\\n`` end a line, and each returned line
    keeps its terminator.  A trailing piece without a terminator is returned
    as the last line.
    """
    lines = []
    start = 0
    idx = 0
    size = len(source)
    while idx < size:
        c = source[idx]
        idx += 1
        if c == '\r':
            # Keep \r\n together
            if idx < size and source[idx] == '\n':
                idx += 1
        elif c != '\n':
            continue
        # Slice the finished line out in one step rather than growing a
        # string one character at a time.
        lines.append(source[start:idx])
        start = idx

    if start < size:
        lines.append(source[start:])
    return lines
277
278
def _pad_whitespace(source):
    r"""Replace all chars except '\f\t' in a line with spaces.

    Form feeds and tabs are kept as they are; every other character becomes
    a single space, so the result has the same length as *source*.
    """
    return ''.join(c if c in '\f\t' else ' ' for c in source)
288
289
def get_source_segment(source, node, *, padded=False):
    """Get source code segment of the *source* that generated *node*.

    If some location information (`lineno`, `end_lineno`, `col_offset`,
    or `end_col_offset`) is missing or is None, return None.

    If *padded* is `True`, the first line of a multi-line statement will
    be padded with spaces to match its original position.
    """
    try:
        lineno = node.lineno
        end_lineno = node.end_lineno
        col_offset = node.col_offset
        end_col_offset = node.end_col_offset
    except AttributeError:
        return None
    # A position that is present but None is as unusable as a missing one;
    # without this check the arithmetic below would raise TypeError.
    if (lineno is None or end_lineno is None
            or col_offset is None or end_col_offset is None):
        return None

    # Switch to 0-based line indexes.
    lineno -= 1
    end_lineno -= 1

    lines = _splitlines_no_ff(source)
    # Column offsets are applied to the encoded form of each line, so every
    # line is encoded before slicing and decoded afterwards.
    if end_lineno == lineno:
        return lines[lineno].encode()[col_offset:end_col_offset].decode()

    if padded:
        padding = _pad_whitespace(lines[lineno].encode()[:col_offset].decode())
    else:
        padding = ''

    first = padding + lines[lineno].encode()[col_offset:].decode()
    last = lines[end_lineno].encode()[:end_col_offset].decode()
    lines = lines[lineno+1:end_lineno]

    lines.insert(0, first)
    lines.append(last)
    return ''.join(lines)
323
324
def walk(node):
    """
    Recursively yield all descendant nodes in the tree starting at *node*
    (including *node* itself), in no specified order.  This is useful if you
    only want to modify nodes in place and don't care about the context.
    """
    from collections import deque
    todo = deque([node])
    while todo:
        node = todo.popleft()
        # Queue the children before handing the node to the caller.
        todo.extend(iter_child_nodes(node))
        yield node
337
338
class NodeVisitor(object):
    """
    A node visitor base class that walks the abstract syntax tree and calls a
    visitor function for every node found.  This function may return a value
    which is forwarded by the `visit` method.

    This class is meant to be subclassed, with the subclass adding visitor
    methods.

    Per default the visitor functions for the nodes are ``'visit_'`` +
    class name of the node.  So a `Name` node visit function would
    be `visit_Name`.  This behavior can be changed by overriding
    the `visit` method.  If no visitor function exists for a node,
    the `generic_visit` visitor is used instead.

    Don't use the `NodeVisitor` if you want to apply changes to nodes during
    traversing.  For this a special visitor exists (`NodeTransformer`) that
    allows modifications.
    """

    def visit(self, node):
        """Visit a node and return whatever its visitor method returns."""
        method = 'visit_' + node.__class__.__name__
        visitor = getattr(self, method, self.generic_visit)
        return visitor(node)

    def generic_visit(self, node):
        """Called if no explicit visitor function exists for a node.

        Visits every AST child of *node*, including AST items inside list
        fields.  Return values of the child visits are discarded.
        """
        for field, value in iter_fields(node):
            if isinstance(value, list):
                for item in value:
                    if isinstance(item, AST):
                        self.visit(item)
            elif isinstance(value, AST):
                self.visit(value)

    def visit_Constant(self, node):
        """Dispatch a Constant to a legacy ``visit_<OldName>`` method if any.

        The legacy name is looked up from the type of ``node.value``.  If the
        visitor defines that method, it is called with a
        PendingDeprecationWarning; otherwise the node gets `generic_visit`.
        """
        value = node.value
        type_name = _const_node_type_names.get(type(value))
        if type_name is None:
            # No exact type match: fall back to an isinstance() scan so that
            # subclasses of the listed types are recognised as well.
            for cls, name in _const_node_type_names.items():
                if isinstance(value, cls):
                    type_name = name
                    break
        if type_name is not None:
            method = 'visit_' + type_name
            try:
                visitor = getattr(self, method)
            except AttributeError:
                pass
            else:
                import warnings
                warnings.warn(f"{method} is deprecated; add visit_Constant",
                              PendingDeprecationWarning, 2)
                return visitor(node)
        return self.generic_visit(node)
395
396
class NodeTransformer(NodeVisitor):
    """
    A :class:`NodeVisitor` subclass that walks the abstract syntax tree and
    allows modification of nodes.

    The `NodeTransformer` will walk the AST and use the return value of the
    visitor methods to replace or remove the old node.  If the return value of
    the visitor method is ``None``, the node will be removed from its location,
    otherwise it is replaced with the return value.  The return value may be the
    original node in which case no replacement takes place.

    Here is an example transformer that rewrites all occurrences of name lookups
    (``foo``) to ``data['foo']``::

       class RewriteName(NodeTransformer):

           def visit_Name(self, node):
               return Subscript(
                   value=Name(id='data', ctx=Load()),
                   slice=Index(value=Constant(value=node.id)),
                   ctx=node.ctx
               )

    Keep in mind that if the node you're operating on has child nodes you must
    either transform the child nodes yourself or call the :meth:`generic_visit`
    method for the node first.

    For nodes that were part of a collection of statements (that applies to all
    statement nodes), the visitor may also return a list of nodes rather than
    just a single node.

    Usually you use the transformer like this::

       node = YourTransformer().visit(node)
    """

    def generic_visit(self, node):
        """Visit the children of *node*, applying replacements in place.

        List fields are rebuilt from the visit results (``None`` drops an
        item, a non-AST result is spliced in with ``extend``).  A single AST
        field is deleted when its visit returns ``None`` and reassigned
        otherwise.  Returns *node*.
        """
        for field, old_value in iter_fields(node):
            if isinstance(old_value, list):
                new_values = []
                for value in old_value:
                    if isinstance(value, AST):
                        value = self.visit(value)
                        if value is None:
                            continue
                        elif not isinstance(value, AST):
                            new_values.extend(value)
                            continue
                    new_values.append(value)
                # Update the existing list object in place.
                old_value[:] = new_values
            elif isinstance(old_value, AST):
                new_node = self.visit(old_value)
                if new_node is None:
                    delattr(node, field)
                else:
                    setattr(node, field, new_node)
        return node
454
455
# The following code is for backward compatibility.
# It will be removed in the future.

def _getter(self):
    """Return the node's ``value`` (read side of the legacy aliases)."""
    return self.value

def _setter(self, value):
    """Store *value* as the node's ``value`` (write side of the aliases)."""
    self.value = value

# ``n`` and ``s`` both read and write the ``value`` attribute of a Constant.
Constant.n = property(_getter, _setter)
Constant.s = property(_getter, _setter)
467
class _ABC(type):
    """Metaclass giving the legacy constant classes a value-based isinstance().

    For a class listed in ``_const_types``, an object is an instance when it
    is a Constant whose ``value`` has one of the listed types and none of the
    types excluded for that class in ``_const_types_not``.
    """

    def __instancecheck__(cls, inst):
        if not isinstance(inst, Constant):
            return False
        if cls in _const_types:
            try:
                value = inst.value
            except AttributeError:
                # A Constant without a value matches no legacy class.
                return False
            else:
                return (
                    isinstance(value, _const_types[cls]) and
                    not isinstance(value, _const_types_not.get(cls, ()))
                )
        # Classes not listed in _const_types use the normal check.
        return type.__instancecheck__(cls, inst)
484
def _new(cls, *args, **kwargs):
    """Shared ``__new__`` for the legacy constant node classes.

    Raises TypeError when a keyword names a field that is already supplied
    positionally.  For a class listed in ``_const_types`` a plain Constant is
    built and returned; any other class is created through
    ``Constant.__new__``.
    """
    for key in kwargs:
        if key not in cls._fields:
            # arbitrary keyword arguments are accepted
            continue
        pos = cls._fields.index(key)
        if pos < len(args):
            raise TypeError(f"{cls.__name__} got multiple values for argument {key!r}")
    if cls in _const_types:
        return Constant(*args, **kwargs)
    return Constant.__new__(cls, *args, **kwargs)
496
# Legacy node classes.  Creating one of them yields a Constant (see _new and
# Ellipsis.__new__), and isinstance() against them is decided by _ABC.

class Num(Constant, metaclass=_ABC):
    _fields = ('n',)
    __new__ = _new

class Str(Constant, metaclass=_ABC):
    _fields = ('s',)
    __new__ = _new

class Bytes(Constant, metaclass=_ABC):
    _fields = ('s',)
    __new__ = _new

class NameConstant(Constant, metaclass=_ABC):
    __new__ = _new

class Ellipsis(Constant, metaclass=_ABC):
    _fields = ()

    def __new__(cls, *args, **kwargs):
        # Ellipsis has no fields, so ``...`` is supplied as the value.
        if cls is Ellipsis:
            return Constant(..., *args, **kwargs)
        return Constant.__new__(cls, *args, **kwargs)
519
# Value types that make a Constant count as an instance of each legacy class.
_const_types = {
    Num: (int, float, complex),
    Str: (str,),
    Bytes: (bytes,),
    NameConstant: (type(None), bool),
    Ellipsis: (type(...),),
}
# Value types excluded even though they match _const_types above.
_const_types_not = {
    Num: (bool,),
}
# Legacy class name for each constant value type; insertion order matters
# to code that scans this mapping with isinstance().
_const_node_type_names = {
    bool: 'NameConstant',  # should be before int
    type(None): 'NameConstant',
    int: 'Num',
    float: 'Num',
    complex: 'Num',
    str: 'Str',
    bytes: 'Bytes',
    type(...): 'Ellipsis',
}
540