1 // Copyright 2014 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #ifdef UNSAFE_BUFFERS_BUILD
6 // TODO(crbug.com/40284755): Remove this and spanify to fix the errors.
7 #pragma allow_unsafe_buffers
8 #endif
9
10 #include "net/server/web_socket_encoder.h"
11
12 #include <limits>
13 #include <string_view>
14 #include <utility>
15
16 #include "base/check.h"
17 #include "base/memory/ptr_util.h"
18 #include "base/strings/strcat.h"
19 #include "base/strings/string_number_conversions.h"
20 #include "net/base/io_buffer.h"
21 #include "net/base/net_export.h"
22 #include "net/websockets/websocket_deflate_parameters.h"
23 #include "net/websockets/websocket_extension.h"
24 #include "net/websockets/websocket_extension_parser.h"
25 #include "net/websockets/websocket_frame.h"
26
27 namespace net {
28
// Extension offer for client-side encoders: permessage-deflate with the
// client_max_window_bits parameter (no value), so the server may choose the
// client's window size. Presumably sent by callers as the
// Sec-WebSocket-Extensions request header — TODO confirm at call sites.
NET_EXPORT
const char WebSocketEncoder::kClientExtensions[] =
    "permessage-deflate; client_max_window_bits";
32
33 namespace {
34
// Chunk size passed for both size arguments of every WebSocketInflater this
// file creates.
constexpr int kInflaterChunkSize = 16 * 1024;

// Frame header layout (hybi-10 and later). First header byte: FIN, RSV1-3 and
// the 4-bit opcode.
constexpr unsigned char kFinalBit = 0x80;
constexpr unsigned char kReserved1Bit = 0x40;
constexpr unsigned char kReserved2Bit = 0x20;
constexpr unsigned char kReserved3Bit = 0x10;
constexpr unsigned char kOpCodeMask = 0xF;

// Second header byte: MASK flag and the 7-bit payload length.
constexpr unsigned char kMaskBit = 0x80;
constexpr unsigned char kPayloadLengthMask = 0x7F;

// 7-bit length values: 0..125 are literal lengths; 126 and 127 announce a
// following 2-byte or 8-byte extended length respectively.
constexpr size_t kMaxSingleBytePayloadLength = 125;
constexpr size_t kTwoBytePayloadLengthField = 126;
constexpr size_t kEightBytePayloadLengthField = 127;

// Size of the masking key that follows the length fields of a masked frame.
constexpr size_t kMaskingKeyWidthInBytes = 4;
51
DecodeFrameHybi17(std::string_view frame,bool client_frame,int * bytes_consumed,std::string * output,bool * compressed)52 WebSocketParseResult DecodeFrameHybi17(std::string_view frame,
53 bool client_frame,
54 int* bytes_consumed,
55 std::string* output,
56 bool* compressed) {
57 size_t data_length = frame.length();
58 if (data_length < 2)
59 return WebSocketParseResult::FRAME_INCOMPLETE;
60
61 const char* buffer_begin = const_cast<char*>(frame.data());
62 const char* p = buffer_begin;
63 const char* buffer_end = p + data_length;
64
65 unsigned char first_byte = *p++;
66 unsigned char second_byte = *p++;
67
68 bool final = (first_byte & kFinalBit) != 0;
69 bool reserved1 = (first_byte & kReserved1Bit) != 0;
70 bool reserved2 = (first_byte & kReserved2Bit) != 0;
71 bool reserved3 = (first_byte & kReserved3Bit) != 0;
72 int op_code = first_byte & kOpCodeMask;
73 bool masked = (second_byte & kMaskBit) != 0;
74 *compressed = reserved1;
75 if (reserved2 || reserved3)
76 return WebSocketParseResult::FRAME_ERROR; // Only compression extension is
77 // supported.
78
79 bool closed = false;
80 switch (op_code) {
81 case WebSocketFrameHeader::OpCodeEnum::kOpCodeClose:
82 closed = true;
83 break;
84
85 case WebSocketFrameHeader::OpCodeEnum::kOpCodeText:
86 case WebSocketFrameHeader::OpCodeEnum::
87 kOpCodeContinuation: // Treated in the same as kOpCodeText.
88 case WebSocketFrameHeader::OpCodeEnum::kOpCodePing:
89 case WebSocketFrameHeader::OpCodeEnum::kOpCodePong:
90 break;
91
92 case WebSocketFrameHeader::OpCodeEnum::kOpCodeBinary: // We don't support
93 // binary frames yet.
94 default:
95 return WebSocketParseResult::FRAME_ERROR;
96 }
97
98 if (client_frame && !masked) // In Hybi-17 spec client MUST mask its frame.
99 return WebSocketParseResult::FRAME_ERROR;
100
101 uint64_t payload_length64 = second_byte & kPayloadLengthMask;
102 if (payload_length64 > kMaxSingleBytePayloadLength) {
103 int extended_payload_length_size;
104 if (payload_length64 == kTwoBytePayloadLengthField) {
105 extended_payload_length_size = 2;
106 } else {
107 DCHECK(payload_length64 == kEightBytePayloadLengthField);
108 extended_payload_length_size = 8;
109 }
110 if (buffer_end - p < extended_payload_length_size)
111 return WebSocketParseResult::FRAME_INCOMPLETE;
112 payload_length64 = 0;
113 for (int i = 0; i < extended_payload_length_size; ++i) {
114 payload_length64 <<= 8;
115 payload_length64 |= static_cast<unsigned char>(*p++);
116 }
117 }
118
119 size_t actual_masking_key_length = masked ? kMaskingKeyWidthInBytes : 0;
120 static const uint64_t max_payload_length = 0x7FFFFFFFFFFFFFFFull;
121 static size_t max_length = std::numeric_limits<size_t>::max();
122 if (payload_length64 > max_payload_length ||
123 payload_length64 + actual_masking_key_length > max_length) {
124 // WebSocket frame length too large.
125 return WebSocketParseResult::FRAME_ERROR;
126 }
127 size_t payload_length = static_cast<size_t>(payload_length64);
128
129 size_t total_length = actual_masking_key_length + payload_length;
130 if (static_cast<size_t>(buffer_end - p) < total_length)
131 return WebSocketParseResult::FRAME_INCOMPLETE;
132
133 if (masked) {
134 output->resize(payload_length);
135 const char* masking_key = p;
136 char* payload = const_cast<char*>(p + kMaskingKeyWidthInBytes);
137 for (size_t i = 0; i < payload_length; ++i) // Unmask the payload.
138 (*output)[i] = payload[i] ^ masking_key[i % kMaskingKeyWidthInBytes];
139 } else {
140 output->assign(p, p + payload_length);
141 }
142
143 size_t pos = p + actual_masking_key_length + payload_length - buffer_begin;
144 *bytes_consumed = pos;
145
146 if (op_code == WebSocketFrameHeader::OpCodeEnum::kOpCodePing)
147 return WebSocketParseResult::FRAME_PING;
148
149 if (op_code == WebSocketFrameHeader::OpCodeEnum::kOpCodePong)
150 return WebSocketParseResult::FRAME_PONG;
151
152 if (closed)
153 return WebSocketParseResult::FRAME_CLOSE;
154
155 return final ? WebSocketParseResult::FRAME_OK_FINAL
156 : WebSocketParseResult::FRAME_OK_MIDDLE;
157 }
158
EncodeFrameHybi17(std::string_view message,int masking_key,bool compressed,WebSocketFrameHeader::OpCodeEnum op_code,std::string * output)159 void EncodeFrameHybi17(std::string_view message,
160 int masking_key,
161 bool compressed,
162 WebSocketFrameHeader::OpCodeEnum op_code,
163 std::string* output) {
164 std::vector<char> frame;
165 size_t data_length = message.length();
166
167 int reserved1 = compressed ? kReserved1Bit : 0;
168 frame.push_back(kFinalBit | op_code | reserved1);
169 char mask_key_bit = masking_key != 0 ? kMaskBit : 0;
170 if (data_length <= kMaxSingleBytePayloadLength) {
171 frame.push_back(static_cast<char>(data_length) | mask_key_bit);
172 } else if (data_length <= 0xFFFF) {
173 frame.push_back(kTwoBytePayloadLengthField | mask_key_bit);
174 frame.push_back((data_length & 0xFF00) >> 8);
175 frame.push_back(data_length & 0xFF);
176 } else {
177 frame.push_back(kEightBytePayloadLengthField | mask_key_bit);
178 char extended_payload_length[8];
179 size_t remaining = data_length;
180 // Fill the length into extended_payload_length in the network byte order.
181 for (int i = 0; i < 8; ++i) {
182 extended_payload_length[7 - i] = remaining & 0xFF;
183 remaining >>= 8;
184 }
185 frame.insert(frame.end(), extended_payload_length,
186 extended_payload_length + 8);
187 DCHECK(!remaining);
188 }
189
190 const char* data = const_cast<char*>(message.data());
191 if (masking_key != 0) {
192 const char* mask_bytes = reinterpret_cast<char*>(&masking_key);
193 frame.insert(frame.end(), mask_bytes, mask_bytes + 4);
194 for (size_t i = 0; i < data_length; ++i) // Mask the payload.
195 frame.push_back(data[i] ^ mask_bytes[i % kMaskingKeyWidthInBytes]);
196 } else {
197 frame.insert(frame.end(), data, data + data_length);
198 }
199 *output = std::string(frame.data(), frame.size());
200 }
201
202 } // anonymous namespace
203
204 // static
CreateServer()205 std::unique_ptr<WebSocketEncoder> WebSocketEncoder::CreateServer() {
206 return base::WrapUnique(new WebSocketEncoder(FOR_SERVER, nullptr, nullptr));
207 }
208
209 // static
CreateServer(const std::string & extensions,WebSocketDeflateParameters * deflate_parameters)210 std::unique_ptr<WebSocketEncoder> WebSocketEncoder::CreateServer(
211 const std::string& extensions,
212 WebSocketDeflateParameters* deflate_parameters) {
213 WebSocketExtensionParser parser;
214 if (!parser.Parse(extensions)) {
215 // Failed to parse Sec-WebSocket-Extensions header. We MUST fail the
216 // connection.
217 return nullptr;
218 }
219
220 for (const auto& extension : parser.extensions()) {
221 std::string failure_message;
222 WebSocketDeflateParameters offer;
223 if (!offer.Initialize(extension, &failure_message) ||
224 !offer.IsValidAsRequest(&failure_message)) {
225 // We decline unknown / malformed extensions.
226 continue;
227 }
228
229 WebSocketDeflateParameters response = offer;
230 if (offer.is_client_max_window_bits_specified() &&
231 !offer.has_client_max_window_bits_value()) {
232 // We need to choose one value for the response.
233 response.SetClientMaxWindowBits(15);
234 }
235 DCHECK(response.IsValidAsResponse());
236 DCHECK(offer.IsCompatibleWith(response));
237 auto deflater = std::make_unique<WebSocketDeflater>(
238 response.server_context_take_over_mode());
239 auto inflater = std::make_unique<WebSocketInflater>(kInflaterChunkSize,
240 kInflaterChunkSize);
241 if (!deflater->Initialize(response.PermissiveServerMaxWindowBits()) ||
242 !inflater->Initialize(response.PermissiveClientMaxWindowBits())) {
243 // For some reason we cannot accept the parameters.
244 continue;
245 }
246 *deflate_parameters = response;
247 return base::WrapUnique(new WebSocketEncoder(
248 FOR_SERVER, std::move(deflater), std::move(inflater)));
249 }
250
251 // We cannot find an acceptable offer.
252 return base::WrapUnique(new WebSocketEncoder(FOR_SERVER, nullptr, nullptr));
253 }
254
255 // static
CreateClient(const std::string & response_extensions)256 std::unique_ptr<WebSocketEncoder> WebSocketEncoder::CreateClient(
257 const std::string& response_extensions) {
258 // TODO(yhirano): Add a way to return an error.
259
260 WebSocketExtensionParser parser;
261 if (!parser.Parse(response_extensions)) {
262 // Parse error. Note that there are two cases here.
263 // 1) There is no Sec-WebSocket-Extensions header.
264 // 2) There is a malformed Sec-WebSocketExtensions header.
265 // We should return a deflate-disabled encoder for the former case and
266 // fail the connection for the latter case.
267 return base::WrapUnique(new WebSocketEncoder(FOR_CLIENT, nullptr, nullptr));
268 }
269 if (parser.extensions().size() != 1) {
270 // Only permessage-deflate extension is supported.
271 // TODO (yhirano): Fail the connection.
272 return base::WrapUnique(new WebSocketEncoder(FOR_CLIENT, nullptr, nullptr));
273 }
274 const auto& extension = parser.extensions()[0];
275 WebSocketDeflateParameters params;
276 std::string failure_message;
277 if (!params.Initialize(extension, &failure_message) ||
278 !params.IsValidAsResponse(&failure_message)) {
279 // TODO (yhirano): Fail the connection.
280 return base::WrapUnique(new WebSocketEncoder(FOR_CLIENT, nullptr, nullptr));
281 }
282
283 auto deflater = std::make_unique<WebSocketDeflater>(
284 params.client_context_take_over_mode());
285 auto inflater = std::make_unique<WebSocketInflater>(kInflaterChunkSize,
286 kInflaterChunkSize);
287 if (!deflater->Initialize(params.PermissiveClientMaxWindowBits()) ||
288 !inflater->Initialize(params.PermissiveServerMaxWindowBits())) {
289 // TODO (yhirano): Fail the connection.
290 return base::WrapUnique(new WebSocketEncoder(FOR_CLIENT, nullptr, nullptr));
291 }
292
293 return base::WrapUnique(new WebSocketEncoder(FOR_CLIENT, std::move(deflater),
294 std::move(inflater)));
295 }
296
WebSocketEncoder(Type type,std::unique_ptr<WebSocketDeflater> deflater,std::unique_ptr<WebSocketInflater> inflater)297 WebSocketEncoder::WebSocketEncoder(Type type,
298 std::unique_ptr<WebSocketDeflater> deflater,
299 std::unique_ptr<WebSocketInflater> inflater)
300 : type_(type),
301 deflater_(std::move(deflater)),
302 inflater_(std::move(inflater)) {}
303
304 WebSocketEncoder::~WebSocketEncoder() = default;
305
DecodeFrame(std::string_view frame,int * bytes_consumed,std::string * output)306 WebSocketParseResult WebSocketEncoder::DecodeFrame(std::string_view frame,
307 int* bytes_consumed,
308 std::string* output) {
309 bool compressed;
310 std::string current_output;
311 WebSocketParseResult result = DecodeFrameHybi17(
312 frame, type_ == FOR_SERVER, bytes_consumed, ¤t_output, &compressed);
313 switch (result) {
314 case WebSocketParseResult::FRAME_OK_FINAL:
315 case WebSocketParseResult::FRAME_OK_MIDDLE: {
316 if (continuation_message_frames_.empty())
317 is_current_message_compressed_ = compressed;
318 continuation_message_frames_.push_back(current_output);
319
320 if (result == WebSocketParseResult::FRAME_OK_FINAL) {
321 *output = base::StrCat(continuation_message_frames_);
322 continuation_message_frames_.clear();
323 if (is_current_message_compressed_ && !Inflate(output)) {
324 return WebSocketParseResult::FRAME_ERROR;
325 }
326 }
327 break;
328 }
329
330 case WebSocketParseResult::FRAME_PING:
331 *output = current_output;
332 break;
333
334 default:
335 // This function doesn't need special handling for other parse results.
336 break;
337 }
338
339 return result;
340 }
341
EncodeTextFrame(std::string_view frame,int masking_key,std::string * output)342 void WebSocketEncoder::EncodeTextFrame(std::string_view frame,
343 int masking_key,
344 std::string* output) {
345 std::string compressed;
346 constexpr auto op_code = WebSocketFrameHeader::OpCodeEnum::kOpCodeText;
347 if (Deflate(frame, &compressed))
348 EncodeFrameHybi17(compressed, masking_key, true, op_code, output);
349 else
350 EncodeFrameHybi17(frame, masking_key, false, op_code, output);
351 }
352
EncodeCloseFrame(std::string_view frame,int masking_key,std::string * output)353 void WebSocketEncoder::EncodeCloseFrame(std::string_view frame,
354 int masking_key,
355 std::string* output) {
356 constexpr auto op_code = WebSocketFrameHeader::OpCodeEnum::kOpCodeClose;
357 EncodeFrameHybi17(frame, masking_key, false, op_code, output);
358 }
359
EncodePongFrame(std::string_view frame,int masking_key,std::string * output)360 void WebSocketEncoder::EncodePongFrame(std::string_view frame,
361 int masking_key,
362 std::string* output) {
363 constexpr auto op_code = WebSocketFrameHeader::OpCodeEnum::kOpCodePong;
364 EncodeFrameHybi17(frame, masking_key, false, op_code, output);
365 }
366
Inflate(std::string * message)367 bool WebSocketEncoder::Inflate(std::string* message) {
368 if (!inflater_)
369 return false;
370 if (!inflater_->AddBytes(message->data(), message->length()))
371 return false;
372 if (!inflater_->Finish())
373 return false;
374
375 std::vector<char> output;
376 while (inflater_->CurrentOutputSize() > 0) {
377 scoped_refptr<IOBufferWithSize> chunk =
378 inflater_->GetOutput(inflater_->CurrentOutputSize());
379 if (!chunk.get())
380 return false;
381 output.insert(output.end(), chunk->data(), chunk->data() + chunk->size());
382 }
383
384 *message =
385 output.size() ? std::string(output.data(), output.size()) : std::string();
386 return true;
387 }
388
Deflate(std::string_view message,std::string * output)389 bool WebSocketEncoder::Deflate(std::string_view message, std::string* output) {
390 if (!deflater_)
391 return false;
392 if (!deflater_->AddBytes(message.data(), message.length())) {
393 deflater_->Finish();
394 return false;
395 }
396 if (!deflater_->Finish())
397 return false;
398 scoped_refptr<IOBufferWithSize> buffer =
399 deflater_->GetOutput(deflater_->CurrentOutputSize());
400 if (!buffer.get())
401 return false;
402 *output = std::string(buffer->data(), buffer->size());
403 return true;
404 }
405
406 } // namespace net
407