• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2023 The Chromium Authors
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
from __future__ import annotations

import argparse
import asyncio
import codecs
import contextlib
import enum
import json
import logging
import os
import secrets
import shlex
import sys
import tempfile
from typing import TYPE_CHECKING, Any, Coroutine, Dict, Optional, Tuple

import websockets
from websockets.server import WebSocketServerProtocol

from crossbench import compat, helper
from crossbench import path as pth
from crossbench.helper.state import BaseState, StateMachine

if TYPE_CHECKING:
  from crossbench.plt.base import ListCmdArgs
  from crossbench.types import JsonDict
29
30
# Root of the crossbench checkout: the directory four levels above the one
# containing this file. The proxy launches `cb.py` from this directory.
CROSSBENCH_ROOT: pth.LocalPath = pth.LocalPath(__file__).parents[4]
32
33
@enum.unique
class State(BaseState):
  """Lifecycle states of the recorder proxy.

  CONNECTED: no crossbench run is in progress; a "run" command transitions
    from this state.
  RUNNING: a crossbench subprocess has been launched and its output is being
    forwarded to the client.
  """
  CONNECTED = enum.auto()
  RUNNING = enum.auto()
38
39
@enum.unique
class Response(compat.StrEnum):
  """Value of the "type" field in messages sent back to the client.

  STATUS: the payload is the current state name.
  OUTPUT: the payload is a {"stdout": ..., "stderr": ...} dict of subprocess
    output.
  """
  STATUS = "status"
  OUTPUT = "output"
44
45
class AuthenticationError(ValueError):
  """Raised when a request's token does not match the server's token."""
48
49
class CrossbenchDevToolsRecorderProxy:
  """Local WebSocket server that lets the DevTools Recorder extension drive
  crossbench.

  Each incoming message is a JSON object with a "command" ("run", "stop" or
  "status"), optional "args" and, unless token authentication is disabled, a
  "token" that must match the token logged at startup. Every message is
  answered with a JSON object carrying "success", "payload", "error" and,
  when the command produced a result, "type". While a crossbench subprocess
  runs, its stdout/stderr are streamed back as Response.OUTPUT messages.
  """
  DEFAULT_PORT = 44645
  # Number of bytes read from a subprocess pipe per forwarded chunk.
  _OUTPUT_BUFFER_SIZE = 128

  @classmethod
  def add_subcommand(cls, subparsers) -> argparse.ArgumentParser:
    """Registers the "devtools-recorder-proxy" (alias "devtools") subcommand.

    Returns the created parser so callers can extend it.
    """
    parser = subparsers.add_parser(
        "devtools-recorder-proxy",
        aliases=["devtools"],
        help=("Starts a local server to communicate with the "
              "DevTools Recorder extension."))
    parser.set_defaults(subcommand_fn=cls._subcommand)
    parser.add_argument(
        "--disable-token-authentication",
        dest="use_auth_token",
        default=True,
        action="store_false",
        help=("Disable token-based authentication. "
              "Unsafe, only use for local development."))
    return parser

  @classmethod
  def _subcommand(cls, args: argparse.Namespace) -> None:
    """Subcommand entry point: creates a proxy from *args* and runs it."""
    instance: CrossbenchDevToolsRecorderProxy = cls(
        use_auth_token=args.use_auth_token)
    instance.run()

  # The most recently connected client; all responses, including background
  # subprocess output, are sent to it.
  _websocket: WebSocketServerProtocol

  def __init__(self, use_auth_token: bool = True) -> None:
    self._token: str = secrets.token_hex(16)
    self._use_auth_token: bool = use_auth_token
    self._print_cmd_output: bool = False
    self._port: int = self.DEFAULT_PORT
    self._state = StateMachine(State.CONNECTED)
    self._crossbench_task: Optional[asyncio.Task] = None
    self._crossbench_process: Optional[asyncio.subprocess.Process] = None
    # mkdtemp's first positional parameter is the *suffix*; pass an explicit
    # prefix so the temporary directory is easy to identify.
    self._tmp_json = pth.LocalPath(
        tempfile.mkdtemp(prefix="crossbench_proxy_")) / "devtools_recorder.json"

  def run(self) -> None:
    """Blocks running the server; run_server() never returns normally."""
    asyncio.run(self.run_server())

  async def run_server(self) -> None:
    """Serves WebSocket connections on localhost forever.

    Tries DEFAULT_PORT first and falls back to an OS-assigned port when that
    port cannot be bound. The chosen port and the auth token are logged.
    """
    async with contextlib.AsyncExitStack() as stack:
      # websockets.serve() only binds the listening socket when it is
      # entered, so a busy port surfaces here as OSError rather than when
      # the serve object is constructed.
      try:
        server = await stack.enter_async_context(
            websockets.serve(self.handler, "localhost", self.DEFAULT_PORT))
      except OSError as e:
        logging.exception(e)
        server = await stack.enter_async_context(
            websockets.serve(self.handler, "localhost"))
      self._port = server.sockets[0].getsockname()[1]
      logging.info("#" * 80)
      logging.info("#" * 80)
      logging.info("# Crossbench DevTools Recorder Replay Server Started")
      logging.info("#")
      if self._port != self.DEFAULT_PORT:
        logging.warning("# Non-default port!")
      logging.info("# PORT:  %s", self._port)
      if not self._use_auth_token:
        logging.warning("# Token authentication has been disabled!")
      logging.info("# TOKEN: %s", self._token)
      logging.info("#")
      logging.info("#" * 80)
      logging.info("#" * 80)
      await asyncio.Future()  # run forever

  async def handler(self, websocket: WebSocketServerProtocol) -> None:
    """Handles one client connection, answering each message in order."""
    self._websocket = websocket
    async for message in websocket:
      await self._send_message(self._handle_message(message))

  async def _send_message(
      self, coroutine: Coroutine[Any, Any, Optional[Tuple[Response,
                                                          Any]]]) -> None:
    """Awaits *coroutine* and sends its result to the client as JSON.

    Exceptions raised by *coroutine* are logged and reported to the client
    as the exception type name in "error" with "success" set to False.
    """
    response: JsonDict = {"success": False, "payload": None, "error": None}
    try:
      result: Optional[Tuple[Response, Any]] = await coroutine
      response["success"] = True
      if result:
        response_type, payload = result
        response["payload"] = payload
        response["type"] = response_type.value
    except Exception as e:  # pylint: disable=broad-except
      logging.exception(e)
      response["error"] = str(type(e).__name__)
    try:
      response_json = json.dumps(response)
    except Exception as e:  # pylint: disable=broad-except
      logging.exception(e)
      response["success"] = False
      response["error"] = "Failed to encode message"
      response["payload"] = None
      response_json = json.dumps(response)
    logging.debug("SEND Response: %s", response_json)
    await self._websocket.send(response_json)

  async def _handle_message(self, message) -> Optional[Tuple[Response, Any]]:
    """Parses, authenticates and dispatches one client message.

    Unknown commands are logged and return None, which _send_message reports
    as a success without payload.

    Raises:
      json.JSONDecodeError: *message* is not valid JSON.
      ValueError: the JSON value is not an object.
      AuthenticationError: token authentication is enabled and the token is
        missing or wrong.
    """
    logging.debug("RECEIVE Message: %s", message)
    try:
      payload: Dict[str, Any] = json.loads(message)
    except json.JSONDecodeError as e:
      logging.error("Could not parse JSON response: %s", e)
      raise
    if not isinstance(payload, dict):
      raise ValueError(
          f"Expected a JSON object, got {type(payload).__name__}")
    if self._use_auth_token:
      self._check_token(payload.get("token"))
    command = payload["command"]
    args = payload.get("args", None)
    if command == "run":
      return await self._run_command(args)
    if command == "stop":
      return await self._stop_command()
    if command == "status":
      return await self._status_command()
    logging.error("Unknown command: %s", command)
    return None

  def _check_token(self, payload_token: Any) -> None:
    """Raises AuthenticationError unless *payload_token* equals the token."""
    # compare_digest avoids leaking the token through timing differences.
    # It rejects non-ASCII str arguments, so compare the UTF-8 encodings.
    if isinstance(payload_token, str) and secrets.compare_digest(
        payload_token.encode("utf-8"), self._token.encode("utf-8")):
      return
    logging.error("Invalid request token: %s", payload_token)
    raise AuthenticationError("Invalid Token")

  async def _stop_command(self) -> Tuple[Response, str]:
    """Kills a running crossbench subprocess (if any) and reports status."""
    if self._crossbench_process:
      logging.info("# CROSSBENCH COMMAND: KILL")
      helper.wait_and_kill(self._crossbench_process)
    # "stop" is valid both during a run and while idle.
    self._state.transition(State.RUNNING, State.CONNECTED, to=State.CONNECTED)
    return await self._status_command()

  async def _run_command(self, args) -> Tuple[Response, str]:
    """Launches crossbench for *args* and reports the new status.

    Output is forwarded by a background task started here.
    """
    self._state.transition(State.CONNECTED, to=State.RUNNING)
    try:
      await self._start_crossbench(args or {})
    except BaseException:
      # Don't leave the proxy stuck in RUNNING when the launch fails.
      self._state.transition(State.RUNNING, to=State.CONNECTED)
      raise
    self._crossbench_task = asyncio.create_task(self._wait_for_crossbench())
    return await self._status_command()

  async def _start_crossbench(self, args: Dict[str, Any]) -> None:
    """Builds the cb.py command line for *args* and spawns the subprocess.

    args["cmd"] selects the mode: "--help" and "describe probes" run the
    corresponding informational commands; anything else runs "load" on the
    recording in args["json"], with args["cmd"] appended as extra flags.
    """
    assert self._crossbench_process is None
    cb_path = CROSSBENCH_ROOT / "cb.py"
    assert cb_path.exists(), f"{cb_path} does not exist."
    cmd_str = args.get("cmd")
    cmd: ListCmdArgs
    if cmd_str == "--help":
      cmd = ["load", "--help"]
      self._print_cmd_output = False
    elif cmd_str == "describe probes":
      cmd = ["describe", "probes"]
      self._print_cmd_output = False
    else:
      self._print_cmd_output = True
      with self._tmp_json.open("w", encoding="utf-8") as f:
        json.dump(args["json"], f)
      assert self._tmp_json.exists(), f"{self._tmp_json} does not exist."
      cmd = [
          "load", "--env-validation=warn", "--verbose",
          f"--devtools-recorder={self._tmp_json.absolute()}",
          # shlex.split(None) reads from stdin before Python 3.12, which
          # would block the server; treat a missing "cmd" as no extra flags.
          *shlex.split(cmd_str or "")
      ]
    logging.info("CROSSBENCH COMMAND: %s", cmd)
    # Unbuffered child output lets it be streamed promptly; set it only for
    # the child instead of mutating this process's environment.
    env = dict(os.environ, PYTHONUNBUFFERED="1")
    self._crossbench_process = await asyncio.create_subprocess_exec(
        cb_path.absolute(),
        *cmd,
        env=env,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE)

  async def _send_output(
      self, stdout_str: Optional[str],
      stderr_str: Optional[str]) -> Optional[Tuple[Response, Dict[str, str]]]:
    """Wraps subprocess output as a Response.OUTPUT result.

    Returns None when no run is active. Output is also echoed locally when
    _print_cmd_output is set.
    """
    if self._state != State.RUNNING:
      return None
    if self._print_cmd_output:
      sys.stdout.write(stdout_str or "")
      sys.stderr.write(stderr_str or "")
    return Response.OUTPUT, {
        "stdout": stdout_str or "",
        "stderr": stderr_str or "",
    }

  async def _wait_for_crossbench(self) -> None:
    """Forwards the subprocess output until it exits, then resets state."""
    process = self._crossbench_process
    assert process
    try:
      stdout_sender = asyncio.create_task(
          self._send_stdout_incrementally(process.stdout))
      stderr_sender = asyncio.create_task(
          self._send_stderr_incrementally(process.stderr))
      # TODO: Figure out why waiting on sending the output hangs when the
      # crossbench subprocess ends with exit!=0
      await asyncio.wait([stdout_sender, stderr_sender])
      stdout, stderr = await process.communicate()
      await self._send_message(
          self._send_output(
              stdout.decode("utf-8", errors="replace"),
              stderr.decode("utf-8", errors="replace")))
    finally:
      # Always release the run, even if forwarding failed. _stop_command may
      # already have moved the state back to CONNECTED.
      self._state.transition(
          State.RUNNING, State.CONNECTED, to=State.CONNECTED)
      self._crossbench_task = None
      self._crossbench_process = None
    await self._send_message(self._status_command())
    logging.info("# CROSSBENCH COMMAND DONE: returncode=%s", process.returncode)

  async def _send_stream_incrementally(self, stream, is_stderr: bool) -> None:
    """Forwards *stream* to the client chunk by chunk until EOF.

    An incremental decoder keeps multi-byte UTF-8 sequences that straddle a
    chunk boundary intact instead of failing with UnicodeDecodeError.
    """
    decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
    while self._crossbench_process:
      data = await stream.read(self._OUTPUT_BUFFER_SIZE)
      # An empty read means EOF: flush any buffered partial sequence.
      text = decoder.decode(data, final=not data)
      if text:
        if is_stderr:
          await self._send_message(self._send_output(None, text))
        else:
          await self._send_message(self._send_output(text, None))
      if not data:
        return

  async def _send_stdout_incrementally(self, stdout) -> None:
    """Forwards the subprocess stdout to the client until EOF."""
    await self._send_stream_incrementally(stdout, is_stderr=False)

  async def _send_stderr_incrementally(self, stderr) -> None:
    """Forwards the subprocess stderr to the client until EOF."""
    await self._send_stream_incrementally(stderr, is_stderr=True)

  async def _status_command(self) -> Tuple[Response, str]:
    """Returns the current state name as a Response.STATUS result."""
    return Response.STATUS, str(self._state.name)
259