1 // Copyright 2024 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
#include "net/test/embedded_test_server/create_websocket_handler.h"

#include <memory>
#include <string>
#include <string_view>
#include <utility>

#include "base/base64.h"
#include "base/functional/bind.h"
#include "base/memory/scoped_refptr.h"
#include "base/ranges/algorithm.h"
#include "base/strings/string_split.h"
#include "base/strings/string_util.h"
#include "base/test/bind.h"
#include "base/time/time.h"
#include "base/types/expected.h"
#include "net/base/host_port_pair.h"
#include "net/base/url_util.h"
#include "net/http/http_status_code.h"
#include "net/test/embedded_test_server/embedded_test_server.h"
#include "net/test/embedded_test_server/http_request.h"
#include "net/test/embedded_test_server/http_response.h"
#include "net/test/embedded_test_server/websocket_connection.h"
21
22 namespace net::test_server {
23
24 namespace {
25
26 // Helper function to strip the query part of the URL
// Returns the portion of |url| that precedes the first '?', i.e. the URL with
// its query component removed. If |url| contains no '?', it is returned
// unchanged. The result is a view into the same buffer as |url|.
std::string_view StripQuery(std::string_view url) {
  std::string_view path_only = url;
  if (const auto question_mark = path_only.find('?');
      question_mark != std::string_view::npos) {
    // Drop the '?' and everything after it.
    path_only.remove_suffix(path_only.size() - question_mark);
  }
  return path_only;
}
33
MakeErrorResponse(HttpStatusCode code,std::string_view content)34 std::unique_ptr<HttpResponse> MakeErrorResponse(HttpStatusCode code,
35 std::string_view content) {
36 auto error_response = std::make_unique<BasicHttpResponse>();
37 error_response->set_code(code);
38 error_response->set_content(content);
39 DVLOG(3) << "Error response created. Code: " << static_cast<int>(code)
40 << ", Content: " << content;
41 return error_response;
42 }
43
HandleWebSocketUpgrade(std::string_view handle_path,WebSocketHandlerCreator websocket_handler_creator,EmbeddedTestServer * server,const HttpRequest & request,HttpConnection * connection)44 EmbeddedTestServer::UpgradeResultOrHttpResponse HandleWebSocketUpgrade(
45 std::string_view handle_path,
46 WebSocketHandlerCreator websocket_handler_creator,
47 EmbeddedTestServer* server,
48 const HttpRequest& request,
49 HttpConnection* connection) {
50 DVLOG(3) << "Handling WebSocket upgrade for path: " << handle_path;
51
52 std::string_view request_path = StripQuery(request.relative_url);
53
54 if (request_path != handle_path) {
55 return UpgradeResult::kNotHandled;
56 }
57
58 if (request.method != METHOD_GET) {
59 return base::unexpected(
60 MakeErrorResponse(HttpStatusCode::HTTP_BAD_REQUEST,
61 "Invalid request method. Expected GET."));
62 }
63
64 // TODO(crbug.com/40812029): Check that the HTTP version is 1.1
65 // See https://datatracker.ietf.org/doc/html/rfc6455#section-4.2.1
66
67 auto host_header = request.headers.find("Host");
68 if (host_header == request.headers.end()) {
69 DVLOG(1) << "Host header is missing.";
70 return base::unexpected(MakeErrorResponse(HttpStatusCode::HTTP_BAD_REQUEST,
71 "Host header is missing."));
72 }
73
74 HostPortPair host_port = HostPortPair::FromString(host_header->second);
75 if (!IsCanonicalizedHostCompliant(host_port.host())) {
76 DVLOG(1) << "Host header is invalid: " << host_port.host();
77 return base::unexpected(MakeErrorResponse(HttpStatusCode::HTTP_BAD_REQUEST,
78 "Host header is invalid."));
79 }
80
81 auto upgrade_header = request.headers.find("Upgrade");
82 if (upgrade_header == request.headers.end() ||
83 !base::EqualsCaseInsensitiveASCII(upgrade_header->second, "websocket")) {
84 DVLOG(1) << "Upgrade header is missing or invalid: "
85 << upgrade_header->second;
86 return base::unexpected(
87 MakeErrorResponse(HttpStatusCode::HTTP_BAD_REQUEST,
88 "Upgrade header is missing or invalid."));
89 }
90
91 auto connection_header = request.headers.find("Connection");
92 if (connection_header == request.headers.end()) {
93 DVLOG(1) << "Connection header is missing.";
94 return base::unexpected(MakeErrorResponse(HttpStatusCode::HTTP_BAD_REQUEST,
95 "Connection header is missing."));
96 }
97
98 auto tokens =
99 base::SplitStringPiece(connection_header->second, ",",
100 base::TRIM_WHITESPACE, base::SPLIT_WANT_NONEMPTY);
101 if (!base::ranges::any_of(tokens, [](std::string_view token) {
102 return base::EqualsCaseInsensitiveASCII(token, "Upgrade");
103 })) {
104 DVLOG(1) << "Connection header does not contain 'Upgrade'. Tokens: "
105 << connection_header->second;
106 return base::unexpected(
107 MakeErrorResponse(HttpStatusCode::HTTP_BAD_REQUEST,
108 "Connection header does not contain 'Upgrade'."));
109 }
110
111 auto websocket_version_header = request.headers.find("Sec-WebSocket-Version");
112 if (websocket_version_header == request.headers.end() ||
113 websocket_version_header->second != "13") {
114 DVLOG(1) << "Invalid or missing Sec-WebSocket-Version: "
115 << websocket_version_header->second;
116 return base::unexpected(MakeErrorResponse(
117 HttpStatusCode::HTTP_BAD_REQUEST, "Sec-WebSocket-Version must be 13."));
118 }
119
120 auto sec_websocket_key_iter = request.headers.find("Sec-WebSocket-Key");
121 if (sec_websocket_key_iter == request.headers.end()) {
122 DVLOG(1) << "Sec-WebSocket-Key header is missing.";
123 return base::unexpected(
124 MakeErrorResponse(HttpStatusCode::HTTP_BAD_REQUEST,
125 "Sec-WebSocket-Key header is missing."));
126 }
127
128 auto decoded = base::Base64Decode(sec_websocket_key_iter->second);
129 if (!decoded || decoded->size() != 16) {
130 DVLOG(1) << "Sec-WebSocket-Key is invalid or has incorrect length.";
131 return base::unexpected(MakeErrorResponse(
132 HttpStatusCode::HTTP_BAD_REQUEST,
133 "Sec-WebSocket-Key is invalid or has incorrect length."));
134 }
135
136 std::unique_ptr<StreamSocket> socket = connection->TakeSocket();
137 CHECK(socket);
138
139 auto websocket_connection = base::MakeRefCounted<WebSocketConnection>(
140 std::move(socket), sec_websocket_key_iter->second, server);
141
142 auto handler = websocket_handler_creator.Run(websocket_connection);
143 handler->OnHandshake(request);
144 websocket_connection->SetHandler(std::move(handler));
145 websocket_connection->SendHandshakeResponse();
146 return UpgradeResult::kUpgraded;
147 }
148
149 } // namespace
150
CreateWebSocketHandler(std::string_view handle_path,WebSocketHandlerCreator websocket_handler_creator,EmbeddedTestServer * server)151 EmbeddedTestServer::HandleUpgradeRequestCallback CreateWebSocketHandler(
152 std::string_view handle_path,
153 WebSocketHandlerCreator websocket_handler_creator,
154 EmbeddedTestServer* server) {
155 // Note: The callback registered in ControllableHttpResponse will not be
156 // called after the server has been destroyed. This guarantees that the
157 // EmbeddedTestServer pointer remains valid for the lifetime of the
158 // ControllableHttpResponse instance.
159 return base::BindRepeating(&HandleWebSocketUpgrade, handle_path,
160 websocket_handler_creator, server);
161 }
162
163 } // namespace net::test_server
164