• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/websockets/websocket_frame_parser.h"
6 
7 #include <algorithm>
8 #include <ostream>
9 #include <utility>
10 #include <vector>
11 
12 #include "base/check.h"
13 #include "base/check_op.h"
14 #include "base/containers/extend.h"
15 #include "base/containers/span.h"
16 #include "base/logging.h"
17 #include "base/numerics/byte_conversions.h"
18 #include "base/numerics/safe_conversions.h"
19 #include "net/websockets/websocket_frame.h"
20 
21 namespace {
22 
23 constexpr uint8_t kFinalBit = 0x80;
24 constexpr uint8_t kReserved1Bit = 0x40;
25 constexpr uint8_t kReserved2Bit = 0x20;
26 constexpr uint8_t kReserved3Bit = 0x10;
27 constexpr uint8_t kOpCodeMask = 0xF;
28 constexpr uint8_t kMaskBit = 0x80;
29 constexpr uint8_t kPayloadLengthMask = 0x7F;
30 constexpr uint64_t kMaxPayloadLengthWithoutExtendedLengthField = 125;
31 constexpr uint64_t kPayloadLengthWithTwoByteExtendedLengthField = 126;
32 constexpr uint64_t kPayloadLengthWithEightByteExtendedLengthField = 127;
33 constexpr size_t kMaximumFrameHeaderSize =
34     net::WebSocketFrameHeader::kBaseHeaderSize +
35     net::WebSocketFrameHeader::kMaximumExtendedLengthSize +
36     net::WebSocketFrameHeader::kMaskingKeyLength;
37 
38 }  // namespace.
39 
40 namespace net {
41 
// Defaulted out of line.
// NOTE(review): the parser's initial state (presumably |websocket_error_| ==
// kWebSocketNormalClosure, since Decode() refuses to run otherwise, and
// |frame_offset_| == 0) must come from in-class member initializers in the
// header — confirm there.
WebSocketFrameParser::WebSocketFrameParser() = default;

// Defaulted out of line; owned state (e.g. |current_frame_header_|) is
// released by the members' own destructors.
WebSocketFrameParser::~WebSocketFrameParser() = default;
45 
Decode(base::span<uint8_t> data_span,std::vector<std::unique_ptr<WebSocketFrameChunk>> * frame_chunks)46 bool WebSocketFrameParser::Decode(
47     base::span<uint8_t> data_span,
48     std::vector<std::unique_ptr<WebSocketFrameChunk>>* frame_chunks) {
49   if (websocket_error_ != kWebSocketNormalClosure) {
50     return false;
51   }
52   if (data_span.empty()) {
53     return true;
54   }
55 
56   // If we have incomplete frame header, try to decode a header combining with
57   // |data|.
58   bool first_chunk = false;
59   if (incomplete_header_buffer_.size() > 0) {
60     DCHECK(!current_frame_header_.get());
61     const size_t original_size = incomplete_header_buffer_.size();
62     DCHECK_LE(original_size, kMaximumFrameHeaderSize);
63     base::Extend(
64         incomplete_header_buffer_,
65         data_span.first(std::min(data_span.size(),
66                                  kMaximumFrameHeaderSize - original_size)));
67     const size_t consumed = DecodeFrameHeader(incomplete_header_buffer_);
68     if (websocket_error_ != kWebSocketNormalClosure)
69       return false;
70     if (!current_frame_header_.get())
71       return true;
72 
73     DCHECK_GE(consumed, original_size);
74     data_span = data_span.subspan(consumed - original_size);
75     incomplete_header_buffer_.clear();
76     first_chunk = true;
77   }
78 
79   DCHECK(incomplete_header_buffer_.empty());
80   while (data_span.size() > 0 || first_chunk) {
81     if (!current_frame_header_.get()) {
82       const size_t consumed = DecodeFrameHeader(data_span);
83       if (websocket_error_ != kWebSocketNormalClosure)
84         return false;
85       // If frame header is incomplete, then carry over the remaining
86       // data to the next round of Decode().
87       if (!current_frame_header_.get()) {
88         DCHECK(!consumed);
89         base::Extend(incomplete_header_buffer_, data_span);
90         // Sanity check: the size of carried-over data should not exceed
91         // the maximum possible length of a frame header.
92         DCHECK_LT(incomplete_header_buffer_.size(), kMaximumFrameHeaderSize);
93         return true;
94       }
95       DCHECK_GE(data_span.size(), consumed);
96       data_span = data_span.subspan(consumed);
97       first_chunk = true;
98     }
99     DCHECK(incomplete_header_buffer_.empty());
100     std::unique_ptr<WebSocketFrameChunk> frame_chunk =
101         DecodeFramePayload(first_chunk, &data_span);
102     first_chunk = false;
103     DCHECK(frame_chunk.get());
104     frame_chunks->push_back(std::move(frame_chunk));
105   }
106   return true;
107 }
108 
DecodeFrameHeader(base::span<const uint8_t> data)109 size_t WebSocketFrameParser::DecodeFrameHeader(base::span<const uint8_t> data) {
110   DVLOG(3) << "DecodeFrameHeader buffer size:"
111            << ", data size:" << data.size();
112   typedef WebSocketFrameHeader::OpCode OpCode;
113   DCHECK(!current_frame_header_.get());
114 
115   // Header needs 2 bytes at minimum.
116   if (data.size() < 2)
117     return 0;
118   size_t current = 0;
119   const uint8_t first_byte = data[current++];
120   const uint8_t second_byte = data[current++];
121 
122   const bool final = (first_byte & kFinalBit) != 0;
123   const bool reserved1 = (first_byte & kReserved1Bit) != 0;
124   const bool reserved2 = (first_byte & kReserved2Bit) != 0;
125   const bool reserved3 = (first_byte & kReserved3Bit) != 0;
126   const OpCode opcode = first_byte & kOpCodeMask;
127 
128   uint64_t payload_length = second_byte & kPayloadLengthMask;
129   if (payload_length == kPayloadLengthWithTwoByteExtendedLengthField) {
130     if (data.size() < current + 2)
131       return 0;
132     uint16_t payload_length_16 =
133         base::U16FromBigEndian(data.subspan(current).first<2>());
134     current += 2;
135     payload_length = payload_length_16;
136     if (payload_length <= kMaxPayloadLengthWithoutExtendedLengthField) {
137       websocket_error_ = kWebSocketErrorProtocolError;
138       return 0;
139     }
140   } else if (payload_length == kPayloadLengthWithEightByteExtendedLengthField) {
141     if (data.size() < current + 8)
142       return 0;
143     payload_length = base::U64FromBigEndian(data.subspan(current).first<8>());
144     current += 8;
145     if (payload_length <= UINT16_MAX ||
146         payload_length > static_cast<uint64_t>(INT64_MAX)) {
147       websocket_error_ = kWebSocketErrorProtocolError;
148       return 0;
149     }
150     if (payload_length > static_cast<uint64_t>(INT32_MAX)) {
151       websocket_error_ = kWebSocketErrorMessageTooBig;
152       return 0;
153     }
154   }
155   DCHECK_EQ(websocket_error_, kWebSocketNormalClosure);
156 
157   WebSocketMaskingKey masking_key = {};
158   const bool masked = (second_byte & kMaskBit) != 0;
159   static constexpr size_t kMaskingKeyLength =
160       WebSocketFrameHeader::kMaskingKeyLength;
161   if (masked) {
162     if (data.size() < current + kMaskingKeyLength)
163       return 0;
164     base::as_writable_byte_span(masking_key.key)
165         .copy_from(data.subspan(current, kMaskingKeyLength));
166     current += kMaskingKeyLength;
167   }
168 
169   current_frame_header_ = std::make_unique<WebSocketFrameHeader>(opcode);
170   current_frame_header_->final = final;
171   current_frame_header_->reserved1 = reserved1;
172   current_frame_header_->reserved2 = reserved2;
173   current_frame_header_->reserved3 = reserved3;
174   current_frame_header_->masked = masked;
175   current_frame_header_->masking_key = masking_key;
176   current_frame_header_->payload_length = payload_length;
177   DCHECK_EQ(0u, frame_offset_);
178   return current;
179 }
180 
DecodeFramePayload(bool first_chunk,base::span<uint8_t> * data)181 std::unique_ptr<WebSocketFrameChunk> WebSocketFrameParser::DecodeFramePayload(
182     bool first_chunk,
183     base::span<uint8_t>* data) {
184   // The cast here is safe because |payload_length| is already checked to be
185   // less than std::numeric_limits<int>::max() when the header is parsed.
186   const auto chunk_data_size = static_cast<uint64_t>(
187       std::min(uint64_t{data->size()},
188                current_frame_header_->payload_length - frame_offset_));
189 
190   auto frame_chunk = std::make_unique<WebSocketFrameChunk>();
191   if (first_chunk) {
192     frame_chunk->header = current_frame_header_->Clone();
193   }
194   frame_chunk->final_chunk = false;
195   if (chunk_data_size) {
196     const auto split_point = base::checked_cast<size_t>(chunk_data_size);
197     frame_chunk->payload = base::as_writable_chars(data->first(split_point));
198     *data = data->subspan(split_point);
199     frame_offset_ += chunk_data_size;
200   }
201 
202   DCHECK_LE(frame_offset_, current_frame_header_->payload_length);
203   if (frame_offset_ == current_frame_header_->payload_length) {
204     frame_chunk->final_chunk = true;
205     current_frame_header_.reset();
206     frame_offset_ = 0;
207   }
208 
209   return frame_chunk;
210 }
211 
212 }  // namespace net
213