# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Open "orig"-branch PRs for a merged ghstack stack.

Given one PR of a ghstack-created stack (by PR number or by its head branch
ref), walk the stack from the bottom up. For every PR that has been merged,
open a PR whose head is that entry's ``gh/<user>/<n>/orig`` branch and whose
base is the previous entry's ``orig`` branch (``main`` for the bottom of the
stack), to help merge the original PR into the main branch. Processing stops
at the first PR that is not merged yet.

Requires a GitHub token in the ``GITHUB_TOKEN`` environment variable.
"""

import argparse
import os
import re

from typing import List, Optional

# Provided by the PyGithub pip package.
from github import Auth, Github
from github.Repository import Repository

# Header line that precedes the stack listing in a ghstack PR body.
_GHSTACK_BEGIN = (
    "Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):"
)
# One stack entry: a bullet, optionally followed by a marker such as
# `__->__`, then ` #<pr number>`.
_STACK_ENTRY_RE = re.compile(r"\*(?:.*?)? #(\d+)")


def parse_args():
    """Parse the command-line arguments (``--repo`` and ``--ref``)."""
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
    )
    parser.add_argument(
        "--repo",
        type=str,
        help='The github repo to modify: e.g. "pytorch/executorch".',
        required=True,
    )
    parser.add_argument(
        "--ref",
        type=str,
        help=(
            "Ref of a PR in the stack to check and create corresponding PRs "
            "for: either a PR number or a branch ref "
            "(e.g. refs/heads/gh/user/1/head)."
        ),
        required=True,
    )
    return parser.parse_args()


def extract_stack_from_body(pr_body: Optional[str]) -> List[int]:
    """Extracts a list of PR numbers from a ghexport-generated PR body.

    The base of the stack is in index 0. Expected format (the ``__->__``
    marker could appear on any line)::

        Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
        * #3
        * __->__ #2
        * #1

        <PR description details>

    For the body above this returns ``[1, 2, 3]``. Parsing starts at the
    ghstack header line and stops at the first blank line after the stack
    entries, so bullets in the free-form description (e.g. ``* fixes #42``)
    are not mistaken for stack entries.

    Args:
        pr_body: The PR description; may be ``None`` or empty.

    Returns:
        The PR numbers, base first, or ``[]`` if no ghstack stack is found.
    """
    if not pr_body:
        # The PR body may be None (e.g. an empty description); nothing to parse.
        return []

    prs = []
    ghstack_begin_seen = False
    for line in pr_body.splitlines():
        if not ghstack_begin_seen:
            if _GHSTACK_BEGIN not in line:
                continue
            ghstack_begin_seen = True
        if prs and not line.strip():
            # The blank line after the entries terminates the stack listing.
            break
        match = _STACK_ENTRY_RE.match(line)
        if match:
            # It's a bullet followed by an integer.
            prs.append(int(match.group(1)))
    # The listing is newest-first ("oldest at bottom"); return base first.
    return list(reversed(prs))


def _strip_prefix(text: str, prefix: str) -> str:
    """Return ``text`` without a leading ``prefix`` (unchanged if absent)."""
    return text[len(prefix):] if text.startswith(prefix) else text


def _swap_branch_suffix(ref: str, old: str, new: str) -> str:
    """Replace the final path component ``old`` of branch ``ref`` with ``new``.

    For example ``_swap_branch_suffix("gh/user/1/base", "base", "orig")``
    returns ``"gh/user/1/orig"``. Only the last component is rewritten, so a
    user name that merely contains ``old`` (e.g. ``gh/database/1/base``) is
    left intact.

    Raises:
        RuntimeError: if ``ref`` does not end with ``/<old>``.
    """
    suffix = "/" + old
    if not ref.endswith(suffix):
        raise RuntimeError(
            f"Branch {ref!r} does not end with {suffix!r}; "
            "was this PR created with ghstack?"
        )
    return ref[: -len(old)] + new


def get_pr_stack_from_number(ref: str, repo: Repository) -> List[int]:
    """Return the ghstack stack (base first) that contains the PR named by ``ref``.

    Args:
        ref: Either a PR number (e.g. ``"123"``) or a branch ref. For a branch
            ref, a leading ``refs/heads/`` is stripped and the first PR
            associated with the branch's head commit is used.
        repo: Repository the PR lives in.

    Raises:
        RuntimeError: if no PR is associated with the branch, or the PR body
            does not contain a ghstack stack.
    """
    if ref.isnumeric():
        pr_number = int(ref)
    else:
        branch_name = _strip_prefix(ref, "refs/heads/")
        pulls = repo.get_branch(branch_name).commit.get_pulls()
        first_pull = next(iter(pulls), None)
        if first_pull is None:
            raise RuntimeError(f"Could not find a PR for branch {branch_name!r}.")
        pr_number = first_pull.number

    pr_stack = extract_stack_from_body(repo.get_pull(pr_number).body)

    if not pr_stack:
        raise RuntimeError(
            f"Could not find PR stack in body of ref {ref!r} (PR #{pr_number}). "
            "Please make sure that the PR was created with ghstack."
        )

    return pr_stack


def create_prs_for_orig_branch(pr_stack: List[int], repo: Repository):
    """Open an ``orig``-branch PR for each merged PR in ``pr_stack``.

    ``pr_stack`` is ordered base first, as returned by
    ``extract_stack_from_body``. The first created PR targets ``main``, and
    each later one targets the previous entry's ``orig`` branch. Processing
    stops (without error) at the first PR that is not merged yet. If a PR in
    any state already exists for the same head/base and has the same title as
    the ghstack PR, it is left as is.

    Raises:
        RuntimeError: if a PR's branches do not follow the ghstack
            ``gh/<user>/<n>/base`` <- ``gh/<user>/<n>/head`` layout.
    """
    # Derive links and the head owner from the repo actually being modified
    # instead of hard-coding pytorch/executorch.
    repo_url = f"https://github.com/{repo.full_name}"
    head_owner = repo.owner.login
    # For the first PR, we want to merge to `main` branch, and we will update
    # as we go through the stack
    orig_branch_merge_base = "main"
    for pr_number in pr_stack:
        pr = repo.get_pull(pr_number)
        if not pr.is_merged():
            print("The PR (and stack above) is not merged yet, skipping")
            return
        # Check for invariant: For the current PR, it must be
        # gh/user/x/base <- gh/user/x/head. Raise explicitly rather than
        # `assert` so the check still runs under `python -O`.
        expected_head = _swap_branch_suffix(pr.base.ref, "base", "head")
        if expected_head != pr.head.ref:
            raise RuntimeError(
                f"PR #{pr.number} is {pr.base.ref} <- {pr.head.ref}; "
                f"expected {pr.base.ref} <- {expected_head}."
            )
        # The PR we want to create is then "branch_to_merge" <- gh/user/x/orig
        # gh/user/x/orig is the clean diff between gh/user/x/base <- gh/user/x/head
        orig_branch_merge_head = _swap_branch_suffix(pr.base.ref, "base", "orig")
        bot_metadata = f"""This PR was created by the merge bot to help merge the original PR into the main branch.
ghstack PR number: {repo_url}/pull/{pr.number}
^ Please use this as the source of truth for the PR details, comments, and reviews
ghstack PR base: {repo_url}/tree/{pr.base.ref}
ghstack PR head: {repo_url}/tree/{pr.head.ref}
Merge bot PR base: {repo_url}/tree/{orig_branch_merge_base}
Merge bot PR head: {repo_url}/tree/{orig_branch_merge_head}
@diff-train-skip-merge"""

        existing_orig_prs = repo.get_pulls(
            head=f"{head_owner}:{orig_branch_merge_head}",
            base=orig_branch_merge_base,
            state="all",
        )
        existing_orig_pr = next(iter(existing_orig_prs), None)
        if existing_orig_pr is not None and existing_orig_pr.title == pr.title:
            print(
                f"PR for {orig_branch_merge_head} already exists {existing_orig_pr}"
            )
            # We don't need to create/edit because the head PR is merged and orig is finalized.
        else:
            repo.create_pull(
                base=orig_branch_merge_base,
                head=orig_branch_merge_head,
                title=pr.title,
                body=bot_metadata,
            )
        # Advance the base for the next PR
        orig_branch_merge_base = orig_branch_merge_head


def main():
    """Find the stack for ``--ref`` and open its orig-branch PRs."""
    args = parse_args()

    with Github(auth=Auth.Token(os.environ["GITHUB_TOKEN"])) as gh:
        repo = gh.get_repo(args.repo)
        create_prs_for_orig_branch(get_pr_stack_from_number(args.ref, repo), repo)


if __name__ == "__main__":
    main()