• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2"""Check whether a PR has required labels."""
3
4import sys
5from typing import Any
6
7from github_utils import gh_delete_comment, gh_post_pr_comment
8from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo
9from label_utils import has_required_labels, is_label_err_comment, LABEL_ERR_MSG
10from trymerge import GitHubPR
11
12
def delete_all_label_err_comments(pr: "GitHubPR") -> None:
    """Delete every comment on ``pr`` that is a label-error comment.

    Comments are scanned lazily, so each matching comment is deleted as soon
    as it is identified.
    """
    for err_comment in filter(is_label_err_comment, pr.get_comments()):
        gh_delete_comment(pr.org, pr.project, err_comment.database_id)
17
18
def add_label_err_comment(pr: "GitHubPR") -> None:
    """Post ``LABEL_ERR_MSG`` on ``pr`` unless a label-error comment exists."""
    for existing in pr.get_comments():
        if is_label_err_comment(existing):
            # A label-error comment is already present; do not post another.
            return
    gh_post_pr_comment(pr.org, pr.project, pr.pr_num, LABEL_ERR_MSG)
23
24
def parse_args() -> Any:
    """Parse the command line and return the resulting namespace.

    The namespace carries ``pr_num`` (an int) and ``exit_non_zero`` (a bool
    that is False unless ``--exit-non-zero`` is passed).
    """
    import argparse

    # The string is the parser's program name (ArgumentParser's first
    # positional parameter is ``prog``).
    cli = argparse.ArgumentParser(prog="Check PR labels")
    cli.add_argument("pr_num", type=int)
    # Opt-in flag: fail the run when the PR lacks the required labels.
    cli.add_argument(
        "--exit-non-zero",
        action="store_true",
        help="Return a non-zero exit code if the PR does not have the required labels",
    )
    namespace = cli.parse_args()
    return namespace
38
39
def main() -> None:
    """Check the PR named on the command line for required labels.

    If the PR lacks the required labels, ``LABEL_ERR_MSG`` is printed and
    posted as a PR comment (unless such a comment already exists). If the
    labels are present, any existing label-error comments are deleted.

    Failures are fatal only when ``--exit-non-zero`` is given. Otherwise the
    check is best-effort: the error is reported on stderr and the process
    still exits with status 0.

    Raises:
        RuntimeError: With ``--exit-non-zero``, when the PR lacks the
            required labels or when any step of the check fails.
        SystemExit: With code 0 on every other path.
    """
    args = parse_args()
    repo = GitRepo(get_git_repo_dir(), get_git_remote_name())
    org, project = repo.gh_owner_and_name()
    pr = GitHubPR(org, project, args.pr_num)

    try:
        if not has_required_labels(pr):
            print(LABEL_ERR_MSG, flush=True)
            add_label_err_comment(pr)
            if args.exit_non_zero:
                # NOTE: this is caught by the handler below and re-raised
                # wrapped in the "Error checking labels" RuntimeError.
                raise RuntimeError("PR does not have required labels")
        else:
            delete_all_label_err_comments(pr)
    except Exception as e:
        if args.exit_non_zero:
            raise RuntimeError(f"Error checking labels: {e}") from e
        # Best-effort mode: do not fail the run, but do not hide the failure
        # either -- surface it on stderr so it shows up in the job log.
        print(f"Error checking labels: {e}", file=sys.stderr, flush=True)

    sys.exit(0)
59
60
# Run the label check only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
63