#!/usr/bin/env python3

# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Checks and fixes formatting for source files.

This uses clang-format, gn format, gofmt, and python -m yapf to format source
code. These tools must be available on the path when this script is invoked!
"""

import argparse
import collections
import difflib
import logging
import os
from pathlib import Path
import re
import subprocess
import sys
import tempfile
from typing import Callable, Collection, Dict, Iterable, List, NamedTuple
from typing import Optional, Pattern, Tuple, Union

try:
    import pw_presubmit
except ImportError:
    # Append the pw_presubmit package path to the module search path to allow
    # running this module without installing the pw_presubmit package.
    sys.path.append(os.path.dirname(os.path.dirname(
        os.path.abspath(__file__))))
    import pw_presubmit

import pw_cli.env
from pw_presubmit import cli, git_repo
from pw_presubmit.tools import exclude_paths, file_summary, log_run, plural

_LOG: logging.Logger = logging.getLogger(__name__)


def _colorize_diff_line(line: str) -> str:
    """Returns one unified-diff line wrapped in the color for its kind."""
    # File headers must be tested before the generic '-'/'+' prefixes, since
    # '--- ' and '+++ ' would otherwise be colored as removed/added lines.
    if line.startswith('--- ') or line.startswith('+++ '):
        return pw_presubmit.color_bold_white(line)
    if line.startswith('-'):
        return pw_presubmit.color_red(line)
    if line.startswith('+'):
        return pw_presubmit.color_green(line)
    if line.startswith('@@ '):
        return pw_presubmit.color_aqua(line)
    return line


def colorize_diff(lines: Iterable[str]) -> str:
    """Takes a diff str or list of str lines and returns a colorized version."""
    if isinstance(lines, str):
        # Keep the line endings so that joining reproduces the input text.
        lines = lines.splitlines(True)

    return ''.join(_colorize_diff_line(line) for line in lines)


def _diff(path, original: bytes, formatted: bytes) -> str:
    """Returns a colorized unified diff between two versions of a file.

    Both inputs are decoded with errors='replace', so bytes that cannot be
    decoded do not raise.
    """
    return colorize_diff(
        difflib.unified_diff(
            original.decode(errors='replace').splitlines(True),
            formatted.decode(errors='replace').splitlines(True),
            f'{path} (original)', f'{path} (reformatted)'))


# A formatter takes a file path and that file's contents and returns the
# formatted contents.
Formatter = Callable[[str, bytes], bytes]


def _diff_formatted(path, formatter: Formatter) -> Optional[str]:
    """Returns a diff comparing a file to its formatted version.

    Returns None if the formatter's output is identical to the file.
    """
    with open(path, 'rb') as fd:
        original = fd.read()

    formatted = formatter(path, original)

    return None if formatted == original else _diff(path, original, formatted)


def _check_files(files, formatter: Formatter) -> Dict[Path, str]:
    """Runs a formatter on each file; returns {path: diff} for changed files."""
    errors = {}

    for path in files:
        difference = _diff_formatted(path, formatter)
        if difference:
            errors[path] = difference

    return errors


def _clang_format(*args: str, **kwargs) -> bytes:
    """Runs clang-format with --style=file and returns its stdout.

    check=True is passed to log_run, so a failing clang-format is expected to
    raise rather than return.
    """
    return log_run(['clang-format', '--style=file', *args],
                   stdout=subprocess.PIPE,
                   check=True,
                   **kwargs).stdout


def clang_format_check(files: Iterable[Path]) -> Dict[Path, str]:
    """Checks formatting; returns {path: diff} for files with bad formatting."""
    return _check_files(files, lambda path, _: _clang_format(path))


def clang_format_fix(files: Iterable) -> Dict[Path, str]:
    """Fixes formatting for the provided files in place."""
    _clang_format('-i', *files)
    return {}


def check_gn_format(files: Iterable[Path]) -> Dict[Path, str]:
    """Checks formatting; returns {path: diff} for files with bad formatting."""
    # The file contents are piped through 'gn format --stdin'; the path itself
    # is not passed to gn.
    return _check_files(
        files, lambda _, data: log_run(['gn', 'format', '--stdin'],
                                       input=data,
                                       stdout=subprocess.PIPE,
                                       check=True).stdout)


def fix_gn_format(files: Iterable[Path]) -> Dict[Path, str]:
    """Fixes formatting for the provided files in place."""
    log_run(['gn', 'format', *files], check=True)
    return {}


def check_bazel_format(files: Iterable[Path]) -> Dict[Path, str]:
    """Checks formatting; returns {path: diff} for files with bad formatting.

    Files for which buildifier exits with a nonzero code are reported with
    buildifier's stderr in place of a diff.
    """
    errors: Dict[Path, str] = {}

    def _format_temp(path: Union[Path, str], data: bytes) -> bytes:
        # buildifier doesn't have an option to output the changed file, so
        # copy the file to a temp location, run buildifier on it, read that
        # modified copy, and return its contents.
        with tempfile.TemporaryDirectory() as temp:
            build = Path(temp) / os.path.basename(path)
            build.write_bytes(data)

            proc = log_run(['buildifier', build], capture_output=True)
            if proc.returncode:
                stderr = proc.stderr.decode(errors='replace')
                # Report the real path rather than the temporary copy.
                stderr = stderr.replace(str(build), str(path))
                errors[Path(path)] = stderr
            return build.read_bytes()

    result = _check_files(files, _format_temp)
    # Tool failures take precedence over any diff recorded for the same path.
    result.update(errors)
    return result


def fix_bazel_format(files: Iterable[Path]) -> Dict[Path, str]:
    """Fixes formatting for the provided files in place.

    Returns {path: stderr} for files on which buildifier exited with a
    nonzero code.
    """
    errors = {}
    for path in files:
        proc = log_run(['buildifier', path], capture_output=True)
        if proc.returncode:
            # Decode leniently, as check_bazel_format does, so that unusual
            # tool output cannot raise UnicodeDecodeError here.
            errors[path] = proc.stderr.decode(errors='replace')
    return errors


def check_go_format(files: Iterable[Path]) -> Dict[Path, str]:
    """Checks formatting; returns {path: diff} for files with bad formatting."""
    return _check_files(
        files, lambda path, _: log_run(
            ['gofmt', path], stdout=subprocess.PIPE, check=True).stdout)


def fix_go_format(files: Iterable[Path]) -> Dict[Path, str]:
    """Fixes formatting for the provided files in place."""
    log_run(['gofmt', '-w', *files], check=True)
    return {}


def _yapf(*args, **kwargs) -> subprocess.CompletedProcess:
    """Runs 'python -m yapf --parallel' with the given args, capturing output."""
    return log_run(['python', '-m', 'yapf', '--parallel', *args],
                   capture_output=True,
                   **kwargs)


# Matches the '--- <path> (original)' header that starts each per-file diff
# in the output of 'yapf --diff'; group 1 is the path.
_DIFF_START = re.compile(r'^--- (.*)\s+\(original\)$', flags=re.MULTILINE)


def check_py_format(files: Iterable[Path]) -> Dict[Path, str]:
    """Checks formatting; returns {path: diff} for files with bad formatting.

    If yapf writes anything to stderr, every file that does not already have
    a diff is reported with an empty diff.
    """
    # The files are iterated twice (the yapf command line and the stderr
    # handling below), so materialize one-shot iterables such as generators.
    files = list(files)

    process = _yapf('--diff', *files)

    errors: Dict[Path, str] = {}

    if process.stdout:
        raw_diff = process.stdout.decode(errors='replace')

        # Split the combined output into per-file diffs: each diff runs from
        # its own header to the start of the next header (or end of output).
        matches = tuple(_DIFF_START.finditer(raw_diff))
        for start, end in zip(matches, (*matches[1:], None)):
            errors[Path(start.group(1))] = colorize_diff(
                raw_diff[start.start():end.start() if end else None])

    if process.stderr:
        _LOG.error('yapf encountered an error:\n%s',
                   process.stderr.decode(errors='replace').rstrip())
        errors.update({file: '' for file in files if file not in errors})

    return errors


def fix_py_format(files: Iterable) -> Dict[Path, str]:
    """Fixes formatting for the provided files in place."""
    _yapf('--in-place', *files, check=True)
    return {}


# Spaces and tabs immediately before the end of a line (or of the data).
_TRAILING_SPACE = re.compile(rb'[ \t]+$', flags=re.MULTILINE)


def _check_trailing_space(paths: Iterable[Path], fix: bool) -> Dict[Path, str]:
    """Checks for and optionally removes trailing whitespace.

    Returns {path: diff} for files that contain trailing whitespace. When fix
    is True, those files are also rewritten without it.
    """
    errors = {}

    for path in paths:
        with path.open('rb') as fd:
            contents = fd.read()

        corrected = _TRAILING_SPACE.sub(b'', contents)
        if corrected != contents:
            errors[path] = _diff(path, contents, corrected)

            if fix:
                with path.open('wb') as fd:
                    fd.write(corrected)

    return errors


def check_trailing_space(files: Iterable[Path]) -> Dict[Path, str]:
    """Returns {path: diff} for files with trailing whitespace."""
    return _check_trailing_space(files, fix=False)


def fix_trailing_space(files: Iterable[Path]) -> Dict[Path, str]:
    """Removes trailing whitespace from the provided files in place."""
    _check_trailing_space(files, fix=True)
    return {}


def print_format_check(errors: Dict[Path, str],
                       show_fix_commands: bool) -> None:
    """Prints the result of a check_*_format function.

    Args:
      errors: {path: diff} as returned by a check function.
      show_fix_commands: whether to log a 'pw format --fix' command for each
        file with errors.
    """
    if not errors:
        # Don't print anything in the all-good case.
        return

    # Show the format fixing diff suggested by the tooling (with colors).
    _LOG.warning('Found %d files with formatting errors. Format changes:',
                 len(errors))
    for diff in errors.values():
        print(diff, end='')

    # Show a copy-and-pastable command to fix the issues.
    if show_fix_commands:

        def path_relative_to_cwd(path):
            try:
                return Path(path).resolve().relative_to(Path.cwd().resolve())
            except ValueError:
                # The file is not under the current directory; fall back to
                # its absolute path.
                return Path(path).resolve()

        message = (f'  pw format --fix {path_relative_to_cwd(path)}'
                   for path in errors)
        _LOG.warning('To fix formatting, run:\n\n%s\n', '\n'.join(message))


class CodeFormat(NamedTuple):
    """Describes how to check and fix formatting for one kind of file."""
    # Human-readable name; also used to derive presubmit check names.
    language: str
    # Path suffixes that select files of this kind.
    extensions: Collection[str]
    # Regular expressions for paths that should not be formatted.
    exclude: Collection[str]
    # Returns {path: diff} for files with incorrect formatting.
    check: Callable[[Iterable], Dict[Path, str]]
    # Formats files in place; returns {path: error} for failures.
    fix: Callable[[Iterable], Dict[Path, str]]


CPP_HEADER_EXTS = frozenset(
    ('.h', '.hpp', '.hxx', '.h++', '.hh', '.H', '.inc', '.inl'))
CPP_SOURCE_EXTS = frozenset(('.c', '.cpp', '.cxx', '.c++', '.cc', '.C'))
CPP_EXTS = CPP_HEADER_EXTS.union(CPP_SOURCE_EXTS)

C_FORMAT: CodeFormat = CodeFormat('C and C++', CPP_EXTS,
                                  (r'\.pb\.h$', r'\.pb\.c$'),
                                  clang_format_check, clang_format_fix)

PROTO_FORMAT: CodeFormat = CodeFormat('Protocol buffer', ('.proto', ), (),
                                      clang_format_check, clang_format_fix)

JAVA_FORMAT: CodeFormat = CodeFormat('Java', ('.java', ), (),
                                     clang_format_check, clang_format_fix)

JAVASCRIPT_FORMAT: CodeFormat = CodeFormat('JavaScript', ('.js', ), (),
                                           clang_format_check,
                                           clang_format_fix)

GO_FORMAT: CodeFormat = CodeFormat('Go', ('.go', ), (), check_go_format,
                                   fix_go_format)

PYTHON_FORMAT: CodeFormat = CodeFormat('Python', ('.py', ), (),
                                       check_py_format, fix_py_format)

GN_FORMAT: CodeFormat = CodeFormat('GN', ('.gn', '.gni'), (), check_gn_format,
                                   fix_gn_format)

# TODO(pwbug/191): Add real code formatting support for Bazel and CMake
BAZEL_FORMAT: CodeFormat = CodeFormat('Bazel', ('BUILD', '.bazel', '.bzl'), (),
                                      check_bazel_format, fix_bazel_format)

CMAKE_FORMAT: CodeFormat = CodeFormat('CMake', ('CMakeLists.txt', '.cmake'),
                                      (), check_trailing_space,
                                      fix_trailing_space)

RST_FORMAT: CodeFormat = CodeFormat('reStructuredText', ('.rst', ), (),
                                    check_trailing_space, fix_trailing_space)

MARKDOWN_FORMAT: CodeFormat = CodeFormat('Markdown', ('.md', ), (),
                                         check_trailing_space,
                                         fix_trailing_space)

CODE_FORMATS: Tuple[CodeFormat, ...] = (
    C_FORMAT,
    JAVA_FORMAT,
    JAVASCRIPT_FORMAT,
    PROTO_FORMAT,
    GO_FORMAT,
    PYTHON_FORMAT,
    GN_FORMAT,
    BAZEL_FORMAT,
    CMAKE_FORMAT,
    RST_FORMAT,
    MARKDOWN_FORMAT,
)


def presubmit_check(code_format: CodeFormat, **filter_paths_args) -> Callable:
    """Creates a presubmit check function from a CodeFormat object."""
    # Caller-provided filters win; otherwise select files using the format's
    # own extensions and exclude patterns.
    filter_paths_args.setdefault('endswith', code_format.extensions)
    filter_paths_args.setdefault('exclude', code_format.exclude)

    @pw_presubmit.filter_paths(**filter_paths_args)
    def check_code_format(ctx: pw_presubmit.PresubmitContext):
        errors = code_format.check(ctx.paths)
        print_format_check(
            errors,
            # When running as part of presubmit, show the fix command help.
            show_fix_commands=True,
        )
        if errors:
            raise pw_presubmit.PresubmitFailure

    # Derive an identifier-like name, e.g. 'C and C++' -> 'c_and_cpp_format'.
    language = code_format.language.lower().replace('+', 'p').replace(' ', '_')
    check_code_format.__name__ = f'{language}_format'

    return check_code_format


def presubmit_checks(**filter_paths_args) -> Tuple[Callable, ...]:
    """Returns a tuple with all supported code format presubmit checks."""
    return tuple(
        presubmit_check(fmt, **filter_paths_args) for fmt in CODE_FORMATS)


def _format_applies(code_format: CodeFormat, path: Path) -> bool:
    """Returns True if the path should be handled by the given CodeFormat.

    A path applies when it ends with one of the format's extensions and none
    of the format's exclude patterns is found in it.
    """
    posix_path = path.as_posix()
    if not posix_path.endswith(tuple(code_format.extensions)):
        return False

    # NOTE(review): the patterns are applied with re.search, presumably as
    # pw_presubmit.filter_paths does for the presubmit checks -- confirm.
    return not any(
        re.search(pattern, posix_path) for pattern in code_format.exclude)


class CodeFormatter:
    """Checks or fixes the formatting of a set of files."""
    def __init__(self, files: Iterable[Path]):
        self.paths = list(files)
        self._formats: Dict[CodeFormat, List] = collections.defaultdict(list)

        # Group the files by format. Only formats with at least one file get
        # an entry, so check/fix functions never run with an empty file list.
        # Excluded paths (e.g. generated .pb.h files) are skipped so this
        # agrees with the presubmit checks made by presubmit_check().
        for path in self.paths:
            for code_format in CODE_FORMATS:
                if _format_applies(code_format, path):
                    self._formats[code_format].append(path)

    def check(self) -> Dict[Path, str]:
        """Returns {path: diff} for files with incorrect formatting."""
        errors: Dict[Path, str] = {}

        for code_format, files in self._formats.items():
            _LOG.debug('Checking %s', ', '.join(str(f) for f in files))
            errors.update(code_format.check(files))

        # Sort by path so that results are reported in a stable order.
        return collections.OrderedDict(sorted(errors.items()))

    def fix(self) -> Dict[Path, str]:
        """Fixes format errors for supported files in place.

        Returns {path: error} for files that could not be formatted.
        """
        all_errors: Dict[Path, str] = {}
        for code_format, files in self._formats.items():
            errors = code_format.fix(files)
            if errors:
                for path, error in errors.items():
                    _LOG.error('Failed to format %s', path)
                    for line in error.splitlines():
                        _LOG.error('%s', line)
                all_errors.update(errors)
                continue

            _LOG.info('Formatted %s',
                      plural(files, code_format.language + ' file'))
        return all_errors


def _file_summary(files: Iterable[Union[Path, str]], base: Path) -> List[str]:
    """Summarizes the files relative to base; returns [] if any is outside."""
    try:
        return file_summary(
            Path(f).resolve().relative_to(base.resolve()) for f in files)
    except ValueError:
        # relative_to() raises ValueError for paths not under base.
        return []


def format_paths_in_repo(paths: Collection[Union[Path, str]],
                         exclude: Collection[Pattern[str]], fix: bool,
                         base: str) -> int:
    """Checks or fixes formatting for files in a Git repo.

    Args:
      paths: files or directories to format; entries that are files are
        always included.
      exclude: patterns for Git-listed paths to skip.
      fix: apply fixes in place instead of only checking.
      base: Git revision to diff against when listing files; only allowed
        when running from a Git repo.

    Returns:
      0 on success, 1 if formatting errors remain or the arguments are
      invalid.
    """
    files = [Path(path).resolve() for path in paths if os.path.isfile(path)]
    repo = git_repo.root() if git_repo.is_repo() else None

    # Implement a graceful fallback in case the tracking branch isn't available.
    if (base == git_repo.TRACKING_BRANCH_ALIAS
            and not git_repo.tracking_branch(repo)):
        _LOG.warning(
            'Failed to determine the tracking branch, using --base HEAD~1 '
            'instead of listing all files')
        base = 'HEAD~1'

    # If this is a Git repo, list the original paths with git ls-files or diff.
    if repo:
        project_root = Path(pw_cli.env.pigweed_environment().PW_PROJECT_ROOT)
        _LOG.info(
            'Formatting %s',
            git_repo.describe_files(repo, Path.cwd(), base, paths, exclude,
                                    project_root))

        # Add files from Git and remove duplicates.
        files = sorted(
            set(exclude_paths(exclude, git_repo.list_files(base, paths)))
            | set(files))
    elif base:
        _LOG.critical(
            'A base commit may only be provided if running from a Git repo')
        return 1

    return format_files(files, fix, repo=repo)


def format_files(paths: Collection[Union[Path, str]],
                 fix: bool,
                 repo: Optional[Path] = None) -> int:
    """Checks or fixes formatting for the specified files.

    Args:
      paths: files to check.
      fix: apply fixes in place instead of only reporting errors.
      repo: directory that the printed file summary is relative to; the
        current directory is used if not provided.

    Returns:
      0 if no changes were needed or all fixes were applied, 1 otherwise.
    """
    formatter = CodeFormatter(Path(p) for p in paths)

    _LOG.info('Checking formatting for %s', plural(formatter.paths, 'file'))

    for line in _file_summary(paths, repo if repo else Path.cwd()):
        print(line, file=sys.stderr)

    check_errors = formatter.check()
    # The fix commands are only useful when fixes are not about to be applied.
    print_format_check(check_errors, show_fix_commands=(not fix))

    if check_errors:
        if fix:
            _LOG.info('Applying formatting fixes to %d files',
                      len(check_errors))
            fix_errors = formatter.fix()
            if fix_errors:
                # This path returns a failing exit code, so log it as an error.
                _LOG.error('Failed to apply formatting fixes')
                print_format_check(fix_errors, show_fix_commands=False)
                return 1

            _LOG.info('Formatting fixes applied successfully')
            return 0

        _LOG.error('Formatting errors found')
        return 1

    _LOG.info('Congratulations! No formatting changes needed')
    return 0


def arguments(git_paths: bool) -> argparse.ArgumentParser:
    """Creates an argument parser for format_files or format_paths_in_repo.

    Args:
      git_paths: if True, add the shared path arguments from
        cli.add_path_arguments (for format_paths_in_repo); otherwise add a
        plain list of existing file paths (for format_files).
    """

    parser = argparse.ArgumentParser(description=__doc__)

    if git_paths:
        cli.add_path_arguments(parser)
    else:

        def existing_path(arg: str) -> Path:
            path = Path(arg)
            if not path.is_file():
                raise argparse.ArgumentTypeError(
                    f'{arg} is not a path to a file')

            return path

        parser.add_argument('paths',
                            metavar='path',
                            nargs='+',
                            type=existing_path,
                            help='File paths to check')

    parser.add_argument('--fix',
                        action='store_true',
                        help='Apply formatting fixes in place.')
    return parser


def main() -> int:
    """Check and fix formatting for source files."""
    return format_paths_in_repo(**vars(arguments(git_paths=True).parse_args()))


if __name__ == '__main__':
    try:
        # If pw_cli is available, use it to initialize logs.
        from pw_cli import log

        log.install(logging.INFO)
    except ImportError:
        # If pw_cli isn't available, display log messages like a simple print.
        logging.basicConfig(format='%(message)s', level=logging.INFO)

    sys.exit(main())