• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2013 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/websockets/websocket_basic_stream.h"
6 
7 #include <stddef.h>
8 #include <stdint.h>
9 
10 #include <algorithm>
11 #include <limits>
12 #include <ostream>
13 #include <utility>
14 
15 #include "base/check.h"
16 #include "base/check_op.h"
17 #include "base/containers/span.h"
18 #include "base/functional/bind.h"
19 #include "base/functional/callback.h"
20 #include "base/logging.h"
21 #include "base/numerics/safe_conversions.h"
22 #include "base/values.h"
23 #include "build/build_config.h"
24 #include "net/base/io_buffer.h"
25 #include "net/base/net_errors.h"
26 #include "net/log/net_log_event_type.h"
27 #include "net/socket/client_socket_handle.h"
28 #include "net/traffic_annotation/network_traffic_annotation.h"
29 #include "net/websockets/websocket_basic_stream_adapters.h"
30 #include "net/websockets/websocket_errors.h"
31 #include "net/websockets/websocket_frame.h"
32 
33 namespace net {
34 
35 namespace {
36 
37 // Please refer to the comment in class header if the usage changes.
38 constexpr net::NetworkTrafficAnnotationTag kTrafficAnnotation =
39     net::DefineNetworkTrafficAnnotation("websocket_basic_stream", R"(
40       semantics {
41         sender: "WebSocket Basic Stream"
42         description:
43           "Implementation of WebSocket API from web content (a page the user "
44           "visits)."
45         trigger: "Website calls the WebSocket API."
46         data:
47           "Any data provided by web content, masked and framed in accordance "
48           "with RFC6455."
49         destination: OTHER
50         destination_other:
51           "The address that the website has chosen to communicate to."
52       }
53       policy {
54         cookies_allowed: YES
55         cookies_store: "user"
56         setting: "These requests cannot be disabled."
57         policy_exception_justification:
58           "Not implemented. WebSocket is a core web platform API."
59       }
60       comments:
61         "The browser will never add cookies to a WebSocket message. But the "
62         "handshake that was performed when the WebSocket connection was "
63         "established may have contained cookies."
64       )");
65 
66 // The number of bytes to attempt to read at a time. It's used only for high
67 // throughput connections.
68 // TODO(ricea): See if there is a better number or algorithm to fulfill our
69 // requirements:
70 //  1. We would like to use minimal memory on low-bandwidth or idle connections
71 //  2. We would like to read as close to line speed as possible on
72 //     high-bandwidth connections
73 //  3. We can't afford to cause jank on the IO thread by copying large buffers
74 //     around
75 //  4. We would like to hit any sweet-spots that might exist in terms of network
76 //     packet sizes / encryption block sizes / IPC alignment issues, etc.
77 #if BUILDFLAG(IS_ANDROID)
78 constexpr size_t kLargeReadBufferSize = 32 * 1024;
79 #else
80 // |2^n - delta| is better than 2^n on Linux. See crrev.com/c/1792208.
81 constexpr size_t kLargeReadBufferSize = 131000;
82 #endif
83 
84 // The number of bytes to attempt to read at a time. It's set as an initial read
85 // buffer size and used for low throughput connections.
86 constexpr size_t kSmallReadBufferSize = 1000;
87 
88 // The threshold to decide whether to switch the read buffer size.
89 constexpr double kThresholdInBytesPerSecond = 1200 * 1000;
90 
91 // Returns the total serialized size of |frames|. This function assumes that
92 // |frames| will be serialized with mask field. This function forces the
93 // masked bit of the frames on.
CalculateSerializedSizeAndTurnOnMaskBit(std::vector<std::unique_ptr<WebSocketFrame>> * frames)94 int CalculateSerializedSizeAndTurnOnMaskBit(
95     std::vector<std::unique_ptr<WebSocketFrame>>* frames) {
96   constexpr uint64_t kMaximumTotalSize = std::numeric_limits<int>::max();
97 
98   uint64_t total_size = 0;
99   for (const auto& frame : *frames) {
100     // Force the masked bit on.
101     frame->header.masked = true;
102     // We enforce flow control so the renderer should never be able to force us
103     // to cache anywhere near 2GB of frames.
104     uint64_t frame_size = frame->header.payload_length +
105                           GetWebSocketFrameHeaderSize(frame->header);
106     CHECK_LE(frame_size, kMaximumTotalSize - total_size)
107         << "Aborting to prevent overflow";
108     total_size += frame_size;
109   }
110   return static_cast<int>(total_size);
111 }
112 
NetLogBufferSizeParam(int buffer_size)113 base::Value::Dict NetLogBufferSizeParam(int buffer_size) {
114   base::Value::Dict dict;
115   dict.Set("read_buffer_size_in_bytes", buffer_size);
116   return dict;
117 }
118 
NetLogFrameHeaderParam(const WebSocketFrameHeader * header)119 base::Value::Dict NetLogFrameHeaderParam(const WebSocketFrameHeader* header) {
120   base::Value::Dict dict;
121   dict.Set("final", header->final);
122   dict.Set("reserved1", header->reserved1);
123   dict.Set("reserved2", header->reserved2);
124   dict.Set("reserved3", header->reserved3);
125   dict.Set("opcode", header->opcode);
126   dict.Set("masked", header->masked);
127   dict.Set("payload_length", static_cast<double>(header->payload_length));
128   return dict;
129 }
130 
131 }  // namespace
132 
133 WebSocketBasicStream::BufferSizeManager::BufferSizeManager() = default;
134 
135 WebSocketBasicStream::BufferSizeManager::~BufferSizeManager() = default;
136 
OnRead(base::TimeTicks now)137 void WebSocketBasicStream::BufferSizeManager::OnRead(base::TimeTicks now) {
138   read_start_timestamps_.push(now);
139 }
140 
OnReadComplete(base::TimeTicks now,int size)141 void WebSocketBasicStream::BufferSizeManager::OnReadComplete(
142     base::TimeTicks now,
143     int size) {
144   DCHECK_GT(size, 0);
145   // This cannot overflow because the result is at most
146   // kLargeReadBufferSize*rolling_average_window_.
147   rolling_byte_total_ += size;
148   recent_read_sizes_.push(size);
149   DCHECK_LE(read_start_timestamps_.size(), rolling_average_window_);
150   if (read_start_timestamps_.size() == rolling_average_window_) {
151     DCHECK_EQ(read_start_timestamps_.size(), recent_read_sizes_.size());
152     base::TimeDelta duration = now - read_start_timestamps_.front();
153     base::TimeDelta threshold_duration =
154         base::Seconds(rolling_byte_total_ / kThresholdInBytesPerSecond);
155     read_start_timestamps_.pop();
156     rolling_byte_total_ -= recent_read_sizes_.front();
157     recent_read_sizes_.pop();
158     if (threshold_duration < duration) {
159       buffer_size_ = BufferSize::kSmall;
160     } else {
161       buffer_size_ = BufferSize::kLarge;
162     }
163   }
164 }
165 
WebSocketBasicStream(std::unique_ptr<Adapter> connection,const scoped_refptr<GrowableIOBuffer> & http_read_buffer,const std::string & sub_protocol,const std::string & extensions,const NetLogWithSource & net_log)166 WebSocketBasicStream::WebSocketBasicStream(
167     std::unique_ptr<Adapter> connection,
168     const scoped_refptr<GrowableIOBuffer>& http_read_buffer,
169     const std::string& sub_protocol,
170     const std::string& extensions,
171     const NetLogWithSource& net_log)
172     : read_buffer_(
173           base::MakeRefCounted<IOBufferWithSize>(kSmallReadBufferSize)),
174       target_read_buffer_size_(read_buffer_->size()),
175       connection_(std::move(connection)),
176       http_read_buffer_(http_read_buffer),
177       sub_protocol_(sub_protocol),
178       extensions_(extensions),
179       net_log_(net_log),
180       generate_websocket_masking_key_(&GenerateWebSocketMaskingKey) {
181   // http_read_buffer_ should not be set if it contains no data.
182   if (http_read_buffer_.get() && http_read_buffer_->offset() == 0)
183     http_read_buffer_ = nullptr;
184   DCHECK(connection_->is_initialized());
185 }
186 
~WebSocketBasicStream()187 WebSocketBasicStream::~WebSocketBasicStream() { Close(); }
188 
ReadFrames(std::vector<std::unique_ptr<WebSocketFrame>> * frames,CompletionOnceCallback callback)189 int WebSocketBasicStream::ReadFrames(
190     std::vector<std::unique_ptr<WebSocketFrame>>* frames,
191     CompletionOnceCallback callback) {
192   read_callback_ = std::move(callback);
193   control_frame_payloads_.clear();
194   if (http_read_buffer_ && is_http_read_buffer_decoded_) {
195     http_read_buffer_.reset();
196   }
197   return ReadEverything(frames);
198 }
199 
WriteFrames(std::vector<std::unique_ptr<WebSocketFrame>> * frames,CompletionOnceCallback callback)200 int WebSocketBasicStream::WriteFrames(
201     std::vector<std::unique_ptr<WebSocketFrame>>* frames,
202     CompletionOnceCallback callback) {
203   // This function always concatenates all frames into a single buffer.
204   // TODO(ricea): Investigate whether it would be better in some cases to
205   // perform multiple writes with smaller buffers.
206 
207   write_callback_ = std::move(callback);
208 
209   // First calculate the size of the buffer we need to allocate.
210   int total_size = CalculateSerializedSizeAndTurnOnMaskBit(frames);
211   auto combined_buffer = base::MakeRefCounted<IOBufferWithSize>(total_size);
212 
213   base::span<uint8_t> dest = combined_buffer->span();
214   for (const auto& frame : *frames) {
215     net_log_.AddEvent(net::NetLogEventType::WEBSOCKET_SENT_FRAME_HEADER,
216                       [&] { return NetLogFrameHeaderParam(&frame->header); });
217     WebSocketMaskingKey mask = generate_websocket_masking_key_();
218     int result = WriteWebSocketFrameHeader(frame->header, &mask, dest);
219     DCHECK_NE(ERR_INVALID_ARGUMENT, result)
220         << "WriteWebSocketFrameHeader() says that " << dest.size()
221         << " is not enough to write the header in. This should not happen.";
222     dest = dest.subspan(base::checked_cast<size_t>(result));
223 
224     CHECK_LE(frame->header.payload_length,
225              base::checked_cast<uint64_t>(dest.size()));
226     const size_t frame_size = frame->header.payload_length;
227     if (frame_size > 0) {
228       dest.copy_prefix_from(frame->payload);
229       MaskWebSocketFramePayload(mask, 0, dest.first(frame_size));
230       dest = dest.subspan(frame_size);
231     }
232   }
233   DCHECK(dest.empty()) << "Buffer size calculation was wrong; " << dest.size()
234                        << " bytes left over.";
235   auto drainable_buffer = base::MakeRefCounted<DrainableIOBuffer>(
236       std::move(combined_buffer), total_size);
237   return WriteEverything(drainable_buffer);
238 }
239 
Close()240 void WebSocketBasicStream::Close() {
241   connection_->Disconnect();
242 }
243 
GetSubProtocol() const244 std::string WebSocketBasicStream::GetSubProtocol() const {
245   return sub_protocol_;
246 }
247 
GetExtensions() const248 std::string WebSocketBasicStream::GetExtensions() const { return extensions_; }
249 
GetNetLogWithSource() const250 const NetLogWithSource& WebSocketBasicStream::GetNetLogWithSource() const {
251   return net_log_;
252 }
253 
254 /*static*/
255 std::unique_ptr<WebSocketBasicStream>
CreateWebSocketBasicStreamForTesting(std::unique_ptr<ClientSocketHandle> connection,const scoped_refptr<GrowableIOBuffer> & http_read_buffer,const std::string & sub_protocol,const std::string & extensions,const NetLogWithSource & net_log,WebSocketMaskingKeyGeneratorFunction key_generator_function)256 WebSocketBasicStream::CreateWebSocketBasicStreamForTesting(
257     std::unique_ptr<ClientSocketHandle> connection,
258     const scoped_refptr<GrowableIOBuffer>& http_read_buffer,
259     const std::string& sub_protocol,
260     const std::string& extensions,
261     const NetLogWithSource& net_log,
262     WebSocketMaskingKeyGeneratorFunction key_generator_function) {
263   auto stream = std::make_unique<WebSocketBasicStream>(
264       std::make_unique<WebSocketClientSocketHandleAdapter>(
265           std::move(connection)),
266       http_read_buffer, sub_protocol, extensions, net_log);
267   stream->generate_websocket_masking_key_ = key_generator_function;
268   return stream;
269 }
270 
ReadEverything(std::vector<std::unique_ptr<WebSocketFrame>> * frames)271 int WebSocketBasicStream::ReadEverything(
272     std::vector<std::unique_ptr<WebSocketFrame>>* frames) {
273   DCHECK(frames->empty());
274 
275   // If there is data left over after parsing the HTTP headers, attempt to parse
276   // it as WebSocket frames.
277   if (http_read_buffer_.get() && !is_http_read_buffer_decoded_) {
278     DCHECK_GE(http_read_buffer_->offset(), 0);
279     is_http_read_buffer_decoded_ = true;
280     std::vector<std::unique_ptr<WebSocketFrameChunk>> frame_chunks;
281     if (!parser_.Decode(http_read_buffer_->span_before_offset(),
282                         &frame_chunks)) {
283       return WebSocketErrorToNetError(parser_.websocket_error());
284     }
285     if (!frame_chunks.empty()) {
286       int result = ConvertChunksToFrames(&frame_chunks, frames);
287       if (result != ERR_IO_PENDING)
288         return result;
289     }
290   }
291 
292   // Run until socket stops giving us data or we get some frames.
293   while (true) {
294     if (buffer_size_manager_.buffer_size() != buffer_size_) {
295       read_buffer_ = base::MakeRefCounted<IOBufferWithSize>(
296           buffer_size_manager_.buffer_size() == BufferSize::kSmall
297               ? kSmallReadBufferSize
298               : kLargeReadBufferSize);
299       buffer_size_ = buffer_size_manager_.buffer_size();
300       net_log_.AddEvent(
301           net::NetLogEventType::WEBSOCKET_READ_BUFFER_SIZE_CHANGED,
302           [&] { return NetLogBufferSizeParam(read_buffer_->size()); });
303     }
304     buffer_size_manager_.OnRead(base::TimeTicks::Now());
305 
306     // base::Unretained(this) here is safe because net::Socket guarantees not to
307     // call any callbacks after Disconnect(), which we call from the destructor.
308     // The caller of ReadEverything() is required to keep |frames| valid.
309     int result = connection_->Read(
310         read_buffer_.get(), read_buffer_->size(),
311         base::BindOnce(&WebSocketBasicStream::OnReadComplete,
312                        base::Unretained(this), base::Unretained(frames)));
313     if (result == ERR_IO_PENDING)
314       return result;
315     result = HandleReadResult(result, frames);
316     if (result != ERR_IO_PENDING)
317       return result;
318     DCHECK(frames->empty());
319   }
320 }
321 
OnReadComplete(std::vector<std::unique_ptr<WebSocketFrame>> * frames,int result)322 void WebSocketBasicStream::OnReadComplete(
323     std::vector<std::unique_ptr<WebSocketFrame>>* frames,
324     int result) {
325   result = HandleReadResult(result, frames);
326   if (result == ERR_IO_PENDING)
327     result = ReadEverything(frames);
328   if (result != ERR_IO_PENDING)
329     std::move(read_callback_).Run(result);
330 }
331 
WriteEverything(const scoped_refptr<DrainableIOBuffer> & buffer)332 int WebSocketBasicStream::WriteEverything(
333     const scoped_refptr<DrainableIOBuffer>& buffer) {
334   while (buffer->BytesRemaining() > 0) {
335     // The use of base::Unretained() here is safe because on destruction we
336     // disconnect the socket, preventing any further callbacks.
337     int result = connection_->Write(
338         buffer.get(), buffer->BytesRemaining(),
339         base::BindOnce(&WebSocketBasicStream::OnWriteComplete,
340                        base::Unretained(this), buffer),
341         kTrafficAnnotation);
342     if (result > 0) {
343       buffer->DidConsume(result);
344     } else {
345       return result;
346     }
347   }
348   return OK;
349 }
350 
OnWriteComplete(const scoped_refptr<DrainableIOBuffer> & buffer,int result)351 void WebSocketBasicStream::OnWriteComplete(
352     const scoped_refptr<DrainableIOBuffer>& buffer,
353     int result) {
354   if (result < 0) {
355     DCHECK_NE(ERR_IO_PENDING, result);
356     std::move(write_callback_).Run(result);
357     return;
358   }
359 
360   DCHECK_NE(0, result);
361 
362   buffer->DidConsume(result);
363   result = WriteEverything(buffer);
364   if (result != ERR_IO_PENDING)
365     std::move(write_callback_).Run(result);
366 }
367 
HandleReadResult(int result,std::vector<std::unique_ptr<WebSocketFrame>> * frames)368 int WebSocketBasicStream::HandleReadResult(
369     int result,
370     std::vector<std::unique_ptr<WebSocketFrame>>* frames) {
371   DCHECK_NE(ERR_IO_PENDING, result);
372   DCHECK(frames->empty());
373   if (result < 0)
374     return result;
375   if (result == 0)
376     return ERR_CONNECTION_CLOSED;
377 
378   buffer_size_manager_.OnReadComplete(base::TimeTicks::Now(), result);
379 
380   std::vector<std::unique_ptr<WebSocketFrameChunk>> frame_chunks;
381   if (!parser_.Decode(
382           read_buffer_->span().first(base::checked_cast<size_t>(result)),
383           &frame_chunks)) {
384     return WebSocketErrorToNetError(parser_.websocket_error());
385   }
386   if (frame_chunks.empty())
387     return ERR_IO_PENDING;
388   return ConvertChunksToFrames(&frame_chunks, frames);
389 }
390 
ConvertChunksToFrames(std::vector<std::unique_ptr<WebSocketFrameChunk>> * frame_chunks,std::vector<std::unique_ptr<WebSocketFrame>> * frames)391 int WebSocketBasicStream::ConvertChunksToFrames(
392     std::vector<std::unique_ptr<WebSocketFrameChunk>>* frame_chunks,
393     std::vector<std::unique_ptr<WebSocketFrame>>* frames) {
394   for (auto& chunk : *frame_chunks) {
395     DCHECK(chunk == frame_chunks->back() || chunk->final_chunk)
396         << "Only last chunk can have |final_chunk| set to be false.";
397 
398     if (chunk->header) {
399       net_log_.AddEvent(net::NetLogEventType::WEBSOCKET_RECV_FRAME_HEADER, [&] {
400         return NetLogFrameHeaderParam(chunk->header.get());
401       });
402     }
403 
404     auto frame_result = chunk_assembler_.HandleChunk(std::move(chunk));
405 
406     if (!frame_result.has_value()) {
407       return frame_result.error();
408     }
409 
410     auto frame = std::move(frame_result.value());
411     bool is_control_opcode =
412         WebSocketFrameHeader::IsKnownControlOpCode(frame->header.opcode) ||
413         WebSocketFrameHeader::IsReservedControlOpCode(frame->header.opcode);
414     if (is_control_opcode) {
415       const size_t length =
416           base::checked_cast<size_t>(frame->header.payload_length);
417       if (length > 0) {
418         auto copied_payload =
419             base::HeapArray<uint8_t>::CopiedFrom(frame->payload);
420         frame->payload = copied_payload.as_span();
421         control_frame_payloads_.emplace_back(std::move(copied_payload));
422       }
423     }
424 
425     frames->emplace_back(std::move(frame));
426   }
427 
428   frame_chunks->clear();
429 
430   return frames->empty() ? ERR_IO_PENDING : OK;
431 }
432 
433 }  // namespace net
434