• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/socket/socket_test_util.h"
6 
7 #include <inttypes.h>  // For SCNx64
8 #include <stdint.h>
9 #include <stdio.h>
10 
11 #include <memory>
12 #include <string>
13 #include <utility>
14 #include <vector>
15 
16 #include "base/compiler_specific.h"
17 #include "base/files/file_util.h"
18 #include "base/functional/bind.h"
19 #include "base/functional/callback_helpers.h"
20 #include "base/location.h"
21 #include "base/logging.h"
22 #include "base/rand_util.h"
23 #include "base/ranges/algorithm.h"
24 #include "base/run_loop.h"
25 #include "base/task/single_thread_task_runner.h"
26 #include "base/time/time.h"
27 #include "build/build_config.h"
28 #include "net/base/address_family.h"
29 #include "net/base/address_list.h"
30 #include "net/base/auth.h"
31 #include "net/base/hex_utils.h"
32 #include "net/base/ip_address.h"
33 #include "net/base/load_timing_info.h"
34 #include "net/base/proxy_server.h"
35 #include "net/http/http_network_session.h"
36 #include "net/http/http_request_headers.h"
37 #include "net/http/http_response_headers.h"
38 #include "net/log/net_log_source.h"
39 #include "net/log/net_log_source_type.h"
40 #include "net/socket/connect_job.h"
41 #include "net/socket/socket.h"
42 #include "net/socket/stream_socket.h"
43 #include "net/socket/websocket_endpoint_lock_manager.h"
44 #include "net/ssl/ssl_cert_request_info.h"
45 #include "net/ssl/ssl_connection_status_flags.h"
46 #include "net/ssl/ssl_info.h"
47 #include "net/traffic_annotation/network_traffic_annotation.h"
48 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
49 #include "testing/gtest/include/gtest/gtest.h"
50 #include "third_party/abseil-cpp/absl/strings/ascii.h"
51 
52 #if BUILDFLAG(IS_ANDROID)
53 #include "base/android/build_info.h"
54 #endif
55 
// Logs |s|, then the enclosing function's name followed by "() ", at VLOG
// verbosity |level|. Callers stream further values after it, e.g.
//   NET_TRACE(1, " *** ") << "sequence_number: " << sequence_number_;
// |s| is parenthesized so an argument containing operators binds as a unit
// instead of interleaving with the surrounding << chain.
#define NET_TRACE(level, s) VLOG(level) << (s) << __FUNCTION__ << "() "
57 
58 namespace net {
59 namespace {
60 
// Returns the uppercase hexadecimal digit ('0'-'9', 'A'-'F') that encodes the
// high nybble (bits 4-7) of |x|.
inline char AsciifyHigh(char x) {
  static constexpr char kHexDigits[] = "0123456789ABCDEF";
  return kHexDigits[(static_cast<unsigned char>(x) >> 4) & 0x0F];
}
65 
// Returns the uppercase hexadecimal digit ('0'-'9', 'A'-'F') that encodes the
// low nybble (bits 0-3) of |x|.
inline char AsciifyLow(char x) {
  static constexpr char kHexDigits[] = "0123456789ABCDEF";
  return kHexDigits[static_cast<unsigned char>(x) & 0x0F];
}
70 
Asciify(char x)71 inline char Asciify(char x) {
72   return absl::ascii_isprint(static_cast<unsigned char>(x)) ? x : '.';
73 }
74 
DumpData(const char * data,int data_len)75 void DumpData(const char* data, int data_len) {
76   if (logging::LOG_INFO < logging::GetMinLogLevel())
77     return;
78   DVLOG(1) << "Length:  " << data_len;
79   const char* pfx = "Data:    ";
80   if (!data || (data_len <= 0)) {
81     DVLOG(1) << pfx << "<None>";
82   } else {
83     int i;
84     for (i = 0; i <= (data_len - 4); i += 4) {
85       DVLOG(1) << pfx << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
86                << AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1])
87                << AsciifyHigh(data[i + 2]) << AsciifyLow(data[i + 2])
88                << AsciifyHigh(data[i + 3]) << AsciifyLow(data[i + 3]) << "  '"
89                << Asciify(data[i + 0]) << Asciify(data[i + 1])
90                << Asciify(data[i + 2]) << Asciify(data[i + 3]) << "'";
91       pfx = "         ";
92     }
93     // Take care of any 'trailing' bytes, if data_len was not a multiple of 4.
94     switch (data_len - i) {
95       case 3:
96         DVLOG(1) << pfx << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
97                  << AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1])
98                  << AsciifyHigh(data[i + 2]) << AsciifyLow(data[i + 2])
99                  << "    '" << Asciify(data[i + 0]) << Asciify(data[i + 1])
100                  << Asciify(data[i + 2]) << " '";
101         break;
102       case 2:
103         DVLOG(1) << pfx << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
104                  << AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1])
105                  << "      '" << Asciify(data[i + 0]) << Asciify(data[i + 1])
106                  << "  '";
107         break;
108       case 1:
109         DVLOG(1) << pfx << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
110                  << "        '" << Asciify(data[i + 0]) << "   '";
111         break;
112     }
113   }
114 }
115 
116 template <MockReadWriteType type>
DumpMockReadWrite(const MockReadWrite<type> & r)117 void DumpMockReadWrite(const MockReadWrite<type>& r) {
118   if (logging::LOG_INFO < logging::GetMinLogLevel())
119     return;
120   DVLOG(1) << "Async:   " << (r.mode == ASYNC) << "\nResult:  " << r.result;
121   DumpData(r.data, r.data_len);
122   const char* stop = (r.sequence_number & MockRead::STOPLOOP) ? " (STOP)" : "";
123   DVLOG(1) << "Stage:   " << (r.sequence_number & ~MockRead::STOPLOOP) << stop;
124 }
125 
RunClosureIfNonNull(base::OnceClosure closure)126 void RunClosureIfNonNull(base::OnceClosure closure) {
127   if (!closure.is_null()) {
128     std::move(closure).Run();
129   }
130 }
131 
132 }  // namespace
133 
MockConnect()134 MockConnect::MockConnect() : mode(ASYNC), result(OK) {
135   peer_addr = IPEndPoint(IPAddress(192, 0, 2, 33), 0);
136 }
137 
MockConnect(IoMode io_mode,int r)138 MockConnect::MockConnect(IoMode io_mode, int r) : mode(io_mode), result(r) {
139   peer_addr = IPEndPoint(IPAddress(192, 0, 2, 33), 0);
140 }
141 
MockConnect(IoMode io_mode,int r,IPEndPoint addr)142 MockConnect::MockConnect(IoMode io_mode, int r, IPEndPoint addr)
143     : mode(io_mode), result(r), peer_addr(addr) {}
144 
MockConnect(IoMode io_mode,int r,IPEndPoint addr,bool first_attempt_fails)145 MockConnect::MockConnect(IoMode io_mode,
146                          int r,
147                          IPEndPoint addr,
148                          bool first_attempt_fails)
149     : mode(io_mode),
150       result(r),
151       peer_addr(addr),
152       first_attempt_fails(first_attempt_fails) {}
153 
154 MockConnect::~MockConnect() = default;
155 
MockConfirm()156 MockConfirm::MockConfirm() : mode(SYNCHRONOUS), result(OK) {}
157 
MockConfirm(IoMode io_mode,int r)158 MockConfirm::MockConfirm(IoMode io_mode, int r) : mode(io_mode), result(r) {}
159 
160 MockConfirm::~MockConfirm() = default;
161 
IsIdle() const162 bool SocketDataProvider::IsIdle() const {
163   return true;
164 }
165 
Initialize(AsyncSocket * socket)166 void SocketDataProvider::Initialize(AsyncSocket* socket) {
167   CHECK(!socket_);
168   CHECK(socket);
169   socket_ = socket;
170   Reset();
171 }
172 
DetachSocket()173 void SocketDataProvider::DetachSocket() {
174   CHECK(socket_);
175   socket_ = nullptr;
176 }
177 
178 SocketDataProvider::SocketDataProvider() = default;
179 
~SocketDataProvider()180 SocketDataProvider::~SocketDataProvider() {
181   if (socket_)
182     socket_->OnDataProviderDestroyed();
183 }
184 
StaticSocketDataHelper(base::span<const MockRead> reads,base::span<const MockWrite> writes)185 StaticSocketDataHelper::StaticSocketDataHelper(
186     base::span<const MockRead> reads,
187     base::span<const MockWrite> writes)
188     : reads_(reads), writes_(writes) {}
189 
190 StaticSocketDataHelper::~StaticSocketDataHelper() = default;
191 
PeekRead() const192 const MockRead& StaticSocketDataHelper::PeekRead() const {
193   CHECK(!AllReadDataConsumed());
194   return reads_[read_index_];
195 }
196 
PeekWrite() const197 const MockWrite& StaticSocketDataHelper::PeekWrite() const {
198   CHECK(!AllWriteDataConsumed());
199   return writes_[write_index_];
200 }
201 
AdvanceRead()202 const MockRead& StaticSocketDataHelper::AdvanceRead() {
203   CHECK(!AllReadDataConsumed());
204   return reads_[read_index_++];
205 }
206 
AdvanceWrite()207 const MockWrite& StaticSocketDataHelper::AdvanceWrite() {
208   CHECK(!AllWriteDataConsumed());
209   return writes_[write_index_++];
210 }
211 
Reset()212 void StaticSocketDataHelper::Reset() {
213   read_index_ = 0;
214   write_index_ = 0;
215 }
216 
VerifyWriteData(const std::string & data,SocketDataPrinter * printer)217 bool StaticSocketDataHelper::VerifyWriteData(const std::string& data,
218                                              SocketDataPrinter* printer) {
219   CHECK(!AllWriteDataConsumed());
220   // Check that the actual data matches the expectations, skipping over any
221   // pause events.
222   const MockWrite& next_write = PeekRealWrite();
223   if (!next_write.data)
224     return true;
225 
226   // Note: Partial writes are supported here.  If the expected data
227   // is a match, but shorter than the write actually written, that is legal.
228   // Example:
229   //   Application writes "foobarbaz" (9 bytes)
230   //   Expected write was "foo" (3 bytes)
231   //   This is a success, and the function returns true.
232   std::string expected_data(next_write.data, next_write.data_len);
233   std::string actual_data(data.substr(0, next_write.data_len));
234   EXPECT_GE(data.length(), expected_data.length());
235   if (printer) {
236     EXPECT_TRUE(actual_data == expected_data)
237         << "Actual formatted write data:\n"
238         << printer->PrintWrite(data) << "Expected formatted write data:\n"
239         << printer->PrintWrite(expected_data) << "Actual raw write data:\n"
240         << HexDump(data) << "Expected raw write data:\n"
241         << HexDump(expected_data);
242   } else {
243     EXPECT_TRUE(actual_data == expected_data)
244         << "Actual write data:\n"
245         << HexDump(data) << "Expected write data:\n"
246         << HexDump(expected_data);
247   }
248   return expected_data == actual_data;
249 }
250 
PeekRealWrite() const251 const MockWrite& StaticSocketDataHelper::PeekRealWrite() const {
252   for (size_t i = write_index_; i < write_count(); i++) {
253     if (writes_[i].mode != ASYNC || writes_[i].result != ERR_IO_PENDING)
254       return writes_[i];
255   }
256 
257   CHECK(false) << "No write data available.";
258   return writes_[0];  // Avoid warning about unreachable missing return.
259 }
260 
StaticSocketDataProvider()261 StaticSocketDataProvider::StaticSocketDataProvider()
262     : StaticSocketDataProvider(base::span<const MockRead>(),
263                                base::span<const MockWrite>()) {}
264 
StaticSocketDataProvider(base::span<const MockRead> reads,base::span<const MockWrite> writes)265 StaticSocketDataProvider::StaticSocketDataProvider(
266     base::span<const MockRead> reads,
267     base::span<const MockWrite> writes)
268     : helper_(reads, writes) {}
269 
270 StaticSocketDataProvider::~StaticSocketDataProvider() = default;
271 
Pause()272 void StaticSocketDataProvider::Pause() {
273   paused_ = true;
274 }
275 
Resume()276 void StaticSocketDataProvider::Resume() {
277   paused_ = false;
278 }
279 
OnRead()280 MockRead StaticSocketDataProvider::OnRead() {
281   if (AllReadDataConsumed()) {
282     const net::MockRead pending_read(net::SYNCHRONOUS, net::ERR_IO_PENDING);
283     return pending_read;
284   }
285 
286   return helper_.AdvanceRead();
287 }
288 
OnWrite(const std::string & data)289 MockWriteResult StaticSocketDataProvider::OnWrite(const std::string& data) {
290   if (helper_.write_count() == 0) {
291     // Not using mock writes; succeed synchronously.
292     return MockWriteResult(SYNCHRONOUS, data.length());
293   }
294   if (printer_) {
295     EXPECT_FALSE(helper_.AllWriteDataConsumed())
296         << "No more mock data to match write:\nFormatted write data:\n"
297         << printer_->PrintWrite(data) << "Raw write data:\n"
298         << HexDump(data);
299   } else {
300     EXPECT_FALSE(helper_.AllWriteDataConsumed())
301         << "No more mock data to match write:\nRaw write data:\n"
302         << HexDump(data);
303   }
304   if (helper_.AllWriteDataConsumed()) {
305     return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED);
306   }
307 
308   // Check that what we are writing matches the expectation.
309   // Then give the mocked return value.
310   if (!helper_.VerifyWriteData(data, printer_))
311     return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED);
312 
313   const MockWrite& next_write = helper_.AdvanceWrite();
314   // In the case that the write was successful, return the number of bytes
315   // written. Otherwise return the error code.
316   int result =
317       next_write.result == OK ? next_write.data_len : next_write.result;
318   return MockWriteResult(next_write.mode, result);
319 }
320 
AllReadDataConsumed() const321 bool StaticSocketDataProvider::AllReadDataConsumed() const {
322   return paused_ || helper_.AllReadDataConsumed();
323 }
324 
AllWriteDataConsumed() const325 bool StaticSocketDataProvider::AllWriteDataConsumed() const {
326   return helper_.AllWriteDataConsumed();
327 }
328 
Reset()329 void StaticSocketDataProvider::Reset() {
330   helper_.Reset();
331 }
332 
SSLSocketDataProvider(IoMode mode,int result)333 SSLSocketDataProvider::SSLSocketDataProvider(IoMode mode, int result)
334     : connect(mode, result),
335       expected_ssl_version_min(kDefaultSSLVersionMin),
336       expected_ssl_version_max(kDefaultSSLVersionMax) {
337   SSLConnectionStatusSetVersion(SSL_CONNECTION_VERSION_TLS1_3,
338                                 &ssl_info.connection_status);
339   // Set to TLS_CHACHA20_POLY1305_SHA256
340   SSLConnectionStatusSetCipherSuite(0x1301, &ssl_info.connection_status);
341 }
342 
343 SSLSocketDataProvider::SSLSocketDataProvider(
344     const SSLSocketDataProvider& other) = default;
345 
346 SSLSocketDataProvider::~SSLSocketDataProvider() = default;
347 
SequencedSocketData()348 SequencedSocketData::SequencedSocketData()
349     : SequencedSocketData(base::span<const MockRead>(),
350                           base::span<const MockWrite>()) {}
351 
SequencedSocketData(base::span<const MockRead> reads,base::span<const MockWrite> writes)352 SequencedSocketData::SequencedSocketData(base::span<const MockRead> reads,
353                                          base::span<const MockWrite> writes)
354     : helper_(reads, writes) {
355   // Check that reads and writes have a contiguous set of sequence numbers
356   // starting from 0 and working their way up, with no repeats and skipping
357   // no values.
358   int next_sequence_number = 0;
359   bool last_event_was_pause = false;
360 
361   auto next_read = reads.begin();
362   auto next_write = writes.begin();
363   while (next_read != reads.end() || next_write != writes.end()) {
364     if (next_read != reads.end() &&
365         next_read->sequence_number == next_sequence_number) {
366       // Check if this is a pause.
367       if (next_read->mode == ASYNC && next_read->result == ERR_IO_PENDING) {
368         CHECK(!last_event_was_pause)
369             << "Two pauses in a row are not allowed: " << next_sequence_number;
370         last_event_was_pause = true;
371       } else if (last_event_was_pause) {
372         CHECK_EQ(ASYNC, next_read->mode)
373             << "A sync event after a pause makes no sense: "
374             << next_sequence_number;
375         CHECK_NE(ERR_IO_PENDING, next_read->result)
376             << "A pause event after a pause makes no sense: "
377             << next_sequence_number;
378         last_event_was_pause = false;
379       }
380 
381       ++next_read;
382       ++next_sequence_number;
383       continue;
384     }
385     if (next_write != writes.end() &&
386         next_write->sequence_number == next_sequence_number) {
387       // Check if this is a pause.
388       if (next_write->mode == ASYNC && next_write->result == ERR_IO_PENDING) {
389         CHECK(!last_event_was_pause)
390             << "Two pauses in a row are not allowed: " << next_sequence_number;
391         last_event_was_pause = true;
392       } else if (last_event_was_pause) {
393         CHECK_EQ(ASYNC, next_write->mode)
394             << "A sync event after a pause makes no sense: "
395             << next_sequence_number;
396         CHECK_NE(ERR_IO_PENDING, next_write->result)
397             << "A pause event after a pause makes no sense: "
398             << next_sequence_number;
399         last_event_was_pause = false;
400       }
401 
402       ++next_write;
403       ++next_sequence_number;
404       continue;
405     }
406     if (next_write != writes.end()) {
407       CHECK(false) << "Sequence number " << next_write->sequence_number
408                    << " not found where expected: " << next_sequence_number;
409     } else {
410       CHECK(false) << "Too few writes, next expected sequence number: "
411                    << next_sequence_number;
412     }
413     return;
414   }
415 
416   // Last event must not be a pause.  For the final event to indicate the
417   // operation never completes, it should be SYNCHRONOUS and return
418   // ERR_IO_PENDING.
419   CHECK(!last_event_was_pause);
420 
421   CHECK(next_read == reads.end());
422   CHECK(next_write == writes.end());
423 }
424 
SequencedSocketData(const MockConnect & connect,base::span<const MockRead> reads,base::span<const MockWrite> writes)425 SequencedSocketData::SequencedSocketData(const MockConnect& connect,
426                                          base::span<const MockRead> reads,
427                                          base::span<const MockWrite> writes)
428     : SequencedSocketData(reads, writes) {
429   set_connect_data(connect);
430 }
OnRead()431 MockRead SequencedSocketData::OnRead() {
432   CHECK_EQ(IoState::kIdle, read_state_);
433   CHECK(!helper_.AllReadDataConsumed())
434       << "Application tried to read but there is no read data left";
435 
436   NET_TRACE(1, " *** ") << "sequence_number: " << sequence_number_;
437   const MockRead& next_read = helper_.PeekRead();
438   NET_TRACE(1, " *** ") << "next_read: " << next_read.sequence_number;
439   CHECK_GE(next_read.sequence_number, sequence_number_);
440 
441   if (next_read.sequence_number <= sequence_number_) {
442     if (next_read.mode == SYNCHRONOUS) {
443       NET_TRACE(1, " *** ") << "Returning synchronously";
444       DumpMockReadWrite(next_read);
445       helper_.AdvanceRead();
446       ++sequence_number_;
447       MaybePostWriteCompleteTask();
448       return next_read;
449     }
450 
451     // If the result is ERR_IO_PENDING, then pause.
452     if (next_read.result == ERR_IO_PENDING) {
453       NET_TRACE(1, " *** ") << "Pausing read at: " << sequence_number_;
454       read_state_ = IoState::kPaused;
455       if (run_until_paused_run_loop_)
456         run_until_paused_run_loop_->Quit();
457       return MockRead(SYNCHRONOUS, ERR_IO_PENDING);
458     }
459     base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
460         FROM_HERE, base::BindOnce(&SequencedSocketData::OnReadComplete,
461                                   weak_factory_.GetWeakPtr()));
462     CHECK_NE(IoState::kCompleting, write_state_);
463     read_state_ = IoState::kCompleting;
464   } else if (next_read.mode == SYNCHRONOUS) {
465     ADD_FAILURE() << "Unable to perform synchronous IO while stopped";
466     return MockRead(SYNCHRONOUS, ERR_UNEXPECTED);
467   } else {
468     NET_TRACE(1, " *** ") << "Waiting for write to trigger read";
469     read_state_ = IoState::kPending;
470   }
471 
472   return MockRead(SYNCHRONOUS, ERR_IO_PENDING);
473 }
474 
OnWrite(const std::string & data)475 MockWriteResult SequencedSocketData::OnWrite(const std::string& data) {
476   CHECK_EQ(IoState::kIdle, write_state_);
477   if (printer_) {
478     CHECK(!helper_.AllWriteDataConsumed())
479         << "\nNo more mock data to match write:\nFormatted write data:\n"
480         << printer_->PrintWrite(data) << "Raw write data:\n"
481         << HexDump(data);
482   } else {
483     CHECK(!helper_.AllWriteDataConsumed())
484         << "\nNo more mock data to match write:\nRaw write data:\n"
485         << HexDump(data);
486   }
487 
488   NET_TRACE(1, " *** ") << "sequence_number: " << sequence_number_;
489   const MockWrite& next_write = helper_.PeekWrite();
490   NET_TRACE(1, " *** ") << "next_write: " << next_write.sequence_number;
491   CHECK_GE(next_write.sequence_number, sequence_number_);
492 
493   if (!helper_.VerifyWriteData(data, printer_))
494     return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED);
495 
496   if (next_write.sequence_number <= sequence_number_) {
497     if (next_write.mode == SYNCHRONOUS) {
498       helper_.AdvanceWrite();
499       ++sequence_number_;
500       MaybePostReadCompleteTask();
501       // In the case that the write was successful, return the number of bytes
502       // written. Otherwise return the error code.
503       int rv =
504           next_write.result != OK ? next_write.result : next_write.data_len;
505       NET_TRACE(1, " *** ") << "Returning synchronously";
506       return MockWriteResult(SYNCHRONOUS, rv);
507     }
508 
509     // If the result is ERR_IO_PENDING, then pause.
510     if (next_write.result == ERR_IO_PENDING) {
511       NET_TRACE(1, " *** ") << "Pausing write at: " << sequence_number_;
512       write_state_ = IoState::kPaused;
513       if (run_until_paused_run_loop_)
514         run_until_paused_run_loop_->Quit();
515       return MockWriteResult(SYNCHRONOUS, ERR_IO_PENDING);
516     }
517 
518     NET_TRACE(1, " *** ") << "Posting task to complete write";
519     base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
520         FROM_HERE, base::BindOnce(&SequencedSocketData::OnWriteComplete,
521                                   weak_factory_.GetWeakPtr()));
522     CHECK_NE(IoState::kCompleting, read_state_);
523     write_state_ = IoState::kCompleting;
524   } else if (next_write.mode == SYNCHRONOUS) {
525     ADD_FAILURE() << "Unable to perform synchronous IO while stopped";
526     return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED);
527   } else {
528     NET_TRACE(1, " *** ") << "Waiting for read to trigger write";
529     write_state_ = IoState::kPending;
530   }
531 
532   return MockWriteResult(SYNCHRONOUS, ERR_IO_PENDING);
533 }
534 
AllReadDataConsumed() const535 bool SequencedSocketData::AllReadDataConsumed() const {
536   return helper_.AllReadDataConsumed();
537 }
538 
CancelPendingRead()539 void SequencedSocketData::CancelPendingRead() {
540   DCHECK_EQ(IoState::kPending, read_state_);
541 
542   read_state_ = IoState::kIdle;
543 }
544 
AllWriteDataConsumed() const545 bool SequencedSocketData::AllWriteDataConsumed() const {
546   return helper_.AllWriteDataConsumed();
547 }
548 
IsIdle() const549 bool SequencedSocketData::IsIdle() const {
550   // If |busy_before_sync_reads_| is not set, always considered idle.  If
551   // no reads left, or the next operation is a write, also consider it idle.
552   if (!busy_before_sync_reads_ || helper_.AllReadDataConsumed() ||
553       helper_.PeekRead().sequence_number != sequence_number_) {
554     return true;
555   }
556 
557   // If the next operation is synchronous read, treat the socket as not idle.
558   if (helper_.PeekRead().mode == SYNCHRONOUS)
559     return false;
560   return true;
561 }
562 
IsPaused() const563 bool SequencedSocketData::IsPaused() const {
564   // Both states should not be paused.
565   DCHECK(read_state_ != IoState::kPaused || write_state_ != IoState::kPaused);
566   return write_state_ == IoState::kPaused || read_state_ == IoState::kPaused;
567 }
568 
Resume()569 void SequencedSocketData::Resume() {
570   if (!IsPaused()) {
571     ADD_FAILURE() << "Unable to Resume when not paused.";
572     return;
573   }
574 
575   sequence_number_++;
576   if (read_state_ == IoState::kPaused) {
577     read_state_ = IoState::kPending;
578     helper_.AdvanceRead();
579   } else {  // write_state_ == IoState::kPaused
580     write_state_ = IoState::kPending;
581     helper_.AdvanceWrite();
582   }
583 
584   if (!helper_.AllWriteDataConsumed() &&
585       helper_.PeekWrite().sequence_number == sequence_number_) {
586     // The next event hasn't even started yet.  Pausing isn't really needed in
587     // that case, but may as well support it.
588     if (write_state_ != IoState::kPending)
589       return;
590     write_state_ = IoState::kCompleting;
591     OnWriteComplete();
592     return;
593   }
594 
595   CHECK(!helper_.AllReadDataConsumed());
596 
597   // The next event hasn't even started yet.  Pausing isn't really needed in
598   // that case, but may as well support it.
599   if (read_state_ != IoState::kPending)
600     return;
601   read_state_ = IoState::kCompleting;
602   OnReadComplete();
603 }
604 
RunUntilPaused()605 void SequencedSocketData::RunUntilPaused() {
606   CHECK(!run_until_paused_run_loop_);
607 
608   if (IsPaused())
609     return;
610 
611   run_until_paused_run_loop_ = std::make_unique<base::RunLoop>();
612   run_until_paused_run_loop_->Run();
613   run_until_paused_run_loop_.reset();
614   DCHECK(IsPaused());
615 }
616 
MaybePostReadCompleteTask()617 void SequencedSocketData::MaybePostReadCompleteTask() {
618   NET_TRACE(1, " ****** ") << " current: " << sequence_number_;
619   // Only trigger the next read to complete if there is already a read pending
620   // which should complete at the current sequence number.
621   if (read_state_ != IoState::kPending ||
622       helper_.PeekRead().sequence_number != sequence_number_) {
623     return;
624   }
625 
626   // If the result is ERR_IO_PENDING, then pause.
627   if (helper_.PeekRead().result == ERR_IO_PENDING) {
628     NET_TRACE(1, " *** ") << "Pausing read at: " << sequence_number_;
629     read_state_ = IoState::kPaused;
630     if (run_until_paused_run_loop_)
631       run_until_paused_run_loop_->Quit();
632     return;
633   }
634 
635   NET_TRACE(1, " ****** ") << "Posting task to complete read: "
636                            << sequence_number_;
637   base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
638       FROM_HERE, base::BindOnce(&SequencedSocketData::OnReadComplete,
639                                 weak_factory_.GetWeakPtr()));
640   CHECK_NE(IoState::kCompleting, write_state_);
641   read_state_ = IoState::kCompleting;
642 }
643 
MaybePostWriteCompleteTask()644 void SequencedSocketData::MaybePostWriteCompleteTask() {
645   NET_TRACE(1, " ****** ") << " current: " << sequence_number_;
646   // Only trigger the next write to complete if there is already a write pending
647   // which should complete at the current sequence number.
648   if (write_state_ != IoState::kPending ||
649       helper_.PeekWrite().sequence_number != sequence_number_) {
650     return;
651   }
652 
653   // If the result is ERR_IO_PENDING, then pause.
654   if (helper_.PeekWrite().result == ERR_IO_PENDING) {
655     NET_TRACE(1, " *** ") << "Pausing write at: " << sequence_number_;
656     write_state_ = IoState::kPaused;
657     if (run_until_paused_run_loop_)
658       run_until_paused_run_loop_->Quit();
659     return;
660   }
661 
662   NET_TRACE(1, " ****** ") << "Posting task to complete write: "
663                            << sequence_number_;
664   base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
665       FROM_HERE, base::BindOnce(&SequencedSocketData::OnWriteComplete,
666                                 weak_factory_.GetWeakPtr()));
667   CHECK_NE(IoState::kCompleting, read_state_);
668   write_state_ = IoState::kCompleting;
669 }
670 
Reset()671 void SequencedSocketData::Reset() {
672   helper_.Reset();
673   sequence_number_ = 0;
674   read_state_ = IoState::kIdle;
675   write_state_ = IoState::kIdle;
676   weak_factory_.InvalidateWeakPtrs();
677 }
678 
OnReadComplete()679 void SequencedSocketData::OnReadComplete() {
680   CHECK_EQ(IoState::kCompleting, read_state_);
681   NET_TRACE(1, " *** ") << "Completing read for: " << sequence_number_;
682 
683   MockRead data = helper_.AdvanceRead();
684   DCHECK_EQ(sequence_number_, data.sequence_number);
685   sequence_number_++;
686   read_state_ = IoState::kIdle;
687 
688   // The result of this read completing might trigger the completion
689   // of a pending write. If so, post a task to complete the write later.
690   // Since the socket may call back into the SequencedSocketData
691   // from socket()->OnReadComplete(), trigger the write task to be posted
692   // before calling that.
693   MaybePostWriteCompleteTask();
694 
695   if (!socket()) {
696     NET_TRACE(1, " *** ") << "No socket available to complete read";
697     return;
698   }
699 
700   NET_TRACE(1, " *** ") << "Completing socket read for: "
701                         << data.sequence_number;
702   DumpMockReadWrite(data);
703   socket()->OnReadComplete(data);
704   NET_TRACE(1, " *** ") << "Done";
705 }
706 
OnWriteComplete()707 void SequencedSocketData::OnWriteComplete() {
708   CHECK_EQ(IoState::kCompleting, write_state_);
709   NET_TRACE(1, " *** ") << " Completing write for: " << sequence_number_;
710 
711   const MockWrite& data = helper_.AdvanceWrite();
712   DCHECK_EQ(sequence_number_, data.sequence_number);
713   sequence_number_++;
714   write_state_ = IoState::kIdle;
715   int rv = data.result == OK ? data.data_len : data.result;
716 
717   // The result of this write completing might trigger the completion
718   // of a pending read. If so, post a task to complete the read later.
719   // Since the socket may call back into the SequencedSocketData
720   // from socket()->OnWriteComplete(), trigger the write task to be posted
721   // before calling that.
722   MaybePostReadCompleteTask();
723 
724   if (!socket()) {
725     NET_TRACE(1, " *** ") << "No socket available to complete write";
726     return;
727   }
728 
729   NET_TRACE(1, " *** ") << " Completing socket write for: "
730                         << data.sequence_number;
731   socket()->OnWriteComplete(rv);
732   NET_TRACE(1, " *** ") << "Done";
733 }
734 
735 SequencedSocketData::~SequencedSocketData() = default;
736 
737 MockClientSocketFactory::MockClientSocketFactory() = default;
738 
739 MockClientSocketFactory::~MockClientSocketFactory() = default;
740 
AddSocketDataProvider(SocketDataProvider * data)741 void MockClientSocketFactory::AddSocketDataProvider(SocketDataProvider* data) {
742   mock_data_.Add(data);
743 }
744 
AddTcpSocketDataProvider(SocketDataProvider * data)745 void MockClientSocketFactory::AddTcpSocketDataProvider(
746     SocketDataProvider* data) {
747   mock_tcp_data_.Add(data);
748 }
749 
AddSSLSocketDataProvider(SSLSocketDataProvider * data)750 void MockClientSocketFactory::AddSSLSocketDataProvider(
751     SSLSocketDataProvider* data) {
752   mock_ssl_data_.Add(data);
753 }
754 
ResetNextMockIndexes()755 void MockClientSocketFactory::ResetNextMockIndexes() {
756   mock_data_.ResetNextIndex();
757   mock_ssl_data_.ResetNextIndex();
758 }
759 
760 std::unique_ptr<DatagramClientSocket>
CreateDatagramClientSocket(DatagramSocket::BindType bind_type,NetLog * net_log,const NetLogSource & source)761 MockClientSocketFactory::CreateDatagramClientSocket(
762     DatagramSocket::BindType bind_type,
763     NetLog* net_log,
764     const NetLogSource& source) {
765   SocketDataProvider* data_provider = mock_data_.GetNext();
766   auto socket = std::make_unique<MockUDPClientSocket>(data_provider, net_log);
767   if (bind_type == DatagramSocket::RANDOM_BIND)
768     socket->set_source_port(static_cast<uint16_t>(base::RandInt(1025, 65535)));
769   udp_client_socket_ports_.push_back(socket->source_port());
770   return std::move(socket);
771 }
772 
773 std::unique_ptr<TransportClientSocket>
CreateTransportClientSocket(const AddressList & addresses,std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,NetworkQualityEstimator * network_quality_estimator,NetLog * net_log,const NetLogSource & source)774 MockClientSocketFactory::CreateTransportClientSocket(
775     const AddressList& addresses,
776     std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,
777     NetworkQualityEstimator* network_quality_estimator,
778     NetLog* net_log,
779     const NetLogSource& source) {
780   SocketDataProvider* data_provider = mock_tcp_data_.GetNextWithoutAsserting();
781   if (!data_provider)
782     data_provider = mock_data_.GetNext();
783   auto socket =
784       std::make_unique<MockTCPClientSocket>(addresses, net_log, data_provider);
785   if (enable_read_if_ready_)
786     socket->set_enable_read_if_ready(enable_read_if_ready_);
787   return std::move(socket);
788 }
789 
CreateSSLClientSocket(SSLClientContext * context,std::unique_ptr<StreamSocket> stream_socket,const HostPortPair & host_and_port,const SSLConfig & ssl_config)790 std::unique_ptr<SSLClientSocket> MockClientSocketFactory::CreateSSLClientSocket(
791     SSLClientContext* context,
792     std::unique_ptr<StreamSocket> stream_socket,
793     const HostPortPair& host_and_port,
794     const SSLConfig& ssl_config) {
795   SSLSocketDataProvider* next_ssl_data = mock_ssl_data_.GetNext();
796   if (next_ssl_data->next_protos_expected_in_ssl_config.has_value()) {
797     EXPECT_TRUE(base::ranges::equal(
798         next_ssl_data->next_protos_expected_in_ssl_config.value(),
799         ssl_config.alpn_protos));
800   }
801 
802   // The protocol version used is a combination of the per-socket SSLConfig and
803   // the SSLConfigService.
804   EXPECT_EQ(
805       next_ssl_data->expected_ssl_version_min,
806       ssl_config.version_min_override.value_or(context->config().version_min));
807   EXPECT_EQ(
808       next_ssl_data->expected_ssl_version_max,
809       ssl_config.version_max_override.value_or(context->config().version_max));
810 
811   if (next_ssl_data->expected_send_client_cert) {
812     // Client certificate preferences come from |context|.
813     scoped_refptr<X509Certificate> client_cert;
814     scoped_refptr<SSLPrivateKey> client_private_key;
815     bool send_client_cert = context->GetClientCertificate(
816         host_and_port, &client_cert, &client_private_key);
817 
818     EXPECT_EQ(*next_ssl_data->expected_send_client_cert, send_client_cert);
819     // Note |send_client_cert| may be true while |client_cert| is null if the
820     // socket is configured to continue without a certificate, as opposed to
821     // surfacing the certificate challenge.
822     EXPECT_EQ(!!next_ssl_data->expected_client_cert, !!client_cert);
823     if (next_ssl_data->expected_client_cert && client_cert) {
824       EXPECT_TRUE(next_ssl_data->expected_client_cert->EqualsIncludingChain(
825           client_cert.get()));
826     }
827   }
828   if (next_ssl_data->expected_host_and_port) {
829     EXPECT_EQ(*next_ssl_data->expected_host_and_port, host_and_port);
830   }
831   if (next_ssl_data->expected_ignore_certificate_errors) {
832     EXPECT_EQ(*next_ssl_data->expected_ignore_certificate_errors,
833               ssl_config.ignore_certificate_errors);
834   }
835   if (next_ssl_data->expected_network_anonymization_key) {
836     EXPECT_EQ(*next_ssl_data->expected_network_anonymization_key,
837               ssl_config.network_anonymization_key);
838   }
839   if (next_ssl_data->expected_disable_sha1_server_signatures) {
840     EXPECT_EQ(*next_ssl_data->expected_disable_sha1_server_signatures,
841               ssl_config.disable_sha1_server_signatures);
842   }
843   if (next_ssl_data->expected_ech_config_list) {
844     EXPECT_EQ(*next_ssl_data->expected_ech_config_list,
845               ssl_config.ech_config_list);
846   }
847   return std::make_unique<MockSSLClientSocket>(
848       std::move(stream_socket), host_and_port, ssl_config, next_ssl_data);
849 }
850 
MockClientSocket(const NetLogWithSource & net_log)851 MockClientSocket::MockClientSocket(const NetLogWithSource& net_log)
852     : net_log_(net_log) {
853   local_addr_ = IPEndPoint(IPAddress(192, 0, 2, 33), 123);
854   peer_addr_ = IPEndPoint(IPAddress(192, 0, 2, 33), 0);
855 }
856 
SetReceiveBufferSize(int32_t size)857 int MockClientSocket::SetReceiveBufferSize(int32_t size) {
858   return OK;
859 }
860 
SetSendBufferSize(int32_t size)861 int MockClientSocket::SetSendBufferSize(int32_t size) {
862   return OK;
863 }
864 
Bind(const net::IPEndPoint & local_addr)865 int MockClientSocket::Bind(const net::IPEndPoint& local_addr) {
866   local_addr_ = local_addr;
867   return net::OK;
868 }
869 
SetNoDelay(bool no_delay)870 bool MockClientSocket::SetNoDelay(bool no_delay) {
871   return true;
872 }
873 
SetKeepAlive(bool enable,int delay)874 bool MockClientSocket::SetKeepAlive(bool enable, int delay) {
875   return true;
876 }
877 
Disconnect()878 void MockClientSocket::Disconnect() {
879   connected_ = false;
880 }
881 
IsConnected() const882 bool MockClientSocket::IsConnected() const {
883   return connected_;
884 }
885 
IsConnectedAndIdle() const886 bool MockClientSocket::IsConnectedAndIdle() const {
887   return connected_;
888 }
889 
GetPeerAddress(IPEndPoint * address) const890 int MockClientSocket::GetPeerAddress(IPEndPoint* address) const {
891   if (!IsConnected())
892     return ERR_SOCKET_NOT_CONNECTED;
893   *address = peer_addr_;
894   return OK;
895 }
896 
GetLocalAddress(IPEndPoint * address) const897 int MockClientSocket::GetLocalAddress(IPEndPoint* address) const {
898   *address = local_addr_;
899   return OK;
900 }
901 
NetLog() const902 const NetLogWithSource& MockClientSocket::NetLog() const {
903   return net_log_;
904 }
905 
GetNegotiatedProtocol() const906 NextProto MockClientSocket::GetNegotiatedProtocol() const {
907   return kProtoUnknown;
908 }
909 
910 MockClientSocket::~MockClientSocket() = default;
911 
RunCallbackAsync(CompletionOnceCallback callback,int result)912 void MockClientSocket::RunCallbackAsync(CompletionOnceCallback callback,
913                                         int result) {
914   base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
915       FROM_HERE,
916       base::BindOnce(&MockClientSocket::RunCallback, weak_factory_.GetWeakPtr(),
917                      std::move(callback), result));
918 }
919 
RunCallback(CompletionOnceCallback callback,int result)920 void MockClientSocket::RunCallback(CompletionOnceCallback callback,
921                                    int result) {
922   std::move(callback).Run(result);
923 }
924 
MockTCPClientSocket(const AddressList & addresses,net::NetLog * net_log,SocketDataProvider * data)925 MockTCPClientSocket::MockTCPClientSocket(const AddressList& addresses,
926                                          net::NetLog* net_log,
927                                          SocketDataProvider* data)
928     : MockClientSocket(NetLogWithSource::Make(net_log, NetLogSourceType::NONE)),
929       addresses_(addresses),
930       data_(data),
931       read_data_(SYNCHRONOUS, ERR_UNEXPECTED) {
932   DCHECK(data_);
933   peer_addr_ = data->connect_data().peer_addr;
934   data_->Initialize(this);
935   if (data_->expected_addresses()) {
936     EXPECT_EQ(*data_->expected_addresses(), addresses);
937   }
938 }
939 
~MockTCPClientSocket()940 MockTCPClientSocket::~MockTCPClientSocket() {
941   if (data_)
942     data_->DetachSocket();
943 }
944 
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)945 int MockTCPClientSocket::Read(IOBuffer* buf,
946                               int buf_len,
947                               CompletionOnceCallback callback) {
948   // If the buffer is already in use, a read is already in progress!
949   DCHECK(!pending_read_buf_);
950   // Use base::Unretained() is safe because MockClientSocket::RunCallbackAsync()
951   // takes a weak ptr of the base class, MockClientSocket.
952   int rv = ReadIfReadyImpl(
953       buf, buf_len,
954       base::BindOnce(&MockTCPClientSocket::RetryRead, base::Unretained(this)));
955   if (rv == ERR_IO_PENDING) {
956     DCHECK(callback);
957 
958     pending_read_buf_ = buf;
959     pending_read_buf_len_ = buf_len;
960     pending_read_callback_ = std::move(callback);
961   }
962   return rv;
963 }
964 
ReadIfReady(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)965 int MockTCPClientSocket::ReadIfReady(IOBuffer* buf,
966                                      int buf_len,
967                                      CompletionOnceCallback callback) {
968   DCHECK(!pending_read_if_ready_callback_);
969 
970   if (!enable_read_if_ready_)
971     return ERR_READ_IF_READY_NOT_IMPLEMENTED;
972   return ReadIfReadyImpl(buf, buf_len, std::move(callback));
973 }
974 
CancelReadIfReady()975 int MockTCPClientSocket::CancelReadIfReady() {
976   DCHECK(pending_read_if_ready_callback_);
977 
978   pending_read_if_ready_callback_.Reset();
979   data_->CancelPendingRead();
980   return OK;
981 }
982 
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag &)983 int MockTCPClientSocket::Write(
984     IOBuffer* buf,
985     int buf_len,
986     CompletionOnceCallback callback,
987     const NetworkTrafficAnnotationTag& /* traffic_annotation */) {
988   DCHECK(buf);
989   DCHECK_GT(buf_len, 0);
990 
991   if (!connected_ || !data_)
992     return ERR_UNEXPECTED;
993 
994   std::string data(buf->data(), buf_len);
995   MockWriteResult write_result = data_->OnWrite(data);
996 
997   was_used_to_convey_data_ = true;
998 
999   if (write_result.result == ERR_CONNECTION_CLOSED) {
1000     // This MockWrite is just a marker to instruct us to set
1001     // peer_closed_connection_.
1002     peer_closed_connection_ = true;
1003   }
1004   // ERR_IO_PENDING is a signal that the socket data will call back
1005   // asynchronously later.
1006   if (write_result.result == ERR_IO_PENDING) {
1007     pending_write_callback_ = std::move(callback);
1008     return ERR_IO_PENDING;
1009   }
1010 
1011   if (write_result.mode == ASYNC) {
1012     RunCallbackAsync(std::move(callback), write_result.result);
1013     return ERR_IO_PENDING;
1014   }
1015 
1016   return write_result.result;
1017 }
1018 
SetReceiveBufferSize(int32_t size)1019 int MockTCPClientSocket::SetReceiveBufferSize(int32_t size) {
1020   if (!connected_)
1021     return net::ERR_UNEXPECTED;
1022   data_->set_receive_buffer_size(size);
1023   return data_->set_receive_buffer_size_result();
1024 }
1025 
SetSendBufferSize(int32_t size)1026 int MockTCPClientSocket::SetSendBufferSize(int32_t size) {
1027   if (!connected_)
1028     return net::ERR_UNEXPECTED;
1029   data_->set_send_buffer_size(size);
1030   return data_->set_send_buffer_size_result();
1031 }
1032 
SetNoDelay(bool no_delay)1033 bool MockTCPClientSocket::SetNoDelay(bool no_delay) {
1034   if (!connected_)
1035     return false;
1036   data_->set_no_delay(no_delay);
1037   return data_->set_no_delay_result();
1038 }
1039 
SetKeepAlive(bool enable,int delay)1040 bool MockTCPClientSocket::SetKeepAlive(bool enable, int delay) {
1041   if (!connected_)
1042     return false;
1043   data_->set_keep_alive(enable, delay);
1044   return data_->set_keep_alive_result();
1045 }
1046 
SetBeforeConnectCallback(const BeforeConnectCallback & before_connect_callback)1047 void MockTCPClientSocket::SetBeforeConnectCallback(
1048     const BeforeConnectCallback& before_connect_callback) {
1049   DCHECK(!before_connect_callback_);
1050   DCHECK(!connected_);
1051 
1052   before_connect_callback_ = before_connect_callback;
1053 }
1054 
Connect(CompletionOnceCallback callback)1055 int MockTCPClientSocket::Connect(CompletionOnceCallback callback) {
1056   if (!data_)
1057     return ERR_UNEXPECTED;
1058 
1059   if (connected_)
1060     return OK;
1061 
1062   // Setting socket options fails if not connected, so need to set this before
1063   // calling |before_connect_callback_|.
1064   connected_ = true;
1065 
1066   if (before_connect_callback_) {
1067     for (size_t index = 0; index < addresses_.size(); index++) {
1068       int result = before_connect_callback_.Run();
1069       if (data_->connect_data().first_attempt_fails && index == 0) {
1070         continue;
1071       }
1072       DCHECK_NE(result, ERR_IO_PENDING);
1073       if (result != net::OK) {
1074         connected_ = false;
1075         return result;
1076       }
1077       break;
1078     }
1079   }
1080 
1081   peer_closed_connection_ = false;
1082 
1083   int result = data_->connect_data().result;
1084   IoMode mode = data_->connect_data().mode;
1085   if (mode == SYNCHRONOUS)
1086     return result;
1087 
1088   DCHECK(callback);
1089 
1090   if (result == ERR_IO_PENDING)
1091     pending_connect_callback_ = std::move(callback);
1092   else
1093     RunCallbackAsync(std::move(callback), result);
1094   return ERR_IO_PENDING;
1095 }
1096 
Disconnect()1097 void MockTCPClientSocket::Disconnect() {
1098   MockClientSocket::Disconnect();
1099   pending_connect_callback_.Reset();
1100   pending_read_callback_.Reset();
1101 }
1102 
IsConnected() const1103 bool MockTCPClientSocket::IsConnected() const {
1104   if (!data_)
1105     return false;
1106   return connected_ && !peer_closed_connection_;
1107 }
1108 
IsConnectedAndIdle() const1109 bool MockTCPClientSocket::IsConnectedAndIdle() const {
1110   if (!data_)
1111     return false;
1112   return IsConnected() && data_->IsIdle();
1113 }
1114 
GetPeerAddress(IPEndPoint * address) const1115 int MockTCPClientSocket::GetPeerAddress(IPEndPoint* address) const {
1116   if (addresses_.empty())
1117     return MockClientSocket::GetPeerAddress(address);
1118 
1119   if (data_->connect_data().first_attempt_fails) {
1120     DCHECK_GE(addresses_.size(), 2U);
1121     *address = addresses_[1];
1122   } else {
1123     *address = addresses_[0];
1124   }
1125   return OK;
1126 }
1127 
WasEverUsed() const1128 bool MockTCPClientSocket::WasEverUsed() const {
1129   return was_used_to_convey_data_;
1130 }
1131 
GetSSLInfo(SSLInfo * ssl_info)1132 bool MockTCPClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
1133   return false;
1134 }
1135 
OnReadComplete(const MockRead & data)1136 void MockTCPClientSocket::OnReadComplete(const MockRead& data) {
1137   // If |data_| has been destroyed, safest to just do nothing.
1138   if (!data_)
1139     return;
1140 
1141   // There must be a read pending.
1142   DCHECK(pending_read_if_ready_callback_);
1143   // You can't complete a read with another ERR_IO_PENDING status code.
1144   DCHECK_NE(ERR_IO_PENDING, data.result);
1145   // Since we've been waiting for data, need_read_data_ should be true.
1146   DCHECK(need_read_data_);
1147 
1148   read_data_ = data;
1149   need_read_data_ = false;
1150 
1151   // The caller is simulating that this IO completes right now.  Don't
1152   // let CompleteRead() schedule a callback.
1153   read_data_.mode = SYNCHRONOUS;
1154   RunCallback(std::move(pending_read_if_ready_callback_),
1155               read_data_.result > 0 ? OK : read_data_.result);
1156 }
1157 
OnWriteComplete(int rv)1158 void MockTCPClientSocket::OnWriteComplete(int rv) {
1159   // If |data_| has been destroyed, safest to just do nothing.
1160   if (!data_)
1161     return;
1162 
1163   // There must be a read pending.
1164   DCHECK(!pending_write_callback_.is_null());
1165   RunCallback(std::move(pending_write_callback_), rv);
1166 }
1167 
OnConnectComplete(const MockConnect & data)1168 void MockTCPClientSocket::OnConnectComplete(const MockConnect& data) {
1169   // If |data_| has been destroyed, safest to just do nothing.
1170   if (!data_)
1171     return;
1172 
1173   RunCallback(std::move(pending_connect_callback_), data.result);
1174 }
1175 
OnDataProviderDestroyed()1176 void MockTCPClientSocket::OnDataProviderDestroyed() {
1177   data_ = nullptr;
1178 }
1179 
RetryRead(int rv)1180 void MockTCPClientSocket::RetryRead(int rv) {
1181   DCHECK(pending_read_callback_);
1182   DCHECK(pending_read_buf_.get());
1183   DCHECK_LT(0, pending_read_buf_len_);
1184 
1185   if (rv == OK) {
1186     rv = ReadIfReadyImpl(pending_read_buf_.get(), pending_read_buf_len_,
1187                          base::BindOnce(&MockTCPClientSocket::RetryRead,
1188                                         base::Unretained(this)));
1189     if (rv == ERR_IO_PENDING)
1190       return;
1191   }
1192   pending_read_buf_ = nullptr;
1193   pending_read_buf_len_ = 0;
1194   RunCallback(std::move(pending_read_callback_), rv);
1195 }
1196 
ReadIfReadyImpl(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)1197 int MockTCPClientSocket::ReadIfReadyImpl(IOBuffer* buf,
1198                                          int buf_len,
1199                                          CompletionOnceCallback callback) {
1200   if (!connected_ || !data_)
1201     return ERR_UNEXPECTED;
1202 
1203   DCHECK(!pending_read_if_ready_callback_);
1204 
1205   if (need_read_data_) {
1206     read_data_ = data_->OnRead();
1207     if (read_data_.result == ERR_CONNECTION_CLOSED) {
1208       // This MockRead is just a marker to instruct us to set
1209       // peer_closed_connection_.
1210       peer_closed_connection_ = true;
1211     }
1212     if (read_data_.result == ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ) {
1213       // This MockRead is just a marker to instruct us to set
1214       // peer_closed_connection_.  Skip it and get the next one.
1215       read_data_ = data_->OnRead();
1216       peer_closed_connection_ = true;
1217     }
1218     // ERR_IO_PENDING means that the SocketDataProvider is taking responsibility
1219     // to complete the async IO manually later (via OnReadComplete).
1220     if (read_data_.result == ERR_IO_PENDING) {
1221       // We need to be using async IO in this case.
1222       DCHECK(!callback.is_null());
1223       pending_read_if_ready_callback_ = std::move(callback);
1224       return ERR_IO_PENDING;
1225     }
1226     need_read_data_ = false;
1227   }
1228 
1229   int result = read_data_.result;
1230   DCHECK_NE(ERR_IO_PENDING, result);
1231   if (read_data_.mode == ASYNC) {
1232     DCHECK(!callback.is_null());
1233     read_data_.mode = SYNCHRONOUS;
1234     pending_read_if_ready_callback_ = std::move(callback);
1235     // base::Unretained() is safe here because RunCallbackAsync will wrap it
1236     // with a callback associated with a weak ptr.
1237     RunCallbackAsync(
1238         base::BindOnce(&MockTCPClientSocket::RunReadIfReadyCallback,
1239                        base::Unretained(this)),
1240         result);
1241     return ERR_IO_PENDING;
1242   }
1243 
1244   was_used_to_convey_data_ = true;
1245   if (read_data_.data) {
1246     if (read_data_.data_len - read_offset_ > 0) {
1247       result = std::min(buf_len, read_data_.data_len - read_offset_);
1248       memcpy(buf->data(), read_data_.data + read_offset_, result);
1249       read_offset_ += result;
1250       if (read_offset_ == read_data_.data_len) {
1251         need_read_data_ = true;
1252         read_offset_ = 0;
1253       }
1254     } else {
1255       result = 0;  // EOF
1256     }
1257   }
1258   return result;
1259 }
1260 
RunReadIfReadyCallback(int result)1261 void MockTCPClientSocket::RunReadIfReadyCallback(int result) {
1262   // If ReadIfReady is already canceled, do nothing.
1263   if (!pending_read_if_ready_callback_)
1264     return;
1265   std::move(pending_read_if_ready_callback_).Run(result);
1266 }
1267 
1268 // static
ConnectCallback(MockSSLClientSocket * ssl_client_socket,CompletionOnceCallback callback,int rv)1269 void MockSSLClientSocket::ConnectCallback(
1270     MockSSLClientSocket* ssl_client_socket,
1271     CompletionOnceCallback callback,
1272     int rv) {
1273   if (rv == OK)
1274     ssl_client_socket->connected_ = true;
1275   std::move(callback).Run(rv);
1276 }
1277 
MockSSLClientSocket(std::unique_ptr<StreamSocket> stream_socket,const HostPortPair & host_and_port,const SSLConfig & ssl_config,SSLSocketDataProvider * data)1278 MockSSLClientSocket::MockSSLClientSocket(
1279     std::unique_ptr<StreamSocket> stream_socket,
1280     const HostPortPair& host_and_port,
1281     const SSLConfig& ssl_config,
1282     SSLSocketDataProvider* data)
1283     : net_log_(stream_socket->NetLog()),
1284       stream_socket_(std::move(stream_socket)),
1285       data_(data) {
1286   DCHECK(data_);
1287   peer_addr_ = data->connect.peer_addr;
1288 }
1289 
~MockSSLClientSocket()1290 MockSSLClientSocket::~MockSSLClientSocket() {
1291   Disconnect();
1292 }
1293 
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)1294 int MockSSLClientSocket::Read(IOBuffer* buf,
1295                               int buf_len,
1296                               CompletionOnceCallback callback) {
1297   return stream_socket_->Read(buf, buf_len, std::move(callback));
1298 }
1299 
ReadIfReady(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)1300 int MockSSLClientSocket::ReadIfReady(IOBuffer* buf,
1301                                      int buf_len,
1302                                      CompletionOnceCallback callback) {
1303   return stream_socket_->ReadIfReady(buf, buf_len, std::move(callback));
1304 }
1305 
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)1306 int MockSSLClientSocket::Write(
1307     IOBuffer* buf,
1308     int buf_len,
1309     CompletionOnceCallback callback,
1310     const NetworkTrafficAnnotationTag& traffic_annotation) {
1311   if (!data_->is_confirm_data_consumed)
1312     data_->write_called_before_confirm = true;
1313   return stream_socket_->Write(buf, buf_len, std::move(callback),
1314                                traffic_annotation);
1315 }
1316 
CancelReadIfReady()1317 int MockSSLClientSocket::CancelReadIfReady() {
1318   return stream_socket_->CancelReadIfReady();
1319 }
1320 
Connect(CompletionOnceCallback callback)1321 int MockSSLClientSocket::Connect(CompletionOnceCallback callback) {
1322   DCHECK(stream_socket_->IsConnected());
1323   data_->is_connect_data_consumed = true;
1324   if (data_->connect.result == OK)
1325     connected_ = true;
1326   RunClosureIfNonNull(std::move(data_->connect_callback));
1327   if (data_->connect.mode == ASYNC) {
1328     RunCallbackAsync(std::move(callback), data_->connect.result);
1329     return ERR_IO_PENDING;
1330   }
1331   return data_->connect.result;
1332 }
1333 
Disconnect()1334 void MockSSLClientSocket::Disconnect() {
1335   if (stream_socket_ != nullptr)
1336     stream_socket_->Disconnect();
1337 }
1338 
RunConfirmHandshakeCallback(CompletionOnceCallback callback,int result)1339 void MockSSLClientSocket::RunConfirmHandshakeCallback(
1340     CompletionOnceCallback callback,
1341     int result) {
1342   DCHECK(in_confirm_handshake_);
1343   in_confirm_handshake_ = false;
1344   data_->is_confirm_data_consumed = true;
1345   std::move(callback).Run(result);
1346 }
1347 
ConfirmHandshake(CompletionOnceCallback callback)1348 int MockSSLClientSocket::ConfirmHandshake(CompletionOnceCallback callback) {
1349   DCHECK(stream_socket_->IsConnected());
1350   DCHECK(!in_confirm_handshake_);
1351   if (data_->is_confirm_data_consumed)
1352     return data_->confirm.result;
1353   RunClosureIfNonNull(std::move(data_->confirm_callback));
1354   if (data_->confirm.mode == ASYNC) {
1355     in_confirm_handshake_ = true;
1356     RunCallbackAsync(
1357         base::BindOnce(&MockSSLClientSocket::RunConfirmHandshakeCallback,
1358                        base::Unretained(this), std::move(callback)),
1359         data_->confirm.result);
1360     return ERR_IO_PENDING;
1361   }
1362   data_->is_confirm_data_consumed = true;
1363   if (data_->confirm.result == ERR_IO_PENDING) {
1364     // `MockConfirm(SYNCHRONOUS, ERR_IO_PENDING)` means `ConfirmHandshake()`
1365     // never completes.
1366     in_confirm_handshake_ = true;
1367   }
1368   return data_->confirm.result;
1369 }
1370 
IsConnected() const1371 bool MockSSLClientSocket::IsConnected() const {
1372   return stream_socket_->IsConnected();
1373 }
1374 
IsConnectedAndIdle() const1375 bool MockSSLClientSocket::IsConnectedAndIdle() const {
1376   return stream_socket_->IsConnectedAndIdle();
1377 }
1378 
WasEverUsed() const1379 bool MockSSLClientSocket::WasEverUsed() const {
1380   return stream_socket_->WasEverUsed();
1381 }
1382 
GetLocalAddress(IPEndPoint * address) const1383 int MockSSLClientSocket::GetLocalAddress(IPEndPoint* address) const {
1384   *address = IPEndPoint(IPAddress(192, 0, 2, 33), 123);
1385   return OK;
1386 }
1387 
GetPeerAddress(IPEndPoint * address) const1388 int MockSSLClientSocket::GetPeerAddress(IPEndPoint* address) const {
1389   return stream_socket_->GetPeerAddress(address);
1390 }
1391 
GetNegotiatedProtocol() const1392 NextProto MockSSLClientSocket::GetNegotiatedProtocol() const {
1393   return data_->next_proto;
1394 }
1395 
1396 absl::optional<base::StringPiece>
GetPeerApplicationSettings() const1397 MockSSLClientSocket::GetPeerApplicationSettings() const {
1398   return data_->peer_application_settings;
1399 }
1400 
GetSSLInfo(SSLInfo * requested_ssl_info)1401 bool MockSSLClientSocket::GetSSLInfo(SSLInfo* requested_ssl_info) {
1402   requested_ssl_info->Reset();
1403   *requested_ssl_info = data_->ssl_info;
1404   return true;
1405 }
1406 
ApplySocketTag(const SocketTag & tag)1407 void MockSSLClientSocket::ApplySocketTag(const SocketTag& tag) {
1408   return stream_socket_->ApplySocketTag(tag);
1409 }
1410 
NetLog() const1411 const NetLogWithSource& MockSSLClientSocket::NetLog() const {
1412   return net_log_;
1413 }
1414 
GetTotalReceivedBytes() const1415 int64_t MockSSLClientSocket::GetTotalReceivedBytes() const {
1416   NOTIMPLEMENTED();
1417   return 0;
1418 }
1419 
GetTotalReceivedBytes() const1420 int64_t MockClientSocket::GetTotalReceivedBytes() const {
1421   NOTIMPLEMENTED();
1422   return 0;
1423 }
1424 
SetReceiveBufferSize(int32_t size)1425 int MockSSLClientSocket::SetReceiveBufferSize(int32_t size) {
1426   return OK;
1427 }
1428 
SetSendBufferSize(int32_t size)1429 int MockSSLClientSocket::SetSendBufferSize(int32_t size) {
1430   return OK;
1431 }
1432 
GetSSLCertRequestInfo(SSLCertRequestInfo * cert_request_info) const1433 void MockSSLClientSocket::GetSSLCertRequestInfo(
1434     SSLCertRequestInfo* cert_request_info) const {
1435   DCHECK(cert_request_info);
1436   if (data_->cert_request_info) {
1437     cert_request_info->host_and_port = data_->cert_request_info->host_and_port;
1438     cert_request_info->is_proxy = data_->cert_request_info->is_proxy;
1439     cert_request_info->cert_authorities =
1440         data_->cert_request_info->cert_authorities;
1441     cert_request_info->signature_algorithms =
1442         data_->cert_request_info->signature_algorithms;
1443   } else {
1444     cert_request_info->Reset();
1445   }
1446 }
1447 
ExportKeyingMaterial(base::StringPiece label,bool has_context,base::StringPiece context,unsigned char * out,unsigned int outlen)1448 int MockSSLClientSocket::ExportKeyingMaterial(base::StringPiece label,
1449                                               bool has_context,
1450                                               base::StringPiece context,
1451                                               unsigned char* out,
1452                                               unsigned int outlen) {
1453   memset(out, 'A', outlen);
1454   return OK;
1455 }
1456 
GetECHRetryConfigs()1457 std::vector<uint8_t> MockSSLClientSocket::GetECHRetryConfigs() {
1458   return data_->ech_retry_configs;
1459 }
1460 
RunCallbackAsync(CompletionOnceCallback callback,int result)1461 void MockSSLClientSocket::RunCallbackAsync(CompletionOnceCallback callback,
1462                                            int result) {
1463   base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
1464       FROM_HERE,
1465       base::BindOnce(&MockSSLClientSocket::RunCallback,
1466                      weak_factory_.GetWeakPtr(), std::move(callback), result));
1467 }
1468 
RunCallback(CompletionOnceCallback callback,int result)1469 void MockSSLClientSocket::RunCallback(CompletionOnceCallback callback,
1470                                       int result) {
1471   std::move(callback).Run(result);
1472 }
1473 
OnReadComplete(const MockRead & data)1474 void MockSSLClientSocket::OnReadComplete(const MockRead& data) {
1475   NOTIMPLEMENTED();
1476 }
1477 
OnWriteComplete(int rv)1478 void MockSSLClientSocket::OnWriteComplete(int rv) {
1479   NOTIMPLEMENTED();
1480 }
1481 
OnConnectComplete(const MockConnect & data)1482 void MockSSLClientSocket::OnConnectComplete(const MockConnect& data) {
1483   NOTIMPLEMENTED();
1484 }
1485 
MockUDPClientSocket(SocketDataProvider * data,net::NetLog * net_log)1486 MockUDPClientSocket::MockUDPClientSocket(SocketDataProvider* data,
1487                                          net::NetLog* net_log)
1488     : data_(data),
1489       read_data_(SYNCHRONOUS, ERR_UNEXPECTED),
1490       source_host_(IPAddress(192, 0, 2, 33)),
1491       net_log_(NetLogWithSource::Make(net_log, NetLogSourceType::NONE)) {
1492   if (data_) {
1493     data_->Initialize(this);
1494     peer_addr_ = data->connect_data().peer_addr;
1495   }
1496 }
1497 
~MockUDPClientSocket()1498 MockUDPClientSocket::~MockUDPClientSocket() {
1499   if (data_)
1500     data_->DetachSocket();
1501 }
1502 
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)1503 int MockUDPClientSocket::Read(IOBuffer* buf,
1504                               int buf_len,
1505                               CompletionOnceCallback callback) {
1506   DCHECK(callback);
1507 
1508   if (!connected_ || !data_)
1509     return ERR_UNEXPECTED;
1510   data_transferred_ = true;
1511 
1512   // If the buffer is already in use, a read is already in progress!
1513   DCHECK(!pending_read_buf_);
1514 
1515   // Store our async IO data.
1516   pending_read_buf_ = buf;
1517   pending_read_buf_len_ = buf_len;
1518   pending_read_callback_ = std::move(callback);
1519 
1520   if (need_read_data_) {
1521     read_data_ = data_->OnRead();
1522     // ERR_IO_PENDING means that the SocketDataProvider is taking responsibility
1523     // to complete the async IO manually later (via OnReadComplete).
1524     if (read_data_.result == ERR_IO_PENDING) {
1525       // We need to be using async IO in this case.
1526       DCHECK(!pending_read_callback_.is_null());
1527       return ERR_IO_PENDING;
1528     }
1529     need_read_data_ = false;
1530   }
1531 
1532   return CompleteRead();
1533 }
1534 
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag &)1535 int MockUDPClientSocket::Write(
1536     IOBuffer* buf,
1537     int buf_len,
1538     CompletionOnceCallback callback,
1539     const NetworkTrafficAnnotationTag& /* traffic_annotation */) {
1540   DCHECK(buf);
1541   DCHECK_GT(buf_len, 0);
1542   DCHECK(callback);
1543 
1544   if (!connected_ || !data_)
1545     return ERR_UNEXPECTED;
1546   data_transferred_ = true;
1547 
1548   std::string data(buf->data(), buf_len);
1549   MockWriteResult write_result = data_->OnWrite(data);
1550 
1551   // ERR_IO_PENDING is a signal that the socket data will call back
1552   // asynchronously.
1553   if (write_result.result == ERR_IO_PENDING) {
1554     pending_write_callback_ = std::move(callback);
1555     return ERR_IO_PENDING;
1556   }
1557   if (write_result.mode == ASYNC) {
1558     RunCallbackAsync(std::move(callback), write_result.result);
1559     return ERR_IO_PENDING;
1560   }
1561   return write_result.result;
1562 }
1563 
SetReceiveBufferSize(int32_t size)1564 int MockUDPClientSocket::SetReceiveBufferSize(int32_t size) {
1565   return OK;
1566 }
1567 
SetSendBufferSize(int32_t size)1568 int MockUDPClientSocket::SetSendBufferSize(int32_t size) {
1569   return OK;
1570 }
1571 
SetDoNotFragment()1572 int MockUDPClientSocket::SetDoNotFragment() {
1573   return OK;
1574 }
1575 
SetRecvEcn()1576 int MockUDPClientSocket::SetRecvEcn() {
1577   return OK;
1578 }
1579 
Close()1580 void MockUDPClientSocket::Close() {
1581   connected_ = false;
1582 }
1583 
GetPeerAddress(IPEndPoint * address) const1584 int MockUDPClientSocket::GetPeerAddress(IPEndPoint* address) const {
1585   if (!data_)
1586     return ERR_UNEXPECTED;
1587 
1588   *address = peer_addr_;
1589   return OK;
1590 }
1591 
GetLocalAddress(IPEndPoint * address) const1592 int MockUDPClientSocket::GetLocalAddress(IPEndPoint* address) const {
1593   *address = IPEndPoint(source_host_, source_port_);
1594   return OK;
1595 }
1596 
UseNonBlockingIO()1597 void MockUDPClientSocket::UseNonBlockingIO() {}
1598 
SetMulticastInterface(uint32_t interface_index)1599 int MockUDPClientSocket::SetMulticastInterface(uint32_t interface_index) {
1600   return OK;
1601 }
1602 
NetLog() const1603 const NetLogWithSource& MockUDPClientSocket::NetLog() const {
1604   return net_log_;
1605 }
1606 
Connect(const IPEndPoint & address)1607 int MockUDPClientSocket::Connect(const IPEndPoint& address) {
1608   if (!data_)
1609     return ERR_UNEXPECTED;
1610   DCHECK_NE(data_->connect_data().result, ERR_IO_PENDING);
1611   connected_ = true;
1612   peer_addr_ = address;
1613   return data_->connect_data().result;
1614 }
1615 
ConnectUsingNetwork(handles::NetworkHandle network,const IPEndPoint & address)1616 int MockUDPClientSocket::ConnectUsingNetwork(handles::NetworkHandle network,
1617                                              const IPEndPoint& address) {
1618   DCHECK(!connected_);
1619   if (!data_)
1620     return ERR_UNEXPECTED;
1621   DCHECK_NE(data_->connect_data().result, ERR_IO_PENDING);
1622   network_ = network;
1623   connected_ = true;
1624   peer_addr_ = address;
1625   return data_->connect_data().result;
1626 }
1627 
ConnectUsingDefaultNetwork(const IPEndPoint & address)1628 int MockUDPClientSocket::ConnectUsingDefaultNetwork(const IPEndPoint& address) {
1629   DCHECK(!connected_);
1630   if (!data_)
1631     return ERR_UNEXPECTED;
1632   DCHECK_NE(data_->connect_data().result, ERR_IO_PENDING);
1633   network_ = kDefaultNetworkForTests;
1634   connected_ = true;
1635   peer_addr_ = address;
1636   return data_->connect_data().result;
1637 }
1638 
ConnectAsync(const IPEndPoint & address,CompletionOnceCallback callback)1639 int MockUDPClientSocket::ConnectAsync(const IPEndPoint& address,
1640                                       CompletionOnceCallback callback) {
1641   DCHECK(callback);
1642   if (!data_) {
1643     return ERR_UNEXPECTED;
1644   }
1645   connected_ = true;
1646   peer_addr_ = address;
1647   int result = data_->connect_data().result;
1648   IoMode mode = data_->connect_data().mode;
1649   if (mode == SYNCHRONOUS) {
1650     return result;
1651   }
1652   RunCallbackAsync(std::move(callback), result);
1653   return ERR_IO_PENDING;
1654 }
1655 
ConnectUsingNetworkAsync(handles::NetworkHandle network,const IPEndPoint & address,CompletionOnceCallback callback)1656 int MockUDPClientSocket::ConnectUsingNetworkAsync(
1657     handles::NetworkHandle network,
1658     const IPEndPoint& address,
1659     CompletionOnceCallback callback) {
1660   DCHECK(callback);
1661   DCHECK(!connected_);
1662   if (!data_)
1663     return ERR_UNEXPECTED;
1664   network_ = network;
1665   connected_ = true;
1666   peer_addr_ = address;
1667   int result = data_->connect_data().result;
1668   IoMode mode = data_->connect_data().mode;
1669   if (mode == SYNCHRONOUS) {
1670     return result;
1671   }
1672   RunCallbackAsync(std::move(callback), result);
1673   return ERR_IO_PENDING;
1674 }
1675 
ConnectUsingDefaultNetworkAsync(const IPEndPoint & address,CompletionOnceCallback callback)1676 int MockUDPClientSocket::ConnectUsingDefaultNetworkAsync(
1677     const IPEndPoint& address,
1678     CompletionOnceCallback callback) {
1679   DCHECK(!connected_);
1680   if (!data_)
1681     return ERR_UNEXPECTED;
1682   network_ = kDefaultNetworkForTests;
1683   connected_ = true;
1684   peer_addr_ = address;
1685   int result = data_->connect_data().result;
1686   IoMode mode = data_->connect_data().mode;
1687   if (mode == SYNCHRONOUS) {
1688     return result;
1689   }
1690   RunCallbackAsync(std::move(callback), result);
1691   return ERR_IO_PENDING;
1692 }
1693 
GetBoundNetwork() const1694 handles::NetworkHandle MockUDPClientSocket::GetBoundNetwork() const {
1695   return network_;
1696 }
1697 
ApplySocketTag(const SocketTag & tag)1698 void MockUDPClientSocket::ApplySocketTag(const SocketTag& tag) {
1699   tagged_before_data_transferred_ &= !data_transferred_ || tag == tag_;
1700   tag_ = tag;
1701 }
1702 
OnReadComplete(const MockRead & data)1703 void MockUDPClientSocket::OnReadComplete(const MockRead& data) {
1704   if (!data_)
1705     return;
1706 
1707   // There must be a read pending.
1708   DCHECK(pending_read_buf_.get());
1709   DCHECK(pending_read_callback_);
1710   // You can't complete a read with another ERR_IO_PENDING status code.
1711   DCHECK_NE(ERR_IO_PENDING, data.result);
1712   // Since we've been waiting for data, need_read_data_ should be true.
1713   DCHECK(need_read_data_);
1714 
1715   read_data_ = data;
1716   need_read_data_ = false;
1717 
1718   // The caller is simulating that this IO completes right now.  Don't
1719   // let CompleteRead() schedule a callback.
1720   read_data_.mode = SYNCHRONOUS;
1721 
1722   CompletionOnceCallback callback = std::move(pending_read_callback_);
1723   int rv = CompleteRead();
1724   RunCallback(std::move(callback), rv);
1725 }
1726 
OnWriteComplete(int rv)1727 void MockUDPClientSocket::OnWriteComplete(int rv) {
1728   if (!data_)
1729     return;
1730 
1731   // There must be a read pending.
1732   DCHECK(!pending_write_callback_.is_null());
1733   RunCallback(std::move(pending_write_callback_), rv);
1734 }
1735 
OnConnectComplete(const MockConnect & data)1736 void MockUDPClientSocket::OnConnectComplete(const MockConnect& data) {
1737   NOTIMPLEMENTED();
1738 }
1739 
OnDataProviderDestroyed()1740 void MockUDPClientSocket::OnDataProviderDestroyed() {
1741   data_ = nullptr;
1742 }
1743 
CompleteRead()1744 int MockUDPClientSocket::CompleteRead() {
1745   DCHECK(pending_read_buf_.get());
1746   DCHECK(pending_read_buf_len_ > 0);
1747 
1748   // Save the pending async IO data and reset our |pending_| state.
1749   scoped_refptr<IOBuffer> buf = pending_read_buf_;
1750   int buf_len = pending_read_buf_len_;
1751   CompletionOnceCallback callback = std::move(pending_read_callback_);
1752   pending_read_buf_ = nullptr;
1753   pending_read_buf_len_ = 0;
1754 
1755   int result = read_data_.result;
1756   DCHECK(result != ERR_IO_PENDING);
1757 
1758   if (read_data_.data) {
1759     if (read_data_.data_len - read_offset_ > 0) {
1760       result = std::min(buf_len, read_data_.data_len - read_offset_);
1761       memcpy(buf->data(), read_data_.data + read_offset_, result);
1762       read_offset_ += result;
1763       if (read_offset_ == read_data_.data_len) {
1764         need_read_data_ = true;
1765         read_offset_ = 0;
1766       }
1767     } else {
1768       result = 0;  // EOF
1769     }
1770   }
1771 
1772   if (read_data_.mode == ASYNC) {
1773     DCHECK(!callback.is_null());
1774     RunCallbackAsync(std::move(callback), result);
1775     return ERR_IO_PENDING;
1776   }
1777   return result;
1778 }
1779 
RunCallbackAsync(CompletionOnceCallback callback,int result)1780 void MockUDPClientSocket::RunCallbackAsync(CompletionOnceCallback callback,
1781                                            int result) {
1782   base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
1783       FROM_HERE,
1784       base::BindOnce(&MockUDPClientSocket::RunCallback,
1785                      weak_factory_.GetWeakPtr(), std::move(callback), result));
1786 }
1787 
RunCallback(CompletionOnceCallback callback,int result)1788 void MockUDPClientSocket::RunCallback(CompletionOnceCallback callback,
1789                                       int result) {
1790   std::move(callback).Run(result);
1791 }
1792 
TestSocketRequest(std::vector<TestSocketRequest * > * request_order,size_t * completion_count)1793 TestSocketRequest::TestSocketRequest(
1794     std::vector<TestSocketRequest*>* request_order,
1795     size_t* completion_count)
1796     : request_order_(request_order), completion_count_(completion_count) {
1797   DCHECK(request_order);
1798   DCHECK(completion_count);
1799 }
1800 
1801 TestSocketRequest::~TestSocketRequest() = default;
1802 
OnComplete(int result)1803 void TestSocketRequest::OnComplete(int result) {
1804   SetResult(result);
1805   (*completion_count_)++;
1806   request_order_->push_back(this);
1807 }
1808 
1809 // static
1810 const int ClientSocketPoolTest::kIndexOutOfBounds = -1;
1811 
1812 // static
1813 const int ClientSocketPoolTest::kRequestNotFound = -2;
1814 
1815 ClientSocketPoolTest::ClientSocketPoolTest() = default;
1816 ClientSocketPoolTest::~ClientSocketPoolTest() = default;
1817 
GetOrderOfRequest(size_t index) const1818 int ClientSocketPoolTest::GetOrderOfRequest(size_t index) const {
1819   index--;
1820   if (index >= requests_.size())
1821     return kIndexOutOfBounds;
1822 
1823   for (size_t i = 0; i < request_order_.size(); i++)
1824     if (requests_[index].get() == request_order_[i])
1825       return i + 1;
1826 
1827   return kRequestNotFound;
1828 }
1829 
ReleaseOneConnection(KeepAlive keep_alive)1830 bool ClientSocketPoolTest::ReleaseOneConnection(KeepAlive keep_alive) {
1831   for (std::unique_ptr<TestSocketRequest>& it : requests_) {
1832     if (it->handle()->is_initialized()) {
1833       if (keep_alive == NO_KEEP_ALIVE)
1834         it->handle()->socket()->Disconnect();
1835       it->handle()->Reset();
1836       base::RunLoop().RunUntilIdle();
1837       return true;
1838     }
1839   }
1840   return false;
1841 }
1842 
ReleaseAllConnections(KeepAlive keep_alive)1843 void ClientSocketPoolTest::ReleaseAllConnections(KeepAlive keep_alive) {
1844   bool released_one;
1845   do {
1846     released_one = ReleaseOneConnection(keep_alive);
1847   } while (released_one);
1848 }
1849 
MockConnectJob(std::unique_ptr<StreamSocket> socket,ClientSocketHandle * handle,const SocketTag & socket_tag,CompletionOnceCallback callback,RequestPriority priority)1850 MockTransportClientSocketPool::MockConnectJob::MockConnectJob(
1851     std::unique_ptr<StreamSocket> socket,
1852     ClientSocketHandle* handle,
1853     const SocketTag& socket_tag,
1854     CompletionOnceCallback callback,
1855     RequestPriority priority)
1856     : socket_(std::move(socket)),
1857       handle_(handle),
1858       socket_tag_(socket_tag),
1859       user_callback_(std::move(callback)),
1860       priority_(priority) {}
1861 
1862 MockTransportClientSocketPool::MockConnectJob::~MockConnectJob() = default;
1863 
Connect()1864 int MockTransportClientSocketPool::MockConnectJob::Connect() {
1865   socket_->ApplySocketTag(socket_tag_);
1866   int rv = socket_->Connect(
1867       base::BindOnce(&MockConnectJob::OnConnect, base::Unretained(this)));
1868   if (rv != ERR_IO_PENDING) {
1869     user_callback_.Reset();
1870     OnConnect(rv);
1871   }
1872   return rv;
1873 }
1874 
CancelHandle(const ClientSocketHandle * handle)1875 bool MockTransportClientSocketPool::MockConnectJob::CancelHandle(
1876     const ClientSocketHandle* handle) {
1877   if (handle != handle_)
1878     return false;
1879   socket_.reset();
1880   handle_ = nullptr;
1881   user_callback_.Reset();
1882   return true;
1883 }
1884 
OnConnect(int rv)1885 void MockTransportClientSocketPool::MockConnectJob::OnConnect(int rv) {
1886   if (!socket_.get())
1887     return;
1888   if (rv == OK) {
1889     handle_->SetSocket(std::move(socket_));
1890 
1891     // Needed for socket pool tests that layer other sockets on top of mock
1892     // sockets.
1893     LoadTimingInfo::ConnectTiming connect_timing;
1894     base::TimeTicks now = base::TimeTicks::Now();
1895     connect_timing.domain_lookup_start = now;
1896     connect_timing.domain_lookup_end = now;
1897     connect_timing.connect_start = now;
1898     connect_timing.connect_end = now;
1899     handle_->set_connect_timing(connect_timing);
1900   } else {
1901     socket_.reset();
1902 
1903     // Needed to test copying of ConnectionAttempts in SSL ConnectJob.
1904     ConnectionAttempts attempts;
1905     attempts.push_back(ConnectionAttempt(IPEndPoint(), rv));
1906     handle_->set_connection_attempts(attempts);
1907   }
1908 
1909   handle_ = nullptr;
1910 
1911   if (!user_callback_.is_null()) {
1912     std::move(user_callback_).Run(rv);
1913   }
1914 }
1915 
MockTransportClientSocketPool(int max_sockets,int max_sockets_per_group,const CommonConnectJobParams * common_connect_job_params)1916 MockTransportClientSocketPool::MockTransportClientSocketPool(
1917     int max_sockets,
1918     int max_sockets_per_group,
1919     const CommonConnectJobParams* common_connect_job_params)
1920     : TransportClientSocketPool(
1921           max_sockets,
1922           max_sockets_per_group,
1923           base::Seconds(10) /* unused_idle_socket_timeout */,
1924           ProxyChain::Direct(),
1925           false /* is_for_websockets */,
1926           common_connect_job_params),
1927       client_socket_factory_(common_connect_job_params->client_socket_factory) {
1928 }
1929 
1930 MockTransportClientSocketPool::~MockTransportClientSocketPool() = default;
1931 
RequestSocket(const ClientSocketPool::GroupId & group_id,scoped_refptr<ClientSocketPool::SocketParams> socket_params,const absl::optional<NetworkTrafficAnnotationTag> & proxy_annotation_tag,RequestPriority priority,const SocketTag & socket_tag,RespectLimits respect_limits,ClientSocketHandle * handle,CompletionOnceCallback callback,const ProxyAuthCallback & on_auth_callback,const NetLogWithSource & net_log)1932 int MockTransportClientSocketPool::RequestSocket(
1933     const ClientSocketPool::GroupId& group_id,
1934     scoped_refptr<ClientSocketPool::SocketParams> socket_params,
1935     const absl::optional<NetworkTrafficAnnotationTag>& proxy_annotation_tag,
1936     RequestPriority priority,
1937     const SocketTag& socket_tag,
1938     RespectLimits respect_limits,
1939     ClientSocketHandle* handle,
1940     CompletionOnceCallback callback,
1941     const ProxyAuthCallback& on_auth_callback,
1942     const NetLogWithSource& net_log) {
1943   last_request_priority_ = priority;
1944   std::unique_ptr<StreamSocket> socket =
1945       client_socket_factory_->CreateTransportClientSocket(
1946           AddressList(), nullptr, nullptr, net_log.net_log(), NetLogSource());
1947   auto job = std::make_unique<MockConnectJob>(
1948       std::move(socket), handle, socket_tag, std::move(callback), priority);
1949   auto* job_ptr = job.get();
1950   job_list_.push_back(std::move(job));
1951   handle->set_group_generation(1);
1952   return job_ptr->Connect();
1953 }
1954 
SetPriority(const ClientSocketPool::GroupId & group_id,ClientSocketHandle * handle,RequestPriority priority)1955 void MockTransportClientSocketPool::SetPriority(
1956     const ClientSocketPool::GroupId& group_id,
1957     ClientSocketHandle* handle,
1958     RequestPriority priority) {
1959   for (auto& job : job_list_) {
1960     if (job->handle() == handle) {
1961       job->set_priority(priority);
1962       return;
1963     }
1964   }
1965   NOTREACHED();
1966 }
1967 
CancelRequest(const ClientSocketPool::GroupId & group_id,ClientSocketHandle * handle,bool cancel_connect_job)1968 void MockTransportClientSocketPool::CancelRequest(
1969     const ClientSocketPool::GroupId& group_id,
1970     ClientSocketHandle* handle,
1971     bool cancel_connect_job) {
1972   for (std::unique_ptr<MockConnectJob>& it : job_list_) {
1973     if (it->CancelHandle(handle)) {
1974       cancel_count_++;
1975       break;
1976     }
1977   }
1978 }
1979 
ReleaseSocket(const ClientSocketPool::GroupId & group_id,std::unique_ptr<StreamSocket> socket,int64_t generation)1980 void MockTransportClientSocketPool::ReleaseSocket(
1981     const ClientSocketPool::GroupId& group_id,
1982     std::unique_ptr<StreamSocket> socket,
1983     int64_t generation) {
1984   EXPECT_EQ(1, generation);
1985   release_count_++;
1986 }
1987 
WrappedStreamSocket(std::unique_ptr<StreamSocket> transport)1988 WrappedStreamSocket::WrappedStreamSocket(
1989     std::unique_ptr<StreamSocket> transport)
1990     : transport_(std::move(transport)) {}
1991 WrappedStreamSocket::~WrappedStreamSocket() = default;
1992 
Bind(const net::IPEndPoint & local_addr)1993 int WrappedStreamSocket::Bind(const net::IPEndPoint& local_addr) {
1994   NOTREACHED();
1995   return ERR_FAILED;
1996 }
1997 
Connect(CompletionOnceCallback callback)1998 int WrappedStreamSocket::Connect(CompletionOnceCallback callback) {
1999   return transport_->Connect(std::move(callback));
2000 }
2001 
Disconnect()2002 void WrappedStreamSocket::Disconnect() {
2003   transport_->Disconnect();
2004 }
2005 
IsConnected() const2006 bool WrappedStreamSocket::IsConnected() const {
2007   return transport_->IsConnected();
2008 }
2009 
IsConnectedAndIdle() const2010 bool WrappedStreamSocket::IsConnectedAndIdle() const {
2011   return transport_->IsConnectedAndIdle();
2012 }
2013 
GetPeerAddress(IPEndPoint * address) const2014 int WrappedStreamSocket::GetPeerAddress(IPEndPoint* address) const {
2015   return transport_->GetPeerAddress(address);
2016 }
2017 
GetLocalAddress(IPEndPoint * address) const2018 int WrappedStreamSocket::GetLocalAddress(IPEndPoint* address) const {
2019   return transport_->GetLocalAddress(address);
2020 }
2021 
NetLog() const2022 const NetLogWithSource& WrappedStreamSocket::NetLog() const {
2023   return transport_->NetLog();
2024 }
2025 
WasEverUsed() const2026 bool WrappedStreamSocket::WasEverUsed() const {
2027   return transport_->WasEverUsed();
2028 }
2029 
GetNegotiatedProtocol() const2030 NextProto WrappedStreamSocket::GetNegotiatedProtocol() const {
2031   return transport_->GetNegotiatedProtocol();
2032 }
2033 
GetSSLInfo(SSLInfo * ssl_info)2034 bool WrappedStreamSocket::GetSSLInfo(SSLInfo* ssl_info) {
2035   return transport_->GetSSLInfo(ssl_info);
2036 }
2037 
GetTotalReceivedBytes() const2038 int64_t WrappedStreamSocket::GetTotalReceivedBytes() const {
2039   return transport_->GetTotalReceivedBytes();
2040 }
2041 
ApplySocketTag(const SocketTag & tag)2042 void WrappedStreamSocket::ApplySocketTag(const SocketTag& tag) {
2043   transport_->ApplySocketTag(tag);
2044 }
2045 
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)2046 int WrappedStreamSocket::Read(IOBuffer* buf,
2047                               int buf_len,
2048                               CompletionOnceCallback callback) {
2049   return transport_->Read(buf, buf_len, std::move(callback));
2050 }
2051 
ReadIfReady(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)2052 int WrappedStreamSocket::ReadIfReady(IOBuffer* buf,
2053                                      int buf_len,
2054                                      CompletionOnceCallback callback) {
2055   return transport_->ReadIfReady(buf, buf_len, std::move((callback)));
2056 }
2057 
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)2058 int WrappedStreamSocket::Write(
2059     IOBuffer* buf,
2060     int buf_len,
2061     CompletionOnceCallback callback,
2062     const NetworkTrafficAnnotationTag& traffic_annotation) {
2063   return transport_->Write(buf, buf_len, std::move(callback),
2064                            TRAFFIC_ANNOTATION_FOR_TESTS);
2065 }
2066 
SetReceiveBufferSize(int32_t size)2067 int WrappedStreamSocket::SetReceiveBufferSize(int32_t size) {
2068   return transport_->SetReceiveBufferSize(size);
2069 }
2070 
SetSendBufferSize(int32_t size)2071 int WrappedStreamSocket::SetSendBufferSize(int32_t size) {
2072   return transport_->SetSendBufferSize(size);
2073 }
2074 
Connect(CompletionOnceCallback callback)2075 int MockTaggingStreamSocket::Connect(CompletionOnceCallback callback) {
2076   connected_ = true;
2077   return WrappedStreamSocket::Connect(std::move(callback));
2078 }
2079 
ApplySocketTag(const SocketTag & tag)2080 void MockTaggingStreamSocket::ApplySocketTag(const SocketTag& tag) {
2081   tagged_before_connected_ &= !connected_ || tag == tag_;
2082   tag_ = tag;
2083   transport_->ApplySocketTag(tag);
2084 }
2085 
2086 std::unique_ptr<TransportClientSocket>
CreateTransportClientSocket(const AddressList & addresses,std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,NetworkQualityEstimator * network_quality_estimator,NetLog * net_log,const NetLogSource & source)2087 MockTaggingClientSocketFactory::CreateTransportClientSocket(
2088     const AddressList& addresses,
2089     std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,
2090     NetworkQualityEstimator* network_quality_estimator,
2091     NetLog* net_log,
2092     const NetLogSource& source) {
2093   auto socket = std::make_unique<MockTaggingStreamSocket>(
2094       MockClientSocketFactory::CreateTransportClientSocket(
2095           addresses, std::move(socket_performance_watcher),
2096           network_quality_estimator, net_log, source));
2097   tcp_socket_ = socket.get();
2098   return std::move(socket);
2099 }
2100 
2101 std::unique_ptr<DatagramClientSocket>
CreateDatagramClientSocket(DatagramSocket::BindType bind_type,NetLog * net_log,const NetLogSource & source)2102 MockTaggingClientSocketFactory::CreateDatagramClientSocket(
2103     DatagramSocket::BindType bind_type,
2104     NetLog* net_log,
2105     const NetLogSource& source) {
2106   std::unique_ptr<DatagramClientSocket> socket(
2107       MockClientSocketFactory::CreateDatagramClientSocket(bind_type, net_log,
2108                                                           source));
2109   udp_socket_ = static_cast<MockUDPClientSocket*>(socket.get());
2110   return socket;
2111 }
2112 
// Canned SOCKS4 and SOCKS5 handshake messages for tests. Each *Length
// constant is the size in bytes of the matching array.

const char kSOCKS4TestHost[] = "127.0.0.1";
const int kSOCKS4TestPort = 80;

// SOCKS4 CONNECT request to kSOCKS4TestHost:kSOCKS4TestPort.
const char kSOCKS4OkRequestLocalHostPort80[] = {
    0x04,           // Protocol version 4.
    0x01,           // Command: CONNECT.
    0x00, 0x50,     // Destination port 80, big-endian.
    127,  0, 0, 1,  // Destination address 127.0.0.1.
    0x00,           // Empty user ID (NUL terminator only).
};
const int kSOCKS4OkRequestLocalHostPort80Length =
    std::size(kSOCKS4OkRequestLocalHostPort80);

// SOCKS4 reply granting the request; port and address fields are zero.
const char kSOCKS4OkReply[] = {
    0x00,           // Reply version.
    0x5A,           // Status: request granted.
    0x00, 0x00,     // Port.
    0,    0, 0, 0,  // Address.
};
const int kSOCKS4OkReplyLength = std::size(kSOCKS4OkReply);

const char kSOCKS5TestHost[] = "host";
const int kSOCKS5TestPort = 80;

// SOCKS5 greeting offering one method: 0x00 (no authentication).
const char kSOCKS5GreetRequest[] = {0x05, 0x01, 0x00};
const int kSOCKS5GreetRequestLength = std::size(kSOCKS5GreetRequest);

// SOCKS5 method selection: 0x00 (no authentication).
const char kSOCKS5GreetResponse[] = {0x05, 0x00};
const int kSOCKS5GreetResponseLength = std::size(kSOCKS5GreetResponse);

// SOCKS5 CONNECT request to domain name kSOCKS5TestHost, port 80.
const char kSOCKS5OkRequest[] = {
    0x05,                 // Protocol version 5.
    0x01,                 // Command: CONNECT.
    0x00,                 // Reserved.
    0x03,                 // Address type: domain name.
    0x04,                 // Domain name length.
    'h',  'o', 's', 't',  // Domain name.
    0x00, 0x50,           // Port 80, big-endian.
};
const int kSOCKS5OkRequestLength = std::size(kSOCKS5OkRequest);

// SOCKS5 success reply with bound IPv4 address 127.0.0.1 and port 80.
const char kSOCKS5OkResponse[] = {
    0x05,           // Protocol version 5.
    0x00,           // Reply: succeeded.
    0x00,           // Reserved.
    0x01,           // Address type: IPv4.
    127,  0, 0, 1,  // Bound address.
    0x00, 0x50,     // Bound port 80, big-endian.
};
const int kSOCKS5OkResponseLength = std::size(kSOCKS5OkResponse);
2140 
CountReadBytes(base::span<const MockRead> reads)2141 int64_t CountReadBytes(base::span<const MockRead> reads) {
2142   int64_t total = 0;
2143   for (const MockRead& read : reads)
2144     total += read.data_len;
2145   return total;
2146 }
2147 
CountWriteBytes(base::span<const MockWrite> writes)2148 int64_t CountWriteBytes(base::span<const MockWrite> writes) {
2149   int64_t total = 0;
2150   for (const MockWrite& write : writes)
2151     total += write.data_len;
2152   return total;
2153 }
2154 
2155 #if BUILDFLAG(IS_ANDROID)
CanGetTaggedBytes()2156 bool CanGetTaggedBytes() {
2157   // In Android P, /proc/net/xt_qtaguid/stats is no longer guaranteed to be
2158   // present, and has been replaced with eBPF Traffic Monitoring in netd. See:
2159   // https://source.android.com/devices/tech/datausage/ebpf-traffic-monitor
2160   //
2161   // To read traffic statistics from netd, apps should use the API
2162   // NetworkStatsManager.queryDetailsForUidTag(). But this API does not provide
2163   // statistics for local traffic, only mobile and WiFi traffic, so it would not
2164   // work in tests that spin up a local server. So for now, GetTaggedBytes is
2165   // only supported on Android releases older than P.
2166   return base::android::BuildInfo::GetInstance()->sdk_int() <
2167          base::android::SDK_VERSION_P;
2168 }
2169 
GetTaggedBytes(int32_t expected_tag)2170 uint64_t GetTaggedBytes(int32_t expected_tag) {
2171   EXPECT_TRUE(CanGetTaggedBytes());
2172 
2173   // To determine how many bytes the system saw with a particular tag read
2174   // the /proc/net/xt_qtaguid/stats file which contains the kernel's
2175   // dump of all the UIDs and their tags sent and received bytes.
2176   uint64_t bytes = 0;
2177   std::string contents;
2178   EXPECT_TRUE(base::ReadFileToString(
2179       base::FilePath::FromUTF8Unsafe("/proc/net/xt_qtaguid/stats"), &contents));
2180   for (size_t i = contents.find('\n');  // Skip first line which is headers.
2181        i != std::string::npos && i < contents.length();) {
2182     uint64_t tag, rx_bytes;
2183     uid_t uid;
2184     int n;
2185     // Parse out the numbers we care about. For reference here's the column
2186     // headers:
2187     // idx iface acct_tag_hex uid_tag_int cnt_set rx_bytes rx_packets tx_bytes
2188     // tx_packets rx_tcp_bytes rx_tcp_packets rx_udp_bytes rx_udp_packets
2189     // rx_other_bytes rx_other_packets tx_tcp_bytes tx_tcp_packets tx_udp_bytes
2190     // tx_udp_packets tx_other_bytes tx_other_packets
2191     EXPECT_EQ(sscanf(contents.c_str() + i,
2192                      "%*d %*s 0x%" SCNx64 " %d %*d %" SCNu64
2193                      " %*d %*d %*d %*d %*d %*d %*d %*d "
2194                      "%*d %*d %*d %*d %*d %*d %*d%n",
2195                      &tag, &uid, &rx_bytes, &n),
2196               3);
2197     // If this line matches our UID and |expected_tag| then add it to the total.
2198     if (uid == getuid() && (int32_t)(tag >> 32) == expected_tag) {
2199       bytes += rx_bytes;
2200     }
2201     // Move |i| to the next line.
2202     i += n + 1;
2203   }
2204   return bytes;
2205 }
2206 #endif
2207 
2208 }  // namespace net
2209