1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #ifdef UNSAFE_BUFFERS_BUILD
6 // TODO(crbug.com/40284755): Remove this and spanify to fix the errors.
7 #pragma allow_unsafe_buffers
8 #endif
9
10 #include "net/socket/socket_test_util.h"
11
12 #include <inttypes.h> // For SCNx64
13 #include <stdint.h>
14 #include <stdio.h>
15
#include <memory>
#include <ostream>
#include <sstream>
#include <string>
#include <string_view>
#include <utility>
#include <vector>
22
23 #include "base/compiler_specific.h"
24 #include "base/files/file_util.h"
25 #include "base/functional/bind.h"
26 #include "base/functional/callback_helpers.h"
27 #include "base/location.h"
28 #include "base/logging.h"
29 #include "base/memory/raw_ptr.h"
30 #include "base/notreached.h"
31 #include "base/rand_util.h"
32 #include "base/ranges/algorithm.h"
33 #include "base/run_loop.h"
34 #include "base/task/single_thread_task_runner.h"
35 #include "base/time/time.h"
36 #include "build/build_config.h"
37 #include "net/base/address_family.h"
38 #include "net/base/address_list.h"
39 #include "net/base/auth.h"
40 #include "net/base/completion_once_callback.h"
41 #include "net/base/hex_utils.h"
42 #include "net/base/ip_address.h"
43 #include "net/base/load_timing_info.h"
44 #include "net/base/net_errors.h"
45 #include "net/base/proxy_server.h"
46 #include "net/http/http_network_session.h"
47 #include "net/http/http_request_headers.h"
48 #include "net/http/http_response_headers.h"
49 #include "net/log/net_log_source.h"
50 #include "net/log/net_log_source_type.h"
51 #include "net/socket/connect_job.h"
52 #include "net/socket/socket.h"
53 #include "net/socket/stream_socket.h"
54 #include "net/socket/websocket_endpoint_lock_manager.h"
55 #include "net/ssl/ssl_cert_request_info.h"
56 #include "net/ssl/ssl_connection_status_flags.h"
57 #include "net/ssl/ssl_info.h"
58 #include "net/traffic_annotation/network_traffic_annotation.h"
59 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
60 #include "testing/gtest/include/gtest/gtest.h"
61 #include "third_party/abseil-cpp/absl/strings/ascii.h"
62
63 #if BUILDFLAG(IS_ANDROID)
64 #include "base/android/build_info.h"
65 #endif
66
// Starts a VLOG(|level|) message made of |s| followed by the name of the
// enclosing function and "() ". Further `<<` operands are appended after that,
// e.g. NET_TRACE(1, " *** ") << "x" logs " *** OnRead() x".
#define NET_TRACE(level, s) VLOG(level) << s << __FUNCTION__ << "() "
68
69 namespace net {
70 namespace {
71
// Maps a value in [0, 15] to its uppercase hexadecimal digit.
inline char NybbleToHexChar(int nybble) {
  return "0123456789ABCDEF"[nybble];
}

// Hex digit for the upper four bits of |x|.
inline char AsciifyHigh(char x) {
  return NybbleToHexChar((static_cast<unsigned char>(x) >> 4) & 0x0F);
}

// Hex digit for the lower four bits of |x|.
inline char AsciifyLow(char x) {
  return NybbleToHexChar(static_cast<unsigned char>(x) & 0x0F);
}
81
Asciify(char x)82 inline char Asciify(char x) {
83 return absl::ascii_isprint(static_cast<unsigned char>(x)) ? x : '.';
84 }
85
DumpData(const char * data,int data_len)86 void DumpData(const char* data, int data_len) {
87 if (logging::LOGGING_INFO < logging::GetMinLogLevel()) {
88 return;
89 }
90 DVLOG(1) << "Length: " << data_len;
91 const char* pfx = "Data: ";
92 if (!data || (data_len <= 0)) {
93 DVLOG(1) << pfx << "<None>";
94 } else {
95 int i;
96 for (i = 0; i <= (data_len - 4); i += 4) {
97 DVLOG(1) << pfx << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
98 << AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1])
99 << AsciifyHigh(data[i + 2]) << AsciifyLow(data[i + 2])
100 << AsciifyHigh(data[i + 3]) << AsciifyLow(data[i + 3]) << " '"
101 << Asciify(data[i + 0]) << Asciify(data[i + 1])
102 << Asciify(data[i + 2]) << Asciify(data[i + 3]) << "'";
103 pfx = " ";
104 }
105 // Take care of any 'trailing' bytes, if data_len was not a multiple of 4.
106 switch (data_len - i) {
107 case 3:
108 DVLOG(1) << pfx << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
109 << AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1])
110 << AsciifyHigh(data[i + 2]) << AsciifyLow(data[i + 2])
111 << " '" << Asciify(data[i + 0]) << Asciify(data[i + 1])
112 << Asciify(data[i + 2]) << " '";
113 break;
114 case 2:
115 DVLOG(1) << pfx << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
116 << AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1])
117 << " '" << Asciify(data[i + 0]) << Asciify(data[i + 1])
118 << " '";
119 break;
120 case 1:
121 DVLOG(1) << pfx << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
122 << " '" << Asciify(data[i + 0]) << " '";
123 break;
124 }
125 }
126 }
127
128 template <MockReadWriteType type>
DumpMockReadWrite(const MockReadWrite<type> & r)129 void DumpMockReadWrite(const MockReadWrite<type>& r) {
130 if (logging::LOGGING_INFO < logging::GetMinLogLevel()) {
131 return;
132 }
133 DVLOG(1) << "Async: " << (r.mode == ASYNC) << "\nResult: " << r.result;
134 DumpData(r.data, r.data_len);
135 const char* stop = (r.sequence_number & MockRead::STOPLOOP) ? " (STOP)" : "";
136 DVLOG(1) << "Stage: " << (r.sequence_number & ~MockRead::STOPLOOP) << stop;
137 }
138
RunClosureIfNonNull(base::OnceClosure closure)139 void RunClosureIfNonNull(base::OnceClosure closure) {
140 if (!closure.is_null()) {
141 std::move(closure).Run();
142 }
143 }
144
145 } // namespace
146
147 MockConnectCompleter::MockConnectCompleter() = default;
148
149 MockConnectCompleter::~MockConnectCompleter() = default;
150
SetCallback(CompletionOnceCallback callback)151 void MockConnectCompleter::SetCallback(CompletionOnceCallback callback) {
152 CHECK(!callback_);
153 callback_ = std::move(callback);
154 }
155
Complete(int result)156 void MockConnectCompleter::Complete(int result) {
157 CHECK(callback_);
158 std::move(callback_).Run(result);
159 }
160
MockConnect()161 MockConnect::MockConnect() : mode(ASYNC), result(OK) {
162 peer_addr = IPEndPoint(IPAddress(192, 0, 2, 33), 0);
163 }
164
MockConnect(IoMode io_mode,int r)165 MockConnect::MockConnect(IoMode io_mode, int r) : mode(io_mode), result(r) {
166 peer_addr = IPEndPoint(IPAddress(192, 0, 2, 33), 0);
167 }
168
MockConnect(IoMode io_mode,int r,IPEndPoint addr)169 MockConnect::MockConnect(IoMode io_mode, int r, IPEndPoint addr)
170 : mode(io_mode), result(r), peer_addr(addr) {}
171
MockConnect(IoMode io_mode,int r,IPEndPoint addr,bool first_attempt_fails)172 MockConnect::MockConnect(IoMode io_mode,
173 int r,
174 IPEndPoint addr,
175 bool first_attempt_fails)
176 : mode(io_mode),
177 result(r),
178 peer_addr(addr),
179 first_attempt_fails(first_attempt_fails) {}
180
MockConnect(MockConnectCompleter * completer)181 MockConnect::MockConnect(MockConnectCompleter* completer)
182 : mode(ASYNC), result(OK), completer(completer) {}
183
184 MockConnect::~MockConnect() = default;
185
MockConfirm()186 MockConfirm::MockConfirm() : mode(SYNCHRONOUS), result(OK) {}
187
MockConfirm(IoMode io_mode,int r)188 MockConfirm::MockConfirm(IoMode io_mode, int r) : mode(io_mode), result(r) {}
189
190 MockConfirm::~MockConfirm() = default;
191
IsIdle() const192 bool SocketDataProvider::IsIdle() const {
193 return true;
194 }
195
Initialize(AsyncSocket * socket)196 void SocketDataProvider::Initialize(AsyncSocket* socket) {
197 CHECK(!socket_);
198 CHECK(socket);
199 socket_ = socket;
200 Reset();
201 }
202
DetachSocket()203 void SocketDataProvider::DetachSocket() {
204 CHECK(socket_);
205 socket_ = nullptr;
206 }
207
208 SocketDataProvider::SocketDataProvider() = default;
209
~SocketDataProvider()210 SocketDataProvider::~SocketDataProvider() {
211 if (socket_)
212 socket_->OnDataProviderDestroyed();
213 }
214
StaticSocketDataHelper(base::span<const MockRead> reads,base::span<const MockWrite> writes)215 StaticSocketDataHelper::StaticSocketDataHelper(
216 base::span<const MockRead> reads,
217 base::span<const MockWrite> writes)
218 : reads_(reads), writes_(writes) {}
219
220 StaticSocketDataHelper::~StaticSocketDataHelper() = default;
221
PeekRead() const222 const MockRead& StaticSocketDataHelper::PeekRead() const {
223 CHECK(!AllReadDataConsumed());
224 return reads_[read_index_];
225 }
226
PeekWrite() const227 const MockWrite& StaticSocketDataHelper::PeekWrite() const {
228 CHECK(!AllWriteDataConsumed());
229 return writes_[write_index_];
230 }
231
AdvanceRead()232 const MockRead& StaticSocketDataHelper::AdvanceRead() {
233 CHECK(!AllReadDataConsumed());
234 return reads_[read_index_++];
235 }
236
AdvanceWrite()237 const MockWrite& StaticSocketDataHelper::AdvanceWrite() {
238 CHECK(!AllWriteDataConsumed());
239 return writes_[write_index_++];
240 }
241
Reset()242 void StaticSocketDataHelper::Reset() {
243 read_index_ = 0;
244 write_index_ = 0;
245 }
246
VerifyWriteData(const std::string & data,SocketDataPrinter * printer)247 bool StaticSocketDataHelper::VerifyWriteData(const std::string& data,
248 SocketDataPrinter* printer) {
249 CHECK(!AllWriteDataConsumed());
250 // Check that the actual data matches the expectations, skipping over any
251 // pause events.
252 const MockWrite& next_write = PeekRealWrite();
253 if (!next_write.data)
254 return true;
255
256 // Note: Partial writes are supported here. If the expected data
257 // is a match, but shorter than the write actually written, that is legal.
258 // Example:
259 // Application writes "foobarbaz" (9 bytes)
260 // Expected write was "foo" (3 bytes)
261 // This is a success, and the function returns true.
262 std::string expected_data(next_write.data, next_write.data_len);
263 std::string actual_data(data.substr(0, next_write.data_len));
264 if (printer) {
265 EXPECT_TRUE(actual_data == expected_data)
266 << "Actual formatted write data:\n"
267 << printer->PrintWrite(data) << "Expected formatted write data:\n"
268 << printer->PrintWrite(expected_data) << "Actual raw write data:\n"
269 << HexDump(data) << "Expected raw write data:\n"
270 << HexDump(expected_data);
271 } else {
272 EXPECT_TRUE(actual_data == expected_data)
273 << "Actual write data:\n"
274 << HexDump(data) << "Expected write data:\n"
275 << HexDump(expected_data);
276 }
277 return expected_data == actual_data;
278 }
279
ExpectAllReadDataConsumed(SocketDataPrinter * printer) const280 void StaticSocketDataHelper::ExpectAllReadDataConsumed(
281 SocketDataPrinter* printer) const {
282 if (AllReadDataConsumed()) {
283 return;
284 }
285
286 std::ostringstream msg;
287 if (read_index_ < read_count()) {
288 msg << "Unconsumed reads:\n";
289 for (size_t i = read_index_; i < read_count(); i++) {
290 msg << (reads_[i].mode == ASYNC ? "ASYNC" : "SYNC") << " MockRead seq "
291 << reads_[i].sequence_number << ":\n";
292 if (reads_[i].result != OK) {
293 msg << "Result: " << reads_[i].result << "\n";
294 }
295 if (reads_[i].data) {
296 std::string data(reads_[i].data, reads_[i].data_len);
297 if (printer) {
298 msg << printer->PrintWrite(data);
299 }
300 msg << HexDump(data);
301 }
302 }
303 }
304 EXPECT_TRUE(AllReadDataConsumed()) << msg.str();
305 }
306
ExpectAllWriteDataConsumed(SocketDataPrinter * printer) const307 void StaticSocketDataHelper::ExpectAllWriteDataConsumed(
308 SocketDataPrinter* printer) const {
309 if (AllWriteDataConsumed()) {
310 return;
311 }
312
313 std::ostringstream msg;
314 if (write_index_ < write_count()) {
315 msg << "Unconsumed writes:\n";
316 for (size_t i = write_index_; i < write_count(); i++) {
317 msg << (writes_[i].mode == ASYNC ? "ASYNC" : "SYNC") << " MockWrite seq "
318 << writes_[i].sequence_number << ":\n";
319 if (writes_[i].result != OK) {
320 msg << "Result: " << writes_[i].result << "\n";
321 }
322 if (writes_[i].data) {
323 std::string data(writes_[i].data, writes_[i].data_len);
324 if (printer) {
325 msg << printer->PrintWrite(data);
326 }
327 msg << HexDump(data);
328 }
329 }
330 }
331 EXPECT_TRUE(AllWriteDataConsumed()) << msg.str();
332 }
333
PeekRealWrite() const334 const MockWrite& StaticSocketDataHelper::PeekRealWrite() const {
335 for (size_t i = write_index_; i < write_count(); i++) {
336 if (writes_[i].mode != ASYNC || writes_[i].result != ERR_IO_PENDING)
337 return writes_[i];
338 }
339
340 NOTREACHED() << "No write data available.";
341 }
342
StaticSocketDataProvider()343 StaticSocketDataProvider::StaticSocketDataProvider()
344 : StaticSocketDataProvider(base::span<const MockRead>(),
345 base::span<const MockWrite>()) {}
346
StaticSocketDataProvider(base::span<const MockRead> reads,base::span<const MockWrite> writes)347 StaticSocketDataProvider::StaticSocketDataProvider(
348 base::span<const MockRead> reads,
349 base::span<const MockWrite> writes)
350 : helper_(reads, writes) {}
351
352 StaticSocketDataProvider::~StaticSocketDataProvider() = default;
353
Pause()354 void StaticSocketDataProvider::Pause() {
355 paused_ = true;
356 }
357
Resume()358 void StaticSocketDataProvider::Resume() {
359 paused_ = false;
360 }
361
OnRead()362 MockRead StaticSocketDataProvider::OnRead() {
363 if (AllReadDataConsumed()) {
364 const net::MockRead pending_read(net::SYNCHRONOUS, net::ERR_IO_PENDING);
365 return pending_read;
366 }
367
368 return helper_.AdvanceRead();
369 }
370
OnWrite(const std::string & data)371 MockWriteResult StaticSocketDataProvider::OnWrite(const std::string& data) {
372 if (helper_.write_count() == 0) {
373 // Not using mock writes; succeed synchronously.
374 return MockWriteResult(SYNCHRONOUS, data.length());
375 }
376 if (printer_) {
377 EXPECT_FALSE(helper_.AllWriteDataConsumed())
378 << "No more mock data to match write:\nFormatted write data:\n"
379 << printer_->PrintWrite(data) << "Raw write data:\n"
380 << HexDump(data);
381 } else {
382 EXPECT_FALSE(helper_.AllWriteDataConsumed())
383 << "No more mock data to match write:\nRaw write data:\n"
384 << HexDump(data);
385 }
386 if (helper_.AllWriteDataConsumed()) {
387 return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED);
388 }
389
390 // Check that what we are writing matches the expectation.
391 // Then give the mocked return value.
392 if (!helper_.VerifyWriteData(data, printer_))
393 return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED);
394
395 const MockWrite& next_write = helper_.AdvanceWrite();
396 // In the case that the write was successful, return the number of bytes
397 // written. Otherwise return the error code.
398 int result =
399 next_write.result == OK ? next_write.data_len : next_write.result;
400 return MockWriteResult(next_write.mode, result);
401 }
402
AllReadDataConsumed() const403 bool StaticSocketDataProvider::AllReadDataConsumed() const {
404 return paused_ || helper_.AllReadDataConsumed();
405 }
406
AllWriteDataConsumed() const407 bool StaticSocketDataProvider::AllWriteDataConsumed() const {
408 return helper_.AllWriteDataConsumed();
409 }
410
Reset()411 void StaticSocketDataProvider::Reset() {
412 helper_.Reset();
413 }
414
SSLSocketDataProvider(IoMode mode,int result)415 SSLSocketDataProvider::SSLSocketDataProvider(IoMode mode, int result)
416 : connect(mode, result),
417 expected_ssl_version_min(kDefaultSSLVersionMin),
418 expected_ssl_version_max(kDefaultSSLVersionMax) {
419 SSLConnectionStatusSetVersion(SSL_CONNECTION_VERSION_TLS1_3,
420 &ssl_info.connection_status);
421 // Set to TLS_CHACHA20_POLY1305_SHA256
422 SSLConnectionStatusSetCipherSuite(0x1301, &ssl_info.connection_status);
423 }
424
SSLSocketDataProvider(MockConnectCompleter * completer)425 SSLSocketDataProvider::SSLSocketDataProvider(MockConnectCompleter* completer)
426 : connect(completer),
427 expected_ssl_version_min(kDefaultSSLVersionMin),
428 expected_ssl_version_max(kDefaultSSLVersionMax) {
429 SSLConnectionStatusSetVersion(SSL_CONNECTION_VERSION_TLS1_3,
430 &ssl_info.connection_status);
431 // Set to TLS_CHACHA20_POLY1305_SHA256
432 SSLConnectionStatusSetCipherSuite(0x1301, &ssl_info.connection_status);
433 }
434
435 SSLSocketDataProvider::SSLSocketDataProvider(
436 const SSLSocketDataProvider& other) = default;
437
438 SSLSocketDataProvider::~SSLSocketDataProvider() = default;
439
SequencedSocketData()440 SequencedSocketData::SequencedSocketData()
441 : SequencedSocketData(base::span<const MockRead>(),
442 base::span<const MockWrite>()) {}
443
// Builds a provider that plays back |reads| and |writes| interleaved by their
// sequence numbers, validating the script up front. Any violation is a CHECK
// or NOTREACHED failure:
//  - Sequence numbers across both spans together form 0, 1, 2, ... with no
//    repeats or gaps, and each span is listed in increasing order (the walk
//    below only ever looks at the front of each span).
//  - A "pause" is an event with mode ASYNC and result ERR_IO_PENDING. Two
//    pauses may not be adjacent, the event after a pause must be ASYNC, and
//    the final event may not be a pause.
SequencedSocketData::SequencedSocketData(base::span<const MockRead> reads,
                                         base::span<const MockWrite> writes)
    : helper_(reads, writes) {
  // Check that reads and writes have a contiguous set of sequence numbers
  // starting from 0 and working their way up, with no repeats and skipping
  // no values.
  int next_sequence_number = 0;
  bool last_event_was_pause = false;

  // Merge-walk both spans, consuming whichever front event carries the
  // expected sequence number.
  auto next_read = reads.begin();
  auto next_write = writes.begin();
  while (next_read != reads.end() || next_write != writes.end()) {
    if (next_read != reads.end() &&
        next_read->sequence_number == next_sequence_number) {
      // Check if this is a pause.
      if (next_read->mode == ASYNC && next_read->result == ERR_IO_PENDING) {
        CHECK(!last_event_was_pause)
            << "Two pauses in a row are not allowed: " << next_sequence_number;
        last_event_was_pause = true;
      } else if (last_event_was_pause) {
        CHECK_EQ(ASYNC, next_read->mode)
            << "A sync event after a pause makes no sense: "
            << next_sequence_number;
        CHECK_NE(ERR_IO_PENDING, next_read->result)
            << "A pause event after a pause makes no sense: "
            << next_sequence_number;
        last_event_was_pause = false;
      }

      ++next_read;
      ++next_sequence_number;
      continue;
    }
    if (next_write != writes.end() &&
        next_write->sequence_number == next_sequence_number) {
      // Check if this is a pause.
      if (next_write->mode == ASYNC && next_write->result == ERR_IO_PENDING) {
        CHECK(!last_event_was_pause)
            << "Two pauses in a row are not allowed: " << next_sequence_number;
        last_event_was_pause = true;
      } else if (last_event_was_pause) {
        CHECK_EQ(ASYNC, next_write->mode)
            << "A sync event after a pause makes no sense: "
            << next_sequence_number;
        CHECK_NE(ERR_IO_PENDING, next_write->result)
            << "A pause event after a pause makes no sense: "
            << next_sequence_number;
        last_event_was_pause = false;
      }

      ++next_write;
      ++next_sequence_number;
      continue;
    }
    // Neither front event carries the expected sequence number, so the
    // script has a gap, a repeat, or an out-of-order entry.
    if (next_write != writes.end()) {
      NOTREACHED() << "Sequence number " << next_write->sequence_number
                   << " not found where expected: " << next_sequence_number;
    }
    NOTREACHED() << "Too few writes, next expected sequence number: "
                 << next_sequence_number;
  }

  // Last event must not be a pause. For the final event to indicate the
  // operation never completes, it should be SYNCHRONOUS and return
  // ERR_IO_PENDING.
  CHECK(!last_event_was_pause);

  CHECK(next_read == reads.end());
  CHECK(next_write == writes.end());
}
514
SequencedSocketData(const MockConnect & connect,base::span<const MockRead> reads,base::span<const MockWrite> writes)515 SequencedSocketData::SequencedSocketData(const MockConnect& connect,
516 base::span<const MockRead> reads,
517 base::span<const MockWrite> writes)
518 : SequencedSocketData(reads, writes) {
519 set_connect_data(connect);
520 }
OnRead()521 MockRead SequencedSocketData::OnRead() {
522 CHECK_EQ(IoState::kIdle, read_state_);
523 CHECK(!helper_.AllReadDataConsumed())
524 << "Application tried to read but there is no read data left";
525
526 NET_TRACE(1, " *** ") << "sequence_number: " << sequence_number_;
527 const MockRead& next_read = helper_.PeekRead();
528 NET_TRACE(1, " *** ") << "next_read: " << next_read.sequence_number;
529 CHECK_GE(next_read.sequence_number, sequence_number_);
530
531 if (next_read.sequence_number <= sequence_number_) {
532 if (next_read.mode == SYNCHRONOUS) {
533 NET_TRACE(1, " *** ") << "Returning synchronously";
534 DumpMockReadWrite(next_read);
535 helper_.AdvanceRead();
536 ++sequence_number_;
537 MaybePostWriteCompleteTask();
538 return next_read;
539 }
540
541 // If the result is ERR_IO_PENDING, then pause.
542 if (next_read.result == ERR_IO_PENDING) {
543 NET_TRACE(1, " *** ") << "Pausing read at: " << sequence_number_;
544 read_state_ = IoState::kPaused;
545 if (run_until_paused_run_loop_)
546 run_until_paused_run_loop_->Quit();
547 return MockRead(SYNCHRONOUS, ERR_IO_PENDING);
548 }
549 base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
550 FROM_HERE, base::BindOnce(&SequencedSocketData::OnReadComplete,
551 weak_factory_.GetWeakPtr()));
552 CHECK_NE(IoState::kCompleting, write_state_);
553 read_state_ = IoState::kCompleting;
554 } else if (next_read.mode == SYNCHRONOUS) {
555 ADD_FAILURE() << "Unable to perform synchronous IO while stopped";
556 return MockRead(SYNCHRONOUS, ERR_UNEXPECTED);
557 } else {
558 NET_TRACE(1, " *** ") << "Waiting for write to trigger read";
559 read_state_ = IoState::kPending;
560 }
561
562 return MockRead(SYNCHRONOUS, ERR_IO_PENDING);
563 }
564
OnWrite(const std::string & data)565 MockWriteResult SequencedSocketData::OnWrite(const std::string& data) {
566 CHECK_EQ(IoState::kIdle, write_state_);
567 if (printer_) {
568 CHECK(!helper_.AllWriteDataConsumed())
569 << "\nNo more mock data to match write:\nFormatted write data:\n"
570 << printer_->PrintWrite(data) << "Raw write data:\n"
571 << HexDump(data);
572 } else {
573 CHECK(!helper_.AllWriteDataConsumed())
574 << "\nNo more mock data to match write:\nRaw write data:\n"
575 << HexDump(data);
576 }
577
578 NET_TRACE(1, " *** ") << "sequence_number: " << sequence_number_;
579 const MockWrite& next_write = helper_.PeekWrite();
580 NET_TRACE(1, " *** ") << "next_write: " << next_write.sequence_number;
581 CHECK_GE(next_write.sequence_number, sequence_number_);
582
583 if (!helper_.VerifyWriteData(data, printer_))
584 return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED);
585
586 if (next_write.sequence_number <= sequence_number_) {
587 if (next_write.mode == SYNCHRONOUS) {
588 helper_.AdvanceWrite();
589 ++sequence_number_;
590 MaybePostReadCompleteTask();
591 // In the case that the write was successful, return the number of bytes
592 // written. Otherwise return the error code.
593 int rv =
594 next_write.result != OK ? next_write.result : next_write.data_len;
595 NET_TRACE(1, " *** ") << "Returning synchronously";
596 return MockWriteResult(SYNCHRONOUS, rv);
597 }
598
599 // If the result is ERR_IO_PENDING, then pause.
600 if (next_write.result == ERR_IO_PENDING) {
601 NET_TRACE(1, " *** ") << "Pausing write at: " << sequence_number_;
602 write_state_ = IoState::kPaused;
603 if (run_until_paused_run_loop_)
604 run_until_paused_run_loop_->Quit();
605 return MockWriteResult(SYNCHRONOUS, ERR_IO_PENDING);
606 }
607
608 NET_TRACE(1, " *** ") << "Posting task to complete write";
609 base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
610 FROM_HERE, base::BindOnce(&SequencedSocketData::OnWriteComplete,
611 weak_factory_.GetWeakPtr()));
612 CHECK_NE(IoState::kCompleting, read_state_);
613 write_state_ = IoState::kCompleting;
614 } else if (next_write.mode == SYNCHRONOUS) {
615 ADD_FAILURE() << "Unable to perform synchronous IO while stopped";
616 return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED);
617 } else {
618 NET_TRACE(1, " *** ") << "Waiting for read to trigger write";
619 write_state_ = IoState::kPending;
620 }
621
622 return MockWriteResult(SYNCHRONOUS, ERR_IO_PENDING);
623 }
624
AllReadDataConsumed() const625 bool SequencedSocketData::AllReadDataConsumed() const {
626 return helper_.AllReadDataConsumed();
627 }
628
CancelPendingRead()629 void SequencedSocketData::CancelPendingRead() {
630 DCHECK_EQ(IoState::kPending, read_state_);
631
632 read_state_ = IoState::kIdle;
633 }
634
AllWriteDataConsumed() const635 bool SequencedSocketData::AllWriteDataConsumed() const {
636 return helper_.AllWriteDataConsumed();
637 }
638
ExpectAllReadDataConsumed() const639 void SequencedSocketData::ExpectAllReadDataConsumed() const {
640 helper_.ExpectAllReadDataConsumed(printer_.get());
641 }
642
ExpectAllWriteDataConsumed() const643 void SequencedSocketData::ExpectAllWriteDataConsumed() const {
644 helper_.ExpectAllWriteDataConsumed(printer_.get());
645 }
646
IsIdle() const647 bool SequencedSocketData::IsIdle() const {
648 // If |busy_before_sync_reads_| is not set, always considered idle. If
649 // no reads left, or the next operation is a write, also consider it idle.
650 if (!busy_before_sync_reads_ || helper_.AllReadDataConsumed() ||
651 helper_.PeekRead().sequence_number != sequence_number_) {
652 return true;
653 }
654
655 // If the next operation is synchronous read, treat the socket as not idle.
656 if (helper_.PeekRead().mode == SYNCHRONOUS)
657 return false;
658 return true;
659 }
660
IsPaused() const661 bool SequencedSocketData::IsPaused() const {
662 // Both states should not be paused.
663 DCHECK(read_state_ != IoState::kPaused || write_state_ != IoState::kPaused);
664 return write_state_ == IoState::kPaused || read_state_ == IoState::kPaused;
665 }
666
// Continues playback after a pause. Consumes the pause marker on the paused
// side, advances the sequence number, and marks that side kPending. Then, if
// the event now due has already been issued by the socket (its side is
// kPending), it is completed synchronously via OnWriteComplete() /
// OnReadComplete(). Calling this while not paused is a test failure.
void SequencedSocketData::Resume() {
  if (!IsPaused()) {
    ADD_FAILURE() << "Unable to Resume when not paused.";
    return;
  }

  // Step past the pause marker itself.
  sequence_number_++;
  if (read_state_ == IoState::kPaused) {
    read_state_ = IoState::kPending;
    helper_.AdvanceRead();
  } else {  // write_state_ == IoState::kPaused
    write_state_ = IoState::kPending;
    helper_.AdvanceWrite();
  }

  // The next due event is a write.
  if (!helper_.AllWriteDataConsumed() &&
      helper_.PeekWrite().sequence_number == sequence_number_) {
    // The next event hasn't even started yet. Pausing isn't really needed in
    // that case, but may as well support it.
    if (write_state_ != IoState::kPending)
      return;
    write_state_ = IoState::kCompleting;
    OnWriteComplete();
    return;
  }

  // Otherwise the next due event must be a read.
  CHECK(!helper_.AllReadDataConsumed());

  // The next event hasn't even started yet. Pausing isn't really needed in
  // that case, but may as well support it.
  if (read_state_ != IoState::kPending)
    return;
  read_state_ = IoState::kCompleting;
  OnReadComplete();
}
702
RunUntilPaused()703 void SequencedSocketData::RunUntilPaused() {
704 CHECK(!run_until_paused_run_loop_);
705
706 if (IsPaused())
707 return;
708
709 run_until_paused_run_loop_ = std::make_unique<base::RunLoop>();
710 run_until_paused_run_loop_->Run();
711 run_until_paused_run_loop_.reset();
712 DCHECK(IsPaused());
713 }
714
MaybePostReadCompleteTask()715 void SequencedSocketData::MaybePostReadCompleteTask() {
716 NET_TRACE(1, " ****** ") << " current: " << sequence_number_;
717 // Only trigger the next read to complete if there is already a read pending
718 // which should complete at the current sequence number.
719 if (read_state_ != IoState::kPending ||
720 helper_.PeekRead().sequence_number != sequence_number_) {
721 return;
722 }
723
724 // If the result is ERR_IO_PENDING, then pause.
725 if (helper_.PeekRead().result == ERR_IO_PENDING) {
726 NET_TRACE(1, " *** ") << "Pausing read at: " << sequence_number_;
727 read_state_ = IoState::kPaused;
728 if (run_until_paused_run_loop_)
729 run_until_paused_run_loop_->Quit();
730 return;
731 }
732
733 NET_TRACE(1, " ****** ") << "Posting task to complete read: "
734 << sequence_number_;
735 base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
736 FROM_HERE, base::BindOnce(&SequencedSocketData::OnReadComplete,
737 weak_factory_.GetWeakPtr()));
738 CHECK_NE(IoState::kCompleting, write_state_);
739 read_state_ = IoState::kCompleting;
740 }
741
MaybePostWriteCompleteTask()742 void SequencedSocketData::MaybePostWriteCompleteTask() {
743 NET_TRACE(1, " ****** ") << " current: " << sequence_number_;
744 // Only trigger the next write to complete if there is already a write pending
745 // which should complete at the current sequence number.
746 if (write_state_ != IoState::kPending ||
747 helper_.PeekWrite().sequence_number != sequence_number_) {
748 return;
749 }
750
751 // If the result is ERR_IO_PENDING, then pause.
752 if (helper_.PeekWrite().result == ERR_IO_PENDING) {
753 NET_TRACE(1, " *** ") << "Pausing write at: " << sequence_number_;
754 write_state_ = IoState::kPaused;
755 if (run_until_paused_run_loop_)
756 run_until_paused_run_loop_->Quit();
757 return;
758 }
759
760 NET_TRACE(1, " ****** ") << "Posting task to complete write: "
761 << sequence_number_;
762 base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
763 FROM_HERE, base::BindOnce(&SequencedSocketData::OnWriteComplete,
764 weak_factory_.GetWeakPtr()));
765 CHECK_NE(IoState::kCompleting, read_state_);
766 write_state_ = IoState::kCompleting;
767 }
768
Reset()769 void SequencedSocketData::Reset() {
770 helper_.Reset();
771 sequence_number_ = 0;
772 read_state_ = IoState::kIdle;
773 write_state_ = IoState::kIdle;
774 weak_factory_.InvalidateWeakPtrs();
775 }
776
OnReadComplete()777 void SequencedSocketData::OnReadComplete() {
778 CHECK_EQ(IoState::kCompleting, read_state_);
779 NET_TRACE(1, " *** ") << "Completing read for: " << sequence_number_;
780
781 MockRead data = helper_.AdvanceRead();
782 DCHECK_EQ(sequence_number_, data.sequence_number);
783 sequence_number_++;
784 read_state_ = IoState::kIdle;
785
786 // The result of this read completing might trigger the completion
787 // of a pending write. If so, post a task to complete the write later.
788 // Since the socket may call back into the SequencedSocketData
789 // from socket()->OnReadComplete(), trigger the write task to be posted
790 // before calling that.
791 MaybePostWriteCompleteTask();
792
793 if (!socket()) {
794 NET_TRACE(1, " *** ") << "No socket available to complete read";
795 return;
796 }
797
798 NET_TRACE(1, " *** ") << "Completing socket read for: "
799 << data.sequence_number;
800 DumpMockReadWrite(data);
801 socket()->OnReadComplete(data);
802 NET_TRACE(1, " *** ") << "Done";
803 }
804
OnWriteComplete()805 void SequencedSocketData::OnWriteComplete() {
806 CHECK_EQ(IoState::kCompleting, write_state_);
807 NET_TRACE(1, " *** ") << " Completing write for: " << sequence_number_;
808
809 const MockWrite& data = helper_.AdvanceWrite();
810 DCHECK_EQ(sequence_number_, data.sequence_number);
811 sequence_number_++;
812 write_state_ = IoState::kIdle;
813 int rv = data.result == OK ? data.data_len : data.result;
814
815 // The result of this write completing might trigger the completion
816 // of a pending read. If so, post a task to complete the read later.
817 // Since the socket may call back into the SequencedSocketData
818 // from socket()->OnWriteComplete(), trigger the write task to be posted
819 // before calling that.
820 MaybePostReadCompleteTask();
821
822 if (!socket()) {
823 NET_TRACE(1, " *** ") << "No socket available to complete write";
824 return;
825 }
826
827 NET_TRACE(1, " *** ") << " Completing socket write for: "
828 << data.sequence_number;
829 socket()->OnWriteComplete(rv);
830 NET_TRACE(1, " *** ") << "Done";
831 }
832
833 SequencedSocketData::~SequencedSocketData() = default;
834
835 MockClientSocketFactory::MockClientSocketFactory() = default;
836
837 MockClientSocketFactory::~MockClientSocketFactory() = default;
838
AddSocketDataProvider(SocketDataProvider * data)839 void MockClientSocketFactory::AddSocketDataProvider(SocketDataProvider* data) {
840 mock_data_.Add(data);
841 }
842
AddTcpSocketDataProvider(SocketDataProvider * data)843 void MockClientSocketFactory::AddTcpSocketDataProvider(
844 SocketDataProvider* data) {
845 mock_tcp_data_.Add(data);
846 }
847
AddSSLSocketDataProvider(SSLSocketDataProvider * data)848 void MockClientSocketFactory::AddSSLSocketDataProvider(
849 SSLSocketDataProvider* data) {
850 mock_ssl_data_.Add(data);
851 }
852
ResetNextMockIndexes()853 void MockClientSocketFactory::ResetNextMockIndexes() {
854 mock_data_.ResetNextIndex();
855 mock_ssl_data_.ResetNextIndex();
856 }
857
858 std::unique_ptr<DatagramClientSocket>
CreateDatagramClientSocket(DatagramSocket::BindType bind_type,NetLog * net_log,const NetLogSource & source)859 MockClientSocketFactory::CreateDatagramClientSocket(
860 DatagramSocket::BindType bind_type,
861 NetLog* net_log,
862 const NetLogSource& source) {
863 NET_TRACE(1, " *** ") << "mock_data_index: " << mock_data_.next_index();
864 SocketDataProvider* data_provider = mock_data_.GetNext();
865 auto socket = std::make_unique<MockUDPClientSocket>(data_provider, net_log);
866 if (bind_type == DatagramSocket::RANDOM_BIND)
867 socket->set_source_port(static_cast<uint16_t>(base::RandInt(1025, 65535)));
868 udp_client_socket_ports_.push_back(socket->source_port());
869 return std::move(socket);
870 }
871
872 std::unique_ptr<TransportClientSocket>
CreateTransportClientSocket(const AddressList & addresses,std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,NetworkQualityEstimator * network_quality_estimator,NetLog * net_log,const NetLogSource & source)873 MockClientSocketFactory::CreateTransportClientSocket(
874 const AddressList& addresses,
875 std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,
876 NetworkQualityEstimator* network_quality_estimator,
877 NetLog* net_log,
878 const NetLogSource& source) {
879 SocketDataProvider* data_provider = mock_tcp_data_.GetNextWithoutAsserting();
880 if (data_provider) {
881 NET_TRACE(1, " *** ") << "mock_tcp_data_index: "
882 << (mock_tcp_data_.next_index() - 1);
883 } else {
884 NET_TRACE(1, " *** ") << "mock_data_index: " << mock_data_.next_index();
885 data_provider = mock_data_.GetNext();
886 }
887 auto socket =
888 std::make_unique<MockTCPClientSocket>(addresses, net_log, data_provider);
889 if (enable_read_if_ready_)
890 socket->set_enable_read_if_ready(enable_read_if_ready_);
891 return std::move(socket);
892 }
893
// Creates a MockSSLClientSocket that wraps |stream_socket| and is driven by
// the next queued SSLSocketDataProvider. Before returning, it uses gtest
// EXPECT_* checks to verify that the configuration requested by the code
// under test matches the expectations recorded on that provider.
std::unique_ptr<SSLClientSocket> MockClientSocketFactory::CreateSSLClientSocket(
    SSLClientContext* context,
    std::unique_ptr<StreamSocket> stream_socket,
    const HostPortPair& host_and_port,
    const SSLConfig& ssl_config) {
  NET_TRACE(1, " *** ") << "mock_ssl_data_index: "
                        << mock_ssl_data_.next_index();
  SSLSocketDataProvider* next_ssl_data = mock_ssl_data_.GetNext();
  // Each optional expectation below is only checked when the test set it.
  if (next_ssl_data->next_protos_expected_in_ssl_config.has_value()) {
    EXPECT_TRUE(base::ranges::equal(
        next_ssl_data->next_protos_expected_in_ssl_config.value(),
        ssl_config.alpn_protos));
  }
  if (next_ssl_data->expected_application_settings) {
    EXPECT_EQ(*next_ssl_data->expected_application_settings,
              ssl_config.application_settings);
  }

  // The protocol version used is a combination of the per-socket SSLConfig and
  // the SSLConfigService.
  // These two checks are unconditional: a per-socket override wins, otherwise
  // the context-wide value is compared.
  EXPECT_EQ(
      next_ssl_data->expected_ssl_version_min,
      ssl_config.version_min_override.value_or(context->config().version_min));
  EXPECT_EQ(
      next_ssl_data->expected_ssl_version_max,
      ssl_config.version_max_override.value_or(context->config().version_max));

  if (next_ssl_data->expected_early_data_enabled) {
    EXPECT_EQ(*next_ssl_data->expected_early_data_enabled,
              ssl_config.early_data_enabled);
  }

  if (next_ssl_data->expected_send_client_cert) {
    // Client certificate preferences come from |context|.
    scoped_refptr<X509Certificate> client_cert;
    scoped_refptr<SSLPrivateKey> client_private_key;
    bool send_client_cert = context->GetClientCertificate(
        host_and_port, &client_cert, &client_private_key);

    EXPECT_EQ(*next_ssl_data->expected_send_client_cert, send_client_cert);
    // Note |send_client_cert| may be true while |client_cert| is null if the
    // socket is configured to continue without a certificate, as opposed to
    // surfacing the certificate challenge.
    EXPECT_EQ(!!next_ssl_data->expected_client_cert, !!client_cert);
    // Compare the full chains only when both sides actually have a cert.
    if (next_ssl_data->expected_client_cert && client_cert) {
      EXPECT_TRUE(next_ssl_data->expected_client_cert->EqualsIncludingChain(
          client_cert.get()));
    }
  }
  if (next_ssl_data->expected_host_and_port) {
    EXPECT_EQ(*next_ssl_data->expected_host_and_port, host_and_port);
  }
  if (next_ssl_data->expected_ignore_certificate_errors) {
    EXPECT_EQ(*next_ssl_data->expected_ignore_certificate_errors,
              ssl_config.ignore_certificate_errors);
  }
  if (next_ssl_data->expected_network_anonymization_key) {
    EXPECT_EQ(*next_ssl_data->expected_network_anonymization_key,
              ssl_config.network_anonymization_key);
  }
  if (next_ssl_data->expected_ech_config_list) {
    EXPECT_EQ(*next_ssl_data->expected_ech_config_list,
              ssl_config.ech_config_list);
  }
  // The provider is handed to the socket, which consults it for the
  // handshake result and connection metadata.
  return std::make_unique<MockSSLClientSocket>(
      std::move(stream_socket), host_and_port, ssl_config, next_ssl_data);
}
961
MockClientSocket(const NetLogWithSource & net_log)962 MockClientSocket::MockClientSocket(const NetLogWithSource& net_log)
963 : net_log_(net_log) {
964 local_addr_ = IPEndPoint(IPAddress(192, 0, 2, 33), 123);
965 peer_addr_ = IPEndPoint(IPAddress(192, 0, 2, 33), 0);
966 }
967
SetReceiveBufferSize(int32_t size)968 int MockClientSocket::SetReceiveBufferSize(int32_t size) {
969 return OK;
970 }
971
SetSendBufferSize(int32_t size)972 int MockClientSocket::SetSendBufferSize(int32_t size) {
973 return OK;
974 }
975
Bind(const net::IPEndPoint & local_addr)976 int MockClientSocket::Bind(const net::IPEndPoint& local_addr) {
977 local_addr_ = local_addr;
978 return net::OK;
979 }
980
SetNoDelay(bool no_delay)981 bool MockClientSocket::SetNoDelay(bool no_delay) {
982 return true;
983 }
984
SetKeepAlive(bool enable,int delay)985 bool MockClientSocket::SetKeepAlive(bool enable, int delay) {
986 return true;
987 }
988
Disconnect()989 void MockClientSocket::Disconnect() {
990 connected_ = false;
991 }
992
IsConnected() const993 bool MockClientSocket::IsConnected() const {
994 return connected_;
995 }
996
IsConnectedAndIdle() const997 bool MockClientSocket::IsConnectedAndIdle() const {
998 return connected_;
999 }
1000
GetPeerAddress(IPEndPoint * address) const1001 int MockClientSocket::GetPeerAddress(IPEndPoint* address) const {
1002 if (!IsConnected())
1003 return ERR_SOCKET_NOT_CONNECTED;
1004 *address = peer_addr_;
1005 return OK;
1006 }
1007
GetLocalAddress(IPEndPoint * address) const1008 int MockClientSocket::GetLocalAddress(IPEndPoint* address) const {
1009 *address = local_addr_;
1010 return OK;
1011 }
1012
NetLog() const1013 const NetLogWithSource& MockClientSocket::NetLog() const {
1014 return net_log_;
1015 }
1016
GetNegotiatedProtocol() const1017 NextProto MockClientSocket::GetNegotiatedProtocol() const {
1018 return kProtoUnknown;
1019 }
1020
1021 MockClientSocket::~MockClientSocket() = default;
1022
RunCallbackAsync(CompletionOnceCallback callback,int result)1023 void MockClientSocket::RunCallbackAsync(CompletionOnceCallback callback,
1024 int result) {
1025 base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
1026 FROM_HERE,
1027 base::BindOnce(&MockClientSocket::RunCallback, weak_factory_.GetWeakPtr(),
1028 std::move(callback), result));
1029 }
1030
RunCallback(CompletionOnceCallback callback,int result)1031 void MockClientSocket::RunCallback(CompletionOnceCallback callback,
1032 int result) {
1033 std::move(callback).Run(result);
1034 }
1035
MockTCPClientSocket(const AddressList & addresses,net::NetLog * net_log,SocketDataProvider * data)1036 MockTCPClientSocket::MockTCPClientSocket(const AddressList& addresses,
1037 net::NetLog* net_log,
1038 SocketDataProvider* data)
1039 : MockClientSocket(
1040 NetLogWithSource::Make(net_log, NetLogSourceType::SOCKET)),
1041 addresses_(addresses),
1042 data_(data),
1043 read_data_(SYNCHRONOUS, ERR_UNEXPECTED) {
1044 DCHECK(data_);
1045 peer_addr_ = data->connect_data().peer_addr;
1046 data_->Initialize(this);
1047 if (data_->expected_addresses()) {
1048 EXPECT_EQ(*data_->expected_addresses(), addresses);
1049 }
1050 }
1051
~MockTCPClientSocket()1052 MockTCPClientSocket::~MockTCPClientSocket() {
1053 if (data_)
1054 data_->DetachSocket();
1055 }
1056
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)1057 int MockTCPClientSocket::Read(IOBuffer* buf,
1058 int buf_len,
1059 CompletionOnceCallback callback) {
1060 // If the buffer is already in use, a read is already in progress!
1061 DCHECK(!pending_read_buf_);
1062 // Use base::Unretained() is safe because MockClientSocket::RunCallbackAsync()
1063 // takes a weak ptr of the base class, MockClientSocket.
1064 int rv = ReadIfReadyImpl(
1065 buf, buf_len,
1066 base::BindOnce(&MockTCPClientSocket::RetryRead, base::Unretained(this)));
1067 if (rv == ERR_IO_PENDING) {
1068 DCHECK(callback);
1069
1070 pending_read_buf_ = buf;
1071 pending_read_buf_len_ = buf_len;
1072 pending_read_callback_ = std::move(callback);
1073 }
1074 return rv;
1075 }
1076
ReadIfReady(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)1077 int MockTCPClientSocket::ReadIfReady(IOBuffer* buf,
1078 int buf_len,
1079 CompletionOnceCallback callback) {
1080 DCHECK(!pending_read_if_ready_callback_);
1081
1082 if (!enable_read_if_ready_)
1083 return ERR_READ_IF_READY_NOT_IMPLEMENTED;
1084 return ReadIfReadyImpl(buf, buf_len, std::move(callback));
1085 }
1086
CancelReadIfReady()1087 int MockTCPClientSocket::CancelReadIfReady() {
1088 DCHECK(pending_read_if_ready_callback_);
1089
1090 pending_read_if_ready_callback_.Reset();
1091 data_->CancelPendingRead();
1092 return OK;
1093 }
1094
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag &)1095 int MockTCPClientSocket::Write(
1096 IOBuffer* buf,
1097 int buf_len,
1098 CompletionOnceCallback callback,
1099 const NetworkTrafficAnnotationTag& /* traffic_annotation */) {
1100 DCHECK(buf);
1101 DCHECK_GT(buf_len, 0);
1102
1103 if (!connected_ || !data_)
1104 return ERR_UNEXPECTED;
1105
1106 std::string data(buf->data(), buf_len);
1107 MockWriteResult write_result = data_->OnWrite(data);
1108
1109 was_used_to_convey_data_ = true;
1110
1111 if (write_result.result == ERR_CONNECTION_CLOSED) {
1112 // This MockWrite is just a marker to instruct us to set
1113 // peer_closed_connection_.
1114 peer_closed_connection_ = true;
1115 }
1116 // ERR_IO_PENDING is a signal that the socket data will call back
1117 // asynchronously later.
1118 if (write_result.result == ERR_IO_PENDING) {
1119 pending_write_callback_ = std::move(callback);
1120 return ERR_IO_PENDING;
1121 }
1122
1123 if (write_result.mode == ASYNC) {
1124 RunCallbackAsync(std::move(callback), write_result.result);
1125 return ERR_IO_PENDING;
1126 }
1127
1128 return write_result.result;
1129 }
1130
SetReceiveBufferSize(int32_t size)1131 int MockTCPClientSocket::SetReceiveBufferSize(int32_t size) {
1132 if (!connected_)
1133 return net::ERR_UNEXPECTED;
1134 data_->set_receive_buffer_size(size);
1135 return data_->set_receive_buffer_size_result();
1136 }
1137
SetSendBufferSize(int32_t size)1138 int MockTCPClientSocket::SetSendBufferSize(int32_t size) {
1139 if (!connected_)
1140 return net::ERR_UNEXPECTED;
1141 data_->set_send_buffer_size(size);
1142 return data_->set_send_buffer_size_result();
1143 }
1144
SetNoDelay(bool no_delay)1145 bool MockTCPClientSocket::SetNoDelay(bool no_delay) {
1146 if (!connected_)
1147 return false;
1148 data_->set_no_delay(no_delay);
1149 return data_->set_no_delay_result();
1150 }
1151
SetKeepAlive(bool enable,int delay)1152 bool MockTCPClientSocket::SetKeepAlive(bool enable, int delay) {
1153 if (!connected_)
1154 return false;
1155 data_->set_keep_alive(enable, delay);
1156 return data_->set_keep_alive_result();
1157 }
1158
SetBeforeConnectCallback(const BeforeConnectCallback & before_connect_callback)1159 void MockTCPClientSocket::SetBeforeConnectCallback(
1160 const BeforeConnectCallback& before_connect_callback) {
1161 DCHECK(!before_connect_callback_);
1162 DCHECK(!connected_);
1163
1164 before_connect_callback_ = before_connect_callback;
1165 }
1166
// Simulates a connect as described by the provider's MockConnect. Returns:
// - ERR_UNEXPECTED if there is no provider.
// - OK if the socket is already connected.
// - A non-OK result from |before_connect_callback_|, if any.
// - Otherwise the MockConnect result: directly in SYNCHRONOUS mode, or
//   ERR_IO_PENDING with |callback| run later.
int MockTCPClientSocket::Connect(CompletionOnceCallback callback) {
  if (!data_)
    return ERR_UNEXPECTED;

  if (connected_)
    return OK;

  // Setting socket options fails if not connected, so need to set this before
  // calling |before_connect_callback_|.
  connected_ = true;

  // The hook runs once per simulated connection attempt. When the provider
  // requests a failing first attempt, that attempt's result is ignored and the
  // next address is tried. A non-OK result from a counted attempt aborts the
  // connect and clears |connected_| again.
  if (before_connect_callback_) {
    for (size_t index = 0; index < addresses_.size(); index++) {
      int result = before_connect_callback_.Run();
      if (data_->connect_data().first_attempt_fails && index == 0) {
        continue;
      }
      DCHECK_NE(result, ERR_IO_PENDING);
      if (result != net::OK) {
        connected_ = false;
        return result;
      }
      break;
    }
  }

  peer_closed_connection_ = false;

  // With a completer, |callback| is handed to it and the connect finishes
  // whenever the completer runs it. |connected_| is already true here.
  if (data_->connect_data().completer) {
    data_->connect_data().completer->SetCallback(std::move(callback));
    return ERR_IO_PENDING;
  }

  int result = data_->connect_data().result;
  IoMode mode = data_->connect_data().mode;
  // NOTE: |connected_| stays true on this path even if |result| is an error.
  if (mode == SYNCHRONOUS)
    return result;

  DCHECK(callback);

  // An ERR_IO_PENDING result leaves the connect to be finished later through
  // OnConnectComplete(); any other result is delivered by a posted task.
  if (result == ERR_IO_PENDING)
    pending_connect_callback_ = std::move(callback);
  else
    RunCallbackAsync(std::move(callback), result);
  return ERR_IO_PENDING;
}
1213
Disconnect()1214 void MockTCPClientSocket::Disconnect() {
1215 MockClientSocket::Disconnect();
1216 pending_connect_callback_.Reset();
1217 pending_read_callback_.Reset();
1218 }
1219
IsConnected() const1220 bool MockTCPClientSocket::IsConnected() const {
1221 if (!data_)
1222 return false;
1223 return connected_ && !peer_closed_connection_;
1224 }
1225
IsConnectedAndIdle() const1226 bool MockTCPClientSocket::IsConnectedAndIdle() const {
1227 if (!data_)
1228 return false;
1229 return IsConnected() && data_->IsIdle();
1230 }
1231
GetPeerAddress(IPEndPoint * address) const1232 int MockTCPClientSocket::GetPeerAddress(IPEndPoint* address) const {
1233 if (addresses_.empty())
1234 return MockClientSocket::GetPeerAddress(address);
1235
1236 if (data_->connect_data().first_attempt_fails) {
1237 DCHECK_GE(addresses_.size(), 2U);
1238 *address = addresses_[1];
1239 } else {
1240 *address = addresses_[0];
1241 }
1242 return OK;
1243 }
1244
WasEverUsed() const1245 bool MockTCPClientSocket::WasEverUsed() const {
1246 return was_used_to_convey_data_;
1247 }
1248
GetSSLInfo(SSLInfo * ssl_info)1249 bool MockTCPClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
1250 return false;
1251 }
1252
OnReadComplete(const MockRead & data)1253 void MockTCPClientSocket::OnReadComplete(const MockRead& data) {
1254 // If |data_| has been destroyed, safest to just do nothing.
1255 if (!data_)
1256 return;
1257
1258 // There must be a read pending.
1259 DCHECK(pending_read_if_ready_callback_);
1260 // You can't complete a read with another ERR_IO_PENDING status code.
1261 DCHECK_NE(ERR_IO_PENDING, data.result);
1262 // Since we've been waiting for data, need_read_data_ should be true.
1263 DCHECK(need_read_data_);
1264
1265 read_data_ = data;
1266 need_read_data_ = false;
1267
1268 // The caller is simulating that this IO completes right now. Don't
1269 // let CompleteRead() schedule a callback.
1270 read_data_.mode = SYNCHRONOUS;
1271 RunCallback(std::move(pending_read_if_ready_callback_),
1272 read_data_.result > 0 ? OK : read_data_.result);
1273 }
1274
OnWriteComplete(int rv)1275 void MockTCPClientSocket::OnWriteComplete(int rv) {
1276 // If |data_| has been destroyed, safest to just do nothing.
1277 if (!data_)
1278 return;
1279
1280 // There must be a read pending.
1281 DCHECK(!pending_write_callback_.is_null());
1282 RunCallback(std::move(pending_write_callback_), rv);
1283 }
1284
OnConnectComplete(const MockConnect & data)1285 void MockTCPClientSocket::OnConnectComplete(const MockConnect& data) {
1286 // If |data_| has been destroyed, safest to just do nothing.
1287 if (!data_)
1288 return;
1289
1290 RunCallback(std::move(pending_connect_callback_), data.result);
1291 }
1292
OnDataProviderDestroyed()1293 void MockTCPClientSocket::OnDataProviderDestroyed() {
1294 data_ = nullptr;
1295 }
1296
RetryRead(int rv)1297 void MockTCPClientSocket::RetryRead(int rv) {
1298 DCHECK(pending_read_callback_);
1299 DCHECK(pending_read_buf_.get());
1300 DCHECK_LT(0, pending_read_buf_len_);
1301
1302 if (rv == OK) {
1303 rv = ReadIfReadyImpl(pending_read_buf_.get(), pending_read_buf_len_,
1304 base::BindOnce(&MockTCPClientSocket::RetryRead,
1305 base::Unretained(this)));
1306 if (rv == ERR_IO_PENDING)
1307 return;
1308 }
1309 pending_read_buf_ = nullptr;
1310 pending_read_buf_len_ = 0;
1311 RunCallback(std::move(pending_read_callback_), rv);
1312 }
1313
// Core read routine shared by Read() (through RetryRead()) and ReadIfReady().
// When the previous MockRead has been fully consumed, it fetches the next one
// from the provider. It then does one of the following:
//  - Returns ERR_IO_PENDING and stores |callback|. This happens when the
//    provider will complete the read later (OnReadComplete()), or when the
//    MockRead is ASYNC (RunReadIfReadyCallback() is posted with its result).
//  - Copies as much of the MockRead's payload as fits into |buf| and returns
//    the number of bytes copied (0 for an empty payload).
//  - Returns the MockRead's result when it carries no payload.
// Returns ERR_UNEXPECTED when not connected or when the provider is gone.
int MockTCPClientSocket::ReadIfReadyImpl(IOBuffer* buf,
                                         int buf_len,
                                         CompletionOnceCallback callback) {
  if (!connected_ || !data_)
    return ERR_UNEXPECTED;

  DCHECK(!pending_read_if_ready_callback_);

  // Only ask the provider for a new MockRead once the previous payload has
  // been completely handed out (see the offset bookkeeping below).
  if (need_read_data_) {
    read_data_ = data_->OnRead();
    if (read_data_.result == ERR_CONNECTION_CLOSED) {
      // This MockRead is just a marker to instruct us to set
      // peer_closed_connection_.
      peer_closed_connection_ = true;
    }
    if (read_data_.result == ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ) {
      // This MockRead is just a marker to instruct us to set
      // peer_closed_connection_. Skip it and get the next one.
      read_data_ = data_->OnRead();
      peer_closed_connection_ = true;
    }
    // ERR_IO_PENDING means that the SocketDataProvider is taking responsibility
    // to complete the async IO manually later (via OnReadComplete).
    if (read_data_.result == ERR_IO_PENDING) {
      // We need to be using async IO in this case.
      DCHECK(!callback.is_null());
      pending_read_if_ready_callback_ = std::move(callback);
      return ERR_IO_PENDING;
    }
    need_read_data_ = false;
  }

  int result = read_data_.result;
  DCHECK_NE(ERR_IO_PENDING, result);
  // An ASYNC MockRead first reports readiness through a posted callback. The
  // mode is flipped to SYNCHRONOUS so that the follow-up call copies the data
  // directly.
  if (read_data_.mode == ASYNC) {
    DCHECK(!callback.is_null());
    read_data_.mode = SYNCHRONOUS;
    pending_read_if_ready_callback_ = std::move(callback);
    // base::Unretained() is safe here because RunCallbackAsync will wrap it
    // with a callback associated with a weak ptr.
    RunCallbackAsync(
        base::BindOnce(&MockTCPClientSocket::RunReadIfReadyCallback,
                       base::Unretained(this)),
        result);
    return ERR_IO_PENDING;
  }

  was_used_to_convey_data_ = true;
  // A payload-carrying MockRead may span several reads. |read_offset_| tracks
  // how much has been delivered. Once it reaches the end, the next call
  // fetches a fresh MockRead.
  if (read_data_.data) {
    if (read_data_.data_len - read_offset_ > 0) {
      result = std::min(buf_len, read_data_.data_len - read_offset_);
      memcpy(buf->data(), read_data_.data + read_offset_, result);
      read_offset_ += result;
      if (read_offset_ == read_data_.data_len) {
        need_read_data_ = true;
        read_offset_ = 0;
      }
    } else {
      result = 0;  // EOF
    }
  }
  return result;
}
1377
RunReadIfReadyCallback(int result)1378 void MockTCPClientSocket::RunReadIfReadyCallback(int result) {
1379 // If ReadIfReady is already canceled, do nothing.
1380 if (!pending_read_if_ready_callback_)
1381 return;
1382 std::move(pending_read_if_ready_callback_).Run(result);
1383 }
1384
1385 // static
ConnectCallback(MockSSLClientSocket * ssl_client_socket,CompletionOnceCallback callback,int rv)1386 void MockSSLClientSocket::ConnectCallback(
1387 MockSSLClientSocket* ssl_client_socket,
1388 CompletionOnceCallback callback,
1389 int rv) {
1390 if (rv == OK)
1391 ssl_client_socket->connected_ = true;
1392 std::move(callback).Run(rv);
1393 }
1394
MockSSLClientSocket(std::unique_ptr<StreamSocket> stream_socket,const HostPortPair & host_and_port,const SSLConfig & ssl_config,SSLSocketDataProvider * data)1395 MockSSLClientSocket::MockSSLClientSocket(
1396 std::unique_ptr<StreamSocket> stream_socket,
1397 const HostPortPair& host_and_port,
1398 const SSLConfig& ssl_config,
1399 SSLSocketDataProvider* data)
1400 : net_log_(stream_socket->NetLog()),
1401 stream_socket_(std::move(stream_socket)),
1402 data_(data) {
1403 DCHECK(data_);
1404 peer_addr_ = data->connect.peer_addr;
1405 }
1406
~MockSSLClientSocket()1407 MockSSLClientSocket::~MockSSLClientSocket() {
1408 Disconnect();
1409 }
1410
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)1411 int MockSSLClientSocket::Read(IOBuffer* buf,
1412 int buf_len,
1413 CompletionOnceCallback callback) {
1414 return stream_socket_->Read(buf, buf_len, std::move(callback));
1415 }
1416
ReadIfReady(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)1417 int MockSSLClientSocket::ReadIfReady(IOBuffer* buf,
1418 int buf_len,
1419 CompletionOnceCallback callback) {
1420 return stream_socket_->ReadIfReady(buf, buf_len, std::move(callback));
1421 }
1422
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)1423 int MockSSLClientSocket::Write(
1424 IOBuffer* buf,
1425 int buf_len,
1426 CompletionOnceCallback callback,
1427 const NetworkTrafficAnnotationTag& traffic_annotation) {
1428 if (!data_->is_confirm_data_consumed)
1429 data_->write_called_before_confirm = true;
1430 return stream_socket_->Write(buf, buf_len, std::move(callback),
1431 traffic_annotation);
1432 }
1433
CancelReadIfReady()1434 int MockSSLClientSocket::CancelReadIfReady() {
1435 return stream_socket_->CancelReadIfReady();
1436 }
1437
// Simulates the TLS handshake described by the provider's `connect` entry.
// The wrapped transport socket must already be connected.
int MockSSLClientSocket::Connect(CompletionOnceCallback callback) {
  DCHECK(stream_socket_->IsConnected());
  // Lets the test see that the provider's connect data was used.
  data_->is_connect_data_consumed = true;
  // With a completer, |callback| is handed over and the handshake finishes
  // whenever the completer runs it. |connected_| is not set on this path.
  if (data_->connect.completer) {
    data_->connect.completer->SetCallback(std::move(callback));
    return ERR_IO_PENDING;
  }
  if (data_->connect.result == OK)
    connected_ = true;
  // Run the provider's optional hook synchronously, before any result is
  // reported.
  RunClosureIfNonNull(std::move(data_->connect_callback));
  if (data_->connect.mode == ASYNC) {
    RunCallbackAsync(std::move(callback), data_->connect.result);
    return ERR_IO_PENDING;
  }
  return data_->connect.result;
}
1454
Disconnect()1455 void MockSSLClientSocket::Disconnect() {
1456 if (stream_socket_ != nullptr)
1457 stream_socket_->Disconnect();
1458 }
1459
RunConfirmHandshakeCallback(CompletionOnceCallback callback,int result)1460 void MockSSLClientSocket::RunConfirmHandshakeCallback(
1461 CompletionOnceCallback callback,
1462 int result) {
1463 DCHECK(in_confirm_handshake_);
1464 in_confirm_handshake_ = false;
1465 data_->is_confirm_data_consumed = true;
1466 std::move(callback).Run(result);
1467 }
1468
// Simulates ConfirmHandshake() using the provider's `confirm` entry.
// Once that entry has been consumed, later calls return its result again
// without running hooks or callbacks.
int MockSSLClientSocket::ConfirmHandshake(CompletionOnceCallback callback) {
  DCHECK(stream_socket_->IsConnected());
  DCHECK(!in_confirm_handshake_);
  if (data_->is_confirm_data_consumed)
    return data_->confirm.result;
  // The provider's optional hook runs before the result is reported.
  RunClosureIfNonNull(std::move(data_->confirm_callback));
  if (data_->confirm.mode == ASYNC) {
    // RunConfirmHandshakeCallback() clears |in_confirm_handshake_| and marks
    // the confirm data consumed when the posted task runs.
    in_confirm_handshake_ = true;
    RunCallbackAsync(
        base::BindOnce(&MockSSLClientSocket::RunConfirmHandshakeCallback,
                       base::Unretained(this), std::move(callback)),
        data_->confirm.result);
    return ERR_IO_PENDING;
  }
  data_->is_confirm_data_consumed = true;
  if (data_->confirm.result == ERR_IO_PENDING) {
    // `MockConfirm(SYNCHRONOUS, ERR_IO_PENDING)` means `ConfirmHandshake()`
    // never completes.
    in_confirm_handshake_ = true;
  }
  return data_->confirm.result;
}
1491
IsConnected() const1492 bool MockSSLClientSocket::IsConnected() const {
1493 return stream_socket_->IsConnected();
1494 }
1495
IsConnectedAndIdle() const1496 bool MockSSLClientSocket::IsConnectedAndIdle() const {
1497 return stream_socket_->IsConnectedAndIdle();
1498 }
1499
WasEverUsed() const1500 bool MockSSLClientSocket::WasEverUsed() const {
1501 return stream_socket_->WasEverUsed();
1502 }
1503
GetLocalAddress(IPEndPoint * address) const1504 int MockSSLClientSocket::GetLocalAddress(IPEndPoint* address) const {
1505 *address = IPEndPoint(IPAddress(192, 0, 2, 33), 123);
1506 return OK;
1507 }
1508
GetPeerAddress(IPEndPoint * address) const1509 int MockSSLClientSocket::GetPeerAddress(IPEndPoint* address) const {
1510 return stream_socket_->GetPeerAddress(address);
1511 }
1512
GetNegotiatedProtocol() const1513 NextProto MockSSLClientSocket::GetNegotiatedProtocol() const {
1514 return data_->next_proto;
1515 }
1516
1517 std::optional<std::string_view>
GetPeerApplicationSettings() const1518 MockSSLClientSocket::GetPeerApplicationSettings() const {
1519 return data_->peer_application_settings;
1520 }
1521
GetSSLInfo(SSLInfo * requested_ssl_info)1522 bool MockSSLClientSocket::GetSSLInfo(SSLInfo* requested_ssl_info) {
1523 *requested_ssl_info = data_->ssl_info;
1524 return true;
1525 }
1526
ApplySocketTag(const SocketTag & tag)1527 void MockSSLClientSocket::ApplySocketTag(const SocketTag& tag) {
1528 return stream_socket_->ApplySocketTag(tag);
1529 }
1530
NetLog() const1531 const NetLogWithSource& MockSSLClientSocket::NetLog() const {
1532 return net_log_;
1533 }
1534
GetTotalReceivedBytes() const1535 int64_t MockSSLClientSocket::GetTotalReceivedBytes() const {
1536 NOTIMPLEMENTED();
1537 return 0;
1538 }
1539
GetTotalReceivedBytes() const1540 int64_t MockClientSocket::GetTotalReceivedBytes() const {
1541 NOTIMPLEMENTED();
1542 return 0;
1543 }
1544
SetReceiveBufferSize(int32_t size)1545 int MockSSLClientSocket::SetReceiveBufferSize(int32_t size) {
1546 return OK;
1547 }
1548
SetSendBufferSize(int32_t size)1549 int MockSSLClientSocket::SetSendBufferSize(int32_t size) {
1550 return OK;
1551 }
1552
GetSSLCertRequestInfo(SSLCertRequestInfo * cert_request_info) const1553 void MockSSLClientSocket::GetSSLCertRequestInfo(
1554 SSLCertRequestInfo* cert_request_info) const {
1555 DCHECK(cert_request_info);
1556 if (data_->cert_request_info) {
1557 cert_request_info->host_and_port = data_->cert_request_info->host_and_port;
1558 cert_request_info->is_proxy = data_->cert_request_info->is_proxy;
1559 cert_request_info->cert_authorities =
1560 data_->cert_request_info->cert_authorities;
1561 cert_request_info->signature_algorithms =
1562 data_->cert_request_info->signature_algorithms;
1563 } else {
1564 cert_request_info->Reset();
1565 }
1566 }
1567
ExportKeyingMaterial(std::string_view label,std::optional<base::span<const uint8_t>> context,base::span<uint8_t> out)1568 int MockSSLClientSocket::ExportKeyingMaterial(
1569 std::string_view label,
1570 std::optional<base::span<const uint8_t>> context,
1571 base::span<uint8_t> out) {
1572 std::ranges::fill(out, 'A');
1573 return OK;
1574 }
1575
GetECHRetryConfigs()1576 std::vector<uint8_t> MockSSLClientSocket::GetECHRetryConfigs() {
1577 return data_->ech_retry_configs;
1578 }
1579
RunCallbackAsync(CompletionOnceCallback callback,int result)1580 void MockSSLClientSocket::RunCallbackAsync(CompletionOnceCallback callback,
1581 int result) {
1582 base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
1583 FROM_HERE,
1584 base::BindOnce(&MockSSLClientSocket::RunCallback,
1585 weak_factory_.GetWeakPtr(), std::move(callback), result));
1586 }
1587
RunCallback(CompletionOnceCallback callback,int result)1588 void MockSSLClientSocket::RunCallback(CompletionOnceCallback callback,
1589 int result) {
1590 std::move(callback).Run(result);
1591 }
1592
OnReadComplete(const MockRead & data)1593 void MockSSLClientSocket::OnReadComplete(const MockRead& data) {
1594 NOTIMPLEMENTED();
1595 }
1596
OnWriteComplete(int rv)1597 void MockSSLClientSocket::OnWriteComplete(int rv) {
1598 NOTIMPLEMENTED();
1599 }
1600
OnConnectComplete(const MockConnect & data)1601 void MockSSLClientSocket::OnConnectComplete(const MockConnect& data) {
1602 NOTIMPLEMENTED();
1603 }
1604
MockUDPClientSocket(SocketDataProvider * data,net::NetLog * net_log)1605 MockUDPClientSocket::MockUDPClientSocket(SocketDataProvider* data,
1606 net::NetLog* net_log)
1607 : data_(data),
1608 read_data_(SYNCHRONOUS, ERR_UNEXPECTED),
1609 source_host_(IPAddress(192, 0, 2, 33)),
1610 net_log_(NetLogWithSource::Make(net_log,
1611 NetLogSourceType::UDP_CLIENT_SOCKET)) {
1612 if (data_) {
1613 data_->Initialize(this);
1614 peer_addr_ = data->connect_data().peer_addr;
1615 }
1616 }
1617
~MockUDPClientSocket()1618 MockUDPClientSocket::~MockUDPClientSocket() {
1619 if (data_)
1620 data_->DetachSocket();
1621 }
1622
// Reads the next datagram scripted by the provider into |buf|.
// The buffer and |callback| are stashed before the provider is consulted.
// They remain in place when the provider returns ERR_IO_PENDING and completes
// the read later. Otherwise CompleteRead() (defined elsewhere) produces the
// result, presumably from the stashed buffer — confirm there.
int MockUDPClientSocket::Read(IOBuffer* buf,
                              int buf_len,
                              CompletionOnceCallback callback) {
  DCHECK(callback);

  if (!connected_ || !data_)
    return ERR_UNEXPECTED;
  data_transferred_ = true;

  // If the buffer is already in use, a read is already in progress!
  DCHECK(!pending_read_buf_);

  // Store our async IO data.
  pending_read_buf_ = buf;
  pending_read_buf_len_ = buf_len;
  pending_read_callback_ = std::move(callback);

  if (need_read_data_) {
    read_data_ = data_->OnRead();
    // Record the TOS value scripted on this MockRead.
    last_tos_ = read_data_.tos;
    // ERR_IO_PENDING means that the SocketDataProvider is taking responsibility
    // to complete the async IO manually later (via OnReadComplete).
    if (read_data_.result == ERR_IO_PENDING) {
      // We need to be using async IO in this case.
      DCHECK(!pending_read_callback_.is_null());
      return ERR_IO_PENDING;
    }
    need_read_data_ = false;
  }

  return CompleteRead();
}
1655
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag &)1656 int MockUDPClientSocket::Write(
1657 IOBuffer* buf,
1658 int buf_len,
1659 CompletionOnceCallback callback,
1660 const NetworkTrafficAnnotationTag& /* traffic_annotation */) {
1661 DCHECK(buf);
1662 DCHECK_GT(buf_len, 0);
1663 DCHECK(callback);
1664
1665 if (!connected_ || !data_)
1666 return ERR_UNEXPECTED;
1667 data_transferred_ = true;
1668
1669 std::string data(buf->data(), buf_len);
1670 MockWriteResult write_result = data_->OnWrite(data);
1671
1672 // ERR_IO_PENDING is a signal that the socket data will call back
1673 // asynchronously.
1674 if (write_result.result == ERR_IO_PENDING) {
1675 pending_write_callback_ = std::move(callback);
1676 return ERR_IO_PENDING;
1677 }
1678 if (write_result.mode == ASYNC) {
1679 RunCallbackAsync(std::move(callback), write_result.result);
1680 return ERR_IO_PENDING;
1681 }
1682 return write_result.result;
1683 }
1684
SetReceiveBufferSize(int32_t size)1685 int MockUDPClientSocket::SetReceiveBufferSize(int32_t size) {
1686 return OK;
1687 }
1688
SetSendBufferSize(int32_t size)1689 int MockUDPClientSocket::SetSendBufferSize(int32_t size) {
1690 return OK;
1691 }
1692
SetDoNotFragment()1693 int MockUDPClientSocket::SetDoNotFragment() {
1694 return OK;
1695 }
1696
SetRecvTos()1697 int MockUDPClientSocket::SetRecvTos() {
1698 return OK;
1699 }
1700
SetTos(DiffServCodePoint dscp,EcnCodePoint ecn)1701 int MockUDPClientSocket::SetTos(DiffServCodePoint dscp, EcnCodePoint ecn) {
1702 return OK;
1703 }
1704
Close()1705 void MockUDPClientSocket::Close() {
1706 connected_ = false;
1707 }
1708
GetPeerAddress(IPEndPoint * address) const1709 int MockUDPClientSocket::GetPeerAddress(IPEndPoint* address) const {
1710 if (!data_)
1711 return ERR_UNEXPECTED;
1712
1713 *address = peer_addr_;
1714 return OK;
1715 }
1716
GetLocalAddress(IPEndPoint * address) const1717 int MockUDPClientSocket::GetLocalAddress(IPEndPoint* address) const {
1718 *address = IPEndPoint(source_host_, source_port_);
1719 return OK;
1720 }
1721
UseNonBlockingIO()1722 void MockUDPClientSocket::UseNonBlockingIO() {}
1723
SetMulticastInterface(uint32_t interface_index)1724 int MockUDPClientSocket::SetMulticastInterface(uint32_t interface_index) {
1725 return OK;
1726 }
1727
NetLog() const1728 const NetLogWithSource& MockUDPClientSocket::NetLog() const {
1729 return net_log_;
1730 }
1731
Connect(const IPEndPoint & address)1732 int MockUDPClientSocket::Connect(const IPEndPoint& address) {
1733 if (!data_)
1734 return ERR_UNEXPECTED;
1735 DCHECK_NE(data_->connect_data().result, ERR_IO_PENDING);
1736 connected_ = true;
1737 peer_addr_ = address;
1738 return data_->connect_data().result;
1739 }
1740
ConnectUsingNetwork(handles::NetworkHandle network,const IPEndPoint & address)1741 int MockUDPClientSocket::ConnectUsingNetwork(handles::NetworkHandle network,
1742 const IPEndPoint& address) {
1743 DCHECK(!connected_);
1744 if (!data_)
1745 return ERR_UNEXPECTED;
1746 DCHECK_NE(data_->connect_data().result, ERR_IO_PENDING);
1747 network_ = network;
1748 connected_ = true;
1749 peer_addr_ = address;
1750 return data_->connect_data().result;
1751 }
1752
ConnectUsingDefaultNetwork(const IPEndPoint & address)1753 int MockUDPClientSocket::ConnectUsingDefaultNetwork(const IPEndPoint& address) {
1754 DCHECK(!connected_);
1755 if (!data_)
1756 return ERR_UNEXPECTED;
1757 DCHECK_NE(data_->connect_data().result, ERR_IO_PENDING);
1758 network_ = kDefaultNetworkForTests;
1759 connected_ = true;
1760 peer_addr_ = address;
1761 return data_->connect_data().result;
1762 }
1763
ConnectAsync(const IPEndPoint & address,CompletionOnceCallback callback)1764 int MockUDPClientSocket::ConnectAsync(const IPEndPoint& address,
1765 CompletionOnceCallback callback) {
1766 DCHECK(callback);
1767 if (!data_) {
1768 return ERR_UNEXPECTED;
1769 }
1770 connected_ = true;
1771 peer_addr_ = address;
1772 int result = data_->connect_data().result;
1773 IoMode mode = data_->connect_data().mode;
1774 if (data_->connect_data().completer) {
1775 data_->connect_data().completer->SetCallback(std::move(callback));
1776 return ERR_IO_PENDING;
1777 }
1778 if (mode == SYNCHRONOUS) {
1779 return result;
1780 }
1781 RunCallbackAsync(std::move(callback), result);
1782 return ERR_IO_PENDING;
1783 }
1784
ConnectUsingNetworkAsync(handles::NetworkHandle network,const IPEndPoint & address,CompletionOnceCallback callback)1785 int MockUDPClientSocket::ConnectUsingNetworkAsync(
1786 handles::NetworkHandle network,
1787 const IPEndPoint& address,
1788 CompletionOnceCallback callback) {
1789 DCHECK(callback);
1790 DCHECK(!connected_);
1791 if (!data_)
1792 return ERR_UNEXPECTED;
1793 network_ = network;
1794 connected_ = true;
1795 peer_addr_ = address;
1796 int result = data_->connect_data().result;
1797 IoMode mode = data_->connect_data().mode;
1798 if (data_->connect_data().completer) {
1799 data_->connect_data().completer->SetCallback(std::move(callback));
1800 return ERR_IO_PENDING;
1801 }
1802 if (mode == SYNCHRONOUS) {
1803 return result;
1804 }
1805 RunCallbackAsync(std::move(callback), result);
1806 return ERR_IO_PENDING;
1807 }
1808
ConnectUsingDefaultNetworkAsync(const IPEndPoint & address,CompletionOnceCallback callback)1809 int MockUDPClientSocket::ConnectUsingDefaultNetworkAsync(
1810 const IPEndPoint& address,
1811 CompletionOnceCallback callback) {
1812 DCHECK(!connected_);
1813 if (!data_)
1814 return ERR_UNEXPECTED;
1815 network_ = kDefaultNetworkForTests;
1816 connected_ = true;
1817 peer_addr_ = address;
1818 int result = data_->connect_data().result;
1819 IoMode mode = data_->connect_data().mode;
1820 if (data_->connect_data().completer) {
1821 data_->connect_data().completer->SetCallback(std::move(callback));
1822 return ERR_IO_PENDING;
1823 }
1824 if (mode == SYNCHRONOUS) {
1825 return result;
1826 }
1827 RunCallbackAsync(std::move(callback), result);
1828 return ERR_IO_PENDING;
1829 }
1830
GetBoundNetwork() const1831 handles::NetworkHandle MockUDPClientSocket::GetBoundNetwork() const {
1832 return network_;
1833 }
1834
ApplySocketTag(const SocketTag & tag)1835 void MockUDPClientSocket::ApplySocketTag(const SocketTag& tag) {
1836 tagged_before_data_transferred_ &= !data_transferred_ || tag == tag_;
1837 tag_ = tag;
1838 }
1839
GetLastTos() const1840 DscpAndEcn MockUDPClientSocket::GetLastTos() const {
1841 return TosToDscpAndEcn(last_tos_);
1842 }
1843
OnReadComplete(const MockRead & data)1844 void MockUDPClientSocket::OnReadComplete(const MockRead& data) {
1845 if (!data_)
1846 return;
1847
1848 // There must be a read pending.
1849 DCHECK(pending_read_buf_.get());
1850 DCHECK(pending_read_callback_);
1851 // You can't complete a read with another ERR_IO_PENDING status code.
1852 DCHECK_NE(ERR_IO_PENDING, data.result);
1853 // Since we've been waiting for data, need_read_data_ should be true.
1854 DCHECK(need_read_data_);
1855
1856 read_data_ = data;
1857 last_tos_ = data.tos;
1858 need_read_data_ = false;
1859
1860 // The caller is simulating that this IO completes right now. Don't
1861 // let CompleteRead() schedule a callback.
1862 read_data_.mode = SYNCHRONOUS;
1863
1864 CompletionOnceCallback callback = std::move(pending_read_callback_);
1865 int rv = CompleteRead();
1866 RunCallback(std::move(callback), rv);
1867 }
1868
OnWriteComplete(int rv)1869 void MockUDPClientSocket::OnWriteComplete(int rv) {
1870 if (!data_)
1871 return;
1872
1873 // There must be a read pending.
1874 DCHECK(!pending_write_callback_.is_null());
1875 RunCallback(std::move(pending_write_callback_), rv);
1876 }
1877
OnConnectComplete(const MockConnect & data)1878 void MockUDPClientSocket::OnConnectComplete(const MockConnect& data) {
1879 NOTIMPLEMENTED();
1880 }
1881
OnDataProviderDestroyed()1882 void MockUDPClientSocket::OnDataProviderDestroyed() {
1883 data_ = nullptr;
1884 }
1885
CompleteRead()1886 int MockUDPClientSocket::CompleteRead() {
1887 DCHECK(pending_read_buf_.get());
1888 DCHECK(pending_read_buf_len_ > 0);
1889
1890 // Save the pending async IO data and reset our |pending_| state.
1891 scoped_refptr<IOBuffer> buf = pending_read_buf_;
1892 int buf_len = pending_read_buf_len_;
1893 CompletionOnceCallback callback = std::move(pending_read_callback_);
1894 pending_read_buf_ = nullptr;
1895 pending_read_buf_len_ = 0;
1896
1897 int result = read_data_.result;
1898 DCHECK(result != ERR_IO_PENDING);
1899
1900 if (read_data_.data) {
1901 if (read_data_.data_len - read_offset_ > 0) {
1902 result = std::min(buf_len, read_data_.data_len - read_offset_);
1903 memcpy(buf->data(), read_data_.data + read_offset_, result);
1904 read_offset_ += result;
1905 if (read_offset_ == read_data_.data_len) {
1906 need_read_data_ = true;
1907 read_offset_ = 0;
1908 }
1909 } else {
1910 result = 0; // EOF
1911 }
1912 }
1913
1914 if (read_data_.mode == ASYNC) {
1915 DCHECK(!callback.is_null());
1916 RunCallbackAsync(std::move(callback), result);
1917 return ERR_IO_PENDING;
1918 }
1919 return result;
1920 }
1921
RunCallbackAsync(CompletionOnceCallback callback,int result)1922 void MockUDPClientSocket::RunCallbackAsync(CompletionOnceCallback callback,
1923 int result) {
1924 base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
1925 FROM_HERE,
1926 base::BindOnce(&MockUDPClientSocket::RunCallback,
1927 weak_factory_.GetWeakPtr(), std::move(callback), result));
1928 }
1929
RunCallback(CompletionOnceCallback callback,int result)1930 void MockUDPClientSocket::RunCallback(CompletionOnceCallback callback,
1931 int result) {
1932 std::move(callback).Run(result);
1933 }
1934
TestSocketRequest(std::vector<raw_ptr<TestSocketRequest,VectorExperimental>> * request_order,size_t * completion_count)1935 TestSocketRequest::TestSocketRequest(
1936 std::vector<raw_ptr<TestSocketRequest, VectorExperimental>>* request_order,
1937 size_t* completion_count)
1938 : request_order_(request_order), completion_count_(completion_count) {
1939 DCHECK(request_order);
1940 DCHECK(completion_count);
1941 }
1942
1943 TestSocketRequest::~TestSocketRequest() = default;
1944
OnComplete(int result)1945 void TestSocketRequest::OnComplete(int result) {
1946 SetResult(result);
1947 (*completion_count_)++;
1948 request_order_->push_back(this);
1949 }
1950
1951 // static
1952 const int ClientSocketPoolTest::kIndexOutOfBounds = -1;
1953
1954 // static
1955 const int ClientSocketPoolTest::kRequestNotFound = -2;
1956
1957 ClientSocketPoolTest::ClientSocketPoolTest() = default;
1958 ClientSocketPoolTest::~ClientSocketPoolTest() = default;
1959
GetOrderOfRequest(size_t index) const1960 int ClientSocketPoolTest::GetOrderOfRequest(size_t index) const {
1961 index--;
1962 if (index >= requests_.size())
1963 return kIndexOutOfBounds;
1964
1965 for (size_t i = 0; i < request_order_.size(); i++)
1966 if (requests_[index].get() == request_order_[i])
1967 return i + 1;
1968
1969 return kRequestNotFound;
1970 }
1971
ReleaseOneConnection(KeepAlive keep_alive)1972 bool ClientSocketPoolTest::ReleaseOneConnection(KeepAlive keep_alive) {
1973 for (std::unique_ptr<TestSocketRequest>& it : requests_) {
1974 if (it->handle()->is_initialized()) {
1975 if (keep_alive == NO_KEEP_ALIVE)
1976 it->handle()->socket()->Disconnect();
1977 it->handle()->Reset();
1978 base::RunLoop().RunUntilIdle();
1979 return true;
1980 }
1981 }
1982 return false;
1983 }
1984
ReleaseAllConnections(KeepAlive keep_alive)1985 void ClientSocketPoolTest::ReleaseAllConnections(KeepAlive keep_alive) {
1986 bool released_one;
1987 do {
1988 released_one = ReleaseOneConnection(keep_alive);
1989 } while (released_one);
1990 }
1991
MockConnectJob(std::unique_ptr<StreamSocket> socket,ClientSocketHandle * handle,const SocketTag & socket_tag,CompletionOnceCallback callback,RequestPriority priority)1992 MockTransportClientSocketPool::MockConnectJob::MockConnectJob(
1993 std::unique_ptr<StreamSocket> socket,
1994 ClientSocketHandle* handle,
1995 const SocketTag& socket_tag,
1996 CompletionOnceCallback callback,
1997 RequestPriority priority)
1998 : socket_(std::move(socket)),
1999 handle_(handle),
2000 socket_tag_(socket_tag),
2001 user_callback_(std::move(callback)),
2002 priority_(priority) {}
2003
2004 MockTransportClientSocketPool::MockConnectJob::~MockConnectJob() = default;
2005
Connect()2006 int MockTransportClientSocketPool::MockConnectJob::Connect() {
2007 socket_->ApplySocketTag(socket_tag_);
2008 int rv = socket_->Connect(
2009 base::BindOnce(&MockConnectJob::OnConnect, base::Unretained(this)));
2010 if (rv != ERR_IO_PENDING) {
2011 user_callback_.Reset();
2012 OnConnect(rv);
2013 }
2014 return rv;
2015 }
2016
CancelHandle(const ClientSocketHandle * handle)2017 bool MockTransportClientSocketPool::MockConnectJob::CancelHandle(
2018 const ClientSocketHandle* handle) {
2019 if (handle != handle_)
2020 return false;
2021 socket_.reset();
2022 handle_ = nullptr;
2023 user_callback_.Reset();
2024 return true;
2025 }
2026
OnConnect(int rv)2027 void MockTransportClientSocketPool::MockConnectJob::OnConnect(int rv) {
2028 if (!socket_.get())
2029 return;
2030 if (rv == OK) {
2031 handle_->SetSocket(std::move(socket_));
2032
2033 // Needed for socket pool tests that layer other sockets on top of mock
2034 // sockets.
2035 LoadTimingInfo::ConnectTiming connect_timing;
2036 base::TimeTicks now = base::TimeTicks::Now();
2037 connect_timing.domain_lookup_start = now;
2038 connect_timing.domain_lookup_end = now;
2039 connect_timing.connect_start = now;
2040 connect_timing.connect_end = now;
2041 handle_->set_connect_timing(connect_timing);
2042 } else {
2043 socket_.reset();
2044
2045 // Needed to test copying of ConnectionAttempts in SSL ConnectJob.
2046 ConnectionAttempts attempts;
2047 attempts.push_back(ConnectionAttempt(IPEndPoint(), rv));
2048 handle_->set_connection_attempts(attempts);
2049 }
2050
2051 handle_ = nullptr;
2052
2053 if (!user_callback_.is_null()) {
2054 std::move(user_callback_).Run(rv);
2055 }
2056 }
2057
MockTransportClientSocketPool(int max_sockets,int max_sockets_per_group,const CommonConnectJobParams * common_connect_job_params)2058 MockTransportClientSocketPool::MockTransportClientSocketPool(
2059 int max_sockets,
2060 int max_sockets_per_group,
2061 const CommonConnectJobParams* common_connect_job_params)
2062 : TransportClientSocketPool(
2063 max_sockets,
2064 max_sockets_per_group,
2065 base::Seconds(10) /* unused_idle_socket_timeout */,
2066 ProxyChain::Direct(),
2067 false /* is_for_websockets */,
2068 common_connect_job_params),
2069 client_socket_factory_(common_connect_job_params->client_socket_factory) {
2070 }
2071
2072 MockTransportClientSocketPool::~MockTransportClientSocketPool() = default;
2073
RequestSocket(const ClientSocketPool::GroupId & group_id,scoped_refptr<ClientSocketPool::SocketParams> socket_params,const std::optional<NetworkTrafficAnnotationTag> & proxy_annotation_tag,RequestPriority priority,const SocketTag & socket_tag,RespectLimits respect_limits,ClientSocketHandle * handle,CompletionOnceCallback callback,const ProxyAuthCallback & on_auth_callback,const NetLogWithSource & net_log)2074 int MockTransportClientSocketPool::RequestSocket(
2075 const ClientSocketPool::GroupId& group_id,
2076 scoped_refptr<ClientSocketPool::SocketParams> socket_params,
2077 const std::optional<NetworkTrafficAnnotationTag>& proxy_annotation_tag,
2078 RequestPriority priority,
2079 const SocketTag& socket_tag,
2080 RespectLimits respect_limits,
2081 ClientSocketHandle* handle,
2082 CompletionOnceCallback callback,
2083 const ProxyAuthCallback& on_auth_callback,
2084 const NetLogWithSource& net_log) {
2085 last_request_priority_ = priority;
2086 std::unique_ptr<StreamSocket> socket =
2087 client_socket_factory_->CreateTransportClientSocket(
2088 AddressList(), nullptr, nullptr, net_log.net_log(), NetLogSource());
2089 auto job = std::make_unique<MockConnectJob>(
2090 std::move(socket), handle, socket_tag, std::move(callback), priority);
2091 auto* job_ptr = job.get();
2092 job_list_.push_back(std::move(job));
2093 handle->set_group_generation(1);
2094 return job_ptr->Connect();
2095 }
2096
SetPriority(const ClientSocketPool::GroupId & group_id,ClientSocketHandle * handle,RequestPriority priority)2097 void MockTransportClientSocketPool::SetPriority(
2098 const ClientSocketPool::GroupId& group_id,
2099 ClientSocketHandle* handle,
2100 RequestPriority priority) {
2101 for (auto& job : job_list_) {
2102 if (job->handle() == handle) {
2103 job->set_priority(priority);
2104 return;
2105 }
2106 }
2107 NOTREACHED();
2108 }
2109
CancelRequest(const ClientSocketPool::GroupId & group_id,ClientSocketHandle * handle,bool cancel_connect_job)2110 void MockTransportClientSocketPool::CancelRequest(
2111 const ClientSocketPool::GroupId& group_id,
2112 ClientSocketHandle* handle,
2113 bool cancel_connect_job) {
2114 for (std::unique_ptr<MockConnectJob>& it : job_list_) {
2115 if (it->CancelHandle(handle)) {
2116 cancel_count_++;
2117 break;
2118 }
2119 }
2120 }
2121
ReleaseSocket(const ClientSocketPool::GroupId & group_id,std::unique_ptr<StreamSocket> socket,int64_t generation)2122 void MockTransportClientSocketPool::ReleaseSocket(
2123 const ClientSocketPool::GroupId& group_id,
2124 std::unique_ptr<StreamSocket> socket,
2125 int64_t generation) {
2126 EXPECT_EQ(1, generation);
2127 release_count_++;
2128 }
2129
WrappedStreamSocket(std::unique_ptr<StreamSocket> transport)2130 WrappedStreamSocket::WrappedStreamSocket(
2131 std::unique_ptr<StreamSocket> transport)
2132 : transport_(std::move(transport)) {}
2133 WrappedStreamSocket::~WrappedStreamSocket() = default;
2134
Bind(const net::IPEndPoint & local_addr)2135 int WrappedStreamSocket::Bind(const net::IPEndPoint& local_addr) {
2136 NOTREACHED();
2137 }
2138
Connect(CompletionOnceCallback callback)2139 int WrappedStreamSocket::Connect(CompletionOnceCallback callback) {
2140 return transport_->Connect(std::move(callback));
2141 }
2142
Disconnect()2143 void WrappedStreamSocket::Disconnect() {
2144 transport_->Disconnect();
2145 }
2146
IsConnected() const2147 bool WrappedStreamSocket::IsConnected() const {
2148 return transport_->IsConnected();
2149 }
2150
IsConnectedAndIdle() const2151 bool WrappedStreamSocket::IsConnectedAndIdle() const {
2152 return transport_->IsConnectedAndIdle();
2153 }
2154
GetPeerAddress(IPEndPoint * address) const2155 int WrappedStreamSocket::GetPeerAddress(IPEndPoint* address) const {
2156 return transport_->GetPeerAddress(address);
2157 }
2158
GetLocalAddress(IPEndPoint * address) const2159 int WrappedStreamSocket::GetLocalAddress(IPEndPoint* address) const {
2160 return transport_->GetLocalAddress(address);
2161 }
2162
NetLog() const2163 const NetLogWithSource& WrappedStreamSocket::NetLog() const {
2164 return transport_->NetLog();
2165 }
2166
WasEverUsed() const2167 bool WrappedStreamSocket::WasEverUsed() const {
2168 return transport_->WasEverUsed();
2169 }
2170
GetNegotiatedProtocol() const2171 NextProto WrappedStreamSocket::GetNegotiatedProtocol() const {
2172 return transport_->GetNegotiatedProtocol();
2173 }
2174
GetSSLInfo(SSLInfo * ssl_info)2175 bool WrappedStreamSocket::GetSSLInfo(SSLInfo* ssl_info) {
2176 return transport_->GetSSLInfo(ssl_info);
2177 }
2178
GetTotalReceivedBytes() const2179 int64_t WrappedStreamSocket::GetTotalReceivedBytes() const {
2180 return transport_->GetTotalReceivedBytes();
2181 }
2182
ApplySocketTag(const SocketTag & tag)2183 void WrappedStreamSocket::ApplySocketTag(const SocketTag& tag) {
2184 transport_->ApplySocketTag(tag);
2185 }
2186
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)2187 int WrappedStreamSocket::Read(IOBuffer* buf,
2188 int buf_len,
2189 CompletionOnceCallback callback) {
2190 return transport_->Read(buf, buf_len, std::move(callback));
2191 }
2192
ReadIfReady(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)2193 int WrappedStreamSocket::ReadIfReady(IOBuffer* buf,
2194 int buf_len,
2195 CompletionOnceCallback callback) {
2196 return transport_->ReadIfReady(buf, buf_len, std::move((callback)));
2197 }
2198
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)2199 int WrappedStreamSocket::Write(
2200 IOBuffer* buf,
2201 int buf_len,
2202 CompletionOnceCallback callback,
2203 const NetworkTrafficAnnotationTag& traffic_annotation) {
2204 return transport_->Write(buf, buf_len, std::move(callback),
2205 TRAFFIC_ANNOTATION_FOR_TESTS);
2206 }
2207
SetReceiveBufferSize(int32_t size)2208 int WrappedStreamSocket::SetReceiveBufferSize(int32_t size) {
2209 return transport_->SetReceiveBufferSize(size);
2210 }
2211
SetSendBufferSize(int32_t size)2212 int WrappedStreamSocket::SetSendBufferSize(int32_t size) {
2213 return transport_->SetSendBufferSize(size);
2214 }
2215
Connect(CompletionOnceCallback callback)2216 int MockTaggingStreamSocket::Connect(CompletionOnceCallback callback) {
2217 connected_ = true;
2218 return WrappedStreamSocket::Connect(std::move(callback));
2219 }
2220
ApplySocketTag(const SocketTag & tag)2221 void MockTaggingStreamSocket::ApplySocketTag(const SocketTag& tag) {
2222 tagged_before_connected_ &= !connected_ || tag == tag_;
2223 tag_ = tag;
2224 transport_->ApplySocketTag(tag);
2225 }
2226
2227 std::unique_ptr<TransportClientSocket>
CreateTransportClientSocket(const AddressList & addresses,std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,NetworkQualityEstimator * network_quality_estimator,NetLog * net_log,const NetLogSource & source)2228 MockTaggingClientSocketFactory::CreateTransportClientSocket(
2229 const AddressList& addresses,
2230 std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,
2231 NetworkQualityEstimator* network_quality_estimator,
2232 NetLog* net_log,
2233 const NetLogSource& source) {
2234 auto socket = std::make_unique<MockTaggingStreamSocket>(
2235 MockClientSocketFactory::CreateTransportClientSocket(
2236 addresses, std::move(socket_performance_watcher),
2237 network_quality_estimator, net_log, source));
2238 tcp_socket_ = socket.get();
2239 return std::move(socket);
2240 }
2241
2242 std::unique_ptr<DatagramClientSocket>
CreateDatagramClientSocket(DatagramSocket::BindType bind_type,NetLog * net_log,const NetLogSource & source)2243 MockTaggingClientSocketFactory::CreateDatagramClientSocket(
2244 DatagramSocket::BindType bind_type,
2245 NetLog* net_log,
2246 const NetLogSource& source) {
2247 std::unique_ptr<DatagramClientSocket> socket(
2248 MockClientSocketFactory::CreateDatagramClientSocket(bind_type, net_log,
2249 source));
2250 udp_socket_ = static_cast<MockUDPClientSocket*>(socket.get());
2251 return socket;
2252 }
2253
// Canned SOCKS4 exchange: CONNECT to 127.0.0.1:80, answered with
// "request granted".
const char kSOCKS4TestHost[] = "127.0.0.1";
const int kSOCKS4TestPort = 80;

const char kSOCKS4OkRequestLocalHostPort80[] = {
    0x04, 0x01,     // VN = 4, CD = 1 (CONNECT)
    0x00, 0x50,     // DSTPORT = 80
    127,  0, 0, 1,  // DSTIP = 127.0.0.1
    0};             // empty USERID, NUL-terminated
const int kSOCKS4OkRequestLocalHostPort80Length =
    static_cast<int>(std::size(kSOCKS4OkRequestLocalHostPort80));

const char kSOCKS4OkReply[] = {
    0x00, 0x5A,   // VN = 0, CD = 90 (request granted)
    0x00, 0x00,   // DSTPORT (ignored by the client)
    0,    0, 0, 0};  // DSTIP (ignored by the client)
const int kSOCKS4OkReplyLength = static_cast<int>(std::size(kSOCKS4OkReply));

// Canned SOCKS5 exchange (RFC 1928): no-auth greeting, then CONNECT to the
// domain name "host", port 80.
const char kSOCKS5TestHost[] = "host";
const int kSOCKS5TestPort = 80;

const char kSOCKS5GreetRequest[] = {
    0x05,  // VER = 5
    0x01,  // NMETHODS = 1
    0x00};  // METHOD = no authentication required
const int kSOCKS5GreetRequestLength =
    static_cast<int>(std::size(kSOCKS5GreetRequest));

const char kSOCKS5GreetResponse[] = {
    0x05,   // VER = 5
    0x00};  // selected METHOD = no authentication required
const int kSOCKS5GreetResponseLength =
    static_cast<int>(std::size(kSOCKS5GreetResponse));

const char kSOCKS5OkRequest[] = {
    0x05, 0x01, 0x00,    // VER = 5, CMD = CONNECT, RSV
    0x03, 0x04,          // ATYP = DOMAINNAME, name length = 4
    'h',  'o',  's', 't',  // "host"
    0x00, 0x50};         // DST.PORT = 80
const int kSOCKS5OkRequestLength =
    static_cast<int>(std::size(kSOCKS5OkRequest));

const char kSOCKS5OkResponse[] = {
    0x05, 0x00, 0x00,  // VER = 5, REP = succeeded, RSV
    0x01,              // ATYP = IPv4
    127,  0,    0, 1,  // BND.ADDR = 127.0.0.1
    0x00, 0x50};       // BND.PORT = 80
const int kSOCKS5OkResponseLength =
    static_cast<int>(std::size(kSOCKS5OkResponse));
2281
CountReadBytes(base::span<const MockRead> reads)2282 int64_t CountReadBytes(base::span<const MockRead> reads) {
2283 int64_t total = 0;
2284 for (const MockRead& read : reads)
2285 total += read.data_len;
2286 return total;
2287 }
2288
CountWriteBytes(base::span<const MockWrite> writes)2289 int64_t CountWriteBytes(base::span<const MockWrite> writes) {
2290 int64_t total = 0;
2291 for (const MockWrite& write : writes)
2292 total += write.data_len;
2293 return total;
2294 }
2295
2296 #if BUILDFLAG(IS_ANDROID)
CanGetTaggedBytes()2297 bool CanGetTaggedBytes() {
2298 // In Android P, /proc/net/xt_qtaguid/stats is no longer guaranteed to be
2299 // present, and has been replaced with eBPF Traffic Monitoring in netd. See:
2300 // https://source.android.com/devices/tech/datausage/ebpf-traffic-monitor
2301 //
2302 // To read traffic statistics from netd, apps should use the API
2303 // NetworkStatsManager.queryDetailsForUidTag(). But this API does not provide
2304 // statistics for local traffic, only mobile and WiFi traffic, so it would not
2305 // work in tests that spin up a local server. So for now, GetTaggedBytes is
2306 // only supported on Android releases older than P.
2307 return base::android::BuildInfo::GetInstance()->sdk_int() <
2308 base::android::SDK_VERSION_P;
2309 }
2310
GetTaggedBytes(int32_t expected_tag)2311 uint64_t GetTaggedBytes(int32_t expected_tag) {
2312 EXPECT_TRUE(CanGetTaggedBytes());
2313
2314 // To determine how many bytes the system saw with a particular tag read
2315 // the /proc/net/xt_qtaguid/stats file which contains the kernel's
2316 // dump of all the UIDs and their tags sent and received bytes.
2317 uint64_t bytes = 0;
2318 std::string contents;
2319 EXPECT_TRUE(base::ReadFileToString(
2320 base::FilePath::FromUTF8Unsafe("/proc/net/xt_qtaguid/stats"), &contents));
2321 for (size_t i = contents.find('\n'); // Skip first line which is headers.
2322 i != std::string::npos && i < contents.length();) {
2323 uint64_t tag, rx_bytes;
2324 uid_t uid;
2325 int n;
2326 // Parse out the numbers we care about. For reference here's the column
2327 // headers:
2328 // idx iface acct_tag_hex uid_tag_int cnt_set rx_bytes rx_packets tx_bytes
2329 // tx_packets rx_tcp_bytes rx_tcp_packets rx_udp_bytes rx_udp_packets
2330 // rx_other_bytes rx_other_packets tx_tcp_bytes tx_tcp_packets tx_udp_bytes
2331 // tx_udp_packets tx_other_bytes tx_other_packets
2332 EXPECT_EQ(sscanf(contents.c_str() + i,
2333 "%*d %*s 0x%" SCNx64 " %d %*d %" SCNu64
2334 " %*d %*d %*d %*d %*d %*d %*d %*d "
2335 "%*d %*d %*d %*d %*d %*d %*d%n",
2336 &tag, &uid, &rx_bytes, &n),
2337 3);
2338 // If this line matches our UID and |expected_tag| then add it to the total.
2339 if (uid == getuid() && (int32_t)(tag >> 32) == expected_tag) {
2340 bytes += rx_bytes;
2341 }
2342 // Move |i| to the next line.
2343 i += n + 1;
2344 }
2345 return bytes;
2346 }
2347 #endif
2348
2349 } // namespace net
2350