• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import argparse
2import json
3import multiprocessing as mp
4import os
5import re
6import tempfile
7from typing import Any, Dict, List, Optional, Tuple
8
9import requests
10import rockset  # type: ignore[import]
11from gitutils import retries_decorator
12
13
# Rockset query returning the (id, name) of every "* / test*" job that ran on
# the 5 most recent commits of viable/strict, excluding rerun_disabled_tests
# and mem_leak_check jobs.  The raw logs of these jobs are later scanned for
# mentions of each disabled test.
LOGS_QUERY = """
with
    shas as (
        SELECT
            push.head_commit.id as sha,
        FROM
            commons.push
        WHERE
            push.ref = 'refs/heads/viable/strict'
            AND push.repository.full_name = 'pytorch/pytorch'
        ORDER BY
            push._event_time DESC
        LIMIT
            5
    )
select
    id,
    name
from
    workflow_job j
    join shas on shas.sha = j.head_sha
where
    j.name like '% / test%'
    and j.name not like '%rerun_disabled_tests%'
    and j.name not like '%mem_leak_check%'
"""

# Rockset query counting sightings of a test in test_run_s3 over the last 7
# days.  Bound parameters :name and :classname are used with LIKE, so callers
# pass patterns (e.g. "test_foo%").
TEST_EXISTS_QUERY = """
select
    count(*) as c
from
    test_run_s3
where
    cast(name as string) like :name
    and classname like :classname
    and _event_time > CURRENT_TIMESTAMP() - DAYS(7)
"""

# Comment posted on a disable-test issue right before it is closed.
CLOSING_COMMENT = (
    "I cannot find any mention of this test in rockset for the past 7 days "
    "or in the logs for the past 5 commits on viable/strict.  Closing this "
    "issue as it is highly likely that this test has either been renamed or "
    "removed.  If you think this is a false positive, please feel free to "
    "re-open this issue."
)

# Condensed JSON file of currently disabled tests, produced/uploaded by CI.
DISABLED_TESTS_JSON = (
    "https://ossci-metrics.s3.amazonaws.com/disabled-tests-condensed.json"
)
63
64
def parse_args() -> Any:
    """Parse and return this script's command-line arguments."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--dry-run", action="store_true", help="Only list the tests."
    )
    return arg_parser.parse_args()
73
74
@retries_decorator()
def query_rockset(
    query: str, params: Optional[Dict[str, Any]] = None
) -> List[Dict[str, Any]]:
    """Run *query* against Rockset (with optional bound *params*) and return
    the result rows as a list of dicts.  Requires ROCKSET_API_KEY in the
    environment; retried on failure via retries_decorator."""
    client = rockset.RocksetClient(
        host="api.rs2.usw2.rockset.com", api_key=os.environ["ROCKSET_API_KEY"]
    )
    response = client.sql(query, params)
    rows: List[Dict[str, Any]] = response.results
    return rows
84
85
def download_log_worker(temp_dir: str, id: int, name: str) -> None:
    """Download the raw log for job *id* and save it into *temp_dir*.

    The file is named "<job name with '/' replaced by '_'> <id>.txt"; open
    mode "x" makes the write fail loudly if the same log were written twice.
    (The parameter is named ``id`` for caller compatibility even though it
    shadows the builtin.)
    """
    url = f"https://ossci-raw-job-status.s3.amazonaws.com/log/{id}"
    # A timeout keeps one stuck connection from hanging the worker pool
    # forever; requests.get with no timeout can block indefinitely.
    data = requests.get(url, timeout=60).text
    with open(f"{temp_dir}/{name.replace('/', '_')} {id}.txt", "x") as f:
        f.write(data)
91
92
def printer(item: Tuple[str, Tuple[int, str, List[Any]]], extra: str) -> None:
    """Print one aligned row: issue link, test name, then a trailing note."""
    test = item[0]
    link = item[1][1]
    print(f"{link:<55} {test:<120} {extra}")
96
97
def close_issue(num: int) -> None:
    """Post CLOSING_COMMENT on pytorch/pytorch issue *num*, then close it.

    Uses the GitHub REST API; requires GITHUB_TOKEN in the environment.
    """
    headers = {
        "Accept": "application/vnd.github.v3+json",
        "Authorization": f"token {os.environ['GITHUB_TOKEN']}",
    }
    base = f"https://api.github.com/repos/pytorch/pytorch/issues/{num}"
    # Leave an explanatory comment first so the closure is not silent.
    requests.post(
        f"{base}/comments",
        data=json.dumps({"body": CLOSING_COMMENT}),
        headers=headers,
    )
    requests.patch(
        base,
        data=json.dumps({"state": "closed"}),
        headers=headers,
    )
113
114
def check_if_exists(
    item: Tuple[str, Tuple[int, str, List[str]]], all_logs: List[str]
) -> Tuple[bool, str]:
    """Decide whether a disabled test still exists.

    *item* maps the test name to (issue number, issue link, extra info).
    Returns (exists, reason): True if the test (or its issue link) appears in
    any of *all_logs* or in rockset within the last 7 days.
    """
    test, (_, link, _) = item

    # Test names should look like `test_a (module.path.classname)`
    parsed = re.match(r"(\S+) \((\S*)\)", test)
    if parsed is None:
        return False, "poorly formed"

    name, dotted_class = parsed.groups()
    classname = dotted_class.split(".")[-1]

    # The issue link usually shows up in the skip reason, and the test itself
    # shows up as `classname::name` when it runs.
    needle = f"{classname}::{name}"
    if any(link in log or needle in log for log in all_logs):
        return True, "found in logs"

    # Fall back to rockset to see if the test ran recently anywhere.
    count = query_rockset(
        TEST_EXISTS_QUERY, {"name": f"{name}%", "classname": f"{classname}%"}
    )
    if count[0]["c"] == 0:
        return False, "not found"
    return True, "found in rockset"
147
148
if __name__ == "__main__":
    args = parse_args()
    # Mapping of "test_name (module.path.ClassName)" -> (issue number, issue
    # link, extra info list) — see check_if_exists for how it is consumed.
    disabled_tests_json = json.loads(requests.get(DISABLED_TESTS_JSON).text)

    all_logs = []
    jobs = query_rockset(LOGS_QUERY)
    with tempfile.TemporaryDirectory() as temp_dir:
        # Fetch every job log concurrently; each worker writes one file into
        # temp_dir.  Exceptions in workers are dropped by apply_async, which
        # the length assertions below will catch.
        pool = mp.Pool(20)
        for job in jobs:
            job_id = job["id"]
            job_name = job["name"]
            pool.apply_async(download_log_worker, args=(temp_dir, job_id, job_name))
        pool.close()
        pool.join()

        # Read every downloaded log back into memory.  (Bug fix: this
        # previously opened a hard-coded bogus path instead of each listed
        # file, so no log was ever actually read.)
        for filename in os.listdir(temp_dir):
            with open(f"{temp_dir}/{filename}") as f:
                all_logs.append(f.read())

    # Sanity checks: fewer than ~200 logs, or a count mismatch with the job
    # list, means downloads failed — bail rather than close issues on bad data.
    assert len(all_logs) > 200
    assert len(all_logs) == len(jobs)

    to_be_closed = []
    for item in disabled_tests_json.items():
        exists, reason = check_if_exists(item, all_logs)
        printer(item, reason)
        if not exists:
            to_be_closed.append(item)

    print(f"There are {len(to_be_closed)} issues that will be closed:")
    for item in to_be_closed:
        printer(item, "")

    if args.dry_run:
        print("dry run, not actually closing")
    else:
        for item in to_be_closed:
            _, (num, _, _) = item
            close_issue(num)
189