• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2024 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
#include "net/socket/tcp_stream_attempt.h"

#include <cstdint>
#include <memory>
#include <optional>
#include <string_view>
#include <utility>
#include <vector>

#include "base/functional/bind.h"
#include "base/functional/callback.h"
#include "base/functional/callback_forward.h"
#include "base/run_loop.h"
#include "base/test/task_environment.h"
#include "base/time/time.h"
#include "net/base/ip_address.h"
#include "net/base/ip_endpoint.h"
#include "net/base/net_errors.h"
#include "net/log/net_log.h"
#include "net/log/net_log_capture_mode.h"
#include "net/log/net_log_entry.h"
#include "net/socket/socket_performance_watcher.h"
#include "net/socket/socket_performance_watcher_factory.h"
#include "net/socket/stream_attempt.h"
#include "net/socket/transport_client_socket_pool_test_util.h"
#include "net/test/gtest_util.h"
#include "net/test/test_with_task_environment.h"
#include "testing/gtest/include/gtest/gtest.h"

using net::test::IsError;
using net::test::IsOk;
27 
28 namespace net {
29 
30 namespace {
31 
MakeIPEndPoint(std::string_view ip_literal,uint16_t port=80)32 IPEndPoint MakeIPEndPoint(std::string_view ip_literal, uint16_t port = 80) {
33   std::optional<IPAddress> ip = IPAddress::FromIPLiteral(std::move(ip_literal));
34   return IPEndPoint(*ip, port);
35 }
36 
37 class NetLogObserver : public NetLog::ThreadSafeObserver {
38  public:
NetLogObserver(NetLog * net_log)39   explicit NetLogObserver(NetLog* net_log) {
40     net_log->AddObserver(this, NetLogCaptureMode::kEverything);
41   }
42 
~NetLogObserver()43   ~NetLogObserver() override {
44     if (net_log()) {
45       net_log()->RemoveObserver(this);
46     }
47   }
48 
OnAddEntry(const NetLogEntry & entry)49   void OnAddEntry(const NetLogEntry& entry) override {
50     entries_.emplace_back(entry.Clone());
51   }
52 
entries() const53   const std::vector<NetLogEntry>& entries() const { return entries_; }
54 
55  private:
56   std::vector<NetLogEntry> entries_;
57 };
58 
59 class TestSocketPerformanceWatcher : public SocketPerformanceWatcher {
60  public:
61   ~TestSocketPerformanceWatcher() override = default;
62 
ShouldNotifyUpdatedRTT() const63   bool ShouldNotifyUpdatedRTT() const override { return false; }
64 
OnUpdatedRTTAvailable(const base::TimeDelta & rtt)65   void OnUpdatedRTTAvailable(const base::TimeDelta& rtt) override {}
66 
OnConnectionChanged()67   void OnConnectionChanged() override {}
68 };
69 
70 class TestSocketPerformanceWatcherFactory
71     : public SocketPerformanceWatcherFactory {
72  public:
73   ~TestSocketPerformanceWatcherFactory() override = default;
74 
CreateSocketPerformanceWatcher(const Protocol protocol,const IPAddress & ip_address)75   std::unique_ptr<SocketPerformanceWatcher> CreateSocketPerformanceWatcher(
76       const Protocol protocol,
77       const IPAddress& ip_address) override {
78     return std::make_unique<TestSocketPerformanceWatcher>();
79   }
80 };
81 
82 class StreamAttemptHelper {
83  public:
StreamAttemptHelper(StreamAttemptParams * params,IPEndPoint ip_endpoint)84   StreamAttemptHelper(StreamAttemptParams* params, IPEndPoint ip_endpoint)
85       : attempt_(std::make_unique<TcpStreamAttempt>(params, ip_endpoint)) {}
86 
Start()87   int Start() {
88     return attempt_->Start(base::BindOnce(&StreamAttemptHelper::OnComplete,
89                                           base::Unretained(this)));
90   }
91 
WaitForCompletion()92   int WaitForCompletion() {
93     if (result_.has_value()) {
94       return *result_;
95     }
96 
97     base::RunLoop loop;
98     completion_closure_ = loop.QuitClosure();
99     loop.Run();
100 
101     return *result_;
102   }
103 
attempt()104   TcpStreamAttempt* attempt() { return attempt_.get(); }
105 
106  private:
OnComplete(int rv)107   void OnComplete(int rv) {
108     result_ = rv;
109     if (completion_closure_) {
110       std::move(completion_closure_).Run();
111     }
112   }
113 
114   std::unique_ptr<TcpStreamAttempt> attempt_;
115   base::OnceClosure completion_closure_;
116   std::optional<int> result_;
117 };
118 
119 }  // namespace
120 
121 class TcpStreamAttemptTest : public TestWithTaskEnvironment {
122  public:
TcpStreamAttemptTest()123   TcpStreamAttemptTest()
124       : TestWithTaskEnvironment(
125             base::test::TaskEnvironment::TimeSource::MOCK_TIME),
126         socket_factory_(NetLog::Get()),
127         params_(&socket_factory_,
128                 /*ssl_client_context=*/nullptr,
129                 /*socket_performance_watcher_factory=*/nullptr,
130                 /*network_quality_estimator=*/nullptr,
131                 /*net_log=*/NetLog::Get()) {}
132 
133  protected:
EnableSocketPerformanceWatcher()134   void EnableSocketPerformanceWatcher() {
135     params_.socket_performance_watcher_factory =
136         &socket_performance_watcher_factory_;
137   }
138 
socket_factory()139   MockTransportClientSocketFactory& socket_factory() { return socket_factory_; }
140 
params()141   StreamAttemptParams* params() { return &params_; }
142 
143  private:
144   MockTransportClientSocketFactory socket_factory_;
145   TestSocketPerformanceWatcherFactory socket_performance_watcher_factory_;
146   StreamAttemptParams params_;
147 };
148 
// A synchronously-connecting socket completes inside Start() and leaves a
// usable stream socket plus filled-in connect timing behind.
TEST_F(TcpStreamAttemptTest, SuccessSync) {
  socket_factory().set_default_client_socket_type(
      MockTransportClientSocketFactory::Type::kSynchronous);
  StreamAttemptHelper helper(params(), MakeIPEndPoint("192.0.2.1"));

  EXPECT_THAT(helper.Start(), IsOk());

  TcpStreamAttempt* attempt = helper.attempt();
  std::unique_ptr<StreamSocket> stream_socket = attempt->ReleaseStreamSocket();
  ASSERT_TRUE(stream_socket);

  const auto& timing = attempt->connect_timing();
  ASSERT_FALSE(timing.connect_start.is_null());
  ASSERT_FALSE(timing.connect_end.is_null());
  ASSERT_EQ(attempt->GetLoadState(), LOAD_STATE_IDLE);
}
163 
// A pending socket reports LOAD_STATE_CONNECTING until it finishes. After that
// it behaves like the synchronous success case.
TEST_F(TcpStreamAttemptTest, SuccessAsync) {
  socket_factory().set_default_client_socket_type(
      MockTransportClientSocketFactory::Type::kPending);
  StreamAttemptHelper helper(params(), MakeIPEndPoint("192.0.2.1"));
  TcpStreamAttempt* attempt = helper.attempt();

  EXPECT_THAT(helper.Start(), IsError(ERR_IO_PENDING));
  ASSERT_EQ(attempt->GetLoadState(), LOAD_STATE_CONNECTING);

  EXPECT_THAT(helper.WaitForCompletion(), IsOk());

  std::unique_ptr<StreamSocket> stream_socket = attempt->ReleaseStreamSocket();
  ASSERT_TRUE(stream_socket);

  const auto& timing = attempt->connect_timing();
  ASSERT_FALSE(timing.connect_start.is_null());
  ASSERT_FALSE(timing.connect_end.is_null());
  ASSERT_EQ(attempt->GetLoadState(), LOAD_STATE_IDLE);
}
182 
// A socket that fails immediately surfaces its error straight from Start().
TEST_F(TcpStreamAttemptTest, FailureSync) {
  socket_factory().set_default_client_socket_type(
      MockTransportClientSocketFactory::Type::kFailing);
  StreamAttemptHelper helper(params(), MakeIPEndPoint("192.0.2.1"));

  EXPECT_THAT(helper.Start(), IsError(ERR_CONNECTION_FAILED));
  ASSERT_EQ(helper.attempt()->GetLoadState(), LOAD_STATE_IDLE);
}
191 
// A socket that fails asynchronously stays in LOAD_STATE_CONNECTING while it
// is pending, then delivers its error through the completion callback.
TEST_F(TcpStreamAttemptTest, FailureAsync) {
  socket_factory().set_default_client_socket_type(
      MockTransportClientSocketFactory::Type::kPendingFailing);
  StreamAttemptHelper helper(params(), MakeIPEndPoint("192.0.2.1"));
  int rv = helper.Start();
  EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
  // Mirrors SuccessAsync: the pending attempt must report it is connecting.
  ASSERT_EQ(helper.attempt()->GetLoadState(), LOAD_STATE_CONNECTING);

  rv = helper.WaitForCompletion();
  EXPECT_THAT(rv, IsError(ERR_CONNECTION_FAILED));
  ASSERT_EQ(helper.attempt()->GetLoadState(), LOAD_STATE_IDLE);
}
203 
// A socket that never finishes connecting is cut off after
// TcpStreamAttempt::kTcpHandshakeTimeout with ERR_TIMED_OUT, and no stream
// socket is left to release.
TEST_F(TcpStreamAttemptTest, Timeout) {
  socket_factory().set_default_client_socket_type(
      MockTransportClientSocketFactory::Type::kStalled);
  StreamAttemptHelper helper(params(), MakeIPEndPoint("192.0.2.1"));
  int rv = helper.Start();
  EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
  // Before the timeout fires, the stalled attempt is still connecting.
  ASSERT_EQ(helper.attempt()->GetLoadState(), LOAD_STATE_CONNECTING);

  // Mock time: this runs the handshake timeout without real waiting.
  FastForwardBy(TcpStreamAttempt::kTcpHandshakeTimeout);
  rv = helper.WaitForCompletion();
  EXPECT_THAT(rv, IsError(ERR_TIMED_OUT));
  ASSERT_FALSE(helper.attempt()->ReleaseStreamSocket());
  ASSERT_EQ(helper.attempt()->GetLoadState(), LOAD_STATE_IDLE);
}
217 
// Destroying a pending attempt logs exactly one NetLog entry, whose
// "net_error" parameter is ERR_ABORTED.
TEST_F(TcpStreamAttemptTest, Abort) {
  socket_factory().set_default_client_socket_type(
      MockTransportClientSocketFactory::Type::kPending);
  auto helper = std::make_unique<StreamAttemptHelper>(
      params(), MakeIPEndPoint("192.0.2.1"));
  int rv = helper->Start();
  EXPECT_THAT(rv, IsError(ERR_IO_PENDING));

  // Attach the observer only after Start(), so it records only entries
  // emitted while the pending attempt is torn down.
  NetLogObserver observer(helper->attempt()->net_log().net_log());
  // Drop the helper to abort the attempt.
  helper.reset();

  ASSERT_EQ(observer.entries().size(), 1u);
  std::optional<int> error =
      observer.entries().front().params.FindInt("net_error");
  ASSERT_TRUE(error.has_value());
  EXPECT_THAT(*error, IsError(ERR_ABORTED));
}
236 
// With a socket performance watcher factory configured, a synchronous connect
// still succeeds and produces a stream socket and connect timing.
TEST_F(TcpStreamAttemptTest, SocketPerformanceWatcher) {
  EnableSocketPerformanceWatcher();

  socket_factory().set_default_client_socket_type(
      MockTransportClientSocketFactory::Type::kSynchronous);
  StreamAttemptHelper helper(params(), MakeIPEndPoint("192.0.2.1"));

  EXPECT_THAT(helper.Start(), IsOk());

  TcpStreamAttempt* attempt = helper.attempt();
  std::unique_ptr<StreamSocket> stream_socket = attempt->ReleaseStreamSocket();
  ASSERT_TRUE(stream_socket);

  const auto& timing = attempt->connect_timing();
  ASSERT_FALSE(timing.connect_start.is_null());
  ASSERT_FALSE(timing.connect_end.is_null());
  ASSERT_EQ(attempt->GetLoadState(), LOAD_STATE_IDLE);
}
253 
254 }  // namespace net
255