# Copyright 2022 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
14"""Keep specified lists sorted."""
15
16import argparse
17import dataclasses
18import difflib
19import logging
20import os
21from pathlib import Path
22import re
23import sys
24from typing import (
25    Callable,
26    Collection,
27    Pattern,
28    Sequence,
29)
30
31import pw_cli
32from pw_cli.plural import plural
33from . import cli, git_repo, presubmit, presubmit_context, tools
34
# Output location, relative to the repo or project root, that is used when no
# output directory is given on the command line.
DEFAULT_PATH = Path('out', 'presubmit', 'keep_sorted')

_LOG: logging.Logger = logging.getLogger(__name__)

# Ignore a whole section. Please do not change the order of these lines.
# Markers that open and close a sorted block. Matching is case-insensitive and
# "begin"/"start" and "stop"/"end" are accepted as synonyms.
_START = re.compile(r'keep-sorted: (begin|start)', re.IGNORECASE)
_END = re.compile(r'keep-sorted: (stop|end)', re.IGNORECASE)
# Optional directives that may follow the opening marker on the same line.
_IGNORE_CASE = re.compile(r'ignore-case', re.IGNORECASE)
_ALLOW_DUPES = re.compile(r'allow-dupes', re.IGNORECASE)
# The captured group is a comma-separated list with no whitespace in it.
_IGNORE_PREFIX = re.compile(r'ignore-prefix=(\S+)', re.IGNORECASE)
_STICKY_COMMENTS = re.compile(r'sticky-comments=(\S+)', re.IGNORECASE)

# Only include these literals here so keep_sorted doesn't try to reorder later
# test lines.
# The two marker lines inside the string literal are adjacent, so when this
# file is itself scanned they open and immediately close an empty block. Other
# code refers to START and END instead of spelling the markers out again.
(
    START,
    END,
) = """
keep-sorted: start
keep-sorted: end
""".strip().splitlines()
56
57
58@dataclasses.dataclass
59class KeepSortedContext:
60    paths: list[Path]
61    fix: bool
62    output_dir: Path
63    failure_summary_log: Path
64    failed: bool = False
65
66    def fail(
67        self,
68        description: str = '',
69        path: Path | None = None,
70        line: int | None = None,
71    ) -> None:
72        if not self.fix:
73            self.failed = True
74
75        line_part: str = ''
76        if line is not None:
77            line_part = f'{line}:'
78
79        log = _LOG.error
80        if self.fix:
81            log = _LOG.warning
82
83        if path:
84            log('%s:%s %s', path, line_part, description)
85        else:
86            log('%s', description)
87
88
class KeepSortedParsingError(presubmit.PresubmitFailure):
    """Raised when the keep-sorted markers in a file cannot be parsed."""
91
92
@dataclasses.dataclass
class _Line:
    """One sortable entry of a block, plus the lines that move with it.

    Attributes:
        value: The line that is used for sorting.
        sticky_comments: Comment lines that directly precede ``value``.
        continuations: Lines that directly follow ``value`` and belong to it.
    """

    value: str = ''
    sticky_comments: Sequence[str] = ()
    continuations: Sequence[str] = ()

    @property
    def full(self):
        """All text of this entry, in the order it appears in the file."""
        pieces = [*self.sticky_comments, self.value, *self.continuations]
        return ''.join(pieces)

    def __lt__(self, other):
        """Orders by value first, then continuations, then sticky comments."""
        if isinstance(other, _Line):
            mine = (self.value, self.continuations, self.sticky_comments)
            theirs = (other.value, other.continuations, other.sticky_comments)
            return mine < theirs
        return NotImplemented
109
110
@dataclasses.dataclass
class _Block:
    """Options and contents of a single sorted block found in a file."""

    # Options parsed from the line that opens the block.
    ignore_case: bool = False
    allow_dupes: bool = False
    ignored_prefixes: Sequence[str] = dataclasses.field(default_factory=list)
    # Prefixes that mark a line as a comment attached to the entry after it.
    sticky_comments: tuple[str, ...] = ()
    # 1-based line number and text of the opening marker line.
    start_line_number: int = -1
    start_line: str = ''
    # Text of the closing marker line.
    end_line: str = ''
    # Lines between the two markers, in their original order.
    lines: list[str] = dataclasses.field(default_factory=list)
121
122
class _FileSorter:
    """Sorts the keep-sorted blocks of a single file.

    Call sort() to parse the file. Afterwards ``changed`` tells whether any
    block was out of order and ``all_lines`` holds the lines of the file with
    every block sorted. write() stores those lines on disk.

    Args:
        ctx: Context whose fail() method receives non-fatal problems.
        path: File to examine.
        errors: Optional mapping that may be shared between sorters. A file
            with unsorted blocks gets an entry holding one diff per block.
    """

    def __init__(
        self,
        ctx: presubmit.PresubmitContext | KeepSortedContext,
        path: Path,
        errors: dict[Path, Sequence[str]] | None = None,
    ):
        self.ctx = ctx
        self.path: Path = path
        self.all_lines: list[str] = []
        self.changed: bool = False
        self._errors: dict[Path, Sequence[str]] = {}
        if errors is not None:
            self._errors = errors

    def _process_block(self, block: _Block) -> Sequence[str]:
        """Returns the lines of ``block`` in sorted order.

        Lines are grouped into entries first: sticky comments attach to the
        entry that follows them and more-indented lines attach to the entry
        that precedes them. Entries are then de-duplicated (unless the block
        allows duplicates) and sorted. When the result differs from the
        original lines, ``changed`` is set and a diff is recorded for the
        file.

        Args:
            block: The block to sort; it is not modified.

        Returns:
            The sorted lines, without the opening and closing marker lines.
        """
        raw_lines: list[str] = block.lines
        lines: list[_Line] = []

        def indent_width(text: str) -> int:
            # Number of leading whitespace characters. lstrip() also removes
            # a trailing newline from a blank line, so that newline counts.
            return len(text) - len(text.lstrip())

        prev_prefix: int | None = None
        comments: list[str] = []
        for raw_line in raw_lines:
            curr_prefix: int = indent_width(raw_line)
            _LOG.debug('prev_prefix %r', prev_prefix)
            _LOG.debug('curr_prefix %r', curr_prefix)
            # A "sticky" comment is a comment in the middle of a list of
            # non-comments. The keep-sorted check keeps this comment with the
            # following item in the list. For more details see
            # https://pigweed.dev/pw_presubmit/#sorted-blocks.
            if block.sticky_comments and raw_line.lstrip().startswith(
                block.sticky_comments
            ):
                _LOG.debug('found sticky %r', raw_line)
                comments.append(raw_line)
            elif prev_prefix is not None and curr_prefix > prev_prefix:
                # Indented deeper than the current entry: it belongs to it.
                _LOG.debug('found continuation %r', raw_line)
                lines[-1].continuations = (*lines[-1].continuations, raw_line)
                _LOG.debug('modified line %s', lines[-1])
            else:
                _LOG.debug('non-sticky %r', raw_line)
                line = _Line(raw_line, tuple(comments))
                _LOG.debug('line %s', line)
                lines.append(line)
                comments = []
                prev_prefix = curr_prefix
        if comments:
            # NOTE(review): these trailing comments have no entry to attach
            # to and are not part of the sorted output, so applying the fix
            # removes them from the file. Confirm that this is intended.
            self.ctx.fail(
                f'sticky comment at end of block: {comments[0].strip()}',
                self.path,
                block.start_line_number,
            )

        if not block.allow_dupes:
            # Entries are duplicates only if all of their text is identical.
            lines = list({x.full: x for x in lines}.values())

        StrLinePair = tuple[str, _Line]  # pylint: disable=invalid-name
        sort_key_funcs: list[Callable[[StrLinePair], StrLinePair]] = []

        if block.ignored_prefixes:

            def strip_ignored_prefixes(val):
                """Remove one ignored prefix from val, if present."""
                text = val[0]
                wo_white = text.lstrip()
                # Slice by length rather than with a negative index: for a
                # whitespace-only line len(wo_white) is 0 and text[0:-0]
                # would be empty instead of the leading whitespace.
                white = text[: len(text) - len(wo_white)]
                for ignored in block.ignored_prefixes:
                    if wo_white.startswith(ignored):
                        return (f'{white}{wo_white[len(ignored):]}', val[1])
                return (text, val[1])

            sort_key_funcs.append(strip_ignored_prefixes)

        if block.ignore_case:
            sort_key_funcs.append(lambda val: (val[0].lower(), val[1]))

        def sort_key(line):
            # The key is (transformed text, entry); the entry itself breaks
            # ties between entries whose transformed text is equal.
            vals = (line.value, line)
            for sort_key_func in sort_key_funcs:
                vals = sort_key_func(vals)
            return vals

        for val in lines:
            _LOG.debug('For sorting: %r => %r', val, sort_key(val))

        sorted_lines = sorted(lines, key=sort_key)
        raw_sorted_lines: list[str] = []
        for line in sorted_lines:
            raw_sorted_lines.extend(line.sticky_comments)
            raw_sorted_lines.append(line.value)
            raw_sorted_lines.extend(line.continuations)

        if block.lines != raw_sorted_lines:
            self.changed = True
            diff = difflib.Differ()
            diff_lines = ''.join(diff.compare(block.lines, raw_sorted_lines))

            # The +2 accounts for the opening and closing marker lines.
            hunk = (
                f'@@ {block.start_line_number},{len(block.lines)+2} '
                f'{block.start_line_number},{len(raw_sorted_lines)+2} @@\n'
                f'  {block.start_line}{diff_lines}  {block.end_line}'
            )
            # Add to the diffs already recorded for this file, so that every
            # unsorted block is reported and not only the last one.
            self._errors[self.path] = [
                *self._errors.get(self.path, ()),
                hunk,
            ]

        return raw_sorted_lines

    def _parse_file(self, ins):
        """Reads lines from ``ins`` and fills ``all_lines``.

        Lines outside of blocks are copied unchanged. Lines inside a block
        are collected and replaced by their sorted form once the closing
        marker is found.

        Args:
            ins: Iterable of the lines of the file, each with its line ending.

        Raises:
            KeepSortedParsingError: An opening marker appears inside a block,
                a closing marker appears outside of one, the opening line has
                an unrecognized directive, or a block is still open at the end
                of the input.
        """
        block: _Block | None = None

        for i, line in enumerate(ins, start=1):
            if block:
                if _START.search(line):
                    raise KeepSortedParsingError(
                        f'found {line.strip()!r} inside keep-sorted block',
                        self.path,
                        i,
                    )

                if _END.search(line):
                    _LOG.debug('Found end line %d %r', i, line)
                    block.end_line = line
                    self.all_lines.extend(self._process_block(block))
                    block = None
                    self.all_lines.append(line)

                else:
                    _LOG.debug('Adding to block line %d %r', i, line)
                    block.lines.append(line)

            elif start_match := _START.search(line):
                _LOG.debug('Found start line %d %r', i, line)

                block = _Block()

                block.ignore_case = bool(_IGNORE_CASE.search(line))
                _LOG.debug('ignore_case: %s', block.ignore_case)

                block.allow_dupes = bool(_ALLOW_DUPES.search(line))
                _LOG.debug('allow_dupes: %s', block.allow_dupes)

                match = _IGNORE_PREFIX.search(line)
                if match:
                    block.ignored_prefixes = match.group(1).split(',')

                    # We want to check the longest prefixes first, in case one
                    # prefix is a prefix of another prefix.
                    block.ignored_prefixes.sort(key=lambda x: (-len(x), x))
                _LOG.debug('ignored_prefixes: %r', block.ignored_prefixes)

                match = _STICKY_COMMENTS.search(line)
                if match:
                    if match.group(1) == 'no':
                        block.sticky_comments = ()
                    else:
                        block.sticky_comments = tuple(match.group(1).split(','))
                else:
                    # Without an explicit setting, short text in front of the
                    # marker (at most three characters) is taken to be the
                    # comment token and used as the sticky-comment prefix.
                    prefix = line[: start_match.start()].strip()
                    if prefix and len(prefix) <= 3:
                        block.sticky_comments = (prefix,)
                _LOG.debug('sticky_comments: %s', block.sticky_comments)

                block.start_line = line
                block.start_line_number = i
                self.all_lines.append(line)

                # Whatever is left after removing the known directives is an
                # error.
                remaining = line[start_match.end() :].strip()
                remaining = _IGNORE_CASE.sub('', remaining, count=1).strip()
                remaining = _ALLOW_DUPES.sub('', remaining, count=1).strip()
                remaining = _IGNORE_PREFIX.sub('', remaining, count=1).strip()
                remaining = _STICKY_COMMENTS.sub('', remaining, count=1).strip()
                if remaining:
                    raise KeepSortedParsingError(
                        f'unrecognized directive on keep-sorted line: '
                        f'{remaining}',
                        self.path,
                        i,
                    )

            elif _END.search(line):
                raise KeepSortedParsingError(
                    f'found {line.strip()!r} outside keep-sorted block',
                    self.path,
                    i,
                )

            else:
                self.all_lines.append(line)

        if block:
            raise KeepSortedParsingError(
                f'found EOF while looking for "{END}"', self.path
            )

    def sort(self) -> None:
        """Check for unsorted keep-sorted blocks.

        Files that cannot be decoded as text are skipped.

        Raises:
            KeepSortedParsingError: The markers in the file are malformed.
        """
        _LOG.debug('Evaluating path %s', self.path)
        try:
            with self.path.open() as ins:
                _LOG.debug('Processing %s', self.path)
                self._parse_file(ins)

        except UnicodeDecodeError:
            # File is not text, like a gif.
            _LOG.debug('File %s is not a text file', self.path)
            # The file is read lazily, so decoding can fail after some blocks
            # were already processed. Discard those partial results so that
            # write() cannot replace the file with a truncated copy and no
            # diff is reported for a file that is being skipped.
            self.all_lines = []
            self.changed = False
            self._errors.pop(self.path, None)

    def write(self, path: Path | None = None) -> None:
        """Writes the sorted lines to disk if sort() found a change.

        Args:
            path: Destination file; defaults to the file that was read.
        """
        if not self.changed:
            return
        if not path:
            path = self.path
        with path.open('w') as outs:
            outs.writelines(self.all_lines)
            _LOG.info('Applied keep-sorted changes to %s', path)
336
337
def _print_howto_fix(paths: Sequence[Path]) -> None:
    """Logs, as a warning, the commands that sort the blocks in ``paths``."""

    def display_path(path):
        # Prefer a path relative to the working directory; fall back to the
        # absolute path for files that are not below it.
        absolute = Path(path).resolve()
        try:
            return absolute.relative_to(Path.cwd().resolve())
        except ValueError:
            return absolute

    commands = [f'  pw keep-sorted --fix {display_path(p)}' for p in paths]
    _LOG.warning('To sort these blocks, run:\n\n%s\n', '\n'.join(commands))
349
350
def _process_files(
    ctx: presubmit.PresubmitContext | KeepSortedContext,
) -> dict[Path, Sequence[str]]:
    """Checks every file in ``ctx.paths`` and reports unsorted blocks.

    Symlinks and directories are skipped. If the context has a true ``fix``
    attribute, files with unsorted blocks are rewritten in place. A file whose
    markers cannot be parsed is reported through ``ctx.fail`` and processing
    continues with the next file.

    When unsorted blocks were found, the combined diff is written to
    ``ctx.failure_summary_log`` and printed.

    Args:
        ctx: Context providing ``paths``, ``fail`` and ``failure_summary_log``.

    Returns:
        Mapping from each file with unsorted blocks to its diff text.
    """
    fix = getattr(ctx, 'fix', False)
    errors: dict[Path, Sequence[str]] = {}

    for path in ctx.paths:
        if path.is_symlink() or path.is_dir():
            continue

        try:
            sorter = _FileSorter(ctx, path, errors)

            sorter.sort()
            if fix and sorter.changed:
                sorter.write()

        except KeepSortedParsingError as exc:
            ctx.fail(str(exc))

    if not errors:
        return errors

    ctx.fail(f'Found {plural(errors, "file")} with keep-sorted errors:')

    # The output directory is not guaranteed to exist yet (for example on the
    # first command-line run), so create it before opening the log.
    ctx.failure_summary_log.parent.mkdir(parents=True, exist_ok=True)

    with ctx.failure_summary_log.open('w') as outs:
        for path, diffs in errors.items():
            diff = ''.join(
                [
                    f'--- {path} (original)\n',
                    f'+++ {path} (sorted)\n',
                    *diffs,
                ]
            )

            outs.write(diff)
            print(tools.colorize_diff(diff))

    return errors
391
392
@presubmit.check(name='keep_sorted')
def presubmit_check(ctx: presubmit.PresubmitContext) -> None:
    """Presubmit check that ensures specified lists remain sorted."""
    ctx.paths = presubmit_context.apply_exclusions(ctx)

    unsorted = _process_files(ctx)
    if not unsorted:
        return

    # Tell the user how to repair the files that failed the check.
    _print_howto_fix(list(unsorted))
402
403
def parse_args() -> argparse.Namespace:
    """Creates an argument parser and parses arguments."""
    # Shown in the help text only; the actual default is chosen later.
    default_outdir = '<repo root>' / DEFAULT_PATH

    parser = argparse.ArgumentParser(description=__doc__)
    cli.add_path_arguments(parser)
    parser.add_argument(
        '--fix',
        action='store_true',
        help='Apply fixes in place.',
    )
    parser.add_argument(
        '--output-directory',
        type=Path,
        help=f'Output directory (default: {default_outdir})',
    )

    return parser.parse_args()
420
421
def keep_sorted_in_repo(
    paths: Collection[Path | str],
    fix: bool,
    exclude: Collection[Pattern[str]],
    base: str,
    output_directory: Path | None,
) -> int:
    """Checks or fixes keep-sorted blocks for files in a Git repo.

    Args:
        paths: Files or directories to consider.
        fix: Rewrite unsorted files in place instead of only reporting them.
        exclude: Patterns of paths to leave out of the files listed by Git.
        base: Git revision to diff against when listing files.
        output_directory: Where to write the failure summary; a default
            location is used when this is not given.

    Returns:
        1 if a failure was recorded or the arguments are invalid, otherwise 0.
    """
    files = [Path(entry).resolve() for entry in paths if os.path.isfile(entry)]

    repo = None
    if git_repo.is_repo():
        repo = git_repo.root()

    # Implement a graceful fallback in case the tracking branch isn't available.
    wants_tracking_branch = base == git_repo.TRACKING_BRANCH_ALIAS
    if wants_tracking_branch and not git_repo.tracking_branch(repo):
        _LOG.warning(
            'Failed to determine the tracking branch, using --base HEAD~1 '
            'instead of listing all files'
        )
        base = 'HEAD~1'

    project_root = pw_cli.env.pigweed_environment().PW_PROJECT_ROOT

    # If this is a Git repo, list the original paths with git ls-files or diff.
    if repo:
        description = git_repo.describe_files(
            repo, Path.cwd(), base, paths, exclude, project_root
        )
        _LOG.info('Sorting %s', description)

        # Add files from Git and remove duplicates.
        from_git = tools.exclude_paths(exclude, git_repo.list_files(base, paths))
        files = sorted({*from_git, *files})
    elif base:
        _LOG.critical(
            'A base commit may only be provided if running from a Git repo'
        )
        return 1

    # An explicit directory wins; otherwise use the default location below
    # the repo root or, outside of a repo, below the project root.
    outdir: Path
    if output_directory:
        outdir = output_directory
    else:
        outdir = (repo if repo else project_root) / DEFAULT_PATH

    ctx = KeepSortedContext(
        paths=files,
        fix=fix,
        output_dir=outdir,
        failure_summary_log=outdir / 'failure-summary.log',
    )
    errors = _process_files(ctx)

    if errors and not fix:
        _print_howto_fix(list(errors))

    return int(ctx.failed)
485
486
def main() -> int:
    """Runs the tool with command-line arguments; returns the exit code."""
    args = parse_args()
    return keep_sorted_in_repo(**vars(args))
489
490
# Script entry point: set up logging at INFO level, then exit with the status
# returned by main().
if __name__ == '__main__':
    pw_cli.log.install(logging.INFO)
    sys.exit(main())
494