• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# Copyright 2023 The ChromiumOS Authors
3# Use of this source code is governed by a BSD-style license that can be
4# found in the LICENSE file.
5
6"""
7Provides general utility functions.
8"""
9
10import argparse
11import contextlib
12import datetime
13import functools
14import os
15import re
16import subprocess
17import sys
18import urllib
19import urllib.request
20import urllib.error
21from pathlib import Path
22from subprocess import DEVNULL, PIPE, STDOUT  # type: ignore
23from typing import (
24    Dict,
25    List,
26    NamedTuple,
27    Optional,
28    Tuple,
29    Union,
30)
31
# Type alias for arguments that may be given either as a `Path` or as a plain `str`.
PathLike = Union[Path, str]

# Regex that matches ANSI escape sequences (e.g. terminal color codes like "\x1b[31m").
# The raw-string pieces below concatenate into a single pattern.
ANSI_ESCAPE = re.compile(
    r"\x1B"  # the ESC character
    r"(?:[@-Z\\-_]"  # either a single-character escape ...
    r"|\[[0-?]*[ -/]*[@-~])"  # ... or a CSI sequence: '[', params, intermediates, final byte
)
36
37
def find_crosvm_root():
    """
    Walk up from CWD until we find the crosvm root dir.

    A directory is considered the crosvm root if it contains `tools/impl/common.py`.
    The CWD itself is checked first, then each ancestor up to the filesystem root.

    Raises:
        Exception: if neither the CWD nor any of its ancestors contains the marker file.
    """
    path = Path("").resolve()
    while True:
        if (path / "tools/impl/common.py").is_file():
            return path
        # `Path` objects are always truthy and the filesystem root is its own parent, so the
        # top is detected by comparing against `path`; otherwise this loop would never end.
        if path.parent == path:
            raise Exception("Cannot find crosvm root dir.")
        path = path.parent
48
49
# Root directory of crosvm derived from CWD.
CROSVM_ROOT = find_crosvm_root()

# Cargo.toml file of crosvm.
CROSVM_TOML = CROSVM_ROOT / "Cargo.toml"

# Root directory of crosvm devtools.
#
# May be different from `CROSVM_ROOT/tools`, which allows you to run the crosvm dev
# tools from this directory on another crosvm repo.
#
# Use this if you want to call crosvm dev tools, which will use the scripts relative
# to this file.
TOOLS_ROOT = Path(__file__).parent.parent.resolve()

# Cache directory that is preserved between builds in CI. Falls back to $TMPDIR, then /tmp.
CACHE_DIR = Path(os.environ.get("CROSVM_CACHE_DIR", os.environ.get("TMPDIR", "/tmp")))

# Ensure that we really found the crosvm root directory. This is an explicit check rather
# than an `assert` so that it still runs when Python is started with -O.
if 'name = "crosvm"' not in CROSVM_TOML.read_text():
    raise Exception(f"{CROSVM_ROOT} does not look like the crosvm root directory.")

# List of times recorded by `record_time` which will be printed if --timing-info is provided.
global_time_records: List[Tuple[str, datetime.timedelta]] = []
75
76
def crosvm_target_dir():
    """
    Return the directory crosvm build output goes to.

    Precedence: $CROSVM_TARGET_DIR as-is, then `crosvm` inside $CARGO_TARGET_DIR, then
    `target/crosvm` under CROSVM_ROOT. Empty environment values count as unset.
    """
    explicit_dir = os.environ.get("CROSVM_TARGET_DIR")
    if explicit_dir:
        return Path(explicit_dir)

    cargo_dir = os.environ.get("CARGO_TARGET_DIR")
    if cargo_dir:
        return Path(cargo_dir) / "crosvm"

    return CROSVM_ROOT / "target/crosvm"
86
87
@functools.lru_cache(None)
def parse_common_args():
    """
    Parse only the args common to all scripts, ignoring everything else on the command line.

    This is done separately from the run_main/run_commands argument parsing so that
    verbose/etc. can be queried before the command's own arguments are parsed. The result is
    cached, so the command line is parsed at most once.
    """
    common_parser = argparse.ArgumentParser(add_help=False)
    add_common_args(common_parser)
    known, _unrecognized = common_parser.parse_known_args()
    return known
99
100
def add_common_args(parser: argparse.ArgumentParser):
    """
    Register the options shared by all commands on `parser`.

    Adds --color (always/never/auto, default auto) plus the boolean switches
    --verbose/-v, --very-verbose/-vv and --timing-info (all default False).
    """
    parser.add_argument(
        "--color",
        choices=("always", "never", "auto"),
        default="auto",
        help="Force enable or disable colors. Defaults to automatic detection.",
    )
    # The remaining options are plain on/off switches; registration order matches --help.
    switches = (
        (("--verbose", "-v"), "Print more details about the commands this script is running."),
        (("--very-verbose", "-vv"), "Print more debug output"),
        (("--timing-info",), "Print info on how long which parts of the command take"),
    )
    for flags, description in switches:
        parser.add_argument(*flags, action="store_true", default=False, help=description)
129
130
def verbose():
    """True if --verbose was passed; --very-verbose implies it."""
    if very_verbose():
        return True
    return parse_common_args().verbose
133
134
def very_verbose():
    """True if --very-verbose (-vv) was passed."""
    common_args = parse_common_args()
    return common_args.very_verbose
137
138
def color_enabled():
    """
    Decide whether output should be colored.

    --color=always and --color=never force the answer; otherwise (--color=auto, the default)
    colors are used only when stdout is a TTY.
    """
    forced = {"always": True, "never": False}
    choice = parse_common_args().color
    if choice in forced:
        return forced[choice]
    return sys.stdout.isatty()
146
147
def find_scripts(path: Path, shebang: str):
    """
    Yield the regular files directly inside `path` whose contents start with `#!{shebang}`.

    Only the first 512 characters of each file are read. Undecodable bytes are ignored, so
    binary files do not raise. Subdirectories are skipped, not searched.
    """
    prefix = f"#!{shebang}"
    for file in path.glob("*"):
        if not file.is_file():
            continue
        # Close each file promptly instead of leaking the handle until garbage collection.
        with file.open(errors="ignore") as f:
            head = f.read(512)
        if head.startswith(prefix):
            yield file
152
153
def confirm(message: str, default: bool = False):
    """
    Ask a yes/no question on stdout and read the answer from stdin.

    Args:
        message: The question to show.
        default: Returned when the answer is neither y/Y nor n/N (including an empty line or
            EOF). It also selects which option is capitalized in the prompt.

    Returns:
        True for "y"/"Y", False for "n"/"N", otherwise `default`.
    """
    hint = "[Y/n]" if default else "[y/N]"
    print(message, hint, end=" ", flush=True)
    answer = sys.stdin.readline().strip()
    if answer in ("y", "Y"):
        return True
    if answer in ("n", "N"):
        return False
    return default
162
163
def is_cros_repo():
    """
    Returns true if the crosvm repo is a symlink or worktree to a CrOS repo checkout.

    A `.git` that is a real directory (not a symlink) is treated as a non-CrOS checkout.
    Otherwise the repo counts as CrOS when a `.repo` entry exists at cros_repo_root().
    """
    git_entry = CROSVM_ROOT / ".git"
    plain_git_dir = git_entry.is_dir() and not git_entry.is_symlink()
    if plain_git_dir:
        return False
    return (cros_repo_root() / ".repo").exists()
170
171
def cros_repo_root():
    """Root directory of the CrOS repo checkout: three levels above CROSVM_ROOT, resolved."""
    return CROSVM_ROOT.joinpath("../../..").resolve()
175
176
def is_kiwi_repo():
    """Returns true if the crosvm repo contains a .kiwi_repo file (any entry by that name)."""
    return CROSVM_ROOT.joinpath(".kiwi_repo").exists()
181
182
def kiwi_repo_root():
    """Root directory of the kiwi repo checkout: two levels above CROSVM_ROOT, resolved."""
    return CROSVM_ROOT.joinpath("../..").resolve()
186
def is_aosp_repo():
    """Returns true if the crosvm repo is an AOSP repo checkout (it contains Android.bp)."""
    android_bp = CROSVM_ROOT / "Android.bp"
    return android_bp.exists()


def aosp_repo_root():
    """Root directory of AOSP repo checkout: two levels above CROSVM_ROOT, resolved."""
    return (CROSVM_ROOT / "../..").resolve()
205
206
def sudo_is_passwordless():
    """
    Returns true if sudo can run without asking for a password.

    `sudo --askpass` is run with SUDO_ASKPASS pointing at `false`. If a password were needed,
    the askpass helper would fail and so would the command. Success therefore means no
    password was required.
    """
    status, _output = subprocess.getstatusoutput("SUDO_ASKPASS=false sudo --askpass true")
    return not status
212
213
# Short aliases accepted on the command line (see `Triple.from_shorthand`), mapped to the full
# cargo target triple each one stands for.
SHORTHANDS = dict(
    mingw64="x86_64-pc-windows-gnu",
    msvc64="x86_64-pc-windows-msvc",
    armhf="armv7-unknown-linux-gnueabihf",
    aarch64="aarch64-unknown-linux-gnu",
    riscv64="riscv64gc-unknown-linux-gnu",
    x86_64="x86_64-unknown-linux-gnu",
)
222
223
class Triple(NamedTuple):
    """
    Build triple in cargo format.

    The format is: <arch><sub>-<vendor>-<sys>-<abi>. However, we will treat <arch><sub> as a
    single arch to simplify things.

    `sys` and `abi` are None when the triple has fewer components (a three-part triple has no
    `abi`). `str()` omits such components, so `Triple.from_str(str(t)) == t`.
    """

    arch: str
    vendor: str
    sys: Optional[str]
    abi: Optional[str]

    @classmethod
    def from_shorthand(cls, shorthand: str):
        """
        These shorthands make it easier to specify triples on the command line.

        Anything containing a "-" is parsed as a full triple; otherwise it must be a key of
        SHORTHANDS.

        Raises:
            Exception: if `shorthand` is neither a full triple nor a known shorthand.
        """
        if "-" in shorthand:
            triple = shorthand
        elif shorthand in SHORTHANDS:
            triple = SHORTHANDS[shorthand]
        else:
            raise Exception(f"Not a valid build triple shorthand: {shorthand}")
        return cls.from_str(triple)

    @classmethod
    def from_str(cls, triple: str):
        """
        Parse a "-"-separated triple string.

        At least arch and vendor are required. A missing sys/abi becomes None, and components
        beyond the fourth are ignored.

        Raises:
            Exception: if the string has fewer than two components.
        """
        parts = triple.split("-")
        if len(parts) < 2:
            raise Exception(f"Unsupported triple {triple}")
        return cls(
            parts[0],
            parts[1],
            parts[2] if len(parts) > 2 else None,
            parts[3] if len(parts) > 3 else None,
        )

    @classmethod
    def from_linux_arch(cls, arch: str):
        "Rough logic to convert the output of `arch` into a corresponding linux build triple."
        if arch == "armhf":
            return cls.from_str("armv7-unknown-linux-gnueabihf")
        else:
            return cls.from_str(f"{arch}-unknown-linux-gnu")

    @classmethod
    def host_default(cls):
        """
        Returns the default build triple of the host, as reported by `rustc -vV`.

        Errors from running rustc propagate. Raises Exception if the output has no
        `host:` line.
        """
        rustc_info = subprocess.check_output(["rustc", "-vV"], text=True)
        match = re.search(r"host: (\S+)", rustc_info)
        if not match:
            raise Exception(f"Cannot parse rustc info: {rustc_info}")
        return cls.from_str(match.group(1))

    @property
    def feature_flag(self):
        """
        Name of the cargo feature for this triple, `all-<shorthand>`.

        Raises:
            Exception: if this triple has no entry in SHORTHANDS.
        """
        triple_to_shorthand = {v: k for k, v in SHORTHANDS.items()}
        shorthand = triple_to_shorthand.get(str(self))
        if not shorthand:
            raise Exception(f"No feature set for triple {self}")
        return f"all-{shorthand}"

    @property
    def target_dir(self):
        """Per-triple build output directory: crosvm_target_dir() / <triple>."""
        return crosvm_target_dir() / str(self)

    def get_cargo_env(self):
        """
        Environment variables to make cargo use the test target.

        Sets CARGO_BUILD_TARGET to this triple, CARGO_TARGET_DIR to `self.target_dir`, and
        CROSVM_TARGET_DIR to crosvm_target_dir().
        """
        env: Dict[str, str] = {}
        cargo_target = str(self)
        env["CARGO_BUILD_TARGET"] = cargo_target
        env["CARGO_TARGET_DIR"] = str(self.target_dir)
        env["CROSVM_TARGET_DIR"] = str(crosvm_target_dir())
        return env

    def __str__(self):
        """Format as a cargo target string, skipping components that are None."""
        # Joining only the present components avoids bogus targets such as
        # "x86_64-unknown-none-None" for triples with fewer than four parts.
        return "-".join(part for part in self if part is not None)
300
301
def download_file(url: str, filename: Path, attempts: int = 3):
    """
    Download `url` to `filename`, retrying on failure.

    Args:
        url: URL to fetch, as accepted by `urllib.request.urlretrieve`.
        filename: Destination path; an existing file is overwritten.
        attempts: Total number of tries. Must be at least 1.

    Raises:
        ValueError: if `attempts` is less than 1.
        Exception: the error from the final attempt, if every attempt fails.
    """
    # Validate explicitly: an `assert` is stripped under `python -O`, and with attempts <= 0
    # the old countdown loop never reached zero and would retry until a download succeeded.
    if attempts < 1:
        raise ValueError(f"attempts must be at least 1, got {attempts}")
    for attempt in range(1, attempts + 1):
        try:
            urllib.request.urlretrieve(url, filename)
            return
        except Exception as e:
            if attempt == attempts:
                raise
            print("Download failed:", e)
314
315
def strip_ansi_escape_sequences(line: str) -> str:
    """Return `line` with every match of ANSI_ESCAPE removed."""
    return re.sub(ANSI_ESCAPE, "", line)
318
319
def ensure_packages_exist(*packages: str):
    """
    Exits with status 1 if any of the listed python packages cannot be imported.

    Before exiting, prints a hint to re-run ./tools/install-deps or to apt-install the
    matching `python3-<package>` packages. Returns None if everything imports.
    """

    def importable(name: str) -> bool:
        try:
            __import__(name)
        except ImportError:
            return False
        return True

    missing = [name for name in packages if not importable(name)]
    if not missing:
        return

    package_list = " ".join(f"python3-{p}" for p in missing)
    print("Missing python dependencies. Please re-run ./tools/install-deps")
    print(f"Or `sudo apt install {package_list}`")
    sys.exit(1)
338
339
@contextlib.contextmanager
def record_time(title: str):
    """
    Context manager that measures how long its body takes, using datetime.now().

    On exit, whether normal or via an exception, a `(title, elapsed)` entry is appended to
    `global_time_records`. Those entries are printed at the end of script execution if
    --timing-info is specified.
    """
    began_at = datetime.datetime.now()
    try:
        yield
    finally:
        elapsed = datetime.datetime.now() - began_at
        global_time_records.append((title, elapsed))
352
353
def print_timing_info():
    """Print each (title, duration) in global_time_records as a padded title plus seconds."""
    print()
    print("Timing info:")
    print()
    for name, duration in global_time_records:
        seconds = duration.total_seconds()
        print(f"  {name:20} {seconds:.2f}s")
360