1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Upgrader for Python scripts according to an API change specification.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import ast 22import os 23import re 24import shutil 25import sys 26import tempfile 27import traceback 28 29import pasta 30import six 31 32# Some regular expressions we will need for parsing 33FIND_OPEN = re.compile(r"^\s*(\[).*$") 34FIND_STRING_CHARS = re.compile(r"['\"]") 35 36 37INFO = "INFO" 38WARNING = "WARNING" 39ERROR = "ERROR" 40 41 42def full_name_node(name, ctx=ast.Load()): 43 """Make an Attribute or Name node for name. 44 45 Translate a qualified name into nested Attribute nodes (and a Name node). 46 47 Args: 48 name: The name to translate to a node. 49 ctx: What context this name is used in. Defaults to Load() 50 51 Returns: 52 A Name or Attribute node. 53 """ 54 names = name.split(".") 55 names.reverse() 56 node = ast.Name(id=names.pop(), ctx=ast.Load()) 57 while names: 58 node = ast.Attribute(value=node, attr=names.pop(), ctx=ast.Load()) 59 60 # Change outermost ctx to the one given to us (inner ones should be Load). 61 node.ctx = ctx 62 return node 63 64 65def get_arg_value(node, arg_name, arg_pos=None): 66 """Get the value of an argument from a ast.Call node. 67 68 This function goes through the positional and keyword arguments to check 69 whether a given argument was used, and if so, returns its value (the node 70 representing its value). 71 72 This cannot introspect *args or **args, but it safely handles *args in 73 Python3.5+. 74 75 Args: 76 node: The ast.Call node to extract arg values from. 77 arg_name: The name of the argument to extract. 78 arg_pos: The position of the argument (in case it's passed as a positional 79 argument). 80 81 Returns: 82 A tuple (arg_present, arg_value) containing a boolean indicating whether 83 the argument is present, and its value in case it is. 84 """ 85 # Check keyword args 86 if arg_name is not None: 87 for kw in node.keywords: 88 if kw.arg == arg_name: 89 return (True, kw.value) 90 91 # Check positional args 92 if arg_pos is not None: 93 idx = 0 94 for arg in node.args: 95 if sys.version_info[:2] >= (3, 5) and isinstance(arg, ast.Starred): 96 continue # Can't parse Starred 97 if idx == arg_pos: 98 return (True, arg) 99 idx += 1 100 101 return (False, None) 102 103 104class APIChangeSpec(object): 105 """This class defines the transformations that need to happen. 106 107 This class must provide the following fields: 108 109 * `function_keyword_renames`: maps function names to a map of old -> new 110 argument names 111 * `symbol_renames`: maps function names to new function names 112 * `change_to_function`: a set of function names that have changed (for 113 notifications) 114 * `function_reorders`: maps functions whose argument order has changed to the 115 list of arguments in the new order 116 * `function_warnings`: maps full names of functions to warnings that will be 117 printed out if the function is used. (e.g. tf.nn.convolution()) 118 * `function_transformers`: maps function names to custom handlers 119 * `module_deprecations`: maps module names to warnings that will be printed 120 if the module is still used after all other transformations have run 121 122 For an example, see `TFAPIChangeSpec`. 123 """ 124 125 126class _PastaEditVisitor(ast.NodeVisitor): 127 """AST Visitor that processes function calls. 128 129 Updates function calls from old API version to new API version using a given 130 change spec. 131 """ 132 133 def __init__(self, api_change_spec): 134 self._api_change_spec = api_change_spec 135 self._log = [] # Holds 4-tuples: severity, line, col, msg. 136 self._stack = [] # Allow easy access to parents. 137 138 # Overridden to maintain a stack of nodes to allow for parent access 139 def visit(self, node): 140 self._stack.append(node) 141 super(_PastaEditVisitor, self).visit(node) 142 self._stack.pop() 143 144 @property 145 def errors(self): 146 return [log for log in self._log if log[0] == ERROR] 147 148 @property 149 def warnings(self): 150 return [log for log in self._log if log[0] == WARNING] 151 152 @property 153 def warnings_and_errors(self): 154 return [log for log in self._log if log[0] in (WARNING, ERROR)] 155 156 @property 157 def info(self): 158 return [log for log in self._log if log[0] == INFO] 159 160 @property 161 def log(self): 162 return self._log 163 164 def add_log(self, severity, lineno, col, msg): 165 self._log.append((severity, lineno, col, msg)) 166 print("%s line %d:%d: %s" % (severity, lineno, col, msg)) 167 168 def add_logs(self, logs): 169 """Record a log and print it. 170 171 The log should be a tuple `(severity, lineno, col_offset, msg)`, which will 172 be printed and recorded. It is part of the log available in the `self.log` 173 property. 174 175 Args: 176 logs: The logs to add. Must be a list of tuples 177 `(severity, lineno, col_offset, msg)`. 178 """ 179 self._log.extend(logs) 180 for log in logs: 181 print("%s line %d:%d: %s" % log) 182 183 def _get_applicable_entries(self, transformer_field, full_name, name): 184 """Get all list entries indexed by name that apply to full_name or name.""" 185 # Transformers are indexed to full name, name, or no name 186 # as a performance optimization. 187 function_transformers = getattr(self._api_change_spec, 188 transformer_field, {}) 189 190 glob_name = "*." + name if name else None 191 transformers = [] 192 if full_name in function_transformers: 193 transformers.append(function_transformers[full_name]) 194 if glob_name in function_transformers: 195 transformers.append(function_transformers[glob_name]) 196 if "*" in function_transformers: 197 transformers.append(function_transformers["*"]) 198 return transformers 199 200 def _get_applicable_dict(self, transformer_field, full_name, name): 201 """Get all dict entries indexed by name that apply to full_name or name.""" 202 # Transformers are indexed to full name, name, or no name 203 # as a performance optimization. 204 function_transformers = getattr(self._api_change_spec, 205 transformer_field, {}) 206 207 glob_name = "*." + name if name else None 208 transformers = function_transformers.get("*", {}).copy() 209 transformers.update(function_transformers.get(glob_name, {})) 210 transformers.update(function_transformers.get(full_name, {})) 211 return transformers 212 213 def _get_full_name(self, node): 214 """Traverse an Attribute node to generate a full name, e.g., "tf.foo.bar". 215 216 This is the inverse of `full_name_node`. 217 218 Args: 219 node: A Node of type Attribute. 220 221 Returns: 222 a '.'-delimited full-name or None if node was not Attribute or Name. 223 i.e. `foo()+b).bar` returns None, while `a.b.c` would return "a.b.c". 224 """ 225 curr = node 226 items = [] 227 while not isinstance(curr, ast.Name): 228 if not isinstance(curr, ast.Attribute): 229 return None 230 items.append(curr.attr) 231 curr = curr.value 232 items.append(curr.id) 233 return ".".join(reversed(items)) 234 235 def _maybe_add_warning(self, node, full_name): 236 """Adds an error to be printed about full_name at node.""" 237 function_warnings = self._api_change_spec.function_warnings 238 if full_name in function_warnings: 239 level, message = function_warnings[full_name] 240 message = message.replace("<function name>", full_name) 241 self.add_log(level, node.lineno, node.col_offset, 242 "%s requires manual check. %s" % (full_name, message)) 243 return True 244 else: 245 return False 246 247 def _maybe_add_module_deprecation_warning(self, node, full_name, whole_name): 248 """Adds a warning if full_name is a deprecated module.""" 249 warnings = self._api_change_spec.module_deprecations 250 if full_name in warnings: 251 level, message = warnings[full_name] 252 message = message.replace("<function name>", whole_name) 253 self.add_log(level, node.lineno, node.col_offset, 254 "Using member %s in deprecated module %s. %s" % (whole_name, 255 full_name, 256 message)) 257 return True 258 else: 259 return False 260 261 def _maybe_add_call_warning(self, node, full_name, name): 262 """Print a warning when specific functions are called with selected args. 263 264 The function _print_warning_for_function matches the full name of the called 265 function, e.g., tf.foo.bar(). This function matches the function name that 266 is called, as long as the function is an attribute. For example, 267 `tf.foo.bar()` and `foo.bar()` are matched, but not `bar()`. 268 269 Args: 270 node: ast.Call object 271 full_name: The precomputed full name of the callable, if one exists, None 272 otherwise. 273 name: The precomputed name of the callable, if one exists, None otherwise. 274 275 Returns: 276 Whether an error was recorded. 277 """ 278 # Only look for *.-warnings here, the other will be handled by the Attribute 279 # visitor. Also, do not warn for bare functions, only if the call func is 280 # an attribute. 281 warned = False 282 if isinstance(node.func, ast.Attribute): 283 warned = self._maybe_add_warning(node, "*." + name) 284 285 # All arg warnings are handled here, since only we have the args 286 arg_warnings = self._get_applicable_dict("function_arg_warnings", 287 full_name, name) 288 289 for (kwarg, arg), (level, warning) in sorted(arg_warnings.items()): 290 present, _ = get_arg_value(node, kwarg, arg) 291 if present: 292 warned = True 293 warning_message = warning.replace("<function name>", full_name or name) 294 self.add_log(level, node.lineno, node.col_offset, 295 "%s called with %s argument requires manual check: %s" % 296 (full_name or name, kwarg, warning_message)) 297 298 return warned 299 300 def _maybe_rename(self, parent, node, full_name): 301 """Replace node (Attribute or Name) with a node representing full_name.""" 302 new_name = self._api_change_spec.symbol_renames.get(full_name, None) 303 if new_name: 304 self.add_log(INFO, node.lineno, node.col_offset, 305 "Renamed %r to %r" % (full_name, new_name)) 306 new_node = full_name_node(new_name, node.ctx) 307 ast.copy_location(new_node, node) 308 pasta.ast_utils.replace_child(parent, node, new_node) 309 return True 310 else: 311 return False 312 313 def _maybe_change_to_function_call(self, parent, node, full_name): 314 """Wraps node (typically, an Attribute or Expr) in a Call.""" 315 if full_name in self._api_change_spec.change_to_function: 316 if not isinstance(parent, ast.Call): 317 # ast.Call's constructor is really picky about how many arguments it 318 # wants, and also, it changed between Py2 and Py3. 319 if six.PY2: 320 new_node = ast.Call(node, [], [], None, None) 321 else: 322 new_node = ast.Call(node, [], []) 323 pasta.ast_utils.replace_child(parent, node, new_node) 324 ast.copy_location(new_node, node) 325 self.add_log(INFO, node.lineno, node.col_offset, 326 "Changed %r to a function call" % full_name) 327 return True 328 return False 329 330 def _maybe_add_arg_names(self, node, full_name): 331 """Make args into keyword args if function called full_name requires it.""" 332 function_reorders = self._api_change_spec.function_reorders 333 334 if full_name in function_reorders: 335 reordered = function_reorders[full_name] 336 new_keywords = [] 337 idx = 0 338 for arg in node.args: 339 if sys.version_info[:2] >= (3, 5) and isinstance(arg, ast.Starred): 340 continue # Can't move Starred to keywords 341 keyword_arg = reordered[idx] 342 keyword = ast.keyword(arg=keyword_arg, value=arg) 343 new_keywords.append(keyword) 344 idx += 1 345 346 if new_keywords: 347 self.add_log(INFO, node.lineno, node.col_offset, 348 "Added keywords to args of function %r" % full_name) 349 node.args = [] 350 node.keywords = new_keywords + (node.keywords or []) 351 return True 352 return False 353 354 def _maybe_modify_args(self, node, full_name, name): 355 """Rename keyword args if the function called full_name requires it.""" 356 renamed_keywords = self._get_applicable_dict("function_keyword_renames", 357 full_name, name) 358 359 if not renamed_keywords: 360 return False 361 362 modified = False 363 new_keywords = [] 364 for keyword in node.keywords: 365 argkey = keyword.arg 366 if argkey in renamed_keywords: 367 modified = True 368 if renamed_keywords[argkey] is None: 369 lineno = getattr(keyword, "lineno", node.lineno) 370 col_offset = getattr(keyword, "col_offset", node.col_offset) 371 self.add_log(INFO, lineno, col_offset, 372 "Removed argument %s for function %s" % ( 373 argkey, full_name or name)) 374 else: 375 keyword.arg = renamed_keywords[argkey] 376 lineno = getattr(keyword, "lineno", node.lineno) 377 col_offset = getattr(keyword, "col_offset", node.col_offset) 378 self.add_log(INFO, lineno, col_offset, 379 "Renamed keyword argument for %s from %s to %s" % ( 380 full_name, argkey, renamed_keywords[argkey])) 381 new_keywords.append(keyword) 382 else: 383 new_keywords.append(keyword) 384 385 if modified: 386 node.keywords = new_keywords 387 return modified 388 389 def visit_Call(self, node): # pylint: disable=invalid-name 390 """Handle visiting a call node in the AST. 391 392 Args: 393 node: Current Node 394 """ 395 assert self._stack[-1] is node 396 397 # Get the name for this call, so we can index stuff with it. 398 full_name = self._get_full_name(node.func) 399 if full_name: 400 name = full_name.split(".")[-1] 401 elif isinstance(node.func, ast.Name): 402 name = node.func.id 403 elif isinstance(node.func, ast.Attribute): 404 name = node.func.attr 405 else: 406 name = None 407 408 # Call standard transformers for this node. 409 # Make sure warnings come first, since args or names triggering warnings 410 # may be removed by the other transformations. 411 self._maybe_add_call_warning(node, full_name, name) 412 # Make all args into kwargs 413 self._maybe_add_arg_names(node, full_name) 414 # Argument name changes or deletions 415 self._maybe_modify_args(node, full_name, name) 416 417 # Call transformers. These have the ability to modify the node, and if they 418 # do, will return the new node they created (or the same node if they just 419 # changed it). The are given the parent, but we will take care of 420 # integrating their changes into the parent if they return a new node. 421 # 422 # These are matched on the old name, since renaming is performed by the 423 # Attribute visitor, which happens later. 424 transformers = self._get_applicable_entries("function_transformers", 425 full_name, name) 426 427 parent = self._stack[-2] 428 429 for transformer in transformers: 430 logs = [] 431 new_node = transformer(parent, node, full_name, name, logs) 432 self.add_logs(logs) 433 if new_node and new_node is not node: 434 pasta.ast_utils.replace_child(parent, node, new_node) 435 node = new_node 436 self._stack[-1] = node 437 438 self.generic_visit(node) 439 440 def visit_Attribute(self, node): # pylint: disable=invalid-name 441 """Handle bare Attributes i.e. [tf.foo, tf.bar].""" 442 assert self._stack[-1] is node 443 444 full_name = self._get_full_name(node) 445 if full_name: 446 parent = self._stack[-2] 447 448 # Make sure the warning comes first, otherwise the name may have changed 449 self._maybe_add_warning(node, full_name) 450 451 # Once we did a modification, node is invalid and not worth inspecting 452 # further. Also, we only perform modifications for simple nodes, so 453 # There'd be no point in descending further. 454 if self._maybe_rename(parent, node, full_name): 455 return 456 if self._maybe_change_to_function_call(parent, node, full_name): 457 return 458 459 # The isinstance check is enough -- a bare Attribute is never root. 460 i = 2 461 while isinstance(self._stack[-i], ast.Attribute): 462 i += 1 463 whole_name = pasta.dump(self._stack[-(i-1)]) 464 465 self._maybe_add_module_deprecation_warning(node, full_name, whole_name) 466 467 self.generic_visit(node) 468 469 470class ASTCodeUpgrader(object): 471 """Handles upgrading a set of Python files using a given API change spec.""" 472 473 def __init__(self, api_change_spec): 474 if not isinstance(api_change_spec, APIChangeSpec): 475 raise TypeError("Must pass APIChangeSpec to ASTCodeUpgrader, got %s" % 476 type(api_change_spec)) 477 self._api_change_spec = api_change_spec 478 479 def process_file(self, in_filename, out_filename): 480 """Process the given python file for incompatible changes. 481 482 Args: 483 in_filename: filename to parse 484 out_filename: output file to write to 485 Returns: 486 A tuple representing number of files processed, log of actions, errors 487 """ 488 489 # Write to a temporary file, just in case we are doing an implace modify. 490 # pylint: disable=g-backslash-continuation 491 with open(in_filename, "r") as in_file, \ 492 tempfile.NamedTemporaryFile("w", delete=False) as temp_file: 493 ret = self.process_opened_file(in_filename, in_file, out_filename, 494 temp_file) 495 # pylint: enable=g-backslash-continuation 496 497 shutil.move(temp_file.name, out_filename) 498 return ret 499 500 def format_log(self, log, in_filename): 501 log_string = "%d:%d: %s: %s" % (log[1], log[2], log[0], log[3]) 502 if in_filename: 503 return in_filename + ":" + log_string 504 else: 505 return log_string 506 507 def update_string_pasta(self, text, in_filename): 508 """Updates a file using pasta.""" 509 try: 510 t = pasta.parse(text) 511 except (SyntaxError, ValueError, TypeError): 512 log = ["ERROR: Failed to parse.\n" + traceback.format_exc()] 513 return 0, "", log, [] 514 515 visitor = _PastaEditVisitor(self._api_change_spec) 516 visitor.visit(t) 517 518 logs = [self.format_log(log, None) for log in visitor.log] 519 errors = [self.format_log(error, in_filename) 520 for error in visitor.warnings_and_errors] 521 return 1, pasta.dump(t), logs, errors 522 523 def _format_log(self, log, in_filename, out_filename): 524 text = "-" * 80 + "\n" 525 text += "Processing file %r\n outputting to %r\n" % (in_filename, 526 out_filename) 527 text += "-" * 80 + "\n\n" 528 text += "\n".join(log) + "\n" 529 text += "-" * 80 + "\n\n" 530 return text 531 532 def process_opened_file(self, in_filename, in_file, out_filename, out_file): 533 """Process the given python file for incompatible changes. 534 535 This function is split out to facilitate StringIO testing from 536 tf_upgrade_test.py. 537 538 Args: 539 in_filename: filename to parse 540 in_file: opened file (or StringIO) 541 out_filename: output file to write to 542 out_file: opened file (or StringIO) 543 Returns: 544 A tuple representing number of files processed, log of actions, errors 545 """ 546 lines = in_file.readlines() 547 processed_file, new_file_content, log, process_errors = ( 548 self.update_string_pasta("".join(lines), in_filename)) 549 550 if out_file and processed_file: 551 out_file.write(new_file_content) 552 553 return (processed_file, 554 self._format_log(log, in_filename, out_filename), 555 process_errors) 556 557 def process_tree(self, root_directory, output_root_directory, 558 copy_other_files): 559 """Processes upgrades on an entire tree of python files in place. 560 561 Note that only Python files. If you have custom code in other languages, 562 you will need to manually upgrade those. 563 564 Args: 565 root_directory: Directory to walk and process. 566 output_root_directory: Directory to use as base. 567 copy_other_files: Copy files that are not touched by this converter. 568 569 Returns: 570 A tuple of files processed, the report string for all files, and a dict 571 mapping filenames to errors encountered in that file. 572 """ 573 574 if output_root_directory == root_directory: 575 return self.process_tree_inplace(root_directory) 576 577 # make sure output directory doesn't exist 578 if output_root_directory and os.path.exists(output_root_directory): 579 print("Output directory %r must not already exist." % 580 (output_root_directory)) 581 sys.exit(1) 582 583 # make sure output directory does not overlap with root_directory 584 norm_root = os.path.split(os.path.normpath(root_directory)) 585 norm_output = os.path.split(os.path.normpath(output_root_directory)) 586 if norm_root == norm_output: 587 print("Output directory %r same as input directory %r" % 588 (root_directory, output_root_directory)) 589 sys.exit(1) 590 591 # Collect list of files to process (we do this to correctly handle if the 592 # user puts the output directory in some sub directory of the input dir) 593 files_to_process = [] 594 files_to_copy = [] 595 for dir_name, _, file_list in os.walk(root_directory): 596 py_files = [f for f in file_list if f.endswith(".py")] 597 copy_files = [f for f in file_list if not f.endswith(".py")] 598 for filename in py_files: 599 fullpath = os.path.join(dir_name, filename) 600 fullpath_output = os.path.join(output_root_directory, 601 os.path.relpath(fullpath, 602 root_directory)) 603 files_to_process.append((fullpath, fullpath_output)) 604 if copy_other_files: 605 for filename in copy_files: 606 fullpath = os.path.join(dir_name, filename) 607 fullpath_output = os.path.join(output_root_directory, 608 os.path.relpath( 609 fullpath, root_directory)) 610 files_to_copy.append((fullpath, fullpath_output)) 611 612 file_count = 0 613 tree_errors = {} 614 report = "" 615 report += ("=" * 80) + "\n" 616 report += "Input tree: %r\n" % root_directory 617 report += ("=" * 80) + "\n" 618 619 for input_path, output_path in files_to_process: 620 output_directory = os.path.dirname(output_path) 621 if not os.path.isdir(output_directory): 622 os.makedirs(output_directory) 623 file_count += 1 624 _, l_report, l_errors = self.process_file(input_path, output_path) 625 tree_errors[input_path] = l_errors 626 report += l_report 627 for input_path, output_path in files_to_copy: 628 output_directory = os.path.dirname(output_path) 629 if not os.path.isdir(output_directory): 630 os.makedirs(output_directory) 631 shutil.copy(input_path, output_path) 632 return file_count, report, tree_errors 633 634 def process_tree_inplace(self, root_directory): 635 """Process a directory of python files in place.""" 636 files_to_process = [] 637 for dir_name, _, file_list in os.walk(root_directory): 638 py_files = [os.path.join(dir_name, 639 f) for f in file_list if f.endswith(".py")] 640 files_to_process += py_files 641 642 file_count = 0 643 tree_errors = {} 644 report = "" 645 report += ("=" * 80) + "\n" 646 report += "Input tree: %r\n" % root_directory 647 report += ("=" * 80) + "\n" 648 649 for path in files_to_process: 650 file_count += 1 651 _, l_report, l_errors = self.process_file(path, path) 652 tree_errors[path] = l_errors 653 report += l_report 654 655 return file_count, report, tree_errors 656