• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015-2017 Google Inc. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Entry points for YAPF.
15
16The main APIs that YAPF exposes to drive the reformatting.
17
18  FormatFile(): reformat a file.
19  FormatCode(): reformat a string of code.
20
21These APIs have some common arguments:
22
23  style_config: (string) Either a style name or a path to a file that contains
24    formatting style settings. If None is specified, use the default style
25    as set in style.DEFAULT_STYLE_FACTORY
26  lines: (list of tuples of integers) A list of tuples of lines, [start, end],
27    that we want to format. The lines are 1-based indexed. It can be used by
28    third-party code (e.g., IDEs) when reformatting a snippet of code rather
29    than a whole file.
  print_diff: (bool) Instead of returning the reformatted source, return a
    diff that turns the original source into the reformatted source.
32  verify: (bool) True if reformatted code should be verified for syntax.
33"""
34
35import difflib
36import re
37import sys
38
39from lib2to3.pgen2 import tokenize
40
41from yapf.yapflib import blank_line_calculator
42from yapf.yapflib import comment_splicer
43from yapf.yapflib import continuation_splicer
44from yapf.yapflib import file_resources
45from yapf.yapflib import py3compat
46from yapf.yapflib import pytree_unwrapper
47from yapf.yapflib import pytree_utils
48from yapf.yapflib import reformatter
49from yapf.yapflib import split_penalty
50from yapf.yapflib import style
51from yapf.yapflib import subtype_assigner
52
53
def FormatFile(filename,
               style_config=None,
               lines=None,
               print_diff=False,
               verify=False,
               in_place=False,
               logger=None):
  """Format a single Python file and return the formatted code.

  Arguments:
    filename: (unicode) The file to reformat.
    in_place: (bool) If True, write the reformatted code back to the file.
    logger: (function) Passed to ReadFile; called with any IOError raised while
      reading the file (the error is still re-raised).
    remaining arguments: see comment at the top of this module.

  Returns:
    Tuple of (reformatted_code, encoding, changed). reformatted_code is None if
    in_place is used (the file is only written when its content would change).
    reformatted_code is a diff if print_diff is True. Non-empty output uses the
    line ending detected in the original file.

  Raises:
    IOError: raised if there was an error reading the file.
    ValueError: raised if in_place and print_diff are both specified.
  """
  _CheckPythonVersion()

  if in_place and print_diff:
    raise ValueError('Cannot pass both in_place and print_diff.')

  original_source, newline, encoding = ReadFile(filename, logger)
  reformatted_source, changed = FormatCode(
      original_source,
      style_config=style_config,
      filename=filename,
      lines=lines,
      print_diff=print_diff,
      verify=verify)

  # `normalized_source` is the output in '\n' form, comparable with
  # original_source (which ReadFile returns '\n'-normalized).
  normalized_source = reformatted_source
  stripped_source = reformatted_source.rstrip('\n')
  if stripped_source:
    # Collapse trailing newlines to one and restore the file's own line ending.
    normalized_source = stripped_source + '\n'
    reformatted_source = (
        newline.join(stripped_source.split('\n')) + newline)

  if in_place:
    # Compare in '\n' form. Comparing against the newline-converted output
    # would always report a difference for files whose line ending is not
    # '\n', causing a needless rewrite.
    if original_source and original_source != normalized_source:
      file_resources.WriteReformattedCode(filename, reformatted_source,
                                          in_place, encoding)
    return None, encoding, changed

  return reformatted_source, encoding, changed
101
102
def FormatCode(unformatted_source,
               filename='<unknown>',
               style_config=None,
               lines=None,
               print_diff=False,
               verify=False):
  """Format a string of Python code.

  This provides an alternative entry point to YAPF.

  Arguments:
    unformatted_source: (unicode) The code to format.
    filename: (unicode) The name of the file being reformatted. It is only used
      to label the diff when print_diff is True.
    remaining arguments: see comment at the top of this module.

  Returns:
    Tuple of (reformatted_source, changed). reformatted_source conforms to the
    desired formatting style. changed is True if the source changed. When
    print_diff is True, a unified diff is returned in place of the reformatted
    source ('' if nothing changed), and changed is True iff that diff is
    non-empty.
  """
  _CheckPythonVersion()
  style.SetGlobalStyle(style.CreateStyleFromConfig(style_config))
  # Make sure the source is newline-terminated before parsing it.
  if not unformatted_source.endswith('\n'):
    unformatted_source += '\n'
  tree = pytree_utils.ParseCodeToTree(unformatted_source)

  # Run passes on the tree, modifying it in place.
  comment_splicer.SpliceComments(tree)
  continuation_splicer.SpliceContinuations(tree)
  subtype_assigner.AssignSubtypes(tree)
  split_penalty.ComputeSplitPenalties(tree)
  blank_line_calculator.CalculateBlankLines(tree)

  uwlines = pytree_unwrapper.UnwrapPyTree(tree)
  for uwl in uwlines:
    uwl.CalculateFormattingInformation()

  _MarkLinesToFormat(uwlines, lines)
  reformatted_source = reformatter.Reformat(uwlines, verify)

  if unformatted_source == reformatted_source:
    return '' if print_diff else reformatted_source, False

  if not print_diff:
    return reformatted_source, True

  # The diff is only built when the caller asked for it.
  code_diff = _GetUnifiedDiff(
      unformatted_source, reformatted_source, filename=filename)
  return code_diff, code_diff != ''
152
153
154def _CheckPythonVersion():  # pragma: no cover
155  errmsg = 'yapf is only supported for Python 2.7 or 3.4+'
156  if sys.version_info[0] == 2:
157    if sys.version_info[1] < 7:
158      raise RuntimeError(errmsg)
159  elif sys.version_info[0] == 3:
160    if sys.version_info[1] < 4:
161      raise RuntimeError(errmsg)
162
163
def ReadFile(filename, logger=None):
  """Read the contents of the file.

  An optional logger can be specified to emit messages to your favorite logging
  stream. If specified, it is called with the IOError before that error is
  re-raised; the exception always propagates to the caller. This is external
  so that it can be used by third-party applications.

  Arguments:
    filename: (unicode) The name of the file.
    logger: (function) A function or lambda that is called with the IOError
      raised while reading the file.

  Returns:
    Tuple of (source, line_ending, encoding). source is the file's text with
    each line's trailing carriage-return/newline characters replaced by a
    single newline (a final newline is always present). line_ending is the
    value file_resources.LineEnding reports for the file's original lines.
    encoding is the encoding found by tokenize.detect_encoding.

  Raises:
    IOError: raised if there was an error reading the file.
  """
  try:
    with open(filename, 'rb') as fd:
      encoding = tokenize.detect_encoding(fd.readline)[0]

    # newline='' preserves the original line endings so they can be examined.
    with py3compat.open_with_encoding(
        filename, mode='r', encoding=encoding, newline='') as fd:
      lines = fd.readlines()
  except IOError as err:
    if logger:
      logger(err)
    raise

  line_ending = file_resources.LineEnding(lines)
  source = '\n'.join(line.rstrip('\r\n') for line in lines) + '\n'
  return source, line_ending, encoding
202
203
# Matches a comment such as "# yapf: disable" that turns formatting off. The
# pattern is anchored at '#', so it only matches text that starts with a
# comment; this module searches it with re.IGNORECASE.
DISABLE_PATTERN = r'^#.*\byapf:\s*disable\b'
# Matches a comment such as "# yapf: enable" that turns formatting back on.
ENABLE_PATTERN = r'^#.*\byapf:\s*enable\b'
206
207
def _MarkLinesToFormat(uwlines, lines):
  """Skip sections of code that we shouldn't reformat.

  Sets the `disable` attribute of the unwrapped lines in place:
    * If `lines` is given, every unwrapped line is first disabled, then the
      ones overlapping one of the requested ranges are re-enabled.
    * Lines after a "yapf: disable" comment, up to the next "yapf: enable"
      comment, are disabled, as is any line whose last token matches
      DISABLE_PATTERN.

  Arguments:
    uwlines: (list) Unwrapped lines from pytree_unwrapper.UnwrapPyTree; assumed
      to be in source order (the range scan below relies on it).
    lines: (list of tuples of integers) 1-based, inclusive [start, end] ranges
      to format, or None/empty to consider every line. Empty ranges are
      ignored.
  """
  if lines:
    for uwline in uwlines:
      uwline.disable = True

    # Sort and combine overlapping ranges. Empty ranges are dropped: an empty
    # range used to leave `line_ranges` empty (or be indexed directly), making
    # the merge fail with an IndexError.
    line_ranges = []
    for line_range in sorted(r for r in lines if r):
      if line_ranges and line_range[0] <= line_ranges[-1][1]:
        # The ranges overlap, so combine them.
        current = line_ranges[-1]
        line_ranges[-1] = (current[0], max(line_range[1], current[1]))
      else:
        line_ranges.append(line_range)

    # Mark lines to format as not being disabled. `line_ranges` is already
    # sorted by construction; re-sorting it could compare merged tuples with
    # caller-supplied lists, which raises TypeError on Python 3.
    index = 0
    for start, end in line_ranges:
      # Skip unwrapped lines that end before this range starts.
      while index < len(uwlines) and uwlines[index].last.lineno < start:
        index += 1
      if index >= len(uwlines):
        break

      while index < len(uwlines):
        if uwlines[index].lineno > end:
          break
        if (uwlines[index].lineno >= start or
            uwlines[index].last.lineno >= start):
          uwlines[index].disable = False
        index += 1

  # Now go through the lines and disable any lines explicitly marked as
  # disabled.
  index = 0
  while index < len(uwlines):
    uwline = uwlines[index]
    if uwline.is_comment:
      if _DisableYAPF(uwline.first.value.strip()):
        # Disable everything up to (but not including) the matching enable
        # comment, or to the end of the file if there is none.
        index += 1
        while index < len(uwlines):
          uwline = uwlines[index]
          if uwline.is_comment and _EnableYAPF(uwline.first.value.strip()):
            break
          uwline.disable = True
          index += 1
    elif re.search(DISABLE_PATTERN, uwline.last.value.strip(), re.IGNORECASE):
      # A disable comment as the line's last token disables just this line.
      uwline.disable = True
    index += 1
260
261
def _DisableYAPF(line):
  """Check whether `line` turns formatting off.

  If `line` spans several physical lines, only its first and last physical
  lines are examined (each stripped of surrounding whitespace).

  Returns:
    The first truthy re match against DISABLE_PATTERN, or None.
  """
  physical_lines = line.split('\n')
  for candidate in (physical_lines[0], physical_lines[-1]):
    match = re.search(DISABLE_PATTERN, candidate.strip(), re.IGNORECASE)
    if match:
      return match
  return None
266
267
def _EnableYAPF(line):
  """Check whether `line` turns formatting back on.

  If `line` spans several physical lines, only its first and last physical
  lines are examined (each stripped of surrounding whitespace).

  Returns:
    A truthy re match against ENABLE_PATTERN, or None.
  """
  physical_lines = line.split('\n')
  head_match = re.search(ENABLE_PATTERN, physical_lines[0].strip(),
                         re.IGNORECASE)
  if head_match:
    return head_match
  return re.search(ENABLE_PATTERN, physical_lines[-1].strip(), re.IGNORECASE)
272
273
274def _GetUnifiedDiff(before, after, filename='code'):
275  """Get a unified diff of the changes.
276
277  Arguments:
278    before: (unicode) The original source code.
279    after: (unicode) The reformatted source code.
280    filename: (unicode) The code's filename.
281
282  Returns:
283    The unified diff text.
284  """
285  before = before.splitlines()
286  after = after.splitlines()
287  return '\n'.join(
288      difflib.unified_diff(
289          before,
290          after,
291          filename,
292          filename,
293          '(original)',
294          '(reformatted)',
295          lineterm='')) + '\n'
296