"""Utility functions, node construction macros, etc."""
# Author: Collin Winter

# Local imports
from .pgen2 import token
from .pytree import Leaf, Node
from .pygram import python_symbols as syms
from . import patcomp


###########################################################
### Common node-construction "macros"
###########################################################

def KeywordArg(keyword, value):
    """Build a ``keyword=value`` argument node from two existing nodes."""
    return Node(syms.argument,
                [keyword, Leaf(token.EQUAL, "="), value])

def LParen():
    """An opening parenthesis leaf"""
    return Leaf(token.LPAR, "(")

def RParen():
    """A closing parenthesis leaf"""
    return Leaf(token.RPAR, ")")

def Assign(target, source):
    """Build an assignment statement: ``target = source``.

    *target* and *source* may each be a single node or a list of nodes.
    A single (non-list) *source* node has its prefix overwritten with " "
    so it is separated from the "=" sign; a list of sources is used as-is.
    """
    if not isinstance(target, list):
        target = [target]
    if not isinstance(source, list):
        source.prefix = " "
        source = [source]

    return Node(syms.atom,
                target + [Leaf(token.EQUAL, "=", prefix=" ")] + source)

def Name(name, prefix=None):
    """Return a NAME leaf"""
    return Leaf(token.NAME, name, prefix=prefix)

def Attr(obj, attr):
    """Return the node list for ``obj.attr``: ``[obj, trailer<'.' attr>]``."""
    return [obj, Node(syms.trailer, [Dot(), attr])]

def Comma():
    """A comma leaf"""
    return Leaf(token.COMMA, ",")

def Dot():
    """A period (.) leaf"""
    return Leaf(token.DOT, ".")

def ArgList(args, lparen=LParen(), rparen=RParen()):
    """A parenthesised argument list, used by Call().

    *args* is a list of nodes wrapped in an ``arglist`` node; when it is
    empty or None the trailer is just ``()``.  The paren leaves are cloned,
    so the default leaves (created once at import time) are never
    inserted into a tree themselves.
    """
    node = Node(syms.trailer, [lparen.clone(), rparen.clone()])
    if args:
        node.insert_child(1, Node(syms.arglist, args))
    return node

def Call(func_name, args=None, prefix=None):
    """A function call ``func_name(args)``, optionally with a prefix."""
    node = Node(syms.power, [func_name, ArgList(args)])
    if prefix is not None:
        node.prefix = prefix
    return node

def Newline():
    """A newline literal"""
    return Leaf(token.NEWLINE, "\n")

def BlankLine():
    """A blank line"""
    return Leaf(token.NEWLINE, "")

def Number(n, prefix=None):
    """A NUMBER leaf; *n* is used verbatim as the leaf value."""
    return Leaf(token.NUMBER, n, prefix=prefix)

def Subscript(index_node):
    """A numeric or string subscript: ``[index_node]``."""
    # Square brackets are LSQB/RSQB tokens (LBRACE/RBRACE are "{"/"}"),
    # which is what the parser itself produces for a subscript.
    return Node(syms.trailer, [Leaf(token.LSQB, "["),
                               index_node,
                               Leaf(token.RSQB, "]")])

def String(string, prefix=None):
    """A string leaf"""
    return Leaf(token.STRING, string, prefix=prefix)

def ListComp(xp, fp, it, test=None):
    """A list comprehension of the form [xp for fp in it if test].

    If test is None, the "if test" part is omitted.  The prefixes of the
    passed-in nodes are overwritten so the result is spaced normally.
    """
    xp.prefix = ""
    fp.prefix = " "
    it.prefix = " "
    for_leaf = Leaf(token.NAME, "for", prefix=" ")
    in_leaf = Leaf(token.NAME, "in", prefix=" ")
    inner_args = [for_leaf, fp, in_leaf, it]
    if test:
        test.prefix = " "
        if_leaf = Leaf(token.NAME, "if", prefix=" ")
        inner_args.append(Node(syms.comp_if, [if_leaf, test]))
    inner = Node(syms.listmaker, [xp, Node(syms.comp_for, inner_args)])
    # List brackets are LSQB/RSQB tokens, not LBRACE/RBRACE ("{"/"}").
    return Node(syms.atom,
                [Leaf(token.LSQB, "["),
                 inner,
                 Leaf(token.RSQB, "]")])

def FromImport(package_name, name_leafs):
    """ Return an import statement in the form:
        from package import name_leafs

    Each leaf in *name_leafs* is first detached from its current tree.
    """
    # XXX: May not handle dotted imports properly (eg, package_name='foo.bar');
    # this has not been tested with dotted package names -- use at your own
    # peril!

    for leaf in name_leafs:
        # Pull the leaves out of their old tree
        leaf.remove()

    children = [Leaf(token.NAME, "from"),
                Leaf(token.NAME, package_name, prefix=" "),
                Leaf(token.NAME, "import", prefix=" "),
                Node(syms.import_as_names, name_leafs)]
    imp = Node(syms.import_from, children)
    return imp

def ImportAndCall(node, results, names):
    """Build the call expression ``names[0].names[1](obj)<after>``.

    *results* must provide "obj" (the argument, or a whole arglist),
    "lpar"/"rpar" (the paren leaves to reuse) and "after" (trailing
    nodes, may be empty or None).  All taken nodes are cloned.  The new
    node gets *node*'s prefix.  No import statement is created here.
    """
    obj = results["obj"].clone()
    # obj is a fresh, parentless clone, so it can be used directly.
    if obj.type == syms.arglist:
        newarglist = obj
    else:
        newarglist = Node(syms.arglist, [obj])
    after = [n.clone() for n in results["after"] or ()]
    new = Node(syms.power,
               Attr(Name(names[0]), Name(names[1])) +
               [Node(syms.trailer,
                     [results["lpar"].clone(),
                      newarglist,
                      results["rpar"].clone()])] + after)
    new.prefix = node.prefix
    return new


###########################################################
### Determine whether a node represents a given literal
###########################################################

def is_tuple(node):
    """Does the node represent a tuple literal?"""
    if isinstance(node, Node) and node.children == [LParen(), RParen()]:
        return True
    return (isinstance(node, Node)
            and len(node.children) == 3
            and isinstance(node.children[0], Leaf)
            and isinstance(node.children[1], Node)
            and isinstance(node.children[2], Leaf)
            and node.children[0].value == "("
            and node.children[2].value == ")")

def is_list(node):
    """Does the node represent a list literal?"""
    return (isinstance(node, Node)
            and len(node.children) > 1
            and isinstance(node.children[0], Leaf)
            and isinstance(node.children[-1], Leaf)
            and node.children[0].value == "["
            and node.children[-1].value == "]")


###########################################################
### Misc
###########################################################

def parenthesize(node):
    """Wrap *node* in parentheses: ``atom< '(' node ')' >``."""
    return Node(syms.atom, [LParen(), node, RParen()])


# Names of builtin callables that take an arbitrary iterable argument.
# Not used in this module; presumably consulted by fixers -- TODO confirm.
consuming_calls = {"sorted", "list", "set", "any", "all", "tuple", "sum",
                   "min", "max", "enumerate"}

def attr_chain(obj, attr):
    """Follow an attribute chain.

    If you have a chain of objects where a.foo -> b, b.foo-> c, etc,
    use this to iterate over all objects in the chain. Iteration is
    terminated by getattr(x, attr) returning a falsy value (e.g. None).

    Args:
        obj: the starting object
        attr: the name of the chaining attribute

    Yields:
        Each successive object in the chain (obj itself is not yielded).
    """
    current = getattr(obj, attr)
    while current:
        yield current
        current = getattr(current, attr)

# Pattern sources for in_special_context().  They are compiled lazily on
# first use and these module globals are then rebound to the compiled
# patterns.  p0 is matched against a node's parent, p1 against its
# grandparent, p2 against its great-grandparent.
p0 = """for_stmt< 'for' any 'in' node=any ':' any* >
        | comp_for< 'for' any 'in' node=any any* >
     """
p1 = """
power<
    ( 'iter' | 'list' | 'tuple' | 'sorted' | 'set' | 'sum' |
      'any' | 'all' | 'enumerate' | (any* trailer< '.' 'join' >) )
    trailer< '(' node=any ')' >
    any*
>
"""
p2 = """
power<
    ( 'sorted' | 'enumerate' )
    trailer< '(' arglist<node=any any*> ')' >
    any*
>
"""
pats_built = False
def in_special_context(node):
    """ Returns true if node is in an environment where all that is required
        of it is being iterable (ie, it doesn't matter if it returns a list
        or an iterator).
        See test_map_nochange in test_fixers.py for some examples and tests.

        Concretely: node is the iterable of a for loop / comprehension (p0),
        the sole argument of iter/list/tuple/sorted/set/sum/any/all/
        enumerate or of a ``.join`` call (p1), or the first of several
        arguments to sorted/enumerate (p2).
        """
    global p0, p1, p2, pats_built
    if not pats_built:
        p0 = patcomp.compile_pattern(p0)
        p1 = patcomp.compile_pattern(p1)
        p2 = patcomp.compile_pattern(p2)
        pats_built = True
    patterns = [p0, p1, p2]
    # Pair pattern i with the i-th ancestor of node.
    for pattern, parent in zip(patterns, attr_chain(node, "parent")):
        results = {}
        if pattern.match(parent, results) and results["node"] is node:
            return True
    return False

def is_probably_builtin(node):
    """
    Check that something isn't an attribute or function name etc.

    Returns False when *node* is an attribute after a '.', a child of a
    def/class header, the first child of an expression statement, or an
    argument name in a parameter list; True otherwise.
    """
    prev = node.prev_sibling
    if prev is not None and prev.type == token.DOT:
        # Attribute lookup.
        return False
    parent = node.parent
    if parent.type in (syms.funcdef, syms.classdef):
        return False
    if parent.type == syms.expr_stmt and parent.children[0] is node:
        # Assignment.
        return False
    if parent.type == syms.parameters or \
            (parent.type == syms.typedargslist and (
                (prev is not None and prev.type == token.COMMA) or
                parent.children[0] is node
            )):
        # The name of an argument.
        return False
    return True

def find_indentation(node):
    """Find the indentation of *node*.

    Returns the INDENT leaf value of the nearest enclosing suite that has
    one, or "" when there is none (e.g. at module level).
    """
    while node is not None:
        if node.type == syms.suite and len(node.children) > 2:
            indent = node.children[1]
            if indent.type == token.INDENT:
                return indent.value
        node = node.parent
    return ""

###########################################################
### Helpers for finding name bindings (assignments, defs,
### imports, loop targets) within a suite
###########################################################

def make_suite(node):
    """Return *node* if it is a suite, else a new suite wrapping a clone."""
    if node.type == syms.suite:
        return node
    node = node.clone()
    # NOTE(review): clone() presumably returns a detached copy, in which
    # case parent is always None here -- confirm whether the original
    # parent was meant to be captured before cloning.
    parent, node.parent = node.parent, None
    suite = Node(syms.suite, [node])
    suite.parent = parent
    return suite

def find_root(node):
    """Find the top level namespace.

    Raises ValueError if the chain of parents ends before a file_input
    node is reached.
    """
    # Scamper up to the top level namespace
    while node.type != syms.file_input:
        node = node.parent
        if node is None:
            raise ValueError("root found before file_input node was found.")
    return node

def does_tree_import(package, name, node):
    """ Returns true if name is imported from package at the
        top level of the tree which node belongs to.
        To cover the case of an import like 'import foo', use
        None for the package and 'foo' for the name. """
    binding = find_binding(name, find_root(node), package)
    return bool(binding)

def is_import(node):
    """Returns true if the node is an import statement."""
    return node.type in (syms.import_name, syms.import_from)

def touch_import(package, name, node):
    """ Works like `does_tree_import` but adds an import statement
        if it was not imported.

        The new statement goes after the first run of top-level imports,
        else right after a leading docstring-like string statement, else
        at the very start of the file. """
    def is_import_stmt(stmt):
        return (stmt.type == syms.simple_stmt and stmt.children and
                is_import(stmt.children[0]))

    root = find_root(node)

    if does_tree_import(package, name, root):
        return

    # figure out where to insert the new import.  First try to find
    # the first import and then skip to the last one.
    insert_pos = offset = 0
    for idx, stmt in enumerate(root.children):
        if not is_import_stmt(stmt):
            continue
        for offset, following in enumerate(root.children[idx:]):
            if not is_import_stmt(following):
                break
        insert_pos = idx + offset
        break

    # if there are no imports where we can insert, find the docstring.
    # if that also fails, we stick to the beginning of the file
    if insert_pos == 0:
        for idx, stmt in enumerate(root.children):
            if (stmt.type == syms.simple_stmt and stmt.children and
                    stmt.children[0].type == token.STRING):
                insert_pos = idx + 1
                break

    if package is None:
        import_ = Node(syms.import_name, [
            Leaf(token.NAME, "import"),
            Leaf(token.NAME, name, prefix=" ")
        ])
    else:
        import_ = FromImport(package, [Leaf(token.NAME, name, prefix=" ")])

    children = [import_, Newline()]
    root.insert_child(insert_pos, Node(syms.simple_stmt, children))


def _find_binding_in_suites(name, stmt, package):
    """Search every suite of compound statement *stmt* for a binding.

    A suite is the child that follows each ':' leaf (covering if/elif/
    else and loop/else bodies).  When several suites bind *name*, the
    last one wins.  Returns the binding node or None.
    """
    ret = None
    kids = stmt.children
    for i, kid in enumerate(kids[:-1]):
        if kid.type == token.COLON:
            n = find_binding(name, make_suite(kids[i + 1]), package)
            if n:
                ret = n
    return ret

_def_syms = {syms.classdef, syms.funcdef}
def find_binding(name, node, package=None):
    """ Returns the node which binds variable name, otherwise None.
        If optional argument package is supplied, only imports will
        be returned.
        See test cases for examples."""
    for child in node.children:
        ret = None
        if child.type == syms.for_stmt:
            if _find(name, child.children[1]):
                # The loop target binds the name; it still goes through
                # the package filter below like any other binding.
                ret = child
            else:
                ret = _find_binding_in_suites(name, child, package)
        elif child.type in (syms.if_stmt, syms.while_stmt):
            ret = _find_binding_in_suites(name, child, package)
        elif child.type == syms.try_stmt:
            n = find_binding(name, make_suite(child.children[2]), package)
            if n:
                ret = n
            else:
                for i, kid in enumerate(child.children[3:]):
                    if kid.type == token.COLON and kid.value == ":":
                        # i+3 is the colon, i+4 is the suite
                        n = find_binding(name, make_suite(child.children[i + 4]),
                                         package)
                        if n:
                            ret = n
        elif child.type in _def_syms and child.children[1].value == name:
            ret = child
        elif _is_import_binding(child, name, package):
            ret = child
        elif child.type == syms.simple_stmt:
            ret = find_binding(name, child, package)
        elif child.type == syms.expr_stmt:
            if _find(name, child.children[0]):
                ret = child

        if ret:
            if not package:
                return ret
            if is_import(ret):
                return ret
    return None

_block_syms = {syms.funcdef, syms.classdef, syms.trailer}
def _find(name, node):
    """Return the first NAME leaf equal to *name* found under *node*, or None.

    Does not descend into def/class nodes or trailers (so ``x.name`` and
    names inside nested scopes are ignored).
    """
    nodes = [node]
    while nodes:
        node = nodes.pop()
        # Presumably types above 256 are grammar symbols (interior nodes).
        if node.type > 256 and node.type not in _block_syms:
            nodes.extend(node.children)
        elif node.type == token.NAME and node.value == name:
            return node
    return None

def _is_import_binding(node, name, package=None):
    """ Will return node if node will import name, or node
        will import * from package.  None is returned otherwise.
        See test cases for examples. """

    if node.type == syms.import_name and not package:
        imp = node.children[1]
        if imp.type == syms.dotted_as_names:
            for child in imp.children:
                if child.type == syms.dotted_as_name:
                    if child.children[2].value == name:
                        return node
                elif child.type == token.NAME and child.value == name:
                    return node
        elif imp.type == syms.dotted_as_name:
            last = imp.children[-1]
            if last.type == token.NAME and last.value == name:
                return node
        elif imp.type == token.NAME and imp.value == name:
            return node
    elif node.type == syms.import_from:
        # Locate the 'import' keyword: relative imports put one leaf per
        # leading dot before the module name, so its index is not fixed.
        for imp_idx, kid in enumerate(node.children):
            if kid.type == token.NAME and kid.value == "import":
                break
        else:
            return None
        # str(...) is used to make life easier here, because
        # from a.b import parses to ['import', ['a', '.', 'b'], ...]
        if package:
            from_part = "".join(str(kid) for kid in node.children[1:imp_idx])
            if from_part.strip() != package:
                return None
        n = node.children[imp_idx + 1]
        if n.type == token.LPAR:
            # from package import (a, b)
            n = node.children[imp_idx + 2]
        if package and _find("as", n):
            # See test_from_import_as for explanation
            return None
        elif n.type == syms.import_as_names and _find(name, n):
            return node
        elif n.type == syms.import_as_name:
            child = n.children[2]
            if child.type == token.NAME and child.value == name:
                return node
        elif n.type == token.NAME and n.value == name:
            return node
        elif package and n.type == token.STAR:
            return node
    return None