1 // Copyright 2013 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/websockets/websocket_test_util.h"
6
7 #include <stddef.h>
8 #include <algorithm>
9 #include <utility>
10
11 #include "base/strings/strcat.h"
12 #include "base/strings/string_util.h"
13 #include "base/strings/stringprintf.h"
14 #include "net/base/ip_endpoint.h"
15 #include "net/http/http_network_session.h"
16 #include "net/proxy_resolution/configured_proxy_resolution_service.h"
17 #include "net/socket/socket_test_util.h"
18 #include "net/third_party/quiche/src/quiche/spdy/core/spdy_protocol.h"
19 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
20 #include "net/url_request/url_request_context.h"
21 #include "net/url_request/url_request_context_builder.h"
22 #include "net/websockets/websocket_basic_handshake_stream.h"
23 #include "url/origin.h"
24
25 namespace net {
26
namespace {

// Parameters of the linear congruential generator implemented below:
//   next_state = (state * kA + kC) mod kM
// kA is the 64-bit constant 0x5851f42d'4c957f2d; kM is 2^48.
constexpr uint64_t kA = 0x5851f42d4c957f2dULL;
constexpr uint64_t kC = 12345;
constexpr uint64_t kM = uint64_t{1} << 48;

}  // namespace
35
LinearCongruentialGenerator(uint32_t seed)36 LinearCongruentialGenerator::LinearCongruentialGenerator(uint32_t seed)
37 : current_(seed) {}
38
Generate()39 uint32_t LinearCongruentialGenerator::Generate() {
40 uint64_t result = current_;
41 current_ = (current_ * kA + kC) % kM;
42 return static_cast<uint32_t>(result >> 16);
43 }
44
WebSocketExtraHeadersToString(const WebSocketExtraHeaders & headers)45 std::string WebSocketExtraHeadersToString(
46 const WebSocketExtraHeaders& headers) {
47 std::string answer;
48 for (const auto& header : headers) {
49 base::StrAppend(&answer, {header.first, ": ", header.second, "\r\n"});
50 }
51 return answer;
52 }
53
WebSocketExtraHeadersToHttpRequestHeaders(const WebSocketExtraHeaders & headers)54 HttpRequestHeaders WebSocketExtraHeadersToHttpRequestHeaders(
55 const WebSocketExtraHeaders& headers) {
56 HttpRequestHeaders headers_to_return;
57 for (const auto& header : headers)
58 headers_to_return.SetHeader(header.first, header.second);
59 return headers_to_return;
60 }
61
WebSocketStandardRequest(const std::string & path,const std::string & host,const url::Origin & origin,const WebSocketExtraHeaders & send_additional_request_headers,const WebSocketExtraHeaders & extra_headers)62 std::string WebSocketStandardRequest(
63 const std::string& path,
64 const std::string& host,
65 const url::Origin& origin,
66 const WebSocketExtraHeaders& send_additional_request_headers,
67 const WebSocketExtraHeaders& extra_headers) {
68 return WebSocketStandardRequestWithCookies(path, host, origin, /*cookies=*/{},
69 send_additional_request_headers,
70 extra_headers);
71 }
72
WebSocketStandardRequestWithCookies(const std::string & path,const std::string & host,const url::Origin & origin,const WebSocketExtraHeaders & cookies,const WebSocketExtraHeaders & send_additional_request_headers,const WebSocketExtraHeaders & extra_headers)73 std::string WebSocketStandardRequestWithCookies(
74 const std::string& path,
75 const std::string& host,
76 const url::Origin& origin,
77 const WebSocketExtraHeaders& cookies,
78 const WebSocketExtraHeaders& send_additional_request_headers,
79 const WebSocketExtraHeaders& extra_headers) {
80 // Unrelated changes in net/http may change the order and default-values of
81 // HTTP headers, causing WebSocket tests to fail. It is safe to update this
82 // in that case.
83 HttpRequestHeaders headers;
84 std::stringstream request_headers;
85
86 request_headers << base::StringPrintf("GET %s HTTP/1.1\r\n", path.c_str());
87 headers.SetHeader("Host", host);
88 headers.SetHeader("Connection", "Upgrade");
89 headers.SetHeader("Pragma", "no-cache");
90 headers.SetHeader("Cache-Control", "no-cache");
91 for (const auto& [key, value] : send_additional_request_headers)
92 headers.SetHeader(key, value);
93 headers.SetHeader("Upgrade", "websocket");
94 headers.SetHeader("Origin", origin.Serialize());
95 headers.SetHeader("Sec-WebSocket-Version", "13");
96 if (!headers.HasHeader("User-Agent"))
97 headers.SetHeader("User-Agent", "");
98 headers.SetHeader("Accept-Encoding", "gzip, deflate");
99 headers.SetHeader("Accept-Language", "en-us,fr");
100 for (const auto& [key, value] : cookies)
101 headers.SetHeader(key, value);
102 headers.SetHeader("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==");
103 headers.SetHeader("Sec-WebSocket-Extensions",
104 "permessage-deflate; client_max_window_bits");
105 for (const auto& [key, value] : extra_headers)
106 headers.SetHeader(key, value);
107
108 request_headers << headers.ToString();
109 return request_headers.str();
110 }
111
WebSocketStandardResponse(const std::string & extra_headers)112 std::string WebSocketStandardResponse(const std::string& extra_headers) {
113 return base::StringPrintf(
114 "HTTP/1.1 101 Switching Protocols\r\n"
115 "Upgrade: websocket\r\n"
116 "Connection: Upgrade\r\n"
117 "Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n"
118 "%s\r\n",
119 extra_headers.c_str());
120 }
121
WebSocketCommonTestHeaders()122 HttpRequestHeaders WebSocketCommonTestHeaders() {
123 HttpRequestHeaders request_headers;
124 request_headers.SetHeader("Host", "www.example.org");
125 request_headers.SetHeader("Connection", "Upgrade");
126 request_headers.SetHeader("Pragma", "no-cache");
127 request_headers.SetHeader("Cache-Control", "no-cache");
128 request_headers.SetHeader("Upgrade", "websocket");
129 request_headers.SetHeader("Origin", "http://origin.example.org");
130 request_headers.SetHeader("Sec-WebSocket-Version", "13");
131 request_headers.SetHeader("User-Agent", "");
132 request_headers.SetHeader("Accept-Encoding", "gzip, deflate");
133 request_headers.SetHeader("Accept-Language", "en-us,fr");
134 return request_headers;
135 }
136
WebSocketHttp2Request(const std::string & path,const std::string & authority,const std::string & origin,const WebSocketExtraHeaders & extra_headers)137 spdy::Http2HeaderBlock WebSocketHttp2Request(
138 const std::string& path,
139 const std::string& authority,
140 const std::string& origin,
141 const WebSocketExtraHeaders& extra_headers) {
142 spdy::Http2HeaderBlock request_headers;
143 request_headers[spdy::kHttp2MethodHeader] = "CONNECT";
144 request_headers[spdy::kHttp2AuthorityHeader] = authority;
145 request_headers[spdy::kHttp2SchemeHeader] = "https";
146 request_headers[spdy::kHttp2PathHeader] = path;
147 request_headers[spdy::kHttp2ProtocolHeader] = "websocket";
148 request_headers["pragma"] = "no-cache";
149 request_headers["cache-control"] = "no-cache";
150 request_headers["origin"] = origin;
151 request_headers["sec-websocket-version"] = "13";
152 request_headers["user-agent"] = "";
153 request_headers["accept-encoding"] = "gzip, deflate";
154 request_headers["accept-language"] = "en-us,fr";
155 request_headers["sec-websocket-extensions"] =
156 "permessage-deflate; client_max_window_bits";
157 for (const auto& header : extra_headers) {
158 request_headers[base::ToLowerASCII(header.first)] = header.second;
159 }
160 return request_headers;
161 }
162
WebSocketHttp2Response(const WebSocketExtraHeaders & extra_headers)163 spdy::Http2HeaderBlock WebSocketHttp2Response(
164 const WebSocketExtraHeaders& extra_headers) {
165 spdy::Http2HeaderBlock response_headers;
166 response_headers[spdy::kHttp2StatusHeader] = "200";
167 for (const auto& header : extra_headers) {
168 response_headers[base::ToLowerASCII(header.first)] = header.second;
169 }
170 return response_headers;
171 }
172
// Private state of WebSocketMockClientSocketFactoryMaker.
//
// Member order matters: members are destroyed in reverse declaration order,
// so |factory| (declared last) is torn down before the data providers it was
// handed raw pointers to.
struct WebSocketMockClientSocketFactoryMaker::Detail {
  // Copies of the strings passed to SetExpectations(). |write| and |reads|
  // are built from pointers into these, so they must live as long as this
  // struct.
  std::string expect_written;
  std::string return_to_read;
  // Mock response chunks and request write built by SetExpectations().
  std::vector<MockRead> reads;
  MockWrite write;
  // Owners of the providers registered with |factory| by raw pointer.
  std::vector<std::unique_ptr<SequencedSocketData>> socket_data_vector;
  std::vector<std::unique_ptr<SSLSocketDataProvider>> ssl_socket_data_vector;
  // The mock socket factory exposed through factory().
  MockClientSocketFactory factory;
};
182
WebSocketMockClientSocketFactoryMaker()183 WebSocketMockClientSocketFactoryMaker::WebSocketMockClientSocketFactoryMaker()
184 : detail_(std::make_unique<Detail>()) {}
185
186 WebSocketMockClientSocketFactoryMaker::
187 ~WebSocketMockClientSocketFactoryMaker() = default;
188
factory()189 MockClientSocketFactory* WebSocketMockClientSocketFactoryMaker::factory() {
190 return &detail_->factory;
191 }
192
SetExpectations(const std::string & expect_written,const std::string & return_to_read)193 void WebSocketMockClientSocketFactoryMaker::SetExpectations(
194 const std::string& expect_written,
195 const std::string& return_to_read) {
196 const size_t kHttpStreamParserBufferSize = 4096;
197 // We need to extend the lifetime of these strings.
198 detail_->expect_written = expect_written;
199 detail_->return_to_read = return_to_read;
200 int sequence = 0;
201 detail_->write = MockWrite(SYNCHRONOUS,
202 detail_->expect_written.data(),
203 detail_->expect_written.size(),
204 sequence++);
205 // HttpStreamParser reads 4KB at a time. We need to take this implementation
206 // detail into account if |return_to_read| is big enough.
207 for (size_t place = 0; place < detail_->return_to_read.size();
208 place += kHttpStreamParserBufferSize) {
209 detail_->reads.emplace_back(SYNCHRONOUS,
210 detail_->return_to_read.data() + place,
211 std::min(detail_->return_to_read.size() - place,
212 kHttpStreamParserBufferSize),
213 sequence++);
214 }
215 auto socket_data = std::make_unique<SequencedSocketData>(
216 detail_->reads, base::make_span(&detail_->write, 1u));
217 socket_data->set_connect_data(MockConnect(SYNCHRONOUS, OK));
218 AddRawExpectations(std::move(socket_data));
219 }
220
AddRawExpectations(std::unique_ptr<SequencedSocketData> socket_data)221 void WebSocketMockClientSocketFactoryMaker::AddRawExpectations(
222 std::unique_ptr<SequencedSocketData> socket_data) {
223 detail_->factory.AddSocketDataProvider(socket_data.get());
224 detail_->socket_data_vector.push_back(std::move(socket_data));
225 }
226
AddSSLSocketDataProvider(std::unique_ptr<SSLSocketDataProvider> ssl_socket_data)227 void WebSocketMockClientSocketFactoryMaker::AddSSLSocketDataProvider(
228 std::unique_ptr<SSLSocketDataProvider> ssl_socket_data) {
229 detail_->factory.AddSSLSocketDataProvider(ssl_socket_data.get());
230 detail_->ssl_socket_data_vector.push_back(std::move(ssl_socket_data));
231 }
232
WebSocketTestURLRequestContextHost()233 WebSocketTestURLRequestContextHost::WebSocketTestURLRequestContextHost()
234 : url_request_context_builder_(CreateTestURLRequestContextBuilder()) {
235 url_request_context_builder_->set_client_socket_factory_for_testing(
236 maker_.factory());
237 HttpNetworkSessionParams params;
238 params.enable_spdy_ping_based_connection_checking = false;
239 params.enable_quic = false;
240 params.disable_idle_sockets_close_on_memory_pressure = false;
241 url_request_context_builder_->set_http_network_session_params(params);
242 }
243
// Out-of-line defaulted destructor.
WebSocketTestURLRequestContextHost::~WebSocketTestURLRequestContextHost() =
    default;
246
AddRawExpectations(std::unique_ptr<SequencedSocketData> socket_data)247 void WebSocketTestURLRequestContextHost::AddRawExpectations(
248 std::unique_ptr<SequencedSocketData> socket_data) {
249 maker_.AddRawExpectations(std::move(socket_data));
250 }
251
AddSSLSocketDataProvider(std::unique_ptr<SSLSocketDataProvider> ssl_socket_data)252 void WebSocketTestURLRequestContextHost::AddSSLSocketDataProvider(
253 std::unique_ptr<SSLSocketDataProvider> ssl_socket_data) {
254 maker_.AddSSLSocketDataProvider(std::move(ssl_socket_data));
255 }
256
SetProxyConfig(const std::string & proxy_rules)257 void WebSocketTestURLRequestContextHost::SetProxyConfig(
258 const std::string& proxy_rules) {
259 DCHECK(!url_request_context_);
260 auto proxy_resolution_service =
261 ConfiguredProxyResolutionService::CreateFixedForTest(
262 proxy_rules, TRAFFIC_ANNOTATION_FOR_TESTS);
263 url_request_context_builder_->set_proxy_resolution_service(
264 std::move(proxy_resolution_service));
265 }
266
OnAuthRequired(const AuthChallengeInfo & auth_info,scoped_refptr<HttpResponseHeaders> response_headers,const IPEndPoint & host_port_pair,base::OnceCallback<void (const AuthCredentials *)> callback,absl::optional<AuthCredentials> * credentials)267 int DummyConnectDelegate::OnAuthRequired(
268 const AuthChallengeInfo& auth_info,
269 scoped_refptr<HttpResponseHeaders> response_headers,
270 const IPEndPoint& host_port_pair,
271 base::OnceCallback<void(const AuthCredentials*)> callback,
272 absl::optional<AuthCredentials>* credentials) {
273 return OK;
274 }
275
GetURLRequestContext()276 URLRequestContext* WebSocketTestURLRequestContextHost::GetURLRequestContext() {
277 if (!url_request_context_) {
278 url_request_context_builder_->set_network_delegate(
279 std::make_unique<TestNetworkDelegate>());
280 url_request_context_ = url_request_context_builder_->Build();
281 url_request_context_builder_ = nullptr;
282 }
283 return url_request_context_.get();
284 }
285
OnBasicHandshakeStreamCreated(WebSocketBasicHandshakeStream * handshake_stream)286 void TestWebSocketStreamRequestAPI::OnBasicHandshakeStreamCreated(
287 WebSocketBasicHandshakeStream* handshake_stream) {
288 handshake_stream->SetWebSocketKeyForTesting("dGhlIHNhbXBsZSBub25jZQ==");
289 }
290
// Intentionally a no-op: unlike OnBasicHandshakeStreamCreated(), no test key
// is installed on HTTP/2 handshake streams.
void TestWebSocketStreamRequestAPI::OnHttp2HandshakeStreamCreated(
    WebSocketHttp2HandshakeStream* handshake_stream) {}

// Intentionally a no-op: unlike OnBasicHandshakeStreamCreated(), no test key
// is installed on HTTP/3 handshake streams.
void TestWebSocketStreamRequestAPI::OnHttp3HandshakeStreamCreated(
    WebSocketHttp3HandshakeStream* handshake_stream) {}
296 } // namespace net
297