• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2013 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/websockets/websocket_stream.h"
6 
7 #include <ostream>
8 #include <utility>
9 
10 #include "base/check.h"
11 #include "base/check_op.h"
12 #include "base/functional/bind.h"
13 #include "base/location.h"
14 #include "base/logging.h"
15 #include "base/memory/raw_ptr.h"
16 #include "base/memory/weak_ptr.h"
17 #include "base/metrics/histogram_functions.h"
18 #include "base/notreached.h"
19 #include "base/strings/strcat.h"
20 #include "base/time/time.h"
21 #include "base/timer/timer.h"
22 #include "net/base/auth.h"
23 #include "net/base/isolation_info.h"
24 #include "net/base/load_flags.h"
25 #include "net/base/net_errors.h"
26 #include "net/base/request_priority.h"
27 #include "net/base/url_util.h"
28 #include "net/http/http_request_headers.h"
29 #include "net/http/http_response_headers.h"
30 #include "net/http/http_response_info.h"
31 #include "net/http/http_status_code.h"
32 #include "net/traffic_annotation/network_traffic_annotation.h"
33 #include "net/url_request/redirect_info.h"
34 #include "net/url_request/url_request.h"
35 #include "net/url_request/url_request_context.h"
36 #include "net/url_request/websocket_handshake_userdata_key.h"
37 #include "net/websockets/websocket_basic_handshake_stream.h"
38 #include "net/websockets/websocket_event_interface.h"
39 #include "net/websockets/websocket_handshake_constants.h"
40 #include "net/websockets/websocket_handshake_response_info.h"
41 #include "net/websockets/websocket_handshake_stream_base.h"
42 #include "net/websockets/websocket_handshake_stream_create_helper.h"
43 #include "net/websockets/websocket_http2_handshake_stream.h"
44 #include "net/websockets/websocket_http3_handshake_stream.h"
45 #include "third_party/abseil-cpp/absl/types/optional.h"
46 #include "url/gurl.h"
47 #include "url/origin.h"
48 
49 namespace net {
50 class SSLCertRequestInfo;
51 class SSLInfo;
52 class SiteForCookies;
53 
54 namespace {
55 
// How long, in seconds, the WebSocket opening handshake may take before it is
// cancelled with ERR_TIMED_OUT. The value deliberately matches the TCP
// connection timeout in net/socket/websocket_transport_client_socket_pool.cc so
// that JavaScript cannot easily tell which of the two caused a timeout.
constexpr int kHandshakeTimeoutIntervalInSeconds = 240;
61 
62 class WebSocketStreamRequestImpl;
63 
64 class Delegate : public URLRequest::Delegate {
65  public:
Delegate(WebSocketStreamRequestImpl * owner)66   explicit Delegate(WebSocketStreamRequestImpl* owner) : owner_(owner) {}
67   ~Delegate() override = default;
68 
69   // Implementation of URLRequest::Delegate methods.
70   void OnReceivedRedirect(URLRequest* request,
71                           const RedirectInfo& redirect_info,
72                           bool* defer_redirect) override;
73 
74   void OnResponseStarted(URLRequest* request, int net_error) override;
75 
76   void OnAuthRequired(URLRequest* request,
77                       const AuthChallengeInfo& auth_info) override;
78 
79   void OnCertificateRequested(URLRequest* request,
80                               SSLCertRequestInfo* cert_request_info) override;
81 
82   void OnSSLCertificateError(URLRequest* request,
83                              int net_error,
84                              const SSLInfo& ssl_info,
85                              bool fatal) override;
86 
87   void OnReadCompleted(URLRequest* request, int bytes_read) override;
88 
89  private:
90   void OnAuthRequiredComplete(URLRequest* request,
91                               const AuthCredentials* auth_credentials);
92 
93   raw_ptr<WebSocketStreamRequestImpl> owner_;
94 };
95 
96 class WebSocketStreamRequestImpl : public WebSocketStreamRequestAPI {
97  public:
WebSocketStreamRequestImpl(const GURL & url,const std::vector<std::string> & requested_subprotocols,const URLRequestContext * context,const url::Origin & origin,const SiteForCookies & site_for_cookies,bool has_storage_access,const IsolationInfo & isolation_info,const HttpRequestHeaders & additional_headers,NetworkTrafficAnnotationTag traffic_annotation,std::unique_ptr<WebSocketStream::ConnectDelegate> connect_delegate,std::unique_ptr<WebSocketStreamRequestAPI> api_delegate)98   WebSocketStreamRequestImpl(
99       const GURL& url,
100       const std::vector<std::string>& requested_subprotocols,
101       const URLRequestContext* context,
102       const url::Origin& origin,
103       const SiteForCookies& site_for_cookies,
104       bool has_storage_access,
105       const IsolationInfo& isolation_info,
106       const HttpRequestHeaders& additional_headers,
107       NetworkTrafficAnnotationTag traffic_annotation,
108       std::unique_ptr<WebSocketStream::ConnectDelegate> connect_delegate,
109       std::unique_ptr<WebSocketStreamRequestAPI> api_delegate)
110       : delegate_(this),
111         url_request_(context->CreateRequest(url,
112                                             DEFAULT_PRIORITY,
113                                             &delegate_,
114                                             traffic_annotation,
115                                             /*is_for_websockets=*/true)),
116         connect_delegate_(std::move(connect_delegate)),
117         api_delegate_(std::move(api_delegate)) {
118     DCHECK_EQ(IsolationInfo::RequestType::kOther,
119               isolation_info.request_type());
120 
121     HttpRequestHeaders headers = additional_headers;
122     headers.SetHeader(websockets::kUpgrade, websockets::kWebSocketLowercase);
123     headers.SetHeader(HttpRequestHeaders::kConnection, websockets::kUpgrade);
124     headers.SetHeader(HttpRequestHeaders::kOrigin, origin.Serialize());
125     headers.SetHeader(websockets::kSecWebSocketVersion,
126                       websockets::kSupportedVersion);
127 
128     // Remove HTTP headers that are important to websocket connections: they
129     // will be added later.
130     headers.RemoveHeader(websockets::kSecWebSocketExtensions);
131     headers.RemoveHeader(websockets::kSecWebSocketKey);
132     headers.RemoveHeader(websockets::kSecWebSocketProtocol);
133 
134     url_request_->SetExtraRequestHeaders(headers);
135     url_request_->set_initiator(origin);
136     url_request_->set_site_for_cookies(site_for_cookies);
137     url_request_->set_has_storage_access(has_storage_access);
138     url_request_->set_isolation_info(isolation_info);
139 
140     auto create_helper = std::make_unique<WebSocketHandshakeStreamCreateHelper>(
141         connect_delegate_.get(), requested_subprotocols, this);
142     url_request_->SetUserData(kWebSocketHandshakeUserDataKey,
143                               std::move(create_helper));
144     url_request_->SetLoadFlags(LOAD_DISABLE_CACHE | LOAD_BYPASS_CACHE);
145     connect_delegate_->OnCreateRequest(url_request_.get());
146   }
147 
148   // Destroying this object destroys the URLRequest, which cancels the request
149   // and so terminates the handshake if it is incomplete.
150   ~WebSocketStreamRequestImpl() override = default;
151 
OnBasicHandshakeStreamCreated(WebSocketBasicHandshakeStream * handshake_stream)152   void OnBasicHandshakeStreamCreated(
153       WebSocketBasicHandshakeStream* handshake_stream) override {
154     if (api_delegate_) {
155       api_delegate_->OnBasicHandshakeStreamCreated(handshake_stream);
156     }
157     OnHandshakeStreamCreated(handshake_stream);
158   }
159 
OnHttp2HandshakeStreamCreated(WebSocketHttp2HandshakeStream * handshake_stream)160   void OnHttp2HandshakeStreamCreated(
161       WebSocketHttp2HandshakeStream* handshake_stream) override {
162     if (api_delegate_) {
163       api_delegate_->OnHttp2HandshakeStreamCreated(handshake_stream);
164     }
165     OnHandshakeStreamCreated(handshake_stream);
166   }
167 
OnHttp3HandshakeStreamCreated(WebSocketHttp3HandshakeStream * handshake_stream)168   void OnHttp3HandshakeStreamCreated(
169       WebSocketHttp3HandshakeStream* handshake_stream) override {
170     if (api_delegate_) {
171       api_delegate_->OnHttp3HandshakeStreamCreated(handshake_stream);
172     }
173     OnHandshakeStreamCreated(handshake_stream);
174   }
175 
OnFailure(const std::string & message,int net_error,absl::optional<int> response_code)176   void OnFailure(const std::string& message,
177                  int net_error,
178                  absl::optional<int> response_code) override {
179     if (api_delegate_)
180       api_delegate_->OnFailure(message, net_error, response_code);
181     failure_message_ = message;
182     failure_net_error_ = net_error;
183     failure_response_code_ = response_code;
184   }
185 
Start(std::unique_ptr<base::OneShotTimer> timer)186   void Start(std::unique_ptr<base::OneShotTimer> timer) {
187     DCHECK(timer);
188     base::TimeDelta timeout(base::Seconds(kHandshakeTimeoutIntervalInSeconds));
189     timer_ = std::move(timer);
190     timer_->Start(FROM_HERE, timeout,
191                   base::BindOnce(&WebSocketStreamRequestImpl::OnTimeout,
192                                  base::Unretained(this)));
193     url_request_->Start();
194   }
195 
PerformUpgrade()196   void PerformUpgrade() {
197     DCHECK(timer_);
198     DCHECK(connect_delegate_);
199 
200     timer_->Stop();
201 
202     if (!handshake_stream_) {
203       ReportFailureWithMessage(
204           "No handshake stream has been created or handshake stream is already "
205           "destroyed.",
206           ERR_FAILED, absl::nullopt);
207       return;
208     }
209 
210     if (!handshake_stream_->CanReadFromStream()) {
211       ReportFailureWithMessage("Handshake stream is not readable.",
212                                ERR_CONNECTION_CLOSED, absl::nullopt);
213       return;
214     }
215 
216     std::unique_ptr<URLRequest> url_request = std::move(url_request_);
217     WebSocketHandshakeStreamBase* handshake_stream = handshake_stream_.get();
218     handshake_stream_.reset();
219     auto handshake_response_info =
220         std::make_unique<WebSocketHandshakeResponseInfo>(
221             url_request->url(), url_request->response_headers(),
222             url_request->GetResponseRemoteEndpoint(),
223             url_request->response_time());
224     connect_delegate_->OnSuccess(handshake_stream->Upgrade(),
225                                  std::move(handshake_response_info));
226 
227     // This is safe even if |this| has already been deleted.
228     url_request->CancelWithError(ERR_WS_UPGRADE);
229   }
230 
FailureMessageFromNetError(int net_error)231   std::string FailureMessageFromNetError(int net_error) {
232     if (net_error == ERR_TUNNEL_CONNECTION_FAILED) {
233       // This error is common and confusing, so special-case it.
234       // TODO(ricea): Include the HostPortPair of the selected proxy server in
235       // the error message. This is not currently possible because it isn't set
236       // in HttpResponseInfo when a ERR_TUNNEL_CONNECTION_FAILED error happens.
237       return "Establishing a tunnel via proxy server failed.";
238     } else {
239       return base::StrCat(
240           {"Error in connection establishment: ", ErrorToString(net_error)});
241     }
242   }
243 
ReportFailure(int net_error,absl::optional<int> response_code)244   void ReportFailure(int net_error, absl::optional<int> response_code) {
245     DCHECK(timer_);
246     timer_->Stop();
247     if (failure_message_.empty()) {
248       switch (net_error) {
249         case OK:
250         case ERR_IO_PENDING:
251           break;
252         case ERR_ABORTED:
253           failure_message_ = "WebSocket opening handshake was canceled";
254           break;
255         case ERR_TIMED_OUT:
256           failure_message_ = "WebSocket opening handshake timed out";
257           break;
258         default:
259           failure_message_ = FailureMessageFromNetError(net_error);
260           break;
261       }
262     }
263 
264     ReportFailureWithMessage(
265         failure_message_, failure_net_error_.value_or(net_error),
266         failure_response_code_ ? failure_response_code_ : response_code);
267   }
268 
ReportFailureWithMessage(const std::string & failure_message,int net_error,absl::optional<int> response_code)269   void ReportFailureWithMessage(const std::string& failure_message,
270                                 int net_error,
271                                 absl::optional<int> response_code) {
272     connect_delegate_->OnFailure(failure_message, net_error, response_code);
273   }
274 
connect_delegate() const275   WebSocketStream::ConnectDelegate* connect_delegate() const {
276     return connect_delegate_.get();
277   }
278 
OnTimeout()279   void OnTimeout() {
280     url_request_->CancelWithError(ERR_TIMED_OUT);
281   }
282 
283  private:
OnHandshakeStreamCreated(WebSocketHandshakeStreamBase * handshake_stream)284   void OnHandshakeStreamCreated(
285       WebSocketHandshakeStreamBase* handshake_stream) {
286     DCHECK(handshake_stream);
287 
288     handshake_stream_ = handshake_stream->GetWeakPtr();
289   }
290 
291   // |delegate_| needs to be declared before |url_request_| so that it gets
292   // initialised first.
293   Delegate delegate_;
294 
295   // Deleting the WebSocketStreamRequestImpl object deletes this URLRequest
296   // object, cancelling the whole connection.
297   std::unique_ptr<URLRequest> url_request_;
298 
299   std::unique_ptr<WebSocketStream::ConnectDelegate> connect_delegate_;
300 
301   // This is owned by the caller of
302   // WebsocketHandshakeStreamCreateHelper::CreateBasicStream() or
303   // CreateHttp2Stream() or CreateHttp3Stream().  Both the stream and this
304   // object will be destroyed during the destruction of the URLRequest object
305   // associated with the handshake. This is only guaranteed to be a valid
306   // pointer if the handshake succeeded.
307   base::WeakPtr<WebSocketHandshakeStreamBase> handshake_stream_;
308 
309   // The failure information supplied by WebSocketBasicHandshakeStream, if any.
310   std::string failure_message_;
311   absl::optional<int> failure_net_error_;
312   absl::optional<int> failure_response_code_;
313 
314   // A timer for handshake timeout.
315   std::unique_ptr<base::OneShotTimer> timer_;
316 
317   // A delegate for On*HandshakeCreated and OnFailure calls.
318   std::unique_ptr<WebSocketStreamRequestAPI> api_delegate_;
319 };
320 
321 class SSLErrorCallbacks : public WebSocketEventInterface::SSLErrorCallbacks {
322  public:
SSLErrorCallbacks(URLRequest * url_request)323   explicit SSLErrorCallbacks(URLRequest* url_request)
324       : url_request_(url_request->GetWeakPtr()) {}
325 
CancelSSLRequest(int error,const SSLInfo * ssl_info)326   void CancelSSLRequest(int error, const SSLInfo* ssl_info) override {
327     if (!url_request_)
328       return;
329 
330     if (ssl_info) {
331       url_request_->CancelWithSSLError(error, *ssl_info);
332     } else {
333       url_request_->CancelWithError(error);
334     }
335   }
336 
ContinueSSLRequest()337   void ContinueSSLRequest() override {
338     if (url_request_)
339       url_request_->ContinueDespiteLastError();
340   }
341 
342  private:
343   base::WeakPtr<URLRequest> url_request_;
344 };
345 
OnReceivedRedirect(URLRequest * request,const RedirectInfo & redirect_info,bool * defer_redirect)346 void Delegate::OnReceivedRedirect(URLRequest* request,
347                                   const RedirectInfo& redirect_info,
348                                   bool* defer_redirect) {
349   // This code should never be reached for externally generated redirects,
350   // as WebSocketBasicHandshakeStream is responsible for filtering out
351   // all response codes besides 101, 401, and 407. As such, the URLRequest
352   // should never see a redirect sent over the network. However, internal
353   // redirects also result in this method being called, such as those
354   // caused by HSTS.
355   // Because it's security critical to prevent externally-generated
356   // redirects in WebSockets, perform additional checks to ensure this
357   // is only internal.
358   GURL::Replacements replacements;
359   replacements.SetSchemeStr("wss");
360   GURL expected_url = request->original_url().ReplaceComponents(replacements);
361   if (redirect_info.new_method != "GET" ||
362       redirect_info.new_url != expected_url) {
363     // This should not happen.
364     DLOG(FATAL) << "Unauthorized WebSocket redirect to "
365                 << redirect_info.new_method << " "
366                 << redirect_info.new_url.spec();
367     request->Cancel();
368   }
369 }
370 
OnResponseStarted(URLRequest * request,int net_error)371 void Delegate::OnResponseStarted(URLRequest* request, int net_error) {
372   DCHECK_NE(ERR_IO_PENDING, net_error);
373 
374   const bool is_http2 =
375       request->response_info().connection_info == HttpConnectionInfo::kHTTP2;
376 
377   // All error codes, including OK and ABORTED, as with
378   // Net.ErrorCodesForMainFrame4
379   base::UmaHistogramSparse("Net.WebSocket.ErrorCodes", -net_error);
380   if (is_http2) {
381     base::UmaHistogramSparse("Net.WebSocket.ErrorCodes.Http2", -net_error);
382   }
383   if (net::IsLocalhost(request->url())) {
384     base::UmaHistogramSparse("Net.WebSocket.ErrorCodes_Localhost", -net_error);
385   } else {
386     base::UmaHistogramSparse("Net.WebSocket.ErrorCodes_NotLocalhost",
387                              -net_error);
388   }
389 
390   if (net_error != OK) {
391     DVLOG(3) << "OnResponseStarted (request failed)";
392     owner_->ReportFailure(net_error, absl::nullopt);
393     return;
394   }
395   const int response_code = request->GetResponseCode();
396   DVLOG(3) << "OnResponseStarted (response code " << response_code << ")";
397 
398   if (is_http2) {
399     if (response_code == HTTP_OK) {
400       owner_->PerformUpgrade();
401       return;
402     }
403 
404     owner_->ReportFailure(net_error, absl::nullopt);
405     return;
406   }
407 
408   switch (response_code) {
409     case HTTP_SWITCHING_PROTOCOLS:
410       owner_->PerformUpgrade();
411       return;
412 
413     case HTTP_UNAUTHORIZED:
414       owner_->ReportFailureWithMessage(
415           "HTTP Authentication failed; no valid credentials available",
416           net_error, response_code);
417       return;
418 
419     case HTTP_PROXY_AUTHENTICATION_REQUIRED:
420       owner_->ReportFailureWithMessage("Proxy authentication failed", net_error,
421                                        response_code);
422       return;
423 
424     default:
425       owner_->ReportFailure(net_error, response_code);
426   }
427 }
428 
OnAuthRequired(URLRequest * request,const AuthChallengeInfo & auth_info)429 void Delegate::OnAuthRequired(URLRequest* request,
430                               const AuthChallengeInfo& auth_info) {
431   absl::optional<AuthCredentials> credentials;
432   // This base::Unretained(this) relies on an assumption that |callback| can
433   // be called called during the opening handshake.
434   int rv = owner_->connect_delegate()->OnAuthRequired(
435       auth_info, request->response_headers(),
436       request->GetResponseRemoteEndpoint(),
437       base::BindOnce(&Delegate::OnAuthRequiredComplete, base::Unretained(this),
438                      request),
439       &credentials);
440   request->LogBlockedBy("WebSocketStream::Delegate::OnAuthRequired");
441   if (rv == ERR_IO_PENDING)
442     return;
443   if (rv != OK) {
444     request->LogUnblocked();
445     owner_->ReportFailure(rv, absl::nullopt);
446     return;
447   }
448   OnAuthRequiredComplete(request, nullptr);
449 }
450 
OnAuthRequiredComplete(URLRequest * request,const AuthCredentials * credentials)451 void Delegate::OnAuthRequiredComplete(URLRequest* request,
452                                       const AuthCredentials* credentials) {
453   request->LogUnblocked();
454   if (!credentials) {
455     request->CancelAuth();
456     return;
457   }
458   request->SetAuth(*credentials);
459 }
460 
OnCertificateRequested(URLRequest * request,SSLCertRequestInfo * cert_request_info)461 void Delegate::OnCertificateRequested(URLRequest* request,
462                                       SSLCertRequestInfo* cert_request_info) {
463   // This method is called when a client certificate is requested, and the
464   // request context does not already contain a client certificate selection for
465   // the endpoint. In this case, a main frame resource request would pop-up UI
466   // to permit selection of a client certificate, but since WebSockets are
467   // sub-resources they should not pop-up UI and so there is nothing more we can
468   // do.
469   request->Cancel();
470 }
471 
OnSSLCertificateError(URLRequest * request,int net_error,const SSLInfo & ssl_info,bool fatal)472 void Delegate::OnSSLCertificateError(URLRequest* request,
473                                      int net_error,
474                                      const SSLInfo& ssl_info,
475                                      bool fatal) {
476   owner_->connect_delegate()->OnSSLCertificateError(
477       std::make_unique<SSLErrorCallbacks>(request), net_error, ssl_info, fatal);
478 }
479 
OnReadCompleted(URLRequest * request,int bytes_read)480 void Delegate::OnReadCompleted(URLRequest* request, int bytes_read) {
481   NOTREACHED();
482 }
483 
484 }  // namespace
485 
486 WebSocketStreamRequest::~WebSocketStreamRequest() = default;
487 
488 WebSocketStream::WebSocketStream() = default;
489 WebSocketStream::~WebSocketStream() = default;
490 
491 WebSocketStream::ConnectDelegate::~ConnectDelegate() = default;
492 
CreateAndConnectStream(const GURL & socket_url,const std::vector<std::string> & requested_subprotocols,const url::Origin & origin,const SiteForCookies & site_for_cookies,bool has_storage_access,const IsolationInfo & isolation_info,const HttpRequestHeaders & additional_headers,URLRequestContext * url_request_context,const NetLogWithSource & net_log,NetworkTrafficAnnotationTag traffic_annotation,std::unique_ptr<ConnectDelegate> connect_delegate)493 std::unique_ptr<WebSocketStreamRequest> WebSocketStream::CreateAndConnectStream(
494     const GURL& socket_url,
495     const std::vector<std::string>& requested_subprotocols,
496     const url::Origin& origin,
497     const SiteForCookies& site_for_cookies,
498     bool has_storage_access,
499     const IsolationInfo& isolation_info,
500     const HttpRequestHeaders& additional_headers,
501     URLRequestContext* url_request_context,
502     const NetLogWithSource& net_log,
503     NetworkTrafficAnnotationTag traffic_annotation,
504     std::unique_ptr<ConnectDelegate> connect_delegate) {
505   auto request = std::make_unique<WebSocketStreamRequestImpl>(
506       socket_url, requested_subprotocols, url_request_context, origin,
507       site_for_cookies, has_storage_access, isolation_info, additional_headers,
508       traffic_annotation, std::move(connect_delegate), nullptr);
509   request->Start(std::make_unique<base::OneShotTimer>());
510   return std::move(request);
511 }
512 
513 std::unique_ptr<WebSocketStreamRequest>
CreateAndConnectStreamForTesting(const GURL & socket_url,const std::vector<std::string> & requested_subprotocols,const url::Origin & origin,const SiteForCookies & site_for_cookies,bool has_storage_access,const IsolationInfo & isolation_info,const HttpRequestHeaders & additional_headers,URLRequestContext * url_request_context,const NetLogWithSource & net_log,NetworkTrafficAnnotationTag traffic_annotation,std::unique_ptr<WebSocketStream::ConnectDelegate> connect_delegate,std::unique_ptr<base::OneShotTimer> timer,std::unique_ptr<WebSocketStreamRequestAPI> api_delegate)514 WebSocketStream::CreateAndConnectStreamForTesting(
515     const GURL& socket_url,
516     const std::vector<std::string>& requested_subprotocols,
517     const url::Origin& origin,
518     const SiteForCookies& site_for_cookies,
519     bool has_storage_access,
520     const IsolationInfo& isolation_info,
521     const HttpRequestHeaders& additional_headers,
522     URLRequestContext* url_request_context,
523     const NetLogWithSource& net_log,
524     NetworkTrafficAnnotationTag traffic_annotation,
525     std::unique_ptr<WebSocketStream::ConnectDelegate> connect_delegate,
526     std::unique_ptr<base::OneShotTimer> timer,
527     std::unique_ptr<WebSocketStreamRequestAPI> api_delegate) {
528   auto request = std::make_unique<WebSocketStreamRequestImpl>(
529       socket_url, requested_subprotocols, url_request_context, origin,
530       site_for_cookies, has_storage_access, isolation_info, additional_headers,
531       traffic_annotation, std::move(connect_delegate), std::move(api_delegate));
532   request->Start(std::move(timer));
533   return std::move(request);
534 }
535 
536 }  // namespace net
537