• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# Copyright 2022 The ChromiumOS Authors
3# Use of this source code is governed by a BSD-style license that can be
4# found in the LICENSE file.
5
6"""
7Provides helpers for writing shell-like scripts in Python.
8
9It provides tools to execute commands with similar flexibility to shell scripts and simplifies
10command line arguments using `argh` and provides common flags (e.g. -v and -vv) for all of
11our command line tools.
12
13Refer to the scripts in ./tools for example usage.
14"""
15
16# Import preamble before anything else
17from . import preamble  # type: ignore
18
19import argparse
20import contextlib
21import csv
22import datetime
23import functools
24import getpass
25import json
26import os
27import re
28import shlex
29import shutil
30import subprocess
31import sys
32import traceback
33import urllib
34import urllib.request
35import urllib.error
36from copy import deepcopy
37from io import StringIO
38from math import ceil
39from multiprocessing.pool import ThreadPool
40from pathlib import Path
41from subprocess import DEVNULL, PIPE, STDOUT  # type: ignore
42from tempfile import gettempdir
43from typing import (
44    Any,
45    Callable,
46    Dict,
47    Iterable,
48    List,
49    NamedTuple,
50    Optional,
51    Tuple,
52    TypeVar,
53    Union,
54    cast,
55)
56
57import argh  # type: ignore
58import rich
59import rich.console
60import rich.live
61import rich.spinner
62import rich.text
63
64
# Per-user file (in the system temp directory) holding the HTTP headers used for gcloud
# authentication. The login name is embedded so different users do not collide.
AUTH_HEADERS_FILE = Path(gettempdir()).joinpath(
    f"crosvm_gcloud_auth_headers_{getpass.getuser()}"
)

# Type of arguments that accept either a pathlib.Path or a plain string path.
PathLike = Union[Path, str]
69
70
def find_crosvm_root():
    """Walk up from CWD until we find the crosvm root dir.

    Returns:
        The first directory, starting at the current working directory and moving towards
        the filesystem root, that contains `tools/impl/common.py`.

    Raises:
        Exception: If no such directory exists up to and including the filesystem root.
    """
    path = Path("").resolve()
    while True:
        if (path / "tools/impl/common.py").is_file():
            return path
        # The parent of the filesystem root is the root itself (and a Path is always truthy),
        # so compare against `path` to detect that we cannot walk up any further.
        if path.parent == path:
            raise Exception("Cannot find crosvm root dir.")
        path = path.parent
81
82
# Root directory of crosvm derived from CWD.
CROSVM_ROOT = find_crosvm_root()

# Cargo.toml file of crosvm.
CROSVM_TOML = CROSVM_ROOT / "Cargo.toml"

# Root directory of crosvm devtools.
#
# May be different from `CROSVM_ROOT/tools`, which allows you to run the crosvm dev
# tools from this directory on another crosvm repo.
#
# Use this if you want to call crosvm dev tools, which will use the scripts relative
# to this file.
TOOLS_ROOT = Path(__file__).parent.parent.resolve()

# Cache directory that is preserved between builds in CI.
# Falls back to $TMPDIR, then /tmp, when $CROSVM_CACHE_DIR is not set.
CACHE_DIR = Path(os.environ.get("CROSVM_CACHE_DIR", os.environ.get("TMPDIR", "/tmp")))

# Url of crosvm's gerrit review host.
GERRIT_URL = "https://chromium-review.googlesource.com"

# Ensure that we really found the crosvm root directory. This is an explicit check rather
# than an `assert` so that it is not stripped when running under `python -O`.
if 'name = "crosvm"' not in CROSVM_TOML.read_text():
    raise Exception(f"{CROSVM_ROOT} does not look like the crosvm root dir.")

# List of times recorded by `record_time` which will be printed if --timing-info is provided.
global_time_records: List[Tuple[str, datetime.timedelta]] = []
111
112
def crosvm_target_dir():
    """Returns the cargo target directory crosvm builds should use.

    Resolution order:
    1. $CROSVM_TARGET_DIR, used as-is.
    2. $CARGO_TARGET_DIR/crosvm.
    3. CROSVM_ROOT/target/crosvm.
    Empty environment variables are treated as unset.
    """
    explicit_dir = os.environ.get("CROSVM_TARGET_DIR")
    if explicit_dir:
        return Path(explicit_dir)

    cargo_dir = os.environ.get("CARGO_TARGET_DIR")
    if cargo_dir:
        return Path(cargo_dir) / "crosvm"

    return CROSVM_ROOT / "target/crosvm"
122
123
# Outcome of running a Command: the captured stdout and stderr text (None when a stream
# was not captured) and the process exit code.
CommandResult = NamedTuple(
    "CommandResult",
    [
        ("stdout", str),
        ("stderr", str),
        ("returncode", int),
    ],
)
CommandResult.__doc__ = "Results of a command execution as returned by Command.run()"
130
131
class Command(object):
    """
    Simplified subprocess handling for shell-like scripts.

    ## Example Usage

    To run a program on behalf of the user:

    >> cmd("cargo build").fg()

    This will run the program with stdio passed to the user. Developer tools usually run a set of
    actions on behalf of the user. These should be executed with fg().

    To make calls in the background to gather information use success/stdout/lines:

    >> cmd("git branch").lines()
    >> cmd("git rev-parse foo").success()

    These will capture all program output. Try to avoid using these to run mutating commands,
    as they will remain hidden to the user even when using --verbose.

    ## Arguments

    Arguments are provided as a list similar to subprocess.run():

    >>> Command('cargo', 'build', '--workspace')
    Command('cargo', 'build', '--workspace')

    In contrast to subprocess.run, all strings are split by whitespaces similar to bash:

    >>> Command('cargo build --workspace', '--features foo')
    Command('cargo', 'build', '--workspace', '--features', 'foo')

    In contrast to bash, globs are *not* evaluated, but can easily be provided using Path:

    >>> Command('ls -l', *Path(CROSVM_ROOT).glob('*.toml'))
    Command('ls', '-l', ...)

    None or False are ignored to make it easy to include conditional arguments:

    >>> all = False
    >>> Command('cargo build', '--workspace' if all else None)
    Command('cargo', 'build')

    ## Nesting

    Commands can be nested, similar to $() subshells in bash. The sub-commands will be executed
    right away and their output will undergo the usual splitting:

    >>> Command('printf "(%s)"', Command('echo foo bar')).stdout()
    '(foo)(bar)'

    Arguments can be explicitly quoted to prevent splitting, it applies to both sub-commands
    as well as strings:

    >>> Command('printf "(%s)"', quoted(Command('echo foo bar'))).stdout()
    '(foo bar)'

    Commands can also be piped into one another:

    >>> wc = Command('wc')
    >>> Command('echo "abcd"').pipe(wc('-c')).stdout()
    '5'

    ## Verbosity

    The --verbose flag is intended for users and will show all command lines executed in the
    foreground with fg(), it'll also include output of programs run with fg(quiet=True). Commands
    executed in the background are not shown.

    For script developers, the --very-verbose flag will print full details and output of all
    executed command lines, including those run hidden from the user.
    """

    def __init__(
        self,
        *args: Any,
        stdin_cmd: Optional["Command"] = None,
        env_vars: Optional[Dict[str, str]] = None,
        cwd: Optional[Path] = None,
    ):
        """
        Args:
            args: Command line arguments, parsed as described in the class docstring.
            stdin_cmd: Command whose stdout is piped into this command's stdin.
            env_vars: Environment variables added on top of the current process environment.
            cwd: Working directory to run the program in (defaults to the current one).
        """
        self.args = Command.__parse_cmd(args)
        self.stdin_cmd = stdin_cmd
        # Copy the mapping so a command never shares (and later mutates) the caller's dict.
        self.env_vars = dict(env_vars) if env_vars else {}
        self.cwd = cwd

    ### Builder API to construct commands

    def with_args(self, *args: Any):
        """Returns a new Command with added arguments.

        >>> cargo = Command('cargo')
        >>> cargo.with_args('clippy')
        Command('cargo', 'clippy')
        """
        cmd = deepcopy(self)
        cmd.args = [*self.args, *Command.__parse_cmd(args)]
        return cmd

    def with_cwd(self, cwd: Optional[Path]):
        """Returns a new Command that is executed in the working directory `cwd`.

        >>> pwd = Command('pwd')
        >>> pwd.with_cwd('/tmp').stdout()
        '/tmp'
        """
        cmd = deepcopy(self)
        cmd.cwd = cwd
        return cmd

    def __call__(self, *args: Any):
        """Shorthand for Command.with_args"""
        return self.with_args(*args)

    def with_env(self, key: str, value: Optional[str]):
        """
        Returns a command with an added env variable.

        The variable is removed if value is None.
        """
        return self.with_envs({key: value})

    def with_envs(self, envs: Union[Dict[str, Optional[str]], Dict[str, str]]):
        """
        Returns a command with the given env variables added.

        Variables whose value is None are removed instead.
        """
        cmd = deepcopy(self)
        for key, value in envs.items():
            if value is not None:
                cmd.env_vars[key] = value
            else:
                if key in cmd.env_vars:
                    del cmd.env_vars[key]
        return cmd

    def with_path_env(self, new_path: str):
        """Returns a command with `new_path` appended to the PATH variable.

        The PATH already set on this command takes precedence over the one of the current
        process.
        """
        path_var = self.env_vars.get("PATH", os.environ.get("PATH", ""))
        # An empty PATH entry means "current directory", so never produce a leading separator.
        new_value = f"{path_var}{os.pathsep}{new_path}" if path_var else new_path
        return self.with_env("PATH", new_value)

    def with_color_arg(
        self,
        always: Optional[str] = None,
        never: Optional[str] = None,
    ):
        """Returns a command with an argument added to pass through enabled/disabled colors.

        `always` is added if colors are enabled (see color_enabled()), `never` otherwise. A
        None argument adds nothing.
        """
        new_cmd = self
        if color_enabled():
            if always:
                new_cmd = new_cmd(always)
        else:
            if never:
                new_cmd = new_cmd(never)
        return new_cmd

    def with_color_env(self, var_name: str):
        """Returns a command with an env var added to pass through enabled/disabled colors.

        The variable is set to "1" if colors are enabled and to "0" otherwise.
        """
        return self.with_env(var_name, "1" if color_enabled() else "0")

    def with_color_flag(self, flag: str = "--color"):
        """Returns a command with `{flag}=always` or `{flag}=never` added, matching color_enabled()."""
        return self.with_color_arg(always=f"{flag}=always", never=f"{flag}=never")

    def foreach(self, arguments: Iterable[Any], batch_size: int = 1):
        """
        Yields a new command for each entry in `arguments`.

        The argument is appended to each command and is intended to be used in
        conjunction with `parallel()` to execute a command on a list of arguments in
        parallel.

        >>> parallel(*cmd('echo').foreach((1, 2, 3))).stdout()
        ['1', '2', '3']

        Arguments can also be batched by setting batch_size > 1, which will append multiple
        arguments to each command.

        >>> parallel(*cmd('echo').foreach((1, 2, 3), batch_size=2)).stdout()
        ['1 2', '3']

        """
        for batch in batched(arguments, batch_size):
            yield self(*batch)

    def pipe(self, *args: Any):
        """
        Pipes the output of this command into another process.

        The target can either be another Command or the argument list to build a new command.
        The new command uses the env variables of this command; when a Command is given, only
        its arguments are used.
        """
        if len(args) == 1 and isinstance(args[0], Command):
            cmd = Command(stdin_cmd=self)
            cmd.args = args[0].args
            cmd.env_vars = self.env_vars.copy()
            return cmd
        else:
            return Command(*args, stdin_cmd=self, env_vars=self.env_vars)

    ### Executing programs in the foreground

    def run_foreground(
        self,
        quiet: bool = False,
        check: bool = True,
        dry_run: bool = False,
        style: Optional[Callable[["subprocess.Popen[str]"], None]] = None,
    ):
        """
        Runs a program in the foreground with output streamed to the user.

        >>> Command('true').fg()
        0

        Non-zero exit codes will trigger an Exception

        >>> Command('false').fg()
        Traceback (most recent call last):
        ...
        subprocess.CalledProcessError...

        But can be disabled:

        >>> Command('false').fg(check=False)
        1

        Output can be hidden by setting quiet=True:

        >>> Command("echo foo").fg(quiet=True)
        0

        This will hide the programs stdout and stderr unless the program fails.

        More sophisticated means of outputting stdout/err are available via `Styles`:

        >>> Command("echo foo").fg(style=Styles.live_truncated())
        …
        foo
        0

        Will output the results of the command but truncate output after a few lines. See `Styles`
        for more options.

        Arguments:
            quiet: Do not show stdout/stderr unless the program failed.
            check: Raise an exception if the program returned an error code.
            dry_run: Only print the command line instead of running it (returns 0).
            style: A function to present the output of the program. See `Styles`

        Returns: The return code of the program.
        """
        if dry_run:
            print(f"Not running: {self}")
            return 0

        if quiet:
            style = Styles.quiet

        if verbose():
            print(f"$ {self}")

        # In verbose mode the output is always streamed directly, ignoring any style.
        if style is None or verbose():
            return self.__run(stdout=None, stderr=None, check=check).returncode

        # stderr is merged into stdout so the style sees the program's complete output.
        process = self.popen(stderr=STDOUT)
        style(process)
        returncode = process.wait()
        if returncode != 0 and check:
            raise subprocess.CalledProcessError(returncode, str(self))
        return returncode

    def fg(
        self,
        quiet: bool = False,
        check: bool = True,
        dry_run: bool = False,
        style: Optional[Callable[["subprocess.Popen[str]"], None]] = None,
    ):
        """
        Shorthand for Command.run_foreground()
        """
        return self.run_foreground(quiet, check, dry_run, style)

    def write_to(self, filename: Path):
        """
        Writes stdout to the provided file.

        The command runs before the file is opened, so a failing command leaves an existing
        file untouched.
        """
        if verbose():
            print(f"$ {self} > {filename}")
        output = self.__run(stdout=PIPE, stderr=PIPE).stdout
        with open(filename, "w") as file:
            file.write(output)

    def append_to(self, filename: Path):
        """
        Appends stdout to the provided file.

        The command runs before the file is opened, so a failing command does not create or
        touch the file.
        """
        if verbose():
            print(f"$ {self} >> {filename}")
        output = self.__run(stdout=PIPE, stderr=PIPE).stdout
        with open(filename, "a") as file:
            file.write(output)

    ### API for executing commands hidden from the user

    def success(self):
        """
        Returns True if the program succeeded (i.e. returned 0).

        The program will not be visible to the user unless --very-verbose is specified.
        """
        if very_verbose():
            print(f"$ {self}")
        return self.__run(stdout=PIPE, stderr=PIPE, check=False).returncode == 0

    def stdout(self, check: bool = True, stderr: int = PIPE):
        """
        Runs a program and returns stdout with surrounding whitespace stripped.

        The program will not be visible to the user unless --very-verbose is specified.
        """
        if very_verbose():
            print(f"$ {self}")
        return self.__run(stdout=PIPE, stderr=stderr, check=check).stdout.strip()

    def json(self, check: bool = True) -> Any:
        """
        Runs a program and returns stdout parsed as json (None if stdout is empty).

        The program will not be visible to the user unless --very-verbose is specified.
        """
        stdout = self.stdout(check=check)
        if stdout:
            return json.loads(stdout)
        else:
            return None

    def lines(self, check: bool = True, stderr: int = PIPE):
        """
        Runs a program and returns stdout line by line.

        The program will not be visible to the user unless --very-verbose is specified.
        """
        return self.stdout(check=check, stderr=stderr).splitlines()

    ### Utilities

    def __str__(self):
        """Shell-like rendering of the command line, including any piped input command."""
        stdin = ""
        if self.stdin_cmd:
            stdin = str(self.stdin_cmd) + " | "
        return stdin + shlex.join(self.args)

    def __repr__(self):
        stdin = ""
        if self.stdin_cmd:
            stdin = ", stdin_cmd=" + repr(self.stdin_cmd)
        return f"Command({', '.join(repr(a) for a in self.args)}{stdin})"

    ### Private implementation details

    def __run(
        self,
        stdout: Optional[int],
        stderr: Optional[int],
        check: bool = True,
    ) -> CommandResult:
        """
        Runs this command to completion via subprocess.run().

        Args:
            stdout: Passed to subprocess.run(); PIPE captures the stream, None inherits it.
            stderr: Same as `stdout`, for the error stream.
            check: Raise CalledProcessError if the program returns a non-zero exit code.
        """
        if very_verbose():
            # Report the directory the program actually runs in, which is `self.cwd` if set.
            print(f"cwd: {(Path(self.cwd) if self.cwd else Path()).resolve()}")
            for k, v in self.env_vars.items():
                print(f"env: {k}={v}")
        # check=False: the error is raised below instead, so that the very-verbose dump is also
        # printed for failing programs and the error carries the readable command line.
        result = subprocess.run(
            self.args,
            cwd=self.cwd,
            stdout=stdout,
            stderr=stderr,
            stdin=self.__stdin_stream(),
            env={**os.environ, **self.env_vars},
            check=False,
            text=True,
        )
        if very_verbose():
            if result.stdout:
                for line in result.stdout.splitlines():
                    print("stdout:", line)
            if result.stderr:
                for line in result.stderr.splitlines():
                    print("stderr:", line)
            print("returncode:", result.returncode)
        if check and result.returncode != 0:
            raise subprocess.CalledProcessError(
                result.returncode, str(self), result.stdout, result.stderr
            )
        return CommandResult(result.stdout, result.stderr, result.returncode)

    def __stdin_stream(self):
        """Starts `stdin_cmd` (if any) and returns its stdout pipe to be used as stdin."""
        if self.stdin_cmd:
            # NOTE(review): the upstream process is never waited on and its stderr pipe is never
            # read; an upstream program writing a lot to stderr could block.
            return self.stdin_cmd.popen().stdout
        return None

    def popen(self, stderr: Optional[int] = PIPE) -> "subprocess.Popen[str]":
        """
        Runs a program and returns the Popen object of the running process.

        stdout is always captured through a pipe; `stderr` is passed to Popen as-is.
        """
        return subprocess.Popen(
            self.args,
            cwd=self.cwd,
            stdout=subprocess.PIPE,
            stderr=stderr,
            stdin=self.__stdin_stream(),
            env={**os.environ, **self.env_vars},
            text=True,
        )

    @staticmethod
    def __parse_cmd(args: Iterable[Any]) -> List[str]:
        """Parses command line arguments for Command."""
        res = [parsed for arg in args for parsed in Command.__parse_cmd_args(arg)]
        return res

    @staticmethod
    def __parse_cmd_args(arg: Any) -> List[str]:
        """Parses a mixed type command line argument into a list of strings.

        Paths and quoted strings stay single arguments, nested Commands are executed and their
        stdout is split, None/False are dropped and everything else is split like a shell would.
        """
        if isinstance(arg, Path):
            return [str(arg)]
        elif isinstance(arg, QuotedString):
            return [arg.value]
        elif isinstance(arg, Command):
            return [*shlex.split(arg.stdout())]
        elif arg is None or arg is False:
            return []
        else:
            return [*shlex.split(str(arg))]
563
564
class Styles(object):
    "A collection of methods that can be passed to `Command.fg(style=)`"

    @staticmethod
    def quiet(process: "subprocess.Popen[str]"):
        "Won't print anything unless the command failed."
        assert process.stdout
        stdout = process.stdout.read()
        if process.wait() != 0:
            print(stdout, end="")

    @staticmethod
    def live_truncated(num_lines: int = 8):
        """
        Shows only the last `num_lines` lines of output, below a spinner, while the program runs.

        On success the final truncated view is printed. On failure the complete output is
        printed so no error details are lost.
        """
        from collections import deque

        def output(process: "subprocess.Popen[str]"):
            assert process.stdout
            spinner = rich.spinner.Spinner("dots")
            # Bounded deque: appending beyond `num_lines` entries drops the oldest one.
            lines: "deque[rich.text.Text]" = deque(maxlen=max(num_lines, 0))
            stdout: List[str] = []
            with rich.live.Live(refresh_per_second=30, transient=True) as live:
                for line in iter(process.stdout.readline, ""):
                    # Keep leading indentation in the full log; it is only printed on failure.
                    stdout.append(line.rstrip())
                    lines.append(rich.text.Text.from_ansi(line.strip(), no_wrap=True))
                    live.update(rich.console.Group(rich.text.Text("…"), *lines, spinner))
            if process.wait() == 0:
                console.print(rich.console.Group(rich.text.Text("…"), *lines))
            else:
                for line in stdout:
                    print(line)

        return output

    @staticmethod
    def quiet_with_progress(title: str):
        """
        Shows a spinner labelled `title` while the program runs and hides its output.

        Prints "OK <title>" on success. On failure, prints the captured output followed by
        "ERR <title>".
        """

        def output(process: "subprocess.Popen[str]"):
            assert process.stdout
            with rich.live.Live(
                rich.spinner.Spinner("dots", title), refresh_per_second=30, transient=True
            ):
                stdout = process.stdout.read()

            if process.wait() == 0:
                console.print(f"[green]OK[/green] {title}")
            else:
                print(stdout)
                console.print(f"[red]ERR[/red] {title}")

        return output
618
619
class ParallelCommands(object):
    """
    Runs a fixed set of commands concurrently on a thread pool.

    Results are returned in the order the commands were given.

    >>> parallel(cmd('true'), cmd('false')).fg(check=False)
    [0, 1]

    >>> parallel(cmd('echo a'), cmd('echo b')).stdout()
    ['a', 'b']
    """

    def __init__(self, *commands: Command):
        self.commands = commands

    def _thread_pool(self):
        "Creates the worker pool: one thread per CPU, or a single thread with --very-verbose."
        # Presumably serialized in very-verbose mode so per-command debug output does not
        # interleave.
        worker_count = 1 if very_verbose() else os.cpu_count()
        return ThreadPool(worker_count)

    def fg(self, quiet: bool = False, check: bool = True):
        """Runs every command with Command.fg() and returns their exit codes."""

        def run_in_foreground(command: Command):
            return command.fg(quiet=quiet, check=check)

        with self._thread_pool() as pool:
            return pool.map(run_in_foreground, self.commands)

    def stdout(self):
        """Runs every command with Command.stdout() and returns their outputs."""

        def capture_stdout(command: Command):
            return command.stdout()

        with self._thread_pool() as pool:
            return pool.map(capture_stdout, self.commands)

    def success(self):
        """True if every command exits with 0 (output is shown only for failing commands)."""
        exit_codes = self.fg(check=False, quiet=True)
        return all(code == 0 for code in exit_codes)
645
646
class Remote(object):
    """
    Wrapper around the cmd() API that allows execution of commands via SSH.

    `opts` are passed to both ssh and scp as `-o<key>=<value>` options.
    """

    def __init__(self, host: str, opts: Dict[str, str]):
        self.host = host
        # Quote each option so that values containing whitespace (e.g. a ProxyCommand) reach
        # ssh/scp as a single argument instead of being split by `cmd()`.
        ssh_opts = [quoted(f"-o{k}={v}") for k, v in opts.items()]
        self.ssh_cmd = cmd("ssh", host, "-T", *ssh_opts)
        self.scp_cmd = cmd("scp", *ssh_opts)

    def ssh(self, cmd: Command, remote_cwd: Optional[Path] = None):
        """
        Returns a Command that runs `cmd` on the remote host. It does not execute it.

        Args:
            cmd: The command to run remotely.
            remote_cwd: Directory to `cd` into on the remote host before running `cmd`.
        """
        # Use huponexit to ensure the process is killed if the connection is lost.
        # Use shlex to properly quote the command.
        wrapped_cmd = f"bash -O huponexit -c {shlex.quote(str(cmd))}"
        if remote_cwd is not None:
            # NOTE(review): inserted unquoted, presumably so `~` keeps expanding remotely;
            # directories containing spaces or shell metacharacters are not supported.
            wrapped_cmd = f"cd {remote_cwd} && {wrapped_cmd}"
        # The whole command to pass it to SSH for remote execution.
        return self.ssh_cmd.with_args(quoted(wrapped_cmd))

    def scp(self, sources: List[Path], target: str, quiet: bool = False):
        """Copies `sources` to `target` on the remote host; returns scp's exit code via fg()."""
        return self.scp_cmd.with_args(*sources, f"{self.host}:{target}").fg(quiet=quiet)
669
670
@contextlib.contextmanager
def record_time(title: str):
    """
    Measures the wall-clock duration of the `with` body and stores it under `title`.

    The entry is recorded even if the body raises. All recorded entries are printed at the end
    of the script when --timing-info is specified.
    """
    started_at = datetime.datetime.now()
    try:
        yield
    finally:
        elapsed = datetime.datetime.now() - started_at
        global_time_records.append((title, elapsed))
683
684
@contextlib.contextmanager
def cwd_context(path: PathLike):
    """Temporarily switches the current working directory to `path`.

    The previous working directory is restored when the context exits, even on error.

    >>> with cwd('/tmp'):
    ...     os.getcwd()
    '/tmp'

    """
    previous_dir = os.getcwd()
    try:
        chdir(path)
        yield
    finally:
        chdir(previous_dir)
700
701
def chdir(path: PathLike):
    """Changes the process working directory, echoing `cd <path>` with --very-verbose."""
    log_it = very_verbose()
    if log_it:
        print("cd", path)
    os.chdir(path)
706
707
class QuotedString(object):
    """
    Wraps a value so `Command` passes it on as one argument instead of splitting it.

    If the value is a `Command`, it is executed right away and its (stripped) stdout is used.
    Any other value is converted with str().
    """

    def __init__(self, value: Any):
        self.value = value.stdout() if isinstance(value, Command) else str(value)

    def __str__(self):
        return '"' + self.value + '"'
723
724
T = TypeVar("T")


def batched(source: Iterable[T], max_batch_size: int) -> Iterable[List[T]]:
    """
    Returns an iterator over batches of elements from `source`.

    Uses the fewest batches that respect `max_batch_size`, and the smallest uniform batch size
    that still achieves that count. Only the last batch may be shorter.

    >>> list(batched([1, 2, 3, 4, 5], 2))
    [[1, 2], [3, 4], [5]]
    >>> list(batched([1, 2, 3, 4], 3))
    [[1, 2], [3, 4]]
    >>> list(batched([], 2))
    []

    Raises:
        ValueError: If `max_batch_size` is smaller than 1 (raised when iteration starts).
    """
    if max_batch_size < 1:
        raise ValueError(f"max_batch_size must be at least 1, got {max_batch_size}")
    source_list = list(source)
    if not source_list:
        # Nothing to batch; this also avoids dividing by a batch count of zero below.
        return
    batch_count = ceil(len(source_list) / max_batch_size)
    batch_size = ceil(len(source_list) / batch_count)
    for index in range(0, len(source_list), batch_size):
        # Slicing past the end is safe, so no explicit clamp is needed.
        yield source_list[index : index + batch_size]
741
742
# Short aliases used by the dev tool scripts:
#   quoted -> QuotedString, cmd -> Command, cwd -> cwd_context, parallel -> ParallelCommands
quoted, cmd = QuotedString, Command
cwd, parallel = cwd_context, ParallelCommands
748
749
def run_main(main_fn: Callable[..., Any], usage: Optional[str] = None):
    """Runs a script that exposes `main_fn` as its default command; see `run_commands`."""
    run_commands(usage=usage, default_fn=main_fn)
752
753
def run_commands(
    *functions: Callable[..., Any],
    default_fn: Optional[Callable[..., Any]] = None,
    usage: Optional[str] = None,
):
    """
    Allow the user to call the provided functions with command line arguments translated to
    function arguments via argh: https://pythonhosted.org/argh

    `functions` are registered as argh commands and `default_fn` (if given) as the parser's
    default command. The common flags from `add_common_args` are always available.

    Any exception raised while building the parser or running the command is printed (with a
    traceback in --verbose mode) and results in exit code 1. With --timing-info, the recorded
    timings are printed. The function finishes by calling sys.exit().
    """
    failed = False
    try:
        parser = argparse.ArgumentParser(
            description=usage,
            # Docstrings are used as the description in argparse, preserve their formatting.
            formatter_class=argparse.RawDescriptionHelpFormatter,
            # Do not allow implied abbreviations. Abbreviations should be manually specified.
            allow_abbrev=False,
        )
        add_common_args(parser)

        if functions:
            argh.add_commands(parser, functions)  # type: ignore
        if default_fn:
            argh.set_default_command(parser, default_fn)  # type: ignore

        with record_time("Total Time"):
            argh.dispatch(parser)  # type: ignore
    except Exception as e:
        if verbose():
            traceback.print_exc()
        else:
            print(e)
        failed = True

    if parse_common_args().timing_info:
        print_timing_info()

    sys.exit(1 if failed else 0)
795
796
def print_timing_info():
    """Prints every duration collected by `record_time`, one aligned line per entry."""
    rows = [
        f"  {title:20} {delta.total_seconds():.2f}s" for title, delta in global_time_records
    ]
    console.print()
    console.print("Timing info:")
    console.print()
    for row in rows:
        console.print(row)
803
804
@functools.lru_cache(maxsize=None)
def parse_common_args():
    """
    Returns the flags shared by all dev tool scripts (see `add_common_args`).

    sys.argv is parsed separately from run_main/run_commands so that verbosity and similar
    flags can be queried before the script's own arguments are parsed. Unrecognized arguments
    are ignored, and the result is cached for the lifetime of the process.
    """
    common_parser = argparse.ArgumentParser(add_help=False)
    add_common_args(common_parser)
    known_args, _unknown_args = common_parser.parse_known_args()
    return known_args
816
817
def add_common_args(parser: argparse.ArgumentParser):
    "Registers the flags every dev tool accepts: --color, --verbose, --very-verbose, --timing-info."
    parser.add_argument(
        "--color",
        default="auto",
        choices=("always", "never", "auto"),
        help="Force enable or disable colors. Defaults to automatic detection.",
    )
    # The remaining flags are all plain on/off switches.
    switches = (
        (("--verbose", "-v"), "Print more details about the commands this script is running."),
        (("--very-verbose", "-vv"), "Print more debug output"),
        (("--timing-info",), "Print info on how long which parts of the command take"),
    )
    for option_names, help_text in switches:
        parser.add_argument(*option_names, action="store_true", default=False, help=help_text)
846
847
def verbose():
    """True if --verbose was passed; --very-verbose implies it."""
    if very_verbose():
        return True
    return parse_common_args().verbose
850
851
def very_verbose():
    """True if --very-verbose (-vv) was passed."""
    common = parse_common_args()
    return common.very_verbose
854
855
def color_enabled():
    """Whether output should be colored.

    --color=always and --color=never force the answer. With --color=auto, colors are enabled
    only when stdout is a terminal.
    """
    mode = parse_common_args().color
    forced = {"always": True, "never": False}
    if mode in forced:
        return forced[mode]
    return sys.stdout.isatty()
863
864
def all_tracked_files():
    """Lazily yields the paths listed by `git ls-files` that exist as regular files.

    Tracked entries missing from the working tree, or that are not regular files, are skipped.
    """
    tracked_paths = map(Path, cmd("git ls-files").lines())
    yield from (path for path in tracked_paths if path.is_file())
870
871
def find_source_files(extension: str, ignore: Optional[List[str]] = None):
    """
    Yields git-tracked files with the given extension, skipping anything under third_party.

    Args:
        extension: File extension without the leading dot, e.g. "rs".
        ignore: Paths (as printed by `git ls-files`) to skip. Defaults to none.
    """
    suffix = f".{extension}"
    # Build the lookup set once instead of scanning the list for every tracked file.
    ignored = set(ignore or ())
    for file in all_tracked_files():
        if file.suffix != suffix:
            continue
        if file.is_relative_to("third_party"):
            continue
        if str(file) in ignored:
            continue
        yield file
881
882
def find_scripts(path: Path, shebang: str):
    """
    Yields the regular files directly inside `path` whose content starts with `#!{shebang}`.

    Only the first 512 characters of each file are read; undecodable bytes are ignored.
    """
    prefix = f"#!{shebang}"
    for file in path.glob("*"):
        if not file.is_file():
            continue
        # Close each file right away instead of leaking the handle while the caller iterates.
        with file.open(errors="ignore") as f:
            head = f.read(512)
        if head.startswith(prefix):
            yield file
887
888
def confirm(message: str, default: bool = False):
    """
    Asks the user a yes/no question on stdin.

    Answers "y"/"yes" return True and "n"/"no" return False (case-insensitive). Any other
    answer, including an empty line or EOF, returns `default`. The prompt capitalizes the
    default choice.
    """
    print(message, "[Y/n]" if default else "[y/N]", end=" ", flush=True)
    response = sys.stdin.readline().strip().lower()
    if response in ("y", "yes"):
        return True
    if response in ("n", "no"):
        return False
    return default
897
898
def get_cookie_file():
    """Returns git's configured `http.cookiefile` as a Path, or None when it is not set."""
    configured = cmd("git config http.cookiefile").stdout(check=False)
    if not configured:
        return None
    return Path(configured)
902
903
def get_gcloud_access_token():
    """Returns the output of `gcloud auth print-access-token`, or None without gcloud installed.

    A failing gcloud invocation is not raised; its (possibly empty) stdout is returned.
    """
    if shutil.which("gcloud") is None:
        return None
    token_cmd = cmd("gcloud auth print-access-token")
    return token_cmd.stdout(check=False)
908
909
@functools.lru_cache(maxsize=None)
def curl_with_git_auth():
    """
    Returns a curl `Command` instance set up to use the same HTTP credentials as git.

    This currently supports two methods:
    - git cookies (the default, used when no git credential.helper is configured)
    - gcloud (used when git's credential.helper ends with `gcloud.sh`)

    Most developers will use git cookies, which are passed to curl.

    gcloud for authorization can be enabled in git via `git config credential.helper gcloud.sh`.
    If enabled in git, this command will also return a curl command using a gcloud access token.

    The result is cached, so credentials are looked up once per process.

    Raises:
        Exception: If the configured credentials are unavailable or the credential helper is
            not supported.
    """
    helper = cmd("git config credential.helper").stdout(check=False)

    if not helper:
        cookie_file = get_cookie_file()
        if not cookie_file or not cookie_file.is_file():
            raise Exception("git http cookiefile is not available.")
        return cmd("curl --cookie", cookie_file)

    if helper.endswith("gcloud.sh"):
        token = get_gcloud_access_token()
        if not token:
            raise Exception("Cannot get gcloud access token.")
        # Write the token to a header file so it will not appear in logs or error messages.
        # The file lives in the shared temp directory, so restrict it to the current user:
        # create it with mode 0o600 and also tighten a pre-existing file before the token is
        # written (chmod fails if the file belongs to someone else).
        fd = os.open(AUTH_HEADERS_FILE, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, "w") as headers_file:
            AUTH_HEADERS_FILE.chmod(0o600)
            headers_file.write(f"Authorization: Bearer {token}")
        return cmd(f"curl -H @{AUTH_HEADERS_FILE}")

    raise Exception(f"Unsupported git credentials.helper: {helper}")
944
945
def strip_xssi(response: str):
    """
    Removes the XSSI-protection prefix that Gerrit prepends to its JSON responses.

    See https://gerrit-review.googlesource.com/Documentation/rest-api.html#output

    Raises:
        ValueError: If `response` does not start with the expected prefix.
    """
    prefix = ")]}'\n"
    # Validate explicitly: an `assert` would be stripped under `python -O`, after which the
    # first characters of an unexpected response would be silently dropped.
    if not response.startswith(prefix):
        raise ValueError(f"Unexpected Gerrit response (missing XSSI prefix): {response[:100]!r}")
    return response[len(prefix) :]
950
951
952def gerrit_api_get(path: str):
953    response = cmd(f"curl --silent --fail {GERRIT_URL}/{path}").stdout()
954    return json.loads(strip_xssi(response))
955
956
957def gerrit_api_post(path: str, body: Any):
958    response = curl_with_git_auth()(
959        "--silent --fail",
960        "-X POST",
961        "-H",
962        quoted("Content-Type: application/json"),
963        "-d",
964        quoted(json.dumps(body)),
965        f"{GERRIT_URL}/a/{path}",
966    ).stdout()
967    if very_verbose():
968        print("Response:", response)
969    return json.loads(strip_xssi(response))
970
971
972class GerritChange(object):
973    """
974    Class to interact with the gerrit /changes/ API.
975
976    For information on the data format returned by the API, see:
977    https://gerrit-review.googlesource.com/Documentation/rest-api-changes.html#change-info
978    """
979
980    id: str
981    _data: Any
982
983    def __init__(self, data: Any):
984        self._data = data
985        self.id = data["id"]
986
987    @functools.cached_property
988    def _details(self) -> Any:
989        return gerrit_api_get(f"changes/{self.id}/detail")
990
991    @functools.cached_property
992    def _messages(self) -> List[Any]:
993        return gerrit_api_get(f"changes/{self.id}/messages")
994
995    @property
996    def status(self):
997        return cast(str, self._data["status"])
998
999    def get_votes(self, label_name: str) -> List[int]:
1000        "Returns the list of votes on `label_name`"
1001        label_info = self._details.get("labels", {}).get(label_name)
1002        votes = label_info.get("all", [])
1003        return [cast(int, v.get("value")) for v in votes]
1004
1005    def get_messages_by(self, email: str) -> List[str]:
1006        "Returns all messages posted by the user with the specified `email`."
1007        return [m["message"] for m in self._messages if m["author"].get("email") == email]
1008
1009    def review(self, message: str, labels: Dict[str, int]):
1010        "Post review `message` and set the specified review `labels`"
1011        print("Posting on", self, ":", message, labels)
1012        gerrit_api_post(
1013            f"changes/{self.id}/revisions/current/review",
1014            {"message": message, "labels": labels},
1015        )
1016
1017    def abandon(self, message: str):
1018        print("Abandoning", self, ":", message)
1019        gerrit_api_post(f"changes/{self.id}/abandon", {"message": message})
1020
1021    @classmethod
1022    def query(cls, *queries: str):
1023        "Returns a list of gerrit changes matching the provided list of queries."
1024        return [cls(c) for c in gerrit_api_get(f"changes/?q={'+'.join(queries)}")]
1025
1026    def short_url(self):
1027        return f"http://crrev.com/c/{self._data['_number']}"
1028
1029    def __str__(self):
1030        return self.short_url()
1031
1032    def pretty_info(self):
1033        return f"{self} - {self._data['subject']}"
1034
1035
1036def is_cros_repo():
1037    "Returns true if the crosvm repo is a symlink or worktree to a CrOS repo checkout."
1038    dot_git = CROSVM_ROOT / ".git"
1039    if not dot_git.is_symlink() and dot_git.is_dir():
1040        return False
1041    return (cros_repo_root() / ".repo").exists()
1042
1043
1044def cros_repo_root():
1045    "Root directory of the CrOS repo checkout."
1046    return (CROSVM_ROOT / "../../..").resolve()
1047
1048
1049def is_kiwi_repo():
1050    "Returns true if the crosvm repo contains .kiwi_repo file."
1051    dot_kiwi_repo = CROSVM_ROOT / ".kiwi_repo"
1052    return dot_kiwi_repo.exists()
1053
1054
1055def kiwi_repo_root():
1056    "Root directory of the kiwi repo checkout."
1057    return (CROSVM_ROOT / "../..").resolve()
1058
1059
1060def sudo_is_passwordless():
1061    # Run with --askpass but no askpass set, succeeds only if passwordless sudo
1062    # is available.
1063    (ret, _) = subprocess.getstatusoutput("SUDO_ASKPASS=false sudo --askpass true")
1064    return ret == 0
1065
1066
1067SHORTHANDS = {
1068    "mingw64": "x86_64-pc-windows-gnu",
1069    "msvc64": "x86_64-pc-windows-msvc",
1070    "armhf": "armv7-unknown-linux-gnueabihf",
1071    "aarch64": "aarch64-unknown-linux-gnu",
1072    "x86_64": "x86_64-unknown-linux-gnu",
1073}
1074
1075
1076class Triple(NamedTuple):
1077    """
1078    Build triple in cargo format.
1079
1080    The format is: <arch><sub>-<vendor>-<sys>-<abi>, However, we will treat <arch><sub> as a single
1081    arch to simplify things.
1082    """
1083
1084    arch: str
1085    vendor: str
1086    sys: Optional[str]
1087    abi: Optional[str]
1088
1089    @classmethod
1090    def from_shorthand(cls, shorthand: str):
1091        "These shorthands make it easier to specify triples on the command line."
1092        if "-" in shorthand:
1093            triple = shorthand
1094        elif shorthand in SHORTHANDS:
1095            triple = SHORTHANDS[shorthand]
1096        else:
1097            raise Exception(f"Not a valid build triple shorthand: {shorthand}")
1098        return cls.from_str(triple)
1099
1100    @classmethod
1101    def from_str(cls, triple: str):
1102        parts = triple.split("-")
1103        if len(parts) < 2:
1104            raise Exception(f"Unsupported triple {triple}")
1105        return cls(
1106            parts[0],
1107            parts[1],
1108            parts[2] if len(parts) > 2 else None,
1109            parts[3] if len(parts) > 3 else None,
1110        )
1111
1112    @classmethod
1113    def from_linux_arch(cls, arch: str):
1114        "Rough logic to convert the output of `arch` into a corresponding linux build triple."
1115        if arch == "armhf":
1116            return cls.from_str("armv7-unknown-linux-gnueabihf")
1117        else:
1118            return cls.from_str(f"{arch}-unknown-linux-gnu")
1119
1120    @classmethod
1121    def host_default(cls):
1122        "Returns the default build triple of the host."
1123        rustc_info = subprocess.check_output(["rustc", "-vV"], text=True)
1124        match = re.search(r"host: (\S+)", rustc_info)
1125        if not match:
1126            raise Exception(f"Cannot parse rustc info: {rustc_info}")
1127        return cls.from_str(match.group(1))
1128
1129    @property
1130    def feature_flag(self):
1131        triple_to_shorthand = {v: k for k, v in SHORTHANDS.items()}
1132        shorthand = triple_to_shorthand.get(str(self))
1133        if not shorthand:
1134            raise Exception(f"No feature set for triple {self}")
1135        return f"all-{shorthand}"
1136
1137    @property
1138    def target_dir(self):
1139        return crosvm_target_dir() / str(self)
1140
1141    def get_cargo_env(self):
1142        """Environment variables to make cargo use the test target."""
1143        env: Dict[str, str] = {}
1144        cargo_target = str(self)
1145        env["CARGO_BUILD_TARGET"] = cargo_target
1146        env["CARGO_TARGET_DIR"] = str(self.target_dir)
1147        env["CROSVM_TARGET_DIR"] = str(crosvm_target_dir())
1148        return env
1149
1150    def __str__(self):
1151        return f"{self.arch}-{self.vendor}-{self.sys}-{self.abi}"
1152
1153
1154def download_file(url: str, filename: Path, attempts: int = 3):
1155    assert attempts > 0
1156    while True:
1157        attempts -= 1
1158        try:
1159            urllib.request.urlretrieve(url, filename)
1160            return
1161        except Exception as e:
1162            if attempts == 0:
1163                raise e
1164            else:
1165                console.print("Download failed:", e)
1166
1167
1168console = rich.console.Console()
1169
1170if __name__ == "__main__":
1171    import doctest
1172
1173    (failures, num_tests) = doctest.testmod(optionflags=doctest.ELLIPSIS)
1174    sys.exit(1 if failures > 0 else 0)
1175