• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""General purpose tools for running presubmit checks."""
15
16import collections.abc
17from collections import Counter, defaultdict
18import logging
19import os
20from pathlib import Path
21import shlex
22import subprocess
23from typing import (
24    Any,
25    Iterable,
26    Iterator,
27    Sequence,
28    Pattern,
29)
30
31import pw_cli.color
32from pw_cli.plural import plural
33from pw_cli.tool_runner import ToolRunner
34from pw_presubmit.presubmit_context import PRESUBMIT_CONTEXT
35
# Logger for messages emitted by this module.
_LOG: logging.Logger = logging.getLogger(name=__name__)
# Color helpers from pw_cli; used below to highlight diff output.
_COLOR = pw_cli.color.colors()
38
39
def colorize_diff_line(line: str) -> str:
    """Wraps a single diff line in the color matching its prefix.

    File headers (``--- `` / ``+++ ``) are bold white, hunk headers
    (``@@ ``) cyan, removed lines red, and added lines green. Any other line
    is returned unchanged.
    """
    # Headers must be checked before the bare '-' / '+' prefixes, since
    # '--- ' and '+++ ' also start with those characters.
    if line.startswith(('--- ', '+++ ')):
        return _COLOR.bold_white(line)
    if line.startswith('@@ '):
        return _COLOR.cyan(line)
    if line.startswith('-'):
        return _COLOR.red(line)
    if line.startswith('+'):
        return _COLOR.green(line)
    return line
50
51
def colorize_diff(lines: Iterable[str]) -> str:
    """Returns a colorized copy of a diff.

    ``lines`` may be the whole diff as one str, which is split into lines
    with their line endings kept, or an iterable of individual lines. Each
    line goes through colorize_diff_line and the results are concatenated.
    """
    diff_lines = lines.splitlines(True) if isinstance(lines, str) else lines
    return ''.join(map(colorize_diff_line, diff_lines))
58
59
def make_box(section_alignments: Sequence[str]) -> str:
    """Builds a ``str.format`` template for a three-row box.

    The box has one column per entry of ``section_alignments``. The
    template expects these arguments:

    * Positional ``{0}``..``{10}``, the border pieces.
      * Top row: ``{0}`` + cells filled with ``{1}`` separated by ``{2}`` +
        ``{3}``.
      * Middle row: ``{4}`` + section cells separated by ``{5}`` + ``{6}``.
      * Bottom row: ``{7}`` + cells filled with ``{8}`` separated by
        ``{9}`` + ``{10}``.
    * Keyword ``sectionN`` and ``widthN`` (1-based), the text and width of
      column N.

    In the top and bottom rows the fill argument is used as its own fill
    character, so it must be a single character. Each entry of
    ``section_alignments`` is used as the format alignment of the
    corresponding middle-row cell (e.g. ``<``, ``^``, ``>``). The rows are
    joined by newlines, with no trailing newline.
    """
    columns = list(enumerate(section_alignments, start=1))

    top = '{2}'.join('{1:{1}^{width%d}}' % n for n, _ in columns)
    middle = '{5}'.join(
        '{section%d:%s{width%d}}' % (n, align, n) for n, align in columns
    )
    bottom = '{9}'.join('{8:{8}^{width%d}}' % n for n, _ in columns)

    rows = (
        '{0}' + top + '{3}',
        '{4}' + middle + '{6}',
        '{7}' + bottom + '{10}',
    )
    return '\n'.join(rows)
82
83
def file_summary(
    paths: Iterable[Path],
    levels: int = 2,
    max_lines: int = 12,
    max_types: int = 3,
    pad: str = ' ',
    pad_start: str = ' ',
    pad_end: str = ' ',
) -> list[str]:
    """Summarizes a list of files by the file types in each directory.

    Each path is grouped under the ancestor
    ``path.parents[len(path.parents) - levels]``. When the path is too
    shallow for that index, it is grouped under its immediate parent. With
    the default ``levels=2``, files are therefore grouped by their top-level
    directory.

    Args:
        paths: Files to summarize.
        levels: Which ancestor to group by, counted from the outermost entry
            of ``path.parents``.
        max_lines: If there are more directories than this, the directories
            with the most files are listed individually. All the others are
            merged into one trailing line, so at most ``max_lines`` lines are
            returned. A single merged line is returned if ``max_lines`` is
            below 1.
        max_types: Number of most common file extensions listed per line.
            The remaining files, including those without an extension, are
            counted as ``other``. A directory containing only extension-less
            files gets no breakdown.
        pad: Fill character used to align the file counts into a column.
        pad_start: Text placed right after each directory's path separator.
        pad_end: Text placed between the aligned directory column and the
            file count.

    Returns:
        One formatted line per listed directory, plus the merged line if
        directories were condensed.
    """

    # Count the file types in each directory.
    all_counts: dict[Any, Counter] = defaultdict(Counter)

    for path in paths:
        parent = path.parents[max(len(path.parents) - levels, 0)]
        all_counts[parent][path.suffix] += 1

    # If there are too many lines, condense directories with the fewest files.
    if len(all_counts) > max_lines:
        by_size = sorted(
            all_counts.items(), key=lambda item: -sum(item[1].values())
        )
        # One line is reserved for the merged entry. Clamp at zero: for
        # max_lines < 1 a negative slice bound would keep all but the last
        # directories instead of none.
        keep = max(max_lines - 1, 0)
        counts = sorted(by_size[:keep])
        others = by_size[keep:]
        counts.append(
            (
                f'({plural(others, "other")})',
                sum((c for _, c in others), Counter()),
            )
        )
    else:
        counts = sorted(all_counts.items())

    width = max(len(str(d)) + len(os.sep) for d, _ in counts) if counts else 0
    width += len(pad_start)

    # Prepare the output.
    output = []
    for path, files in counts:
        total = sum(files.values())
        del files['']  # Never display no-extension files individually.

        if files:
            extensions = files.most_common(max_types)
            # Whatever is not among the top extensions (including the
            # extension-less files removed above) is reported as 'other'.
            other_extensions = total - sum(count for _, count in extensions)
            if other_extensions:
                extensions.append(('other', other_extensions))

            types = ' (' + ', '.join(f'{c} {e}' for e, c in extensions) + ')'
        else:
            types = ''

        root = f'{path}{os.sep}{pad_start}'.ljust(width, pad)
        output.append(f'{root}{pad_end}{plural(total, "file")}{types}')

    return output
143
144
def relative_paths(paths: Iterable[Path], start: Path) -> Iterable[Path]:
    """Lazily yields each of ``paths`` expressed relative to ``start``.

    Each result is computed with os.path.relpath and wrapped in a Path.
    """
    yield from (Path(os.path.relpath(p, start)) for p in paths)
149
150
151def exclude_paths(
152    exclusions: Iterable[Pattern[str]],
153    paths: Iterable[Path],
154    relative_to: Path | None = None,
155) -> Iterable[Path]:
156    """Excludes paths based on a series of regular expressions."""
157    if relative_to:
158        relpath = lambda path: Path(os.path.relpath(path, relative_to))
159    else:
160        relpath = lambda path: path
161
162    for path in paths:
163        if not any(e.search(relpath(path).as_posix()) for e in exclusions):
164            yield path
165
166
167def _truncate(value, length: int = 60) -> str:
168    value = str(value)
169    return (value[: length - 5] + '[...]') if len(value) > length else value
170
171
def format_command(args: Sequence, kwargs: dict) -> tuple[str, str]:
    """Renders a command and its options for logging.

    Returns:
        A pair ``(attributes, command)``:

        * ``attributes``: the keyword arguments as comma-separated
          ``key=value`` pairs, sorted by key, with each value shortened by
          _truncate.
        * ``command``: the arguments, each shell-quoted with shlex.quote and
          joined with spaces.
    """
    pairs = [
        f'{key}={_truncate(value)}' for key, value in sorted(kwargs.items())
    ]
    quoted = [shlex.quote(str(part)) for part in args]
    return ', '.join(pairs), ' '.join(quoted)
175
176
def log_run(
    args, ignore_dry_run: bool = False, **kwargs
) -> subprocess.CompletedProcess:
    """Logs a command then runs it with subprocess.run.

    Takes the same arguments as subprocess.run. When a presubmit context is
    active and ``ignore_dry_run`` is False, the command is recorded on that
    context. If the context is also in dry-run mode, the command is not
    executed.

    Args:
        args: The command, in any form accepted by subprocess.run.
        ignore_dry_run: If True, the command is neither recorded on the
            presubmit context nor skipped in dry-run mode.
        **kwargs: Forwarded unchanged to subprocess.run.

    Returns:
        The result of subprocess.run. In dry-run mode, returns a placeholder
        CompletedProcess instead: same ``args``, return code 0, and empty
        stdout/stderr. The empty output is ``str`` if text mode was
        requested and ``bytes`` otherwise.
    """
    ctx = PRESUBMIT_CONTEXT.get()
    if ctx and not ignore_dry_run:
        ctx.append_check_command(*args, **kwargs)
        if ctx.dry_run:
            # subprocess.run produces str output when any of these options
            # are set, so the placeholder mirrors that. Callers then see the
            # same output type in dry runs as in real runs.
            text_mode = any(
                kwargs.get(option)
                for option in ('text', 'universal_newlines', 'encoding', 'errors')
            )
            empty = '' if text_mode else b''
            return subprocess.CompletedProcess(
                args, 0, stdout=empty, stderr=empty
            )
    _LOG.debug('[COMMAND] %s\n%s', *format_command(args, kwargs))
    return subprocess.run(args, **kwargs)
199
200
class PresubmitToolRunner(ToolRunner):
    """A ToolRunner that executes each tool through `log_run()`.

    Commands therefore follow log_run's dry-run handling. Callers can opt
    out of it per call by passing ``pw_presubmit_ignore_dry_run=True``.
    """

    @staticmethod
    def _custom_args() -> Iterable[str]:
        # Names of the extra keyword arguments _run_tool accepts on top of
        # subprocess.run's. Presumably ToolRunner uses this list to route
        # them; verify against the ToolRunner base class.
        return ['pw_presubmit_ignore_dry_run']

    def _run_tool(
        self, tool: str, args, pw_presubmit_ignore_dry_run=False, **kwargs
    ) -> subprocess.CompletedProcess:
        """Runs ``tool`` with ``args`` as a subprocess via log_run."""
        command = [tool, *args]
        return log_run(
            command,
            ignore_dry_run=pw_presubmit_ignore_dry_run,
            **kwargs,
        )
217
218
def flatten(*items) -> Iterator:
    """Yields the leaves of arbitrarily nested iterables, in order.

    Any iterable argument (or nested element) is expanded recursively.
    ``str``, ``bytes`` and ``bytearray`` values are treated as leaves and
    yielded whole rather than split into characters or bytes.
    """

    for element in items:
        is_leaf = isinstance(element, (str, bytes, bytearray)) or not isinstance(
            element, collections.abc.Iterable
        )
        if is_leaf:
            yield element
            continue
        yield from flatten(*element)
233