1 // Copyright 2013 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/websockets/websocket_basic_stream.h"
6
7 #include <stddef.h>
8 #include <stdint.h>
9
10 #include <algorithm>
11 #include <limits>
12 #include <ostream>
13 #include <utility>
14
15 #include "base/check.h"
16 #include "base/check_op.h"
17 #include "base/containers/span.h"
18 #include "base/functional/bind.h"
19 #include "base/functional/callback.h"
20 #include "base/logging.h"
21 #include "base/numerics/safe_conversions.h"
22 #include "base/values.h"
23 #include "build/build_config.h"
24 #include "net/base/io_buffer.h"
25 #include "net/base/net_errors.h"
26 #include "net/log/net_log_event_type.h"
27 #include "net/socket/client_socket_handle.h"
28 #include "net/traffic_annotation/network_traffic_annotation.h"
29 #include "net/websockets/websocket_basic_stream_adapters.h"
30 #include "net/websockets/websocket_errors.h"
31 #include "net/websockets/websocket_frame.h"
32
33 namespace net {
34
35 namespace {
36
// Traffic annotation attached to every socket write issued by
// WebSocketBasicStream::WriteEverything().
// Please refer to the comment in class header if the usage changes.
constexpr net::NetworkTrafficAnnotationTag kTrafficAnnotation =
    net::DefineNetworkTrafficAnnotation("websocket_basic_stream", R"(
semantics {
sender: "WebSocket Basic Stream"
description:
"Implementation of WebSocket API from web content (a page the user "
"visits)."
trigger: "Website calls the WebSocket API."
data:
"Any data provided by web content, masked and framed in accordance "
"with RFC6455."
destination: OTHER
destination_other:
"The address that the website has chosen to communicate to."
}
policy {
cookies_allowed: YES
cookies_store: "user"
setting: "These requests cannot be disabled."
policy_exception_justification:
"Not implemented. WebSocket is a core web platform API."
}
comments:
"The browser will never add cookies to a WebSocket message. But the "
"handshake that was performed when the WebSocket connection was "
"established may have contained cookies."
)");
65
66 // The number of bytes to attempt to read at a time. It's used only for high
67 // throughput connections.
68 // TODO(ricea): See if there is a better number or algorithm to fulfill our
69 // requirements:
70 // 1. We would like to use minimal memory on low-bandwidth or idle connections
71 // 2. We would like to read as close to line speed as possible on
72 // high-bandwidth connections
73 // 3. We can't afford to cause jank on the IO thread by copying large buffers
74 // around
75 // 4. We would like to hit any sweet-spots that might exist in terms of network
76 // packet sizes / encryption block sizes / IPC alignment issues, etc.
77 #if BUILDFLAG(IS_ANDROID)
78 constexpr size_t kLargeReadBufferSize = 32 * 1024;
79 #else
80 // |2^n - delta| is better than 2^n on Linux. See crrev.com/c/1792208.
81 constexpr size_t kLargeReadBufferSize = 131000;
82 #endif
83
// Read size (in bytes) every connection starts with, and falls back to
// whenever its measured throughput drops below the threshold below.
constexpr size_t kSmallReadBufferSize = 1'000;

// Throughput, in bytes per second, that separates "small" from "large" read
// buffers in BufferSizeManager::OnReadComplete().
constexpr double kThresholdInBytesPerSecond = 1'200'000.0;
90
91 // Returns the total serialized size of |frames|. This function assumes that
92 // |frames| will be serialized with mask field. This function forces the
93 // masked bit of the frames on.
CalculateSerializedSizeAndTurnOnMaskBit(std::vector<std::unique_ptr<WebSocketFrame>> * frames)94 int CalculateSerializedSizeAndTurnOnMaskBit(
95 std::vector<std::unique_ptr<WebSocketFrame>>* frames) {
96 constexpr uint64_t kMaximumTotalSize = std::numeric_limits<int>::max();
97
98 uint64_t total_size = 0;
99 for (const auto& frame : *frames) {
100 // Force the masked bit on.
101 frame->header.masked = true;
102 // We enforce flow control so the renderer should never be able to force us
103 // to cache anywhere near 2GB of frames.
104 uint64_t frame_size = frame->header.payload_length +
105 GetWebSocketFrameHeaderSize(frame->header);
106 CHECK_LE(frame_size, kMaximumTotalSize - total_size)
107 << "Aborting to prevent overflow";
108 total_size += frame_size;
109 }
110 return static_cast<int>(total_size);
111 }
112
NetLogBufferSizeParam(int buffer_size)113 base::Value::Dict NetLogBufferSizeParam(int buffer_size) {
114 base::Value::Dict dict;
115 dict.Set("read_buffer_size_in_bytes", buffer_size);
116 return dict;
117 }
118
NetLogFrameHeaderParam(const WebSocketFrameHeader * header)119 base::Value::Dict NetLogFrameHeaderParam(const WebSocketFrameHeader* header) {
120 base::Value::Dict dict;
121 dict.Set("final", header->final);
122 dict.Set("reserved1", header->reserved1);
123 dict.Set("reserved2", header->reserved2);
124 dict.Set("reserved3", header->reserved3);
125 dict.Set("opcode", header->opcode);
126 dict.Set("masked", header->masked);
127 dict.Set("payload_length", static_cast<double>(header->payload_length));
128 return dict;
129 }
130
131 } // namespace
132
// Constructor and destructor are defaulted out of line; member setup and
// teardown is left entirely to the members themselves.
WebSocketBasicStream::BufferSizeManager::BufferSizeManager() = default;

WebSocketBasicStream::BufferSizeManager::~BufferSizeManager() = default;
136
OnRead(base::TimeTicks now)137 void WebSocketBasicStream::BufferSizeManager::OnRead(base::TimeTicks now) {
138 read_start_timestamps_.push(now);
139 }
140
OnReadComplete(base::TimeTicks now,int size)141 void WebSocketBasicStream::BufferSizeManager::OnReadComplete(
142 base::TimeTicks now,
143 int size) {
144 DCHECK_GT(size, 0);
145 // This cannot overflow because the result is at most
146 // kLargeReadBufferSize*rolling_average_window_.
147 rolling_byte_total_ += size;
148 recent_read_sizes_.push(size);
149 DCHECK_LE(read_start_timestamps_.size(), rolling_average_window_);
150 if (read_start_timestamps_.size() == rolling_average_window_) {
151 DCHECK_EQ(read_start_timestamps_.size(), recent_read_sizes_.size());
152 base::TimeDelta duration = now - read_start_timestamps_.front();
153 base::TimeDelta threshold_duration =
154 base::Seconds(rolling_byte_total_ / kThresholdInBytesPerSecond);
155 read_start_timestamps_.pop();
156 rolling_byte_total_ -= recent_read_sizes_.front();
157 recent_read_sizes_.pop();
158 if (threshold_duration < duration) {
159 buffer_size_ = BufferSize::kSmall;
160 } else {
161 buffer_size_ = BufferSize::kLarge;
162 }
163 }
164 }
165
WebSocketBasicStream(std::unique_ptr<Adapter> connection,const scoped_refptr<GrowableIOBuffer> & http_read_buffer,const std::string & sub_protocol,const std::string & extensions,const NetLogWithSource & net_log)166 WebSocketBasicStream::WebSocketBasicStream(
167 std::unique_ptr<Adapter> connection,
168 const scoped_refptr<GrowableIOBuffer>& http_read_buffer,
169 const std::string& sub_protocol,
170 const std::string& extensions,
171 const NetLogWithSource& net_log)
172 : read_buffer_(
173 base::MakeRefCounted<IOBufferWithSize>(kSmallReadBufferSize)),
174 target_read_buffer_size_(read_buffer_->size()),
175 connection_(std::move(connection)),
176 http_read_buffer_(http_read_buffer),
177 sub_protocol_(sub_protocol),
178 extensions_(extensions),
179 net_log_(net_log),
180 generate_websocket_masking_key_(&GenerateWebSocketMaskingKey) {
181 // http_read_buffer_ should not be set if it contains no data.
182 if (http_read_buffer_.get() && http_read_buffer_->offset() == 0)
183 http_read_buffer_ = nullptr;
184 DCHECK(connection_->is_initialized());
185 }
186
~WebSocketBasicStream()187 WebSocketBasicStream::~WebSocketBasicStream() { Close(); }
188
ReadFrames(std::vector<std::unique_ptr<WebSocketFrame>> * frames,CompletionOnceCallback callback)189 int WebSocketBasicStream::ReadFrames(
190 std::vector<std::unique_ptr<WebSocketFrame>>* frames,
191 CompletionOnceCallback callback) {
192 read_callback_ = std::move(callback);
193 control_frame_payloads_.clear();
194 if (http_read_buffer_ && is_http_read_buffer_decoded_) {
195 http_read_buffer_.reset();
196 }
197 return ReadEverything(frames);
198 }
199
WriteFrames(std::vector<std::unique_ptr<WebSocketFrame>> * frames,CompletionOnceCallback callback)200 int WebSocketBasicStream::WriteFrames(
201 std::vector<std::unique_ptr<WebSocketFrame>>* frames,
202 CompletionOnceCallback callback) {
203 // This function always concatenates all frames into a single buffer.
204 // TODO(ricea): Investigate whether it would be better in some cases to
205 // perform multiple writes with smaller buffers.
206
207 write_callback_ = std::move(callback);
208
209 // First calculate the size of the buffer we need to allocate.
210 int total_size = CalculateSerializedSizeAndTurnOnMaskBit(frames);
211 auto combined_buffer = base::MakeRefCounted<IOBufferWithSize>(total_size);
212
213 base::span<uint8_t> dest = combined_buffer->span();
214 for (const auto& frame : *frames) {
215 net_log_.AddEvent(net::NetLogEventType::WEBSOCKET_SENT_FRAME_HEADER,
216 [&] { return NetLogFrameHeaderParam(&frame->header); });
217 WebSocketMaskingKey mask = generate_websocket_masking_key_();
218 int result = WriteWebSocketFrameHeader(frame->header, &mask, dest);
219 DCHECK_NE(ERR_INVALID_ARGUMENT, result)
220 << "WriteWebSocketFrameHeader() says that " << dest.size()
221 << " is not enough to write the header in. This should not happen.";
222 dest = dest.subspan(base::checked_cast<size_t>(result));
223
224 CHECK_LE(frame->header.payload_length,
225 base::checked_cast<uint64_t>(dest.size()));
226 const size_t frame_size = frame->header.payload_length;
227 if (frame_size > 0) {
228 dest.copy_prefix_from(frame->payload);
229 MaskWebSocketFramePayload(mask, 0, dest.first(frame_size));
230 dest = dest.subspan(frame_size);
231 }
232 }
233 DCHECK(dest.empty()) << "Buffer size calculation was wrong; " << dest.size()
234 << " bytes left over.";
235 auto drainable_buffer = base::MakeRefCounted<DrainableIOBuffer>(
236 std::move(combined_buffer), total_size);
237 return WriteEverything(drainable_buffer);
238 }
239
Close()240 void WebSocketBasicStream::Close() {
241 connection_->Disconnect();
242 }
243
GetSubProtocol() const244 std::string WebSocketBasicStream::GetSubProtocol() const {
245 return sub_protocol_;
246 }
247
GetExtensions() const248 std::string WebSocketBasicStream::GetExtensions() const { return extensions_; }
249
GetNetLogWithSource() const250 const NetLogWithSource& WebSocketBasicStream::GetNetLogWithSource() const {
251 return net_log_;
252 }
253
254 /*static*/
255 std::unique_ptr<WebSocketBasicStream>
CreateWebSocketBasicStreamForTesting(std::unique_ptr<ClientSocketHandle> connection,const scoped_refptr<GrowableIOBuffer> & http_read_buffer,const std::string & sub_protocol,const std::string & extensions,const NetLogWithSource & net_log,WebSocketMaskingKeyGeneratorFunction key_generator_function)256 WebSocketBasicStream::CreateWebSocketBasicStreamForTesting(
257 std::unique_ptr<ClientSocketHandle> connection,
258 const scoped_refptr<GrowableIOBuffer>& http_read_buffer,
259 const std::string& sub_protocol,
260 const std::string& extensions,
261 const NetLogWithSource& net_log,
262 WebSocketMaskingKeyGeneratorFunction key_generator_function) {
263 auto stream = std::make_unique<WebSocketBasicStream>(
264 std::make_unique<WebSocketClientSocketHandleAdapter>(
265 std::move(connection)),
266 http_read_buffer, sub_protocol, extensions, net_log);
267 stream->generate_websocket_masking_key_ = key_generator_function;
268 return stream;
269 }
270
ReadEverything(std::vector<std::unique_ptr<WebSocketFrame>> * frames)271 int WebSocketBasicStream::ReadEverything(
272 std::vector<std::unique_ptr<WebSocketFrame>>* frames) {
273 DCHECK(frames->empty());
274
275 // If there is data left over after parsing the HTTP headers, attempt to parse
276 // it as WebSocket frames.
277 if (http_read_buffer_.get() && !is_http_read_buffer_decoded_) {
278 DCHECK_GE(http_read_buffer_->offset(), 0);
279 is_http_read_buffer_decoded_ = true;
280 std::vector<std::unique_ptr<WebSocketFrameChunk>> frame_chunks;
281 if (!parser_.Decode(http_read_buffer_->span_before_offset(),
282 &frame_chunks)) {
283 return WebSocketErrorToNetError(parser_.websocket_error());
284 }
285 if (!frame_chunks.empty()) {
286 int result = ConvertChunksToFrames(&frame_chunks, frames);
287 if (result != ERR_IO_PENDING)
288 return result;
289 }
290 }
291
292 // Run until socket stops giving us data or we get some frames.
293 while (true) {
294 if (buffer_size_manager_.buffer_size() != buffer_size_) {
295 read_buffer_ = base::MakeRefCounted<IOBufferWithSize>(
296 buffer_size_manager_.buffer_size() == BufferSize::kSmall
297 ? kSmallReadBufferSize
298 : kLargeReadBufferSize);
299 buffer_size_ = buffer_size_manager_.buffer_size();
300 net_log_.AddEvent(
301 net::NetLogEventType::WEBSOCKET_READ_BUFFER_SIZE_CHANGED,
302 [&] { return NetLogBufferSizeParam(read_buffer_->size()); });
303 }
304 buffer_size_manager_.OnRead(base::TimeTicks::Now());
305
306 // base::Unretained(this) here is safe because net::Socket guarantees not to
307 // call any callbacks after Disconnect(), which we call from the destructor.
308 // The caller of ReadEverything() is required to keep |frames| valid.
309 int result = connection_->Read(
310 read_buffer_.get(), read_buffer_->size(),
311 base::BindOnce(&WebSocketBasicStream::OnReadComplete,
312 base::Unretained(this), base::Unretained(frames)));
313 if (result == ERR_IO_PENDING)
314 return result;
315 result = HandleReadResult(result, frames);
316 if (result != ERR_IO_PENDING)
317 return result;
318 DCHECK(frames->empty());
319 }
320 }
321
OnReadComplete(std::vector<std::unique_ptr<WebSocketFrame>> * frames,int result)322 void WebSocketBasicStream::OnReadComplete(
323 std::vector<std::unique_ptr<WebSocketFrame>>* frames,
324 int result) {
325 result = HandleReadResult(result, frames);
326 if (result == ERR_IO_PENDING)
327 result = ReadEverything(frames);
328 if (result != ERR_IO_PENDING)
329 std::move(read_callback_).Run(result);
330 }
331
WriteEverything(const scoped_refptr<DrainableIOBuffer> & buffer)332 int WebSocketBasicStream::WriteEverything(
333 const scoped_refptr<DrainableIOBuffer>& buffer) {
334 while (buffer->BytesRemaining() > 0) {
335 // The use of base::Unretained() here is safe because on destruction we
336 // disconnect the socket, preventing any further callbacks.
337 int result = connection_->Write(
338 buffer.get(), buffer->BytesRemaining(),
339 base::BindOnce(&WebSocketBasicStream::OnWriteComplete,
340 base::Unretained(this), buffer),
341 kTrafficAnnotation);
342 if (result > 0) {
343 buffer->DidConsume(result);
344 } else {
345 return result;
346 }
347 }
348 return OK;
349 }
350
OnWriteComplete(const scoped_refptr<DrainableIOBuffer> & buffer,int result)351 void WebSocketBasicStream::OnWriteComplete(
352 const scoped_refptr<DrainableIOBuffer>& buffer,
353 int result) {
354 if (result < 0) {
355 DCHECK_NE(ERR_IO_PENDING, result);
356 std::move(write_callback_).Run(result);
357 return;
358 }
359
360 DCHECK_NE(0, result);
361
362 buffer->DidConsume(result);
363 result = WriteEverything(buffer);
364 if (result != ERR_IO_PENDING)
365 std::move(write_callback_).Run(result);
366 }
367
HandleReadResult(int result,std::vector<std::unique_ptr<WebSocketFrame>> * frames)368 int WebSocketBasicStream::HandleReadResult(
369 int result,
370 std::vector<std::unique_ptr<WebSocketFrame>>* frames) {
371 DCHECK_NE(ERR_IO_PENDING, result);
372 DCHECK(frames->empty());
373 if (result < 0)
374 return result;
375 if (result == 0)
376 return ERR_CONNECTION_CLOSED;
377
378 buffer_size_manager_.OnReadComplete(base::TimeTicks::Now(), result);
379
380 std::vector<std::unique_ptr<WebSocketFrameChunk>> frame_chunks;
381 if (!parser_.Decode(
382 read_buffer_->span().first(base::checked_cast<size_t>(result)),
383 &frame_chunks)) {
384 return WebSocketErrorToNetError(parser_.websocket_error());
385 }
386 if (frame_chunks.empty())
387 return ERR_IO_PENDING;
388 return ConvertChunksToFrames(&frame_chunks, frames);
389 }
390
ConvertChunksToFrames(std::vector<std::unique_ptr<WebSocketFrameChunk>> * frame_chunks,std::vector<std::unique_ptr<WebSocketFrame>> * frames)391 int WebSocketBasicStream::ConvertChunksToFrames(
392 std::vector<std::unique_ptr<WebSocketFrameChunk>>* frame_chunks,
393 std::vector<std::unique_ptr<WebSocketFrame>>* frames) {
394 for (auto& chunk : *frame_chunks) {
395 DCHECK(chunk == frame_chunks->back() || chunk->final_chunk)
396 << "Only last chunk can have |final_chunk| set to be false.";
397
398 if (chunk->header) {
399 net_log_.AddEvent(net::NetLogEventType::WEBSOCKET_RECV_FRAME_HEADER, [&] {
400 return NetLogFrameHeaderParam(chunk->header.get());
401 });
402 }
403
404 auto frame_result = chunk_assembler_.HandleChunk(std::move(chunk));
405
406 if (!frame_result.has_value()) {
407 return frame_result.error();
408 }
409
410 auto frame = std::move(frame_result.value());
411 bool is_control_opcode =
412 WebSocketFrameHeader::IsKnownControlOpCode(frame->header.opcode) ||
413 WebSocketFrameHeader::IsReservedControlOpCode(frame->header.opcode);
414 if (is_control_opcode) {
415 const size_t length =
416 base::checked_cast<size_t>(frame->header.payload_length);
417 if (length > 0) {
418 auto copied_payload =
419 base::HeapArray<uint8_t>::CopiedFrom(frame->payload);
420 frame->payload = copied_payload.as_span();
421 control_frame_payloads_.emplace_back(std::move(copied_payload));
422 }
423 }
424
425 frames->emplace_back(std::move(frame));
426 }
427
428 frame_chunks->clear();
429
430 return frames->empty() ? ERR_IO_PENDING : OK;
431 }
432
433 } // namespace net
434