1 // Copyright 2024 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/test/embedded_test_server/websocket_connection.h"
6
7 #include <stdint.h>
8
9 #include "base/compiler_specific.h"
10 #include "base/containers/extend.h"
11 #include "base/containers/span.h"
12 #include "base/containers/span_writer.h"
13 #include "base/functional/bind.h"
14 #include "base/logging.h"
15 #include "base/memory/scoped_refptr.h"
16 #include "base/numerics/byte_conversions.h"
17 #include "base/numerics/safe_conversions.h"
18 #include "base/strings/strcat.h"
19 #include "net/base/net_errors.h"
20 #include "net/socket/socket.h"
21 #include "net/socket/stream_socket.h"
22 #include "net/test/embedded_test_server/websocket_handler.h"
23 #include "net/test/embedded_test_server/websocket_message_assembler.h"
24 #include "net/websockets/websocket_frame.h"
25 #include "net/websockets/websocket_frame_parser.h"
26 #include "net/websockets/websocket_handshake_challenge.h"
27
28 namespace net::test_server {
29
WebSocketConnection(std::unique_ptr<StreamSocket> socket,std::string_view sec_websocket_key,EmbeddedTestServer * server)30 WebSocketConnection::WebSocketConnection(std::unique_ptr<StreamSocket> socket,
31 std::string_view sec_websocket_key,
32 EmbeddedTestServer* server)
33 : stream_socket_(std::move(socket)),
34 // Register a shutdown closure to safely disconnect this connection when
35 // the
36 // server shuts down. base::Unretained is safe here because:
37 // 1. The shutdown closure is registered during the construction of the
38 // WebSocketConnection object, ensuring `this` is fully initialized.
39 // 2. The lifetime of the closure is tied to the `WebSocketConnection`
40 // object via `shutdown_subscription_`, which ensures that the closure
41 // is automatically unregistered when the object is destroyed.
42 // 3. DisconnectImmediately() ensures safe cleanup by resetting the socket
43 // and marking the connection state as closed.
44 shutdown_subscription_(server->RegisterShutdownClosure(
45 base::BindOnce(&WebSocketConnection::DisconnectImmediately,
46 base::Unretained(this)))) {
47 CHECK(stream_socket_);
48
49 response_headers_.emplace_back("Upgrade", "websocket");
50 response_headers_.emplace_back("Connection", "Upgrade");
51 response_headers_.emplace_back(
52 "Sec-WebSocket-Accept",
53 ComputeSecWebSocketAccept(std::string(sec_websocket_key)));
54 }
55
~WebSocketConnection()56 WebSocketConnection::~WebSocketConnection() {
57 DisconnectImmediately();
58 }
59
SetResponseHeader(std::string_view name,std::string_view value)60 void WebSocketConnection::SetResponseHeader(std::string_view name,
61 std::string_view value) {
62 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
63 CHECK(stream_socket_);
64 for (auto& header : response_headers_) {
65 if (header.first == name) {
66 header.second = value;
67 return;
68 }
69 }
70 response_headers_.emplace_back(name, value);
71 }
72
SendTextMessage(std::string_view message)73 void WebSocketConnection::SendTextMessage(std::string_view message) {
74 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
75 CHECK(stream_socket_);
76 CHECK(base::IsStringUTF8AllowingNoncharacters(message));
77 scoped_refptr<IOBufferWithSize> frame = CreateTextFrame(message);
78
79 SendInternal(std::move(frame), /*wait_for_handshake=*/true);
80 }
81
SendBinaryMessage(base::span<const uint8_t> message)82 void WebSocketConnection::SendBinaryMessage(base::span<const uint8_t> message) {
83 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
84 CHECK(stream_socket_);
85 scoped_refptr<IOBufferWithSize> frame = CreateBinaryFrame(message);
86 SendInternal(std::move(frame), /*wait_for_handshake=*/true);
87 }
88
StartClosingHandshake(std::optional<uint16_t> code,std::string_view message)89 void WebSocketConnection::StartClosingHandshake(std::optional<uint16_t> code,
90 std::string_view message) {
91 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
92 if (!stream_socket_) {
93 DVLOG(2) << "Attempted to start closing handshake, but socket is null.";
94 return;
95 }
96
97 DVLOG(3) << "Starting closing handshake. Code: "
98 << (code ? base::NumberToString(*code) : "none")
99 << ", Message: " << message;
100
101 if (!code) {
102 CHECK(base::IsStringUTF8AllowingNoncharacters(message));
103 SendInternal(BuildWebSocketFrame(base::span<const uint8_t>(),
104 WebSocketFrameHeader::kOpCodeClose),
105 /*wait_for_handshake=*/true);
106 state_ = WebSocketState::kWaitingForClientClose;
107 return;
108 }
109
110 scoped_refptr<IOBufferWithSize> close_frame = CreateCloseFrame(code, message);
111 SendInternal(std::move(close_frame), /*wait_for_handshake=*/true);
112
113 state_ = WebSocketState::kWaitingForClientClose;
114 }
115
RespondToCloseFrame(std::optional<uint16_t> code,std::string_view message)116 void WebSocketConnection::RespondToCloseFrame(std::optional<uint16_t> code,
117 std::string_view message) {
118 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
119 if (state_ == WebSocketState::kClosed) {
120 DVLOG(2) << "Attempted to respond to close frame, but connection is "
121 "already closed.";
122 return;
123 }
124
125 CHECK(base::IsStringUTF8AllowingNoncharacters(message));
126 scoped_refptr<IOBufferWithSize> close_frame = CreateCloseFrame(code, message);
127 SendInternal(std::move(close_frame), /*wait_for_handshake=*/false);
128 DisconnectAfterAnyWritesDone();
129 }
130
SendPing(base::span<const uint8_t> payload)131 void WebSocketConnection::SendPing(base::span<const uint8_t> payload) {
132 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
133 scoped_refptr<IOBufferWithSize> ping_frame = CreatePingFrame(payload);
134 SendInternal(std::move(ping_frame), /*wait_for_handshake=*/true);
135 }
136
SendPong(base::span<const uint8_t> payload)137 void WebSocketConnection::SendPong(base::span<const uint8_t> payload) {
138 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
139 scoped_refptr<IOBufferWithSize> pong_frame = CreatePongFrame(payload);
140 SendInternal(std::move(pong_frame), /*wait_for_handshake=*/true);
141 }
142
DisconnectAfterAnyWritesDone()143 void WebSocketConnection::DisconnectAfterAnyWritesDone() {
144 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
145 if (!stream_socket_) {
146 DVLOG(3) << "Socket is already disconnected.";
147 return;
148 }
149
150 if (!pending_buffer_) {
151 DisconnectImmediately();
152 return;
153 }
154
155 should_disconnect_after_write_ = true;
156 state_ = WebSocketState::kDisconnectingSoon;
157 handler_.reset();
158 }
159
DisconnectImmediately()160 void WebSocketConnection::DisconnectImmediately() {
161 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
162 if (!stream_socket_) {
163 DVLOG(3) << "Socket is already disconnected.";
164 handler_.reset();
165 return;
166 }
167
168 // Intentionally not calling Disconnect(), as it doesn't work with
169 // SSLServerSocket. Resetting the socket here is sufficient to disconnect.
170 ResetStreamSocket();
171 handler_.reset();
172 }
173
ResetStreamSocket()174 void WebSocketConnection::ResetStreamSocket() {
175 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
176 if (stream_socket_) {
177 stream_socket_.reset();
178 state_ = WebSocketState::kClosed;
179 }
180 // `this` may be deleted here.
181 }
182
SendRaw(base::span<const uint8_t> bytes)183 void WebSocketConnection::SendRaw(base::span<const uint8_t> bytes) {
184 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
185 scoped_refptr<IOBufferWithSize> buffer =
186 base::MakeRefCounted<IOBufferWithSize>(bytes.size());
187 buffer->span().copy_from(bytes);
188 SendInternal(std::move(buffer), /*wait_for_handshake=*/false);
189 }
190
SendInternal(scoped_refptr<IOBufferWithSize> buffer,bool wait_for_handshake)191 void WebSocketConnection::SendInternal(scoped_refptr<IOBufferWithSize> buffer,
192 bool wait_for_handshake) {
193 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
194 if ((wait_for_handshake && state_ != WebSocketState::kOpen) ||
195 pending_buffer_) {
196 pending_messages_.emplace(std::move(buffer));
197 return;
198 }
199
200 const size_t buffer_size = buffer->size();
201 pending_buffer_ =
202 base::MakeRefCounted<DrainableIOBuffer>(std::move(buffer), buffer_size);
203
204 PerformWrite();
205 }
206
SetHandler(std::unique_ptr<WebSocketHandler> handler)207 void WebSocketConnection::SetHandler(
208 std::unique_ptr<WebSocketHandler> handler) {
209 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
210 handler_ = std::move(handler);
211 }
212
PerformWrite()213 void WebSocketConnection::PerformWrite()
214 VALID_CONTEXT_REQUIRED(sequence_checker_) {
215 const int result = stream_socket_->Write(
216 pending_buffer_.get(), pending_buffer_->BytesRemaining(),
217 base::BindOnce(&WebSocketConnection::OnWriteComplete,
218 base::WrapRefCounted(this)),
219 DefineNetworkTrafficAnnotation(
220 "test", "Traffic annotation for unit, browser and other tests"));
221
222 if (result != ERR_IO_PENDING) {
223 OnWriteComplete(result);
224 }
225 }
226
OnWriteComplete(int result)227 void WebSocketConnection::OnWriteComplete(int result)
228 VALID_CONTEXT_REQUIRED(sequence_checker_) {
229 if (result < 0) {
230 DVLOG(1) << "Failed to write to WebSocket connection, error: " << result;
231 DisconnectImmediately();
232 return;
233 }
234
235 pending_buffer_->DidConsume(result);
236
237 if (pending_buffer_->BytesRemaining() > 0) {
238 PerformWrite();
239 return;
240 }
241
242 pending_buffer_ = nullptr;
243
244 if (!pending_messages_.empty()) {
245 scoped_refptr<IOBufferWithSize> next_message =
246 std::move(pending_messages_.front());
247 pending_messages_.pop();
248 SendInternal(std::move(next_message), /*wait_for_handshake=*/false);
249 return;
250 }
251
252 if (should_disconnect_after_write_) {
253 DisconnectImmediately();
254 }
255 }
256
Read()257 void WebSocketConnection::Read() VALID_CONTEXT_REQUIRED(sequence_checker_) {
258 read_buffer_ = base::MakeRefCounted<IOBufferWithSize>(4096);
259
260 const int result =
261 stream_socket_->Read(read_buffer_.get(), read_buffer_->size(),
262 base::BindOnce(&WebSocketConnection::OnReadComplete,
263 base::WrapRefCounted(this)));
264 if (result != ERR_IO_PENDING) {
265 OnReadComplete(result);
266 }
267 }
268
OnReadComplete(int result)269 void WebSocketConnection::OnReadComplete(int result)
270 VALID_CONTEXT_REQUIRED(sequence_checker_) {
271 if (result <= 0) {
272 DVLOG(1) << "Failed to read from WebSocket connection, error: " << result;
273 DisconnectImmediately();
274 return;
275 }
276
277 if (!handler_) {
278 DVLOG(1) << "No handler set, ignoring read.";
279 return;
280 }
281
282 base::span<uint8_t> data_span =
283 read_buffer_->span().first(static_cast<size_t>(result));
284
285 WebSocketFrameParser parser;
286 std::vector<std::unique_ptr<WebSocketFrameChunk>> frame_chunks;
287 parser.Decode(data_span, &frame_chunks);
288
289 for (auto& chunk : frame_chunks) {
290 auto assemble_result = chunk_assembler_.HandleChunk(std::move(chunk));
291
292 if (assemble_result.has_value()) {
293 std::unique_ptr<WebSocketFrame> assembled_frame =
294 std::move(assemble_result).value();
295 HandleFrame(assembled_frame->header.opcode,
296 base::as_chars(assembled_frame->payload),
297 assembled_frame->header.final);
298 continue;
299 }
300
301 if (assemble_result.error() == ERR_WS_PROTOCOL_ERROR) {
302 DVLOG(1) << "Protocol error while handling frame.";
303 StartClosingHandshake(1002, "Protocol error");
304 DisconnectAfterAnyWritesDone();
305 return;
306 }
307 }
308
309 if (stream_socket_) {
310 Read();
311 }
312 }
313
HandleFrame(WebSocketFrameHeader::OpCode opcode,base::span<const char> payload,bool is_final)314 void WebSocketConnection::HandleFrame(WebSocketFrameHeader::OpCode opcode,
315 base::span<const char> payload,
316 bool is_final)
317 VALID_CONTEXT_REQUIRED(sequence_checker_) {
318 CHECK(handler_) << "No handler set for WebSocket connection.";
319
320 switch (opcode) {
321 case WebSocketFrameHeader::kOpCodeText:
322 case WebSocketFrameHeader::kOpCodeBinary:
323 case WebSocketFrameHeader::kOpCodeContinuation: {
324 auto message_result =
325 message_assembler_.HandleFrame(is_final, opcode, payload);
326
327 if (message_result.has_value()) {
328 if (message_result->is_text_message) {
329 handler_->OnTextMessage(base::as_string_view(message_result->body));
330 } else {
331 handler_->OnBinaryMessage(message_result->body);
332 }
333 } else if (message_result.error() == ERR_WS_PROTOCOL_ERROR) {
334 StartClosingHandshake(1002, "Protocol error");
335 DisconnectAfterAnyWritesDone();
336 }
337
338 break;
339 }
340 case WebSocketFrameHeader::kOpCodeClose: {
341 auto parse_close_frame_result = ParseCloseFrame(payload);
342 if (parse_close_frame_result.error.has_value()) {
343 DVLOG(1) << "Failed to parse close frame: "
344 << parse_close_frame_result.error.value();
345 StartClosingHandshake(1002, "Protocol error");
346 DisconnectAfterAnyWritesDone();
347 } else {
348 handler_->OnClosingHandshake(parse_close_frame_result.code,
349 parse_close_frame_result.reason);
350 }
351 break;
352 }
353 case WebSocketFrameHeader::kOpCodePing:
354 handler_->OnPing(base::as_bytes(payload));
355 break;
356 case WebSocketFrameHeader::kOpCodePong:
357 handler_->OnPong(base::as_bytes(payload));
358 break;
359 default:
360 DVLOG(2) << "Unknown frame opcode: " << opcode;
361 StartClosingHandshake(1002, "Protocol error");
362 DisconnectAfterAnyWritesDone();
363 break;
364 }
365 }
366
SendHandshakeResponse()367 void WebSocketConnection::SendHandshakeResponse() {
368 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
369
370 if (!stream_socket_) {
371 DVLOG(2) << "Stream socket is already null. Returning early.";
372 return;
373 }
374
375 std::string response_text = "HTTP/1.1 101 Switching Protocols\r\n";
376 for (const auto& header : response_headers_) {
377 base::StrAppend(&response_text,
378 {header.first, ": ", header.second, "\r\n"});
379 }
380 base::StrAppend(&response_text, {"\r\n"});
381
382 SendRaw(base::as_byte_span(response_text));
383
384 state_ = WebSocketState::kOpen;
385
386 Read();
387
388 // A nullptr check is performed because the connection may have been closed
389 // within Read().
390 if (handler_) {
391 handler_->OnHandshakeComplete();
392 } else {
393 DVLOG(2)
394 << "Handler is null after starting Read. Connection likely closed.";
395 }
396 }
397
CreateTextFrame(std::string_view message)398 scoped_refptr<IOBufferWithSize> CreateTextFrame(std::string_view message) {
399 return BuildWebSocketFrame(base::as_byte_span(message),
400 WebSocketFrameHeader::kOpCodeText);
401 }
402
CreateBinaryFrame(base::span<const uint8_t> message)403 scoped_refptr<IOBufferWithSize> CreateBinaryFrame(
404 base::span<const uint8_t> message) {
405 return BuildWebSocketFrame(message, WebSocketFrameHeader::kOpCodeBinary);
406 }
407
CreateCloseFrame(std::optional<uint16_t> code,std::string_view message)408 scoped_refptr<IOBufferWithSize> CreateCloseFrame(std::optional<uint16_t> code,
409 std::string_view message) {
410 DVLOG(3) << "Creating close frame with code: "
411 << (code ? base::NumberToString(*code) : "none")
412 << ", Message: " << message;
413 CHECK(message.empty() || code);
414 CHECK(base::IsStringUTF8AllowingNoncharacters(message));
415
416 if (!code) {
417 return BuildWebSocketFrame(base::span<const uint8_t>(),
418 WebSocketFrameHeader::kOpCodeClose);
419 }
420
421 auto payload =
422 base::HeapArray<uint8_t>::Uninit(sizeof(uint16_t) + message.size());
423 base::SpanWriter<uint8_t> writer{payload};
424 writer.WriteU16BigEndian(code.value());
425 writer.Write(base::as_byte_span(message));
426
427 return BuildWebSocketFrame(payload, WebSocketFrameHeader::kOpCodeClose);
428 }
429
CreatePingFrame(base::span<const uint8_t> payload)430 scoped_refptr<IOBufferWithSize> CreatePingFrame(
431 base::span<const uint8_t> payload) {
432 return BuildWebSocketFrame(payload, WebSocketFrameHeader::kOpCodePing);
433 }
434
CreatePongFrame(base::span<const uint8_t> payload)435 scoped_refptr<IOBufferWithSize> CreatePongFrame(
436 base::span<const uint8_t> payload) {
437 return BuildWebSocketFrame(payload, WebSocketFrameHeader::kOpCodePong);
438 }
439
BuildWebSocketFrame(base::span<const uint8_t> payload,WebSocketFrameHeader::OpCode op_code)440 scoped_refptr<IOBufferWithSize> BuildWebSocketFrame(
441 base::span<const uint8_t> payload,
442 WebSocketFrameHeader::OpCode op_code) {
443 WebSocketFrameHeader header(op_code);
444 header.final = true;
445 header.payload_length = payload.size();
446
447 const size_t header_size = GetWebSocketFrameHeaderSize(header);
448
449 scoped_refptr<IOBufferWithSize> buffer =
450 base::MakeRefCounted<IOBufferWithSize>(header_size + payload.size());
451
452 const int written_header_size =
453 WriteWebSocketFrameHeader(header, nullptr, buffer->span());
454 base::span<uint8_t> buffer_span = buffer->span().subspan(
455 base::checked_cast<size_t>(written_header_size), payload.size());
456 buffer_span.copy_from(payload);
457
458 return buffer;
459 }
460
461 } // namespace net::test_server
462