• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #ifdef UNSAFE_BUFFERS_BUILD
6 // TODO(crbug.com/40284755): Remove this and spanify to fix the errors.
7 #pragma allow_unsafe_buffers
8 #endif
9 
10 #include "net/socket/socket_test_util.h"
11 
12 #include <inttypes.h>  // For SCNx64
13 #include <stdint.h>
14 #include <stdio.h>
15 
16 #include <memory>
17 #include <ostream>
18 #include <string>
19 #include <string_view>
20 #include <utility>
21 #include <vector>
22 
23 #include "base/compiler_specific.h"
24 #include "base/files/file_util.h"
25 #include "base/functional/bind.h"
26 #include "base/functional/callback_helpers.h"
27 #include "base/location.h"
28 #include "base/logging.h"
29 #include "base/memory/raw_ptr.h"
30 #include "base/notreached.h"
31 #include "base/rand_util.h"
32 #include "base/ranges/algorithm.h"
33 #include "base/run_loop.h"
34 #include "base/task/single_thread_task_runner.h"
35 #include "base/time/time.h"
36 #include "build/build_config.h"
37 #include "net/base/address_family.h"
38 #include "net/base/address_list.h"
39 #include "net/base/auth.h"
40 #include "net/base/completion_once_callback.h"
41 #include "net/base/hex_utils.h"
42 #include "net/base/ip_address.h"
43 #include "net/base/load_timing_info.h"
44 #include "net/base/net_errors.h"
45 #include "net/base/proxy_server.h"
46 #include "net/http/http_network_session.h"
47 #include "net/http/http_request_headers.h"
48 #include "net/http/http_response_headers.h"
49 #include "net/log/net_log_source.h"
50 #include "net/log/net_log_source_type.h"
51 #include "net/socket/connect_job.h"
52 #include "net/socket/socket.h"
53 #include "net/socket/stream_socket.h"
54 #include "net/socket/websocket_endpoint_lock_manager.h"
55 #include "net/ssl/ssl_cert_request_info.h"
56 #include "net/ssl/ssl_connection_status_flags.h"
57 #include "net/ssl/ssl_info.h"
58 #include "net/traffic_annotation/network_traffic_annotation.h"
59 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
60 #include "testing/gtest/include/gtest/gtest.h"
61 #include "third_party/abseil-cpp/absl/strings/ascii.h"
62 
63 #if BUILDFLAG(IS_ANDROID)
64 #include "base/android/build_info.h"
65 #endif
66 
// Starts a VLOG(|level|) message with the prefix |s| followed by the name of
// the enclosing function and "() ", so trace output identifies its origin.
// Callers stream additional text after the macro.
#define NET_TRACE(level, s) \
  VLOG(level) << s << __FUNCTION__ << "() "
68 
69 namespace net {
70 namespace {
71 
// Returns the uppercase hexadecimal digit ('0'-'9', 'A'-'F') for the high
// nybble (bits 4-7) of |x|.
inline char AsciifyHigh(char x) {
  const unsigned int high = (static_cast<unsigned char>(x) >> 4) & 0x0Fu;
  return "0123456789ABCDEF"[high];
}
76 
// Returns the uppercase hexadecimal digit ('0'-'9', 'A'-'F') for the low
// nybble (bits 0-3) of |x|.
inline char AsciifyLow(char x) {
  const unsigned int low = static_cast<unsigned char>(x) & 0x0Fu;
  return "0123456789ABCDEF"[low];
}
81 
Asciify(char x)82 inline char Asciify(char x) {
83   return absl::ascii_isprint(static_cast<unsigned char>(x)) ? x : '.';
84 }
85 
DumpData(const char * data,int data_len)86 void DumpData(const char* data, int data_len) {
87   if (logging::LOGGING_INFO < logging::GetMinLogLevel()) {
88     return;
89   }
90   DVLOG(1) << "Length:  " << data_len;
91   const char* pfx = "Data:    ";
92   if (!data || (data_len <= 0)) {
93     DVLOG(1) << pfx << "<None>";
94   } else {
95     int i;
96     for (i = 0; i <= (data_len - 4); i += 4) {
97       DVLOG(1) << pfx << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
98                << AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1])
99                << AsciifyHigh(data[i + 2]) << AsciifyLow(data[i + 2])
100                << AsciifyHigh(data[i + 3]) << AsciifyLow(data[i + 3]) << "  '"
101                << Asciify(data[i + 0]) << Asciify(data[i + 1])
102                << Asciify(data[i + 2]) << Asciify(data[i + 3]) << "'";
103       pfx = "         ";
104     }
105     // Take care of any 'trailing' bytes, if data_len was not a multiple of 4.
106     switch (data_len - i) {
107       case 3:
108         DVLOG(1) << pfx << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
109                  << AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1])
110                  << AsciifyHigh(data[i + 2]) << AsciifyLow(data[i + 2])
111                  << "    '" << Asciify(data[i + 0]) << Asciify(data[i + 1])
112                  << Asciify(data[i + 2]) << " '";
113         break;
114       case 2:
115         DVLOG(1) << pfx << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
116                  << AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1])
117                  << "      '" << Asciify(data[i + 0]) << Asciify(data[i + 1])
118                  << "  '";
119         break;
120       case 1:
121         DVLOG(1) << pfx << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
122                  << "        '" << Asciify(data[i + 0]) << "   '";
123         break;
124     }
125   }
126 }
127 
128 template <MockReadWriteType type>
DumpMockReadWrite(const MockReadWrite<type> & r)129 void DumpMockReadWrite(const MockReadWrite<type>& r) {
130   if (logging::LOGGING_INFO < logging::GetMinLogLevel()) {
131     return;
132   }
133   DVLOG(1) << "Async:   " << (r.mode == ASYNC) << "\nResult:  " << r.result;
134   DumpData(r.data, r.data_len);
135   const char* stop = (r.sequence_number & MockRead::STOPLOOP) ? " (STOP)" : "";
136   DVLOG(1) << "Stage:   " << (r.sequence_number & ~MockRead::STOPLOOP) << stop;
137 }
138 
RunClosureIfNonNull(base::OnceClosure closure)139 void RunClosureIfNonNull(base::OnceClosure closure) {
140   if (!closure.is_null()) {
141     std::move(closure).Run();
142   }
143 }
144 
145 }  // namespace
146 
147 MockConnectCompleter::MockConnectCompleter() = default;
148 
149 MockConnectCompleter::~MockConnectCompleter() = default;
150 
SetCallback(CompletionOnceCallback callback)151 void MockConnectCompleter::SetCallback(CompletionOnceCallback callback) {
152   CHECK(!callback_);
153   callback_ = std::move(callback);
154 }
155 
Complete(int result)156 void MockConnectCompleter::Complete(int result) {
157   CHECK(callback_);
158   std::move(callback_).Run(result);
159 }
160 
MockConnect()161 MockConnect::MockConnect() : mode(ASYNC), result(OK) {
162   peer_addr = IPEndPoint(IPAddress(192, 0, 2, 33), 0);
163 }
164 
MockConnect(IoMode io_mode,int r)165 MockConnect::MockConnect(IoMode io_mode, int r) : mode(io_mode), result(r) {
166   peer_addr = IPEndPoint(IPAddress(192, 0, 2, 33), 0);
167 }
168 
MockConnect(IoMode io_mode,int r,IPEndPoint addr)169 MockConnect::MockConnect(IoMode io_mode, int r, IPEndPoint addr)
170     : mode(io_mode), result(r), peer_addr(addr) {}
171 
MockConnect(IoMode io_mode,int r,IPEndPoint addr,bool first_attempt_fails)172 MockConnect::MockConnect(IoMode io_mode,
173                          int r,
174                          IPEndPoint addr,
175                          bool first_attempt_fails)
176     : mode(io_mode),
177       result(r),
178       peer_addr(addr),
179       first_attempt_fails(first_attempt_fails) {}
180 
MockConnect(MockConnectCompleter * completer)181 MockConnect::MockConnect(MockConnectCompleter* completer)
182     : mode(ASYNC), result(OK), completer(completer) {}
183 
184 MockConnect::~MockConnect() = default;
185 
MockConfirm()186 MockConfirm::MockConfirm() : mode(SYNCHRONOUS), result(OK) {}
187 
MockConfirm(IoMode io_mode,int r)188 MockConfirm::MockConfirm(IoMode io_mode, int r) : mode(io_mode), result(r) {}
189 
190 MockConfirm::~MockConfirm() = default;
191 
IsIdle() const192 bool SocketDataProvider::IsIdle() const {
193   return true;
194 }
195 
Initialize(AsyncSocket * socket)196 void SocketDataProvider::Initialize(AsyncSocket* socket) {
197   CHECK(!socket_);
198   CHECK(socket);
199   socket_ = socket;
200   Reset();
201 }
202 
DetachSocket()203 void SocketDataProvider::DetachSocket() {
204   CHECK(socket_);
205   socket_ = nullptr;
206 }
207 
208 SocketDataProvider::SocketDataProvider() = default;
209 
~SocketDataProvider()210 SocketDataProvider::~SocketDataProvider() {
211   if (socket_)
212     socket_->OnDataProviderDestroyed();
213 }
214 
StaticSocketDataHelper(base::span<const MockRead> reads,base::span<const MockWrite> writes)215 StaticSocketDataHelper::StaticSocketDataHelper(
216     base::span<const MockRead> reads,
217     base::span<const MockWrite> writes)
218     : reads_(reads), writes_(writes) {}
219 
220 StaticSocketDataHelper::~StaticSocketDataHelper() = default;
221 
PeekRead() const222 const MockRead& StaticSocketDataHelper::PeekRead() const {
223   CHECK(!AllReadDataConsumed());
224   return reads_[read_index_];
225 }
226 
PeekWrite() const227 const MockWrite& StaticSocketDataHelper::PeekWrite() const {
228   CHECK(!AllWriteDataConsumed());
229   return writes_[write_index_];
230 }
231 
AdvanceRead()232 const MockRead& StaticSocketDataHelper::AdvanceRead() {
233   CHECK(!AllReadDataConsumed());
234   return reads_[read_index_++];
235 }
236 
AdvanceWrite()237 const MockWrite& StaticSocketDataHelper::AdvanceWrite() {
238   CHECK(!AllWriteDataConsumed());
239   return writes_[write_index_++];
240 }
241 
Reset()242 void StaticSocketDataHelper::Reset() {
243   read_index_ = 0;
244   write_index_ = 0;
245 }
246 
VerifyWriteData(const std::string & data,SocketDataPrinter * printer)247 bool StaticSocketDataHelper::VerifyWriteData(const std::string& data,
248                                              SocketDataPrinter* printer) {
249   CHECK(!AllWriteDataConsumed());
250   // Check that the actual data matches the expectations, skipping over any
251   // pause events.
252   const MockWrite& next_write = PeekRealWrite();
253   if (!next_write.data)
254     return true;
255 
256   // Note: Partial writes are supported here.  If the expected data
257   // is a match, but shorter than the write actually written, that is legal.
258   // Example:
259   //   Application writes "foobarbaz" (9 bytes)
260   //   Expected write was "foo" (3 bytes)
261   //   This is a success, and the function returns true.
262   std::string expected_data(next_write.data, next_write.data_len);
263   std::string actual_data(data.substr(0, next_write.data_len));
264   if (printer) {
265     EXPECT_TRUE(actual_data == expected_data)
266         << "Actual formatted write data:\n"
267         << printer->PrintWrite(data) << "Expected formatted write data:\n"
268         << printer->PrintWrite(expected_data) << "Actual raw write data:\n"
269         << HexDump(data) << "Expected raw write data:\n"
270         << HexDump(expected_data);
271   } else {
272     EXPECT_TRUE(actual_data == expected_data)
273         << "Actual write data:\n"
274         << HexDump(data) << "Expected write data:\n"
275         << HexDump(expected_data);
276   }
277   return expected_data == actual_data;
278 }
279 
ExpectAllReadDataConsumed(SocketDataPrinter * printer) const280 void StaticSocketDataHelper::ExpectAllReadDataConsumed(
281     SocketDataPrinter* printer) const {
282   if (AllReadDataConsumed()) {
283     return;
284   }
285 
286   std::ostringstream msg;
287   if (read_index_ < read_count()) {
288     msg << "Unconsumed reads:\n";
289     for (size_t i = read_index_; i < read_count(); i++) {
290       msg << (reads_[i].mode == ASYNC ? "ASYNC" : "SYNC") << " MockRead seq "
291           << reads_[i].sequence_number << ":\n";
292       if (reads_[i].result != OK) {
293         msg << "Result: " << reads_[i].result << "\n";
294       }
295       if (reads_[i].data) {
296         std::string data(reads_[i].data, reads_[i].data_len);
297         if (printer) {
298           msg << printer->PrintWrite(data);
299         }
300         msg << HexDump(data);
301       }
302     }
303   }
304   EXPECT_TRUE(AllReadDataConsumed()) << msg.str();
305 }
306 
ExpectAllWriteDataConsumed(SocketDataPrinter * printer) const307 void StaticSocketDataHelper::ExpectAllWriteDataConsumed(
308     SocketDataPrinter* printer) const {
309   if (AllWriteDataConsumed()) {
310     return;
311   }
312 
313   std::ostringstream msg;
314   if (write_index_ < write_count()) {
315     msg << "Unconsumed writes:\n";
316     for (size_t i = write_index_; i < write_count(); i++) {
317       msg << (writes_[i].mode == ASYNC ? "ASYNC" : "SYNC") << " MockWrite seq "
318           << writes_[i].sequence_number << ":\n";
319       if (writes_[i].result != OK) {
320         msg << "Result: " << writes_[i].result << "\n";
321       }
322       if (writes_[i].data) {
323         std::string data(writes_[i].data, writes_[i].data_len);
324         if (printer) {
325           msg << printer->PrintWrite(data);
326         }
327         msg << HexDump(data);
328       }
329     }
330   }
331   EXPECT_TRUE(AllWriteDataConsumed()) << msg.str();
332 }
333 
PeekRealWrite() const334 const MockWrite& StaticSocketDataHelper::PeekRealWrite() const {
335   for (size_t i = write_index_; i < write_count(); i++) {
336     if (writes_[i].mode != ASYNC || writes_[i].result != ERR_IO_PENDING)
337       return writes_[i];
338   }
339 
340   NOTREACHED() << "No write data available.";
341 }
342 
StaticSocketDataProvider()343 StaticSocketDataProvider::StaticSocketDataProvider()
344     : StaticSocketDataProvider(base::span<const MockRead>(),
345                                base::span<const MockWrite>()) {}
346 
StaticSocketDataProvider(base::span<const MockRead> reads,base::span<const MockWrite> writes)347 StaticSocketDataProvider::StaticSocketDataProvider(
348     base::span<const MockRead> reads,
349     base::span<const MockWrite> writes)
350     : helper_(reads, writes) {}
351 
352 StaticSocketDataProvider::~StaticSocketDataProvider() = default;
353 
Pause()354 void StaticSocketDataProvider::Pause() {
355   paused_ = true;
356 }
357 
Resume()358 void StaticSocketDataProvider::Resume() {
359   paused_ = false;
360 }
361 
OnRead()362 MockRead StaticSocketDataProvider::OnRead() {
363   if (AllReadDataConsumed()) {
364     const net::MockRead pending_read(net::SYNCHRONOUS, net::ERR_IO_PENDING);
365     return pending_read;
366   }
367 
368   return helper_.AdvanceRead();
369 }
370 
OnWrite(const std::string & data)371 MockWriteResult StaticSocketDataProvider::OnWrite(const std::string& data) {
372   if (helper_.write_count() == 0) {
373     // Not using mock writes; succeed synchronously.
374     return MockWriteResult(SYNCHRONOUS, data.length());
375   }
376   if (printer_) {
377     EXPECT_FALSE(helper_.AllWriteDataConsumed())
378         << "No more mock data to match write:\nFormatted write data:\n"
379         << printer_->PrintWrite(data) << "Raw write data:\n"
380         << HexDump(data);
381   } else {
382     EXPECT_FALSE(helper_.AllWriteDataConsumed())
383         << "No more mock data to match write:\nRaw write data:\n"
384         << HexDump(data);
385   }
386   if (helper_.AllWriteDataConsumed()) {
387     return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED);
388   }
389 
390   // Check that what we are writing matches the expectation.
391   // Then give the mocked return value.
392   if (!helper_.VerifyWriteData(data, printer_))
393     return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED);
394 
395   const MockWrite& next_write = helper_.AdvanceWrite();
396   // In the case that the write was successful, return the number of bytes
397   // written. Otherwise return the error code.
398   int result =
399       next_write.result == OK ? next_write.data_len : next_write.result;
400   return MockWriteResult(next_write.mode, result);
401 }
402 
AllReadDataConsumed() const403 bool StaticSocketDataProvider::AllReadDataConsumed() const {
404   return paused_ || helper_.AllReadDataConsumed();
405 }
406 
AllWriteDataConsumed() const407 bool StaticSocketDataProvider::AllWriteDataConsumed() const {
408   return helper_.AllWriteDataConsumed();
409 }
410 
Reset()411 void StaticSocketDataProvider::Reset() {
412   helper_.Reset();
413 }
414 
SSLSocketDataProvider(IoMode mode,int result)415 SSLSocketDataProvider::SSLSocketDataProvider(IoMode mode, int result)
416     : connect(mode, result),
417       expected_ssl_version_min(kDefaultSSLVersionMin),
418       expected_ssl_version_max(kDefaultSSLVersionMax) {
419   SSLConnectionStatusSetVersion(SSL_CONNECTION_VERSION_TLS1_3,
420                                 &ssl_info.connection_status);
421   // Set to TLS_CHACHA20_POLY1305_SHA256
422   SSLConnectionStatusSetCipherSuite(0x1301, &ssl_info.connection_status);
423 }
424 
SSLSocketDataProvider(MockConnectCompleter * completer)425 SSLSocketDataProvider::SSLSocketDataProvider(MockConnectCompleter* completer)
426     : connect(completer),
427       expected_ssl_version_min(kDefaultSSLVersionMin),
428       expected_ssl_version_max(kDefaultSSLVersionMax) {
429   SSLConnectionStatusSetVersion(SSL_CONNECTION_VERSION_TLS1_3,
430                                 &ssl_info.connection_status);
431   // Set to TLS_CHACHA20_POLY1305_SHA256
432   SSLConnectionStatusSetCipherSuite(0x1301, &ssl_info.connection_status);
433 }
434 
435 SSLSocketDataProvider::SSLSocketDataProvider(
436     const SSLSocketDataProvider& other) = default;
437 
438 SSLSocketDataProvider::~SSLSocketDataProvider() = default;
439 
SequencedSocketData()440 SequencedSocketData::SequencedSocketData()
441     : SequencedSocketData(base::span<const MockRead>(),
442                           base::span<const MockWrite>()) {}
443 
SequencedSocketData(base::span<const MockRead> reads,base::span<const MockWrite> writes)444 SequencedSocketData::SequencedSocketData(base::span<const MockRead> reads,
445                                          base::span<const MockWrite> writes)
446     : helper_(reads, writes) {
447   // Check that reads and writes have a contiguous set of sequence numbers
448   // starting from 0 and working their way up, with no repeats and skipping
449   // no values.
450   int next_sequence_number = 0;
451   bool last_event_was_pause = false;
452 
453   auto next_read = reads.begin();
454   auto next_write = writes.begin();
455   while (next_read != reads.end() || next_write != writes.end()) {
456     if (next_read != reads.end() &&
457         next_read->sequence_number == next_sequence_number) {
458       // Check if this is a pause.
459       if (next_read->mode == ASYNC && next_read->result == ERR_IO_PENDING) {
460         CHECK(!last_event_was_pause)
461             << "Two pauses in a row are not allowed: " << next_sequence_number;
462         last_event_was_pause = true;
463       } else if (last_event_was_pause) {
464         CHECK_EQ(ASYNC, next_read->mode)
465             << "A sync event after a pause makes no sense: "
466             << next_sequence_number;
467         CHECK_NE(ERR_IO_PENDING, next_read->result)
468             << "A pause event after a pause makes no sense: "
469             << next_sequence_number;
470         last_event_was_pause = false;
471       }
472 
473       ++next_read;
474       ++next_sequence_number;
475       continue;
476     }
477     if (next_write != writes.end() &&
478         next_write->sequence_number == next_sequence_number) {
479       // Check if this is a pause.
480       if (next_write->mode == ASYNC && next_write->result == ERR_IO_PENDING) {
481         CHECK(!last_event_was_pause)
482             << "Two pauses in a row are not allowed: " << next_sequence_number;
483         last_event_was_pause = true;
484       } else if (last_event_was_pause) {
485         CHECK_EQ(ASYNC, next_write->mode)
486             << "A sync event after a pause makes no sense: "
487             << next_sequence_number;
488         CHECK_NE(ERR_IO_PENDING, next_write->result)
489             << "A pause event after a pause makes no sense: "
490             << next_sequence_number;
491         last_event_was_pause = false;
492       }
493 
494       ++next_write;
495       ++next_sequence_number;
496       continue;
497     }
498     if (next_write != writes.end()) {
499       NOTREACHED() << "Sequence number " << next_write->sequence_number
500                    << " not found where expected: " << next_sequence_number;
501     }
502     NOTREACHED() << "Too few writes, next expected sequence number: "
503                  << next_sequence_number;
504   }
505 
506   // Last event must not be a pause.  For the final event to indicate the
507   // operation never completes, it should be SYNCHRONOUS and return
508   // ERR_IO_PENDING.
509   CHECK(!last_event_was_pause);
510 
511   CHECK(next_read == reads.end());
512   CHECK(next_write == writes.end());
513 }
514 
SequencedSocketData(const MockConnect & connect,base::span<const MockRead> reads,base::span<const MockWrite> writes)515 SequencedSocketData::SequencedSocketData(const MockConnect& connect,
516                                          base::span<const MockRead> reads,
517                                          base::span<const MockWrite> writes)
518     : SequencedSocketData(reads, writes) {
519   set_connect_data(connect);
520 }
OnRead()521 MockRead SequencedSocketData::OnRead() {
522   CHECK_EQ(IoState::kIdle, read_state_);
523   CHECK(!helper_.AllReadDataConsumed())
524       << "Application tried to read but there is no read data left";
525 
526   NET_TRACE(1, " *** ") << "sequence_number: " << sequence_number_;
527   const MockRead& next_read = helper_.PeekRead();
528   NET_TRACE(1, " *** ") << "next_read: " << next_read.sequence_number;
529   CHECK_GE(next_read.sequence_number, sequence_number_);
530 
531   if (next_read.sequence_number <= sequence_number_) {
532     if (next_read.mode == SYNCHRONOUS) {
533       NET_TRACE(1, " *** ") << "Returning synchronously";
534       DumpMockReadWrite(next_read);
535       helper_.AdvanceRead();
536       ++sequence_number_;
537       MaybePostWriteCompleteTask();
538       return next_read;
539     }
540 
541     // If the result is ERR_IO_PENDING, then pause.
542     if (next_read.result == ERR_IO_PENDING) {
543       NET_TRACE(1, " *** ") << "Pausing read at: " << sequence_number_;
544       read_state_ = IoState::kPaused;
545       if (run_until_paused_run_loop_)
546         run_until_paused_run_loop_->Quit();
547       return MockRead(SYNCHRONOUS, ERR_IO_PENDING);
548     }
549     base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
550         FROM_HERE, base::BindOnce(&SequencedSocketData::OnReadComplete,
551                                   weak_factory_.GetWeakPtr()));
552     CHECK_NE(IoState::kCompleting, write_state_);
553     read_state_ = IoState::kCompleting;
554   } else if (next_read.mode == SYNCHRONOUS) {
555     ADD_FAILURE() << "Unable to perform synchronous IO while stopped";
556     return MockRead(SYNCHRONOUS, ERR_UNEXPECTED);
557   } else {
558     NET_TRACE(1, " *** ") << "Waiting for write to trigger read";
559     read_state_ = IoState::kPending;
560   }
561 
562   return MockRead(SYNCHRONOUS, ERR_IO_PENDING);
563 }
564 
OnWrite(const std::string & data)565 MockWriteResult SequencedSocketData::OnWrite(const std::string& data) {
566   CHECK_EQ(IoState::kIdle, write_state_);
567   if (printer_) {
568     CHECK(!helper_.AllWriteDataConsumed())
569         << "\nNo more mock data to match write:\nFormatted write data:\n"
570         << printer_->PrintWrite(data) << "Raw write data:\n"
571         << HexDump(data);
572   } else {
573     CHECK(!helper_.AllWriteDataConsumed())
574         << "\nNo more mock data to match write:\nRaw write data:\n"
575         << HexDump(data);
576   }
577 
578   NET_TRACE(1, " *** ") << "sequence_number: " << sequence_number_;
579   const MockWrite& next_write = helper_.PeekWrite();
580   NET_TRACE(1, " *** ") << "next_write: " << next_write.sequence_number;
581   CHECK_GE(next_write.sequence_number, sequence_number_);
582 
583   if (!helper_.VerifyWriteData(data, printer_))
584     return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED);
585 
586   if (next_write.sequence_number <= sequence_number_) {
587     if (next_write.mode == SYNCHRONOUS) {
588       helper_.AdvanceWrite();
589       ++sequence_number_;
590       MaybePostReadCompleteTask();
591       // In the case that the write was successful, return the number of bytes
592       // written. Otherwise return the error code.
593       int rv =
594           next_write.result != OK ? next_write.result : next_write.data_len;
595       NET_TRACE(1, " *** ") << "Returning synchronously";
596       return MockWriteResult(SYNCHRONOUS, rv);
597     }
598 
599     // If the result is ERR_IO_PENDING, then pause.
600     if (next_write.result == ERR_IO_PENDING) {
601       NET_TRACE(1, " *** ") << "Pausing write at: " << sequence_number_;
602       write_state_ = IoState::kPaused;
603       if (run_until_paused_run_loop_)
604         run_until_paused_run_loop_->Quit();
605       return MockWriteResult(SYNCHRONOUS, ERR_IO_PENDING);
606     }
607 
608     NET_TRACE(1, " *** ") << "Posting task to complete write";
609     base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
610         FROM_HERE, base::BindOnce(&SequencedSocketData::OnWriteComplete,
611                                   weak_factory_.GetWeakPtr()));
612     CHECK_NE(IoState::kCompleting, read_state_);
613     write_state_ = IoState::kCompleting;
614   } else if (next_write.mode == SYNCHRONOUS) {
615     ADD_FAILURE() << "Unable to perform synchronous IO while stopped";
616     return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED);
617   } else {
618     NET_TRACE(1, " *** ") << "Waiting for read to trigger write";
619     write_state_ = IoState::kPending;
620   }
621 
622   return MockWriteResult(SYNCHRONOUS, ERR_IO_PENDING);
623 }
624 
AllReadDataConsumed() const625 bool SequencedSocketData::AllReadDataConsumed() const {
626   return helper_.AllReadDataConsumed();
627 }
628 
CancelPendingRead()629 void SequencedSocketData::CancelPendingRead() {
630   DCHECK_EQ(IoState::kPending, read_state_);
631 
632   read_state_ = IoState::kIdle;
633 }
634 
AllWriteDataConsumed() const635 bool SequencedSocketData::AllWriteDataConsumed() const {
636   return helper_.AllWriteDataConsumed();
637 }
638 
ExpectAllReadDataConsumed() const639 void SequencedSocketData::ExpectAllReadDataConsumed() const {
640   helper_.ExpectAllReadDataConsumed(printer_.get());
641 }
642 
ExpectAllWriteDataConsumed() const643 void SequencedSocketData::ExpectAllWriteDataConsumed() const {
644   helper_.ExpectAllWriteDataConsumed(printer_.get());
645 }
646 
IsIdle() const647 bool SequencedSocketData::IsIdle() const {
648   // If |busy_before_sync_reads_| is not set, always considered idle.  If
649   // no reads left, or the next operation is a write, also consider it idle.
650   if (!busy_before_sync_reads_ || helper_.AllReadDataConsumed() ||
651       helper_.PeekRead().sequence_number != sequence_number_) {
652     return true;
653   }
654 
655   // If the next operation is synchronous read, treat the socket as not idle.
656   if (helper_.PeekRead().mode == SYNCHRONOUS)
657     return false;
658   return true;
659 }
660 
IsPaused() const661 bool SequencedSocketData::IsPaused() const {
662   // Both states should not be paused.
663   DCHECK(read_state_ != IoState::kPaused || write_state_ != IoState::kPaused);
664   return write_state_ == IoState::kPaused || read_state_ == IoState::kPaused;
665 }
666 
Resume()667 void SequencedSocketData::Resume() {
668   if (!IsPaused()) {
669     ADD_FAILURE() << "Unable to Resume when not paused.";
670     return;
671   }
672 
673   sequence_number_++;
674   if (read_state_ == IoState::kPaused) {
675     read_state_ = IoState::kPending;
676     helper_.AdvanceRead();
677   } else {  // write_state_ == IoState::kPaused
678     write_state_ = IoState::kPending;
679     helper_.AdvanceWrite();
680   }
681 
682   if (!helper_.AllWriteDataConsumed() &&
683       helper_.PeekWrite().sequence_number == sequence_number_) {
684     // The next event hasn't even started yet.  Pausing isn't really needed in
685     // that case, but may as well support it.
686     if (write_state_ != IoState::kPending)
687       return;
688     write_state_ = IoState::kCompleting;
689     OnWriteComplete();
690     return;
691   }
692 
693   CHECK(!helper_.AllReadDataConsumed());
694 
695   // The next event hasn't even started yet.  Pausing isn't really needed in
696   // that case, but may as well support it.
697   if (read_state_ != IoState::kPending)
698     return;
699   read_state_ = IoState::kCompleting;
700   OnReadComplete();
701 }
702 
RunUntilPaused()703 void SequencedSocketData::RunUntilPaused() {
704   CHECK(!run_until_paused_run_loop_);
705 
706   if (IsPaused())
707     return;
708 
709   run_until_paused_run_loop_ = std::make_unique<base::RunLoop>();
710   run_until_paused_run_loop_->Run();
711   run_until_paused_run_loop_.reset();
712   DCHECK(IsPaused());
713 }
714 
MaybePostReadCompleteTask()715 void SequencedSocketData::MaybePostReadCompleteTask() {
716   NET_TRACE(1, " ****** ") << " current: " << sequence_number_;
717   // Only trigger the next read to complete if there is already a read pending
718   // which should complete at the current sequence number.
719   if (read_state_ != IoState::kPending ||
720       helper_.PeekRead().sequence_number != sequence_number_) {
721     return;
722   }
723 
724   // If the result is ERR_IO_PENDING, then pause.
725   if (helper_.PeekRead().result == ERR_IO_PENDING) {
726     NET_TRACE(1, " *** ") << "Pausing read at: " << sequence_number_;
727     read_state_ = IoState::kPaused;
728     if (run_until_paused_run_loop_)
729       run_until_paused_run_loop_->Quit();
730     return;
731   }
732 
733   NET_TRACE(1, " ****** ") << "Posting task to complete read: "
734                            << sequence_number_;
735   base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
736       FROM_HERE, base::BindOnce(&SequencedSocketData::OnReadComplete,
737                                 weak_factory_.GetWeakPtr()));
738   CHECK_NE(IoState::kCompleting, write_state_);
739   read_state_ = IoState::kCompleting;
740 }
741 
MaybePostWriteCompleteTask()742 void SequencedSocketData::MaybePostWriteCompleteTask() {
743   NET_TRACE(1, " ****** ") << " current: " << sequence_number_;
744   // Only trigger the next write to complete if there is already a write pending
745   // which should complete at the current sequence number.
746   if (write_state_ != IoState::kPending ||
747       helper_.PeekWrite().sequence_number != sequence_number_) {
748     return;
749   }
750 
751   // If the result is ERR_IO_PENDING, then pause.
752   if (helper_.PeekWrite().result == ERR_IO_PENDING) {
753     NET_TRACE(1, " *** ") << "Pausing write at: " << sequence_number_;
754     write_state_ = IoState::kPaused;
755     if (run_until_paused_run_loop_)
756       run_until_paused_run_loop_->Quit();
757     return;
758   }
759 
760   NET_TRACE(1, " ****** ") << "Posting task to complete write: "
761                            << sequence_number_;
762   base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
763       FROM_HERE, base::BindOnce(&SequencedSocketData::OnWriteComplete,
764                                 weak_factory_.GetWeakPtr()));
765   CHECK_NE(IoState::kCompleting, read_state_);
766   write_state_ = IoState::kCompleting;
767 }
768 
Reset()769 void SequencedSocketData::Reset() {
770   helper_.Reset();
771   sequence_number_ = 0;
772   read_state_ = IoState::kIdle;
773   write_state_ = IoState::kIdle;
774   weak_factory_.InvalidateWeakPtrs();
775 }
776 
// Completes the read scheduled at the current sequence number.
//
// Requires |read_state_| to be kCompleting (CHECKed). It consumes the next
// MockRead from |helper_|, advances |sequence_number_|, and returns the read
// side to idle. It may then post the next write completion. Finally it hands
// the MockRead to the attached socket, if one is still attached.
void SequencedSocketData::OnReadComplete() {
  CHECK_EQ(IoState::kCompleting, read_state_);
  NET_TRACE(1, " *** ") << "Completing read for: " << sequence_number_;

  // NOTE(review): kept as a copy (OnWriteComplete() binds a reference instead);
  // presumably so |data| stays valid if the socket callback below re-enters
  // this object — confirm before changing.
  MockRead data = helper_.AdvanceRead();
  DCHECK_EQ(sequence_number_, data.sequence_number);
  sequence_number_++;
  read_state_ = IoState::kIdle;

  // The result of this read completing might trigger the completion
  // of a pending write. If so, post a task to complete the write later.
  // Since the socket may call back into the SequencedSocketData
  // from socket()->OnReadComplete(), trigger the write task to be posted
  // before calling that.
  MaybePostWriteCompleteTask();

  if (!socket()) {
    NET_TRACE(1, " *** ") << "No socket available to complete read";
    return;
  }

  NET_TRACE(1, " *** ") << "Completing socket read for: "
                        << data.sequence_number;
  DumpMockReadWrite(data);
  socket()->OnReadComplete(data);
  NET_TRACE(1, " *** ") << "Done";
}
804 
// Completes the write scheduled at the current sequence number. This is the
// task posted (through a weak pointer) when a pending write's sequence number
// comes up.
//
// Requires |write_state_| to be kCompleting (CHECKed). It consumes the next
// MockWrite from |helper_|, advances |sequence_number_|, and returns the write
// side to idle. It may then post the next read completion. Finally it reports
// the write result to the attached socket, if one is still attached.
void SequencedSocketData::OnWriteComplete() {
  CHECK_EQ(IoState::kCompleting, write_state_);
  NET_TRACE(1, " *** ") << " Completing write for: " << sequence_number_;

  const MockWrite& data = helper_.AdvanceWrite();
  DCHECK_EQ(sequence_number_, data.sequence_number);
  sequence_number_++;
  write_state_ = IoState::kIdle;
  // A successful (OK) MockWrite reports the number of bytes written; any other
  // result code is passed through to the socket unchanged.
  int rv = data.result == OK ? data.data_len : data.result;

  // The result of this write completing might trigger the completion
  // of a pending read. If so, post a task to complete the read later.
  // Since the socket may call back into the SequencedSocketData
  // from socket()->OnWriteComplete(), trigger the read task to be posted
  // before calling that.
  MaybePostReadCompleteTask();

  if (!socket()) {
    NET_TRACE(1, " *** ") << "No socket available to complete write";
    return;
  }

  NET_TRACE(1, " *** ") << " Completing socket write for: "
                        << data.sequence_number;
  socket()->OnWriteComplete(rv);
  NET_TRACE(1, " *** ") << "Done";
}
832 
// Special members defined out of line as defaulted.
SequencedSocketData::~SequencedSocketData() = default;

MockClientSocketFactory::MockClientSocketFactory() = default;

MockClientSocketFactory::~MockClientSocketFactory() = default;
838 
// Queues |data| in the general provider list |mock_data_|.
// CreateDatagramClientSocket() takes its provider from this list. So does
// CreateTransportClientSocket(), when no TCP-specific provider is available.
// NOTE(review): ownership of |data| is not visible here; presumably the caller
// keeps it alive for as long as the factory may hand it out — confirm.
void MockClientSocketFactory::AddSocketDataProvider(SocketDataProvider* data) {
  mock_data_.Add(data);
}

// Queues |data| in |mock_tcp_data_|. CreateTransportClientSocket() checks this
// list before falling back to |mock_data_|.
void MockClientSocketFactory::AddTcpSocketDataProvider(
    SocketDataProvider* data) {
  mock_tcp_data_.Add(data);
}

// Queues |data| in |mock_ssl_data_|. CreateSSLClientSocket() takes its
// provider from this list.
void MockClientSocketFactory::AddSSLSocketDataProvider(
    SSLSocketDataProvider* data) {
  mock_ssl_data_.Add(data);
}
852 
ResetNextMockIndexes()853 void MockClientSocketFactory::ResetNextMockIndexes() {
854   mock_data_.ResetNextIndex();
855   mock_ssl_data_.ResetNextIndex();
856 }
857 
858 std::unique_ptr<DatagramClientSocket>
CreateDatagramClientSocket(DatagramSocket::BindType bind_type,NetLog * net_log,const NetLogSource & source)859 MockClientSocketFactory::CreateDatagramClientSocket(
860     DatagramSocket::BindType bind_type,
861     NetLog* net_log,
862     const NetLogSource& source) {
863   NET_TRACE(1, " *** ") << "mock_data_index: " << mock_data_.next_index();
864   SocketDataProvider* data_provider = mock_data_.GetNext();
865   auto socket = std::make_unique<MockUDPClientSocket>(data_provider, net_log);
866   if (bind_type == DatagramSocket::RANDOM_BIND)
867     socket->set_source_port(static_cast<uint16_t>(base::RandInt(1025, 65535)));
868   udp_client_socket_ports_.push_back(socket->source_port());
869   return std::move(socket);
870 }
871 
872 std::unique_ptr<TransportClientSocket>
CreateTransportClientSocket(const AddressList & addresses,std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,NetworkQualityEstimator * network_quality_estimator,NetLog * net_log,const NetLogSource & source)873 MockClientSocketFactory::CreateTransportClientSocket(
874     const AddressList& addresses,
875     std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,
876     NetworkQualityEstimator* network_quality_estimator,
877     NetLog* net_log,
878     const NetLogSource& source) {
879   SocketDataProvider* data_provider = mock_tcp_data_.GetNextWithoutAsserting();
880   if (data_provider) {
881     NET_TRACE(1, " *** ") << "mock_tcp_data_index: "
882                           << (mock_tcp_data_.next_index() - 1);
883   } else {
884     NET_TRACE(1, " *** ") << "mock_data_index: " << mock_data_.next_index();
885     data_provider = mock_data_.GetNext();
886   }
887   auto socket =
888       std::make_unique<MockTCPClientSocket>(addresses, net_log, data_provider);
889   if (enable_read_if_ready_)
890     socket->set_enable_read_if_ready(enable_read_if_ready_);
891   return std::move(socket);
892 }
893 
// Creates a MockSSLClientSocket layered over |stream_socket|. It is backed by
// the next SSLSocketDataProvider registered via AddSSLSocketDataProvider().
//
// Before building the socket, this checks the expectations recorded on that
// provider against what the caller actually passed in |ssl_config|,
// |host_and_port| and |context|, using gtest EXPECT_* macros. Most checks run
// only when the provider sets the corresponding expectation. The TLS
// min/max version checks always run.
std::unique_ptr<SSLClientSocket> MockClientSocketFactory::CreateSSLClientSocket(
    SSLClientContext* context,
    std::unique_ptr<StreamSocket> stream_socket,
    const HostPortPair& host_and_port,
    const SSLConfig& ssl_config) {
  NET_TRACE(1, " *** ") << "mock_ssl_data_index: "
                        << mock_ssl_data_.next_index();
  SSLSocketDataProvider* next_ssl_data = mock_ssl_data_.GetNext();
  // ALPN protocol list, compared element-wise.
  if (next_ssl_data->next_protos_expected_in_ssl_config.has_value()) {
    EXPECT_TRUE(base::ranges::equal(
        next_ssl_data->next_protos_expected_in_ssl_config.value(),
        ssl_config.alpn_protos));
  }
  if (next_ssl_data->expected_application_settings) {
    EXPECT_EQ(*next_ssl_data->expected_application_settings,
              ssl_config.application_settings);
  }

  // The protocol version used is a combination of the per-socket SSLConfig and
  // the SSLConfigService.
  EXPECT_EQ(
      next_ssl_data->expected_ssl_version_min,
      ssl_config.version_min_override.value_or(context->config().version_min));
  EXPECT_EQ(
      next_ssl_data->expected_ssl_version_max,
      ssl_config.version_max_override.value_or(context->config().version_max));

  if (next_ssl_data->expected_early_data_enabled) {
    EXPECT_EQ(*next_ssl_data->expected_early_data_enabled,
              ssl_config.early_data_enabled);
  }

  if (next_ssl_data->expected_send_client_cert) {
    // Client certificate preferences come from |context|.
    scoped_refptr<X509Certificate> client_cert;
    scoped_refptr<SSLPrivateKey> client_private_key;
    bool send_client_cert = context->GetClientCertificate(
        host_and_port, &client_cert, &client_private_key);

    EXPECT_EQ(*next_ssl_data->expected_send_client_cert, send_client_cert);
    // Note |send_client_cert| may be true while |client_cert| is null if the
    // socket is configured to continue without a certificate, as opposed to
    // surfacing the certificate challenge.
    EXPECT_EQ(!!next_ssl_data->expected_client_cert, !!client_cert);
    if (next_ssl_data->expected_client_cert && client_cert) {
      EXPECT_TRUE(next_ssl_data->expected_client_cert->EqualsIncludingChain(
          client_cert.get()));
    }
  }
  if (next_ssl_data->expected_host_and_port) {
    EXPECT_EQ(*next_ssl_data->expected_host_and_port, host_and_port);
  }
  if (next_ssl_data->expected_ignore_certificate_errors) {
    EXPECT_EQ(*next_ssl_data->expected_ignore_certificate_errors,
              ssl_config.ignore_certificate_errors);
  }
  if (next_ssl_data->expected_network_anonymization_key) {
    EXPECT_EQ(*next_ssl_data->expected_network_anonymization_key,
              ssl_config.network_anonymization_key);
  }
  if (next_ssl_data->expected_ech_config_list) {
    EXPECT_EQ(*next_ssl_data->expected_ech_config_list,
              ssl_config.ech_config_list);
  }
  return std::make_unique<MockSSLClientSocket>(
      std::move(stream_socket), host_and_port, ssl_config, next_ssl_data);
}
961 
MockClientSocket(const NetLogWithSource & net_log)962 MockClientSocket::MockClientSocket(const NetLogWithSource& net_log)
963     : net_log_(net_log) {
964   local_addr_ = IPEndPoint(IPAddress(192, 0, 2, 33), 123);
965   peer_addr_ = IPEndPoint(IPAddress(192, 0, 2, 33), 0);
966 }
967 
SetReceiveBufferSize(int32_t size)968 int MockClientSocket::SetReceiveBufferSize(int32_t size) {
969   return OK;
970 }
971 
SetSendBufferSize(int32_t size)972 int MockClientSocket::SetSendBufferSize(int32_t size) {
973   return OK;
974 }
975 
Bind(const net::IPEndPoint & local_addr)976 int MockClientSocket::Bind(const net::IPEndPoint& local_addr) {
977   local_addr_ = local_addr;
978   return net::OK;
979 }
980 
SetNoDelay(bool no_delay)981 bool MockClientSocket::SetNoDelay(bool no_delay) {
982   return true;
983 }
984 
SetKeepAlive(bool enable,int delay)985 bool MockClientSocket::SetKeepAlive(bool enable, int delay) {
986   return true;
987 }
988 
Disconnect()989 void MockClientSocket::Disconnect() {
990   connected_ = false;
991 }
992 
IsConnected() const993 bool MockClientSocket::IsConnected() const {
994   return connected_;
995 }
996 
IsConnectedAndIdle() const997 bool MockClientSocket::IsConnectedAndIdle() const {
998   return connected_;
999 }
1000 
GetPeerAddress(IPEndPoint * address) const1001 int MockClientSocket::GetPeerAddress(IPEndPoint* address) const {
1002   if (!IsConnected())
1003     return ERR_SOCKET_NOT_CONNECTED;
1004   *address = peer_addr_;
1005   return OK;
1006 }
1007 
GetLocalAddress(IPEndPoint * address) const1008 int MockClientSocket::GetLocalAddress(IPEndPoint* address) const {
1009   *address = local_addr_;
1010   return OK;
1011 }
1012 
NetLog() const1013 const NetLogWithSource& MockClientSocket::NetLog() const {
1014   return net_log_;
1015 }
1016 
GetNegotiatedProtocol() const1017 NextProto MockClientSocket::GetNegotiatedProtocol() const {
1018   return kProtoUnknown;
1019 }
1020 
1021 MockClientSocket::~MockClientSocket() = default;
1022 
RunCallbackAsync(CompletionOnceCallback callback,int result)1023 void MockClientSocket::RunCallbackAsync(CompletionOnceCallback callback,
1024                                         int result) {
1025   base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
1026       FROM_HERE,
1027       base::BindOnce(&MockClientSocket::RunCallback, weak_factory_.GetWeakPtr(),
1028                      std::move(callback), result));
1029 }
1030 
RunCallback(CompletionOnceCallback callback,int result)1031 void MockClientSocket::RunCallback(CompletionOnceCallback callback,
1032                                    int result) {
1033   std::move(callback).Run(result);
1034 }
1035 
MockTCPClientSocket(const AddressList & addresses,net::NetLog * net_log,SocketDataProvider * data)1036 MockTCPClientSocket::MockTCPClientSocket(const AddressList& addresses,
1037                                          net::NetLog* net_log,
1038                                          SocketDataProvider* data)
1039     : MockClientSocket(
1040           NetLogWithSource::Make(net_log, NetLogSourceType::SOCKET)),
1041       addresses_(addresses),
1042       data_(data),
1043       read_data_(SYNCHRONOUS, ERR_UNEXPECTED) {
1044   DCHECK(data_);
1045   peer_addr_ = data->connect_data().peer_addr;
1046   data_->Initialize(this);
1047   if (data_->expected_addresses()) {
1048     EXPECT_EQ(*data_->expected_addresses(), addresses);
1049   }
1050 }
1051 
~MockTCPClientSocket()1052 MockTCPClientSocket::~MockTCPClientSocket() {
1053   if (data_)
1054     data_->DetachSocket();
1055 }
1056 
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)1057 int MockTCPClientSocket::Read(IOBuffer* buf,
1058                               int buf_len,
1059                               CompletionOnceCallback callback) {
1060   // If the buffer is already in use, a read is already in progress!
1061   DCHECK(!pending_read_buf_);
1062   // Use base::Unretained() is safe because MockClientSocket::RunCallbackAsync()
1063   // takes a weak ptr of the base class, MockClientSocket.
1064   int rv = ReadIfReadyImpl(
1065       buf, buf_len,
1066       base::BindOnce(&MockTCPClientSocket::RetryRead, base::Unretained(this)));
1067   if (rv == ERR_IO_PENDING) {
1068     DCHECK(callback);
1069 
1070     pending_read_buf_ = buf;
1071     pending_read_buf_len_ = buf_len;
1072     pending_read_callback_ = std::move(callback);
1073   }
1074   return rv;
1075 }
1076 
ReadIfReady(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)1077 int MockTCPClientSocket::ReadIfReady(IOBuffer* buf,
1078                                      int buf_len,
1079                                      CompletionOnceCallback callback) {
1080   DCHECK(!pending_read_if_ready_callback_);
1081 
1082   if (!enable_read_if_ready_)
1083     return ERR_READ_IF_READY_NOT_IMPLEMENTED;
1084   return ReadIfReadyImpl(buf, buf_len, std::move(callback));
1085 }
1086 
CancelReadIfReady()1087 int MockTCPClientSocket::CancelReadIfReady() {
1088   DCHECK(pending_read_if_ready_callback_);
1089 
1090   pending_read_if_ready_callback_.Reset();
1091   data_->CancelPendingRead();
1092   return OK;
1093 }
1094 
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag &)1095 int MockTCPClientSocket::Write(
1096     IOBuffer* buf,
1097     int buf_len,
1098     CompletionOnceCallback callback,
1099     const NetworkTrafficAnnotationTag& /* traffic_annotation */) {
1100   DCHECK(buf);
1101   DCHECK_GT(buf_len, 0);
1102 
1103   if (!connected_ || !data_)
1104     return ERR_UNEXPECTED;
1105 
1106   std::string data(buf->data(), buf_len);
1107   MockWriteResult write_result = data_->OnWrite(data);
1108 
1109   was_used_to_convey_data_ = true;
1110 
1111   if (write_result.result == ERR_CONNECTION_CLOSED) {
1112     // This MockWrite is just a marker to instruct us to set
1113     // peer_closed_connection_.
1114     peer_closed_connection_ = true;
1115   }
1116   // ERR_IO_PENDING is a signal that the socket data will call back
1117   // asynchronously later.
1118   if (write_result.result == ERR_IO_PENDING) {
1119     pending_write_callback_ = std::move(callback);
1120     return ERR_IO_PENDING;
1121   }
1122 
1123   if (write_result.mode == ASYNC) {
1124     RunCallbackAsync(std::move(callback), write_result.result);
1125     return ERR_IO_PENDING;
1126   }
1127 
1128   return write_result.result;
1129 }
1130 
SetReceiveBufferSize(int32_t size)1131 int MockTCPClientSocket::SetReceiveBufferSize(int32_t size) {
1132   if (!connected_)
1133     return net::ERR_UNEXPECTED;
1134   data_->set_receive_buffer_size(size);
1135   return data_->set_receive_buffer_size_result();
1136 }
1137 
SetSendBufferSize(int32_t size)1138 int MockTCPClientSocket::SetSendBufferSize(int32_t size) {
1139   if (!connected_)
1140     return net::ERR_UNEXPECTED;
1141   data_->set_send_buffer_size(size);
1142   return data_->set_send_buffer_size_result();
1143 }
1144 
SetNoDelay(bool no_delay)1145 bool MockTCPClientSocket::SetNoDelay(bool no_delay) {
1146   if (!connected_)
1147     return false;
1148   data_->set_no_delay(no_delay);
1149   return data_->set_no_delay_result();
1150 }
1151 
SetKeepAlive(bool enable,int delay)1152 bool MockTCPClientSocket::SetKeepAlive(bool enable, int delay) {
1153   if (!connected_)
1154     return false;
1155   data_->set_keep_alive(enable, delay);
1156   return data_->set_keep_alive_result();
1157 }
1158 
SetBeforeConnectCallback(const BeforeConnectCallback & before_connect_callback)1159 void MockTCPClientSocket::SetBeforeConnectCallback(
1160     const BeforeConnectCallback& before_connect_callback) {
1161   DCHECK(!before_connect_callback_);
1162   DCHECK(!connected_);
1163 
1164   before_connect_callback_ = before_connect_callback;
1165 }
1166 
// Simulates connecting, driven by the provider's MockConnect data.
//
// If |before_connect_callback_| is set, it runs once per simulated attempt.
// When the connect data says the first attempt fails, the hook's result for
// address 0 is ignored and the hook runs again for the next address. A non-OK
// hook result aborts the connect and leaves the socket disconnected.
//
// After that, the result is reported in one of four ways:
// - A completer in the connect data takes |callback|.
// - Otherwise a SYNCHRONOUS result is returned directly.
// - Otherwise an ERR_IO_PENDING result parks |callback| until
//   OnConnectComplete().
// - Otherwise the result is posted asynchronously.
// NOTE(review): |connected_| stays true even when the canned connect result is
// an error; only a failing hook resets it.
int MockTCPClientSocket::Connect(CompletionOnceCallback callback) {
  if (!data_)
    return ERR_UNEXPECTED;

  if (connected_)
    return OK;

  // Setting socket options fails if not connected, so need to set this before
  // calling |before_connect_callback_|.
  connected_ = true;

  if (before_connect_callback_) {
    for (size_t index = 0; index < addresses_.size(); index++) {
      int result = before_connect_callback_.Run();
      // A simulated failed first attempt: ignore its hook result and retry
      // with the next address.
      if (data_->connect_data().first_attempt_fails && index == 0) {
        continue;
      }
      DCHECK_NE(result, ERR_IO_PENDING);
      if (result != net::OK) {
        connected_ = false;
        return result;
      }
      break;
    }
  }

  peer_closed_connection_ = false;

  if (data_->connect_data().completer) {
    data_->connect_data().completer->SetCallback(std::move(callback));
    return ERR_IO_PENDING;
  }

  int result = data_->connect_data().result;
  IoMode mode = data_->connect_data().mode;
  if (mode == SYNCHRONOUS)
    return result;

  DCHECK(callback);

  // ERR_IO_PENDING means the provider completes the connect later via
  // OnConnectComplete(); any other ASYNC result is posted.
  if (result == ERR_IO_PENDING)
    pending_connect_callback_ = std::move(callback);
  else
    RunCallbackAsync(std::move(callback), result);
  return ERR_IO_PENDING;
}
1213 
Disconnect()1214 void MockTCPClientSocket::Disconnect() {
1215   MockClientSocket::Disconnect();
1216   pending_connect_callback_.Reset();
1217   pending_read_callback_.Reset();
1218 }
1219 
IsConnected() const1220 bool MockTCPClientSocket::IsConnected() const {
1221   if (!data_)
1222     return false;
1223   return connected_ && !peer_closed_connection_;
1224 }
1225 
IsConnectedAndIdle() const1226 bool MockTCPClientSocket::IsConnectedAndIdle() const {
1227   if (!data_)
1228     return false;
1229   return IsConnected() && data_->IsIdle();
1230 }
1231 
GetPeerAddress(IPEndPoint * address) const1232 int MockTCPClientSocket::GetPeerAddress(IPEndPoint* address) const {
1233   if (addresses_.empty())
1234     return MockClientSocket::GetPeerAddress(address);
1235 
1236   if (data_->connect_data().first_attempt_fails) {
1237     DCHECK_GE(addresses_.size(), 2U);
1238     *address = addresses_[1];
1239   } else {
1240     *address = addresses_[0];
1241   }
1242   return OK;
1243 }
1244 
WasEverUsed() const1245 bool MockTCPClientSocket::WasEverUsed() const {
1246   return was_used_to_convey_data_;
1247 }
1248 
GetSSLInfo(SSLInfo * ssl_info)1249 bool MockTCPClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
1250   return false;
1251 }
1252 
OnReadComplete(const MockRead & data)1253 void MockTCPClientSocket::OnReadComplete(const MockRead& data) {
1254   // If |data_| has been destroyed, safest to just do nothing.
1255   if (!data_)
1256     return;
1257 
1258   // There must be a read pending.
1259   DCHECK(pending_read_if_ready_callback_);
1260   // You can't complete a read with another ERR_IO_PENDING status code.
1261   DCHECK_NE(ERR_IO_PENDING, data.result);
1262   // Since we've been waiting for data, need_read_data_ should be true.
1263   DCHECK(need_read_data_);
1264 
1265   read_data_ = data;
1266   need_read_data_ = false;
1267 
1268   // The caller is simulating that this IO completes right now.  Don't
1269   // let CompleteRead() schedule a callback.
1270   read_data_.mode = SYNCHRONOUS;
1271   RunCallback(std::move(pending_read_if_ready_callback_),
1272               read_data_.result > 0 ? OK : read_data_.result);
1273 }
1274 
OnWriteComplete(int rv)1275 void MockTCPClientSocket::OnWriteComplete(int rv) {
1276   // If |data_| has been destroyed, safest to just do nothing.
1277   if (!data_)
1278     return;
1279 
1280   // There must be a read pending.
1281   DCHECK(!pending_write_callback_.is_null());
1282   RunCallback(std::move(pending_write_callback_), rv);
1283 }
1284 
OnConnectComplete(const MockConnect & data)1285 void MockTCPClientSocket::OnConnectComplete(const MockConnect& data) {
1286   // If |data_| has been destroyed, safest to just do nothing.
1287   if (!data_)
1288     return;
1289 
1290   RunCallback(std::move(pending_connect_callback_), data.result);
1291 }
1292 
OnDataProviderDestroyed()1293 void MockTCPClientSocket::OnDataProviderDestroyed() {
1294   data_ = nullptr;
1295 }
1296 
RetryRead(int rv)1297 void MockTCPClientSocket::RetryRead(int rv) {
1298   DCHECK(pending_read_callback_);
1299   DCHECK(pending_read_buf_.get());
1300   DCHECK_LT(0, pending_read_buf_len_);
1301 
1302   if (rv == OK) {
1303     rv = ReadIfReadyImpl(pending_read_buf_.get(), pending_read_buf_len_,
1304                          base::BindOnce(&MockTCPClientSocket::RetryRead,
1305                                         base::Unretained(this)));
1306     if (rv == ERR_IO_PENDING)
1307       return;
1308   }
1309   pending_read_buf_ = nullptr;
1310   pending_read_buf_len_ = 0;
1311   RunCallback(std::move(pending_read_callback_), rv);
1312 }
1313 
// Core read logic shared by Read() and ReadIfReady().
//
// When |need_read_data_| is set, the next MockRead is fetched from the
// provider. Peer-close markers are applied at that point. An ERR_IO_PENDING
// read parks |callback| until OnReadComplete().
//
// An ASYNC MockRead posts its result to |callback| and flips the stored read
// to SYNCHRONOUS. Because |need_read_data_| is already false, the caller's
// retry then serves the same MockRead synchronously.
//
// A synchronous read copies up to |buf_len| bytes starting at |read_offset_|.
// A fresh MockRead is fetched only after the current one is fully consumed.
// MockReads without data are never marked consumed, so every later call
// returns the same result code.
int MockTCPClientSocket::ReadIfReadyImpl(IOBuffer* buf,
                                         int buf_len,
                                         CompletionOnceCallback callback) {
  if (!connected_ || !data_)
    return ERR_UNEXPECTED;

  DCHECK(!pending_read_if_ready_callback_);

  if (need_read_data_) {
    read_data_ = data_->OnRead();
    if (read_data_.result == ERR_CONNECTION_CLOSED) {
      // This MockRead is just a marker to instruct us to set
      // peer_closed_connection_.
      peer_closed_connection_ = true;
    }
    if (read_data_.result == ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ) {
      // This MockRead is just a marker to instruct us to set
      // peer_closed_connection_.  Skip it and get the next one.
      read_data_ = data_->OnRead();
      peer_closed_connection_ = true;
    }
    // ERR_IO_PENDING means that the SocketDataProvider is taking responsibility
    // to complete the async IO manually later (via OnReadComplete).
    if (read_data_.result == ERR_IO_PENDING) {
      // We need to be using async IO in this case.
      DCHECK(!callback.is_null());
      pending_read_if_ready_callback_ = std::move(callback);
      return ERR_IO_PENDING;
    }
    need_read_data_ = false;
  }

  int result = read_data_.result;
  DCHECK_NE(ERR_IO_PENDING, result);
  if (read_data_.mode == ASYNC) {
    DCHECK(!callback.is_null());
    // The next call (the caller's retry) serves this same MockRead
    // synchronously.
    read_data_.mode = SYNCHRONOUS;
    pending_read_if_ready_callback_ = std::move(callback);
    // base::Unretained() is safe here because RunCallbackAsync will wrap it
    // with a callback associated with a weak ptr.
    RunCallbackAsync(
        base::BindOnce(&MockTCPClientSocket::RunReadIfReadyCallback,
                       base::Unretained(this)),
        result);
    return ERR_IO_PENDING;
  }

  was_used_to_convey_data_ = true;
  if (read_data_.data) {
    // Copy the next slice of the current MockRead; fetch a new one only once
    // this one has been fully handed out.
    if (read_data_.data_len - read_offset_ > 0) {
      result = std::min(buf_len, read_data_.data_len - read_offset_);
      memcpy(buf->data(), read_data_.data + read_offset_, result);
      read_offset_ += result;
      if (read_offset_ == read_data_.data_len) {
        need_read_data_ = true;
        read_offset_ = 0;
      }
    } else {
      result = 0;  // EOF
    }
  }
  return result;
}
1377 
RunReadIfReadyCallback(int result)1378 void MockTCPClientSocket::RunReadIfReadyCallback(int result) {
1379   // If ReadIfReady is already canceled, do nothing.
1380   if (!pending_read_if_ready_callback_)
1381     return;
1382   std::move(pending_read_if_ready_callback_).Run(result);
1383 }
1384 
1385 // static
ConnectCallback(MockSSLClientSocket * ssl_client_socket,CompletionOnceCallback callback,int rv)1386 void MockSSLClientSocket::ConnectCallback(
1387     MockSSLClientSocket* ssl_client_socket,
1388     CompletionOnceCallback callback,
1389     int rv) {
1390   if (rv == OK)
1391     ssl_client_socket->connected_ = true;
1392   std::move(callback).Run(rv);
1393 }
1394 
MockSSLClientSocket(std::unique_ptr<StreamSocket> stream_socket,const HostPortPair & host_and_port,const SSLConfig & ssl_config,SSLSocketDataProvider * data)1395 MockSSLClientSocket::MockSSLClientSocket(
1396     std::unique_ptr<StreamSocket> stream_socket,
1397     const HostPortPair& host_and_port,
1398     const SSLConfig& ssl_config,
1399     SSLSocketDataProvider* data)
1400     : net_log_(stream_socket->NetLog()),
1401       stream_socket_(std::move(stream_socket)),
1402       data_(data) {
1403   DCHECK(data_);
1404   peer_addr_ = data->connect.peer_addr;
1405 }
1406 
~MockSSLClientSocket()1407 MockSSLClientSocket::~MockSSLClientSocket() {
1408   Disconnect();
1409 }
1410 
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)1411 int MockSSLClientSocket::Read(IOBuffer* buf,
1412                               int buf_len,
1413                               CompletionOnceCallback callback) {
1414   return stream_socket_->Read(buf, buf_len, std::move(callback));
1415 }
1416 
ReadIfReady(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)1417 int MockSSLClientSocket::ReadIfReady(IOBuffer* buf,
1418                                      int buf_len,
1419                                      CompletionOnceCallback callback) {
1420   return stream_socket_->ReadIfReady(buf, buf_len, std::move(callback));
1421 }
1422 
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)1423 int MockSSLClientSocket::Write(
1424     IOBuffer* buf,
1425     int buf_len,
1426     CompletionOnceCallback callback,
1427     const NetworkTrafficAnnotationTag& traffic_annotation) {
1428   if (!data_->is_confirm_data_consumed)
1429     data_->write_called_before_confirm = true;
1430   return stream_socket_->Write(buf, buf_len, std::move(callback),
1431                                traffic_annotation);
1432 }
1433 
CancelReadIfReady()1434 int MockSSLClientSocket::CancelReadIfReady() {
1435   return stream_socket_->CancelReadIfReady();
1436 }
1437 
Connect(CompletionOnceCallback callback)1438 int MockSSLClientSocket::Connect(CompletionOnceCallback callback) {
1439   DCHECK(stream_socket_->IsConnected());
1440   data_->is_connect_data_consumed = true;
1441   if (data_->connect.completer) {
1442     data_->connect.completer->SetCallback(std::move(callback));
1443     return ERR_IO_PENDING;
1444   }
1445   if (data_->connect.result == OK)
1446     connected_ = true;
1447   RunClosureIfNonNull(std::move(data_->connect_callback));
1448   if (data_->connect.mode == ASYNC) {
1449     RunCallbackAsync(std::move(callback), data_->connect.result);
1450     return ERR_IO_PENDING;
1451   }
1452   return data_->connect.result;
1453 }
1454 
Disconnect()1455 void MockSSLClientSocket::Disconnect() {
1456   if (stream_socket_ != nullptr)
1457     stream_socket_->Disconnect();
1458 }
1459 
RunConfirmHandshakeCallback(CompletionOnceCallback callback,int result)1460 void MockSSLClientSocket::RunConfirmHandshakeCallback(
1461     CompletionOnceCallback callback,
1462     int result) {
1463   DCHECK(in_confirm_handshake_);
1464   in_confirm_handshake_ = false;
1465   data_->is_confirm_data_consumed = true;
1466   std::move(callback).Run(result);
1467 }
1468 
// Simulates ConfirmHandshake() from the provider's MockConfirm data.
//
// Once the confirm data has been consumed, later calls simply return its
// result. For ASYNC data, the result is posted, and the data is marked
// consumed only when RunConfirmHandshakeCallback() runs. SYNCHRONOUS data is
// marked consumed and returned immediately.
int MockSSLClientSocket::ConfirmHandshake(CompletionOnceCallback callback) {
  DCHECK(stream_socket_->IsConnected());
  DCHECK(!in_confirm_handshake_);
  if (data_->is_confirm_data_consumed)
    return data_->confirm.result;
  RunClosureIfNonNull(std::move(data_->confirm_callback));
  if (data_->confirm.mode == ASYNC) {
    in_confirm_handshake_ = true;
    RunCallbackAsync(
        base::BindOnce(&MockSSLClientSocket::RunConfirmHandshakeCallback,
                       base::Unretained(this), std::move(callback)),
        data_->confirm.result);
    return ERR_IO_PENDING;
  }
  data_->is_confirm_data_consumed = true;
  if (data_->confirm.result == ERR_IO_PENDING) {
    // `MockConfirm(SYNCHRONOUS, ERR_IO_PENDING)` means `ConfirmHandshake()`
    // never completes.
    in_confirm_handshake_ = true;
  }
  return data_->confirm.result;
}
1491 
IsConnected() const1492 bool MockSSLClientSocket::IsConnected() const {
1493   return stream_socket_->IsConnected();
1494 }
1495 
IsConnectedAndIdle() const1496 bool MockSSLClientSocket::IsConnectedAndIdle() const {
1497   return stream_socket_->IsConnectedAndIdle();
1498 }
1499 
WasEverUsed() const1500 bool MockSSLClientSocket::WasEverUsed() const {
1501   return stream_socket_->WasEverUsed();
1502 }
1503 
GetLocalAddress(IPEndPoint * address) const1504 int MockSSLClientSocket::GetLocalAddress(IPEndPoint* address) const {
1505   *address = IPEndPoint(IPAddress(192, 0, 2, 33), 123);
1506   return OK;
1507 }
1508 
GetPeerAddress(IPEndPoint * address) const1509 int MockSSLClientSocket::GetPeerAddress(IPEndPoint* address) const {
1510   return stream_socket_->GetPeerAddress(address);
1511 }
1512 
GetNegotiatedProtocol() const1513 NextProto MockSSLClientSocket::GetNegotiatedProtocol() const {
1514   return data_->next_proto;
1515 }
1516 
1517 std::optional<std::string_view>
GetPeerApplicationSettings() const1518 MockSSLClientSocket::GetPeerApplicationSettings() const {
1519   return data_->peer_application_settings;
1520 }
1521 
GetSSLInfo(SSLInfo * requested_ssl_info)1522 bool MockSSLClientSocket::GetSSLInfo(SSLInfo* requested_ssl_info) {
1523   *requested_ssl_info = data_->ssl_info;
1524   return true;
1525 }
1526 
ApplySocketTag(const SocketTag & tag)1527 void MockSSLClientSocket::ApplySocketTag(const SocketTag& tag) {
1528   return stream_socket_->ApplySocketTag(tag);
1529 }
1530 
NetLog() const1531 const NetLogWithSource& MockSSLClientSocket::NetLog() const {
1532   return net_log_;
1533 }
1534 
// Byte accounting is not modelled by the SSL mock; this logs NOTIMPLEMENTED
// and reports 0.
int64_t MockSSLClientSocket::GetTotalReceivedBytes() const {
  NOTIMPLEMENTED();
  return 0;
}

// The base client-socket mock does not model byte accounting either.
int64_t MockClientSocket::GetTotalReceivedBytes() const {
  NOTIMPLEMENTED();
  return 0;
}

// Buffer-size requests are accepted and ignored. They are not forwarded to
// the wrapped transport socket.
int MockSSLClientSocket::SetReceiveBufferSize(int32_t size) {
  return OK;
}

int MockSSLClientSocket::SetSendBufferSize(int32_t size) {
  return OK;
}
1552 
GetSSLCertRequestInfo(SSLCertRequestInfo * cert_request_info) const1553 void MockSSLClientSocket::GetSSLCertRequestInfo(
1554     SSLCertRequestInfo* cert_request_info) const {
1555   DCHECK(cert_request_info);
1556   if (data_->cert_request_info) {
1557     cert_request_info->host_and_port = data_->cert_request_info->host_and_port;
1558     cert_request_info->is_proxy = data_->cert_request_info->is_proxy;
1559     cert_request_info->cert_authorities =
1560         data_->cert_request_info->cert_authorities;
1561     cert_request_info->signature_algorithms =
1562         data_->cert_request_info->signature_algorithms;
1563   } else {
1564     cert_request_info->Reset();
1565   }
1566 }
1567 
ExportKeyingMaterial(std::string_view label,std::optional<base::span<const uint8_t>> context,base::span<uint8_t> out)1568 int MockSSLClientSocket::ExportKeyingMaterial(
1569     std::string_view label,
1570     std::optional<base::span<const uint8_t>> context,
1571     base::span<uint8_t> out) {
1572   std::ranges::fill(out, 'A');
1573   return OK;
1574 }
1575 
GetECHRetryConfigs()1576 std::vector<uint8_t> MockSSLClientSocket::GetECHRetryConfigs() {
1577   return data_->ech_retry_configs;
1578 }
1579 
RunCallbackAsync(CompletionOnceCallback callback,int result)1580 void MockSSLClientSocket::RunCallbackAsync(CompletionOnceCallback callback,
1581                                            int result) {
1582   base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
1583       FROM_HERE,
1584       base::BindOnce(&MockSSLClientSocket::RunCallback,
1585                      weak_factory_.GetWeakPtr(), std::move(callback), result));
1586 }
1587 
RunCallback(CompletionOnceCallback callback,int result)1588 void MockSSLClientSocket::RunCallback(CompletionOnceCallback callback,
1589                                       int result) {
1590   std::move(callback).Run(result);
1591 }
1592 
OnReadComplete(const MockRead & data)1593 void MockSSLClientSocket::OnReadComplete(const MockRead& data) {
1594   NOTIMPLEMENTED();
1595 }
1596 
OnWriteComplete(int rv)1597 void MockSSLClientSocket::OnWriteComplete(int rv) {
1598   NOTIMPLEMENTED();
1599 }
1600 
OnConnectComplete(const MockConnect & data)1601 void MockSSLClientSocket::OnConnectComplete(const MockConnect& data) {
1602   NOTIMPLEMENTED();
1603 }
1604 
MockUDPClientSocket(SocketDataProvider * data,net::NetLog * net_log)1605 MockUDPClientSocket::MockUDPClientSocket(SocketDataProvider* data,
1606                                          net::NetLog* net_log)
1607     : data_(data),
1608       read_data_(SYNCHRONOUS, ERR_UNEXPECTED),
1609       source_host_(IPAddress(192, 0, 2, 33)),
1610       net_log_(NetLogWithSource::Make(net_log,
1611                                       NetLogSourceType::UDP_CLIENT_SOCKET)) {
1612   if (data_) {
1613     data_->Initialize(this);
1614     peer_addr_ = data->connect_data().peer_addr;
1615   }
1616 }
1617 
~MockUDPClientSocket()1618 MockUDPClientSocket::~MockUDPClientSocket() {
1619   if (data_)
1620     data_->DetachSocket();
1621 }
1622 
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)1623 int MockUDPClientSocket::Read(IOBuffer* buf,
1624                               int buf_len,
1625                               CompletionOnceCallback callback) {
1626   DCHECK(callback);
1627 
1628   if (!connected_ || !data_)
1629     return ERR_UNEXPECTED;
1630   data_transferred_ = true;
1631 
1632   // If the buffer is already in use, a read is already in progress!
1633   DCHECK(!pending_read_buf_);
1634 
1635   // Store our async IO data.
1636   pending_read_buf_ = buf;
1637   pending_read_buf_len_ = buf_len;
1638   pending_read_callback_ = std::move(callback);
1639 
1640   if (need_read_data_) {
1641     read_data_ = data_->OnRead();
1642     last_tos_ = read_data_.tos;
1643     // ERR_IO_PENDING means that the SocketDataProvider is taking responsibility
1644     // to complete the async IO manually later (via OnReadComplete).
1645     if (read_data_.result == ERR_IO_PENDING) {
1646       // We need to be using async IO in this case.
1647       DCHECK(!pending_read_callback_.is_null());
1648       return ERR_IO_PENDING;
1649     }
1650     need_read_data_ = false;
1651   }
1652 
1653   return CompleteRead();
1654 }
1655 
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag &)1656 int MockUDPClientSocket::Write(
1657     IOBuffer* buf,
1658     int buf_len,
1659     CompletionOnceCallback callback,
1660     const NetworkTrafficAnnotationTag& /* traffic_annotation */) {
1661   DCHECK(buf);
1662   DCHECK_GT(buf_len, 0);
1663   DCHECK(callback);
1664 
1665   if (!connected_ || !data_)
1666     return ERR_UNEXPECTED;
1667   data_transferred_ = true;
1668 
1669   std::string data(buf->data(), buf_len);
1670   MockWriteResult write_result = data_->OnWrite(data);
1671 
1672   // ERR_IO_PENDING is a signal that the socket data will call back
1673   // asynchronously.
1674   if (write_result.result == ERR_IO_PENDING) {
1675     pending_write_callback_ = std::move(callback);
1676     return ERR_IO_PENDING;
1677   }
1678   if (write_result.mode == ASYNC) {
1679     RunCallbackAsync(std::move(callback), write_result.result);
1680     return ERR_IO_PENDING;
1681   }
1682   return write_result.result;
1683 }
1684 
SetReceiveBufferSize(int32_t size)1685 int MockUDPClientSocket::SetReceiveBufferSize(int32_t size) {
1686   return OK;
1687 }
1688 
SetSendBufferSize(int32_t size)1689 int MockUDPClientSocket::SetSendBufferSize(int32_t size) {
1690   return OK;
1691 }
1692 
SetDoNotFragment()1693 int MockUDPClientSocket::SetDoNotFragment() {
1694   return OK;
1695 }
1696 
SetRecvTos()1697 int MockUDPClientSocket::SetRecvTos() {
1698   return OK;
1699 }
1700 
SetTos(DiffServCodePoint dscp,EcnCodePoint ecn)1701 int MockUDPClientSocket::SetTos(DiffServCodePoint dscp, EcnCodePoint ecn) {
1702   return OK;
1703 }
1704 
Close()1705 void MockUDPClientSocket::Close() {
1706   connected_ = false;
1707 }
1708 
GetPeerAddress(IPEndPoint * address) const1709 int MockUDPClientSocket::GetPeerAddress(IPEndPoint* address) const {
1710   if (!data_)
1711     return ERR_UNEXPECTED;
1712 
1713   *address = peer_addr_;
1714   return OK;
1715 }
1716 
GetLocalAddress(IPEndPoint * address) const1717 int MockUDPClientSocket::GetLocalAddress(IPEndPoint* address) const {
1718   *address = IPEndPoint(source_host_, source_port_);
1719   return OK;
1720 }
1721 
UseNonBlockingIO()1722 void MockUDPClientSocket::UseNonBlockingIO() {}
1723 
SetMulticastInterface(uint32_t interface_index)1724 int MockUDPClientSocket::SetMulticastInterface(uint32_t interface_index) {
1725   return OK;
1726 }
1727 
NetLog() const1728 const NetLogWithSource& MockUDPClientSocket::NetLog() const {
1729   return net_log_;
1730 }
1731 
Connect(const IPEndPoint & address)1732 int MockUDPClientSocket::Connect(const IPEndPoint& address) {
1733   if (!data_)
1734     return ERR_UNEXPECTED;
1735   DCHECK_NE(data_->connect_data().result, ERR_IO_PENDING);
1736   connected_ = true;
1737   peer_addr_ = address;
1738   return data_->connect_data().result;
1739 }
1740 
ConnectUsingNetwork(handles::NetworkHandle network,const IPEndPoint & address)1741 int MockUDPClientSocket::ConnectUsingNetwork(handles::NetworkHandle network,
1742                                              const IPEndPoint& address) {
1743   DCHECK(!connected_);
1744   if (!data_)
1745     return ERR_UNEXPECTED;
1746   DCHECK_NE(data_->connect_data().result, ERR_IO_PENDING);
1747   network_ = network;
1748   connected_ = true;
1749   peer_addr_ = address;
1750   return data_->connect_data().result;
1751 }
1752 
ConnectUsingDefaultNetwork(const IPEndPoint & address)1753 int MockUDPClientSocket::ConnectUsingDefaultNetwork(const IPEndPoint& address) {
1754   DCHECK(!connected_);
1755   if (!data_)
1756     return ERR_UNEXPECTED;
1757   DCHECK_NE(data_->connect_data().result, ERR_IO_PENDING);
1758   network_ = kDefaultNetworkForTests;
1759   connected_ = true;
1760   peer_addr_ = address;
1761   return data_->connect_data().result;
1762 }
1763 
ConnectAsync(const IPEndPoint & address,CompletionOnceCallback callback)1764 int MockUDPClientSocket::ConnectAsync(const IPEndPoint& address,
1765                                       CompletionOnceCallback callback) {
1766   DCHECK(callback);
1767   if (!data_) {
1768     return ERR_UNEXPECTED;
1769   }
1770   connected_ = true;
1771   peer_addr_ = address;
1772   int result = data_->connect_data().result;
1773   IoMode mode = data_->connect_data().mode;
1774   if (data_->connect_data().completer) {
1775     data_->connect_data().completer->SetCallback(std::move(callback));
1776     return ERR_IO_PENDING;
1777   }
1778   if (mode == SYNCHRONOUS) {
1779     return result;
1780   }
1781   RunCallbackAsync(std::move(callback), result);
1782   return ERR_IO_PENDING;
1783 }
1784 
ConnectUsingNetworkAsync(handles::NetworkHandle network,const IPEndPoint & address,CompletionOnceCallback callback)1785 int MockUDPClientSocket::ConnectUsingNetworkAsync(
1786     handles::NetworkHandle network,
1787     const IPEndPoint& address,
1788     CompletionOnceCallback callback) {
1789   DCHECK(callback);
1790   DCHECK(!connected_);
1791   if (!data_)
1792     return ERR_UNEXPECTED;
1793   network_ = network;
1794   connected_ = true;
1795   peer_addr_ = address;
1796   int result = data_->connect_data().result;
1797   IoMode mode = data_->connect_data().mode;
1798   if (data_->connect_data().completer) {
1799     data_->connect_data().completer->SetCallback(std::move(callback));
1800     return ERR_IO_PENDING;
1801   }
1802   if (mode == SYNCHRONOUS) {
1803     return result;
1804   }
1805   RunCallbackAsync(std::move(callback), result);
1806   return ERR_IO_PENDING;
1807 }
1808 
ConnectUsingDefaultNetworkAsync(const IPEndPoint & address,CompletionOnceCallback callback)1809 int MockUDPClientSocket::ConnectUsingDefaultNetworkAsync(
1810     const IPEndPoint& address,
1811     CompletionOnceCallback callback) {
1812   DCHECK(!connected_);
1813   if (!data_)
1814     return ERR_UNEXPECTED;
1815   network_ = kDefaultNetworkForTests;
1816   connected_ = true;
1817   peer_addr_ = address;
1818   int result = data_->connect_data().result;
1819   IoMode mode = data_->connect_data().mode;
1820   if (data_->connect_data().completer) {
1821     data_->connect_data().completer->SetCallback(std::move(callback));
1822     return ERR_IO_PENDING;
1823   }
1824   if (mode == SYNCHRONOUS) {
1825     return result;
1826   }
1827   RunCallbackAsync(std::move(callback), result);
1828   return ERR_IO_PENDING;
1829 }
1830 
GetBoundNetwork() const1831 handles::NetworkHandle MockUDPClientSocket::GetBoundNetwork() const {
1832   return network_;
1833 }
1834 
ApplySocketTag(const SocketTag & tag)1835 void MockUDPClientSocket::ApplySocketTag(const SocketTag& tag) {
1836   tagged_before_data_transferred_ &= !data_transferred_ || tag == tag_;
1837   tag_ = tag;
1838 }
1839 
GetLastTos() const1840 DscpAndEcn MockUDPClientSocket::GetLastTos() const {
1841   return TosToDscpAndEcn(last_tos_);
1842 }
1843 
OnReadComplete(const MockRead & data)1844 void MockUDPClientSocket::OnReadComplete(const MockRead& data) {
1845   if (!data_)
1846     return;
1847 
1848   // There must be a read pending.
1849   DCHECK(pending_read_buf_.get());
1850   DCHECK(pending_read_callback_);
1851   // You can't complete a read with another ERR_IO_PENDING status code.
1852   DCHECK_NE(ERR_IO_PENDING, data.result);
1853   // Since we've been waiting for data, need_read_data_ should be true.
1854   DCHECK(need_read_data_);
1855 
1856   read_data_ = data;
1857   last_tos_ = data.tos;
1858   need_read_data_ = false;
1859 
1860   // The caller is simulating that this IO completes right now.  Don't
1861   // let CompleteRead() schedule a callback.
1862   read_data_.mode = SYNCHRONOUS;
1863 
1864   CompletionOnceCallback callback = std::move(pending_read_callback_);
1865   int rv = CompleteRead();
1866   RunCallback(std::move(callback), rv);
1867 }
1868 
OnWriteComplete(int rv)1869 void MockUDPClientSocket::OnWriteComplete(int rv) {
1870   if (!data_)
1871     return;
1872 
1873   // There must be a read pending.
1874   DCHECK(!pending_write_callback_.is_null());
1875   RunCallback(std::move(pending_write_callback_), rv);
1876 }
1877 
OnConnectComplete(const MockConnect & data)1878 void MockUDPClientSocket::OnConnectComplete(const MockConnect& data) {
1879   NOTIMPLEMENTED();
1880 }
1881 
OnDataProviderDestroyed()1882 void MockUDPClientSocket::OnDataProviderDestroyed() {
1883   data_ = nullptr;
1884 }
1885 
CompleteRead()1886 int MockUDPClientSocket::CompleteRead() {
1887   DCHECK(pending_read_buf_.get());
1888   DCHECK(pending_read_buf_len_ > 0);
1889 
1890   // Save the pending async IO data and reset our |pending_| state.
1891   scoped_refptr<IOBuffer> buf = pending_read_buf_;
1892   int buf_len = pending_read_buf_len_;
1893   CompletionOnceCallback callback = std::move(pending_read_callback_);
1894   pending_read_buf_ = nullptr;
1895   pending_read_buf_len_ = 0;
1896 
1897   int result = read_data_.result;
1898   DCHECK(result != ERR_IO_PENDING);
1899 
1900   if (read_data_.data) {
1901     if (read_data_.data_len - read_offset_ > 0) {
1902       result = std::min(buf_len, read_data_.data_len - read_offset_);
1903       memcpy(buf->data(), read_data_.data + read_offset_, result);
1904       read_offset_ += result;
1905       if (read_offset_ == read_data_.data_len) {
1906         need_read_data_ = true;
1907         read_offset_ = 0;
1908       }
1909     } else {
1910       result = 0;  // EOF
1911     }
1912   }
1913 
1914   if (read_data_.mode == ASYNC) {
1915     DCHECK(!callback.is_null());
1916     RunCallbackAsync(std::move(callback), result);
1917     return ERR_IO_PENDING;
1918   }
1919   return result;
1920 }
1921 
RunCallbackAsync(CompletionOnceCallback callback,int result)1922 void MockUDPClientSocket::RunCallbackAsync(CompletionOnceCallback callback,
1923                                            int result) {
1924   base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
1925       FROM_HERE,
1926       base::BindOnce(&MockUDPClientSocket::RunCallback,
1927                      weak_factory_.GetWeakPtr(), std::move(callback), result));
1928 }
1929 
RunCallback(CompletionOnceCallback callback,int result)1930 void MockUDPClientSocket::RunCallback(CompletionOnceCallback callback,
1931                                       int result) {
1932   std::move(callback).Run(result);
1933 }
1934 
TestSocketRequest(std::vector<raw_ptr<TestSocketRequest,VectorExperimental>> * request_order,size_t * completion_count)1935 TestSocketRequest::TestSocketRequest(
1936     std::vector<raw_ptr<TestSocketRequest, VectorExperimental>>* request_order,
1937     size_t* completion_count)
1938     : request_order_(request_order), completion_count_(completion_count) {
1939   DCHECK(request_order);
1940   DCHECK(completion_count);
1941 }
1942 
1943 TestSocketRequest::~TestSocketRequest() = default;
1944 
OnComplete(int result)1945 void TestSocketRequest::OnComplete(int result) {
1946   SetResult(result);
1947   (*completion_count_)++;
1948   request_order_->push_back(this);
1949 }
1950 
1951 // static
1952 const int ClientSocketPoolTest::kIndexOutOfBounds = -1;
1953 
1954 // static
1955 const int ClientSocketPoolTest::kRequestNotFound = -2;
1956 
1957 ClientSocketPoolTest::ClientSocketPoolTest() = default;
1958 ClientSocketPoolTest::~ClientSocketPoolTest() = default;
1959 
GetOrderOfRequest(size_t index) const1960 int ClientSocketPoolTest::GetOrderOfRequest(size_t index) const {
1961   index--;
1962   if (index >= requests_.size())
1963     return kIndexOutOfBounds;
1964 
1965   for (size_t i = 0; i < request_order_.size(); i++)
1966     if (requests_[index].get() == request_order_[i])
1967       return i + 1;
1968 
1969   return kRequestNotFound;
1970 }
1971 
ReleaseOneConnection(KeepAlive keep_alive)1972 bool ClientSocketPoolTest::ReleaseOneConnection(KeepAlive keep_alive) {
1973   for (std::unique_ptr<TestSocketRequest>& it : requests_) {
1974     if (it->handle()->is_initialized()) {
1975       if (keep_alive == NO_KEEP_ALIVE)
1976         it->handle()->socket()->Disconnect();
1977       it->handle()->Reset();
1978       base::RunLoop().RunUntilIdle();
1979       return true;
1980     }
1981   }
1982   return false;
1983 }
1984 
ReleaseAllConnections(KeepAlive keep_alive)1985 void ClientSocketPoolTest::ReleaseAllConnections(KeepAlive keep_alive) {
1986   bool released_one;
1987   do {
1988     released_one = ReleaseOneConnection(keep_alive);
1989   } while (released_one);
1990 }
1991 
MockConnectJob(std::unique_ptr<StreamSocket> socket,ClientSocketHandle * handle,const SocketTag & socket_tag,CompletionOnceCallback callback,RequestPriority priority)1992 MockTransportClientSocketPool::MockConnectJob::MockConnectJob(
1993     std::unique_ptr<StreamSocket> socket,
1994     ClientSocketHandle* handle,
1995     const SocketTag& socket_tag,
1996     CompletionOnceCallback callback,
1997     RequestPriority priority)
1998     : socket_(std::move(socket)),
1999       handle_(handle),
2000       socket_tag_(socket_tag),
2001       user_callback_(std::move(callback)),
2002       priority_(priority) {}
2003 
2004 MockTransportClientSocketPool::MockConnectJob::~MockConnectJob() = default;
2005 
Connect()2006 int MockTransportClientSocketPool::MockConnectJob::Connect() {
2007   socket_->ApplySocketTag(socket_tag_);
2008   int rv = socket_->Connect(
2009       base::BindOnce(&MockConnectJob::OnConnect, base::Unretained(this)));
2010   if (rv != ERR_IO_PENDING) {
2011     user_callback_.Reset();
2012     OnConnect(rv);
2013   }
2014   return rv;
2015 }
2016 
CancelHandle(const ClientSocketHandle * handle)2017 bool MockTransportClientSocketPool::MockConnectJob::CancelHandle(
2018     const ClientSocketHandle* handle) {
2019   if (handle != handle_)
2020     return false;
2021   socket_.reset();
2022   handle_ = nullptr;
2023   user_callback_.Reset();
2024   return true;
2025 }
2026 
OnConnect(int rv)2027 void MockTransportClientSocketPool::MockConnectJob::OnConnect(int rv) {
2028   if (!socket_.get())
2029     return;
2030   if (rv == OK) {
2031     handle_->SetSocket(std::move(socket_));
2032 
2033     // Needed for socket pool tests that layer other sockets on top of mock
2034     // sockets.
2035     LoadTimingInfo::ConnectTiming connect_timing;
2036     base::TimeTicks now = base::TimeTicks::Now();
2037     connect_timing.domain_lookup_start = now;
2038     connect_timing.domain_lookup_end = now;
2039     connect_timing.connect_start = now;
2040     connect_timing.connect_end = now;
2041     handle_->set_connect_timing(connect_timing);
2042   } else {
2043     socket_.reset();
2044 
2045     // Needed to test copying of ConnectionAttempts in SSL ConnectJob.
2046     ConnectionAttempts attempts;
2047     attempts.push_back(ConnectionAttempt(IPEndPoint(), rv));
2048     handle_->set_connection_attempts(attempts);
2049   }
2050 
2051   handle_ = nullptr;
2052 
2053   if (!user_callback_.is_null()) {
2054     std::move(user_callback_).Run(rv);
2055   }
2056 }
2057 
MockTransportClientSocketPool(int max_sockets,int max_sockets_per_group,const CommonConnectJobParams * common_connect_job_params)2058 MockTransportClientSocketPool::MockTransportClientSocketPool(
2059     int max_sockets,
2060     int max_sockets_per_group,
2061     const CommonConnectJobParams* common_connect_job_params)
2062     : TransportClientSocketPool(
2063           max_sockets,
2064           max_sockets_per_group,
2065           base::Seconds(10) /* unused_idle_socket_timeout */,
2066           ProxyChain::Direct(),
2067           false /* is_for_websockets */,
2068           common_connect_job_params),
2069       client_socket_factory_(common_connect_job_params->client_socket_factory) {
2070 }
2071 
2072 MockTransportClientSocketPool::~MockTransportClientSocketPool() = default;
2073 
RequestSocket(const ClientSocketPool::GroupId & group_id,scoped_refptr<ClientSocketPool::SocketParams> socket_params,const std::optional<NetworkTrafficAnnotationTag> & proxy_annotation_tag,RequestPriority priority,const SocketTag & socket_tag,RespectLimits respect_limits,ClientSocketHandle * handle,CompletionOnceCallback callback,const ProxyAuthCallback & on_auth_callback,const NetLogWithSource & net_log)2074 int MockTransportClientSocketPool::RequestSocket(
2075     const ClientSocketPool::GroupId& group_id,
2076     scoped_refptr<ClientSocketPool::SocketParams> socket_params,
2077     const std::optional<NetworkTrafficAnnotationTag>& proxy_annotation_tag,
2078     RequestPriority priority,
2079     const SocketTag& socket_tag,
2080     RespectLimits respect_limits,
2081     ClientSocketHandle* handle,
2082     CompletionOnceCallback callback,
2083     const ProxyAuthCallback& on_auth_callback,
2084     const NetLogWithSource& net_log) {
2085   last_request_priority_ = priority;
2086   std::unique_ptr<StreamSocket> socket =
2087       client_socket_factory_->CreateTransportClientSocket(
2088           AddressList(), nullptr, nullptr, net_log.net_log(), NetLogSource());
2089   auto job = std::make_unique<MockConnectJob>(
2090       std::move(socket), handle, socket_tag, std::move(callback), priority);
2091   auto* job_ptr = job.get();
2092   job_list_.push_back(std::move(job));
2093   handle->set_group_generation(1);
2094   return job_ptr->Connect();
2095 }
2096 
SetPriority(const ClientSocketPool::GroupId & group_id,ClientSocketHandle * handle,RequestPriority priority)2097 void MockTransportClientSocketPool::SetPriority(
2098     const ClientSocketPool::GroupId& group_id,
2099     ClientSocketHandle* handle,
2100     RequestPriority priority) {
2101   for (auto& job : job_list_) {
2102     if (job->handle() == handle) {
2103       job->set_priority(priority);
2104       return;
2105     }
2106   }
2107   NOTREACHED();
2108 }
2109 
CancelRequest(const ClientSocketPool::GroupId & group_id,ClientSocketHandle * handle,bool cancel_connect_job)2110 void MockTransportClientSocketPool::CancelRequest(
2111     const ClientSocketPool::GroupId& group_id,
2112     ClientSocketHandle* handle,
2113     bool cancel_connect_job) {
2114   for (std::unique_ptr<MockConnectJob>& it : job_list_) {
2115     if (it->CancelHandle(handle)) {
2116       cancel_count_++;
2117       break;
2118     }
2119   }
2120 }
2121 
ReleaseSocket(const ClientSocketPool::GroupId & group_id,std::unique_ptr<StreamSocket> socket,int64_t generation)2122 void MockTransportClientSocketPool::ReleaseSocket(
2123     const ClientSocketPool::GroupId& group_id,
2124     std::unique_ptr<StreamSocket> socket,
2125     int64_t generation) {
2126   EXPECT_EQ(1, generation);
2127   release_count_++;
2128 }
2129 
WrappedStreamSocket(std::unique_ptr<StreamSocket> transport)2130 WrappedStreamSocket::WrappedStreamSocket(
2131     std::unique_ptr<StreamSocket> transport)
2132     : transport_(std::move(transport)) {}
2133 WrappedStreamSocket::~WrappedStreamSocket() = default;
2134 
Bind(const net::IPEndPoint & local_addr)2135 int WrappedStreamSocket::Bind(const net::IPEndPoint& local_addr) {
2136   NOTREACHED();
2137 }
2138 
Connect(CompletionOnceCallback callback)2139 int WrappedStreamSocket::Connect(CompletionOnceCallback callback) {
2140   return transport_->Connect(std::move(callback));
2141 }
2142 
Disconnect()2143 void WrappedStreamSocket::Disconnect() {
2144   transport_->Disconnect();
2145 }
2146 
IsConnected() const2147 bool WrappedStreamSocket::IsConnected() const {
2148   return transport_->IsConnected();
2149 }
2150 
IsConnectedAndIdle() const2151 bool WrappedStreamSocket::IsConnectedAndIdle() const {
2152   return transport_->IsConnectedAndIdle();
2153 }
2154 
GetPeerAddress(IPEndPoint * address) const2155 int WrappedStreamSocket::GetPeerAddress(IPEndPoint* address) const {
2156   return transport_->GetPeerAddress(address);
2157 }
2158 
GetLocalAddress(IPEndPoint * address) const2159 int WrappedStreamSocket::GetLocalAddress(IPEndPoint* address) const {
2160   return transport_->GetLocalAddress(address);
2161 }
2162 
NetLog() const2163 const NetLogWithSource& WrappedStreamSocket::NetLog() const {
2164   return transport_->NetLog();
2165 }
2166 
WasEverUsed() const2167 bool WrappedStreamSocket::WasEverUsed() const {
2168   return transport_->WasEverUsed();
2169 }
2170 
GetNegotiatedProtocol() const2171 NextProto WrappedStreamSocket::GetNegotiatedProtocol() const {
2172   return transport_->GetNegotiatedProtocol();
2173 }
2174 
GetSSLInfo(SSLInfo * ssl_info)2175 bool WrappedStreamSocket::GetSSLInfo(SSLInfo* ssl_info) {
2176   return transport_->GetSSLInfo(ssl_info);
2177 }
2178 
GetTotalReceivedBytes() const2179 int64_t WrappedStreamSocket::GetTotalReceivedBytes() const {
2180   return transport_->GetTotalReceivedBytes();
2181 }
2182 
ApplySocketTag(const SocketTag & tag)2183 void WrappedStreamSocket::ApplySocketTag(const SocketTag& tag) {
2184   transport_->ApplySocketTag(tag);
2185 }
2186 
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)2187 int WrappedStreamSocket::Read(IOBuffer* buf,
2188                               int buf_len,
2189                               CompletionOnceCallback callback) {
2190   return transport_->Read(buf, buf_len, std::move(callback));
2191 }
2192 
ReadIfReady(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)2193 int WrappedStreamSocket::ReadIfReady(IOBuffer* buf,
2194                                      int buf_len,
2195                                      CompletionOnceCallback callback) {
2196   return transport_->ReadIfReady(buf, buf_len, std::move((callback)));
2197 }
2198 
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)2199 int WrappedStreamSocket::Write(
2200     IOBuffer* buf,
2201     int buf_len,
2202     CompletionOnceCallback callback,
2203     const NetworkTrafficAnnotationTag& traffic_annotation) {
2204   return transport_->Write(buf, buf_len, std::move(callback),
2205                            TRAFFIC_ANNOTATION_FOR_TESTS);
2206 }
2207 
SetReceiveBufferSize(int32_t size)2208 int WrappedStreamSocket::SetReceiveBufferSize(int32_t size) {
2209   return transport_->SetReceiveBufferSize(size);
2210 }
2211 
SetSendBufferSize(int32_t size)2212 int WrappedStreamSocket::SetSendBufferSize(int32_t size) {
2213   return transport_->SetSendBufferSize(size);
2214 }
2215 
Connect(CompletionOnceCallback callback)2216 int MockTaggingStreamSocket::Connect(CompletionOnceCallback callback) {
2217   connected_ = true;
2218   return WrappedStreamSocket::Connect(std::move(callback));
2219 }
2220 
ApplySocketTag(const SocketTag & tag)2221 void MockTaggingStreamSocket::ApplySocketTag(const SocketTag& tag) {
2222   tagged_before_connected_ &= !connected_ || tag == tag_;
2223   tag_ = tag;
2224   transport_->ApplySocketTag(tag);
2225 }
2226 
2227 std::unique_ptr<TransportClientSocket>
CreateTransportClientSocket(const AddressList & addresses,std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,NetworkQualityEstimator * network_quality_estimator,NetLog * net_log,const NetLogSource & source)2228 MockTaggingClientSocketFactory::CreateTransportClientSocket(
2229     const AddressList& addresses,
2230     std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,
2231     NetworkQualityEstimator* network_quality_estimator,
2232     NetLog* net_log,
2233     const NetLogSource& source) {
2234   auto socket = std::make_unique<MockTaggingStreamSocket>(
2235       MockClientSocketFactory::CreateTransportClientSocket(
2236           addresses, std::move(socket_performance_watcher),
2237           network_quality_estimator, net_log, source));
2238   tcp_socket_ = socket.get();
2239   return std::move(socket);
2240 }
2241 
2242 std::unique_ptr<DatagramClientSocket>
CreateDatagramClientSocket(DatagramSocket::BindType bind_type,NetLog * net_log,const NetLogSource & source)2243 MockTaggingClientSocketFactory::CreateDatagramClientSocket(
2244     DatagramSocket::BindType bind_type,
2245     NetLog* net_log,
2246     const NetLogSource& source) {
2247   std::unique_ptr<DatagramClientSocket> socket(
2248       MockClientSocketFactory::CreateDatagramClientSocket(bind_type, net_log,
2249                                                           source));
2250   udp_socket_ = static_cast<MockUDPClientSocket*>(socket.get());
2251   return socket;
2252 }
2253 
// Canned SOCKS4 handshake: CONNECT to 127.0.0.1:80 and a "request granted"
// reply.
const char kSOCKS4TestHost[] = "127.0.0.1";
const int kSOCKS4TestPort = 80;

const char kSOCKS4OkRequestLocalHostPort80[] = {
    0x04,                    // SOCKS version 4
    0x01,                    // CONNECT
    0x00, 0x50,              // destination port 80
    0x7F, 0x00, 0x00, 0x01,  // destination address 127.0.0.1
    0x00,                    // empty, NUL-terminated user id
};
const int kSOCKS4OkRequestLocalHostPort80Length =
    static_cast<int>(sizeof(kSOCKS4OkRequestLocalHostPort80));

const char kSOCKS4OkReply[] = {
    0x00,                    // reply version
    0x5A,                    // request granted
    0x00, 0x00,              // port (ignored)
    0x00, 0x00, 0x00, 0x00,  // address (ignored)
};
const int kSOCKS4OkReplyLength = static_cast<int>(sizeof(kSOCKS4OkReply));

// Canned SOCKS5 (RFC 1928) handshake: no-auth greeting, then CONNECT to the
// domain name "host" on port 80.
const char kSOCKS5TestHost[] = "host";
const int kSOCKS5TestPort = 80;

const char kSOCKS5GreetRequest[] = {
    0x05,  // SOCKS version 5
    0x01,  // one auth method offered
    0x00,  // method: no authentication
};
const int kSOCKS5GreetRequestLength =
    static_cast<int>(sizeof(kSOCKS5GreetRequest));

const char kSOCKS5GreetResponse[] = {
    0x05,  // SOCKS version 5
    0x00,  // selected method: no authentication
};
const int kSOCKS5GreetResponseLength =
    static_cast<int>(sizeof(kSOCKS5GreetResponse));

const char kSOCKS5OkRequest[] = {
    0x05,                // SOCKS version 5
    0x01,                // CONNECT
    0x00,                // reserved
    0x03,                // address type: domain name
    0x04,                // domain name length
    'h',  'o', 's', 't',  // domain name
    0x00, 0x50,          // port 80
};
const int kSOCKS5OkRequestLength = static_cast<int>(sizeof(kSOCKS5OkRequest));

const char kSOCKS5OkResponse[] = {
    0x05,                    // SOCKS version 5
    0x00,                    // succeeded
    0x00,                    // reserved
    0x01,                    // address type: IPv4
    0x7F, 0x00, 0x00, 0x01,  // bound address 127.0.0.1
    0x00, 0x50,              // bound port 80
};
const int kSOCKS5OkResponseLength =
    static_cast<int>(sizeof(kSOCKS5OkResponse));
2281 
CountReadBytes(base::span<const MockRead> reads)2282 int64_t CountReadBytes(base::span<const MockRead> reads) {
2283   int64_t total = 0;
2284   for (const MockRead& read : reads)
2285     total += read.data_len;
2286   return total;
2287 }
2288 
CountWriteBytes(base::span<const MockWrite> writes)2289 int64_t CountWriteBytes(base::span<const MockWrite> writes) {
2290   int64_t total = 0;
2291   for (const MockWrite& write : writes)
2292     total += write.data_len;
2293   return total;
2294 }
2295 
2296 #if BUILDFLAG(IS_ANDROID)
CanGetTaggedBytes()2297 bool CanGetTaggedBytes() {
2298   // In Android P, /proc/net/xt_qtaguid/stats is no longer guaranteed to be
2299   // present, and has been replaced with eBPF Traffic Monitoring in netd. See:
2300   // https://source.android.com/devices/tech/datausage/ebpf-traffic-monitor
2301   //
2302   // To read traffic statistics from netd, apps should use the API
2303   // NetworkStatsManager.queryDetailsForUidTag(). But this API does not provide
2304   // statistics for local traffic, only mobile and WiFi traffic, so it would not
2305   // work in tests that spin up a local server. So for now, GetTaggedBytes is
2306   // only supported on Android releases older than P.
2307   return base::android::BuildInfo::GetInstance()->sdk_int() <
2308          base::android::SDK_VERSION_P;
2309 }
2310 
GetTaggedBytes(int32_t expected_tag)2311 uint64_t GetTaggedBytes(int32_t expected_tag) {
2312   EXPECT_TRUE(CanGetTaggedBytes());
2313 
2314   // To determine how many bytes the system saw with a particular tag read
2315   // the /proc/net/xt_qtaguid/stats file which contains the kernel's
2316   // dump of all the UIDs and their tags sent and received bytes.
2317   uint64_t bytes = 0;
2318   std::string contents;
2319   EXPECT_TRUE(base::ReadFileToString(
2320       base::FilePath::FromUTF8Unsafe("/proc/net/xt_qtaguid/stats"), &contents));
2321   for (size_t i = contents.find('\n');  // Skip first line which is headers.
2322        i != std::string::npos && i < contents.length();) {
2323     uint64_t tag, rx_bytes;
2324     uid_t uid;
2325     int n;
2326     // Parse out the numbers we care about. For reference here's the column
2327     // headers:
2328     // idx iface acct_tag_hex uid_tag_int cnt_set rx_bytes rx_packets tx_bytes
2329     // tx_packets rx_tcp_bytes rx_tcp_packets rx_udp_bytes rx_udp_packets
2330     // rx_other_bytes rx_other_packets tx_tcp_bytes tx_tcp_packets tx_udp_bytes
2331     // tx_udp_packets tx_other_bytes tx_other_packets
2332     EXPECT_EQ(sscanf(contents.c_str() + i,
2333                      "%*d %*s 0x%" SCNx64 " %d %*d %" SCNu64
2334                      " %*d %*d %*d %*d %*d %*d %*d %*d "
2335                      "%*d %*d %*d %*d %*d %*d %*d%n",
2336                      &tag, &uid, &rx_bytes, &n),
2337               3);
2338     // If this line matches our UID and |expected_tag| then add it to the total.
2339     if (uid == getuid() && (int32_t)(tag >> 32) == expected_tag) {
2340       bytes += rx_bytes;
2341     }
2342     // Move |i| to the next line.
2343     i += n + 1;
2344   }
2345   return bytes;
2346 }
2347 #endif
2348 
2349 }  // namespace net
2350