#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
import re
from typing import Any, cast, Dict, List, Optional

from urllib.error import HTTPError

from github_utils import gh_fetch_url, gh_post_pr_comment, gh_query_issues_by_labels

from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo
from trymerge import get_pr_commit_sha, GitHubPR


# This is only a suggestion for now, not a strict requirement
REQUIRES_ISSUE = {
    "regression",
    "critical",
    "fixnewfeature",
}
RELEASE_BRANCH_REGEX = re.compile(r"release/(?P<version>.+)")


def parse_args() -> Any:
    """
    Parse the command line arguments of the cherry pick script.

    Returns the argparse namespace with ``onto_branch``, ``github_actor``,
    ``classification``, ``pr_num``, ``fixes`` and ``dry_run`` attributes.
    """
    from argparse import ArgumentParser

    parser = ArgumentParser("cherry pick a landed PR onto a release branch")
    parser.add_argument(
        "--onto-branch", type=str, required=True, help="the target release branch"
    )
    parser.add_argument(
        "--github-actor", type=str, required=True, help="all the world's a stage"
    )
    parser.add_argument(
        "--classification",
        choices=["regression", "critical", "fixnewfeature", "docs", "release"],
        required=True,
        help="the cherry pick category",
    )
    parser.add_argument("pr_num", type=int)
    parser.add_argument(
        "--fixes",
        type=str,
        default="",
        help="the GitHub issue that the cherry pick fixes",
    )
    parser.add_argument("--dry-run", action="store_true")

    return parser.parse_args()


def get_merge_commit_sha(repo: GitRepo, pr: GitHubPR) -> Optional[str]:
    """
    Return the commit SHA of the PR if the PR is closed, None otherwise.

    For simplicity, we will only cherry pick PRs that have been merged into
    main. Note that the SHA is looked up via get_pr_commit_sha before the
    closed state of the PR is checked.
    """
    commit_sha = get_pr_commit_sha(repo, pr)
    return commit_sha if pr.is_closed() else None


def get_release_version(onto_branch: str) -> Optional[str]:
    """
    Return the release version if the target branch is a release branch, i.e.
    its name starts with release/VERSION. An empty string is returned when the
    branch name does not match, so callers should test the result for
    truthiness.
    """
    m = RELEASE_BRANCH_REGEX.match(onto_branch)
    return m.group("version") if m else ""


def get_tracker_issues(
    org: str, project: str, onto_branch: str
) -> List[Dict[str, Any]]:
    """
    Find the tracker issue from the repo. The tracker issue needs to have the title
    like [VERSION] Release Tracker following the convention on PyTorch

    Returns an empty list when onto_branch is not a release branch or when no
    issue labeled "release tracker" mentions the version in its title.
    """
    version = get_release_version(onto_branch)
    if not version:
        return []

    tracker_issues = gh_query_issues_by_labels(org, project, labels=["release tracker"])
    if not tracker_issues:
        return []

    # Figure out the tracker issue from the list by looking at the title
    return [issue for issue in tracker_issues if version in issue.get("title", "")]


def cherry_pick(
    github_actor: str,
    repo: GitRepo,
    pr: GitHubPR,
    commit_sha: str,
    onto_branch: str,
    classification: str,
    fixes: str,
    dry_run: bool = False,
) -> None:
    """
    Create a local branch to cherry pick the commit and submit it as a pull request

    After the cherry pick branch is created, this (1) submits the cherry pick
    PR unless dry_run is set, (2) comments on every matching release tracker
    issue, and (3) comments on the original PR with a summary. The branch that
    was checked out on entry is restored at the end, even on failure.
    """
    current_branch = repo.current_branch()
    cherry_pick_branch = create_cherry_pick_branch(
        github_actor, repo, pr, commit_sha, onto_branch
    )

    try:
        org, project = repo.gh_owner_and_name()

        # No PR is submitted in dry-run mode, so the link stays empty
        cherry_pick_pr = ""
        if not dry_run:
            cherry_pick_pr = submit_pr(repo, pr, cherry_pick_branch, onto_branch)

        tracker_issues_comments = []
        for issue in get_tracker_issues(org, project, onto_branch):
            issue_number = int(str(issue.get("number", "0")))
            if not issue_number:
                # Skip issues without a usable number
                continue

            res = cast(
                Dict[str, Any],
                post_tracker_issue_comment(
                    org,
                    project,
                    issue_number,
                    pr.pr_num,
                    cherry_pick_pr,
                    classification,
                    fixes,
                    dry_run,
                ),
            )

            comment_url = res.get("html_url", "")
            if comment_url:
                tracker_issues_comments.append(comment_url)

        msg = f"The cherry pick PR is at {cherry_pick_pr}"
        if fixes:
            msg += f" and it is linked with issue {fixes}."
        elif classification in REQUIRES_ISSUE:
            msg += f" and it is recommended to link a {classification} cherry pick PR with an issue."

        if tracker_issues_comments:
            msg += " The following tracker issues are updated:\n"
            msg += "".join(f"* {url}\n" for url in tracker_issues_comments)

        post_pr_comment(org, project, pr.pr_num, msg, dry_run)

    finally:
        # Always go back to the branch we started from
        if current_branch:
            repo.checkout(branch=current_branch)


def create_cherry_pick_branch(
    github_actor: str, repo: GitRepo, pr: GitHubPR, commit_sha: str, onto_branch: str
) -> str:
    """
    Create a local branch and cherry pick the commit. Return the name of the local
    cherry picking branch.

    The branch is created from onto_branch, named
    cherry-pick-PR_NUM-by-ACTOR, and pushed after the cherry pick.
    """
    repo.checkout(branch=onto_branch)
    repo._run_git("submodule", "update", "--init", "--recursive")

    # Remove all special characters if we want to include the actor in the branch name
    github_actor = re.sub("[^0-9a-zA-Z]+", "_", github_actor)

    cherry_pick_branch = f"cherry-pick-{pr.pr_num}-by-{github_actor}"
    repo.create_branch_and_checkout(branch=cherry_pick_branch)

    # We might want to support ghstack later
    repo._run_git("cherry-pick", "-x", "-X", "theirs", commit_sha)
    # NOTE(review): the push is requested with dry_run=False unconditionally,
    # so the branch is pushed even when the script runs with --dry-run.
    # Confirm this is intended.
    repo.push(branch=cherry_pick_branch, dry_run=False)

    return cherry_pick_branch


def submit_pr(
    repo: GitRepo,
    pr: GitHubPR,
    cherry_pick_branch: str,
    onto_branch: str,
) -> str:
    """
    Submit the cherry pick PR and return the link to the PR

    The title and body of the original PR are reused. Raises RuntimeError when
    the request fails with an HTTPError or when the response carries no
    html_url.
    """
    org, project = repo.gh_owner_and_name()

    default_msg = f"Cherry pick #{pr.pr_num} onto {onto_branch} branch"
    title = pr.info.get("title", default_msg)
    # The body can be present but empty or None when the original PR has no
    # description, so fall back on any falsy value, not only on a missing key
    body = pr.info.get("body") or default_msg

    try:
        response = gh_fetch_url(
            f"https://api.github.com/repos/{org}/{project}/pulls",
            method="POST",
            data={
                "title": title,
                "body": body,
                "head": cherry_pick_branch,
                "base": onto_branch,
            },
            headers={"Accept": "application/vnd.github.v3+json"},
            reader=json.load,
        )

        cherry_pick_pr = response.get("html_url", "")
        if not cherry_pick_pr:
            raise RuntimeError(
                f"Fail to find the cherry pick PR: {json.dumps(response)}"
            )

        return str(cherry_pick_pr)

    except HTTPError as error:
        msg = f"Fail to submit the cherry pick PR: {error}"
        raise RuntimeError(msg) from error


def post_pr_comment(
    org: str, project: str, pr_num: int, msg: str, dry_run: bool = False
) -> List[Dict[str, Any]]:
    """
    Post a comment on the PR itself to point to the cherry picking PR when success
    or print the error when failure

    When the GH_RUN_URL environment variable is set, a collapsed section
    linking to the workflow job is appended to the comment.
    """
    internal_debugging = ""

    run_url = os.getenv("GH_RUN_URL")
    # Post a comment to tell folks that the PR is being cherry picked
    if run_url is not None:
        internal_debugging = "\n".join(
            (
                "<details><summary>Details for Dev Infra team</summary>",
                f'Raised by <a href="{run_url}">workflow job</a>\n',
                "</details>",
            )
        )

    comment = "\n".join(
        (f"### Cherry picking #{pr_num}", f"{msg}", "", f"{internal_debugging}")
    )
    return gh_post_pr_comment(org, project, pr_num, comment, dry_run)


def post_tracker_issue_comment(
    org: str,
    project: str,
    issue_num: int,
    pr_num: int,
    cherry_pick_pr: str,
    classification: str,
    fixes: str,
    dry_run: bool = False,
) -> List[Dict[str, Any]]:
    """
    Post a comment on the tracker issue (if any) to record the cherry pick

    The comment lists the landed PR, the release branch PR and the criteria
    category, followed by the fixed issue when one is given.
    """
    # Keep the issue reference verbatim: str.capitalize() would lowercase
    # everything after the first character and mangle URLs. Omit the
    # separator entirely when no issue is given.
    criteria = classification.capitalize()
    if fixes:
        criteria = f"{criteria} - {fixes}"

    comment = "\n".join(
        (
            "Link to landed trunk PR (if applicable):",
            f"* https://github.com/{org}/{project}/pull/{pr_num}",
            "",
            "Link to release branch PR:",
            f"* {cherry_pick_pr}",
            "",
            "Criteria Category:",
            criteria,
        )
    )
    return gh_post_pr_comment(org, project, issue_num, comment, dry_run)


def main() -> None:
    """
    Entry point: cherry pick the given PR onto the requested release branch.

    A RuntimeError raised along the way is posted as a comment on the PR,
    unless --dry-run is set, in which case it is re-raised.
    """
    args = parse_args()
    pr_num = args.pr_num

    repo = GitRepo(get_git_repo_dir(), get_git_remote_name())
    org, project = repo.gh_owner_and_name()

    pr = GitHubPR(org, project, pr_num)

    try:
        commit_sha = get_merge_commit_sha(repo, pr)
        if not commit_sha:
            raise RuntimeError(
                f"Refuse to cherry pick #{pr_num} because it hasn't been merged yet"
            )

        cherry_pick(
            args.github_actor,
            repo,
            pr,
            commit_sha,
            args.onto_branch,
            args.classification,
            args.fixes,
            args.dry_run,
        )

    except RuntimeError as error:
        if args.dry_run:
            raise
        # Report the failure on the PR instead of failing the script
        post_pr_comment(org, project, pr_num, str(error))


if __name__ == "__main__":
    main()