• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1from __future__ import annotations
2
3import re
4from typing import Any
5
6from tools.testing.target_determination.heuristics.interface import (
7    HeuristicInterface,
8    TestPrioritizations,
9)
10from tools.testing.target_determination.heuristics.utils import (
11    get_git_commit_info,
12    get_issue_or_pr_body,
13    get_pr_number,
14)
15from tools.testing.test_run import TestRun
16
17
# This heuristic searches the PR body and commit messages, as well as the
# bodies of issues/PRs referenced from them (search depth of 1), for test names
# and gives each mentioned test a rating of 1.  For example, if the PR body
# mentions "test_foo", test_foo is rated 1.  If the PR body mentions #123 and
# #123 mentions "test_foo", test_foo is also rated 1.
class MentionedInPR(HeuristicInterface):
    """Prioritize tests whose names appear in the PR's text or linked issues/PRs.

    Matching is a plain substring check, so a test name that is a prefix of a
    longer mentioned identifier (e.g. ``test_ops`` inside ``test_ops_jit``) is
    also considered mentioned.
    """

    # "#123"-style shorthand references.
    _HASH_REF_RE = re.compile(r"#(\d+)")
    # Full issue/PR URLs, e.g. https://github.com/pytorch/pytorch/pull/123 or
    # .../pytorch/pytorch/issues/123.  The path segment is pinned to
    # issues|pull so that a greedy wildcard cannot span several URLs on one
    # line or capture unrelated numbers (e.g. .../actions/runs/N/job/M).
    _URL_REF_RE = re.compile(r"/pytorch/pytorch/(?:issues|pull)/(\d+)")

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)

    def _search_for_linked_issues(self, s: str) -> list[str]:
        """Return the issue/PR numbers (as digit strings) referenced in ``s``.

        ``#123`` shorthand matches come first, followed by full-URL matches.
        Duplicates are not removed here.
        """
        return self._HASH_REF_RE.findall(s) + self._URL_REF_RE.findall(s)

    def get_prediction_confidence(self, tests: list[str]) -> TestPrioritizations:
        """Rate every test in ``tests`` that is mentioned in the PR context as 1.

        The PR context is the commit messages, the PR body, and the bodies of
        any issues/PRs those two reference.  Every source is fetched
        best-effort: a failure is printed and that source contributes nothing.

        Args:
            tests: candidate test names to look for.

        Returns:
            A ``TestPrioritizations`` over ``tests`` in which each mentioned
            test is scored 1.
        """
        try:
            commit_messages = get_git_commit_info()
        except Exception as e:
            print(f"Can't get commit info due to {e}")
            commit_messages = ""
        try:
            pr_number = get_pr_number()
            if pr_number is not None:
                pr_body = get_issue_or_pr_body(pr_number)
            else:
                pr_body = ""
        except Exception as e:
            print(f"Can't get PR body due to {e}")
            pr_body = ""

        # Deduplicate (order-preserving) so an issue cited in several places,
        # e.g. both a commit message and the PR body, is only fetched once.
        linked_issues = dict.fromkeys(
            self._search_for_linked_issues(commit_messages)
            + self._search_for_linked_issues(pr_body)
        )
        linked_issue_bodies: list[str] = []
        for issue in linked_issues:
            try:
                linked_issue_bodies.append(get_issue_or_pr_body(int(issue)))
            except Exception as e:
                # Best effort: an unfetchable reference just contributes no text.
                print(f"Can't get body of issue/PR #{issue} due to {e}")

        mentioned = [
            test
            for test in tests
            if test in commit_messages
            or test in pr_body
            or any(test in body for body in linked_issue_bodies)
        ]

        return TestPrioritizations(tests, {TestRun(test): 1 for test in mentioned})
66