• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Utility functions, node construction macros, etc."""
2# Author: Collin Winter
3
4from . import patcomp
5# Local imports
6from .pgen2 import token
7from .pygram import python_symbols as syms
8from .pytree import Leaf
9from .pytree import Node
10
11###########################################################
12# Common node-construction "macros"
13###########################################################
14
15
def KeywordArg(keyword, value):
  """Build an `argument` node of the form ``keyword=value``."""
  equals = Leaf(token.EQUAL, '=')
  return Node(syms.argument, [keyword, equals, value])
18
19
def LParen():
  """Return a new left-parenthesis leaf."""
  opening = Leaf(token.LPAR, '(')
  return opening
22
23
def RParen():
  """Return a new right-parenthesis leaf."""
  closing = Leaf(token.RPAR, ')')
  return closing
26
27
def Assign(target, source):
  """Build an assignment node of the form ``target = source``.

  Each argument may be a single node or a list of nodes. A single `source`
  node has its prefix set to one space in place; a list is used as given.
  """
  targets = target if isinstance(target, list) else [target]
  if isinstance(source, list):
    sources = source
  else:
    source.prefix = ' '
    sources = [source]

  equals = Leaf(token.EQUAL, '=', prefix=' ')
  return Node(syms.atom, targets + [equals] + sources)
37
38
def Name(name, prefix=None):
  """Build a NAME leaf whose value is `name`, with the given `prefix`."""
  leaf = Leaf(token.NAME, name, prefix=prefix)
  return leaf
42
43
def Attr(obj, attr):
  """Return the two-element list ``[obj, trailer]`` spelling ``obj.attr``.

  The result is a plain list (not a Node) so callers can splice it into the
  children of a larger node.
  """
  trailer = Node(syms.trailer, [Dot(), attr])
  return [obj, trailer]
47
48
def Comma():
  """Return a new comma (,) leaf."""
  leaf = Leaf(token.COMMA, ',')
  return leaf
52
53
def Dot():
  """Return a new period (.) leaf."""
  leaf = Leaf(token.DOT, '.')
  return leaf
57
58
def ArgList(args, lparen=LParen(), rparen=RParen()):
  """Build the parenthesised argument-list trailer used by Call().

  `lparen` and `rparen` are cloned, so the (shared) default leaves are never
  attached to a tree. When `args` is empty or None the trailer is just ``()``.
  """
  opening = lparen.clone()
  closing = rparen.clone()
  trailer = Node(syms.trailer, [opening, closing])
  if not args:
    return trailer
  # Slot the arguments in between the two parentheses.
  trailer.insert_child(1, Node(syms.arglist, args))
  return trailer
65
66
def Call(func_name, args=None, prefix=None):
  """Build a `power` node calling `func_name` with `args`.

  The node's prefix is only touched when `prefix` is not None.
  """
  call = Node(syms.power, [func_name, ArgList(args)])
  if prefix is None:
    return call
  call.prefix = prefix
  return call
73
74
def Newline():
  """Return a NEWLINE leaf whose value is a line feed."""
  leaf = Leaf(token.NEWLINE, '\n')
  return leaf
78
79
def BlankLine():
  """Return a NEWLINE leaf with an empty value (a blank line)."""
  leaf = Leaf(token.NEWLINE, '')
  return leaf
83
84
def Number(n, prefix=None):
  """Build a NUMBER leaf whose value is `n`, with the given `prefix`."""
  leaf = Leaf(token.NUMBER, n, prefix=prefix)
  return leaf
87
88
def Subscript(index_node):
  """Build a subscript trailer ``[index_node]``.

  The brackets are LSQB/RSQB leaves. They used to be created with the
  LBRACE/RBRACE token types, which are the types of '{' and '}', so the leaf
  type disagreed with its value.
  """
  return Node(syms.trailer,
              [Leaf(token.LSQB, '['), index_node,
               Leaf(token.RSQB, ']')])
94
95
def String(string, prefix=None):
  """Build a STRING leaf whose value is `string`, with the given `prefix`."""
  leaf = Leaf(token.STRING, string, prefix=prefix)
  return leaf
99
100
def ListComp(xp, fp, it, test=None):
  """A list comprehension of the form [xp for fp in it if test].

  If test is None (or otherwise falsy), the "if test" part is omitted.

  The prefixes of xp, fp, it and test are overwritten in place so that the
  pieces are spaced correctly. The surrounding brackets are LSQB/RSQB leaves;
  they used to be created with the LBRACE/RBRACE token types, which are the
  types of '{' and '}'.
  """
  xp.prefix = ''
  fp.prefix = ' '
  it.prefix = ' '
  for_leaf = Leaf(token.NAME, 'for')
  for_leaf.prefix = ' '
  in_leaf = Leaf(token.NAME, 'in')
  in_leaf.prefix = ' '
  inner_args = [for_leaf, fp, in_leaf, it]
  if test:
    test.prefix = ' '
    if_leaf = Leaf(token.NAME, 'if')
    if_leaf.prefix = ' '
    inner_args.append(Node(syms.comp_if, [if_leaf, test]))
  inner = Node(syms.listmaker, [xp, Node(syms.comp_for, inner_args)])
  return Node(syms.atom,
              [Leaf(token.LSQB, '['), inner,
               Leaf(token.RSQB, ']')])
123
124
def FromImport(package_name, name_leafs):
  """Build an import statement of the form:

       from package_name import name_leafs

  Every leaf in `name_leafs` is first removed from the tree it currently
  belongs to, then the list becomes the children of the new import node.
  """
  # XXX: May not handle dotted imports properly (eg, package_name='foo.bar');
  # this has not been tested with dotted package names.

  # Detach the leaves from their old tree before re-parenting them.
  for leaf in name_leafs:
    leaf.remove()

  from_kw = Leaf(token.NAME, 'from')
  package = Leaf(token.NAME, package_name, prefix=' ')
  import_kw = Leaf(token.NAME, 'import', prefix=' ')
  names = Node(syms.import_as_names, name_leafs)
  return Node(syms.import_from, [from_kw, package, import_kw, names])
147
148
def ImportAndCall(node, results, names):
  """Build the call ``names[0].names[1](obj)`` from pattern match results.

  Args:
      node: the node being replaced; only its prefix is copied.
      results: match results providing 'obj', 'lpar', 'rpar' and 'after'.
      names: a sequence whose first two items are the module name and the
        attribute name to call.

  Returns:
      A new `power` node. Everything taken from `results` is cloned, so the
      matched tree is left untouched.
  """
  # One clone is enough: the copy has no parent and can be used directly.
  obj = results['obj'].clone()
  if obj.type == syms.arglist:
    newarglist = obj
  else:
    newarglist = Node(syms.arglist, [obj])
  # A missing (None) or empty 'after' means there is nothing to append;
  # concatenating None to the child list would raise TypeError.
  after = [n.clone() for n in (results['after'] or [])]
  call_trailer = Node(
      syms.trailer,
      [results['lpar'].clone(), newarglist, results['rpar'].clone()])
  new = Node(syms.power,
             Attr(Name(names[0]), Name(names[1])) + [call_trailer] + after)
  new.prefix = node.prefix
  return new
171
172
173###########################################################
174# Determine whether a node represents a given literal
175###########################################################
176
177
def is_tuple(node):
  """Tell whether `node` looks like a tuple literal.

  True for a Node whose children are exactly ``(`` ``)``, or whose three
  children are a '(' leaf, an inner Node and a ')' leaf.
  """
  if not isinstance(node, Node):
    return False
  kids = node.children
  if kids == [LParen(), RParen()]:
    return True
  if len(kids) != 3:
    return False
  opening, middle, closing = kids
  if not (isinstance(opening, Leaf) and isinstance(middle, Node) and
          isinstance(closing, Leaf)):
    return False
  return opening.value == '(' and closing.value == ')'
187
188
def is_list(node):
  """Tell whether `node` looks like a list literal.

  True for a Node with at least two children whose first and last children
  are leaves with the values '[' and ']'.
  """
  if not isinstance(node, Node):
    return False
  kids = node.children
  if len(kids) <= 1:
    return False
  first, last = kids[0], kids[-1]
  if not (isinstance(first, Leaf) and isinstance(last, Leaf)):
    return False
  return first.value == '[' and last.value == ']'
195
196
197###########################################################
198# Misc
199###########################################################
200
201
def parenthesize(node):
  """Wrap `node` in an atom of the form ``( node )``."""
  children = [LParen(), node, RParen()]
  return Node(syms.atom, children)
204
205
# Names of callables collected here as "consuming" calls.
consuming_calls = {
    'all',
    'any',
    'enumerate',
    'list',
    'max',
    'min',
    'set',
    'sorted',
    'sum',
    'tuple',
}
210
211
def attr_chain(obj, attr):
  """Walk a chain of objects linked through one attribute.

  Given a.foo -> b, b.foo -> c and so on, this yields b, c, ... in order. The
  starting object itself is not yielded. The walk stops at the first link
  whose value is falsy (normally None).

  Args:
      obj: the starting object
      attr: the name of the chaining attribute

  Yields:
      Each successive object in the chain.
  """
  current = getattr(obj, attr)
  while current:
    yield current
    current = getattr(current, attr)
230
231
# Pattern sources used by in_special_context(). Each one captures the
# candidate node under the name `node`. The first call to
# in_special_context() replaces these strings with compiled patterns.
#
# p0: the expression after 'in' in a for statement or in a comprehension's
# for clause.
p0 = """for_stmt< 'for' any 'in' node=any ':' any* >
        | comp_for< 'for' any 'in' node=any any* >
     """
# p1: the single argument of a call to one of the listed names, or to a
# `.join` attribute.
p1 = """
power<
    ( 'iter' | 'list' | 'tuple' | 'sorted' | 'set' | 'sum' |
      'any' | 'all' | 'enumerate' | (any* trailer< '.' 'join' >) )
    trailer< '(' node=any ')' >
    any*
>
"""
# p2: the first entry of the argument list in a call to sorted() or
# enumerate().
p2 = """
power<
    ( 'sorted' | 'enumerate' )
    trailer< '(' arglist<node=any any*> ')' >
    any*
>
"""
# Set to True by in_special_context() once p0, p1 and p2 have been compiled.
pats_built = False
251
252
def in_special_context(node):
  """Tell whether `node` sits where only iterability is required.

  That is, a place where it does not matter whether the node evaluates to a
  list or to an iterator. p0 is tried against the node's parent, p1 against
  its grandparent and p2 against its great-grandparent; a pattern counts only
  if the node it captured is `node` itself.
  See test_map_nochange in test_fixers.py for some examples and tests.
  """
  global p0, p1, p2, pats_built
  if not pats_built:
    # Compile lazily, on first use, and keep the results in the globals.
    p0 = patcomp.compile_pattern(p0)
    p1 = patcomp.compile_pattern(p1)
    p2 = patcomp.compile_pattern(p2)
    pats_built = True

  ancestors = attr_chain(node, 'parent')
  for pattern, ancestor in zip((p0, p1, p2), ancestors):
    results = {}
    if not pattern.match(ancestor, results):
      continue
    if results['node'] is node:
      return True
  return False
271
272
def is_probably_builtin(node):
  """Return False when `node` is an attribute, definition name, etc.

  The name is rejected when it follows a dot, names a function or class
  being defined, is the first child of an assignment, or is a parameter
  name. Anything else is assumed to refer to a builtin.
  """
  prev = node.prev_sibling
  if prev is not None and prev.type == token.DOT:
    # Attribute lookup.
    return False

  parent = node.parent
  if parent.type in (syms.funcdef, syms.classdef):
    return False

  if parent.type == syms.expr_stmt and parent.children[0] is node:
    # Assignment.
    return False

  # The name of an argument.
  if parent.type == syms.parameters:
    return False
  if parent.type == syms.typedargslist:
    follows_comma = prev is not None and prev.type == token.COMMA
    if follows_comma or parent.children[0] is node:
      return False

  return True
291
292
def find_indentation(node):
  """Return the indentation string of the suite enclosing *node*.

  Walks from `node` up through its parents and returns the value of the
  INDENT token of the first suite found that has one as its second child.
  Returns '' when no such suite exists.
  """
  current = node
  while current is not None:
    if current.type == syms.suite and len(current.children) > 2:
      candidate = current.children[1]
      if candidate.type == token.INDENT:
        return candidate.value
    current = current.parent
  return ''
302
303
304###########################################################
305# The following functions are to find bindings in a suite
306###########################################################
307
308
def make_suite(node):
  """Return `node` itself if it is a suite, else a new suite wrapping a clone.

  The new suite's parent is whatever parent the clone reported before it was
  detached.
  NOTE(review): the parent is read from the clone, not from the original
  node; confirm whether clone() carries the parent over.
  """
  if node.type == syms.suite:
    return node
  copy = node.clone()
  inherited_parent = copy.parent
  # A child must be parentless before it can be given to a new Node.
  copy.parent = None
  wrapper = Node(syms.suite, [copy])
  wrapper.parent = inherited_parent
  return wrapper
317
318
def find_root(node):
  """Return the file_input node at the top of the tree containing `node`.

  Raises:
      ValueError: if the parent chain ends before a file_input node is
        reached.
  """
  current = node
  while True:
    if current.type == syms.file_input:
      return current
    # Not there yet: climb one level.
    current = current.parent
    if not current:
      raise ValueError('root found before file_input node was found.')
327
328
def does_tree_import(package, name, node):
  """Tell whether `name` is imported from `package` at the top level.

  The search runs over the top level of the tree that `node` belongs to.
  To cover the case of an import like 'import foo', use None for the
  package and 'foo' for the name.
  """
  root = find_root(node)
  return bool(find_binding(name, root, package))
337
338
def is_import(node):
  """Tell whether `node` is an import_name or import_from statement."""
  kind = node.type
  return kind == syms.import_name or kind == syms.import_from
342
343
def touch_import(package, name, node):
  """Make sure `name` is imported, adding an import statement if needed.

  Works like `does_tree_import`, but when the import is missing a new
  ``import name`` (package is None) or ``from package import name`` statement
  is inserted at the top level of the tree containing `node`.
  """

  def is_import_stmt(candidate):
    return (candidate.type == syms.simple_stmt and candidate.children and
            is_import(candidate.children[0]))

  root = find_root(node)
  if does_tree_import(package, name, root):
    return

  top_level = root.children

  # Figure out where to insert the new import: locate the first import
  # statement, then skip to the end of the run of imports that starts there.
  insert_pos = 0
  for idx, stmt in enumerate(top_level):
    if is_import_stmt(stmt):
      run = 0
      for run, later in enumerate(top_level[idx:]):
        if not is_import_stmt(later):
          break
      insert_pos = idx + run
      break

  # With no usable import position, insert right after the docstring; if
  # there is no docstring either, stay at the beginning of the file.
  if insert_pos == 0:
    for idx, stmt in enumerate(top_level):
      if (stmt.type == syms.simple_stmt and stmt.children and
          stmt.children[0].type == token.STRING):
        insert_pos = idx + 1
        break

  if package is not None:
    import_ = FromImport(package, [Leaf(token.NAME, name, prefix=' ')])
  else:
    import_ = Node(
        syms.import_name,
        [Leaf(token.NAME, 'import'),
         Leaf(token.NAME, name, prefix=' ')])

  statement = Node(syms.simple_stmt, [import_, Newline()])
  root.insert_child(insert_pos, statement)
388
389
# Node types whose second child is the name being defined (see find_binding).
_def_syms = {syms.funcdef, syms.classdef}
391
392
def find_binding(name, node, package=None):
  """Return the child of `node` that binds `name`, or None.

  Compound statements (for, if, while, try) are searched recursively through
  their suites. If the optional argument `package` is supplied, only import
  statements are returned.
  See test cases for examples.
  """
  for child in node.children:
    kind = child.type
    found = None
    if kind == syms.for_stmt:
      # The loop target itself binds the name.
      if _find(name, child.children[1]):
        return child
      found = find_binding(name, make_suite(child.children[-1]), package)
    elif kind in (syms.if_stmt, syms.while_stmt):
      found = find_binding(name, make_suite(child.children[-1]), package)
    elif kind == syms.try_stmt:
      found = find_binding(name, make_suite(child.children[2]), package)
      if not found:
        for offset, kid in enumerate(child.children[3:]):
          if kid.type == token.COLON and kid.value == ':':
            # The colon is child offset + 3; its suite comes right after.
            clause = find_binding(name,
                                  make_suite(child.children[offset + 4]),
                                  package)
            if clause:
              found = clause
    elif kind in _def_syms and child.children[1].value == name:
      found = child
    elif _is_import_binding(child, name, package):
      found = child
    elif kind == syms.simple_stmt:
      found = find_binding(name, child, package)
    elif kind == syms.expr_stmt:
      if _find(name, child.children[0]):
        found = child

    # With a package filter only import statements qualify.
    if found and (not package or is_import(found)):
      return found
  return None
438
439
# Node types that _find() does not descend into.
_block_syms = {syms.trailer, syms.classdef, syms.funcdef}
441
442
def _find(name, node):
  """Return a NAME leaf with value `name` found under `node`, or None.

  The search is an explicit-stack traversal. It descends into nodes whose
  type is above 256, except those listed in _block_syms.
  """
  pending = [node]
  while pending:
    current = pending.pop()
    descend = current.type > 256 and current.type not in _block_syms
    if descend:
      pending.extend(current.children)
      continue
    if current.type == token.NAME and current.value == name:
      return current
  return None
452
453
def _is_import_binding(node, name, package=None):
  """Return `node` if it is an import statement that binds `name`.

  With `package` given, only ``from package import ...`` statements are
  considered, and ``from package import *`` also counts. Plain ``import``
  statements are only considered when `package` is not given. None is
  returned otherwise.
  See test cases for examples.
  """
  if node.type == syms.import_name and not package:
    imp = node.children[1]
    if imp.type == syms.dotted_as_names:
      for child in imp.children:
        if child.type == syms.dotted_as_name:
          # `x as y`: the bound name is the third child.
          if child.children[2].value == name:
            return node
        elif child.type == token.NAME and child.value == name:
          return node
    elif imp.type == syms.dotted_as_name:
      last = imp.children[-1]
      if last.type == token.NAME and last.value == name:
        return node
    elif imp.type == token.NAME and imp.value == name:
      return node
  elif node.type == syms.import_from:
    # str(...) is used to make life easier here, because
    # from a.b import parses to ['import', ['a', '.', 'b'], ...]
    if package and str(node.children[1]).strip() != package:
      return None
    n = node.children[3]
    if n.type == token.LPAR and len(node.children) > 4:
      # Parenthesised form, `from pkg import (a, b)`: the imported names
      # follow the opening parenthesis.
      n = node.children[4]
    if package and _find('as', n):
      # See test_from_import_as for explanation
      return None
    elif n.type == syms.import_as_names and _find(name, n):
      return node
    elif n.type == syms.import_as_name:
      child = n.children[2]
      if child.type == token.NAME and child.value == name:
        return node
    elif n.type == token.NAME and n.value == name:
      return node
    elif package and n.type == token.STAR:
      return node
  return None
494