• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2014 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/server/web_socket_encoder.h"
6 
7 #include <limits>
8 #include <utility>
9 
10 #include "base/check.h"
11 #include "base/memory/ptr_util.h"
12 #include "base/strings/strcat.h"
13 #include "base/strings/string_number_conversions.h"
14 #include "net/base/io_buffer.h"
15 #include "net/websockets/websocket_deflate_parameters.h"
16 #include "net/websockets/websocket_extension.h"
17 #include "net/websockets/websocket_extension_parser.h"
18 #include "net/websockets/websocket_frame.h"
19 
20 namespace net {
21 
22 const char WebSocketEncoder::kClientExtensions[] =
23     "permessage-deflate; client_max_window_bits";
24 
25 namespace {
26 
// Both chunk-size arguments handed to the WebSocketInflater constructor.
constexpr int kInflaterChunkSize = 16 * 1024;

// Constants for the RFC 6455 (hybi-17) frame format.

// First header byte: FIN flag, the three reserved (RSV) bits, 4-bit opcode.
constexpr unsigned char kFinalBit = 0x80;
constexpr unsigned char kReserved1Bit = 0x40;
constexpr unsigned char kReserved2Bit = 0x20;
constexpr unsigned char kReserved3Bit = 0x10;
constexpr unsigned char kOpCodeMask = 0xF;

// Second header byte: MASK flag and 7-bit payload length field.
constexpr unsigned char kMaskBit = 0x80;
constexpr unsigned char kPayloadLengthMask = 0x7F;

// Values of the 7-bit length field: up to 125 is the length itself, 126 and
// 127 announce a 2-byte or 8-byte extended length respectively.
constexpr size_t kMaxSingleBytePayloadLength = 125;
constexpr size_t kTwoBytePayloadLengthField = 126;
constexpr size_t kEightBytePayloadLengthField = 127;

// Size of the masking key that follows the length in masked frames.
constexpr size_t kMaskingKeyWidthInBytes = 4;
43 
DecodeFrameHybi17(base::StringPiece frame,bool client_frame,int * bytes_consumed,std::string * output,bool * compressed)44 WebSocket::ParseResult DecodeFrameHybi17(base::StringPiece frame,
45                                          bool client_frame,
46                                          int* bytes_consumed,
47                                          std::string* output,
48                                          bool* compressed) {
49   size_t data_length = frame.length();
50   if (data_length < 2)
51     return WebSocket::FRAME_INCOMPLETE;
52 
53   const char* buffer_begin = const_cast<char*>(frame.data());
54   const char* p = buffer_begin;
55   const char* buffer_end = p + data_length;
56 
57   unsigned char first_byte = *p++;
58   unsigned char second_byte = *p++;
59 
60   bool final = (first_byte & kFinalBit) != 0;
61   bool reserved1 = (first_byte & kReserved1Bit) != 0;
62   bool reserved2 = (first_byte & kReserved2Bit) != 0;
63   bool reserved3 = (first_byte & kReserved3Bit) != 0;
64   int op_code = first_byte & kOpCodeMask;
65   bool masked = (second_byte & kMaskBit) != 0;
66   *compressed = reserved1;
67   if (reserved2 || reserved3)
68     return WebSocket::FRAME_ERROR;  // Only compression extension is supported.
69 
70   bool closed = false;
71   switch (op_code) {
72     case WebSocketFrameHeader::OpCodeEnum::kOpCodeClose:
73       closed = true;
74       break;
75 
76     case WebSocketFrameHeader::OpCodeEnum::kOpCodeText:
77     case WebSocketFrameHeader::OpCodeEnum::
78         kOpCodeContinuation:  // Treated in the same as kOpCodeText.
79     case WebSocketFrameHeader::OpCodeEnum::kOpCodePing:
80     case WebSocketFrameHeader::OpCodeEnum::kOpCodePong:
81       break;
82 
83     case WebSocketFrameHeader::OpCodeEnum::kOpCodeBinary:  // We don't support
84                                                            // binary frames yet.
85     default:
86       return WebSocket::FRAME_ERROR;
87   }
88 
89   if (client_frame && !masked)  // In Hybi-17 spec client MUST mask its frame.
90     return WebSocket::FRAME_ERROR;
91 
92   uint64_t payload_length64 = second_byte & kPayloadLengthMask;
93   if (payload_length64 > kMaxSingleBytePayloadLength) {
94     int extended_payload_length_size;
95     if (payload_length64 == kTwoBytePayloadLengthField) {
96       extended_payload_length_size = 2;
97     } else {
98       DCHECK(payload_length64 == kEightBytePayloadLengthField);
99       extended_payload_length_size = 8;
100     }
101     if (buffer_end - p < extended_payload_length_size)
102       return WebSocket::FRAME_INCOMPLETE;
103     payload_length64 = 0;
104     for (int i = 0; i < extended_payload_length_size; ++i) {
105       payload_length64 <<= 8;
106       payload_length64 |= static_cast<unsigned char>(*p++);
107     }
108   }
109 
110   size_t actual_masking_key_length = masked ? kMaskingKeyWidthInBytes : 0;
111   static const uint64_t max_payload_length = 0x7FFFFFFFFFFFFFFFull;
112   static size_t max_length = std::numeric_limits<size_t>::max();
113   if (payload_length64 > max_payload_length ||
114       payload_length64 + actual_masking_key_length > max_length) {
115     // WebSocket frame length too large.
116     return WebSocket::FRAME_ERROR;
117   }
118   size_t payload_length = static_cast<size_t>(payload_length64);
119 
120   size_t total_length = actual_masking_key_length + payload_length;
121   if (static_cast<size_t>(buffer_end - p) < total_length)
122     return WebSocket::FRAME_INCOMPLETE;
123 
124   if (masked) {
125     output->resize(payload_length);
126     const char* masking_key = p;
127     char* payload = const_cast<char*>(p + kMaskingKeyWidthInBytes);
128     for (size_t i = 0; i < payload_length; ++i)  // Unmask the payload.
129       (*output)[i] = payload[i] ^ masking_key[i % kMaskingKeyWidthInBytes];
130   } else {
131     output->assign(p, p + payload_length);
132   }
133 
134   size_t pos = p + actual_masking_key_length + payload_length - buffer_begin;
135   *bytes_consumed = pos;
136 
137   if (op_code == WebSocketFrameHeader::OpCodeEnum::kOpCodePing)
138     return WebSocket::FRAME_PING;
139 
140   if (op_code == WebSocketFrameHeader::OpCodeEnum::kOpCodePong)
141     return WebSocket::FRAME_PONG;
142 
143   if (closed)
144     return WebSocket::FRAME_CLOSE;
145 
146   return final ? WebSocket::FRAME_OK_FINAL : WebSocket::FRAME_OK_MIDDLE;
147 }
148 
EncodeFrameHybi17(base::StringPiece message,int masking_key,bool compressed,WebSocketFrameHeader::OpCodeEnum op_code,std::string * output)149 void EncodeFrameHybi17(base::StringPiece message,
150                        int masking_key,
151                        bool compressed,
152                        WebSocketFrameHeader::OpCodeEnum op_code,
153                        std::string* output) {
154   std::vector<char> frame;
155   size_t data_length = message.length();
156 
157   int reserved1 = compressed ? kReserved1Bit : 0;
158   frame.push_back(kFinalBit | op_code | reserved1);
159   char mask_key_bit = masking_key != 0 ? kMaskBit : 0;
160   if (data_length <= kMaxSingleBytePayloadLength) {
161     frame.push_back(static_cast<char>(data_length) | mask_key_bit);
162   } else if (data_length <= 0xFFFF) {
163     frame.push_back(kTwoBytePayloadLengthField | mask_key_bit);
164     frame.push_back((data_length & 0xFF00) >> 8);
165     frame.push_back(data_length & 0xFF);
166   } else {
167     frame.push_back(kEightBytePayloadLengthField | mask_key_bit);
168     char extended_payload_length[8];
169     size_t remaining = data_length;
170     // Fill the length into extended_payload_length in the network byte order.
171     for (int i = 0; i < 8; ++i) {
172       extended_payload_length[7 - i] = remaining & 0xFF;
173       remaining >>= 8;
174     }
175     frame.insert(frame.end(), extended_payload_length,
176                  extended_payload_length + 8);
177     DCHECK(!remaining);
178   }
179 
180   const char* data = const_cast<char*>(message.data());
181   if (masking_key != 0) {
182     const char* mask_bytes = reinterpret_cast<char*>(&masking_key);
183     frame.insert(frame.end(), mask_bytes, mask_bytes + 4);
184     for (size_t i = 0; i < data_length; ++i)  // Mask the payload.
185       frame.push_back(data[i] ^ mask_bytes[i % kMaskingKeyWidthInBytes]);
186   } else {
187     frame.insert(frame.end(), data, data + data_length);
188   }
189   *output = std::string(frame.data(), frame.size());
190 }
191 
192 }  // anonymous namespace
193 
194 // static
CreateServer()195 std::unique_ptr<WebSocketEncoder> WebSocketEncoder::CreateServer() {
196   return base::WrapUnique(new WebSocketEncoder(FOR_SERVER, nullptr, nullptr));
197 }
198 
199 // static
CreateServer(const std::string & extensions,WebSocketDeflateParameters * deflate_parameters)200 std::unique_ptr<WebSocketEncoder> WebSocketEncoder::CreateServer(
201     const std::string& extensions,
202     WebSocketDeflateParameters* deflate_parameters) {
203   WebSocketExtensionParser parser;
204   if (!parser.Parse(extensions)) {
205     // Failed to parse Sec-WebSocket-Extensions header. We MUST fail the
206     // connection.
207     return nullptr;
208   }
209 
210   for (const auto& extension : parser.extensions()) {
211     std::string failure_message;
212     WebSocketDeflateParameters offer;
213     if (!offer.Initialize(extension, &failure_message) ||
214         !offer.IsValidAsRequest(&failure_message)) {
215       // We decline unknown / malformed extensions.
216       continue;
217     }
218 
219     WebSocketDeflateParameters response = offer;
220     if (offer.is_client_max_window_bits_specified() &&
221         !offer.has_client_max_window_bits_value()) {
222       // We need to choose one value for the response.
223       response.SetClientMaxWindowBits(15);
224     }
225     DCHECK(response.IsValidAsResponse());
226     DCHECK(offer.IsCompatibleWith(response));
227     auto deflater = std::make_unique<WebSocketDeflater>(
228         response.server_context_take_over_mode());
229     auto inflater = std::make_unique<WebSocketInflater>(kInflaterChunkSize,
230                                                         kInflaterChunkSize);
231     if (!deflater->Initialize(response.PermissiveServerMaxWindowBits()) ||
232         !inflater->Initialize(response.PermissiveClientMaxWindowBits())) {
233       // For some reason we cannot accept the parameters.
234       continue;
235     }
236     *deflate_parameters = response;
237     return base::WrapUnique(new WebSocketEncoder(
238         FOR_SERVER, std::move(deflater), std::move(inflater)));
239   }
240 
241   // We cannot find an acceptable offer.
242   return base::WrapUnique(new WebSocketEncoder(FOR_SERVER, nullptr, nullptr));
243 }
244 
245 // static
CreateClient(const std::string & response_extensions)246 std::unique_ptr<WebSocketEncoder> WebSocketEncoder::CreateClient(
247     const std::string& response_extensions) {
248   // TODO(yhirano): Add a way to return an error.
249 
250   WebSocketExtensionParser parser;
251   if (!parser.Parse(response_extensions)) {
252     // Parse error. Note that there are two cases here.
253     // 1) There is no Sec-WebSocket-Extensions header.
254     // 2) There is a malformed Sec-WebSocketExtensions header.
255     // We should return a deflate-disabled encoder for the former case and
256     // fail the connection for the latter case.
257     return base::WrapUnique(new WebSocketEncoder(FOR_CLIENT, nullptr, nullptr));
258   }
259   if (parser.extensions().size() != 1) {
260     // Only permessage-deflate extension is supported.
261     // TODO (yhirano): Fail the connection.
262     return base::WrapUnique(new WebSocketEncoder(FOR_CLIENT, nullptr, nullptr));
263   }
264   const auto& extension = parser.extensions()[0];
265   WebSocketDeflateParameters params;
266   std::string failure_message;
267   if (!params.Initialize(extension, &failure_message) ||
268       !params.IsValidAsResponse(&failure_message)) {
269     // TODO (yhirano): Fail the connection.
270     return base::WrapUnique(new WebSocketEncoder(FOR_CLIENT, nullptr, nullptr));
271   }
272 
273   auto deflater = std::make_unique<WebSocketDeflater>(
274       params.client_context_take_over_mode());
275   auto inflater = std::make_unique<WebSocketInflater>(kInflaterChunkSize,
276                                                       kInflaterChunkSize);
277   if (!deflater->Initialize(params.PermissiveClientMaxWindowBits()) ||
278       !inflater->Initialize(params.PermissiveServerMaxWindowBits())) {
279     // TODO (yhirano): Fail the connection.
280     return base::WrapUnique(new WebSocketEncoder(FOR_CLIENT, nullptr, nullptr));
281   }
282 
283   return base::WrapUnique(new WebSocketEncoder(FOR_CLIENT, std::move(deflater),
284                                                std::move(inflater)));
285 }
286 
WebSocketEncoder(Type type,std::unique_ptr<WebSocketDeflater> deflater,std::unique_ptr<WebSocketInflater> inflater)287 WebSocketEncoder::WebSocketEncoder(Type type,
288                                    std::unique_ptr<WebSocketDeflater> deflater,
289                                    std::unique_ptr<WebSocketInflater> inflater)
290     : type_(type),
291       deflater_(std::move(deflater)),
292       inflater_(std::move(inflater)) {}
293 
294 WebSocketEncoder::~WebSocketEncoder() = default;
295 
DecodeFrame(base::StringPiece frame,int * bytes_consumed,std::string * output)296 WebSocket::ParseResult WebSocketEncoder::DecodeFrame(base::StringPiece frame,
297                                                      int* bytes_consumed,
298                                                      std::string* output) {
299   bool compressed;
300   std::string current_output;
301   WebSocket::ParseResult result = DecodeFrameHybi17(
302       frame, type_ == FOR_SERVER, bytes_consumed, &current_output, &compressed);
303   switch (result) {
304     case WebSocket::FRAME_OK_FINAL:
305     case WebSocket::FRAME_OK_MIDDLE: {
306       if (continuation_message_frames_.empty())
307         is_current_message_compressed_ = compressed;
308       continuation_message_frames_.push_back(current_output);
309 
310       if (result == WebSocket::FRAME_OK_FINAL) {
311         *output = base::StrCat(continuation_message_frames_);
312         continuation_message_frames_.clear();
313         if (is_current_message_compressed_ && !Inflate(output)) {
314           return WebSocket::FRAME_ERROR;
315         }
316       }
317       break;
318     }
319 
320     case WebSocket::FRAME_PING:
321       *output = current_output;
322       break;
323 
324     default:
325       // This function doesn't need special handling for other parse results.
326       break;
327   }
328 
329   return result;
330 }
331 
EncodeTextFrame(base::StringPiece frame,int masking_key,std::string * output)332 void WebSocketEncoder::EncodeTextFrame(base::StringPiece frame,
333                                        int masking_key,
334                                        std::string* output) {
335   std::string compressed;
336   constexpr auto op_code = WebSocketFrameHeader::OpCodeEnum::kOpCodeText;
337   if (Deflate(frame, &compressed))
338     EncodeFrameHybi17(compressed, masking_key, true, op_code, output);
339   else
340     EncodeFrameHybi17(frame, masking_key, false, op_code, output);
341 }
342 
EncodePongFrame(base::StringPiece frame,int masking_key,std::string * output)343 void WebSocketEncoder::EncodePongFrame(base::StringPiece frame,
344                                        int masking_key,
345                                        std::string* output) {
346   constexpr auto op_code = WebSocketFrameHeader::OpCodeEnum::kOpCodePong;
347   EncodeFrameHybi17(frame, masking_key, false, op_code, output);
348 }
349 
Inflate(std::string * message)350 bool WebSocketEncoder::Inflate(std::string* message) {
351   if (!inflater_)
352     return false;
353   if (!inflater_->AddBytes(message->data(), message->length()))
354     return false;
355   if (!inflater_->Finish())
356     return false;
357 
358   std::vector<char> output;
359   while (inflater_->CurrentOutputSize() > 0) {
360     scoped_refptr<IOBufferWithSize> chunk =
361         inflater_->GetOutput(inflater_->CurrentOutputSize());
362     if (!chunk.get())
363       return false;
364     output.insert(output.end(), chunk->data(), chunk->data() + chunk->size());
365   }
366 
367   *message =
368       output.size() ? std::string(output.data(), output.size()) : std::string();
369   return true;
370 }
371 
Deflate(base::StringPiece message,std::string * output)372 bool WebSocketEncoder::Deflate(base::StringPiece message, std::string* output) {
373   if (!deflater_)
374     return false;
375   if (!deflater_->AddBytes(message.data(), message.length())) {
376     deflater_->Finish();
377     return false;
378   }
379   if (!deflater_->Finish())
380     return false;
381   scoped_refptr<IOBufferWithSize> buffer =
382       deflater_->GetOutput(deflater_->CurrentOutputSize());
383   if (!buffer.get())
384     return false;
385   *output = std::string(buffer->data(), buffer->size());
386   return true;
387 }
388 
389 }  // namespace net
390