# Copyright 2023 The Chromium Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Local static-file network for crossbench.

Serves a local directory over a threaded `http.server` and optionally attaches
extra response headers read from a HEADERS / HEADERS.txt file in that
directory.
"""

from __future__ import annotations

import contextlib
import email.parser
import http.server
import json
import logging
import os
import threading
from typing import (TYPE_CHECKING, Final, Iterator, Mapping, Optional, Tuple,
                    Type)

from immutabledict import immutabledict

from crossbench import plt
from crossbench.network.base import Network
from crossbench.parse import ObjectParser

if TYPE_CHECKING:
  from crossbench.network.traffic_shaping.base import TrafficShaper
  from crossbench.path import LocalPath
  from crossbench.runner.groups.session import BrowserSessionRunGroup

DEFAULT_HOST = "localhost"
DEFAULT_PORT = 8000
# List of known headers that are served by the default HTTPServer and might
# be accidentally overridden by provided extra headers.
# Stored lower-cased so the check in _validate_extra_headers is
# case-insensitive, matching HTTP header-name semantics.
CONFLICTING_EXTRA_HEADERS: Final[frozenset[str]] = frozenset(
    header.lower()
    for header in ("Content-Type", "Content-Length", "Last-Modified", "Server",
                   "Date", "Connection", "Location"))


class CustomHeadersRequestHandler(http.server.SimpleHTTPRequestHandler):
  """Static-file request handler that appends a fixed set of extra headers.

  The extra headers are emitted from end_headers(), i.e. right before the
  header block of a response is terminated.
  """

  @classmethod
  def bind(
      cls: Type[CustomHeadersRequestHandler],
      server_dir: LocalPath,
      extra_headers: Mapping[str, str],
  ) -> Type[http.server.SimpleHTTPRequestHandler]:
    """Return a handler subclass with `server_dir` and `extra_headers` baked in.

    The HTTP server is given a handler *class* (not an instance) and creates
    the instances itself, so per-server configuration has to be bound into a
    class.

    Args:
      server_dir: Directory served as the document root.
      extra_headers: Headers added to every response's header block.

    Returns:
      A subclass of `cls` whose constructor supplies both values.
    """

    # Use a temporary class to bind arguments.
    class BoundDirectoryRequestHandler(cls):

      def __init__(self, *args, **kwargs):
        super().__init__(
            *args,
            directory=os.fspath(server_dir),
            extra_headers=extra_headers,
            **kwargs)

    return BoundDirectoryRequestHandler

  def __init__(self,
               *args,
               directory: Optional[str] = None,
               extra_headers: Optional[Mapping[str, str]] = None,
               **kwargs):
    # Must be assigned before super().__init__(): the base request handler
    # serves the request from inside its constructor, which already ends up
    # calling end_headers().
    self._extra_headers: immutabledict[str, str] = (
        immutabledict(extra_headers) if extra_headers else immutabledict())
    super().__init__(*args, directory=directory, **kwargs)

  def end_headers(self):
    """Emit the configured extra headers, then terminate the header block."""
    if self._extra_headers:
      self._send_custom_headers()
    super().end_headers()

  def _send_custom_headers(self):
    """Send every configured extra header via send_header()."""
    for key, value in self._extra_headers.items():
      self.send_header(key, value)


class LocalFileNetwork(Network):
  """Network that serves a local directory through an in-process HTTP server.

  If the served directory contains a HEADERS or HEADERS.txt file, the headers
  listed there are added to every response.
  """

  def __init__(self,
               path: LocalPath,
               url: Optional[str],
               traffic_shaper: Optional[TrafficShaper] = None,
               browser_platform: plt.Platform = plt.PLATFORM):
    """
    Args:
      path: Directory to serve.
      url: Optional URL whose host/port override DEFAULT_HOST/DEFAULT_PORT.
      traffic_shaper: Optional traffic shaper, forwarded to the base class.
      browser_platform: Platform the browser runs on.
    """
    super().__init__(traffic_shaper, browser_platform)
    self._path = path
    self._host, self._port = self._parse_url(url)
    # TODO: support custom headers via command line
    self._extra_headers: immutabledict[str, str] = self._try_parse_headers()
    if self._extra_headers:
      self._validate_extra_headers()

  @property
  def is_local_file_server(self) -> bool:
    return True

  @property
  def path(self) -> LocalPath:
    return self._path

  def _parse_url(self, url: Optional[str]) -> Tuple[str, int]:
    """Return (host, port) from `url`, falling back to the module defaults
    for any part that is missing (or when `url` is empty/None)."""
    host: str = DEFAULT_HOST
    port: int = DEFAULT_PORT
    if not url:
      return host, port
    parsed_url = ObjectParser.url(url)
    if parsed_url.hostname:
      host = parsed_url.hostname
    if parsed_url.port is not None:
      port = parsed_url.port
    return host, port

  def _try_parse_headers(self) -> immutabledict[str, str]:
    """Load extra headers from the first existing HEADERS / HEADERS.txt file.

    Returns an empty mapping when neither file is present.
    """
    for name in ("HEADERS", "HEADERS.txt"):
      header_file = self._path / name
      # is_file() rather than exists(): a directory with one of these names
      # would otherwise make open() fail with IsADirectoryError.
      if header_file.is_file():
        return self._read_headers_file(header_file)
    return immutabledict()

  def _read_headers_file(self,
                         header_file: LocalPath) -> immutabledict[str, str]:
    """Parse `Name: value` header lines from `header_file`.

    Parsing stops at the first blank line (the email parser treats the rest
    as a message body). For a header name that appears more than once only
    a single value is kept.
    """
    with header_file.open("rb") as f:
      # Reuse python's email message library to parse headers
      message = email.parser.BytesParser().parsebytes(f.read())
    return immutabledict(message)

  def _validate_extra_headers(self):
    """Log an error for each extra header that would override a header the
    default HTTP server already sends (see CONFLICTING_EXTRA_HEADERS)."""
    for key, value in self._extra_headers.items():
      if key.lower() in CONFLICTING_EXTRA_HEADERS:
        logging.error(
            "BROWSER Network: Extra header overrides server defaults: '%s: %s'",
            key, value)

  @contextlib.contextmanager
  def open(self, session: BrowserSessionRunGroup) -> Iterator[Network]:
    """Start the file server, traffic shaper and port forwarding for
    `session`, yielding this network while all of them are active."""
    with super().open(session):
      with self._open_local_file_server():
        # TODO: properly hook up traffic shaper for the local http server
        with self._traffic_shaper.open(self, session):
          with self._forward_ports(session):
            yield self

  @contextlib.contextmanager
  def _open_local_file_server(self) -> Iterator[None]:
    """Run a threaded HTTP server for self._path for the duration of the
    context."""
    # TODO: write request log file to session results folder.
    # TODO: support https server using SSLContext.wrap_socket(httpd.socket)
    request_handler_cls = CustomHeadersRequestHandler.bind(
        self._path, self._extra_headers)
    server = http.server.ThreadingHTTPServer((self._host, self._port),
                                             request_handler_cls)
    with self._server_thread(server):
      logging.info("%s custom host=%s, port=%s",
                   type(self).__name__, self.host, self.http_port)
      yield

  @contextlib.contextmanager
  def _server_thread(self, server: http.server.HTTPServer) -> Iterator[None]:
    """Serve `server` on a daemon thread; shut it down and join on exit."""
    with server:
      server_thread = threading.Thread(target=server.serve_forever)
      server_thread.daemon = True
      server_thread.start()
      # Record the port the server actually bound to (differs from the
      # requested one if port 0 was requested).
      self._port = server.server_port
      try:
        yield
      finally:
        server.shutdown()
        server_thread.join()

  @contextlib.contextmanager
  def _forward_ports(self, session: BrowserSessionRunGroup) -> Iterator[None]:
    """Reverse-forward the server port to a remote browser platform.

    Does nothing for a non-remote browser platform. The forward is removed
    even if the body of the context raises.
    """
    browser_platform = session.browser_platform
    if not browser_platform.is_remote:
      yield
      return
    logging.info("REMOTE PORT FORWARDING: %s <= %s", self.host_platform,
                 browser_platform)
    # TODO: create port-forwarder service that is shut down properly.
    # TODO: make ports configurable
    browser_platform.reverse_port_forward(self._port, self._port)
    try:
      yield
    finally:
      # Previously skipped when the body raised, leaking the forward.
      browser_platform.stop_reverse_port_forward(self._port)

  @property
  def http_port(self) -> Optional[int]:
    return self._port

  @property
  def https_port(self) -> Optional[int]:
    # TODO: support https locally
    return None

  @property
  def host(self) -> Optional[str]:
    return self._host

  def __str__(self) -> str:
    extra_headers_str = ""
    if self._extra_headers:
      formatted_headers = json.dumps(dict(self._extra_headers))
      extra_headers_str = f" extra_headers={formatted_headers}"
    # Both literals must be f-strings; the first one previously lacked the
    # prefix and printed "{self._path}" verbatim.
    return (f"LOCAL(path={self._path}, "
            f"speed={self.traffic_shaper}{extra_headers_str})")