• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2024 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Helpful commands for working with a Git repository."""
15
16import logging
17from pathlib import Path
18import re
19import subprocess
20from typing import Collection, Iterable, Pattern
21
22from pw_cli.plural import plural
23from pw_cli.tool_runner import ToolRunner
24
# Logger shared by everything in this module.
_LOG = logging.getLogger(__name__)

# Git revision syntax that refers to the upstream of the current branch.
TRACKING_BRANCH_ALIAS = '@{upstream}'
# Every spelling of the upstream alias that this module recognizes.
_TRACKING_BRANCH_ALIASES = (TRACKING_BRANCH_ALIAS, '@{u}')
# Revision substituted for the upstream alias when listing changed files in a
# checkout whose current branch has no tracking branch.
_NON_TRACKING_FALLBACK = 'HEAD~10'
30
31
class GitError(Exception):
    """Error reported by a Git command.

    Attributes:
        returncode: Exit status supplied by whoever raised the error.
    """

    def __init__(self, message, returncode):
        self.returncode = returncode
        super().__init__(message)
38
39
class _GitTool:
    """Callable that runs ``git`` subcommands from a fixed directory."""

    def __init__(self, tool_runner: ToolRunner, working_dir: Path):
        """Stores the runner and the directory passed to ``git -C``.

        Arguments:
            tool_runner: Callable used to launch the ``git`` executable.
            working_dir: Directory every command is executed from.
        """
        self._run_tool = tool_runner
        self._working_dir = working_dir

    def __call__(self, *args, **kwargs) -> str:
        """Runs ``git -C <working_dir> <args...>`` and returns its stdout.

        Arguments:
            *args: Arguments placed after the ``-C <working_dir>`` option.
            **kwargs: Extra keyword arguments forwarded to the tool runner.

        Returns:
            The decoded standard output with surrounding whitespace stripped,
            or an empty string if the command printed nothing.

        Raises:
            GitError: The command exited with a non-zero status. The message
                is the stripped stderr text, or ``(no output)`` if stderr was
                empty.
        """
        proc = self._run_tool(
            tool='git',
            args=('-C', str(self._working_dir), *args),
            **kwargs,
        )

        if proc.returncode != 0:
            if proc.stderr:
                # stderr only feeds the error message, so undecodable bytes
                # are replaced instead of raising UnicodeDecodeError, which
                # would hide the GitError that callers catch.
                err = proc.stderr.decode(errors='replace').strip()
            else:
                err = '(no output)'
            raise GitError(err, proc.returncode)

        # stdout carries data (paths, hashes), so it is still decoded strictly.
        return proc.stdout.decode().strip() if proc.stdout else ''
64
65
class GitRepo:
    """Represents a checked out Git repository that may be queried for info."""

    def __init__(self, root: Path, tool_runner: ToolRunner):
        """Binds this object to the checkout rooted at ``root``.

        Arguments:
            root: Top-level directory of the repository; stored resolved.
            tool_runner: Runner used to launch every ``git`` invocation.
        """
        self._root = root.resolve()
        self._git = _GitTool(tool_runner, self._root)

    def tracking_branch(
        self,
        fallback: str | None = None,
    ) -> str | None:
        """Returns the tracking branch of the current branch.

        Since most callers of this function can safely handle a return value of
        None, Git errors are suppressed and ``fallback`` is returned if the
        tracking branch cannot be resolved.

        Arguments:
            fallback: Value returned when Git cannot resolve the upstream.

        Returns:
            The remote tracking branch name, or ``fallback`` (``None`` by
            default) if there is none.
        """
        # This command should only error out if there's no upstream branch set.
        try:
            return self._git(
                'rev-parse',
                '--abbrev-ref',
                '--symbolic-full-name',
                TRACKING_BRANCH_ALIAS,
            )
        except GitError:
            return fallback

    def current_branch(self) -> str | None:
        """Returns the current branch, or None if it cannot be determined."""
        try:
            return self._git('rev-parse', '--abbrev-ref', 'HEAD')
        except GitError:
            return None

    def _existing_files(self, relative_paths: Iterable[str]) -> Iterable[Path]:
        """Yields absolute paths for the entries that are regular files."""
        for relative_path in relative_paths:
            full_path = self._root / relative_path
            # Modified submodules will show up as directories and should be
            # ignored.
            if full_path.is_file():
                yield full_path

    def _ls_files(self, pathspecs: Collection[Path | str]) -> Iterable[Path]:
        """Returns results of git ls-files as absolute paths."""
        listing = self._git('ls-files', '--', *pathspecs)
        yield from self._existing_files(listing.splitlines())

    def _diff_names(
        self, commit: str, pathspecs: Collection[Path | str]
    ) -> Iterable[Path]:
        """Returns paths of files changed since the specified commit.

        All returned paths are absolute file paths.
        """
        listing = self._git(
            'diff',
            '--name-only',
            '--diff-filter=d',
            commit,
            '--',
            *pathspecs,
        )
        yield from self._existing_files(listing.splitlines())

    def list_files(
        self,
        commit: str | None = None,
        pathspecs: Collection[Path | str] = (),
    ) -> list[Path]:
        """Lists files modified since the specified commit.

        If ``commit`` is one of the tracking-branch aliases it is replaced by
        the current tracking branch, or by ``_NON_TRACKING_FALLBACK`` when no
        tracking branch is set. If ``commit`` is omitted, or the comparison
        with it fails, all files in the repository are listed.

        Arguments:
            commit: The Git revision to start from when listing modified files
            pathspecs: Git pathspecs to use when filtering results

        Returns:
            A sorted list of absolute paths.
        """
        if commit in _TRACKING_BRANCH_ALIASES:
            commit = self.tracking_branch(fallback=_NON_TRACKING_FALLBACK)

        if commit:
            try:
                return sorted(self._diff_names(commit, pathspecs))
            except GitError:
                _LOG.warning(
                    'Error comparing with base revision %s of %s, listing all '
                    'files instead of just changed files',
                    commit,
                    self._root,
                )

        return sorted(self._ls_files(pathspecs))

    def has_uncommitted_changes(self) -> bool:
        """Returns True if this Git repo has uncommitted changes in it.

        Note: This does not check for untracked files.

        Returns:
            True if the Git repo has uncommitted changes in it.

        Raises:
            GitError: Refreshing the index kept failing, or ``git diff-index``
                failed with an exit status other than 1.
        """
        # Refresh the Git index so that the diff-index command will be accurate.
        # The `git update-index` command isn't reliable when run in parallel
        # with other processes that may touch files in the repo directory, so
        # retry a few times before giving up. The hallmark of this failure mode
        # is the lack of an error message on stderr, so if we see something
        # there we can assume it's some other issue and raise.
        retries = 6
        for attempt in range(retries):
            last_attempt = attempt == retries - 1
            try:
                self._git(
                    'update-index',
                    '-q',
                    '--refresh',
                    pw_presubmit_ignore_dry_run=True,
                )
            except GitError as err:
                # The Git wrapper reports a failure without stderr output as
                # '(no output)'; anything else is a real error message.
                silent = str(err) in ('', '(no output)')
                if not silent or last_attempt:
                    raise
            except subprocess.CalledProcessError as err:
                # Kept in case the tool runner itself raises on failure.
                if err.stderr or last_attempt:
                    raise
            else:
                # The refresh succeeded; there is nothing left to retry.
                break

        try:
            self._git(
                'diff-index',
                '--quiet',
                'HEAD',
                '--',
                pw_presubmit_ignore_dry_run=True,
            )
        except GitError as err:
            # diff-index exits with 1 if there are uncommitted changes.
            if err.returncode == 1:
                return True

            # Unexpected error.
            raise

        return False

    def root(self) -> Path:
        """The root file path of this Git repository.

        Returns:
            The repository root as an absolute path.
        """
        return self._root

    def list_submodules(
        self, excluded_paths: Collection[Pattern | str] = ()
    ) -> list[Path]:
        """Query Git and return a list of submodules in the current project.

        Arguments:
            excluded_paths: Pattern or string that match submodules that should
                not be returned. All matches are done on posix-style paths
                relative to the project root. Patterns must match the whole
                relative path; strings must be equal to it.

        Returns:
            List of "Path"s which were found but not excluded. All paths are
            absolute.
        """
        discovery_report = self._git(
            'submodule',
            'foreach',
            '--quiet',
            '--recursive',
            'echo $toplevel/$sm_path',
        )
        # One submodule is reported per line. Split on line boundaries only so
        # that paths containing spaces stay intact.
        module_dirs = [
            Path(line) for line in discovery_report.splitlines() if line
        ]

        if not excluded_paths:
            return module_dirs

        def is_excluded(relative_path: str) -> bool:
            for exclude in excluded_paths:
                if isinstance(exclude, Pattern):
                    if exclude.fullmatch(relative_path):
                        return True
                elif exclude == relative_path:
                    return True
            return False

        return [
            module_dir
            for module_dir in module_dirs
            if not is_excluded(module_dir.relative_to(self._root).as_posix())
        ]

    def commit_message(self, commit: str = 'HEAD') -> str:
        """Returns the commit message of the specified commit.

        Defaults to ``HEAD`` if no commit specified.

        Returns:
            Commit message contents as a string.
        """
        return self._git('log', '--format=%B', '-n1', commit)

    def commit_author(self, commit: str = 'HEAD') -> str:
        """Returns the author email of the specified commit.

        Defaults to ``HEAD`` if no commit specified.

        Returns:
            Commit author email (Git's ``%ae`` format) as a string.
        """
        return self._git('log', '--format=%ae', '-n1', commit)

    def commit_hash(
        self,
        commit: str = 'HEAD',
        short: bool = True,
    ) -> str:
        """Returns the hash associated with the specified commit.

        Defaults to ``HEAD`` if no commit specified.

        Arguments:
            commit: Revision to resolve.
            short: Whether to pass ``--short`` to ``git rev-parse``.

        Returns:
            Commit hash as a string.
        """
        args = ['rev-parse']
        if short:
            args.append('--short')
        args.append(commit)
        return self._git(*args)

    def commit_change_id(self, commit: str = 'HEAD') -> str | None:
        """Returns the Gerrit Change-Id of the specified commit.

        Defaults to ``HEAD`` if no commit specified.

        Returns:
            Change-Id as a string, or ``None`` if it does not exist.
        """
        message = self.commit_message(commit)
        regex = re.compile(
            'Change-Id: (I[a-fA-F0-9]+)',
            re.MULTILINE,
        )
        match = regex.search(message)
        return match.group(1) if match else None
316
317
def find_git_repo(path_in_repo: Path, tool_runner: ToolRunner) -> GitRepo:
    """Locates the Git repository that contains ``path_in_repo``.

    Git is invoked from ``path_in_repo`` itself when it is a directory, and
    from its parent directory otherwise.

    Raises:
        GitError: The specified path does not live in a Git repository.

    Returns:
        A GitRepo representing the enclosing repository that tracks the
        specified file or folder.
    """
    if path_in_repo.is_dir():
        search_dir = path_in_repo
    else:
        search_dir = path_in_repo.parent

    run_git = _GitTool(tool_runner, search_dir)
    toplevel = run_git('rev-parse', '--show-toplevel')
    return GitRepo(Path(toplevel), tool_runner)
340
341
def is_in_git_repo(p: Path, tool_runner: ToolRunner) -> bool:
    """Reports whether ``p`` is located inside a Git repository.

    Returns:
        True if a repository enclosing the specified file or folder was
        found, False if the lookup failed with a GitError.
    """
    try:
        find_git_repo(p, tool_runner)
    except GitError:
        return False
    else:
        return True
354
355
def _describe_constraints(
    repo: GitRepo,
    working_dir: Path,
    commit: str | None,
    pathspecs: Collection[Path | str],
    exclude: Collection[Pattern[str]],
) -> Iterable[str]:
    """Yields one phrase for each restriction applied to a file listing.

    Phrases are produced in a fixed order: subdirectory, base commit,
    pathspecs, then exclusion patterns. A restriction that does not apply
    yields nothing.
    """
    repo_root = repo.root()
    if not repo_root.samefile(working_dir):
        subdir = working_dir.resolve().relative_to(repo_root.resolve())
        yield f'under the {subdir} subdirectory'

    if commit in _TRACKING_BRANCH_ALIASES:
        # Show the concrete upstream name rather than the alias.
        commit = repo.tracking_branch()
        if commit is None:
            _LOG.warning(
                'Attempted to list files changed since the remote tracking '
                'branch, but the repo is not tracking a branch'
            )

    if commit:
        yield f'that have changed since {commit}'

    if pathspecs:
        joined_specs = ', '.join(map(str, pathspecs))
        yield f'that match {plural(pathspecs, "pathspec")} ({joined_specs})'

    if exclude:
        pattern_count = plural(exclude, "pattern")
        joined_patterns = ', '.join(regex.pattern for regex in exclude)
        yield f'that do not match {pattern_count} ({joined_patterns})'
391
392
def describe_git_pattern(
    working_dir: Path,
    commit: str | None,
    pathspecs: Collection[Path | str],
    exclude: Collection[Pattern],
    tool_runner: ToolRunner,
    project_root: Path | None = None,
) -> str:
    """Builds a human-readable summary of a selection of files in a Git repo.

    The repository is named by its root directory, or by its path relative
    to ``project_root`` when that is given and differs from the repo root.
    With no constraints the result is ``all files in the <name> repo``; a
    single constraint is appended on the same line; several constraints are
    listed one per line, for example:

        files in the pigweed repo
            - that have changed since origin/main..HEAD
            - that do not match 7 patterns (...)

    The unit tests for this function are the source of truth for the expected
    output.

    Returns:
        A possibly multi-line string with descriptive information about the
        provided Git pathspecs.
    """
    repo = find_git_repo(working_dir, tool_runner)
    constraints = [
        *_describe_constraints(repo, working_dir, commit, pathspecs, exclude)
    ]

    repo_root = repo.root()
    if project_root and project_root != repo_root:
        name = str(repo_root.relative_to(project_root))
    else:
        name = repo_root.name

    if not constraints:
        return f'all files in the {name} repo'

    msg = f'files in the {name} repo'
    if len(constraints) == 1:
        (only_constraint,) = constraints
        return f'{msg} {only_constraint}'

    bullets = [f'\n    - {line}' for line in constraints]
    return msg + ''.join(bullets)
433