1 // Copyright 2023 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/websockets/websocket_http3_handshake_stream.h"
6
7 #include <string_view>
8 #include <utility>
9
10 #include "base/check.h"
11 #include "base/check_op.h"
12 #include "base/functional/bind.h"
13 #include "base/functional/callback.h"
14 #include "base/memory/scoped_refptr.h"
15 #include "base/strings/strcat.h"
16 #include "base/strings/stringprintf.h"
17 #include "base/time/time.h"
18 #include "net/base/ip_endpoint.h"
19 #include "net/http/http_request_headers.h"
20 #include "net/http/http_request_info.h"
21 #include "net/http/http_response_headers.h"
22 #include "net/http/http_response_info.h"
23 #include "net/http/http_status_code.h"
24 #include "net/spdy/spdy_http_utils.h"
25 #include "net/traffic_annotation/network_traffic_annotation.h"
26 #include "net/websockets/websocket_basic_stream.h"
27 #include "net/websockets/websocket_deflate_predictor_impl.h"
28 #include "net/websockets/websocket_deflate_stream.h"
29 #include "net/websockets/websocket_handshake_constants.h"
30 #include "net/websockets/websocket_handshake_request_info.h"
31
32 namespace net {
33 struct AlternativeService;
34
35 namespace {
36
ValidateStatus(const HttpResponseHeaders * headers)37 bool ValidateStatus(const HttpResponseHeaders* headers) {
38 return headers->GetStatusLine() == "HTTP/1.1 200";
39 }
40
41 } // namespace
42
WebSocketHttp3HandshakeStream(std::unique_ptr<QuicChromiumClientSession::Handle> session,WebSocketStream::ConnectDelegate * connect_delegate,std::vector<std::string> requested_sub_protocols,std::vector<std::string> requested_extensions,WebSocketStreamRequestAPI * request,std::set<std::string> dns_aliases)43 WebSocketHttp3HandshakeStream::WebSocketHttp3HandshakeStream(
44 std::unique_ptr<QuicChromiumClientSession::Handle> session,
45 WebSocketStream::ConnectDelegate* connect_delegate,
46 std::vector<std::string> requested_sub_protocols,
47 std::vector<std::string> requested_extensions,
48 WebSocketStreamRequestAPI* request,
49 std::set<std::string> dns_aliases)
50 : session_(std::move(session)),
51 connect_delegate_(connect_delegate),
52 requested_sub_protocols_(std::move(requested_sub_protocols)),
53 requested_extensions_(std::move(requested_extensions)),
54 stream_request_(request),
55 dns_aliases_(std::move(dns_aliases)) {
56 DCHECK(connect_delegate);
57 DCHECK(request);
58 }
59
~WebSocketHttp3HandshakeStream()60 WebSocketHttp3HandshakeStream::~WebSocketHttp3HandshakeStream() {
61 RecordHandshakeResult(result_);
62 }
63
RegisterRequest(const HttpRequestInfo * request_info)64 void WebSocketHttp3HandshakeStream::RegisterRequest(
65 const HttpRequestInfo* request_info) {
66 DCHECK(request_info);
67 DCHECK(request_info->traffic_annotation.is_valid());
68 request_info_ = request_info;
69 }
70
InitializeStream(bool can_send_early,RequestPriority priority,const NetLogWithSource & net_log,CompletionOnceCallback callback)71 int WebSocketHttp3HandshakeStream::InitializeStream(
72 bool can_send_early,
73 RequestPriority priority,
74 const NetLogWithSource& net_log,
75 CompletionOnceCallback callback) {
76 priority_ = priority;
77 net_log_ = net_log;
78 request_time_ = base::Time::Now();
79 return OK;
80 }
81
// Starts the WebSocket opening handshake over HTTP/3.
//
// Records the peer address in `response`, builds the HTTP/3 request headers
// from `request_headers` plus the requested extensions and sub-protocols,
// reports the request to `connect_delegate_`, and then obtains a stream
// adapter from the QUIC session on which the headers are written.
//
// Returns:
//   - OK if the adapter was returned immediately and the headers were
//     written through ReceiveAdapterAndStartRequest().
//   - ERR_IO_PENDING if no adapter was returned; presumably the session later
//     runs the bound ReceiveAdapterAndStartRequest() callback.
//   - ERR_CONNECTION_CLOSED if there is no session handle, or the error from
//     GetPeerAddress(). OnFailure() has been called in both error cases.
// `callback` is stored in `callback_`, which OnHeadersSent() runs with OK and
// OnClose() runs with the close status.
int WebSocketHttp3HandshakeStream::SendRequest(
    const HttpRequestHeaders& request_headers,
    HttpResponseInfo* response,
    CompletionOnceCallback callback) {
  // The caller must supply the generic upgrade headers but none of the
  // Sec-WebSocket-Key/-Protocol/-Extensions headers; protocol and extension
  // headers are presumably added by AddVectorHeaders() below.
  DCHECK(!request_headers.HasHeader(websockets::kSecWebSocketKey));
  DCHECK(!request_headers.HasHeader(websockets::kSecWebSocketProtocol));
  DCHECK(!request_headers.HasHeader(websockets::kSecWebSocketExtensions));
  DCHECK(request_headers.HasHeader(HttpRequestHeaders::kOrigin));
  DCHECK(request_headers.HasHeader(websockets::kUpgrade));
  DCHECK(request_headers.HasHeader(HttpRequestHeaders::kConnection));
  DCHECK(request_headers.HasHeader(websockets::kSecWebSocketVersion));

  if (!session_) {
    constexpr int rv = ERR_CONNECTION_CLOSED;
    OnFailure("Connection closed before sending request.", rv, std::nullopt);
    return rv;
  }

  // Kept so OnHeadersReceived() can fill in the response later.
  http_response_info_ = response;

  IPEndPoint address;
  int result = session_->GetPeerAddress(&address);
  if (result != OK) {
    OnFailure("Error getting IP address.", result, std::nullopt);
    return result;
  }
  http_response_info_->remote_endpoint = address;

  // Handshake request as reported to the connect delegate.
  auto request = std::make_unique<WebSocketHandshakeRequestInfo>(
      request_info_->url, base::Time::Now());
  request->headers = request_headers;

  AddVectorHeaders(requested_extensions_, requested_sub_protocols_,
                   &request->headers);

  // Convert the final header set into the HTTP/3 header block that
  // ReceiveAdapterAndStartRequest() writes.
  CreateSpdyHeadersFromHttpRequestForWebSocket(
      request_info_->url, request->headers, &http3_request_headers_);

  connect_delegate_->OnStartOpeningHandshake(std::move(request));

  // Armed before the adapter is requested, because writing the headers may
  // lead to OnHeadersSent(), which runs it.
  callback_ = std::move(callback);

  // NOTE(review): base::Unretained(this) assumes the session never runs this
  // callback after the handshake stream is destroyed — confirm.
  std::unique_ptr<WebSocketQuicStreamAdapter> stream_adapter =
      session_->CreateWebSocketQuicStreamAdapter(
          this,
          base::BindOnce(
              &WebSocketHttp3HandshakeStream::ReceiveAdapterAndStartRequest,
              base::Unretained(this)),
          NetworkTrafficAnnotationTag(request_info_->traffic_annotation));
  if (!stream_adapter) {
    return ERR_IO_PENDING;
  }
  // NOTE(review): returns OK while `callback_` stays armed for
  // OnHeadersSent(); confirm the caller tolerates that callback also running
  // in this synchronous case.
  ReceiveAdapterAndStartRequest(std::move(stream_adapter));
  return OK;
}
137
ReadResponseHeaders(CompletionOnceCallback callback)138 int WebSocketHttp3HandshakeStream::ReadResponseHeaders(
139 CompletionOnceCallback callback) {
140 if (stream_closed_) {
141 return stream_error_;
142 }
143
144 if (response_headers_complete_) {
145 return ValidateResponse();
146 }
147
148 callback_ = std::move(callback);
149 return ERR_IO_PENDING;
150 }
151
152 // TODO(momoka): Implement this.
ReadResponseBody(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)153 int WebSocketHttp3HandshakeStream::ReadResponseBody(
154 IOBuffer* buf,
155 int buf_len,
156 CompletionOnceCallback callback) {
157 return OK;
158 }
159
Close(bool not_reusable)160 void WebSocketHttp3HandshakeStream::Close(bool not_reusable) {
161 if (stream_adapter_) {
162 stream_adapter_->Disconnect();
163 stream_closed_ = true;
164 stream_error_ = ERR_CONNECTION_CLOSED;
165 }
166 }
167
168 // TODO(momoka): Implement this.
IsResponseBodyComplete() const169 bool WebSocketHttp3HandshakeStream::IsResponseBodyComplete() const {
170 return false;
171 }
172
173 // TODO(momoka): Implement this.
IsConnectionReused() const174 bool WebSocketHttp3HandshakeStream::IsConnectionReused() const {
175 return true;
176 }
177
178 // TODO(momoka): Implement this.
SetConnectionReused()179 void WebSocketHttp3HandshakeStream::SetConnectionReused() {}
180
181 // TODO(momoka): Implement this.
CanReuseConnection() const182 bool WebSocketHttp3HandshakeStream::CanReuseConnection() const {
183 return false;
184 }
185
186 // TODO(momoka): Implement this.
GetTotalReceivedBytes() const187 int64_t WebSocketHttp3HandshakeStream::GetTotalReceivedBytes() const {
188 return 0;
189 }
190
191 // TODO(momoka): Implement this.
GetTotalSentBytes() const192 int64_t WebSocketHttp3HandshakeStream::GetTotalSentBytes() const {
193 return 0;
194 }
195
196 // TODO(momoka): Implement this.
GetAlternativeService(AlternativeService * alternative_service) const197 bool WebSocketHttp3HandshakeStream::GetAlternativeService(
198 AlternativeService* alternative_service) const {
199 return false;
200 }
201
202 // TODO(momoka): Implement this.
GetLoadTimingInfo(LoadTimingInfo * load_timing_info) const203 bool WebSocketHttp3HandshakeStream::GetLoadTimingInfo(
204 LoadTimingInfo* load_timing_info) const {
205 return false;
206 }
207
208 // TODO(momoka): Implement this.
GetSSLInfo(SSLInfo * ssl_info)209 void WebSocketHttp3HandshakeStream::GetSSLInfo(SSLInfo* ssl_info) {}
210
211 // TODO(momoka): Implement this.
GetRemoteEndpoint(IPEndPoint * endpoint)212 int WebSocketHttp3HandshakeStream::GetRemoteEndpoint(IPEndPoint* endpoint) {
213 return 0;
214 }
215
216 // TODO(momoka): Implement this.
Drain(HttpNetworkSession * session)217 void WebSocketHttp3HandshakeStream::Drain(HttpNetworkSession* session) {}
218
219 // TODO(momoka): Implement this.
SetPriority(RequestPriority priority)220 void WebSocketHttp3HandshakeStream::SetPriority(RequestPriority priority) {}
221
222 // TODO(momoka): Implement this.
PopulateNetErrorDetails(NetErrorDetails * details)223 void WebSocketHttp3HandshakeStream::PopulateNetErrorDetails(
224 NetErrorDetails* details) {}
225
226 // TODO(momoka): Implement this.
227 std::unique_ptr<HttpStream>
RenewStreamForAuth()228 WebSocketHttp3HandshakeStream::RenewStreamForAuth() {
229 return nullptr;
230 }
231
// Returns the DNS aliases supplied at construction.
GetDnsAliases() const233 const std::set<std::string>& WebSocketHttp3HandshakeStream::GetDnsAliases()
234 const {
235 return dns_aliases_;
236 }
237
238 // TODO(momoka): Implement this.
GetAcceptChViaAlps() const239 std::string_view WebSocketHttp3HandshakeStream::GetAcceptChViaAlps() const {
240 return {};
241 }
242
243 // WebSocketHandshakeStreamBase methods.
244
245 // TODO(momoka): Implement this.
Upgrade()246 std::unique_ptr<WebSocketStream> WebSocketHttp3HandshakeStream::Upgrade() {
247 DCHECK(extension_params_.get());
248
249 stream_adapter_->clear_delegate();
250 std::unique_ptr<WebSocketStream> basic_stream =
251 std::make_unique<WebSocketBasicStream>(std::move(stream_adapter_),
252 nullptr, sub_protocol_,
253 extensions_, net_log_);
254
255 if (!extension_params_->deflate_enabled) {
256 return basic_stream;
257 }
258
259 return std::make_unique<WebSocketDeflateStream>(
260 std::move(basic_stream), extension_params_->deflate_parameters,
261 std::make_unique<WebSocketDeflatePredictorImpl>());
262 }
263
CanReadFromStream() const264 bool WebSocketHttp3HandshakeStream::CanReadFromStream() const {
265 return stream_adapter_ && stream_adapter_->is_initialized();
266 }
267
268 base::WeakPtr<WebSocketHandshakeStreamBase>
GetWeakPtr()269 WebSocketHttp3HandshakeStream::GetWeakPtr() {
270 return weak_ptr_factory_.GetWeakPtr();
271 }
272
OnHeadersSent()273 void WebSocketHttp3HandshakeStream::OnHeadersSent() {
274 std::move(callback_).Run(OK);
275 }
276
OnHeadersReceived(const quiche::HttpHeaderBlock & response_headers)277 void WebSocketHttp3HandshakeStream::OnHeadersReceived(
278 const quiche::HttpHeaderBlock& response_headers) {
279 DCHECK(!response_headers_complete_);
280 DCHECK(http_response_info_);
281
282 response_headers_complete_ = true;
283
284 const int rv =
285 SpdyHeadersToHttpResponse(response_headers, http_response_info_);
286 DCHECK_NE(rv, ERR_INCOMPLETE_HTTP2_HEADERS);
287
288 // Do not store SSLInfo in the response here, HttpNetworkTransaction will take
289 // care of that part.
290 http_response_info_->was_alpn_negotiated = true;
291 http_response_info_->response_time =
292 http_response_info_->original_response_time = base::Time::Now();
293 http_response_info_->request_time = request_time_;
294 http_response_info_->connection_info = HttpConnectionInfo::kHTTP2;
295 http_response_info_->alpn_negotiated_protocol =
296 HttpConnectionInfoToString(http_response_info_->connection_info);
297
298 if (callback_) {
299 std::move(callback_).Run(ValidateResponse());
300 }
301 }
302
OnClose(int status)303 void WebSocketHttp3HandshakeStream::OnClose(int status) {
304 DCHECK(stream_adapter_);
305 DCHECK_GT(ERR_IO_PENDING, status);
306
307 stream_closed_ = true;
308 stream_error_ = status;
309
310 stream_adapter_.reset();
311
312 // If response headers have already been received,
313 // then ValidateResponse() sets `result_`.
314 if (!response_headers_complete_) {
315 result_ = HandshakeResult::HTTP3_FAILED;
316 }
317
318 OnFailure(base::StrCat({"Stream closed with error: ", ErrorToString(status)}),
319 status, std::nullopt);
320
321 if (callback_) {
322 std::move(callback_).Run(status);
323 }
324 }
325
ReceiveAdapterAndStartRequest(std::unique_ptr<WebSocketQuicStreamAdapter> adapter)326 void WebSocketHttp3HandshakeStream::ReceiveAdapterAndStartRequest(
327 std::unique_ptr<WebSocketQuicStreamAdapter> adapter) {
328 stream_adapter_ = std::move(adapter);
329 // WriteHeaders returns synchronously.
330 stream_adapter_->WriteHeaders(std::move(http3_request_headers_), false);
331 }
332
ValidateResponse()333 int WebSocketHttp3HandshakeStream::ValidateResponse() {
334 DCHECK(http_response_info_);
335 const HttpResponseHeaders* headers = http_response_info_->headers.get();
336 const int response_code = headers->response_code();
337 switch (response_code) {
338 case HTTP_OK:
339 return ValidateUpgradeResponse(headers);
340
341 // We need to pass these through for authentication to work.
342 case HTTP_UNAUTHORIZED:
343 case HTTP_PROXY_AUTHENTICATION_REQUIRED:
344 return OK;
345
346 // Other status codes are potentially risky (see the warnings in the
347 // WHATWG WebSocket API spec) and so are dropped by default.
348 default:
349 OnFailure(
350 base::StringPrintf(
351 "Error during WebSocket handshake: Unexpected response code: %d",
352 headers->response_code()),
353 ERR_FAILED, headers->response_code());
354 result_ = HandshakeResult::HTTP3_INVALID_STATUS;
355 return ERR_INVALID_RESPONSE;
356 }
357 }
358
ValidateUpgradeResponse(const HttpResponseHeaders * headers)359 int WebSocketHttp3HandshakeStream::ValidateUpgradeResponse(
360 const HttpResponseHeaders* headers) {
361 extension_params_ = std::make_unique<WebSocketExtensionParams>();
362 std::string failure_message;
363 if (!ValidateStatus(headers)) {
364 result_ = HandshakeResult::HTTP3_INVALID_STATUS;
365 } else if (!ValidateSubProtocol(headers, requested_sub_protocols_,
366 &sub_protocol_, &failure_message)) {
367 result_ = HandshakeResult::HTTP3_FAILED_SUBPROTO;
368 } else if (!ValidateExtensions(headers, &extensions_, &failure_message,
369 extension_params_.get())) {
370 result_ = HandshakeResult::HTTP3_FAILED_EXTENSIONS;
371 } else {
372 result_ = HandshakeResult::HTTP3_CONNECTED;
373 return OK;
374 }
375
376 const int rv = ERR_INVALID_RESPONSE;
377 OnFailure("Error during WebSocket handshake: " + failure_message, rv,
378 std::nullopt);
379 return rv;
380 }
381
382 // TODO(momoka): Implement this.
OnFailure(const std::string & message,int net_error,std::optional<int> response_code)383 void WebSocketHttp3HandshakeStream::OnFailure(
384 const std::string& message,
385 int net_error,
386 std::optional<int> response_code) {
387 stream_request_->OnFailure(message, net_error, response_code);
388 }
389
390 } // namespace net
391