1# Copyright 2023 The Chromium Authors 2# Use of this source code is governed by a BSD-style license that can be 3# found in the LICENSE file. 4 5from __future__ import annotations 6 7import argparse 8import asyncio 9import enum 10import json 11import logging 12import os 13import secrets 14import shlex 15import sys 16import tempfile 17from typing import TYPE_CHECKING, Any, Coroutine, Dict, Optional, Tuple 18 19import websockets 20from websockets.server import WebSocketServerProtocol 21 22from crossbench import compat, helper 23from crossbench import path as pth 24from crossbench.helper.state import BaseState, StateMachine 25 26if TYPE_CHECKING: 27 from crossbench.plt.base import ListCmdArgs 28 from crossbench.types import JsonDict 29 30 31CROSSBENCH_ROOT: pth.LocalPath = pth.LocalPath(__file__).parents[4] 32 33 34@enum.unique 35class State(BaseState): 36 CONNECTED = enum.auto() 37 RUNNING = enum.auto() 38 39 40@enum.unique 41class Response(compat.StrEnum): 42 STATUS = "status" 43 OUTPUT = "output" 44 45 46class AuthenticationError(ValueError): 47 pass 48 49 50class CrossbenchDevToolsRecorderProxy: 51 DEFAULT_PORT = 44645 52 53 @classmethod 54 def add_subcommand(cls, subparsers) -> argparse.ArgumentParser: 55 parser = subparsers.add_parser( 56 "devtools-recorder-proxy", 57 aliases=["devtools"], 58 help=("Starts a local server to communicate with the " 59 "DevTools Recorder extension.")) 60 parser.set_defaults(subcommand_fn=cls._subcommand) 61 parser.add_argument( 62 "--disable-token-authentication", 63 dest="use_auth_token", 64 default=True, 65 action="store_false", 66 help=("Disable token-based authentication. " 67 "Unsafe, only use for local development.")) 68 return parser 69 70 @classmethod 71 def _subcommand(cls, args: argparse.Namespace) -> None: 72 instance: CrossbenchDevToolsRecorderProxy = cls( 73 use_auth_token=args.use_auth_token) 74 instance.run() 75 76 _websocket: WebSocketServerProtocol 77 78 def __init__(self, use_auth_token: bool = True) -> None: 79 self._token: str = secrets.token_hex(16) 80 self._use_auth_token: bool = use_auth_token 81 self._print_cmd_output: bool = False 82 self._port: int = self.DEFAULT_PORT 83 self._state = StateMachine(State.CONNECTED) 84 self._crossbench_task: Optional[asyncio.Task] = None 85 self._crossbench_process = None 86 self._tmp_json = pth.LocalPath( 87 tempfile.mkdtemp("crossbench_proxy")) / "devtools_recorder.json" 88 89 def run(self) -> None: 90 asyncio.run(self.run_server()) 91 92 async def run_server(self) -> None: 93 try: 94 serve = websockets.serve(self.handler, "localhost", self.DEFAULT_PORT) 95 except Exception as e: # pylint: disable=broad-except 96 logging.exception(e) 97 serve = websockets.serve(self.handler, "localhost") 98 async with serve as server: 99 self._port = server.sockets[0].getsockname()[1] 100 logging.info("#" * 80) 101 logging.info("#" * 80) 102 logging.info("# Crossbench DevTools Recorder Replay Server Started") 103 logging.info("#") 104 if self._port != self.DEFAULT_PORT: 105 logging.warning("# Non-default port!") 106 logging.info("# PORT: %s", self._port) 107 if not self._use_auth_token: 108 logging.warning("# Token authentication has been disabled!") 109 logging.info("# TOKEN: %s", self._token) 110 logging.info("#") 111 logging.info("#" * 80) 112 logging.info("#" * 80) 113 await asyncio.Future() # run forever 114 115 async def handler(self, websocket: WebSocketServerProtocol) -> None: 116 self._websocket = websocket 117 async for message in websocket: 118 await self._send_message(self._handle_message(message)) 119 120 async def _send_message( 121 self, coroutine: Coroutine[Any, Any, Optional[Tuple[Response, 122 Any]]]) -> None: 123 response: JsonDict = {"success": False, "payload": None, "error": None} 124 try: 125 result: Optional[Tuple[Response, Any]] = await coroutine 126 response["success"] = True 127 if result: 128 response_type, payload = result 129 response["payload"] = payload 130 response["type"] = response_type.value 131 except Exception as e: # pylint: disable=broad-except 132 logging.exception(e) 133 response["error"] = str(type(e).__name__) 134 try: 135 response_json = json.dumps(response) 136 except Exception as e: # pylint: disable=broad-except 137 logging.exception(e) 138 response["success"] = False 139 response["error"] = "Failed to encode message" 140 response["payload"] = None 141 response_json = json.dumps(response) 142 logging.debug("SEND Response: %s", response_json) 143 await self._websocket.send(response_json) 144 145 async def _handle_message(self, message) -> Optional[Tuple[Response, Any]]: 146 logging.debug("RECEIVE Message: %s", message) 147 try: 148 payload: Dict[str, Any] = json.loads(message) 149 except json.JSONDecodeError as e: 150 logging.error("Could not parse JSON response: %s", e) 151 raise e 152 if self._use_auth_token: 153 payload_token = payload["token"] 154 if payload_token != self._token: 155 logging.error("Invalid request token: %s", payload_token) 156 raise AuthenticationError("Invalid Token") 157 command = payload["command"] 158 args = payload.get("args", None) 159 if command == "run": 160 return await self._run_command(args) 161 if command == "stop": 162 return await self._stop_command() 163 if command == "status": 164 return await self._status_command() 165 logging.error("Unknown command: %s", command) 166 return None 167 168 async def _stop_command(self) -> Tuple[Response, str]: 169 if self._crossbench_process: 170 logging.info("# CROSSBENCH COMMAND: KILL") 171 helper.wait_and_kill(self._crossbench_process) 172 self._state.transition(State.CONNECTED, State.CONNECTED, to=State.CONNECTED) 173 return await self._status_command() 174 175 async def _run_command(self, args) -> Tuple[Response, str]: 176 self._state.transition(State.CONNECTED, to=State.RUNNING) 177 assert self._crossbench_process is None 178 cb_path = CROSSBENCH_ROOT / "cb.py" 179 os.environ["PYTHONUNBUFFERED"] = "1" 180 cmd: ListCmdArgs = [] 181 if args.get("cmd") == "--help": 182 cmd = ["load", "--help"] 183 self._print_cmd_output = False 184 elif args.get("cmd") == "describe probes": 185 cmd = ["describe", "probes"] 186 self._print_cmd_output = False 187 else: 188 self._print_cmd_output = True 189 with self._tmp_json.open("w", encoding="utf-8") as f: 190 json.dump(args["json"], f) 191 assert self._tmp_json.exists(), f"{self._tmp_json} does not exist." 192 assert cb_path.exists(), f"{cb_path} does not exist." 193 cmd = [ 194 "load", "--env-validation=warn", "--verbose", 195 f"--devtools-recorder={self._tmp_json.absolute()}", 196 *shlex.split(args.get("cmd")) 197 ] 198 logging.info("CROSSBENCH COMMAND: %s", cmd) 199 self._crossbench_process = await asyncio.create_subprocess_exec( 200 cb_path.absolute(), 201 *cmd, 202 stdout=asyncio.subprocess.PIPE, 203 stderr=asyncio.subprocess.PIPE) 204 self._crossbench_task = asyncio.create_task(self._wait_for_crossbench()) 205 return await self._status_command() 206 207 async def _send_output( 208 self, stdout_str: Optional[str], 209 stderr_str: Optional[str]) -> Optional[Tuple[Response, Dict[str, str]]]: 210 if self._state != State.RUNNING: 211 return None 212 if self._print_cmd_output: 213 sys.stdout.write(stdout_str or "") 214 sys.stderr.write(stderr_str or "") 215 return Response.OUTPUT, { 216 "stdout": stdout_str or "", 217 "stderr": stderr_str or "", 218 } 219 220 async def _wait_for_crossbench(self) -> None: 221 assert self._crossbench_process 222 stdout_sender = asyncio.create_task( 223 self._send_stdout_incrementally(self._crossbench_process.stdout)) 224 stderr_sender = asyncio.create_task( 225 self._send_stderr_incrementally(self._crossbench_process.stderr)) 226 # TODO: Figure out why waiting on sending the output hangs when the 227 # crossbench subprocess ends with exit!=0 228 await asyncio.wait([stdout_sender, stderr_sender]) 229 stdout, stderr = await self._crossbench_process.communicate() 230 await self._send_message( 231 self._send_output(stdout.decode("utf-8"), stderr.decode("utf-8"))) 232 returncode = self._crossbench_process.returncode 233 self._state.transition(State.RUNNING, to=State.CONNECTED) 234 self._crossbench_task = None 235 self._crossbench_process = None 236 await self._send_message(self._status_command()) 237 logging.info("# CROSSBENCH COMMAND DONE: returncode=%s", returncode) 238 239 _OUTPUT_BUFFER_SIZE = 128 240 241 async def _send_stdout_incrementally(self, stdout) -> None: 242 while self._crossbench_process: 243 stdout_data = await stdout.read(self._OUTPUT_BUFFER_SIZE) 244 if not stdout_data: 245 return 246 stdout_str = stdout_data.decode("utf-8") 247 await self._send_message(self._send_output(stdout_str, None)) 248 249 async def _send_stderr_incrementally(self, stderr) -> None: 250 while self._crossbench_process: 251 stderr_data = await stderr.read(self._OUTPUT_BUFFER_SIZE) 252 if not stderr_data: 253 return 254 stderr_str = stderr_data.decode("utf-8") 255 await self._send_message(self._send_output(None, stderr_str)) 256 257 async def _status_command(self) -> Tuple[Response, str]: 258 return Response.STATUS, str(self._state.name) 259