• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2023 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/websockets/websocket_http3_handshake_stream.h"
6 
7 #include <utility>
8 
9 #include "base/strings/stringprintf.h"
10 #include "base/time/time.h"
11 #include "net/base/ip_endpoint.h"
12 #include "net/http/http_request_headers.h"
13 #include "net/http/http_request_info.h"
14 #include "net/http/http_response_headers.h"
15 #include "net/http/http_status_code.h"
16 #include "net/spdy/spdy_http_utils.h"
17 #include "net/traffic_annotation/network_traffic_annotation.h"
18 #include "net/websockets/websocket_deflate_predictor_impl.h"
19 #include "net/websockets/websocket_deflate_stream.h"
20 #include "net/websockets/websocket_handshake_constants.h"
21 
22 namespace net {
23 
24 namespace {
25 
ValidateStatus(const HttpResponseHeaders * headers)26 bool ValidateStatus(const HttpResponseHeaders* headers) {
27   return headers->GetStatusLine() == "HTTP/1.1 200";
28 }
29 
30 }  // namespace
31 
WebSocketHttp3HandshakeStream(std::unique_ptr<QuicChromiumClientSession::Handle> session,WebSocketStream::ConnectDelegate * connect_delegate,std::vector<std::string> requested_sub_protocols,std::vector<std::string> requested_extensions,WebSocketStreamRequestAPI * request,std::set<std::string> dns_aliases)32 WebSocketHttp3HandshakeStream::WebSocketHttp3HandshakeStream(
33     std::unique_ptr<QuicChromiumClientSession::Handle> session,
34     WebSocketStream::ConnectDelegate* connect_delegate,
35     std::vector<std::string> requested_sub_protocols,
36     std::vector<std::string> requested_extensions,
37     WebSocketStreamRequestAPI* request,
38     std::set<std::string> dns_aliases)
39     : session_(std::move(session)),
40       connect_delegate_(connect_delegate),
41       requested_sub_protocols_(std::move(requested_sub_protocols)),
42       requested_extensions_(std::move(requested_extensions)),
43       stream_request_(request),
44       dns_aliases_(std::move(dns_aliases)) {
45   DCHECK(connect_delegate);
46   DCHECK(request);
47 }
48 
~WebSocketHttp3HandshakeStream()49 WebSocketHttp3HandshakeStream::~WebSocketHttp3HandshakeStream() {
50   RecordHandshakeResult(result_);
51 }
52 
RegisterRequest(const HttpRequestInfo * request_info)53 void WebSocketHttp3HandshakeStream::RegisterRequest(
54     const HttpRequestInfo* request_info) {
55   DCHECK(request_info);
56   DCHECK(request_info->traffic_annotation.is_valid());
57   request_info_ = request_info;
58 }
59 
InitializeStream(bool can_send_early,RequestPriority priority,const NetLogWithSource & net_log,CompletionOnceCallback callback)60 int WebSocketHttp3HandshakeStream::InitializeStream(
61     bool can_send_early,
62     RequestPriority priority,
63     const NetLogWithSource& net_log,
64     CompletionOnceCallback callback) {
65   priority_ = priority;
66   net_log_ = net_log;
67   request_time_ = base::Time::Now();
68   return OK;
69 }
70 
SendRequest(const HttpRequestHeaders & request_headers,HttpResponseInfo * response,CompletionOnceCallback callback)71 int WebSocketHttp3HandshakeStream::SendRequest(
72     const HttpRequestHeaders& request_headers,
73     HttpResponseInfo* response,
74     CompletionOnceCallback callback) {
75   DCHECK(!request_headers.HasHeader(websockets::kSecWebSocketKey));
76   DCHECK(!request_headers.HasHeader(websockets::kSecWebSocketProtocol));
77   DCHECK(!request_headers.HasHeader(websockets::kSecWebSocketExtensions));
78   DCHECK(request_headers.HasHeader(HttpRequestHeaders::kOrigin));
79   DCHECK(request_headers.HasHeader(websockets::kUpgrade));
80   DCHECK(request_headers.HasHeader(HttpRequestHeaders::kConnection));
81   DCHECK(request_headers.HasHeader(websockets::kSecWebSocketVersion));
82 
83   if (!session_) {
84     constexpr int rv = ERR_CONNECTION_CLOSED;
85     OnFailure("Connection closed before sending request.", rv, absl::nullopt);
86     return rv;
87   }
88 
89   http_response_info_ = response;
90 
91   IPEndPoint address;
92   int result = session_->GetPeerAddress(&address);
93   if (result != OK) {
94     OnFailure("Error getting IP address.", result, absl::nullopt);
95     return result;
96   }
97   http_response_info_->remote_endpoint = address;
98 
99   auto request = std::make_unique<WebSocketHandshakeRequestInfo>(
100       request_info_->url, base::Time::Now());
101   request->headers.CopyFrom(request_headers);
102 
103   AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketExtensions,
104                             requested_extensions_, &request->headers);
105   AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketProtocol,
106                             requested_sub_protocols_, &request->headers);
107 
108   CreateSpdyHeadersFromHttpRequestForWebSocket(
109       request_info_->url, request->headers, &http3_request_headers_);
110 
111   connect_delegate_->OnStartOpeningHandshake(std::move(request));
112 
113   callback_ = std::move(callback);
114 
115   std::unique_ptr<WebSocketQuicStreamAdapter> stream_adapter =
116       session_->CreateWebSocketQuicStreamAdapter(
117           this,
118           base::BindOnce(
119               &WebSocketHttp3HandshakeStream::ReceiveAdapterAndStartRequest,
120               base::Unretained(this)),
121           NetworkTrafficAnnotationTag(request_info_->traffic_annotation));
122   if (!stream_adapter) {
123     return ERR_IO_PENDING;
124   }
125   ReceiveAdapterAndStartRequest(std::move(stream_adapter));
126   return OK;
127 }
128 
ReadResponseHeaders(CompletionOnceCallback callback)129 int WebSocketHttp3HandshakeStream::ReadResponseHeaders(
130     CompletionOnceCallback callback) {
131   if (stream_closed_) {
132     return stream_error_;
133   }
134 
135   if (response_headers_complete_) {
136     return ValidateResponse();
137   }
138 
139   callback_ = std::move(callback);
140   return ERR_IO_PENDING;
141 }
142 
143 // TODO(momoka): Implement this.
ReadResponseBody(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)144 int WebSocketHttp3HandshakeStream::ReadResponseBody(
145     IOBuffer* buf,
146     int buf_len,
147     CompletionOnceCallback callback) {
148   return OK;
149 }
150 
Close(bool not_reusable)151 void WebSocketHttp3HandshakeStream::Close(bool not_reusable) {
152   if (stream_adapter_) {
153     stream_adapter_->Disconnect();
154     stream_closed_ = true;
155     stream_error_ = ERR_CONNECTION_CLOSED;
156   }
157 }
158 
159 // TODO(momoka): Implement this.
IsResponseBodyComplete() const160 bool WebSocketHttp3HandshakeStream::IsResponseBodyComplete() const {
161   return false;
162 }
163 
164 // TODO(momoka): Implement this.
IsConnectionReused() const165 bool WebSocketHttp3HandshakeStream::IsConnectionReused() const {
166   return true;
167 }
168 
169 // TODO(momoka): Implement this.
SetConnectionReused()170 void WebSocketHttp3HandshakeStream::SetConnectionReused() {}
171 
172 // TODO(momoka): Implement this.
CanReuseConnection() const173 bool WebSocketHttp3HandshakeStream::CanReuseConnection() const {
174   return false;
175 }
176 
177 // TODO(momoka): Implement this.
GetTotalReceivedBytes() const178 int64_t WebSocketHttp3HandshakeStream::GetTotalReceivedBytes() const {
179   return 0;
180 }
181 
182 // TODO(momoka): Implement this.
GetTotalSentBytes() const183 int64_t WebSocketHttp3HandshakeStream::GetTotalSentBytes() const {
184   return 0;
185 }
186 
187 // TODO(momoka): Implement this.
GetAlternativeService(AlternativeService * alternative_service) const188 bool WebSocketHttp3HandshakeStream::GetAlternativeService(
189     AlternativeService* alternative_service) const {
190   return false;
191 }
192 
193 // TODO(momoka): Implement this.
GetLoadTimingInfo(LoadTimingInfo * load_timing_info) const194 bool WebSocketHttp3HandshakeStream::GetLoadTimingInfo(
195     LoadTimingInfo* load_timing_info) const {
196   return false;
197 }
198 
199 // TODO(momoka): Implement this.
GetSSLInfo(SSLInfo * ssl_info)200 void WebSocketHttp3HandshakeStream::GetSSLInfo(SSLInfo* ssl_info) {}
201 
202 // TODO(momoka): Implement this.
GetSSLCertRequestInfo(SSLCertRequestInfo * cert_request_info)203 void WebSocketHttp3HandshakeStream::GetSSLCertRequestInfo(
204     SSLCertRequestInfo* cert_request_info) {}
205 
206 // TODO(momoka): Implement this.
GetRemoteEndpoint(IPEndPoint * endpoint)207 int WebSocketHttp3HandshakeStream::GetRemoteEndpoint(IPEndPoint* endpoint) {
208   return 0;
209 }
210 
211 // TODO(momoka): Implement this.
Drain(HttpNetworkSession * session)212 void WebSocketHttp3HandshakeStream::Drain(HttpNetworkSession* session) {}
213 
214 // TODO(momoka): Implement this.
SetPriority(RequestPriority priority)215 void WebSocketHttp3HandshakeStream::SetPriority(RequestPriority priority) {}
216 
217 // TODO(momoka): Implement this.
PopulateNetErrorDetails(NetErrorDetails * details)218 void WebSocketHttp3HandshakeStream::PopulateNetErrorDetails(
219     NetErrorDetails* details) {}
220 
221 // TODO(momoka): Implement this.
222 std::unique_ptr<HttpStream>
RenewStreamForAuth()223 WebSocketHttp3HandshakeStream::RenewStreamForAuth() {
224   return nullptr;
225 }
226 
227 // TODO(momoka): Implement this.
GetDnsAliases() const228 const std::set<std::string>& WebSocketHttp3HandshakeStream::GetDnsAliases()
229     const {
230   return dns_aliases_;
231 }
232 
233 // TODO(momoka): Implement this.
GetAcceptChViaAlps() const234 base::StringPiece WebSocketHttp3HandshakeStream::GetAcceptChViaAlps() const {
235   return {};
236 }
237 
238 // WebSocketHandshakeStreamBase methods.
239 
240 // TODO(momoka): Implement this.
Upgrade()241 std::unique_ptr<WebSocketStream> WebSocketHttp3HandshakeStream::Upgrade() {
242   DCHECK(extension_params_.get());
243 
244   stream_adapter_->clear_delegate();
245   std::unique_ptr<WebSocketStream> basic_stream =
246       std::make_unique<WebSocketBasicStream>(std::move(stream_adapter_),
247                                              nullptr, sub_protocol_,
248                                              extensions_, net_log_);
249 
250   if (!extension_params_->deflate_enabled) {
251     return basic_stream;
252   }
253 
254   return std::make_unique<WebSocketDeflateStream>(
255       std::move(basic_stream), extension_params_->deflate_parameters,
256       std::make_unique<WebSocketDeflatePredictorImpl>());
257 }
258 
259 base::WeakPtr<WebSocketHandshakeStreamBase>
GetWeakPtr()260 WebSocketHttp3HandshakeStream::GetWeakPtr() {
261   return weak_ptr_factory_.GetWeakPtr();
262 }
263 
OnHeadersSent()264 void WebSocketHttp3HandshakeStream::OnHeadersSent() {
265   std::move(callback_).Run(OK);
266 }
267 
OnHeadersReceived(const spdy::Http2HeaderBlock & response_headers)268 void WebSocketHttp3HandshakeStream::OnHeadersReceived(
269     const spdy::Http2HeaderBlock& response_headers) {
270   DCHECK(!response_headers_complete_);
271   DCHECK(http_response_info_);
272 
273   response_headers_complete_ = true;
274 
275   const int rv =
276       SpdyHeadersToHttpResponse(response_headers, http_response_info_);
277   DCHECK_NE(rv, ERR_INCOMPLETE_HTTP2_HEADERS);
278 
279   // Do not store SSLInfo in the response here, HttpNetworkTransaction will take
280   // care of that part.
281   http_response_info_->was_alpn_negotiated = true;
282   http_response_info_->response_time = base::Time::Now();
283   http_response_info_->request_time = request_time_;
284   http_response_info_->connection_info =
285       HttpResponseInfo::CONNECTION_INFO_HTTP2;
286   http_response_info_->alpn_negotiated_protocol =
287       HttpResponseInfo::ConnectionInfoToString(
288           http_response_info_->connection_info);
289 
290   if (callback_) {
291     std::move(callback_).Run(ValidateResponse());
292   }
293 }
294 
OnClose(int status)295 void WebSocketHttp3HandshakeStream::OnClose(int status) {
296   DCHECK(stream_adapter_);
297   DCHECK_GT(ERR_IO_PENDING, status);
298 
299   stream_closed_ = true;
300   stream_error_ = status;
301 
302   stream_adapter_.reset();
303 
304   // If response headers have already been received,
305   // then ValidateResponse() sets `result_`.
306   if (!response_headers_complete_) {
307     result_ = HandshakeResult::HTTP3_FAILED;
308   }
309 
310   OnFailure(std::string("Stream closed with error: ") + ErrorToString(status),
311             status, absl::nullopt);
312 
313   if (callback_) {
314     std::move(callback_).Run(status);
315   }
316 }
317 
ReceiveAdapterAndStartRequest(std::unique_ptr<WebSocketQuicStreamAdapter> adapter)318 void WebSocketHttp3HandshakeStream::ReceiveAdapterAndStartRequest(
319     std::unique_ptr<WebSocketQuicStreamAdapter> adapter) {
320   stream_adapter_ = std::move(adapter);
321   // WriteHeaders returns synchronously.
322   stream_adapter_->WriteHeaders(std::move(http3_request_headers_), false);
323 }
324 
ValidateResponse()325 int WebSocketHttp3HandshakeStream::ValidateResponse() {
326   DCHECK(http_response_info_);
327   const HttpResponseHeaders* headers = http_response_info_->headers.get();
328   const int response_code = headers->response_code();
329   switch (response_code) {
330     case HTTP_OK:
331       return ValidateUpgradeResponse(headers);
332 
333     // We need to pass these through for authentication to work.
334     case HTTP_UNAUTHORIZED:
335     case HTTP_PROXY_AUTHENTICATION_REQUIRED:
336       return OK;
337 
338     // Other status codes are potentially risky (see the warnings in the
339     // WHATWG WebSocket API spec) and so are dropped by default.
340     default:
341       OnFailure(
342           base::StringPrintf(
343               "Error during WebSocket handshake: Unexpected response code: %d",
344               headers->response_code()),
345           ERR_FAILED, headers->response_code());
346       result_ = HandshakeResult::HTTP3_INVALID_STATUS;
347       return ERR_INVALID_RESPONSE;
348   }
349 }
350 
ValidateUpgradeResponse(const HttpResponseHeaders * headers)351 int WebSocketHttp3HandshakeStream::ValidateUpgradeResponse(
352     const HttpResponseHeaders* headers) {
353   extension_params_ = std::make_unique<WebSocketExtensionParams>();
354   std::string failure_message;
355   if (!ValidateStatus(headers)) {
356     result_ = HandshakeResult::HTTP3_INVALID_STATUS;
357   } else if (!ValidateSubProtocol(headers, requested_sub_protocols_,
358                                   &sub_protocol_, &failure_message)) {
359     result_ = HandshakeResult::HTTP3_FAILED_SUBPROTO;
360   } else if (!ValidateExtensions(headers, &extensions_, &failure_message,
361                                  extension_params_.get())) {
362     result_ = HandshakeResult::HTTP3_FAILED_EXTENSIONS;
363   } else {
364     result_ = HandshakeResult::HTTP3_CONNECTED;
365     return OK;
366   }
367 
368   const int rv = ERR_INVALID_RESPONSE;
369   OnFailure("Error during WebSocket handshake: " + failure_message, rv,
370             absl::nullopt);
371   return rv;
372 }
373 
374 // TODO(momoka): Implement this.
OnFailure(const std::string & message,int net_error,absl::optional<int> response_code)375 void WebSocketHttp3HandshakeStream::OnFailure(
376     const std::string& message,
377     int net_error,
378     absl::optional<int> response_code) {
379   stream_request_->OnFailure(message, net_error, response_code);
380 }
381 
382 }  // namespace net
383