#!/usr/bin/env python3
"""Check whether a PR has required labels.

If the labels are missing, print the label error message and post it as a PR
comment (only once). If the labels are present, delete any label error
comments that were posted earlier.
"""

import sys
from typing import Any

from github_utils import gh_delete_comment, gh_post_pr_comment
from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo
from label_utils import has_required_labels, is_label_err_comment, LABEL_ERR_MSG
from trymerge import GitHubPR


def delete_all_label_err_comments(pr: GitHubPR) -> None:
    """Delete every comment on ``pr`` that ``is_label_err_comment`` matches."""
    for comment in pr.get_comments():
        if is_label_err_comment(comment):
            gh_delete_comment(pr.org, pr.project, comment.database_id)


def add_label_err_comment(pr: GitHubPR) -> None:
    """Post ``LABEL_ERR_MSG`` on ``pr`` unless a label error comment already exists."""
    # Only make a comment if one doesn't exist already, so reruns don't
    # pile up duplicate error comments on the PR.
    if not any(is_label_err_comment(comment) for comment in pr.get_comments()):
        gh_post_pr_comment(pr.org, pr.project, pr.pr_num, LABEL_ERR_MSG)


def parse_args() -> Any:
    """Parse command-line arguments.

    Returns:
        An ``argparse.Namespace`` with ``pr_num`` (int) and
        ``exit_non_zero`` (bool).
    """
    from argparse import ArgumentParser

    # The first positional argument of ArgumentParser is ``prog`` (the program
    # name shown in usage), so pass this text explicitly as the description.
    parser = ArgumentParser(description="Check PR labels")
    parser.add_argument("pr_num", type=int)
    # add a flag to return a non-zero exit code if the PR does not have the required labels
    parser.add_argument(
        "--exit-non-zero",
        action="store_true",
        help="Return a non-zero exit code if the PR does not have the required labels",
    )

    return parser.parse_args()


def main() -> None:
    """Check the PR's labels and add or remove the label error comment.

    Raises:
        RuntimeError: only when ``--exit-non-zero`` is given, either because
            the PR lacks the required labels or because checking/commenting
            failed. Without the flag, the script always exits with status 0.
    """
    args = parse_args()
    repo = GitRepo(get_git_repo_dir(), get_git_remote_name())
    org, project = repo.gh_owner_and_name()
    pr = GitHubPR(org, project, args.pr_num)

    try:
        labels_ok = has_required_labels(pr)
        if labels_ok:
            delete_all_label_err_comments(pr)
        else:
            print(LABEL_ERR_MSG, flush=True)
            add_label_err_comment(pr)
    except Exception as e:
        if args.exit_non_zero:
            raise RuntimeError(f"Error checking labels: {e}") from e
        # Best effort: without --exit-non-zero a failure must not fail the
        # job, but it should not vanish silently either.
        print(f"Error checking labels: {e}", file=sys.stderr, flush=True)
    else:
        # Raised outside the ``try`` so it is not caught and re-wrapped as
        # "Error checking labels: ..." by the handler above.
        if not labels_ok and args.exit_non_zero:
            raise RuntimeError("PR does not have required labels")

    sys.exit(0)


if __name__ == "__main__":
    main()