• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/websockets/websocket_basic_handshake_stream.h"
6 
7 #include <algorithm>
8 #include <iterator>
9 
10 #include "base/base64.h"
11 #include "base/basictypes.h"
12 #include "base/bind.h"
13 #include "base/containers/hash_tables.h"
14 #include "base/stl_util.h"
15 #include "base/strings/string_util.h"
16 #include "crypto/random.h"
17 #include "net/http/http_request_headers.h"
18 #include "net/http/http_request_info.h"
19 #include "net/http/http_response_body_drainer.h"
20 #include "net/http/http_response_headers.h"
21 #include "net/http/http_status_code.h"
22 #include "net/http/http_stream_parser.h"
23 #include "net/socket/client_socket_handle.h"
24 #include "net/websockets/websocket_basic_stream.h"
25 #include "net/websockets/websocket_handshake_constants.h"
26 #include "net/websockets/websocket_handshake_handler.h"
27 #include "net/websockets/websocket_stream.h"
28 
29 namespace net {
30 namespace {
31 
GenerateHandshakeChallenge()32 std::string GenerateHandshakeChallenge() {
33   std::string raw_challenge(websockets::kRawChallengeLength, '\0');
34   crypto::RandBytes(string_as_array(&raw_challenge), raw_challenge.length());
35   std::string encoded_challenge;
36   base::Base64Encode(raw_challenge, &encoded_challenge);
37   return encoded_challenge;
38 }
39 
AddVectorHeaderIfNonEmpty(const char * name,const std::vector<std::string> & value,HttpRequestHeaders * headers)40 void AddVectorHeaderIfNonEmpty(const char* name,
41                                const std::vector<std::string>& value,
42                                HttpRequestHeaders* headers) {
43   if (value.empty())
44     return;
45   headers->SetHeader(name, JoinString(value, ", "));
46 }
47 
48 // If |case_sensitive| is false, then |value| must be in lower-case.
ValidateSingleTokenHeader(const scoped_refptr<HttpResponseHeaders> & headers,const base::StringPiece & name,const std::string & value,bool case_sensitive)49 bool ValidateSingleTokenHeader(
50     const scoped_refptr<HttpResponseHeaders>& headers,
51     const base::StringPiece& name,
52     const std::string& value,
53     bool case_sensitive) {
54   void* state = NULL;
55   std::string token;
56   int tokens = 0;
57   bool has_value = false;
58   while (headers->EnumerateHeader(&state, name, &token)) {
59     if (++tokens > 1)
60       return false;
61     has_value = case_sensitive ? value == token
62                                : LowerCaseEqualsASCII(token, value.c_str());
63   }
64   return has_value;
65 }
66 
ValidateSubProtocol(const scoped_refptr<HttpResponseHeaders> & headers,const std::vector<std::string> & requested_sub_protocols,std::string * sub_protocol)67 bool ValidateSubProtocol(
68     const scoped_refptr<HttpResponseHeaders>& headers,
69     const std::vector<std::string>& requested_sub_protocols,
70     std::string* sub_protocol) {
71   void* state = NULL;
72   std::string token;
73   base::hash_set<std::string> requested_set(requested_sub_protocols.begin(),
74                                             requested_sub_protocols.end());
75   int accepted = 0;
76   while (headers->EnumerateHeader(
77       &state, websockets::kSecWebSocketProtocol, &token)) {
78     if (requested_set.count(token) == 0)
79       return false;
80 
81     *sub_protocol = token;
82     // The server is only allowed to accept one protocol.
83     if (++accepted > 1)
84       return false;
85   }
86   // If the browser requested > 0 protocols, the server is required to accept
87   // one.
88   return requested_set.empty() || accepted == 1;
89 }
90 
ValidateExtensions(const scoped_refptr<HttpResponseHeaders> & headers,const std::vector<std::string> & requested_extensions,std::string * extensions)91 bool ValidateExtensions(const scoped_refptr<HttpResponseHeaders>& headers,
92                         const std::vector<std::string>& requested_extensions,
93                         std::string* extensions) {
94   void* state = NULL;
95   std::string token;
96   while (headers->EnumerateHeader(
97       &state, websockets::kSecWebSocketExtensions, &token)) {
98     // TODO(ricea): Accept permessage-deflate with valid parameters.
99     return false;
100   }
101   return true;
102 }
103 
104 }  // namespace
105 
WebSocketBasicHandshakeStream(scoped_ptr<ClientSocketHandle> connection,bool using_proxy,std::vector<std::string> requested_sub_protocols,std::vector<std::string> requested_extensions)106 WebSocketBasicHandshakeStream::WebSocketBasicHandshakeStream(
107     scoped_ptr<ClientSocketHandle> connection,
108     bool using_proxy,
109     std::vector<std::string> requested_sub_protocols,
110     std::vector<std::string> requested_extensions)
111     : state_(connection.release(), using_proxy),
112       http_response_info_(NULL),
113       requested_sub_protocols_(requested_sub_protocols),
114       requested_extensions_(requested_extensions) {}
115 
~WebSocketBasicHandshakeStream()116 WebSocketBasicHandshakeStream::~WebSocketBasicHandshakeStream() {}
117 
InitializeStream(const HttpRequestInfo * request_info,RequestPriority priority,const BoundNetLog & net_log,const CompletionCallback & callback)118 int WebSocketBasicHandshakeStream::InitializeStream(
119     const HttpRequestInfo* request_info,
120     RequestPriority priority,
121     const BoundNetLog& net_log,
122     const CompletionCallback& callback) {
123   state_.Initialize(request_info, priority, net_log, callback);
124   return OK;
125 }
126 
SendRequest(const HttpRequestHeaders & headers,HttpResponseInfo * response,const CompletionCallback & callback)127 int WebSocketBasicHandshakeStream::SendRequest(
128     const HttpRequestHeaders& headers,
129     HttpResponseInfo* response,
130     const CompletionCallback& callback) {
131   DCHECK(!headers.HasHeader(websockets::kSecWebSocketKey));
132   DCHECK(!headers.HasHeader(websockets::kSecWebSocketProtocol));
133   DCHECK(!headers.HasHeader(websockets::kSecWebSocketExtensions));
134   DCHECK(headers.HasHeader(HttpRequestHeaders::kOrigin));
135   DCHECK(headers.HasHeader(websockets::kUpgrade));
136   DCHECK(headers.HasHeader(HttpRequestHeaders::kConnection));
137   DCHECK(headers.HasHeader(websockets::kSecWebSocketVersion));
138   DCHECK(parser());
139 
140   http_response_info_ = response;
141 
142   // Create a copy of the headers object, so that we can add the
143   // Sec-WebSockey-Key header.
144   HttpRequestHeaders enriched_headers;
145   enriched_headers.CopyFrom(headers);
146   std::string handshake_challenge;
147   if (handshake_challenge_for_testing_) {
148     handshake_challenge = *handshake_challenge_for_testing_;
149     handshake_challenge_for_testing_.reset();
150   } else {
151     handshake_challenge = GenerateHandshakeChallenge();
152   }
153   enriched_headers.SetHeader(websockets::kSecWebSocketKey, handshake_challenge);
154 
155   AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketProtocol,
156                             requested_sub_protocols_,
157                             &enriched_headers);
158   AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketExtensions,
159                             requested_extensions_,
160                             &enriched_headers);
161 
162   ComputeSecWebSocketAccept(handshake_challenge,
163                             &handshake_challenge_response_);
164 
165   return parser()->SendRequest(
166       state_.GenerateRequestLine(), enriched_headers, response, callback);
167 }
168 
ReadResponseHeaders(const CompletionCallback & callback)169 int WebSocketBasicHandshakeStream::ReadResponseHeaders(
170     const CompletionCallback& callback) {
171   // HttpStreamParser uses a weak pointer when reading from the
172   // socket, so it won't be called back after being destroyed. The
173   // HttpStreamParser is owned by HttpBasicState which is owned by this object,
174   // so this use of base::Unretained() is safe.
175   int rv = parser()->ReadResponseHeaders(
176       base::Bind(&WebSocketBasicHandshakeStream::ReadResponseHeadersCallback,
177                  base::Unretained(this),
178                  callback));
179   return rv == OK ? ValidateResponse() : rv;
180 }
181 
GetResponseInfo() const182 const HttpResponseInfo* WebSocketBasicHandshakeStream::GetResponseInfo() const {
183   return parser()->GetResponseInfo();
184 }
185 
ReadResponseBody(IOBuffer * buf,int buf_len,const CompletionCallback & callback)186 int WebSocketBasicHandshakeStream::ReadResponseBody(
187     IOBuffer* buf,
188     int buf_len,
189     const CompletionCallback& callback) {
190   return parser()->ReadResponseBody(buf, buf_len, callback);
191 }
192 
Close(bool not_reusable)193 void WebSocketBasicHandshakeStream::Close(bool not_reusable) {
194   // This class ignores the value of |not_reusable| and never lets the socket be
195   // re-used.
196   if (parser())
197     parser()->Close(true);
198 }
199 
IsResponseBodyComplete() const200 bool WebSocketBasicHandshakeStream::IsResponseBodyComplete() const {
201   return parser()->IsResponseBodyComplete();
202 }
203 
CanFindEndOfResponse() const204 bool WebSocketBasicHandshakeStream::CanFindEndOfResponse() const {
205   return parser() && parser()->CanFindEndOfResponse();
206 }
207 
IsConnectionReused() const208 bool WebSocketBasicHandshakeStream::IsConnectionReused() const {
209   return parser()->IsConnectionReused();
210 }
211 
SetConnectionReused()212 void WebSocketBasicHandshakeStream::SetConnectionReused() {
213   parser()->SetConnectionReused();
214 }
215 
IsConnectionReusable() const216 bool WebSocketBasicHandshakeStream::IsConnectionReusable() const {
217   return false;
218 }
219 
GetTotalReceivedBytes() const220 int64 WebSocketBasicHandshakeStream::GetTotalReceivedBytes() const {
221   return 0;
222 }
223 
GetLoadTimingInfo(LoadTimingInfo * load_timing_info) const224 bool WebSocketBasicHandshakeStream::GetLoadTimingInfo(
225     LoadTimingInfo* load_timing_info) const {
226   return state_.connection()->GetLoadTimingInfo(IsConnectionReused(),
227                                                 load_timing_info);
228 }
229 
GetSSLInfo(SSLInfo * ssl_info)230 void WebSocketBasicHandshakeStream::GetSSLInfo(SSLInfo* ssl_info) {
231   parser()->GetSSLInfo(ssl_info);
232 }
233 
GetSSLCertRequestInfo(SSLCertRequestInfo * cert_request_info)234 void WebSocketBasicHandshakeStream::GetSSLCertRequestInfo(
235     SSLCertRequestInfo* cert_request_info) {
236   parser()->GetSSLCertRequestInfo(cert_request_info);
237 }
238 
IsSpdyHttpStream() const239 bool WebSocketBasicHandshakeStream::IsSpdyHttpStream() const { return false; }
240 
Drain(HttpNetworkSession * session)241 void WebSocketBasicHandshakeStream::Drain(HttpNetworkSession* session) {
242   HttpResponseBodyDrainer* drainer = new HttpResponseBodyDrainer(this);
243   drainer->Start(session);
244   // |drainer| will delete itself.
245 }
246 
SetPriority(RequestPriority priority)247 void WebSocketBasicHandshakeStream::SetPriority(RequestPriority priority) {
248   // TODO(ricea): See TODO comment in HttpBasicStream::SetPriority(). If it is
249   // gone, then copy whatever has happened there over here.
250 }
251 
Upgrade()252 scoped_ptr<WebSocketStream> WebSocketBasicHandshakeStream::Upgrade() {
253   // TODO(ricea): Add deflate support.
254 
255   // The HttpStreamParser object has a pointer to our ClientSocketHandle. Make
256   // sure it does not touch it again before it is destroyed.
257   state_.DeleteParser();
258   return scoped_ptr<WebSocketStream>(
259       new WebSocketBasicStream(state_.ReleaseConnection(),
260                                state_.read_buf(),
261                                sub_protocol_,
262                                extensions_));
263 }
264 
SetWebSocketKeyForTesting(const std::string & key)265 void WebSocketBasicHandshakeStream::SetWebSocketKeyForTesting(
266     const std::string& key) {
267   handshake_challenge_for_testing_.reset(new std::string(key));
268 }
269 
ReadResponseHeadersCallback(const CompletionCallback & callback,int result)270 void WebSocketBasicHandshakeStream::ReadResponseHeadersCallback(
271     const CompletionCallback& callback,
272     int result) {
273   if (result == OK)
274     result = ValidateResponse();
275   callback.Run(result);
276 }
277 
ValidateResponse()278 int WebSocketBasicHandshakeStream::ValidateResponse() {
279   DCHECK(http_response_info_);
280   const scoped_refptr<HttpResponseHeaders>& headers =
281       http_response_info_->headers;
282 
283   switch (headers->response_code()) {
284     case HTTP_SWITCHING_PROTOCOLS:
285       return ValidateUpgradeResponse(headers);
286 
287     // We need to pass these through for authentication to work.
288     case HTTP_UNAUTHORIZED:
289     case HTTP_PROXY_AUTHENTICATION_REQUIRED:
290       return OK;
291 
292     // Other status codes are potentially risky (see the warnings in the
293     // WHATWG WebSocket API spec) and so are dropped by default.
294     default:
295       return ERR_INVALID_RESPONSE;
296   }
297 }
298 
ValidateUpgradeResponse(const scoped_refptr<HttpResponseHeaders> & headers)299 int WebSocketBasicHandshakeStream::ValidateUpgradeResponse(
300     const scoped_refptr<HttpResponseHeaders>& headers) {
301   if (ValidateSingleTokenHeader(headers,
302                                 websockets::kUpgrade,
303                                 websockets::kWebSocketLowercase,
304                                 false) &&
305       ValidateSingleTokenHeader(headers,
306                                 websockets::kSecWebSocketAccept,
307                                 handshake_challenge_response_,
308                                 true) &&
309       headers->HasHeaderValue(HttpRequestHeaders::kConnection,
310                               websockets::kUpgrade) &&
311       ValidateSubProtocol(headers, requested_sub_protocols_, &sub_protocol_) &&
312       ValidateExtensions(headers, requested_extensions_, &extensions_)) {
313     return OK;
314   }
315   return ERR_INVALID_RESPONSE;
316 }
317 
318 }  // namespace net
319