• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2
3# Copyright 2020 The Pigweed Authors
4#
5# Licensed under the Apache License, Version 2.0 (the "License"); you may not
6# use this file except in compliance with the License. You may obtain a copy of
7# the License at
8#
9#     https://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
14# License for the specific language governing permissions and limitations under
15# the License.
16"""Checks and fixes formatting for source files.
17
18This uses clang-format, gn format, gofmt, and python -m yapf to format source
19code. These tools must be available on the path when this script is invoked!
20"""
21
22import argparse
23import collections
24import difflib
25import logging
26import os
27from pathlib import Path
28import re
29import subprocess
30import sys
31import tempfile
32from typing import (
33    Callable,
34    Collection,
35    Dict,
36    Iterable,
37    List,
38    NamedTuple,
39    Optional,
40    Pattern,
41    Sequence,
42    TextIO,
43    Tuple,
44    Union,
45)
46
47try:
48    import pw_presubmit
49except ImportError:
50    # Append the pw_presubmit package path to the module search path to allow
51    # running this module without installing the pw_presubmit package.
52    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
53    import pw_presubmit
54
55import pw_cli.color
56import pw_cli.env
57from pw_presubmit.presubmit import FileFilter
58from pw_presubmit import (
59    cli,
60    FormatContext,
61    FormatOptions,
62    git_repo,
63    owners_checks,
64    PresubmitContext,
65)
66from pw_presubmit.tools import exclude_paths, file_summary, log_run, plural
67
# Logger shared by every helper in this module.
_LOG: logging.Logger = logging.getLogger(__name__)

# Terminal color helpers used when rendering diffs.
_COLOR = pw_cli.color.colors()

# Default output location, relative to the repository root.
_DEFAULT_PATH = Path('out') / 'format'

# Either context type can drive the check_*/fix_* functions below.
_Context = Union[PresubmitContext, FormatContext]
73
74
def _colorize_diff_line(line: str) -> str:
    """Wraps a single unified-diff line in the color matching its role.

    File headers ('--- ' / '+++ ') are bold white, removed lines red, added
    lines green and hunk headers ('@@ ') cyan; other lines are unchanged.
    """
    if line.startswith(('--- ', '+++ ')):
        painter = _COLOR.bold_white
    elif line[:1] == '-':
        painter = _COLOR.red
    elif line[:1] == '+':
        painter = _COLOR.green
    elif line.startswith('@@ '):
        painter = _COLOR.cyan
    else:
        return line
    return painter(line)
85
86
def colorize_diff(lines: Iterable[str]) -> str:
    """Returns a colorized copy of a diff.

    Accepts either one diff string (split here, keeping line endings) or an
    iterable of already-split lines.
    """
    source = lines.splitlines(True) if isinstance(lines, str) else lines
    colored = [_colorize_diff_line(line) for line in source]
    return ''.join(colored)
93
94
def _diff(path, original: bytes, formatted: bytes) -> str:
    """Renders a unified diff between two byte strings as text.

    Undecodable bytes are replaced instead of raising. The diff headers name
    `path` with '  (original)' and '  (reformatted)' suffixes.
    """
    before = original.decode(errors='replace').splitlines(True)
    after = formatted.decode(errors='replace').splitlines(True)
    hunks = difflib.unified_diff(
        before,
        after,
        fromfile=f'{path}  (original)',
        tofile=f'{path}  (reformatted)',
    )
    return ''.join(hunks)
104
105
# Signature shared by the in-memory formatters: given a file's path and its
# current contents, return the formatted contents.
Formatter = Callable[[str, bytes], bytes]


def _diff_formatted(path, formatter: Formatter) -> Optional[str]:
    """Returns a diff comparing a file to its formatted version.

    Returns:
      None if `formatter` leaves the contents unchanged; otherwise the
      unified diff produced by _diff().
    """
    with open(path, 'rb') as source:
        original = source.read()

    formatted = formatter(path, original)
    if formatted == original:
        return None
    return _diff(path, original, formatted)
117
118
def _check_files(files, formatter: Formatter) -> Dict[Path, str]:
    """Runs `formatter` over each file; returns {path: diff} for changed ones.

    Files are processed in iteration order and unchanged files are omitted.
    """
    diffs = ((path, _diff_formatted(path, formatter)) for path in files)
    return {path: difference for path, difference in diffs if difference}
128
129
def _clang_format(*args: Union[Path, str], **kwargs) -> bytes:
    """Runs clang-format using the project's style file; returns its stdout.

    Positional args are appended to the command line and keyword args are
    forwarded to log_run. check=True is passed so a nonzero exit is treated
    as an error by log_run.
    """
    command = ['clang-format', '--style=file']
    command.extend(args)
    completed = log_run(command, stdout=subprocess.PIPE, check=True, **kwargs)
    return completed.stdout
137
138
def clang_format_check(ctx: _Context) -> Dict[Path, str]:
    """Checks formatting; returns {path: diff} for files with bad formatting."""

    def _formatted(path, _unused_contents):
        # clang-format reads the file from disk itself, so the contents that
        # _check_files passes in are not needed here.
        return _clang_format(path)

    return _check_files(ctx.paths, _formatted)
142
143
def clang_format_fix(ctx: _Context) -> Dict[Path, str]:
    """Fixes formatting for the provided files in place.

    All paths go to a single `clang-format -i` run. Failures are not
    returned; they surface through _clang_format's check=True instead.
    """
    paths = ctx.paths
    _clang_format('-i', *paths)
    no_errors: Dict[Path, str] = {}
    return no_errors
148
149
def check_gn_format(ctx: _Context) -> Dict[Path, str]:
    """Checks formatting; returns {path: diff} for files with bad formatting."""

    def _gn_format_stdin(_path, contents: bytes) -> bytes:
        # Pipe the current contents through `gn format --stdin` and use its
        # stdout as the formatted version.
        result = log_run(
            ['gn', 'format', '--stdin'],
            input=contents,
            stdout=subprocess.PIPE,
            check=True,
        )
        return result.stdout

    return _check_files(ctx.paths, _gn_format_stdin)
161
162
def fix_gn_format(ctx: _Context) -> Dict[Path, str]:
    """Fixes formatting for the provided files in place."""
    command = ['gn', 'format']
    command.extend(ctx.paths)
    log_run(command, check=True)
    return {}
167
168
def check_bazel_format(ctx: _Context) -> Dict[Path, str]:
    """Checks formatting; returns {path: diff} for files with bad formatting.

    Files buildifier fails on are reported with buildifier's stderr (the
    scratch copy's path replaced by the real path) instead of a diff.
    """
    failures: Dict[Path, str] = {}

    def _buildifier_output(path: Union[Path, str], data: bytes) -> bytes:
        # Run buildifier on a scratch copy that keeps the original basename,
        # then read the rewritten copy back; the real file is never touched.
        with tempfile.TemporaryDirectory(dir=ctx.output_dir) as scratch_dir:
            scratch = Path(scratch_dir) / os.path.basename(path)
            scratch.write_bytes(data)

            proc = log_run(['buildifier', scratch], capture_output=True)
            if proc.returncode != 0:
                message = proc.stderr.decode(errors='replace')
                failures[Path(path)] = message.replace(str(scratch), str(path))
            return scratch.read_bytes()

    diffs = _check_files(ctx.paths, _buildifier_output)
    diffs.update(failures)
    return diffs
191
192
def fix_bazel_format(ctx: _Context) -> Dict[Path, str]:
    """Fixes formatting for the provided files in place.

    buildifier is run once per file so failures can be attributed to a path.

    Returns:
      {path: buildifier stderr} for each file buildifier failed on; an empty
      dict when every file was formatted.
    """
    errors: Dict[Path, str] = {}
    for path in ctx.paths:
        proc = log_run(['buildifier', path], capture_output=True)
        if proc.returncode:
            # Decode leniently, matching check_bazel_format, so undecodable
            # bytes in one error message can't abort fixing the other files.
            errors[path] = proc.stderr.decode(errors='replace')
    return errors
201
202
def check_owners_format(ctx: _Context) -> Dict[Path, str]:
    """Checks OWNERS files; returns owners_checks' findings keyed by path."""
    paths = ctx.paths
    return owners_checks.run_owners_checks(paths)
205
206
def fix_owners_format(ctx: _Context) -> Dict[Path, str]:
    """Formats OWNERS files via owners_checks; returns its per-path errors."""
    paths = ctx.paths
    return owners_checks.format_owners_file(paths)
209
210
def check_go_format(ctx: _Context) -> Dict[Path, str]:
    """Checks formatting; returns {path: diff} for files with bad formatting."""

    def _gofmt(path, _contents: bytes) -> bytes:
        # gofmt reads the file itself; its stdout is the formatted source.
        completed = log_run(['gofmt', path], stdout=subprocess.PIPE, check=True)
        return completed.stdout

    return _check_files(ctx.paths, _gofmt)
219
220
def fix_go_format(ctx: _Context) -> Dict[Path, str]:
    """Fixes formatting for the provided files in place."""
    # -w writes the result back to each file.
    command = ['gofmt', '-w']
    command.extend(ctx.paths)
    log_run(command, check=True)
    return {}
225
226
# TODO(b/259595799) Remove yapf support.
def _yapf(*args, **kwargs) -> subprocess.CompletedProcess:
    """Runs `python -m yapf --parallel` with stdout and stderr captured.

    Positional args are appended to the command; keyword args are forwarded
    to log_run.
    """
    command = ['python', '-m', 'yapf', '--parallel']
    command.extend(args)
    return log_run(command, capture_output=True, **kwargs)
234
235
# Matches the header line that begins each file's section of `yapf --diff`
# output ('--- <path>  (original)'); group 1 captures the path.
_DIFF_START = re.compile(r'^--- (.*)\s+\(original\)$', flags=re.MULTILINE)


def check_py_format_yapf(ctx: _Context) -> Dict[Path, str]:
    """Checks formatting; returns {path: diff} for files with bad formatting.

    yapf runs once over all paths and its combined diff is split back into
    per-file diffs. If yapf writes anything to stderr, the message is logged
    and every path without a diff is reported with an empty diff, so those
    paths are treated as failing.
    """
    process = _yapf('--diff', *ctx.paths)

    errors: Dict[Path, str] = {}

    if process.stdout:
        raw_diff = process.stdout.decode(errors='replace')
        headers = list(_DIFF_START.finditer(raw_diff))
        # Each file's diff runs from its header to the next header, or to the
        # end of the output for the last file.
        stops: List[Optional[int]] = [h.start() for h in headers[1:]]
        stops.append(None)
        for header, stop in zip(headers, stops):
            errors[Path(header.group(1))] = raw_diff[header.start() : stop]

    if process.stderr:
        _LOG.error(
            'yapf encountered an error:\n%s',
            process.stderr.decode(errors='replace').rstrip(),
        )
        for file in ctx.paths:
            errors.setdefault(file, '')

    return errors
262
263
def fix_py_format_yapf(ctx: _Context) -> Dict[Path, str]:
    """Fixes formatting for the provided files in place using yapf."""
    paths = ctx.paths
    _yapf('--in-place', *paths, check=True)
    no_errors: Dict[Path, str] = {}
    return no_errors
268
269
def _enumerate_black_configs() -> Iterable[Path]:
    """Yields candidate black config files in priority order.

    For PW_PROJECT_ROOT and then PW_ROOT (each only when set and non-empty),
    yields <root>/.black.toml followed by <root>/pyproject.toml. Whether the
    files exist is not checked here.
    """
    for variable in ('PW_PROJECT_ROOT', 'PW_ROOT'):
        root = os.environ.get(variable)
        if not root:
            continue
        for filename in ('.black.toml', 'pyproject.toml'):
            yield Path(root, filename)
278
279
def _black_config_args() -> Sequence[Union[str, Path]]:
    """Returns the black arguments that select a config file.

    Returns ('--config', <path>) for the first candidate from
    _enumerate_black_configs() that is an existing file, or () otherwise.
    """
    found = next(
        (path for path in _enumerate_black_configs() if path.is_file()), None
    )
    if found is None:
        return ()
    return ('--config', found)
291
292
def _black_multiple_files(ctx: _Context) -> Tuple[str, ...]:
    """Returns the paths black needs to look at, from one `black --check` run.

    One black invocation over every path is cheaper than one per file; the
    callers then run black individually only on the paths returned here.

    Paths black reports it cannot format (e.g. files that fail to parse) are
    included too, so the per-file run reports black's error message instead
    of the file silently passing.

    Returns:
      The paths exactly as black printed them; callers match them against
      ctx.paths with str.endswith.
    """
    black = ctx.format_options.black_path
    proc = log_run(
        [black, '--check', *_black_config_args(), *ctx.paths],
        capture_output=True,
    )

    # black reports one line per file on stderr, e.g.
    #   would reformat <path>
    #   error: cannot format <path>: <reason>
    # Non-greedy captures keep trailing whitespace (and the reason) out of
    # the path.
    would_reformat = re.compile(r'^would reformat (.*?)\s*$')
    cannot_format = re.compile(r'^error: cannot format (.*?): ')

    changed_paths: List[str] = []
    # Decode leniently so undecodable bytes can't abort the whole check.
    for line in proc.stderr.decode(errors='replace').splitlines():
        match = would_reformat.search(line) or cannot_format.search(line)
        if match:
            changed_paths.append(match.group(1))
    return tuple(changed_paths)
307
308
def check_py_format_black(ctx: _Context) -> Dict[Path, str]:
    """Checks formatting; returns {path: diff} for files with bad formatting.

    Files black fails on are reported with black's stderr (the scratch copy's
    path replaced by the real path) instead of a diff.
    """
    errors: Dict[Path, str] = {}

    # Run black --check on the full list of paths and then only run black
    # individually on the files that black found issue with.
    paths: Tuple[str, ...] = _black_multiple_files(ctx)

    # Loop-invariant: look these up once instead of once per file.
    black = ctx.format_options.black_path
    config_args = _black_config_args()

    def _format_temp(path: Union[Path, str], data: bytes) -> bytes:
        # Run black on a copy in a temporary directory (not on the original
        # file) so checking never modifies the source tree, then read the
        # reformatted copy back for diffing.
        with tempfile.TemporaryDirectory(dir=ctx.output_dir) as temp:
            copy = Path(temp) / os.path.basename(path)
            copy.write_bytes(data)

            proc = log_run([black, *config_args, copy], capture_output=True)
            if proc.returncode:
                stderr = proc.stderr.decode(errors='replace')
                stderr = stderr.replace(str(copy), str(path))
                errors[Path(path)] = stderr
            return copy.read_bytes()

    result = _check_files(
        [x for x in ctx.paths if str(x).endswith(paths)],
        _format_temp,
    )
    result.update(errors)
    return result
341
342
def fix_py_format_black(ctx: _Context) -> Dict[Path, str]:
    """Fixes formatting for the provided files in place.

    Returns:
      {path: black stderr} for every file black failed to format.
    """
    errors: Dict[Path, str] = {}

    # Run black --check on the full list of paths and then only run black
    # individually on the files that black found issue with.
    paths: Tuple[str, ...] = _black_multiple_files(ctx)

    # Loop-invariant: look these up once instead of once per file.
    black = ctx.format_options.black_path
    config_args = _black_config_args()

    for path in ctx.paths:
        if not str(path).endswith(paths):
            continue

        proc = log_run([black, *config_args, path], capture_output=True)
        if proc.returncode:
            # Decode leniently, matching check_py_format_black, so one
            # undecodable message can't abort fixing the remaining files.
            errors[path] = proc.stderr.decode(errors='replace')
    return errors
362
363
def check_py_format(ctx: _Context) -> Dict[Path, str]:
    """Checks Python formatting with the configured formatter.

    Dispatches on ctx.format_options.python_formatter ('black' or 'yapf').

    Raises:
      ValueError: if the configured formatter is not recognized.
    """
    formatter = ctx.format_options.python_formatter
    if formatter == 'black':
        checker = check_py_format_black
    elif formatter == 'yapf':
        checker = check_py_format_yapf
    else:
        raise ValueError(formatter)
    return checker(ctx)
370
371
def fix_py_format(ctx: _Context) -> Dict[Path, str]:
    """Fixes Python formatting in place with the configured formatter.

    Dispatches on ctx.format_options.python_formatter ('black' or 'yapf').

    Raises:
      ValueError: if the configured formatter is not recognized.
    """
    formatter = ctx.format_options.python_formatter
    if formatter == 'black':
        fixer = fix_py_format_black
    elif formatter == 'yapf':
        fixer = fix_py_format_yapf
    else:
        raise ValueError(formatter)
    return fixer(ctx)
378
379
# Runs of spaces/tabs at the end of a line; MULTILINE makes `$` match before
# every newline, not just at the end of the file.
_TRAILING_SPACE = re.compile(rb'[ \t]+$', flags=re.MULTILINE)


def _check_trailing_space(paths: Iterable[Path], fix: bool) -> Dict[Path, str]:
    """Finds trailing spaces and tabs and, if `fix` is set, strips them.

    Returns:
      {path: diff} for every file that contained trailing whitespace. The
      diff is returned even when `fix` is True and the file was rewritten.
    """
    errors: Dict[Path, str] = {}
    for path in paths:
        contents = path.read_bytes()
        stripped = _TRAILING_SPACE.sub(b'', contents)
        if stripped == contents:
            continue

        errors[path] = _diff(path, contents, stripped)
        if fix:
            path.write_bytes(stripped)
    return errors
400
401
def check_trailing_space(ctx: _Context) -> Dict[Path, str]:
    """Reports trailing whitespace as {path: diff} without modifying files."""
    return _check_trailing_space(paths=ctx.paths, fix=False)
404
405
def fix_trailing_space(ctx: _Context) -> Dict[Path, str]:
    """Strips trailing whitespace from the provided files in place.

    The per-file diffs computed while fixing are discarded; the empty dict
    means no fix failures.
    """
    _check_trailing_space(paths=ctx.paths, fix=True)
    return {}
409
410
def print_format_check(
    errors: Dict[Path, str],
    show_fix_commands: bool,
    show_summary: bool = True,
    colors: Optional[bool] = None,
    file: TextIO = sys.stdout,
) -> None:
    """Prints the result of a check_*_format function; returns nothing.

    Args:
      errors: {path: diff or error text}. Nothing is printed when empty.
      show_fix_commands: Also log a `pw format --fix <path>` line per path.
      show_summary: Log a warning with the number of affected files.
      colors: Colorize the diffs. When None, colorize only if `file` is
        sys.stdout.
      file: Stream that receives the diffs.
    """
    if not errors:
        return  # Stay quiet in the all-good case.

    use_colors = (file == sys.stdout) if colors is None else colors

    if show_summary:
        _LOG.warning(
            'Found %d files with formatting errors. Format changes:',
            len(errors),
        )

    for diff in errors.values():
        print(colorize_diff(diff) if use_colors else diff, end='', file=file)

    if not show_fix_commands:
        return

    def _display_path(path: Path):
        # Prefer a cwd-relative path so the command is short and can be
        # copied as-is; fall back to the absolute path outside the cwd.
        try:
            absolute = Path(path).resolve()
            return absolute.relative_to(Path.cwd().resolve())
        except ValueError:
            return Path(path).resolve()

    commands = '\n'.join(
        f'  pw format --fix {_display_path(path)}' for path in errors
    )
    _LOG.warning('To fix formatting, run:\n\n%s\n', commands)
450
451
class CodeFormat(NamedTuple):
    """Describes how to check and fix one kind of source file."""

    # Human-readable language name; also used for the per-language output
    # directory and the presubmit check name.
    language: str
    # Decides which paths this format applies to.
    filter: FileFilter
    # Returns {path: diff or error} for files that are not formatted.
    check: Callable[[_Context], Dict[Path, str]]
    # Formats files in place; returns {path: error} for failures.
    fix: Callable[[_Context], Dict[Path, str]]

    @property
    def extensions(self):
        """Deprecated: the `endswith` suffixes of this format's filter."""
        # TODO(b/23842636): Switch calls of this to using 'filter' and remove.
        file_filter = self.filter
        return file_filter.endswith
462
463
# File extensions treated as C/C++ headers and sources, respectively.
CPP_HEADER_EXTS = frozenset({'.h', '.hpp', '.hxx', '.h++', '.hh', '.H'})
CPP_SOURCE_EXTS = frozenset(
    {'.c', '.cpp', '.cxx', '.c++', '.cc', '.C', '.inc', '.inl'}
)
CPP_EXTS = CPP_HEADER_EXTS | CPP_SOURCE_EXTS

# All C/C++ files except generated protobuf sources.
CPP_FILE_FILTER = FileFilter(
    endswith=CPP_EXTS, exclude=(r'\.pb\.h$', r'\.pb\.c$')
)
472
# Languages handled by clang-format.
C_FORMAT = CodeFormat(
    language='C and C++',
    filter=CPP_FILE_FILTER,
    check=clang_format_check,
    fix=clang_format_fix,
)

PROTO_FORMAT: CodeFormat = CodeFormat(
    language='Protocol buffer',
    filter=FileFilter(endswith=('.proto',)),
    check=clang_format_check,
    fix=clang_format_fix,
)

JAVA_FORMAT: CodeFormat = CodeFormat(
    language='Java',
    filter=FileFilter(endswith=('.java',)),
    check=clang_format_check,
    fix=clang_format_fix,
)

JAVASCRIPT_FORMAT: CodeFormat = CodeFormat(
    language='JavaScript',
    filter=FileFilter(endswith=('.js',)),
    check=clang_format_check,
    fix=clang_format_fix,
)

# Languages with their own dedicated formatters.
GO_FORMAT: CodeFormat = CodeFormat(
    language='Go',
    filter=FileFilter(endswith=('.go',)),
    check=check_go_format,
    fix=fix_go_format,
)

PYTHON_FORMAT: CodeFormat = CodeFormat(
    language='Python',
    filter=FileFilter(endswith=('.py',)),
    check=check_py_format,
    fix=fix_py_format,
)

GN_FORMAT: CodeFormat = CodeFormat(
    language='GN',
    filter=FileFilter(endswith=('.gn', '.gni')),
    check=check_gn_format,
    fix=fix_gn_format,
)
512
# BUILD files and Starlark sources, plus the WORKSPACE file, via buildifier.
# `name` needs a one-element tuple, as in OWNERS_CODE_FORMAT's
# name=('OWNERS',). Without the trailing comma, ('WORKSPACE') is just the
# string 'WORKSPACE', so the filter would receive its individual characters
# as patterns rather than the file name.
BAZEL_FORMAT: CodeFormat = CodeFormat(
    'Bazel',
    FileFilter(endswith=('BUILD', '.bazel', '.bzl'), name=('WORKSPACE',)),
    check_bazel_format,
    fix_bazel_format,
)
519
# Copybara configs are Starlark, so they reuse the buildifier-based checks.
COPYBARA_FORMAT: CodeFormat = CodeFormat(
    language='Copybara',
    filter=FileFilter(endswith=('.bara.sky',)),
    check=check_bazel_format,
    fix=fix_bazel_format,
)

# Text-like formats that only get trailing-whitespace cleanup.
# TODO(b/234881054): Add real code formatting support for CMake
CMAKE_FORMAT: CodeFormat = CodeFormat(
    language='CMake',
    filter=FileFilter(endswith=('CMakeLists.txt', '.cmake')),
    check=check_trailing_space,
    fix=fix_trailing_space,
)

RST_FORMAT: CodeFormat = CodeFormat(
    language='reStructuredText',
    filter=FileFilter(endswith=('.rst',)),
    check=check_trailing_space,
    fix=fix_trailing_space,
)

MARKDOWN_FORMAT: CodeFormat = CodeFormat(
    language='Markdown',
    filter=FileFilter(endswith=('.md',)),
    check=check_trailing_space,
    fix=fix_trailing_space,
)

# OWNERS files are matched by file name rather than by suffix.
OWNERS_CODE_FORMAT = CodeFormat(
    language='OWNERS',
    filter=FileFilter(name=('OWNERS',)),
    check=check_owners_format,
    fix=fix_owners_format,
)
555
# Every supported format. CodeFormatter assigns each file to the FIRST format
# in this tuple whose filter matches it, so when filters overlap the order
# here decides which format wins.
CODE_FORMATS: Tuple[CodeFormat, ...] = (
    # keep-sorted: start
    BAZEL_FORMAT,
    CMAKE_FORMAT,
    COPYBARA_FORMAT,
    C_FORMAT,
    GN_FORMAT,
    GO_FORMAT,
    JAVASCRIPT_FORMAT,
    JAVA_FORMAT,
    MARKDOWN_FORMAT,
    OWNERS_CODE_FORMAT,
    PROTO_FORMAT,
    PYTHON_FORMAT,
    RST_FORMAT,
    # keep-sorted: end
)

# TODO(b/264578594) Remove these lines when these globals aren't referenced.
# Both names are aliases of CODE_FORMATS (the same tuple object), kept only
# until callers stop referencing them.
CODE_FORMATS_WITH_BLACK: Tuple[CodeFormat, ...] = CODE_FORMATS
CODE_FORMATS_WITH_YAPF: Tuple[CodeFormat, ...] = CODE_FORMATS
577
578
def presubmit_check(
    code_format: CodeFormat,
    *,
    exclude: Collection[Union[str, Pattern[str]]] = (),
) -> Callable:
    """Creates a presubmit check function from a CodeFormat object.

    Args:
      code_format: Supplies the file filter, check function and language
        name for the generated check.
      exclude: Additional exclusion regexes to apply.

    Returns:
      A check wrapped by pw_presubmit.filter_paths, named
      '<language>_format'. It prints any problems, writes the diffs to
      ctx.failure_summary_log and raises pw_presubmit.PresubmitFailure when
      code_format.check reports errors.
    """
    # Work on a copy of the FileFilter so the extra excludes stay local to
    # this check.
    file_filter = FileFilter(**vars(code_format.filter))
    extra_excludes = tuple(re.compile(pattern) for pattern in exclude)
    file_filter.exclude += extra_excludes

    @pw_presubmit.filter_paths(file_filter=file_filter)
    def check_code_format(ctx: pw_presubmit.PresubmitContext):
        errors = code_format.check(ctx)
        # On the console, include copy-and-pastable fix commands.
        print_format_check(errors, show_fix_commands=True)
        if errors:
            # The failure summary gets only the diffs: no count, no commands.
            with ctx.failure_summary_log.open('w') as summary:
                print_format_check(
                    errors,
                    show_summary=False,
                    show_fix_commands=False,
                    file=summary,
                )
            raise pw_presubmit.PresubmitFailure

    slug = code_format.language.lower().replace('+', 'p').replace(' ', '_')
    check_code_format.name = f'{slug}_format'
    check_code_format.doc = f'Check the format of {code_format.language} files.'

    return check_code_format
620
621
def presubmit_checks(
    *,
    exclude: Collection[Union[str, Pattern[str]]] = (),
    code_formats: Collection[CodeFormat] = CODE_FORMATS,
) -> Tuple[Callable, ...]:
    """Builds one presubmit check per supported CodeFormat.

    Args:
      exclude: Additional exclusion regexes applied to every check.
      code_formats: A list of CodeFormat objects to run checks with.

    Returns:
      The checks, in the same order as `code_formats`.
    """
    checks = [presubmit_check(fmt, exclude=exclude) for fmt in code_formats]
    return tuple(checks)
635
636
class CodeFormatter:
    """Checks or fixes the formatting of a set of files.

    Each file is assigned to the first CodeFormat whose filter matches it;
    files no format matches are skipped (logged at debug level).
    """

    def __init__(
        self,
        root: Optional[Path],
        files: Iterable[Path],
        output_dir: Path,
        code_formats: Collection[CodeFormat] = CODE_FORMATS,
        package_root: Optional[Path] = None,
    ):
        """Sorts `files` into per-format lists.

        Args:
          root: Repository root recorded in each FormatContext (may be None).
          files: Files to check or fix.
          output_dir: Parent of the per-language output directories.
          code_formats: Formats to try, in priority order. Defaults to
            CODE_FORMATS; the previous default, CODE_FORMATS_WITH_YAPF, is a
            deprecated alias of the very same tuple.
          package_root: Package root for each FormatContext; defaults to
            <output_dir>/packages.
        """
        self.root = root
        self.paths = list(files)
        self._formats: Dict[CodeFormat, List] = collections.defaultdict(list)
        self.root_output_dir = output_dir
        self.package_root = package_root or output_dir / 'packages'

        for path in self.paths:
            for code_format in code_formats:
                if code_format.filter.matches(path):
                    _LOG.debug(
                        'Formatting %s as %s', path, code_format.language
                    )
                    self._formats[code_format].append(path)
                    break
            else:
                _LOG.debug('No formatter found for %s', path)

    def _context(self, code_format: CodeFormat):
        """Builds the FormatContext for one format, creating its output dir."""
        outdir = self.root_output_dir / code_format.language.replace(' ', '_')
        os.makedirs(outdir, exist_ok=True)

        return FormatContext(
            root=self.root,
            output_dir=outdir,
            paths=tuple(self._formats[code_format]),
            package_root=self.package_root,
            format_options=FormatOptions.load(),
        )

    def check(self) -> Dict[Path, str]:
        """Returns {path: diff} for files with incorrect formatting.

        The result is ordered by path.
        """
        errors: Dict[Path, str] = {}

        for code_format, files in self._formats.items():
            _LOG.debug('Checking %s', ', '.join(str(f) for f in files))
            errors.update(code_format.check(self._context(code_format)))

        return collections.OrderedDict(sorted(errors.items()))

    def fix(self) -> Dict[Path, str]:
        """Fixes format errors for supported files in place.

        Returns:
          {path: error} for every file a formatter failed on; empty when all
          formats applied cleanly.
        """
        all_errors: Dict[Path, str] = {}
        for code_format, files in self._formats.items():
            errors = code_format.fix(self._context(code_format))
            if errors:
                for path, error in errors.items():
                    _LOG.error('Failed to format %s', path)
                    for line in error.splitlines():
                        _LOG.error('%s', line)
                all_errors.update(errors)
                continue

            _LOG.info(
                'Formatted %s', plural(files, code_format.language + ' file')
            )
        return all_errors
704
705
def _file_summary(files: Iterable[Union[Path, str]], base: Path) -> List[str]:
    """Summarizes `files`, shown relative to `base`, via tools.file_summary.

    Returns an empty list instead when a ValueError is raised, e.g. because
    a file is not under `base`.
    """

    def _relative(path: Union[Path, str]) -> Path:
        return Path(path).resolve().relative_to(base.resolve())

    try:
        return file_summary(map(_relative, files))
    except ValueError:
        return []
713
714
def format_paths_in_repo(
    paths: Collection[Union[Path, str]],
    exclude: Collection[Pattern[str]],
    fix: bool,
    base: str,
    code_formats: Collection[CodeFormat] = CODE_FORMATS,
    output_directory: Optional[Path] = None,
    package_root: Optional[Path] = None,
) -> int:
    """Checks or fixes formatting for files in a Git repo.

    Existing files named in `paths` are always included. Inside a Git repo,
    the files from git_repo.list_files(base, paths), minus `exclude`, are
    merged in as well.

    Returns:
      The result of format_files(), or 1 when `base` is set outside a Git
      repo.
    """
    explicit_files = [
        Path(candidate).resolve()
        for candidate in paths
        if os.path.isfile(candidate)
    ]
    repo = git_repo.root() if git_repo.is_repo() else None

    # Implement a graceful fallback in case the tracking branch isn't available.
    tracking_branch_missing = (
        base == git_repo.TRACKING_BRANCH_ALIAS
        and not git_repo.tracking_branch(repo)
    )
    if tracking_branch_missing:
        _LOG.warning(
            'Failed to determine the tracking branch, using --base HEAD~1 '
            'instead of listing all files'
        )
        base = 'HEAD~1'

    if not repo:
        if base:
            _LOG.critical(
                'A base commit may only be provided if running from a Git repo'
            )
            return 1
        files = explicit_files
    else:
        # Describe the selection, then list the original paths with git.
        project_root = pw_cli.env.pigweed_environment().PW_PROJECT_ROOT
        _LOG.info(
            'Formatting %s',
            git_repo.describe_files(
                repo, Path.cwd(), base, paths, exclude, project_root
            ),
        )

        # Merge Git's files with the explicit ones, dropping duplicates.
        git_files = exclude_paths(exclude, git_repo.list_files(base, paths))
        files = sorted(set(git_files) | set(explicit_files))

    return format_files(
        files,
        fix,
        repo=repo,
        code_formats=code_formats,
        output_directory=output_directory,
        package_root=package_root,
    )
768
769
def format_files(
    paths: Collection[Union[Path, str]],
    fix: bool,
    repo: Optional[Path] = None,
    code_formats: Collection[CodeFormat] = CODE_FORMATS,
    output_directory: Optional[Path] = None,
    package_root: Optional[Path] = None,
) -> int:
    """Checks or fixes formatting for the specified files.

    Args:
      paths: Files to check; files no CodeFormat matches are skipped.
      fix: Apply fixes in place when formatting problems are found.
      repo: Base directory for the printed file summary (cwd if None).
      code_formats: Formats to use, in priority order.
      output_directory: Output directory for formatter scratch files.
        Defaults to <repo root>/out/format, or to a temporary directory
        (removed before returning) when no Git repo root is found.
      package_root: Passed through to CodeFormatter.

    Returns:
      0 if no changes were needed or all fixes were applied; 1 if problems
      were found in check mode or a fix failed.
    """

    root: Optional[Path] = None

    if git_repo.is_repo():
        root = git_repo.root()
    elif paths:
        parent = Path(next(iter(paths))).parent
        if git_repo.is_repo(parent):
            root = git_repo.root(parent)

    # Only created when there is neither an explicit output directory nor a
    # repo root. It is cleaned up explicitly below instead of being left to
    # garbage collection, which would emit a ResourceWarning.
    tempdir: Optional[tempfile.TemporaryDirectory] = None
    output_dir: Path
    if output_directory:
        output_dir = output_directory
    elif root:
        output_dir = root / _DEFAULT_PATH
    else:
        tempdir = tempfile.TemporaryDirectory()
        output_dir = Path(tempdir.name)

    try:
        formatter = CodeFormatter(
            files=(Path(p) for p in paths),
            code_formats=code_formats,
            root=root,
            output_dir=output_dir,
            package_root=package_root,
        )

        _LOG.info('Checking formatting for %s', plural(formatter.paths, 'file'))

        for line in _file_summary(paths, repo if repo else Path.cwd()):
            print(line, file=sys.stderr)

        check_errors = formatter.check()
        print_format_check(check_errors, show_fix_commands=(not fix))

        if not check_errors:
            _LOG.info('Congratulations! No formatting changes needed')
            return 0

        if not fix:
            _LOG.error('Formatting errors found')
            return 1

        _LOG.info('Applying formatting fixes to %d files', len(check_errors))
        fix_errors = formatter.fix()
        if fix_errors:
            # A failed fix is an error, consistent with 'Formatting errors
            # found' above (previously logged at INFO level).
            _LOG.error('Failed to apply formatting fixes')
            print_format_check(fix_errors, show_fix_commands=False)
            return 1

        _LOG.info('Formatting fixes applied successfully')
        return 0
    finally:
        if tempdir is not None:
            tempdir.cleanup()
833
834
def arguments(git_paths: bool) -> argparse.ArgumentParser:
    """Creates an argument parser for format_files or format_paths_in_repo.

    Args:
      git_paths: If True, add the Git path arguments from
        cli.add_path_arguments (for format_paths_in_repo); otherwise accept
        one or more existing file paths (for format_files).
    """

    parser = argparse.ArgumentParser(description=__doc__)

    if git_paths:
        cli.add_path_arguments(parser)
    else:

        def existing_path(arg: str) -> Path:
            path = Path(arg)
            if not path.is_file():
                raise argparse.ArgumentTypeError(
                    f'{arg} is not a path to a file'
                )

            return path

        parser.add_argument(
            'paths',
            metavar='path',
            nargs='+',
            type=existing_path,
            help='File paths to check',
        )

    parser.add_argument(
        '--fix', action='store_true', help='Apply formatting fixes in place.'
    )

    parser.add_argument(
        '--output-directory',
        type=Path,
        help=f"Output directory (default: {'<repo root>' / _DEFAULT_PATH})",
    )

    # Reading os.environ['PW_PACKAGE_ROOT'] directly raised KeyError while
    # building the parser whenever the variable was unset. Default to None in
    # that case; CodeFormatter then falls back to <output dir>/packages.
    package_root_env = os.environ.get('PW_PACKAGE_ROOT')
    parser.add_argument(
        '--package-root',
        type=Path,
        default=(
            Path(package_root_env) if package_root_env is not None else None
        ),
        help='Package root directory',
    )

    return parser
878
879
def main() -> int:
    """Check and fix formatting for source files."""
    parser = arguments(git_paths=True)
    options = vars(parser.parse_args())
    return format_paths_in_repo(**options)
883
884
def _pigweed_upstream_main() -> int:
    """Check and fix formatting for source files in upstream Pigweed.

    Excludes third party sources.
    """
    args = arguments(git_paths=True).parse_args()

    # Vendored third-party code under this directory is never reformatted.
    third_party = re.compile('^third_party/fuchsia/repo/')
    args.exclude.append(third_party)

    return format_paths_in_repo(**vars(args))
896
897
if __name__ == '__main__':
    try:
        # If pw_cli is available, use it to initialize logs.
        # NOTE: the try also covers log.install(), so an ImportError raised
        # while installing the handler falls back to basicConfig as well.
        from pw_cli import log

        log.install(logging.INFO)
    except ImportError:
        # If pw_cli isn't available, display log messages like a simple print.
        logging.basicConfig(format='%(message)s', level=logging.INFO)

    # main() returns the process exit code.
    sys.exit(main())
909