• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2009 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include <algorithm>
6 #include <limits>
7 
8 #include "net/websockets/websocket.h"
9 
10 #include "base/message_loop.h"
11 #include "net/http/http_response_headers.h"
12 #include "net/http/http_util.h"
13 
14 namespace net {
15 
namespace {

// Default ports for ws:// and wss://. The Host header of the client handshake
// only carries an explicit port when the URL's port differs from the default
// for its scheme.
const int kWebSocketPort = 80;
const int kSecureWebSocketPort = 443;

// Status line the server must send first to accept the handshake.
const char kServerHandshakeHeader[] =
    "HTTP/1.1 101 Web Socket Protocol Handshake\r\n";
// All *Length constants exclude the literal's terminating NUL.
const size_t kServerHandshakeHeaderLength =
    sizeof(kServerHandshakeHeader) - 1;

// Sent by the client, and expected back in this order right after the
// server's status line.
const char kUpgradeHeader[] = "Upgrade: WebSocket\r\n";
const size_t kUpgradeHeaderLength = sizeof(kUpgradeHeader) - 1;

const char kConnectionHeader[] = "Connection: Upgrade\r\n";
const size_t kConnectionHeaderLength = sizeof(kConnectionHeader) - 1;

}  // namespace
29 
is_secure() const30 bool WebSocket::Request::is_secure() const {
31   return url_.SchemeIs("wss");
32 }
33 
WebSocket(Request * request,WebSocketDelegate * delegate)34 WebSocket::WebSocket(Request* request, WebSocketDelegate* delegate)
35     : ready_state_(INITIALIZED),
36       mode_(MODE_INCOMPLETE),
37       request_(request),
38       delegate_(delegate),
39       origin_loop_(MessageLoop::current()),
40       socket_stream_(NULL),
41       max_pending_send_allowed_(0),
42       current_read_buf_(NULL),
43       read_consumed_len_(0),
44       current_write_buf_(NULL) {
45   DCHECK(request_.get());
46   DCHECK(delegate_);
47   DCHECK(origin_loop_);
48 }
49 
~WebSocket()50 WebSocket::~WebSocket() {
51   DCHECK(ready_state_ == INITIALIZED || !delegate_);
52   DCHECK(!socket_stream_);
53   DCHECK(!delegate_);
54 }
55 
Connect()56 void WebSocket::Connect() {
57   DCHECK(ready_state_ == INITIALIZED);
58   DCHECK(request_.get());
59   DCHECK(delegate_);
60   DCHECK(!socket_stream_);
61   DCHECK(MessageLoop::current() == origin_loop_);
62 
63   socket_stream_ = new SocketStream(request_->url(), this);
64   socket_stream_->set_context(request_->context());
65 
66   if (request_->host_resolver())
67     socket_stream_->SetHostResolver(request_->host_resolver());
68   if (request_->client_socket_factory())
69     socket_stream_->SetClientSocketFactory(request_->client_socket_factory());
70 
71   AddRef();  // Release in DoClose().
72   ready_state_ = CONNECTING;
73   socket_stream_->Connect();
74 }
75 
Send(const std::string & msg)76 void WebSocket::Send(const std::string& msg) {
77   DCHECK(ready_state_ == OPEN);
78   DCHECK(MessageLoop::current() == origin_loop_);
79 
80   IOBufferWithSize* buf = new IOBufferWithSize(msg.size() + 2);
81   char* p = buf->data();
82   *p = '\0';
83   memcpy(p + 1, msg.data(), msg.size());
84   *(p + 1 + msg.size()) = '\xff';
85   pending_write_bufs_.push_back(buf);
86   SendPending();
87 }
88 
Close()89 void WebSocket::Close() {
90   DCHECK(MessageLoop::current() == origin_loop_);
91 
92   if (ready_state_ == INITIALIZED) {
93     DCHECK(!socket_stream_);
94     ready_state_ = CLOSED;
95     return;
96   }
97   if (ready_state_ != CLOSED) {
98     DCHECK(socket_stream_);
99     socket_stream_->Close();
100     return;
101   }
102 }
103 
DetachDelegate()104 void WebSocket::DetachDelegate() {
105   if (!delegate_)
106     return;
107   delegate_ = NULL;
108   Close();
109 }
110 
OnConnected(SocketStream * socket_stream,int max_pending_send_allowed)111 void WebSocket::OnConnected(SocketStream* socket_stream,
112                             int max_pending_send_allowed) {
113   DCHECK(socket_stream == socket_stream_);
114   max_pending_send_allowed_ = max_pending_send_allowed;
115 
116   // Use |max_pending_send_allowed| as hint for initial size of read buffer.
117   current_read_buf_ = new GrowableIOBuffer();
118   current_read_buf_->SetCapacity(max_pending_send_allowed_);
119   read_consumed_len_ = 0;
120 
121   DCHECK(!current_write_buf_);
122   const std::string msg = request_->CreateClientHandshakeMessage();
123   IOBufferWithSize* buf = new IOBufferWithSize(msg.size());
124   memcpy(buf->data(), msg.data(), msg.size());
125   pending_write_bufs_.push_back(buf);
126   origin_loop_->PostTask(FROM_HERE,
127                          NewRunnableMethod(this, &WebSocket::SendPending));
128 }
129 
OnSentData(SocketStream * socket_stream,int amount_sent)130 void WebSocket::OnSentData(SocketStream* socket_stream, int amount_sent) {
131   DCHECK(socket_stream == socket_stream_);
132   DCHECK(current_write_buf_);
133   current_write_buf_->DidConsume(amount_sent);
134   DCHECK_GE(current_write_buf_->BytesRemaining(), 0);
135   if (current_write_buf_->BytesRemaining() == 0) {
136     current_write_buf_ = NULL;
137     pending_write_bufs_.pop_front();
138   }
139   origin_loop_->PostTask(FROM_HERE,
140                          NewRunnableMethod(this, &WebSocket::SendPending));
141 }
142 
OnReceivedData(SocketStream * socket_stream,const char * data,int len)143 void WebSocket::OnReceivedData(SocketStream* socket_stream,
144                                const char* data, int len) {
145   DCHECK(socket_stream == socket_stream_);
146   AddToReadBuffer(data, len);
147   origin_loop_->PostTask(FROM_HERE,
148                          NewRunnableMethod(this, &WebSocket::DoReceivedData));
149 }
150 
OnClose(SocketStream * socket_stream)151 void WebSocket::OnClose(SocketStream* socket_stream) {
152   origin_loop_->PostTask(FROM_HERE,
153                          NewRunnableMethod(this, &WebSocket::DoClose));
154 }
155 
OnError(const SocketStream * socket_stream,int error)156 void WebSocket::OnError(const SocketStream* socket_stream, int error) {
157   origin_loop_->PostTask(FROM_HERE,
158                          NewRunnableMethod(this, &WebSocket::DoError, error));
159 }
160 
CreateClientHandshakeMessage() const161 std::string WebSocket::Request::CreateClientHandshakeMessage() const {
162   std::string msg;
163   msg = "GET ";
164   msg += url_.path();
165   if (url_.has_query()) {
166     msg += "?";
167     msg += url_.query();
168   }
169   msg += " HTTP/1.1\r\n";
170   msg += kUpgradeHeader;
171   msg += kConnectionHeader;
172   msg += "Host: ";
173   msg += StringToLowerASCII(url_.host());
174   if (url_.has_port()) {
175     bool secure = is_secure();
176     int port = url_.EffectiveIntPort();
177     if ((!secure &&
178          port != kWebSocketPort && port != url_parse::PORT_UNSPECIFIED) ||
179         (secure &&
180          port != kSecureWebSocketPort && port != url_parse::PORT_UNSPECIFIED)) {
181       msg += ":";
182       msg += IntToString(port);
183     }
184   }
185   msg += "\r\n";
186   msg += "Origin: ";
187   // It's OK to lowercase the origin as the Origin header does not contain
188   // the path or query portions, as per
189   // http://tools.ietf.org/html/draft-abarth-origin-00.
190   //
191   // TODO(satorux): Should we trim the port portion here if it's 80 for
192   // http:// or 443 for https:// ? Or can we assume it's done by the
193   // client of the library?
194   msg += StringToLowerASCII(origin_);
195   msg += "\r\n";
196   if (!protocol_.empty()) {
197     msg += "WebSocket-Protocol: ";
198     msg += protocol_;
199     msg += "\r\n";
200   }
201   // TODO(ukai): Add cookie if necessary.
202   msg += "\r\n";
203   return msg;
204 }
205 
CheckHandshake()206 int WebSocket::CheckHandshake() {
207   DCHECK(current_read_buf_);
208   DCHECK(ready_state_ == CONNECTING);
209   mode_ = MODE_INCOMPLETE;
210   const char *start = current_read_buf_->StartOfBuffer() + read_consumed_len_;
211   const char *p = start;
212   size_t len = current_read_buf_->offset() - read_consumed_len_;
213   if (len < kServerHandshakeHeaderLength) {
214     return -1;
215   }
216   if (!memcmp(p, kServerHandshakeHeader, kServerHandshakeHeaderLength)) {
217     mode_ = MODE_NORMAL;
218   } else {
219     int eoh = HttpUtil::LocateEndOfHeaders(p, len);
220     if (eoh < 0)
221       return -1;
222     scoped_refptr<HttpResponseHeaders> headers(
223         new HttpResponseHeaders(HttpUtil::AssembleRawHeaders(p, eoh)));
224     if (headers->response_code() == 407) {
225       mode_ = MODE_AUTHENTICATE;
226       // TODO(ukai): Implement authentication handlers.
227     }
228     DLOG(INFO) << "non-normal websocket connection. "
229                << "response_code=" << headers->response_code()
230                << " mode=" << mode_;
231     // Invalid response code.
232     ready_state_ = CLOSED;
233     return eoh;
234   }
235   const char* end = p + len + 1;
236   p += kServerHandshakeHeaderLength;
237 
238   if (mode_ == MODE_NORMAL) {
239     size_t header_size = end - p;
240     if (header_size < kUpgradeHeaderLength)
241       return -1;
242     if (memcmp(p, kUpgradeHeader, kUpgradeHeaderLength)) {
243       DLOG(INFO) << "Bad Upgrade Header "
244                  << std::string(p, kUpgradeHeaderLength);
245       ready_state_ = CLOSED;
246       return p - start;
247     }
248     p += kUpgradeHeaderLength;
249 
250     header_size = end - p;
251     if (header_size < kConnectionHeaderLength)
252       return -1;
253     if (memcmp(p, kConnectionHeader, kConnectionHeaderLength)) {
254       DLOG(INFO) << "Bad Connection Header "
255                  << std::string(p, kConnectionHeaderLength);
256       ready_state_ = CLOSED;
257       return p - start;
258     }
259     p += kConnectionHeaderLength;
260   }
261   int eoh = HttpUtil::LocateEndOfHeaders(start, len);
262   if (eoh == -1)
263     return eoh;
264   scoped_refptr<HttpResponseHeaders> headers(
265       new HttpResponseHeaders(HttpUtil::AssembleRawHeaders(start, eoh)));
266   if (!ProcessHeaders(*headers)) {
267     DLOG(INFO) << "Process Headers failed: "
268                << std::string(start, eoh);
269     ready_state_ = CLOSED;
270     return eoh;
271   }
272   switch (mode_) {
273     case MODE_NORMAL:
274       if (CheckResponseHeaders()) {
275         ready_state_ = OPEN;
276       } else {
277         ready_state_ = CLOSED;
278       }
279       break;
280     default:
281       ready_state_ = CLOSED;
282       break;
283   }
284   if (ready_state_ == CLOSED)
285     DLOG(INFO) << "CheckHandshake mode=" << mode_
286                << " " << std::string(start, eoh);
287   return eoh;
288 }
289 
290 // Gets the value of the specified header.
291 // It assures only one header of |name| in |headers|.
292 // Returns true iff single header of |name| is found in |headers|
293 // and |value| is filled with the value.
294 // Returns false otherwise.
GetSingleHeader(const HttpResponseHeaders & headers,const std::string & name,std::string * value)295 static bool GetSingleHeader(const HttpResponseHeaders& headers,
296                             const std::string& name,
297                             std::string* value) {
298   std::string first_value;
299   void* iter = NULL;
300   if (!headers.EnumerateHeader(&iter, name, &first_value))
301     return false;
302 
303   // Checks no more |name| found in |headers|.
304   // Second call of EnumerateHeader() must return false.
305   std::string second_value;
306   if (headers.EnumerateHeader(&iter, name, &second_value))
307     return false;
308   *value = first_value;
309   return true;
310 }
311 
ProcessHeaders(const HttpResponseHeaders & headers)312 bool WebSocket::ProcessHeaders(const HttpResponseHeaders& headers) {
313   if (!GetSingleHeader(headers, "websocket-origin", &ws_origin_))
314     return false;
315 
316   if (!GetSingleHeader(headers, "websocket-location", &ws_location_))
317     return false;
318 
319   if (!request_->protocol().empty()
320       && !GetSingleHeader(headers, "websocket-protocol", &ws_protocol_))
321     return false;
322   return true;
323 }
324 
CheckResponseHeaders() const325 bool WebSocket::CheckResponseHeaders() const {
326   DCHECK(mode_ == MODE_NORMAL);
327   if (!LowerCaseEqualsASCII(request_->origin(), ws_origin_.c_str()))
328     return false;
329   if (request_->location() != ws_location_)
330     return false;
331   if (request_->protocol() != ws_protocol_)
332     return false;
333   return true;
334 }
335 
SendPending()336 void WebSocket::SendPending() {
337   DCHECK(MessageLoop::current() == origin_loop_);
338   DCHECK(socket_stream_);
339   if (!current_write_buf_) {
340     if (pending_write_bufs_.empty())
341       return;
342     current_write_buf_ = new DrainableIOBuffer(
343         pending_write_bufs_.front(), pending_write_bufs_.front()->size());
344   }
345   DCHECK_GT(current_write_buf_->BytesRemaining(), 0);
346   bool sent = socket_stream_->SendData(
347       current_write_buf_->data(),
348       std::min(current_write_buf_->BytesRemaining(),
349                max_pending_send_allowed_));
350   DCHECK(sent);
351 }
352 
DoReceivedData()353 void WebSocket::DoReceivedData() {
354   DCHECK(MessageLoop::current() == origin_loop_);
355   switch (ready_state_) {
356     case CONNECTING:
357       {
358         int eoh = CheckHandshake();
359         if (eoh < 0) {
360           // Not enough data,  Retry when more data is available.
361           return;
362         }
363         SkipReadBuffer(eoh);
364       }
365       if (ready_state_ != OPEN) {
366         // Handshake failed.
367         socket_stream_->Close();
368         return;
369       }
370       if (delegate_)
371         delegate_->OnOpen(this);
372       if (current_read_buf_->offset() == read_consumed_len_) {
373         // No remaining data after handshake message.
374         break;
375       }
376       // FALL THROUGH
377     case OPEN:
378       ProcessFrameData();
379       break;
380 
381     case CLOSED:
382       // Closed just after DoReceivedData is queued on |origin_loop_|.
383       break;
384     default:
385       NOTREACHED();
386       break;
387   }
388 }
389 
ProcessFrameData()390 void WebSocket::ProcessFrameData() {
391   DCHECK(current_read_buf_);
392   const char* start_frame =
393       current_read_buf_->StartOfBuffer() + read_consumed_len_;
394   const char* next_frame = start_frame;
395   const char* p = next_frame;
396   const char* end =
397       current_read_buf_->StartOfBuffer() + current_read_buf_->offset();
398   while (p < end) {
399     unsigned char frame_byte = static_cast<unsigned char>(*p++);
400     if ((frame_byte & 0x80) == 0x80) {
401       int length = 0;
402       while (p < end) {
403         if (length > std::numeric_limits<int>::max() / 128) {
404           // frame length overflow.
405           socket_stream_->Close();
406           return;
407         }
408         unsigned char c = static_cast<unsigned char>(*p);
409         length = length * 128 + (c & 0x7f);
410         ++p;
411         if ((c & 0x80) != 0x80)
412           break;
413       }
414       // Checks if the frame body hasn't been completely received yet.
415       // It also checks the case the frame length bytes haven't been completely
416       // received yet, because p == end and length > 0 in such case.
417       if (p + length < end) {
418         p += length;
419         next_frame = p;
420       } else {
421         break;
422       }
423     } else {
424       const char* msg_start = p;
425       while (p < end && *p != '\xff')
426         ++p;
427       if (p < end && *p == '\xff') {
428         if (frame_byte == 0x00 && delegate_)
429           delegate_->OnMessage(this, std::string(msg_start, p - msg_start));
430         ++p;
431         next_frame = p;
432       }
433     }
434   }
435   SkipReadBuffer(next_frame - start_frame);
436 }
437 
AddToReadBuffer(const char * data,int len)438 void WebSocket::AddToReadBuffer(const char* data, int len) {
439   DCHECK(current_read_buf_);
440   // Check if |current_read_buf_| has enough space to store |len| of |data|.
441   if (len >= current_read_buf_->RemainingCapacity()) {
442     current_read_buf_->SetCapacity(
443         current_read_buf_->offset() + len);
444   }
445 
446   DCHECK(current_read_buf_->RemainingCapacity() >= len);
447   memcpy(current_read_buf_->data(), data, len);
448   current_read_buf_->set_offset(current_read_buf_->offset() + len);
449 }
450 
SkipReadBuffer(int len)451 void WebSocket::SkipReadBuffer(int len) {
452   if (len == 0)
453     return;
454   DCHECK_GT(len, 0);
455   read_consumed_len_ += len;
456   int remaining = current_read_buf_->offset() - read_consumed_len_;
457   DCHECK_GE(remaining, 0);
458   if (remaining < read_consumed_len_ &&
459       current_read_buf_->RemainingCapacity() < read_consumed_len_) {
460     // Pre compaction:
461     // 0             v-read_consumed_len_  v-offset               v- capacity
462     // |..processed..| .. remaining ..     | .. RemainingCapacity |
463     //
464     memmove(current_read_buf_->StartOfBuffer(),
465             current_read_buf_->StartOfBuffer() + read_consumed_len_,
466             remaining);
467     read_consumed_len_ = 0;
468     current_read_buf_->set_offset(remaining);
469     // Post compaction:
470     // 0read_consumed_len_  v- offset                             v- capacity
471     // |.. remaining ..     | ..  RemainingCapacity  ...          |
472     //
473   }
474 }
475 
DoClose()476 void WebSocket::DoClose() {
477   DCHECK(MessageLoop::current() == origin_loop_);
478   WebSocketDelegate* delegate = delegate_;
479   delegate_ = NULL;
480   ready_state_ = CLOSED;
481   if (!socket_stream_)
482     return;
483   socket_stream_ = NULL;
484   if (delegate)
485     delegate->OnClose(this);
486   Release();
487 }
488 
DoError(int error)489 void WebSocket::DoError(int error) {
490   DCHECK(MessageLoop::current() == origin_loop_);
491   if (delegate_)
492     delegate_->OnError(this, error);
493 }
494 
495 }  // namespace net
496