• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Utilities for invoking LLVM tools."""
2
3import asyncio
4import functools
5import os
6import re
7import shlex
8import subprocess
9import typing
10
# The LLVM major version that every tool located by this module must report.
_LLVM_VERSION = 18
# Searched for in a tool's ``--version`` output: the word "version", then
# "<major>.<minor>.<patch>" for the required major version, any non-space
# suffix directly after the patch number, and at least one whitespace character.
_LLVM_VERSION_PATTERN = re.compile(rf"version\s+{_LLVM_VERSION}\.\d+\.\d+\S*\s+")

# Typing helpers for decorators that wrap coroutine functions: _P captures the
# wrapped function's parameters, _R is the awaited result type, and _C is a
# callable that takes _P and returns an awaitable of _R.
_P = typing.ParamSpec("_P")
_R = typing.TypeVar("_R")
_C = typing.Callable[_P, typing.Awaitable[_R]]
17
18
def _async_cache(f: _C[_P, _R]) -> _C[_P, _R]:
    """Memoize the coroutine function *f* by its positional arguments.

    The cache key is the tuple of positional arguments only. Keyword
    arguments are forwarded to *f* on a cache miss but do not take part in
    the key, so calls that differ only in keyword arguments share one entry.

    A single lock is held for the whole lookup, including the time spent
    awaiting *f*, so calls made through one decorated function are handled
    one at a time and a given key is computed at most once. If *f* raises,
    nothing is stored and the exception propagates to the caller.
    """
    results: dict[tuple[typing.Any, ...], _R] = {}
    guard = asyncio.Lock()

    @functools.wraps(f)
    async def memoized(
        *args: _P.args, **kwargs: _P.kwargs  # pylint: disable = no-member
    ) -> _R:
        async with guard:
            try:
                return results[args]
            except KeyError:
                # First request for this key: fall through and compute it.
                pass
            value = await f(*args, **kwargs)
            results[args] = value
            return value

    return memoized
33
34
# Caps how many subprocesses _run() has in flight at once: one slot per CPU
# reported by os.cpu_count(), or a single slot when that count is unavailable.
_CORES = asyncio.BoundedSemaphore(os.cpu_count() or 1)
36
37
async def _run(
    tool: str, args: typing.Iterable[str], echo: bool = False
) -> str | None:
    """Run *tool* with *args* as a subprocess and return its decoded stdout.

    Only stdout is captured through a pipe; no redirection is requested for
    the other standard streams. The launch and the wait for completion both
    happen while holding a slot of the module-wide semaphore.

    Args:
        tool: Executable name or path.
        args: Command-line arguments placed after *tool*.
        echo: When true, print the shell-quoted command before launching it.

    Returns:
        The captured stdout decoded with ``bytes.decode()`` defaults, or None
        if launching the process raised FileNotFoundError.

    Raises:
        RuntimeError: The process finished with a non-zero return code.
    """
    argv = [tool]
    argv.extend(args)
    async with _CORES:
        if echo:
            print(shlex.join(argv))
        try:
            child = await asyncio.create_subprocess_exec(
                *argv, stdout=subprocess.PIPE
            )
        except FileNotFoundError:
            return None
        stdout, _stderr = await child.communicate()
    # The semaphore slot is already released here; only the result is checked.
    status = child.returncode
    if status:
        raise RuntimeError(f"{tool} exited with return code {status}")
    return stdout.decode()
53
54
@_async_cache
async def _check_tool_version(name: str, *, echo: bool = False) -> bool:
    """Return whether the executable *name* reports the required LLVM version.

    Runs ``name --version`` and searches its output for the module's version
    pattern. The answer is memoized per *name* by the ``_async_cache``
    decorator.

    Args:
        name: Executable name or path to probe.
        echo: Forwarded to ``_run`` so the probe command can be printed.

    Returns:
        True if the version pattern is found in the output. False if the
        executable cannot be launched, exits with a non-zero status, prints
        nothing, or reports a different version.
    """
    try:
        output = await _run(name, ["--version"], echo=echo)
    except RuntimeError:
        # _run() raises RuntimeError when the command exits with a non-zero
        # status. A candidate that cannot even report its version is not a
        # usable match, so reject it here rather than aborting the search
        # and preventing the remaining candidates from being tried.
        return False
    if not output:
        return False
    return _LLVM_VERSION_PATTERN.search(output) is not None
59
60
@_async_cache
async def _get_brew_llvm_prefix(*, echo: bool = False) -> str | None:
    """Return the prefix Homebrew reports for the required LLVM formula.

    Runs ``brew --prefix llvm@<version>`` and strips one trailing newline
    from its output. The answer is memoized by the ``_async_cache``
    decorator.

    Args:
        echo: Forwarded to ``_run`` so the command can be printed.

    Returns:
        The reported prefix, or None when ``brew`` cannot be launched, exits
        with a non-zero status, or prints no path.
    """
    try:
        output = await _run("brew", ["--prefix", f"llvm@{_LLVM_VERSION}"], echo=echo)
    except RuntimeError:
        # _run() raises RuntimeError on a non-zero exit status. A failing
        # ``brew`` query only means there is no Homebrew prefix to offer; it
        # should not abort tool discovery.
        return None
    if output is None:
        return None
    prefix = output.removesuffix("\n")
    # An empty prefix must not be returned: callers test the result with
    # ``is not None`` and would join "" with "bin/<tool>", yielding a path
    # relative to the current directory.
    return prefix or None
65
66
@_async_cache
async def _find_tool(tool: str, *, echo: bool = False) -> str | None:
    """Locate an executable for *tool* that reports the required LLVM version.

    Candidates are probed in order, and the first one that passes the version
    check is returned:

    1. the bare name (``tool``);
    2. the version-suffixed name (``tool-<version>``);
    3. ``<prefix>/bin/tool``, where the prefix is the one reported by
       Homebrew (skipped when no prefix is available).

    Later candidates are only probed once the earlier ones have failed.
    The answer is memoized per *tool* by the ``_async_cache`` decorator.

    Returns:
        The first matching candidate, or None if none of them qualifies.
    """
    # Unversioned, then versioned executable names:
    for candidate in (tool, f"{tool}-{_LLVM_VERSION}"):
        if await _check_tool_version(candidate, echo=echo):
            return candidate
    # Homebrew-installed executables:
    brew_prefix = await _get_brew_llvm_prefix(echo=echo)
    if brew_prefix is None:
        return None
    candidate = os.path.join(brew_prefix, "bin", tool)
    if await _check_tool_version(candidate, echo=echo):
        return candidate
    # Nothing found:
    return None
85
86
async def maybe_run(
    tool: str, args: typing.Iterable[str], echo: bool = False
) -> str | None:
    """Run an LLVM tool if a suitable executable is found, else return None.

    Args:
        tool: Base name of the LLVM tool (for example ``"clang"``).
        args: Command-line arguments to pass to the tool.
        echo: When true, the commands being run are printed.

    Returns:
        The tool's decoded stdout, or None if no executable with the
        required LLVM version could be located.

    Raises:
        RuntimeError: The located tool exited with a non-zero return code.
    """
    executable = await _find_tool(tool, echo=echo)
    if not executable:
        # Nothing suitable was found; pass the lookup result straight back.
        return executable
    return await _run(executable, args, echo=echo)
93
94
async def run(tool: str, args: typing.Iterable[str], echo: bool = False) -> str:
    """Run an LLVM tool and return its output, failing if it is unavailable.

    Args:
        tool: Base name of the LLVM tool (for example ``"clang"``).
        args: Command-line arguments to pass to the tool.
        echo: When true, the commands being run are printed.

    Returns:
        The tool's decoded stdout.

    Raises:
        RuntimeError: No result was obtained from ``maybe_run`` (the tool
            could not be found), or the tool exited with a non-zero return
            code.
    """
    result = await maybe_run(tool, args, echo=echo)
    if result is not None:
        return result
    raise RuntimeError(f"Can't find {tool}-{_LLVM_VERSION}!")
101