1#!/usr/bin/env python3
2
3# NB: the following functions are used in Meta-internal workflows
4# (github_first_try_merge/my_handler.py) and thus have functionality limitations
5# (no `git` command access, no network access besides the strict allow list):
6#
7# find_matching_merge_rule
8# read_merge_rules
9#
10# Also any signature changes of these functions, as well as changes to the `GitHubPR`
11# class, will likely require corresponding changes for the internal workflows.
12
13import base64
14import json
15import os
16import re
17import time
18import urllib.parse
19from collections import defaultdict
20from dataclasses import dataclass
21from functools import lru_cache
22from pathlib import Path
23from typing import (
24    Any,
25    Callable,
26    cast,
27    Dict,
28    Iterable,
29    List,
30    NamedTuple,
31    Optional,
32    Pattern,
33    Tuple,
34)
35from warnings import warn
36
37import yaml
38from github_utils import (
39    gh_fetch_json_list,
40    gh_fetch_merge_base,
41    gh_fetch_url,
42    gh_graphql,
43    gh_post_commit_comment,
44    gh_post_pr_comment,
45    gh_update_pr_state,
46    GitHubComment,
47)
48from gitutils import (
49    are_ghstack_branches_in_sync,
50    get_git_remote_name,
51    get_git_repo_dir,
52    GitRepo,
53    patterns_to_regex,
54    retries_decorator,
55)
56from label_utils import (
57    gh_add_labels,
58    gh_remove_label,
59    has_required_labels,
60    LABEL_ERR_MSG,
61)
62from trymerge_explainer import get_revert_message, TryMergeExplainer
63
64
65# labels
66MERGE_IN_PROGRESS_LABEL = "merging"
67MERGE_COMPLETE_LABEL = "merged"
68
69
70class JobCheckState(NamedTuple):
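    """State of a single check (check run, classic commit status, or whole workflow) reported on the PR's last commit."""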
71    name: str
72    url: str
73    status: Optional[str]
74    classification: Optional[str]
75    job_id: Optional[int]
76    title: Optional[str]
77    summary: Optional[str]
78
79
80JobNameToStateDict = Dict[str, JobCheckState]
81
82
83class WorkflowCheckState:
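    """State of a single workflow run plus the per-job states collected under it."""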
84    def __init__(self, name: str, url: str, run_id: int, status: Optional[str]):
85        self.name: str = name
86        self.url: str = url
87        self.run_id: int = run_id
88        self.status: Optional[str] = status
89        self.jobs: JobNameToStateDict = {}
90
91
92GH_PR_REVIEWS_FRAGMENT = """
93fragment PRReviews on PullRequestReviewConnection {
94  nodes {
95    author {
96      login
97    }
98    bodyText
99    createdAt
100    authorAssociation
101    editor {
102      login
103    }
104    databaseId
105    url
106    state
107  }
108  pageInfo {
109    startCursor
110    hasPreviousPage
111  }
112}
113"""
114
115GH_CHECKSUITES_FRAGMENT = """
116fragment PRCheckSuites on CheckSuiteConnection {
117  edges {
118    node {
119      app {
120        name
121        databaseId
122      }
123      workflowRun {
124        workflow {
125          name
126          databaseId
127        }
128        databaseId
129        url
130      }
131      checkRuns(first: 50) {
132        nodes {
133          name
134          conclusion
135          detailsUrl
136          databaseId
137          title
138          summary
139        }
140        pageInfo {
141          endCursor
142          hasNextPage
143        }
144      }
145      conclusion
146    }
147    cursor
148  }
149  pageInfo {
150    hasNextPage
151  }
152}
153"""
154
155GH_COMMIT_AUTHORS_FRAGMENT = """
156fragment CommitAuthors on PullRequestCommitConnection {
157  nodes {
158    commit {
159      authors(first: 2) {
160        nodes {
161          user {
162            login
163          }
164          email
165          name
166        }
167      }
168      oid
169    }
170  }
171  pageInfo {
172    endCursor
173    hasNextPage
174  }
175}
176"""
177
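# Top-level query that fetches most of the PR data needed below in one request:
# metadata, commit authors, check suites on the last commit, changed files,
# reviews, comments, and labels.  Paginated fields are fetched further via the
# GH_GET_PR_NEXT_* / GH_GET_PR_PREV_* queries defined after it.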
178GH_GET_PR_INFO_QUERY = (
179    GH_PR_REVIEWS_FRAGMENT
180    + GH_CHECKSUITES_FRAGMENT
181    + GH_COMMIT_AUTHORS_FRAGMENT
182    + """
183query ($owner: String!, $name: String!, $number: Int!) {
184  repository(owner: $owner, name: $name) {
185    pullRequest(number: $number) {
186      closed
187      isCrossRepository
188      author {
189        login
190      }
191      title
192      body
193      headRefName
194      headRepository {
195        nameWithOwner
196      }
197      baseRefName
198      baseRefOid
199      baseRepository {
200        nameWithOwner
201        isPrivate
202        defaultBranchRef {
203          name
204        }
205      }
206      mergeCommit {
207        oid
208      }
209      commits_with_authors: commits(first: 100) {
210        ...CommitAuthors
211        totalCount
212      }
213      commits(last: 1) {
214        nodes {
215          commit {
216            checkSuites(first: 10) {
217              ...PRCheckSuites
218            }
219            status {
220              contexts {
221                context
222                state
223                targetUrl
224              }
225            }
226            oid
227          }
228        }
229      }
230      changedFiles
231      files(first: 100) {
232        nodes {
233          path
234        }
235        pageInfo {
236          endCursor
237          hasNextPage
238        }
239      }
240      reviews(last: 100) {
241        ...PRReviews
242      }
243      comments(last: 5) {
244        nodes {
245          bodyText
246          createdAt
247          author {
248            login
249          }
250          authorAssociation
251          editor {
252            login
253          }
254          databaseId
255          url
256        }
257        pageInfo {
258          startCursor
259          hasPreviousPage
260        }
261      }
262      labels(first: 100) {
263        edges {
264          node {
265            name
266          }
267        }
268      }
269    }
270  }
271}
272"""
273)
274
275GH_GET_PR_NEXT_FILES_QUERY = """
276query ($owner: String!, $name: String!, $number: Int!, $cursor: String!) {
277  repository(name: $name, owner: $owner) {
278    pullRequest(number: $number) {
279      files(first: 100, after: $cursor) {
280        nodes {
281          path
282        }
283        pageInfo {
284          endCursor
285          hasNextPage
286        }
287      }
288    }
289  }
290}
291"""
292
293GH_GET_PR_NEXT_CHECKSUITES = (
294    GH_CHECKSUITES_FRAGMENT
295    + """
296query ($owner: String!, $name: String!, $number: Int!, $cursor: String!) {
297  repository(name: $name, owner: $owner) {
298    pullRequest(number: $number) {
299      commits(last: 1) {
300        nodes {
301          commit {
302            oid
303            checkSuites(first: 10, after: $cursor) {
304              ...PRCheckSuites
305            }
306          }
307        }
308      }
309    }
310  }
311}
312"""
313)
314
315GH_GET_PR_NEXT_CHECK_RUNS = """
316query ($owner: String!, $name: String!, $number: Int!, $cs_cursor: String, $cr_cursor: String!) {
317  repository(name: $name, owner: $owner) {
318    pullRequest(number: $number) {
319      commits(last: 1) {
320        nodes {
321          commit {
322            oid
323            checkSuites(first: 1, after: $cs_cursor) {
324              nodes {
325                checkRuns(first: 100, after: $cr_cursor) {
326                  nodes {
327                    name
328                    conclusion
329                    detailsUrl
330                    databaseId
331                    title
332                    summary
333                  }
334                  pageInfo {
335                    endCursor
336                    hasNextPage
337                  }
338                }
339              }
340            }
341          }
342        }
343      }
344    }
345  }
346}
347"""
348
349GH_GET_PR_PREV_COMMENTS = """
350query ($owner: String!, $name: String!, $number: Int!, $cursor: String!) {
351  repository(name: $name, owner: $owner) {
352    pullRequest(number: $number) {
353      comments(last: 100, before: $cursor) {
354        nodes {
355          bodyText
356          createdAt
357          author {
358            login
359          }
360          authorAssociation
361          editor {
362            login
363          }
364          databaseId
365          url
366        }
367        pageInfo {
368          startCursor
369          hasPreviousPage
370        }
371      }
372    }
373  }
374}
375"""
376
377# This query needs read-org permission
378GH_GET_TEAM_MEMBERS_QUERY = """
379query($org: String!, $name: String!, $cursor: String) {
380  organization(login: $org) {
381    team(slug: $name) {
382      members(first: 100, after: $cursor) {
383        nodes {
384          login
385        }
386        pageInfo {
387          hasNextPage
388          endCursor
389        }
390      }
391    }
392  }
393}
394"""
395
396GH_GET_PR_NEXT_AUTHORS_QUERY = (
397    GH_COMMIT_AUTHORS_FRAGMENT
398    + """
399query ($owner: String!, $name: String!, $number: Int!, $cursor: String) {
400  repository(name: $name, owner: $owner) {
401    pullRequest(number: $number) {
402      commits_with_authors: commits(first: 100, after: $cursor) {
403        ...CommitAuthors
404      }
405    }
406  }
407}
408"""
409)
410
411GH_GET_PR_PREV_REVIEWS_QUERY = (
412    GH_PR_REVIEWS_FRAGMENT
413    + """
414query ($owner: String!, $name: String!, $number: Int!, $cursor: String!) {
415  repository(name: $name, owner: $owner) {
416    pullRequest(number: $number) {
417      reviews(last: 100, before: $cursor) {
418        ...PRReviews
419      }
420    }
421  }
422}
423"""
424)
425
426GH_GET_REPO_SUBMODULES = """
427query ($owner: String!, $name: String!) {
428  repository(owner: $owner, name: $name) {
429    submodules(first: 100) {
430      nodes {
431        path
432      }
433      pageInfo {
434        endCursor
435        hasNextPage
436      }
437    }
438  }
439}
440"""
441
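# Regexes for recognizing ghstack branches and extracting PR metadata from
# commit messages and PR bodies, e.g. RE_GHSTACK_HEAD_REF matches head refs of
# the form "gh/<username>/<number>/head".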
442RE_GHSTACK_HEAD_REF = re.compile(r"^(gh/[^/]+/[0-9]+/)head$")
443RE_GHSTACK_DESC = re.compile(r"Stack.*:\r?\n(\* [^\r\n]+\r?\n)+", re.MULTILINE)
444RE_PULL_REQUEST_RESOLVED = re.compile(
445    r"Pull Request resolved: "
446    r"https://github.com/(?P<owner>[^/]+)/(?P<repo>[^/]+)/pull/(?P<number>[0-9]+)",
447    re.MULTILINE,
448)
449RE_PR_CC_LINE = re.compile(r"^cc:? @\w+.*\r?\n?$", re.MULTILINE)
450RE_DIFF_REV = re.compile(r"^Differential Revision:.+?(D[0-9]+)", re.MULTILINE)
451CIFLOW_LABEL = re.compile(r"^ciflow/.+")
452CIFLOW_TRUNK_LABEL = re.compile(r"^ciflow/trunk")
453MERGE_RULE_PATH = Path(".github") / "merge_rules.yaml"
454ROCKSET_MERGES_COLLECTION = "merges"
455ROCKSET_MERGES_WORKSPACE = "commons"
456REMOTE_MAIN_BRANCH = "origin/main"
457DRCI_CHECKRUN_NAME = "Dr.CI"
458INTERNAL_CHANGES_CHECKRUN_NAME = "Meta Internal-Only Changes Check"
459HAS_NO_CONNECTED_DIFF_TITLE = (
460    "There is no internal Diff connected, this can be merged now"
461)
462# This could be set to -1 to ignore all flaky and broken trunk failures. On the
463# other hand, using a large value like 10 here might be useful in a SEV situation
464IGNORABLE_FAILED_CHECKS_THESHOLD = 10
465
466
467def gh_get_pr_info(org: str, proj: str, pr_no: int) -> Any:
468    rc = gh_graphql(GH_GET_PR_INFO_QUERY, name=proj, owner=org, number=pr_no)
469    return rc["data"]["repository"]["pullRequest"]
470
471
472@lru_cache(maxsize=None)
473def gh_get_team_members(org: str, name: str) -> List[str]:
474    rc: List[str] = []
475    team_members: Dict[str, Any] = {
476        "pageInfo": {"hasNextPage": "true", "endCursor": None}
477    }
478    while bool(team_members["pageInfo"]["hasNextPage"]):
479        query = gh_graphql(
480            GH_GET_TEAM_MEMBERS_QUERY,
481            org=org,
482            name=name,
483            cursor=team_members["pageInfo"]["endCursor"],
484        )
485        team = query["data"]["organization"]["team"]
486        if team is None:
487            warn(f"Requested non-existing team {org}/{name}")
488            return []
489        team_members = team["members"]
490        rc += [member["login"] for member in team_members["nodes"]]
491    return rc
492
493
494def get_check_run_name_prefix(workflow_run: Any) -> str:
495    if workflow_run is None:
496        return ""
497    else:
498        return f'{workflow_run["workflow"]["name"]} / '
499
500
501def is_passing_status(status: Optional[str]) -> bool:
502    return status is not None and status.upper() in ["SUCCESS", "SKIPPED", "NEUTRAL"]
503
504
505def add_workflow_conclusions(
506    checksuites: Any,
507    get_next_checkruns_page: Callable[[List[Dict[str, Dict[str, Any]]], int, Any], Any],
508    get_next_checksuites: Callable[[Any], Any],
509) -> JobNameToStateDict:
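    """Walks all check suites (and their paginated check runs) and flattens them
    into a mapping from job name to JobCheckState, keeping only the latest run
    of each workflow."""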
510    # graphql seems to favor the most recent workflow run, so in theory we
511    # shouldn't need to account for reruns, but do it just in case
512
513    # workflow -> job -> job info
514    workflows: Dict[str, WorkflowCheckState] = {}
515
516    # for the jobs that don't have a workflow
517    no_workflow_obj: WorkflowCheckState = WorkflowCheckState("", "", 0, None)
518
519    def add_conclusions(edges: Any) -> None:
520        for edge_idx, edge in enumerate(edges):
521            node = edge["node"]
522            workflow_run = node["workflowRun"]
523            checkruns = node["checkRuns"]
524
525            workflow_obj: WorkflowCheckState = no_workflow_obj
526
527            if workflow_run is not None:
528                # This is the usual workflow run ID we see on GitHub
529                workflow_run_id = workflow_run["databaseId"]
530                # While this is the metadata name and ID of the workflow itself
531                workflow_name = workflow_run["workflow"]["name"]
532                workflow_id = workflow_run["workflow"]["databaseId"]
533
534                workflow_conclusion = node["conclusion"]
535                # Do not override existing status with cancelled
536                if workflow_conclusion == "CANCELLED" and workflow_name in workflows:
537                    continue
538
539                # Only keep the latest workflow run for each workflow; heuristically,
540                # it's the run with the largest run ID
541                if (
542                    workflow_id not in workflows
543                    or workflows[workflow_id].run_id < workflow_run_id
544                ):
545                    workflows[workflow_id] = WorkflowCheckState(
546                        name=workflow_name,
547                        status=workflow_conclusion,
548                        url=workflow_run["url"],
549                        run_id=workflow_run_id,
550                    )
551                workflow_obj = workflows[workflow_id]
552
553            while checkruns is not None:
554                for checkrun_node in checkruns["nodes"]:
555                    if not isinstance(checkrun_node, dict):
556                        warn(f"Expected dictionary, but got {type(checkrun_node)}")
557                        continue
558                    checkrun_name = f'{get_check_run_name_prefix(workflow_run)}{checkrun_node["name"]}'
559                    existing_checkrun = workflow_obj.jobs.get(checkrun_name)
560                    if existing_checkrun is None or not is_passing_status(
561                        existing_checkrun.status
562                    ):
563                        workflow_obj.jobs[checkrun_name] = JobCheckState(
564                            checkrun_name,
565                            checkrun_node["detailsUrl"],
566                            checkrun_node["conclusion"],
567                            classification=None,
568                            job_id=checkrun_node["databaseId"],
569                            title=checkrun_node["title"],
570                            summary=checkrun_node["summary"],
571                        )
572
573                if bool(checkruns["pageInfo"]["hasNextPage"]):
574                    checkruns = get_next_checkruns_page(edges, edge_idx, checkruns)
575                else:
576                    checkruns = None
577
578    all_edges = checksuites["edges"].copy()
579    while bool(checksuites["pageInfo"]["hasNextPage"]):
580        checksuites = get_next_checksuites(checksuites)
581        all_edges.extend(checksuites["edges"])
582
583    add_conclusions(all_edges)
584
585    # Flatten the dictionaries.  If a workflow run has jobs, add the jobs but
586    # not the workflow itself.  We care more about the jobs in the workflow that
587    # ran than about the containing workflow.
588    res: JobNameToStateDict = {}
589    for workflow in workflows.values():
590        if len(workflow.jobs) > 0:
591            for job_name, job in workflow.jobs.items():
592                res[job_name] = job
593        else:
594            res[workflow.name] = JobCheckState(
595                workflow.name,
596                workflow.url,
597                workflow.status,
598                classification=None,
599                job_id=None,
600                title=None,
601                summary=None,
602            )
603    for job_name, job in no_workflow_obj.jobs.items():
604        res[job_name] = job
605    return res
606
607
608def parse_args() -> Any:
609    from argparse import ArgumentParser
610
611    parser = ArgumentParser("Merge PR into default branch")
612    parser.add_argument("--dry-run", action="store_true")
613    parser.add_argument("--revert", action="store_true")
614    parser.add_argument("--force", action="store_true")
615    parser.add_argument("--ignore-current", action="store_true")
616    parser.add_argument("--check-mergeability", action="store_true")
617    parser.add_argument("--comment-id", type=int)
618    parser.add_argument("--reason", type=str)
619    parser.add_argument("pr_num", type=int)
620    return parser.parse_args()
621
622
623def can_skip_internal_checks(pr: "GitHubPR", comment_id: Optional[int] = None) -> bool:
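    """Internal checks may only be skipped when the triggering comment was
    authored (and never edited) by facebook-github-bot."""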
624    if comment_id is None:
625        return False
626    comment = pr.get_comment_by_id(comment_id)
627    if comment.editor_login is not None:
628        return False
629    return comment.author_login == "facebook-github-bot"
630
631
632def _revlist_to_prs(
633    repo: GitRepo,
634    pr: "GitHubPR",
635    rev_list: Iterable[str],
636    should_skip: Optional[Callable[[int, "GitHubPR"], bool]] = None,
637) -> List[Tuple["GitHubPR", str]]:
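    """Maps each revision in rev_list to the GitHubPR it resolves to via the
    "Pull Request resolved" line in its commit message, skipping revisions for
    which should_skip returns True."""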
638    rc: List[Tuple[GitHubPR, str]] = []
639    for idx, rev in enumerate(rev_list):
640        msg = repo.commit_message(rev)
641        m = RE_PULL_REQUEST_RESOLVED.search(msg)
642        if m is None:
643            raise RuntimeError(
644                f"Could not find PR-resolved string in {msg} of ghstacked PR {pr.pr_num}"
645            )
646        if pr.org != m.group("owner") or pr.project != m.group("repo"):
647            raise RuntimeError(
648                f"PR {m.group('number')} resolved to wrong owner/repo pair"
649            )
650        pr_num = int(m.group("number"))
651        candidate = GitHubPR(pr.org, pr.project, pr_num) if pr_num != pr.pr_num else pr
652        if should_skip is not None and should_skip(idx, candidate):
653            continue
654        rc.append((candidate, rev))
655    return rc
656
657
658def get_ghstack_prs(
659    repo: GitRepo, pr: "GitHubPR", open_only: bool = True
660) -> List[Tuple["GitHubPR", str]]:
661    """
662    Get the PRs in the stack that are below this PR (inclusive).  Raises an error if any of the open PRs are out of sync.
663    :param open_only: Only return open PRs
664    """
665    # For ghstack, cherry-pick commits based on origin
666    orig_ref = f"{repo.remote}/{pr.get_ghstack_orig_ref()}"
667    rev_list = repo.revlist(f"{pr.default_branch()}..{orig_ref}")
668
669    def skip_func(idx: int, candidate: "GitHubPR") -> bool:
670        if not open_only or not candidate.is_closed():
671            return False
672        print(
673            f"Skipping {idx+1} of {len(rev_list)} PR (#{candidate.pr_num}) as it's already been merged"
674        )
675        return True
676
677    assert pr.is_ghstack_pr()
678    entire_stack = _revlist_to_prs(repo, pr, reversed(rev_list), skip_func)
679
680    for stacked_pr, rev in entire_stack:
681        if stacked_pr.is_closed():
682            continue
683        base_ref = stacked_pr.base_ref()
684        if base_ref == pr.default_branch():
685            base_ref = repo.get_merge_base(
686                f"{repo.remote}/{base_ref}", f"{repo.remote}/{stacked_pr.head_ref()}"
687            )
688        if not are_ghstack_branches_in_sync(repo, stacked_pr.head_ref(), base_ref):
689            raise RuntimeError(
690                f"PR {stacked_pr.pr_num} is out of sync with the corresponding revision {rev} on "
691                + f"branch {stacked_pr.get_ghstack_orig_ref()} that would be merged into {stacked_pr.default_branch()}.  "
692                + "This usually happens because there is a non-ghstack change in the PR.  "
693                + f"Please sync them and try again (ex. make the changes on {orig_ref} and run ghstack)."
694            )
695    return entire_stack
696
697
698class GitHubPR:
699    def __init__(self, org: str, project: str, pr_num: int) -> None:
700        assert isinstance(pr_num, int)
701        self.org = org
702        self.project = project
703        self.pr_num = pr_num
704        self.info = gh_get_pr_info(org, project, pr_num)
705        self.changed_files: Optional[List[str]] = None
706        self.labels: Optional[List[str]] = None
707        self.conclusions: Optional[JobNameToStateDict] = None
708        self.comments: Optional[List[GitHubComment]] = None
709        self._authors: Optional[List[Tuple[str, str]]] = None
710        self._reviews: Optional[List[Tuple[str, str]]] = None
711        self.merge_base: Optional[str] = None
712        self.submodules: Optional[List[str]] = None
713
714    def is_closed(self) -> bool:
715        return bool(self.info["closed"])
716
717    def is_cross_repo(self) -> bool:
718        return bool(self.info["isCrossRepository"])
719
720    def base_ref(self) -> str:
721        return cast(str, self.info["baseRefName"])
722
723    def default_branch(self) -> str:
724        return cast(str, self.info["baseRepository"]["defaultBranchRef"]["name"])
725
726    def head_ref(self) -> str:
727        return cast(str, self.info["headRefName"])
728
729    def is_ghstack_pr(self) -> bool:
730        return RE_GHSTACK_HEAD_REF.match(self.head_ref()) is not None
731
732    def get_ghstack_orig_ref(self) -> str:
733        assert self.is_ghstack_pr()
734        return re.sub(r"/head$", "/orig", self.head_ref())
735
736    def is_base_repo_private(self) -> bool:
737        return bool(self.info["baseRepository"]["isPrivate"])
738
739    def get_changed_files_count(self) -> int:
740        return int(self.info["changedFiles"])
741
742    def last_commit(self) -> Any:
743        return self.info["commits"]["nodes"][-1]["commit"]
744
745    def get_merge_base(self) -> str:
746        if self.merge_base:
747            return self.merge_base
748
749        last_commit_oid = self.last_commit()["oid"]
750        # NB: We could use self.base_ref() here for a regular PR; however, that doesn't
751        # work for ghstack, where the base is a custom branch, i.e. gh/USER/ID/base,
752        # so let's just use main instead
753        self.merge_base = gh_fetch_merge_base(
754            self.org, self.project, last_commit_oid, self.default_branch()
755        )
756
757        # Fall back to baseRefOid if the API call fails, e.g. due to rate limiting. Note that baseRefOid
758        # points to the base ref associated with the PR or, in other words, the head of main
759        # when the PR was created or rebased. This is not necessarily the merge base commit,
760        # but it can serve as a fallback in most cases and it's readily available as part
761        # of the PR info
762        if not self.merge_base:
763            self.merge_base = cast(str, self.info["baseRefOid"])
764
765        return self.merge_base
766
767    def get_changed_files(self) -> List[str]:
768        if self.changed_files is None:
769            info = self.info
770            unique_changed_files = set()
771            # Do not try to fetch more than 10K files
772            for _ in range(100):
773                unique_changed_files.update([x["path"] for x in info["files"]["nodes"]])
774                if not info["files"]["pageInfo"]["hasNextPage"]:
775                    break
776                rc = gh_graphql(
777                    GH_GET_PR_NEXT_FILES_QUERY,
778                    name=self.project,
779                    owner=self.org,
780                    number=self.pr_num,
781                    cursor=info["files"]["pageInfo"]["endCursor"],
782                )
783                info = rc["data"]["repository"]["pullRequest"]
784            self.changed_files = list(unique_changed_files)
785
786        if len(self.changed_files) != self.get_changed_files_count():
787            raise RuntimeError("Changed file count mismatch")
788        return self.changed_files
789
790    def get_submodules(self) -> List[str]:
791        if self.submodules is None:
792            rc = gh_graphql(GH_GET_REPO_SUBMODULES, name=self.project, owner=self.org)
793            info = rc["data"]["repository"]["submodules"]
794            self.submodules = [s["path"] for s in info["nodes"]]
795        return self.submodules
796
797    def get_changed_submodules(self) -> List[str]:
798        submodules = self.get_submodules()
799        return [f for f in self.get_changed_files() if f in submodules]
800
801    def has_invalid_submodule_updates(self) -> bool:
802        """Submodule updates in a PR are invalid if the submodule keyword
803        is mentioned neither in the title nor in the body/description
804        nor in any of the labels.
805        """
806        return (
807            len(self.get_changed_submodules()) > 0
808            and "submodule" not in self.get_title().lower()
809            and "submodule" not in self.get_body().lower()
810            and all("submodule" not in label for label in self.get_labels())
811        )
812
813    def _get_reviews(self) -> List[Tuple[str, str]]:
814        if self._reviews is None:
815            self._reviews = []
816            info = self.info
817            for _ in range(100):
818                nodes = info["reviews"]["nodes"]
819                self._reviews = [
820                    (node["author"]["login"], node["state"]) for node in nodes
821                ] + self._reviews
822                if not info["reviews"]["pageInfo"]["hasPreviousPage"]:
823                    break
824                rc = gh_graphql(
825                    GH_GET_PR_PREV_REVIEWS_QUERY,
826                    name=self.project,
827                    owner=self.org,
828                    number=self.pr_num,
829                    cursor=info["reviews"]["pageInfo"]["startCursor"],
830                )
831                info = rc["data"]["repository"]["pullRequest"]
832        reviews = {}
833        for author, state in self._reviews:
834            if state != "COMMENTED":
835                reviews[author] = state
836        return list(reviews.items())
837
838    def get_approved_by(self) -> List[str]:
839        return [login for (login, state) in self._get_reviews() if state == "APPROVED"]
840
841    def get_commit_count(self) -> int:
842        return int(self.info["commits_with_authors"]["totalCount"])
843
844    def get_pr_creator_login(self) -> str:
845        return cast(str, self.info["author"]["login"])
846
847    def _fetch_authors(self) -> List[Tuple[str, str]]:
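        """Returns (github_login, "Name <email>") pairs for every commit author
        on the PR, paginating through all commits; the login is empty when the
        author has no associated GitHub account."""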
848        if self._authors is not None:
849            return self._authors
850        authors: List[Tuple[str, str]] = []
851
852        def add_authors(info: Dict[str, Any]) -> None:
853            for node in info["commits_with_authors"]["nodes"]:
854                for author_node in node["commit"]["authors"]["nodes"]:
855                    user_node = author_node["user"]
856                    author = f"{author_node['name']} <{author_node['email']}>"
857                    if user_node is None:
858                        # If the author is not a GitHub user, the user node will be null
859                        authors.append(("", author))
860                    else:
861                        authors.append((cast(str, user_node["login"]), author))
862
863        info = self.info
864        for _ in range(100):
865            add_authors(info)
866            if not info["commits_with_authors"]["pageInfo"]["hasNextPage"]:
867                break
868            rc = gh_graphql(
869                GH_GET_PR_NEXT_AUTHORS_QUERY,
870                name=self.project,
871                owner=self.org,
872                number=self.pr_num,
873                cursor=info["commits_with_authors"]["pageInfo"]["endCursor"],
874            )
875            info = rc["data"]["repository"]["pullRequest"]
876        self._authors = authors
877        return authors
878
879    def get_committer_login(self, num: int = 0) -> str:
880        return self._fetch_authors()[num][0]
881
882    def get_committer_author(self, num: int = 0) -> str:
883        return self._fetch_authors()[num][1]
884
885    def get_labels(self) -> List[str]:
886        if self.labels is not None:
887            return self.labels
888        labels = (
889            [node["node"]["name"] for node in self.info["labels"]["edges"]]
890            if "labels" in self.info
891            else []
892        )
893        self.labels = labels
894        return self.labels
895
896    def get_checkrun_conclusions(self) -> JobNameToStateDict:
897        """Returns a dict mapping checkrun name to its JobCheckState"""
898        if self.conclusions is not None:
899            return self.conclusions
900        orig_last_commit = self.last_commit()
901
902        def get_pr_next_check_runs(
903            edges: List[Dict[str, Dict[str, Any]]], edge_idx: int, checkruns: Any
904        ) -> Any:
905            rc = gh_graphql(
906                GH_GET_PR_NEXT_CHECK_RUNS,
907                name=self.project,
908                owner=self.org,
909                number=self.pr_num,
910                cs_cursor=edges[edge_idx - 1]["cursor"] if edge_idx > 0 else None,
911                cr_cursor=checkruns["pageInfo"]["endCursor"],
912            )
913            last_commit = rc["data"]["repository"]["pullRequest"]["commits"]["nodes"][
914                -1
915            ]["commit"]
916            checkruns = last_commit["checkSuites"]["nodes"][-1]["checkRuns"]
917            return checkruns
918
919        def get_pr_next_checksuites(checksuites: Any) -> Any:
920            rc = gh_graphql(
921                GH_GET_PR_NEXT_CHECKSUITES,
922                name=self.project,
923                owner=self.org,
924                number=self.pr_num,
925                cursor=checksuites["edges"][-1]["cursor"],
926            )
927            info = rc["data"]["repository"]["pullRequest"]
928            last_commit = info["commits"]["nodes"][-1]["commit"]
929            if last_commit["oid"] != orig_last_commit["oid"]:
930                raise RuntimeError("Last commit changed on PR")
931            return last_commit["checkSuites"]
932
933        checksuites = orig_last_commit["checkSuites"]
934
935        self.conclusions = add_workflow_conclusions(
936            checksuites, get_pr_next_check_runs, get_pr_next_checksuites
937        )
938
939        # Append old-style statuses (like ones populated by CircleCI or EasyCLA) to conclusions
940        if orig_last_commit["status"] and orig_last_commit["status"]["contexts"]:
941            for status in orig_last_commit["status"]["contexts"]:
942                name = status["context"]
943                self.conclusions[name] = JobCheckState(
944                    name,
945                    status["targetUrl"],
946                    status["state"],
947                    classification=None,
948                    job_id=None,
949                    title=None,
950                    summary=None,
951                )
952
953        return self.conclusions
954
955    def get_authors(self) -> Dict[str, str]:
956        rc = {}
957        for idx in range(len(self._fetch_authors())):
958            rc[self.get_committer_login(idx)] = self.get_committer_author(idx)
959
960        return rc
961
962    def get_author(self) -> str:
963        authors = self.get_authors()
964        if len(authors) == 1:
965            return next(iter(authors.values()))
966        creator = self.get_pr_creator_login()
967        # If the PR creator is not among the authors,
968        # assume it was authored by the first commit author
969        if creator not in authors:
970            return self.get_committer_author(0)
971        return authors[creator]
972
973    def get_title(self) -> str:
974        return cast(str, self.info["title"])
975
976    def get_body(self) -> str:
977        return cast(str, self.info["body"])
978
979    def get_merge_commit(self) -> Optional[str]:
980        mc = self.info["mergeCommit"]
981        return mc["oid"] if mc is not None else None
982
983    def get_pr_url(self) -> str:
984        return f"https://github.com/{self.org}/{self.project}/pull/{self.pr_num}"
985
986    @staticmethod
987    def _comment_from_node(node: Any) -> GitHubComment:
988        editor = node["editor"]
989        return GitHubComment(
990            body_text=node["bodyText"],
991            created_at=node["createdAt"] if "createdAt" in node else "",
992            author_login=node["author"]["login"],
993            author_association=node["authorAssociation"],
994            editor_login=editor["login"] if editor else None,
995            database_id=node["databaseId"],
996            url=node["url"],
997        )
998
999    def get_comments(self) -> List[GitHubComment]:
1000        if self.comments is not None:
1001            return self.comments
1002        self.comments = []
1003        info = self.info["comments"]
1004        # Do not try to fetch more than 10K comments
1005        for _ in range(100):
1006            self.comments = [
1007                self._comment_from_node(node) for node in info["nodes"]
1008            ] + self.comments
1009            if not info["pageInfo"]["hasPreviousPage"]:
1010                break
1011            rc = gh_graphql(
1012                GH_GET_PR_PREV_COMMENTS,
1013                name=self.project,
1014                owner=self.org,
1015                number=self.pr_num,
1016                cursor=info["pageInfo"]["startCursor"],
1017            )
1018            info = rc["data"]["repository"]["pullRequest"]["comments"]
1019        return self.comments
1020
1021    def get_last_comment(self) -> GitHubComment:
1022        return self._comment_from_node(self.info["comments"]["nodes"][-1])
1023
1024    def get_comment_by_id(self, database_id: int) -> GitHubComment:
1025        if self.comments is None:
1026            # Fastpath - try searching in partial prefetched comments
1027            for node in self.info["comments"]["nodes"]:
1028                comment = self._comment_from_node(node)
1029                if comment.database_id == database_id:
1030                    return comment
1031
1032        for comment in self.get_comments():
1033            if comment.database_id == database_id:
1034                return comment
1035
1036        # The comment could have actually been a review left on the PR (the message written alongside the review).
1037        # (This is generally done to trigger the merge right when a comment is left)
1038        # Check those review comments to see if one of those was the comment in question.
1039        for node in self.info["reviews"]["nodes"]:
1040            # These review comments contain all the fields regular comments need
1041            comment = self._comment_from_node(node)
1042            if comment.database_id == database_id:
1043                return comment
1044
1045        raise RuntimeError(f"Comment with id {database_id} not found")
1046
1047    def get_diff_revision(self) -> Optional[str]:
1048        rc = RE_DIFF_REV.search(self.get_body())
1049        return rc.group(1) if rc is not None else None
1050
1051    def has_internal_changes(self) -> bool:
1052        checkrun_name = INTERNAL_CHANGES_CHECKRUN_NAME
1053        if self.get_diff_revision() is None:
1054            return False
1055        checks = self.get_checkrun_conclusions()
1056        if checks is None or checkrun_name not in checks:
1057            return False
1058        return checks[checkrun_name].status != "SUCCESS"
1059
1060    def has_no_connected_diff(self) -> bool:
1061        checkrun_name = INTERNAL_CHANGES_CHECKRUN_NAME
1062        checks = self.get_checkrun_conclusions()
1063        if checks is None or checkrun_name not in checks:
1064            return False
1065        return checks[checkrun_name].title == HAS_NO_CONNECTED_DIFF_TITLE
1066
1067    def merge_ghstack_into(
1068        self,
1069        repo: GitRepo,
1070        skip_mandatory_checks: bool,
1071        comment_id: Optional[int] = None,
1072        skip_all_rule_checks: bool = False,
1073    ) -> List["GitHubPR"]:
1074        assert self.is_ghstack_pr()
1075        ghstack_prs = get_ghstack_prs(
1076            repo, self, open_only=False
1077        )  # raises error if out of sync
1078        pr_dependencies = []
1079        for pr, rev in ghstack_prs:
1080            if pr.is_closed():
1081                pr_dependencies.append(pr)
1082                continue
1083
1084            commit_msg = pr.gen_commit_message(
1085                filter_ghstack=True, ghstack_deps=pr_dependencies
1086            )
1087            if pr.pr_num != self.pr_num and not skip_all_rule_checks:
1088                # Raises exception if matching rule is not found
1089                find_matching_merge_rule(
1090                    pr,
1091                    repo,
1092                    skip_mandatory_checks=skip_mandatory_checks,
1093                    skip_internal_checks=can_skip_internal_checks(self, comment_id),
1094                )
1095            repo.cherry_pick(rev)
1096            repo.amend_commit_message(commit_msg)
1097            pr_dependencies.append(pr)
1098        return [x for x, _ in ghstack_prs if not x.is_closed()]
1099
1100    def gen_commit_message(
1101        self,
1102        filter_ghstack: bool = False,
1103        ghstack_deps: Optional[List["GitHubPR"]] = None,
1104    ) -> str:
1105        """Fetches the title and body from the PR description,
1106        adds the approved-by and "Pull Request resolved" lines, and optionally
1107        filters out the ghstack info"""
1108        # Adding the URL here makes it clickable within the GitHub UI
1109        approved_by_urls = ", ".join(
1110            prefix_with_github_url(login) for login in self.get_approved_by()
1111        )
1112        # Remove "cc: " line from the message body
1113        msg_body = re.sub(RE_PR_CC_LINE, "", self.get_body())
1114        if filter_ghstack:
1115            msg_body = re.sub(RE_GHSTACK_DESC, "", msg_body)
1116        msg = self.get_title() + f" (#{self.pr_num})\n\n"
1117        msg += msg_body
1118
1119        msg += f"\nPull Request resolved: {self.get_pr_url()}\n"
1120        msg += f"Approved by: {approved_by_urls}\n"
1121        if ghstack_deps:
1122            msg += f"ghstack dependencies: {', '.join([f'#{pr.pr_num}' for pr in ghstack_deps])}\n"
1123
1124        # Mention PR co-authors at the end of the message,
1125        # separated from the body by two newlines
1126        first_coauthor = True
1127        for author_login, author_name in self.get_authors().items():
1128            if author_login != self.get_pr_creator_login():
1129                if first_coauthor:
1130                    msg, first_coauthor = (msg + "\n", False)
1131                msg += f"\nCo-authored-by: {author_name}"
1132
1133        return msg
1134
1135    def add_numbered_label(self, label_base: str, dry_run: bool) -> None:
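        """Adds label_base to the PR; if labels containing label_base already
        exist, a numeric suffix is appended (e.g. "merged", "mergedX1", ...) to
        keep each application of the label distinct."""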
1136        labels = self.get_labels() if self.labels is not None else []
1137        full_label = label_base
1138        count = 0
1139        for label in labels:
1140            if label_base in label:
1141                count += 1
1142                full_label = f"{label_base}X{count}"
1143        gh_add_labels(self.org, self.project, self.pr_num, [full_label], dry_run)
1144
1145    def merge_into(
1146        self,
1147        repo: GitRepo,
1148        *,
1149        skip_mandatory_checks: bool = False,
1150        dry_run: bool = False,
1151        comment_id: Optional[int] = None,
1152        ignore_current_checks: Optional[List[str]] = None,
1153    ) -> None:
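        """Verifies that the PR satisfies a merge rule (raising otherwise),
        merges it into the default branch, applies the merged label(s), and
        records the merge for later upload."""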
1154        # Raises exception if matching rule is not found
1155        (
1156            merge_rule,
1157            pending_checks,
1158            failed_checks,
1159            ignorable_checks,
1160        ) = find_matching_merge_rule(
1161            self,
1162            repo,
1163            skip_mandatory_checks=skip_mandatory_checks,
1164            skip_internal_checks=can_skip_internal_checks(self, comment_id),
1165            ignore_current_checks=ignore_current_checks,
1166        )
1167        additional_merged_prs = self.merge_changes(
1168            repo, skip_mandatory_checks, comment_id
1169        )
1170
1171        repo.push(self.default_branch(), dry_run)
1172        if not dry_run:
1173            self.add_numbered_label(MERGE_COMPLETE_LABEL, dry_run)
1174            for pr in additional_merged_prs:
1175                pr.add_numbered_label(MERGE_COMPLETE_LABEL, dry_run)
1176
1177        if comment_id and self.pr_num:
1178            # When the merge process reaches this part, we can assume that the commit
1179            # has been successfully pushed to trunk
1180            merge_commit_sha = repo.rev_parse(name=REMOTE_MAIN_BRANCH)
1181
1182            # Finally, upload the record to Rockset. The lists of pending and failed
1183            # checks are as of the time of the merge
1184            save_merge_record(
1185                comment_id=comment_id,
1186                pr_num=self.pr_num,
1187                owner=self.org,
1188                project=self.project,
1189                author=self.get_author(),
1190                pending_checks=pending_checks,
1191                failed_checks=failed_checks,
1192                ignore_current_checks=ignorable_checks.get("IGNORE_CURRENT_CHECK", []),
1193                broken_trunk_checks=ignorable_checks.get("BROKEN_TRUNK", []),
1194                flaky_checks=ignorable_checks.get("FLAKY", []),
1195                unstable_checks=ignorable_checks.get("UNSTABLE", []),
1196                last_commit_sha=self.last_commit().get("oid", ""),
1197                merge_base_sha=self.get_merge_base(),
1198                merge_commit_sha=merge_commit_sha,
1199                is_failed=False,
1200                skip_mandatory_checks=skip_mandatory_checks,
1201                ignore_current=bool(ignore_current_checks),
1202            )
1203        else:
1204            print("Missing comment ID or PR number, couldn't upload to Rockset")
1205
1206    def merge_changes(
1207        self,
1208        repo: GitRepo,
1209        skip_mandatory_checks: bool = False,
1210        comment_id: Optional[int] = None,
1211        branch: Optional[str] = None,
1212        skip_all_rule_checks: bool = False,
1213    ) -> List["GitHubPR"]:
1214        """
1215        :param skip_all_rule_checks: If true, skips all rule checks, useful for dry-running merge locally
1216        """
1217        branch_to_merge_into = self.default_branch() if branch is None else branch
1218        if repo.current_branch() != branch_to_merge_into:
1219            repo.checkout(branch_to_merge_into)
1220        if not self.is_ghstack_pr():
1221            msg = self.gen_commit_message()
1222            pr_branch_name = f"__pull-request-{self.pr_num}__init__"
1223            repo.fetch(f"pull/{self.pr_num}/head", pr_branch_name)
1224            repo._run_git("merge", "--squash", pr_branch_name)
1225            repo._run_git("commit", f'--author="{self.get_author()}"', "-m", msg)
1226            return []
1227        else:
1228            return self.merge_ghstack_into(
1229                repo,
1230                skip_mandatory_checks,
1231                comment_id=comment_id,
1232                skip_all_rule_checks=skip_all_rule_checks,
1233            )
1234
1235
1236class MergeRuleFailedError(RuntimeError):
1237    def __init__(self, message: str, rule: Optional["MergeRule"] = None) -> None:
1238        super().__init__(message)
1239        self.rule = rule
1240
1241
1242class MandatoryChecksMissingError(MergeRuleFailedError):
1243    pass
1244
1245
1246class PostCommentError(Exception):
1247    pass
1248
1249
1250@dataclass
1251class MergeRule:
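    """A single rule from merge_rules.yaml: the file patterns it covers, who can
    approve the change, and which checks are mandatory for it."""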
1252    name: str
1253    patterns: List[str]
1254    approved_by: List[str]
1255    mandatory_checks_name: Optional[List[str]]
1256    ignore_flaky_failures: bool = True
1257
1258
1259def gen_new_issue_link(
1260    org: str, project: str, labels: List[str], template: str = "bug-report.yml"
1261) -> str:
1262    labels_str = ",".join(labels)
1263    return (
1264        f"https://github.com/{org}/{project}/issues/new?"
1265        f"labels={urllib.parse.quote(labels_str)}&"
1266        f"template={urllib.parse.quote(template)}"
1267    )
1268
1269
1270def read_merge_rules(
1271    repo: Optional[GitRepo], org: str, project: str
1272) -> List[MergeRule]:
1273    """Returns the list of all merge rules for the repo or project.
1274
1275    NB: this function is used in Meta-internal workflows, see the comment
1276    at the top of this file for details.
1277    """
1278    repo_relative_rules_path = MERGE_RULE_PATH
1279    if repo is None:
1280        json_data = gh_fetch_url(
1281            f"https://api.github.com/repos/{org}/{project}/contents/{repo_relative_rules_path}",
1282            headers={"Accept": "application/vnd.github.v3+json"},
1283            reader=json.load,
1284        )
1285        content = base64.b64decode(json_data["content"])
1286        return [MergeRule(**x) for x in yaml.safe_load(content)]
1287    else:
1288        rules_path = Path(repo.repo_dir) / repo_relative_rules_path
1289        if not rules_path.exists():
1290            print(f"{rules_path} does not exist, returning empty rules")
1291            return []
1292        with open(rules_path) as fp:
1293            rc = yaml.safe_load(fp)
1294        return [MergeRule(**x) for x in rc]
1295
1296
1297def find_matching_merge_rule(
1298    pr: GitHubPR,
1299    repo: Optional[GitRepo] = None,
1300    skip_mandatory_checks: bool = False,
1301    skip_internal_checks: bool = False,
1302    ignore_current_checks: Optional[List[str]] = None,
1303) -> Tuple[
1304    MergeRule,
1305    List[Tuple[str, Optional[str], Optional[int]]],
1306    List[Tuple[str, Optional[str], Optional[int]]],
1307    Dict[str, List[Any]],
1308]:
1309    """
1310    Returns merge rule matching to this pr together with the list of associated pending
1311    and failing jobs OR raises an exception.
1312
1313    NB: this function is used in Meta-internal workflows, see the comment at the top of
1314    this file for details.
1315    """
1316    changed_files = pr.get_changed_files()
1317    approved_by = set(pr.get_approved_by())
1318
1319    issue_link = gen_new_issue_link(
1320        org=pr.org,
1321        project=pr.project,
1322        labels=["module: ci"],
1323    )
1324    reject_reason = f"No rule found to match PR. Please [report]({issue_link}) this issue to the DevX team."
1325
1326    rules = read_merge_rules(repo, pr.org, pr.project)
1327    if not rules:
1328        reject_reason = f"Rejecting the merge as no rules are defined for the repository in {MERGE_RULE_PATH}"
1329        raise RuntimeError(reject_reason)
1330
1331    checks = pr.get_checkrun_conclusions()
1332    checks = get_classifications(
1333        pr.pr_num,
1334        pr.project,
1335        checks,
1336        ignore_current_checks=ignore_current_checks,
1337    )
1338
1339    # This keeps the list of all approvers that could stamp the change
1340    all_rule_approvers = {}
1341
1342    # A PR can fail multiple merge rules, but it only needs to pass one rule to be approved.
1343    # If it fails all rules, we need to find the rule that it came closest to passing and report
1344    # that to the dev.
1345    #
1346    # reject_reason_score ranks rules by relevancy. The higher the score, the more relevant the
1347    # rule & rejection reason, and we only care about the most relevant rule/reason.
1348    #
1349    # reject_reason_score interpretation:
1350    # Score 0 to 10K - how many files the rule matched
1351    # Score 10K - matched all files, but no overlapping approvers
1352    # Score 20K - matched all files and approvers, but mandatory checks are pending
1353    # Score 30K - matched all files and approvers, but mandatory checks failed
1354    reject_reason_score = 0
1355    for rule in rules:
1356        rule_name = rule.name
1357        patterns_re = patterns_to_regex(rule.patterns)
1358        non_matching_files = []
1359
1360        # Does this rule apply to all the files?
1361        for fname in changed_files:
1362            if not patterns_re.match(fname):
1363                non_matching_files.append(fname)
1364        if len(non_matching_files) > 0:
1365            num_matching_files = len(changed_files) - len(non_matching_files)
1366            if num_matching_files > reject_reason_score:
1367                reject_reason_score = num_matching_files
1368                reject_reason = "\n".join(
1369                    (
1370                        f"Not all files match rule `{rule_name}`.",
1371                        f"{num_matching_files} files matched, but there are still non-matching files:",
1372                        f"{','.join(non_matching_files[:5])}{', ...' if len(non_matching_files) > 5 else ''}",
1373                    )
1374                )
1375            continue
1376
1377        # If rule needs approvers but PR has not been reviewed, skip it
1378        if len(rule.approved_by) > 0 and len(approved_by) == 0:
1379            if reject_reason_score < 10000:
1380                reject_reason_score = 10000
1381                reject_reason = f"PR #{pr.pr_num} has not been reviewed yet"
1382            continue
1383
1384        # Does the PR have the required approvals for this rule?
1385        rule_approvers = set()
1386        for approver in rule.approved_by:
1387            if "/" in approver:
1388                org, name = approver.split("/")
1389                rule_approvers.update(gh_get_team_members(org, name))
1390            else:
1391                rule_approvers.add(approver)
1392        approvers_intersection = approved_by.intersection(rule_approvers)
1393        # If rule requires approvers but they aren't the ones that reviewed PR
1394        if len(approvers_intersection) == 0 and len(rule_approvers) > 0:
1395            # Less than or equal is intentionally used here to gather all potential
1396            # approvers
1397            if reject_reason_score <= 10000:
1398                reject_reason_score = 10000
1399
1400                all_rule_approvers[rule.name] = rule.approved_by
1401                # Prepare the reject reason
1402                all_rule_approvers_msg = [
1403                    f"- {name} ({', '.join(approved_by[:5])}{', ...' if len(approved_by) > 5 else ''})"
1404                    for name, approved_by in all_rule_approvers.items()
1405                ]
1406
1407                reject_reason = "Approvers from one of the following sets are needed:\n"
1408                reject_reason += "\n".join(all_rule_approvers_msg)
1409
1410            continue
1411
1412        # Does the PR pass the checks required by this rule?
1413        mandatory_checks = (
1414            rule.mandatory_checks_name if rule.mandatory_checks_name is not None else []
1415        )
1416        required_checks = list(
1417            filter(
1418                lambda x: ("EasyCLA" in x)
1419                or ("Facebook CLA Check" in x)
1420                or not skip_mandatory_checks,
1421                mandatory_checks,
1422            )
1423        )
1424        pending_checks, failed_checks, _ = categorize_checks(
1425            checks,
1426            required_checks,
1427            ok_failed_checks_threshold=IGNORABLE_FAILED_CHECKS_THESHOLD
1428            if rule.ignore_flaky_failures
1429            else 0,
1430        )
1431
1432        # categorize_checks assumes all tests are required if required_checks is empty.
1433        # This is a workaround, as we want to keep that behavior for categorize_checks
1434        # generally.
1435        if not required_checks:
1436            pending_checks = []
1437            failed_checks = []
1438
1439        hud_link = f"https://hud.pytorch.org/{pr.org}/{pr.project}/commit/{pr.last_commit()['oid']}"
1440        if len(failed_checks) > 0:
1441            if reject_reason_score < 30000:
1442                reject_reason_score = 30000
1443                reject_reason = "\n".join(
1444                    (
1445                        f"{len(failed_checks)} mandatory check(s) failed.  The first few are:",
1446                        *checks_to_markdown_bullets(failed_checks),
1447                        "",
1448                        f"Dig deeper by [viewing the failures on hud]({hud_link})",
1449                    )
1450                )
1451            continue
1452        elif len(pending_checks) > 0:
1453            if reject_reason_score < 20000:
1454                reject_reason_score = 20000
1455                reject_reason = "\n".join(
1456                    (
1457                        f"{len(pending_checks)} mandatory check(s) are pending/not yet run.  The first few are:",
1458                        *checks_to_markdown_bullets(pending_checks),
1459                        "",
1460                        f"Dig deeper by [viewing the pending checks on hud]({hud_link})",
1461                    )
1462                )
1463            continue
1464
1465        if not skip_internal_checks and pr.has_internal_changes():
1466            raise RuntimeError(
1467                "This PR has internal changes and must be landed via Phabricator! Please try re-importing/re-exporting the PR!"
1468            )
1469
1470        # Categorize all checks when skip_mandatory_checks (force merge) is set. Do it here
1471        # where the list of checks is readily available. These records will be saved into
1472        # Rockset merge records
1473        (
1474            pending_mandatory_checks,
1475            failed_mandatory_checks,
1476            ignorable_checks,
1477        ) = categorize_checks(
1478            checks,
1479            [],
1480            ok_failed_checks_threshold=IGNORABLE_FAILED_CHECKS_THESHOLD,
1481        )
1482        return (
1483            rule,
1484            pending_mandatory_checks,
1485            failed_mandatory_checks,
1486            ignorable_checks,
1487        )
1488
1489    if reject_reason_score == 20000:
1490        raise MandatoryChecksMissingError(reject_reason, rule)
1491    raise MergeRuleFailedError(reject_reason, rule)
1492
1493
1494def checks_to_str(checks: List[Tuple[str, Optional[str]]]) -> str:
1495    return ", ".join(f"[{c[0]}]({c[1]})" if c[1] is not None else c[0] for c in checks)
1496
1497
1498def checks_to_markdown_bullets(
1499    checks: List[Tuple[str, Optional[str], Optional[int]]]
1500) -> List[str]:
1501    return [
1502        f"- [{c[0]}]({c[1]})" if c[1] is not None else f"- {c[0]}" for c in checks[:5]
1503    ]
1504
1505
1506@retries_decorator()
1507def save_merge_record(
1508    comment_id: int,
1509    pr_num: int,
1510    owner: str,
1511    project: str,
1512    author: str,
1513    pending_checks: List[Tuple[str, Optional[str], Optional[int]]],
1514    failed_checks: List[Tuple[str, Optional[str], Optional[int]]],
1515    ignore_current_checks: List[Tuple[str, Optional[str], Optional[int]]],
1516    broken_trunk_checks: List[Tuple[str, Optional[str], Optional[int]]],
1517    flaky_checks: List[Tuple[str, Optional[str], Optional[int]]],
1518    unstable_checks: List[Tuple[str, Optional[str], Optional[int]]],
1519    last_commit_sha: str,
1520    merge_base_sha: str,
1521    merge_commit_sha: str = "",
1522    is_failed: bool = False,
1523    skip_mandatory_checks: bool = False,
1524    ignore_current: bool = False,
1525    error: str = "",
1526) -> None:
1527    """
1528    This saves the merge record as a JSON file, which can later be uploaded to S3
1529    """
1530
1531    # Prepare the record to be written into Rockset
1532    data = [
1533        {
1534            "comment_id": comment_id,
1535            "pr_num": pr_num,
1536            "owner": owner,
1537            "project": project,
1538            "author": author,
1539            "pending_checks": pending_checks,
1540            "failed_checks": failed_checks,
1541            "ignore_current_checks": ignore_current_checks,
1542            "broken_trunk_checks": broken_trunk_checks,
1543            "flaky_checks": flaky_checks,
1544            "unstable_checks": unstable_checks,
1545            "last_commit_sha": last_commit_sha,
1546            "merge_base_sha": merge_base_sha,
1547            "merge_commit_sha": merge_commit_sha,
1548            "is_failed": is_failed,
1549            "skip_mandatory_checks": skip_mandatory_checks,
1550            "ignore_current": ignore_current,
1551            "error": error,
1552            # This is a unique identifier for the record for deduping purposes
1553            # in rockset.  Any unique string would work
1554            "_id": f"{project}-{pr_num}-{comment_id}-{os.environ.get('GITHUB_RUN_ID')}",
1555        }
1556    ]
1557    repo_root = Path(__file__).resolve().parent.parent.parent
1558
1559    with open(repo_root / "merge_record.json", "w") as f:
1560        json.dump(data, f)
1561
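# NB (editor sketch, hypothetical values): a minimal call showing the expected argument
# shapes for save_merge_record; the check tuples are (name, url, job_id). Not invoked
# anywhere, kept only as documentation.
def _example_save_merge_record() -> None:
    save_merge_record(
        comment_id=123456789,
        pr_num=100000,
        owner="pytorch",
        project="pytorch",
        author="octocat",
        pending_checks=[],
        failed_checks=[("pull / linux-build", "https://example.com/job/1", 1)],
        ignore_current_checks=[],
        broken_trunk_checks=[],
        flaky_checks=[],
        unstable_checks=[],
        last_commit_sha="deadbeef",
        merge_base_sha="cafef00d",
        is_failed=True,
        error="example failure",
    )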
1562
1563@retries_decorator(rc=[])
1564def get_rockset_results(head_sha: str, merge_base: str) -> List[Dict[str, Any]]:
1565    query = f"""
1566SELECT
1567    w.name as workflow_name,
1568    j.id,
1569    j.name,
1570    j.conclusion,
1571    j.completed_at,
1572    j.html_url,
1573    j.head_sha,
1574    j.torchci_classification.captures as failure_captures,
1575    LENGTH(j.steps) as steps,
1576FROM
1577    commons.workflow_job j join commons.workflow_run w on w.id = j.run_id
1578where
1579    j.head_sha in ('{head_sha}','{merge_base}')
1580"""
1581    try:
1582        import rockset  # type: ignore[import]
1583
1584        res = rockset.RocksetClient(
1585            host="api.usw2a1.rockset.com", api_key=os.environ["ROCKSET_API_KEY"]
1586        ).sql(query)
1587        return cast(List[Dict[str, Any]], res.results)
1588    except ModuleNotFoundError:
1589        print("Could not use Rockset as the rockset dependency is missing")
1590        return []
1591
1592
1593@retries_decorator()
1594def get_drci_classifications(pr_num: int, project: str = "pytorch") -> Any:
1595    """
1596    Query HUD API to find similar failures to decide if they are flaky
1597    """
1598    # NB: This doesn't work internally atm because this requires making an
1599    # external API call to HUD
1600    failures = gh_fetch_url(
1601        f"https://hud.pytorch.org/api/drci/drci?prNumber={pr_num}",
1602        data=f"repo={project}",
1603        headers={
1604            "Authorization": os.getenv("DRCI_BOT_KEY", ""),
1605            "Accept": "application/vnd.github.v3+json",
1606        },
1607        method="POST",
1608        reader=json.load,
1609    )
1610
1611    return failures.get(str(pr_num), {}) if failures else {}
1612
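# NB (editor note, inferred from the consumers below): the Dr.CI payload for a PR is a
# dict keyed by classification ("FLAKY", "BROKEN_TRUNK", "UNSTABLE", "FAILED"), where each
# value is a list of job dicts that contain at least "name" and "id".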
1613
1614REMOVE_JOB_NAME_SUFFIX_REGEX = re.compile(r", [0-9]+, [0-9]+, .+\)$")
1615
1616
1617def remove_job_name_suffix(name: str, replacement: str = ")") -> str:
1618    return re.sub(REMOVE_JOB_NAME_SUFFIX_REGEX, replacement, name)
1619
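# NB (editor sketch, hypothetical job name): the regex above strips a trailing
# ", <shard>, <num_shards>, <runner>)" suffix from a matrix job name.
def _example_remove_job_name_suffix() -> None:
    name = "linux-focal-py3.8 / test (default, 1, 3, linux.2xlarge)"
    assert remove_job_name_suffix(name) == "linux-focal-py3.8 / test (default)"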
1620
1621def is_broken_trunk(
1622    check: JobCheckState,
1623    drci_classifications: Any,
1624) -> bool:
1625    if not check or not drci_classifications:
1626        return False
1627
1628    name = check.name
1629    job_id = check.job_id
1630
1631    # Consult the list of broken trunk failures from Dr.CI
1632    return any(
1633        (name == broken_trunk["name"]) or (job_id and job_id == broken_trunk["id"])
1634        for broken_trunk in drci_classifications.get("BROKEN_TRUNK", [])
1635    )
1636
1637
1638def is_unstable(
1639    check: JobCheckState,
1640    drci_classifications: Any,
1641) -> bool:
1642    if not check or not drci_classifications:
1643        return False
1644
1645    name = check.name
1646    job_id = check.job_id
1647
1648    # The job name has the unstable keyword. This is the original way to mark a job
1649    # as unstable on HUD, Dr.CI, and trymerge
1650    if "unstable" in name:
1651        return True
1652
1653    # Consult the list of unstable failures from Dr.CI
1654    return any(
1655        (name == unstable["name"] or (job_id and job_id == unstable["id"]))
1656        for unstable in drci_classifications.get("UNSTABLE", [])
1657    )
1658
1659
1660def is_flaky(
1661    check: JobCheckState,
1662    drci_classifications: Any,
1663) -> bool:
1664    if not check or not drci_classifications:
1665        return False
1666
1667    name = check.name
1668    job_id = check.job_id
1669
1670    # Consult the list of flaky failures from Dr.CI
1671    return any(
1672        (name == flaky["name"] or (job_id and job_id == flaky["id"]))
1673        for flaky in drci_classifications.get("FLAKY", [])
1674    )
1675
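# NB (editor sketch, hypothetical payload): is_broken_trunk, is_unstable and is_flaky all
# match a check against the corresponding Dr.CI list either by job name or by job id.
def _example_is_flaky() -> None:
    check = JobCheckState(
        "pull / test", "https://example.com", "FAILURE", None, 42, None, None
    )
    drci_classifications = {"FLAKY": [{"name": "another job", "id": 42}]}
    # Matched by job id even though the job name differs
    assert is_flaky(check, drci_classifications)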
1676
1677def is_invalid_cancel(
1678    name: str,
1679    conclusion: Optional[str],
1680    drci_classifications: Any,
1681) -> bool:
1682    """
1683    After https://github.com/pytorch/test-infra/pull/4579, invalid cancelled
1684    signals have been removed from HUD and Dr.CI. The same needs to be done
1685    here for consistency
1686    """
1687    if (
1688        not name
1689        or not drci_classifications
1690        or not conclusion
1691        or conclusion.upper() != "CANCELLED"
1692    ):
1693        return False
1694
1695    # If a job is cancelled and not listed as a failure by Dr.CI, it's an
1696    # invalid signal and can be ignored
1697    return all(
1698        name != failure["name"] for failure in drci_classifications.get("FAILED", [])
1699    )
1700
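# NB (editor sketch, hypothetical payload): a cancelled job that Dr.CI does not list under
# FAILED is treated as an invalid signal and later skipped by categorize_checks.
def _example_is_invalid_cancel() -> None:
    drci_classifications = {"FAILED": [{"name": "pull / real-failure"}]}
    assert is_invalid_cancel("pull / cancelled-job", "CANCELLED", drci_classifications)
    assert not is_invalid_cancel("pull / real-failure", "CANCELLED", drci_classifications)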
1701
1702def get_classifications(
1703    pr_num: int,
1704    project: str,
1705    checks: Dict[str, JobCheckState],
1706    ignore_current_checks: Optional[List[str]],
1707) -> Dict[str, JobCheckState]:
1708    # Get the failure classifications from Dr.CI, which is the source of truth
1709    # going forward. It's preferable to call the Dr.CI API directly first to get
1710    # the latest results and to update the Dr.CI PR comment
1711    drci_classifications = get_drci_classifications(pr_num=pr_num, project=project)
1712
1713    def get_readable_drci_results(drci_classifications: Any) -> str:
1714        try:
1715            s = f"From Dr.CI API ({pr_num}):\n"
1716            for classification, jobs in drci_classifications.items():
1717                s += f"  {classification}: \n"
1718                for job in jobs:
1719                    s += f"    {job['id']} {job['name']}\n"
1720            return s
1721        except Exception:
1722            return f"From Dr.CI API: {json.dumps(drci_classifications)}"
1723
1724    print(get_readable_drci_results(drci_classifications))
1725
1726    # NB: if the latest results from Dr.CI are not available, i.e. when calling from
1727    # SandCastle, we fall back to whatever results we can find in the Dr.CI check run summary
1728    if (
1729        not drci_classifications
1730        and DRCI_CHECKRUN_NAME in checks
1731        and checks[DRCI_CHECKRUN_NAME]
1732        and checks[DRCI_CHECKRUN_NAME].summary
1733    ):
1734        drci_summary = checks[DRCI_CHECKRUN_NAME].summary
1735        try:
1736            print(f"From Dr.CI checkrun summary: {drci_summary}")
1737            drci_classifications = json.loads(str(drci_summary))
1738        except json.JSONDecodeError as error:
1739            warn(f"Invalid Dr.CI checkrun summary: {error}")
1740            drci_classifications = {}
1741
1742    checks_with_classifications = checks.copy()
1743    for name, check in checks.items():
1744        if check.status == "SUCCESS" or check.status == "NEUTRAL":
1745            continue
1746
1747        if is_unstable(check, drci_classifications):
1748            checks_with_classifications[name] = JobCheckState(
1749                check.name,
1750                check.url,
1751                check.status,
1752                "UNSTABLE",
1753                check.job_id,
1754                check.title,
1755                check.summary,
1756            )
1757            continue
1758
1759        # NB: It's important to note that when it comes to ghstack and broken trunk classification,
1760        # Dr.CI uses the base of the whole stack
1761        if is_broken_trunk(check, drci_classifications):
1762            checks_with_classifications[name] = JobCheckState(
1763                check.name,
1764                check.url,
1765                check.status,
1766                "BROKEN_TRUNK",
1767                check.job_id,
1768                check.title,
1769                check.summary,
1770            )
1771            continue
1772
1773        elif is_flaky(check, drci_classifications):
1774            checks_with_classifications[name] = JobCheckState(
1775                check.name,
1776                check.url,
1777                check.status,
1778                "FLAKY",
1779                check.job_id,
1780                check.title,
1781                check.summary,
1782            )
1783            continue
1784
1785        elif is_invalid_cancel(name, check.status, drci_classifications):
1786            # NB: Create a new category here for invalid cancelled signals because
1787            # there are usually many of them when they happen. So, they shouldn't
1788            # be counted toward ignorable failures threshold
1789            checks_with_classifications[name] = JobCheckState(
1790                check.name,
1791                check.url,
1792                check.status,
1793                "INVALID_CANCEL",
1794                check.job_id,
1795                check.title,
1796                check.summary,
1797            )
1798            continue
1799
1800        if ignore_current_checks is not None and name in ignore_current_checks:
1801            checks_with_classifications[name] = JobCheckState(
1802                check.name,
1803                check.url,
1804                check.status,
1805                "IGNORE_CURRENT_CHECK",
1806                check.job_id,
1807                check.title,
1808                check.summary,
1809            )
1810
1811    return checks_with_classifications
1812
1813
1814def filter_checks_with_lambda(
1815    checks: JobNameToStateDict, status_filter: Callable[[Optional[str]], bool]
1816) -> List[JobCheckState]:
1817    return [check for check in checks.values() if status_filter(check.status)]
1818
1819
1820def get_pr_commit_sha(repo: GitRepo, pr: GitHubPR) -> str:
1821    commit_sha = pr.get_merge_commit()
1822    if commit_sha is not None:
1823        return commit_sha
1824    commits = repo.commits_resolving_gh_pr(pr.pr_num)
1825    if len(commits) == 0:
1826        raise PostCommentError("Can't find any commits resolving PR")
1827    return commits[0]
1828
1829
1830def validate_revert(
1831    repo: GitRepo, pr: GitHubPR, *, comment_id: Optional[int] = None
1832) -> Tuple[str, str]:
1833    comment = (
1834        pr.get_last_comment()
1835        if comment_id is None
1836        else pr.get_comment_by_id(comment_id)
1837    )
1838    if comment.editor_login is not None:
1839        raise PostCommentError("Don't want to revert based on edited command")
1840    author_association = comment.author_association
1841    author_login = comment.author_login
1842    allowed_reverters = ["COLLABORATOR", "MEMBER", "OWNER"]
1843    # For some reason, one cannot be a member of a private repo, only a CONTRIBUTOR
1844    if pr.is_base_repo_private():
1845        allowed_reverters.append("CONTRIBUTOR")
1846    if author_association not in allowed_reverters:
1847        raise PostCommentError(
1848            f"Will not revert as @{author_login} is not one of "
1849            f"[{', '.join(allowed_reverters)}], but instead is {author_association}."
1850        )
1851
1852    # Raises exception if matching rule is not found, but ignores all status checks
1853    find_matching_merge_rule(
1854        pr, repo, skip_mandatory_checks=True, skip_internal_checks=True
1855    )
1856    commit_sha = get_pr_commit_sha(repo, pr)
1857    return (author_login, commit_sha)
1858
1859
1860def get_ghstack_dependent_prs(
1861    repo: GitRepo, pr: GitHubPR, only_closed: bool = True
1862) -> List[Tuple[str, GitHubPR]]:
1863    """
1864    Get the PRs in the stack that are above this PR (inclusive).
1865    Throws an error if the stack has branched or the original branches are gone
1866    """
1867    assert pr.is_ghstack_pr()
1868    orig_ref = f"{repo.remote}/{pr.get_ghstack_orig_ref()}"
1869    rev_list = repo.revlist(f"{pr.default_branch()}..{orig_ref}")
1870    if len(rev_list) == 0:
1871        raise RuntimeError(
1872            f"PR {pr.pr_num} does not have any revisions associated with it"
1873        )
1874    skip_len = len(rev_list) - 1
1875    for branch in repo.branches_containing_ref(orig_ref):
1876        candidate = repo.revlist(f"{pr.default_branch()}..{branch}")
1877        # Pick longest candidate
1878        if len(candidate) > len(rev_list):
1879            candidate, rev_list = rev_list, candidate
1880        # Validate that candidate is always a suffix of rev_list
1881        if rev_list[-len(candidate) :] != candidate:
1882            raise RuntimeError(
1883                f"Branch {branch} revlist {', '.join(candidate)} is not a subset of {', '.join(rev_list)}"
1884            )
1885    # Remove the commits that the original PR depends on
1886    if skip_len > 0:
1887        rev_list = rev_list[:-skip_len]
1888    rc: List[Tuple[str, GitHubPR]] = []
1889    for pr_, sha in _revlist_to_prs(repo, pr, rev_list):
1890        if not pr_.is_closed():
1891            if not only_closed:
1892                rc.append(("", pr_))
1893            continue
1894        commit_sha = get_pr_commit_sha(repo, pr_)
1895        rc.append((commit_sha, pr_))
1896    return rc
1897
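# NB (editor sketch, hypothetical stack): for a ghstack stack A <- B <- C (C on top) where
# this PR is B, revlist(default_branch..B_orig) is [B, A], so skip_len == 1. The longest
# branch containing B_orig yields [C, B, A]; dropping the trailing skip_len commits leaves
# [C, B], i.e. this PR and everything stacked on top of it.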
1898
1899def do_revert_prs(
1900    repo: GitRepo,
1901    shas_and_prs: List[Tuple[str, GitHubPR]],
1902    *,
1903    author_login: str,
1904    extra_msg: str = "",
1905    skip_internal_checks: bool = False,
1906    dry_run: bool = False,
1907) -> None:
1908    # Prepare and push revert commits
1909    commit_shas: List[str] = []
1910    for commit_sha, pr in shas_and_prs:
1911        revert_msg = f"\nReverted {pr.get_pr_url()} on behalf of {prefix_with_github_url(author_login)}"
1912        revert_msg += extra_msg
1913        repo.checkout(pr.default_branch())
1914        repo.revert(commit_sha)
1915        msg = repo.commit_message("HEAD")
1916        msg = re.sub(RE_PULL_REQUEST_RESOLVED, "", msg)
1917        msg += revert_msg
1918        repo.amend_commit_message(msg)
1919    repo.push(shas_and_prs[0][1].default_branch(), dry_run)
1920
1921    # Comment/reopen PRs
1922    for commit_sha, pr in shas_and_prs:
1923        revert_message = (
1924            f"@{pr.get_pr_creator_login()} your PR has been successfully reverted."
1925        )
1926        if (
1927            pr.has_internal_changes()
1928            and not pr.has_no_connected_diff()
1929            and not skip_internal_checks
1930        ):
1931            revert_message += "\n:warning: This PR might contain internal changes"
1932            revert_message += "\ncc: @pytorch/pytorch-dev-infra"
1933        gh_post_pr_comment(
1934            pr.org, pr.project, pr.pr_num, revert_message, dry_run=dry_run
1935        )
1936
1937        pr.add_numbered_label("reverted", dry_run)
1938        if not dry_run:
1939            gh_post_commit_comment(pr.org, pr.project, commit_sha, revert_msg)
1940            gh_update_pr_state(pr.org, pr.project, pr.pr_num)
1941
1942
1943def try_revert(
1944    repo: GitRepo,
1945    pr: GitHubPR,
1946    *,
1947    dry_run: bool = False,
1948    comment_id: Optional[int] = None,
1949    reason: Optional[str] = None,
1950) -> None:
1951    try:
1952        author_login, commit_sha = validate_revert(repo, pr, comment_id=comment_id)
1953    except PostCommentError as e:
1954        gh_post_pr_comment(pr.org, pr.project, pr.pr_num, str(e), dry_run=dry_run)
1955        return
1956
1957    extra_msg = f" due to {reason}" if reason is not None else ""
1958    extra_msg += (
1959        f" ([comment]({pr.get_comment_by_id(comment_id).url}))\n"
1960        if comment_id is not None
1961        else "\n"
1962    )
1963    shas_and_prs = [(commit_sha, pr)]
1964    if pr.is_ghstack_pr():
1965        try:
1966            shas_and_prs = get_ghstack_dependent_prs(repo, pr)
1967            prs_to_revert = " ".join([t[1].get_pr_url() for t in shas_and_prs])
1968            print(f"About to revert the following stack of PRs: {prs_to_revert}")
1969        except Exception as e:
1970            print(
1971                f"Failed to fetch dependent PRs: {str(e)}, falling back to a single revert"
1972            )
1973
1974    do_revert_prs(
1975        repo,
1976        shas_and_prs,
1977        author_login=author_login,
1978        extra_msg=extra_msg,
1979        dry_run=dry_run,
1980        skip_internal_checks=can_skip_internal_checks(pr, comment_id),
1981    )
1982
1983
1984def prefix_with_github_url(suffix_str: str) -> str:
1985    return f"https://github.com/{suffix_str}"
1986
1987
1988def check_for_sev(org: str, project: str, skip_mandatory_checks: bool) -> None:
1989    if skip_mandatory_checks:
1990        return
1991    response = cast(
1992        Dict[str, Any],
1993        gh_fetch_json_list(
1994            "https://api.github.com/search/issues",
1995            params={"q": f'repo:{org}/{project} is:open is:issue label:"ci: sev"'},
1996        ),
1997    )
1998    if response["total_count"] != 0:
1999        for item in response["items"]:
2000            if "MERGE BLOCKING" in item["body"]:
2001                raise RuntimeError(
2002                    "Not merging any PRs at the moment because there is a "
2003                    + "merge blocking https://github.com/pytorch/pytorch/labels/ci:%20sev issue open at: \n"
2004                    + f"{item['html_url']}"
2005                )
2006    return
2007
2008
2009def has_label(labels: List[str], pattern: Pattern[str] = CIFLOW_LABEL) -> bool:
2010    return len(list(filter(pattern.match, labels))) > 0
2011
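# NB (editor sketch, hypothetical labels/pattern): has_label matches each label against the
# given regex with re.Pattern.match, i.e. anchored at the start of the label.
def _example_has_label() -> None:
    ciflow = re.compile(r"^ciflow/")
    assert has_label(["ciflow/trunk", "release notes: nn"], ciflow)
    assert not has_label(["release notes: nn"], ciflow)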
2012
2013def categorize_checks(
2014    check_runs: JobNameToStateDict,
2015    required_checks: List[str],
2016    ok_failed_checks_threshold: Optional[int] = None,
2017) -> Tuple[
2018    List[Tuple[str, Optional[str], Optional[int]]],
2019    List[Tuple[str, Optional[str], Optional[int]]],
2020    Dict[str, List[Any]],
2021]:
2022    """
2023    Categorizes all jobs into lists of pending and failing jobs. All known flaky
2024    failures and broken trunk failures are ignored by default when ok_failed_checks_threshold
2025    is not set (unlimited)
2026    """
2027    pending_checks: List[Tuple[str, Optional[str], Optional[int]]] = []
2028    failed_checks: List[Tuple[str, Optional[str], Optional[int]]] = []
2029
2030    # failed_checks_categorization is used to keep track of all ignorable failures when saving the merge record on Rockset
2031    failed_checks_categorization: Dict[str, List[Any]] = defaultdict(list)
2032
2033    # If required_checks is not set or empty, consider all check names relevant
2034    relevant_checknames = [
2035        name
2036        for name in check_runs.keys()
2037        if not required_checks or any(x in name for x in required_checks)
2038    ]
2039
2040    for checkname in required_checks:
2041        if all(checkname not in x for x in check_runs.keys()):
2042            pending_checks.append((checkname, None, None))
2043
2044    for checkname in relevant_checknames:
2045        status = check_runs[checkname].status
2046        url = check_runs[checkname].url
2047        classification = check_runs[checkname].classification
2048        job_id = check_runs[checkname].job_id
2049
2050        if status is None and classification != "UNSTABLE":
2051            # NB: No need to wait if the job classification is unstable as it would be
2052            # ignored anyway. This avoids having to wait for scarce resources
2053            # like ROCm, which are also frequently in unstable mode
2054            pending_checks.append((checkname, url, job_id))
2055        elif classification == "INVALID_CANCEL":
2056            continue
2057        elif not is_passing_status(check_runs[checkname].status):
2058            target = (
2059                failed_checks_categorization[classification]
2060                if classification
2061                in ("IGNORE_CURRENT_CHECK", "BROKEN_TRUNK", "FLAKY", "UNSTABLE")
2062                else failed_checks
2063            )
2064            target.append((checkname, url, job_id))
2065
2066    flaky_or_broken_trunk = (
2067        failed_checks_categorization["BROKEN_TRUNK"]
2068        + failed_checks_categorization["FLAKY"]
2069    )
2070
2071    if flaky_or_broken_trunk:
2072        warn(
2073            f"The following {len(flaky_or_broken_trunk)} checks failed but were likely due to flakiness or broken trunk: "
2074            + ", ".join([x[0] for x in flaky_or_broken_trunk])
2075            + (
2076                f" but this is greater than the threshold of {ok_failed_checks_threshold} so merge will fail"
2077                if ok_failed_checks_threshold is not None
2078                and len(flaky_or_broken_trunk) > ok_failed_checks_threshold
2079                else ""
2080            )
2081        )
2082
2083    if (
2084        ok_failed_checks_threshold is not None
2085        and len(flaky_or_broken_trunk) > ok_failed_checks_threshold
2086    ):
2087        failed_checks = failed_checks + flaky_or_broken_trunk
2088
2089    # The list of failed_checks_categorization is returned so that it can be saved into the Rockset merge record
2090    return (pending_checks, failed_checks, failed_checks_categorization)
2091
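# NB (editor sketch, hypothetical checks): how categorize_checks partitions results. A check
# with no status yet is pending, a plain failure is failing, and a failure classified as
# FLAKY/BROKEN_TRUNK/UNSTABLE/IGNORE_CURRENT_CHECK lands in the ignorable bucket (unless the
# flaky/broken-trunk count exceeds ok_failed_checks_threshold, in which case those count as
# failed too).
def _example_categorize_checks() -> None:
    checks: JobNameToStateDict = {
        "pull / build": JobCheckState("pull / build", "u1", None, None, 1, None, None),
        "pull / test": JobCheckState("pull / test", "u2", "FAILURE", "FLAKY", 2, None, None),
        "pull / lint": JobCheckState("pull / lint", "u3", "FAILURE", None, 3, None, None),
    }
    pending, failed, ignorable = categorize_checks(checks, [], ok_failed_checks_threshold=2)
    assert [name for name, _, _ in pending] == ["pull / build"]
    assert [name for name, _, _ in failed] == ["pull / lint"]
    assert [name for name, _, _ in ignorable["FLAKY"]] == ["pull / test"]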
2092
2093def merge(
2094    pr: GitHubPR,
2095    repo: GitRepo,
2096    dry_run: bool = False,
2097    skip_mandatory_checks: bool = False,
2098    comment_id: Optional[int] = None,
2099    timeout_minutes: int = 400,
2100    stale_pr_days: int = 3,
2101    ignore_current: bool = False,
2102) -> None:
2103    initial_commit_sha = pr.last_commit()["oid"]
2104    pr_link = f"https://github.com/{pr.org}/{pr.project}/pull/{pr.pr_num}"
2105    print(f"Attempting merge of {initial_commit_sha} ({pr_link})")
2106
2107    if MERGE_IN_PROGRESS_LABEL not in pr.get_labels():
2108        gh_add_labels(pr.org, pr.project, pr.pr_num, [MERGE_IN_PROGRESS_LABEL], dry_run)
2109
2110    explainer = TryMergeExplainer(
2111        skip_mandatory_checks,
2112        pr.get_labels(),
2113        pr.pr_num,
2114        pr.org,
2115        pr.project,
2116        ignore_current,
2117    )
2118
2119    # probably a bad name, but this is a list of current checks that should be
2120    # ignored and is toggled by the --ignore-current flag
2121    ignore_current_checks_info = []
2122
2123    if pr.is_ghstack_pr():
2124        get_ghstack_prs(repo, pr)  # raises error if out of sync
2125
2126    check_for_sev(pr.org, pr.project, skip_mandatory_checks)
2127
2128    if skip_mandatory_checks:
2129        gh_post_pr_comment(
2130            pr.org,
2131            pr.project,
2132            pr.pr_num,
2133            explainer.get_merge_message(),
2134            dry_run=dry_run,
2135        )
2136        return pr.merge_into(
2137            repo,
2138            dry_run=dry_run,
2139            skip_mandatory_checks=skip_mandatory_checks,
2140            comment_id=comment_id,
2141        )
2142
2143    # Check for approvals
2144    find_matching_merge_rule(pr, repo, skip_mandatory_checks=True)
2145
2146    if not has_required_labels(pr):
2147        raise RuntimeError(LABEL_ERR_MSG.lstrip(" #"))
2148
2149    if ignore_current:
2150        checks = pr.get_checkrun_conclusions()
2151        _, failing, _ = categorize_checks(
2152            checks,
2153            list(checks.keys()),
2154            ok_failed_checks_threshold=IGNORABLE_FAILED_CHECKS_THESHOLD,
2155        )
2156        ignore_current_checks_info = failing
2157
2158    gh_post_pr_comment(
2159        pr.org,
2160        pr.project,
2161        pr.pr_num,
2162        explainer.get_merge_message(ignore_current_checks_info),
2163        dry_run=dry_run,
2164    )
2165
2166    start_time = time.time()
2167    last_exception = ""
2168    elapsed_time = 0.0
2169    ignore_current_checks = [
2170        x[0] for x in ignore_current_checks_info
2171    ]  # convert to List[str] for convenience
2172    while elapsed_time < timeout_minutes * 60:
2173        check_for_sev(pr.org, pr.project, skip_mandatory_checks)
2174        current_time = time.time()
2175        elapsed_time = current_time - start_time
2176        print(
2177            f"Attempting merge of https://github.com/{pr.org}/{pr.project}/pull/{pr.pr_num} ({elapsed_time / 60} minutes elapsed)"
2178        )
2179        pr = GitHubPR(pr.org, pr.project, pr.pr_num)
2180        if initial_commit_sha != pr.last_commit()["oid"]:
2181            raise RuntimeError(
2182                "New commits were pushed while merging. Please rerun the merge command."
2183            )
2184        try:
2185            required_checks = []
2186            failed_rule_message = None
2187            ignore_flaky_failures = True
2188            try:
2189                find_matching_merge_rule(
2190                    pr, repo, ignore_current_checks=ignore_current_checks
2191                )
2192            except MandatoryChecksMissingError as ex:
2193                if ex.rule is not None:
2194                    ignore_flaky_failures = ex.rule.ignore_flaky_failures
2195                    if ex.rule.mandatory_checks_name is not None:
2196                        required_checks = ex.rule.mandatory_checks_name
2197                failed_rule_message = ex
2198
2199            checks = pr.get_checkrun_conclusions()
2200            checks = get_classifications(
2201                pr.pr_num,
2202                pr.project,
2203                checks,
2204                ignore_current_checks=ignore_current_checks,
2205            )
2206            pending, failing, _ = categorize_checks(
2207                checks,
2208                required_checks
2209                + [x for x in checks.keys() if x not in required_checks],
2210                ok_failed_checks_threshold=IGNORABLE_FAILED_CHECKS_THESHOLD
2211                if ignore_flaky_failures
2212                else 0,
2213            )
2214            # HACK until GitHub is better about surfacing those
2215            startup_failures = filter_checks_with_lambda(
2216                checks, lambda status: status == "STARTUP_FAILURE"
2217            )
2218            if len(startup_failures) > 0:
2219                raise RuntimeError(
2220                    f"{len(startup_failures)} STARTUP failures reported, please check workflows syntax! "
2221                    + ", ".join(f"[{x.name}]({x.url})" for x in startup_failures[:5])
2222                )
2223            # END of HACK
2224
2225            if len(failing) > 0:
2226                raise RuntimeError(
2227                    f"{len(failing)} jobs have failed, first few of them are: "
2228                    + ", ".join(f"[{x[0]}]({x[1]})" for x in failing[:5])
2229                )
2230            if len(pending) > 0:
2231                if failed_rule_message is not None:
2232                    raise failed_rule_message
2233                else:
2234                    raise MandatoryChecksMissingError(
2235                        f"Still waiting for {len(pending)} jobs to finish, "
2236                        + f"first few of them are: {', '.join(x[0] for x in pending[:5])}"
2237                    )
2238
2239            return pr.merge_into(
2240                repo,
2241                dry_run=dry_run,
2242                skip_mandatory_checks=skip_mandatory_checks,
2243                comment_id=comment_id,
2244                ignore_current_checks=ignore_current_checks,
2245            )
2246        except MandatoryChecksMissingError as ex:
2247            last_exception = str(ex)
2248            print(
2249                f"Merge of https://github.com/{pr.org}/{pr.project}/pull/{pr.pr_num} failed due to: {ex}. Retrying in 5 min"
2250            )
2251            time.sleep(5 * 60)
2252    # Finally report timeout back
2253    msg = f"Merge timed out after {timeout_minutes} minutes. Please contact the pytorch_dev_infra team. "
2254    msg += f"The last exception was: {last_exception}"
2255    gh_add_labels(pr.org, pr.project, pr.pr_num, ["land-failed"], dry_run)
2256    raise RuntimeError(msg)
2257
2258
2259def main() -> None:
2260    args = parse_args()
2261    repo = GitRepo(get_git_repo_dir(), get_git_remote_name())
2262    org, project = repo.gh_owner_and_name()
2263    pr = GitHubPR(org, project, args.pr_num)
2264
2265    def handle_exception(e: Exception, title: str = "Merge failed") -> None:
2266        exception = f"**Reason**: {e}"
2267
2268        failing_rule = None
2269        if isinstance(e, MergeRuleFailedError):
2270            failing_rule = e.rule.name if e.rule else None
2271
2272        internal_debugging = ""
2273        run_url = os.getenv("GH_RUN_URL")
2274        if run_url is not None:
2275            # Hide this behind a collapsed bullet since it's not helpful to most devs
2276            internal_debugging = "\n".join(
2277                line
2278                for line in (
2279                    "<details><summary>Details for Dev Infra team</summary>",
2280                    f'Raised by <a href="{run_url}">workflow job</a>\n',
2281                    f"Failing merge rule: {failing_rule}" if failing_rule else "",
2282                    "</details>",
2283                )
2284                if line
2285            )  # ignore empty lines during the join
2286
2287        msg = "\n".join((f"## {title}", f"{exception}", "", f"{internal_debugging}"))
2288
2289        gh_post_pr_comment(org, project, args.pr_num, msg, dry_run=args.dry_run)
2290        import traceback
2291
2292        traceback.print_exc()
2293
2294    if args.revert:
2295        try:
2296            gh_post_pr_comment(
2297                org,
2298                project,
2299                args.pr_num,
2300                get_revert_message(org, project, pr.pr_num),
2301                args.dry_run,
2302            )
2303            try_revert(
2304                repo,
2305                pr,
2306                dry_run=args.dry_run,
2307                comment_id=args.comment_id,
2308                reason=args.reason,
2309            )
2310        except Exception as e:
2311            handle_exception(e, f"Reverting PR {args.pr_num} failed")
2312        return
2313
2314    if pr.is_closed():
2315        gh_post_pr_comment(
2316            org,
2317            project,
2318            args.pr_num,
2319            f"Can't merge closed PR #{args.pr_num}",
2320            dry_run=args.dry_run,
2321        )
2322        return
2323
2324    if pr.is_cross_repo() and pr.is_ghstack_pr():
2325        gh_post_pr_comment(
2326            org,
2327            project,
2328            args.pr_num,
2329            "Cross-repo ghstack merges are not supported",
2330            dry_run=args.dry_run,
2331        )
2332        return
2333    if not pr.is_ghstack_pr() and pr.base_ref() != pr.default_branch():
2334        gh_post_pr_comment(
2335            org,
2336            project,
2337            args.pr_num,
2338            f"PR targets {pr.base_ref()} rather than {pr.default_branch()}, refusing merge request",
2339            dry_run=args.dry_run,
2340        )
2341        return
2342
2343    if args.check_mergeability:
2344        if pr.is_ghstack_pr():
2345            get_ghstack_prs(repo, pr)  # raises error if out of sync
2346        pr.merge_changes(
2347            repo,
2348            skip_mandatory_checks=True,
2349            skip_all_rule_checks=True,
2350        )
2351        return
2352
2353    if not args.force and pr.has_invalid_submodule_updates():
2354        message = (
2355            f"This PR updates submodules {', '.join(pr.get_changed_submodules())}\n"
2356        )
2357        message += '\nIf those updates are intentional, please add "submodule" keyword to PR title/description.'
2358        gh_post_pr_comment(org, project, args.pr_num, message, dry_run=args.dry_run)
2359        return
2360    try:
2361        merge(
2362            pr,
2363            repo,
2364            dry_run=args.dry_run,
2365            skip_mandatory_checks=args.force,
2366            comment_id=args.comment_id,
2367            ignore_current=args.ignore_current,
2368        )
2369    except Exception as e:
2370        handle_exception(e)
2371
2372        if args.comment_id and args.pr_num:
2373            # Finally, upload the record to Rockset. We don't have access to the
2374            # list of pending and failed checks here, but they are not really
2375            # needed at the moment
2376            save_merge_record(
2377                comment_id=args.comment_id,
2378                pr_num=args.pr_num,
2379                owner=org,
2380                project=project,
2381                author=pr.get_author(),
2382                pending_checks=[],
2383                failed_checks=[],
2384                ignore_current_checks=[],
2385                broken_trunk_checks=[],
2386                flaky_checks=[],
2387                unstable_checks=[],
2388                last_commit_sha=pr.last_commit().get("oid", ""),
2389                merge_base_sha=pr.get_merge_base(),
2390                is_failed=True,
2391                skip_mandatory_checks=args.force,
2392                ignore_current=args.ignore_current,
2393                error=str(e),
2394            )
2395        else:
2396            print("Missing comment ID or PR number, couldn't upload to Rockset")
2397    finally:
2398        if not args.check_mergeability:
2399            gh_remove_label(
2400                org, project, args.pr_num, MERGE_IN_PROGRESS_LABEL, args.dry_run
2401            )
2402
2403
2404if __name__ == "__main__":
2405    main()
2406