• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2023 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/websockets/websocket_http3_handshake_stream.h"
6 
7 #include <utility>
8 
9 #include "base/check.h"
10 #include "base/check_op.h"
11 #include "base/functional/bind.h"
12 #include "base/functional/callback.h"
13 #include "base/memory/scoped_refptr.h"
14 #include "base/strings/strcat.h"
15 #include "base/strings/stringprintf.h"
16 #include "base/time/time.h"
17 #include "net/base/ip_endpoint.h"
18 #include "net/http/http_request_headers.h"
19 #include "net/http/http_request_info.h"
20 #include "net/http/http_response_headers.h"
21 #include "net/http/http_response_info.h"
22 #include "net/http/http_status_code.h"
23 #include "net/spdy/spdy_http_utils.h"
24 #include "net/third_party/quiche/src/quiche/spdy/core/http2_header_block.h"
25 #include "net/traffic_annotation/network_traffic_annotation.h"
26 #include "net/websockets/websocket_basic_stream.h"
27 #include "net/websockets/websocket_deflate_predictor_impl.h"
28 #include "net/websockets/websocket_deflate_stream.h"
29 #include "net/websockets/websocket_handshake_constants.h"
30 #include "net/websockets/websocket_handshake_request_info.h"
31 
32 namespace net {
33 struct AlternativeService;
34 
35 namespace {
36 
ValidateStatus(const HttpResponseHeaders * headers)37 bool ValidateStatus(const HttpResponseHeaders* headers) {
38   return headers->GetStatusLine() == "HTTP/1.1 200";
39 }
40 
41 }  // namespace
42 
WebSocketHttp3HandshakeStream(std::unique_ptr<QuicChromiumClientSession::Handle> session,WebSocketStream::ConnectDelegate * connect_delegate,std::vector<std::string> requested_sub_protocols,std::vector<std::string> requested_extensions,WebSocketStreamRequestAPI * request,std::set<std::string> dns_aliases)43 WebSocketHttp3HandshakeStream::WebSocketHttp3HandshakeStream(
44     std::unique_ptr<QuicChromiumClientSession::Handle> session,
45     WebSocketStream::ConnectDelegate* connect_delegate,
46     std::vector<std::string> requested_sub_protocols,
47     std::vector<std::string> requested_extensions,
48     WebSocketStreamRequestAPI* request,
49     std::set<std::string> dns_aliases)
50     : session_(std::move(session)),
51       connect_delegate_(connect_delegate),
52       requested_sub_protocols_(std::move(requested_sub_protocols)),
53       requested_extensions_(std::move(requested_extensions)),
54       stream_request_(request),
55       dns_aliases_(std::move(dns_aliases)) {
56   DCHECK(connect_delegate);
57   DCHECK(request);
58 }
59 
~WebSocketHttp3HandshakeStream()60 WebSocketHttp3HandshakeStream::~WebSocketHttp3HandshakeStream() {
61   RecordHandshakeResult(result_);
62 }
63 
RegisterRequest(const HttpRequestInfo * request_info)64 void WebSocketHttp3HandshakeStream::RegisterRequest(
65     const HttpRequestInfo* request_info) {
66   DCHECK(request_info);
67   DCHECK(request_info->traffic_annotation.is_valid());
68   request_info_ = request_info;
69 }
70 
InitializeStream(bool can_send_early,RequestPriority priority,const NetLogWithSource & net_log,CompletionOnceCallback callback)71 int WebSocketHttp3HandshakeStream::InitializeStream(
72     bool can_send_early,
73     RequestPriority priority,
74     const NetLogWithSource& net_log,
75     CompletionOnceCallback callback) {
76   priority_ = priority;
77   net_log_ = net_log;
78   request_time_ = base::Time::Now();
79   return OK;
80 }
81 
SendRequest(const HttpRequestHeaders & request_headers,HttpResponseInfo * response,CompletionOnceCallback callback)82 int WebSocketHttp3HandshakeStream::SendRequest(
83     const HttpRequestHeaders& request_headers,
84     HttpResponseInfo* response,
85     CompletionOnceCallback callback) {
86   DCHECK(!request_headers.HasHeader(websockets::kSecWebSocketKey));
87   DCHECK(!request_headers.HasHeader(websockets::kSecWebSocketProtocol));
88   DCHECK(!request_headers.HasHeader(websockets::kSecWebSocketExtensions));
89   DCHECK(request_headers.HasHeader(HttpRequestHeaders::kOrigin));
90   DCHECK(request_headers.HasHeader(websockets::kUpgrade));
91   DCHECK(request_headers.HasHeader(HttpRequestHeaders::kConnection));
92   DCHECK(request_headers.HasHeader(websockets::kSecWebSocketVersion));
93 
94   if (!session_) {
95     constexpr int rv = ERR_CONNECTION_CLOSED;
96     OnFailure("Connection closed before sending request.", rv, absl::nullopt);
97     return rv;
98   }
99 
100   http_response_info_ = response;
101 
102   IPEndPoint address;
103   int result = session_->GetPeerAddress(&address);
104   if (result != OK) {
105     OnFailure("Error getting IP address.", result, absl::nullopt);
106     return result;
107   }
108   http_response_info_->remote_endpoint = address;
109 
110   auto request = std::make_unique<WebSocketHandshakeRequestInfo>(
111       request_info_->url, base::Time::Now());
112   request->headers = request_headers;
113 
114   AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketExtensions,
115                             requested_extensions_, &request->headers);
116   AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketProtocol,
117                             requested_sub_protocols_, &request->headers);
118 
119   CreateSpdyHeadersFromHttpRequestForWebSocket(
120       request_info_->url, request->headers, &http3_request_headers_);
121 
122   connect_delegate_->OnStartOpeningHandshake(std::move(request));
123 
124   callback_ = std::move(callback);
125 
126   std::unique_ptr<WebSocketQuicStreamAdapter> stream_adapter =
127       session_->CreateWebSocketQuicStreamAdapter(
128           this,
129           base::BindOnce(
130               &WebSocketHttp3HandshakeStream::ReceiveAdapterAndStartRequest,
131               base::Unretained(this)),
132           NetworkTrafficAnnotationTag(request_info_->traffic_annotation));
133   if (!stream_adapter) {
134     return ERR_IO_PENDING;
135   }
136   ReceiveAdapterAndStartRequest(std::move(stream_adapter));
137   return OK;
138 }
139 
ReadResponseHeaders(CompletionOnceCallback callback)140 int WebSocketHttp3HandshakeStream::ReadResponseHeaders(
141     CompletionOnceCallback callback) {
142   if (stream_closed_) {
143     return stream_error_;
144   }
145 
146   if (response_headers_complete_) {
147     return ValidateResponse();
148   }
149 
150   callback_ = std::move(callback);
151   return ERR_IO_PENDING;
152 }
153 
154 // TODO(momoka): Implement this.
ReadResponseBody(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)155 int WebSocketHttp3HandshakeStream::ReadResponseBody(
156     IOBuffer* buf,
157     int buf_len,
158     CompletionOnceCallback callback) {
159   return OK;
160 }
161 
Close(bool not_reusable)162 void WebSocketHttp3HandshakeStream::Close(bool not_reusable) {
163   if (stream_adapter_) {
164     stream_adapter_->Disconnect();
165     stream_closed_ = true;
166     stream_error_ = ERR_CONNECTION_CLOSED;
167   }
168 }
169 
170 // TODO(momoka): Implement this.
IsResponseBodyComplete() const171 bool WebSocketHttp3HandshakeStream::IsResponseBodyComplete() const {
172   return false;
173 }
174 
175 // TODO(momoka): Implement this.
IsConnectionReused() const176 bool WebSocketHttp3HandshakeStream::IsConnectionReused() const {
177   return true;
178 }
179 
180 // TODO(momoka): Implement this.
SetConnectionReused()181 void WebSocketHttp3HandshakeStream::SetConnectionReused() {}
182 
183 // TODO(momoka): Implement this.
CanReuseConnection() const184 bool WebSocketHttp3HandshakeStream::CanReuseConnection() const {
185   return false;
186 }
187 
188 // TODO(momoka): Implement this.
GetTotalReceivedBytes() const189 int64_t WebSocketHttp3HandshakeStream::GetTotalReceivedBytes() const {
190   return 0;
191 }
192 
193 // TODO(momoka): Implement this.
GetTotalSentBytes() const194 int64_t WebSocketHttp3HandshakeStream::GetTotalSentBytes() const {
195   return 0;
196 }
197 
198 // TODO(momoka): Implement this.
GetAlternativeService(AlternativeService * alternative_service) const199 bool WebSocketHttp3HandshakeStream::GetAlternativeService(
200     AlternativeService* alternative_service) const {
201   return false;
202 }
203 
204 // TODO(momoka): Implement this.
GetLoadTimingInfo(LoadTimingInfo * load_timing_info) const205 bool WebSocketHttp3HandshakeStream::GetLoadTimingInfo(
206     LoadTimingInfo* load_timing_info) const {
207   return false;
208 }
209 
210 // TODO(momoka): Implement this.
GetSSLInfo(SSLInfo * ssl_info)211 void WebSocketHttp3HandshakeStream::GetSSLInfo(SSLInfo* ssl_info) {}
212 
213 // TODO(momoka): Implement this.
GetSSLCertRequestInfo(SSLCertRequestInfo * cert_request_info)214 void WebSocketHttp3HandshakeStream::GetSSLCertRequestInfo(
215     SSLCertRequestInfo* cert_request_info) {}
216 
217 // TODO(momoka): Implement this.
GetRemoteEndpoint(IPEndPoint * endpoint)218 int WebSocketHttp3HandshakeStream::GetRemoteEndpoint(IPEndPoint* endpoint) {
219   return 0;
220 }
221 
222 // TODO(momoka): Implement this.
Drain(HttpNetworkSession * session)223 void WebSocketHttp3HandshakeStream::Drain(HttpNetworkSession* session) {}
224 
225 // TODO(momoka): Implement this.
SetPriority(RequestPriority priority)226 void WebSocketHttp3HandshakeStream::SetPriority(RequestPriority priority) {}
227 
228 // TODO(momoka): Implement this.
PopulateNetErrorDetails(NetErrorDetails * details)229 void WebSocketHttp3HandshakeStream::PopulateNetErrorDetails(
230     NetErrorDetails* details) {}
231 
232 // TODO(momoka): Implement this.
233 std::unique_ptr<HttpStream>
RenewStreamForAuth()234 WebSocketHttp3HandshakeStream::RenewStreamForAuth() {
235   return nullptr;
236 }
237 
238 // TODO(momoka): Implement this.
GetDnsAliases() const239 const std::set<std::string>& WebSocketHttp3HandshakeStream::GetDnsAliases()
240     const {
241   return dns_aliases_;
242 }
243 
244 // TODO(momoka): Implement this.
GetAcceptChViaAlps() const245 base::StringPiece WebSocketHttp3HandshakeStream::GetAcceptChViaAlps() const {
246   return {};
247 }
248 
249 // WebSocketHandshakeStreamBase methods.
250 
251 // TODO(momoka): Implement this.
Upgrade()252 std::unique_ptr<WebSocketStream> WebSocketHttp3HandshakeStream::Upgrade() {
253   DCHECK(extension_params_.get());
254 
255   stream_adapter_->clear_delegate();
256   std::unique_ptr<WebSocketStream> basic_stream =
257       std::make_unique<WebSocketBasicStream>(std::move(stream_adapter_),
258                                              nullptr, sub_protocol_,
259                                              extensions_, net_log_);
260 
261   if (!extension_params_->deflate_enabled) {
262     return basic_stream;
263   }
264 
265   return std::make_unique<WebSocketDeflateStream>(
266       std::move(basic_stream), extension_params_->deflate_parameters,
267       std::make_unique<WebSocketDeflatePredictorImpl>());
268 }
269 
CanReadFromStream() const270 bool WebSocketHttp3HandshakeStream::CanReadFromStream() const {
271   return stream_adapter_ && stream_adapter_->is_initialized();
272 }
273 
274 base::WeakPtr<WebSocketHandshakeStreamBase>
GetWeakPtr()275 WebSocketHttp3HandshakeStream::GetWeakPtr() {
276   return weak_ptr_factory_.GetWeakPtr();
277 }
278 
OnHeadersSent()279 void WebSocketHttp3HandshakeStream::OnHeadersSent() {
280   std::move(callback_).Run(OK);
281 }
282 
OnHeadersReceived(const spdy::Http2HeaderBlock & response_headers)283 void WebSocketHttp3HandshakeStream::OnHeadersReceived(
284     const spdy::Http2HeaderBlock& response_headers) {
285   DCHECK(!response_headers_complete_);
286   DCHECK(http_response_info_);
287 
288   response_headers_complete_ = true;
289 
290   const int rv =
291       SpdyHeadersToHttpResponse(response_headers, http_response_info_);
292   DCHECK_NE(rv, ERR_INCOMPLETE_HTTP2_HEADERS);
293 
294   // Do not store SSLInfo in the response here, HttpNetworkTransaction will take
295   // care of that part.
296   http_response_info_->was_alpn_negotiated = true;
297   http_response_info_->response_time = base::Time::Now();
298   http_response_info_->request_time = request_time_;
299   http_response_info_->connection_info = HttpConnectionInfo::kHTTP2;
300   http_response_info_->alpn_negotiated_protocol =
301       HttpConnectionInfoToString(http_response_info_->connection_info);
302 
303   if (callback_) {
304     std::move(callback_).Run(ValidateResponse());
305   }
306 }
307 
OnClose(int status)308 void WebSocketHttp3HandshakeStream::OnClose(int status) {
309   DCHECK(stream_adapter_);
310   DCHECK_GT(ERR_IO_PENDING, status);
311 
312   stream_closed_ = true;
313   stream_error_ = status;
314 
315   stream_adapter_.reset();
316 
317   // If response headers have already been received,
318   // then ValidateResponse() sets `result_`.
319   if (!response_headers_complete_) {
320     result_ = HandshakeResult::HTTP3_FAILED;
321   }
322 
323   OnFailure(base::StrCat({"Stream closed with error: ", ErrorToString(status)}),
324             status, absl::nullopt);
325 
326   if (callback_) {
327     std::move(callback_).Run(status);
328   }
329 }
330 
ReceiveAdapterAndStartRequest(std::unique_ptr<WebSocketQuicStreamAdapter> adapter)331 void WebSocketHttp3HandshakeStream::ReceiveAdapterAndStartRequest(
332     std::unique_ptr<WebSocketQuicStreamAdapter> adapter) {
333   stream_adapter_ = std::move(adapter);
334   // WriteHeaders returns synchronously.
335   stream_adapter_->WriteHeaders(std::move(http3_request_headers_), false);
336 }
337 
ValidateResponse()338 int WebSocketHttp3HandshakeStream::ValidateResponse() {
339   DCHECK(http_response_info_);
340   const HttpResponseHeaders* headers = http_response_info_->headers.get();
341   const int response_code = headers->response_code();
342   switch (response_code) {
343     case HTTP_OK:
344       return ValidateUpgradeResponse(headers);
345 
346     // We need to pass these through for authentication to work.
347     case HTTP_UNAUTHORIZED:
348     case HTTP_PROXY_AUTHENTICATION_REQUIRED:
349       return OK;
350 
351     // Other status codes are potentially risky (see the warnings in the
352     // WHATWG WebSocket API spec) and so are dropped by default.
353     default:
354       OnFailure(
355           base::StringPrintf(
356               "Error during WebSocket handshake: Unexpected response code: %d",
357               headers->response_code()),
358           ERR_FAILED, headers->response_code());
359       result_ = HandshakeResult::HTTP3_INVALID_STATUS;
360       return ERR_INVALID_RESPONSE;
361   }
362 }
363 
ValidateUpgradeResponse(const HttpResponseHeaders * headers)364 int WebSocketHttp3HandshakeStream::ValidateUpgradeResponse(
365     const HttpResponseHeaders* headers) {
366   extension_params_ = std::make_unique<WebSocketExtensionParams>();
367   std::string failure_message;
368   if (!ValidateStatus(headers)) {
369     result_ = HandshakeResult::HTTP3_INVALID_STATUS;
370   } else if (!ValidateSubProtocol(headers, requested_sub_protocols_,
371                                   &sub_protocol_, &failure_message)) {
372     result_ = HandshakeResult::HTTP3_FAILED_SUBPROTO;
373   } else if (!ValidateExtensions(headers, &extensions_, &failure_message,
374                                  extension_params_.get())) {
375     result_ = HandshakeResult::HTTP3_FAILED_EXTENSIONS;
376   } else {
377     result_ = HandshakeResult::HTTP3_CONNECTED;
378     return OK;
379   }
380 
381   const int rv = ERR_INVALID_RESPONSE;
382   OnFailure("Error during WebSocket handshake: " + failure_message, rv,
383             absl::nullopt);
384   return rv;
385 }
386 
387 // TODO(momoka): Implement this.
OnFailure(const std::string & message,int net_error,absl::optional<int> response_code)388 void WebSocketHttp3HandshakeStream::OnFailure(
389     const std::string& message,
390     int net_error,
391     absl::optional<int> response_code) {
392   stream_request_->OnFailure(message, net_error, response_code);
393 }
394 
395 }  // namespace net
396