• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import argparse
8import os
9import re
10
11from typing import List
12
13# Provided by the PyGithub pip package.
14from github import Auth, Github
15from github.Repository import Repository
16
17
def parse_args(argv=None):
    """Parse the command-line arguments of this script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process arguments (``sys.argv[1:]``). Defaults to ``None``, which
            keeps the original behavior; passing a list is mainly useful for
            tests or programmatic invocation.

    Returns:
        An ``argparse.Namespace`` with the string attributes ``repo`` and
        ``ref``. argparse exits with status 2 if either is missing.
    """
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
    )
    parser.add_argument(
        "--repo",
        type=str,
        help='The github repo to modify: e.g. "pytorch/executorch".',
        required=True,
    )
    parser.add_argument(
        "--ref",
        type=str,
        # Either a PR number or a branch ref (optionally "refs/heads/"-prefixed).
        help="Ref of a PR in the stack (PR number or branch) to check and create corresponding PR",
        required=True,
    )
    return parser.parse_args(argv)
35
36
def extract_stack_from_body(pr_body: str) -> List[int]:
    """Extract the list of PR numbers from a ghstack-generated PR body.

    Expected format (the ``__->__`` marker may appear on any bullet line)::

        Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
        * #3
        * __->__ #2
        * #1

        <PR description details>

    which yields ``[1, 2, 3]``. Parsing stops at the first blank line after
    the stack bullets, so bullets in the free-form description (for example
    ``* Fixes #123``) are not mistaken for stack entries.

    Args:
        pr_body: The PR description. A falsy value (empty string or ``None``,
            which PyGithub may return for a PR without a description) yields
            an empty list.

    Returns:
        The PR numbers with the base of the stack at index 0, or an empty
        list if no ghstack header is present.
    """
    if not pr_body:
        return []

    ghstack_begin = (
        "Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):"
    )
    # A bullet, optionally followed by text such as "__->__", then " #<number>".
    bullet_re = re.compile(r"\*.*? #(\d+)")

    prs: List[int] = []
    ghstack_begin_seen = False
    for line in pr_body.splitlines():
        if not ghstack_begin_seen:
            # Ignore everything up to and including the ghstack header line.
            ghstack_begin_seen = ghstack_begin in line
            continue
        if not line.strip():
            if prs:
                # The stack list is terminated by the first blank line after it.
                break
            continue
        match = bullet_re.match(line)
        if match:
            prs.append(int(match.group(1)))
    # ghstack lists the newest PR first; callers want the base first.
    return prs[::-1]
69
70
def get_pr_stack_from_number(ref: str, repo: Repository) -> List[int]:
    """Resolve ``ref`` to a ghstack PR and return the PR numbers of its stack.

    Args:
        ref: Either a PR number (e.g. ``"1234"``) or a branch name, optionally
            prefixed with ``refs/heads/``. For a branch, the first PR
            associated with the branch's head commit is used.
        repo: The repository the PR lives in.

    Returns:
        The PR numbers of the whole stack, base of the stack at index 0.

    Raises:
        RuntimeError: If no PR is associated with the branch, or the PR body
            does not contain a ghstack stack.
    """
    if ref.isnumeric():
        pr_number = int(ref)
    else:
        # Strip only a leading "refs/heads/"; str.replace() would also rewrite
        # the substring if it appeared elsewhere in the branch name.
        prefix = "refs/heads/"
        branch_name = ref[len(prefix):] if ref.startswith(prefix) else ref
        first_pr = next(iter(repo.get_branch(branch_name).commit.get_pulls()), None)
        if first_pr is None:
            raise RuntimeError(
                f"Could not find a PR associated with branch {branch_name!r}."
            )
        pr_number = first_pr.number

    pr_stack = extract_stack_from_body(repo.get_pull(pr_number).body)

    if not pr_stack:
        raise RuntimeError(
            f"Could not find PR stack in body of ref {ref!r} (PR #{pr_number}). "
            "Please make sure that the PR was created with ghstack."
        )

    return pr_stack
87
88
def _swap_ghstack_branch_kind(base_ref: str, kind: str) -> str:
    """Turn a ``.../base`` ghstack branch name into the sibling ``.../<kind>``.

    Only the final path component is replaced, so user or branch names that
    merely contain the text "base" are left intact.

    Raises:
        RuntimeError: If ``base_ref`` does not end with ``/base``.
    """
    stem, sep, last = base_ref.rpartition("/")
    if not sep or last != "base":
        raise RuntimeError(
            f"Expected a ghstack base branch ending in '/base', got {base_ref!r}"
        )
    return f"{stem}/{kind}"


def create_prs_for_orig_branch(pr_stack: List[int], repo: Repository):
    """Open merge-bot PRs for the ``orig`` branches of a merged ghstack stack.

    Each ghstack PR in ``pr_stack`` is expected to be
    ``gh/<user>/<n>/base <- gh/<user>/<n>/head``. For each merged PR, a PR
    ``<merge base> <- gh/<user>/<n>/orig`` is created, where the merge base
    is ``main`` for the first PR and the previous PR's ``orig`` branch for
    every later one. Processing stops at the first PR that is not merged.
    If a PR with the same head, base and title already exists (in any
    state), nothing new is created for that entry.

    Args:
        pr_stack: PR numbers of the ghstack stack, base of the stack first.
        repo: The repository to read PRs from and create PRs in.

    Raises:
        RuntimeError: If a PR does not follow the ghstack
            ``.../base <- .../head`` branch layout.
    """
    # Derived from `repo` instead of hard-coding "pytorch/executorch", so the
    # script honors whatever repository it was pointed at.
    owner = repo.owner.login
    repo_url = f"https://github.com/{repo.full_name}"

    # For the first PR, we want to merge to `main` branch, and we will update
    # as we go through the stack.
    orig_branch_merge_base = "main"
    for pr_number in pr_stack:
        pr = repo.get_pull(pr_number)
        if not pr.is_merged():
            print("The PR (and stack above) is not merged yet, skipping")
            return

        # Invariant: the ghstack PR must be gh/user/x/base <- gh/user/x/head.
        # Checked explicitly (not with `assert`) so it survives `python -O`.
        expected_head = _swap_ghstack_branch_kind(pr.base.ref, "head")
        if pr.head.ref != expected_head:
            raise RuntimeError(
                f"PR #{pr.number} does not follow the ghstack branch layout: "
                f"expected head {expected_head!r}, got {pr.head.ref!r}"
            )

        # The PR we want to create is "orig_branch_merge_base" <- gh/user/x/orig;
        # gh/user/x/orig is the clean diff between gh/user/x/base <- gh/user/x/head.
        orig_branch_merge_head = _swap_ghstack_branch_kind(pr.base.ref, "orig")

        existing_orig_pr = repo.get_pulls(
            head=f"{owner}:{orig_branch_merge_head}",
            base=orig_branch_merge_base,
            state="all",
        )
        if existing_orig_pr.totalCount > 0 and existing_orig_pr[0].title == pr.title:
            print(
                f"PR for {orig_branch_merge_head} already exists {existing_orig_pr[0]}"
            )
            # No create/edit needed: the head PR is merged and orig is finalized.
        else:
            bot_metadata = "\n".join(
                [
                    "This PR was created by the merge bot to help merge the original PR into the main branch.",
                    f"ghstack PR number: {repo_url}/pull/{pr.number}",
                    "^ Please use this as the source of truth for the PR details, comments, and reviews",
                    f"ghstack PR base: {repo_url}/tree/{pr.base.ref}",
                    f"ghstack PR head: {repo_url}/tree/{pr.head.ref}",
                    f"Merge bot PR base: {repo_url}/tree/{orig_branch_merge_base}",
                    f"Merge bot PR head: {repo_url}/tree/{orig_branch_merge_head}",
                    "@diff-train-skip-merge",
                ]
            )
            repo.create_pull(
                base=orig_branch_merge_base,
                head=orig_branch_merge_head,
                title=pr.title,
                body=bot_metadata,
            )
        # Advance the base for the next PR in the stack.
        orig_branch_merge_base = orig_branch_merge_head
131
132
def main():
    """Create merge-bot PRs for the ghstack stack identified by ``--ref``.

    Authenticates with the token in the ``GITHUB_TOKEN`` environment
    variable and exits with an explanatory message if it is unset or empty.
    """
    args = parse_args()

    token = os.environ.get("GITHUB_TOKEN")
    if not token:
        # Fail with a clear message instead of a bare KeyError.
        raise SystemExit(
            "GITHUB_TOKEN environment variable must be set to a GitHub token."
        )

    with Github(auth=Auth.Token(token)) as gh:
        repo = gh.get_repo(args.repo)
        pr_stack = get_pr_stack_from_number(args.ref, repo)
        create_prs_for_orig_branch(pr_stack, repo)


if __name__ == "__main__":
    main()
143