1 // Copyright 2013 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/websockets/websocket_stream.h"
6
7 #include <utility>
8
9 #include "base/functional/bind.h"
10 #include "base/logging.h"
11 #include "base/memory/raw_ptr.h"
12 #include "base/memory/weak_ptr.h"
13 #include "base/metrics/histogram_functions.h"
14 #include "base/time/time.h"
15 #include "base/timer/timer.h"
16 #include "net/base/ip_endpoint.h"
17 #include "net/base/isolation_info.h"
18 #include "net/base/load_flags.h"
19 #include "net/base/url_util.h"
20 #include "net/http/http_request_headers.h"
21 #include "net/http/http_response_headers.h"
22 #include "net/http/http_response_info.h"
23 #include "net/http/http_status_code.h"
24 #include "net/traffic_annotation/network_traffic_annotation.h"
25 #include "net/url_request/redirect_info.h"
26 #include "net/url_request/url_request.h"
27 #include "net/url_request/url_request_context.h"
28 #include "net/url_request/websocket_handshake_userdata_key.h"
29 #include "net/websockets/websocket_basic_handshake_stream.h"
30 #include "net/websockets/websocket_errors.h"
31 #include "net/websockets/websocket_event_interface.h"
32 #include "net/websockets/websocket_handshake_constants.h"
33 #include "net/websockets/websocket_handshake_stream_base.h"
34 #include "net/websockets/websocket_handshake_stream_create_helper.h"
35 #include "net/websockets/websocket_http2_handshake_stream.h"
36 #include "net/websockets/websocket_http3_handshake_stream.h"
37 #include "third_party/abseil-cpp/absl/types/optional.h"
38 #include "url/gurl.h"
39 #include "url/origin.h"
40
41 namespace net {
42 namespace {
43
// How long, in seconds, the WebSocket opening handshake may take before it is
// cancelled with ERR_TIMED_OUT. The value deliberately matches the TCP
// connection timeout in net/socket/websocket_transport_client_socket_pool.cc,
// so that JavaScript cannot easily tell which of the two timeouts fired.
constexpr int kHandshakeTimeoutIntervalInSeconds = 240;

class WebSocketStreamRequestImpl;
51
52 class Delegate : public URLRequest::Delegate {
53 public:
Delegate(WebSocketStreamRequestImpl * owner)54 explicit Delegate(WebSocketStreamRequestImpl* owner) : owner_(owner) {}
55 ~Delegate() override = default;
56
57 // Implementation of URLRequest::Delegate methods.
58 void OnReceivedRedirect(URLRequest* request,
59 const RedirectInfo& redirect_info,
60 bool* defer_redirect) override;
61
62 void OnResponseStarted(URLRequest* request, int net_error) override;
63
64 void OnAuthRequired(URLRequest* request,
65 const AuthChallengeInfo& auth_info) override;
66
67 void OnCertificateRequested(URLRequest* request,
68 SSLCertRequestInfo* cert_request_info) override;
69
70 void OnSSLCertificateError(URLRequest* request,
71 int net_error,
72 const SSLInfo& ssl_info,
73 bool fatal) override;
74
75 void OnReadCompleted(URLRequest* request, int bytes_read) override;
76
77 private:
78 void OnAuthRequiredComplete(URLRequest* request,
79 const AuthCredentials* auth_credentials);
80
81 raw_ptr<WebSocketStreamRequestImpl> owner_;
82 };
83
84 class WebSocketStreamRequestImpl : public WebSocketStreamRequestAPI {
85 public:
WebSocketStreamRequestImpl(const GURL & url,const std::vector<std::string> & requested_subprotocols,const URLRequestContext * context,const url::Origin & origin,const SiteForCookies & site_for_cookies,const IsolationInfo & isolation_info,const HttpRequestHeaders & additional_headers,NetworkTrafficAnnotationTag traffic_annotation,std::unique_ptr<WebSocketStream::ConnectDelegate> connect_delegate,std::unique_ptr<WebSocketStreamRequestAPI> api_delegate)86 WebSocketStreamRequestImpl(
87 const GURL& url,
88 const std::vector<std::string>& requested_subprotocols,
89 const URLRequestContext* context,
90 const url::Origin& origin,
91 const SiteForCookies& site_for_cookies,
92 const IsolationInfo& isolation_info,
93 const HttpRequestHeaders& additional_headers,
94 NetworkTrafficAnnotationTag traffic_annotation,
95 std::unique_ptr<WebSocketStream::ConnectDelegate> connect_delegate,
96 std::unique_ptr<WebSocketStreamRequestAPI> api_delegate)
97 : delegate_(this),
98 url_request_(context->CreateRequest(url,
99 DEFAULT_PRIORITY,
100 &delegate_,
101 traffic_annotation,
102 /*is_for_websockets=*/true)),
103 connect_delegate_(std::move(connect_delegate)),
104 api_delegate_(std::move(api_delegate)) {
105 DCHECK_EQ(IsolationInfo::RequestType::kOther,
106 isolation_info.request_type());
107
108 HttpRequestHeaders headers = additional_headers;
109 headers.SetHeader(websockets::kUpgrade, websockets::kWebSocketLowercase);
110 headers.SetHeader(HttpRequestHeaders::kConnection, websockets::kUpgrade);
111 headers.SetHeader(HttpRequestHeaders::kOrigin, origin.Serialize());
112 headers.SetHeader(websockets::kSecWebSocketVersion,
113 websockets::kSupportedVersion);
114
115 // Remove HTTP headers that are important to websocket connections: they
116 // will be added later.
117 headers.RemoveHeader(websockets::kSecWebSocketExtensions);
118 headers.RemoveHeader(websockets::kSecWebSocketKey);
119 headers.RemoveHeader(websockets::kSecWebSocketProtocol);
120
121 url_request_->SetExtraRequestHeaders(headers);
122 url_request_->set_initiator(origin);
123 url_request_->set_site_for_cookies(site_for_cookies);
124 url_request_->set_isolation_info(isolation_info);
125
126 auto create_helper = std::make_unique<WebSocketHandshakeStreamCreateHelper>(
127 connect_delegate_.get(), requested_subprotocols, this);
128 url_request_->SetUserData(kWebSocketHandshakeUserDataKey,
129 std::move(create_helper));
130 url_request_->SetLoadFlags(LOAD_DISABLE_CACHE | LOAD_BYPASS_CACHE);
131 connect_delegate_->OnCreateRequest(url_request_.get());
132 }
133
134 // Destroying this object destroys the URLRequest, which cancels the request
135 // and so terminates the handshake if it is incomplete.
136 ~WebSocketStreamRequestImpl() override = default;
137
OnBasicHandshakeStreamCreated(WebSocketBasicHandshakeStream * handshake_stream)138 void OnBasicHandshakeStreamCreated(
139 WebSocketBasicHandshakeStream* handshake_stream) override {
140 if (api_delegate_) {
141 api_delegate_->OnBasicHandshakeStreamCreated(handshake_stream);
142 }
143 OnHandshakeStreamCreated(handshake_stream);
144 }
145
OnHttp2HandshakeStreamCreated(WebSocketHttp2HandshakeStream * handshake_stream)146 void OnHttp2HandshakeStreamCreated(
147 WebSocketHttp2HandshakeStream* handshake_stream) override {
148 if (api_delegate_) {
149 api_delegate_->OnHttp2HandshakeStreamCreated(handshake_stream);
150 }
151 OnHandshakeStreamCreated(handshake_stream);
152 }
153
OnHttp3HandshakeStreamCreated(WebSocketHttp3HandshakeStream * handshake_stream)154 void OnHttp3HandshakeStreamCreated(
155 WebSocketHttp3HandshakeStream* handshake_stream) override {
156 if (api_delegate_) {
157 api_delegate_->OnHttp3HandshakeStreamCreated(handshake_stream);
158 }
159 OnHandshakeStreamCreated(handshake_stream);
160 }
161
OnFailure(const std::string & message,int net_error,absl::optional<int> response_code)162 void OnFailure(const std::string& message,
163 int net_error,
164 absl::optional<int> response_code) override {
165 if (api_delegate_)
166 api_delegate_->OnFailure(message, net_error, response_code);
167 failure_message_ = message;
168 failure_net_error_ = net_error;
169 failure_response_code_ = response_code;
170 }
171
Start(std::unique_ptr<base::OneShotTimer> timer)172 void Start(std::unique_ptr<base::OneShotTimer> timer) {
173 DCHECK(timer);
174 base::TimeDelta timeout(base::Seconds(kHandshakeTimeoutIntervalInSeconds));
175 timer_ = std::move(timer);
176 timer_->Start(FROM_HERE, timeout,
177 base::BindOnce(&WebSocketStreamRequestImpl::OnTimeout,
178 base::Unretained(this)));
179 url_request_->Start();
180 }
181
PerformUpgrade()182 void PerformUpgrade() {
183 DCHECK(timer_);
184 DCHECK(connect_delegate_);
185
186 timer_->Stop();
187
188 if (!handshake_stream_) {
189 ReportFailureWithMessage(
190 "No handshake stream has been created or handshake stream is already "
191 "destroyed.",
192 ERR_FAILED, absl::nullopt);
193 return;
194 }
195
196 std::unique_ptr<URLRequest> url_request = std::move(url_request_);
197 WebSocketHandshakeStreamBase* handshake_stream = handshake_stream_.get();
198 handshake_stream_.reset();
199 auto handshake_response_info =
200 std::make_unique<WebSocketHandshakeResponseInfo>(
201 url_request->url(), url_request->response_headers(),
202 url_request->GetResponseRemoteEndpoint(),
203 url_request->response_time());
204 connect_delegate_->OnSuccess(handshake_stream->Upgrade(),
205 std::move(handshake_response_info));
206
207 // This is safe even if |this| has already been deleted.
208 url_request->CancelWithError(ERR_WS_UPGRADE);
209 }
210
FailureMessageFromNetError(int net_error)211 std::string FailureMessageFromNetError(int net_error) {
212 if (net_error == ERR_TUNNEL_CONNECTION_FAILED) {
213 // This error is common and confusing, so special-case it.
214 // TODO(ricea): Include the HostPortPair of the selected proxy server in
215 // the error message. This is not currently possible because it isn't set
216 // in HttpResponseInfo when a ERR_TUNNEL_CONNECTION_FAILED error happens.
217 return "Establishing a tunnel via proxy server failed.";
218 } else {
219 return std::string("Error in connection establishment: ") +
220 ErrorToString(net_error);
221 }
222 }
223
ReportFailure(int net_error,absl::optional<int> response_code)224 void ReportFailure(int net_error, absl::optional<int> response_code) {
225 DCHECK(timer_);
226 timer_->Stop();
227 if (failure_message_.empty()) {
228 switch (net_error) {
229 case OK:
230 case ERR_IO_PENDING:
231 break;
232 case ERR_ABORTED:
233 failure_message_ = "WebSocket opening handshake was canceled";
234 break;
235 case ERR_TIMED_OUT:
236 failure_message_ = "WebSocket opening handshake timed out";
237 break;
238 default:
239 failure_message_ = FailureMessageFromNetError(net_error);
240 break;
241 }
242 }
243
244 ReportFailureWithMessage(
245 failure_message_, failure_net_error_.value_or(net_error),
246 failure_response_code_ ? failure_response_code_ : response_code);
247 }
248
ReportFailureWithMessage(const std::string & failure_message,int net_error,absl::optional<int> response_code)249 void ReportFailureWithMessage(const std::string& failure_message,
250 int net_error,
251 absl::optional<int> response_code) {
252 connect_delegate_->OnFailure(failure_message, net_error, response_code);
253 }
254
connect_delegate() const255 WebSocketStream::ConnectDelegate* connect_delegate() const {
256 return connect_delegate_.get();
257 }
258
OnTimeout()259 void OnTimeout() {
260 url_request_->CancelWithError(ERR_TIMED_OUT);
261 }
262
263 private:
OnHandshakeStreamCreated(WebSocketHandshakeStreamBase * handshake_stream)264 void OnHandshakeStreamCreated(
265 WebSocketHandshakeStreamBase* handshake_stream) {
266 DCHECK(handshake_stream);
267
268 handshake_stream_ = handshake_stream->GetWeakPtr();
269 }
270
271 // |delegate_| needs to be declared before |url_request_| so that it gets
272 // initialised first.
273 Delegate delegate_;
274
275 // Deleting the WebSocketStreamRequestImpl object deletes this URLRequest
276 // object, cancelling the whole connection.
277 std::unique_ptr<URLRequest> url_request_;
278
279 std::unique_ptr<WebSocketStream::ConnectDelegate> connect_delegate_;
280
281 // This is owned by the caller of
282 // WebsocketHandshakeStreamCreateHelper::CreateBasicStream() or
283 // CreateHttp2Stream() or CreateHttp3Stream(). Both the stream and this
284 // object will be destroyed during the destruction of the URLRequest object
285 // associated with the handshake. This is only guaranteed to be a valid
286 // pointer if the handshake succeeded.
287 base::WeakPtr<WebSocketHandshakeStreamBase> handshake_stream_;
288
289 // The failure information supplied by WebSocketBasicHandshakeStream, if any.
290 std::string failure_message_;
291 absl::optional<int> failure_net_error_;
292 absl::optional<int> failure_response_code_;
293
294 // A timer for handshake timeout.
295 std::unique_ptr<base::OneShotTimer> timer_;
296
297 // A delegate for On*HandshakeCreated and OnFailure calls.
298 std::unique_ptr<WebSocketStreamRequestAPI> api_delegate_;
299 };
300
301 class SSLErrorCallbacks : public WebSocketEventInterface::SSLErrorCallbacks {
302 public:
SSLErrorCallbacks(URLRequest * url_request)303 explicit SSLErrorCallbacks(URLRequest* url_request)
304 : url_request_(url_request->GetWeakPtr()) {}
305
CancelSSLRequest(int error,const SSLInfo * ssl_info)306 void CancelSSLRequest(int error, const SSLInfo* ssl_info) override {
307 if (!url_request_)
308 return;
309
310 if (ssl_info) {
311 url_request_->CancelWithSSLError(error, *ssl_info);
312 } else {
313 url_request_->CancelWithError(error);
314 }
315 }
316
ContinueSSLRequest()317 void ContinueSSLRequest() override {
318 if (url_request_)
319 url_request_->ContinueDespiteLastError();
320 }
321
322 private:
323 base::WeakPtr<URLRequest> url_request_;
324 };
325
OnReceivedRedirect(URLRequest * request,const RedirectInfo & redirect_info,bool * defer_redirect)326 void Delegate::OnReceivedRedirect(URLRequest* request,
327 const RedirectInfo& redirect_info,
328 bool* defer_redirect) {
329 // This code should never be reached for externally generated redirects,
330 // as WebSocketBasicHandshakeStream is responsible for filtering out
331 // all response codes besides 101, 401, and 407. As such, the URLRequest
332 // should never see a redirect sent over the network. However, internal
333 // redirects also result in this method being called, such as those
334 // caused by HSTS.
335 // Because it's security critical to prevent externally-generated
336 // redirects in WebSockets, perform additional checks to ensure this
337 // is only internal.
338 GURL::Replacements replacements;
339 replacements.SetSchemeStr("wss");
340 GURL expected_url = request->original_url().ReplaceComponents(replacements);
341 if (redirect_info.new_method != "GET" ||
342 redirect_info.new_url != expected_url) {
343 // This should not happen.
344 DLOG(FATAL) << "Unauthorized WebSocket redirect to "
345 << redirect_info.new_method << " "
346 << redirect_info.new_url.spec();
347 request->Cancel();
348 }
349 }
350
OnResponseStarted(URLRequest * request,int net_error)351 void Delegate::OnResponseStarted(URLRequest* request, int net_error) {
352 DCHECK_NE(ERR_IO_PENDING, net_error);
353 // All error codes, including OK and ABORTED, as with
354 // Net.ErrorCodesForMainFrame4
355 base::UmaHistogramSparse("Net.WebSocket.ErrorCodes", -net_error);
356 if (net::IsLocalhost(request->url())) {
357 base::UmaHistogramSparse("Net.WebSocket.ErrorCodes_Localhost", -net_error);
358 } else {
359 base::UmaHistogramSparse("Net.WebSocket.ErrorCodes_NotLocalhost",
360 -net_error);
361 }
362
363 if (net_error != OK) {
364 DVLOG(3) << "OnResponseStarted (request failed)";
365 owner_->ReportFailure(net_error, absl::nullopt);
366 return;
367 }
368 const int response_code = request->GetResponseCode();
369 DVLOG(3) << "OnResponseStarted (response code " << response_code << ")";
370
371 if (request->response_info().connection_info ==
372 HttpResponseInfo::CONNECTION_INFO_HTTP2) {
373 if (response_code == HTTP_OK) {
374 owner_->PerformUpgrade();
375 return;
376 }
377
378 owner_->ReportFailure(net_error, absl::nullopt);
379 return;
380 }
381
382 switch (response_code) {
383 case HTTP_SWITCHING_PROTOCOLS:
384 owner_->PerformUpgrade();
385 return;
386
387 case HTTP_UNAUTHORIZED:
388 owner_->ReportFailureWithMessage(
389 "HTTP Authentication failed; no valid credentials available",
390 net_error, response_code);
391 return;
392
393 case HTTP_PROXY_AUTHENTICATION_REQUIRED:
394 owner_->ReportFailureWithMessage("Proxy authentication failed", net_error,
395 response_code);
396 return;
397
398 default:
399 owner_->ReportFailure(net_error, response_code);
400 }
401 }
402
OnAuthRequired(URLRequest * request,const AuthChallengeInfo & auth_info)403 void Delegate::OnAuthRequired(URLRequest* request,
404 const AuthChallengeInfo& auth_info) {
405 absl::optional<AuthCredentials> credentials;
406 // This base::Unretained(this) relies on an assumption that |callback| can
407 // be called called during the opening handshake.
408 int rv = owner_->connect_delegate()->OnAuthRequired(
409 auth_info, request->response_headers(),
410 request->GetResponseRemoteEndpoint(),
411 base::BindOnce(&Delegate::OnAuthRequiredComplete, base::Unretained(this),
412 request),
413 &credentials);
414 request->LogBlockedBy("WebSocketStream::Delegate::OnAuthRequired");
415 if (rv == ERR_IO_PENDING)
416 return;
417 if (rv != OK) {
418 request->LogUnblocked();
419 owner_->ReportFailure(rv, absl::nullopt);
420 return;
421 }
422 OnAuthRequiredComplete(request, nullptr);
423 }
424
OnAuthRequiredComplete(URLRequest * request,const AuthCredentials * credentials)425 void Delegate::OnAuthRequiredComplete(URLRequest* request,
426 const AuthCredentials* credentials) {
427 request->LogUnblocked();
428 if (!credentials) {
429 request->CancelAuth();
430 return;
431 }
432 request->SetAuth(*credentials);
433 }
434
OnCertificateRequested(URLRequest * request,SSLCertRequestInfo * cert_request_info)435 void Delegate::OnCertificateRequested(URLRequest* request,
436 SSLCertRequestInfo* cert_request_info) {
437 // This method is called when a client certificate is requested, and the
438 // request context does not already contain a client certificate selection for
439 // the endpoint. In this case, a main frame resource request would pop-up UI
440 // to permit selection of a client certificate, but since WebSockets are
441 // sub-resources they should not pop-up UI and so there is nothing more we can
442 // do.
443 request->Cancel();
444 }
445
OnSSLCertificateError(URLRequest * request,int net_error,const SSLInfo & ssl_info,bool fatal)446 void Delegate::OnSSLCertificateError(URLRequest* request,
447 int net_error,
448 const SSLInfo& ssl_info,
449 bool fatal) {
450 owner_->connect_delegate()->OnSSLCertificateError(
451 std::make_unique<SSLErrorCallbacks>(request), net_error, ssl_info, fatal);
452 }
453
OnReadCompleted(URLRequest * request,int bytes_read)454 void Delegate::OnReadCompleted(URLRequest* request, int bytes_read) {
455 NOTREACHED();
456 }
457
458 } // namespace
459
460 WebSocketStreamRequest::~WebSocketStreamRequest() = default;
461
462 WebSocketStream::WebSocketStream() = default;
463 WebSocketStream::~WebSocketStream() = default;
464
465 WebSocketStream::ConnectDelegate::~ConnectDelegate() = default;
466
CreateAndConnectStream(const GURL & socket_url,const std::vector<std::string> & requested_subprotocols,const url::Origin & origin,const SiteForCookies & site_for_cookies,const IsolationInfo & isolation_info,const HttpRequestHeaders & additional_headers,URLRequestContext * url_request_context,const NetLogWithSource & net_log,NetworkTrafficAnnotationTag traffic_annotation,std::unique_ptr<ConnectDelegate> connect_delegate)467 std::unique_ptr<WebSocketStreamRequest> WebSocketStream::CreateAndConnectStream(
468 const GURL& socket_url,
469 const std::vector<std::string>& requested_subprotocols,
470 const url::Origin& origin,
471 const SiteForCookies& site_for_cookies,
472 const IsolationInfo& isolation_info,
473 const HttpRequestHeaders& additional_headers,
474 URLRequestContext* url_request_context,
475 const NetLogWithSource& net_log,
476 NetworkTrafficAnnotationTag traffic_annotation,
477 std::unique_ptr<ConnectDelegate> connect_delegate) {
478 auto request = std::make_unique<WebSocketStreamRequestImpl>(
479 socket_url, requested_subprotocols, url_request_context, origin,
480 site_for_cookies, isolation_info, additional_headers, traffic_annotation,
481 std::move(connect_delegate), nullptr);
482 request->Start(std::make_unique<base::OneShotTimer>());
483 return std::move(request);
484 }
485
486 std::unique_ptr<WebSocketStreamRequest>
CreateAndConnectStreamForTesting(const GURL & socket_url,const std::vector<std::string> & requested_subprotocols,const url::Origin & origin,const SiteForCookies & site_for_cookies,const IsolationInfo & isolation_info,const HttpRequestHeaders & additional_headers,URLRequestContext * url_request_context,const NetLogWithSource & net_log,NetworkTrafficAnnotationTag traffic_annotation,std::unique_ptr<WebSocketStream::ConnectDelegate> connect_delegate,std::unique_ptr<base::OneShotTimer> timer,std::unique_ptr<WebSocketStreamRequestAPI> api_delegate)487 WebSocketStream::CreateAndConnectStreamForTesting(
488 const GURL& socket_url,
489 const std::vector<std::string>& requested_subprotocols,
490 const url::Origin& origin,
491 const SiteForCookies& site_for_cookies,
492 const IsolationInfo& isolation_info,
493 const HttpRequestHeaders& additional_headers,
494 URLRequestContext* url_request_context,
495 const NetLogWithSource& net_log,
496 NetworkTrafficAnnotationTag traffic_annotation,
497 std::unique_ptr<WebSocketStream::ConnectDelegate> connect_delegate,
498 std::unique_ptr<base::OneShotTimer> timer,
499 std::unique_ptr<WebSocketStreamRequestAPI> api_delegate) {
500 auto request = std::make_unique<WebSocketStreamRequestImpl>(
501 socket_url, requested_subprotocols, url_request_context, origin,
502 site_for_cookies, isolation_info, additional_headers, traffic_annotation,
503 std::move(connect_delegate), std::move(api_delegate));
504 request->Start(std::move(timer));
505 return std::move(request);
506 }
507
508 } // namespace net
509