1 // Copyright 2024 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/socket/tcp_stream_attempt.h"
6
7 #include <optional>
8 #include <string_view>
9
10 #include "base/functional/callback_forward.h"
11 #include "base/test/task_environment.h"
12 #include "base/time/time.h"
13 #include "net/base/ip_endpoint.h"
14 #include "net/base/net_errors.h"
15 #include "net/log/net_log_capture_mode.h"
16 #include "net/log/net_log_entry.h"
17 #include "net/socket/socket_performance_watcher.h"
18 #include "net/socket/socket_performance_watcher_factory.h"
19 #include "net/socket/stream_attempt.h"
20 #include "net/socket/transport_client_socket_pool_test_util.h"
21 #include "net/test/gtest_util.h"
22 #include "net/test/test_with_task_environment.h"
23 #include "testing/gtest/include/gtest/gtest.h"
24
25 using net::test::IsError;
26 using net::test::IsOk;
27
28 namespace net {
29
30 namespace {
31
MakeIPEndPoint(std::string_view ip_literal,uint16_t port=80)32 IPEndPoint MakeIPEndPoint(std::string_view ip_literal, uint16_t port = 80) {
33 std::optional<IPAddress> ip = IPAddress::FromIPLiteral(std::move(ip_literal));
34 return IPEndPoint(*ip, port);
35 }
36
37 class NetLogObserver : public NetLog::ThreadSafeObserver {
38 public:
NetLogObserver(NetLog * net_log)39 explicit NetLogObserver(NetLog* net_log) {
40 net_log->AddObserver(this, NetLogCaptureMode::kEverything);
41 }
42
~NetLogObserver()43 ~NetLogObserver() override {
44 if (net_log()) {
45 net_log()->RemoveObserver(this);
46 }
47 }
48
OnAddEntry(const NetLogEntry & entry)49 void OnAddEntry(const NetLogEntry& entry) override {
50 entries_.emplace_back(entry.Clone());
51 }
52
entries() const53 const std::vector<NetLogEntry>& entries() const { return entries_; }
54
55 private:
56 std::vector<NetLogEntry> entries_;
57 };
58
59 class TestSocketPerformanceWatcher : public SocketPerformanceWatcher {
60 public:
61 ~TestSocketPerformanceWatcher() override = default;
62
ShouldNotifyUpdatedRTT() const63 bool ShouldNotifyUpdatedRTT() const override { return false; }
64
OnUpdatedRTTAvailable(const base::TimeDelta & rtt)65 void OnUpdatedRTTAvailable(const base::TimeDelta& rtt) override {}
66
OnConnectionChanged()67 void OnConnectionChanged() override {}
68 };
69
70 class TestSocketPerformanceWatcherFactory
71 : public SocketPerformanceWatcherFactory {
72 public:
73 ~TestSocketPerformanceWatcherFactory() override = default;
74
CreateSocketPerformanceWatcher(const Protocol protocol,const IPAddress & ip_address)75 std::unique_ptr<SocketPerformanceWatcher> CreateSocketPerformanceWatcher(
76 const Protocol protocol,
77 const IPAddress& ip_address) override {
78 return std::make_unique<TestSocketPerformanceWatcher>();
79 }
80 };
81
82 class StreamAttemptHelper {
83 public:
StreamAttemptHelper(StreamAttemptParams * params,IPEndPoint ip_endpoint)84 StreamAttemptHelper(StreamAttemptParams* params, IPEndPoint ip_endpoint)
85 : attempt_(std::make_unique<TcpStreamAttempt>(params, ip_endpoint)) {}
86
Start()87 int Start() {
88 return attempt_->Start(base::BindOnce(&StreamAttemptHelper::OnComplete,
89 base::Unretained(this)));
90 }
91
WaitForCompletion()92 int WaitForCompletion() {
93 if (result_.has_value()) {
94 return *result_;
95 }
96
97 base::RunLoop loop;
98 completion_closure_ = loop.QuitClosure();
99 loop.Run();
100
101 return *result_;
102 }
103
attempt()104 TcpStreamAttempt* attempt() { return attempt_.get(); }
105
106 private:
OnComplete(int rv)107 void OnComplete(int rv) {
108 result_ = rv;
109 if (completion_closure_) {
110 std::move(completion_closure_).Run();
111 }
112 }
113
114 std::unique_ptr<TcpStreamAttempt> attempt_;
115 base::OnceClosure completion_closure_;
116 std::optional<int> result_;
117 };
118
119 } // namespace
120
121 class TcpStreamAttemptTest : public TestWithTaskEnvironment {
122 public:
TcpStreamAttemptTest()123 TcpStreamAttemptTest()
124 : TestWithTaskEnvironment(
125 base::test::TaskEnvironment::TimeSource::MOCK_TIME),
126 socket_factory_(NetLog::Get()),
127 params_(&socket_factory_,
128 /*ssl_client_context=*/nullptr,
129 /*socket_performance_watcher_factory=*/nullptr,
130 /*network_quality_estimator=*/nullptr,
131 /*net_log=*/NetLog::Get()) {}
132
133 protected:
EnableSocketPerformanceWatcher()134 void EnableSocketPerformanceWatcher() {
135 params_.socket_performance_watcher_factory =
136 &socket_performance_watcher_factory_;
137 }
138
socket_factory()139 MockTransportClientSocketFactory& socket_factory() { return socket_factory_; }
140
params()141 StreamAttemptParams* params() { return ¶ms_; }
142
143 private:
144 MockTransportClientSocketFactory socket_factory_;
145 TestSocketPerformanceWatcherFactory socket_performance_watcher_factory_;
146 StreamAttemptParams params_;
147 };
148
// With a synchronously connecting mock socket, Start() returns OK directly
// and the attempt exposes a socket, connect timing and an idle load state.
TEST_F(TcpStreamAttemptTest, SuccessSync) {
  socket_factory().set_default_client_socket_type(
      MockTransportClientSocketFactory::Type::kSynchronous);

  StreamAttemptHelper helper(params(), MakeIPEndPoint("192.0.2.1"));
  const int start_result = helper.Start();
  EXPECT_THAT(start_result, IsOk());

  TcpStreamAttempt* const attempt = helper.attempt();
  std::unique_ptr<StreamSocket> socket = attempt->ReleaseStreamSocket();
  ASSERT_TRUE(socket);
  ASSERT_FALSE(attempt->connect_timing().connect_start.is_null());
  ASSERT_FALSE(attempt->connect_timing().connect_end.is_null());
  ASSERT_EQ(attempt->GetLoadState(), LOAD_STATE_IDLE);
}
163
// With a pending mock socket, Start() reports ERR_IO_PENDING while the
// attempt is connecting, and the completion callback later delivers OK.
TEST_F(TcpStreamAttemptTest, SuccessAsync) {
  socket_factory().set_default_client_socket_type(
      MockTransportClientSocketFactory::Type::kPending);

  StreamAttemptHelper helper(params(), MakeIPEndPoint("192.0.2.1"));
  TcpStreamAttempt* const attempt = helper.attempt();

  const int start_result = helper.Start();
  EXPECT_THAT(start_result, IsError(ERR_IO_PENDING));
  ASSERT_EQ(attempt->GetLoadState(), LOAD_STATE_CONNECTING);

  const int final_result = helper.WaitForCompletion();
  EXPECT_THAT(final_result, IsOk());

  std::unique_ptr<StreamSocket> socket = attempt->ReleaseStreamSocket();
  ASSERT_TRUE(socket);
  ASSERT_FALSE(attempt->connect_timing().connect_start.is_null());
  ASSERT_FALSE(attempt->connect_timing().connect_end.is_null());
  ASSERT_EQ(attempt->GetLoadState(), LOAD_STATE_IDLE);
}
182
// With a mock socket that fails synchronously, Start() returns
// ERR_CONNECTION_FAILED directly and the attempt is idle afterwards.
TEST_F(TcpStreamAttemptTest, FailureSync) {
  socket_factory().set_default_client_socket_type(
      MockTransportClientSocketFactory::Type::kFailing);

  StreamAttemptHelper helper(params(), MakeIPEndPoint("192.0.2.1"));
  const int start_result = helper.Start();

  EXPECT_THAT(start_result, IsError(ERR_CONNECTION_FAILED));
  ASSERT_EQ(helper.attempt()->GetLoadState(), LOAD_STATE_IDLE);
}
191
// With a mock socket that fails asynchronously, Start() reports
// ERR_IO_PENDING and the completion callback delivers ERR_CONNECTION_FAILED.
TEST_F(TcpStreamAttemptTest, FailureAsync) {
  socket_factory().set_default_client_socket_type(
      MockTransportClientSocketFactory::Type::kPendingFailing);

  StreamAttemptHelper helper(params(), MakeIPEndPoint("192.0.2.1"));
  const int start_result = helper.Start();
  EXPECT_THAT(start_result, IsError(ERR_IO_PENDING));

  const int final_result = helper.WaitForCompletion();
  EXPECT_THAT(final_result, IsError(ERR_CONNECTION_FAILED));
  ASSERT_EQ(helper.attempt()->GetLoadState(), LOAD_STATE_IDLE);
}
203
// With a stalled mock socket, advancing mock time by kTcpHandshakeTimeout
// completes the attempt with ERR_TIMED_OUT and leaves no socket to release.
TEST_F(TcpStreamAttemptTest, Timeout) {
  socket_factory().set_default_client_socket_type(
      MockTransportClientSocketFactory::Type::kStalled);

  StreamAttemptHelper helper(params(), MakeIPEndPoint("192.0.2.1"));
  const int start_result = helper.Start();
  EXPECT_THAT(start_result, IsError(ERR_IO_PENDING));

  FastForwardBy(TcpStreamAttempt::kTcpHandshakeTimeout);

  const int final_result = helper.WaitForCompletion();
  EXPECT_THAT(final_result, IsError(ERR_TIMED_OUT));

  TcpStreamAttempt* const attempt = helper.attempt();
  ASSERT_FALSE(attempt->ReleaseStreamSocket());
  ASSERT_EQ(attempt->GetLoadState(), LOAD_STATE_IDLE);
}
217
// Destroying an attempt that is still pending logs exactly one NetLog entry,
// whose "net_error" parameter is ERR_ABORTED.
TEST_F(TcpStreamAttemptTest, Abort) {
  socket_factory().set_default_client_socket_type(
      MockTransportClientSocketFactory::Type::kPending);

  auto helper = std::make_unique<StreamAttemptHelper>(
      params(), MakeIPEndPoint("192.0.2.1"));
  const int start_result = helper->Start();
  EXPECT_THAT(start_result, IsError(ERR_IO_PENDING));

  // Start observing only after Start(), so that just the entries emitted by
  // the teardown below are captured.
  NetLogObserver observer(helper->attempt()->net_log().net_log());

  // Dropping the helper destroys the attempt, which aborts it.
  helper.reset();

  const std::vector<NetLogEntry>& logged = observer.entries();
  ASSERT_EQ(logged.size(), 1u);

  std::optional<int> net_error = logged.front().params.FindInt("net_error");
  ASSERT_TRUE(net_error.has_value());
  EXPECT_THAT(*net_error, IsError(ERR_ABORTED));
}
236
// Same flow as the synchronous success case, but with a socket performance
// watcher factory installed in the params.
TEST_F(TcpStreamAttemptTest, SocketPerformanceWatcher) {
  EnableSocketPerformanceWatcher();
  socket_factory().set_default_client_socket_type(
      MockTransportClientSocketFactory::Type::kSynchronous);

  StreamAttemptHelper helper(params(), MakeIPEndPoint("192.0.2.1"));
  const int start_result = helper.Start();
  EXPECT_THAT(start_result, IsOk());

  TcpStreamAttempt* const attempt = helper.attempt();
  std::unique_ptr<StreamSocket> socket = attempt->ReleaseStreamSocket();
  ASSERT_TRUE(socket);
  ASSERT_FALSE(attempt->connect_timing().connect_start.is_null());
  ASSERT_FALSE(attempt->connect_timing().connect_end.is_null());
  ASSERT_EQ(attempt->GetLoadState(), LOAD_STATE_IDLE);
}
253
254 } // namespace net
255