• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7"""GitHub Label Utilities."""
8
import json
import re

from functools import lru_cache
from typing import Any, List, Tuple, TYPE_CHECKING, Union
from urllib.parse import quote

from github_utils import gh_fetch_url_and_headers, GitHubComment
15
16# TODO: this is a temp workaround to avoid circular dependencies,
17#       and should be removed once GitHubPR is refactored out of trymerge script.
18if TYPE_CHECKING:
19    from trymerge import GitHubPR
20
# Logins of bot accounts. is_label_err_comment only treats a comment as the
# label error message when its author is one of these.
BOT_AUTHORS = ["github-actions", "pytorchmergebot", "pytorch-bot"]

# Heading of LABEL_ERR_MSG. is_label_err_comment strips the backticks from it
# before matching, because comment.body_text has the markdown removed.
LABEL_ERR_MSG_TITLE = "This PR needs a `release notes:` label"
# Markdown message asking the author to add either a `release notes:` label or
# the `topic: not user facing` label. It starts with LABEL_ERR_MSG_TITLE as an
# H1 heading, which is what is_label_err_comment looks for (presumably this is
# posted as a PR comment by one of BOT_AUTHORS; the poster is not shown here).
LABEL_ERR_MSG = f"""# {LABEL_ERR_MSG_TITLE}
If your changes are user facing and intended to be a part of release notes, please use a label starting with `release notes:`.

If not, please add the `topic: not user facing` label.

To add a label, you can comment to pytorchbot, for example
`@pytorchbot label "topic: not user facing"`

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.
"""
35
36
def request_for_labels(url: str) -> Tuple[Any, Any]:
    """Fetch one page of labels from the GitHub REST API.

    :param url: fully-qualified labels endpoint URL, including query parameters.
    :returns: whatever ``gh_fetch_url_and_headers`` returns; callers unpack it
        as ``(headers, body)``. The reader passed below decodes the response
        body as UTF-8 text.
    """

    def _decode_utf8(response: Any) -> str:
        return response.read().decode("utf-8")

    accept_v3_json = {"Accept": "application/vnd.github.v3+json"}
    return gh_fetch_url_and_headers(url, headers=accept_v3_json, reader=_decode_utf8)
42
43
def update_labels(labels: List[str], info: str) -> None:
    """Append the ``name`` of every label in a JSON array to *labels* in place.

    :param labels: list to extend; it is mutated.
    :param info: JSON text of an array of label objects, each with a ``"name"`` key.
    """
    # Collect every name before extending, so a malformed entry leaves
    # *labels* untouched.
    names = [entry["name"] for entry in json.loads(info)]
    labels.extend(names)
47
48
def get_last_page_num_from_header(header: Any) -> int:
    """Return the number of the last page advertised by a GitHub ``Link`` header.

    GitHub describes the other pages of a paginated list in a ``Link`` response
    header, for example::

        <https://api.github.com/repositories/65600975/labels?per_page=100&page=2>; rel="next",
        <https://api.github.com/repositories/65600975/labels?per_page=100&page=3>; rel="last"

    :param header: response headers, looked up with ``.get("link")``. A dict or
        an ``email.message.Message`` (case-insensitive lookup) both work.
    :returns: the ``page`` query parameter of the ``rel="last"`` link, or 1 when
        there is no (or an empty) ``Link`` header.
    :raises ValueError: if a ``Link`` header is present but has no ``rel="last"``
        entry carrying a ``page`` parameter.
    """
    link_info = header.get("link")
    # The docs do not guarantee a Link header for single-page results, and in
    # practice it is absent then, see
    # https://github.com/malfet/deleteme/actions/runs/7334565243/job/19971396887
    if not link_info:
        return 1
    # Each entry is "<url>" followed by its ";"-separated parameters. Parse the
    # entries instead of slicing the raw string, so neither the order of the
    # entries nor the order of the query parameters matters.
    for url, params in re.findall(r"<([^>]*)>([^<]*)", link_info):
        rel = re.search(r'rel="([^"]*)"', params)
        if rel is None or "last" not in rel.group(1).split():
            continue
        # "[?&]" keeps "per_page=" from matching.
        page = re.search(r"[?&]page=(\d+)", url)
        if page is not None:
            return int(page.group(1))
    raise ValueError(
        f"Cannot determine the last page number from Link header: {link_info!r}"
    )
62
63
@lru_cache
def gh_get_labels(org: str, repo: str) -> List[str]:
    """Return the names of all labels defined in the ``org/repo`` repository.

    Labels are requested 100 per page. The first page's response headers give
    the total page count, and each remaining page is then requested in order.

    The result is memoised per ``(org, repo)`` by ``lru_cache``, so every call
    returns the same list object and callers should not mutate it.
    """
    base_url = f"https://api.github.com/repos/{org}/{repo}/labels?per_page=100"
    collected: List[str] = []

    first_headers, first_body = request_for_labels(f"{base_url}&page=1")
    update_labels(collected, first_body)

    total_pages = get_last_page_num_from_header(first_headers)
    assert (
        total_pages > 0
    ), "Error reading header info to determine total number of pages of labels"

    # Page 1 was already consumed above.
    for page in range(2, total_pages + 1):
        _, page_body = request_for_labels(f"{base_url}&page={page}")
        update_labels(collected, page_body)

    return collected
80
81
def gh_add_labels(
    org: str, repo: str, pr_num: int, labels: Union[str, List[str]], dry_run: bool
) -> None:
    """Add *labels* to issue/PR *pr_num* in ``org/repo`` via the GitHub REST API.

    :param labels: label name(s), sent unchanged as the ``"labels"`` field of
        the request payload.
    :param dry_run: when true, only print what would be added and make no request.
    """
    if not dry_run:
        endpoint = f"https://api.github.com/repos/{org}/{repo}/issues/{pr_num}/labels"
        payload = {"labels": labels}
        gh_fetch_url_and_headers(url=endpoint, data=payload)
        return
    print(f"Dryrun: Adding labels {labels} to PR {pr_num}")
92
93
def gh_remove_label(
    org: str, repo: str, pr_num: int, label: str, dry_run: bool
) -> None:
    """Remove *label* from issue/PR *pr_num* in ``org/repo`` via the GitHub REST API.

    The label name goes into the URL path, so it is percent-encoded first.
    Names such as ``topic: not user facing`` contain spaces, which the stdlib
    ``http.client`` rejects in a request URL (assumed to be what
    gh_fetch_url_and_headers uses underneath; confirm). A ``/`` in a name would
    otherwise be read as a path separator.

    :param dry_run: when true, only print what would be removed and make no
        request. The printed label is the raw, unencoded name.
    """
    if dry_run:
        print(f"Dryrun: Removing {label} from PR {pr_num}")
        return
    # safe="" so that "/" is encoded as well.
    encoded_label = quote(label, safe="")
    gh_fetch_url_and_headers(
        url=f"https://api.github.com/repos/{org}/{repo}/issues/{pr_num}/labels/{encoded_label}",
        method="DELETE",
    )
104
105
def get_release_notes_labels(org: str, repo: str) -> List[str]:
    """Return the labels of ``org/repo`` whose names begin with ``release notes:``.

    Leading whitespace is ignored for the prefix test, but the returned names
    are not modified.
    """

    def _is_release_notes(name: str) -> bool:
        return name.lstrip().startswith("release notes:")

    return list(filter(_is_release_notes, gh_get_labels(org, repo)))
112
113
def has_required_labels(pr: "GitHubPR") -> bool:
    """Return whether *pr* has a label that satisfies the release-notes requirement.

    A PR passes if it has the ``topic: not user facing`` label, or any label
    that is one of the repository's ``release notes:`` labels. Surrounding
    whitespace on the PR's labels is ignored when comparing.

    The repository's label list is only consulted when the first check fails
    and the PR has at least one label.
    """
    pr_labels = [label.strip() for label in pr.get_labels()]
    if "topic: not user facing" in pr_labels:
        return True
    if not pr_labels:
        return False
    # Build the lookup set once, instead of re-filtering every repository label
    # for each PR label.
    release_notes_labels = set(get_release_notes_labels(pr.org, pr.project))
    return any(label in release_notes_labels for label in pr_labels)
124
125
def is_label_err_comment(comment: GitHubComment) -> bool:
    """Return whether *comment* is the "needs a release notes label" message.

    The comment text must start with LABEL_ERR_MSG_TITLE (ignoring leading
    spaces and ``#`` heading markers) and its author must be in BOT_AUTHORS.
    """
    # comment.body_text is the text without markdown, so the backticks in the
    # title have to be dropped before comparing.
    plain_title = LABEL_ERR_MSG_TITLE.replace("`", "")
    body = comment.body_text.lstrip(" #")
    if not body.startswith(plain_title):
        return False
    return comment.author_login in BOT_AUTHORS
133