# Copyright 2022 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Keep specified lists sorted."""

import argparse
import dataclasses
import difflib
import logging
import os
from pathlib import Path
import re
import sys
from typing import (
    Callable,
    Collection,
    Dict,
    List,
    Optional,
    Pattern,
    Sequence,
    Tuple,
    Union,
)

import pw_cli
from . import cli, format_code, git_repo, presubmit, tools

DEFAULT_PATH = Path('out', 'presubmit', 'keep_sorted')

_COLOR = pw_cli.color.colors()
_LOG: logging.Logger = logging.getLogger(__name__)

# Ignore a whole section. Please do not change the order of these lines.
_START = re.compile(r'keep-sorted: (begin|start)', re.IGNORECASE)
_END = re.compile(r'keep-sorted: (stop|end)', re.IGNORECASE)
_IGNORE_CASE = re.compile(r'ignore-case', re.IGNORECASE)
_ALLOW_DUPES = re.compile(r'allow-dupes', re.IGNORECASE)
_IGNORE_PREFIX = re.compile(r'ignore-prefix=(\S+)', re.IGNORECASE)
_STICKY_COMMENTS = re.compile(r'sticky-comments=(\S+)', re.IGNORECASE)

# Only include these literals here so keep_sorted doesn't try to reorder later
# test lines.
(
    START,
    END,
) = """
keep-sorted: start
keep-sorted: end
""".strip().splitlines()


@dataclasses.dataclass
class KeepSortedContext:
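    """Context for a keep-sorted run invoked directly from the command line.

    Tracks the paths to check, whether fixes are applied in place, and
    whether any check has failed.
    """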
    paths: List[Path]
    fix: bool
    output_dir: Path
    failure_summary_log: Path
    failed: bool = False

    def fail(
        self,
        description: str = '',
        path: Optional[Path] = None,
        line: Optional[int] = None,
    ) -> None:
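        """Log a keep-sorted failure and mark this run as failed.

        When fixes are being applied, the message is logged as a warning and
        the run is not marked as failed.
        """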
        if not self.fix:
            self.failed = True

        line_part: str = ''
        if line is not None:
            line_part = f'{line}:'

        log = _LOG.error
        if self.fix:
            log = _LOG.warning

        if path:
            log('%s:%s %s', path, line_part, description)
        else:
            log('%s', description)


class KeepSortedParsingError(presubmit.PresubmitFailure):
    pass


@dataclasses.dataclass
class _Line:
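    """A single sortable entry within a keep-sorted block.

    Includes any sticky comments and indented continuation lines that must
    move together with the entry when the block is sorted.
    """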
    value: str = ''
    sticky_comments: Sequence[str] = ()
    continuations: Sequence[str] = ()

    @property
    def full(self):
        return ''.join((*self.sticky_comments, self.value, *self.continuations))

    def __lt__(self, other):
        if not isinstance(other, _Line):
            return NotImplemented
        left = (self.value, self.continuations, self.sticky_comments)
        right = (other.value, other.continuations, other.sticky_comments)
        return left < right


@dataclasses.dataclass
class _Block:
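    """Directives and contents of a single keep-sorted block."""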
    ignore_case: bool = False
    allow_dupes: bool = False
    ignored_prefixes: Sequence[str] = dataclasses.field(default_factory=list)
    sticky_comments: Tuple[str, ...] = ()
    start_line_number: int = -1
    start_line: str = ''
    end_line: str = ''
    lines: List[str] = dataclasses.field(default_factory=list)


class _FileSorter:
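    """Parses a file and sorts the contents of its keep-sorted blocks."""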
    def __init__(
        self,
        ctx: Union[presubmit.PresubmitContext, KeepSortedContext],
        path: Path,
        errors: Optional[Dict[Path, Sequence[str]]] = None,
    ):
        self.ctx = ctx
        self.path: Path = path
        self.all_lines: List[str] = []
        self.changed: bool = False
        self._errors: Dict[Path, Sequence[str]] = {}
        if errors is not None:
            self._errors = errors

    def _process_block(self, block: _Block) -> Sequence[str]:
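        """Sort one keep-sorted block and return its sorted lines.

        Attaches sticky comments and continuation lines to the entries they
        belong with, drops duplicates unless allow-dupes is set, sorts
        according to the block's directives, and records a diff hunk when the
        block was not already sorted.
        """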
        raw_lines: List[str] = block.lines
        lines: List[_Line] = []

        prefix = lambda x: len(x) - len(x.lstrip())

        prev_prefix: Optional[int] = None
        comments: List[str] = []
        for raw_line in raw_lines:
            curr_prefix: int = prefix(raw_line)
            _LOG.debug('prev_prefix %r', prev_prefix)
            _LOG.debug('curr_prefix %r', curr_prefix)
            # A "sticky" comment is a comment in the middle of a list of
            # non-comments. The keep-sorted check keeps this comment with the
            # following item in the list. For more details see
            # https://pigweed.dev/pw_presubmit/#sorted-blocks.
            if block.sticky_comments and raw_line.lstrip().startswith(
                block.sticky_comments
            ):
                _LOG.debug('found sticky %r', raw_line)
                comments.append(raw_line)
            elif prev_prefix is not None and curr_prefix > prev_prefix:
                _LOG.debug('found continuation %r', raw_line)
                lines[-1].continuations = (*lines[-1].continuations, raw_line)
                _LOG.debug('modified line %s', lines[-1])
            else:
                _LOG.debug('non-sticky %r', raw_line)
                line = _Line(raw_line, tuple(comments))
                _LOG.debug('line %s', line)
                lines.append(line)
                comments = []
                prev_prefix = curr_prefix
        if comments:
            self.ctx.fail(
                f'sticky comment at end of block: {comments[0].strip()}',
                self.path,
                block.start_line_number,
            )

        if not block.allow_dupes:
            lines = list({x.full: x for x in lines}.values())

        StrLinePair = Tuple[str, _Line]
        sort_key_funcs: List[Callable[[StrLinePair], StrLinePair]] = []

        if block.ignored_prefixes:

            def strip_ignored_prefixes(val):
                """Remove one ignored prefix from val, if present."""
                wo_white = val[0].lstrip()
                white = val[0][0 : -len(wo_white)]
                for prefix in block.ignored_prefixes:
                    if wo_white.startswith(prefix):
                        return (f'{white}{wo_white[len(prefix):]}', val[1])
                return (val[0], val[1])

            sort_key_funcs.append(strip_ignored_prefixes)

        if block.ignore_case:
            sort_key_funcs.append(lambda val: (val[0].lower(), val[1]))

        def sort_key(line):
            vals = (line.value, line)
            for sort_key_func in sort_key_funcs:
                vals = sort_key_func(vals)
            return vals

        for val in lines:
            _LOG.debug('For sorting: %r => %r', val, sort_key(val))

        sorted_lines = sorted(lines, key=sort_key)
        raw_sorted_lines: List[str] = []
        for line in sorted_lines:
            raw_sorted_lines.extend(line.sticky_comments)
            raw_sorted_lines.append(line.value)
            raw_sorted_lines.extend(line.continuations)

        if block.lines != raw_sorted_lines:
            self.changed = True
            diff = difflib.Differ()
            diff_lines = ''.join(diff.compare(block.lines, raw_sorted_lines))

            self._errors.setdefault(self.path, [])
            self._errors[self.path].append(
                f'@@ {block.start_line_number},{len(block.lines)+2} '
                f'{block.start_line_number},{len(raw_sorted_lines)+2} @@\n'
                f'  {block.start_line}{diff_lines}  {block.end_line}'
            )

        return raw_sorted_lines

    def _parse_file(self, ins):
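        """Read the file, sorting the contents of each keep-sorted block.

        Lines outside keep-sorted blocks are passed through unchanged. Raises
        KeepSortedParsingError for nested or unterminated blocks and for
        unrecognized directives on the start line.
        """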
        block: Optional[_Block] = None

        for i, line in enumerate(ins, start=1):
            if block:
                if _START.search(line):
                    raise KeepSortedParsingError(
                        f'found {line.strip()!r} inside keep-sorted block',
                        self.path,
                        i,
                    )

                if _END.search(line):
                    _LOG.debug('Found end line %d %r', i, line)
                    block.end_line = line
                    self.all_lines.extend(self._process_block(block))
                    block = None
                    self.all_lines.append(line)

                else:
                    _LOG.debug('Adding to block line %d %r', i, line)
                    block.lines.append(line)

            elif start_match := _START.search(line):
                _LOG.debug('Found start line %d %r', i, line)

                block = _Block()

                block.ignore_case = bool(_IGNORE_CASE.search(line))
                _LOG.debug('ignore_case: %s', block.ignore_case)

                block.allow_dupes = bool(_ALLOW_DUPES.search(line))
                _LOG.debug('allow_dupes: %s', block.allow_dupes)

                match = _IGNORE_PREFIX.search(line)
                if match:
                    block.ignored_prefixes = match.group(1).split(',')

                    # We want to check the longest prefixes first, in case one
                    # prefix is a prefix of another prefix.
                    block.ignored_prefixes.sort(key=lambda x: (-len(x), x))
                _LOG.debug('ignored_prefixes: %r', block.ignored_prefixes)

                match = _STICKY_COMMENTS.search(line)
                if match:
                    if match.group(1) == 'no':
                        block.sticky_comments = ()
                    else:
                        block.sticky_comments = tuple(match.group(1).split(','))
                else:
                    prefix = line[: start_match.start()].strip()
                    if prefix and len(prefix) <= 3:
                        block.sticky_comments = (prefix,)
                _LOG.debug('sticky_comments: %s', block.sticky_comments)

                block.start_line = line
                block.start_line_number = i
                self.all_lines.append(line)

                remaining = line[start_match.end() :].strip()
                remaining = _IGNORE_CASE.sub('', remaining, count=1).strip()
                remaining = _ALLOW_DUPES.sub('', remaining, count=1).strip()
                remaining = _IGNORE_PREFIX.sub('', remaining, count=1).strip()
                remaining = _STICKY_COMMENTS.sub('', remaining, count=1).strip()
                if remaining.strip():
                    raise KeepSortedParsingError(
                        f'unrecognized directive on keep-sorted line: '
                        f'{remaining}',
                        self.path,
                        i,
                    )

            elif _END.search(line):
                raise KeepSortedParsingError(
                    f'found {line.strip()!r} outside keep-sorted block',
                    self.path,
                    i,
                )

            else:
                self.all_lines.append(line)

        if block:
            raise KeepSortedParsingError(
                f'found EOF while looking for "{END}"', self.path
            )

    def sort(self) -> None:
        """Check for unsorted keep-sorted blocks."""
        _LOG.debug('Evaluating path %s', self.path)
        try:
            with self.path.open() as ins:
                _LOG.debug('Processing %s', self.path)
                self._parse_file(ins)

        except UnicodeDecodeError:
            # File is not text, like a gif.
            _LOG.debug('File %s is not a text file', self.path)

    def write(self, path: Optional[Path] = None) -> None:
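        """Write the sorted file contents, if any changes were made.

        Writes to the original path unless a different path is given.
        """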
        if not self.changed:
            return
        if not path:
            path = self.path
        with path.open('w') as outs:
            outs.writelines(self.all_lines)
            _LOG.info('Applied keep-sorted changes to %s', path)


def _print_howto_fix(paths: Sequence[Path]) -> None:
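    """Log the `pw keep-sorted --fix` commands that sort the given paths."""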
    def path_relative_to_cwd(path):
        try:
            return Path(path).resolve().relative_to(Path.cwd().resolve())
        except ValueError:
            return Path(path).resolve()

    message = (
        f'  pw keep-sorted --fix {path_relative_to_cwd(path)}' for path in paths
    )
    _LOG.warning('To sort these blocks, run:\n\n%s\n', '\n'.join(message))


def _process_files(
    ctx: Union[presubmit.PresubmitContext, KeepSortedContext]
) -> Dict[Path, Sequence[str]]:
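    """Check the keep-sorted blocks in every file in ctx.paths.

    Files are rewritten in place when fixing is enabled. Returns a mapping
    from each unsorted file to its diff hunks, and writes a combined diff to
    the failure summary log when errors are found.
    """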
    fix = getattr(ctx, 'fix', False)
    errors: Dict[Path, Sequence[str]] = {}

    for path in ctx.paths:
        if path.is_symlink() or path.is_dir():
            continue

        try:
            sorter = _FileSorter(ctx, path, errors)

            sorter.sort()
            if sorter.changed:
                if fix:
                    sorter.write()

        except KeepSortedParsingError as exc:
            ctx.fail(str(exc))

    if not errors:
        return errors

    ctx.fail(f'Found {len(errors)} files with keep-sorted errors:')

    with ctx.failure_summary_log.open('w') as outs:
        for path, diffs in errors.items():
            diff = ''.join(
                [
                    f'--- {path} (original)\n',
                    f'+++ {path} (sorted)\n',
                    *diffs,
                ]
            )

            outs.write(diff)
            print(format_code.colorize_diff(diff))

    return errors


@presubmit.check(name='keep_sorted')
def presubmit_check(ctx: presubmit.PresubmitContext) -> None:
    """Presubmit check that ensures specified lists remain sorted."""

    errors = _process_files(ctx)

    if errors:
        _print_howto_fix(list(errors.keys()))


def parse_args() -> argparse.Namespace:
    """Creates an argument parser and parses arguments."""

    parser = argparse.ArgumentParser(description=__doc__)
    cli.add_path_arguments(parser)
    parser.add_argument(
        '--fix', action='store_true', help='Apply fixes in place.'
    )

    parser.add_argument(
        '--output-directory',
        type=Path,
        help=f'Output directory (default: {"<repo root>" / DEFAULT_PATH})',
    )

    return parser.parse_args()


def keep_sorted_in_repo(
    paths: Collection[Union[Path, str]],
    fix: bool,
    exclude: Collection[Pattern[str]],
    base: str,
    output_directory: Optional[Path],
) -> int:
    """Checks or fixes keep-sorted blocks for files in a Git repo."""

    files = [Path(path).resolve() for path in paths if os.path.isfile(path)]
    repo = git_repo.root() if git_repo.is_repo() else None

    # Implement a graceful fallback in case the tracking branch isn't available.
    if base == git_repo.TRACKING_BRANCH_ALIAS and not git_repo.tracking_branch(
        repo
    ):
        _LOG.warning(
            'Failed to determine the tracking branch, using --base HEAD~1 '
            'instead of listing all files'
        )
        base = 'HEAD~1'

    # If this is a Git repo, list the original paths with git ls-files or diff.
    project_root = pw_cli.env.pigweed_environment().PW_PROJECT_ROOT
    if repo:
        _LOG.info(
            'Sorting %s',
            git_repo.describe_files(
                repo, Path.cwd(), base, paths, exclude, project_root
            ),
        )

        # Add files from Git and remove duplicates.
        files = sorted(
            set(tools.exclude_paths(exclude, git_repo.list_files(base, paths)))
            | set(files)
        )
    elif base:
        _LOG.critical(
            'A base commit may only be provided if running from a Git repo'
        )
        return 1

    outdir: Path
    if output_directory:
        outdir = output_directory
    elif repo:
        outdir = repo / DEFAULT_PATH
    else:
        outdir = project_root / DEFAULT_PATH

    ctx = KeepSortedContext(
        paths=files,
        fix=fix,
        output_dir=outdir,
        failure_summary_log=outdir / 'failure-summary.log',
    )
    errors = _process_files(ctx)

    if not fix and errors:
        _print_howto_fix(list(errors.keys()))

    return int(ctx.failed)


def main() -> int:
    return keep_sorted_in_repo(**vars(parse_args()))


if __name__ == '__main__':
    pw_cli.log.install(logging.INFO)
    sys.exit(main())