#!/usr/bin/env python3
"""
Filter the GitHub Actions test matrix of a CI job.

Given the original test matrix of a job, this script:

* keeps only the test configs requested through ``test-config/*`` PR labels
  or through the ``--selected-test-configs`` argument,
* expands the matrix with the periodic modes on the dedicated schedule,
* removes disabled jobs and marks unstable jobs according to the published
  lists of disabled / unstable jobs,
* writes the resulting matrix and a few PR-derived settings as step outputs.
"""

import json
import logging
import os
import re
import subprocess
import sys
import warnings
from enum import Enum
from functools import lru_cache
from logging import info
from typing import Any, Callable, Dict, List, Optional, Set
from urllib.request import Request, urlopen

# NB: PyYAML is imported inside main(), its only user (the same way argparse is
# imported inside parse_args), so that the helpers in this module can be
# imported without the third-party dependency being installed

REENABLE_TEST_REGEX = "(?i)(Close(d|s)?|Resolve(d|s)?|Fix(ed|es)?) (#|https://github.com/pytorch/pytorch/issues/)([0-9]+)"

PREFIX = "test-config/"

logging.basicConfig(level=logging.INFO)


def is_cuda_or_rocm_job(job_name: Optional[str]) -> bool:
    """
    Return True when the job name mentions cuda or rocm. An empty or missing
    job name is never a GPU job
    """
    if not job_name:
        return False

    return "cuda" in job_name or "rocm" in job_name


# Supported modes when running periodically. Only applying the mode when
# its lambda condition returns true
SUPPORTED_PERIODICAL_MODES: Dict[str, Callable[[Optional[str]], bool]] = {
    # Memory leak check is only needed for CUDA and ROCm jobs which utilize GPU memory
    "mem_leak_check": is_cuda_or_rocm_job,
    "rerun_disabled_tests": lambda job_name: True,
}

# The link to the published list of disabled jobs
DISABLED_JOBS_URL = "https://ossci-metrics.s3.amazonaws.com/disabled-jobs.json?versionId=tIl0Qo224T_NDVw0dtG4hU1cZJM97inV"
# and unstable jobs
UNSTABLE_JOBS_URL = "https://ossci-metrics.s3.amazonaws.com/unstable-jobs.json?versionId=GPyRZRsOo26Gfk_WjAoNNxEMGXkIxIes"

# Some constants used to handle disabled and unstable jobs
JOB_NAME_SEP = "/"
BUILD_JOB_NAME = "build"
TEST_JOB_NAME = "test"
BUILD_AND_TEST_JOB_NAME = "build-and-test"
JOB_NAME_CFG_REGEX = re.compile(r"(?P<job>[\w-]+)\s+\((?P<cfg>[\w-]+)\)")
EXCLUDED_BRANCHES = ["nightly"]
MEM_LEAK_LABEL = "enable-mem-leak-check"


class IssueType(Enum):
    """The two kinds of job records handled by process_jobs"""

    DISABLED = "disabled"
    UNSTABLE = "unstable"


def parse_args() -> Any:
    """
    Parse the command line arguments of the script
    """
    from argparse import ArgumentParser

    parser = ArgumentParser(
        "Filter all test configurations and keep only requested ones"
    )
    parser.add_argument(
        "--test-matrix", type=str, required=True, help="the original test matrix"
    )
    parser.add_argument(
        "--selected-test-configs",
        type=str,
        default="",
        help="a comma-separated list of test configurations from the test matrix to keep",
    )
    parser.add_argument(
        "--workflow", type=str, help="the name of the current workflow, i.e. pull"
    )
    parser.add_argument(
        "--job-name",
        type=str,
        help="the name of the current job, i.e. linux-focal-py3.8-gcc7 / build",
    )
    parser.add_argument("--pr-number", type=str, help="the pull request number")
    parser.add_argument("--tag", type=str, help="the associated tag if it exists")
    parser.add_argument(
        "--event-name",
        type=str,
        help="name of the event that triggered the job (pull, schedule, etc)",
    )
    parser.add_argument(
        "--schedule",
        type=str,
        help="cron schedule that triggered the job",
    )
    parser.add_argument(
        "--branch",
        type=str,
        default="main",
        help="the branch name",
    )
    return parser.parse_args()


@lru_cache(maxsize=None)
def get_pr_info(pr_number: int) -> Dict[str, Any]:
    """
    Dynamically get PR information

    The result is cached per PR number, so labels and body are fetched with a
    single API call. Requires GITHUB_TOKEN to be set in the environment
    (KeyError otherwise). Returns an empty dict when the request fails.
    """
    # From https://docs.github.com/en/actions/learn-github-actions/environment-variables
    pytorch_repo = os.environ.get("GITHUB_REPOSITORY", "pytorch/pytorch")
    pytorch_github_api = f"https://api.github.com/repos/{pytorch_repo}"
    github_token = os.environ["GITHUB_TOKEN"]

    headers = {
        "Accept": "application/vnd.github.v3+json",
        "Authorization": f"token {github_token}",
    }
    json_response: Dict[str, Any] = download_json(
        url=f"{pytorch_github_api}/issues/{pr_number}",
        headers=headers,
    )

    if not json_response:
        warnings.warn(f"Failed to get the labels for #{pr_number}")
        return {}

    return json_response


def get_labels(pr_number: int) -> Set[str]:
    """
    Dynamically get the latest list of labels from the pull request
    """
    pr_info = get_pr_info(pr_number)
    return {
        label.get("name") for label in pr_info.get("labels", []) if label.get("name")
    }


def filter_labels(labels: Set[str], label_regex: Any) -> Set[str]:
    """
    Return the set of labels matching the regex (anchored at the start of the label)
    """
    return {label for label in labels if re.match(label_regex, label)}


# NB: This intentionally keeps its historical name even though it shadows the
# builtin filter() inside this module, because it is part of the public interface
def filter(test_matrix: Dict[str, List[Any]], labels: Set[str]) -> Dict[str, List[Any]]:
    """
    Select the list of test config to run from the test matrix. The logic works
    as follows:

    If the PR has one or more test-config labels as specified, only these test configs
    will be selected. This also works with ciflow labels, for example, if a PR has both
    ciflow/trunk and test-config/functorch, only trunk functorch builds and tests will
    be run.

    If the PR has none of the test-config label, all tests are run as usual.
    """
    filtered_test_matrix: Dict[str, List[Any]] = {"include": []}

    for entry in test_matrix.get("include", []):
        config_name = entry.get("config", "")
        if not config_name:
            continue

        label = f"{PREFIX}{config_name.strip()}"
        if label in labels:
            msg = f"Select {config_name} because label {label} is present in the pull request by the time the test starts"
            info(msg)
            filtered_test_matrix["include"].append(entry)

    test_config_labels = filter_labels(labels, re.compile(f"{PREFIX}.+"))
    if not filtered_test_matrix["include"] and not test_config_labels:
        info("Found no test-config label on the PR, so all test configs are included")
        # Found no test-config label and the filtered test matrix is empty, return the same
        # test matrix as before so that all tests can be run normally
        return test_matrix

    msg = f"Found {test_config_labels} on the PR so only these test configs are run"
    info(msg)
    # When the filter test matrix contain matches or if a valid test config label
    # is found in the PR, return the filtered test matrix
    return filtered_test_matrix


def filter_selected_test_configs(
    test_matrix: Dict[str, List[Any]], selected_test_configs: Set[str]
) -> Dict[str, List[Any]]:
    """
    Keep only the selected configs if the list is not empty. Otherwise, keep all test configs.
    This filter is used when the workflow is dispatched manually.
    """
    if not selected_test_configs:
        return test_matrix

    return {
        "include": [
            entry
            for entry in test_matrix.get("include", [])
            # Entries without a config name can never be selected
            if entry.get("config", "") and entry["config"] in selected_test_configs
        ]
    }


def set_periodic_modes(
    test_matrix: Dict[str, List[Any]], job_name: Optional[str]
) -> Dict[str, List[Any]]:
    """
    Apply all periodic modes when running under a schedule

    Each config is duplicated once per applicable mode, and the copy carries
    the mode name as both key and value. The input matrix is not modified.
    """
    scheduled_test_matrix: Dict[str, List[Any]] = {
        "include": [],
    }

    for config in test_matrix.get("include", []):
        for mode, cond in SUPPORTED_PERIODICAL_MODES.items():
            if not cond(job_name):
                continue

            cfg = config.copy()
            cfg[mode] = mode
            scheduled_test_matrix["include"].append(cfg)

    return scheduled_test_matrix


def mark_unstable_jobs(
    workflow: str, job_name: str, test_matrix: Dict[str, List[Any]]
) -> Dict[str, List[Any]]:
    """
    Check the list of unstable jobs and mark them accordingly. Note that if a job
    is unstable, all its dependents will also be marked accordingly
    """
    return process_jobs(
        workflow=workflow,
        job_name=job_name,
        test_matrix=test_matrix,
        issue_type=IssueType.UNSTABLE,
        url=UNSTABLE_JOBS_URL,
    )


def remove_disabled_jobs(
    workflow: str, job_name: str, test_matrix: Dict[str, List[Any]]
) -> Dict[str, List[Any]]:
    """
    Check the list of disabled jobs, remove the current job and all its dependents
    if it exists in the list
    """
    return process_jobs(
        workflow=workflow,
        job_name=job_name,
        test_matrix=test_matrix,
        issue_type=IssueType.DISABLED,
        url=DISABLED_JOBS_URL,
    )


def _filter_jobs(
    test_matrix: Dict[str, List[Any]],
    issue_type: IssueType,
    target_cfg: Optional[str] = None,
) -> Dict[str, List[Any]]:
    """
    An utility function used to actually apply the job filter

    * DISABLED: remove only target_cfg when it is given, everything otherwise
    * UNSTABLE: mark only target_cfg as unstable when it is given, everything
      otherwise. The entries are copied, the input matrix is not modified
    * anything else: return the input matrix unchanged
    """
    # A matrix without the include key is treated as an empty one, like everywhere
    # else in this file, instead of failing with a KeyError
    entries = test_matrix.get("include", [])

    # This is an issue to disable a CI job
    if issue_type == IssueType.DISABLED:
        if not target_cfg:
            # No target config, the whole job is disabled
            return {"include": []}

        # If there is a target config, disable (remove) only that
        return {"include": [r for r in entries if r.get("config", "") != target_cfg]}

    if issue_type == IssueType.UNSTABLE:
        unstable = IssueType.UNSTABLE.value
        marked: List[Any] = []

        for r in entries:
            cpy = r.copy()

            if not target_cfg or r.get("config", "") == target_cfg:
                # If there is a target config, only mark that as unstable, otherwise,
                # mark everything as unstable
                cpy[unstable] = unstable

            marked.append(cpy)

        return {"include": marked}

    # No matching issue, return everything
    return test_matrix


def process_jobs(
    workflow: str,
    job_name: str,
    test_matrix: Dict[str, List[Any]],
    issue_type: IssueType,
    url: str,
) -> Dict[str, List[Any]]:
    """
    Both disabled and unstable jobs are in the following format:

    {
        "WORKFLOW / PLATFORM / JOB (CONFIG)": [
            AUTHOR,
            ISSUE_NUMBER,
            ISSUE_URL,
            WORKFLOW,
            PLATFORM,
            JOB (CONFIG),
        ],
        "pull / linux-bionic-py3.8-clang9 / test (dynamo)": [
            "pytorchbot",
            "94861",
            "https://github.com/pytorch/pytorch/issues/94861",
            "pull",
            "linux-bionic-py3.8-clang9",
            "test (dynamo)",
        ],
    }

    The records are downloaded from url and applied to the test matrix of the
    current workflow / job_name. Records that don't follow the format above
    are skipped with a warning. The input matrix is returned as it is when the
    job name is invalid or when no record matches.
    """
    try:
        # The job name from github is in the PLATFORM / JOB (CONFIG) format, so breaking
        # it into its two components first
        current_platform, _ = (n.strip() for n in job_name.split(JOB_NAME_SEP, 1) if n)
    except ValueError:
        warnings.warn(f"Invalid job name {job_name}, returning")
        return test_matrix

    records = download_json(url=url, headers={})
    if not isinstance(records, dict):
        warnings.warn(f"Unexpected {issue_type.value} jobs format from {url}, returning")
        return test_matrix

    # There is an exception here for binary build workflows in which the platform
    # names have the build and test suffix. For example, we have a build job called
    # manywheel-py3-cuda11_8-build / build and its subsequent test job called
    # manywheel-py3-cuda11_8-test / test. So they are linked, but their suffixes
    # are different
    cleanup_regex = rf"(-{BUILD_JOB_NAME}|-{TEST_JOB_NAME})$"
    current_platform_no_suffix = re.sub(cleanup_regex, "", current_platform)

    for record in records.values():
        try:
            (
                author,
                _,
                target_url,
                target_workflow,
                target_platform,
                target_job_cfg,
            ) = record
        except (TypeError, ValueError):
            # One malformed record must not take down the whole filtering step
            warnings.warn(f"Skipping malformed {issue_type.value} job record {record}")
            continue

        if target_workflow != workflow:
            # The current workflow doesn't match this record
            continue

        target_platform_no_suffix = re.sub(cleanup_regex, "", target_platform)

        if (
            target_platform != current_platform
            and target_platform_no_suffix != current_platform_no_suffix
        ):
            # The current platform doesn't match this record
            continue

        # The logic after this is fairly complicated:
        #
        # - If the target record doesn't have the optional job (config) name,
        #   i.e. pull / linux-bionic-py3.8-clang9, all build and test jobs will
        #   be skipped if it's a disabled record or marked as unstable if it's
        #   an unstable record
        #
        # - If the target record has the job name and it's a build job, i.e.
        #   pull / linux-bionic-py3.8-clang9 / build, all build and test jobs
        #   will be skipped if it's a disabled record or marked as unstable if
        #   it's an unstable record, because the latter requires the former
        #
        # - If the target record has the job name and it's a test job without
        #   the config part, i.e. pull / linux-bionic-py3.8-clang9 / test, all
        #   test jobs will be skipped if it's a disabled record or marked as
        #   unstable if it's an unstable record
        #
        # - If the target record has the job (config) name, only that test config
        #   will be skipped or marked as unstable
        if not target_job_cfg:
            scope: Optional[str] = "all CI jobs"
        elif target_job_cfg == BUILD_JOB_NAME:
            scope = "the build job"
        elif target_job_cfg in (TEST_JOB_NAME, BUILD_AND_TEST_JOB_NAME):
            scope = "all the test jobs"
        else:
            scope = None

        if scope is not None:
            # The first three cases all apply to the whole test matrix
            msg = (
                f"Issue {target_url} created by {author} has {issue_type.value} "
                + f"{scope} for {workflow} / {job_name}"
            )
            info(msg)
            return _filter_jobs(
                test_matrix=test_matrix,
                issue_type=issue_type,
            )

        m = JOB_NAME_CFG_REGEX.match(target_job_cfg)
        if m:
            target_job = m.group("job")
            # Make sure that the job name is a valid test job name first before checking the config
            if target_job in (TEST_JOB_NAME, BUILD_AND_TEST_JOB_NAME):
                target_cfg = m.group("cfg")

                # NB: There can be multiple unstable configurations, i.e. inductor, inductor_huggingface,
                # so keep going through the remaining records instead of returning here
                test_matrix = _filter_jobs(
                    test_matrix=test_matrix,
                    issue_type=issue_type,
                    target_cfg=target_cfg,
                )
        else:
            warnings.warn(
                f"Found a matching {issue_type.value} issue {target_url} for {workflow} / {job_name}, "
                + f"but the name {target_job_cfg} is invalid"
            )

    # Found no matching target, return the same input test matrix
    return test_matrix


def download_json(url: str, headers: Dict[str, str], num_retries: int = 3) -> Any:
    """
    Download and parse the JSON document at url, trying up to num_retries times.

    This is best-effort on purpose: any failure (network, decoding, parsing) is
    reported as a warning and an empty dict is returned once all retries are
    exhausted, so that the CI job can carry on with an unfiltered test matrix.
    """
    for _ in range(num_retries):
        try:
            req = Request(url=url, headers=headers)
            # Close the response deterministically instead of leaking the connection
            with urlopen(req, timeout=5) as response:
                content = response.read().decode("utf-8")
            return json.loads(content)
        except Exception as e:
            warnings.warn(f"Could not download {url}: {e}")

    warnings.warn(f"All {num_retries} retries exhausted, downloading {url} failed")
    return {}


def set_output(name: str, val: Any) -> None:
    """
    Set a step output. The name=val line is appended to the file pointed to by
    GITHUB_OUTPUT when it is set, otherwise the legacy ::set-output command
    is printed to stdout
    """
    github_output = os.getenv("GITHUB_OUTPUT")
    if github_output:
        with open(github_output, "a", encoding="utf-8") as env:
            print(f"{name}={val}", file=env)
    else:
        print(f"::set-output name={name}::{val}")


def parse_reenabled_issues(s: Optional[str]) -> List[str]:
    """
    Return the numbers (as strings) of all issues that the text closes,
    resolves, or fixes
    """
    # NB: When the PR body is empty, GitHub API returns a None value, which is
    # passed into this function
    if not s:
        return []

    # The regex is meant to match all *case-insensitive* keywords that
    # GitHub has delineated would link PRs to issues, more details here:
    # https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue.
    # E.g., "Close #62851", "fixES #62851" and "RESOLVED #62851" would all match, but not
    # "closes  #62851" --> extra space, "fixing #62851" --> not a keyword, nor "fix 62851" --> no #
    #
    # The regex has 6 groups, the issue number is the last one
    issue_numbers = [x[5] for x in re.findall(REENABLE_TEST_REGEX, s)]
    return issue_numbers


def get_reenabled_issues(pr_body: str = "") -> List[str]:
    """
    Return the issues re-enabled by the PR body followed by those re-enabled
    by the messages of the commits not yet on the default branch
    """
    default_branch = f"origin/{os.environ.get('GIT_DEFAULT_BRANCH', 'main')}"
    try:
        # Pass the arguments as a list so that the branch name is always one argument
        commit_messages = subprocess.check_output(
            ["git", "cherry", "-v", default_branch]
        ).decode("utf-8")
    except Exception as e:
        warnings.warn(f"failed to get commit messages: {e}")
        commit_messages = ""
    return parse_reenabled_issues(pr_body) + parse_reenabled_issues(commit_messages)


def check_for_setting(labels: Set[str], body: str, setting: str) -> bool:
    """
    A setting is on when it's a PR label or when [setting] appears in the PR body
    """
    return setting in labels or f"[{setting}]" in body


def perform_misc_tasks(
    labels: Set[str],
    test_matrix: Dict[str, List[Any]],
    job_name: Optional[str],
    pr_body: str,
) -> None:
    """
    In addition to apply the filter logic, the script also does the following
    misc tasks to set keep-going and is-unstable variables

    NB: When the mem leak check label is present on a cuda or rocm job, the
    entries of test_matrix are updated in place.
    """
    set_output("keep-going", check_for_setting(labels, pr_body, "keep-going"))
    set_output(
        "ci-verbose-test-logs",
        check_for_setting(labels, pr_body, "ci-verbose-test-logs"),
    )
    set_output(
        "ci-no-test-timeout", check_for_setting(labels, pr_body, "ci-no-test-timeout")
    )
    set_output("ci-no-td", check_for_setting(labels, pr_body, "ci-no-td"))
    # Only relevant for the one linux distributed cuda job, delete this when TD
    # is rolled out completely
    set_output(
        "ci-td-distributed", check_for_setting(labels, pr_body, "ci-td-distributed")
    )

    # Obviously, if the job name includes unstable, then this is an unstable job.
    # Make sure that this is always a bool, a missing job name used to leak into
    # the output as is-unstable=None
    is_unstable = job_name is not None and IssueType.UNSTABLE.value in job_name
    if not is_unstable and test_matrix and test_matrix.get("include"):
        # Even when the job name doesn't mention unstable, we will also mark it as
        # unstable when the test matrix only includes unstable jobs. Basically, this
        # logic allows build or build-and-test jobs to be marked as unstable too.
        #
        # Basically, when a build job is unstable, all the subsequent test jobs are
        # also unstable. And when all test jobs are unstable, we will also treat the
        # build job as unstable. It's simpler this way
        is_unstable = all(IssueType.UNSTABLE.value in r for r in test_matrix["include"])

    set_output(
        "is-unstable",
        is_unstable,
    )

    set_output("reenabled-issues", ",".join(get_reenabled_issues(pr_body=pr_body)))

    # Enable mem leak check if label is added, the check only makes sense on
    # GPU jobs and the job name is the same for all configs
    if MEM_LEAK_LABEL in labels and is_cuda_or_rocm_job(job_name):
        for config in test_matrix.get("include", []):
            config["mem_leak_check"] = "mem_leak_check"


def main() -> None:
    """
    Filter the test matrix given on the command line and write the outputs
    test-matrix and is-test-matrix-empty, plus those of perform_misc_tasks
    """
    # Only main() needs PyYAML, see the note next to the imports at the top
    import yaml

    args = parse_args()
    # Load the original test matrix set by the workflow. Its format, however,
    # doesn't follow the strict JSON format, so we load it using yaml here for
    # its more relaxed syntax
    test_matrix = yaml.safe_load(args.test_matrix)

    # An empty input is loaded as None, and anything that isn't a mapping (a
    # scalar, a list) can't be a test matrix either
    if not isinstance(test_matrix, dict):
        warnings.warn(f"Invalid test matrix input '{args.test_matrix}', exiting")
        # We handle invalid test matrix gracefully by marking it as empty
        set_output("is-test-matrix-empty", True)
        sys.exit(0)

    pr_number = args.pr_number
    tag = args.tag

    if not pr_number and tag:
        # If the tag matches, we can get the PR number from it, this is from ciflow
        # workflow dispatcher. When there is a tag but it isn't ciflow, there is
        # nothing left to do
        tag_regex = re.compile(r"^ciflow/\w+/(?P<pr_number>\d+)$")
        m = tag_regex.match(tag)
        if m:
            pr_number = m.group("pr_number")

    labels: Set[str] = set()
    if pr_number:
        # If a PR number is set, query all the labels from that PR
        labels = get_labels(int(pr_number))
        # Then filter the test matrix and keep only the selected ones
        filtered_test_matrix = filter(test_matrix, labels)
    else:
        # No PR number, we can just return the test matrix as it is
        filtered_test_matrix = test_matrix

    if args.selected_test_configs:
        selected_test_configs = {
            v.strip().lower()
            for v in args.selected_test_configs.split(",")
            if v.strip()
        }
        filtered_test_matrix = filter_selected_test_configs(
            filtered_test_matrix, selected_test_configs
        )

    if args.event_name == "schedule" and args.schedule == "29 8 * * *":
        # we don't want to run the mem leak check or disabled tests on normal
        # periodically scheduled jobs, only the ones at this time
        filtered_test_matrix = set_periodic_modes(filtered_test_matrix, args.job_name)

    if args.workflow and args.job_name and args.branch not in EXCLUDED_BRANCHES:
        # If both workflow and job name are available, we will check if the current job
        # is disabled and remove it and all its dependants from the test matrix
        filtered_test_matrix = remove_disabled_jobs(
            args.workflow, args.job_name, filtered_test_matrix
        )

        filtered_test_matrix = mark_unstable_jobs(
            args.workflow, args.job_name, filtered_test_matrix
        )

    # NB: The PR info is cached, so this doesn't query GitHub a second time
    pr_body = get_pr_info(int(pr_number)).get("body", "") if pr_number else ""

    perform_misc_tasks(
        labels=labels,
        test_matrix=filtered_test_matrix,
        job_name=args.job_name,
        # GitHub returns None for an empty PR body
        pr_body=pr_body if pr_body else "",
    )

    # Set the filtered test matrix as the output
    set_output("test-matrix", json.dumps(filtered_test_matrix))

    filtered_test_matrix_len = len(filtered_test_matrix.get("include", []))
    # and also put a flag if the test matrix is empty, so subsequent jobs can
    # quickly check it without the need to parse the JSON string
    set_output("is-test-matrix-empty", filtered_test_matrix_len == 0)


if __name__ == "__main__":
    main()