1# Copyright 2006 Google, Inc. All Rights Reserved. 2# Licensed to PSF under a Contributor Agreement. 3 4"""Refactoring framework. 5 6Used as a main program, this can refactor any number of files and/or 7recursively descend down directories. Imported as a module, this 8provides infrastructure to write your own refactoring tool. 9""" 10 11from __future__ import with_statement 12 13__author__ = "Guido van Rossum <guido@python.org>" 14 15 16# Python imports 17import os 18import sys 19import logging 20import operator 21import collections 22import StringIO 23from itertools import chain 24 25# Local imports 26from .pgen2 import driver, tokenize, token 27from .fixer_util import find_root 28from . import pytree, pygram 29from . import btm_utils as bu 30from . import btm_matcher as bm 31 32 33def get_all_fix_names(fixer_pkg, remove_prefix=True): 34 """Return a sorted list of all available fix names in the given package.""" 35 pkg = __import__(fixer_pkg, [], [], ["*"]) 36 fixer_dir = os.path.dirname(pkg.__file__) 37 fix_names = [] 38 for name in sorted(os.listdir(fixer_dir)): 39 if name.startswith("fix_") and name.endswith(".py"): 40 if remove_prefix: 41 name = name[4:] 42 fix_names.append(name[:-3]) 43 return fix_names 44 45 46class _EveryNode(Exception): 47 pass 48 49 50def _get_head_types(pat): 51 """ Accepts a pytree Pattern Node and returns a set 52 of the pattern types which will match first. """ 53 54 if isinstance(pat, (pytree.NodePattern, pytree.LeafPattern)): 55 # NodePatters must either have no type and no content 56 # or a type and content -- so they don't get any farther 57 # Always return leafs 58 if pat.type is None: 59 raise _EveryNode 60 return set([pat.type]) 61 62 if isinstance(pat, pytree.NegatedPattern): 63 if pat.content: 64 return _get_head_types(pat.content) 65 raise _EveryNode # Negated Patterns don't have a type 66 67 if isinstance(pat, pytree.WildcardPattern): 68 # Recurse on each node in content 69 r = set() 70 for p in pat.content: 71 for x in p: 72 r.update(_get_head_types(x)) 73 return r 74 75 raise Exception("Oh no! I don't understand pattern %s" %(pat)) 76 77 78def _get_headnode_dict(fixer_list): 79 """ Accepts a list of fixers and returns a dictionary 80 of head node type --> fixer list. """ 81 head_nodes = collections.defaultdict(list) 82 every = [] 83 for fixer in fixer_list: 84 if fixer.pattern: 85 try: 86 heads = _get_head_types(fixer.pattern) 87 except _EveryNode: 88 every.append(fixer) 89 else: 90 for node_type in heads: 91 head_nodes[node_type].append(fixer) 92 else: 93 if fixer._accept_type is not None: 94 head_nodes[fixer._accept_type].append(fixer) 95 else: 96 every.append(fixer) 97 for node_type in chain(pygram.python_grammar.symbol2number.itervalues(), 98 pygram.python_grammar.tokens): 99 head_nodes[node_type].extend(every) 100 return dict(head_nodes) 101 102 103def get_fixers_from_package(pkg_name): 104 """ 105 Return the fully qualified names for fixers in the package pkg_name. 106 """ 107 return [pkg_name + "." + fix_name 108 for fix_name in get_all_fix_names(pkg_name, False)] 109 110def _identity(obj): 111 return obj 112 113if sys.version_info < (3, 0): 114 import codecs 115 _open_with_encoding = codecs.open 116 # codecs.open doesn't translate newlines sadly. 117 def _from_system_newlines(input): 118 return input.replace(u"\r\n", u"\n") 119 def _to_system_newlines(input): 120 if os.linesep != "\n": 121 return input.replace(u"\n", os.linesep) 122 else: 123 return input 124else: 125 _open_with_encoding = open 126 _from_system_newlines = _identity 127 _to_system_newlines = _identity 128 129 130def _detect_future_features(source): 131 have_docstring = False 132 gen = tokenize.generate_tokens(StringIO.StringIO(source).readline) 133 def advance(): 134 tok = gen.next() 135 return tok[0], tok[1] 136 ignore = frozenset((token.NEWLINE, tokenize.NL, token.COMMENT)) 137 features = set() 138 try: 139 while True: 140 tp, value = advance() 141 if tp in ignore: 142 continue 143 elif tp == token.STRING: 144 if have_docstring: 145 break 146 have_docstring = True 147 elif tp == token.NAME and value == u"from": 148 tp, value = advance() 149 if tp != token.NAME or value != u"__future__": 150 break 151 tp, value = advance() 152 if tp != token.NAME or value != u"import": 153 break 154 tp, value = advance() 155 if tp == token.OP and value == u"(": 156 tp, value = advance() 157 while tp == token.NAME: 158 features.add(value) 159 tp, value = advance() 160 if tp != token.OP or value != u",": 161 break 162 tp, value = advance() 163 else: 164 break 165 except StopIteration: 166 pass 167 return frozenset(features) 168 169 170class FixerError(Exception): 171 """A fixer could not be loaded.""" 172 173 174class RefactoringTool(object): 175 176 _default_options = {"print_function" : False, 177 "write_unchanged_files" : False} 178 179 CLASS_PREFIX = "Fix" # The prefix for fixer classes 180 FILE_PREFIX = "fix_" # The prefix for modules with a fixer within 181 182 def __init__(self, fixer_names, options=None, explicit=None): 183 """Initializer. 184 185 Args: 186 fixer_names: a list of fixers to import 187 options: a dict with configuration. 188 explicit: a list of fixers to run even if they are explicit. 189 """ 190 self.fixers = fixer_names 191 self.explicit = explicit or [] 192 self.options = self._default_options.copy() 193 if options is not None: 194 self.options.update(options) 195 if self.options["print_function"]: 196 self.grammar = pygram.python_grammar_no_print_statement 197 else: 198 self.grammar = pygram.python_grammar 199 # When this is True, the refactor*() methods will call write_file() for 200 # files processed even if they were not changed during refactoring. If 201 # and only if the refactor method's write parameter was True. 202 self.write_unchanged_files = self.options.get("write_unchanged_files") 203 self.errors = [] 204 self.logger = logging.getLogger("RefactoringTool") 205 self.fixer_log = [] 206 self.wrote = False 207 self.driver = driver.Driver(self.grammar, 208 convert=pytree.convert, 209 logger=self.logger) 210 self.pre_order, self.post_order = self.get_fixers() 211 212 213 self.files = [] # List of files that were or should be modified 214 215 self.BM = bm.BottomMatcher() 216 self.bmi_pre_order = [] # Bottom Matcher incompatible fixers 217 self.bmi_post_order = [] 218 219 for fixer in chain(self.post_order, self.pre_order): 220 if fixer.BM_compatible: 221 self.BM.add_fixer(fixer) 222 # remove fixers that will be handled by the bottom-up 223 # matcher 224 elif fixer in self.pre_order: 225 self.bmi_pre_order.append(fixer) 226 elif fixer in self.post_order: 227 self.bmi_post_order.append(fixer) 228 229 self.bmi_pre_order_heads = _get_headnode_dict(self.bmi_pre_order) 230 self.bmi_post_order_heads = _get_headnode_dict(self.bmi_post_order) 231 232 233 234 def get_fixers(self): 235 """Inspects the options to load the requested patterns and handlers. 236 237 Returns: 238 (pre_order, post_order), where pre_order is the list of fixers that 239 want a pre-order AST traversal, and post_order is the list that want 240 post-order traversal. 241 """ 242 pre_order_fixers = [] 243 post_order_fixers = [] 244 for fix_mod_path in self.fixers: 245 mod = __import__(fix_mod_path, {}, {}, ["*"]) 246 fix_name = fix_mod_path.rsplit(".", 1)[-1] 247 if fix_name.startswith(self.FILE_PREFIX): 248 fix_name = fix_name[len(self.FILE_PREFIX):] 249 parts = fix_name.split("_") 250 class_name = self.CLASS_PREFIX + "".join([p.title() for p in parts]) 251 try: 252 fix_class = getattr(mod, class_name) 253 except AttributeError: 254 raise FixerError("Can't find %s.%s" % (fix_name, class_name)) 255 fixer = fix_class(self.options, self.fixer_log) 256 if fixer.explicit and self.explicit is not True and \ 257 fix_mod_path not in self.explicit: 258 self.log_message("Skipping optional fixer: %s", fix_name) 259 continue 260 261 self.log_debug("Adding transformation: %s", fix_name) 262 if fixer.order == "pre": 263 pre_order_fixers.append(fixer) 264 elif fixer.order == "post": 265 post_order_fixers.append(fixer) 266 else: 267 raise FixerError("Illegal fixer order: %r" % fixer.order) 268 269 key_func = operator.attrgetter("run_order") 270 pre_order_fixers.sort(key=key_func) 271 post_order_fixers.sort(key=key_func) 272 return (pre_order_fixers, post_order_fixers) 273 274 def log_error(self, msg, *args, **kwds): 275 """Called when an error occurs.""" 276 raise 277 278 def log_message(self, msg, *args): 279 """Hook to log a message.""" 280 if args: 281 msg = msg % args 282 self.logger.info(msg) 283 284 def log_debug(self, msg, *args): 285 if args: 286 msg = msg % args 287 self.logger.debug(msg) 288 289 def print_output(self, old_text, new_text, filename, equal): 290 """Called with the old version, new version, and filename of a 291 refactored file.""" 292 pass 293 294 def refactor(self, items, write=False, doctests_only=False): 295 """Refactor a list of files and directories.""" 296 297 for dir_or_file in items: 298 if os.path.isdir(dir_or_file): 299 self.refactor_dir(dir_or_file, write, doctests_only) 300 else: 301 self.refactor_file(dir_or_file, write, doctests_only) 302 303 def refactor_dir(self, dir_name, write=False, doctests_only=False): 304 """Descends down a directory and refactor every Python file found. 305 306 Python files are assumed to have a .py extension. 307 308 Files and subdirectories starting with '.' are skipped. 309 """ 310 py_ext = os.extsep + "py" 311 for dirpath, dirnames, filenames in os.walk(dir_name): 312 self.log_debug("Descending into %s", dirpath) 313 dirnames.sort() 314 filenames.sort() 315 for name in filenames: 316 if (not name.startswith(".") and 317 os.path.splitext(name)[1] == py_ext): 318 fullname = os.path.join(dirpath, name) 319 self.refactor_file(fullname, write, doctests_only) 320 # Modify dirnames in-place to remove subdirs with leading dots 321 dirnames[:] = [dn for dn in dirnames if not dn.startswith(".")] 322 323 def _read_python_source(self, filename): 324 """ 325 Do our best to decode a Python source file correctly. 326 """ 327 try: 328 f = open(filename, "rb") 329 except IOError as err: 330 self.log_error("Can't open %s: %s", filename, err) 331 return None, None 332 try: 333 encoding = tokenize.detect_encoding(f.readline)[0] 334 finally: 335 f.close() 336 with _open_with_encoding(filename, "r", encoding=encoding) as f: 337 return _from_system_newlines(f.read()), encoding 338 339 def refactor_file(self, filename, write=False, doctests_only=False): 340 """Refactors a file.""" 341 input, encoding = self._read_python_source(filename) 342 if input is None: 343 # Reading the file failed. 344 return 345 input += u"\n" # Silence certain parse errors 346 if doctests_only: 347 self.log_debug("Refactoring doctests in %s", filename) 348 output = self.refactor_docstring(input, filename) 349 if self.write_unchanged_files or output != input: 350 self.processed_file(output, filename, input, write, encoding) 351 else: 352 self.log_debug("No doctest changes in %s", filename) 353 else: 354 tree = self.refactor_string(input, filename) 355 if self.write_unchanged_files or (tree and tree.was_changed): 356 # The [:-1] is to take off the \n we added earlier 357 self.processed_file(unicode(tree)[:-1], filename, 358 write=write, encoding=encoding) 359 else: 360 self.log_debug("No changes in %s", filename) 361 362 def refactor_string(self, data, name): 363 """Refactor a given input string. 364 365 Args: 366 data: a string holding the code to be refactored. 367 name: a human-readable name for use in error/log messages. 368 369 Returns: 370 An AST corresponding to the refactored input stream; None if 371 there were errors during the parse. 372 """ 373 features = _detect_future_features(data) 374 if "print_function" in features: 375 self.driver.grammar = pygram.python_grammar_no_print_statement 376 try: 377 tree = self.driver.parse_string(data) 378 except Exception as err: 379 self.log_error("Can't parse %s: %s: %s", 380 name, err.__class__.__name__, err) 381 return 382 finally: 383 self.driver.grammar = self.grammar 384 tree.future_features = features 385 self.log_debug("Refactoring %s", name) 386 self.refactor_tree(tree, name) 387 return tree 388 389 def refactor_stdin(self, doctests_only=False): 390 input = sys.stdin.read() 391 if doctests_only: 392 self.log_debug("Refactoring doctests in stdin") 393 output = self.refactor_docstring(input, "<stdin>") 394 if self.write_unchanged_files or output != input: 395 self.processed_file(output, "<stdin>", input) 396 else: 397 self.log_debug("No doctest changes in stdin") 398 else: 399 tree = self.refactor_string(input, "<stdin>") 400 if self.write_unchanged_files or (tree and tree.was_changed): 401 self.processed_file(unicode(tree), "<stdin>", input) 402 else: 403 self.log_debug("No changes in stdin") 404 405 def refactor_tree(self, tree, name): 406 """Refactors a parse tree (modifying the tree in place). 407 408 For compatible patterns the bottom matcher module is 409 used. Otherwise the tree is traversed node-to-node for 410 matches. 411 412 Args: 413 tree: a pytree.Node instance representing the root of the tree 414 to be refactored. 415 name: a human-readable name for this tree. 416 417 Returns: 418 True if the tree was modified, False otherwise. 419 """ 420 421 for fixer in chain(self.pre_order, self.post_order): 422 fixer.start_tree(tree, name) 423 424 #use traditional matching for the incompatible fixers 425 self.traverse_by(self.bmi_pre_order_heads, tree.pre_order()) 426 self.traverse_by(self.bmi_post_order_heads, tree.post_order()) 427 428 # obtain a set of candidate nodes 429 match_set = self.BM.run(tree.leaves()) 430 431 while any(match_set.values()): 432 for fixer in self.BM.fixers: 433 if fixer in match_set and match_set[fixer]: 434 #sort by depth; apply fixers from bottom(of the AST) to top 435 match_set[fixer].sort(key=pytree.Base.depth, reverse=True) 436 437 if fixer.keep_line_order: 438 #some fixers(eg fix_imports) must be applied 439 #with the original file's line order 440 match_set[fixer].sort(key=pytree.Base.get_lineno) 441 442 for node in list(match_set[fixer]): 443 if node in match_set[fixer]: 444 match_set[fixer].remove(node) 445 446 try: 447 find_root(node) 448 except ValueError: 449 # this node has been cut off from a 450 # previous transformation ; skip 451 continue 452 453 if node.fixers_applied and fixer in node.fixers_applied: 454 # do not apply the same fixer again 455 continue 456 457 results = fixer.match(node) 458 459 if results: 460 new = fixer.transform(node, results) 461 if new is not None: 462 node.replace(new) 463 #new.fixers_applied.append(fixer) 464 for node in new.post_order(): 465 # do not apply the fixer again to 466 # this or any subnode 467 if not node.fixers_applied: 468 node.fixers_applied = [] 469 node.fixers_applied.append(fixer) 470 471 # update the original match set for 472 # the added code 473 new_matches = self.BM.run(new.leaves()) 474 for fxr in new_matches: 475 if not fxr in match_set: 476 match_set[fxr]=[] 477 478 match_set[fxr].extend(new_matches[fxr]) 479 480 for fixer in chain(self.pre_order, self.post_order): 481 fixer.finish_tree(tree, name) 482 return tree.was_changed 483 484 def traverse_by(self, fixers, traversal): 485 """Traverse an AST, applying a set of fixers to each node. 486 487 This is a helper method for refactor_tree(). 488 489 Args: 490 fixers: a list of fixer instances. 491 traversal: a generator that yields AST nodes. 492 493 Returns: 494 None 495 """ 496 if not fixers: 497 return 498 for node in traversal: 499 for fixer in fixers[node.type]: 500 results = fixer.match(node) 501 if results: 502 new = fixer.transform(node, results) 503 if new is not None: 504 node.replace(new) 505 node = new 506 507 def processed_file(self, new_text, filename, old_text=None, write=False, 508 encoding=None): 509 """ 510 Called when a file has been refactored and there may be changes. 511 """ 512 self.files.append(filename) 513 if old_text is None: 514 old_text = self._read_python_source(filename)[0] 515 if old_text is None: 516 return 517 equal = old_text == new_text 518 self.print_output(old_text, new_text, filename, equal) 519 if equal: 520 self.log_debug("No changes to %s", filename) 521 if not self.write_unchanged_files: 522 return 523 if write: 524 self.write_file(new_text, filename, old_text, encoding) 525 else: 526 self.log_debug("Not writing changes to %s", filename) 527 528 def write_file(self, new_text, filename, old_text, encoding=None): 529 """Writes a string to a file. 530 531 It first shows a unified diff between the old text and the new text, and 532 then rewrites the file; the latter is only done if the write option is 533 set. 534 """ 535 try: 536 f = _open_with_encoding(filename, "w", encoding=encoding) 537 except os.error as err: 538 self.log_error("Can't create %s: %s", filename, err) 539 return 540 try: 541 f.write(_to_system_newlines(new_text)) 542 except os.error as err: 543 self.log_error("Can't write %s: %s", filename, err) 544 finally: 545 f.close() 546 self.log_debug("Wrote changes to %s", filename) 547 self.wrote = True 548 549 PS1 = ">>> " 550 PS2 = "... " 551 552 def refactor_docstring(self, input, filename): 553 """Refactors a docstring, looking for doctests. 554 555 This returns a modified version of the input string. It looks 556 for doctests, which start with a ">>>" prompt, and may be 557 continued with "..." prompts, as long as the "..." is indented 558 the same as the ">>>". 559 560 (Unfortunately we can't use the doctest module's parser, 561 since, like most parsers, it is not geared towards preserving 562 the original source.) 563 """ 564 result = [] 565 block = None 566 block_lineno = None 567 indent = None 568 lineno = 0 569 for line in input.splitlines(True): 570 lineno += 1 571 if line.lstrip().startswith(self.PS1): 572 if block is not None: 573 result.extend(self.refactor_doctest(block, block_lineno, 574 indent, filename)) 575 block_lineno = lineno 576 block = [line] 577 i = line.find(self.PS1) 578 indent = line[:i] 579 elif (indent is not None and 580 (line.startswith(indent + self.PS2) or 581 line == indent + self.PS2.rstrip() + u"\n")): 582 block.append(line) 583 else: 584 if block is not None: 585 result.extend(self.refactor_doctest(block, block_lineno, 586 indent, filename)) 587 block = None 588 indent = None 589 result.append(line) 590 if block is not None: 591 result.extend(self.refactor_doctest(block, block_lineno, 592 indent, filename)) 593 return u"".join(result) 594 595 def refactor_doctest(self, block, lineno, indent, filename): 596 """Refactors one doctest. 597 598 A doctest is given as a block of lines, the first of which starts 599 with ">>>" (possibly indented), while the remaining lines start 600 with "..." (identically indented). 601 602 """ 603 try: 604 tree = self.parse_block(block, lineno, indent) 605 except Exception as err: 606 if self.logger.isEnabledFor(logging.DEBUG): 607 for line in block: 608 self.log_debug("Source: %s", line.rstrip(u"\n")) 609 self.log_error("Can't parse docstring in %s line %s: %s: %s", 610 filename, lineno, err.__class__.__name__, err) 611 return block 612 if self.refactor_tree(tree, filename): 613 new = unicode(tree).splitlines(True) 614 # Undo the adjustment of the line numbers in wrap_toks() below. 615 clipped, new = new[:lineno-1], new[lineno-1:] 616 assert clipped == [u"\n"] * (lineno-1), clipped 617 if not new[-1].endswith(u"\n"): 618 new[-1] += u"\n" 619 block = [indent + self.PS1 + new.pop(0)] 620 if new: 621 block += [indent + self.PS2 + line for line in new] 622 return block 623 624 def summarize(self): 625 if self.wrote: 626 were = "were" 627 else: 628 were = "need to be" 629 if not self.files: 630 self.log_message("No files %s modified.", were) 631 else: 632 self.log_message("Files that %s modified:", were) 633 for file in self.files: 634 self.log_message(file) 635 if self.fixer_log: 636 self.log_message("Warnings/messages while refactoring:") 637 for message in self.fixer_log: 638 self.log_message(message) 639 if self.errors: 640 if len(self.errors) == 1: 641 self.log_message("There was 1 error:") 642 else: 643 self.log_message("There were %d errors:", len(self.errors)) 644 for msg, args, kwds in self.errors: 645 self.log_message(msg, *args, **kwds) 646 647 def parse_block(self, block, lineno, indent): 648 """Parses a block into a tree. 649 650 This is necessary to get correct line number / offset information 651 in the parser diagnostics and embedded into the parse tree. 652 """ 653 tree = self.driver.parse_tokens(self.wrap_toks(block, lineno, indent)) 654 tree.future_features = frozenset() 655 return tree 656 657 def wrap_toks(self, block, lineno, indent): 658 """Wraps a tokenize stream to systematically modify start/end.""" 659 tokens = tokenize.generate_tokens(self.gen_lines(block, indent).next) 660 for type, value, (line0, col0), (line1, col1), line_text in tokens: 661 line0 += lineno - 1 662 line1 += lineno - 1 663 # Don't bother updating the columns; this is too complicated 664 # since line_text would also have to be updated and it would 665 # still break for tokens spanning lines. Let the user guess 666 # that the column numbers for doctests are relative to the 667 # end of the prompt string (PS1 or PS2). 668 yield type, value, (line0, col0), (line1, col1), line_text 669 670 671 def gen_lines(self, block, indent): 672 """Generates lines as expected by tokenize from a list of lines. 673 674 This strips the first len(indent + self.PS1) characters off each line. 675 """ 676 prefix1 = indent + self.PS1 677 prefix2 = indent + self.PS2 678 prefix = prefix1 679 for line in block: 680 if line.startswith(prefix): 681 yield line[len(prefix):] 682 elif line == prefix.rstrip() + u"\n": 683 yield u"\n" 684 else: 685 raise AssertionError("line=%r, prefix=%r" % (line, prefix)) 686 prefix = prefix2 687 while True: 688 yield "" 689 690 691class MultiprocessingUnsupported(Exception): 692 pass 693 694 695class MultiprocessRefactoringTool(RefactoringTool): 696 697 def __init__(self, *args, **kwargs): 698 super(MultiprocessRefactoringTool, self).__init__(*args, **kwargs) 699 self.queue = None 700 self.output_lock = None 701 702 def refactor(self, items, write=False, doctests_only=False, 703 num_processes=1): 704 if num_processes == 1: 705 return super(MultiprocessRefactoringTool, self).refactor( 706 items, write, doctests_only) 707 try: 708 import multiprocessing 709 except ImportError: 710 raise MultiprocessingUnsupported 711 if self.queue is not None: 712 raise RuntimeError("already doing multiple processes") 713 self.queue = multiprocessing.JoinableQueue() 714 self.output_lock = multiprocessing.Lock() 715 processes = [multiprocessing.Process(target=self._child) 716 for i in xrange(num_processes)] 717 try: 718 for p in processes: 719 p.start() 720 super(MultiprocessRefactoringTool, self).refactor(items, write, 721 doctests_only) 722 finally: 723 self.queue.join() 724 for i in xrange(num_processes): 725 self.queue.put(None) 726 for p in processes: 727 if p.is_alive(): 728 p.join() 729 self.queue = None 730 731 def _child(self): 732 task = self.queue.get() 733 while task is not None: 734 args, kwargs = task 735 try: 736 super(MultiprocessRefactoringTool, self).refactor_file( 737 *args, **kwargs) 738 finally: 739 self.queue.task_done() 740 task = self.queue.get() 741 742 def refactor_file(self, *args, **kwargs): 743 if self.queue is not None: 744 self.queue.put((args, kwargs)) 745 else: 746 return super(MultiprocessRefactoringTool, self).refactor_file( 747 *args, **kwargs) 748