#!/usr/bin/env python3

# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Checks and fixes formatting for source files.

This uses clang-format, gn format, gofmt, and python -m yapf to format source
code. These tools must be available on the path when this script is invoked!
"""

import argparse
import collections
import difflib
import logging
import os
from pathlib import Path
import re
import subprocess
import sys
from typing import Callable, Collection, Dict, Iterable, List, NamedTuple
from typing import Optional, Pattern, Tuple, Union

try:
    import pw_presubmit
except ImportError:
    # Append the pw_presubmit package path to the module search path to allow
    # running this module without installing the pw_presubmit package.
    sys.path.append(os.path.dirname(os.path.dirname(
        os.path.abspath(__file__))))
    import pw_presubmit

from pw_presubmit import cli, git_repo
from pw_presubmit.tools import exclude_paths, file_summary, log_run, plural

_LOG: logging.Logger = logging.getLogger(__name__)


def _colorize_diff_line(line: str) -> str:
    """Wraps one unified-diff line in a color chosen by its leading marker.

    File headers ('--- ' / '+++ ') are tested first so that they are not
    colored as removed/added lines. Lines with no recognized marker (context
    lines, '\\ No newline' markers) are returned unchanged.
    """
    if line.startswith('--- ') or line.startswith('+++ '):
        return pw_presubmit.color_bold_white(line)
    if line.startswith('-'):
        return pw_presubmit.color_red(line)
    if line.startswith('+'):
        return pw_presubmit.color_green(line)
    if line.startswith('@@ '):
        return pw_presubmit.color_aqua(line)
    return line


def colorize_diff(lines: Union[str, Iterable[str]]) -> str:
    """Takes a diff str or list of str lines and returns a colorized version.

    A str argument is first split into lines, keeping their line endings. The
    colorized lines are concatenated without adding any separators.
    """
    if isinstance(lines, str):
        lines = lines.splitlines(True)

    return ''.join(_colorize_diff_line(line) for line in lines)


def _terminate_diff_lines(diff_lines: Iterable[str]) -> Iterable[str]:
    """Yields diff lines, ensuring each one ends with a line terminator.

    difflib.unified_diff copies input lines verbatim, so the last line of a
    file that has no trailing newline would otherwise run straight into the
    following diff line when the diff is printed. Such lines get a newline
    plus the conventional "\\ No newline at end of file" marker.
    """
    for line in diff_lines:
        # splitlines() strips any line terminator, so an unchanged result means
        # this line had none.
        if line.splitlines() == [line]:
            yield line + '\n'
            yield '\\ No newline at end of file\n'
        else:
            yield line


def _diff(path, original: bytes, formatted: bytes) -> str:
    """Returns a colorized unified diff between two versions of a file.

    Args:
      path: the file's path; only used to label the diff headers.
      original: the file's current contents.
      formatted: the contents after formatting.
    """
    return colorize_diff(
        _terminate_diff_lines(
            difflib.unified_diff(
                original.decode(errors='replace').splitlines(True),
                formatted.decode(errors='replace').splitlines(True),
                f'{path} (original)', f'{path} (reformatted)')))


# A formatter takes (path, original contents) and returns formatted contents.
Formatter = Callable[[str, bytes], bytes]


def _diff_formatted(path, formatter: Formatter) -> Optional[str]:
    """Returns a diff comparing a file to its formatted version.

    Returns None when the formatter leaves the file's bytes unchanged.
    """
    with open(path, 'rb') as fd:
        original = fd.read()

    formatted = formatter(path, original)

    return None if formatted == original else _diff(path, original, formatted)


def _check_files(files, formatter: Formatter) -> Dict[Path, str]:
    """Returns {path: diff} for each file whose formatted contents differ."""
    errors = {}

    for path in files:
        difference = _diff_formatted(path, formatter)
        if difference:
            errors[path] = difference

    return errors


def _clang_format(*args: str, **kwargs) -> bytes:
    """Runs `clang-format --style=file` with args and returns its stdout.

    check=True and any extra kwargs are forwarded to log_run.
    """
    return log_run(['clang-format', '--style=file', *args],
                   stdout=subprocess.PIPE,
                   check=True,
                   **kwargs).stdout


def clang_format_check(files: Iterable[Path]) -> Dict[Path, str]:
    """Checks formatting; returns {path: diff} for files with bad formatting."""
    return _check_files(files, lambda path, _: _clang_format(path))


def clang_format_fix(files: Iterable) -> None:
    """Fixes formatting for the provided files in place."""
    _clang_format('-i', *files)


def check_gn_format(files: Iterable[Path]) -> Dict[Path, str]:
    """Checks formatting; returns {path: diff} for files with bad formatting."""
    # File contents are piped through `gn format --stdin`; its stdout is the
    # formatted version compared against the original.
    return _check_files(
        files, lambda _, data: log_run(['gn', 'format', '--stdin'],
                                       input=data,
                                       stdout=subprocess.PIPE,
                                       check=True).stdout)


def fix_gn_format(files: Iterable[Path]) -> None:
    """Fixes formatting for the provided files in place."""
    log_run(['gn', 'format', *files], check=True)


def check_go_format(files: Iterable[Path]) -> Dict[Path, str]:
    """Checks formatting; returns {path: diff} for files with bad formatting."""
    # gofmt without -w writes the formatted file to stdout.
    return _check_files(
        files, lambda path, _: log_run(
            ['gofmt', path], stdout=subprocess.PIPE, check=True).stdout)


def fix_go_format(files: Iterable[Path]) -> None:
    """Fixes formatting for the provided files in place."""
    log_run(['gofmt', '-w', *files], check=True)


def _yapf(*args, **kwargs) -> subprocess.CompletedProcess:
    """Runs `python -m yapf --parallel` with args, capturing stdout/stderr."""
    return log_run(['python', '-m', 'yapf', '--parallel', *args],
                   capture_output=True,
                   **kwargs)


# Matches the "--- <path> (original)" header that opens each file's section of
# yapf's --diff output; group 1 captures the path.
_DIFF_START = re.compile(r'^--- (.*)\s+\(original\)$', flags=re.MULTILINE)


def check_py_format(files: Iterable[Path]) -> Dict[Path, str]:
    """Checks formatting; returns {path: diff} for files with bad formatting.

    yapf is run once over all files; its combined --diff output is split into
    one colorized diff per file. If yapf writes anything to stderr, the error
    is logged and every file is reported (with an empty diff when yapf gave
    none) so that the check fails.
    """
    # files is iterated twice (yapf's command line and the stderr fallback),
    # so materialize it in case a one-shot iterator was passed.
    files = list(files)
    process = _yapf('--diff', *files)

    errors: Dict[Path, str] = {}

    if process.stdout:
        raw_diff = process.stdout.decode(errors='replace')

        # Each file's diff runs from its header to the next header (or EOF).
        matches = tuple(_DIFF_START.finditer(raw_diff))
        for start, end in zip(matches, (*matches[1:], None)):
            errors[Path(start.group(1))] = colorize_diff(
                raw_diff[start.start():end.start() if end else None])

    if process.stderr:
        _LOG.error('yapf encountered an error:\n%s',
                   process.stderr.decode(errors='replace').rstrip())
        for file in files:
            errors.setdefault(Path(file), '')

    return errors


def fix_py_format(files: Iterable):
    """Fixes formatting for the provided files in place."""
    _yapf('--in-place', *files, check=True)


# Runs of spaces/tabs at the end of any line.
_TRAILING_SPACE = re.compile(rb'[ \t]+$', flags=re.MULTILINE)


def _check_trailing_space(paths: Iterable[Path], fix: bool) -> Dict[Path, str]:
    """Checks for and optionally removes trailing whitespace.

    Returns {path: diff} for files that contain trailing whitespace; when fix
    is True those files are also rewritten without it.
    """
    errors = {}

    for path in paths:
        with path.open('rb') as fd:
            contents = fd.read()

        corrected = _TRAILING_SPACE.sub(b'', contents)
        if corrected != contents:
            errors[path] = _diff(path, contents, corrected)

            if fix:
                with path.open('wb') as fd:
                    fd.write(corrected)

    return errors


def check_trailing_space(files: Iterable[Path]) -> Dict[Path, str]:
    """Returns {path: diff} for files containing trailing whitespace."""
    return _check_trailing_space(files, fix=False)


def fix_trailing_space(files: Iterable[Path]) -> None:
    """Removes trailing whitespace from the provided files in place."""
    _check_trailing_space(files, fix=True)


def print_format_check(errors: Dict[Path, str],
                       show_fix_commands: bool) -> None:
    """Prints the result of a check_*_format function; returns nothing.

    Prints each diff in errors after a warning. If show_fix_commands is set,
    also logs a copy-and-pastable `pw format --fix` command per file. Prints
    nothing when errors is empty.
    """
    if not errors:
        # Don't print anything in the all-good case.
        return

    # Show the format fixing diff suggested by the tooling (with colors).
    _LOG.warning('Found %d files with formatting errors. Format changes:',
                 len(errors))
    for diff in errors.values():
        print(diff, end='')

    # Show a copy-and-pastable command to fix the issues.
    if show_fix_commands:

        def path_relative_to_cwd(path):
            # Falls back to the absolute path for files outside the cwd.
            try:
                return Path(path).resolve().relative_to(Path.cwd().resolve())
            except ValueError:
                return Path(path).resolve()

        message = (f' pw format --fix {path_relative_to_cwd(path)}'
                   for path in errors)
        _LOG.warning('To fix formatting, run:\n\n%s\n', '\n'.join(message))


class CodeFormat(NamedTuple):
    """Describes how to check and fix one kind of source file."""
    # Human-readable name; also used to name the presubmit check.
    language: str
    # Path suffixes (or whole file names) this format applies to.
    extensions: Collection[str]
    # Regular expressions for paths to skip even if their suffix matches.
    exclude: Collection[str]
    # Returns {path: diff} for files with bad formatting.
    check: Callable[[Iterable], Dict[Path, str]]
    # Fixes the given files in place.
    fix: Callable[[Iterable], None]


C_FORMAT: CodeFormat = CodeFormat(
    'C and C++',
    frozenset(['.h', '.hh', '.hpp', '.c', '.cc', '.cpp', '.inc', '.inl']),
    (r'\.pb\.h$', r'\.pb\.c$'), clang_format_check, clang_format_fix)

PROTO_FORMAT: CodeFormat = CodeFormat('Protocol buffer', ('.proto', ), (),
                                      clang_format_check, clang_format_fix)

JAVA_FORMAT: CodeFormat = CodeFormat('Java', ('.java', ), (),
                                     clang_format_check, clang_format_fix)

JAVASCRIPT_FORMAT: CodeFormat = CodeFormat('JavaScript', ('.js', ), (),
                                           clang_format_check,
                                           clang_format_fix)

GO_FORMAT: CodeFormat = CodeFormat('Go', ('.go', ), (), check_go_format,
                                   fix_go_format)

PYTHON_FORMAT: CodeFormat = CodeFormat('Python', ('.py', ), (),
                                       check_py_format, fix_py_format)

GN_FORMAT: CodeFormat = CodeFormat('GN', ('.gn', '.gni'), (), check_gn_format,
                                   fix_gn_format)

# TODO(pwbug/191): Add real code formatting support for Bazel and CMake
BAZEL_FORMAT: CodeFormat = CodeFormat('Bazel', ('BUILD', ), (),
                                      check_trailing_space, fix_trailing_space)

CMAKE_FORMAT: CodeFormat = CodeFormat('CMake', ('CMakeLists.txt', '.cmake'),
                                      (), check_trailing_space,
                                      fix_trailing_space)

RST_FORMAT: CodeFormat = CodeFormat('reStructuredText', ('.rst', ), (),
                                    check_trailing_space, fix_trailing_space)

MARKDOWN_FORMAT: CodeFormat = CodeFormat('Markdown', ('.md', ), (),
                                         check_trailing_space,
                                         fix_trailing_space)

CODE_FORMATS: Tuple[CodeFormat, ...] = (
    C_FORMAT,
    JAVA_FORMAT,
    JAVASCRIPT_FORMAT,
    PROTO_FORMAT,
    GO_FORMAT,
    PYTHON_FORMAT,
    GN_FORMAT,
    BAZEL_FORMAT,
    CMAKE_FORMAT,
    RST_FORMAT,
    MARKDOWN_FORMAT,
)


def presubmit_check(code_format: CodeFormat, **filter_paths_args) -> Callable:
    """Creates a presubmit check function from a CodeFormat object.

    filter_paths_args are passed to pw_presubmit.filter_paths; 'endswith' and
    'exclude' default to the format's extensions and exclusions. The returned
    check raises PresubmitFailure if any file is badly formatted.
    """
    filter_paths_args.setdefault('endswith', code_format.extensions)
    filter_paths_args.setdefault('exclude', code_format.exclude)

    @pw_presubmit.filter_paths(**filter_paths_args)
    def check_code_format(ctx: pw_presubmit.PresubmitContext):
        errors = code_format.check(ctx.paths)
        print_format_check(
            errors,
            # When running as part of presubmit, show the fix command help.
            show_fix_commands=True,
        )
        if errors:
            raise pw_presubmit.PresubmitFailure

    # Give each check a distinct name, e.g. 'C and C++' -> 'c_and_cpp_format'.
    language = code_format.language.lower().replace('+', 'p').replace(' ', '_')
    check_code_format.__name__ = f'{language}_format'

    return check_code_format


def presubmit_checks(**filter_paths_args) -> Tuple[Callable, ...]:
    """Returns a tuple with all supported code format presubmit checks."""
    return tuple(
        presubmit_check(fmt, **filter_paths_args) for fmt in CODE_FORMATS)


class CodeFormatter:
    """Checks or fixes the formatting of a set of files.

    Each file is assigned to every CodeFormat whose extensions match its path,
    unless the path matches one of that format's exclude patterns (the same
    patterns presubmit_check uses). Files that match no format are kept in
    self.paths but are neither checked nor fixed.
    """
    def __init__(self, files: Iterable[Path]):
        self.paths = list(files)
        self._formats: Dict[CodeFormat, List] = collections.defaultdict(list)

        for path in self.paths:
            posix_path = path.as_posix()
            for code_format in CODE_FORMATS:
                if not posix_path.endswith(tuple(code_format.extensions)):
                    continue
                # Skip excluded paths (e.g. generated *.pb.h files) so that
                # `pw format` agrees with the presubmit checks.
                if any(
                        re.search(pattern, posix_path)
                        for pattern in code_format.exclude):
                    continue
                self._formats[code_format].append(path)

    def check(self) -> Dict[Path, str]:
        """Returns {path: diff} for files with incorrect formatting."""
        errors: Dict[Path, str] = {}

        for code_format, files in self._formats.items():
            _LOG.debug('Checking %s', ', '.join(str(f) for f in files))
            errors.update(code_format.check(files))

        return collections.OrderedDict(sorted(errors.items()))

    def fix(self) -> None:
        """Fixes format errors for supported files in place."""
        for code_format, files in self._formats.items():
            code_format.fix(files)
            _LOG.info('Formatted %s',
                      plural(files, code_format.language + ' file'))


def _file_summary(files: Iterable[Union[Path, str]], base: Path) -> List[str]:
    """Summarizes files relative to base; returns [] if any file is outside."""
    try:
        return file_summary(
            Path(f).resolve().relative_to(base.resolve()) for f in files)
    except ValueError:
        return []


def format_paths_in_repo(paths: Collection[Union[Path, str]],
                         exclude: Collection[Pattern[str]], fix: bool,
                         base: str) -> int:
    """Checks or fixes formatting for files in a Git repo.

    Args:
      paths: paths given on the command line; existing files are always
          included, and in a Git repo they are also passed to
          git_repo.list_files.
      exclude: patterns applied via exclude_paths to the files listed by Git.
      fix: apply fixes in place instead of only checking.
      base: optional Git commit forwarded to git_repo.list_files; an error
          outside a Git repo.
    Returns:
      A process exit code: 0 on success, 1 on failure.
    """
    files = [Path(path).resolve() for path in paths if os.path.isfile(path)]
    repo = git_repo.root() if git_repo.is_repo() else None

    # If this is a Git repo, list the original paths with git ls-files or diff.
    if repo:
        _LOG.info(
            'Formatting %s',
            git_repo.describe_files(repo, Path.cwd(), base, paths, exclude))

        # Add files from Git and remove duplicates.
        files = sorted(
            set(exclude_paths(exclude, git_repo.list_files(base, paths)))
            | set(files))
    elif base:
        _LOG.critical(
            'A base commit may only be provided if running from a Git repo')
        return 1

    return format_files(files, fix, repo=repo)


def format_files(paths: Collection[Union[Path, str]],
                 fix: bool,
                 repo: Optional[Path] = None) -> int:
    """Checks or fixes formatting for the specified files.

    Returns 0 if no formatting changes were needed or fixes were applied, and
    1 if errors were found while only checking.
    """
    formatter = CodeFormatter(Path(p) for p in paths)

    _LOG.info('Checking formatting for %s', plural(formatter.paths, 'file'))

    for line in _file_summary(paths, repo if repo else Path.cwd()):
        print(line, file=sys.stderr)

    errors = formatter.check()
    # Fix commands are only useful when the user didn't already ask for --fix.
    print_format_check(errors, show_fix_commands=(not fix))

    if errors:
        if fix:
            formatter.fix()
            # TODO: This should perhaps check that the fixes were successful.
            _LOG.info('Formatting fixes applied successfully')
            return 0

        _LOG.error('Formatting errors found')
        return 1

    _LOG.info('Congratulations! No formatting changes needed')
    return 0


def arguments(git_paths: bool) -> argparse.ArgumentParser:
    """Creates an argument parser for format_files or format_paths_in_repo.

    With git_paths, path arguments are added by cli.add_path_arguments;
    otherwise one or more existing file paths are required. Both variants
    accept --fix.
    """

    parser = argparse.ArgumentParser(description=__doc__)

    if git_paths:
        cli.add_path_arguments(parser)
    else:

        def existing_path(arg: str) -> Path:
            path = Path(arg)
            if not path.is_file():
                raise argparse.ArgumentTypeError(
                    f'{arg} is not a path to a file')

            return path

        parser.add_argument('paths',
                            metavar='path',
                            nargs='+',
                            type=existing_path,
                            help='File paths to check')

    parser.add_argument('--fix',
                        action='store_true',
                        help='Apply formatting fixes in place.')
    return parser


def main() -> int:
    """Check and fix formatting for source files."""
    return format_paths_in_repo(**vars(arguments(git_paths=True).parse_args()))


if __name__ == '__main__':
    try:
        # If pw_cli is available, use it to initialize logs.
        from pw_cli import log

        log.install(logging.INFO)
    except ImportError:
        # If pw_cli isn't available, display log messages like a simple print.
        logging.basicConfig(format='%(message)s', level=logging.INFO)

    sys.exit(main())