# Lint as: python2, python3
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Upgrader for Python scripts according to an API change specification."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import ast
import collections
import os
import re
import shutil
import sys
import tempfile
import traceback

import pasta
import six
from six.moves import range

# Some regular expressions we will need for parsing.
FIND_OPEN = re.compile(r"^\s*(\[).*$")
FIND_STRING_CHARS = re.compile(r"['\"]")


INFO = "INFO"
WARNING = "WARNING"
ERROR = "ERROR"


ImportRename = collections.namedtuple(
    "ImportRename", ["new_name", "excluded_prefixes"])


def full_name_node(name, ctx=ast.Load()):
  """Make an Attribute or Name node for name.

  Translate a qualified name into nested Attribute nodes (and a Name node).

  Args:
    name: The name to translate to a node.
    ctx: What context this name is used in. Defaults to Load().

  Returns:
    A Name or Attribute node.
  """
  names = six.ensure_str(name).split(".")
  names.reverse()
  node = ast.Name(id=names.pop(), ctx=ast.Load())
  while names:
    node = ast.Attribute(value=node, attr=names.pop(), ctx=ast.Load())

  # Change outermost ctx to the one given to us (inner ones should be Load).
  node.ctx = ctx
  return node


def get_arg_value(node, arg_name, arg_pos=None):
  """Get the value of an argument from an ast.Call node.

  This function goes through the positional and keyword arguments to check
  whether a given argument was used, and if so, returns its value (the node
  representing its value).

  This cannot introspect *args or **kwargs, but it safely handles *args in
  Python 3.5+.

  Args:
    node: The ast.Call node to extract arg values from.
    arg_name: The name of the argument to extract.
    arg_pos: The position of the argument (in case it's passed as a positional
      argument).

  Returns:
    A tuple (arg_present, arg_value) containing a boolean indicating whether
    the argument is present, and its value in case it is.
  """
  # Check keyword args.
  if arg_name is not None:
    for kw in node.keywords:
      if kw.arg == arg_name:
        return (True, kw.value)

  # Check positional args.
  if arg_pos is not None:
    idx = 0
    for arg in node.args:
      if sys.version_info[:2] >= (3, 5) and isinstance(arg, ast.Starred):
        continue  # Can't parse Starred
      if idx == arg_pos:
        return (True, arg)
      idx += 1

  return (False, None)

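
# Illustrative usage sketch (not part of the upgrader; the call text below is
# hypothetical). Given a parsed call, `get_arg_value` reports whether an
# argument was supplied, either by keyword or by position:
#
#   call = ast.parse('tf.foo(1, name="x")').body[0].value  # an ast.Call
#   get_arg_value(call, "name")           # -> (True, <node for "x">)
#   get_arg_value(call, None, arg_pos=0)  # -> (True, <node for 1>)
#   get_arg_value(call, "missing")        # -> (False, None)
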
def uses_star_args_in_call(node):
  """Check if an ast.Call node uses arbitrary-length positional *args.

  This function works with the AST call node format of Python 3.5+
  as well as the different AST format of earlier versions of Python.

  Args:
    node: The ast.Call node to check arg values for.

  Returns:
    True if the node uses starred variadic positional args.
    False if it does not.
  """
  if sys.version_info[:2] >= (3, 5):
    # Check for an *args usage in python 3.5+
    for arg in node.args:
      if isinstance(arg, ast.Starred):
        return True
  else:
    if node.starargs:
      return True
  return False


def uses_star_kwargs_in_call(node):
  """Check if an ast.Call node uses arbitrary-length **kwargs.

  This function works with the AST call node format of Python 3.5+
  as well as the different AST format of earlier versions of Python.

  Args:
    node: The ast.Call node to check arg values for.

  Returns:
    True if the node uses starred variadic keyword args (**kwargs).
    False if it does not.
  """
  if sys.version_info[:2] >= (3, 5):
    # Check for a **kwarg usage in python 3.5+
    for keyword in node.keywords:
      if keyword.arg is None:
        return True
  else:
    if node.kwargs:
      return True
  return False


def uses_star_args_or_kwargs_in_call(node):
  """Check if an ast.Call node uses arbitrary-length *args or **kwargs.

  This function works with the AST call node format of Python 3.5+
  as well as the different AST format of earlier versions of Python.

  Args:
    node: The ast.Call node to check arg values for.

  Returns:
    True if the node uses starred variadic positional args or keyword args.
    False if it does not.
  """
  return uses_star_args_in_call(node) or uses_star_kwargs_in_call(node)


def excluded_from_module_rename(module, import_rename_spec):
  """Check if this module import should not be renamed.

  Args:
    module: (string) module name.
    import_rename_spec: ImportRename instance.

  Returns:
    True if this import should not be renamed according to the
    import_rename_spec.
  """
  for excluded_prefix in import_rename_spec.excluded_prefixes:
    if module.startswith(excluded_prefix):
      return True
  return False

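
# Illustrative sketch (hypothetical rename spec; not part of the upgrader).
# An ImportRename pairs the replacement module name with prefixes that must
# be left untouched:
#
#   rename = ImportRename("tensorflow.compat.v1",
#                         excluded_prefixes=["tensorflow.compat"])
#   excluded_from_module_rename("tensorflow.compat.v2", rename)  # -> True
#   excluded_from_module_rename("tensorflow.python", rename)     # -> False
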
class APIChangeSpec(object):
  """This class defines the transformations that need to happen.

  This class must provide the following fields:

  * `function_keyword_renames`: maps function names to a map of old -> new
    argument names
  * `symbol_renames`: maps function names to new function names
  * `change_to_function`: a set of function names that have changed (for
    notifications)
  * `function_reorders`: maps functions whose argument order has changed to the
    list of arguments in the new order
  * `function_warnings`: maps full names of functions to warnings that will be
    printed out if the function is used. (e.g. tf.nn.convolution())
  * `function_transformers`: maps function names to custom handlers
  * `module_deprecations`: maps module names to warnings that will be printed
    if the module is still used after all other transformations have run
  * `import_renames`: maps import name (must be a short name without '.')
    to ImportRename instance.

  For an example, see `TFAPIChangeSpec`.
  """

  def preprocess(self, root_node):  # pylint: disable=unused-argument
    """Preprocess a parse tree. Return a preprocessed node, logs and errors."""
    return root_node, [], []

  def clear_preprocessing(self):
    """Restore this APIChangeSpec to before it preprocessed a file.

    This is needed if preprocessing a file changed any rewriting rules.
    """
    pass


class NoUpdateSpec(APIChangeSpec):
  """A specification of an API change which doesn't change anything."""

  def __init__(self):
    self.function_handle = {}
    self.function_reorders = {}
    self.function_keyword_renames = {}
    self.symbol_renames = {}
    self.function_warnings = {}
    self.change_to_function = {}
    self.module_deprecations = {}
    self.function_transformers = {}
    self.import_renames = {}


class _PastaEditVisitor(ast.NodeVisitor):
  """AST Visitor that processes function calls.

  Updates function calls from old API version to new API version using a given
  change spec.
  """

  def __init__(self, api_change_spec):
    self._api_change_spec = api_change_spec
    self._log = []  # Holds 4-tuples: severity, line, col, msg.
    self._stack = []  # Allow easy access to parents.

  # Overridden to maintain a stack of nodes to allow for parent access.
  def visit(self, node):
    self._stack.append(node)
    super(_PastaEditVisitor, self).visit(node)
    self._stack.pop()

  @property
  def errors(self):
    return [log for log in self._log if log[0] == ERROR]

  @property
  def warnings(self):
    return [log for log in self._log if log[0] == WARNING]

  @property
  def warnings_and_errors(self):
    return [log for log in self._log if log[0] in (WARNING, ERROR)]

  @property
  def info(self):
    return [log for log in self._log if log[0] == INFO]

  @property
  def log(self):
    return self._log

  def add_log(self, severity, lineno, col, msg):
    self._log.append((severity, lineno, col, msg))
    print("%s line %d:%d: %s" % (severity, lineno, col, msg))

  def add_logs(self, logs):
    """Record a list of logs and print them.

    Each log should be a tuple `(severity, lineno, col_offset, msg)`, which
    will be printed and recorded. It becomes part of the log available in the
    `self.log` property.

    Args:
      logs: The logs to add. Must be a list of tuples
        `(severity, lineno, col_offset, msg)`.
    """
    self._log.extend(logs)
    for log in logs:
      print("%s line %d:%d: %s" % log)

  def _get_applicable_entries(self, transformer_field, full_name, name):
    """Get all list entries indexed by name that apply to full_name or name."""
    # Transformers are indexed to full name, name, or no name
    # as a performance optimization.
    function_transformers = getattr(self._api_change_spec,
                                    transformer_field, {})

    glob_name = "*." + six.ensure_str(name) if name else None
    transformers = []
    if full_name in function_transformers:
      transformers.append(function_transformers[full_name])
    if glob_name in function_transformers:
      transformers.append(function_transformers[glob_name])
    if "*" in function_transformers:
      transformers.append(function_transformers["*"])
    return transformers

  def _get_applicable_dict(self, transformer_field, full_name, name):
    """Get all dict entries indexed by name that apply to full_name or name."""
    # Transformers are indexed to full name, name, or no name
    # as a performance optimization.
    function_transformers = getattr(self._api_change_spec,
                                    transformer_field, {})

    glob_name = "*." + six.ensure_str(name) if name else None
    transformers = function_transformers.get("*", {}).copy()
    transformers.update(function_transformers.get(glob_name, {}))
    transformers.update(function_transformers.get(full_name, {}))
    return transformers

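  # Illustrative sketch (hypothetical spec contents; not part of the upgrader).
  # Tables consulted through the two helpers above may be keyed by a full name,
  # by "*.<short name>", or by the catch-all "*"; for the dict-valued tables,
  # more specific keys override less specific ones:
  #
  #   spec.function_keyword_renames = {
  #       "tf.argmax": {"dimension": "axis"},          # exact full-name match
  #       "*.reduce_sum": {"keep_dims": "keepdims"},   # any call to .reduce_sum
  #   }
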
  def _get_full_name(self, node):
    """Traverse an Attribute node to generate a full name, e.g., "tf.foo.bar".

    This is the inverse of `full_name_node`.

    Args:
      node: A Node of type Attribute.

    Returns:
      a '.'-delimited full-name or None if node was not Attribute or Name.
      i.e. `(foo()+b).bar` returns None, while `a.b.c` would return "a.b.c".
    """
    curr = node
    items = []
    while not isinstance(curr, ast.Name):
      if not isinstance(curr, ast.Attribute):
        return None
      items.append(curr.attr)
      curr = curr.value
    items.append(curr.id)
    return ".".join(reversed(items))

  def _maybe_add_warning(self, node, full_name):
    """Adds a warning about full_name at node, if the spec defines one."""
    function_warnings = self._api_change_spec.function_warnings
    if full_name in function_warnings:
      level, message = function_warnings[full_name]
      message = six.ensure_str(message).replace("<function name>", full_name)
      self.add_log(level, node.lineno, node.col_offset,
                   "%s requires manual check. %s" % (full_name, message))
      return True
    else:
      return False

  def _maybe_add_module_deprecation_warning(self, node, full_name, whole_name):
    """Adds a warning if full_name is a deprecated module."""
    warnings = self._api_change_spec.module_deprecations
    if full_name in warnings:
      level, message = warnings[full_name]
      message = six.ensure_str(message).replace("<function name>",
                                                six.ensure_str(whole_name))
      self.add_log(level, node.lineno, node.col_offset,
                   "Using member %s in deprecated module %s. %s" % (whole_name,
                                                                    full_name,
                                                                    message))
      return True
    else:
      return False

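  # Illustrative sketch (hypothetical entry; not part of the upgrader).
  # `module_deprecations` maps a module prefix to a (severity, message) pair;
  # "<function name>" in the message is replaced with the full member that was
  # actually referenced:
  #
  #   spec.module_deprecations = {
  #       "tf.contrib": (ERROR, "tf.contrib is unavailable; <function name> "
  #                             "needs a manual replacement."),
  #   }
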
  def _maybe_add_call_warning(self, node, full_name, name):
    """Print a warning when specific functions are called with selected args.

    The function _maybe_add_warning matches the full name of the called
    function, e.g., tf.foo.bar(). This function matches the function name that
    is called, as long as the function is an attribute. For example,
    `tf.foo.bar()` and `foo.bar()` are matched, but not `bar()`.

    Args:
      node: ast.Call object
      full_name: The precomputed full name of the callable, if one exists, None
        otherwise.
      name: The precomputed name of the callable, if one exists, None otherwise.

    Returns:
      Whether a warning was recorded.
    """
    # Only look for *.-warnings here, the other will be handled by the Attribute
    # visitor. Also, do not warn for bare functions, only if the call func is
    # an attribute.
    warned = False
    if isinstance(node.func, ast.Attribute):
      warned = self._maybe_add_warning(node, "*." + six.ensure_str(name))

    # All arg warnings are handled here, since only we have the args.
    arg_warnings = self._get_applicable_dict("function_arg_warnings",
                                             full_name, name)

    variadic_args = uses_star_args_or_kwargs_in_call(node)

    for (kwarg, arg), (level, warning) in sorted(arg_warnings.items()):
      present, _ = get_arg_value(node, kwarg, arg)
      if present or variadic_args:
        warned = True
        warning_message = six.ensure_str(warning).replace(
            "<function name>", six.ensure_str(full_name or name))
        template = "%s called with %s argument, requires manual check: %s"
        if variadic_args:
          template = ("%s called with *args or **kwargs that may include %s, "
                      "requires manual check: %s")
        self.add_log(level, node.lineno, node.col_offset,
                     template % (full_name or name, kwarg, warning_message))

    return warned

  def _maybe_rename(self, parent, node, full_name):
    """Replace node (Attribute or Name) with a node representing full_name."""
    new_name = self._api_change_spec.symbol_renames.get(full_name, None)
    if new_name:
      self.add_log(INFO, node.lineno, node.col_offset,
                   "Renamed %r to %r" % (full_name, new_name))
      new_node = full_name_node(new_name, node.ctx)
      ast.copy_location(new_node, node)
      pasta.ast_utils.replace_child(parent, node, new_node)
      return True
    else:
      return False

  def _maybe_change_to_function_call(self, parent, node, full_name):
    """Wraps node (typically, an Attribute or Expr) in a Call."""
    if full_name in self._api_change_spec.change_to_function:
      if not isinstance(parent, ast.Call):
        # ast.Call's constructor is really picky about how many arguments it
        # wants, and also, it changed between Py2 and Py3.
        if six.PY2:
          new_node = ast.Call(node, [], [], None, None)
        else:
          new_node = ast.Call(node, [], [])
        pasta.ast_utils.replace_child(parent, node, new_node)
        ast.copy_location(new_node, node)
        self.add_log(INFO, node.lineno, node.col_offset,
                     "Changed %r to a function call" % full_name)
        return True
    return False

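  # Illustrative sketch (hypothetical entries; not part of the upgrader).
  # `symbol_renames` drives _maybe_rename and `change_to_function` drives
  # _maybe_change_to_function_call:
  #
  #   spec.symbol_renames = {"tf.log": "tf.math.log"}
  #   spec.change_to_function = {"tf.some_symbol"}  # hypothetical symbol
  #
  # With these entries, a bare `tf.log` attribute is rewritten to
  # `tf.math.log`, and a bare `tf.some_symbol` attribute is rewritten to the
  # call `tf.some_symbol()`.
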
  def _maybe_add_arg_names(self, node, full_name):
    """Make args into keyword args if function called full_name requires it."""
    function_reorders = self._api_change_spec.function_reorders

    if full_name in function_reorders:
      if uses_star_args_in_call(node):
        self.add_log(WARNING, node.lineno, node.col_offset,
                     "(Manual check required) upgrading %s may require "
                     "re-ordering the call arguments, but it was passed "
                     "variable-length positional *args. The upgrade "
                     "script cannot handle these automatically." % full_name)

      reordered = function_reorders[full_name]
      new_keywords = []
      idx = 0
      for arg in node.args:
        if sys.version_info[:2] >= (3, 5) and isinstance(arg, ast.Starred):
          continue  # Can't move Starred to keywords
        keyword_arg = reordered[idx]
        keyword = ast.keyword(arg=keyword_arg, value=arg)
        new_keywords.append(keyword)
        idx += 1

      if new_keywords:
        self.add_log(INFO, node.lineno, node.col_offset,
                     "Added keywords to args of function %r" % full_name)
        node.args = []
        node.keywords = new_keywords + (node.keywords or [])
        return True
    return False

  def _maybe_modify_args(self, node, full_name, name):
    """Rename keyword args if the function called full_name requires it."""
    renamed_keywords = self._get_applicable_dict("function_keyword_renames",
                                                 full_name, name)

    if not renamed_keywords:
      return False

    if uses_star_kwargs_in_call(node):
      self.add_log(WARNING, node.lineno, node.col_offset,
                   "(Manual check required) upgrading %s may require "
                   "renaming or removing call arguments, but it was passed "
                   "variable-length *args or **kwargs. The upgrade "
                   "script cannot handle these automatically." %
                   (full_name or name))
    modified = False
    new_keywords = []
    for keyword in node.keywords:
      argkey = keyword.arg
      if argkey in renamed_keywords:
        modified = True
        if renamed_keywords[argkey] is None:
          lineno = getattr(keyword, "lineno", node.lineno)
          col_offset = getattr(keyword, "col_offset", node.col_offset)
          self.add_log(INFO, lineno, col_offset,
                       "Removed argument %s for function %s" % (
                           argkey, full_name or name))
        else:
          keyword.arg = renamed_keywords[argkey]
          lineno = getattr(keyword, "lineno", node.lineno)
          col_offset = getattr(keyword, "col_offset", node.col_offset)
          self.add_log(INFO, lineno, col_offset,
                       "Renamed keyword argument for %s from %s to %s" % (
                           full_name, argkey, renamed_keywords[argkey]))
          new_keywords.append(keyword)
      else:
        new_keywords.append(keyword)

    if modified:
      node.keywords = new_keywords
    return modified

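  # Illustrative sketch (hypothetical entry; not part of the upgrader).
  # `function_reorders` lists a function's old positional-argument order, so
  # _maybe_add_arg_names can turn positional arguments into keyword arguments
  # before any renames are applied:
  #
  #   spec.function_reorders = {"tf.argmax": ["input", "dimension", "name"]}
  #
  # A call written as `tf.argmax(x, 1)` would then be rewritten to
  # `tf.argmax(input=x, dimension=1)`.
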
  def visit_Call(self, node):  # pylint: disable=invalid-name
    """Handle visiting a call node in the AST.

    Args:
      node: Current Node
    """
    assert self._stack[-1] is node

    # Get the name for this call, so we can index stuff with it.
    full_name = self._get_full_name(node.func)
    if full_name:
      name = full_name.split(".")[-1]
    elif isinstance(node.func, ast.Name):
      name = node.func.id
    elif isinstance(node.func, ast.Attribute):
      name = node.func.attr
    else:
      name = None

    # Call standard transformers for this node.
    # Make sure warnings come first, since args or names triggering warnings
    # may be removed by the other transformations.
    self._maybe_add_call_warning(node, full_name, name)
    # Make all args into kwargs
    self._maybe_add_arg_names(node, full_name)
    # Argument name changes or deletions
    self._maybe_modify_args(node, full_name, name)

    # Call transformers. These have the ability to modify the node, and if they
    # do, will return the new node they created (or the same node if they just
    # changed it). They are given the parent, but we will take care of
    # integrating their changes into the parent if they return a new node.
    #
    # These are matched on the old name, since renaming is performed by the
    # Attribute visitor, which happens later.
    transformers = self._get_applicable_entries("function_transformers",
                                                full_name, name)

    parent = self._stack[-2]

    if transformers:
      if uses_star_args_or_kwargs_in_call(node):
        self.add_log(WARNING, node.lineno, node.col_offset,
                     "(Manual check required) upgrading %s may require "
                     "modifying call arguments, but it was passed "
                     "variable-length *args or **kwargs. The upgrade "
                     "script cannot handle these automatically." %
                     (full_name or name))

      for transformer in transformers:
        logs = []
        new_node = transformer(parent, node, full_name, name, logs)
        self.add_logs(logs)
        if new_node and new_node is not node:
          pasta.ast_utils.replace_child(parent, node, new_node)
          node = new_node
          self._stack[-1] = node

    self.generic_visit(node)

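  # Illustrative sketch (hypothetical transformer; not part of the upgrader).
  # Entries in `function_transformers` are callables taking
  # (parent, node, full_name, name, logs); they may mutate or replace the call
  # node, append (severity, lineno, col, msg) tuples to `logs`, and return the
  # node to keep (or None to leave the call unchanged):
  #
  #   def _rename_old_arg(parent, node, full_name, name, logs):
  #     for kw in node.keywords:
  #       if kw.arg == "old_arg":      # hypothetical argument name
  #         kw.arg = "new_arg"
  #         logs.append((INFO, node.lineno, node.col_offset,
  #                      "Renamed old_arg to new_arg in %s." % full_name))
  #     return node
  #
  #   spec.function_transformers = {"tf.some_function": _rename_old_arg}
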
  def visit_Attribute(self, node):  # pylint: disable=invalid-name
    """Handle bare Attributes i.e. [tf.foo, tf.bar]."""
    assert self._stack[-1] is node

    full_name = self._get_full_name(node)
    if full_name:
      parent = self._stack[-2]

      # Make sure the warning comes first, otherwise the name may have changed.
      self._maybe_add_warning(node, full_name)

      # Once we did a modification, node is invalid and not worth inspecting
      # further. Also, we only perform modifications for simple nodes, so
      # there'd be no point in descending further.
      if self._maybe_rename(parent, node, full_name):
        return
      if self._maybe_change_to_function_call(parent, node, full_name):
        return

      # The isinstance check is enough -- a bare Attribute is never root.
      i = 2
      while isinstance(self._stack[-i], ast.Attribute):
        i += 1
      whole_name = pasta.dump(self._stack[-(i-1)])

      self._maybe_add_module_deprecation_warning(node, full_name, whole_name)

    self.generic_visit(node)

  def visit_Import(self, node):  # pylint: disable=invalid-name
    """Handle visiting an import node in the AST.

    Args:
      node: Current Node
    """
    new_aliases = []
    import_updated = False
    import_renames = getattr(self._api_change_spec, "import_renames", {})
    max_submodule_depth = getattr(self._api_change_spec, "max_submodule_depth",
                                  1)
    inserts_after_imports = getattr(self._api_change_spec,
                                    "inserts_after_imports", {})

    # This loop processes imports in the format
    # import foo as f, bar as b
    for import_alias in node.names:
      all_import_components = six.ensure_str(import_alias.name).split(".")
      # Look for a rename, starting with the longest import prefix.
      found_update = False
      for i in reversed(list(range(1, max_submodule_depth + 1))):
        import_component = all_import_components[0]
        for j in range(1, min(i, len(all_import_components))):
          import_component += "." + six.ensure_str(all_import_components[j])
        import_rename_spec = import_renames.get(import_component, None)

        if not import_rename_spec or excluded_from_module_rename(
            import_alias.name, import_rename_spec):
          continue

        new_name = (
            import_rename_spec.new_name +
            import_alias.name[len(import_component):])

        # If the current import is
        #   import foo
        # then the new import should preserve the imported name:
        #   import new_foo as foo
        # This happens when the module has just one component.
        new_asname = import_alias.asname
        if not new_asname and "." not in import_alias.name:
          new_asname = import_alias.name

        new_alias = ast.alias(name=new_name, asname=new_asname)
        new_aliases.append(new_alias)
        import_updated = True
        found_update = True

        # Insert any followup lines that should happen after this import.
        full_import = (import_alias.name, import_alias.asname)
        insert_offset = 1
        for line_to_insert in inserts_after_imports.get(full_import, []):
          assert self._stack[-1] is node
          parent = self._stack[-2]

          new_line_node = pasta.parse(line_to_insert)
          ast.copy_location(new_line_node, node)
          parent.body.insert(
              parent.body.index(node) + insert_offset, new_line_node)
          insert_offset += 1

          # Insert a newline after the import if necessary.
          old_suffix = pasta.base.formatting.get(node, "suffix")
          if old_suffix is None:
            old_suffix = os.linesep
          if os.linesep not in old_suffix:
            pasta.base.formatting.set(node, "suffix",
                                      six.ensure_str(old_suffix) + os.linesep)

          # Apply indentation to new node.
          pasta.base.formatting.set(new_line_node, "prefix",
                                    pasta.base.formatting.get(node, "prefix"))
          pasta.base.formatting.set(new_line_node, "suffix", os.linesep)
          self.add_log(
              INFO, node.lineno, node.col_offset,
              "Adding `%s` after import of %s" %
              (new_line_node, import_alias.name))
        # Found one match, stop looking.
        if found_update:
          break
      # No rename was found at any level.
      if not found_update:
        new_aliases.append(import_alias)  # no change needed

    # Replace the node if at least one import needs to be updated.
    if import_updated:
      assert self._stack[-1] is node
      parent = self._stack[-2]

      new_node = ast.Import(new_aliases)
      ast.copy_location(new_node, node)
      pasta.ast_utils.replace_child(parent, node, new_node)
      self.add_log(
          INFO, node.lineno, node.col_offset,
          "Changed import from %r to %r." %
          (pasta.dump(node), pasta.dump(new_node)))

    self.generic_visit(node)

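  # Illustrative sketch (hypothetical entry; not part of the upgrader).
  # `import_renames` maps an import prefix (up to `max_submodule_depth`
  # components) to an ImportRename:
  #
  #   spec.import_renames = {
  #       "tensorflow": ImportRename(
  #           "tensorflow.compat.v1",
  #           excluded_prefixes=["tensorflow.compat"]),
  #   }
  #
  # With this entry, `import tensorflow` becomes
  # `import tensorflow.compat.v1 as tensorflow`, while
  # `import tensorflow.compat.v2` is excluded and left unchanged.
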
  def visit_ImportFrom(self, node):  # pylint: disable=invalid-name
    """Handle visiting an import-from node in the AST.

    Args:
      node: Current Node
    """
    if not node.module:
      self.generic_visit(node)
      return

    from_import = node.module

    # Look for a rename based on the first component of the from-import,
    # i.e. based on foo in foo.bar.
    from_import_first_component = six.ensure_str(from_import).split(".")[0]
    import_renames = getattr(self._api_change_spec, "import_renames", {})
    import_rename_spec = import_renames.get(from_import_first_component, None)
    if not import_rename_spec:
      self.generic_visit(node)
      return

    # Split module aliases into the ones that require an import update
    # and those that don't. For example, if we want to rename "a" to "b"
    # unless we import "a.c" in the following:
    #   from a import c, d
    # we want to update the import for "d" but not for "c".
    updated_aliases = []
    same_aliases = []
    for import_alias in node.names:
      full_module_name = "%s.%s" % (from_import, import_alias.name)
      if excluded_from_module_rename(full_module_name, import_rename_spec):
        same_aliases.append(import_alias)
      else:
        updated_aliases.append(import_alias)

    if not updated_aliases:
      self.generic_visit(node)
      return

    assert self._stack[-1] is node
    parent = self._stack[-2]

    # Replace the first component of the from-import with the new name.
    new_from_import = (
        import_rename_spec.new_name +
        from_import[len(from_import_first_component):])
    updated_node = ast.ImportFrom(new_from_import, updated_aliases, node.level)
    ast.copy_location(updated_node, node)
    pasta.ast_utils.replace_child(parent, node, updated_node)

    # If some imports had to stay the same, add another import for them.
    additional_import_log = ""
    if same_aliases:
      same_node = ast.ImportFrom(from_import, same_aliases, node.level,
                                 col_offset=node.col_offset, lineno=node.lineno)
      ast.copy_location(same_node, node)
      parent.body.insert(parent.body.index(updated_node), same_node)
      # Apply indentation to new node.
      pasta.base.formatting.set(
          same_node, "prefix",
          pasta.base.formatting.get(updated_node, "prefix"))
      additional_import_log = " and %r" % pasta.dump(same_node)

    self.add_log(
        INFO, node.lineno, node.col_offset,
        "Changed import from %r to %r%s." %
        (pasta.dump(node),
         pasta.dump(updated_node),
         additional_import_log))

    self.generic_visit(node)


class AnalysisResult(object):
  """This class represents an analysis result and how it should be logged.

  This class must provide the following fields:

  * `log_level`: The log level to which this detection should be logged
  * `log_message`: The message that should be logged for this detection

  For an example, see `VersionedTFImport`.
  """


class APIAnalysisSpec(object):
  """This class defines how `AnalysisResult`s should be generated.

  It specifies how to map imports and symbols to `AnalysisResult`s.

  This class must provide the following fields:

  * `symbols_to_detect`: maps function names to `AnalysisResult`s
  * `imports_to_detect`: maps imports represented as (full module name, alias)
    tuples to `AnalysisResult`s

  For an example, see `TFAPIImportAnalysisSpec`.
  """

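
# Illustrative sketch (hypothetical spec; not part of the upgrader). A concrete
# AnalysisResult carries a log level and a message, and an APIAnalysisSpec maps
# symbols and imports to such results:
#
#   class _CompatV1ImportResult(AnalysisResult):
#     def __init__(self):
#       self.log_level = INFO
#       self.log_message = "Detected `import tensorflow.compat.v1`."
#
#   class _MyAnalysisSpec(APIAnalysisSpec):
#     symbols_to_detect = {}
#     imports_to_detect = {
#         ("tensorflow.compat.v1", None): _CompatV1ImportResult(),
#     }
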
class PastaAnalyzeVisitor(_PastaEditVisitor):
  """AST Visitor that looks for specific API usage without editing anything.

  This is used before any rewriting is done to detect if any symbols are used
  that require changing imports or disabling rewriting altogether.
  """

  def __init__(self, api_analysis_spec):
    super(PastaAnalyzeVisitor, self).__init__(NoUpdateSpec())
    self._api_analysis_spec = api_analysis_spec
    self._results = []  # Holds AnalysisResult objects.

  @property
  def results(self):
    return self._results

  def add_result(self, analysis_result):
    self._results.append(analysis_result)

  def visit_Attribute(self, node):  # pylint: disable=invalid-name
    """Handle bare Attributes i.e. [tf.foo, tf.bar]."""
    full_name = self._get_full_name(node)
    if full_name:
      detection = self._api_analysis_spec.symbols_to_detect.get(full_name, None)
      if detection:
        self.add_result(detection)
        self.add_log(
            detection.log_level, node.lineno, node.col_offset,
            detection.log_message)

    self.generic_visit(node)

  def visit_Import(self, node):  # pylint: disable=invalid-name
    """Handle visiting an import node in the AST.

    Args:
      node: Current Node
    """
    for import_alias in node.names:
      # Detect based on the full import name and alias.
      full_import = (import_alias.name, import_alias.asname)
      detection = (self._api_analysis_spec
                   .imports_to_detect.get(full_import, None))
      if detection:
        self.add_result(detection)
        self.add_log(
            detection.log_level, node.lineno, node.col_offset,
            detection.log_message)

    self.generic_visit(node)

  def visit_ImportFrom(self, node):  # pylint: disable=invalid-name
    """Handle visiting an import-from node in the AST.

    Args:
      node: Current Node
    """
    if not node.module:
      self.generic_visit(node)
      return

    from_import = node.module

    for import_alias in node.names:
      # Detect based on the full import name (module and alias).
      full_module_name = "%s.%s" % (from_import, import_alias.name)
      full_import = (full_module_name, import_alias.asname)
      detection = (self._api_analysis_spec
                   .imports_to_detect.get(full_import, None))
      if detection:
        self.add_result(detection)
        self.add_log(
            detection.log_level, node.lineno, node.col_offset,
            detection.log_message)

    self.generic_visit(node)


class ASTCodeUpgrader(object):
  """Handles upgrading a set of Python files using a given API change spec."""

  def __init__(self, api_change_spec):
    if not isinstance(api_change_spec, APIChangeSpec):
      raise TypeError("Must pass APIChangeSpec to ASTCodeUpgrader, got %s" %
                      type(api_change_spec))
    self._api_change_spec = api_change_spec

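  # Illustrative usage sketch (hypothetical spec class; not part of the
  # upgrader). A typical driver builds a concrete APIChangeSpec and then feeds
  # individual files or whole trees through the upgrader:
  #
  #   upgrader = ASTCodeUpgrader(MyAPIChangeSpec())
  #   count, report, errors = upgrader.process_file("old.py", "new.py")
  #   count, report, errors = upgrader.process_tree(
  #       "src/", "src_upgraded/", copy_other_files=True)
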
  def process_file(self,
                   in_filename,
                   out_filename,
                   no_change_to_outfile_on_error=False):
    """Process the given python file for incompatible changes.

    Args:
      in_filename: filename to parse
      out_filename: output file to write to
      no_change_to_outfile_on_error: whether to leave the output file
        unmodified when errors occur
    Returns:
      A tuple representing number of files processed, log of actions, errors
    """

    # Write to a temporary file, just in case we are doing an in-place modify.
    # pylint: disable=g-backslash-continuation
    with open(in_filename, "r") as in_file, \
        tempfile.NamedTemporaryFile("w", delete=False) as temp_file:
      ret = self.process_opened_file(in_filename, in_file, out_filename,
                                     temp_file)
    # pylint: enable=g-backslash-continuation

    if no_change_to_outfile_on_error and ret[0] == 0:
      os.remove(temp_file.name)
    else:
      shutil.move(temp_file.name, out_filename)
    return ret

  def format_log(self, log, in_filename):
    log_string = "%d:%d: %s: %s" % (log[1], log[2], log[0], log[3])
    if in_filename:
      return six.ensure_str(in_filename) + ":" + log_string
    else:
      return log_string

  def update_string_pasta(self, text, in_filename):
    """Updates a file using pasta."""
    try:
      t = pasta.parse(text)
    except (SyntaxError, ValueError, TypeError):
      log = ["ERROR: Failed to parse.\n" + traceback.format_exc()]
      return 0, "", log, []

    t, preprocess_logs, preprocess_errors = self._api_change_spec.preprocess(t)

    visitor = _PastaEditVisitor(self._api_change_spec)
    visitor.visit(t)

    self._api_change_spec.clear_preprocessing()

    logs = [self.format_log(log, None) for log in (preprocess_logs +
                                                   visitor.log)]
    errors = [self.format_log(error, in_filename)
              for error in (preprocess_errors +
                            visitor.warnings_and_errors)]
    return 1, pasta.dump(t), logs, errors

  def _format_log(self, log, in_filename, out_filename):
    text = six.ensure_str("-" * 80) + "\n"
    text += "Processing file %r\n outputting to %r\n" % (in_filename,
                                                         out_filename)
    text += six.ensure_str("-" * 80) + "\n\n"
    text += "\n".join(log) + "\n"
    text += six.ensure_str("-" * 80) + "\n\n"
    return text

  def process_opened_file(self, in_filename, in_file, out_filename, out_file):
    """Process the given python file for incompatible changes.

    This function is split out to facilitate StringIO testing from
    tf_upgrade_test.py.

    Args:
      in_filename: filename to parse
      in_file: opened file (or StringIO)
      out_filename: output file to write to
      out_file: opened file (or StringIO)
    Returns:
      A tuple representing number of files processed, log of actions, errors
    """
    lines = in_file.readlines()
    processed_file, new_file_content, log, process_errors = (
        self.update_string_pasta("".join(lines), in_filename))

    if out_file and processed_file:
      out_file.write(new_file_content)

    return (processed_file,
            self._format_log(log, in_filename, out_filename),
            process_errors)

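  # Illustrative sketch (hypothetical input; not part of the upgrader).
  # `update_string_pasta` can also be driven directly on a source string:
  #
  #   count, new_src, logs, errors = upgrader.update_string_pasta(
  #       "import tensorflow\n", "<string>")
  #
  # `count` is 1 on success and 0 if the source failed to parse, `new_src`
  # holds the rewritten source, and `logs`/`errors` are formatted strings.
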
  def process_tree(self, root_directory, output_root_directory,
                   copy_other_files):
    """Processes upgrades on an entire tree of python files.

    Note that only Python files are processed. If you have custom code in
    other languages, you will need to manually upgrade it.

    Args:
      root_directory: Directory to walk and process.
      output_root_directory: Directory to use as the base for the output tree.
      copy_other_files: Copy files that are not touched by this converter.

    Returns:
      A tuple of files processed, the report string for all files, and a dict
      mapping filenames to errors encountered in that file.
    """

    if output_root_directory == root_directory:
      return self.process_tree_inplace(root_directory)

    # Make sure the output directory doesn't exist.
    if output_root_directory and os.path.exists(output_root_directory):
      print("Output directory %r must not already exist." %
            (output_root_directory))
      sys.exit(1)

    # Make sure the output directory does not overlap with root_directory.
    norm_root = os.path.split(os.path.normpath(root_directory))
    norm_output = os.path.split(os.path.normpath(output_root_directory))
    if norm_root == norm_output:
      print("Output directory %r same as input directory %r" %
            (output_root_directory, root_directory))
      sys.exit(1)

    # Collect the list of files to process (we do this to correctly handle the
    # case where the user puts the output directory in some sub directory of
    # the input dir).
    files_to_process = []
    files_to_copy = []
    for dir_name, _, file_list in os.walk(root_directory):
      py_files = [f for f in file_list if six.ensure_str(f).endswith(".py")]
      copy_files = [
          f for f in file_list if not six.ensure_str(f).endswith(".py")
      ]
      for filename in py_files:
        fullpath = os.path.join(dir_name, filename)
        fullpath_output = os.path.join(output_root_directory,
                                       os.path.relpath(fullpath,
                                                       root_directory))
        files_to_process.append((fullpath, fullpath_output))
      if copy_other_files:
        for filename in copy_files:
          fullpath = os.path.join(dir_name, filename)
          fullpath_output = os.path.join(output_root_directory,
                                         os.path.relpath(
                                             fullpath, root_directory))
          files_to_copy.append((fullpath, fullpath_output))

    file_count = 0
    tree_errors = {}
    report = ""
    report += six.ensure_str(("=" * 80)) + "\n"
    report += "Input tree: %r\n" % root_directory
    report += six.ensure_str(("=" * 80)) + "\n"

    for input_path, output_path in files_to_process:
      output_directory = os.path.dirname(output_path)
      if not os.path.isdir(output_directory):
        os.makedirs(output_directory)

      if os.path.islink(input_path):
        link_target = os.readlink(input_path)
        link_target_output = os.path.join(
            output_root_directory, os.path.relpath(link_target, root_directory))
        if (link_target, link_target_output) in files_to_process:
          # Create a link to the new location of the target file.
          os.symlink(link_target_output, output_path)
        else:
          report += "Copying symlink %s without modifying its target %s\n" % (
              input_path, link_target)
          os.symlink(link_target, output_path)
        continue

      file_count += 1
      _, l_report, l_errors = self.process_file(input_path, output_path)
      tree_errors[input_path] = l_errors
      report += l_report

    for input_path, output_path in files_to_copy:
      output_directory = os.path.dirname(output_path)
      if not os.path.isdir(output_directory):
        os.makedirs(output_directory)
      shutil.copy(input_path, output_path)
    return file_count, report, tree_errors

  def process_tree_inplace(self, root_directory):
    """Process a directory of python files in place."""
    files_to_process = []
    for dir_name, _, file_list in os.walk(root_directory):
      py_files = [
          os.path.join(dir_name, f)
          for f in file_list
          if six.ensure_str(f).endswith(".py")
      ]
      files_to_process += py_files

    file_count = 0
    tree_errors = {}
    report = ""
    report += six.ensure_str(("=" * 80)) + "\n"
    report += "Input tree: %r\n" % root_directory
    report += six.ensure_str(("=" * 80)) + "\n"

    for path in files_to_process:
      if os.path.islink(path):
        report += "Skipping symlink %s.\n" % path
        continue
      file_count += 1
      _, l_report, l_errors = self.process_file(path, path)
      tree_errors[path] = l_errors
      report += l_report

    return file_count, report, tree_errors