1 // Copyright (c) 2009 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include <algorithm>
6 #include <limits>
7
8 #include "net/websockets/websocket.h"
9
10 #include "base/message_loop.h"
11 #include "net/http/http_response_headers.h"
12 #include "net/http/http_util.h"
13
14 namespace net {
15
// Default ports for ws:// and wss://. The Host header only carries an
// explicit ":port" suffix when the URL's port differs from these.
static const int kWebSocketPort = 80;
static const int kSecureWebSocketPort = 443;

// Status line the server's handshake response must start with.
static const char kServerHandshakeHeader[] =
    "HTTP/1.1 101 Web Socket Protocol Handshake\r\n";
// Every *Length constant below is the literal's character count minus the
// terminating NUL.
static const size_t kServerHandshakeHeaderLength =
    sizeof(kServerHandshakeHeader) / sizeof(kServerHandshakeHeader[0]) - 1;

// Header lines written verbatim into the client handshake and expected
// verbatim, in this order, right after the server's status line.
static const char kUpgradeHeader[] = "Upgrade: WebSocket\r\n";
static const size_t kUpgradeHeaderLength =
    sizeof(kUpgradeHeader) / sizeof(kUpgradeHeader[0]) - 1;

static const char kConnectionHeader[] = "Connection: Upgrade\r\n";
static const size_t kConnectionHeaderLength =
    sizeof(kConnectionHeader) / sizeof(kConnectionHeader[0]) - 1;
29
is_secure() const30 bool WebSocket::Request::is_secure() const {
31 return url_.SchemeIs("wss");
32 }
33
WebSocket(Request * request,WebSocketDelegate * delegate)34 WebSocket::WebSocket(Request* request, WebSocketDelegate* delegate)
35 : ready_state_(INITIALIZED),
36 mode_(MODE_INCOMPLETE),
37 request_(request),
38 delegate_(delegate),
39 origin_loop_(MessageLoop::current()),
40 socket_stream_(NULL),
41 max_pending_send_allowed_(0),
42 current_read_buf_(NULL),
43 read_consumed_len_(0),
44 current_write_buf_(NULL) {
45 DCHECK(request_.get());
46 DCHECK(delegate_);
47 DCHECK(origin_loop_);
48 }
49
~WebSocket()50 WebSocket::~WebSocket() {
51 DCHECK(ready_state_ == INITIALIZED || !delegate_);
52 DCHECK(!socket_stream_);
53 DCHECK(!delegate_);
54 }
55
Connect()56 void WebSocket::Connect() {
57 DCHECK(ready_state_ == INITIALIZED);
58 DCHECK(request_.get());
59 DCHECK(delegate_);
60 DCHECK(!socket_stream_);
61 DCHECK(MessageLoop::current() == origin_loop_);
62
63 socket_stream_ = new SocketStream(request_->url(), this);
64 socket_stream_->set_context(request_->context());
65
66 if (request_->host_resolver())
67 socket_stream_->SetHostResolver(request_->host_resolver());
68 if (request_->client_socket_factory())
69 socket_stream_->SetClientSocketFactory(request_->client_socket_factory());
70
71 AddRef(); // Release in DoClose().
72 ready_state_ = CONNECTING;
73 socket_stream_->Connect();
74 }
75
Send(const std::string & msg)76 void WebSocket::Send(const std::string& msg) {
77 DCHECK(ready_state_ == OPEN);
78 DCHECK(MessageLoop::current() == origin_loop_);
79
80 IOBufferWithSize* buf = new IOBufferWithSize(msg.size() + 2);
81 char* p = buf->data();
82 *p = '\0';
83 memcpy(p + 1, msg.data(), msg.size());
84 *(p + 1 + msg.size()) = '\xff';
85 pending_write_bufs_.push_back(buf);
86 SendPending();
87 }
88
Close()89 void WebSocket::Close() {
90 DCHECK(MessageLoop::current() == origin_loop_);
91
92 if (ready_state_ == INITIALIZED) {
93 DCHECK(!socket_stream_);
94 ready_state_ = CLOSED;
95 return;
96 }
97 if (ready_state_ != CLOSED) {
98 DCHECK(socket_stream_);
99 socket_stream_->Close();
100 return;
101 }
102 }
103
DetachDelegate()104 void WebSocket::DetachDelegate() {
105 if (!delegate_)
106 return;
107 delegate_ = NULL;
108 Close();
109 }
110
OnConnected(SocketStream * socket_stream,int max_pending_send_allowed)111 void WebSocket::OnConnected(SocketStream* socket_stream,
112 int max_pending_send_allowed) {
113 DCHECK(socket_stream == socket_stream_);
114 max_pending_send_allowed_ = max_pending_send_allowed;
115
116 // Use |max_pending_send_allowed| as hint for initial size of read buffer.
117 current_read_buf_ = new GrowableIOBuffer();
118 current_read_buf_->SetCapacity(max_pending_send_allowed_);
119 read_consumed_len_ = 0;
120
121 DCHECK(!current_write_buf_);
122 const std::string msg = request_->CreateClientHandshakeMessage();
123 IOBufferWithSize* buf = new IOBufferWithSize(msg.size());
124 memcpy(buf->data(), msg.data(), msg.size());
125 pending_write_bufs_.push_back(buf);
126 origin_loop_->PostTask(FROM_HERE,
127 NewRunnableMethod(this, &WebSocket::SendPending));
128 }
129
OnSentData(SocketStream * socket_stream,int amount_sent)130 void WebSocket::OnSentData(SocketStream* socket_stream, int amount_sent) {
131 DCHECK(socket_stream == socket_stream_);
132 DCHECK(current_write_buf_);
133 current_write_buf_->DidConsume(amount_sent);
134 DCHECK_GE(current_write_buf_->BytesRemaining(), 0);
135 if (current_write_buf_->BytesRemaining() == 0) {
136 current_write_buf_ = NULL;
137 pending_write_bufs_.pop_front();
138 }
139 origin_loop_->PostTask(FROM_HERE,
140 NewRunnableMethod(this, &WebSocket::SendPending));
141 }
142
OnReceivedData(SocketStream * socket_stream,const char * data,int len)143 void WebSocket::OnReceivedData(SocketStream* socket_stream,
144 const char* data, int len) {
145 DCHECK(socket_stream == socket_stream_);
146 AddToReadBuffer(data, len);
147 origin_loop_->PostTask(FROM_HERE,
148 NewRunnableMethod(this, &WebSocket::DoReceivedData));
149 }
150
OnClose(SocketStream * socket_stream)151 void WebSocket::OnClose(SocketStream* socket_stream) {
152 origin_loop_->PostTask(FROM_HERE,
153 NewRunnableMethod(this, &WebSocket::DoClose));
154 }
155
OnError(const SocketStream * socket_stream,int error)156 void WebSocket::OnError(const SocketStream* socket_stream, int error) {
157 origin_loop_->PostTask(FROM_HERE,
158 NewRunnableMethod(this, &WebSocket::DoError, error));
159 }
160
CreateClientHandshakeMessage() const161 std::string WebSocket::Request::CreateClientHandshakeMessage() const {
162 std::string msg;
163 msg = "GET ";
164 msg += url_.path();
165 if (url_.has_query()) {
166 msg += "?";
167 msg += url_.query();
168 }
169 msg += " HTTP/1.1\r\n";
170 msg += kUpgradeHeader;
171 msg += kConnectionHeader;
172 msg += "Host: ";
173 msg += StringToLowerASCII(url_.host());
174 if (url_.has_port()) {
175 bool secure = is_secure();
176 int port = url_.EffectiveIntPort();
177 if ((!secure &&
178 port != kWebSocketPort && port != url_parse::PORT_UNSPECIFIED) ||
179 (secure &&
180 port != kSecureWebSocketPort && port != url_parse::PORT_UNSPECIFIED)) {
181 msg += ":";
182 msg += IntToString(port);
183 }
184 }
185 msg += "\r\n";
186 msg += "Origin: ";
187 // It's OK to lowercase the origin as the Origin header does not contain
188 // the path or query portions, as per
189 // http://tools.ietf.org/html/draft-abarth-origin-00.
190 //
191 // TODO(satorux): Should we trim the port portion here if it's 80 for
192 // http:// or 443 for https:// ? Or can we assume it's done by the
193 // client of the library?
194 msg += StringToLowerASCII(origin_);
195 msg += "\r\n";
196 if (!protocol_.empty()) {
197 msg += "WebSocket-Protocol: ";
198 msg += protocol_;
199 msg += "\r\n";
200 }
201 // TODO(ukai): Add cookie if necessary.
202 msg += "\r\n";
203 return msg;
204 }
205
CheckHandshake()206 int WebSocket::CheckHandshake() {
207 DCHECK(current_read_buf_);
208 DCHECK(ready_state_ == CONNECTING);
209 mode_ = MODE_INCOMPLETE;
210 const char *start = current_read_buf_->StartOfBuffer() + read_consumed_len_;
211 const char *p = start;
212 size_t len = current_read_buf_->offset() - read_consumed_len_;
213 if (len < kServerHandshakeHeaderLength) {
214 return -1;
215 }
216 if (!memcmp(p, kServerHandshakeHeader, kServerHandshakeHeaderLength)) {
217 mode_ = MODE_NORMAL;
218 } else {
219 int eoh = HttpUtil::LocateEndOfHeaders(p, len);
220 if (eoh < 0)
221 return -1;
222 scoped_refptr<HttpResponseHeaders> headers(
223 new HttpResponseHeaders(HttpUtil::AssembleRawHeaders(p, eoh)));
224 if (headers->response_code() == 407) {
225 mode_ = MODE_AUTHENTICATE;
226 // TODO(ukai): Implement authentication handlers.
227 }
228 DLOG(INFO) << "non-normal websocket connection. "
229 << "response_code=" << headers->response_code()
230 << " mode=" << mode_;
231 // Invalid response code.
232 ready_state_ = CLOSED;
233 return eoh;
234 }
235 const char* end = p + len + 1;
236 p += kServerHandshakeHeaderLength;
237
238 if (mode_ == MODE_NORMAL) {
239 size_t header_size = end - p;
240 if (header_size < kUpgradeHeaderLength)
241 return -1;
242 if (memcmp(p, kUpgradeHeader, kUpgradeHeaderLength)) {
243 DLOG(INFO) << "Bad Upgrade Header "
244 << std::string(p, kUpgradeHeaderLength);
245 ready_state_ = CLOSED;
246 return p - start;
247 }
248 p += kUpgradeHeaderLength;
249
250 header_size = end - p;
251 if (header_size < kConnectionHeaderLength)
252 return -1;
253 if (memcmp(p, kConnectionHeader, kConnectionHeaderLength)) {
254 DLOG(INFO) << "Bad Connection Header "
255 << std::string(p, kConnectionHeaderLength);
256 ready_state_ = CLOSED;
257 return p - start;
258 }
259 p += kConnectionHeaderLength;
260 }
261 int eoh = HttpUtil::LocateEndOfHeaders(start, len);
262 if (eoh == -1)
263 return eoh;
264 scoped_refptr<HttpResponseHeaders> headers(
265 new HttpResponseHeaders(HttpUtil::AssembleRawHeaders(start, eoh)));
266 if (!ProcessHeaders(*headers)) {
267 DLOG(INFO) << "Process Headers failed: "
268 << std::string(start, eoh);
269 ready_state_ = CLOSED;
270 return eoh;
271 }
272 switch (mode_) {
273 case MODE_NORMAL:
274 if (CheckResponseHeaders()) {
275 ready_state_ = OPEN;
276 } else {
277 ready_state_ = CLOSED;
278 }
279 break;
280 default:
281 ready_state_ = CLOSED;
282 break;
283 }
284 if (ready_state_ == CLOSED)
285 DLOG(INFO) << "CheckHandshake mode=" << mode_
286 << " " << std::string(start, eoh);
287 return eoh;
288 }
289
290 // Gets the value of the specified header.
291 // It assures only one header of |name| in |headers|.
292 // Returns true iff single header of |name| is found in |headers|
293 // and |value| is filled with the value.
294 // Returns false otherwise.
GetSingleHeader(const HttpResponseHeaders & headers,const std::string & name,std::string * value)295 static bool GetSingleHeader(const HttpResponseHeaders& headers,
296 const std::string& name,
297 std::string* value) {
298 std::string first_value;
299 void* iter = NULL;
300 if (!headers.EnumerateHeader(&iter, name, &first_value))
301 return false;
302
303 // Checks no more |name| found in |headers|.
304 // Second call of EnumerateHeader() must return false.
305 std::string second_value;
306 if (headers.EnumerateHeader(&iter, name, &second_value))
307 return false;
308 *value = first_value;
309 return true;
310 }
311
ProcessHeaders(const HttpResponseHeaders & headers)312 bool WebSocket::ProcessHeaders(const HttpResponseHeaders& headers) {
313 if (!GetSingleHeader(headers, "websocket-origin", &ws_origin_))
314 return false;
315
316 if (!GetSingleHeader(headers, "websocket-location", &ws_location_))
317 return false;
318
319 if (!request_->protocol().empty()
320 && !GetSingleHeader(headers, "websocket-protocol", &ws_protocol_))
321 return false;
322 return true;
323 }
324
CheckResponseHeaders() const325 bool WebSocket::CheckResponseHeaders() const {
326 DCHECK(mode_ == MODE_NORMAL);
327 if (!LowerCaseEqualsASCII(request_->origin(), ws_origin_.c_str()))
328 return false;
329 if (request_->location() != ws_location_)
330 return false;
331 if (request_->protocol() != ws_protocol_)
332 return false;
333 return true;
334 }
335
SendPending()336 void WebSocket::SendPending() {
337 DCHECK(MessageLoop::current() == origin_loop_);
338 DCHECK(socket_stream_);
339 if (!current_write_buf_) {
340 if (pending_write_bufs_.empty())
341 return;
342 current_write_buf_ = new DrainableIOBuffer(
343 pending_write_bufs_.front(), pending_write_bufs_.front()->size());
344 }
345 DCHECK_GT(current_write_buf_->BytesRemaining(), 0);
346 bool sent = socket_stream_->SendData(
347 current_write_buf_->data(),
348 std::min(current_write_buf_->BytesRemaining(),
349 max_pending_send_allowed_));
350 DCHECK(sent);
351 }
352
DoReceivedData()353 void WebSocket::DoReceivedData() {
354 DCHECK(MessageLoop::current() == origin_loop_);
355 switch (ready_state_) {
356 case CONNECTING:
357 {
358 int eoh = CheckHandshake();
359 if (eoh < 0) {
360 // Not enough data, Retry when more data is available.
361 return;
362 }
363 SkipReadBuffer(eoh);
364 }
365 if (ready_state_ != OPEN) {
366 // Handshake failed.
367 socket_stream_->Close();
368 return;
369 }
370 if (delegate_)
371 delegate_->OnOpen(this);
372 if (current_read_buf_->offset() == read_consumed_len_) {
373 // No remaining data after handshake message.
374 break;
375 }
376 // FALL THROUGH
377 case OPEN:
378 ProcessFrameData();
379 break;
380
381 case CLOSED:
382 // Closed just after DoReceivedData is queued on |origin_loop_|.
383 break;
384 default:
385 NOTREACHED();
386 break;
387 }
388 }
389
ProcessFrameData()390 void WebSocket::ProcessFrameData() {
391 DCHECK(current_read_buf_);
392 const char* start_frame =
393 current_read_buf_->StartOfBuffer() + read_consumed_len_;
394 const char* next_frame = start_frame;
395 const char* p = next_frame;
396 const char* end =
397 current_read_buf_->StartOfBuffer() + current_read_buf_->offset();
398 while (p < end) {
399 unsigned char frame_byte = static_cast<unsigned char>(*p++);
400 if ((frame_byte & 0x80) == 0x80) {
401 int length = 0;
402 while (p < end) {
403 if (length > std::numeric_limits<int>::max() / 128) {
404 // frame length overflow.
405 socket_stream_->Close();
406 return;
407 }
408 unsigned char c = static_cast<unsigned char>(*p);
409 length = length * 128 + (c & 0x7f);
410 ++p;
411 if ((c & 0x80) != 0x80)
412 break;
413 }
414 // Checks if the frame body hasn't been completely received yet.
415 // It also checks the case the frame length bytes haven't been completely
416 // received yet, because p == end and length > 0 in such case.
417 if (p + length < end) {
418 p += length;
419 next_frame = p;
420 } else {
421 break;
422 }
423 } else {
424 const char* msg_start = p;
425 while (p < end && *p != '\xff')
426 ++p;
427 if (p < end && *p == '\xff') {
428 if (frame_byte == 0x00 && delegate_)
429 delegate_->OnMessage(this, std::string(msg_start, p - msg_start));
430 ++p;
431 next_frame = p;
432 }
433 }
434 }
435 SkipReadBuffer(next_frame - start_frame);
436 }
437
AddToReadBuffer(const char * data,int len)438 void WebSocket::AddToReadBuffer(const char* data, int len) {
439 DCHECK(current_read_buf_);
440 // Check if |current_read_buf_| has enough space to store |len| of |data|.
441 if (len >= current_read_buf_->RemainingCapacity()) {
442 current_read_buf_->SetCapacity(
443 current_read_buf_->offset() + len);
444 }
445
446 DCHECK(current_read_buf_->RemainingCapacity() >= len);
447 memcpy(current_read_buf_->data(), data, len);
448 current_read_buf_->set_offset(current_read_buf_->offset() + len);
449 }
450
SkipReadBuffer(int len)451 void WebSocket::SkipReadBuffer(int len) {
452 if (len == 0)
453 return;
454 DCHECK_GT(len, 0);
455 read_consumed_len_ += len;
456 int remaining = current_read_buf_->offset() - read_consumed_len_;
457 DCHECK_GE(remaining, 0);
458 if (remaining < read_consumed_len_ &&
459 current_read_buf_->RemainingCapacity() < read_consumed_len_) {
460 // Pre compaction:
461 // 0 v-read_consumed_len_ v-offset v- capacity
462 // |..processed..| .. remaining .. | .. RemainingCapacity |
463 //
464 memmove(current_read_buf_->StartOfBuffer(),
465 current_read_buf_->StartOfBuffer() + read_consumed_len_,
466 remaining);
467 read_consumed_len_ = 0;
468 current_read_buf_->set_offset(remaining);
469 // Post compaction:
470 // 0read_consumed_len_ v- offset v- capacity
471 // |.. remaining .. | .. RemainingCapacity ... |
472 //
473 }
474 }
475
DoClose()476 void WebSocket::DoClose() {
477 DCHECK(MessageLoop::current() == origin_loop_);
478 WebSocketDelegate* delegate = delegate_;
479 delegate_ = NULL;
480 ready_state_ = CLOSED;
481 if (!socket_stream_)
482 return;
483 socket_stream_ = NULL;
484 if (delegate)
485 delegate->OnClose(this);
486 Release();
487 }
488
DoError(int error)489 void WebSocket::DoError(int error) {
490 DCHECK(MessageLoop::current() == origin_loop_);
491 if (delegate_)
492 delegate_->OnError(this, error);
493 }
494
495 } // namespace net
496