1 // Copyright 2015 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
#include "net/websockets/websocket_stream_create_test_base.h"

#include <utility>

#include "base/functional/callback.h"
#include "base/memory/raw_ptr.h"
#include "net/base/ip_endpoint.h"
#include "net/http/http_request_headers.h"
#include "net/http/http_response_headers.h"
#include "net/log/net_log_with_source.h"
#include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
#include "net/websockets/websocket_basic_handshake_stream.h"
#include "net/websockets/websocket_handshake_request_info.h"
#include "net/websockets/websocket_handshake_response_info.h"
#include "net/websockets/websocket_stream.h"
#include "url/gurl.h"
#include "url/origin.h"
22
23 namespace net {
24
// File-local shorthand for the fixture's header pair type, so the out-of-line
// return types of the *HeadersToVector() helpers below stay short.
using HeaderKeyValuePair = WebSocketStreamCreateTestBase::HeaderKeyValuePair;
26
27 class WebSocketStreamCreateTestBase::TestConnectDelegate
28 : public WebSocketStream::ConnectDelegate {
29 public:
TestConnectDelegate(WebSocketStreamCreateTestBase * owner,base::OnceClosure done_callback)30 TestConnectDelegate(WebSocketStreamCreateTestBase* owner,
31 base::OnceClosure done_callback)
32 : owner_(owner), done_callback_(std::move(done_callback)) {}
33
34 TestConnectDelegate(const TestConnectDelegate&) = delete;
35 TestConnectDelegate& operator=(const TestConnectDelegate&) = delete;
36
OnCreateRequest(URLRequest * request)37 void OnCreateRequest(URLRequest* request) override {
38 owner_->url_request_ = request;
39 }
40
OnSuccess(std::unique_ptr<WebSocketStream> stream,std::unique_ptr<WebSocketHandshakeResponseInfo> response)41 void OnSuccess(
42 std::unique_ptr<WebSocketStream> stream,
43 std::unique_ptr<WebSocketHandshakeResponseInfo> response) override {
44 if (owner_->response_info_)
45 ADD_FAILURE();
46 owner_->response_info_ = std::move(response);
47 stream.swap(owner_->stream_);
48 std::move(done_callback_).Run();
49 }
50
OnFailure(const std::string & message,int net_error,absl::optional<int> response_code)51 void OnFailure(const std::string& message,
52 int net_error,
53 absl::optional<int> response_code) override {
54 owner_->has_failed_ = true;
55 owner_->failure_message_ = message;
56 owner_->failure_response_code_ = response_code.value_or(-1);
57 std::move(done_callback_).Run();
58 }
59
OnStartOpeningHandshake(std::unique_ptr<WebSocketHandshakeRequestInfo> request)60 void OnStartOpeningHandshake(
61 std::unique_ptr<WebSocketHandshakeRequestInfo> request) override {
62 // Can be called multiple times (in the case of HTTP auth). Last call
63 // wins.
64 owner_->request_info_ = std::move(request);
65 }
66
OnSSLCertificateError(std::unique_ptr<WebSocketEventInterface::SSLErrorCallbacks> ssl_error_callbacks,int net_error,const SSLInfo & ssl_info,bool fatal)67 void OnSSLCertificateError(
68 std::unique_ptr<WebSocketEventInterface::SSLErrorCallbacks>
69 ssl_error_callbacks,
70 int net_error,
71 const SSLInfo& ssl_info,
72 bool fatal) override {
73 owner_->ssl_error_callbacks_ = std::move(ssl_error_callbacks);
74 owner_->ssl_info_ = ssl_info;
75 owner_->ssl_fatal_ = fatal;
76 }
77
OnAuthRequired(const AuthChallengeInfo & auth_info,scoped_refptr<HttpResponseHeaders> response_headers,const IPEndPoint & remote_endpoint,base::OnceCallback<void (const AuthCredentials *)> callback,absl::optional<AuthCredentials> * credentials)78 int OnAuthRequired(const AuthChallengeInfo& auth_info,
79 scoped_refptr<HttpResponseHeaders> response_headers,
80 const IPEndPoint& remote_endpoint,
81 base::OnceCallback<void(const AuthCredentials*)> callback,
82 absl::optional<AuthCredentials>* credentials) override {
83 owner_->run_loop_waiting_for_on_auth_required_.Quit();
84 owner_->auth_challenge_info_ = auth_info;
85 *credentials = owner_->auth_credentials_;
86 owner_->on_auth_required_callback_ = std::move(callback);
87 return owner_->on_auth_required_rv_;
88 }
89
90 private:
91 raw_ptr<WebSocketStreamCreateTestBase> owner_;
92 base::OnceClosure done_callback_;
93 };
94
// Special members are defaulted out of line, in this .cc, rather than in the
// header. NOTE(review): presumably so that members whose types are only
// forward-declared in the header are constructed/destroyed where their full
// definitions are visible — confirm against the header.
WebSocketStreamCreateTestBase::WebSocketStreamCreateTestBase() = default;

WebSocketStreamCreateTestBase::~WebSocketStreamCreateTestBase() = default;
98
CreateAndConnectStream(const GURL & socket_url,const std::vector<std::string> & sub_protocols,const url::Origin & origin,const SiteForCookies & site_for_cookies,const IsolationInfo & isolation_info,const HttpRequestHeaders & additional_headers,std::unique_ptr<base::OneShotTimer> timer)99 void WebSocketStreamCreateTestBase::CreateAndConnectStream(
100 const GURL& socket_url,
101 const std::vector<std::string>& sub_protocols,
102 const url::Origin& origin,
103 const SiteForCookies& site_for_cookies,
104 const IsolationInfo& isolation_info,
105 const HttpRequestHeaders& additional_headers,
106 std::unique_ptr<base::OneShotTimer> timer) {
107 auto connect_delegate = std::make_unique<TestConnectDelegate>(
108 this, connect_run_loop_.QuitClosure());
109 auto api_delegate = std::make_unique<TestWebSocketStreamRequestAPI>();
110 stream_request_ = WebSocketStream::CreateAndConnectStreamForTesting(
111 socket_url, sub_protocols, origin, site_for_cookies, isolation_info,
112 additional_headers, url_request_context_host_.GetURLRequestContext(),
113 NetLogWithSource(), TRAFFIC_ANNOTATION_FOR_TESTS,
114 std::move(connect_delegate),
115 timer ? std::move(timer) : std::make_unique<base::OneShotTimer>(),
116 std::move(api_delegate));
117 }
118
119 std::vector<HeaderKeyValuePair>
RequestHeadersToVector(const HttpRequestHeaders & headers)120 WebSocketStreamCreateTestBase::RequestHeadersToVector(
121 const HttpRequestHeaders& headers) {
122 HttpRequestHeaders::Iterator it(headers);
123 std::vector<HeaderKeyValuePair> result;
124 while (it.GetNext())
125 result.emplace_back(it.name(), it.value());
126 return result;
127 }
128
129 std::vector<HeaderKeyValuePair>
ResponseHeadersToVector(const HttpResponseHeaders & headers)130 WebSocketStreamCreateTestBase::ResponseHeadersToVector(
131 const HttpResponseHeaders& headers) {
132 size_t iter = 0;
133 std::string name, value;
134 std::vector<HeaderKeyValuePair> result;
135 while (headers.EnumerateHeaderLines(&iter, &name, &value))
136 result.emplace_back(name, value);
137 return result;
138 }
139
// Runs |connect_run_loop_| until it is quit. Its quit closure is the done
// callback that CreateAndConnectStream() hands to TestConnectDelegate, which
// runs it from OnSuccess() or OnFailure().
void WebSocketStreamCreateTestBase::WaitUntilConnectDone() {
  connect_run_loop_.Run();
}
143
// Runs |run_loop_waiting_for_on_auth_required_| until
// TestConnectDelegate::OnAuthRequired() quits it.
void WebSocketStreamCreateTestBase::WaitUntilOnAuthRequired() {
  run_loop_waiting_for_on_auth_required_.Run();
}
147
NoSubProtocols()148 std::vector<std::string> WebSocketStreamCreateTestBase::NoSubProtocols() {
149 return std::vector<std::string>();
150 }
151
152 } // namespace net
153