• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2
3# Copyright 2020 The Pigweed Authors
4#
5# Licensed under the Apache License, Version 2.0 (the "License"); you may not
6# use this file except in compliance with the License. You may obtain a copy of
7# the License at
8#
9#     https://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
14# License for the specific language governing permissions and limitations under
15# the License.
16"""Checks and fixes formatting for source files.
17
This uses clang-format, gn format, gofmt, buildifier, and python -m yapf to
format source code; CMake, reStructuredText, and Markdown files are only checked
for trailing whitespace. These tools must be available on the path when this
script is invoked!
20"""
21
22import argparse
23import collections
24import difflib
25import logging
26import os
27from pathlib import Path
28import re
29import subprocess
30import sys
31import tempfile
32from typing import Callable, Collection, Dict, Iterable, List, NamedTuple
33from typing import Optional, Pattern, Tuple, Union
34
35try:
36    import pw_presubmit
37except ImportError:
38    # Append the pw_presubmit package path to the module search path to allow
39    # running this module without installing the pw_presubmit package.
40    sys.path.append(os.path.dirname(os.path.dirname(
41        os.path.abspath(__file__))))
42    import pw_presubmit
43
44import pw_cli.env
45from pw_presubmit import cli, git_repo
46from pw_presubmit.tools import exclude_paths, file_summary, log_run, plural
47
# Logger for this module; the check, fix, and reporting functions below log
# through it.
_LOG: logging.Logger = logging.getLogger(__name__)
49
50
51def _colorize_diff_line(line: str) -> str:
52    if line.startswith('--- ') or line.startswith('+++ '):
53        return pw_presubmit.color_bold_white(line)
54    if line.startswith('-'):
55        return pw_presubmit.color_red(line)
56    if line.startswith('+'):
57        return pw_presubmit.color_green(line)
58    if line.startswith('@@ '):
59        return pw_presubmit.color_aqua(line)
60    return line
61
62
def colorize_diff(lines: Union[str, Iterable[str]]) -> str:
    """Takes a diff str or list of str lines and returns a colorized version.

    Args:
      lines: either the whole diff as one string, which is split on line
          boundaries with the line endings kept, or an iterable of lines.
          Lines are joined without adding separators, so each line should
          carry its own line ending.

    Returns:
      The colorized lines concatenated into a single string.
    """
    if isinstance(lines, str):
        lines = lines.splitlines(True)

    return ''.join(_colorize_diff_line(line) for line in lines)
69
70
def _diff(path, original: bytes, formatted: bytes) -> str:
    """Builds a colorized unified diff between two versions of a file.

    Both versions are decoded with errors='replace', so bytes that are not
    valid UTF-8 become replacement characters instead of raising.
    """
    before = original.decode(errors='replace').splitlines(True)
    after = formatted.decode(errors='replace').splitlines(True)
    diff_lines = difflib.unified_diff(before,
                                      after,
                                      fromfile=f'{path}  (original)',
                                      tofile=f'{path}  (reformatted)')
    return colorize_diff(diff_lines)
77
78
# Callback used by _check_files: takes a file's path and its current contents
# and returns what the contents should be after formatting. Paths arrive as
# either pathlib.Path or str, so both are accepted.
Formatter = Callable[[Union[Path, str], bytes], bytes]
80
81
def _diff_formatted(path, formatter: Formatter) -> Optional[str]:
    """Compares a file on disk against what the formatter would produce.

    The formatter is given the path and the file's raw bytes. Returns None if
    its output equals the current contents, otherwise a colorized diff.
    """
    with open(path, 'rb') as source_file:
        before = source_file.read()

    after = formatter(path, before)
    if after == before:
        return None
    return _diff(path, before, after)
90
91
def _check_files(files, formatter: Formatter) -> Dict[Path, str]:
    """Diffs each file against its formatted version.

    Returns {path: diff} holding only the files the formatter would change,
    in the order the files were given.
    """
    candidates = ((path, _diff_formatted(path, formatter)) for path in files)
    return {path: diff for path, diff in candidates if diff}
101
102
def _clang_format(*args: str, **kwargs) -> bytes:
    """Runs clang-format with --style=file and returns its captured stdout.

    Positional arguments are appended to the command line; keyword arguments
    are forwarded to log_run. check=True is passed, so a nonzero exit status
    presumably raises inside log_run rather than being returned.
    """
    command = ['clang-format', '--style=file']
    command.extend(args)
    result = log_run(command, stdout=subprocess.PIPE, check=True, **kwargs)
    return result.stdout
108
109
def clang_format_check(files: Iterable[Path]) -> Dict[Path, str]:
    """Checks formatting; returns {path: diff} for files with bad formatting."""
    def run_clang_format(path, _contents):
        # clang-format is pointed at the file itself, so the contents already
        # read by _check_files are not needed here.
        return _clang_format(path)

    return _check_files(files, run_clang_format)
113
114
def clang_format_fix(files: Iterable) -> Dict[Path, str]:
    """Fixes formatting for the provided files in place.

    Returns an empty dict. check=True is passed down to log_run, so a
    clang-format failure presumably raises instead of being reported here.
    """
    files = list(files)
    if files:
        # An empty batch is a no-op; never run `clang-format -i` without any
        # file arguments.
        _clang_format('-i', *files)
    return {}
119
120
def check_gn_format(files: Iterable[Path]) -> Dict[Path, str]:
    """Checks formatting; returns {path: diff} for files with bad formatting."""
    def gn_format_stdin(_path, contents: bytes) -> bytes:
        # Pipe the contents through `gn format --stdin` and capture the
        # formatted result from stdout.
        completed = log_run(['gn', 'format', '--stdin'],
                            input=contents,
                            stdout=subprocess.PIPE,
                            check=True)
        return completed.stdout

    return _check_files(files, gn_format_stdin)
128
129
def fix_gn_format(files: Iterable[Path]) -> Dict[Path, str]:
    """Fixes formatting for the provided files in place.

    Returns an empty dict. check=True is passed to log_run, so a `gn format`
    failure presumably raises instead of being reported here.
    """
    files = list(files)
    if files:
        # An empty batch is a no-op; never run `gn format` without any file
        # arguments.
        log_run(['gn', 'format', *files], check=True)
    return {}
134
135
def check_bazel_format(files: Iterable[Path]) -> Dict[Path, str]:
    """Checks formatting; returns {path: diff} for files with bad formatting.

    Files that buildifier fails on map to buildifier's error output (with the
    temporary copy's path replaced by the real path) instead of a diff.
    """
    failures: Dict[Path, str] = {}

    def buildifier_copy(path: Union[Path, str], contents: bytes) -> bytes:
        # buildifier has no option to print the formatted result, so format a
        # throwaway copy and read it back. The copy keeps the original base
        # name, presumably because buildifier keys off the file name.
        with tempfile.TemporaryDirectory() as scratch_dir:
            scratch = Path(scratch_dir, os.path.basename(path))
            scratch.write_bytes(contents)

            completed = log_run(['buildifier', scratch], capture_output=True)
            if completed.returncode:
                # Report the error against the real file, not the copy.
                message = completed.stderr.decode(errors='replace')
                failures[Path(path)] = message.replace(str(scratch), str(path))
            return scratch.read_bytes()

    diffs = _check_files(files, buildifier_copy)
    diffs.update(failures)
    return diffs
158
159
def fix_bazel_format(files: Iterable[Path]) -> Dict[Path, str]:
    """Fixes formatting for the provided files in place.

    buildifier runs once per file, and any failure is recorded against that
    file.

    Returns:
      {path: buildifier stderr} for each file buildifier failed on; empty if
      every file was formatted.
    """
    errors = {}
    for path in files:
        proc = log_run(['buildifier', path], capture_output=True)
        if proc.returncode:
            # Decode leniently, as check_bazel_format does, so non-UTF-8 bytes
            # in the error output cannot raise UnicodeDecodeError here.
            errors[path] = proc.stderr.decode(errors='replace')
    return errors
168
169
def check_go_format(files: Iterable[Path]) -> Dict[Path, str]:
    """Checks formatting; returns {path: diff} for files with bad formatting."""
    def gofmt_output(path, _contents: bytes) -> bytes:
        # gofmt reads the file itself; its stdout is the formatted version
        # that _check_files compares against the file on disk.
        completed = log_run(['gofmt', path],
                            stdout=subprocess.PIPE,
                            check=True)
        return completed.stdout

    return _check_files(files, gofmt_output)
175
176
def fix_go_format(files: Iterable[Path]) -> Dict[Path, str]:
    """Fixes formatting for the provided files in place.

    Returns an empty dict. check=True is passed to log_run, so a gofmt failure
    presumably raises instead of being reported here.
    """
    files = list(files)
    if files:
        # An empty batch is a no-op; never run `gofmt -w` without any file
        # arguments.
        log_run(['gofmt', '-w', *files], check=True)
    return {}
181
182
def _yapf(*args, **kwargs) -> subprocess.CompletedProcess:
    """Runs `python -m yapf --parallel` and captures its stdout and stderr.

    Positional arguments are appended to the yapf command line; keyword
    arguments are forwarded to log_run.
    """
    yapf_command = ['python', '-m', 'yapf', '--parallel']
    yapf_command += args
    return log_run(yapf_command, capture_output=True, **kwargs)
187
188
# Matches a '--- <path> ... (original)' header line in `yapf --diff` output;
# check_py_format uses each match as the start of one file's diff. Group 1
# captures the path (MULTILINE lets ^/$ anchor at every line).
_DIFF_START = re.compile(r'^--- (.*)\s+\(original\)$', flags=re.MULTILINE)
190
191
def check_py_format(files: Iterable[Path]) -> Dict[Path, str]:
    """Checks formatting; returns {path: diff} for files with bad formatting.

    All files are passed to a single `yapf --diff` run, whose output is split
    into one colorized diff per file. If yapf writes anything to stderr, the
    error is logged and every file without a diff is reported with an empty
    diff, so the check does not pass silently.
    """
    # The file list is used twice (yapf's arguments and the stderr fallback),
    # so a one-shot iterator must be materialized first.
    files = list(files)
    if not files:
        # Nothing to check; don't invoke yapf without file arguments.
        return {}

    process = _yapf('--diff', *files)

    errors: Dict[Path, str] = {}

    if process.stdout:
        raw_diff = process.stdout.decode(errors='replace')

        # Each file's diff runs from its header up to the next header, or to
        # the end of the output for the last file.
        matches = tuple(_DIFF_START.finditer(raw_diff))
        for start, end in zip(matches, (*matches[1:], None)):
            errors[Path(start.group(1))] = colorize_diff(
                raw_diff[start.start():end.start() if end else None])

    if process.stderr:
        _LOG.error('yapf encountered an error:\n%s',
                   process.stderr.decode(errors='replace').rstrip())
        # Key by Path so these entries line up with the diff entries above.
        for file in map(Path, files):
            errors.setdefault(file, '')

    return errors
212
213
def fix_py_format(files: Iterable) -> Dict[Path, str]:
    """Fixes formatting for the provided files in place.

    Returns an empty dict. check=True is passed through _yapf to log_run, so a
    yapf failure presumably raises instead of being reported here.
    """
    files = list(files)
    if files:
        # An empty batch is a no-op; never run `yapf --in-place` without any
        # file arguments.
        _yapf('--in-place', *files, check=True)
    return {}
218
219
# A run of spaces and/or tabs at the end of a line, matched on raw bytes
# (MULTILINE makes $ match before every '\n'). NOTE: in CRLF files the '\r'
# sits between such whitespace and the '\n', so it is not matched.
_TRAILING_SPACE = re.compile(rb'[ \t]+$', flags=re.MULTILINE)
221
222
223def _check_trailing_space(paths: Iterable[Path], fix: bool) -> Dict[Path, str]:
224    """Checks for and optionally removes trailing whitespace."""
225    errors = {}
226
227    for path in paths:
228        with path.open('rb') as fd:
229            contents = fd.read()
230
231        corrected = _TRAILING_SPACE.sub(b'', contents)
232        if corrected != contents:
233            errors[path] = _diff(path, contents, corrected)
234
235            if fix:
236                with path.open('wb') as fd:
237                    fd.write(corrected)
238
239    return errors
240
241
def check_trailing_space(files: Iterable[Path]) -> Dict[Path, str]:
    """Reports trailing whitespace without modifying any file.

    Returns {path: diff} for each file that has trailing spaces or tabs.
    """
    return _check_trailing_space(paths=files, fix=False)
244
245
def fix_trailing_space(files: Iterable[Path]) -> Dict[Path, str]:
    """Strips trailing spaces and tabs from the files in place.

    Always returns an empty dict; the diffs computed while fixing are
    discarded. Exceptions raised while reading or writing files propagate.
    """
    no_errors: Dict[Path, str] = {}
    _check_trailing_space(files, fix=True)
    return no_errors
249
250
def print_format_check(errors: Dict[Path, str],
                       show_fix_commands: bool) -> None:
    """Prints the result of a check_*_format function; returns nothing.

    Args:
      errors: {path: diff or error text}; nothing is printed when empty.
      show_fix_commands: if True, also log a copy-and-pastable
          `pw format --fix <path>` line for every path in errors.
    """
    if not errors:
        # Don't print anything in the all-good case.
        return

    # Show the format fixing diff suggested by the tooling (with colors).
    _LOG.warning('Found %d files with formatting errors. Format changes:',
                 len(errors))
    for diff in errors.values():
        print(diff, end='')

    # Show a copy-and-pastable command to fix the issues.
    if show_fix_commands:
        cwd = Path.cwd().resolve()

        def path_relative_to_cwd(path):
            # Files outside the current directory are shown as absolute paths.
            resolved = Path(path).resolve()
            try:
                return resolved.relative_to(cwd)
            except ValueError:
                return resolved

        message = (f'  pw format --fix {path_relative_to_cwd(path)}'
                   for path in errors)
        _LOG.warning('To fix formatting, run:\n\n%s\n', '\n'.join(message))
276
277
class CodeFormat(NamedTuple):
    """Describes how files of one language are checked and fixed.

    Instances are used as dict keys by CodeFormatter, so every field must hold
    a hashable value (e.g. tuples or frozensets for the collections).
    """
    # Human-readable language name; also used to name the presubmit check.
    language: str
    # Path suffixes matched with str.endswith, e.g. '.py' or 'BUILD'.
    extensions: Collection[str]
    # Regular expressions for paths to skip; passed to filter_paths as the
    # default 'exclude' by presubmit_check.
    exclude: Collection[str]
    # Returns {path: diff or error text} for files with formatting problems.
    check: Callable[[Iterable[Path]], Dict[Path, str]]
    # Fixes files in place; returns {path: error text} for files it failed on.
    fix: Callable[[Iterable[Path]], Dict[Path, str]]
284
285
# Suffixes of C and C++ files, all handled by clang-format.
CPP_HEADER_EXTS = frozenset(
    ('.h', '.hpp', '.hxx', '.h++', '.hh', '.H', '.inc', '.inl'))
CPP_SOURCE_EXTS = frozenset(('.c', '.cpp', '.cxx', '.c++', '.cc', '.C'))
CPP_EXTS = CPP_HEADER_EXTS.union(CPP_SOURCE_EXTS)

# *.pb.h and *.pb.c files (presumably generated protobuf code) are excluded.
C_FORMAT: CodeFormat = CodeFormat('C and C++', CPP_EXTS,
                                  (r'\.pb\.h$', r'\.pb\.c$'),
                                  clang_format_check, clang_format_fix)

PROTO_FORMAT: CodeFormat = CodeFormat('Protocol buffer', ('.proto', ), (),
                                      clang_format_check, clang_format_fix)

JAVA_FORMAT: CodeFormat = CodeFormat('Java', ('.java', ), (),
                                     clang_format_check, clang_format_fix)

JAVASCRIPT_FORMAT: CodeFormat = CodeFormat('JavaScript', ('.js', ), (),
                                           clang_format_check,
                                           clang_format_fix)

GO_FORMAT: CodeFormat = CodeFormat('Go', ('.go', ), (), check_go_format,
                                   fix_go_format)

PYTHON_FORMAT: CodeFormat = CodeFormat('Python', ('.py', ), (),
                                       check_py_format, fix_py_format)

GN_FORMAT: CodeFormat = CodeFormat('GN', ('.gn', '.gni'), (), check_gn_format,
                                   fix_gn_format)

# Bazel files are formatted with buildifier. Extensions are matched as path
# suffixes, so whole file names such as 'BUILD' and 'CMakeLists.txt' work too.
BAZEL_FORMAT: CodeFormat = CodeFormat('Bazel', ('BUILD', '.bazel', '.bzl'), (),
                                      check_bazel_format, fix_bazel_format)

# TODO(pwbug/191): Add real code formatting support for CMake. For now CMake,
# reStructuredText, and Markdown files are only checked for trailing whitespace.
CMAKE_FORMAT: CodeFormat = CodeFormat('CMake', ('CMakeLists.txt', '.cmake'),
                                      (), check_trailing_space,
                                      fix_trailing_space)

RST_FORMAT: CodeFormat = CodeFormat('reStructuredText', ('.rst', ), (),
                                    check_trailing_space, fix_trailing_space)

MARKDOWN_FORMAT: CodeFormat = CodeFormat('Markdown', ('.md', ), (),
                                         check_trailing_space,
                                         fix_trailing_space)

# Every supported format; presubmit_checks and CodeFormatter iterate over it.
CODE_FORMATS: Tuple[CodeFormat, ...] = (
    C_FORMAT,
    JAVA_FORMAT,
    JAVASCRIPT_FORMAT,
    PROTO_FORMAT,
    GO_FORMAT,
    PYTHON_FORMAT,
    GN_FORMAT,
    BAZEL_FORMAT,
    CMAKE_FORMAT,
    RST_FORMAT,
    MARKDOWN_FORMAT,
)
342
343
def presubmit_check(code_format: CodeFormat, **filter_paths_args) -> Callable:
    """Wraps a CodeFormat's check function as a presubmit step.

    Args:
      code_format: the format whose check function the step runs.
      **filter_paths_args: forwarded to pw_presubmit.filter_paths; 'endswith'
          and 'exclude' default to the format's extensions and exclusions.

    Returns:
      A check named '<language>_format' that raises PresubmitFailure when any
      file needs reformatting.
    """
    for key, default in (('endswith', code_format.extensions),
                         ('exclude', code_format.exclude)):
        filter_paths_args.setdefault(key, default)

    @pw_presubmit.filter_paths(**filter_paths_args)
    def check_code_format(ctx: pw_presubmit.PresubmitContext):
        errors = code_format.check(ctx.paths)
        # In presubmit, also print the copy-and-pastable fix commands.
        print_format_check(errors, show_fix_commands=True)
        if errors:
            raise pw_presubmit.PresubmitFailure

    # Derive the check's name from the language, e.g. 'c_and_cpp_format'.
    slug = code_format.language.lower()
    slug = slug.replace('+', 'p').replace(' ', '_')
    check_code_format.__name__ = f'{slug}_format'
    return check_code_format
364
365
def presubmit_checks(**filter_paths_args) -> Tuple[Callable, ...]:
    """Builds one presubmit check per entry in CODE_FORMATS.

    The same keyword arguments are forwarded to presubmit_check for every
    format.
    """
    checks = [
        presubmit_check(code_format, **filter_paths_args)
        for code_format in CODE_FORMATS
    ]
    return tuple(checks)
370
371
class CodeFormatter:
    """Checks or fixes the formatting of a set of files.

    Each file is assigned to every CodeFormat in CODE_FORMATS whose extensions
    match the end of its path, unless the path matches one of that format's
    exclude patterns (the same patterns presubmit_check uses by default).
    """
    def __init__(self, files: Iterable[Path]):
        self.paths = list(files)
        # CodeFormat -> the files in self.paths it applies to.
        self._formats: Dict[CodeFormat, List] = collections.defaultdict(list)

        for path in self.paths:
            posix_path = path.as_posix()
            for code_format in CODE_FORMATS:
                if not any(
                        posix_path.endswith(e)
                        for e in code_format.extensions):
                    continue
                # Honor the format's exclusions (e.g. *.pb.h for C and C++) so
                # `pw format` agrees with the presubmit check.
                if any(
                        re.search(pattern, posix_path)
                        for pattern in code_format.exclude):
                    continue
                self._formats[code_format].append(path)

    def check(self) -> Dict[Path, str]:
        """Returns {path: diff} for files with incorrect formatting.

        The result is ordered by path.
        """
        errors: Dict[Path, str] = {}

        for code_format, files in self._formats.items():
            _LOG.debug('Checking %s', ', '.join(str(f) for f in files))
            errors.update(code_format.check(files))

        return collections.OrderedDict(sorted(errors.items()))

    def fix(self) -> Dict[Path, str]:
        """Fixes format errors for supported files in place.

        Returns:
          {path: error text} for files a fixer failed on; empty on success.
        """
        all_errors: Dict[Path, str] = {}
        for code_format, files in self._formats.items():
            errors = code_format.fix(files)
            if errors:
                for path, error in errors.items():
                    _LOG.error('Failed to format %s', path)
                    for line in error.splitlines():
                        _LOG.error('%s', line)
                all_errors.update(errors)
                continue

            _LOG.info('Formatted %s',
                      plural(files, code_format.language + ' file'))
        return all_errors
410
411
def _file_summary(files: Iterable[Union[Path, str]], base: Path) -> List[str]:
    """Summarizes files as paths relative to base.

    Path.relative_to raises ValueError for a file outside base; in that case
    (or on any other ValueError) an empty list is returned.
    """
    def relative_to_base(file):
        return Path(file).resolve().relative_to(base.resolve())

    try:
        return file_summary(map(relative_to_base, files))
    except ValueError:
        return []
418
419
def format_paths_in_repo(paths: Collection[Union[Path, str]],
                         exclude: Collection[Pattern[str]], fix: bool,
                         base: str) -> int:
    """Checks or fixes formatting for files in a Git repo.

    Outside a Git repo, only the existing files named in paths are formatted.

    Args:
      paths: file paths or pathspecs. Existing files among them are always
          included; inside a repo they are also passed to git_repo.list_files.
      exclude: patterns for paths to drop from the files listed by Git.
      fix: apply fixes in place instead of only reporting problems.
      base: Git revision passed to git_repo.list_files. The tracking-branch
          alias falls back to HEAD~1 when no tracking branch is available.

    Returns:
      0 on success; 1 if formatting problems remain or base was given
      outside a Git repo.
    """
    files = [Path(path).resolve() for path in paths if os.path.isfile(path)]
    repo = git_repo.root() if git_repo.is_repo() else None

    if not repo:
        # The tracking-branch alias is ignored here rather than rejected: it
        # looks like the command-line default (assumption - confirm against
        # cli.add_path_arguments), and rejecting it would make formatting
        # explicit files outside Git always fail.
        if base and base != git_repo.TRACKING_BRANCH_ALIAS:
            _LOG.critical(
                'A base commit may only be provided if running from a Git repo')
            return 1
        return format_files(files, fix, repo=repo)

    # Implement a graceful fallback in case the tracking branch isn't available.
    if (base == git_repo.TRACKING_BRANCH_ALIAS
            and not git_repo.tracking_branch(repo)):
        _LOG.warning(
            'Failed to determine the tracking branch, using --base HEAD~1 '
            'instead of listing all files')
        base = 'HEAD~1'

    # List the original paths with git ls-files or diff.
    project_root = Path(pw_cli.env.pigweed_environment().PW_PROJECT_ROOT)
    _LOG.info(
        'Formatting %s',
        git_repo.describe_files(repo, Path.cwd(), base, paths, exclude,
                                project_root))

    # Add files from Git and remove duplicates.
    files = sorted(
        set(exclude_paths(exclude, git_repo.list_files(base, paths)))
        | set(files))

    return format_files(files, fix, repo=repo)
453
454
def format_files(paths: Collection[Union[Path, str]],
                 fix: bool,
                 repo: Optional[Path] = None) -> int:
    """Checks or fixes formatting for the specified files.

    Args:
      paths: files to check; files whose suffix matches no supported format
          are counted but not checked.
      fix: if True, apply formatting fixes in place when problems are found.
      repo: Git repository root used to shorten paths in the file summary;
          the current working directory is used when None.

    Returns:
      0 if no changes were needed or all fixes were applied, 1 otherwise.
    """
    formatter = CodeFormatter(Path(p) for p in paths)

    _LOG.info('Checking formatting for %s', plural(formatter.paths, 'file'))

    for line in _file_summary(paths, repo if repo else Path.cwd()):
        print(line, file=sys.stderr)

    check_errors = formatter.check()
    print_format_check(check_errors, show_fix_commands=(not fix))

    if not check_errors:
        _LOG.info('Congratulations! No formatting changes needed')
        return 0

    if not fix:
        _LOG.error('Formatting errors found')
        return 1

    _LOG.info('Applying formatting fixes to %d files', len(check_errors))
    fix_errors = formatter.fix()
    if fix_errors:
        # The command fails in this case, so report it at error level like the
        # other failure paths.
        _LOG.error('Failed to apply formatting fixes')
        print_format_check(fix_errors, show_fix_commands=False)
        return 1

    _LOG.info('Formatting fixes applied successfully')
    return 0
487
488
def arguments(git_paths: bool) -> argparse.ArgumentParser:
    """Builds an argument parser for format_files or format_paths_in_repo.

    Args:
      git_paths: if True, add the Git path arguments from
          cli.add_path_arguments; otherwise accept one or more positional
          paths, each of which must name an existing file.

    Returns:
      The parser, which also accepts --fix.
    """
    parser = argparse.ArgumentParser(description=__doc__)

    def existing_path(arg: str) -> Path:
        candidate = Path(arg)
        if candidate.is_file():
            return candidate
        raise argparse.ArgumentTypeError(f'{arg} is not a path to a file')

    if not git_paths:
        parser.add_argument('paths',
                            metavar='path',
                            nargs='+',
                            type=existing_path,
                            help='File paths to check')
    else:
        cli.add_path_arguments(parser)

    parser.add_argument('--fix',
                        action='store_true',
                        help='Apply formatting fixes in place.')
    return parser
516
517
def main() -> int:
    """Parses the command line, then checks or fixes formatting.

    Returns the process exit code produced by format_paths_in_repo.
    """
    parsed_args = arguments(git_paths=True).parse_args()
    return format_paths_in_repo(**vars(parsed_args))
521
522
# Script entry point: configure logging, then exit with main()'s return code.
if __name__ == '__main__':
    try:
        # If pw_cli is available, use it to initialize logs.
        from pw_cli import log

        log.install(logging.INFO)
    except ImportError:
        # If pw_cli isn't available, display log messages like a simple print.
        logging.basicConfig(format='%(message)s', level=logging.INFO)

    sys.exit(main())
534