1 // Copyright 2014 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/server/web_socket_encoder.h"
6
7 #include <limits>
8 #include <utility>
9
10 #include "base/check.h"
11 #include "base/memory/ptr_util.h"
12 #include "base/strings/strcat.h"
13 #include "base/strings/string_number_conversions.h"
14 #include "net/base/io_buffer.h"
15 #include "net/websockets/websocket_deflate_parameters.h"
16 #include "net/websockets/websocket_extension.h"
17 #include "net/websockets/websocket_extension_parser.h"
18 #include "net/websockets/websocket_frame.h"
19
20 namespace net {
21
22 const char WebSocketEncoder::kClientExtensions[] =
23 "permessage-deflate; client_max_window_bits";
24
25 namespace {
26
// Passed as both size arguments of every WebSocketInflater created here.
constexpr int kInflaterChunkSize = 16 * 1024;

// Constants for hybi-10 frame format.

// First header byte: FIN flag, three RSV flags, and a 4-bit opcode.
constexpr unsigned char kFinalBit = 0x80;
constexpr unsigned char kReserved1Bit = 0x40;
constexpr unsigned char kReserved2Bit = 0x20;
constexpr unsigned char kReserved3Bit = 0x10;
constexpr unsigned char kOpCodeMask = 0xF;

// Second header byte: MASK flag and a 7-bit payload length.
constexpr unsigned char kMaskBit = 0x80;
constexpr unsigned char kPayloadLengthMask = 0x7F;

// Interpretation of the 7-bit length: values up to 125 are the length itself,
// 126 announces a 2-byte extended length, 127 announces an 8-byte one.
constexpr size_t kMaxSingleBytePayloadLength = 125;
constexpr size_t kTwoBytePayloadLengthField = 126;
constexpr size_t kEightBytePayloadLengthField = 127;

// Size of the masking key that follows the length when MASK is set.
constexpr size_t kMaskingKeyWidthInBytes = 4;
43
DecodeFrameHybi17(base::StringPiece frame,bool client_frame,int * bytes_consumed,std::string * output,bool * compressed)44 WebSocket::ParseResult DecodeFrameHybi17(base::StringPiece frame,
45 bool client_frame,
46 int* bytes_consumed,
47 std::string* output,
48 bool* compressed) {
49 size_t data_length = frame.length();
50 if (data_length < 2)
51 return WebSocket::FRAME_INCOMPLETE;
52
53 const char* buffer_begin = const_cast<char*>(frame.data());
54 const char* p = buffer_begin;
55 const char* buffer_end = p + data_length;
56
57 unsigned char first_byte = *p++;
58 unsigned char second_byte = *p++;
59
60 bool final = (first_byte & kFinalBit) != 0;
61 bool reserved1 = (first_byte & kReserved1Bit) != 0;
62 bool reserved2 = (first_byte & kReserved2Bit) != 0;
63 bool reserved3 = (first_byte & kReserved3Bit) != 0;
64 int op_code = first_byte & kOpCodeMask;
65 bool masked = (second_byte & kMaskBit) != 0;
66 *compressed = reserved1;
67 if (reserved2 || reserved3)
68 return WebSocket::FRAME_ERROR; // Only compression extension is supported.
69
70 bool closed = false;
71 switch (op_code) {
72 case WebSocketFrameHeader::OpCodeEnum::kOpCodeClose:
73 closed = true;
74 break;
75
76 case WebSocketFrameHeader::OpCodeEnum::kOpCodeText:
77 case WebSocketFrameHeader::OpCodeEnum::
78 kOpCodeContinuation: // Treated in the same as kOpCodeText.
79 case WebSocketFrameHeader::OpCodeEnum::kOpCodePing:
80 case WebSocketFrameHeader::OpCodeEnum::kOpCodePong:
81 break;
82
83 case WebSocketFrameHeader::OpCodeEnum::kOpCodeBinary: // We don't support
84 // binary frames yet.
85 default:
86 return WebSocket::FRAME_ERROR;
87 }
88
89 if (client_frame && !masked) // In Hybi-17 spec client MUST mask its frame.
90 return WebSocket::FRAME_ERROR;
91
92 uint64_t payload_length64 = second_byte & kPayloadLengthMask;
93 if (payload_length64 > kMaxSingleBytePayloadLength) {
94 int extended_payload_length_size;
95 if (payload_length64 == kTwoBytePayloadLengthField) {
96 extended_payload_length_size = 2;
97 } else {
98 DCHECK(payload_length64 == kEightBytePayloadLengthField);
99 extended_payload_length_size = 8;
100 }
101 if (buffer_end - p < extended_payload_length_size)
102 return WebSocket::FRAME_INCOMPLETE;
103 payload_length64 = 0;
104 for (int i = 0; i < extended_payload_length_size; ++i) {
105 payload_length64 <<= 8;
106 payload_length64 |= static_cast<unsigned char>(*p++);
107 }
108 }
109
110 size_t actual_masking_key_length = masked ? kMaskingKeyWidthInBytes : 0;
111 static const uint64_t max_payload_length = 0x7FFFFFFFFFFFFFFFull;
112 static size_t max_length = std::numeric_limits<size_t>::max();
113 if (payload_length64 > max_payload_length ||
114 payload_length64 + actual_masking_key_length > max_length) {
115 // WebSocket frame length too large.
116 return WebSocket::FRAME_ERROR;
117 }
118 size_t payload_length = static_cast<size_t>(payload_length64);
119
120 size_t total_length = actual_masking_key_length + payload_length;
121 if (static_cast<size_t>(buffer_end - p) < total_length)
122 return WebSocket::FRAME_INCOMPLETE;
123
124 if (masked) {
125 output->resize(payload_length);
126 const char* masking_key = p;
127 char* payload = const_cast<char*>(p + kMaskingKeyWidthInBytes);
128 for (size_t i = 0; i < payload_length; ++i) // Unmask the payload.
129 (*output)[i] = payload[i] ^ masking_key[i % kMaskingKeyWidthInBytes];
130 } else {
131 output->assign(p, p + payload_length);
132 }
133
134 size_t pos = p + actual_masking_key_length + payload_length - buffer_begin;
135 *bytes_consumed = pos;
136
137 if (op_code == WebSocketFrameHeader::OpCodeEnum::kOpCodePing)
138 return WebSocket::FRAME_PING;
139
140 if (op_code == WebSocketFrameHeader::OpCodeEnum::kOpCodePong)
141 return WebSocket::FRAME_PONG;
142
143 if (closed)
144 return WebSocket::FRAME_CLOSE;
145
146 return final ? WebSocket::FRAME_OK_FINAL : WebSocket::FRAME_OK_MIDDLE;
147 }
148
EncodeFrameHybi17(base::StringPiece message,int masking_key,bool compressed,WebSocketFrameHeader::OpCodeEnum op_code,std::string * output)149 void EncodeFrameHybi17(base::StringPiece message,
150 int masking_key,
151 bool compressed,
152 WebSocketFrameHeader::OpCodeEnum op_code,
153 std::string* output) {
154 std::vector<char> frame;
155 size_t data_length = message.length();
156
157 int reserved1 = compressed ? kReserved1Bit : 0;
158 frame.push_back(kFinalBit | op_code | reserved1);
159 char mask_key_bit = masking_key != 0 ? kMaskBit : 0;
160 if (data_length <= kMaxSingleBytePayloadLength) {
161 frame.push_back(static_cast<char>(data_length) | mask_key_bit);
162 } else if (data_length <= 0xFFFF) {
163 frame.push_back(kTwoBytePayloadLengthField | mask_key_bit);
164 frame.push_back((data_length & 0xFF00) >> 8);
165 frame.push_back(data_length & 0xFF);
166 } else {
167 frame.push_back(kEightBytePayloadLengthField | mask_key_bit);
168 char extended_payload_length[8];
169 size_t remaining = data_length;
170 // Fill the length into extended_payload_length in the network byte order.
171 for (int i = 0; i < 8; ++i) {
172 extended_payload_length[7 - i] = remaining & 0xFF;
173 remaining >>= 8;
174 }
175 frame.insert(frame.end(), extended_payload_length,
176 extended_payload_length + 8);
177 DCHECK(!remaining);
178 }
179
180 const char* data = const_cast<char*>(message.data());
181 if (masking_key != 0) {
182 const char* mask_bytes = reinterpret_cast<char*>(&masking_key);
183 frame.insert(frame.end(), mask_bytes, mask_bytes + 4);
184 for (size_t i = 0; i < data_length; ++i) // Mask the payload.
185 frame.push_back(data[i] ^ mask_bytes[i % kMaskingKeyWidthInBytes]);
186 } else {
187 frame.insert(frame.end(), data, data + data_length);
188 }
189 *output = std::string(frame.data(), frame.size());
190 }
191
192 } // anonymous namespace
193
194 // static
CreateServer()195 std::unique_ptr<WebSocketEncoder> WebSocketEncoder::CreateServer() {
196 return base::WrapUnique(new WebSocketEncoder(FOR_SERVER, nullptr, nullptr));
197 }
198
199 // static
CreateServer(const std::string & extensions,WebSocketDeflateParameters * deflate_parameters)200 std::unique_ptr<WebSocketEncoder> WebSocketEncoder::CreateServer(
201 const std::string& extensions,
202 WebSocketDeflateParameters* deflate_parameters) {
203 WebSocketExtensionParser parser;
204 if (!parser.Parse(extensions)) {
205 // Failed to parse Sec-WebSocket-Extensions header. We MUST fail the
206 // connection.
207 return nullptr;
208 }
209
210 for (const auto& extension : parser.extensions()) {
211 std::string failure_message;
212 WebSocketDeflateParameters offer;
213 if (!offer.Initialize(extension, &failure_message) ||
214 !offer.IsValidAsRequest(&failure_message)) {
215 // We decline unknown / malformed extensions.
216 continue;
217 }
218
219 WebSocketDeflateParameters response = offer;
220 if (offer.is_client_max_window_bits_specified() &&
221 !offer.has_client_max_window_bits_value()) {
222 // We need to choose one value for the response.
223 response.SetClientMaxWindowBits(15);
224 }
225 DCHECK(response.IsValidAsResponse());
226 DCHECK(offer.IsCompatibleWith(response));
227 auto deflater = std::make_unique<WebSocketDeflater>(
228 response.server_context_take_over_mode());
229 auto inflater = std::make_unique<WebSocketInflater>(kInflaterChunkSize,
230 kInflaterChunkSize);
231 if (!deflater->Initialize(response.PermissiveServerMaxWindowBits()) ||
232 !inflater->Initialize(response.PermissiveClientMaxWindowBits())) {
233 // For some reason we cannot accept the parameters.
234 continue;
235 }
236 *deflate_parameters = response;
237 return base::WrapUnique(new WebSocketEncoder(
238 FOR_SERVER, std::move(deflater), std::move(inflater)));
239 }
240
241 // We cannot find an acceptable offer.
242 return base::WrapUnique(new WebSocketEncoder(FOR_SERVER, nullptr, nullptr));
243 }
244
245 // static
CreateClient(const std::string & response_extensions)246 std::unique_ptr<WebSocketEncoder> WebSocketEncoder::CreateClient(
247 const std::string& response_extensions) {
248 // TODO(yhirano): Add a way to return an error.
249
250 WebSocketExtensionParser parser;
251 if (!parser.Parse(response_extensions)) {
252 // Parse error. Note that there are two cases here.
253 // 1) There is no Sec-WebSocket-Extensions header.
254 // 2) There is a malformed Sec-WebSocketExtensions header.
255 // We should return a deflate-disabled encoder for the former case and
256 // fail the connection for the latter case.
257 return base::WrapUnique(new WebSocketEncoder(FOR_CLIENT, nullptr, nullptr));
258 }
259 if (parser.extensions().size() != 1) {
260 // Only permessage-deflate extension is supported.
261 // TODO (yhirano): Fail the connection.
262 return base::WrapUnique(new WebSocketEncoder(FOR_CLIENT, nullptr, nullptr));
263 }
264 const auto& extension = parser.extensions()[0];
265 WebSocketDeflateParameters params;
266 std::string failure_message;
267 if (!params.Initialize(extension, &failure_message) ||
268 !params.IsValidAsResponse(&failure_message)) {
269 // TODO (yhirano): Fail the connection.
270 return base::WrapUnique(new WebSocketEncoder(FOR_CLIENT, nullptr, nullptr));
271 }
272
273 auto deflater = std::make_unique<WebSocketDeflater>(
274 params.client_context_take_over_mode());
275 auto inflater = std::make_unique<WebSocketInflater>(kInflaterChunkSize,
276 kInflaterChunkSize);
277 if (!deflater->Initialize(params.PermissiveClientMaxWindowBits()) ||
278 !inflater->Initialize(params.PermissiveServerMaxWindowBits())) {
279 // TODO (yhirano): Fail the connection.
280 return base::WrapUnique(new WebSocketEncoder(FOR_CLIENT, nullptr, nullptr));
281 }
282
283 return base::WrapUnique(new WebSocketEncoder(FOR_CLIENT, std::move(deflater),
284 std::move(inflater)));
285 }
286
WebSocketEncoder(Type type,std::unique_ptr<WebSocketDeflater> deflater,std::unique_ptr<WebSocketInflater> inflater)287 WebSocketEncoder::WebSocketEncoder(Type type,
288 std::unique_ptr<WebSocketDeflater> deflater,
289 std::unique_ptr<WebSocketInflater> inflater)
290 : type_(type),
291 deflater_(std::move(deflater)),
292 inflater_(std::move(inflater)) {}
293
294 WebSocketEncoder::~WebSocketEncoder() = default;
295
DecodeFrame(base::StringPiece frame,int * bytes_consumed,std::string * output)296 WebSocket::ParseResult WebSocketEncoder::DecodeFrame(base::StringPiece frame,
297 int* bytes_consumed,
298 std::string* output) {
299 bool compressed;
300 std::string current_output;
301 WebSocket::ParseResult result = DecodeFrameHybi17(
302 frame, type_ == FOR_SERVER, bytes_consumed, ¤t_output, &compressed);
303 switch (result) {
304 case WebSocket::FRAME_OK_FINAL:
305 case WebSocket::FRAME_OK_MIDDLE: {
306 if (continuation_message_frames_.empty())
307 is_current_message_compressed_ = compressed;
308 continuation_message_frames_.push_back(current_output);
309
310 if (result == WebSocket::FRAME_OK_FINAL) {
311 *output = base::StrCat(continuation_message_frames_);
312 continuation_message_frames_.clear();
313 if (is_current_message_compressed_ && !Inflate(output)) {
314 return WebSocket::FRAME_ERROR;
315 }
316 }
317 break;
318 }
319
320 case WebSocket::FRAME_PING:
321 *output = current_output;
322 break;
323
324 default:
325 // This function doesn't need special handling for other parse results.
326 break;
327 }
328
329 return result;
330 }
331
EncodeTextFrame(base::StringPiece frame,int masking_key,std::string * output)332 void WebSocketEncoder::EncodeTextFrame(base::StringPiece frame,
333 int masking_key,
334 std::string* output) {
335 std::string compressed;
336 constexpr auto op_code = WebSocketFrameHeader::OpCodeEnum::kOpCodeText;
337 if (Deflate(frame, &compressed))
338 EncodeFrameHybi17(compressed, masking_key, true, op_code, output);
339 else
340 EncodeFrameHybi17(frame, masking_key, false, op_code, output);
341 }
342
EncodePongFrame(base::StringPiece frame,int masking_key,std::string * output)343 void WebSocketEncoder::EncodePongFrame(base::StringPiece frame,
344 int masking_key,
345 std::string* output) {
346 constexpr auto op_code = WebSocketFrameHeader::OpCodeEnum::kOpCodePong;
347 EncodeFrameHybi17(frame, masking_key, false, op_code, output);
348 }
349
Inflate(std::string * message)350 bool WebSocketEncoder::Inflate(std::string* message) {
351 if (!inflater_)
352 return false;
353 if (!inflater_->AddBytes(message->data(), message->length()))
354 return false;
355 if (!inflater_->Finish())
356 return false;
357
358 std::vector<char> output;
359 while (inflater_->CurrentOutputSize() > 0) {
360 scoped_refptr<IOBufferWithSize> chunk =
361 inflater_->GetOutput(inflater_->CurrentOutputSize());
362 if (!chunk.get())
363 return false;
364 output.insert(output.end(), chunk->data(), chunk->data() + chunk->size());
365 }
366
367 *message =
368 output.size() ? std::string(output.data(), output.size()) : std::string();
369 return true;
370 }
371
Deflate(base::StringPiece message,std::string * output)372 bool WebSocketEncoder::Deflate(base::StringPiece message, std::string* output) {
373 if (!deflater_)
374 return false;
375 if (!deflater_->AddBytes(message.data(), message.length())) {
376 deflater_->Finish();
377 return false;
378 }
379 if (!deflater_->Finish())
380 return false;
381 scoped_refptr<IOBufferWithSize> buffer =
382 deflater_->GetOutput(deflater_->CurrentOutputSize());
383 if (!buffer.get())
384 return false;
385 *output = std::string(buffer->data(), buffer->size());
386 return true;
387 }
388
389 } // namespace net
390