1 // Copyright 2013 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #ifdef UNSAFE_BUFFERS_BUILD
6 // TODO(crbug.com/40284755): Remove this and spanify to fix the errors.
7 #pragma allow_unsafe_buffers
8 #endif
9
10 #include "net/websockets/websocket_test_util.h"
11
12 #include <stddef.h>
13
14 #include <algorithm>
15 #include <sstream>
16 #include <utility>
17
18 #include "base/check.h"
19 #include "base/containers/span.h"
20 #include "base/strings/strcat.h"
21 #include "base/strings/string_util.h"
22 #include "base/strings/stringprintf.h"
23 #include "net/base/net_errors.h"
24 #include "net/http/http_network_session.h"
25 #include "net/proxy_resolution/configured_proxy_resolution_service.h"
26 #include "net/proxy_resolution/proxy_resolution_service.h"
27 #include "net/socket/socket_test_util.h"
28 #include "net/third_party/quiche/src/quiche/http2/core/spdy_protocol.h"
29 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
30 #include "net/url_request/url_request_context.h"
31 #include "net/url_request/url_request_context_builder.h"
32 #include "net/websockets/websocket_basic_handshake_stream.h"
33 #include "url/origin.h"
34
35 namespace net {
36 class AuthChallengeInfo;
37 class AuthCredentials;
38 class HttpResponseHeaders;
39 class WebSocketHttp2HandshakeStream;
40 class WebSocketHttp3HandshakeStream;
41
namespace {

// Parameters of the linear congruential recurrence used by
// LinearCongruentialGenerator::Generate():
//   state' = (state * kA + kC) % kM
// kA is the 64-bit value (0x5851f42d << 32) + 0x4c957f2d written as a single
// literal; kM = 2^48 keeps the state within 48 bits.
constexpr uint64_t kA = 0x5851f42d4c957f2dULL;
constexpr uint64_t kC = 12345;
constexpr uint64_t kM = uint64_t{1} << 48;

}  // namespace
50
LinearCongruentialGenerator(uint32_t seed)51 LinearCongruentialGenerator::LinearCongruentialGenerator(uint32_t seed)
52 : current_(seed) {}
53
Generate()54 uint32_t LinearCongruentialGenerator::Generate() {
55 uint64_t result = current_;
56 current_ = (current_ * kA + kC) % kM;
57 return static_cast<uint32_t>(result >> 16);
58 }
59
WebSocketExtraHeadersToString(const WebSocketExtraHeaders & headers)60 std::string WebSocketExtraHeadersToString(
61 const WebSocketExtraHeaders& headers) {
62 std::string answer;
63 for (const auto& header : headers) {
64 base::StrAppend(&answer, {header.first, ": ", header.second, "\r\n"});
65 }
66 return answer;
67 }
68
WebSocketExtraHeadersToHttpRequestHeaders(const WebSocketExtraHeaders & headers)69 HttpRequestHeaders WebSocketExtraHeadersToHttpRequestHeaders(
70 const WebSocketExtraHeaders& headers) {
71 HttpRequestHeaders headers_to_return;
72 for (const auto& header : headers)
73 headers_to_return.SetHeader(header.first, header.second);
74 return headers_to_return;
75 }
76
WebSocketStandardRequest(const std::string & path,const std::string & host,const url::Origin & origin,const WebSocketExtraHeaders & send_additional_request_headers,const WebSocketExtraHeaders & extra_headers)77 std::string WebSocketStandardRequest(
78 const std::string& path,
79 const std::string& host,
80 const url::Origin& origin,
81 const WebSocketExtraHeaders& send_additional_request_headers,
82 const WebSocketExtraHeaders& extra_headers) {
83 return WebSocketStandardRequestWithCookies(path, host, origin, /*cookies=*/{},
84 send_additional_request_headers,
85 extra_headers);
86 }
87
WebSocketStandardRequestWithCookies(const std::string & path,const std::string & host,const url::Origin & origin,const WebSocketExtraHeaders & cookies,const WebSocketExtraHeaders & send_additional_request_headers,const WebSocketExtraHeaders & extra_headers)88 std::string WebSocketStandardRequestWithCookies(
89 const std::string& path,
90 const std::string& host,
91 const url::Origin& origin,
92 const WebSocketExtraHeaders& cookies,
93 const WebSocketExtraHeaders& send_additional_request_headers,
94 const WebSocketExtraHeaders& extra_headers) {
95 // Unrelated changes in net/http may change the order and default-values of
96 // HTTP headers, causing WebSocket tests to fail. It is safe to update this
97 // in that case.
98 HttpRequestHeaders headers;
99 std::stringstream request_headers;
100
101 request_headers << base::StringPrintf("GET %s HTTP/1.1\r\n", path.c_str());
102 headers.SetHeader("Host", host);
103 headers.SetHeader("Connection", "Upgrade");
104 headers.SetHeader("Pragma", "no-cache");
105 headers.SetHeader("Cache-Control", "no-cache");
106 for (const auto& [key, value] : send_additional_request_headers)
107 headers.SetHeader(key, value);
108 headers.SetHeader("Upgrade", "websocket");
109 headers.SetHeader("Origin", origin.Serialize());
110 headers.SetHeader("Sec-WebSocket-Version", "13");
111 if (!headers.HasHeader("User-Agent"))
112 headers.SetHeader("User-Agent", "");
113 headers.SetHeader("Accept-Encoding", "gzip, deflate");
114 headers.SetHeader("Accept-Language", "en-us,fr");
115 for (const auto& [key, value] : cookies)
116 headers.SetHeader(key, value);
117 headers.SetHeader("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==");
118 headers.SetHeader("Sec-WebSocket-Extensions",
119 "permessage-deflate; client_max_window_bits");
120 for (const auto& [key, value] : extra_headers)
121 headers.SetHeader(key, value);
122
123 request_headers << headers.ToString();
124 return request_headers.str();
125 }
126
WebSocketStandardResponse(const std::string & extra_headers)127 std::string WebSocketStandardResponse(const std::string& extra_headers) {
128 return base::StrCat(
129 {"HTTP/1.1 101 Switching Protocols\r\n"
130 "Upgrade: websocket\r\n"
131 "Connection: Upgrade\r\n"
132 "Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n",
133 extra_headers, "\r\n"});
134 }
135
WebSocketCommonTestHeaders()136 HttpRequestHeaders WebSocketCommonTestHeaders() {
137 HttpRequestHeaders request_headers;
138 request_headers.SetHeader("Host", "www.example.org");
139 request_headers.SetHeader("Connection", "Upgrade");
140 request_headers.SetHeader("Pragma", "no-cache");
141 request_headers.SetHeader("Cache-Control", "no-cache");
142 request_headers.SetHeader("Upgrade", "websocket");
143 request_headers.SetHeader("Origin", "http://origin.example.org");
144 request_headers.SetHeader("Sec-WebSocket-Version", "13");
145 request_headers.SetHeader("User-Agent", "");
146 request_headers.SetHeader("Accept-Encoding", "gzip, deflate");
147 request_headers.SetHeader("Accept-Language", "en-us,fr");
148 return request_headers;
149 }
150
WebSocketHttp2Request(const std::string & path,const std::string & authority,const std::string & origin,const WebSocketExtraHeaders & extra_headers)151 quiche::HttpHeaderBlock WebSocketHttp2Request(
152 const std::string& path,
153 const std::string& authority,
154 const std::string& origin,
155 const WebSocketExtraHeaders& extra_headers) {
156 quiche::HttpHeaderBlock request_headers;
157 request_headers[spdy::kHttp2MethodHeader] = "CONNECT";
158 request_headers[spdy::kHttp2AuthorityHeader] = authority;
159 request_headers[spdy::kHttp2SchemeHeader] = "https";
160 request_headers[spdy::kHttp2PathHeader] = path;
161 request_headers[spdy::kHttp2ProtocolHeader] = "websocket";
162 request_headers["pragma"] = "no-cache";
163 request_headers["cache-control"] = "no-cache";
164 request_headers["origin"] = origin;
165 request_headers["sec-websocket-version"] = "13";
166 request_headers["user-agent"] = "";
167 request_headers["accept-encoding"] = "gzip, deflate";
168 request_headers["accept-language"] = "en-us,fr";
169 request_headers["sec-websocket-extensions"] =
170 "permessage-deflate; client_max_window_bits";
171 for (const auto& header : extra_headers) {
172 request_headers[base::ToLowerASCII(header.first)] = header.second;
173 }
174 return request_headers;
175 }
176
WebSocketHttp2Response(const WebSocketExtraHeaders & extra_headers)177 quiche::HttpHeaderBlock WebSocketHttp2Response(
178 const WebSocketExtraHeaders& extra_headers) {
179 quiche::HttpHeaderBlock response_headers;
180 response_headers[spdy::kHttp2StatusHeader] = "200";
181 for (const auto& header : extra_headers) {
182 response_headers[base::ToLowerASCII(header.first)] = header.second;
183 }
184 return response_headers;
185 }
186
// Storage owned by WebSocketMockClientSocketFactoryMaker (through |detail_|).
// Mock reads/writes and socket data providers point into memory held here, so
// everything lives as long as the maker.
struct WebSocketMockClientSocketFactoryMaker::Detail {
  // Copies of the strings passed to SetExpectations(); |write| and |reads|
  // are given raw pointers into them.
  std::string expect_written;
  std::string return_to_read;
  // Mock reads and the single mock write built by SetExpectations().
  std::vector<MockRead> reads;
  MockWrite write;
  // Owned providers; |factory| is only handed raw pointers to them.
  std::vector<std::unique_ptr<SequencedSocketData>> socket_data_vector;
  std::vector<std::unique_ptr<SSLSocketDataProvider>> ssl_socket_data_vector;
  // Declared last so that it is destroyed first (members are destroyed in
  // reverse declaration order), before the providers it points to.
  MockClientSocketFactory factory;
};
196
WebSocketMockClientSocketFactoryMaker()197 WebSocketMockClientSocketFactoryMaker::WebSocketMockClientSocketFactoryMaker()
198 : detail_(std::make_unique<Detail>()) {}
199
// Defaulted here, where Detail is a complete type, so |detail_| can be
// destroyed. Presumably the header only forward-declares Detail — confirm.
WebSocketMockClientSocketFactoryMaker::
    ~WebSocketMockClientSocketFactoryMaker() = default;
202
factory()203 MockClientSocketFactory* WebSocketMockClientSocketFactoryMaker::factory() {
204 return &detail_->factory;
205 }
206
SetExpectations(const std::string & expect_written,const std::string & return_to_read)207 void WebSocketMockClientSocketFactoryMaker::SetExpectations(
208 const std::string& expect_written,
209 const std::string& return_to_read) {
210 constexpr size_t kHttpStreamParserBufferSize = 4096;
211 // We need to extend the lifetime of these strings.
212 detail_->expect_written = expect_written;
213 detail_->return_to_read = return_to_read;
214 int sequence = 0;
215 detail_->write = MockWrite(SYNCHRONOUS,
216 detail_->expect_written.data(),
217 detail_->expect_written.size(),
218 sequence++);
219 // HttpStreamParser reads 4KB at a time. We need to take this implementation
220 // detail into account if |return_to_read| is big enough.
221 for (size_t place = 0; place < detail_->return_to_read.size();
222 place += kHttpStreamParserBufferSize) {
223 detail_->reads.emplace_back(SYNCHRONOUS,
224 detail_->return_to_read.data() + place,
225 std::min(detail_->return_to_read.size() - place,
226 kHttpStreamParserBufferSize),
227 sequence++);
228 }
229 auto socket_data = std::make_unique<SequencedSocketData>(
230 detail_->reads, base::span(&detail_->write, 1u));
231 socket_data->set_connect_data(MockConnect(SYNCHRONOUS, OK));
232 AddRawExpectations(std::move(socket_data));
233 }
234
AddRawExpectations(std::unique_ptr<SequencedSocketData> socket_data)235 void WebSocketMockClientSocketFactoryMaker::AddRawExpectations(
236 std::unique_ptr<SequencedSocketData> socket_data) {
237 detail_->factory.AddSocketDataProvider(socket_data.get());
238 detail_->socket_data_vector.push_back(std::move(socket_data));
239 }
240
AddSSLSocketDataProvider(std::unique_ptr<SSLSocketDataProvider> ssl_socket_data)241 void WebSocketMockClientSocketFactoryMaker::AddSSLSocketDataProvider(
242 std::unique_ptr<SSLSocketDataProvider> ssl_socket_data) {
243 detail_->factory.AddSSLSocketDataProvider(ssl_socket_data.get());
244 detail_->ssl_socket_data_vector.push_back(std::move(ssl_socket_data));
245 }
246
WebSocketTestURLRequestContextHost()247 WebSocketTestURLRequestContextHost::WebSocketTestURLRequestContextHost()
248 : url_request_context_builder_(CreateTestURLRequestContextBuilder()) {
249 url_request_context_builder_->set_client_socket_factory_for_testing(
250 maker_.factory());
251 HttpNetworkSessionParams params;
252 params.enable_spdy_ping_based_connection_checking = false;
253 params.enable_quic = false;
254 params.disable_idle_sockets_close_on_memory_pressure = false;
255 url_request_context_builder_->set_http_network_session_params(params);
256 }
257
// Defaulted out of line. Members are destroyed implicitly: |maker_|, and
// whichever of the builder or the built context is still held.
WebSocketTestURLRequestContextHost::~WebSocketTestURLRequestContextHost() =
    default;
260
// Forwards to |maker_|, which keeps ownership of |socket_data| and registers
// it with the mock socket factory installed on the context builder.
void WebSocketTestURLRequestContextHost::AddRawExpectations(
    std::unique_ptr<SequencedSocketData> socket_data) {
  maker_.AddRawExpectations(std::move(socket_data));
}
265
// Forwards to |maker_|, which keeps ownership of |ssl_socket_data| and
// registers it with the mock socket factory installed on the context builder.
void WebSocketTestURLRequestContextHost::AddSSLSocketDataProvider(
    std::unique_ptr<SSLSocketDataProvider> ssl_socket_data) {
  maker_.AddSSLSocketDataProvider(std::move(ssl_socket_data));
}
270
SetProxyConfig(const std::string & proxy_rules)271 void WebSocketTestURLRequestContextHost::SetProxyConfig(
272 const std::string& proxy_rules) {
273 DCHECK(!url_request_context_);
274 auto proxy_resolution_service =
275 ConfiguredProxyResolutionService::CreateFixedForTest(
276 proxy_rules, TRAFFIC_ANNOTATION_FOR_TESTS);
277 url_request_context_builder_->set_proxy_resolution_service(
278 std::move(proxy_resolution_service));
279 }
280
// No-op: DummyConnectDelegate ignores connection notifications.
void DummyConnectDelegate::OnURLRequestConnected(URLRequest* request,
                                                 const TransportInfo& info) {}
283
// Supplies no credentials: |callback| is not run, |credentials| is left
// untouched, and OK is returned synchronously.
int DummyConnectDelegate::OnAuthRequired(
    const AuthChallengeInfo& auth_info,
    scoped_refptr<HttpResponseHeaders> response_headers,
    const IPEndPoint& host_port_pair,
    base::OnceCallback<void(const AuthCredentials*)> callback,
    std::optional<AuthCredentials>* credentials) {
  return OK;
}
292
GetURLRequestContext()293 URLRequestContext* WebSocketTestURLRequestContextHost::GetURLRequestContext() {
294 if (!url_request_context_) {
295 url_request_context_builder_->set_network_delegate(
296 std::make_unique<TestNetworkDelegate>());
297 url_request_context_ = url_request_context_builder_->Build();
298 url_request_context_builder_ = nullptr;
299 }
300 return url_request_context_.get();
301 }
302
OnBasicHandshakeStreamCreated(WebSocketBasicHandshakeStream * handshake_stream)303 void TestWebSocketStreamRequestAPI::OnBasicHandshakeStreamCreated(
304 WebSocketBasicHandshakeStream* handshake_stream) {
305 handshake_stream->SetWebSocketKeyForTesting("dGhlIHNhbXBsZSBub25jZQ==");
306 }
307
// No-op: HTTP/2 handshake streams are not customized by this test API.
void TestWebSocketStreamRequestAPI::OnHttp2HandshakeStreamCreated(
    WebSocketHttp2HandshakeStream* handshake_stream) {}
310
// No-op: HTTP/3 handshake streams are not customized by this test API.
void TestWebSocketStreamRequestAPI::OnHttp3HandshakeStreamCreated(
    WebSocketHttp3HandshakeStream* handshake_stream) {}
313 } // namespace net
314