• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Upgrader for Python scripts according to an API change specification."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import ast
22import os
23import re
24import shutil
25import sys
26import tempfile
27import traceback
28
29import pasta
30import six
31
# Regular expressions used while parsing source text.
# FIND_OPEN: text whose first non-whitespace character is "[" (group 1).
FIND_OPEN = re.compile(r"^\s*(\[).*$")
# FIND_STRING_CHARS: a single- or double-quote character.
FIND_STRING_CHARS = re.compile(r"['\"]")

# Severity levels attached to the upgrader's log entries.
INFO, WARNING, ERROR = "INFO", "WARNING", "ERROR"
40
41
def full_name_node(name, ctx=ast.Load()):
  """Builds the AST node that references the dotted name `name`.

  A qualified name such as "a.b.c" becomes nested Attribute nodes wrapped
  around a leading Name node; an unqualified name becomes a single Name node.
  This is the inverse of `_PastaEditVisitor._get_full_name`.

  Args:
    name: The (possibly dotted) name to build a node for.
    ctx: Expression context for the outermost node. Defaults to Load().

  Returns:
    A Name node (for an unqualified name) or an Attribute node.
  """
  parts = name.split(".")
  # Inner nodes are only ever read from, so they always get a Load context.
  node = ast.Name(id=parts[0], ctx=ast.Load())
  for attr in parts[1:]:
    node = ast.Attribute(value=node, attr=attr, ctx=ast.Load())
  # Only the outermost node takes the caller-supplied context.
  node.ctx = ctx
  return node
63
64
def get_arg_value(node, arg_name, arg_pos=None):
  """Looks up the value passed for one argument of an ast.Call node.

  The argument is searched for first among the keyword arguments (by
  `arg_name`), then among the positional arguments (by `arg_pos`).

  Only explicitly passed arguments can be found: the contents of `*args` and
  `**kwargs` are not inspected. On Python 3.5+, starred positional arguments
  are skipped and do not count towards positions, so they never cause an
  error.

  Args:
    node: The ast.Call node to extract arg values from.
    arg_name: The keyword name of the argument, or None to skip the keyword
      search.
    arg_pos: The position of the argument when passed positionally, or None to
      skip the positional search.

  Returns:
    A tuple `(arg_present, arg_value)`: whether the argument was found, and the
    AST node of its value if so (None otherwise).
  """
  missing = object()

  if arg_name is not None:
    keyword_value = next(
        (kw.value for kw in node.keywords if kw.arg == arg_name), missing)
    if keyword_value is not missing:
      return (True, keyword_value)

  if arg_pos is not None:
    skip_starred = sys.version_info[:2] >= (3, 5)
    explicit_args = (arg for arg in node.args
                     if not (skip_starred and isinstance(arg, ast.Starred)))
    for position, arg in enumerate(explicit_args):
      if position == arg_pos:
        return (True, arg)

  return (False, None)
102
103
class APIChangeSpec(object):
  """This class defines the transformations that need to happen.

  This class must provide the following fields:

  * `function_keyword_renames`: maps function names to a map of old -> new
    argument names. Mapping an old name to `None` removes that argument from
    the call.
  * `symbol_renames`: maps full names of symbols (e.g. "tf.foo.bar") to the
    full names that replace them.
  * `change_to_function`: a set of full names of symbols that must now be
    called; a reference to one that is not already called is wrapped in a
    call, e.g. `tf.foo` becomes `tf.foo()`.
  * `function_reorders`: maps full names of functions whose argument order has
    changed to the list of argument names, in the order that callers written
    against the old API pass them positionally. Those positional arguments
    are turned into keyword arguments so that they survive the reorder.
  * `function_warnings`: maps names to `(level, message)` tuples that are
    logged if the name is used. Keys are full names (e.g.
    "tf.nn.convolution"), or "*.name" to match a call of any attribute called
    `name`.
  * `function_arg_warnings`: maps function names to dicts of
    `(keyword_name, position) -> (level, message)`; the message is logged when
    a call passes that argument, either by keyword or at that position.
  * `function_transformers`: maps function names to custom handlers, called
    as `handler(parent, node, full_name, name, logs)`. A handler may append
    log tuples to `logs` and may return a replacement node.
  * `module_deprecations`: maps module names to `(level, message)` tuples that
    are logged if the module is still used after all other transformations
    have run.

  Keys of `function_keyword_renames`, `function_arg_warnings` and
  `function_transformers` may be a full name (e.g. "tf.foo.bar"), "*.bar" to
  match any call whose last name component is `bar`, or "*" to match every
  call. Warning messages may contain the placeholder "<function name>", which
  is replaced by the name of the symbol being reported.

  For an example, see `TFAPIChangeSpec`.
  """
124
125
class _PastaEditVisitor(ast.NodeVisitor):
  """AST Visitor that processes function calls.

  Updates function calls from old API version to new API version using a given
  change spec. Nodes are edited in place; every action taken (and every
  construct that needs a manual check) is recorded as a
  `(severity, lineno, col_offset, message)` tuple in `self.log` and printed.
  """

  def __init__(self, api_change_spec):
    """Creates a visitor.

    Args:
      api_change_spec: The `APIChangeSpec` whose fields drive the edits.
    """
    self._api_change_spec = api_change_spec
    self._log = []   # Holds 4-tuples: severity, line, col, msg.
    self._stack = []  # Allow easy access to parents.

  # Overridden to maintain a stack of nodes to allow for parent access
  def visit(self, node):
    """Visits `node` while it sits on top of `self._stack`."""
    self._stack.append(node)
    super(_PastaEditVisitor, self).visit(node)
    self._stack.pop()

  @property
  def errors(self):
    """Log entries with severity ERROR."""
    return [log for log in self._log if log[0] == ERROR]

  @property
  def warnings(self):
    """Log entries with severity WARNING."""
    return [log for log in self._log if log[0] == WARNING]

  @property
  def warnings_and_errors(self):
    """Log entries with severity WARNING or ERROR."""
    return [log for log in self._log if log[0] in (WARNING, ERROR)]

  @property
  def info(self):
    """Log entries with severity INFO."""
    return [log for log in self._log if log[0] == INFO]

  @property
  def log(self):
    """All log entries, in the order they were recorded."""
    return self._log

  def add_log(self, severity, lineno, col, msg):
    """Records a single log entry and prints it.

    Args:
      severity: Severity string, e.g. INFO, WARNING or ERROR.
      lineno: Line number the entry refers to.
      col: Column offset the entry refers to.
      msg: The message text.
    """
    self._log.append((severity, lineno, col, msg))
    print("%s line %d:%d: %s" % (severity, lineno, col, msg))

  def add_logs(self, logs):
    """Record several logs and print them.

    Each log should be a tuple `(severity, lineno, col_offset, msg)`, which
    will be printed and recorded. It is part of the log available in the
    `self.log` property.

    Args:
      logs: The logs to add. Must be a list of tuples
        `(severity, lineno, col_offset, msg)`.
    """
    self._log.extend(logs)
    for log in logs:
      print("%s line %d:%d: %s" % log)

  def _get_applicable_entries(self, transformer_field, full_name, name):
    """Get all list entries indexed by name that apply to full_name or name.

    Args:
      transformer_field: Name of the spec field to read; a missing field is
        treated as empty.
      full_name: Full name of the callable, or None.
      name: Last name component of the callable, or None.

    Returns:
      The entries stored under `full_name`, "*.name" and "*", in that order.
    """
    # Transformers are indexed to full name, name, or no name
    # as a performance optimization.
    function_transformers = getattr(self._api_change_spec,
                                    transformer_field, {})

    glob_name = "*." + name if name else None
    transformers = []
    if full_name in function_transformers:
      transformers.append(function_transformers[full_name])
    if glob_name in function_transformers:
      transformers.append(function_transformers[glob_name])
    if "*" in function_transformers:
      transformers.append(function_transformers["*"])
    return transformers

  def _get_applicable_dict(self, transformer_field, full_name, name):
    """Get all dict entries indexed by name that apply to full_name or name.

    The dicts stored under "*", "*.name" and `full_name` are merged, with later
    ones taking precedence (so the full-name entry wins).

    Args:
      transformer_field: Name of the spec field to read; a missing field is
        treated as empty.
      full_name: Full name of the callable, or None.
      name: Last name component of the callable, or None.

    Returns:
      A new dict with the merged entries.
    """
    # Transformers are indexed to full name, name, or no name
    # as a performance optimization.
    function_transformers = getattr(self._api_change_spec,
                                    transformer_field, {})

    glob_name = "*." + name if name else None
    transformers = function_transformers.get("*", {}).copy()
    transformers.update(function_transformers.get(glob_name, {}))
    transformers.update(function_transformers.get(full_name, {}))
    return transformers

  def _get_full_name(self, node):
    """Traverse an Attribute node to generate a full name, e.g., "tf.foo.bar".

    This is the inverse of `full_name_node`.

    Args:
      node: A Node of type Attribute.

    Returns:
      a '.'-delimited full-name or None if node was not Attribute or Name.
      i.e. `(foo() + b).bar` returns None, while `a.b.c` would return "a.b.c".
    """
    curr = node
    items = []
    while not isinstance(curr, ast.Name):
      if not isinstance(curr, ast.Attribute):
        return None
      items.append(curr.attr)
      curr = curr.value
    items.append(curr.id)
    return ".".join(reversed(items))

  def _maybe_add_warning(self, node, full_name):
    """Logs the spec's `function_warnings` entry for full_name, if any.

    Args:
      node: The node the warning is reported at.
      full_name: The key to look up (a full name or a "*.name" pattern).

    Returns:
      Whether a warning was logged.
    """
    function_warnings = self._api_change_spec.function_warnings
    if full_name in function_warnings:
      level, message = function_warnings[full_name]
      message = message.replace("<function name>", full_name)
      self.add_log(level, node.lineno, node.col_offset,
                   "%s requires manual check. %s" % (full_name, message))
      return True
    else:
      return False

  def _maybe_add_module_deprecation_warning(self, node, full_name, whole_name):
    """Adds a warning if full_name is a deprecated module.

    Args:
      node: The node the warning is reported at.
      full_name: The module name to look up in `module_deprecations`.
      whole_name: Source text of the full expression using the module, used in
        the message.

    Returns:
      Whether a warning was logged.
    """
    warnings = self._api_change_spec.module_deprecations
    if full_name in warnings:
      level, message = warnings[full_name]
      message = message.replace("<function name>", whole_name)
      self.add_log(level, node.lineno, node.col_offset,
                   "Using member %s in deprecated module %s. %s" % (whole_name,
                                                                    full_name,
                                                                    message))
      return True
    else:
      return False

  def _maybe_add_call_warning(self, node, full_name, name):
    """Print a warning when specific functions are called with selected args.

    `_maybe_add_warning` called from the Attribute visitor matches the full
    name of the called function, e.g., tf.foo.bar(). This function matches the
    function name that is called, as long as the function is an attribute. For
    example, `tf.foo.bar()` and `foo.bar()` are matched, but not `bar()`.

    Args:
      node: ast.Call object
      full_name: The precomputed full name of the callable, if one exists, None
        otherwise.
      name: The precomputed name of the callable, if one exists, None otherwise.

    Returns:
      Whether an error was recorded.
    """
    # Only look for *.-warnings here, the other will be handled by the Attribute
    # visitor. Also, do not warn for bare functions, only if the call func is
    # an attribute.
    warned = False
    if isinstance(node.func, ast.Attribute):
      warned = self._maybe_add_warning(node, "*." + name)

    # All arg warnings are handled here, since only we have the args
    arg_warnings = self._get_applicable_dict("function_arg_warnings",
                                             full_name, name)

    # A call such as `(lambda: 0)(x)` has neither a full name nor a name; use a
    # placeholder rather than crashing in str.replace(..., None) below.
    display_name = full_name or name or "<unknown function>"
    for (kwarg, arg), (level, warning) in sorted(arg_warnings.items()):
      present, _ = get_arg_value(node, kwarg, arg)
      if present:
        warned = True
        warning_message = warning.replace("<function name>", display_name)
        self.add_log(level, node.lineno, node.col_offset,
                     "%s called with %s argument requires manual check: %s" %
                     (display_name, kwarg, warning_message))

    return warned

  def _maybe_rename(self, parent, node, full_name):
    """Replace node (Attribute or Name) with a node representing full_name.

    Args:
      parent: The parent of `node`, in which the replacement is made.
      node: The node referencing the symbol.
      full_name: The symbol's full name, looked up in `symbol_renames`.

    Returns:
      Whether the node was replaced.
    """
    new_name = self._api_change_spec.symbol_renames.get(full_name, None)
    if new_name:
      self.add_log(INFO, node.lineno, node.col_offset,
                   "Renamed %r to %r" % (full_name, new_name))
      new_node = full_name_node(new_name, node.ctx)
      ast.copy_location(new_node, node)
      pasta.ast_utils.replace_child(parent, node, new_node)
      return True
    else:
      return False

  def _maybe_change_to_function_call(self, parent, node, full_name):
    """Wraps node (typically, an Attribute or Expr) in a Call.

    Only done when full_name is in `change_to_function` and the parent is not
    already a Call.

    Returns:
      Whether the node was wrapped.
    """
    if full_name in self._api_change_spec.change_to_function:
      if not isinstance(parent, ast.Call):
        # ast.Call's constructor is really picky about how many arguments it
        # wants, and also, it changed between Py2 and Py3.
        if six.PY2:
          new_node = ast.Call(node, [], [], None, None)
        else:
          new_node = ast.Call(node, [], [])
        pasta.ast_utils.replace_child(parent, node, new_node)
        ast.copy_location(new_node, node)
        self.add_log(INFO, node.lineno, node.col_offset,
                     "Changed %r to a function call" % full_name)
        return True
    return False

  def _maybe_add_arg_names(self, node, full_name):
    """Make args into keyword args if function called full_name requires it.

    Positional arguments are converted to keyword arguments named after
    `function_reorders[full_name]`, so the call stays correct after the
    function's parameters are reordered.

    Calls that cannot be converted safely are left untouched and a WARNING is
    logged instead:
      * calls passing starred positional arguments (`f(x, *rest)`): their
        expansion length is unknown, and converting the other arguments would
        require dropping the starred one from the call;
      * calls passing more positional arguments than the spec has names for.

    Args:
      node: The ast.Call node to edit in place.
      full_name: The full name of the called function, or None.

    Returns:
      Whether the call was modified.
    """
    function_reorders = self._api_change_spec.function_reorders
    if full_name not in function_reorders:
      return False
    reordered = function_reorders[full_name]

    # Starred positional args appear as ast.Starred in Call.args on Py3.5+.
    check_starred = sys.version_info[:2] >= (3, 5)
    if check_starred and any(isinstance(arg, ast.Starred) for arg in node.args):
      self.add_log(WARNING, node.lineno, node.col_offset,
                   "(Manual check required) upgrading %s may require "
                   "re-ordering the call arguments, but it was passed "
                   "variable-length positional *args. The upgrade script "
                   "cannot handle these automatically." % full_name)
      return False

    if len(node.args) > len(reordered):
      self.add_log(WARNING, node.lineno, node.col_offset,
                   "(Manual check required) %s was called with %d positional "
                   "arguments, but only %d argument names are known. The "
                   "upgrade script cannot convert them to keywords "
                   "automatically." % (full_name, len(node.args),
                                        len(reordered)))
      return False

    new_keywords = [ast.keyword(arg=keyword_arg, value=arg)
                    for keyword_arg, arg in zip(reordered, node.args)]
    if not new_keywords:
      return False

    self.add_log(INFO, node.lineno, node.col_offset,
                 "Added keywords to args of function %r" % full_name)
    node.args = []
    node.keywords = new_keywords + (node.keywords or [])
    return True

  def _maybe_modify_args(self, node, full_name, name):
    """Rename keyword args if the function called full_name requires it.

    Keywords mapped to a new name in `function_keyword_renames` are renamed;
    keywords mapped to None are removed from the call.

    Args:
      node: The ast.Call node to edit in place.
      full_name: The full name of the callable, or None.
      name: The last name component of the callable, or None.

    Returns:
      Whether any keyword was renamed or removed.
    """
    renamed_keywords = self._get_applicable_dict("function_keyword_renames",
                                                 full_name, name)

    if not renamed_keywords:
      return False

    # Calls matched through a "*.name" or "*" entry may have no full name.
    display_name = full_name or name
    modified = False
    new_keywords = []
    for keyword in node.keywords:
      argkey = keyword.arg
      if argkey not in renamed_keywords:
        new_keywords.append(keyword)
        continue

      modified = True
      # keyword nodes may lack position attributes depending on the Python
      # version; fall back to the call's position.
      lineno = getattr(keyword, "lineno", node.lineno)
      col_offset = getattr(keyword, "col_offset", node.col_offset)
      new_arg = renamed_keywords[argkey]
      if new_arg is None:
        self.add_log(INFO, lineno, col_offset,
                     "Removed argument %s for function %s" % (
                         argkey, display_name))
      else:
        keyword.arg = new_arg
        self.add_log(INFO, lineno, col_offset,
                     "Renamed keyword argument for %s from %s to %s" % (
                         display_name, argkey, new_arg))
        new_keywords.append(keyword)

    if modified:
      node.keywords = new_keywords
    return modified

  def visit_Call(self, node):  # pylint: disable=invalid-name
    """Handle visiting a call node in the AST.

    Args:
      node: Current Node
    """
    assert self._stack[-1] is node

    # Get the name for this call, so we can index stuff with it.
    full_name = self._get_full_name(node.func)
    if full_name:
      name = full_name.split(".")[-1]
    elif isinstance(node.func, ast.Name):
      name = node.func.id
    elif isinstance(node.func, ast.Attribute):
      name = node.func.attr
    else:
      name = None

    # Call standard transformers for this node.
    # Make sure warnings come first, since args or names triggering warnings
    # may be removed by the other transformations.
    self._maybe_add_call_warning(node, full_name, name)
    # Make all args into kwargs
    self._maybe_add_arg_names(node, full_name)
    # Argument name changes or deletions
    self._maybe_modify_args(node, full_name, name)

    # Call transformers. These have the ability to modify the node, and if they
    # do, will return the new node they created (or the same node if they just
    # changed it). The are given the parent, but we will take care of
    # integrating their changes into the parent if they return a new node.
    #
    # These are matched on the old name, since renaming is performed by the
    # Attribute visitor, which happens later.
    transformers = self._get_applicable_entries("function_transformers",
                                                full_name, name)

    parent = self._stack[-2]

    for transformer in transformers:
      logs = []
      new_node = transformer(parent, node, full_name, name, logs)
      self.add_logs(logs)
      if new_node and new_node is not node:
        pasta.ast_utils.replace_child(parent, node, new_node)
        node = new_node
        self._stack[-1] = node

    self.generic_visit(node)

  def visit_Attribute(self, node):  # pylint: disable=invalid-name
    """Handle bare Attributes i.e. [tf.foo, tf.bar]."""
    assert self._stack[-1] is node

    full_name = self._get_full_name(node)
    if full_name:
      parent = self._stack[-2]

      # Make sure the warning comes first, otherwise the name may have changed
      self._maybe_add_warning(node, full_name)

      # Once we did a modification, node is invalid and not worth inspecting
      # further. Also, we only perform modifications for simple nodes, so
      # There'd be no point in descending further.
      if self._maybe_rename(parent, node, full_name):
        return
      if self._maybe_change_to_function_call(parent, node, full_name):
        return

      # The isinstance check is enough -- a bare Attribute is never root.
      # Walk up to the outermost enclosing Attribute to report the whole name.
      i = 2
      while isinstance(self._stack[-i], ast.Attribute):
        i += 1
      whole_name = pasta.dump(self._stack[-(i-1)])

      self._maybe_add_module_deprecation_warning(node, full_name, whole_name)

    self.generic_visit(node)
469
class ASTCodeUpgrader(object):
  """Handles upgrading a set of Python files using a given API change spec."""

  def __init__(self, api_change_spec):
    """Creates an upgrader.

    Args:
      api_change_spec: The `APIChangeSpec` describing the transformations.

    Raises:
      TypeError: If `api_change_spec` is not an `APIChangeSpec`.
    """
    if not isinstance(api_change_spec, APIChangeSpec):
      raise TypeError("Must pass APIChangeSpec to ASTCodeUpgrader, got %s" %
                      type(api_change_spec))
    self._api_change_spec = api_change_spec

  def process_file(self, in_filename, out_filename):
    """Process the given python file for incompatible changes.

    The result is written to a temporary file first and only moved over
    `out_filename` once processing succeeded, so `in_filename` and
    `out_filename` may be the same path. If the file could not be processed
    (e.g. it failed to parse), `out_filename` is left untouched instead of
    being replaced by an empty file, and the temporary file is removed. The
    written file gets the permission bits of `in_filename`.

    Args:
      in_filename: filename to parse
      out_filename: output file to write to
    Returns:
      A tuple representing number of files processed, log of actions, errors
    """
    temp_name = None
    try:
      # Write to a temporary file, just in case we are doing an in-place
      # modify.
      # pylint: disable=g-backslash-continuation
      with open(in_filename, "r") as in_file, \
          tempfile.NamedTemporaryFile("w", delete=False) as temp_file:
        temp_name = temp_file.name
        ret = self.process_opened_file(in_filename, in_file, out_filename,
                                       temp_file)
      # pylint: enable=g-backslash-continuation

      if ret[0]:
        # Temporary files are created readable/writable by the owner only;
        # keep the input's mode so e.g. executable scripts stay executable.
        shutil.copymode(in_filename, temp_name)
        shutil.move(temp_name, out_filename)
        temp_name = None  # The temporary file now lives at out_filename.
    finally:
      # Never leave the temporary file behind on failure or on an exception.
      if temp_name is not None and os.path.exists(temp_name):
        os.remove(temp_name)
    return ret

  def format_log(self, log, in_filename):
    """Formats a `(severity, lineno, col, msg)` log tuple as a string.

    Args:
      log: The log tuple.
      in_filename: Filename to prefix the entry with, or None for no prefix.

    Returns:
      "[in_filename:]lineno:col: severity: msg".
    """
    log_string = "%d:%d: %s: %s" % (log[1], log[2], log[0], log[3])
    if in_filename:
      return in_filename + ":" + log_string
    else:
      return log_string

  def update_string_pasta(self, text, in_filename):
    """Updates a file using pasta.

    Args:
      text: The source text to upgrade.
      in_filename: Filename used to prefix the returned error entries.

    Returns:
      A tuple `(processed, new_text, logs, errors)`. On a parse failure this is
      `(0, "", [traceback message], [])`.
    """
    try:
      t = pasta.parse(text)
    except (SyntaxError, ValueError, TypeError):
      log = ["ERROR: Failed to parse.\n" + traceback.format_exc()]
      return 0, "", log, []

    visitor = _PastaEditVisitor(self._api_change_spec)
    visitor.visit(t)

    logs = [self.format_log(log, None) for log in visitor.log]
    errors = [self.format_log(error, in_filename)
              for error in visitor.warnings_and_errors]
    return 1, pasta.dump(t), logs, errors

  def _format_log(self, log, in_filename, out_filename):
    """Wraps the per-file log lines in a banner naming both files."""
    text = "-" * 80 + "\n"
    text += "Processing file %r\n outputting to %r\n" % (in_filename,
                                                         out_filename)
    text += "-" * 80 + "\n\n"
    text += "\n".join(log) + "\n"
    text += "-" * 80 + "\n\n"
    return text

  def process_opened_file(self, in_filename, in_file, out_filename, out_file):
    """Process the given python file for incompatible changes.

    This function is split out to facilitate StringIO testing from
    tf_upgrade_test.py.

    Args:
      in_filename: filename to parse
      in_file: opened file (or StringIO)
      out_filename: output file to write to
      out_file: opened file (or StringIO)
    Returns:
      A tuple representing number of files processed, log of actions, errors
    """
    lines = in_file.readlines()
    processed_file, new_file_content, log, process_errors = (
        self.update_string_pasta("".join(lines), in_filename))

    # Nothing is written when processing failed.
    if out_file and processed_file:
      out_file.write(new_file_content)

    return (processed_file,
            self._format_log(log, in_filename, out_filename),
            process_errors)

  def _tree_report_header(self, root_directory):
    """Returns the banner that starts a tree-processing report."""
    rule = ("=" * 80) + "\n"
    return rule + "Input tree: %r\n" % root_directory + rule

  def process_tree(self, root_directory, output_root_directory,
                   copy_other_files):
    """Processes upgrades on an entire tree of python files in place.

    Note that only Python files. If you have custom code in other languages,
    you will need to manually upgrade those.

    Exits the process (status 1) if the output directory already exists or is
    the input directory under another spelling.

    Args:
      root_directory: Directory to walk and process.
      output_root_directory: Directory to use as base.
      copy_other_files: Copy files that are not touched by this converter.

    Returns:
      A tuple of files processed, the report string for all files, and a dict
        mapping filenames to errors encountered in that file.
    """

    if output_root_directory == root_directory:
      return self.process_tree_inplace(root_directory)

    # make sure output directory doesn't exist
    if output_root_directory and os.path.exists(output_root_directory):
      print("Output directory %r must not already exist." %
            (output_root_directory))
      sys.exit(1)

    # make sure output directory is not the input directory spelled
    # differently (e.g. with a trailing slash)
    norm_root = os.path.split(os.path.normpath(root_directory))
    norm_output = os.path.split(os.path.normpath(output_root_directory))
    if norm_root == norm_output:
      print("Output directory %r same as input directory %r" %
            (output_root_directory, root_directory))
      sys.exit(1)

    # Collect list of files to process (we do this to correctly handle if the
    # user puts the output directory in some sub directory of the input dir)
    files_to_process = []
    files_to_copy = []
    for dir_name, _, file_list in os.walk(root_directory):
      py_files = [f for f in file_list if f.endswith(".py")]
      copy_files = [f for f in file_list if not f.endswith(".py")]
      for filename in py_files:
        fullpath = os.path.join(dir_name, filename)
        fullpath_output = os.path.join(output_root_directory,
                                       os.path.relpath(fullpath,
                                                       root_directory))
        files_to_process.append((fullpath, fullpath_output))
      if copy_other_files:
        for filename in copy_files:
          fullpath = os.path.join(dir_name, filename)
          fullpath_output = os.path.join(output_root_directory,
                                         os.path.relpath(
                                             fullpath, root_directory))
          files_to_copy.append((fullpath, fullpath_output))

    file_count = 0
    tree_errors = {}
    report = self._tree_report_header(root_directory)

    for input_path, output_path in files_to_process:
      output_directory = os.path.dirname(output_path)
      if not os.path.isdir(output_directory):
        os.makedirs(output_directory)
      file_count += 1
      _, l_report, l_errors = self.process_file(input_path, output_path)
      tree_errors[input_path] = l_errors
      report += l_report
    for input_path, output_path in files_to_copy:
      output_directory = os.path.dirname(output_path)
      if not os.path.isdir(output_directory):
        os.makedirs(output_directory)
      shutil.copy(input_path, output_path)
    return file_count, report, tree_errors

  def process_tree_inplace(self, root_directory):
    """Process a directory of python files in place.

    Args:
      root_directory: Directory to walk; every ".py" file under it is
        rewritten in place.

    Returns:
      A tuple of files processed, the report string for all files, and a dict
        mapping filenames to errors encountered in that file.
    """
    files_to_process = []
    for dir_name, _, file_list in os.walk(root_directory):
      py_files = [os.path.join(dir_name,
                               f) for f in file_list if f.endswith(".py")]
      files_to_process += py_files

    file_count = 0
    tree_errors = {}
    report = self._tree_report_header(root_directory)

    for path in files_to_process:
      file_count += 1
      _, l_report, l_errors = self.process_file(path, path)
      tree_errors[path] = l_errors
      report += l_report

    return file_count, report, tree_errors
656