"""Utility functions, node construction macros, etc."""
# Author: Collin Winter

from . import patcomp
# Local imports
from .pgen2 import token
from .pygram import python_symbols as syms
from .pytree import Leaf
from .pytree import Node

###########################################################
# Common node-construction "macros"
###########################################################


def KeywordArg(keyword, value):
    """Build a keyword argument of the form ``keyword=value``.

    Args:
        keyword: the node placed to the left of the ``=`` leaf.
        value: the node placed to the right of the ``=`` leaf.

    Returns:
        A ``syms.argument`` Node whose children are keyword, ``=``, value.
    """
    return Node(syms.argument, [keyword, Leaf(token.EQUAL, '='), value])


def LParen():
    """Return a fresh ``(`` leaf."""
    return Leaf(token.LPAR, '(')


def RParen():
    """Return a fresh ``)`` leaf."""
    return Leaf(token.RPAR, ')')


def Assign(target, source):
    """Build an assignment statement.

    Args:
        target: a node, or a list of nodes, for the left-hand side.
        source: a node, or a list of nodes, for the right-hand side. When a
            single node is passed, its prefix is set to one space in place.

    Returns:
        A ``syms.atom`` Node holding the target nodes, an ``=`` leaf with a
        one-space prefix, and the source nodes, in that order.
    """
    if not isinstance(target, list):
        target = [target]
    if not isinstance(source, list):
        source.prefix = ' '
        source = [source]

    return Node(syms.atom,
                target + [Leaf(token.EQUAL, '=', prefix=' ')] + source)


def Name(name, prefix=None):
    """Return a NAME leaf with the given value and optional prefix."""
    return Leaf(token.NAME, name, prefix=prefix)


def Attr(obj, attr):
    """Build the pieces of an attribute access ``obj.attr``.

    Returns:
        A two-element list: ``obj`` itself followed by a ``syms.trailer``
        Node containing a ``.`` leaf and ``attr``.
    """
    return [obj, Node(syms.trailer, [Dot(), attr])]


def Comma():
    """Return a comma leaf."""
    return Leaf(token.COMMA, ',')


def Dot():
    """Return a period (.) leaf."""
    return Leaf(token.DOT, '.')


def ArgList(args, lparen=LParen(), rparen=RParen()):
    """Build a parenthesised argument list, used by Call().

    Args:
        args: list of argument nodes; when falsy, the parentheses are empty.
        lparen: leaf to clone for the opening parenthesis.
        rparen: leaf to clone for the closing parenthesis.

    Returns:
        A ``syms.trailer`` Node.
    """
    # The default leaves are created once at definition time; cloning them
    # here keeps every returned trailer independent of the shared defaults.
    node = Node(syms.trailer, [lparen.clone(), rparen.clone()])
    if args:
        node.insert_child(1, Node(syms.arglist, args))
    return node


def Call(func_name, args=None, prefix=None):
    """Build a function call.

    Args:
        func_name: node for the callable being invoked.
        args: optional list of argument nodes, passed to ArgList().
        prefix: when not None, assigned to the resulting node's prefix.

    Returns:
        A ``syms.power`` Node of ``func_name`` followed by the argument list.
    """
    node = Node(syms.power, [func_name, ArgList(args)])
    if prefix is not None:
        node.prefix = prefix
    return node


def Newline():
    """Return a NEWLINE leaf whose value is a newline character."""
    return Leaf(token.NEWLINE, '\n')


def BlankLine():
    """Return a NEWLINE leaf whose value is the empty string."""
    return Leaf(token.NEWLINE, '')


def Number(n, prefix=None):
    """Return a NUMBER leaf with value ``n`` and optional prefix."""
    return Leaf(token.NUMBER, n, prefix=prefix)


def Subscript(index_node):
    """Build a numeric or string subscript trailer ``[index_node]``."""
    # Square brackets are LSQB/RSQB; LBRACE/RBRACE are the curly-brace token
    # types and would mislabel these leaves.
    return Node(syms.trailer, [
        Leaf(token.LSQB, '['),
        index_node,
        Leaf(token.RSQB, ']'),
    ])


def String(string, prefix=None):
    """Return a STRING leaf with the given value and optional prefix."""
    return Leaf(token.STRING, string, prefix=prefix)


def ListComp(xp, fp, it, test=None):
    """Build a list comprehension of the form [xp for fp in it if test].

    If test is None, the "if test" part is omitted.

    The prefixes of ``xp``, ``fp``, ``it`` and (when given) ``test`` are
    overwritten in place so the pieces are spaced correctly.

    Returns:
        A ``syms.atom`` Node wrapping the comprehension in square brackets.
    """
    xp.prefix = ''
    fp.prefix = ' '
    it.prefix = ' '
    for_leaf = Leaf(token.NAME, 'for')
    for_leaf.prefix = ' '
    in_leaf = Leaf(token.NAME, 'in')
    in_leaf.prefix = ' '
    inner_args = [for_leaf, fp, in_leaf, it]
    if test:
        test.prefix = ' '
        if_leaf = Leaf(token.NAME, 'if')
        if_leaf.prefix = ' '
        inner_args.append(Node(syms.comp_if, [if_leaf, test]))
    inner = Node(syms.listmaker, [xp, Node(syms.comp_for, inner_args)])
    # Square brackets are LSQB/RSQB, not the curly-brace LBRACE/RBRACE types.
    return Node(syms.atom, [
        Leaf(token.LSQB, '['),
        inner,
        Leaf(token.RSQB, ']'),
    ])


def FromImport(package_name, name_leafs):
    """Return an import statement in the form:

        from package import name_leafs

    Args:
        package_name: text of the package to import from.
        name_leafs: list of leaves naming what to import. Each leaf is
            removed from its current tree before being re-used here.

    Returns:
        A ``syms.import_from`` Node.
    """
    # XXX: May not handle dotted imports properly (eg, package_name='foo.bar')
    # #assert package_name == '.' or '.' not in package_name, "FromImport has "\
    #       "not been tested with dotted package names -- use at your own "\
    #       "peril!"

    for leaf in name_leafs:
        # Pull the leaves out of their old tree
        leaf.remove()

    children = [
        Leaf(token.NAME, 'from'),
        Leaf(token.NAME, package_name, prefix=' '),
        Leaf(token.NAME, 'import', prefix=' '),
        Node(syms.import_as_names, name_leafs),
    ]
    return Node(syms.import_from, children)


def ImportAndCall(node, results, names):
    """Build a call to a method of a module:

        import module
        module.name()

    Only the ``module.name(...)`` call expression is constructed and
    returned by this function.

    Args:
        node: the node being replaced; its prefix is copied to the result.
        results: match results providing the 'obj', 'after', 'lpar' and
            'rpar' entries.
        names: sequence whose first two items are the module name and the
            attribute name to call.

    Returns:
        A ``syms.power`` Node for ``names[0].names[1](obj)`` followed by
        clones of the nodes in ``results['after']``.
    """
    obj = results['obj'].clone()
    if obj.type == syms.arglist:
        newarglist = obj.clone()
    else:
        newarglist = Node(syms.arglist, [obj.clone()])
    # A falsy 'after' (None or an empty list) means there is nothing to
    # append; normalise it to a list so the concatenation below cannot fail.
    after = [n.clone() for n in results['after'] or []]
    call_trailer = Node(
        syms.trailer,
        [results['lpar'].clone(), newarglist, results['rpar'].clone()])
    new = Node(
        syms.power,
        Attr(Name(names[0]), Name(names[1])) + [call_trailer] + after)
    new.prefix = node.prefix
    return new


###########################################################
# Determine whether a node represents a given literal
###########################################################


def is_tuple(node):
    """Report whether the node looks like a tuple literal.

    A Node qualifies when it is exactly ``()`` or when it has three children:
    a ``(`` leaf, an inner Node, and a ``)`` leaf.
    """
    if not isinstance(node, Node):
        return False
    kids = node.children
    if kids == [LParen(), RParen()]:
        return True
    if len(kids) != 3:
        return False
    opener, middle, closer = kids
    return (isinstance(opener, Leaf) and isinstance(middle, Node) and
            isinstance(closer, Leaf) and opener.value == '(' and
            closer.value == ')')


def is_list(node):
    """Report whether the node looks like a list literal.

    A Node qualifies when it has more than one child and its first and last
    children are ``[`` and ``]`` leaves.
    """
    if not isinstance(node, Node) or len(node.children) <= 1:
        return False
    opener = node.children[0]
    closer = node.children[-1]
    return (isinstance(opener, Leaf) and isinstance(closer, Leaf) and
            opener.value == '[' and closer.value == ']')


###########################################################
# Misc
###########################################################


def parenthesize(node):
    """Wrap ``node`` in parentheses, returning a new ``syms.atom`` Node."""
    wrapped = [LParen(), node, RParen()]
    return Node(syms.atom, wrapped)


consuming_calls = {
    'sorted', 'list', 'set', 'any', 'all', 'tuple', 'sum', 'min', 'max',
    'enumerate'
}


def attr_chain(obj, attr):
    """Walk a chain of objects linked through one attribute.

    Given a.foo -> b, b.foo -> c, and so on, this yields b, c, ... in order.
    The walk stops as soon as the attribute value is falsy (e.g. None).

    Args:
        obj: the starting object (not itself yielded)
        attr: the name of the chaining attribute

    Yields:
        Each successive object in the chain.
    """
    link = getattr(obj, attr)
    while link:
        yield link
        link = getattr(link, attr)


p0 = """for_stmt< 'for' any 'in' node=any ':' any* >
        | comp_for< 'for' any 'in' node=any any* >
     """
p1 = """
power<
    ( 'iter' | 'list' | 'tuple' | 'sorted' | 'set' | 'sum' |
      'any' | 'all' | 'enumerate' | (any* trailer< '.' 'join' >) )
    trailer< '(' node=any ')' >
    any*
>
"""
p2 = """
power<
    ( 'sorted' | 'enumerate' )
    trailer< '(' arglist<node=any any*> ')' >
    any*
>
"""
pats_built = False


def in_special_context(node):
    """Tell whether node only needs to be iterable where it is used.

    Returns true if node is in an environment where all that is required
    of it is being iterable (ie, it doesn't matter if it returns a list
    or an iterator).
    See test_map_nochange in test_fixers.py for some examples and tests.
    """
    global p0, p1, p2, pats_built
    if not pats_built:
        # Compile the pattern sources lazily, replacing the module-level
        # strings with their compiled form on first use.
        p0 = patcomp.compile_pattern(p0)
        p1 = patcomp.compile_pattern(p1)
        p2 = patcomp.compile_pattern(p2)
        pats_built = True
    # Pair p0 with the parent, p1 with the grandparent, p2 with the
    # great-grandparent; zip stops at whichever sequence ends first.
    for pattern, ancestor in zip((p0, p1, p2), attr_chain(node, 'parent')):
        captured = {}
        if pattern.match(ancestor, captured) and captured['node'] is node:
            return True
    return False


def is_probably_builtin(node):
    """Check that something isn't an attribute or function name etc."""
    previous = node.prev_sibling
    if previous is not None and previous.type == token.DOT:
        # Attribute lookup.
        return False

    owner = node.parent
    if owner.type in (syms.funcdef, syms.classdef):
        return False
    if owner.type == syms.expr_stmt and owner.children[0] is node:
        # Assignment.
        return False

    # The name of an argument.
    if owner.type == syms.parameters:
        return False
    if owner.type == syms.typedargslist:
        after_comma = previous is not None and previous.type == token.COMMA
        if after_comma or owner.children[0] is node:
            return False
    return True


def find_indentation(node):
    """Find the indentation of *node*.

    Walks from node up through its ancestors and returns the value of the
    INDENT leaf of the first suite that has one; returns '' if none does.
    """
    current = node
    while current is not None:
        if current.type == syms.suite and len(current.children) > 2:
            candidate = current.children[1]
            if candidate.type == token.INDENT:
                return candidate.value
        current = current.parent
    return ''


###########################################################
# The following functions are to find bindings in a suite
###########################################################


def make_suite(node):
    """Return node if it is a suite, else a new suite wrapping a clone of it."""
    if node.type == syms.suite:
        return node
    duplicate = node.clone()
    former_parent = duplicate.parent
    duplicate.parent = None
    wrapper = Node(syms.suite, [duplicate])
    wrapper.parent = former_parent
    return wrapper


def find_root(node):
    """Find the top level namespace.

    Raises:
        ValueError: if the parent chain ends before a file_input node.
    """
    # Scamper up to the top level namespace
    current = node
    while current.type != syms.file_input:
        current = current.parent
        if not current:
            raise ValueError('root found before file_input node was found.')
    return current


def does_tree_import(package, name, node):
    """Returns true if name is imported from package at the
    top level of the tree which node belongs to.
    To cover the case of an import like 'import foo', use
    None for the package and 'foo' for the name.
    """
    root = find_root(node)
    return bool(find_binding(name, root, package))


def is_import(node):
    """Returns true if the node is an import statement."""
    import_kinds = (syms.import_name, syms.import_from)
    return node.type in import_kinds


def touch_import(package, name, node):
    """Works like `does_tree_import` but adds an import statement
    if it was not imported.
    """

    def _is_import_line(stmt):
        return (stmt.type == syms.simple_stmt and stmt.children and
                is_import(stmt.children[0]))

    root = find_root(node)
    if does_tree_import(package, name, root):
        return

    top_level = root.children

    # figure out where to insert the new import.  First try to find
    # the first import and then skip to the last one.
    insert_pos = 0
    for idx, stmt in enumerate(top_level):
        if _is_import_line(stmt):
            offset = 0
            for offset, later in enumerate(top_level[idx:]):
                if not _is_import_line(later):
                    break
            insert_pos = idx + offset
            break

    # if there are no imports where we can insert, find the docstring.
    # if that also fails, we stick to the beginning of the file
    if insert_pos == 0:
        for idx, stmt in enumerate(top_level):
            if (stmt.type == syms.simple_stmt and stmt.children and
                    stmt.children[0].type == token.STRING):
                insert_pos = idx + 1
                break

    if package is None:
        import_ = Node(syms.import_name, [
            Leaf(token.NAME, 'import'),
            Leaf(token.NAME, name, prefix=' '),
        ])
    else:
        import_ = FromImport(package, [Leaf(token.NAME, name, prefix=' ')])

    new_stmt = Node(syms.simple_stmt, [import_, Newline()])
    root.insert_child(insert_pos, new_stmt)


_def_syms = {syms.classdef, syms.funcdef}


def find_binding(name, node, package=None):
    """Returns the node which binds variable name, otherwise None.
    If optional argument package is supplied, only imports will
    be returned.
    See test cases for examples.
    """
    for child in node.children:
        found = None
        if child.type == syms.for_stmt:
            # The loop target binds the name directly.
            if _find(name, child.children[1]):
                return child
            hit = find_binding(name, make_suite(child.children[-1]), package)
            if hit:
                found = hit
        elif child.type in (syms.if_stmt, syms.while_stmt):
            hit = find_binding(name, make_suite(child.children[-1]), package)
            if hit:
                found = hit
        elif child.type == syms.try_stmt:
            hit = find_binding(name, make_suite(child.children[2]), package)
            if hit:
                found = hit
            else:
                # Scan the remaining clauses: each ':' leaf is followed by
                # the suite of that clause.
                for pos, kid in enumerate(child.children[3:], 3):
                    if kid.type == token.COLON and kid.value == ':':
                        hit = find_binding(
                            name, make_suite(child.children[pos + 1]), package)
                        if hit:
                            found = hit
        elif child.type in _def_syms and child.children[1].value == name:
            found = child
        elif _is_import_binding(child, name, package):
            found = child
        elif child.type == syms.simple_stmt:
            found = find_binding(name, child, package)
        elif child.type == syms.expr_stmt:
            if _find(name, child.children[0]):
                found = child

        # With a package given, only import statements count as bindings.
        if found and (not package or is_import(found)):
            return found
    return None


_block_syms = {syms.funcdef, syms.classdef, syms.trailer}


def _find(name, node):
    """Search node's subtree for a NAME leaf equal to name.

    Nodes whose type is in _block_syms are not descended into. Returns the
    matching leaf, or None when there is none.
    """
    pending = [node]
    while pending:
        current = pending.pop()
        if current.type > 256 and current.type not in _block_syms:
            pending.extend(current.children)
        elif current.type == token.NAME and current.value == name:
            return current
    return None


def _is_import_binding(node, name, package=None):
    """Will return node if node will import name, or node
    will import * from package. None is returned otherwise.
    See test cases for examples.
    """
    if node.type == syms.import_name and not package:
        imported = node.children[1]
        if imported.type == syms.dotted_as_names:
            for item in imported.children:
                if item.type == syms.dotted_as_name:
                    if item.children[2].value == name:
                        return node
                elif item.type == token.NAME and item.value == name:
                    return node
        elif imported.type == syms.dotted_as_name:
            alias = imported.children[-1]
            if alias.type == token.NAME and alias.value == name:
                return node
        elif imported.type == token.NAME and imported.value == name:
            return node
    elif node.type == syms.import_from:
        # str(...) is used to make life easier here, because
        # from a.b import parses to ['import', ['a', '.', 'b'], ...]
        if package and str(node.children[1]).strip() != package:
            return None
        target = node.children[3]
        if package and _find('as', target):
            # See test_from_import_as for explanation
            return None
        if target.type == syms.import_as_names and _find(name, target):
            return node
        elif target.type == syms.import_as_name:
            alias = target.children[2]
            if alias.type == token.NAME and alias.value == name:
                return node
        elif target.type == token.NAME and target.value == name:
            return node
        elif package and target.type == token.STAR:
            return node
    return None