• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2023 The Chromium Authors
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5from __future__ import annotations
6
7import contextlib
8import email.parser
9import http.server
10import json
11import logging
12import os
13import threading
14from typing import (TYPE_CHECKING, Final, Iterator, Mapping, Optional, Tuple,
15                    Type)
16
17from immutabledict import immutabledict
18
19from crossbench import plt
20from crossbench.network.base import Network
21from crossbench.parse import ObjectParser
22
23if TYPE_CHECKING:
24  from crossbench.network.traffic_shaping.base import TrafficShaper
25  from crossbench.path import LocalPath
26  from crossbench.runner.groups.session import BrowserSessionRunGroup
27
DEFAULT_HOST = "localhost"
DEFAULT_PORT = 8000
# Lower-cased names of headers that the default HTTPServer already serves.
# Extra headers using one of these names might accidentally override them.
CONFLICTING_EXTRA_HEADERS: Final[frozenset[str]] = frozenset(
    header.lower() for header in (
        "Content-Type",
        "Content-Length",
        "Last-Modified",
        "Server",
        "Date",
        "Connection",
        "Location",
    ))
36
37
class CustomHeadersRequestHandler(http.server.SimpleHTTPRequestHandler):
  """Static-file request handler that appends fixed extra headers."""

  @classmethod
  def bind(
      cls: Type[CustomHeadersRequestHandler],
      server_dir: LocalPath,
      extra_headers: Mapping[str, str],
  ) -> Type[http.server.SimpleHTTPRequestHandler]:
    """Returns a handler subclass with directory and headers pre-bound.

    Args:
      server_dir: Directory whose files are served.
      extra_headers: Headers added to every response of the handler.
    """

    # A throw-away subclass captures the arguments in its constructor.
    class BoundDirectoryRequestHandler(cls):

      def __init__(self, *args, **kwargs):
        bound_directory = os.fspath(server_dir)
        super().__init__(
            *args,
            directory=bound_directory,
            extra_headers=extra_headers,
            **kwargs)

    return BoundDirectoryRequestHandler

  def __init__(self,
               *args,
               directory: Optional[str] = None,
               extra_headers: Optional[Mapping[str, str]] = None,
               **kwargs):
    """Stores an immutable copy of the extra headers, then sets up the base."""
    if extra_headers:
      frozen_headers = immutabledict(extra_headers)
    else:
      frozen_headers = immutabledict()
    # The headers have to be assigned before the base constructor runs, since
    # socketserver.BaseRequestHandler handles the request from its __init__.
    self._extra_headers: immutabledict[str, str] = frozen_headers
    super().__init__(*args, directory=directory, **kwargs)

  def end_headers(self):
    """Emits the extra headers (if any) before finishing the header block."""
    if self._extra_headers:
      self._send_custom_headers()
    super().end_headers()

  def _send_custom_headers(self):
    """Sends each configured extra header via send_header()."""
    for header_name, header_value in self._extra_headers.items():
      self.send_header(header_name, header_value)
75
76
class LocalFileNetwork(Network):
  """Network that serves a local directory through a local HTTP server.

  Optional extra response headers are read from a "HEADERS" or "HEADERS.txt"
  file inside the served directory.
  """

  def __init__(self,
               path: LocalPath,
               url: Optional[str],
               traffic_shaper: Optional[TrafficShaper] = None,
               browser_platform: plt.Platform = plt.PLATFORM):
    """Sets up the network without starting the server.

    Args:
      path: Directory served by the local HTTP server.
      url: Optional URL providing host and port; missing parts fall back to
        DEFAULT_HOST / DEFAULT_PORT.
      traffic_shaper: Forwarded to the Network base class.
      browser_platform: Forwarded to the Network base class.
    """
    super().__init__(traffic_shaper, browser_platform)
    self._path = path
    self._host, self._port = self._parse_url(url)
    # TODO: support custom headers via command line
    self._extra_headers: immutabledict[str, str] = self._try_parse_headers()
    if self._extra_headers:
      self._validate_extra_headers()

  @property
  def is_local_file_server(self) -> bool:
    return True

  @property
  def path(self) -> LocalPath:
    """The directory served by the local HTTP server."""
    return self._path

  def _parse_url(self, url: Optional[str]) -> Tuple[str, int]:
    """Returns (host, port) from url, using the defaults for missing parts."""
    host: str = DEFAULT_HOST
    port: int = DEFAULT_PORT
    if not url:
      return host, port
    parsed_url = ObjectParser.url(url)
    if parsed_url.hostname:
      host = parsed_url.hostname
    # Compare against None so that an explicit port 0 is preserved.
    if parsed_url.port is not None:
      port = parsed_url.port
    return host, port

  def _try_parse_headers(self) -> immutabledict[str, str]:
    """Returns the headers of the first existing headers file, or no headers.

    The candidates "HEADERS" and "HEADERS.txt" are looked up in that order
    inside the served directory.
    """
    for name in ("HEADERS", "HEADERS.txt"):
      header_file = self._path / name
      if header_file.exists():
        return self._read_headers_file(header_file)
    return immutabledict()

  def _read_headers_file(self,
                         header_file: LocalPath) -> immutabledict[str, str]:
    """Parses header_file and returns its headers as an immutable mapping."""
    with header_file.open("rb") as f:
      # Reuse python's email message library to parse headers
      message = email.parser.BytesParser().parsebytes(f.read())
      return immutabledict(message)

  def _validate_extra_headers(self):
    """Logs an error for every extra header that clashes with a default one.

    The comparison is case-insensitive. Conflicts are only reported, the
    headers are kept as they are.
    """
    for key, value in self._extra_headers.items():
      if key.lower() in CONFLICTING_EXTRA_HEADERS:
        logging.error(
            "BROWSER Network: Extra header overrides server defaults: '%s: %s'",
            key, value)

  @contextlib.contextmanager
  def open(self, session: BrowserSessionRunGroup) -> Iterator[Network]:
    """Starts server, traffic shaper and port forwarding for the session."""
    with super().open(session):
      with self._open_local_file_server():
        # TODO: properly hook up traffic shaper for the local http server
        with self._traffic_shaper.open(self, session):
          with self._forward_ports(session):
            yield self

  @contextlib.contextmanager
  def _open_local_file_server(self) -> Iterator[None]:
    """Runs the local HTTP file server for the duration of the context."""
    # TODO: write request log file to session results folder.
    # TODO: support  https server using SSLContext.wrap_socket(httpd.socket)
    request_handler_cls = CustomHeadersRequestHandler.bind(
        self._path, self._extra_headers)
    server = http.server.ThreadingHTTPServer((self._host, self._port),
                                             request_handler_cls)
    with self._server_thread(server):
      logging.info("%s custom host=%s, port=%s",
                   type(self).__name__, self.host, self.http_port)
      yield

  @contextlib.contextmanager
  def _server_thread(self, server: http.server.HTTPServer) -> Iterator[None]:
    """Serves requests on a daemon thread until the context exits."""
    # Using the server as context manager closes its socket on exit.
    with server:
      server_thread = threading.Thread(target=server.serve_forever)
      server_thread.daemon = True
      server_thread.start()
      # Remember the port reported by the server, which is the bound port.
      self._port = server.server_port
      try:
        yield
      finally:
        # Stop serve_forever() first, then wait for the thread to finish.
        server.shutdown()
        server_thread.join()

  @contextlib.contextmanager
  def _forward_ports(self, session: BrowserSessionRunGroup) -> Iterator[None]:
    """Reverse-forwards the server port to a remote browser platform.

    Nothing is forwarded for a non-remote browser platform. The forwarding is
    stopped again on exit, also if the body of the context raised.
    """
    browser_platform = session.browser_platform
    if not browser_platform.is_remote:
      yield
      return
    logging.info("REMOTE PORT FORWARDING: %s <= %s", self.host_platform,
                 browser_platform)
    # TODO: create port-forwarder service that is shut down properly.
    # TODO: make ports configurable
    browser_platform.reverse_port_forward(self._port, self._port)
    try:
      yield
    finally:
      # Without the finally an exception thrown into the generator at the
      # yield would skip the cleanup and leave the forwarding active.
      browser_platform.stop_reverse_port_forward(self._port)

  @property
  def http_port(self) -> Optional[int]:
    return self._port

  @property
  def https_port(self) -> Optional[int]:
    # TODO: support https locally
    return None

  @property
  def host(self) -> Optional[str]:
    return self._host

  def __str__(self) -> str:
    """Returns a short description with path, speed and extra headers."""
    extra_headers_str = ""
    if self._extra_headers:
      formatted_headers = json.dumps(dict(self._extra_headers))
      extra_headers_str = f" extra_headers={formatted_headers}"
    # Both fragments need the f-prefix, otherwise "{self._path}" is printed
    # literally instead of the served path.
    return (f"LOCAL(path={self._path}, "
            f"speed={self.traffic_shaper}{extra_headers_str})")
201