• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2
3# Copyright 2020 The Pigweed Authors
4#
5# Licensed under the Apache License, Version 2.0 (the "License"); you may not
6# use this file except in compliance with the License. You may obtain a copy of
7# the License at
8#
9#     https://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
14# License for the specific language governing permissions and limitations under
15# the License.
16"""Checks and fixes formatting for source files.
17
18This uses clang-format, gn format, gofmt, and python -m yapf to format source
19code. These tools must be available on the path when this script is invoked!
20"""
21
22import argparse
23import collections
24import difflib
25import logging
26import os
27from pathlib import Path
28import re
29import subprocess
30import sys
31from typing import Callable, Collection, Dict, Iterable, List, NamedTuple
32from typing import Optional, Pattern, Tuple, Union
33
34try:
35    import pw_presubmit
36except ImportError:
37    # Append the pw_presubmit package path to the module search path to allow
38    # running this module without installing the pw_presubmit package.
39    sys.path.append(os.path.dirname(os.path.dirname(
40        os.path.abspath(__file__))))
41    import pw_presubmit
42
43from pw_presubmit import cli, git_repo
44from pw_presubmit.tools import exclude_paths, file_summary, log_run, plural
45
46_LOG: logging.Logger = logging.getLogger(__name__)
47
48
49def _colorize_diff_line(line: str) -> str:
50    if line.startswith('--- ') or line.startswith('+++ '):
51        return pw_presubmit.color_bold_white(line)
52    if line.startswith('-'):
53        return pw_presubmit.color_red(line)
54    if line.startswith('+'):
55        return pw_presubmit.color_green(line)
56    if line.startswith('@@ '):
57        return pw_presubmit.color_aqua(line)
58    return line
59
60
def colorize_diff(lines: Iterable[str]) -> str:
    """Colorizes a diff given as one string or as an iterable of lines.

    A single string is split into lines with the line endings kept, so the
    joined result has the same text as the input apart from the added colors.
    """
    diff_lines = lines.splitlines(True) if isinstance(lines, str) else lines
    return ''.join(map(_colorize_diff_line, diff_lines))
67
68
def _diff(path, original: bytes, formatted: bytes) -> str:
    """Returns a colorized unified diff between two versions of a file.

    Both versions are decoded with errors='replace', so contents that cannot
    be decoded still produce a diff instead of raising.
    """
    before = original.decode(errors='replace').splitlines(True)
    after = formatted.decode(errors='replace').splitlines(True)

    unified = difflib.unified_diff(before,
                                   after,
                                   fromfile=f'{path}  (original)',
                                   tofile=f'{path}  (reformatted)')
    return colorize_diff(unified)
75
76
# A formatter is called with a file's path and its original contents and
# returns the formatted contents.
Formatter = Callable[[str, bytes], bytes]
78
79
def _diff_formatted(path, formatter: Formatter) -> Optional[str]:
    """Diffs a file against the output of a formatter.

    Returns None when the formatter leaves the contents unchanged; otherwise
    returns the colorized diff from the original to the formatted contents.
    """
    with open(path, 'rb') as source:
        before = source.read()

    after = formatter(path, before)
    if after == before:
        return None

    return _diff(path, before, after)
88
89
def _check_files(files, formatter: Formatter) -> Dict[Path, str]:
    """Runs a formatter over files; returns {path: diff} for changed files."""
    diffs = ((path, _diff_formatted(path, formatter)) for path in files)

    # Files whose diff is None (or empty) are already formatted; leave them out.
    return {path: difference for path, difference in diffs if difference}
99
100
def _clang_format(*args: str, **kwargs) -> bytes:
    """Runs clang-format with --style=file and returns what it wrote to stdout.

    Extra positional arguments are appended to the command line; keyword
    arguments are passed through to log_run.
    """
    command = ['clang-format', '--style=file', *args]
    process = log_run(command, stdout=subprocess.PIPE, check=True, **kwargs)
    return process.stdout
106
107
def clang_format_check(files: Iterable[Path]) -> Dict[Path, str]:
    """Checks formatting; returns {path: diff} for files with bad formatting."""
    def format_file(path, _) -> bytes:
        # clang-format is given the path itself, so the contents that were
        # already read from the file are not needed.
        return _clang_format(path)

    return _check_files(files, format_file)
111
112
def clang_format_fix(files: Iterable) -> None:
    """Fixes formatting for the provided files in place."""
    # -i makes clang-format write the result back to each file.
    in_place_args = ['-i', *files]
    _clang_format(*in_place_args)
116
117
def check_gn_format(files: Iterable[Path]) -> Dict[Path, str]:
    """Checks formatting; returns {path: diff} for files with bad formatting."""
    def format_from_stdin(_, data: bytes) -> bytes:
        # The file contents are piped to gn on stdin and the formatted result
        # is captured from stdout, so the file path is not needed.
        process = log_run(['gn', 'format', '--stdin'],
                          input=data,
                          stdout=subprocess.PIPE,
                          check=True)
        return process.stdout

    return _check_files(files, format_from_stdin)
125
126
def fix_gn_format(files: Iterable[Path]) -> None:
    """Fixes formatting for the provided files in place."""
    command = ['gn', 'format', *files]
    log_run(command, check=True)
130
131
def check_go_format(files: Iterable[Path]) -> Dict[Path, str]:
    """Checks formatting; returns {path: diff} for files with bad formatting."""
    def format_file(path, _) -> bytes:
        # gofmt is given the path itself and its stdout is captured, so the
        # contents that were already read from the file are not needed.
        process = log_run(['gofmt', path], stdout=subprocess.PIPE, check=True)
        return process.stdout

    return _check_files(files, format_file)
137
138
def fix_go_format(files: Iterable[Path]) -> None:
    """Fixes formatting for the provided files in place."""
    # -w makes gofmt write the result back to each file.
    command = ['gofmt', '-w', *files]
    log_run(command, check=True)
142
143
def _yapf(*args, **kwargs) -> subprocess.CompletedProcess:
    """Runs 'python -m yapf --parallel' with the given arguments.

    Output is captured rather than printed. Keyword arguments are passed
    through to log_run.
    """
    command = ['python', '-m', 'yapf', '--parallel', *args]
    return log_run(command, capture_output=True, **kwargs)
148
149
# Matches a '--- <path> (original)' line; group 1 is the path. It is used to
# split the combined output of 'yapf --diff' into one diff per file.
_DIFF_START = re.compile(r'^--- (.*)\s+\(original\)$', flags=re.MULTILINE)
151
152
def check_py_format(files: Iterable[Path]) -> Dict[Path, str]:
    """Checks formatting; returns {path: diff} for files with bad formatting.

    yapf is run once over all files and its combined diff output is split into
    one colorized diff per file. If yapf writes anything to stderr, the message
    is logged and every file that has no diff of its own is also reported,
    mapped to an empty diff, so the failure is not mistaken for success.

    Args:
      files: paths of the Python files to check; any iterable is accepted,
          including one-shot iterators.

    Returns:
      A dict mapping each badly formatted (or unverifiable) file to its diff.
    """
    # Materialize the iterable: it is expanded into the yapf command line and
    # iterated again below if yapf reports an error, so a one-shot iterator
    # would silently be empty the second time.
    paths = list(files)

    if not paths:
        # Nothing to check; don't invoke yapf without any file arguments.
        return {}

    process = _yapf('--diff', *paths)

    errors: Dict[Path, str] = {}

    if process.stdout:
        raw_diff = process.stdout.decode(errors='replace')

        # Each file's diff runs from its '--- <path> (original)' line up to the
        # next file's header, or to the end of the output for the last file.
        matches = tuple(_DIFF_START.finditer(raw_diff))
        for start, end in zip(matches, (*matches[1:], None)):
            errors[Path(start.group(1))] = colorize_diff(
                raw_diff[start.start():end.start() if end else None])

    if process.stderr:
        _LOG.error('yapf encountered an error:\n%s',
                   process.stderr.decode(errors='replace').rstrip())
        errors.update({path: '' for path in paths if path not in errors})

    return errors
173
174
def fix_py_format(files: Iterable):
    """Fixes formatting for the provided files in place."""
    # --in-place makes yapf write the result back to each file.
    in_place_args = ['--in-place', *files]
    _yapf(*in_place_args, check=True)
178
179
# Matches a run of spaces and tabs at the end of any line of a bytes object.
_TRAILING_SPACE = re.compile(rb'[ \t]+$', flags=re.MULTILINE)
181
182
def _check_trailing_space(paths: Iterable[Path], fix: bool) -> Dict[Path, str]:
    """Finds trailing whitespace in files and optionally strips it.

    Returns {path: diff} for every file that has trailing spaces or tabs. When
    fix is True, each of those files is also rewritten without them.
    """
    errors = {}

    for path in paths:
        original = path.read_bytes()
        stripped = _TRAILING_SPACE.sub(b'', original)

        if stripped == original:
            continue

        errors[path] = _diff(path, original, stripped)

        if fix:
            path.write_bytes(stripped)

    return errors
200
201
def check_trailing_space(files: Iterable[Path]) -> Dict[Path, str]:
    """Returns {path: diff} for files with trailing whitespace; edits nothing."""
    errors = _check_trailing_space(files, fix=False)
    return errors
204
205
def fix_trailing_space(files: Iterable[Path]) -> None:
    """Removes trailing whitespace from the provided files in place."""
    # The {path: diff} result is only of interest when checking; drop it.
    _ = _check_trailing_space(files, fix=True)
208
209
def print_format_check(errors: Dict[Path, str],
                       show_fix_commands: bool) -> None:
    """Prints the diffs produced by a check_*_format function.

    Nothing is printed when there are no errors. Otherwise a warning with the
    number of affected files is logged, the diffs are printed, and, if
    show_fix_commands is set, one 'pw format --fix' command per file is logged.
    """
    if not errors:
        # Stay silent in the all-good case.
        return

    # The diffs suggested by the tooling already carry their colors.
    _LOG.warning('Found %d files with formatting errors. Format changes:',
                 len(errors))
    for diff in errors.values():
        print(diff, end='')

    if not show_fix_commands:
        return

    def path_relative_to_cwd(path):
        # Prefer a path relative to the working directory; fall back to the
        # absolute path for files outside of it.
        resolved = Path(path).resolve()
        try:
            return resolved.relative_to(Path.cwd().resolve())
        except ValueError:
            return resolved

    # Copy-and-pastable commands that fix the reported files.
    commands = '\n'.join(f'  pw format --fix {path_relative_to_cwd(path)}'
                         for path in errors)
    _LOG.warning('To fix formatting, run:\n\n%s\n', commands)
235
236
class CodeFormat(NamedTuple):
    """Describes how to check and fix the formatting of one kind of file."""
    # Human-readable name; used in log messages and to derive the name of the
    # presubmit check created for this format.
    language: str
    # Path endings handled by this format; paths are matched with endswith().
    extensions: Collection[str]
    # Patterns for paths to skip; passed as the 'exclude' argument of
    # pw_presubmit.filter_paths when a presubmit check is created.
    exclude: Collection[str]
    # Called with the files to check; returns {path: diff} for the files with
    # formatting errors.
    check: Callable[[Iterable], Dict[Path, str]]
    # Called with the files to fix.
    fix: Callable[[Iterable], None]
243
244
# C-family sources; generated protobuf headers and sources are excluded.
C_FORMAT: CodeFormat = CodeFormat(
    language='C and C++',
    extensions=frozenset(
        ['.h', '.hh', '.hpp', '.c', '.cc', '.cpp', '.inc', '.inl']),
    exclude=(r'\.pb\.h$', r'\.pb\.c$'),
    check=clang_format_check,
    fix=clang_format_fix,
)

PROTO_FORMAT: CodeFormat = CodeFormat(
    language='Protocol buffer',
    extensions=('.proto', ),
    exclude=(),
    check=clang_format_check,
    fix=clang_format_fix,
)

JAVA_FORMAT: CodeFormat = CodeFormat(
    language='Java',
    extensions=('.java', ),
    exclude=(),
    check=clang_format_check,
    fix=clang_format_fix,
)

JAVASCRIPT_FORMAT: CodeFormat = CodeFormat(
    language='JavaScript',
    extensions=('.js', ),
    exclude=(),
    check=clang_format_check,
    fix=clang_format_fix,
)

GO_FORMAT: CodeFormat = CodeFormat(
    language='Go',
    extensions=('.go', ),
    exclude=(),
    check=check_go_format,
    fix=fix_go_format,
)

PYTHON_FORMAT: CodeFormat = CodeFormat(
    language='Python',
    extensions=('.py', ),
    exclude=(),
    check=check_py_format,
    fix=fix_py_format,
)

GN_FORMAT: CodeFormat = CodeFormat(
    language='GN',
    extensions=('.gn', '.gni'),
    exclude=(),
    check=check_gn_format,
    fix=fix_gn_format,
)

# The formats below are only checked for trailing whitespace.
# TODO(pwbug/191): Add real code formatting support for Bazel and CMake
BAZEL_FORMAT: CodeFormat = CodeFormat(
    language='Bazel',
    extensions=('BUILD', ),
    exclude=(),
    check=check_trailing_space,
    fix=fix_trailing_space,
)

CMAKE_FORMAT: CodeFormat = CodeFormat(
    language='CMake',
    extensions=('CMakeLists.txt', '.cmake'),
    exclude=(),
    check=check_trailing_space,
    fix=fix_trailing_space,
)

RST_FORMAT: CodeFormat = CodeFormat(
    language='reStructuredText',
    extensions=('.rst', ),
    exclude=(),
    check=check_trailing_space,
    fix=fix_trailing_space,
)

MARKDOWN_FORMAT: CodeFormat = CodeFormat(
    language='Markdown',
    extensions=('.md', ),
    exclude=(),
    check=check_trailing_space,
    fix=fix_trailing_space,
)

# Every supported format, in the order in which they are applied.
CODE_FORMATS: Tuple[CodeFormat, ...] = (
    C_FORMAT,
    JAVA_FORMAT,
    JAVASCRIPT_FORMAT,
    PROTO_FORMAT,
    GO_FORMAT,
    PYTHON_FORMAT,
    GN_FORMAT,
    BAZEL_FORMAT,
    CMAKE_FORMAT,
    RST_FORMAT,
    MARKDOWN_FORMAT,
)
297
298
def presubmit_check(code_format: CodeFormat, **filter_paths_args) -> Callable:
    """Creates a presubmit check function from a CodeFormat object.

    The keyword arguments are passed to pw_presubmit.filter_paths. Unless the
    caller provides them, 'endswith' and 'exclude' default to the extensions
    and exclude patterns of the code format.
    """
    filter_paths_args.setdefault('endswith', code_format.extensions)
    filter_paths_args.setdefault('exclude', code_format.exclude)

    # Derive a function-style name from the language, for example
    # 'C and C++' -> 'c_and_cpp'.
    replacements = str.maketrans({'+': 'p', ' ': '_'})
    language = code_format.language.lower().translate(replacements)

    @pw_presubmit.filter_paths(**filter_paths_args)
    def check_code_format(ctx: pw_presubmit.PresubmitContext):
        errors = code_format.check(ctx.paths)

        # When running as part of presubmit, show the fix command help.
        print_format_check(errors, show_fix_commands=True)

        if errors:
            raise pw_presubmit.PresubmitFailure

    check_code_format.__name__ = f'{language}_format'

    return check_code_format
319
320
def presubmit_checks(**filter_paths_args) -> Tuple[Callable, ...]:
    """Returns a tuple with all supported code format presubmit checks."""
    # One check per entry in CODE_FORMATS, all sharing the same filter options.
    checks = [
        presubmit_check(code_format, **filter_paths_args)
        for code_format in CODE_FORMATS
    ]
    return tuple(checks)
325
326
class CodeFormatter:
    """Checks or fixes the formatting of a set of files."""
    def __init__(self, files: Iterable[Path]):
        """Groups the files by the code formats that apply to them.

        A file belongs to a format when its path ends with one of the format's
        extensions and matches none of the format's exclude patterns. Files
        that belong to no format are kept in self.paths but are never checked
        or fixed.

        Args:
          files: paths of the files to check or fix.
        """
        self.paths = list(files)
        self._formats: Dict[CodeFormat, List] = collections.defaultdict(list)

        for path in self.paths:
            posix_path = path.as_posix()

            for code_format in CODE_FORMATS:
                if not posix_path.endswith(tuple(code_format.extensions)):
                    continue

                # Honor the format's exclude patterns so that files skipped by
                # the presubmit checks (e.g. generated .pb.h headers) are not
                # checked or rewritten here either.
                if any(
                        re.search(pattern, posix_path)
                        for pattern in code_format.exclude):
                    continue

                self._formats[code_format].append(path)

    def check(self) -> Dict[Path, str]:
        """Returns {path: diff} for files with incorrect formatting.

        The result is ordered by path.
        """
        errors: Dict[Path, str] = {}

        for code_format, files in self._formats.items():
            _LOG.debug('Checking %s', ', '.join(str(f) for f in files))
            errors.update(code_format.check(files))

        return collections.OrderedDict(sorted(errors.items()))

    def fix(self) -> None:
        """Fixes format errors for supported files in place."""
        for code_format, files in self._formats.items():
            code_format.fix(files)
            _LOG.info('Formatted %s',
                      plural(files, code_format.language + ' file'))
355
356
def _file_summary(files: Iterable[Union[Path, str]], base: Path) -> List[str]:
    """Summarizes the files by their paths relative to base.

    Returns an empty list if any file is not located under base.
    """
    def relative_paths():
        for file in files:
            yield Path(file).resolve().relative_to(base.resolve())

    try:
        return file_summary(relative_paths())
    except ValueError:
        # relative_to() raises ValueError for paths outside of base.
        return []
363
364
def format_paths_in_repo(paths: Collection[Union[Path, str]],
                         exclude: Collection[Pattern[str]], fix: bool,
                         base: str) -> int:
    """Checks or fixes formatting for files in a Git repo.

    Paths that name existing files are always included. In a Git repo, the
    files that Git lists for the paths and base commit are added, except for
    those matching an exclude pattern. Outside of a Git repo, providing a base
    commit is an error.

    Returns:
      The exit code: 1 if a base commit was given outside of a Git repo,
      otherwise the result of format_files.
    """
    explicit_files = [
        Path(path).resolve() for path in paths if os.path.isfile(path)
    ]
    repo = git_repo.root() if git_repo.is_repo() else None

    if not repo:
        if base:
            _LOG.critical(
                'A base commit may only be provided if running from a Git repo'
            )
            return 1

        return format_files(explicit_files, fix, repo=repo)

    # This is a Git repo: list the original paths with git ls-files or diff.
    _LOG.info('Formatting %s',
              git_repo.describe_files(repo, Path.cwd(), base, paths, exclude))

    # Add files from Git and remove duplicates.
    git_files = set(exclude_paths(exclude, git_repo.list_files(base, paths)))
    files = sorted(git_files | set(explicit_files))

    return format_files(files, fix, repo=repo)
388
389
def format_files(paths: Collection[Union[Path, str]],
                 fix: bool,
                 repo: Optional[Path] = None) -> int:
    """Checks or fixes formatting for the specified files.

    A summary of the files is printed to stderr, relative to the repo root if
    one is given and to the working directory otherwise.

    Returns:
      The exit code: 1 if formatting errors were found and fix is False,
      otherwise 0.
    """
    formatter = CodeFormatter(Path(p) for p in paths)

    _LOG.info('Checking formatting for %s', plural(formatter.paths, 'file'))

    summary_base = repo if repo else Path.cwd()
    for line in _file_summary(paths, summary_base):
        print(line, file=sys.stderr)

    errors = formatter.check()
    # The fix commands are only useful when the fixes are not being applied.
    print_format_check(errors, show_fix_commands=(not fix))

    if not errors:
        _LOG.info('Congratulations! No formatting changes needed')
        return 0

    if not fix:
        _LOG.error('Formatting errors found')
        return 1

    formatter.fix()
    # TODO: This should perhaps check that the fixes were successful.
    _LOG.info('Formatting fixes applied successfully')
    return 0
416
417
def arguments(git_paths: bool) -> argparse.ArgumentParser:
    """Creates an argument parser for format_files or format_paths_in_repo.

    With git_paths set, the path arguments are added by
    cli.add_path_arguments. Otherwise the parser takes one or more 'paths',
    each of which must name an existing file. A --fix flag is always added.
    """
    def existing_path(arg: str) -> Path:
        # Used as an argparse type: converts the argument to a Path and
        # rejects anything that is not an existing file.
        candidate = Path(arg)
        if candidate.is_file():
            return candidate

        raise argparse.ArgumentTypeError(f'{arg} is not a path to a file')

    parser = argparse.ArgumentParser(description=__doc__)

    if git_paths:
        cli.add_path_arguments(parser)
    else:
        parser.add_argument('paths',
                            metavar='path',
                            nargs='+',
                            type=existing_path,
                            help='File paths to check')

    parser.add_argument('--fix',
                        action='store_true',
                        help='Apply formatting fixes in place.')
    return parser
445
446
def main() -> int:
    """Check and fix formatting for source files."""
    parser = arguments(git_paths=True)
    args = parser.parse_args()

    # The parsed argument names match the parameters of format_paths_in_repo.
    return format_paths_in_repo(**vars(args))
450
451
if __name__ == '__main__':
    # Set up logging before running: use pw_cli if it is available.
    try:
        from pw_cli import log as pw_cli_log

        pw_cli_log.install(logging.INFO)
    except ImportError:
        # If pw_cli isn't available, display log messages like a simple print.
        logging.basicConfig(format='%(message)s', level=logging.INFO)

    exit_code = main()
    sys.exit(exit_code)
463