• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
"""
    ast
    ~~~

    The `ast` module helps Python applications to process trees of the Python
    abstract syntax grammar.  The abstract syntax itself might change with
    each Python release; this module helps to find out programmatically what
    the current grammar looks like and allows modifications of it.

    An abstract syntax tree can be generated by passing `ast.PyCF_ONLY_AST` as
    a flag to the `compile()` builtin function or by using the `parse()`
    function from this module.  The result will be a tree of objects whose
    classes all inherit from `ast.AST`.

    A modified abstract syntax tree can be compiled into a Python code object
    using the built-in `compile()` function.

    Additionally, various helper functions are provided that make working with
    the trees simpler.  The main intention of the helper functions and this
    module in general is to provide an easy-to-use interface for libraries
    that work tightly with the Python syntax (template engines, for example).


    :copyright: Copyright 2008 by Armin Ronacher.
    :license: Python License.
"""
27from _ast import *
28
29
def parse(source, filename='<unknown>', mode='exec'):
    """
    Build an abstract syntax tree from *source* and return its root node.

    This is a thin wrapper around the ``compile()`` builtin: *source*,
    *filename* and *mode* are forwarded unchanged, and the
    ``PyCF_ONLY_AST`` flag makes ``compile()`` return the tree instead of
    a code object.
    """
    ast_only_flag = PyCF_ONLY_AST
    tree = compile(source, filename, mode, ast_only_flag)
    return tree
36
37
# Value types accepted as operands of the unary/binary +/- handled below.
_NUM_TYPES = (int, float, complex)


def literal_eval(node_or_string):
    """
    Evaluate a literal expression without running arbitrary code.

    *node_or_string* is either source text or an AST node.  Text is parsed
    in ``'eval'`` mode, and an ``Expression`` wrapper is replaced by its
    body before conversion.  The structures that are converted are strings,
    bytes, numbers, tuples, lists, dicts, sets, booleans and None, plus
    unary ``+``/``-`` applied to a number and binary ``+``/``-`` between
    two numbers.  Any other node raises ValueError.
    """
    target = node_or_string
    if isinstance(target, str):
        target = parse(target, mode='eval')
    if isinstance(target, Expression):
        target = target.body

    def _convert(node):
        # Leaf nodes carry their Python value directly.
        if isinstance(node, Constant):
            return node.value
        if isinstance(node, (Str, Bytes)):
            return node.s
        if isinstance(node, Num):
            return node.n
        # Containers are rebuilt from their converted members.
        if isinstance(node, Tuple):
            return tuple(_convert(elt) for elt in node.elts)
        if isinstance(node, List):
            return [_convert(elt) for elt in node.elts]
        if isinstance(node, Set):
            return {_convert(elt) for elt in node.elts}
        if isinstance(node, Dict):
            pairs = []
            for key_node, value_node in zip(node.keys, node.values):
                pairs.append((_convert(key_node), _convert(value_node)))
            return dict(pairs)
        if isinstance(node, NameConstant):
            return node.value
        # Signs and +/- arithmetic are only accepted on numeric operands;
        # anything else falls through to the error below.
        if isinstance(node, UnaryOp) and isinstance(node.op, (UAdd, USub)):
            value = _convert(node.operand)
            if isinstance(value, _NUM_TYPES):
                return +value if isinstance(node.op, UAdd) else -value
        elif isinstance(node, BinOp) and isinstance(node.op, (Add, Sub)):
            lhs = _convert(node.left)
            rhs = _convert(node.right)
            if isinstance(lhs, _NUM_TYPES) and isinstance(rhs, _NUM_TYPES):
                return lhs + rhs if isinstance(node.op, Add) else lhs - rhs
        raise ValueError('malformed node or string: ' + repr(node))

    return _convert(target)
86
87
def dump(node, annotate_fields=True, include_attributes=False):
    """
    Return a formatted dump of the tree in *node*, mainly for debugging.

    Each AST node is rendered as ``ClassName(arg, ...)``, lists as
    ``[item, ...]`` and every other value with ``repr()``.

    *annotate_fields* -- when true (the default) each field is written as
    ``name=value``.  This makes the result impossible to evaluate, so pass
    False if evaluation is wanted; fields are then written positionally.

    *include_attributes* -- when true, the attributes listed in the node's
    ``_attributes`` (such as line numbers and column offsets) are appended
    as ``name=value``.  Attributes that are not set on a node are skipped,
    so trees containing generated nodes without location information can
    still be dumped.

    Raises TypeError if *node* is not an AST instance.
    """
    def _format(node):
        if isinstance(node, AST):
            fields = [(name, _format(value))
                      for name, value in iter_fields(node)]
            if annotate_fields:
                args = ['%s=%s' % field for field in fields]
            else:
                args = [value for _name, value in fields]
            if include_attributes and node._attributes:
                for name in node._attributes:
                    try:
                        value = getattr(node, name)
                    except AttributeError:
                        # Generated nodes frequently lack location info;
                        # leave the attribute out rather than fail.
                        continue
                    args.append('%s=%s' % (name, _format(value)))
            # Joining one argument list keeps the separator uniform, even
            # for nodes that have attributes but no fields.
            return '%s(%s)' % (node.__class__.__name__, ', '.join(args))
        if isinstance(node, list):
            return '[%s]' % ', '.join(_format(item) for item in node)
        return repr(node)

    if not isinstance(node, AST):
        raise TypeError('expected AST, got %r' % node.__class__.__name__)
    return _format(node)
116
117
def copy_location(new_node, old_node):
    """
    Transfer the source position of *old_node* onto *new_node*.

    ``lineno`` and ``col_offset`` are each copied only when both node types
    list the attribute in ``_attributes`` and *old_node* actually has a
    value for it.  *new_node* is returned.
    """
    for attr in ('lineno', 'col_offset'):
        if attr not in old_node._attributes:
            continue
        if attr not in new_node._attributes:
            continue
        if not hasattr(old_node, attr):
            continue
        setattr(new_node, attr, getattr(old_node, attr))
    return new_node
128
129
def fix_missing_locations(node):
    """
    Fill in absent ``lineno`` and ``col_offset`` attributes below *node*.

    compile() expects both attributes on every node that supports them,
    which is tedious to arrange for generated nodes.  This helper walks the
    tree recursively: a node that already has a value keeps it and passes
    it down to its children, while a node without one receives the value
    inherited from its parent.  The root inherits line 1, column 0.
    *node* is returned.
    """
    def _resolve(target, attr, inherited):
        # Return the value the children of *target* should inherit for
        # *attr*, assigning the inherited one to *target* if it has none.
        if attr not in target._attributes:
            return inherited
        if hasattr(target, attr):
            return getattr(target, attr)
        setattr(target, attr, inherited)
        return inherited

    def _propagate(current, line, column):
        line = _resolve(current, 'lineno', line)
        column = _resolve(current, 'col_offset', column)
        for child in iter_child_nodes(current):
            _propagate(child, line, column)

    _propagate(node, 1, 0)
    return node
153
154
def increment_lineno(node, n=1):
    """
    Shift the line number of every node in the tree rooted at *node* by *n*.

    Useful for "moving" code to a different place in a file.  Only nodes
    whose type lists ``lineno`` in ``_attributes`` are touched; a node of
    such a type with no line number yet is treated as being on line 0.
    *node* is returned.
    """
    located = (child for child in walk(node)
               if 'lineno' in child._attributes)
    for child in located:
        current_line = getattr(child, 'lineno', 0)
        child.lineno = current_line + n
    return node
164
165
def iter_fields(node):
    """
    Generate ``(fieldname, value)`` pairs for *node*.

    Names are taken from ``node._fields`` in order; a field that is not
    present on *node* is silently skipped.
    """
    for name in node._fields:
        try:
            value = getattr(node, name)
        except AttributeError:
            # Field declared on the type but not set on this instance.
            continue
        yield name, value
176
177
def iter_child_nodes(node):
    """
    Generate the direct children of *node*.

    A child is any field value that is itself an AST node, or any AST node
    found inside a field whose value is a list.  Everything else (strings,
    numbers, None, non-node list items) is ignored.
    """
    for _name, value in iter_fields(node):
        # Treat a single value as a one-item group so both shapes share
        # the same filtering loop.
        group = value if isinstance(value, list) else (value,)
        for candidate in group:
            if isinstance(candidate, AST):
                yield candidate
190
191
def get_docstring(node, clean=True):
    """
    Return the docstring of *node*, or None when it has none.

    *node* must be a module, class, function or async function definition;
    any other node raises TypeError.  The docstring is the string literal
    that forms the first statement of the body.  With *clean* true (the
    default) the text is passed through ``inspect.cleandoc()`` before it is
    returned.
    """
    documentable = (AsyncFunctionDef, FunctionDef, ClassDef, Module)
    if not isinstance(node, documentable):
        raise TypeError("%r can't have docstrings" % node.__class__.__name__)
    body = node.body
    if not body or not isinstance(body[0], Expr):
        return None
    first = body[0].value
    if isinstance(first, Str):
        text = first.s
    elif isinstance(first, Constant) and isinstance(first.value, str):
        text = first.value
    else:
        # The first statement is an expression, but not a string literal.
        return None
    if not clean:
        return text
    import inspect
    return inspect.cleandoc(text)
213
214
def walk(node):
    """
    Generate *node* and every node beneath it.

    No particular order is promised to callers.  Handy when nodes are only
    modified in place and their context does not matter.
    """
    from collections import deque
    pending = deque((node,))
    while pending:
        current = pending.popleft()
        # Queue the children before handing the node out, so changes the
        # consumer makes to it afterwards do not alter the traversal.
        for child in iter_child_nodes(current):
            pending.append(child)
        yield current
227
228
class NodeVisitor(object):
    """
    Base class for read-only traversal of an abstract syntax tree.

    `visit` dispatches each node to the method named ``'visit_'`` plus the
    class name of the node (a `Name` node goes to `visit_Name`) and
    forwards whatever that method returns.  When the instance has no such
    method, `generic_visit` is used instead.  Override `visit` to change
    this dispatch rule.

    The class is meant to be subclassed, with the subclass adding the
    ``visit_*`` methods it needs.

    Don't use `NodeVisitor` to change nodes while traversing; the
    `NodeTransformer` subclass exists for that purpose.
    """

    def visit(self, node):
        """Dispatch *node* to its visitor method and return the result."""
        handler_name = 'visit_%s' % node.__class__.__name__
        handler = getattr(self, handler_name, self.generic_visit)
        return handler(node)

    def generic_visit(self, node):
        """Fallback visitor: call `visit` on every child node of *node*."""
        for _field, value in iter_fields(node):
            # A field holds either a single value or a list of values;
            # only the AST nodes among them are visited.
            children = value if isinstance(value, list) else (value,)
            for child in children:
                if isinstance(child, AST):
                    self.visit(child)
265
class NodeTransformer(NodeVisitor):
    """
    A :class:`NodeVisitor` subclass that may rewrite the tree it walks.

    The value returned by a visitor method decides what happens to the
    visited node: ``None`` removes it from its location, any other value
    replaces it.  Returning the original node leaves the tree unchanged at
    that point.

    The following transformer turns every name lookup (``foo``) into
    ``data['foo']``::

       class RewriteName(NodeTransformer):

           def visit_Name(self, node):
               return copy_location(Subscript(
                   value=Name(id='data', ctx=Load()),
                   slice=Index(value=Str(s=node.id)),
                   ctx=node.ctx
               ), node)

    A visitor method for a node that has children must either transform
    those children itself or call :meth:`generic_visit` on the node first.

    For nodes that live in a collection of statements (which is the case
    for all statement nodes), the visitor may return a list of nodes
    instead of a single node.

    Typical usage::

       node = YourTransformer().visit(node)
    """

    def generic_visit(self, node):
        for field, old_value in iter_fields(node):
            if isinstance(old_value, list):
                replacement = []
                for item in old_value:
                    if not isinstance(item, AST):
                        # Non-node items are carried over untouched.
                        replacement.append(item)
                        continue
                    result = self.visit(item)
                    if result is None:
                        # The visitor asked for this node to be dropped.
                        continue
                    if isinstance(result, AST):
                        replacement.append(result)
                    else:
                        # Any non-node result is an iterable of nodes.
                        replacement.extend(result)
                # Update in place so the field keeps the same list object.
                old_value[:] = replacement
            elif isinstance(old_value, AST):
                result = self.visit(old_value)
                if result is None:
                    delattr(node, field)
                else:
                    setattr(node, field, result)
        return node
323