• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2
3import json
4import logging
5import os
6import re
7import subprocess
8import sys
9import warnings
10from enum import Enum
11from functools import lru_cache
12from logging import info
13from typing import Any, Callable, Dict, List, Optional, Set
14from urllib.request import Request, urlopen
15
16import yaml
17
# Case-insensitive pattern for GitHub's issue-closing keywords, e.g.
# "Fixes #12345" or "Closes https://github.com/pytorch/pytorch/issues/12345".
# findall returns 6 capture groups per match; index 5 is the issue number
# (consumed by parse_reenabled_issues below)
REENABLE_TEST_REGEX = "(?i)(Close(d|s)?|Resolve(d|s)?|Fix(ed|es)?) (#|https://github.com/pytorch/pytorch/issues/)([0-9]+)"

# Prefix of the PR labels used to select individual test configs, e.g.
# test-config/functorch (see filter below)
PREFIX = "test-config/"

logging.basicConfig(level=logging.INFO)
24
def is_cuda_or_rocm_job(job_name: Optional[str]) -> bool:
    """Return True when the job name mentions either CUDA or ROCm.

    A missing or empty job name is treated as not a GPU job.
    """
    if job_name:
        return any(accel in job_name for accel in ("cuda", "rocm"))
    return False
30
31
# Supported modes when running periodically. Only applying the mode when
# its lambda condition returns true
SUPPORTED_PERIODICAL_MODES: Dict[str, Callable[[Optional[str]], bool]] = {
    # Memory leak check is only needed for CUDA and ROCm jobs which utilize GPU memory
    "mem_leak_check": is_cuda_or_rocm_job,
    # Rerunning disabled tests applies to every job regardless of its name
    "rerun_disabled_tests": lambda job_name: True,
}

# The link to the published list of disabled jobs
DISABLED_JOBS_URL = "https://ossci-metrics.s3.amazonaws.com/disabled-jobs.json?versionId=tIl0Qo224T_NDVw0dtG4hU1cZJM97inV"
# and unstable jobs
UNSTABLE_JOBS_URL = "https://ossci-metrics.s3.amazonaws.com/unstable-jobs.json?versionId=GPyRZRsOo26Gfk_WjAoNNxEMGXkIxIes"

# Some constants used to handle disabled and unstable jobs

# Separator between the platform and job parts of a full job name,
# e.g. "linux-bionic-py3.8-clang9 / test (dynamo)"
JOB_NAME_SEP = "/"
BUILD_JOB_NAME = "build"
TEST_JOB_NAME = "test"
BUILD_AND_TEST_JOB_NAME = "build-and-test"
# Matches a "job (config)" pair, e.g. "test (dynamo)" -> job=test, cfg=dynamo
JOB_NAME_CFG_REGEX = re.compile(r"(?P<job>[\w-]+)\s+\((?P<cfg>[\w-]+)\)")
# Branches on which disabled/unstable processing is skipped (see main)
EXCLUDED_BRANCHES = ["nightly"]
# PR label that turns on the memory leak check (see perform_misc_tasks)
MEM_LEAK_LABEL = "enable-mem-leak-check"
54
class IssueType(Enum):
    """The kind of job-tracking issue being processed.

    The string values double as the keys/values written into test matrix
    entries when marking jobs (see _filter_jobs).
    """

    DISABLED = "disabled"
    UNSTABLE = "unstable"
58
59
def parse_args() -> Any:
    """Parse the command-line arguments controlling how the matrix is filtered."""
    from argparse import ArgumentParser

    parser = ArgumentParser(
        "Filter all test configurations and keep only requested ones"
    )
    parser.add_argument(
        "--test-matrix", type=str, required=True, help="the original test matrix"
    )
    parser.add_argument(
        "--selected-test-configs",
        type=str,
        default="",
        help="a comma-separated list of test configurations from the test matrix to keep",
    )
    # The remaining plain optional string flags share the same shape, so they
    # are registered table-driven as (flag, help text) pairs
    for flag, help_text in (
        ("--workflow", "the name of the current workflow, i.e. pull"),
        (
            "--job-name",
            "the name of the current job, i.e. linux-focal-py3.8-gcc7 / build",
        ),
        ("--pr-number", "the pull request number"),
        ("--tag", "the associated tag if it exists"),
        (
            "--event-name",
            "name of the event that triggered the job (pull, schedule, etc)",
        ),
        ("--schedule", "cron schedule that triggered the job"),
    ):
        parser.add_argument(flag, type=str, help=help_text)
    parser.add_argument(
        "--branch",
        type=str,
        default="main",
        help="the branch name",
    )
    return parser.parse_args()
102
103
@lru_cache(maxsize=None)
def get_pr_info(pr_number: int) -> Dict[str, Any]:
    """Fetch PR metadata (as an issue) from the GitHub API.

    Results are memoized per PR number for the lifetime of the process.
    Returns an empty dict (after warning) when the download fails.
    """
    # GITHUB_REPOSITORY and GITHUB_TOKEN are provided by GitHub Actions, see
    # https://docs.github.com/en/actions/learn-github-actions/environment-variables
    repo = os.environ.get("GITHUB_REPOSITORY", "pytorch/pytorch")
    api_base = f"https://api.github.com/repos/{repo}"
    token = os.environ["GITHUB_TOKEN"]

    response: Dict[str, Any] = download_json(
        url=f"{api_base}/issues/{pr_number}",
        headers={
            "Accept": "application/vnd.github.v3+json",
            "Authorization": f"token {token}",
        },
    )

    if response:
        return response

    warnings.warn(f"Failed to get the labels for #{pr_number}")
    return {}
128
129
def get_labels(pr_number: int) -> Set[str]:
    """Query the current set of label names attached to the pull request."""
    names: Set[str] = set()
    for label in get_pr_info(pr_number).get("labels", []):
        name = label.get("name")
        if name:
            names.add(name)
    return names
138
139
def filter_labels(labels: Set[str], label_regex: Any) -> Set[str]:
    """Return the subset of labels matching label_regex (re.match semantics,
    i.e. anchored at the start of the label)."""
    matching: Set[str] = set()
    for candidate in labels:
        if re.match(label_regex, candidate):
            matching.add(candidate)
    return matching
145
146
def filter(test_matrix: Dict[str, List[Any]], labels: Set[str]) -> Dict[str, List[Any]]:
    """
    Select which test configs to run based on the PR's test-config labels.

    If the PR carries one or more test-config/* labels, only the matching
    configs are kept. This composes with ciflow labels: a PR with both
    ciflow/trunk and test-config/functorch runs only trunk functorch builds
    and tests.

    If the PR carries no test-config label at all, the matrix is returned
    unchanged so that every test runs as usual.
    """
    selected_matrix: Dict[str, List[Any]] = {"include": []}

    for entry in test_matrix.get("include", []):
        config_name = entry.get("config", "")
        if not config_name:
            continue

        label = f"{PREFIX}{config_name.strip()}"
        if label not in labels:
            continue

        info(
            f"Select {config_name} because label {label} is present in the pull request by the time the test starts"
        )
        selected_matrix["include"].append(entry)

    test_config_labels = filter_labels(labels, re.compile(f"{PREFIX}.+"))
    if selected_matrix["include"] or test_config_labels:
        # Either a config matched or a valid test-config label exists on the
        # PR, so return only the filtered subset
        info(f"Found {test_config_labels} on the PR so only these test configs are run")
        return selected_matrix

    # Found no test-config label and the filtered test matrix is empty: return
    # the original matrix so that all tests run normally
    info("Found no test-config label on the PR, so all test configs are included")
    return test_matrix
184
185
def filter_selected_test_configs(
    test_matrix: Dict[str, List[Any]], selected_test_configs: Set[str]
) -> Dict[str, List[Any]]:
    """Keep only the selected test configs; keep everything when the selection
    is empty. Used when the workflow is dispatched manually."""
    if not selected_test_configs:
        return test_matrix

    # An entry survives only when it names a config and that config was selected
    kept = [
        entry
        for entry in test_matrix.get("include", [])
        if (cfg := entry.get("config", "")) and cfg in selected_test_configs
    ]
    return {"include": kept}
206
207
def set_periodic_modes(
    test_matrix: Dict[str, List[Any]], job_name: Optional[str]
) -> Dict[str, List[Any]]:
    """Expand the matrix with every periodic mode applicable to this job.

    Each config is duplicated once per applicable mode, with the mode name
    written into the copy (e.g. entry["mem_leak_check"] = "mem_leak_check").
    """
    expanded: Dict[str, List[Any]] = {"include": []}

    for config in test_matrix.get("include", []):
        for mode, applies in SUPPORTED_PERIODICAL_MODES.items():
            if applies(job_name):
                entry = config.copy()
                entry[mode] = mode
                expanded["include"].append(entry)

    return expanded
228
229
def mark_unstable_jobs(
    workflow: str, job_name: str, test_matrix: Dict[str, List[Any]]
) -> Dict[str, List[Any]]:
    """Mark the current job as unstable if it appears in the published
    unstable-jobs list. When a job is unstable, all its dependents are
    marked as well. Thin wrapper around process_jobs."""
    return process_jobs(
        workflow=workflow,
        job_name=job_name,
        test_matrix=test_matrix,
        issue_type=IssueType.UNSTABLE,
        url=UNSTABLE_JOBS_URL,
    )
244
245
def remove_disabled_jobs(
    workflow: str, job_name: str, test_matrix: Dict[str, List[Any]]
) -> Dict[str, List[Any]]:
    """Remove the current job (and all its dependents) from the matrix if it
    appears in the published disabled-jobs list. Thin wrapper around
    process_jobs."""
    return process_jobs(
        workflow=workflow,
        job_name=job_name,
        test_matrix=test_matrix,
        issue_type=IssueType.DISABLED,
        url=DISABLED_JOBS_URL,
    )
260
261
def _filter_jobs(
    test_matrix: Dict[str, List[Any]],
    issue_type: IssueType,
    target_cfg: Optional[str] = None,
) -> Dict[str, List[Any]]:
    """Apply a disable/unstable job filter to the test matrix.

    DISABLED: drop the target config, or drop everything when no target
    config is given. UNSTABLE: mark the target config (or every config when
    none is given) by writing "unstable": "unstable" into the entry. Any
    other issue type returns the matrix untouched.
    """
    result: Dict[str, List[Any]] = {"include": []}

    if issue_type == IssueType.DISABLED:
        # With a target config, keep everything except that config; without
        # one, the whole matrix is disabled and the empty result stands
        if target_cfg:
            result["include"] = [
                entry
                for entry in test_matrix["include"]
                if entry.get("config", "") != target_cfg
            ]
        return result

    if issue_type == IssueType.UNSTABLE:
        for entry in test_matrix["include"]:
            marked = entry.copy()

            # No target config means everything is unstable; otherwise only
            # the matching config gets the mark
            if not target_cfg or entry.get("config", "") == target_cfg:
                marked[IssueType.UNSTABLE.value] = IssueType.UNSTABLE.value

            result["include"].append(marked)

        return result

    # No matching issue type, return everything
    return test_matrix
299
300
def process_jobs(
    workflow: str,
    job_name: str,
    test_matrix: Dict[str, List[Any]],
    issue_type: IssueType,
    url: str,
) -> Dict[str, List[Any]]:
    """
    Download the disabled/unstable jobs list from url and apply any matching
    record to the test matrix via _filter_jobs.

    Both disabled and unstable jobs are in the following format:

    {
        "WORKFLOW / PLATFORM / JOB (CONFIG)": [
            AUTHOR,
            ISSUE_NUMBER,
            ISSUE_URL,
            WORKFLOW,
            PLATFORM,
            JOB (CONFIG),
        ],
        "pull / linux-bionic-py3.8-clang9 / test (dynamo)": [
            "pytorchbot",
            "94861",
            "https://github.com/pytorch/pytorch/issues/94861",
            "pull",
            "linux-bionic-py3.8-clang9",
            "test (dynamo)",
        ],
    }

    Returns the (possibly filtered/marked) test matrix; the input matrix is
    returned unchanged when the job name is malformed or no record matches.
    """
    try:
        # The job name from github is in the PLATFORM / JOB (CONFIG) format, so breaking
        # it into its two components first. A name without the separator makes
        # the unpacking raise ValueError (too few values)
        current_platform, _ = (n.strip() for n in job_name.split(JOB_NAME_SEP, 1) if n)
    except ValueError as error:
        warnings.warn(f"Invalid job name {job_name}, returning")
        return test_matrix

    for record in download_json(url=url, headers={}).values():
        # Unpack per the record layout documented above (issue number unused)
        (
            author,
            _,
            target_url,
            target_workflow,
            target_platform,
            target_job_cfg,
        ) = record

        if target_workflow != workflow:
            # The current workflow doesn't match this record
            continue

        cleanup_regex = rf"(-{BUILD_JOB_NAME}|-{TEST_JOB_NAME})$"
        # There is an exception here for binary build workflows in which the platform
        # names have the build and test suffix. For example, we have a build job called
        # manywheel-py3-cuda11_8-build / build and its subsequent test job called
        # manywheel-py3-cuda11_8-test / test. So they are linked, but their suffixes
        # are different
        target_platform_no_suffix = re.sub(cleanup_regex, "", target_platform)
        current_platform_no_suffix = re.sub(cleanup_regex, "", current_platform)

        if (
            target_platform != current_platform
            and target_platform_no_suffix != current_platform_no_suffix
        ):
            # The current platform doesn't match this record
            continue

        # The logic after this is fairly complicated:
        #
        # - If the target record doesn't have the optional job (config) name,
        #   i.e. pull / linux-bionic-py3.8-clang9, all build and test jobs will
        #   be skipped if it's a disabled record or marked as unstable if it's
        #   an unstable record
        #
        # - If the target record has the job name and it's a build job, i.e.
        #   pull / linux-bionic-py3.8-clang9 / build, all build and test jobs
        #   will be skipped if it's a disabled record or marked as unstable if
        #   it's an unstable record, because the latter requires the former
        #
        # - If the target record has the job name and it's a test job without
        #   the config part, i.e. pull / linux-bionic-py3.8-clang9 / test, all
        #   test jobs will be skipped if it's a disabled record or marked as
        #   unstable if it's an unstable record
        #
        # - If the target record has the job (config) name, only that test config
        #   will be skipped or marked as unstable
        if not target_job_cfg:
            msg = (
                f"Issue {target_url} created by {author} has {issue_type.value} "
                + f"all CI jobs for {workflow} / {job_name}"
            )
            info(msg)
            return _filter_jobs(
                test_matrix=test_matrix,
                issue_type=issue_type,
            )

        if target_job_cfg == BUILD_JOB_NAME:
            msg = (
                f"Issue {target_url} created by {author} has {issue_type.value} "
                + f"the build job for {workflow} / {job_name}"
            )
            info(msg)
            return _filter_jobs(
                test_matrix=test_matrix,
                issue_type=issue_type,
            )

        if target_job_cfg in (TEST_JOB_NAME, BUILD_AND_TEST_JOB_NAME):
            msg = (
                f"Issue {target_url} created by {author} has {issue_type.value} "
                + f"all the test jobs for {workflow} / {job_name}"
            )
            info(msg)
            return _filter_jobs(
                test_matrix=test_matrix,
                issue_type=issue_type,
            )

        m = JOB_NAME_CFG_REGEX.match(target_job_cfg)
        if m:
            target_job = m.group("job")
            # Make sure that the job name is a valid test job name first before checking the config
            if target_job in (TEST_JOB_NAME, BUILD_AND_TEST_JOB_NAME):
                target_cfg = m.group("cfg")

                # NB: There can be multiple unstable configurations, i.e. inductor, inductor_huggingface,
                # so keep looping over the remaining records instead of returning
                test_matrix = _filter_jobs(
                    test_matrix=test_matrix,
                    issue_type=issue_type,
                    target_cfg=target_cfg,
                )
        else:
            warnings.warn(
                f"Found a matching {issue_type.value} issue {target_url} for {workflow} / {job_name}, "
                + f"but the name {target_job_cfg} is invalid"
            )

    # Found no matching target, return the same input test matrix
    return test_matrix
441
442
def download_json(url: str, headers: Dict[str, str], num_retries: int = 3) -> Any:
    """Download and parse a JSON document from url.

    Retries up to num_retries times with a 5-second timeout per attempt and
    returns an empty dict (after warning) when every attempt fails.
    """
    for _ in range(num_retries):
        try:
            response = urlopen(Request(url=url, headers=headers), timeout=5)
            return json.loads(response.read().decode("utf-8"))
        except Exception as e:
            warnings.warn(f"Could not download {url}: {e}")

    warnings.warn(f"All {num_retries} retries exhausted, downloading {url} failed")
    return {}
454
455
def set_output(name: str, val: Any) -> None:
    """Publish a GitHub Actions output variable.

    Appends "name=val" to the file named by GITHUB_OUTPUT when that variable
    is set; otherwise falls back to the legacy ::set-output command on stdout.
    """
    if not os.getenv("GITHUB_OUTPUT"):
        print(f"::set-output name={name}::{val}")
        return
    with open(str(os.getenv("GITHUB_OUTPUT")), "a") as env:
        print(f"{name}={val}", file=env)
462
463
def parse_reenabled_issues(s: Optional[str]) -> List[str]:
    """Extract issue numbers referenced via GitHub closing keywords in s.

    The regex matches the *case-insensitive* keywords GitHub uses to link PRs
    to issues (https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue):
    "Close #62851", "fixES #62851" and "RESOLVED #62851" all match, but not
    "closes  #62851" (extra space), "fixing #62851" (not a keyword), nor
    "fix 62851" (no #). Group index 5 of each match is the issue number.
    """
    # NB: GitHub returns None for an empty PR body, which flows in here
    if not s:
        return []

    return [match[5] for match in re.findall(REENABLE_TEST_REGEX, s)]
477
478
def get_reenabled_issues(pr_body: str = "") -> List[str]:
    """Collect reenabled issue numbers from the PR body plus the commit
    messages not yet on the default branch (via git cherry)."""
    default_branch = f"origin/{os.environ.get('GIT_DEFAULT_BRANCH', 'main')}"
    commit_messages = ""
    try:
        commit_messages = subprocess.check_output(
            f"git cherry -v {default_branch}".split(" ")
        ).decode("utf-8")
    except Exception as e:
        # Best effort: a missing git repo/branch just means no extra issues
        warnings.warn(f"failed to get commit messages: {e}")
    return parse_reenabled_issues(pr_body) + parse_reenabled_issues(commit_messages)
489
490
def check_for_setting(labels: Set[str], body: str, setting: str) -> bool:
    """Return True when the setting is present as a PR label or appears as
    "[setting]" anywhere in the PR body."""
    if setting in labels:
        return True
    return f"[{setting}]" in body
493
494
def perform_misc_tasks(
    labels: Set[str], test_matrix: Dict[str, List[Any]], job_name: str, pr_body: str
) -> None:
    """
    In addition to the filter logic, set the miscellaneous workflow output
    variables: keep-going, ci-verbose-test-logs, ci-no-test-timeout, ci-no-td,
    ci-td-distributed, is-unstable, and reenabled-issues. Also turn on the
    memory leak check for CUDA/ROCm jobs when the PR carries the
    enable-mem-leak-check label.
    """
    # Boolean settings toggled by a PR label or a [setting] marker in the body
    for setting in (
        "keep-going",
        "ci-verbose-test-logs",
        "ci-no-test-timeout",
        "ci-no-td",
        # Only relevant for the one linux distributed cuda job, delete this
        # when TD is rolled out completely
        "ci-td-distributed",
    ):
        set_output(setting, check_for_setting(labels, pr_body, setting))

    # Obviously, if the job name includes unstable, then this is an unstable job
    is_unstable = job_name and IssueType.UNSTABLE.value in job_name
    if not is_unstable and test_matrix and test_matrix.get("include"):
        # Even when the job name doesn't mention unstable, we will also mark it as
        # unstable when the test matrix only includes unstable jobs. Basically, this
        # logic allows build or build-and-test jobs to be marked as unstable too.
        #
        # Basically, when a build job is unstable, all the subsequent test jobs are
        # also unstable. And when all test jobs are unstable, we will also treat the
        # build job as unstable. It's simpler this way
        is_unstable = all(IssueType.UNSTABLE.value in r for r in test_matrix["include"])

    set_output("is-unstable", is_unstable)

    set_output("reenabled-issues", ",".join(get_reenabled_issues(pr_body=pr_body)))

    # Enable mem leak check if the label is added. The check only applies to
    # CUDA/ROCm jobs, and since that condition depends solely on job_name it is
    # hoisted out of the loop (it was previously re-evaluated per config)
    if MEM_LEAK_LABEL in labels and is_cuda_or_rocm_job(job_name):
        for config in test_matrix.get("include", []):
            config["mem_leak_check"] = "mem_leak_check"
541
542
def main() -> None:
    """Entry point: filter the test matrix per PR labels, disabled/unstable
    jobs, and periodic modes, then publish it as workflow outputs."""
    args = parse_args()
    # Load the original test matrix set by the workflow. Its format, however,
    # doesn't follow the strict JSON format, so we load it using yaml here for
    # its more relaxed syntax
    test_matrix = yaml.safe_load(args.test_matrix)

    if test_matrix is None:
        warnings.warn(f"Invalid test matrix input '{args.test_matrix}', exiting")
        # We handle invalid test matrix gracefully by marking it as empty
        set_output("is-test-matrix-empty", True)
        sys.exit(0)

    pr_number = args.pr_number
    tag = args.tag

    # If the tag matches, we can get the PR number from it, this is from ciflow
    # workflow dispatcher
    tag_regex = re.compile(r"^ciflow/\w+/(?P<pr_number>\d+)$")

    labels = set()
    if pr_number:
        # If a PR number is set, query all the labels from that PR
        labels = get_labels(int(pr_number))
        # Then filter the test matrix and keep only the selected ones
        filtered_test_matrix = filter(test_matrix, labels)

    elif tag:
        m = tag_regex.match(tag)

        if m:
            pr_number = m.group("pr_number")

            # The PR number can also come from the tag in ciflow tag event
            labels = get_labels(int(pr_number))
            # Filter the test matrix and keep only the selected ones
            filtered_test_matrix = filter(test_matrix, labels)

        else:
            # There is a tag but it isn't ciflow, so there is nothing left to do
            filtered_test_matrix = test_matrix

    else:
        # No PR number, no tag, we can just return the test matrix as it is
        filtered_test_matrix = test_matrix

    if args.selected_test_configs:
        # Manual workflow dispatch: keep only the explicitly requested configs
        selected_test_configs = {
            v.strip().lower()
            for v in args.selected_test_configs.split(",")
            if v.strip()
        }
        filtered_test_matrix = filter_selected_test_configs(
            filtered_test_matrix, selected_test_configs
        )

    if args.event_name == "schedule" and args.schedule == "29 8 * * *":
        # we don't want to run the mem leak check or disabled tests on normal
        # periodically scheduled jobs, only the ones at this time
        filtered_test_matrix = set_periodic_modes(filtered_test_matrix, args.job_name)

    if args.workflow and args.job_name and args.branch not in EXCLUDED_BRANCHES:
        # If both workflow and job name are available, we will check if the current job
        # is disabled and remove it and all its dependants from the test matrix
        filtered_test_matrix = remove_disabled_jobs(
            args.workflow, args.job_name, filtered_test_matrix
        )

        filtered_test_matrix = mark_unstable_jobs(
            args.workflow, args.job_name, filtered_test_matrix
        )

    # The PR body is needed to detect [setting] markers and reenabled issues
    pr_body = get_pr_info(int(pr_number)).get("body", "") if pr_number else ""

    perform_misc_tasks(
        labels=labels,
        test_matrix=filtered_test_matrix,
        job_name=args.job_name,
        pr_body=pr_body if pr_body else "",
    )

    # Set the filtered test matrix as the output
    set_output("test-matrix", json.dumps(filtered_test_matrix))

    filtered_test_matrix_len = len(filtered_test_matrix.get("include", []))
    # and also put a flag if the test matrix is empty, so subsequent jobs can
    # quickly check it without the need to parse the JSON string
    set_output("is-test-matrix-empty", filtered_test_matrix_len == 0)
631
632
# Run only when executed as a script, not when imported
if __name__ == "__main__":
    main()
635