1 // Copyright 2024 The Chromium Authors 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #ifndef NET_TEST_EMBEDDED_TEST_SERVER_WEBSOCKET_CONNECTION_H_ 6 #define NET_TEST_EMBEDDED_TEST_SERVER_WEBSOCKET_CONNECTION_H_ 7 8 #include <stddef.h> 9 #include <stdint.h> 10 11 #include <memory> 12 #include <optional> 13 #include <queue> 14 #include <string_view> 15 16 #include "base/containers/span.h" 17 #include "base/memory/scoped_refptr.h" 18 #include "base/sequence_checker.h" 19 #include "net/base/io_buffer.h" 20 #include "net/test/embedded_test_server/embedded_test_server.h" 21 #include "net/test/embedded_test_server/websocket_message_assembler.h" 22 #include "net/websockets/websocket_chunk_assembler.h" 23 #include "net/websockets/websocket_frame.h" 24 25 namespace net { 26 27 class StreamSocket; 28 29 namespace test_server { 30 31 class WebSocketHandler; 32 33 class WebSocketConnection final : public base::RefCounted<WebSocketConnection> { 34 public: 35 WebSocketConnection(const WebSocketConnection&) = delete; 36 WebSocketConnection& operator=(const WebSocketConnection&) = delete; 37 38 // Constructor initializes the WebSocket connection with a given socket and 39 // prepares for the WebSocket handshake by setting up necessary headers. 40 explicit WebSocketConnection(std::unique_ptr<StreamSocket> socket, 41 std::string_view sec_websocket_key, 42 EmbeddedTestServer* server); 43 44 // Adds or replaces the response header with name `name`. Should only be 45 // called from WebSocketHandler::OnHandshake(). 46 void SetResponseHeader(std::string_view name, std::string_view value); 47 48 // Send a text message. Can be called in OnHandshake(), in which case the 49 // message will be queued to be sent immediately after the response headers. 50 // Can be called at any time up until WebSocketHandler::OnClosingHandshake(), 51 // WebSocketConnection::StartClosingHandshake(), 52 // WebSocketConnection::DisconnectAfterAnyWritesDone() or 53 // WebSocketConnection::DisconnectImmediately() is called. 54 void SendTextMessage(std::string_view message); 55 56 // Send a binary message. Can be called as with SendTextMessage(). 57 void SendBinaryMessage(base::span<const uint8_t> message); 58 59 // Send a CLOSE frame with `code` and `message`. If `code` is std::nullopt 60 // then an empty CLOSE frame will be sent. Initiates a close handshake from 61 // the server side. 62 void StartClosingHandshake(std::optional<uint16_t> code, 63 std::string_view message); 64 65 // Responds to a CLOSE frame received from the client. If `code` is 66 // std::nullopt then an empty CLOSE frame will be sent. 67 void RespondToCloseFrame(std::optional<uint16_t> code, 68 std::string_view message); 69 70 // Send a PING frame. The payload is optional and can be omitted or included 71 // based on the application logic. 72 void SendPing(base::span<const uint8_t> payload); 73 74 // Send a PONG frame. The payload is optional and can be omitted or included 75 // based on the application logic. 76 void SendPong(base::span<const uint8_t> payload); 77 78 // Delete the handler, scheduling a disconnect after any pending writes are 79 // completed. 80 void DisconnectAfterAnyWritesDone(); 81 82 // Sends `bytes` as-is directly on stream. Can be called from 83 // WebSocketHandler::OnHandshake() to send data before the normal 84 // response header. After OnHandshake() returns, can be used to send invalid 85 // WebSocket frames. 86 void SendRaw(base::span<const uint8_t> bytes); 87 88 // Sends the handshake response after headers are set. 89 void SendHandshakeResponse(); 90 91 // Set the WebSocketHandler instance for this connection. 92 void SetHandler(std::unique_ptr<WebSocketHandler> handler); 93 94 private: 95 friend class base::RefCounted<WebSocketConnection>; 96 97 // Enum to represent the current state of the WebSocket connection. 98 // For managing transitions between different phases of the WebSocket 99 // lifecycle. 100 enum class WebSocketState { 101 kHandshakeInProgress, 102 kOpen, 103 kWaitingForClientClose, 104 kDisconnectingSoon, 105 kClosed 106 }; 107 108 ~WebSocketConnection(); 109 110 // Internal function to immediately disconnect, deleting the handler and 111 // closing the socket. 112 void DisconnectImmediately(); 113 114 // Internal function to reset the stream socket. 115 void ResetStreamSocket(); 116 117 void PerformWrite() VALID_CONTEXT_REQUIRED(sequence_checker_); 118 void OnWriteComplete(int result) VALID_CONTEXT_REQUIRED(sequence_checker_); 119 void Read() VALID_CONTEXT_REQUIRED(sequence_checker_); 120 void OnReadComplete(int result) VALID_CONTEXT_REQUIRED(sequence_checker_); 121 122 // Handles incoming WebSocket frames of different opcodes: text, binary, 123 // and continuation frames. Based on the frame's opcode and whether the 124 // frame is marked as final (`is_final`), the payload is processed and 125 // dispatched accordingly. `is_final` determines if the frame completes the 126 // current message. 127 void HandleFrame(WebSocketFrameHeader::OpCode opcode, 128 base::span<const char> payload, 129 bool is_final) VALID_CONTEXT_REQUIRED(sequence_checker_); 130 131 // Internal function to handle sending buffers. 132 // `wait_for_handshake`: If true, the message will be queued until the 133 // handshake is complete. 134 void SendInternal(scoped_refptr<IOBufferWithSize> buffer, 135 bool wait_for_handshake); 136 137 std::unique_ptr<StreamSocket> stream_socket_; 138 base::StringPairs response_headers_; 139 std::unique_ptr<WebSocketHandler> handler_; 140 141 // Messages that are pending until the handshake is complete or until a 142 // previous write is completed. 143 std::queue<scoped_refptr<IOBufferWithSize>> pending_messages_; 144 145 // Tracks pending bytes to be written, used for handling partial writes. 146 scoped_refptr<DrainableIOBuffer> pending_buffer_; 147 148 scoped_refptr<IOBufferWithSize> read_buffer_; 149 150 // The current state of the WebSocket connection, such as OPEN or CLOSED. 151 WebSocketState state_ = WebSocketState::kHandshakeInProgress; 152 153 // Flag to indicate if a disconnect should be performed after write 154 // completion. 155 bool should_disconnect_after_write_ = false; 156 157 // Assembles fragmented frames into full messages. 158 WebSocketMessageAssembler message_assembler_; 159 160 // Handles assembling fragmented WebSocket frame chunks. 161 WebSocketChunkAssembler chunk_assembler_; 162 163 // Subscription to the shutdown closure in EmbeddedTestServer. 164 base::CallbackListSubscription shutdown_subscription_; 165 166 SEQUENCE_CHECKER(sequence_checker_); 167 }; 168 169 // Methods to create specific WebSocket frames. 170 scoped_refptr<IOBufferWithSize> CreateTextFrame(std::string_view message); 171 scoped_refptr<IOBufferWithSize> CreateBinaryFrame( 172 base::span<const uint8_t> message); 173 scoped_refptr<IOBufferWithSize> CreateCloseFrame(std::optional<uint16_t> code, 174 std::string_view message); 175 scoped_refptr<IOBufferWithSize> CreatePingFrame( 176 base::span<const uint8_t> payload); 177 scoped_refptr<IOBufferWithSize> CreatePongFrame( 178 base::span<const uint8_t> payload); 179 180 // Helper for building WebSocket frames (both data and control frames). 181 scoped_refptr<IOBufferWithSize> BuildWebSocketFrame( 182 base::span<const uint8_t> payload, 183 WebSocketFrameHeader::OpCode op_code); 184 185 } // namespace test_server 186 187 } // namespace net 188 189 #endif // NET_TEST_EMBEDDED_TEST_SERVER_WEBSOCKET_CONNECTION_H_ 190