#!/usr/bin/env python3

# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Checks and fixes formatting for source files.

This uses clang-format, gn format, gofmt, buildifier, and black or yapf (for
Python) to format source code. These tools must be available on the path when
this script is invoked!
"""

import argparse
import collections
import difflib
import logging
import os
from pathlib import Path
import re
import subprocess
import sys
import tempfile
from typing import (
    Callable,
    Collection,
    Dict,
    Iterable,
    List,
    NamedTuple,
    Optional,
    Pattern,
    Sequence,
    TextIO,
    Tuple,
    Union,
)

try:
    import pw_presubmit
except ImportError:
    # Append the pw_presubmit package path to the module search path to allow
    # running this module without installing the pw_presubmit package.
    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    import pw_presubmit

import pw_cli.color
import pw_cli.env
from pw_presubmit.presubmit import FileFilter
from pw_presubmit import (
    cli,
    FormatContext,
    FormatOptions,
    git_repo,
    owners_checks,
    PresubmitContext,
)
from pw_presubmit.tools import exclude_paths, file_summary, log_run, plural

_LOG: logging.Logger = logging.getLogger(__name__)
_COLOR = pw_cli.color.colors()
_DEFAULT_PATH = Path('out', 'format')

_Context = Union[PresubmitContext, FormatContext]


def _colorize_diff_line(line: str) -> str:
    """Wraps a single unified-diff line in the color matching its role."""
    # File headers are tested first so that '--- ' / '+++ ' lines are not
    # colored as ordinary removed/added lines.
    if line.startswith(('--- ', '+++ ')):
        return _COLOR.bold_white(line)
    if line.startswith('-'):
        return _COLOR.red(line)
    if line.startswith('+'):
        return _COLOR.green(line)
    if line.startswith('@@ '):
        return _COLOR.cyan(line)
    return line


def colorize_diff(lines: Union[str, Iterable[str]]) -> str:
    """Takes a diff str or list of str lines and returns a colorized version.

    Args:
        lines: Either the whole diff as one string, or an iterable of its
            lines (each keeping its line ending, if any).

    Returns:
        The colorized lines concatenated back into a single string.
    """
    if isinstance(lines, str):
        lines = lines.splitlines(True)

    return ''.join(_colorize_diff_line(line) for line in lines)


def _diff(path, original: bytes, formatted: bytes) -> str:
    """Returns a unified diff between two versions of a file's contents.

    Undecodable bytes are replaced rather than raising, so a diff can always
    be produced. ``path`` is only used to label the diff headers.
    """
    return ''.join(
        difflib.unified_diff(
            original.decode(errors='replace').splitlines(True),
            formatted.decode(errors='replace').splitlines(True),
            f'{path} (original)',
            f'{path} (reformatted)',
        )
    )


# A formatter takes (path, original file contents) and returns the contents the
# file should have after formatting.
Formatter = Callable[[str, bytes], bytes]


def _diff_formatted(path, formatter: Formatter) -> Optional[str]:
    """Returns a diff comparing a file to its formatted version.

    Returns None when the formatter leaves the contents unchanged.
    """
    with open(path, 'rb') as fd:
        original = fd.read()

    formatted = formatter(path, original)

    return None if formatted == original else _diff(path, original, formatted)


def _check_files(files, formatter: Formatter) -> Dict[Path, str]:
    """Runs ``formatter`` on each file; returns {path: diff} for changed ones."""
    errors: Dict[Path, str] = {}

    for path in files:
        difference = _diff_formatted(path, formatter)
        if difference:
            errors[path] = difference

    return errors


def _clang_format(*args: Union[Path, str], **kwargs) -> bytes:
    """Runs ``clang-format --style=file`` with ``args``; returns its stdout.

    Runs with check=True, so a non-zero exit is expected to raise.
    """
    return log_run(
        ['clang-format', '--style=file', *args],
        stdout=subprocess.PIPE,
        check=True,
        **kwargs,
    ).stdout


def clang_format_check(ctx: _Context) -> Dict[Path, str]:
    """Checks formatting; returns {path: diff} for files with bad formatting."""
    # clang-format reads the file itself, so the provided contents are unused.
    return _check_files(ctx.paths, lambda path, _: _clang_format(path))


def clang_format_fix(ctx: _Context) -> Dict[Path, str]:
    """Fixes formatting for the provided files in place."""
    _clang_format('-i', *ctx.paths)
    return {}


def check_gn_format(ctx: _Context) -> Dict[Path, str]:
    """Checks formatting; returns {path: diff} for files with bad formatting."""
    # The file contents are piped through `gn format --stdin`.
    return _check_files(
        ctx.paths,
        lambda _, data: log_run(
            ['gn', 'format', '--stdin'],
            input=data,
            stdout=subprocess.PIPE,
            check=True,
        ).stdout,
    )


def fix_gn_format(ctx: _Context) -> Dict[Path, str]:
    """Fixes formatting for the provided files in place."""
    log_run(['gn', 'format', *ctx.paths], check=True)
    return {}


def check_bazel_format(ctx: _Context) -> Dict[Path, str]:
    """Checks formatting; returns {path: diff} for files with bad formatting.

    Files buildifier fails on are reported as {path: buildifier stderr}
    instead of a diff.
    """
    errors: Dict[Path, str] = {}

    def _format_temp(path: Union[Path, str], data: bytes) -> bytes:
        # buildifier doesn't have an option to output the changed file, so
        # copy the file to a temp location, run buildifier on it, read that
        # modified copy, and return its contents. The copy keeps the original
        # basename (presumably so buildifier can tell BUILD from .bzl files).
        with tempfile.TemporaryDirectory(dir=ctx.output_dir) as temp:
            build = Path(temp) / os.path.basename(path)
            build.write_bytes(data)

            proc = log_run(['buildifier', build], capture_output=True)
            if proc.returncode:
                stderr = proc.stderr.decode(errors='replace')
                # Point the error at the real file rather than the temp copy.
                stderr = stderr.replace(str(build), str(path))
                errors[Path(path)] = stderr
            return build.read_bytes()

    result = _check_files(ctx.paths, _format_temp)
    result.update(errors)
    return result


def fix_bazel_format(ctx: _Context) -> Dict[Path, str]:
    """Fixes formatting for the provided files in place.

    Returns:
        {path: buildifier stderr} for each file buildifier failed on.
    """
    errors: Dict[Path, str] = {}
    for path in ctx.paths:
        proc = log_run(['buildifier', path], capture_output=True)
        if proc.returncode:
            # Like check_bazel_format, never fail on undecodable tool output.
            errors[path] = proc.stderr.decode(errors='replace')
    return errors


def check_owners_format(ctx: _Context) -> Dict[Path, str]:
    """Checks OWNERS files by delegating to owners_checks.run_owners_checks."""
    return owners_checks.run_owners_checks(ctx.paths)


def fix_owners_format(ctx: _Context) -> Dict[Path, str]:
    """Formats OWNERS files by delegating to owners_checks.format_owners_file."""
    return owners_checks.format_owners_file(ctx.paths)


def check_go_format(ctx: _Context) -> Dict[Path, str]:
    """Checks formatting; returns {path: diff} for files with bad formatting."""
    return _check_files(
        ctx.paths,
        lambda path, _: log_run(
            ['gofmt', path], stdout=subprocess.PIPE, check=True
        ).stdout,
    )


def fix_go_format(ctx: _Context) -> Dict[Path, str]:
    """Fixes formatting for the provided files in place."""
    log_run(['gofmt', '-w', *ctx.paths], check=True)
    return {}


# TODO(b/259595799) Remove yapf support.
def _yapf(*args, **kwargs) -> subprocess.CompletedProcess:
    """Runs ``python -m yapf --parallel`` with ``args``, capturing output.

    Extra keyword arguments are forwarded to log_run.
    """
    return log_run(
        ['python', '-m', 'yapf', '--parallel', *args],
        capture_output=True,
        **kwargs,
    )


# Matches the "--- <path> (original)" header that starts each file's section of
# yapf's --diff output; group 1 is the path.
_DIFF_START = re.compile(r'^--- (.*)\s+\(original\)$', flags=re.MULTILINE)


def check_py_format_yapf(ctx: _Context) -> Dict[Path, str]:
    """Checks formatting; returns {path: diff} for files with bad formatting.

    If yapf writes anything to stderr, every file without a diff is also
    reported (with an empty diff) so that the check fails.
    """
    process = _yapf('--diff', *ctx.paths)

    errors: Dict[Path, str] = {}

    if process.stdout:
        raw_diff = process.stdout.decode(errors='replace')

        # Split the combined diff into per-file chunks: each chunk runs from
        # one file header up to the next header, or to the end of the output.
        matches = tuple(_DIFF_START.finditer(raw_diff))
        for start, end in zip(matches, (*matches[1:], None)):
            errors[Path(start.group(1))] = raw_diff[
                start.start() : end.start() if end else None
            ]

    if process.stderr:
        _LOG.error(
            'yapf encountered an error:\n%s',
            process.stderr.decode(errors='replace').rstrip(),
        )
        errors.update({file: '' for file in ctx.paths if file not in errors})

    return errors


def fix_py_format_yapf(ctx: _Context) -> Dict[Path, str]:
    """Fixes formatting for the provided files in place."""
    _yapf('--in-place', *ctx.paths, check=True)
    return {}


def _enumerate_black_configs() -> Iterable[Path]:
    """Yields candidate black config files in priority order.

    Candidates are .black.toml then pyproject.toml, first under
    $PW_PROJECT_ROOT and then under $PW_ROOT; unset variables are skipped.
    """
    if directory := os.environ.get('PW_PROJECT_ROOT'):
        yield Path(directory, '.black.toml')
        yield Path(directory, 'pyproject.toml')

    if directory := os.environ.get('PW_ROOT'):
        yield Path(directory, '.black.toml')
        yield Path(directory, 'pyproject.toml')


def _black_config_args() -> Sequence[Union[str, Path]]:
    """Returns ('--config', path) for the first existing config file, else ()."""
    config = next(
        (path for path in _enumerate_black_configs() if path.is_file()), None
    )
    if config is None:
        return ()
    return ('--config', config)


def _black_multiple_files(ctx: _Context) -> Tuple[str, ...]:
    """Returns the paths that `black --check` reports it would reformat.

    All of ctx.paths are checked in a single black invocation so that the
    per-file formatting pass only has to run on the files black flagged.
    """
    black = ctx.format_options.black_path
    proc = log_run(
        [black, '--check', *_black_config_args(), *ctx.paths],
        capture_output=True,
    )

    changed_paths: List[str] = []
    # The "would reformat <path>" lines are read from black's stderr.
    for line in proc.stderr.decode(errors='replace').splitlines():
        # Non-greedy capture so trailing whitespace is not part of the path.
        if match := re.search(r'^would reformat (.*?)\s*$', line):
            changed_paths.append(match.group(1))
    return tuple(changed_paths)


def check_py_format_black(ctx: _Context) -> Dict[Path, str]:
    """Checks formatting; returns {path: diff} for files with bad formatting.

    Files black fails on are reported as {path: black stderr} instead of a
    diff.
    """
    errors: Dict[Path, str] = {}

    # Run black --check on the full list of paths and then only run black
    # individually on the files that black found issue with.
    paths: Tuple[str, ...] = _black_multiple_files(ctx)

    def _format_temp(path: Union[Path, str], data: bytes) -> bytes:
        # Format a temporary copy so the original file is left untouched, then
        # read that modified copy back and return its contents.
        with tempfile.TemporaryDirectory(dir=ctx.output_dir) as temp:
            build = Path(temp) / os.path.basename(path)
            build.write_bytes(data)

            proc = log_run(
                [ctx.format_options.black_path, *_black_config_args(), build],
                capture_output=True,
            )
            if proc.returncode:
                stderr = proc.stderr.decode(errors='replace')
                # Point the error at the real file rather than the temp copy.
                stderr = stderr.replace(str(build), str(path))
                errors[Path(path)] = stderr
            return build.read_bytes()

    # str.endswith() with an empty tuple is always False, so nothing is
    # re-checked when `black --check` flagged no files.
    result = _check_files(
        [x for x in ctx.paths if str(x).endswith(paths)],
        _format_temp,
    )
    result.update(errors)
    return result


def fix_py_format_black(ctx: _Context) -> Dict[Path, str]:
    """Fixes formatting for the provided files in place.

    Returns:
        {path: black stderr} for each file black failed on.
    """
    errors: Dict[Path, str] = {}

    # Run black --check on the full list of paths and then only run black
    # individually on the files that black found issue with.
    paths: Tuple[str, ...] = _black_multiple_files(ctx)

    for path in ctx.paths:
        if not str(path).endswith(paths):
            continue

        proc = log_run(
            [ctx.format_options.black_path, *_black_config_args(), path],
            capture_output=True,
        )
        if proc.returncode:
            errors[path] = proc.stderr.decode(errors='replace')
    return errors


def check_py_format(ctx: _Context) -> Dict[Path, str]:
    """Checks Python formatting with the configured formatter (black or yapf).

    Raises:
        ValueError: if ctx.format_options.python_formatter is not recognized.
    """
    if ctx.format_options.python_formatter == 'black':
        return check_py_format_black(ctx)
    if ctx.format_options.python_formatter == 'yapf':
        return check_py_format_yapf(ctx)
    raise ValueError(ctx.format_options.python_formatter)


def fix_py_format(ctx: _Context) -> Dict[Path, str]:
    """Fixes Python formatting with the configured formatter (black or yapf).

    Raises:
        ValueError: if ctx.format_options.python_formatter is not recognized.
    """
    if ctx.format_options.python_formatter == 'black':
        return fix_py_format_black(ctx)
    if ctx.format_options.python_formatter == 'yapf':
        return fix_py_format_yapf(ctx)
    raise ValueError(ctx.format_options.python_formatter)


# Spaces or tabs immediately before a line ending (or the end of the file).
_TRAILING_SPACE = re.compile(rb'[ \t]+$', flags=re.MULTILINE)


def _check_trailing_space(paths: Iterable[Path], fix: bool) -> Dict[Path, str]:
    """Checks for and optionally removes trailing whitespace.

    Args:
        paths: Files to inspect.
        fix: If True, rewrite offending files in place without the trailing
            spaces/tabs.

    Returns:
        {path: diff} for every file that contained trailing whitespace.
    """
    errors = {}

    for path in paths:
        contents = path.read_bytes()
        corrected = _TRAILING_SPACE.sub(b'', contents)
        if corrected == contents:
            continue

        errors[path] = _diff(path, contents, corrected)
        if fix:
            path.write_bytes(corrected)

    return errors


def check_trailing_space(ctx: _Context) -> Dict[Path, str]:
    """Returns {path: diff} for files that contain trailing whitespace."""
    return _check_trailing_space(ctx.paths, fix=False)


def fix_trailing_space(ctx: _Context) -> Dict[Path, str]:
    """Removes trailing whitespace from the provided files in place."""
    _check_trailing_space(ctx.paths, fix=True)
    return {}


def print_format_check(
    errors: Dict[Path, str],
    show_fix_commands: bool,
    show_summary: bool = True,
    colors: Optional[bool] = None,
    file: TextIO = sys.stdout,
) -> None:
    """Prints the result of a check_*_format function.

    Args:
        errors: {path: diff} as returned by a check function. Nothing is
            printed when this is empty.
        show_fix_commands: Log a copy-and-pastable `pw format --fix` command
            for each file.
        show_summary: Log a warning with the number of files with errors.
        colors: Whether to colorize the diffs; defaults to True only when
            writing to sys.stdout.
        file: Stream the diffs are printed to.
    """
    if not errors:
        # Don't print anything in the all-good case.
        return

    if colors is None:
        colors = file == sys.stdout

    # Show the format fixing diff suggested by the tooling (with colors).
    if show_summary:
        _LOG.warning(
            'Found %d files with formatting errors. Format changes:',
            len(errors),
        )
    for diff in errors.values():
        if colors:
            diff = colorize_diff(diff)
        print(diff, end='', file=file)

    # Show a copy-and-pastable command to fix the issues.
    if show_fix_commands:

        def path_relative_to_cwd(path: Path):
            try:
                return Path(path).resolve().relative_to(Path.cwd().resolve())
            except ValueError:
                # Paths outside the current directory are shown absolute.
                return Path(path).resolve()

        message = (
            f'  pw format --fix {path_relative_to_cwd(path)}' for path in errors
        )
        _LOG.warning('To fix formatting, run:\n\n%s\n', '\n'.join(message))


class CodeFormat(NamedTuple):
    """Describes how to check and fix one kind of source file."""

    # Human-readable name; also used to derive presubmit check names and the
    # per-language output directory name.
    language: str
    # Selects the files this format applies to.
    filter: FileFilter
    # Returns {path: diff or error text} for files needing changes.
    check: Callable[[_Context], Dict[Path, str]]
    # Fixes files in place; returns {path: error text} for failures.
    fix: Callable[[_Context], Dict[Path, str]]

    @property
    def extensions(self):
        # TODO(b/23842636): Switch calls of this to using 'filter' and remove.
        return self.filter.endswith


CPP_HEADER_EXTS = frozenset(('.h', '.hpp', '.hxx', '.h++', '.hh', '.H'))
CPP_SOURCE_EXTS = frozenset(
    ('.c', '.cpp', '.cxx', '.c++', '.cc', '.C', '.inc', '.inl')
)
CPP_EXTS = CPP_HEADER_EXTS.union(CPP_SOURCE_EXTS)
# Generated protobuf sources (.pb.h / .pb.c) are excluded from formatting.
CPP_FILE_FILTER = FileFilter(
    endswith=CPP_EXTS, exclude=(r'\.pb\.h$', r'\.pb\.c$')
)

C_FORMAT = CodeFormat(
    'C and C++', CPP_FILE_FILTER, clang_format_check, clang_format_fix
)

PROTO_FORMAT: CodeFormat = CodeFormat(
    'Protocol buffer',
    FileFilter(endswith=('.proto',)),
    clang_format_check,
    clang_format_fix,
)

JAVA_FORMAT: CodeFormat = CodeFormat(
    'Java',
    FileFilter(endswith=('.java',)),
    clang_format_check,
    clang_format_fix,
)

JAVASCRIPT_FORMAT: CodeFormat = CodeFormat(
    'JavaScript',
    FileFilter(endswith=('.js',)),
    clang_format_check,
    clang_format_fix,
)

GO_FORMAT: CodeFormat = CodeFormat(
    'Go', FileFilter(endswith=('.go',)), check_go_format, fix_go_format
)

PYTHON_FORMAT: CodeFormat = CodeFormat(
    'Python',
    FileFilter(endswith=('.py',)),
    check_py_format,
    fix_py_format,
)

GN_FORMAT: CodeFormat = CodeFormat(
    'GN', FileFilter(endswith=('.gn', '.gni')), check_gn_format, fix_gn_format
)

BAZEL_FORMAT: CodeFormat = CodeFormat(
    'Bazel',
    # The trailing comma matters: ('WORKSPACE') is just a str, which would be
    # treated as a collection of single-character names.
    FileFilter(endswith=('BUILD', '.bazel', '.bzl'), name=('WORKSPACE',)),
    check_bazel_format,
    fix_bazel_format,
)

COPYBARA_FORMAT: CodeFormat = CodeFormat(
    'Copybara',
    FileFilter(endswith=('.bara.sky',)),
    check_bazel_format,
    fix_bazel_format,
)

# TODO(b/234881054): Add real code formatting support for CMake
CMAKE_FORMAT: CodeFormat = CodeFormat(
    'CMake',
    FileFilter(endswith=('CMakeLists.txt', '.cmake')),
    check_trailing_space,
    fix_trailing_space,
)

RST_FORMAT: CodeFormat = CodeFormat(
    'reStructuredText',
    FileFilter(endswith=('.rst',)),
    check_trailing_space,
    fix_trailing_space,
)

MARKDOWN_FORMAT: CodeFormat = CodeFormat(
    'Markdown',
    FileFilter(endswith=('.md',)),
    check_trailing_space,
    fix_trailing_space,
)

OWNERS_CODE_FORMAT = CodeFormat(
    'OWNERS',
    filter=FileFilter(name=('OWNERS',)),
    check=check_owners_format,
    fix=fix_owners_format,
)

CODE_FORMATS: Tuple[CodeFormat, ...] = (
    # keep-sorted: start
    BAZEL_FORMAT,
    CMAKE_FORMAT,
    COPYBARA_FORMAT,
    C_FORMAT,
    GN_FORMAT,
    GO_FORMAT,
    JAVASCRIPT_FORMAT,
    JAVA_FORMAT,
    MARKDOWN_FORMAT,
    OWNERS_CODE_FORMAT,
    PROTO_FORMAT,
    PYTHON_FORMAT,
    RST_FORMAT,
    # keep-sorted: end
)

# TODO(b/264578594) Remove these lines when these globals aren't referenced.
CODE_FORMATS_WITH_BLACK: Tuple[CodeFormat, ...] = CODE_FORMATS
CODE_FORMATS_WITH_YAPF: Tuple[CodeFormat, ...] = CODE_FORMATS


def presubmit_check(
    code_format: CodeFormat,
    *,
    exclude: Collection[Union[str, Pattern[str]]] = (),
) -> Callable:
    """Creates a presubmit check function from a CodeFormat object.

    Args:
        code_format: The format whose check function the presubmit step runs.
        exclude: Additional exclusion regexes to apply.

    Returns:
        A presubmit check that raises PresubmitFailure when formatting errors
        are found. Its name is derived from code_format.language.
    """

    # Make a copy of the FileFilter and add in any additional excludes.
    file_filter = FileFilter(**vars(code_format.filter))
    file_filter.exclude += tuple(re.compile(e) for e in exclude)

    @pw_presubmit.filter_paths(file_filter=file_filter)
    def check_code_format(ctx: pw_presubmit.PresubmitContext):
        errors = code_format.check(ctx)
        print_format_check(
            errors,
            # When running as part of presubmit, show the fix command help.
            show_fix_commands=True,
        )
        if not errors:
            return

        # Also record the diffs in the failure summary log. Colors default to
        # off here because the output is not sys.stdout.
        with ctx.failure_summary_log.open('w') as outs:
            print_format_check(
                errors,
                show_summary=False,
                show_fix_commands=False,
                file=outs,
            )

        raise pw_presubmit.PresubmitFailure

    language = code_format.language.lower().replace('+', 'p').replace(' ', '_')
    check_code_format.name = f'{language}_format'
    check_code_format.doc = f'Check the format of {code_format.language} files.'

    return check_code_format


def presubmit_checks(
    *,
    exclude: Collection[Union[str, Pattern[str]]] = (),
    code_formats: Collection[CodeFormat] = CODE_FORMATS,
) -> Tuple[Callable, ...]:
    """Returns a tuple with all supported code format presubmit checks.

    Args:
        exclude: Additional exclusion regexes to apply.
        code_formats: A list of CodeFormat objects to run checks with.
    """

    return tuple(presubmit_check(fmt, exclude=exclude) for fmt in code_formats)


class CodeFormatter:
    """Checks or fixes the formatting of a set of files."""

    def __init__(
        self,
        root: Optional[Path],
        files: Iterable[Path],
        output_dir: Path,
        code_formats: Collection[CodeFormat] = CODE_FORMATS_WITH_YAPF,
        package_root: Optional[Path] = None,
    ):
        self.root = root
        self.paths = list(files)
        # Each file is assigned to the first CodeFormat (in code_formats
        # order) whose filter matches it; unmatched files are skipped.
        self._formats: Dict[CodeFormat, List] = collections.defaultdict(list)
        self.root_output_dir = output_dir
        self.package_root = package_root or output_dir / 'packages'

        for path in self.paths:
            for code_format in code_formats:
                if code_format.filter.matches(path):
                    _LOG.debug(
                        'Formatting %s as %s', path, code_format.language
                    )
                    self._formats[code_format].append(path)
                    break
            else:
                _LOG.debug('No formatter found for %s', path)

    def _context(self, code_format: CodeFormat):
        """Builds the FormatContext for one format's files.

        Each language gets its own output subdirectory, created if needed.
        """
        outdir = self.root_output_dir / code_format.language.replace(' ', '_')
        os.makedirs(outdir, exist_ok=True)

        return FormatContext(
            root=self.root,
            output_dir=outdir,
            paths=tuple(self._formats[code_format]),
            package_root=self.package_root,
            format_options=FormatOptions.load(),
        )

    def check(self) -> Dict[Path, str]:
        """Returns {path: diff} for files with incorrect formatting."""
        errors: Dict[Path, str] = {}

        for code_format, files in self._formats.items():
            _LOG.debug('Checking %s', ', '.join(str(f) for f in files))
            errors.update(code_format.check(self._context(code_format)))

        return collections.OrderedDict(sorted(errors.items()))

    def fix(self) -> Dict[Path, str]:
        """Fixes format errors for supported files in place.

        Returns:
            {path: error text} for files that could not be fixed.
        """
        all_errors: Dict[Path, str] = {}
        for code_format, files in self._formats.items():
            errors = code_format.fix(self._context(code_format))
            if errors:
                for path, error in errors.items():
                    _LOG.error('Failed to format %s', path)
                    for line in error.splitlines():
                        _LOG.error('%s', line)
                all_errors.update(errors)
                continue

            _LOG.info(
                'Formatted %s', plural(files, code_format.language + ' file')
            )
        return all_errors


def _file_summary(files: Iterable[Union[Path, str]], base: Path) -> List[str]:
    """Returns summary lines for ``files`` relative to ``base``.

    Returns an empty list if any file is not under ``base``.
    """
    try:
        return file_summary(
            Path(f).resolve().relative_to(base.resolve()) for f in files
        )
    except ValueError:
        return []


def format_paths_in_repo(
    paths: Collection[Union[Path, str]],
    exclude: Collection[Pattern[str]],
    fix: bool,
    base: str,
    code_formats: Collection[CodeFormat] = CODE_FORMATS,
    output_directory: Optional[Path] = None,
    package_root: Optional[Path] = None,
) -> int:
    """Checks or fixes formatting for files in a Git repo.

    Existing files named in ``paths`` are always included. Inside a Git repo,
    the files from git_repo.list_files(base, paths) that do not match
    ``exclude`` are added as well.

    Returns:
        The exit code from format_files, or 1 if ``base`` was given while not
        running inside a Git repo.
    """

    files = [Path(path).resolve() for path in paths if os.path.isfile(path)]
    repo = git_repo.root() if git_repo.is_repo() else None

    # Implement a graceful fallback in case the tracking branch isn't available.
    if base == git_repo.TRACKING_BRANCH_ALIAS and not git_repo.tracking_branch(
        repo
    ):
        _LOG.warning(
            'Failed to determine the tracking branch, using --base HEAD~1 '
            'instead of listing all files'
        )
        base = 'HEAD~1'

    # If this is a Git repo, list the original paths with git ls-files or diff.
    if repo:
        project_root = pw_cli.env.pigweed_environment().PW_PROJECT_ROOT
        _LOG.info(
            'Formatting %s',
            git_repo.describe_files(
                repo, Path.cwd(), base, paths, exclude, project_root
            ),
        )

        # Add files from Git and remove duplicates.
        files = sorted(
            set(exclude_paths(exclude, git_repo.list_files(base, paths)))
            | set(files)
        )
    elif base:
        _LOG.critical(
            'A base commit may only be provided if running from a Git repo'
        )
        return 1

    return format_files(
        files,
        fix,
        repo=repo,
        code_formats=code_formats,
        output_directory=output_directory,
        package_root=package_root,
    )


def format_files(
    paths: Collection[Union[Path, str]],
    fix: bool,
    repo: Optional[Path] = None,
    code_formats: Collection[CodeFormat] = CODE_FORMATS,
    output_directory: Optional[Path] = None,
    package_root: Optional[Path] = None,
) -> int:
    """Checks or fixes formatting for the specified files.

    Args:
        paths: Files to check or fix.
        fix: If True, apply fixes in place when formatting errors are found.
        repo: Base directory for the printed file summary (defaults to the
            current directory).
        code_formats: Formats to match the files against.
        output_directory: Scratch output location. Defaults to
            <repo root>/out/format, or to a temporary directory (removed
            before returning) when no Git repo is found.
        package_root: Forwarded to CodeFormatter.

    Returns:
        0 if no changes were needed or all fixes were applied, otherwise 1.
    """

    root: Optional[Path] = None

    if git_repo.is_repo():
        root = git_repo.root()
    elif paths:
        parent = Path(next(iter(paths))).parent
        if git_repo.is_repo(parent):
            root = git_repo.root(parent)

    # Only fall back to a temporary directory when neither an explicit output
    # directory nor a repo root is available. It is cleaned up explicitly in
    # the finally clause below instead of being left to garbage collection.
    tempdir: Optional[tempfile.TemporaryDirectory] = None
    output_dir: Path
    if output_directory:
        output_dir = output_directory
    elif root:
        output_dir = root / _DEFAULT_PATH
    else:
        tempdir = tempfile.TemporaryDirectory()
        output_dir = Path(tempdir.name)

    try:
        formatter = CodeFormatter(
            files=(Path(p) for p in paths),
            code_formats=code_formats,
            root=root,
            output_dir=output_dir,
            package_root=package_root,
        )

        _LOG.info(
            'Checking formatting for %s', plural(formatter.paths, 'file')
        )

        for line in _file_summary(paths, repo if repo else Path.cwd()):
            print(line, file=sys.stderr)

        check_errors = formatter.check()
        print_format_check(check_errors, show_fix_commands=(not fix))

        if not check_errors:
            _LOG.info('Congratulations! No formatting changes needed')
            return 0

        if not fix:
            _LOG.error('Formatting errors found')
            return 1

        _LOG.info('Applying formatting fixes to %d files', len(check_errors))
        fix_errors = formatter.fix()
        if fix_errors:
            _LOG.error('Failed to apply formatting fixes')
            print_format_check(fix_errors, show_fix_commands=False)
            return 1

        _LOG.info('Formatting fixes applied successfully')
        return 0
    finally:
        if tempdir is not None:
            tempdir.cleanup()


def arguments(git_paths: bool) -> argparse.ArgumentParser:
    """Creates an argument parser for format_files or format_paths_in_repo.

    Args:
        git_paths: If True, use the Git-aware path arguments from cli;
            otherwise accept one or more existing file paths.
    """

    parser = argparse.ArgumentParser(description=__doc__)

    if git_paths:
        cli.add_path_arguments(parser)
    else:

        def existing_path(arg: str) -> Path:
            path = Path(arg)
            if not path.is_file():
                raise argparse.ArgumentTypeError(
                    f'{arg} is not a path to a file'
                )

            return path

        parser.add_argument(
            'paths',
            metavar='path',
            nargs='+',
            type=existing_path,
            help='File paths to check',
        )

    parser.add_argument(
        '--fix', action='store_true', help='Apply formatting fixes in place.'
    )

    parser.add_argument(
        '--output-directory',
        type=Path,
        help=f"Output directory (default: {'<repo root>' / _DEFAULT_PATH})",
    )

    # Read the environment lazily with .get(): a missing PW_PACKAGE_ROOT must
    # not raise KeyError while building the parser. With no default,
    # CodeFormatter falls back to <output directory>/packages.
    package_root_env = os.environ.get('PW_PACKAGE_ROOT')
    parser.add_argument(
        '--package-root',
        type=Path,
        default=Path(package_root_env) if package_root_env else None,
        help='Package root directory',
    )

    return parser


def main() -> int:
    """Check and fix formatting for source files."""
    return format_paths_in_repo(**vars(arguments(git_paths=True).parse_args()))


def _pigweed_upstream_main() -> int:
    """Check and fix formatting for source files in upstream Pigweed.

    Excludes third party sources.
    """
    args = arguments(git_paths=True).parse_args()

    # Exclude paths with third party code from formatting.
    args.exclude.append(re.compile('^third_party/fuchsia/repo/'))

    return format_paths_in_repo(**vars(args))


if __name__ == '__main__':
    try:
        # If pw_cli is available, use it to initialize logs.
        from pw_cli import log

        log.install(logging.INFO)
    except ImportError:
        # If pw_cli isn't available, display log messages like a simple print.
        logging.basicConfig(format='%(message)s', level=logging.INFO)

    sys.exit(main())