#!/usr/bin/env python3
# For the dependencies, see the requirements.txt

import logging
import re
import traceback
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, Namespace
from collections import OrderedDict
from copy import deepcopy
from dataclasses import dataclass, field
from itertools import accumulate
from pathlib import Path
from subprocess import check_output
from textwrap import dedent
from typing import Any, Iterable, Optional, Pattern, TypedDict, Union

import yaml
from filecache import DAY, filecache
from gitlab_common import get_token_from_default_dir
from gql import Client, gql
from gql.transport.requests import RequestsHTTPTransport
from graphql import DocumentNode


class DagNode(TypedDict):
    needs: set[str]
    stage: str
    # `name` is redundant but is kept for backward compatibility
    name: str


# see create_job_needs_dag function for more details
Dag = dict[str, DagNode]


StageSeq = OrderedDict[str, set[str]]
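# Shape sketch for the aliases above (illustrative, hypothetical job/stage names):
#   dag: Dag = {"job-x": {"needs": {"job-y"}, "stage": "test", "name": "job-x"}}
#   stage_seq: StageSeq = OrderedDict([("build", {"job-y"}), ("test", {"job-x"})])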


def get_project_root_dir():
    root_path = Path(__file__).parent.parent.parent.resolve()
    gitlab_file = root_path / ".gitlab-ci.yml"
    assert gitlab_file.exists()

    return root_path


@dataclass
class GitlabGQL:
    _transport: Any = field(init=False)
    client: Client = field(init=False)
    url: str = "https://gitlab.freedesktop.org/api/graphql"
    token: Optional[str] = None

    def __post_init__(self) -> None:
        self._setup_gitlab_gql_client()

    def _setup_gitlab_gql_client(self) -> None:
        # Select your transport with a defined url endpoint
        headers = {}
        if self.token:
            headers["Authorization"] = f"Bearer {self.token}"
        self._transport = RequestsHTTPTransport(url=self.url, headers=headers)

        # Create a GraphQL client using the defined transport
        self.client = Client(transport=self._transport, fetch_schema_from_transport=True)

    def query(
        self,
        gql_file: Union[Path, str],
        params: dict[str, Any] = {},
        operation_name: Optional[str] = None,
        paginated_key_loc: Iterable[str] = [],
        disable_cache: bool = False,
    ) -> dict[str, Any]:
        def run_uncached() -> dict[str, Any]:
            if paginated_key_loc:
                return self._sweep_pages(gql_file, params, operation_name, paginated_key_loc)
            return self._query(gql_file, params, operation_name)

        if disable_cache:
            return run_uncached()

        try:
            # Run the cached variants of the queries so repeated invocations hit the
            # on-disk cache; if anything goes wrong, fall back to an uncached run.
            if paginated_key_loc:
                result = self._sweep_pages_cached(
                    gql_file, params, operation_name, paginated_key_loc
                )
            else:
                result = self._query_cached(gql_file, params, operation_name)
            return result  # type: ignore
        except Exception as ex:
            logging.error(f"Cached query failed with {ex}")
            # print the exception traceback for debugging
            traceback_str = "".join(traceback.format_exception(ex))
            logging.error(traceback_str)
            self.invalidate_query_cache()
            logging.error("Cache invalidated, retrying without cache")
            return run_uncached()

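    # Usage sketch (illustrative values; the .gql file must exist next to this
    # script and the params must match the operation it defines):
    #   gl = GitlabGQL(token="<token>")
    #   jobs = gl.query(
    #       "pipeline_details.gql",
    #       params={"projectPath": "mesa/mesa", "iid": "12345"},
    #       paginated_key_loc=["project", "pipeline", "jobs"],
    #   )
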
    def _query(
        self,
        gql_file: Union[Path, str],
        params: dict[str, Any] = {},
        operation_name: Optional[str] = None,
    ) -> dict[str, Any]:
        # Provide a GraphQL query
        source_path: Path = Path(__file__).parent
        pipeline_query_file: Path = source_path / gql_file

        query: DocumentNode
        with open(pipeline_query_file, "r") as f:
            pipeline_query = f.read()
            query = gql(pipeline_query)

        # Execute the query on the transport
        return self.client.execute_sync(
            query, variable_values=params, operation_name=operation_name
        )

    @filecache(DAY)
    def _sweep_pages_cached(self, *args, **kwargs):
        return self._sweep_pages(*args, **kwargs)

    @filecache(DAY)
    def _query_cached(self, *args, **kwargs):
        return self._query(*args, **kwargs)

    def _sweep_pages(
        self, query, params, operation_name=None, paginated_key_loc: Iterable[str] = []
    ) -> dict[str, Any]:
        """
        Retrieve paginated data from a GraphQL API and concatenate the results into a single
        response.

        Args:
            query: path of the file holding the GraphQL query to be executed.
            params: a dictionary of parameters passed to the query. These parameters can be used
                to filter or modify the results of the query.
            operation_name: optional name of the GraphQL operation to execute. Use it to select
                which operation to run when the query file defines several; if not provided, the
                default operation is executed.
            paginated_key_loc (Iterable[str]): sequence of keys giving the location of the
                paginated field within the response. It is used to extract the paginated field
                from the response and append it to the final result. The node has to be a list of
                objects with a `pageInfo` field that contains at least the `hasNextPage` and
                `endCursor` fields.

        Returns:
            a dictionary containing the response from the query with the paginated field
            concatenated.
        """

        def fetch_page(cursor: str | None = None) -> dict[str, Any]:
            if cursor:
                params["cursor"] = cursor
                logging.info(
                    f"Found more than 100 elements, paginating. "
                    f"Current cursor at {cursor}"
                )

            return self._query(query, params, operation_name)

        # Execute the initial query
        response: dict[str, Any] = fetch_page()

        # Initialize an empty list to store the final result
        final_partial_field: list[dict[str, Any]] = []

        # Loop until all pages have been retrieved
        while True:
            # Get the partial field to be appended to the final result
            partial_field = response
            for key in paginated_key_loc:
                partial_field = partial_field[key]

            # Append the partial field to the final result
            final_partial_field += partial_field["nodes"]

            # Check if there are more pages to retrieve
            page_info = partial_field["pageInfo"]
            if not page_info["hasNextPage"]:
                break

            # Execute the query with the updated cursor parameter
            response = fetch_page(page_info["endCursor"])

        # Replace the "nodes" field in the original response with the final result
        partial_field["nodes"] = final_partial_field
        return response

    def invalidate_query_cache(self) -> None:
        logging.warning("Invalidating query cache")
        try:
            self._sweep_pages_cached._db.clear()
            self._query_cached._db.clear()
        except AttributeError as ex:
            logging.warning(f"Could not invalidate cache, maybe it was not used? ({ex.args})")


def insert_early_stage_jobs(stage_sequence: StageSeq, jobs_metadata: Dag) -> Dag:
    pre_processed_dag: dict[str, set[str]] = {}
    jobs_from_early_stages = list(accumulate(stage_sequence.values(), set.union))
    for job_name, metadata in jobs_metadata.items():
        final_needs: set[str] = deepcopy(metadata["needs"])
        # Pre-process jobs that do not rely on the needs field,
        # e.g. the sanity job in Mesa MR pipelines
        if not final_needs:
            job_stage: str = jobs_metadata[job_name]["stage"]
            stage_index: int = list(stage_sequence.keys()).index(job_stage)
            if stage_index > 0:
                final_needs |= jobs_from_early_stages[stage_index - 1]
        pre_processed_dag[job_name] = final_needs

    for job_name, needs in pre_processed_dag.items():
        jobs_metadata[job_name]["needs"] = needs

    return jobs_metadata

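# Illustrative sketch for insert_early_stage_jobs (hypothetical names): given
#   stage_sequence = OrderedDict([("sanity", {"sanity"}), ("build", {"build-x"})])
# a "build-x" job with an empty needs set is rewritten to depend on every job
# from the earlier stages, i.e. its needs become {"sanity"}.
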

def traverse_dag_needs(jobs_metadata: Dag) -> None:
    created_jobs = set(jobs_metadata.keys())
    for job, metadata in jobs_metadata.items():
        final_needs: set = deepcopy(metadata["needs"]) & created_jobs
        # Post-process jobs whose dependencies come from the needs field, expanding
        # the direct needs into the full transitive set
        partial = True

        while partial:
            next_depth: set[str] = {n for dn in final_needs for n in jobs_metadata[dn]["needs"]}
            partial = not final_needs.issuperset(next_depth)
            final_needs = final_needs.union(next_depth)

        jobs_metadata[job]["needs"] = final_needs

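# Illustrative sketch for traverse_dag_needs (hypothetical names): starting from
#   a -> needs {}, b -> needs {"a"}, c -> needs {"b"}
# the loop above expands c's needs to the transitive set {"a", "b"}.
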

def extract_stages_and_job_needs(
    pipeline_jobs: dict[str, Any], pipeline_stages: dict[str, Any]
) -> tuple[StageSeq, Dag]:
    jobs_metadata = Dag()
    # Record the stage sequence to post process deps that are not based on needs
    # field, for example: sanity job
    stage_sequence: OrderedDict[str, set[str]] = OrderedDict()
    for stage in pipeline_stages["nodes"]:
        stage_sequence[stage["name"]] = set()

    for job in pipeline_jobs["nodes"]:
        stage_sequence[job["stage"]["name"]].add(job["name"])
        dag_job: DagNode = {
            "name": job["name"],
            "stage": job["stage"]["name"],
            "needs": {j["node"]["name"] for j in job["needs"]["edges"]},
        }
        jobs_metadata[job["name"]] = dag_job

    return stage_sequence, jobs_metadata


def create_job_needs_dag(gl_gql: GitlabGQL, params, disable_cache: bool = True) -> Dag:
    """
    This function creates a Directed Acyclic Graph (DAG) to represent a sequence of jobs, where
    each job has a set of jobs that it depends on (its "needs") and belongs to a certain "stage".
    The "name" of the job is used as the key in the dictionary.

    For example, consider the following DAG:

        1. build stage: job1 -> job2 -> job3
        2. test stage: job2 -> job4

    - The job needs for job3 are: job1, job2
    - The job needs for job4 are: job2
    - job2 needs to wait for all jobs from the build stage to finish.

    The resulting DAG would look like this:

        dag = {
            "job1": {"needs": set(), "stage": "build", "name": "job1"},
            "job2": {"needs": {"job1", "job2", "job3"}, "stage": "test", "name": "job2"},
            "job3": {"needs": {"job1", "job2"}, "stage": "build", "name": "job3"},
            "job4": {"needs": {"job2"}, "stage": "test", "name": "job4"},
        }

    To access the job needs, one can do:

        dag["job3"]["needs"]

    This will return the set of jobs that job3 needs: {"job1", "job2"}

    Args:
        gl_gql (GitlabGQL): an instance of the `GitlabGQL` class, used to make GraphQL queries to
            the GitLab API.
        params (dict): a dictionary with the parameters for the GraphQL query. It specifies the
            pipeline for which the job needs DAG is being created; the exact keys and values
            depend on the GraphQL query being executed.
        disable_cache (bool): whether to bypass the local query cache and always fetch fresh data
            from the API.

    Returns:
        The final DAG (Directed Acyclic Graph) representing the job dependencies sourced from the
        needs or stages rules.
    """
    stages_jobs_gql = gl_gql.query(
        "pipeline_details.gql",
        params=params,
        paginated_key_loc=["project", "pipeline", "jobs"],
        disable_cache=disable_cache,
    )
    pipeline_data = stages_jobs_gql["project"]["pipeline"]
    if not pipeline_data:
        raise RuntimeError(f"Could not find any pipelines for {params}")

    stage_sequence, jobs_metadata = extract_stages_and_job_needs(
        pipeline_data["jobs"], pipeline_data["stages"]
    )
    # Fill the DAG with the job needs from stages that don't have any needs but still need to wait
    # for previous stages
    final_dag = insert_early_stage_jobs(stage_sequence, jobs_metadata)
    # Now that each job has its direct needs filled correctly, update the "needs" field for each job
    # in the DAG by performing a topological traversal
    traverse_dag_needs(final_dag)

    return final_dag


def filter_dag(dag: Dag, regex: Pattern) -> Dag:
    jobs_with_regex: set[str] = {job for job in dag if regex.fullmatch(job)}
    return Dag({job: data for job, data in dag.items() if job in jobs_with_regex})

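# Note on filter_dag (hypothetical pattern): filter_dag(dag, re.compile(r"debian.*"))
# keeps only the jobs whose full name matches the pattern, since fullmatch() is
# used rather than search().
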

def print_dag(dag: Dag) -> None:
    for job, data in dag.items():
        print(f"{job}:")
        print(f"\t{' '.join(data['needs'])}")
        print()


def fetch_merged_yaml(gl_gql: GitlabGQL, params) -> dict[str, Any]:
    params["content"] = dedent("""\
    include:
      - local: .gitlab-ci.yml
    """)
    raw_response = gl_gql.query("job_details.gql", params)
    ci_config = raw_response["ciConfig"]
    if merged_yaml := ci_config["mergedYaml"]:
        return yaml.safe_load(merged_yaml)
    if "errors" in ci_config:
        for error in ci_config["errors"]:
            print(error)

    gl_gql.invalidate_query_cache()
    raise ValueError(
        "Could not fetch any content for the merged YAML. "
        "Please verify that the git SHA exists on the remote; maybe you forgot to `git push`?"
    )


def recursive_fill(job, relationship_field, target_data, acc_data: dict, merged_yaml):
    if relatives := job.get(relationship_field):
        if isinstance(relatives, str):
            relatives = [relatives]

        for relative in relatives:
            parent_job = merged_yaml[relative]
            acc_data = recursive_fill(
                parent_job, relationship_field, target_data, acc_data, merged_yaml
            )

    acc_data |= job.get(target_data, {})

    return acc_data


def get_variables(job, merged_yaml, project_path, sha) -> dict[str, str]:
    p = get_project_root_dir() / ".gitlab-ci" / "image-tags.yml"
    image_tags = yaml.safe_load(p.read_text())

    variables = image_tags["variables"]
    variables |= merged_yaml["variables"]
    variables |= job["variables"]
    variables["CI_PROJECT_PATH"] = project_path
    variables["CI_PROJECT_NAME"] = project_path.split("/")[1]
    variables["CI_REGISTRY_IMAGE"] = "registry.freedesktop.org/${CI_PROJECT_PATH}"
    variables["CI_COMMIT_SHA"] = sha

    while recurse_among_variables_space(variables):
        pass

    return variables


# Based on: https://stackoverflow.com/a/2158532/1079223
def flatten(xs):
    for x in xs:
        if isinstance(x, Iterable) and not isinstance(x, (str, bytes)):
            yield from flatten(x)
        else:
            yield x

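# e.g. list(flatten(["a", ["b", ["c", "d"]]])) == ["a", "b", "c", "d"]
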

def get_full_script(job) -> list[str]:
    script = []
    for script_part in ("before_script", "script", "after_script"):
        script.append(f"# {script_part}")
        lines = flatten(job.get(script_part, []))
        script.extend(lines)
        script.append("")

    return script

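# Illustrative sketch for get_full_script: for a hypothetical job dict
#   {"script": ["make", ["make check"]]}
# the result is ["# before_script", "", "# script", "make", "make check", "",
#                "# after_script", ""].
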

def recurse_among_variables_space(var_graph) -> bool:
    updated = False
    for var, value in var_graph.items():
        value = str(value)
        dep_vars = []
        if match := re.findall(r"(\$[{]?[\w\d_]*[}]?)", value):
            all_dep_vars = [v.lstrip("${").rstrip("}") for v in match]
            dep_vars = [v for v in all_dep_vars if v in var_graph]

        for dep_var in dep_vars:
            dep_value = str(var_graph[dep_var])
            new_value = var_graph[var]
            new_value = new_value.replace(f"${{{dep_var}}}", dep_value)
            new_value = new_value.replace(f"${dep_var}", dep_value)
            var_graph[var] = new_value
            updated |= dep_value != new_value

    return updated

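# Illustrative sketch for recurse_among_variables_space (hypothetical values):
# starting from {"A": "x", "B": "${A}/y", "C": "$B/z"}, repeated passes converge
# to {"A": "x", "B": "x/y", "C": "x/y/z"}, which is why get_variables() loops
# until no substitution is reported.
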

def print_job_final_definition(job_name, merged_yaml, project_path, sha):
    job = merged_yaml[job_name]
    variables = get_variables(job, merged_yaml, project_path, sha)

    print("# --------- variables ---------------")
    for var, value in sorted(variables.items()):
        print(f"export {var}={value!r}")

    # TODO: Recurse into needs to get the full script
    # TODO: maybe create an extra yaml file to avoid too much rework
    script = get_full_script(job)
    print()
    print()
    print("# --------- full script ---------------")
    print("\n".join(script))

    if image := variables.get("MESA_IMAGE"):
        print()
        print()
        print("# --------- container image ---------------")
        print(image)


def from_sha_to_pipeline_iid(gl_gql: GitlabGQL, params) -> str:
    result = gl_gql.query("pipeline_utils.gql", params)

    return result["project"]["pipelines"]["nodes"][0]["iid"]


def parse_args() -> Namespace:
    parser = ArgumentParser(
        formatter_class=ArgumentDefaultsHelpFormatter,
        description="CLI and library with utility functions to debug jobs via the GitLab GraphQL API",
        epilog=f"""Example:
        {Path(__file__).name} --print-dag""",
    )
    parser.add_argument("-pp", "--project-path", type=str, default="mesa/mesa")
    parser.add_argument("--sha", "--rev", type=str, default="HEAD")
    parser.add_argument(
        "--regex",
        type=str,
        required=False,
        help="Regex pattern for the job name to be considered",
    )
    mutex_group_print = parser.add_mutually_exclusive_group()
    mutex_group_print.add_argument(
        "--print-dag",
        action="store_true",
        help="Print the job needs DAG",
    )
    mutex_group_print.add_argument(
        "--print-merged-yaml",
        action="store_true",
        help="Print the resulting YAML for the specified SHA",
    )
    mutex_group_print.add_argument(
        "--print-job-manifest",
        metavar="JOB_NAME",
        type=str,
        help="Print the resulting job data",
    )
    parser.add_argument(
        "--gitlab-token-file",
        type=str,
        default=get_token_from_default_dir(),
        help="Force a specific GitLab token file; by default it is read from "
        "$XDG_CONFIG_HOME/gitlab-token",
    )

    args = parser.parse_args()
    args.gitlab_token = Path(args.gitlab_token_file).read_text().strip()
    return args

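# Example invocations (the script name and job name below are illustrative):
#   python3 gitlab_gql.py --print-dag --regex 'debian-.*'
#   python3 gitlab_gql.py --print-job-manifest debian-build-testing
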

def main():
    args = parse_args()
    gl_gql = GitlabGQL(token=args.gitlab_token)

    sha = check_output(["git", "rev-parse", args.sha]).decode("ascii").strip()

    if args.print_dag:
        iid = from_sha_to_pipeline_iid(gl_gql, {"projectPath": args.project_path, "sha": sha})
        dag = create_job_needs_dag(
            gl_gql, {"projectPath": args.project_path, "iid": iid}, disable_cache=True
        )

        if args.regex:
            dag = filter_dag(dag, re.compile(args.regex))

        print_dag(dag)

    if args.print_merged_yaml or args.print_job_manifest:
        merged_yaml = fetch_merged_yaml(
            gl_gql, {"projectPath": args.project_path, "sha": sha}
        )

        if args.print_merged_yaml:
            print(yaml.dump(merged_yaml, indent=2))

        if args.print_job_manifest:
            print_job_final_definition(
                args.print_job_manifest, merged_yaml, args.project_path, sha
            )


if __name__ == "__main__":
    main()