1#!/usr/bin/env python3 2# For the dependencies, see the requirements.txt 3 4import logging 5import re 6import traceback 7from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, Namespace 8from collections import OrderedDict 9from copy import deepcopy 10from dataclasses import dataclass, field 11from itertools import accumulate 12from pathlib import Path 13from subprocess import check_output 14from textwrap import dedent 15from typing import Any, Iterable, Optional, Pattern, TypedDict, Union 16 17import yaml 18from filecache import DAY, filecache 19from gitlab_common import get_token_from_default_dir 20from gql import Client, gql 21from gql.transport.requests import RequestsHTTPTransport 22from graphql import DocumentNode 23 24 25class DagNode(TypedDict): 26 needs: set[str] 27 stage: str 28 # `name` is redundant but is here for retro-compatibility 29 name: str 30 31 32# see create_job_needs_dag function for more details 33Dag = dict[str, DagNode] 34 35 36StageSeq = OrderedDict[str, set[str]] 37 38 39def get_project_root_dir(): 40 root_path = Path(__file__).parent.parent.parent.resolve() 41 gitlab_file = root_path / ".gitlab-ci.yml" 42 assert gitlab_file.exists() 43 44 return root_path 45 46 47@dataclass 48class GitlabGQL: 49 _transport: Any = field(init=False) 50 client: Client = field(init=False) 51 url: str = "https://gitlab.freedesktop.org/api/graphql" 52 token: Optional[str] = None 53 54 def __post_init__(self) -> None: 55 self._setup_gitlab_gql_client() 56 57 def _setup_gitlab_gql_client(self) -> None: 58 # Select your transport with a defined url endpoint 59 headers = {} 60 if self.token: 61 headers["Authorization"] = f"Bearer {self.token}" 62 self._transport = RequestsHTTPTransport(url=self.url, headers=headers) 63 64 # Create a GraphQL client using the defined transport 65 self.client = Client(transport=self._transport, fetch_schema_from_transport=True) 66 67 def query( 68 self, 69 gql_file: Union[Path, str], 70 params: dict[str, Any] = {}, 71 operation_name: Optional[str] = None, 72 paginated_key_loc: Iterable[str] = [], 73 disable_cache: bool = False, 74 ) -> dict[str, Any]: 75 def run_uncached() -> dict[str, Any]: 76 if paginated_key_loc: 77 return self._sweep_pages(gql_file, params, operation_name, paginated_key_loc) 78 return self._query(gql_file, params, operation_name) 79 80 if disable_cache: 81 return run_uncached() 82 83 try: 84 # Create an auxiliary variable to deliver a cached result and enable catching exceptions 85 # Decorate the query to be cached 86 if paginated_key_loc: 87 result = self._sweep_pages_cached( 88 gql_file, params, operation_name, paginated_key_loc 89 ) 90 else: 91 result = self._query_cached(gql_file, params, operation_name) 92 return result # type: ignore 93 except Exception as ex: 94 logging.error(f"Cached query failed with {ex}") 95 # print exception traceback 96 traceback_str = "".join(traceback.format_exception(ex)) 97 logging.error(traceback_str) 98 self.invalidate_query_cache() 99 logging.error("Cache invalidated, retrying without cache") 100 finally: 101 return run_uncached() 102 103 def _query( 104 self, 105 gql_file: Union[Path, str], 106 params: dict[str, Any] = {}, 107 operation_name: Optional[str] = None, 108 ) -> dict[str, Any]: 109 # Provide a GraphQL query 110 source_path: Path = Path(__file__).parent 111 pipeline_query_file: Path = source_path / gql_file 112 113 query: DocumentNode 114 with open(pipeline_query_file, "r") as f: 115 pipeline_query = f.read() 116 query = gql(pipeline_query) 117 118 # Execute the query on the transport 119 return self.client.execute_sync( 120 query, variable_values=params, operation_name=operation_name 121 ) 122 123 @filecache(DAY) 124 def _sweep_pages_cached(self, *args, **kwargs): 125 return self._sweep_pages(*args, **kwargs) 126 127 @filecache(DAY) 128 def _query_cached(self, *args, **kwargs): 129 return self._query(*args, **kwargs) 130 131 def _sweep_pages( 132 self, query, params, operation_name=None, paginated_key_loc: Iterable[str] = [] 133 ) -> dict[str, Any]: 134 """ 135 Retrieve paginated data from a GraphQL API and concatenate the results into a single 136 response. 137 138 Args: 139 query: represents a filepath with the GraphQL query to be executed. 140 params: a dictionary that contains the parameters to be passed to the query. These 141 parameters can be used to filter or modify the results of the query. 142 operation_name: The `operation_name` parameter is an optional parameter that specifies 143 the name of the GraphQL operation to be executed. It is used when making a GraphQL 144 query to specify which operation to execute if there are multiple operations defined 145 in the GraphQL schema. If not provided, the default operation will be executed. 146 paginated_key_loc (Iterable[str]): The `paginated_key_loc` parameter is an iterable of 147 strings that represents the location of the paginated field within the response. It 148 is used to extract the paginated field from the response and append it to the final 149 result. The node has to be a list of objects with a `pageInfo` field that contains 150 at least the `hasNextPage` and `endCursor` fields. 151 152 Returns: 153 a dictionary containing the response from the query with the paginated field 154 concatenated. 155 """ 156 157 def fetch_page(cursor: str | None = None) -> dict[str, Any]: 158 if cursor: 159 params["cursor"] = cursor 160 logging.info( 161 f"Found more than 100 elements, paginating. " 162 f"Current cursor at {cursor}" 163 ) 164 165 return self._query(query, params, operation_name) 166 167 # Execute the initial query 168 response: dict[str, Any] = fetch_page() 169 170 # Initialize an empty list to store the final result 171 final_partial_field: list[dict[str, Any]] = [] 172 173 # Loop until all pages have been retrieved 174 while True: 175 # Get the partial field to be appended to the final result 176 partial_field = response 177 for key in paginated_key_loc: 178 partial_field = partial_field[key] 179 180 # Append the partial field to the final result 181 final_partial_field += partial_field["nodes"] 182 183 # Check if there are more pages to retrieve 184 page_info = partial_field["pageInfo"] 185 if not page_info["hasNextPage"]: 186 break 187 188 # Execute the query with the updated cursor parameter 189 response = fetch_page(page_info["endCursor"]) 190 191 # Replace the "nodes" field in the original response with the final result 192 partial_field["nodes"] = final_partial_field 193 return response 194 195 def invalidate_query_cache(self) -> None: 196 logging.warning("Invalidating query cache") 197 try: 198 self._sweep_pages._db.clear() 199 self._query._db.clear() 200 except AttributeError as ex: 201 logging.warning(f"Could not invalidate cache, maybe it was not used in {ex.args}?") 202 203 204def insert_early_stage_jobs(stage_sequence: StageSeq, jobs_metadata: Dag) -> Dag: 205 pre_processed_dag: dict[str, set[str]] = {} 206 jobs_from_early_stages = list(accumulate(stage_sequence.values(), set.union)) 207 for job_name, metadata in jobs_metadata.items(): 208 final_needs: set[str] = deepcopy(metadata["needs"]) 209 # Pre-process jobs that are not based on needs field 210 # e.g. sanity job in mesa MR pipelines 211 if not final_needs: 212 job_stage: str = jobs_metadata[job_name]["stage"] 213 stage_index: int = list(stage_sequence.keys()).index(job_stage) 214 if stage_index > 0: 215 final_needs |= jobs_from_early_stages[stage_index - 1] 216 pre_processed_dag[job_name] = final_needs 217 218 for job_name, needs in pre_processed_dag.items(): 219 jobs_metadata[job_name]["needs"] = needs 220 221 return jobs_metadata 222 223 224def traverse_dag_needs(jobs_metadata: Dag) -> None: 225 created_jobs = set(jobs_metadata.keys()) 226 for job, metadata in jobs_metadata.items(): 227 final_needs: set = deepcopy(metadata["needs"]) & created_jobs 228 # Post process jobs that are based on needs field 229 partial = True 230 231 while partial: 232 next_depth: set[str] = {n for dn in final_needs for n in jobs_metadata[dn]["needs"]} 233 partial: bool = not final_needs.issuperset(next_depth) 234 final_needs = final_needs.union(next_depth) 235 236 jobs_metadata[job]["needs"] = final_needs 237 238 239def extract_stages_and_job_needs( 240 pipeline_jobs: dict[str, Any], pipeline_stages: dict[str, Any] 241) -> tuple[StageSeq, Dag]: 242 jobs_metadata = Dag() 243 # Record the stage sequence to post process deps that are not based on needs 244 # field, for example: sanity job 245 stage_sequence: OrderedDict[str, set[str]] = OrderedDict() 246 for stage in pipeline_stages["nodes"]: 247 stage_sequence[stage["name"]] = set() 248 249 for job in pipeline_jobs["nodes"]: 250 stage_sequence[job["stage"]["name"]].add(job["name"]) 251 dag_job: DagNode = { 252 "name": job["name"], 253 "stage": job["stage"]["name"], 254 "needs": set([j["node"]["name"] for j in job["needs"]["edges"]]), 255 } 256 jobs_metadata[job["name"]] = dag_job 257 258 return stage_sequence, jobs_metadata 259 260 261def create_job_needs_dag(gl_gql: GitlabGQL, params, disable_cache: bool = True) -> Dag: 262 """ 263 This function creates a Directed Acyclic Graph (DAG) to represent a sequence of jobs, where each 264 job has a set of jobs that it depends on (its "needs") and belongs to a certain "stage". 265 The "name" of the job is used as the key in the dictionary. 266 267 For example, consider the following DAG: 268 269 1. build stage: job1 -> job2 -> job3 270 2. test stage: job2 -> job4 271 272 - The job needs for job3 are: job1, job2 273 - The job needs for job4 are: job2 274 - The job2 needs to wait all jobs from build stage to finish. 275 276 The resulting DAG would look like this: 277 278 dag = { 279 "job1": {"needs": set(), "stage": "build", "name": "job1"}, 280 "job2": {"needs": {"job1", "job2", job3"}, "stage": "test", "name": "job2"}, 281 "job3": {"needs": {"job1", "job2"}, "stage": "build", "name": "job3"}, 282 "job4": {"needs": {"job2"}, "stage": "test", "name": "job4"}, 283 } 284 285 To access the job needs, one can do: 286 287 dag["job3"]["needs"] 288 289 This will return the set of jobs that job3 needs: {"job1", "job2"} 290 291 Args: 292 gl_gql (GitlabGQL): The `gl_gql` parameter is an instance of the `GitlabGQL` class, which is 293 used to make GraphQL queries to the GitLab API. 294 params (dict): The `params` parameter is a dictionary that contains the necessary parameters 295 for the GraphQL query. It is used to specify the details of the pipeline for which the 296 job needs DAG is being created. 297 The specific keys and values in the `params` dictionary will depend on 298 the requirements of the GraphQL query being executed 299 disable_cache (bool): The `disable_cache` parameter is a boolean that specifies whether the 300 301 Returns: 302 The final DAG (Directed Acyclic Graph) representing the job dependencies sourced from needs 303 or stages rule. 304 """ 305 stages_jobs_gql = gl_gql.query( 306 "pipeline_details.gql", 307 params=params, 308 paginated_key_loc=["project", "pipeline", "jobs"], 309 disable_cache=disable_cache, 310 ) 311 pipeline_data = stages_jobs_gql["project"]["pipeline"] 312 if not pipeline_data: 313 raise RuntimeError(f"Could not find any pipelines for {params}") 314 315 stage_sequence, jobs_metadata = extract_stages_and_job_needs( 316 pipeline_data["jobs"], pipeline_data["stages"] 317 ) 318 # Fill the DAG with the job needs from stages that don't have any needs but still need to wait 319 # for previous stages 320 final_dag = insert_early_stage_jobs(stage_sequence, jobs_metadata) 321 # Now that each job has its direct needs filled correctly, update the "needs" field for each job 322 # in the DAG by performing a topological traversal 323 traverse_dag_needs(final_dag) 324 325 return final_dag 326 327 328def filter_dag(dag: Dag, regex: Pattern) -> Dag: 329 jobs_with_regex: set[str] = {job for job in dag if regex.fullmatch(job)} 330 return Dag({job: data for job, data in dag.items() if job in sorted(jobs_with_regex)}) 331 332 333def print_dag(dag: Dag) -> None: 334 for job, data in dag.items(): 335 print(f"{job}:") 336 print(f"\t{' '.join(data['needs'])}") 337 print() 338 339 340def fetch_merged_yaml(gl_gql: GitlabGQL, params) -> dict[str, Any]: 341 params["content"] = dedent("""\ 342 include: 343 - local: .gitlab-ci.yml 344 """) 345 raw_response = gl_gql.query("job_details.gql", params) 346 ci_config = raw_response["ciConfig"] 347 if merged_yaml := ci_config["mergedYaml"]: 348 return yaml.safe_load(merged_yaml) 349 if "errors" in ci_config: 350 for error in ci_config["errors"]: 351 print(error) 352 353 gl_gql.invalidate_query_cache() 354 raise ValueError( 355 """ 356 Could not fetch any content for merged YAML, 357 please verify if the git SHA exists in remote. 358 Maybe you forgot to `git push`? """ 359 ) 360 361 362def recursive_fill(job, relationship_field, target_data, acc_data: dict, merged_yaml): 363 if relatives := job.get(relationship_field): 364 if isinstance(relatives, str): 365 relatives = [relatives] 366 367 for relative in relatives: 368 parent_job = merged_yaml[relative] 369 acc_data = recursive_fill(parent_job, acc_data, merged_yaml) # type: ignore 370 371 acc_data |= job.get(target_data, {}) 372 373 return acc_data 374 375 376def get_variables(job, merged_yaml, project_path, sha) -> dict[str, str]: 377 p = get_project_root_dir() / ".gitlab-ci" / "image-tags.yml" 378 image_tags = yaml.safe_load(p.read_text()) 379 380 variables = image_tags["variables"] 381 variables |= merged_yaml["variables"] 382 variables |= job["variables"] 383 variables["CI_PROJECT_PATH"] = project_path 384 variables["CI_PROJECT_NAME"] = project_path.split("/")[1] 385 variables["CI_REGISTRY_IMAGE"] = "registry.freedesktop.org/${CI_PROJECT_PATH}" 386 variables["CI_COMMIT_SHA"] = sha 387 388 while recurse_among_variables_space(variables): 389 pass 390 391 return variables 392 393 394# Based on: https://stackoverflow.com/a/2158532/1079223 395def flatten(xs): 396 for x in xs: 397 if isinstance(x, Iterable) and not isinstance(x, (str, bytes)): 398 yield from flatten(x) 399 else: 400 yield x 401 402 403def get_full_script(job) -> list[str]: 404 script = [] 405 for script_part in ("before_script", "script", "after_script"): 406 script.append(f"# {script_part}") 407 lines = flatten(job.get(script_part, [])) 408 script.extend(lines) 409 script.append("") 410 411 return script 412 413 414def recurse_among_variables_space(var_graph) -> bool: 415 updated = False 416 for var, value in var_graph.items(): 417 value = str(value) 418 dep_vars = [] 419 if match := re.findall(r"(\$[{]?[\w\d_]*[}]?)", value): 420 all_dep_vars = [v.lstrip("${").rstrip("}") for v in match] 421 # print(value, match, all_dep_vars) 422 dep_vars = [v for v in all_dep_vars if v in var_graph] 423 424 for dep_var in dep_vars: 425 dep_value = str(var_graph[dep_var]) 426 new_value = var_graph[var] 427 new_value = new_value.replace(f"${{{dep_var}}}", dep_value) 428 new_value = new_value.replace(f"${dep_var}", dep_value) 429 var_graph[var] = new_value 430 updated |= dep_value != new_value 431 432 return updated 433 434 435def print_job_final_definition(job_name, merged_yaml, project_path, sha): 436 job = merged_yaml[job_name] 437 variables = get_variables(job, merged_yaml, project_path, sha) 438 439 print("# --------- variables ---------------") 440 for var, value in sorted(variables.items()): 441 print(f"export {var}={value!r}") 442 443 # TODO: Recurse into needs to get full script 444 # TODO: maybe create a extra yaml file to avoid too much rework 445 script = get_full_script(job) 446 print() 447 print() 448 print("# --------- full script ---------------") 449 print("\n".join(script)) 450 451 if image := variables.get("MESA_IMAGE"): 452 print() 453 print() 454 print("# --------- container image ---------------") 455 print(image) 456 457 458def from_sha_to_pipeline_iid(gl_gql: GitlabGQL, params) -> str: 459 result = gl_gql.query("pipeline_utils.gql", params) 460 461 return result["project"]["pipelines"]["nodes"][0]["iid"] 462 463 464def parse_args() -> Namespace: 465 parser = ArgumentParser( 466 formatter_class=ArgumentDefaultsHelpFormatter, 467 description="CLI and library with utility functions to debug jobs via Gitlab GraphQL", 468 epilog=f"""Example: 469 {Path(__file__).name} --print-dag""", 470 ) 471 parser.add_argument("-pp", "--project-path", type=str, default="mesa/mesa") 472 parser.add_argument("--sha", "--rev", type=str, default='HEAD') 473 parser.add_argument( 474 "--regex", 475 type=str, 476 required=False, 477 help="Regex pattern for the job name to be considered", 478 ) 479 mutex_group_print = parser.add_mutually_exclusive_group() 480 mutex_group_print.add_argument( 481 "--print-dag", 482 action="store_true", 483 help="Print job needs DAG", 484 ) 485 mutex_group_print.add_argument( 486 "--print-merged-yaml", 487 action="store_true", 488 help="Print the resulting YAML for the specific SHA", 489 ) 490 mutex_group_print.add_argument( 491 "--print-job-manifest", 492 metavar='JOB_NAME', 493 type=str, 494 help="Print the resulting job data" 495 ) 496 parser.add_argument( 497 "--gitlab-token-file", 498 type=str, 499 default=get_token_from_default_dir(), 500 help="force GitLab token, otherwise it's read from $XDG_CONFIG_HOME/gitlab-token", 501 ) 502 503 args = parser.parse_args() 504 args.gitlab_token = Path(args.gitlab_token_file).read_text().strip() 505 return args 506 507 508def main(): 509 args = parse_args() 510 gl_gql = GitlabGQL(token=args.gitlab_token) 511 512 sha = check_output(['git', 'rev-parse', args.sha]).decode('ascii').strip() 513 514 if args.print_dag: 515 iid = from_sha_to_pipeline_iid(gl_gql, {"projectPath": args.project_path, "sha": sha}) 516 dag = create_job_needs_dag( 517 gl_gql, {"projectPath": args.project_path, "iid": iid}, disable_cache=True 518 ) 519 520 if args.regex: 521 dag = filter_dag(dag, re.compile(args.regex)) 522 523 print_dag(dag) 524 525 if args.print_merged_yaml or args.print_job_manifest: 526 merged_yaml = fetch_merged_yaml( 527 gl_gql, {"projectPath": args.project_path, "sha": sha} 528 ) 529 530 if args.print_merged_yaml: 531 print(yaml.dump(merged_yaml, indent=2)) 532 533 if args.print_job_manifest: 534 print_job_final_definition( 535 args.print_job_manifest, merged_yaml, args.project_path, sha 536 ) 537 538 539if __name__ == "__main__": 540 main() 541