• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2014 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #ifdef UNSAFE_BUFFERS_BUILD
6 // TODO(crbug.com/40284755): Remove this and spanify to fix the errors.
7 #pragma allow_unsafe_buffers
8 #endif
9 
10 #include "net/server/web_socket_encoder.h"
11 
12 #include <limits>
13 #include <string_view>
14 #include <utility>
15 
16 #include "base/check.h"
17 #include "base/memory/ptr_util.h"
18 #include "base/strings/strcat.h"
19 #include "base/strings/string_number_conversions.h"
20 #include "net/base/io_buffer.h"
21 #include "net/base/net_export.h"
22 #include "net/websockets/websocket_deflate_parameters.h"
23 #include "net/websockets/websocket_extension.h"
24 #include "net/websockets/websocket_extension_parser.h"
25 #include "net/websockets/websocket_frame.h"
26 
27 namespace net {
28 
29 NET_EXPORT
30 const char WebSocketEncoder::kClientExtensions[] =
31     "permessage-deflate; client_max_window_bits";
32 
33 namespace {
34 
// Chunk size passed to every WebSocketInflater created in this file (used for
// both of its constructor arguments).
constexpr int kInflaterChunkSize = 16 * 1024;

// Constants for hybi-10 frame format.

// Bits of the first header byte: FIN, RSV1-3 and the 4-bit opcode.
constexpr unsigned char kFinalBit = 0x80;
constexpr unsigned char kReserved1Bit = 0x40;
constexpr unsigned char kReserved2Bit = 0x20;
constexpr unsigned char kReserved3Bit = 0x10;
constexpr unsigned char kOpCodeMask = 0xF;

// Bits of the second header byte: MASK and the 7-bit payload length.
constexpr unsigned char kMaskBit = 0x80;
constexpr unsigned char kPayloadLengthMask = 0x7F;

// Interpretation of the 7-bit length: values up to 125 are the length itself,
// 126 announces a 2-byte extended length and 127 an 8-byte one.
constexpr size_t kMaxSingleBytePayloadLength = 125;
constexpr size_t kTwoBytePayloadLengthField = 126;
constexpr size_t kEightBytePayloadLengthField = 127;

// Size of the masking key that follows the length when MASK is set.
constexpr size_t kMaskingKeyWidthInBytes = 4;
51 
DecodeFrameHybi17(std::string_view frame,bool client_frame,int * bytes_consumed,std::string * output,bool * compressed)52 WebSocketParseResult DecodeFrameHybi17(std::string_view frame,
53                                        bool client_frame,
54                                        int* bytes_consumed,
55                                        std::string* output,
56                                        bool* compressed) {
57   size_t data_length = frame.length();
58   if (data_length < 2)
59     return WebSocketParseResult::FRAME_INCOMPLETE;
60 
61   const char* buffer_begin = const_cast<char*>(frame.data());
62   const char* p = buffer_begin;
63   const char* buffer_end = p + data_length;
64 
65   unsigned char first_byte = *p++;
66   unsigned char second_byte = *p++;
67 
68   bool final = (first_byte & kFinalBit) != 0;
69   bool reserved1 = (first_byte & kReserved1Bit) != 0;
70   bool reserved2 = (first_byte & kReserved2Bit) != 0;
71   bool reserved3 = (first_byte & kReserved3Bit) != 0;
72   int op_code = first_byte & kOpCodeMask;
73   bool masked = (second_byte & kMaskBit) != 0;
74   *compressed = reserved1;
75   if (reserved2 || reserved3)
76     return WebSocketParseResult::FRAME_ERROR;  // Only compression extension is
77                                                // supported.
78 
79   bool closed = false;
80   switch (op_code) {
81     case WebSocketFrameHeader::OpCodeEnum::kOpCodeClose:
82       closed = true;
83       break;
84 
85     case WebSocketFrameHeader::OpCodeEnum::kOpCodeText:
86     case WebSocketFrameHeader::OpCodeEnum::
87         kOpCodeContinuation:  // Treated in the same as kOpCodeText.
88     case WebSocketFrameHeader::OpCodeEnum::kOpCodePing:
89     case WebSocketFrameHeader::OpCodeEnum::kOpCodePong:
90       break;
91 
92     case WebSocketFrameHeader::OpCodeEnum::kOpCodeBinary:  // We don't support
93                                                            // binary frames yet.
94     default:
95       return WebSocketParseResult::FRAME_ERROR;
96   }
97 
98   if (client_frame && !masked)  // In Hybi-17 spec client MUST mask its frame.
99     return WebSocketParseResult::FRAME_ERROR;
100 
101   uint64_t payload_length64 = second_byte & kPayloadLengthMask;
102   if (payload_length64 > kMaxSingleBytePayloadLength) {
103     int extended_payload_length_size;
104     if (payload_length64 == kTwoBytePayloadLengthField) {
105       extended_payload_length_size = 2;
106     } else {
107       DCHECK(payload_length64 == kEightBytePayloadLengthField);
108       extended_payload_length_size = 8;
109     }
110     if (buffer_end - p < extended_payload_length_size)
111       return WebSocketParseResult::FRAME_INCOMPLETE;
112     payload_length64 = 0;
113     for (int i = 0; i < extended_payload_length_size; ++i) {
114       payload_length64 <<= 8;
115       payload_length64 |= static_cast<unsigned char>(*p++);
116     }
117   }
118 
119   size_t actual_masking_key_length = masked ? kMaskingKeyWidthInBytes : 0;
120   static const uint64_t max_payload_length = 0x7FFFFFFFFFFFFFFFull;
121   static size_t max_length = std::numeric_limits<size_t>::max();
122   if (payload_length64 > max_payload_length ||
123       payload_length64 + actual_masking_key_length > max_length) {
124     // WebSocket frame length too large.
125     return WebSocketParseResult::FRAME_ERROR;
126   }
127   size_t payload_length = static_cast<size_t>(payload_length64);
128 
129   size_t total_length = actual_masking_key_length + payload_length;
130   if (static_cast<size_t>(buffer_end - p) < total_length)
131     return WebSocketParseResult::FRAME_INCOMPLETE;
132 
133   if (masked) {
134     output->resize(payload_length);
135     const char* masking_key = p;
136     char* payload = const_cast<char*>(p + kMaskingKeyWidthInBytes);
137     for (size_t i = 0; i < payload_length; ++i)  // Unmask the payload.
138       (*output)[i] = payload[i] ^ masking_key[i % kMaskingKeyWidthInBytes];
139   } else {
140     output->assign(p, p + payload_length);
141   }
142 
143   size_t pos = p + actual_masking_key_length + payload_length - buffer_begin;
144   *bytes_consumed = pos;
145 
146   if (op_code == WebSocketFrameHeader::OpCodeEnum::kOpCodePing)
147     return WebSocketParseResult::FRAME_PING;
148 
149   if (op_code == WebSocketFrameHeader::OpCodeEnum::kOpCodePong)
150     return WebSocketParseResult::FRAME_PONG;
151 
152   if (closed)
153     return WebSocketParseResult::FRAME_CLOSE;
154 
155   return final ? WebSocketParseResult::FRAME_OK_FINAL
156                : WebSocketParseResult::FRAME_OK_MIDDLE;
157 }
158 
EncodeFrameHybi17(std::string_view message,int masking_key,bool compressed,WebSocketFrameHeader::OpCodeEnum op_code,std::string * output)159 void EncodeFrameHybi17(std::string_view message,
160                        int masking_key,
161                        bool compressed,
162                        WebSocketFrameHeader::OpCodeEnum op_code,
163                        std::string* output) {
164   std::vector<char> frame;
165   size_t data_length = message.length();
166 
167   int reserved1 = compressed ? kReserved1Bit : 0;
168   frame.push_back(kFinalBit | op_code | reserved1);
169   char mask_key_bit = masking_key != 0 ? kMaskBit : 0;
170   if (data_length <= kMaxSingleBytePayloadLength) {
171     frame.push_back(static_cast<char>(data_length) | mask_key_bit);
172   } else if (data_length <= 0xFFFF) {
173     frame.push_back(kTwoBytePayloadLengthField | mask_key_bit);
174     frame.push_back((data_length & 0xFF00) >> 8);
175     frame.push_back(data_length & 0xFF);
176   } else {
177     frame.push_back(kEightBytePayloadLengthField | mask_key_bit);
178     char extended_payload_length[8];
179     size_t remaining = data_length;
180     // Fill the length into extended_payload_length in the network byte order.
181     for (int i = 0; i < 8; ++i) {
182       extended_payload_length[7 - i] = remaining & 0xFF;
183       remaining >>= 8;
184     }
185     frame.insert(frame.end(), extended_payload_length,
186                  extended_payload_length + 8);
187     DCHECK(!remaining);
188   }
189 
190   const char* data = const_cast<char*>(message.data());
191   if (masking_key != 0) {
192     const char* mask_bytes = reinterpret_cast<char*>(&masking_key);
193     frame.insert(frame.end(), mask_bytes, mask_bytes + 4);
194     for (size_t i = 0; i < data_length; ++i)  // Mask the payload.
195       frame.push_back(data[i] ^ mask_bytes[i % kMaskingKeyWidthInBytes]);
196   } else {
197     frame.insert(frame.end(), data, data + data_length);
198   }
199   *output = std::string(frame.data(), frame.size());
200 }
201 
202 }  // anonymous namespace
203 
204 // static
CreateServer()205 std::unique_ptr<WebSocketEncoder> WebSocketEncoder::CreateServer() {
206   return base::WrapUnique(new WebSocketEncoder(FOR_SERVER, nullptr, nullptr));
207 }
208 
209 // static
CreateServer(const std::string & extensions,WebSocketDeflateParameters * deflate_parameters)210 std::unique_ptr<WebSocketEncoder> WebSocketEncoder::CreateServer(
211     const std::string& extensions,
212     WebSocketDeflateParameters* deflate_parameters) {
213   WebSocketExtensionParser parser;
214   if (!parser.Parse(extensions)) {
215     // Failed to parse Sec-WebSocket-Extensions header. We MUST fail the
216     // connection.
217     return nullptr;
218   }
219 
220   for (const auto& extension : parser.extensions()) {
221     std::string failure_message;
222     WebSocketDeflateParameters offer;
223     if (!offer.Initialize(extension, &failure_message) ||
224         !offer.IsValidAsRequest(&failure_message)) {
225       // We decline unknown / malformed extensions.
226       continue;
227     }
228 
229     WebSocketDeflateParameters response = offer;
230     if (offer.is_client_max_window_bits_specified() &&
231         !offer.has_client_max_window_bits_value()) {
232       // We need to choose one value for the response.
233       response.SetClientMaxWindowBits(15);
234     }
235     DCHECK(response.IsValidAsResponse());
236     DCHECK(offer.IsCompatibleWith(response));
237     auto deflater = std::make_unique<WebSocketDeflater>(
238         response.server_context_take_over_mode());
239     auto inflater = std::make_unique<WebSocketInflater>(kInflaterChunkSize,
240                                                         kInflaterChunkSize);
241     if (!deflater->Initialize(response.PermissiveServerMaxWindowBits()) ||
242         !inflater->Initialize(response.PermissiveClientMaxWindowBits())) {
243       // For some reason we cannot accept the parameters.
244       continue;
245     }
246     *deflate_parameters = response;
247     return base::WrapUnique(new WebSocketEncoder(
248         FOR_SERVER, std::move(deflater), std::move(inflater)));
249   }
250 
251   // We cannot find an acceptable offer.
252   return base::WrapUnique(new WebSocketEncoder(FOR_SERVER, nullptr, nullptr));
253 }
254 
255 // static
CreateClient(const std::string & response_extensions)256 std::unique_ptr<WebSocketEncoder> WebSocketEncoder::CreateClient(
257     const std::string& response_extensions) {
258   // TODO(yhirano): Add a way to return an error.
259 
260   WebSocketExtensionParser parser;
261   if (!parser.Parse(response_extensions)) {
262     // Parse error. Note that there are two cases here.
263     // 1) There is no Sec-WebSocket-Extensions header.
264     // 2) There is a malformed Sec-WebSocketExtensions header.
265     // We should return a deflate-disabled encoder for the former case and
266     // fail the connection for the latter case.
267     return base::WrapUnique(new WebSocketEncoder(FOR_CLIENT, nullptr, nullptr));
268   }
269   if (parser.extensions().size() != 1) {
270     // Only permessage-deflate extension is supported.
271     // TODO (yhirano): Fail the connection.
272     return base::WrapUnique(new WebSocketEncoder(FOR_CLIENT, nullptr, nullptr));
273   }
274   const auto& extension = parser.extensions()[0];
275   WebSocketDeflateParameters params;
276   std::string failure_message;
277   if (!params.Initialize(extension, &failure_message) ||
278       !params.IsValidAsResponse(&failure_message)) {
279     // TODO (yhirano): Fail the connection.
280     return base::WrapUnique(new WebSocketEncoder(FOR_CLIENT, nullptr, nullptr));
281   }
282 
283   auto deflater = std::make_unique<WebSocketDeflater>(
284       params.client_context_take_over_mode());
285   auto inflater = std::make_unique<WebSocketInflater>(kInflaterChunkSize,
286                                                       kInflaterChunkSize);
287   if (!deflater->Initialize(params.PermissiveClientMaxWindowBits()) ||
288       !inflater->Initialize(params.PermissiveServerMaxWindowBits())) {
289     // TODO (yhirano): Fail the connection.
290     return base::WrapUnique(new WebSocketEncoder(FOR_CLIENT, nullptr, nullptr));
291   }
292 
293   return base::WrapUnique(new WebSocketEncoder(FOR_CLIENT, std::move(deflater),
294                                                std::move(inflater)));
295 }
296 
WebSocketEncoder(Type type,std::unique_ptr<WebSocketDeflater> deflater,std::unique_ptr<WebSocketInflater> inflater)297 WebSocketEncoder::WebSocketEncoder(Type type,
298                                    std::unique_ptr<WebSocketDeflater> deflater,
299                                    std::unique_ptr<WebSocketInflater> inflater)
300     : type_(type),
301       deflater_(std::move(deflater)),
302       inflater_(std::move(inflater)) {}
303 
304 WebSocketEncoder::~WebSocketEncoder() = default;
305 
DecodeFrame(std::string_view frame,int * bytes_consumed,std::string * output)306 WebSocketParseResult WebSocketEncoder::DecodeFrame(std::string_view frame,
307                                                    int* bytes_consumed,
308                                                    std::string* output) {
309   bool compressed;
310   std::string current_output;
311   WebSocketParseResult result = DecodeFrameHybi17(
312       frame, type_ == FOR_SERVER, bytes_consumed, &current_output, &compressed);
313   switch (result) {
314     case WebSocketParseResult::FRAME_OK_FINAL:
315     case WebSocketParseResult::FRAME_OK_MIDDLE: {
316       if (continuation_message_frames_.empty())
317         is_current_message_compressed_ = compressed;
318       continuation_message_frames_.push_back(current_output);
319 
320       if (result == WebSocketParseResult::FRAME_OK_FINAL) {
321         *output = base::StrCat(continuation_message_frames_);
322         continuation_message_frames_.clear();
323         if (is_current_message_compressed_ && !Inflate(output)) {
324           return WebSocketParseResult::FRAME_ERROR;
325         }
326       }
327       break;
328     }
329 
330     case WebSocketParseResult::FRAME_PING:
331       *output = current_output;
332       break;
333 
334     default:
335       // This function doesn't need special handling for other parse results.
336       break;
337   }
338 
339   return result;
340 }
341 
EncodeTextFrame(std::string_view frame,int masking_key,std::string * output)342 void WebSocketEncoder::EncodeTextFrame(std::string_view frame,
343                                        int masking_key,
344                                        std::string* output) {
345   std::string compressed;
346   constexpr auto op_code = WebSocketFrameHeader::OpCodeEnum::kOpCodeText;
347   if (Deflate(frame, &compressed))
348     EncodeFrameHybi17(compressed, masking_key, true, op_code, output);
349   else
350     EncodeFrameHybi17(frame, masking_key, false, op_code, output);
351 }
352 
EncodeCloseFrame(std::string_view frame,int masking_key,std::string * output)353 void WebSocketEncoder::EncodeCloseFrame(std::string_view frame,
354                                         int masking_key,
355                                         std::string* output) {
356   constexpr auto op_code = WebSocketFrameHeader::OpCodeEnum::kOpCodeClose;
357   EncodeFrameHybi17(frame, masking_key, false, op_code, output);
358 }
359 
EncodePongFrame(std::string_view frame,int masking_key,std::string * output)360 void WebSocketEncoder::EncodePongFrame(std::string_view frame,
361                                        int masking_key,
362                                        std::string* output) {
363   constexpr auto op_code = WebSocketFrameHeader::OpCodeEnum::kOpCodePong;
364   EncodeFrameHybi17(frame, masking_key, false, op_code, output);
365 }
366 
Inflate(std::string * message)367 bool WebSocketEncoder::Inflate(std::string* message) {
368   if (!inflater_)
369     return false;
370   if (!inflater_->AddBytes(message->data(), message->length()))
371     return false;
372   if (!inflater_->Finish())
373     return false;
374 
375   std::vector<char> output;
376   while (inflater_->CurrentOutputSize() > 0) {
377     scoped_refptr<IOBufferWithSize> chunk =
378         inflater_->GetOutput(inflater_->CurrentOutputSize());
379     if (!chunk.get())
380       return false;
381     output.insert(output.end(), chunk->data(), chunk->data() + chunk->size());
382   }
383 
384   *message =
385       output.size() ? std::string(output.data(), output.size()) : std::string();
386   return true;
387 }
388 
Deflate(std::string_view message,std::string * output)389 bool WebSocketEncoder::Deflate(std::string_view message, std::string* output) {
390   if (!deflater_)
391     return false;
392   if (!deflater_->AddBytes(message.data(), message.length())) {
393     deflater_->Finish();
394     return false;
395   }
396   if (!deflater_->Finish())
397     return false;
398   scoped_refptr<IOBufferWithSize> buffer =
399       deflater_->GetOutput(deflater_->CurrentOutputSize());
400   if (!buffer.get())
401     return false;
402   *output = std::string(buffer->data(), buffer->size());
403   return true;
404 }
405 
406 }  // namespace net
407