# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""GitHub Label Utilities."""

import json

from functools import lru_cache
from typing import Any, List, Tuple, TYPE_CHECKING, Union

from github_utils import gh_fetch_url_and_headers, GitHubComment

# TODO: this is a temp workaround to avoid circular dependencies,
# and should be removed once GitHubPR is refactored out of trymerge script.
if TYPE_CHECKING:
    from trymerge import GitHubPR

# Logins accepted as the author of the missing-label error comment.
BOT_AUTHORS = ["github-actions", "pytorchmergebot", "pytorch-bot"]

LABEL_ERR_MSG_TITLE = "This PR needs a `release notes:` label"
LABEL_ERR_MSG = f"""# {LABEL_ERR_MSG_TITLE}
If your changes are user facing and intended to be a part of release notes, please use a label starting with `release notes:`.

If not, please add the `topic: not user facing` label.

To add a label, you can comment to pytorchbot, for example
`@pytorchbot label "topic: not user facing"`

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.
"""


def request_for_labels(url: str) -> Tuple[Any, Any]:
    """Fetch *url* and return ``(headers, body)`` with the body decoded as UTF-8."""

    def _decode_body(response: Any) -> str:
        return response.read().decode("utf-8")

    return gh_fetch_url_and_headers(
        url,
        headers={"Accept": "application/vnd.github.v3+json"},
        reader=_decode_body,
    )


def update_labels(labels: List[str], info: str) -> None:
    """Parse *info* as a JSON list of label objects and append their names.

    *labels* is mutated in place. All names are collected before anything
    is appended, so a malformed entry leaves *labels* untouched.
    """
    page_names = [entry["name"] for entry in json.loads(info)]
    labels.extend(page_names)


def get_last_page_num_from_header(header: Any) -> int:
    """Return the page number of the last page advertised by a ``link`` header.

    Returns 1 when ``header["link"]`` is None (no pagination info).
    """
    # Example link value:
    #   <https://api.github.com/repositories/65600975/labels?per_page=100&page=2>; rel="next",
    #   <https://api.github.com/repositories/65600975/labels?per_page=100&page=3>; rel="last"
    # The number is taken from the final "&page=" up to the final ">;".
    link_info = header["link"]
    # The header can be absent for repos with few labels, see e.g.
    # https://github.com/malfet/deleteme/actions/runs/7334565243/job/19971396887
    if link_info is None:
        return 1
    page_marker = "&page="
    start = link_info.rindex(page_marker) + len(page_marker)
    end = link_info.rindex(">;")
    return int(link_info[start:end])


@lru_cache
def gh_get_labels(org: str, repo: str) -> List[str]:
    """Return the names of every label defined in ``org/repo``.

    Pages through the labels endpoint 100 at a time. The result is cached
    per ``(org, repo)`` by ``lru_cache``.
    """
    base_url = f"https://api.github.com/repos/{org}/{repo}/labels?per_page=100"
    collected: List[str] = []

    first_header, first_body = request_for_labels(f"{base_url}&page=1")
    update_labels(collected, first_body)

    final_page = get_last_page_num_from_header(first_header)
    assert (
        final_page > 0
    ), "Error reading header info to determine total number of pages of labels"

    # Page 1 is already loaded; fetch the remainder in order.
    for page in range(2, final_page + 1):
        _, body = request_for_labels(f"{base_url}&page={page}")
        update_labels(collected, body)

    return collected


def gh_add_labels(
    org: str, repo: str, pr_num: int, labels: Union[str, List[str]], dry_run: bool
) -> None:
    """POST *labels* to the labels endpoint of issue/PR *pr_num*.

    With *dry_run* set, only print what would be added.
    """
    if not dry_run:
        endpoint = f"https://api.github.com/repos/{org}/{repo}/issues/{pr_num}/labels"
        gh_fetch_url_and_headers(url=endpoint, data={"labels": labels})
    else:
        print(f"Dryrun: Adding labels {labels} to PR {pr_num}")
def gh_remove_label(
    org: str, repo: str, pr_num: int, label: str, dry_run: bool
) -> None:
    """Send a DELETE for a single *label* on issue/PR *pr_num*.

    The label name is percent-encoded before it is placed in the URL path.
    Label names may contain spaces and colons (e.g. ``topic: not user
    facing``) or ``/``. Unencoded, these yield an invalid path or an
    unintended extra path segment.

    With *dry_run* set, only print what would be removed.
    """
    from urllib.parse import quote

    if dry_run:
        print(f"Dryrun: Removing {label} from PR {pr_num}")
        return
    # safe="" so that "/" is encoded too and the name stays one path segment.
    encoded_label = quote(label, safe="")
    gh_fetch_url_and_headers(
        url=f"https://api.github.com/repos/{org}/{repo}/issues/{pr_num}/labels/{encoded_label}",
        method="DELETE",
    )


def get_release_notes_labels(org: str, repo: str) -> List[str]:
    """Return the repo's labels whose name, ignoring leading whitespace, starts with ``release notes:``."""
    return [
        label
        for label in gh_get_labels(org, repo)
        if label.lstrip().startswith("release notes:")
    ]


def has_required_labels(pr: "GitHubPR") -> bool:
    """Return True if *pr* is labelled as not user facing or carries a release-notes label.

    PR label names are compared after ``strip()``.
    """
    pr_labels = pr.get_labels()
    # A PR explicitly marked as not user facing needs no release-notes label.
    if any(label.strip() == "topic: not user facing" for label in pr_labels):
        return True
    # Nothing to match: skip listing the repo's labels.
    if not pr_labels:
        return False
    # Build the lookup once instead of re-filtering the repo labels per PR label.
    release_notes_labels = set(get_release_notes_labels(pr.org, pr.project))
    return any(label.strip() in release_notes_labels for label in pr_labels)


def is_label_err_comment(comment: GitHubComment) -> bool:
    """Return True if *comment* is the missing-label error posted by one of BOT_AUTHORS."""
    # comment.body_text returns text without markdown, so compare against the
    # title with its backticks removed and ignore leading spaces/'#'.
    no_format_title = LABEL_ERR_MSG_TITLE.replace("`", "")
    return (
        comment.body_text.lstrip(" #").startswith(no_format_title)
        and comment.author_login in BOT_AUTHORS
    )