• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# Copyright (c) Meta Platforms, Inc. and affiliates.
3# All rights reserved.
4#
5# This source code is licensed under the BSD-style license found in the
6# LICENSE file in the root directory of this source tree.
7
8import json
9import os
10import re
11from typing import Any, cast, Dict, List, Optional
12
13from urllib.error import HTTPError
14
15from github_utils import gh_fetch_url, gh_post_pr_comment, gh_query_issues_by_labels
16
17from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo
18from trymerge import get_pr_commit_sha, GitHubPR
19
20
# Cherry pick classifications that ought to be linked to a GitHub issue.
# This is only a suggestion for now, not a strict requirement.
REQUIRES_ISSUE = {"regression", "critical", "fixnewfeature"}

# A release branch looks like ``release/<version>``; the text after the slash
# is captured in the named group ``version``.
RELEASE_BRANCH_REGEX = re.compile(r"release/(?P<version>.+)")


def parse_args() -> Any:
    """
    Parse the command line of the cherry pick script.

    Returns the ``argparse.Namespace`` carrying ``onto_branch``,
    ``github_actor``, ``classification``, ``pr_num``, ``fixes`` (default ``""``)
    and ``dry_run`` (default ``False``).
    """
    from argparse import ArgumentParser

    # The first positional parameter of ArgumentParser is ``prog`` (the program
    # name printed in the usage line), so the sentence must go to
    # ``description`` instead.
    parser = ArgumentParser(
        description="cherry pick a landed PR onto a release branch"
    )
    parser.add_argument(
        "--onto-branch", type=str, required=True, help="the target release branch"
    )
    parser.add_argument(
        "--github-actor",
        type=str,
        required=True,
        help="the GitHub user who requested the cherry pick",
    )
    parser.add_argument(
        "--classification",
        choices=["regression", "critical", "fixnewfeature", "docs", "release"],
        required=True,
        help="the cherry pick category",
    )
    parser.add_argument("pr_num", type=int)
    parser.add_argument(
        "--fixes",
        type=str,
        default="",
        help="the GitHub issue that the cherry pick fixes",
    )
    parser.add_argument("--dry-run", action="store_true")

    return parser.parse_args()


def get_merge_commit_sha(repo: GitRepo, pr: GitHubPR) -> Optional[str]:
    """
    Return the landed commit SHA of ``pr`` if the PR is closed, else None.

    For simplicity, we only cherry pick PRs that have already landed. The PR
    state is checked before ``get_pr_commit_sha`` is called, so no commit lookup
    is attempted for a PR that is still open.

    NOTE(review): this tests ``is_closed()``, not "merged"; a PR closed without
    landing also passes this check — confirm that is acceptable.
    """
    if not pr.is_closed():
        return None
    return get_pr_commit_sha(repo, pr)


def get_release_version(onto_branch: str) -> str:
    """
    Return the release version encoded in a release branch name.

    ``release/2.1`` yields ``"2.1"``. When ``onto_branch`` is not a release
    branch an empty string (never None) is returned, so callers can simply test
    the result for truthiness.
    """
    match = RELEASE_BRANCH_REGEX.match(onto_branch)
    return match.group("version") if match else ""


def get_tracker_issues(
    org: str, project: str, onto_branch: str
) -> List[Dict[str, Any]]:
    """
    Find the release tracker issues for the target branch.

    Tracker issues carry the ``release tracker`` label and are titled like
    ``[VERSION] Release Tracker`` following the convention on PyTorch. An issue
    is selected when its title contains the branch's release version as a whole
    version number, so ``release/2.1`` matches ``[v2.1.0] Release Tracker`` but
    not ``[v2.10.0] Release Tracker`` or ``[v12.1.0] Release Tracker``.

    Returns an empty list when ``onto_branch`` is not a release branch or when
    no matching tracker issue exists.
    """
    version = get_release_version(onto_branch)
    if not version:
        return []

    tracker_issues = gh_query_issues_by_labels(org, project, labels=["release tracker"])
    if not tracker_issues:
        return []

    # A plain substring test would let "2.1" match "2.10" or "12.1", so reject
    # matches embedded in a longer dotted number.
    version_pattern = re.compile(rf"(?<![\d.]){re.escape(version)}(?!\d)")
    return [
        issue
        for issue in tracker_issues
        # ``or ""`` also covers a title that is present but null
        if version_pattern.search(str(issue.get("title") or ""))
    ]


def cherry_pick(
    github_actor: str,
    repo: GitRepo,
    pr: GitHubPR,
    commit_sha: str,
    onto_branch: str,
    classification: str,
    fixes: str,
    dry_run: bool = False,
) -> None:
    """
    Cherry pick ``commit_sha`` onto ``onto_branch`` and report the result.

    Steps: create and push a local cherry pick branch, open a pull request for
    it (skipped when ``dry_run`` is set), comment on the matching release
    tracker issues, then comment on the original PR with a summary.

    If a branch was checked out on entry, it is checked out again after those
    steps, even when one of them fails. A failure inside
    ``create_cherry_pick_branch`` itself happens before this protection starts.
    """
    current_branch = repo.current_branch()
    cherry_pick_branch = create_cherry_pick_branch(
        github_actor, repo, pr, commit_sha, onto_branch
    )

    try:
        org, project = repo.gh_owner_and_name()

        cherry_pick_pr = ""
        if not dry_run:
            cherry_pick_pr = submit_pr(repo, pr, cherry_pick_branch, onto_branch)

        tracker_issues_comments = _comment_on_tracker_issues(
            org,
            project,
            pr,
            cherry_pick_pr,
            onto_branch,
            classification,
            fixes,
            dry_run,
        )

        msg = _build_cherry_pick_summary(
            cherry_pick_pr, classification, fixes, tracker_issues_comments
        )
        post_pr_comment(org, project, pr.pr_num, msg, dry_run)

    finally:
        if current_branch:
            repo.checkout(branch=current_branch)


def _comment_on_tracker_issues(
    org: str,
    project: str,
    pr: GitHubPR,
    cherry_pick_pr: str,
    onto_branch: str,
    classification: str,
    fixes: str,
    dry_run: bool,
) -> List[str]:
    """Comment on each release tracker issue of ``onto_branch``; return the comment URLs."""
    comment_urls: List[str] = []
    for issue in get_tracker_issues(org, project, onto_branch):
        issue_number = int(str(issue.get("number", "0")))
        if not issue_number:
            continue

        res = post_tracker_issue_comment(
            org,
            project,
            issue_number,
            pr.pr_num,
            cherry_pick_pr,
            classification,
            fixes,
            dry_run,
        )
        # post_tracker_issue_comment is annotated to return a list, but only a
        # comment object (dict) carries ``html_url``. Guard against the list
        # shape (presumably returned e.g. in dry-run mode) instead of crashing
        # on ``.get``.
        comment_url = res.get("html_url", "") if isinstance(res, dict) else ""
        if comment_url:
            comment_urls.append(comment_url)
    return comment_urls


def _build_cherry_pick_summary(
    cherry_pick_pr: str,
    classification: str,
    fixes: str,
    tracker_issues_comments: List[str],
) -> str:
    """Build the summary comment posted on the original PR."""
    msg = f"The cherry pick PR is at {cherry_pick_pr}"
    if fixes:
        msg += f" and it is linked with issue {fixes}."
    elif classification in REQUIRES_ISSUE:
        msg += f" and it is recommended to link a {classification} cherry pick PR with an issue."

    if tracker_issues_comments:
        msg += " The following tracker issues are updated:\n"
        msg += "".join(f"* {url}\n" for url in tracker_issues_comments)
    return msg


def create_cherry_pick_branch(
    github_actor: str, repo: GitRepo, pr: GitHubPR, commit_sha: str, onto_branch: str
) -> str:
    """
    Cherry pick ``commit_sha`` onto a new local branch based on ``onto_branch``
    and push that branch. Returns the new branch name.

    The branch is named ``cherry-pick-<PR number>-by-<actor>``. In the actor
    name, every run of characters outside ``[0-9a-zA-Z]`` is replaced by ``_``.
    """
    # Start from the target branch with its submodules initialized.
    repo.checkout(branch=onto_branch)
    repo._run_git("submodule", "update", "--init", "--recursive")

    safe_actor = re.sub("[^0-9a-zA-Z]+", "_", github_actor)
    branch_name = f"cherry-pick-{pr.pr_num}-by-{safe_actor}"
    repo.create_branch_and_checkout(branch=branch_name)

    # Plain commits only (ghstack might be supported later). "-x" records the
    # source commit in the message; "-X theirs" resolves conflicting hunks in
    # favour of the commit being picked.
    repo._run_git("cherry-pick", "-x", "-X", "theirs", commit_sha)
    # NOTE(review): pushed unconditionally (dry_run=False) — there is no dry-run
    # flag reaching this function; confirm that is intended.
    repo.push(branch=branch_name, dry_run=False)

    return branch_name


def submit_pr(
    repo: GitRepo,
    pr: GitHubPR,
    cherry_pick_branch: str,
    onto_branch: str,
) -> str:
    """
    Open a pull request merging ``cherry_pick_branch`` into ``onto_branch``.

    The original PR's title and body are reused. A generic
    "Cherry pick #N onto <branch> branch" message replaces whichever of them is
    missing, null or empty.

    Returns the ``html_url`` of the new pull request.
    Raises RuntimeError if the GitHub request fails or the response has no URL.
    """
    org, project = repo.gh_owner_and_name()

    default_msg = f"Cherry pick #{pr.pr_num} onto {onto_branch} branch"
    # ``dict.get(key, default)`` only helps when the key is absent. The GitHub
    # API can return a null body for a PR without a description, so fall back on
    # any falsy value.
    title = pr.info.get("title") or default_msg
    body = pr.info.get("body") or default_msg

    try:
        response = gh_fetch_url(
            f"https://api.github.com/repos/{org}/{project}/pulls",
            method="POST",
            data={
                "title": title,
                "body": body,
                "head": cherry_pick_branch,
                "base": onto_branch,
            },
            headers={"Accept": "application/vnd.github.v3+json"},
            reader=json.load,
        )
    except HTTPError as error:
        msg = f"Fail to submit the cherry pick PR: {error}"
        raise RuntimeError(msg) from error

    cherry_pick_pr = response.get("html_url", "")
    if not cherry_pick_pr:
        raise RuntimeError(
            f"Fail to find the cherry pick PR: {json.dumps(response)}"
        )

    return str(cherry_pick_pr)


def post_pr_comment(
    org: str, project: str, pr_num: int, msg: str, dry_run: bool = False
) -> List[Dict[str, Any]]:
    """
    Comment on PR ``pr_num`` with a cherry picking report.

    ``msg`` is the link to the cherry pick PR on success, or the error text on
    failure. When the ``GH_RUN_URL`` environment variable is set, a collapsed
    section linking to that workflow run is appended. Returns whatever
    ``gh_post_pr_comment`` returns.
    """
    run_url = os.getenv("GH_RUN_URL")

    internal_debugging = ""
    if run_url is not None:
        internal_debugging = (
            "<details><summary>Details for Dev Infra team</summary>\n"
            f'Raised by <a href="{run_url}">workflow job</a>\n'
            "\n"
            "</details>"
        )

    comment = f"### Cherry picking #{pr_num}\n{msg}\n\n{internal_debugging}"
    return gh_post_pr_comment(org, project, pr_num, comment, dry_run)


def post_tracker_issue_comment(
    org: str,
    project: str,
    issue_num: int,
    pr_num: int,
    cherry_pick_pr: str,
    classification: str,
    fixes: str,
    dry_run: bool = False,
) -> List[Dict[str, Any]]:
    """
    Record the cherry pick in a comment on release tracker issue ``issue_num``.

    The comment links the landed trunk PR and the release branch PR, and states
    the criteria category. The category is the capitalized classification,
    followed by `` - <fixes>`` only when an issue reference was given.
    Returns whatever ``gh_post_pr_comment`` returns.
    """
    # str.capitalize() lowercases everything after the first character, which
    # mangles an issue URL or reference, so ``fixes`` is used verbatim. Without
    # ``fixes`` there is no dangling " - " separator.
    category = classification.capitalize()
    if fixes:
        category = f"{category} - {fixes}"

    comment = "\n".join(
        (
            "Link to landed trunk PR (if applicable):",
            f"* https://github.com/{org}/{project}/pull/{pr_num}",
            "",
            "Link to release branch PR:",
            f"* {cherry_pick_pr}",
            "",
            "Criteria Category:",
            category,
        )
    )
    return gh_post_pr_comment(org, project, issue_num, comment, dry_run)


def main() -> None:
    """
    Entry point: cherry pick the PR named on the command line.

    A RuntimeError raised along the way is re-raised in dry-run mode. Otherwise
    it is posted as a comment on the original PR.
    """
    cli = parse_args()
    pr_num = cli.pr_num

    repo = GitRepo(get_git_repo_dir(), get_git_remote_name())
    org, project = repo.gh_owner_and_name()
    pr = GitHubPR(org, project, pr_num)

    try:
        commit_sha = get_merge_commit_sha(repo, pr)
        if not commit_sha:
            raise RuntimeError(
                f"Refuse to cherry pick #{pr_num} because it hasn't been merged yet"
            )

        cherry_pick(
            github_actor=cli.github_actor,
            repo=repo,
            pr=pr,
            commit_sha=commit_sha,
            onto_branch=cli.onto_branch,
            classification=cli.classification,
            fixes=cli.fixes,
            dry_run=cli.dry_run,
        )
    except RuntimeError as error:
        if cli.dry_run:
            raise
        post_pr_comment(org, project, pr_num, str(error))


if __name__ == "__main__":
    main()