• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2024 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/test/embedded_test_server/websocket_connection.h"
6 
7 #include <stdint.h>
8 
9 #include "base/compiler_specific.h"
10 #include "base/containers/extend.h"
11 #include "base/containers/span.h"
12 #include "base/containers/span_writer.h"
13 #include "base/functional/bind.h"
14 #include "base/logging.h"
15 #include "base/memory/scoped_refptr.h"
16 #include "base/numerics/byte_conversions.h"
17 #include "base/numerics/safe_conversions.h"
18 #include "base/strings/strcat.h"
19 #include "net/base/net_errors.h"
20 #include "net/socket/socket.h"
21 #include "net/socket/stream_socket.h"
22 #include "net/test/embedded_test_server/websocket_handler.h"
23 #include "net/test/embedded_test_server/websocket_message_assembler.h"
24 #include "net/websockets/websocket_frame.h"
25 #include "net/websockets/websocket_frame_parser.h"
26 #include "net/websockets/websocket_handshake_challenge.h"
27 
28 namespace net::test_server {
29 
WebSocketConnection(std::unique_ptr<StreamSocket> socket,std::string_view sec_websocket_key,EmbeddedTestServer * server)30 WebSocketConnection::WebSocketConnection(std::unique_ptr<StreamSocket> socket,
31                                          std::string_view sec_websocket_key,
32                                          EmbeddedTestServer* server)
33     : stream_socket_(std::move(socket)),
34       // Register a shutdown closure to safely disconnect this connection when
35       // the
36       // server shuts down. base::Unretained is safe here because:
37       // 1. The shutdown closure is registered during the construction of the
38       //    WebSocketConnection object, ensuring `this` is fully initialized.
39       // 2. The lifetime of the closure is tied to the `WebSocketConnection`
40       //    object via `shutdown_subscription_`, which ensures that the closure
41       //    is automatically unregistered when the object is destroyed.
42       // 3. DisconnectImmediately() ensures safe cleanup by resetting the socket
43       //    and marking the connection state as closed.
44       shutdown_subscription_(server->RegisterShutdownClosure(
45           base::BindOnce(&WebSocketConnection::DisconnectImmediately,
46                          base::Unretained(this)))) {
47   CHECK(stream_socket_);
48 
49   response_headers_.emplace_back("Upgrade", "websocket");
50   response_headers_.emplace_back("Connection", "Upgrade");
51   response_headers_.emplace_back(
52       "Sec-WebSocket-Accept",
53       ComputeSecWebSocketAccept(std::string(sec_websocket_key)));
54 }
55 
~WebSocketConnection()56 WebSocketConnection::~WebSocketConnection() {
57   DisconnectImmediately();
58 }
59 
SetResponseHeader(std::string_view name,std::string_view value)60 void WebSocketConnection::SetResponseHeader(std::string_view name,
61                                             std::string_view value) {
62   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
63   CHECK(stream_socket_);
64   for (auto& header : response_headers_) {
65     if (header.first == name) {
66       header.second = value;
67       return;
68     }
69   }
70   response_headers_.emplace_back(name, value);
71 }
72 
SendTextMessage(std::string_view message)73 void WebSocketConnection::SendTextMessage(std::string_view message) {
74   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
75   CHECK(stream_socket_);
76   CHECK(base::IsStringUTF8AllowingNoncharacters(message));
77   scoped_refptr<IOBufferWithSize> frame = CreateTextFrame(message);
78 
79   SendInternal(std::move(frame), /*wait_for_handshake=*/true);
80 }
81 
SendBinaryMessage(base::span<const uint8_t> message)82 void WebSocketConnection::SendBinaryMessage(base::span<const uint8_t> message) {
83   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
84   CHECK(stream_socket_);
85   scoped_refptr<IOBufferWithSize> frame = CreateBinaryFrame(message);
86   SendInternal(std::move(frame), /*wait_for_handshake=*/true);
87 }
88 
StartClosingHandshake(std::optional<uint16_t> code,std::string_view message)89 void WebSocketConnection::StartClosingHandshake(std::optional<uint16_t> code,
90                                                 std::string_view message) {
91   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
92   if (!stream_socket_) {
93     DVLOG(2) << "Attempted to start closing handshake, but socket is null.";
94     return;
95   }
96 
97   DVLOG(3) << "Starting closing handshake. Code: "
98            << (code ? base::NumberToString(*code) : "none")
99            << ", Message: " << message;
100 
101   if (!code) {
102     CHECK(base::IsStringUTF8AllowingNoncharacters(message));
103     SendInternal(BuildWebSocketFrame(base::span<const uint8_t>(),
104                                      WebSocketFrameHeader::kOpCodeClose),
105                  /*wait_for_handshake=*/true);
106     state_ = WebSocketState::kWaitingForClientClose;
107     return;
108   }
109 
110   scoped_refptr<IOBufferWithSize> close_frame = CreateCloseFrame(code, message);
111   SendInternal(std::move(close_frame), /*wait_for_handshake=*/true);
112 
113   state_ = WebSocketState::kWaitingForClientClose;
114 }
115 
RespondToCloseFrame(std::optional<uint16_t> code,std::string_view message)116 void WebSocketConnection::RespondToCloseFrame(std::optional<uint16_t> code,
117                                               std::string_view message) {
118   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
119   if (state_ == WebSocketState::kClosed) {
120     DVLOG(2) << "Attempted to respond to close frame, but connection is "
121                 "already closed.";
122     return;
123   }
124 
125   CHECK(base::IsStringUTF8AllowingNoncharacters(message));
126   scoped_refptr<IOBufferWithSize> close_frame = CreateCloseFrame(code, message);
127   SendInternal(std::move(close_frame), /*wait_for_handshake=*/false);
128   DisconnectAfterAnyWritesDone();
129 }
130 
SendPing(base::span<const uint8_t> payload)131 void WebSocketConnection::SendPing(base::span<const uint8_t> payload) {
132   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
133   scoped_refptr<IOBufferWithSize> ping_frame = CreatePingFrame(payload);
134   SendInternal(std::move(ping_frame), /*wait_for_handshake=*/true);
135 }
136 
SendPong(base::span<const uint8_t> payload)137 void WebSocketConnection::SendPong(base::span<const uint8_t> payload) {
138   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
139   scoped_refptr<IOBufferWithSize> pong_frame = CreatePongFrame(payload);
140   SendInternal(std::move(pong_frame), /*wait_for_handshake=*/true);
141 }
142 
DisconnectAfterAnyWritesDone()143 void WebSocketConnection::DisconnectAfterAnyWritesDone() {
144   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
145   if (!stream_socket_) {
146     DVLOG(3) << "Socket is already disconnected.";
147     return;
148   }
149 
150   if (!pending_buffer_) {
151     DisconnectImmediately();
152     return;
153   }
154 
155   should_disconnect_after_write_ = true;
156   state_ = WebSocketState::kDisconnectingSoon;
157   handler_.reset();
158 }
159 
DisconnectImmediately()160 void WebSocketConnection::DisconnectImmediately() {
161   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
162   if (!stream_socket_) {
163     DVLOG(3) << "Socket is already disconnected.";
164     handler_.reset();
165     return;
166   }
167 
168   // Intentionally not calling Disconnect(), as it doesn't work with
169   // SSLServerSocket. Resetting the socket here is sufficient to disconnect.
170   ResetStreamSocket();
171   handler_.reset();
172 }
173 
ResetStreamSocket()174 void WebSocketConnection::ResetStreamSocket() {
175   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
176   if (stream_socket_) {
177     stream_socket_.reset();
178     state_ = WebSocketState::kClosed;
179   }
180   // `this` may be deleted here.
181 }
182 
SendRaw(base::span<const uint8_t> bytes)183 void WebSocketConnection::SendRaw(base::span<const uint8_t> bytes) {
184   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
185   scoped_refptr<IOBufferWithSize> buffer =
186       base::MakeRefCounted<IOBufferWithSize>(bytes.size());
187   buffer->span().copy_from(bytes);
188   SendInternal(std::move(buffer), /*wait_for_handshake=*/false);
189 }
190 
SendInternal(scoped_refptr<IOBufferWithSize> buffer,bool wait_for_handshake)191 void WebSocketConnection::SendInternal(scoped_refptr<IOBufferWithSize> buffer,
192                                        bool wait_for_handshake) {
193   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
194   if ((wait_for_handshake && state_ != WebSocketState::kOpen) ||
195       pending_buffer_) {
196     pending_messages_.emplace(std::move(buffer));
197     return;
198   }
199 
200   const size_t buffer_size = buffer->size();
201   pending_buffer_ =
202       base::MakeRefCounted<DrainableIOBuffer>(std::move(buffer), buffer_size);
203 
204   PerformWrite();
205 }
206 
SetHandler(std::unique_ptr<WebSocketHandler> handler)207 void WebSocketConnection::SetHandler(
208     std::unique_ptr<WebSocketHandler> handler) {
209   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
210   handler_ = std::move(handler);
211 }
212 
PerformWrite()213 void WebSocketConnection::PerformWrite()
214     VALID_CONTEXT_REQUIRED(sequence_checker_) {
215   const int result = stream_socket_->Write(
216       pending_buffer_.get(), pending_buffer_->BytesRemaining(),
217       base::BindOnce(&WebSocketConnection::OnWriteComplete,
218                      base::WrapRefCounted(this)),
219       DefineNetworkTrafficAnnotation(
220           "test", "Traffic annotation for unit, browser and other tests"));
221 
222   if (result != ERR_IO_PENDING) {
223     OnWriteComplete(result);
224   }
225 }
226 
OnWriteComplete(int result)227 void WebSocketConnection::OnWriteComplete(int result)
228     VALID_CONTEXT_REQUIRED(sequence_checker_) {
229   if (result < 0) {
230     DVLOG(1) << "Failed to write to WebSocket connection, error: " << result;
231     DisconnectImmediately();
232     return;
233   }
234 
235   pending_buffer_->DidConsume(result);
236 
237   if (pending_buffer_->BytesRemaining() > 0) {
238     PerformWrite();
239     return;
240   }
241 
242   pending_buffer_ = nullptr;
243 
244   if (!pending_messages_.empty()) {
245     scoped_refptr<IOBufferWithSize> next_message =
246         std::move(pending_messages_.front());
247     pending_messages_.pop();
248     SendInternal(std::move(next_message), /*wait_for_handshake=*/false);
249     return;
250   }
251 
252   if (should_disconnect_after_write_) {
253     DisconnectImmediately();
254   }
255 }
256 
Read()257 void WebSocketConnection::Read() VALID_CONTEXT_REQUIRED(sequence_checker_) {
258   read_buffer_ = base::MakeRefCounted<IOBufferWithSize>(4096);
259 
260   const int result =
261       stream_socket_->Read(read_buffer_.get(), read_buffer_->size(),
262                            base::BindOnce(&WebSocketConnection::OnReadComplete,
263                                           base::WrapRefCounted(this)));
264   if (result != ERR_IO_PENDING) {
265     OnReadComplete(result);
266   }
267 }
268 
OnReadComplete(int result)269 void WebSocketConnection::OnReadComplete(int result)
270     VALID_CONTEXT_REQUIRED(sequence_checker_) {
271   if (result <= 0) {
272     DVLOG(1) << "Failed to read from WebSocket connection, error: " << result;
273     DisconnectImmediately();
274     return;
275   }
276 
277   if (!handler_) {
278     DVLOG(1) << "No handler set, ignoring read.";
279     return;
280   }
281 
282   base::span<uint8_t> data_span =
283       read_buffer_->span().first(static_cast<size_t>(result));
284 
285   WebSocketFrameParser parser;
286   std::vector<std::unique_ptr<WebSocketFrameChunk>> frame_chunks;
287   parser.Decode(data_span, &frame_chunks);
288 
289   for (auto& chunk : frame_chunks) {
290     auto assemble_result = chunk_assembler_.HandleChunk(std::move(chunk));
291 
292     if (assemble_result.has_value()) {
293       std::unique_ptr<WebSocketFrame> assembled_frame =
294           std::move(assemble_result).value();
295       HandleFrame(assembled_frame->header.opcode,
296                   base::as_chars(assembled_frame->payload),
297                   assembled_frame->header.final);
298       continue;
299     }
300 
301     if (assemble_result.error() == ERR_WS_PROTOCOL_ERROR) {
302       DVLOG(1) << "Protocol error while handling frame.";
303       StartClosingHandshake(1002, "Protocol error");
304       DisconnectAfterAnyWritesDone();
305       return;
306     }
307   }
308 
309   if (stream_socket_) {
310     Read();
311   }
312 }
313 
HandleFrame(WebSocketFrameHeader::OpCode opcode,base::span<const char> payload,bool is_final)314 void WebSocketConnection::HandleFrame(WebSocketFrameHeader::OpCode opcode,
315                                       base::span<const char> payload,
316                                       bool is_final)
317     VALID_CONTEXT_REQUIRED(sequence_checker_) {
318   CHECK(handler_) << "No handler set for WebSocket connection.";
319 
320   switch (opcode) {
321     case WebSocketFrameHeader::kOpCodeText:
322     case WebSocketFrameHeader::kOpCodeBinary:
323     case WebSocketFrameHeader::kOpCodeContinuation: {
324       auto message_result =
325           message_assembler_.HandleFrame(is_final, opcode, payload);
326 
327       if (message_result.has_value()) {
328         if (message_result->is_text_message) {
329           handler_->OnTextMessage(base::as_string_view(message_result->body));
330         } else {
331           handler_->OnBinaryMessage(message_result->body);
332         }
333       } else if (message_result.error() == ERR_WS_PROTOCOL_ERROR) {
334         StartClosingHandshake(1002, "Protocol error");
335         DisconnectAfterAnyWritesDone();
336       }
337 
338       break;
339     }
340     case WebSocketFrameHeader::kOpCodeClose: {
341       auto parse_close_frame_result = ParseCloseFrame(payload);
342       if (parse_close_frame_result.error.has_value()) {
343         DVLOG(1) << "Failed to parse close frame: "
344                  << parse_close_frame_result.error.value();
345         StartClosingHandshake(1002, "Protocol error");
346         DisconnectAfterAnyWritesDone();
347       } else {
348         handler_->OnClosingHandshake(parse_close_frame_result.code,
349                                      parse_close_frame_result.reason);
350       }
351       break;
352     }
353     case WebSocketFrameHeader::kOpCodePing:
354       handler_->OnPing(base::as_bytes(payload));
355       break;
356     case WebSocketFrameHeader::kOpCodePong:
357       handler_->OnPong(base::as_bytes(payload));
358       break;
359     default:
360       DVLOG(2) << "Unknown frame opcode: " << opcode;
361       StartClosingHandshake(1002, "Protocol error");
362       DisconnectAfterAnyWritesDone();
363       break;
364   }
365 }
366 
SendHandshakeResponse()367 void WebSocketConnection::SendHandshakeResponse() {
368   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
369 
370   if (!stream_socket_) {
371     DVLOG(2) << "Stream socket is already null. Returning early.";
372     return;
373   }
374 
375   std::string response_text = "HTTP/1.1 101 Switching Protocols\r\n";
376   for (const auto& header : response_headers_) {
377     base::StrAppend(&response_text,
378                     {header.first, ": ", header.second, "\r\n"});
379   }
380   base::StrAppend(&response_text, {"\r\n"});
381 
382   SendRaw(base::as_byte_span(response_text));
383 
384   state_ = WebSocketState::kOpen;
385 
386   Read();
387 
388   // A nullptr check is performed because the connection may have been closed
389   // within Read().
390   if (handler_) {
391     handler_->OnHandshakeComplete();
392   } else {
393     DVLOG(2)
394         << "Handler is null after starting Read. Connection likely closed.";
395   }
396 }
397 
CreateTextFrame(std::string_view message)398 scoped_refptr<IOBufferWithSize> CreateTextFrame(std::string_view message) {
399   return BuildWebSocketFrame(base::as_byte_span(message),
400                              WebSocketFrameHeader::kOpCodeText);
401 }
402 
CreateBinaryFrame(base::span<const uint8_t> message)403 scoped_refptr<IOBufferWithSize> CreateBinaryFrame(
404     base::span<const uint8_t> message) {
405   return BuildWebSocketFrame(message, WebSocketFrameHeader::kOpCodeBinary);
406 }
407 
CreateCloseFrame(std::optional<uint16_t> code,std::string_view message)408 scoped_refptr<IOBufferWithSize> CreateCloseFrame(std::optional<uint16_t> code,
409                                                  std::string_view message) {
410   DVLOG(3) << "Creating close frame with code: "
411            << (code ? base::NumberToString(*code) : "none")
412            << ", Message: " << message;
413   CHECK(message.empty() || code);
414   CHECK(base::IsStringUTF8AllowingNoncharacters(message));
415 
416   if (!code) {
417     return BuildWebSocketFrame(base::span<const uint8_t>(),
418                                WebSocketFrameHeader::kOpCodeClose);
419   }
420 
421   auto payload =
422       base::HeapArray<uint8_t>::Uninit(sizeof(uint16_t) + message.size());
423   base::SpanWriter<uint8_t> writer{payload};
424   writer.WriteU16BigEndian(code.value());
425   writer.Write(base::as_byte_span(message));
426 
427   return BuildWebSocketFrame(payload, WebSocketFrameHeader::kOpCodeClose);
428 }
429 
CreatePingFrame(base::span<const uint8_t> payload)430 scoped_refptr<IOBufferWithSize> CreatePingFrame(
431     base::span<const uint8_t> payload) {
432   return BuildWebSocketFrame(payload, WebSocketFrameHeader::kOpCodePing);
433 }
434 
CreatePongFrame(base::span<const uint8_t> payload)435 scoped_refptr<IOBufferWithSize> CreatePongFrame(
436     base::span<const uint8_t> payload) {
437   return BuildWebSocketFrame(payload, WebSocketFrameHeader::kOpCodePong);
438 }
439 
BuildWebSocketFrame(base::span<const uint8_t> payload,WebSocketFrameHeader::OpCode op_code)440 scoped_refptr<IOBufferWithSize> BuildWebSocketFrame(
441     base::span<const uint8_t> payload,
442     WebSocketFrameHeader::OpCode op_code) {
443   WebSocketFrameHeader header(op_code);
444   header.final = true;
445   header.payload_length = payload.size();
446 
447   const size_t header_size = GetWebSocketFrameHeaderSize(header);
448 
449   scoped_refptr<IOBufferWithSize> buffer =
450       base::MakeRefCounted<IOBufferWithSize>(header_size + payload.size());
451 
452   const int written_header_size =
453       WriteWebSocketFrameHeader(header, nullptr, buffer->span());
454   base::span<uint8_t> buffer_span = buffer->span().subspan(
455       base::checked_cast<size_t>(written_header_size), payload.size());
456   buffer_span.copy_from(payload);
457 
458   return buffer;
459 }
460 
461 }  // namespace net::test_server
462