1"""Utility functions, node construction macros, etc."""
2# Author: Collin Winter
4from itertools import islice
6# Local imports
7from .pgen2 import token
8from .pytree import Leaf, Node
9from .pygram import python_symbols as syms
10from . import patcomp
14### Common node-construction "macros"
def KeywordArg(keyword, value):
    """Build an ``argument`` node spelling ``keyword=value``.

    *keyword* and *value* are inserted as given (they are not cloned); the
    ``=`` leaf is created fresh.
    """
    equals_sign = Leaf(token.EQUAL, u"=")
    return Node(syms.argument, [keyword, equals_sign, value])
def LParen():
    """Return a new left-parenthesis ``(`` leaf of type LPAR."""
    paren = Leaf(token.LPAR, u"(")
    return paren
def RParen():
    """Return a new right-parenthesis ``)`` leaf of type RPAR."""
    paren = Leaf(token.RPAR, u")")
    return paren
def Assign(target, source):
    """Build an assignment statement ``target = source``.

    Either argument may be a single node or a list of nodes. A single
    *source* node has its prefix overwritten with one space so the output
    reads ``target = source``; a list *source* is used unchanged.

    NOTE(review): the result is typed ``syms.atom`` rather than
    ``syms.expr_stmt``, so checks in this module that look for
    ``expr_stmt`` assignments will not recognise it. Kept as-is for
    compatibility.
    """
    targets = target if isinstance(target, list) else [target]
    if isinstance(source, list):
        sources = source
    else:
        source.prefix = u" "
        sources = [source]
    equals_sign = Leaf(token.EQUAL, u"=", prefix=u" ")
    return Node(syms.atom, targets + [equals_sign] + sources)
def Name(name, prefix=None):
    """Return a NAME leaf whose value is *name*, with optional *prefix*."""
    leaf = Leaf(token.NAME, name, prefix=prefix)
    return leaf
def Attr(obj, attr):
    """Return the pieces of ``obj.attr`` as a two-element list.

    The list is ``[obj, trailer('.', attr)]``. It is a plain Python list,
    not a single node; presumably callers splice it into a larger child
    list.
    """
    dotted = Node(syms.trailer, [Dot(), attr])
    return [obj, dotted]
def Comma():
    """A comma leaf: a new ``,`` token of type COMMA."""
    comma = Leaf(token.COMMA, u",")
    return comma
def Dot():
    """A period leaf: a new ``.`` token of type DOT."""
    period = Leaf(token.DOT, u".")
    return period
def ArgList(args, lparen=LParen(), rparen=RParen()):
    """A parenthesised argument list (a ``trailer`` node), used by Call().

    The default paren leaves are created once, when the function is
    defined. They are cloned on every call, so the shared defaults never
    end up inside a tree. If *args* is non-empty it is wrapped in an
    ``arglist`` node placed between the parentheses; otherwise the trailer
    is just ``()``.
    """
    trailer = Node(syms.trailer, [lparen.clone(), rparen.clone()])
    if not args:
        return trailer
    trailer.insert_child(1, Node(syms.arglist, args))
    return trailer
def Call(func_name, args=None, prefix=None):
    """Build the call expression ``func_name(args)`` as a ``power`` node.

    *args* is passed straight to ArgList(). If *prefix* is given, it
    replaces the new node's prefix.
    """
    call = Node(syms.power, [func_name, ArgList(args)])
    if prefix is None:
        return call
    call.prefix = prefix
    return call
def Newline():
    """A newline literal: a NEWLINE leaf containing a single line break."""
    line_break = Leaf(token.NEWLINE, u"\n")
    return line_break
def BlankLine():
    """A blank line: a NEWLINE-typed leaf whose value is the empty string."""
    empty = Leaf(token.NEWLINE, u"")
    return empty
def Number(n, prefix=None):
    """Return a NUMBER leaf whose value is *n*, with optional *prefix*."""
    literal = Leaf(token.NUMBER, n, prefix=prefix)
    return literal
def Subscript(index_node):
    """A numeric or string subscript: a ``trailer`` node ``[index_node]``.

    The bracket leaves are typed LSQB/RSQB. Those are the token types the
    tokenizer assigns to ``[`` and ``]``; LBRACE/RBRACE are the types for
    ``{`` and ``}``. Typing them correctly keeps type-based checks
    consistent with parsed code.
    """
    return Node(syms.trailer, [Leaf(token.LSQB, u"["),
                               index_node,
                               Leaf(token.RSQB, u"]")])
def String(string, prefix=None):
    """Return a STRING leaf whose value is *string*, with optional *prefix*."""
    literal = Leaf(token.STRING, string, prefix=prefix)
    return literal
def ListComp(xp, fp, it, test=None):
    """A list comprehension of the form [xp for fp in it if test].

    If test is None, the "if test" part is omitted.

    The prefixes of *xp*, *fp*, *it* and (if given) *test* are overwritten
    in place, so the result is spaced as ``[xp for fp in it if test]``.
    The brackets are LSQB/RSQB leaves, the token types the tokenizer
    assigns to ``[`` and ``]`` (LBRACE/RBRACE are ``{`` and ``}``).
    """
    xp.prefix = u""
    fp.prefix = u" "
    it.prefix = u" "
    for_leaf = Leaf(token.NAME, u"for")
    for_leaf.prefix = u" "
    in_leaf = Leaf(token.NAME, u"in")
    in_leaf.prefix = u" "
    inner_args = [for_leaf, fp, in_leaf, it]
    # Explicit None check, matching the documented contract.
    if test is not None:
        test.prefix = u" "
        if_leaf = Leaf(token.NAME, u"if")
        if_leaf.prefix = u" "
        inner_args.append(Node(syms.comp_if, [if_leaf, test]))
    inner = Node(syms.listmaker, [xp, Node(syms.comp_for, inner_args)])
    return Node(syms.atom,
                [Leaf(token.LSQB, u"["),
                 inner,
                 Leaf(token.RSQB, u"]")])
def FromImport(package_name, name_leafs):
    """Return an ``import_from`` node for ``from package_name import names``.

    Each leaf in *name_leafs* is first detached from its current tree, then
    used directly as a child of a new ``import_as_names`` node. The name
    leaves' prefixes are not adjusted here.

    NOTE(review): a dotted *package_name* (e.g. 'foo.bar') is emitted as a
    single NAME leaf and has not been tested; use at your own peril.
    """
    for leaf in name_leafs:
        # Pull the leaf out of its old tree before re-using it below.
        leaf.remove()

    from_keyword = Leaf(token.NAME, u"from")
    package = Leaf(token.NAME, package_name, prefix=u" ")
    import_keyword = Leaf(token.NAME, u"import", prefix=u" ")
    names = Node(syms.import_as_names, name_leafs)
    return Node(syms.import_from,
                [from_keyword, package, import_keyword, names])
134### Determine whether a node represents a given literal
def is_tuple(node):
    """Does the node represent a tuple literal?

    True for an empty ``()`` node. Also true for any three-child node of
    the form ``( <Node> )``.

    NOTE(review): the second case also accepts a parenthesised non-tuple
    expression such as ``(a + b)``.
    """
    if not isinstance(node, Node):
        return False
    kids = node.children
    if kids == [LParen(), RParen()]:
        return True
    if len(kids) != 3:
        return False
    first, middle, last = kids
    return (isinstance(first, Leaf)
            and isinstance(middle, Node)
            and isinstance(last, Leaf)
            and first.value == u"("
            and last.value == u")")
def is_list(node):
    """Does the node represent a list literal?

    True when *node* is a Node with at least two children, whose first
    child is a ``[`` leaf and whose last child is a ``]`` leaf.

    NOTE(review): any bracketed node passes this test, including subscript
    trailers such as ``[i]``.
    """
    if not isinstance(node, Node) or len(node.children) <= 1:
        return False
    opening, closing = node.children[0], node.children[-1]
    return (isinstance(opening, Leaf) and isinstance(closing, Leaf)
            and opening.value == u"[" and closing.value == u"]")
160### Misc
def parenthesize(node):
    """Wrap *node* in fresh parentheses, returning an ``atom`` node."""
    wrapped = [LParen(), node, RParen()]
    return Node(syms.atom, wrapped)
# Builtin call names treated as consuming their iterable argument.
# NOTE(review): presumably consulted by fixers to decide whether a list()
# wrapper can be dropped -- confirm at the call sites.
consuming_calls = set(("all", "any", "enumerate", "list", "max", "min",
                       "set", "sorted", "sum", "tuple"))
def attr_chain(obj, attr):
    """Follow an attribute chain.

    If you have a chain of objects where a.foo -> b, b.foo -> c, etc,
    use this to iterate over all objects in the chain. *obj* itself is not
    yielded. Iteration stops at the first falsy value of
    getattr(x, attr), typically None.

    Args:
        obj: the starting object
        attr: the name of the chaining attribute

    Yields:
        Each successive object in the chain.
    """
    link = getattr(obj, attr)
    while link:
        yield link
        link = getattr(link, attr)
# Pattern sources used by in_special_context(). On its first call these
# module-level names are rebound to the compiled patterns and pats_built is
# set to True.
#   p0: node is the iterable of a for loop / comprehension.
#   p1: node is the sole argument of a consuming call (or str.join).
#   p2: node is the first argument of sorted()/enumerate().
p0 = """for_stmt< 'for' any 'in' node=any ':' any* >
        | comp_for< 'for' any 'in' node=any any* >
     """
p1 = """
power<
    ( 'iter' | 'list' | 'tuple' | 'sorted' | 'set' | 'sum' |
      'any' | 'all' | 'enumerate' | (any* trailer< '.' 'join' >) )
    trailer< '(' node=any ')' >
    any*
>
"""
p2 = """
power<
    ( 'sorted' | 'enumerate' )
    trailer< '(' arglist<node=any any*> ')' >
    any*
>
"""
pats_built = False
def in_special_context(node):
    """ Returns true if node is in an environment where all that is required
        of it is being iterable (ie, it doesn't matter if it returns a list
        or an iterator).
        See test_map_nochange in test_fixers.py for some examples and tests.

        On the first call, the pattern sources p0, p1 and p2 are compiled
        and the module-level names are rebound to the compiled patterns.
        Pattern k is then matched against the k-th ancestor of *node*: p0
        against the parent, p1 against the grandparent, p2 against the
        great-grandparent. A match only counts if the captured ``node`` is
        *node* itself.
        """
    global p0, p1, p2, pats_built
    if not pats_built:
        p0 = patcomp.compile_pattern(p0)
        p1 = patcomp.compile_pattern(p1)
        p2 = patcomp.compile_pattern(p2)
        pats_built = True
    ancestors = attr_chain(node, "parent")
    for compiled, ancestor in zip((p0, p1, p2), ancestors):
        captures = {}
        matched = compiled.match(ancestor, captures)
        if matched and captures["node"] is node:
            return True
    return False
def is_probably_builtin(node):
    """
    Check that something isn't an attribute or function name etc.

    Returns False when *node* is:
      * preceded by a ``.`` (an attribute lookup),
      * a direct child of a ``def``/``class`` statement,
      * the first child of an ``expr_stmt`` (an assignment target),
      * inside ``parameters``, or in a ``typedargslist`` either after a
        comma or as its first child (an argument name).
    Returns True otherwise. Assumes *node* has a parent.
    """
    prev = node.prev_sibling
    if prev is not None and prev.type == token.DOT:
        # Attribute lookup: ``x.name``.
        return False
    parent = node.parent
    kind = parent.type
    if kind in (syms.funcdef, syms.classdef):
        return False
    if kind == syms.expr_stmt and parent.children[0] is node:
        # Assignment target: ``name = ...``.
        return False
    if kind == syms.parameters:
        # The name of an argument.
        return False
    if kind == syms.typedargslist:
        after_comma = prev is not None and prev.type == token.COMMA
        if after_comma or parent.children[0] is node:
            # The name of an argument.
            return False
    return True
def find_indentation(node):
    """Find the indentation of *node*.

    Walks from *node* up through its ancestors. Returns the value of the
    INDENT leaf of the nearest ``suite`` (with more than two children)
    whose second child is an INDENT. Returns u"" if there is no such
    suite.
    """
    current = node
    while current is not None:
        indented_suite = (current.type == syms.suite
                          and len(current.children) > 2
                          and current.children[1].type == token.INDENT)
        if indented_suite:
            return current.children[1].value
        current = current.parent
    return u""
261### The following functions are to find bindings in a suite
def make_suite(node):
    """Return *node* as a ``suite`` node.

    A node that already is a suite is returned unchanged. Anything else is
    cloned, and the clone (not *node*) is wrapped in a new ``suite`` node.

    NOTE(review): the parent assigned to the new suite is read from the
    clone, not from *node*. If clone() returns a detached copy, the suite's
    parent is always None -- confirm whether *node*'s parent was intended.
    """
    if node.type == syms.suite:
        return node
    copy = node.clone()
    inherited_parent = copy.parent
    copy.parent = None
    wrapper = Node(syms.suite, [copy])
    wrapper.parent = inherited_parent
    return wrapper
def find_root(node):
    """Find the top level namespace.

    Follows ``parent`` links from *node* (inclusive) and returns the first
    ``file_input`` node reached.

    Raises:
        ValueError: if the parent chain ends before a ``file_input`` node
            is found.
    """
    while True:
        if node.type == syms.file_input:
            return node
        node = node.parent
        if not node:
            raise ValueError("root found before file_input node was found.")
def does_tree_import(package, name, node):
    """ Returns true if name is imported from package at the
        top level of the tree which node belongs to.
        To cover the case of an import like 'import foo', use
        None for the package and 'foo' for the name.

        Raises ValueError (via find_root) if *node* is not inside a
        ``file_input`` tree.
        """
    root = find_root(node)
    return bool(find_binding(name, root, package))
def is_import(node):
    """Returns true if the node is an import statement.

    Only *node*'s own type is checked (``import_name`` or ``import_from``).
    A ``simple_stmt`` wrapping an import does not count.
    """
    kind = node.type
    return kind == syms.import_name or kind == syms.import_from
def touch_import(package, name, node):
    """ Works like `does_tree_import` but adds an import statement
        if it was not imported.

        package: the module to import from, or None to emit a plain
            ``import name`` statement.
        name: the name that must be bound at the top level of the tree.
        node: any node of the tree. The statement is inserted into the
            ``file_input`` root found by find_root(), which raises
            ValueError if there is none.

        Where the new statement goes:
          1. right after the first run of consecutive top-level import
             statements; otherwise
          2. right after the first top-level statement whose first child
             is a string literal (a docstring); otherwise
          3. at index 0 of the root.
        """
    def is_import_stmt(node):
        # A simple statement whose first child is an import node.
        return (node.type == syms.simple_stmt and node.children and
                is_import(node.children[0]))

    root = find_root(node)

    if does_tree_import(package, name, root):
        return

    # figure out where to insert the new import.  First try to find
    # the first import and then skip to the last one.
    # NOTE: the loop variables below rebind the ``node`` parameter; it is
    # not needed again once ``root`` has been found.
    insert_pos = offset = 0
    for idx, node in enumerate(root.children):
        if not is_import_stmt(node):
            continue
        # offset ends as the position (relative to idx) of the first
        # non-import after this run, so idx + offset is just past the run.
        # NOTE(review): if every remaining child were an import, the inner
        # loop would not break and offset would point AT the last import,
        # inserting before it. Presumably the root always ends with a
        # non-import child (e.g. an end marker) -- confirm.
        for offset, node2 in enumerate(root.children[idx:]):
            if not is_import_stmt(node2):
                break
        insert_pos = idx + offset
        break

    # if there are no imports where we can insert, find the docstring.
    # if that also fails, we stick to the beginning of the file
    if insert_pos == 0:
        for idx, node in enumerate(root.children):
            if (node.type == syms.simple_stmt and node.children and
               node.children[0].type == token.STRING):
                insert_pos = idx + 1
                break

    if package is None:
        # Plain ``import name``.
        import_ = Node(syms.import_name, [
            Leaf(token.NAME, u"import"),
            Leaf(token.NAME, name, prefix=u" ")
        ])
    else:
        # ``from package import name``. FromImport does not adjust name
        # prefixes, so the space before the name is supplied here.
        import_ = FromImport(package, [Leaf(token.NAME, name, prefix=u" ")])

    children = [import_, Newline()]
    root.insert_child(insert_pos, Node(syms.simple_stmt, children))
# Statement types that bind the name held in their second child
# (``def name`` / ``class name``).
_def_syms = set([syms.classdef, syms.funcdef])
def find_binding(name, node, package=None):
    """ Returns the node which binds variable name, otherwise None.
        If optional argument package is supplied, only imports will
        be returned.
        See test cases for examples.

        Only the direct children of *node* are inspected. The search
        recurses into:
          * the target and last suite of ``for`` statements,
          * the last suite of ``if``/``while`` statements,
          * the body and clause suites of ``try`` statements,
          * ``simple_stmt`` nodes.
        Function and class bodies are not searched; a ``def``/``class``
        whose name matches is itself returned.
        """
    for child in node.children:
        ret = None
        if child.type == syms.for_stmt:
            # children[1] is the loop target (``for <target> in ...``).
            # NOTE(review): this returns the for_stmt even when *package*
            # is given, although the docstring says only imports are
            # returned in that case.
            if _find(name, child.children[1]):
                return child
            # NOTE(review): only the final child is searched; with a
            # ``for ... else`` this is presumably the else-suite.
            n = find_binding(name, make_suite(child.children[-1]), package)
            if n: ret = n
        elif child.type in (syms.if_stmt, syms.while_stmt):
            # NOTE(review): only the final child (presumably the else or
            # last elif suite) is searched; earlier branches are skipped.
            n = find_binding(name, make_suite(child.children[-1]), package)
            if n: ret = n
        elif child.type == syms.try_stmt:
            # children[2] is the suite following ``try:``.
            n = find_binding(name, make_suite(child.children[2]), package)
            if n:
                ret = n
            else:
                # Search the suite after each remaining ':' (the
                # except/else/finally clauses); a later match overwrites
                # an earlier one.
                for i, kid in enumerate(child.children[3:]):
                    if kid.type == token.COLON and kid.value == ":":
                        # i+3 is the colon, i+4 is the suite
                        n = find_binding(name, make_suite(child.children[i+4]), package)
                        if n: ret = n
        elif child.type in _def_syms and child.children[1].value == name:
            ret = child
        elif _is_import_binding(child, name, package):
            ret = child
        elif child.type == syms.simple_stmt:
            ret = find_binding(name, child, package)
        elif child.type == syms.expr_stmt:
            # Assignment: the name must occur in the left-most child.
            if _find(name, child.children[0]):
                ret = child

        if ret:
            # With *package* given, non-import bindings are skipped and
            # the scan continues with the next child.
            if not package:
                return ret
            if is_import(ret):
                return ret
    return None
# Node types _find() does not descend into: nested defs/classes and
# trailers (attribute access, call arguments, subscripts).
_block_syms = set([syms.funcdef, syms.classdef, syms.trailer])
def _find(name, node):
    """Return a NAME leaf under *node* whose value is *name*, or None.

    The search is depth-first using an explicit stack, so sibling
    subtrees are visited last-child first. Nodes whose type is above 256
    are expanded unless their type is in ``_block_syms``.

    NOTE(review): a node of type exactly 256 is neither expanded nor
    matched; confirm whether ``>= 256`` was intended.
    """
    pending = [node]
    while pending:
        current = pending.pop()
        kind = current.type
        if kind > 256 and kind not in _block_syms:
            pending.extend(current.children)
        elif kind == token.NAME and current.value == name:
            return current
    return None
def _is_import_binding(node, name, package=None):
    """ Will return node if node will import name, or node
        will import * from package.  None is returned otherwise.
        See test cases for examples.

        Without *package*, both ``import ...`` and ``from ... import ...``
        statements are checked. With *package*, only ``from ... import``
        statements whose module text equals *package* can match, and any
        such statement containing ``as`` is rejected.
        """

    if node.type == syms.import_name and not package:
        # ``import ...``: children[1] is everything after the keyword.
        imp = node.children[1]
        if imp.type == syms.dotted_as_names:
            # ``import a, b as c``: check each comma-separated item.
            for child in imp.children:
                if child.type == syms.dotted_as_name:
                    # ``x as alias``: children[2] is the bound alias.
                    if child.children[2].value == name:
                        return node
                elif child.type == token.NAME and child.value == name:
                    return node
        elif imp.type == syms.dotted_as_name:
            # Single ``import x as alias``: the alias is the last child.
            last = imp.children[-1]
            if last.type == token.NAME and last.value == name:
                return node
        elif imp.type == token.NAME and imp.value == name:
            # Single ``import name``.
            return node
        # NOTE(review): a bare dotted import (``import a.b``) matches none
        # of the branches above.
    elif node.type == syms.import_from:
        # unicode(...) is used to make life easier here, because
        # from a.b import parses to ['import', ['a', '.', 'b'], ...]
        if package and unicode(node.children[1]).strip() != package:
            return None
        # children[3] is what follows ``import``.
        # NOTE(review): for a parenthesised list ``from m import (a, b)``
        # this is presumably the '(' leaf, which no branch below matches.
        n = node.children[3]
        if package and _find(u"as", n):
            # See test_from_import_as for explanation
            return None
        elif n.type == syms.import_as_names and _find(name, n):
            # ``from m import a, b``: any NAME leaf equal to name matches.
            # NOTE(review): this includes the original (pre-``as``) name
            # in ``from m import name as other``.
            return node
        elif n.type == syms.import_as_name:
            # ``from m import x as alias``: children[2] is the alias.
            child = n.children[2]
            if child.type == token.NAME and child.value == name:
                return node
        elif n.type == token.NAME and n.value == name:
            # ``from m import name``.
            return node
        elif package and n.type == token.STAR:
            # ``from package import *`` binds everything.
            return node
    return None