• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2024 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
#include "net/test/embedded_test_server/create_websocket_handler.h"

#include <memory>
#include <string>
#include <string_view>
#include <utility>

#include "base/base64.h"
#include "base/functional/bind.h"
#include "base/memory/scoped_refptr.h"
#include "base/ranges/algorithm.h"
#include "base/strings/string_split.h"
#include "base/strings/string_util.h"
#include "base/test/bind.h"
#include "base/time/time.h"
#include "base/types/expected.h"
#include "net/base/host_port_pair.h"
#include "net/base/url_util.h"
#include "net/http/http_status_code.h"
#include "net/test/embedded_test_server/embedded_test_server.h"
#include "net/test/embedded_test_server/http_request.h"
#include "net/test/embedded_test_server/http_response.h"
#include "net/test/embedded_test_server/websocket_connection.h"
21 
22 namespace net::test_server {
23 
24 namespace {
25 
// Returns |url| with its query component removed: everything from the first
// '?' onward is dropped. If |url| contains no '?', it is returned unchanged.
// The result is a view into the same buffer as |url|.
std::string_view StripQuery(std::string_view url) {
  // With no '?', find() yields npos; substr() clamps that count to the
  // remaining length, so the whole of |url| is returned.
  return url.substr(0, url.find('?'));
}
33 
MakeErrorResponse(HttpStatusCode code,std::string_view content)34 std::unique_ptr<HttpResponse> MakeErrorResponse(HttpStatusCode code,
35                                                 std::string_view content) {
36   auto error_response = std::make_unique<BasicHttpResponse>();
37   error_response->set_code(code);
38   error_response->set_content(content);
39   DVLOG(3) << "Error response created. Code: " << static_cast<int>(code)
40            << ", Content: " << content;
41   return error_response;
42 }
43 
HandleWebSocketUpgrade(std::string_view handle_path,WebSocketHandlerCreator websocket_handler_creator,EmbeddedTestServer * server,const HttpRequest & request,HttpConnection * connection)44 EmbeddedTestServer::UpgradeResultOrHttpResponse HandleWebSocketUpgrade(
45     std::string_view handle_path,
46     WebSocketHandlerCreator websocket_handler_creator,
47     EmbeddedTestServer* server,
48     const HttpRequest& request,
49     HttpConnection* connection) {
50   DVLOG(3) << "Handling WebSocket upgrade for path: " << handle_path;
51 
52   std::string_view request_path = StripQuery(request.relative_url);
53 
54   if (request_path != handle_path) {
55     return UpgradeResult::kNotHandled;
56   }
57 
58   if (request.method != METHOD_GET) {
59     return base::unexpected(
60         MakeErrorResponse(HttpStatusCode::HTTP_BAD_REQUEST,
61                           "Invalid request method. Expected GET."));
62   }
63 
64   // TODO(crbug.com/40812029): Check that the HTTP version is 1.1
65   // See https://datatracker.ietf.org/doc/html/rfc6455#section-4.2.1
66 
67   auto host_header = request.headers.find("Host");
68   if (host_header == request.headers.end()) {
69     DVLOG(1) << "Host header is missing.";
70     return base::unexpected(MakeErrorResponse(HttpStatusCode::HTTP_BAD_REQUEST,
71                                               "Host header is missing."));
72   }
73 
74   HostPortPair host_port = HostPortPair::FromString(host_header->second);
75   if (!IsCanonicalizedHostCompliant(host_port.host())) {
76     DVLOG(1) << "Host header is invalid: " << host_port.host();
77     return base::unexpected(MakeErrorResponse(HttpStatusCode::HTTP_BAD_REQUEST,
78                                               "Host header is invalid."));
79   }
80 
81   auto upgrade_header = request.headers.find("Upgrade");
82   if (upgrade_header == request.headers.end() ||
83       !base::EqualsCaseInsensitiveASCII(upgrade_header->second, "websocket")) {
84     DVLOG(1) << "Upgrade header is missing or invalid: "
85              << upgrade_header->second;
86     return base::unexpected(
87         MakeErrorResponse(HttpStatusCode::HTTP_BAD_REQUEST,
88                           "Upgrade header is missing or invalid."));
89   }
90 
91   auto connection_header = request.headers.find("Connection");
92   if (connection_header == request.headers.end()) {
93     DVLOG(1) << "Connection header is missing.";
94     return base::unexpected(MakeErrorResponse(HttpStatusCode::HTTP_BAD_REQUEST,
95                                               "Connection header is missing."));
96   }
97 
98   auto tokens =
99       base::SplitStringPiece(connection_header->second, ",",
100                              base::TRIM_WHITESPACE, base::SPLIT_WANT_NONEMPTY);
101   if (!base::ranges::any_of(tokens, [](std::string_view token) {
102         return base::EqualsCaseInsensitiveASCII(token, "Upgrade");
103       })) {
104     DVLOG(1) << "Connection header does not contain 'Upgrade'. Tokens: "
105              << connection_header->second;
106     return base::unexpected(
107         MakeErrorResponse(HttpStatusCode::HTTP_BAD_REQUEST,
108                           "Connection header does not contain 'Upgrade'."));
109   }
110 
111   auto websocket_version_header = request.headers.find("Sec-WebSocket-Version");
112   if (websocket_version_header == request.headers.end() ||
113       websocket_version_header->second != "13") {
114     DVLOG(1) << "Invalid or missing Sec-WebSocket-Version: "
115              << websocket_version_header->second;
116     return base::unexpected(MakeErrorResponse(
117         HttpStatusCode::HTTP_BAD_REQUEST, "Sec-WebSocket-Version must be 13."));
118   }
119 
120   auto sec_websocket_key_iter = request.headers.find("Sec-WebSocket-Key");
121   if (sec_websocket_key_iter == request.headers.end()) {
122     DVLOG(1) << "Sec-WebSocket-Key header is missing.";
123     return base::unexpected(
124         MakeErrorResponse(HttpStatusCode::HTTP_BAD_REQUEST,
125                           "Sec-WebSocket-Key header is missing."));
126   }
127 
128   auto decoded = base::Base64Decode(sec_websocket_key_iter->second);
129   if (!decoded || decoded->size() != 16) {
130     DVLOG(1) << "Sec-WebSocket-Key is invalid or has incorrect length.";
131     return base::unexpected(MakeErrorResponse(
132         HttpStatusCode::HTTP_BAD_REQUEST,
133         "Sec-WebSocket-Key is invalid or has incorrect length."));
134   }
135 
136   std::unique_ptr<StreamSocket> socket = connection->TakeSocket();
137   CHECK(socket);
138 
139   auto websocket_connection = base::MakeRefCounted<WebSocketConnection>(
140       std::move(socket), sec_websocket_key_iter->second, server);
141 
142   auto handler = websocket_handler_creator.Run(websocket_connection);
143   handler->OnHandshake(request);
144   websocket_connection->SetHandler(std::move(handler));
145   websocket_connection->SendHandshakeResponse();
146   return UpgradeResult::kUpgraded;
147 }
148 
149 }  // namespace
150 
CreateWebSocketHandler(std::string_view handle_path,WebSocketHandlerCreator websocket_handler_creator,EmbeddedTestServer * server)151 EmbeddedTestServer::HandleUpgradeRequestCallback CreateWebSocketHandler(
152     std::string_view handle_path,
153     WebSocketHandlerCreator websocket_handler_creator,
154     EmbeddedTestServer* server) {
155   // Note: The callback registered in ControllableHttpResponse will not be
156   // called after the server has been destroyed. This guarantees that the
157   // EmbeddedTestServer pointer remains valid for the lifetime of the
158   // ControllableHttpResponse instance.
159   return base::BindRepeating(&HandleWebSocketUpgrade, handle_path,
160                              websocket_handler_creator, server);
161 }
162 
163 }  // namespace net::test_server
164