1 // Copyright 2018 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/websockets/websocket_http2_handshake_stream.h"
6
7 #include <set>
8 #include <string_view>
9 #include <utility>
10
11 #include "base/check.h"
12 #include "base/check_op.h"
13 #include "base/functional/bind.h"
14 #include "base/functional/callback.h"
15 #include "base/memory/scoped_refptr.h"
16 #include "base/notreached.h"
17 #include "base/strings/strcat.h"
18 #include "base/strings/stringprintf.h"
19 #include "base/time/time.h"
20 #include "net/base/ip_endpoint.h"
21 #include "net/http/http_request_headers.h"
22 #include "net/http/http_request_info.h"
23 #include "net/http/http_response_headers.h"
24 #include "net/http/http_response_info.h"
25 #include "net/http/http_status_code.h"
26 #include "net/spdy/spdy_http_utils.h"
27 #include "net/spdy/spdy_session.h"
28 #include "net/spdy/spdy_stream.h"
29 #include "net/traffic_annotation/network_traffic_annotation.h"
30 #include "net/websockets/websocket_basic_stream.h"
31 #include "net/websockets/websocket_deflate_predictor_impl.h"
32 #include "net/websockets/websocket_deflate_stream.h"
33 #include "net/websockets/websocket_handshake_constants.h"
34 #include "net/websockets/websocket_handshake_request_info.h"
35
36 namespace net {
37
38 namespace {
39
ValidateStatus(const HttpResponseHeaders * headers)40 bool ValidateStatus(const HttpResponseHeaders* headers) {
41 return headers->GetStatusLine() == "HTTP/1.1 200";
42 }
43
44 } // namespace
45
WebSocketHttp2HandshakeStream(base::WeakPtr<SpdySession> session,WebSocketStream::ConnectDelegate * connect_delegate,std::vector<std::string> requested_sub_protocols,std::vector<std::string> requested_extensions,WebSocketStreamRequestAPI * request,std::set<std::string> dns_aliases)46 WebSocketHttp2HandshakeStream::WebSocketHttp2HandshakeStream(
47 base::WeakPtr<SpdySession> session,
48 WebSocketStream::ConnectDelegate* connect_delegate,
49 std::vector<std::string> requested_sub_protocols,
50 std::vector<std::string> requested_extensions,
51 WebSocketStreamRequestAPI* request,
52 std::set<std::string> dns_aliases)
53 : session_(session),
54 connect_delegate_(connect_delegate),
55 requested_sub_protocols_(requested_sub_protocols),
56 requested_extensions_(requested_extensions),
57 stream_request_(request),
58 dns_aliases_(std::move(dns_aliases)) {
59 DCHECK(connect_delegate);
60 DCHECK(request);
61 }
62
~WebSocketHttp2HandshakeStream()63 WebSocketHttp2HandshakeStream::~WebSocketHttp2HandshakeStream() {
64 spdy_stream_request_.reset();
65 RecordHandshakeResult(result_);
66 }
67
RegisterRequest(const HttpRequestInfo * request_info)68 void WebSocketHttp2HandshakeStream::RegisterRequest(
69 const HttpRequestInfo* request_info) {
70 DCHECK(request_info);
71 DCHECK(request_info->traffic_annotation.is_valid());
72 request_info_ = request_info;
73 }
74
InitializeStream(bool can_send_early,RequestPriority priority,const NetLogWithSource & net_log,CompletionOnceCallback callback)75 int WebSocketHttp2HandshakeStream::InitializeStream(
76 bool can_send_early,
77 RequestPriority priority,
78 const NetLogWithSource& net_log,
79 CompletionOnceCallback callback) {
80 priority_ = priority;
81 net_log_ = net_log;
82 return OK;
83 }
84
SendRequest(const HttpRequestHeaders & headers,HttpResponseInfo * response,CompletionOnceCallback callback)85 int WebSocketHttp2HandshakeStream::SendRequest(
86 const HttpRequestHeaders& headers,
87 HttpResponseInfo* response,
88 CompletionOnceCallback callback) {
89 DCHECK(!headers.HasHeader(websockets::kSecWebSocketKey));
90 DCHECK(!headers.HasHeader(websockets::kSecWebSocketProtocol));
91 DCHECK(!headers.HasHeader(websockets::kSecWebSocketExtensions));
92 DCHECK(headers.HasHeader(HttpRequestHeaders::kOrigin));
93 DCHECK(headers.HasHeader(websockets::kUpgrade));
94 DCHECK(headers.HasHeader(HttpRequestHeaders::kConnection));
95 DCHECK(headers.HasHeader(websockets::kSecWebSocketVersion));
96
97 if (!session_) {
98 const int rv = ERR_CONNECTION_CLOSED;
99 OnFailure("Connection closed before sending request.", rv, std::nullopt);
100 return rv;
101 }
102
103 http_response_info_ = response;
104
105 IPEndPoint address;
106 int result = session_->GetPeerAddress(&address);
107 if (result != OK) {
108 OnFailure("Error getting IP address.", result, std::nullopt);
109 return result;
110 }
111 http_response_info_->remote_endpoint = address;
112
113 auto request = std::make_unique<WebSocketHandshakeRequestInfo>(
114 request_info_->url, base::Time::Now());
115 request->headers = headers;
116
117 AddVectorHeaders(requested_extensions_, requested_sub_protocols_,
118 &request->headers);
119
120 CreateSpdyHeadersFromHttpRequestForWebSocket(
121 request_info_->url, request->headers, &http2_request_headers_);
122
123 connect_delegate_->OnStartOpeningHandshake(std::move(request));
124
125 callback_ = std::move(callback);
126 spdy_stream_request_ = std::make_unique<SpdyStreamRequest>();
127 // The initial request for the WebSocket is a CONNECT, so there is no need to
128 // call ConfirmHandshake().
129 int rv = spdy_stream_request_->StartRequest(
130 SPDY_BIDIRECTIONAL_STREAM, session_, request_info_->url, true, priority_,
131 request_info_->socket_tag, net_log_,
132 base::BindOnce(&WebSocketHttp2HandshakeStream::StartRequestCallback,
133 base::Unretained(this)),
134 NetworkTrafficAnnotationTag(request_info_->traffic_annotation));
135 if (rv == OK) {
136 StartRequestCallback(rv);
137 return ERR_IO_PENDING;
138 }
139 return rv;
140 }
141
ReadResponseHeaders(CompletionOnceCallback callback)142 int WebSocketHttp2HandshakeStream::ReadResponseHeaders(
143 CompletionOnceCallback callback) {
144 if (stream_closed_)
145 return stream_error_;
146
147 if (response_headers_complete_)
148 return ValidateResponse();
149
150 callback_ = std::move(callback);
151 return ERR_IO_PENDING;
152 }
153
ReadResponseBody(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)154 int WebSocketHttp2HandshakeStream::ReadResponseBody(
155 IOBuffer* buf,
156 int buf_len,
157 CompletionOnceCallback callback) {
158 // Callers should instead call Upgrade() to get a WebSocketStream
159 // and call ReadFrames() on that.
160 NOTREACHED();
161 }
162
Close(bool not_reusable)163 void WebSocketHttp2HandshakeStream::Close(bool not_reusable) {
164 spdy_stream_request_.reset();
165 if (stream_) {
166 stream_ = nullptr;
167 stream_closed_ = true;
168 stream_error_ = ERR_CONNECTION_CLOSED;
169 }
170 stream_adapter_.reset();
171 }
172
IsResponseBodyComplete() const173 bool WebSocketHttp2HandshakeStream::IsResponseBodyComplete() const {
174 return false;
175 }
176
IsConnectionReused() const177 bool WebSocketHttp2HandshakeStream::IsConnectionReused() const {
178 return true;
179 }
180
SetConnectionReused()181 void WebSocketHttp2HandshakeStream::SetConnectionReused() {}
182
CanReuseConnection() const183 bool WebSocketHttp2HandshakeStream::CanReuseConnection() const {
184 return false;
185 }
186
GetTotalReceivedBytes() const187 int64_t WebSocketHttp2HandshakeStream::GetTotalReceivedBytes() const {
188 return stream_ ? stream_->raw_received_bytes() : 0;
189 }
190
GetTotalSentBytes() const191 int64_t WebSocketHttp2HandshakeStream::GetTotalSentBytes() const {
192 return stream_ ? stream_->raw_sent_bytes() : 0;
193 }
194
GetAlternativeService(AlternativeService * alternative_service) const195 bool WebSocketHttp2HandshakeStream::GetAlternativeService(
196 AlternativeService* alternative_service) const {
197 return false;
198 }
199
GetLoadTimingInfo(LoadTimingInfo * load_timing_info) const200 bool WebSocketHttp2HandshakeStream::GetLoadTimingInfo(
201 LoadTimingInfo* load_timing_info) const {
202 return stream_ && stream_->GetLoadTimingInfo(load_timing_info);
203 }
204
GetSSLInfo(SSLInfo * ssl_info)205 void WebSocketHttp2HandshakeStream::GetSSLInfo(SSLInfo* ssl_info) {
206 if (stream_)
207 stream_->GetSSLInfo(ssl_info);
208 }
209
GetRemoteEndpoint(IPEndPoint * endpoint)210 int WebSocketHttp2HandshakeStream::GetRemoteEndpoint(IPEndPoint* endpoint) {
211 if (!session_)
212 return ERR_SOCKET_NOT_CONNECTED;
213
214 return session_->GetRemoteEndpoint(endpoint);
215 }
216
PopulateNetErrorDetails(NetErrorDetails *)217 void WebSocketHttp2HandshakeStream::PopulateNetErrorDetails(
218 NetErrorDetails* /*details*/) {
219 return;
220 }
221
Drain(HttpNetworkSession * session)222 void WebSocketHttp2HandshakeStream::Drain(HttpNetworkSession* session) {
223 Close(true /* not_reusable */);
224 }
225
SetPriority(RequestPriority priority)226 void WebSocketHttp2HandshakeStream::SetPriority(RequestPriority priority) {
227 priority_ = priority;
228 if (stream_)
229 stream_->SetPriority(priority_);
230 }
231
232 std::unique_ptr<HttpStream>
RenewStreamForAuth()233 WebSocketHttp2HandshakeStream::RenewStreamForAuth() {
234 // Renewing the stream is not supported.
235 return nullptr;
236 }
237
GetDnsAliases() const238 const std::set<std::string>& WebSocketHttp2HandshakeStream::GetDnsAliases()
239 const {
240 return dns_aliases_;
241 }
242
GetAcceptChViaAlps() const243 std::string_view WebSocketHttp2HandshakeStream::GetAcceptChViaAlps() const {
244 return {};
245 }
246
Upgrade()247 std::unique_ptr<WebSocketStream> WebSocketHttp2HandshakeStream::Upgrade() {
248 DCHECK(extension_params_.get());
249
250 stream_adapter_->DetachDelegate();
251 std::unique_ptr<WebSocketStream> basic_stream =
252 std::make_unique<WebSocketBasicStream>(std::move(stream_adapter_),
253 nullptr, sub_protocol_,
254 extensions_, net_log_);
255
256 if (!extension_params_->deflate_enabled)
257 return basic_stream;
258
259 return std::make_unique<WebSocketDeflateStream>(
260 std::move(basic_stream), extension_params_->deflate_parameters,
261 std::make_unique<WebSocketDeflatePredictorImpl>());
262 }
263
CanReadFromStream() const264 bool WebSocketHttp2HandshakeStream::CanReadFromStream() const {
265 return stream_adapter_ && stream_adapter_->is_initialized();
266 }
267
268 base::WeakPtr<WebSocketHandshakeStreamBase>
GetWeakPtr()269 WebSocketHttp2HandshakeStream::GetWeakPtr() {
270 return weak_ptr_factory_.GetWeakPtr();
271 }
272
OnHeadersSent()273 void WebSocketHttp2HandshakeStream::OnHeadersSent() {
274 std::move(callback_).Run(OK);
275 }
276
OnHeadersReceived(const quiche::HttpHeaderBlock & response_headers)277 void WebSocketHttp2HandshakeStream::OnHeadersReceived(
278 const quiche::HttpHeaderBlock& response_headers) {
279 DCHECK(!response_headers_complete_);
280 DCHECK(http_response_info_);
281
282 response_headers_complete_ = true;
283
284 const int rv =
285 SpdyHeadersToHttpResponse(response_headers, http_response_info_);
286 DCHECK_NE(rv, ERR_INCOMPLETE_HTTP2_HEADERS);
287
288 http_response_info_->response_time =
289 http_response_info_->original_response_time = stream_->response_time();
290 // Do not store SSLInfo in the response here, HttpNetworkTransaction will take
291 // care of that part.
292 http_response_info_->was_alpn_negotiated = true;
293 http_response_info_->request_time = stream_->GetRequestTime();
294 http_response_info_->connection_info = HttpConnectionInfo::kHTTP2;
295 http_response_info_->alpn_negotiated_protocol =
296 HttpConnectionInfoToString(http_response_info_->connection_info);
297
298 if (callback_)
299 std::move(callback_).Run(ValidateResponse());
300 }
301
OnClose(int status)302 void WebSocketHttp2HandshakeStream::OnClose(int status) {
303 DCHECK(stream_adapter_);
304 DCHECK_GT(ERR_IO_PENDING, status);
305
306 stream_closed_ = true;
307 stream_error_ = status;
308 stream_ = nullptr;
309
310 stream_adapter_.reset();
311
312 // If response headers have already been received,
313 // then ValidateResponse() sets |result_|.
314 if (!response_headers_complete_)
315 result_ = HandshakeResult::HTTP2_FAILED;
316
317 OnFailure(base::StrCat({"Stream closed with error: ", ErrorToString(status)}),
318 status, std::nullopt);
319
320 if (callback_)
321 std::move(callback_).Run(status);
322 }
323
StartRequestCallback(int rv)324 void WebSocketHttp2HandshakeStream::StartRequestCallback(int rv) {
325 DCHECK(callback_);
326 if (rv != OK) {
327 spdy_stream_request_.reset();
328 std::move(callback_).Run(rv);
329 return;
330 }
331 stream_ = spdy_stream_request_->ReleaseStream();
332 spdy_stream_request_.reset();
333 stream_adapter_ =
334 std::make_unique<WebSocketSpdyStreamAdapter>(stream_, this, net_log_);
335 rv = stream_->SendRequestHeaders(std::move(http2_request_headers_),
336 MORE_DATA_TO_SEND);
337 // SendRequestHeaders() always returns asynchronously,
338 // and instead of taking a callback, it calls OnHeadersSent().
339 DCHECK_EQ(ERR_IO_PENDING, rv);
340 }
341
ValidateResponse()342 int WebSocketHttp2HandshakeStream::ValidateResponse() {
343 DCHECK(http_response_info_);
344 const HttpResponseHeaders* headers = http_response_info_->headers.get();
345 const int response_code = headers->response_code();
346 switch (response_code) {
347 case HTTP_OK:
348 return ValidateUpgradeResponse(headers);
349
350 // We need to pass these through for authentication to work.
351 case HTTP_UNAUTHORIZED:
352 case HTTP_PROXY_AUTHENTICATION_REQUIRED:
353 return OK;
354
355 // Other status codes are potentially risky (see the warnings in the
356 // WHATWG WebSocket API spec) and so are dropped by default.
357 default:
358 OnFailure(
359 base::StringPrintf(
360 "Error during WebSocket handshake: Unexpected response code: %d",
361 headers->response_code()),
362 ERR_FAILED, headers->response_code());
363 result_ = HandshakeResult::HTTP2_INVALID_STATUS;
364 return ERR_INVALID_RESPONSE;
365 }
366 }
367
ValidateUpgradeResponse(const HttpResponseHeaders * headers)368 int WebSocketHttp2HandshakeStream::ValidateUpgradeResponse(
369 const HttpResponseHeaders* headers) {
370 extension_params_ = std::make_unique<WebSocketExtensionParams>();
371 std::string failure_message;
372 if (!ValidateStatus(headers)) {
373 result_ = HandshakeResult::HTTP2_INVALID_STATUS;
374 } else if (!ValidateSubProtocol(headers, requested_sub_protocols_,
375 &sub_protocol_, &failure_message)) {
376 result_ = HandshakeResult::HTTP2_FAILED_SUBPROTO;
377 } else if (!ValidateExtensions(headers, &extensions_, &failure_message,
378 extension_params_.get())) {
379 result_ = HandshakeResult::HTTP2_FAILED_EXTENSIONS;
380 } else {
381 result_ = HandshakeResult::HTTP2_CONNECTED;
382 return OK;
383 }
384
385 const int rv = ERR_INVALID_RESPONSE;
386 OnFailure("Error during WebSocket handshake: " + failure_message, rv,
387 std::nullopt);
388 return rv;
389 }
390
OnFailure(const std::string & message,int net_error,std::optional<int> response_code)391 void WebSocketHttp2HandshakeStream::OnFailure(
392 const std::string& message,
393 int net_error,
394 std::optional<int> response_code) {
395 stream_request_->OnFailure(message, net_error, response_code);
396 }
397
398 } // namespace net
399