• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/socket/socket_test_util.h"
6 
7 #include <inttypes.h>  // For SCNx64
8 #include <stdint.h>
9 #include <stdio.h>
10 
11 #include <memory>
12 #include <string>
13 #include <utility>
14 #include <vector>
15 
16 #include "base/compiler_specific.h"
17 #include "base/files/file_util.h"
18 #include "base/functional/bind.h"
19 #include "base/functional/callback_helpers.h"
20 #include "base/location.h"
21 #include "base/logging.h"
22 #include "base/rand_util.h"
23 #include "base/ranges/algorithm.h"
24 #include "base/run_loop.h"
25 #include "base/task/single_thread_task_runner.h"
26 #include "base/time/time.h"
27 #include "build/build_config.h"
28 #include "net/base/address_family.h"
29 #include "net/base/address_list.h"
30 #include "net/base/auth.h"
31 #include "net/base/hex_utils.h"
32 #include "net/base/ip_address.h"
33 #include "net/base/load_timing_info.h"
34 #include "net/base/proxy_server.h"
35 #include "net/http/http_network_session.h"
36 #include "net/http/http_request_headers.h"
37 #include "net/http/http_response_headers.h"
38 #include "net/log/net_log_source.h"
39 #include "net/log/net_log_source_type.h"
40 #include "net/socket/connect_job.h"
41 #include "net/socket/socket.h"
42 #include "net/socket/stream_socket.h"
43 #include "net/socket/websocket_endpoint_lock_manager.h"
44 #include "net/ssl/ssl_cert_request_info.h"
45 #include "net/ssl/ssl_connection_status_flags.h"
46 #include "net/ssl/ssl_info.h"
47 #include "net/traffic_annotation/network_traffic_annotation.h"
48 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
49 #include "testing/gtest/include/gtest/gtest.h"
50 
51 #if BUILDFLAG(IS_ANDROID)
52 #include "base/android/build_info.h"
53 #endif
54 
// Verbose-logging helper used by the mock socket state machines below. Emits
// |s| at VLOG level |level|, then the enclosing function's name and "() ";
// further values can be streamed onto the result with <<.
#define NET_TRACE(level, s) VLOG(level) << s << __FUNCTION__ << "() "
56 
57 namespace net {
58 namespace {
59 
// Returns the uppercase hexadecimal digit ('0'-'9', 'A'-'F') for the high
// nybble (bits 4-7) of |x|.
inline char AsciifyHigh(char x) {
  static constexpr char kHexDigits[] = "0123456789ABCDEF";
  const unsigned int high_nybble = (static_cast<unsigned char>(x) >> 4) & 0xFu;
  return kHexDigits[high_nybble];
}
64 
// Returns the uppercase hexadecimal digit ('0'-'9', 'A'-'F') for the low
// nybble (bits 0-3) of |x|.
inline char AsciifyLow(char x) {
  static constexpr char kHexDigits[] = "0123456789ABCDEF";
  const unsigned int low_nybble = static_cast<unsigned char>(x) & 0xFu;
  return kHexDigits[low_nybble];
}
69 
// Returns |x| unchanged when it is a non-negative, printable character and
// '.' otherwise. Negative values are rejected before isprint() sees them.
inline char Asciify(char x) {
  const bool printable = x >= 0 && isprint(x);
  return printable ? x : '.';
}
75 
DumpData(const char * data,int data_len)76 void DumpData(const char* data, int data_len) {
77   if (logging::LOG_INFO < logging::GetMinLogLevel())
78     return;
79   DVLOG(1) << "Length:  " << data_len;
80   const char* pfx = "Data:    ";
81   if (!data || (data_len <= 0)) {
82     DVLOG(1) << pfx << "<None>";
83   } else {
84     int i;
85     for (i = 0; i <= (data_len - 4); i += 4) {
86       DVLOG(1) << pfx << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
87                << AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1])
88                << AsciifyHigh(data[i + 2]) << AsciifyLow(data[i + 2])
89                << AsciifyHigh(data[i + 3]) << AsciifyLow(data[i + 3]) << "  '"
90                << Asciify(data[i + 0]) << Asciify(data[i + 1])
91                << Asciify(data[i + 2]) << Asciify(data[i + 3]) << "'";
92       pfx = "         ";
93     }
94     // Take care of any 'trailing' bytes, if data_len was not a multiple of 4.
95     switch (data_len - i) {
96       case 3:
97         DVLOG(1) << pfx << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
98                  << AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1])
99                  << AsciifyHigh(data[i + 2]) << AsciifyLow(data[i + 2])
100                  << "    '" << Asciify(data[i + 0]) << Asciify(data[i + 1])
101                  << Asciify(data[i + 2]) << " '";
102         break;
103       case 2:
104         DVLOG(1) << pfx << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
105                  << AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1])
106                  << "      '" << Asciify(data[i + 0]) << Asciify(data[i + 1])
107                  << "  '";
108         break;
109       case 1:
110         DVLOG(1) << pfx << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
111                  << "        '" << Asciify(data[i + 0]) << "   '";
112         break;
113     }
114   }
115 }
116 
117 template <MockReadWriteType type>
DumpMockReadWrite(const MockReadWrite<type> & r)118 void DumpMockReadWrite(const MockReadWrite<type>& r) {
119   if (logging::LOG_INFO < logging::GetMinLogLevel())
120     return;
121   DVLOG(1) << "Async:   " << (r.mode == ASYNC) << "\nResult:  " << r.result;
122   DumpData(r.data, r.data_len);
123   const char* stop = (r.sequence_number & MockRead::STOPLOOP) ? " (STOP)" : "";
124   DVLOG(1) << "Stage:   " << (r.sequence_number & ~MockRead::STOPLOOP) << stop;
125 }
126 
RunClosureIfNonNull(base::OnceClosure closure)127 void RunClosureIfNonNull(base::OnceClosure closure) {
128   if (!closure.is_null()) {
129     std::move(closure).Run();
130   }
131 }
132 
133 }  // namespace
134 
MockConnect()135 MockConnect::MockConnect() : mode(ASYNC), result(OK) {
136   peer_addr = IPEndPoint(IPAddress(192, 0, 2, 33), 0);
137 }
138 
MockConnect(IoMode io_mode,int r)139 MockConnect::MockConnect(IoMode io_mode, int r) : mode(io_mode), result(r) {
140   peer_addr = IPEndPoint(IPAddress(192, 0, 2, 33), 0);
141 }
142 
MockConnect(IoMode io_mode,int r,IPEndPoint addr)143 MockConnect::MockConnect(IoMode io_mode, int r, IPEndPoint addr)
144     : mode(io_mode), result(r), peer_addr(addr) {}
145 
MockConnect(IoMode io_mode,int r,IPEndPoint addr,bool first_attempt_fails)146 MockConnect::MockConnect(IoMode io_mode,
147                          int r,
148                          IPEndPoint addr,
149                          bool first_attempt_fails)
150     : mode(io_mode),
151       result(r),
152       peer_addr(addr),
153       first_attempt_fails(first_attempt_fails) {}
154 
155 MockConnect::~MockConnect() = default;
156 
MockConfirm()157 MockConfirm::MockConfirm() : mode(SYNCHRONOUS), result(OK) {}
158 
MockConfirm(IoMode io_mode,int r)159 MockConfirm::MockConfirm(IoMode io_mode, int r) : mode(io_mode), result(r) {}
160 
161 MockConfirm::~MockConfirm() = default;
162 
IsIdle() const163 bool SocketDataProvider::IsIdle() const {
164   return true;
165 }
166 
Initialize(AsyncSocket * socket)167 void SocketDataProvider::Initialize(AsyncSocket* socket) {
168   CHECK(!socket_);
169   CHECK(socket);
170   socket_ = socket;
171   Reset();
172 }
173 
DetachSocket()174 void SocketDataProvider::DetachSocket() {
175   CHECK(socket_);
176   socket_ = nullptr;
177 }
178 
179 SocketDataProvider::SocketDataProvider() = default;
180 
~SocketDataProvider()181 SocketDataProvider::~SocketDataProvider() {
182   if (socket_)
183     socket_->OnDataProviderDestroyed();
184 }
185 
StaticSocketDataHelper(base::span<const MockRead> reads,base::span<const MockWrite> writes)186 StaticSocketDataHelper::StaticSocketDataHelper(
187     base::span<const MockRead> reads,
188     base::span<const MockWrite> writes)
189     : reads_(reads), writes_(writes) {}
190 
191 StaticSocketDataHelper::~StaticSocketDataHelper() = default;
192 
PeekRead() const193 const MockRead& StaticSocketDataHelper::PeekRead() const {
194   CHECK(!AllReadDataConsumed());
195   return reads_[read_index_];
196 }
197 
PeekWrite() const198 const MockWrite& StaticSocketDataHelper::PeekWrite() const {
199   CHECK(!AllWriteDataConsumed());
200   return writes_[write_index_];
201 }
202 
AdvanceRead()203 const MockRead& StaticSocketDataHelper::AdvanceRead() {
204   CHECK(!AllReadDataConsumed());
205   return reads_[read_index_++];
206 }
207 
AdvanceWrite()208 const MockWrite& StaticSocketDataHelper::AdvanceWrite() {
209   CHECK(!AllWriteDataConsumed());
210   return writes_[write_index_++];
211 }
212 
Reset()213 void StaticSocketDataHelper::Reset() {
214   read_index_ = 0;
215   write_index_ = 0;
216 }
217 
VerifyWriteData(const std::string & data,SocketDataPrinter * printer)218 bool StaticSocketDataHelper::VerifyWriteData(const std::string& data,
219                                              SocketDataPrinter* printer) {
220   CHECK(!AllWriteDataConsumed());
221   // Check that the actual data matches the expectations, skipping over any
222   // pause events.
223   const MockWrite& next_write = PeekRealWrite();
224   if (!next_write.data)
225     return true;
226 
227   // Note: Partial writes are supported here.  If the expected data
228   // is a match, but shorter than the write actually written, that is legal.
229   // Example:
230   //   Application writes "foobarbaz" (9 bytes)
231   //   Expected write was "foo" (3 bytes)
232   //   This is a success, and the function returns true.
233   std::string expected_data(next_write.data, next_write.data_len);
234   std::string actual_data(data.substr(0, next_write.data_len));
235   EXPECT_GE(data.length(), expected_data.length());
236   if (printer) {
237     EXPECT_TRUE(actual_data == expected_data)
238         << "Actual formatted write data:\n"
239         << printer->PrintWrite(data) << "Expected formatted write data:\n"
240         << printer->PrintWrite(expected_data) << "Actual raw write data:\n"
241         << HexDump(data) << "Expected raw write data:\n"
242         << HexDump(expected_data);
243   } else {
244     EXPECT_TRUE(actual_data == expected_data)
245         << "Actual write data:\n"
246         << HexDump(data) << "Expected write data:\n"
247         << HexDump(expected_data);
248   }
249   return expected_data == actual_data;
250 }
251 
PeekRealWrite() const252 const MockWrite& StaticSocketDataHelper::PeekRealWrite() const {
253   for (size_t i = write_index_; i < write_count(); i++) {
254     if (writes_[i].mode != ASYNC || writes_[i].result != ERR_IO_PENDING)
255       return writes_[i];
256   }
257 
258   CHECK(false) << "No write data available.";
259   return writes_[0];  // Avoid warning about unreachable missing return.
260 }
261 
StaticSocketDataProvider()262 StaticSocketDataProvider::StaticSocketDataProvider()
263     : StaticSocketDataProvider(base::span<const MockRead>(),
264                                base::span<const MockWrite>()) {}
265 
StaticSocketDataProvider(base::span<const MockRead> reads,base::span<const MockWrite> writes)266 StaticSocketDataProvider::StaticSocketDataProvider(
267     base::span<const MockRead> reads,
268     base::span<const MockWrite> writes)
269     : helper_(reads, writes) {}
270 
271 StaticSocketDataProvider::~StaticSocketDataProvider() = default;
272 
Pause()273 void StaticSocketDataProvider::Pause() {
274   paused_ = true;
275 }
276 
Resume()277 void StaticSocketDataProvider::Resume() {
278   paused_ = false;
279 }
280 
OnRead()281 MockRead StaticSocketDataProvider::OnRead() {
282   if (AllReadDataConsumed()) {
283     const net::MockRead pending_read(net::SYNCHRONOUS, net::ERR_IO_PENDING);
284     return pending_read;
285   }
286 
287   return helper_.AdvanceRead();
288 }
289 
OnWrite(const std::string & data)290 MockWriteResult StaticSocketDataProvider::OnWrite(const std::string& data) {
291   if (helper_.write_count() == 0) {
292     // Not using mock writes; succeed synchronously.
293     return MockWriteResult(SYNCHRONOUS, data.length());
294   }
295   if (printer_) {
296     EXPECT_FALSE(helper_.AllWriteDataConsumed())
297         << "No more mock data to match write:\nFormatted write data:\n"
298         << printer_->PrintWrite(data) << "Raw write data:\n"
299         << HexDump(data);
300   } else {
301     EXPECT_FALSE(helper_.AllWriteDataConsumed())
302         << "No more mock data to match write:\nRaw write data:\n"
303         << HexDump(data);
304   }
305   if (helper_.AllWriteDataConsumed()) {
306     return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED);
307   }
308 
309   // Check that what we are writing matches the expectation.
310   // Then give the mocked return value.
311   if (!helper_.VerifyWriteData(data, printer_))
312     return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED);
313 
314   const MockWrite& next_write = helper_.AdvanceWrite();
315   // In the case that the write was successful, return the number of bytes
316   // written. Otherwise return the error code.
317   int result =
318       next_write.result == OK ? next_write.data_len : next_write.result;
319   return MockWriteResult(next_write.mode, result);
320 }
321 
AllReadDataConsumed() const322 bool StaticSocketDataProvider::AllReadDataConsumed() const {
323   return paused_ || helper_.AllReadDataConsumed();
324 }
325 
AllWriteDataConsumed() const326 bool StaticSocketDataProvider::AllWriteDataConsumed() const {
327   return helper_.AllWriteDataConsumed();
328 }
329 
Reset()330 void StaticSocketDataProvider::Reset() {
331   helper_.Reset();
332 }
333 
SSLSocketDataProvider(IoMode mode,int result)334 SSLSocketDataProvider::SSLSocketDataProvider(IoMode mode, int result)
335     : connect(mode, result),
336       expected_ssl_version_min(kDefaultSSLVersionMin),
337       expected_ssl_version_max(kDefaultSSLVersionMax) {
338   SSLConnectionStatusSetVersion(SSL_CONNECTION_VERSION_TLS1_3,
339                                 &ssl_info.connection_status);
340   // Set to TLS_CHACHA20_POLY1305_SHA256
341   SSLConnectionStatusSetCipherSuite(0x1301, &ssl_info.connection_status);
342 }
343 
344 SSLSocketDataProvider::SSLSocketDataProvider(
345     const SSLSocketDataProvider& other) = default;
346 
347 SSLSocketDataProvider::~SSLSocketDataProvider() = default;
348 
SequencedSocketData()349 SequencedSocketData::SequencedSocketData()
350     : SequencedSocketData(base::span<const MockRead>(),
351                           base::span<const MockWrite>()) {}
352 
SequencedSocketData(base::span<const MockRead> reads,base::span<const MockWrite> writes)353 SequencedSocketData::SequencedSocketData(base::span<const MockRead> reads,
354                                          base::span<const MockWrite> writes)
355     : helper_(reads, writes) {
356   // Check that reads and writes have a contiguous set of sequence numbers
357   // starting from 0 and working their way up, with no repeats and skipping
358   // no values.
359   int next_sequence_number = 0;
360   bool last_event_was_pause = false;
361 
362   auto next_read = reads.begin();
363   auto next_write = writes.begin();
364   while (next_read != reads.end() || next_write != writes.end()) {
365     if (next_read != reads.end() &&
366         next_read->sequence_number == next_sequence_number) {
367       // Check if this is a pause.
368       if (next_read->mode == ASYNC && next_read->result == ERR_IO_PENDING) {
369         CHECK(!last_event_was_pause)
370             << "Two pauses in a row are not allowed: " << next_sequence_number;
371         last_event_was_pause = true;
372       } else if (last_event_was_pause) {
373         CHECK_EQ(ASYNC, next_read->mode)
374             << "A sync event after a pause makes no sense: "
375             << next_sequence_number;
376         CHECK_NE(ERR_IO_PENDING, next_read->result)
377             << "A pause event after a pause makes no sense: "
378             << next_sequence_number;
379         last_event_was_pause = false;
380       }
381 
382       ++next_read;
383       ++next_sequence_number;
384       continue;
385     }
386     if (next_write != writes.end() &&
387         next_write->sequence_number == next_sequence_number) {
388       // Check if this is a pause.
389       if (next_write->mode == ASYNC && next_write->result == ERR_IO_PENDING) {
390         CHECK(!last_event_was_pause)
391             << "Two pauses in a row are not allowed: " << next_sequence_number;
392         last_event_was_pause = true;
393       } else if (last_event_was_pause) {
394         CHECK_EQ(ASYNC, next_write->mode)
395             << "A sync event after a pause makes no sense: "
396             << next_sequence_number;
397         CHECK_NE(ERR_IO_PENDING, next_write->result)
398             << "A pause event after a pause makes no sense: "
399             << next_sequence_number;
400         last_event_was_pause = false;
401       }
402 
403       ++next_write;
404       ++next_sequence_number;
405       continue;
406     }
407     if (next_write != writes.end()) {
408       CHECK(false) << "Sequence number " << next_write->sequence_number
409                    << " not found where expected: " << next_sequence_number;
410     } else {
411       CHECK(false) << "Too few writes, next expected sequence number: "
412                    << next_sequence_number;
413     }
414     return;
415   }
416 
417   // Last event must not be a pause.  For the final event to indicate the
418   // operation never completes, it should be SYNCHRONOUS and return
419   // ERR_IO_PENDING.
420   CHECK(!last_event_was_pause);
421 
422   CHECK(next_read == reads.end());
423   CHECK(next_write == writes.end());
424 }
425 
SequencedSocketData(const MockConnect & connect,base::span<const MockRead> reads,base::span<const MockWrite> writes)426 SequencedSocketData::SequencedSocketData(const MockConnect& connect,
427                                          base::span<const MockRead> reads,
428                                          base::span<const MockWrite> writes)
429     : SequencedSocketData(reads, writes) {
430   set_connect_data(connect);
431 }
// Handles a read request against the scripted sequence.
//
// The next MockRead may not be scheduled before the current sequence number
// (CHECK_GE below), so the first branch runs exactly when it is scheduled at
// the current sequence number. In that case:
//  - a SYNCHRONOUS read is consumed and returned directly, and a write that
//    was waiting on this sequence number may then be scheduled;
//  - an ASYNC ERR_IO_PENDING entry pauses the read until Resume();
//  - any other ASYNC read posts OnReadComplete() to deliver it later.
// If the read is scheduled for a later sequence number, an ASYNC read is left
// pending until a write moves the sequence forward. A SYNCHRONOUS one is a
// test failure (ERR_UNEXPECTED). Every non-synchronous outcome returns
// ERR_IO_PENDING.
MockRead SequencedSocketData::OnRead() {
  CHECK_EQ(IoState::kIdle, read_state_);
  CHECK(!helper_.AllReadDataConsumed())
      << "Application tried to read but there is no read data left";

  NET_TRACE(1, " *** ") << "sequence_number: " << sequence_number_;
  const MockRead& next_read = helper_.PeekRead();
  NET_TRACE(1, " *** ") << "next_read: " << next_read.sequence_number;
  CHECK_GE(next_read.sequence_number, sequence_number_);

  // Given the CHECK_GE above, this is true only on equality.
  if (next_read.sequence_number <= sequence_number_) {
    if (next_read.mode == SYNCHRONOUS) {
      NET_TRACE(1, " *** ") << "Returning synchronously";
      DumpMockReadWrite(next_read);
      helper_.AdvanceRead();
      ++sequence_number_;
      MaybePostWriteCompleteTask();
      return next_read;
    }

    // If the result is ERR_IO_PENDING, then pause.
    if (next_read.result == ERR_IO_PENDING) {
      NET_TRACE(1, " *** ") << "Pausing read at: " << sequence_number_;
      read_state_ = IoState::kPaused;
      // Wake a RunUntilPaused() caller, if any.
      if (run_until_paused_run_loop_)
        run_until_paused_run_loop_->Quit();
      return MockRead(SYNCHRONOUS, ERR_IO_PENDING);
    }
    // ASYNC read due now: deliver it from a posted task. The weak pointer
    // keeps the task from running after Reset() or destruction.
    base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
        FROM_HERE, base::BindOnce(&SequencedSocketData::OnReadComplete,
                                  weak_factory_.GetWeakPtr()));
    CHECK_NE(IoState::kCompleting, write_state_);
    read_state_ = IoState::kCompleting;
  } else if (next_read.mode == SYNCHRONOUS) {
    ADD_FAILURE() << "Unable to perform synchronous IO while stopped";
    return MockRead(SYNCHRONOUS, ERR_UNEXPECTED);
  } else {
    // Picked up later by MaybePostReadCompleteTask() or Resume().
    NET_TRACE(1, " *** ") << "Waiting for write to trigger read";
    read_state_ = IoState::kPending;
  }

  return MockRead(SYNCHRONOUS, ERR_IO_PENDING);
}
475 
// Handles a write of |data| against the scripted sequence.
//
// Before anything else, the payload is checked against the next non-pause
// MockWrite (VerifyWriteData() uses PeekRealWrite()). A mismatch returns
// ERR_UNEXPECTED synchronously. Then, mirroring OnRead():
//  - a SYNCHRONOUS write scheduled at the current sequence number completes
//    immediately, returning its data_len on OK or its error result;
//  - an ASYNC ERR_IO_PENDING entry pauses the write until Resume();
//  - any other ASYNC write at the current sequence number posts
//    OnWriteComplete();
//  - a write scheduled later stays pending until a read moves the sequence
//    forward. A SYNCHRONOUS one is a test failure (ERR_UNEXPECTED).
// Every non-synchronous outcome returns ERR_IO_PENDING.
MockWriteResult SequencedSocketData::OnWrite(const std::string& data) {
  CHECK_EQ(IoState::kIdle, write_state_);
  if (printer_) {
    CHECK(!helper_.AllWriteDataConsumed())
        << "\nNo more mock data to match write:\nFormatted write data:\n"
        << printer_->PrintWrite(data) << "Raw write data:\n"
        << HexDump(data);
  } else {
    CHECK(!helper_.AllWriteDataConsumed())
        << "\nNo more mock data to match write:\nRaw write data:\n"
        << HexDump(data);
  }

  NET_TRACE(1, " *** ") << "sequence_number: " << sequence_number_;
  const MockWrite& next_write = helper_.PeekWrite();
  NET_TRACE(1, " *** ") << "next_write: " << next_write.sequence_number;
  CHECK_GE(next_write.sequence_number, sequence_number_);

  if (!helper_.VerifyWriteData(data, printer_))
    return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED);

  // Given the CHECK_GE above, this is true only on equality.
  if (next_write.sequence_number <= sequence_number_) {
    if (next_write.mode == SYNCHRONOUS) {
      helper_.AdvanceWrite();
      ++sequence_number_;
      MaybePostReadCompleteTask();
      // In the case that the write was successful, return the number of bytes
      // written. Otherwise return the error code.
      int rv =
          next_write.result != OK ? next_write.result : next_write.data_len;
      NET_TRACE(1, " *** ") << "Returning synchronously";
      return MockWriteResult(SYNCHRONOUS, rv);
    }

    // If the result is ERR_IO_PENDING, then pause.
    if (next_write.result == ERR_IO_PENDING) {
      NET_TRACE(1, " *** ") << "Pausing write at: " << sequence_number_;
      write_state_ = IoState::kPaused;
      // Wake a RunUntilPaused() caller, if any.
      if (run_until_paused_run_loop_)
        run_until_paused_run_loop_->Quit();
      return MockWriteResult(SYNCHRONOUS, ERR_IO_PENDING);
    }

    // ASYNC write due now: complete it from a posted task. The weak pointer
    // keeps the task from running after Reset() or destruction.
    NET_TRACE(1, " *** ") << "Posting task to complete write";
    base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
        FROM_HERE, base::BindOnce(&SequencedSocketData::OnWriteComplete,
                                  weak_factory_.GetWeakPtr()));
    CHECK_NE(IoState::kCompleting, read_state_);
    write_state_ = IoState::kCompleting;
  } else if (next_write.mode == SYNCHRONOUS) {
    ADD_FAILURE() << "Unable to perform synchronous IO while stopped";
    return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED);
  } else {
    // Picked up later by MaybePostWriteCompleteTask() or Resume().
    NET_TRACE(1, " *** ") << "Waiting for read to trigger write";
    write_state_ = IoState::kPending;
  }

  return MockWriteResult(SYNCHRONOUS, ERR_IO_PENDING);
}
535 
AllReadDataConsumed() const536 bool SequencedSocketData::AllReadDataConsumed() const {
537   return helper_.AllReadDataConsumed();
538 }
539 
CancelPendingRead()540 void SequencedSocketData::CancelPendingRead() {
541   DCHECK_EQ(IoState::kPending, read_state_);
542 
543   read_state_ = IoState::kIdle;
544 }
545 
AllWriteDataConsumed() const546 bool SequencedSocketData::AllWriteDataConsumed() const {
547   return helper_.AllWriteDataConsumed();
548 }
549 
IsIdle() const550 bool SequencedSocketData::IsIdle() const {
551   // If |busy_before_sync_reads_| is not set, always considered idle.  If
552   // no reads left, or the next operation is a write, also consider it idle.
553   if (!busy_before_sync_reads_ || helper_.AllReadDataConsumed() ||
554       helper_.PeekRead().sequence_number != sequence_number_) {
555     return true;
556   }
557 
558   // If the next operation is synchronous read, treat the socket as not idle.
559   if (helper_.PeekRead().mode == SYNCHRONOUS)
560     return false;
561   return true;
562 }
563 
IsPaused() const564 bool SequencedSocketData::IsPaused() const {
565   // Both states should not be paused.
566   DCHECK(read_state_ != IoState::kPaused || write_state_ != IoState::kPaused);
567   return write_state_ == IoState::kPaused || read_state_ == IoState::kPaused;
568 }
569 
// Moves past the pause entry the read or write side is currently stopped at.
// The paused side goes back to kPending. If the event at the new sequence
// number belongs to an operation that is already pending, that operation is
// completed synchronously here (OnWriteComplete()/OnReadComplete() are called
// directly, not posted). Calling this while not paused is a test failure.
void SequencedSocketData::Resume() {
  if (!IsPaused()) {
    ADD_FAILURE() << "Unable to Resume when not paused.";
    return;
  }

  // Consume the pause entry itself; the paused operation becomes pending again.
  sequence_number_++;
  if (read_state_ == IoState::kPaused) {
    read_state_ = IoState::kPending;
    helper_.AdvanceRead();
  } else {  // write_state_ == IoState::kPaused
    write_state_ = IoState::kPending;
    helper_.AdvanceWrite();
  }

  // Is the next event a write?
  if (!helper_.AllWriteDataConsumed() &&
      helper_.PeekWrite().sequence_number == sequence_number_) {
    // The next event hasn't even started yet.  Pausing isn't really needed in
    // that case, but may as well support it.
    if (write_state_ != IoState::kPending)
      return;
    write_state_ = IoState::kCompleting;
    OnWriteComplete();
    return;
  }

  // Otherwise the next event must be a read; the constructor rejects scripts
  // that end with a pause.
  CHECK(!helper_.AllReadDataConsumed());

  // The next event hasn't even started yet.  Pausing isn't really needed in
  // that case, but may as well support it.
  if (read_state_ != IoState::kPending)
    return;
  read_state_ = IoState::kCompleting;
  OnReadComplete();
}
605 
RunUntilPaused()606 void SequencedSocketData::RunUntilPaused() {
607   CHECK(!run_until_paused_run_loop_);
608 
609   if (IsPaused())
610     return;
611 
612   run_until_paused_run_loop_ = std::make_unique<base::RunLoop>();
613   run_until_paused_run_loop_->Run();
614   run_until_paused_run_loop_.reset();
615   DCHECK(IsPaused());
616 }
617 
MaybePostReadCompleteTask()618 void SequencedSocketData::MaybePostReadCompleteTask() {
619   NET_TRACE(1, " ****** ") << " current: " << sequence_number_;
620   // Only trigger the next read to complete if there is already a read pending
621   // which should complete at the current sequence number.
622   if (read_state_ != IoState::kPending ||
623       helper_.PeekRead().sequence_number != sequence_number_) {
624     return;
625   }
626 
627   // If the result is ERR_IO_PENDING, then pause.
628   if (helper_.PeekRead().result == ERR_IO_PENDING) {
629     NET_TRACE(1, " *** ") << "Pausing read at: " << sequence_number_;
630     read_state_ = IoState::kPaused;
631     if (run_until_paused_run_loop_)
632       run_until_paused_run_loop_->Quit();
633     return;
634   }
635 
636   NET_TRACE(1, " ****** ") << "Posting task to complete read: "
637                            << sequence_number_;
638   base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
639       FROM_HERE, base::BindOnce(&SequencedSocketData::OnReadComplete,
640                                 weak_factory_.GetWeakPtr()));
641   CHECK_NE(IoState::kCompleting, write_state_);
642   read_state_ = IoState::kCompleting;
643 }
644 
MaybePostWriteCompleteTask()645 void SequencedSocketData::MaybePostWriteCompleteTask() {
646   NET_TRACE(1, " ****** ") << " current: " << sequence_number_;
647   // Only trigger the next write to complete if there is already a write pending
648   // which should complete at the current sequence number.
649   if (write_state_ != IoState::kPending ||
650       helper_.PeekWrite().sequence_number != sequence_number_) {
651     return;
652   }
653 
654   // If the result is ERR_IO_PENDING, then pause.
655   if (helper_.PeekWrite().result == ERR_IO_PENDING) {
656     NET_TRACE(1, " *** ") << "Pausing write at: " << sequence_number_;
657     write_state_ = IoState::kPaused;
658     if (run_until_paused_run_loop_)
659       run_until_paused_run_loop_->Quit();
660     return;
661   }
662 
663   NET_TRACE(1, " ****** ") << "Posting task to complete write: "
664                            << sequence_number_;
665   base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
666       FROM_HERE, base::BindOnce(&SequencedSocketData::OnWriteComplete,
667                                 weak_factory_.GetWeakPtr()));
668   CHECK_NE(IoState::kCompleting, read_state_);
669   write_state_ = IoState::kCompleting;
670 }
671 
Reset()672 void SequencedSocketData::Reset() {
673   helper_.Reset();
674   sequence_number_ = 0;
675   read_state_ = IoState::kIdle;
676   write_state_ = IoState::kIdle;
677   weak_factory_.InvalidateWeakPtrs();
678 }
679 
// Finishes the read that was marked kCompleting: consumes the scripted read,
// advances the sequence, and delivers the read to the attached socket (if
// any). Reached either through a posted task or directly from Resume().
void SequencedSocketData::OnReadComplete() {
  CHECK_EQ(IoState::kCompleting, read_state_);
  NET_TRACE(1, " *** ") << "Completing read for: " << sequence_number_;

  // Copied, not referenced, because it is handed to the socket below.
  MockRead data = helper_.AdvanceRead();
  DCHECK_EQ(sequence_number_, data.sequence_number);
  sequence_number_++;
  read_state_ = IoState::kIdle;

  // The result of this read completing might trigger the completion
  // of a pending write. If so, post a task to complete the write later.
  // Since the socket may call back into the SequencedSocketData
  // from socket()->OnReadComplete(), trigger the write task to be posted
  // before calling that.
  MaybePostWriteCompleteTask();

  if (!socket()) {
    NET_TRACE(1, " *** ") << "No socket available to complete read";
    return;
  }

  NET_TRACE(1, " *** ") << "Completing socket read for: "
                        << data.sequence_number;
  DumpMockReadWrite(data);
  socket()->OnReadComplete(data);
  NET_TRACE(1, " *** ") << "Done";
}
707 
// Finishes the write that was marked kCompleting: consumes the scripted write,
// advances the sequence, and reports the result to the attached socket (if
// any). The result is data_len on OK, otherwise the error code. Reached
// either through a posted task or directly from Resume().
void SequencedSocketData::OnWriteComplete() {
  CHECK_EQ(IoState::kCompleting, write_state_);
  NET_TRACE(1, " *** ") << " Completing write for: " << sequence_number_;

  const MockWrite& data = helper_.AdvanceWrite();
  DCHECK_EQ(sequence_number_, data.sequence_number);
  sequence_number_++;
  write_state_ = IoState::kIdle;
  int rv = data.result == OK ? data.data_len : data.result;

  // The result of this write completing might trigger the completion
  // of a pending read. If so, post a task to complete the read later.
  // Since the socket may call back into the SequencedSocketData
  // from socket()->OnWriteComplete(), trigger the read task to be posted
  // before calling that.
  MaybePostReadCompleteTask();

  if (!socket()) {
    NET_TRACE(1, " *** ") << "No socket available to complete write";
    return;
  }

  NET_TRACE(1, " *** ") << " Completing socket write for: "
                        << data.sequence_number;
  socket()->OnWriteComplete(rv);
  NET_TRACE(1, " *** ") << "Done";
}
735 
// No teardown beyond member destruction. The SocketDataProvider base
// destructor notifies any socket that is still attached.
SequencedSocketData::~SequencedSocketData() = default;

// Defaulted special members.
MockClientSocketFactory::MockClientSocketFactory() = default;

MockClientSocketFactory::~MockClientSocketFactory() = default;
741 
AddSocketDataProvider(SocketDataProvider * data)742 void MockClientSocketFactory::AddSocketDataProvider(SocketDataProvider* data) {
743   mock_data_.Add(data);
744 }
745 
AddTcpSocketDataProvider(SocketDataProvider * data)746 void MockClientSocketFactory::AddTcpSocketDataProvider(
747     SocketDataProvider* data) {
748   mock_tcp_data_.Add(data);
749 }
750 
AddSSLSocketDataProvider(SSLSocketDataProvider * data)751 void MockClientSocketFactory::AddSSLSocketDataProvider(
752     SSLSocketDataProvider* data) {
753   mock_ssl_data_.Add(data);
754 }
755 
ResetNextMockIndexes()756 void MockClientSocketFactory::ResetNextMockIndexes() {
757   mock_data_.ResetNextIndex();
758   mock_ssl_data_.ResetNextIndex();
759 }
760 
761 std::unique_ptr<DatagramClientSocket>
CreateDatagramClientSocket(DatagramSocket::BindType bind_type,NetLog * net_log,const NetLogSource & source)762 MockClientSocketFactory::CreateDatagramClientSocket(
763     DatagramSocket::BindType bind_type,
764     NetLog* net_log,
765     const NetLogSource& source) {
766   SocketDataProvider* data_provider = mock_data_.GetNext();
767   auto socket = std::make_unique<MockUDPClientSocket>(data_provider, net_log);
768   if (bind_type == DatagramSocket::RANDOM_BIND)
769     socket->set_source_port(static_cast<uint16_t>(base::RandInt(1025, 65535)));
770   udp_client_socket_ports_.push_back(socket->source_port());
771   return std::move(socket);
772 }
773 
774 std::unique_ptr<TransportClientSocket>
CreateTransportClientSocket(const AddressList & addresses,std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,NetworkQualityEstimator * network_quality_estimator,NetLog * net_log,const NetLogSource & source)775 MockClientSocketFactory::CreateTransportClientSocket(
776     const AddressList& addresses,
777     std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,
778     NetworkQualityEstimator* network_quality_estimator,
779     NetLog* net_log,
780     const NetLogSource& source) {
781   SocketDataProvider* data_provider = mock_tcp_data_.GetNextWithoutAsserting();
782   if (!data_provider)
783     data_provider = mock_data_.GetNext();
784   auto socket =
785       std::make_unique<MockTCPClientSocket>(addresses, net_log, data_provider);
786   if (enable_read_if_ready_)
787     socket->set_enable_read_if_ready(enable_read_if_ready_);
788   return std::move(socket);
789 }
790 
CreateSSLClientSocket(SSLClientContext * context,std::unique_ptr<StreamSocket> stream_socket,const HostPortPair & host_and_port,const SSLConfig & ssl_config)791 std::unique_ptr<SSLClientSocket> MockClientSocketFactory::CreateSSLClientSocket(
792     SSLClientContext* context,
793     std::unique_ptr<StreamSocket> stream_socket,
794     const HostPortPair& host_and_port,
795     const SSLConfig& ssl_config) {
796   SSLSocketDataProvider* next_ssl_data = mock_ssl_data_.GetNext();
797   if (next_ssl_data->next_protos_expected_in_ssl_config.has_value()) {
798     EXPECT_TRUE(base::ranges::equal(
799         next_ssl_data->next_protos_expected_in_ssl_config.value(),
800         ssl_config.alpn_protos));
801   }
802 
803   // The protocol version used is a combination of the per-socket SSLConfig and
804   // the SSLConfigService.
805   EXPECT_EQ(
806       next_ssl_data->expected_ssl_version_min,
807       ssl_config.version_min_override.value_or(context->config().version_min));
808   EXPECT_EQ(
809       next_ssl_data->expected_ssl_version_max,
810       ssl_config.version_max_override.value_or(context->config().version_max));
811 
812   if (next_ssl_data->expected_send_client_cert) {
813     // Client certificate preferences come from |context|.
814     scoped_refptr<X509Certificate> client_cert;
815     scoped_refptr<SSLPrivateKey> client_private_key;
816     bool send_client_cert = context->GetClientCertificate(
817         host_and_port, &client_cert, &client_private_key);
818 
819     EXPECT_EQ(*next_ssl_data->expected_send_client_cert, send_client_cert);
820     // Note |send_client_cert| may be true while |client_cert| is null if the
821     // socket is configured to continue without a certificate, as opposed to
822     // surfacing the certificate challenge.
823     EXPECT_EQ(!!next_ssl_data->expected_client_cert, !!client_cert);
824     if (next_ssl_data->expected_client_cert && client_cert) {
825       EXPECT_TRUE(next_ssl_data->expected_client_cert->EqualsIncludingChain(
826           client_cert.get()));
827     }
828   }
829   if (next_ssl_data->expected_host_and_port) {
830     EXPECT_EQ(*next_ssl_data->expected_host_and_port, host_and_port);
831   }
832   if (next_ssl_data->expected_network_anonymization_key) {
833     EXPECT_EQ(*next_ssl_data->expected_network_anonymization_key,
834               ssl_config.network_anonymization_key);
835   }
836   if (next_ssl_data->expected_disable_sha1_server_signatures) {
837     EXPECT_EQ(*next_ssl_data->expected_disable_sha1_server_signatures,
838               ssl_config.disable_sha1_server_signatures);
839   }
840   if (next_ssl_data->expected_ech_config_list) {
841     EXPECT_EQ(*next_ssl_data->expected_ech_config_list,
842               ssl_config.ech_config_list);
843   }
844   return std::make_unique<MockSSLClientSocket>(
845       std::move(stream_socket), host_and_port, ssl_config, next_ssl_data);
846 }
847 
MockClientSocket(const NetLogWithSource & net_log)848 MockClientSocket::MockClientSocket(const NetLogWithSource& net_log)
849     : net_log_(net_log) {
850   local_addr_ = IPEndPoint(IPAddress(192, 0, 2, 33), 123);
851   peer_addr_ = IPEndPoint(IPAddress(192, 0, 2, 33), 0);
852 }
853 
SetReceiveBufferSize(int32_t size)854 int MockClientSocket::SetReceiveBufferSize(int32_t size) {
855   return OK;
856 }
857 
SetSendBufferSize(int32_t size)858 int MockClientSocket::SetSendBufferSize(int32_t size) {
859   return OK;
860 }
861 
Bind(const net::IPEndPoint & local_addr)862 int MockClientSocket::Bind(const net::IPEndPoint& local_addr) {
863   local_addr_ = local_addr;
864   return net::OK;
865 }
866 
SetNoDelay(bool no_delay)867 bool MockClientSocket::SetNoDelay(bool no_delay) {
868   return true;
869 }
870 
SetKeepAlive(bool enable,int delay)871 bool MockClientSocket::SetKeepAlive(bool enable, int delay) {
872   return true;
873 }
874 
Disconnect()875 void MockClientSocket::Disconnect() {
876   connected_ = false;
877 }
878 
IsConnected() const879 bool MockClientSocket::IsConnected() const {
880   return connected_;
881 }
882 
IsConnectedAndIdle() const883 bool MockClientSocket::IsConnectedAndIdle() const {
884   return connected_;
885 }
886 
GetPeerAddress(IPEndPoint * address) const887 int MockClientSocket::GetPeerAddress(IPEndPoint* address) const {
888   if (!IsConnected())
889     return ERR_SOCKET_NOT_CONNECTED;
890   *address = peer_addr_;
891   return OK;
892 }
893 
GetLocalAddress(IPEndPoint * address) const894 int MockClientSocket::GetLocalAddress(IPEndPoint* address) const {
895   *address = local_addr_;
896   return OK;
897 }
898 
NetLog() const899 const NetLogWithSource& MockClientSocket::NetLog() const {
900   return net_log_;
901 }
902 
WasAlpnNegotiated() const903 bool MockClientSocket::WasAlpnNegotiated() const {
904   return false;
905 }
906 
GetNegotiatedProtocol() const907 NextProto MockClientSocket::GetNegotiatedProtocol() const {
908   return kProtoUnknown;
909 }
910 
911 MockClientSocket::~MockClientSocket() = default;
912 
RunCallbackAsync(CompletionOnceCallback callback,int result)913 void MockClientSocket::RunCallbackAsync(CompletionOnceCallback callback,
914                                         int result) {
915   base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
916       FROM_HERE,
917       base::BindOnce(&MockClientSocket::RunCallback, weak_factory_.GetWeakPtr(),
918                      std::move(callback), result));
919 }
920 
RunCallback(CompletionOnceCallback callback,int result)921 void MockClientSocket::RunCallback(CompletionOnceCallback callback,
922                                    int result) {
923   std::move(callback).Run(result);
924 }
925 
MockTCPClientSocket(const AddressList & addresses,net::NetLog * net_log,SocketDataProvider * data)926 MockTCPClientSocket::MockTCPClientSocket(const AddressList& addresses,
927                                          net::NetLog* net_log,
928                                          SocketDataProvider* data)
929     : MockClientSocket(NetLogWithSource::Make(net_log, NetLogSourceType::NONE)),
930       addresses_(addresses),
931       data_(data),
932       read_data_(SYNCHRONOUS, ERR_UNEXPECTED) {
933   DCHECK(data_);
934   peer_addr_ = data->connect_data().peer_addr;
935   data_->Initialize(this);
936   if (data_->expected_addresses()) {
937     EXPECT_EQ(*data_->expected_addresses(), addresses);
938   }
939 }
940 
~MockTCPClientSocket()941 MockTCPClientSocket::~MockTCPClientSocket() {
942   if (data_)
943     data_->DetachSocket();
944 }
945 
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)946 int MockTCPClientSocket::Read(IOBuffer* buf,
947                               int buf_len,
948                               CompletionOnceCallback callback) {
949   // If the buffer is already in use, a read is already in progress!
950   DCHECK(!pending_read_buf_);
951   // Use base::Unretained() is safe because MockClientSocket::RunCallbackAsync()
952   // takes a weak ptr of the base class, MockClientSocket.
953   int rv = ReadIfReadyImpl(
954       buf, buf_len,
955       base::BindOnce(&MockTCPClientSocket::RetryRead, base::Unretained(this)));
956   if (rv == ERR_IO_PENDING) {
957     DCHECK(callback);
958 
959     pending_read_buf_ = buf;
960     pending_read_buf_len_ = buf_len;
961     pending_read_callback_ = std::move(callback);
962   }
963   return rv;
964 }
965 
ReadIfReady(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)966 int MockTCPClientSocket::ReadIfReady(IOBuffer* buf,
967                                      int buf_len,
968                                      CompletionOnceCallback callback) {
969   DCHECK(!pending_read_if_ready_callback_);
970 
971   if (!enable_read_if_ready_)
972     return ERR_READ_IF_READY_NOT_IMPLEMENTED;
973   return ReadIfReadyImpl(buf, buf_len, std::move(callback));
974 }
975 
CancelReadIfReady()976 int MockTCPClientSocket::CancelReadIfReady() {
977   DCHECK(pending_read_if_ready_callback_);
978 
979   pending_read_if_ready_callback_.Reset();
980   data_->CancelPendingRead();
981   return OK;
982 }
983 
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag &)984 int MockTCPClientSocket::Write(
985     IOBuffer* buf,
986     int buf_len,
987     CompletionOnceCallback callback,
988     const NetworkTrafficAnnotationTag& /* traffic_annotation */) {
989   DCHECK(buf);
990   DCHECK_GT(buf_len, 0);
991 
992   if (!connected_ || !data_)
993     return ERR_UNEXPECTED;
994 
995   std::string data(buf->data(), buf_len);
996   MockWriteResult write_result = data_->OnWrite(data);
997 
998   was_used_to_convey_data_ = true;
999 
1000   if (write_result.result == ERR_CONNECTION_CLOSED) {
1001     // This MockWrite is just a marker to instruct us to set
1002     // peer_closed_connection_.
1003     peer_closed_connection_ = true;
1004   }
1005   // ERR_IO_PENDING is a signal that the socket data will call back
1006   // asynchronously later.
1007   if (write_result.result == ERR_IO_PENDING) {
1008     pending_write_callback_ = std::move(callback);
1009     return ERR_IO_PENDING;
1010   }
1011 
1012   if (write_result.mode == ASYNC) {
1013     RunCallbackAsync(std::move(callback), write_result.result);
1014     return ERR_IO_PENDING;
1015   }
1016 
1017   return write_result.result;
1018 }
1019 
SetReceiveBufferSize(int32_t size)1020 int MockTCPClientSocket::SetReceiveBufferSize(int32_t size) {
1021   if (!connected_)
1022     return net::ERR_UNEXPECTED;
1023   data_->set_receive_buffer_size(size);
1024   return data_->set_receive_buffer_size_result();
1025 }
1026 
SetSendBufferSize(int32_t size)1027 int MockTCPClientSocket::SetSendBufferSize(int32_t size) {
1028   if (!connected_)
1029     return net::ERR_UNEXPECTED;
1030   data_->set_send_buffer_size(size);
1031   return data_->set_send_buffer_size_result();
1032 }
1033 
SetNoDelay(bool no_delay)1034 bool MockTCPClientSocket::SetNoDelay(bool no_delay) {
1035   if (!connected_)
1036     return false;
1037   data_->set_no_delay(no_delay);
1038   return data_->set_no_delay_result();
1039 }
1040 
SetKeepAlive(bool enable,int delay)1041 bool MockTCPClientSocket::SetKeepAlive(bool enable, int delay) {
1042   if (!connected_)
1043     return false;
1044   data_->set_keep_alive(enable, delay);
1045   return data_->set_keep_alive_result();
1046 }
1047 
SetBeforeConnectCallback(const BeforeConnectCallback & before_connect_callback)1048 void MockTCPClientSocket::SetBeforeConnectCallback(
1049     const BeforeConnectCallback& before_connect_callback) {
1050   DCHECK(!before_connect_callback_);
1051   DCHECK(!connected_);
1052 
1053   before_connect_callback_ = before_connect_callback;
1054 }
1055 
Connect(CompletionOnceCallback callback)1056 int MockTCPClientSocket::Connect(CompletionOnceCallback callback) {
1057   if (!data_)
1058     return ERR_UNEXPECTED;
1059 
1060   if (connected_)
1061     return OK;
1062 
1063   // Setting socket options fails if not connected, so need to set this before
1064   // calling |before_connect_callback_|.
1065   connected_ = true;
1066 
1067   if (before_connect_callback_) {
1068     for (size_t index = 0; index < addresses_.size(); index++) {
1069       int result = before_connect_callback_.Run();
1070       if (data_->connect_data().first_attempt_fails && index == 0) {
1071         continue;
1072       }
1073       DCHECK_NE(result, ERR_IO_PENDING);
1074       if (result != net::OK) {
1075         connected_ = false;
1076         return result;
1077       }
1078       break;
1079     }
1080   }
1081 
1082   peer_closed_connection_ = false;
1083 
1084   int result = data_->connect_data().result;
1085   IoMode mode = data_->connect_data().mode;
1086   if (mode == SYNCHRONOUS)
1087     return result;
1088 
1089   DCHECK(callback);
1090 
1091   if (result == ERR_IO_PENDING)
1092     pending_connect_callback_ = std::move(callback);
1093   else
1094     RunCallbackAsync(std::move(callback), result);
1095   return ERR_IO_PENDING;
1096 }
1097 
Disconnect()1098 void MockTCPClientSocket::Disconnect() {
1099   MockClientSocket::Disconnect();
1100   pending_connect_callback_.Reset();
1101   pending_read_callback_.Reset();
1102 }
1103 
IsConnected() const1104 bool MockTCPClientSocket::IsConnected() const {
1105   if (!data_)
1106     return false;
1107   return connected_ && !peer_closed_connection_;
1108 }
1109 
IsConnectedAndIdle() const1110 bool MockTCPClientSocket::IsConnectedAndIdle() const {
1111   if (!data_)
1112     return false;
1113   return IsConnected() && data_->IsIdle();
1114 }
1115 
GetPeerAddress(IPEndPoint * address) const1116 int MockTCPClientSocket::GetPeerAddress(IPEndPoint* address) const {
1117   if (addresses_.empty())
1118     return MockClientSocket::GetPeerAddress(address);
1119 
1120   if (data_->connect_data().first_attempt_fails) {
1121     DCHECK_GE(addresses_.size(), 2U);
1122     *address = addresses_[1];
1123   } else {
1124     *address = addresses_[0];
1125   }
1126   return OK;
1127 }
1128 
WasEverUsed() const1129 bool MockTCPClientSocket::WasEverUsed() const {
1130   return was_used_to_convey_data_;
1131 }
1132 
GetSSLInfo(SSLInfo * ssl_info)1133 bool MockTCPClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
1134   return false;
1135 }
1136 
OnReadComplete(const MockRead & data)1137 void MockTCPClientSocket::OnReadComplete(const MockRead& data) {
1138   // If |data_| has been destroyed, safest to just do nothing.
1139   if (!data_)
1140     return;
1141 
1142   // There must be a read pending.
1143   DCHECK(pending_read_if_ready_callback_);
1144   // You can't complete a read with another ERR_IO_PENDING status code.
1145   DCHECK_NE(ERR_IO_PENDING, data.result);
1146   // Since we've been waiting for data, need_read_data_ should be true.
1147   DCHECK(need_read_data_);
1148 
1149   read_data_ = data;
1150   need_read_data_ = false;
1151 
1152   // The caller is simulating that this IO completes right now.  Don't
1153   // let CompleteRead() schedule a callback.
1154   read_data_.mode = SYNCHRONOUS;
1155   RunCallback(std::move(pending_read_if_ready_callback_),
1156               read_data_.result > 0 ? OK : read_data_.result);
1157 }
1158 
OnWriteComplete(int rv)1159 void MockTCPClientSocket::OnWriteComplete(int rv) {
1160   // If |data_| has been destroyed, safest to just do nothing.
1161   if (!data_)
1162     return;
1163 
1164   // There must be a read pending.
1165   DCHECK(!pending_write_callback_.is_null());
1166   RunCallback(std::move(pending_write_callback_), rv);
1167 }
1168 
OnConnectComplete(const MockConnect & data)1169 void MockTCPClientSocket::OnConnectComplete(const MockConnect& data) {
1170   // If |data_| has been destroyed, safest to just do nothing.
1171   if (!data_)
1172     return;
1173 
1174   RunCallback(std::move(pending_connect_callback_), data.result);
1175 }
1176 
OnDataProviderDestroyed()1177 void MockTCPClientSocket::OnDataProviderDestroyed() {
1178   data_ = nullptr;
1179 }
1180 
RetryRead(int rv)1181 void MockTCPClientSocket::RetryRead(int rv) {
1182   DCHECK(pending_read_callback_);
1183   DCHECK(pending_read_buf_.get());
1184   DCHECK_LT(0, pending_read_buf_len_);
1185 
1186   if (rv == OK) {
1187     rv = ReadIfReadyImpl(pending_read_buf_.get(), pending_read_buf_len_,
1188                          base::BindOnce(&MockTCPClientSocket::RetryRead,
1189                                         base::Unretained(this)));
1190     if (rv == ERR_IO_PENDING)
1191       return;
1192   }
1193   pending_read_buf_ = nullptr;
1194   pending_read_buf_len_ = 0;
1195   RunCallback(std::move(pending_read_callback_), rv);
1196 }
1197 
ReadIfReadyImpl(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)1198 int MockTCPClientSocket::ReadIfReadyImpl(IOBuffer* buf,
1199                                          int buf_len,
1200                                          CompletionOnceCallback callback) {
1201   if (!connected_ || !data_)
1202     return ERR_UNEXPECTED;
1203 
1204   DCHECK(!pending_read_if_ready_callback_);
1205 
1206   if (need_read_data_) {
1207     read_data_ = data_->OnRead();
1208     if (read_data_.result == ERR_CONNECTION_CLOSED) {
1209       // This MockRead is just a marker to instruct us to set
1210       // peer_closed_connection_.
1211       peer_closed_connection_ = true;
1212     }
1213     if (read_data_.result == ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ) {
1214       // This MockRead is just a marker to instruct us to set
1215       // peer_closed_connection_.  Skip it and get the next one.
1216       read_data_ = data_->OnRead();
1217       peer_closed_connection_ = true;
1218     }
1219     // ERR_IO_PENDING means that the SocketDataProvider is taking responsibility
1220     // to complete the async IO manually later (via OnReadComplete).
1221     if (read_data_.result == ERR_IO_PENDING) {
1222       // We need to be using async IO in this case.
1223       DCHECK(!callback.is_null());
1224       pending_read_if_ready_callback_ = std::move(callback);
1225       return ERR_IO_PENDING;
1226     }
1227     need_read_data_ = false;
1228   }
1229 
1230   int result = read_data_.result;
1231   DCHECK_NE(ERR_IO_PENDING, result);
1232   if (read_data_.mode == ASYNC) {
1233     DCHECK(!callback.is_null());
1234     read_data_.mode = SYNCHRONOUS;
1235     pending_read_if_ready_callback_ = std::move(callback);
1236     // base::Unretained() is safe here because RunCallbackAsync will wrap it
1237     // with a callback associated with a weak ptr.
1238     RunCallbackAsync(
1239         base::BindOnce(&MockTCPClientSocket::RunReadIfReadyCallback,
1240                        base::Unretained(this)),
1241         result);
1242     return ERR_IO_PENDING;
1243   }
1244 
1245   was_used_to_convey_data_ = true;
1246   if (read_data_.data) {
1247     if (read_data_.data_len - read_offset_ > 0) {
1248       result = std::min(buf_len, read_data_.data_len - read_offset_);
1249       memcpy(buf->data(), read_data_.data + read_offset_, result);
1250       read_offset_ += result;
1251       if (read_offset_ == read_data_.data_len) {
1252         need_read_data_ = true;
1253         read_offset_ = 0;
1254       }
1255     } else {
1256       result = 0;  // EOF
1257     }
1258   }
1259   return result;
1260 }
1261 
RunReadIfReadyCallback(int result)1262 void MockTCPClientSocket::RunReadIfReadyCallback(int result) {
1263   // If ReadIfReady is already canceled, do nothing.
1264   if (!pending_read_if_ready_callback_)
1265     return;
1266   std::move(pending_read_if_ready_callback_).Run(result);
1267 }
1268 
1269 // static
ConnectCallback(MockSSLClientSocket * ssl_client_socket,CompletionOnceCallback callback,int rv)1270 void MockSSLClientSocket::ConnectCallback(
1271     MockSSLClientSocket* ssl_client_socket,
1272     CompletionOnceCallback callback,
1273     int rv) {
1274   if (rv == OK)
1275     ssl_client_socket->connected_ = true;
1276   std::move(callback).Run(rv);
1277 }
1278 
MockSSLClientSocket(std::unique_ptr<StreamSocket> stream_socket,const HostPortPair & host_and_port,const SSLConfig & ssl_config,SSLSocketDataProvider * data)1279 MockSSLClientSocket::MockSSLClientSocket(
1280     std::unique_ptr<StreamSocket> stream_socket,
1281     const HostPortPair& host_and_port,
1282     const SSLConfig& ssl_config,
1283     SSLSocketDataProvider* data)
1284     : net_log_(stream_socket->NetLog()),
1285       stream_socket_(std::move(stream_socket)),
1286       data_(data) {
1287   DCHECK(data_);
1288   peer_addr_ = data->connect.peer_addr;
1289 }
1290 
~MockSSLClientSocket()1291 MockSSLClientSocket::~MockSSLClientSocket() {
1292   Disconnect();
1293 }
1294 
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)1295 int MockSSLClientSocket::Read(IOBuffer* buf,
1296                               int buf_len,
1297                               CompletionOnceCallback callback) {
1298   return stream_socket_->Read(buf, buf_len, std::move(callback));
1299 }
1300 
ReadIfReady(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)1301 int MockSSLClientSocket::ReadIfReady(IOBuffer* buf,
1302                                      int buf_len,
1303                                      CompletionOnceCallback callback) {
1304   return stream_socket_->ReadIfReady(buf, buf_len, std::move(callback));
1305 }
1306 
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)1307 int MockSSLClientSocket::Write(
1308     IOBuffer* buf,
1309     int buf_len,
1310     CompletionOnceCallback callback,
1311     const NetworkTrafficAnnotationTag& traffic_annotation) {
1312   if (!data_->is_confirm_data_consumed)
1313     data_->write_called_before_confirm = true;
1314   return stream_socket_->Write(buf, buf_len, std::move(callback),
1315                                traffic_annotation);
1316 }
1317 
CancelReadIfReady()1318 int MockSSLClientSocket::CancelReadIfReady() {
1319   return stream_socket_->CancelReadIfReady();
1320 }
1321 
Connect(CompletionOnceCallback callback)1322 int MockSSLClientSocket::Connect(CompletionOnceCallback callback) {
1323   DCHECK(stream_socket_->IsConnected());
1324   data_->is_connect_data_consumed = true;
1325   if (data_->connect.result == OK)
1326     connected_ = true;
1327   RunClosureIfNonNull(std::move(data_->connect_callback));
1328   if (data_->connect.mode == ASYNC) {
1329     RunCallbackAsync(std::move(callback), data_->connect.result);
1330     return ERR_IO_PENDING;
1331   }
1332   return data_->connect.result;
1333 }
1334 
Disconnect()1335 void MockSSLClientSocket::Disconnect() {
1336   if (stream_socket_ != nullptr)
1337     stream_socket_->Disconnect();
1338 }
1339 
RunConfirmHandshakeCallback(CompletionOnceCallback callback,int result)1340 void MockSSLClientSocket::RunConfirmHandshakeCallback(
1341     CompletionOnceCallback callback,
1342     int result) {
1343   DCHECK(in_confirm_handshake_);
1344   in_confirm_handshake_ = false;
1345   data_->is_confirm_data_consumed = true;
1346   std::move(callback).Run(result);
1347 }
1348 
ConfirmHandshake(CompletionOnceCallback callback)1349 int MockSSLClientSocket::ConfirmHandshake(CompletionOnceCallback callback) {
1350   DCHECK(stream_socket_->IsConnected());
1351   DCHECK(!in_confirm_handshake_);
1352   if (data_->is_confirm_data_consumed)
1353     return data_->confirm.result;
1354   RunClosureIfNonNull(std::move(data_->confirm_callback));
1355   if (data_->confirm.mode == ASYNC) {
1356     in_confirm_handshake_ = true;
1357     RunCallbackAsync(
1358         base::BindOnce(&MockSSLClientSocket::RunConfirmHandshakeCallback,
1359                        base::Unretained(this), std::move(callback)),
1360         data_->confirm.result);
1361     return ERR_IO_PENDING;
1362   }
1363   data_->is_confirm_data_consumed = true;
1364   if (data_->confirm.result == ERR_IO_PENDING) {
1365     // `MockConfirm(SYNCHRONOUS, ERR_IO_PENDING)` means `ConfirmHandshake()`
1366     // never completes.
1367     in_confirm_handshake_ = true;
1368   }
1369   return data_->confirm.result;
1370 }
1371 
IsConnected() const1372 bool MockSSLClientSocket::IsConnected() const {
1373   return stream_socket_->IsConnected();
1374 }
1375 
IsConnectedAndIdle() const1376 bool MockSSLClientSocket::IsConnectedAndIdle() const {
1377   return stream_socket_->IsConnectedAndIdle();
1378 }
1379 
WasEverUsed() const1380 bool MockSSLClientSocket::WasEverUsed() const {
1381   return stream_socket_->WasEverUsed();
1382 }
1383 
GetLocalAddress(IPEndPoint * address) const1384 int MockSSLClientSocket::GetLocalAddress(IPEndPoint* address) const {
1385   *address = IPEndPoint(IPAddress(192, 0, 2, 33), 123);
1386   return OK;
1387 }
1388 
GetPeerAddress(IPEndPoint * address) const1389 int MockSSLClientSocket::GetPeerAddress(IPEndPoint* address) const {
1390   return stream_socket_->GetPeerAddress(address);
1391 }
1392 
WasAlpnNegotiated() const1393 bool MockSSLClientSocket::WasAlpnNegotiated() const {
1394   return data_->next_proto != kProtoUnknown;
1395 }
1396 
GetNegotiatedProtocol() const1397 NextProto MockSSLClientSocket::GetNegotiatedProtocol() const {
1398   return data_->next_proto;
1399 }
1400 
1401 absl::optional<base::StringPiece>
GetPeerApplicationSettings() const1402 MockSSLClientSocket::GetPeerApplicationSettings() const {
1403   return data_->peer_application_settings;
1404 }
1405 
GetSSLInfo(SSLInfo * requested_ssl_info)1406 bool MockSSLClientSocket::GetSSLInfo(SSLInfo* requested_ssl_info) {
1407   requested_ssl_info->Reset();
1408   *requested_ssl_info = data_->ssl_info;
1409   return true;
1410 }
1411 
ApplySocketTag(const SocketTag & tag)1412 void MockSSLClientSocket::ApplySocketTag(const SocketTag& tag) {
1413   return stream_socket_->ApplySocketTag(tag);
1414 }
1415 
NetLog() const1416 const NetLogWithSource& MockSSLClientSocket::NetLog() const {
1417   return net_log_;
1418 }
1419 
GetTotalReceivedBytes() const1420 int64_t MockSSLClientSocket::GetTotalReceivedBytes() const {
1421   NOTIMPLEMENTED();
1422   return 0;
1423 }
1424 
GetTotalReceivedBytes() const1425 int64_t MockClientSocket::GetTotalReceivedBytes() const {
1426   NOTIMPLEMENTED();
1427   return 0;
1428 }
1429 
SetReceiveBufferSize(int32_t size)1430 int MockSSLClientSocket::SetReceiveBufferSize(int32_t size) {
1431   return OK;
1432 }
1433 
SetSendBufferSize(int32_t size)1434 int MockSSLClientSocket::SetSendBufferSize(int32_t size) {
1435   return OK;
1436 }
1437 
GetSSLCertRequestInfo(SSLCertRequestInfo * cert_request_info) const1438 void MockSSLClientSocket::GetSSLCertRequestInfo(
1439     SSLCertRequestInfo* cert_request_info) const {
1440   DCHECK(cert_request_info);
1441   if (data_->cert_request_info) {
1442     cert_request_info->host_and_port = data_->cert_request_info->host_and_port;
1443     cert_request_info->is_proxy = data_->cert_request_info->is_proxy;
1444     cert_request_info->cert_authorities =
1445         data_->cert_request_info->cert_authorities;
1446     cert_request_info->cert_key_types =
1447         data_->cert_request_info->cert_key_types;
1448   } else {
1449     cert_request_info->Reset();
1450   }
1451 }
1452 
ExportKeyingMaterial(base::StringPiece label,bool has_context,base::StringPiece context,unsigned char * out,unsigned int outlen)1453 int MockSSLClientSocket::ExportKeyingMaterial(base::StringPiece label,
1454                                               bool has_context,
1455                                               base::StringPiece context,
1456                                               unsigned char* out,
1457                                               unsigned int outlen) {
1458   memset(out, 'A', outlen);
1459   return OK;
1460 }
1461 
GetECHRetryConfigs()1462 std::vector<uint8_t> MockSSLClientSocket::GetECHRetryConfigs() {
1463   return data_->ech_retry_configs;
1464 }
1465 
RunCallbackAsync(CompletionOnceCallback callback,int result)1466 void MockSSLClientSocket::RunCallbackAsync(CompletionOnceCallback callback,
1467                                            int result) {
1468   base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
1469       FROM_HERE,
1470       base::BindOnce(&MockSSLClientSocket::RunCallback,
1471                      weak_factory_.GetWeakPtr(), std::move(callback), result));
1472 }
1473 
RunCallback(CompletionOnceCallback callback,int result)1474 void MockSSLClientSocket::RunCallback(CompletionOnceCallback callback,
1475                                       int result) {
1476   std::move(callback).Run(result);
1477 }
1478 
OnReadComplete(const MockRead & data)1479 void MockSSLClientSocket::OnReadComplete(const MockRead& data) {
1480   NOTIMPLEMENTED();
1481 }
1482 
OnWriteComplete(int rv)1483 void MockSSLClientSocket::OnWriteComplete(int rv) {
1484   NOTIMPLEMENTED();
1485 }
1486 
OnConnectComplete(const MockConnect & data)1487 void MockSSLClientSocket::OnConnectComplete(const MockConnect& data) {
1488   NOTIMPLEMENTED();
1489 }
1490 
MockUDPClientSocket(SocketDataProvider * data,net::NetLog * net_log)1491 MockUDPClientSocket::MockUDPClientSocket(SocketDataProvider* data,
1492                                          net::NetLog* net_log)
1493     : data_(data),
1494       read_data_(SYNCHRONOUS, ERR_UNEXPECTED),
1495       source_host_(IPAddress(192, 0, 2, 33)),
1496       net_log_(NetLogWithSource::Make(net_log, NetLogSourceType::NONE)) {
1497   if (data_) {
1498     data_->Initialize(this);
1499     peer_addr_ = data->connect_data().peer_addr;
1500   }
1501 }
1502 
~MockUDPClientSocket()1503 MockUDPClientSocket::~MockUDPClientSocket() {
1504   if (data_)
1505     data_->DetachSocket();
1506 }
1507 
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)1508 int MockUDPClientSocket::Read(IOBuffer* buf,
1509                               int buf_len,
1510                               CompletionOnceCallback callback) {
1511   DCHECK(callback);
1512 
1513   if (!connected_ || !data_)
1514     return ERR_UNEXPECTED;
1515   data_transferred_ = true;
1516 
1517   // If the buffer is already in use, a read is already in progress!
1518   DCHECK(!pending_read_buf_);
1519 
1520   // Store our async IO data.
1521   pending_read_buf_ = buf;
1522   pending_read_buf_len_ = buf_len;
1523   pending_read_callback_ = std::move(callback);
1524 
1525   if (need_read_data_) {
1526     read_data_ = data_->OnRead();
1527     // ERR_IO_PENDING means that the SocketDataProvider is taking responsibility
1528     // to complete the async IO manually later (via OnReadComplete).
1529     if (read_data_.result == ERR_IO_PENDING) {
1530       // We need to be using async IO in this case.
1531       DCHECK(!pending_read_callback_.is_null());
1532       return ERR_IO_PENDING;
1533     }
1534     need_read_data_ = false;
1535   }
1536 
1537   return CompleteRead();
1538 }
1539 
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag &)1540 int MockUDPClientSocket::Write(
1541     IOBuffer* buf,
1542     int buf_len,
1543     CompletionOnceCallback callback,
1544     const NetworkTrafficAnnotationTag& /* traffic_annotation */) {
1545   DCHECK(buf);
1546   DCHECK_GT(buf_len, 0);
1547   DCHECK(callback);
1548 
1549   if (!connected_ || !data_)
1550     return ERR_UNEXPECTED;
1551   data_transferred_ = true;
1552 
1553   std::string data(buf->data(), buf_len);
1554   MockWriteResult write_result = data_->OnWrite(data);
1555 
1556   // ERR_IO_PENDING is a signal that the socket data will call back
1557   // asynchronously.
1558   if (write_result.result == ERR_IO_PENDING) {
1559     pending_write_callback_ = std::move(callback);
1560     return ERR_IO_PENDING;
1561   }
1562   if (write_result.mode == ASYNC) {
1563     RunCallbackAsync(std::move(callback), write_result.result);
1564     return ERR_IO_PENDING;
1565   }
1566   return write_result.result;
1567 }
1568 
SetReceiveBufferSize(int32_t size)1569 int MockUDPClientSocket::SetReceiveBufferSize(int32_t size) {
1570   return OK;
1571 }
1572 
SetSendBufferSize(int32_t size)1573 int MockUDPClientSocket::SetSendBufferSize(int32_t size) {
1574   return OK;
1575 }
1576 
SetDoNotFragment()1577 int MockUDPClientSocket::SetDoNotFragment() {
1578   return OK;
1579 }
1580 
Close()1581 void MockUDPClientSocket::Close() {
1582   connected_ = false;
1583 }
1584 
GetPeerAddress(IPEndPoint * address) const1585 int MockUDPClientSocket::GetPeerAddress(IPEndPoint* address) const {
1586   if (!data_)
1587     return ERR_UNEXPECTED;
1588 
1589   *address = peer_addr_;
1590   return OK;
1591 }
1592 
GetLocalAddress(IPEndPoint * address) const1593 int MockUDPClientSocket::GetLocalAddress(IPEndPoint* address) const {
1594   *address = IPEndPoint(source_host_, source_port_);
1595   return OK;
1596 }
1597 
UseNonBlockingIO()1598 void MockUDPClientSocket::UseNonBlockingIO() {}
1599 
SetMulticastInterface(uint32_t interface_index)1600 int MockUDPClientSocket::SetMulticastInterface(uint32_t interface_index) {
1601   return OK;
1602 }
1603 
NetLog() const1604 const NetLogWithSource& MockUDPClientSocket::NetLog() const {
1605   return net_log_;
1606 }
1607 
Connect(const IPEndPoint & address)1608 int MockUDPClientSocket::Connect(const IPEndPoint& address) {
1609   if (!data_)
1610     return ERR_UNEXPECTED;
1611   DCHECK_NE(data_->connect_data().result, ERR_IO_PENDING);
1612   connected_ = true;
1613   peer_addr_ = address;
1614   return data_->connect_data().result;
1615 }
1616 
ConnectUsingNetwork(handles::NetworkHandle network,const IPEndPoint & address)1617 int MockUDPClientSocket::ConnectUsingNetwork(handles::NetworkHandle network,
1618                                              const IPEndPoint& address) {
1619   DCHECK(!connected_);
1620   if (!data_)
1621     return ERR_UNEXPECTED;
1622   DCHECK_NE(data_->connect_data().result, ERR_IO_PENDING);
1623   network_ = network;
1624   connected_ = true;
1625   peer_addr_ = address;
1626   return data_->connect_data().result;
1627 }
1628 
ConnectUsingDefaultNetwork(const IPEndPoint & address)1629 int MockUDPClientSocket::ConnectUsingDefaultNetwork(const IPEndPoint& address) {
1630   DCHECK(!connected_);
1631   if (!data_)
1632     return ERR_UNEXPECTED;
1633   DCHECK_NE(data_->connect_data().result, ERR_IO_PENDING);
1634   network_ = kDefaultNetworkForTests;
1635   connected_ = true;
1636   peer_addr_ = address;
1637   return data_->connect_data().result;
1638 }
1639 
ConnectAsync(const IPEndPoint & address,CompletionOnceCallback callback)1640 int MockUDPClientSocket::ConnectAsync(const IPEndPoint& address,
1641                                       CompletionOnceCallback callback) {
1642   DCHECK(callback);
1643   if (!data_) {
1644     return ERR_UNEXPECTED;
1645   }
1646   connected_ = true;
1647   peer_addr_ = address;
1648   int result = data_->connect_data().result;
1649   IoMode mode = data_->connect_data().mode;
1650   if (mode == SYNCHRONOUS) {
1651     return result;
1652   }
1653   RunCallbackAsync(std::move(callback), result);
1654   return ERR_IO_PENDING;
1655 }
1656 
ConnectUsingNetworkAsync(handles::NetworkHandle network,const IPEndPoint & address,CompletionOnceCallback callback)1657 int MockUDPClientSocket::ConnectUsingNetworkAsync(
1658     handles::NetworkHandle network,
1659     const IPEndPoint& address,
1660     CompletionOnceCallback callback) {
1661   DCHECK(callback);
1662   DCHECK(!connected_);
1663   if (!data_)
1664     return ERR_UNEXPECTED;
1665   network_ = network;
1666   connected_ = true;
1667   peer_addr_ = address;
1668   int result = data_->connect_data().result;
1669   IoMode mode = data_->connect_data().mode;
1670   if (mode == SYNCHRONOUS) {
1671     return result;
1672   }
1673   RunCallbackAsync(std::move(callback), result);
1674   return ERR_IO_PENDING;
1675 }
1676 
ConnectUsingDefaultNetworkAsync(const IPEndPoint & address,CompletionOnceCallback callback)1677 int MockUDPClientSocket::ConnectUsingDefaultNetworkAsync(
1678     const IPEndPoint& address,
1679     CompletionOnceCallback callback) {
1680   DCHECK(!connected_);
1681   if (!data_)
1682     return ERR_UNEXPECTED;
1683   network_ = kDefaultNetworkForTests;
1684   connected_ = true;
1685   peer_addr_ = address;
1686   int result = data_->connect_data().result;
1687   IoMode mode = data_->connect_data().mode;
1688   if (mode == SYNCHRONOUS) {
1689     return result;
1690   }
1691   RunCallbackAsync(std::move(callback), result);
1692   return ERR_IO_PENDING;
1693 }
1694 
GetBoundNetwork() const1695 handles::NetworkHandle MockUDPClientSocket::GetBoundNetwork() const {
1696   return network_;
1697 }
1698 
ApplySocketTag(const SocketTag & tag)1699 void MockUDPClientSocket::ApplySocketTag(const SocketTag& tag) {
1700   tagged_before_data_transferred_ &= !data_transferred_ || tag == tag_;
1701   tag_ = tag;
1702 }
1703 
OnReadComplete(const MockRead & data)1704 void MockUDPClientSocket::OnReadComplete(const MockRead& data) {
1705   if (!data_)
1706     return;
1707 
1708   // There must be a read pending.
1709   DCHECK(pending_read_buf_.get());
1710   DCHECK(pending_read_callback_);
1711   // You can't complete a read with another ERR_IO_PENDING status code.
1712   DCHECK_NE(ERR_IO_PENDING, data.result);
1713   // Since we've been waiting for data, need_read_data_ should be true.
1714   DCHECK(need_read_data_);
1715 
1716   read_data_ = data;
1717   need_read_data_ = false;
1718 
1719   // The caller is simulating that this IO completes right now.  Don't
1720   // let CompleteRead() schedule a callback.
1721   read_data_.mode = SYNCHRONOUS;
1722 
1723   CompletionOnceCallback callback = std::move(pending_read_callback_);
1724   int rv = CompleteRead();
1725   RunCallback(std::move(callback), rv);
1726 }
1727 
OnWriteComplete(int rv)1728 void MockUDPClientSocket::OnWriteComplete(int rv) {
1729   if (!data_)
1730     return;
1731 
1732   // There must be a read pending.
1733   DCHECK(!pending_write_callback_.is_null());
1734   RunCallback(std::move(pending_write_callback_), rv);
1735 }
1736 
OnConnectComplete(const MockConnect & data)1737 void MockUDPClientSocket::OnConnectComplete(const MockConnect& data) {
1738   NOTIMPLEMENTED();
1739 }
1740 
OnDataProviderDestroyed()1741 void MockUDPClientSocket::OnDataProviderDestroyed() {
1742   data_ = nullptr;
1743 }
1744 
CompleteRead()1745 int MockUDPClientSocket::CompleteRead() {
1746   DCHECK(pending_read_buf_.get());
1747   DCHECK(pending_read_buf_len_ > 0);
1748 
1749   // Save the pending async IO data and reset our |pending_| state.
1750   scoped_refptr<IOBuffer> buf = pending_read_buf_;
1751   int buf_len = pending_read_buf_len_;
1752   CompletionOnceCallback callback = std::move(pending_read_callback_);
1753   pending_read_buf_ = nullptr;
1754   pending_read_buf_len_ = 0;
1755 
1756   int result = read_data_.result;
1757   DCHECK(result != ERR_IO_PENDING);
1758 
1759   if (read_data_.data) {
1760     if (read_data_.data_len - read_offset_ > 0) {
1761       result = std::min(buf_len, read_data_.data_len - read_offset_);
1762       memcpy(buf->data(), read_data_.data + read_offset_, result);
1763       read_offset_ += result;
1764       if (read_offset_ == read_data_.data_len) {
1765         need_read_data_ = true;
1766         read_offset_ = 0;
1767       }
1768     } else {
1769       result = 0;  // EOF
1770     }
1771   }
1772 
1773   if (read_data_.mode == ASYNC) {
1774     DCHECK(!callback.is_null());
1775     RunCallbackAsync(std::move(callback), result);
1776     return ERR_IO_PENDING;
1777   }
1778   return result;
1779 }
1780 
RunCallbackAsync(CompletionOnceCallback callback,int result)1781 void MockUDPClientSocket::RunCallbackAsync(CompletionOnceCallback callback,
1782                                            int result) {
1783   base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
1784       FROM_HERE,
1785       base::BindOnce(&MockUDPClientSocket::RunCallback,
1786                      weak_factory_.GetWeakPtr(), std::move(callback), result));
1787 }
1788 
RunCallback(CompletionOnceCallback callback,int result)1789 void MockUDPClientSocket::RunCallback(CompletionOnceCallback callback,
1790                                       int result) {
1791   std::move(callback).Run(result);
1792 }
1793 
TestSocketRequest(std::vector<TestSocketRequest * > * request_order,size_t * completion_count)1794 TestSocketRequest::TestSocketRequest(
1795     std::vector<TestSocketRequest*>* request_order,
1796     size_t* completion_count)
1797     : request_order_(request_order), completion_count_(completion_count) {
1798   DCHECK(request_order);
1799   DCHECK(completion_count);
1800 }
1801 
1802 TestSocketRequest::~TestSocketRequest() = default;
1803 
OnComplete(int result)1804 void TestSocketRequest::OnComplete(int result) {
1805   SetResult(result);
1806   (*completion_count_)++;
1807   request_order_->push_back(this);
1808 }
1809 
1810 // static
1811 const int ClientSocketPoolTest::kIndexOutOfBounds = -1;
1812 
1813 // static
1814 const int ClientSocketPoolTest::kRequestNotFound = -2;
1815 
1816 ClientSocketPoolTest::ClientSocketPoolTest() = default;
1817 ClientSocketPoolTest::~ClientSocketPoolTest() = default;
1818 
GetOrderOfRequest(size_t index) const1819 int ClientSocketPoolTest::GetOrderOfRequest(size_t index) const {
1820   index--;
1821   if (index >= requests_.size())
1822     return kIndexOutOfBounds;
1823 
1824   for (size_t i = 0; i < request_order_.size(); i++)
1825     if (requests_[index].get() == request_order_[i])
1826       return i + 1;
1827 
1828   return kRequestNotFound;
1829 }
1830 
ReleaseOneConnection(KeepAlive keep_alive)1831 bool ClientSocketPoolTest::ReleaseOneConnection(KeepAlive keep_alive) {
1832   for (std::unique_ptr<TestSocketRequest>& it : requests_) {
1833     if (it->handle()->is_initialized()) {
1834       if (keep_alive == NO_KEEP_ALIVE)
1835         it->handle()->socket()->Disconnect();
1836       it->handle()->Reset();
1837       base::RunLoop().RunUntilIdle();
1838       return true;
1839     }
1840   }
1841   return false;
1842 }
1843 
ReleaseAllConnections(KeepAlive keep_alive)1844 void ClientSocketPoolTest::ReleaseAllConnections(KeepAlive keep_alive) {
1845   bool released_one;
1846   do {
1847     released_one = ReleaseOneConnection(keep_alive);
1848   } while (released_one);
1849 }
1850 
MockConnectJob(std::unique_ptr<StreamSocket> socket,ClientSocketHandle * handle,const SocketTag & socket_tag,CompletionOnceCallback callback,RequestPriority priority)1851 MockTransportClientSocketPool::MockConnectJob::MockConnectJob(
1852     std::unique_ptr<StreamSocket> socket,
1853     ClientSocketHandle* handle,
1854     const SocketTag& socket_tag,
1855     CompletionOnceCallback callback,
1856     RequestPriority priority)
1857     : socket_(std::move(socket)),
1858       handle_(handle),
1859       socket_tag_(socket_tag),
1860       user_callback_(std::move(callback)),
1861       priority_(priority) {}
1862 
1863 MockTransportClientSocketPool::MockConnectJob::~MockConnectJob() = default;
1864 
Connect()1865 int MockTransportClientSocketPool::MockConnectJob::Connect() {
1866   socket_->ApplySocketTag(socket_tag_);
1867   int rv = socket_->Connect(
1868       base::BindOnce(&MockConnectJob::OnConnect, base::Unretained(this)));
1869   if (rv != ERR_IO_PENDING) {
1870     user_callback_.Reset();
1871     OnConnect(rv);
1872   }
1873   return rv;
1874 }
1875 
CancelHandle(const ClientSocketHandle * handle)1876 bool MockTransportClientSocketPool::MockConnectJob::CancelHandle(
1877     const ClientSocketHandle* handle) {
1878   if (handle != handle_)
1879     return false;
1880   socket_.reset();
1881   handle_ = nullptr;
1882   user_callback_.Reset();
1883   return true;
1884 }
1885 
OnConnect(int rv)1886 void MockTransportClientSocketPool::MockConnectJob::OnConnect(int rv) {
1887   if (!socket_.get())
1888     return;
1889   if (rv == OK) {
1890     handle_->SetSocket(std::move(socket_));
1891 
1892     // Needed for socket pool tests that layer other sockets on top of mock
1893     // sockets.
1894     LoadTimingInfo::ConnectTiming connect_timing;
1895     base::TimeTicks now = base::TimeTicks::Now();
1896     connect_timing.domain_lookup_start = now;
1897     connect_timing.domain_lookup_end = now;
1898     connect_timing.connect_start = now;
1899     connect_timing.connect_end = now;
1900     handle_->set_connect_timing(connect_timing);
1901   } else {
1902     socket_.reset();
1903 
1904     // Needed to test copying of ConnectionAttempts in SSL ConnectJob.
1905     ConnectionAttempts attempts;
1906     attempts.push_back(ConnectionAttempt(IPEndPoint(), rv));
1907     handle_->set_connection_attempts(attempts);
1908   }
1909 
1910   handle_ = nullptr;
1911 
1912   if (!user_callback_.is_null()) {
1913     std::move(user_callback_).Run(rv);
1914   }
1915 }
1916 
MockTransportClientSocketPool(int max_sockets,int max_sockets_per_group,const CommonConnectJobParams * common_connect_job_params)1917 MockTransportClientSocketPool::MockTransportClientSocketPool(
1918     int max_sockets,
1919     int max_sockets_per_group,
1920     const CommonConnectJobParams* common_connect_job_params)
1921     : TransportClientSocketPool(
1922           max_sockets,
1923           max_sockets_per_group,
1924           base::Seconds(10) /* unused_idle_socket_timeout */,
1925           ProxyServer::Direct(),
1926           false /* is_for_websockets */,
1927           common_connect_job_params),
1928       client_socket_factory_(common_connect_job_params->client_socket_factory) {
1929 }
1930 
1931 MockTransportClientSocketPool::~MockTransportClientSocketPool() = default;
1932 
RequestSocket(const ClientSocketPool::GroupId & group_id,scoped_refptr<ClientSocketPool::SocketParams> socket_params,const absl::optional<NetworkTrafficAnnotationTag> & proxy_annotation_tag,RequestPriority priority,const SocketTag & socket_tag,RespectLimits respect_limits,ClientSocketHandle * handle,CompletionOnceCallback callback,const ProxyAuthCallback & on_auth_callback,const NetLogWithSource & net_log)1933 int MockTransportClientSocketPool::RequestSocket(
1934     const ClientSocketPool::GroupId& group_id,
1935     scoped_refptr<ClientSocketPool::SocketParams> socket_params,
1936     const absl::optional<NetworkTrafficAnnotationTag>& proxy_annotation_tag,
1937     RequestPriority priority,
1938     const SocketTag& socket_tag,
1939     RespectLimits respect_limits,
1940     ClientSocketHandle* handle,
1941     CompletionOnceCallback callback,
1942     const ProxyAuthCallback& on_auth_callback,
1943     const NetLogWithSource& net_log) {
1944   last_request_priority_ = priority;
1945   std::unique_ptr<StreamSocket> socket =
1946       client_socket_factory_->CreateTransportClientSocket(
1947           AddressList(), nullptr, nullptr, net_log.net_log(), NetLogSource());
1948   auto job = std::make_unique<MockConnectJob>(
1949       std::move(socket), handle, socket_tag, std::move(callback), priority);
1950   auto* job_ptr = job.get();
1951   job_list_.push_back(std::move(job));
1952   handle->set_group_generation(1);
1953   return job_ptr->Connect();
1954 }
1955 
SetPriority(const ClientSocketPool::GroupId & group_id,ClientSocketHandle * handle,RequestPriority priority)1956 void MockTransportClientSocketPool::SetPriority(
1957     const ClientSocketPool::GroupId& group_id,
1958     ClientSocketHandle* handle,
1959     RequestPriority priority) {
1960   for (auto& job : job_list_) {
1961     if (job->handle() == handle) {
1962       job->set_priority(priority);
1963       return;
1964     }
1965   }
1966   NOTREACHED();
1967 }
1968 
CancelRequest(const ClientSocketPool::GroupId & group_id,ClientSocketHandle * handle,bool cancel_connect_job)1969 void MockTransportClientSocketPool::CancelRequest(
1970     const ClientSocketPool::GroupId& group_id,
1971     ClientSocketHandle* handle,
1972     bool cancel_connect_job) {
1973   for (std::unique_ptr<MockConnectJob>& it : job_list_) {
1974     if (it->CancelHandle(handle)) {
1975       cancel_count_++;
1976       break;
1977     }
1978   }
1979 }
1980 
ReleaseSocket(const ClientSocketPool::GroupId & group_id,std::unique_ptr<StreamSocket> socket,int64_t generation)1981 void MockTransportClientSocketPool::ReleaseSocket(
1982     const ClientSocketPool::GroupId& group_id,
1983     std::unique_ptr<StreamSocket> socket,
1984     int64_t generation) {
1985   EXPECT_EQ(1, generation);
1986   release_count_++;
1987 }
1988 
WrappedStreamSocket(std::unique_ptr<StreamSocket> transport)1989 WrappedStreamSocket::WrappedStreamSocket(
1990     std::unique_ptr<StreamSocket> transport)
1991     : transport_(std::move(transport)) {}
1992 WrappedStreamSocket::~WrappedStreamSocket() = default;
1993 
Bind(const net::IPEndPoint & local_addr)1994 int WrappedStreamSocket::Bind(const net::IPEndPoint& local_addr) {
1995   NOTREACHED();
1996   return ERR_FAILED;
1997 }
1998 
Connect(CompletionOnceCallback callback)1999 int WrappedStreamSocket::Connect(CompletionOnceCallback callback) {
2000   return transport_->Connect(std::move(callback));
2001 }
2002 
Disconnect()2003 void WrappedStreamSocket::Disconnect() {
2004   transport_->Disconnect();
2005 }
2006 
IsConnected() const2007 bool WrappedStreamSocket::IsConnected() const {
2008   return transport_->IsConnected();
2009 }
2010 
IsConnectedAndIdle() const2011 bool WrappedStreamSocket::IsConnectedAndIdle() const {
2012   return transport_->IsConnectedAndIdle();
2013 }
2014 
GetPeerAddress(IPEndPoint * address) const2015 int WrappedStreamSocket::GetPeerAddress(IPEndPoint* address) const {
2016   return transport_->GetPeerAddress(address);
2017 }
2018 
GetLocalAddress(IPEndPoint * address) const2019 int WrappedStreamSocket::GetLocalAddress(IPEndPoint* address) const {
2020   return transport_->GetLocalAddress(address);
2021 }
2022 
NetLog() const2023 const NetLogWithSource& WrappedStreamSocket::NetLog() const {
2024   return transport_->NetLog();
2025 }
2026 
WasEverUsed() const2027 bool WrappedStreamSocket::WasEverUsed() const {
2028   return transport_->WasEverUsed();
2029 }
2030 
WasAlpnNegotiated() const2031 bool WrappedStreamSocket::WasAlpnNegotiated() const {
2032   return transport_->WasAlpnNegotiated();
2033 }
2034 
GetNegotiatedProtocol() const2035 NextProto WrappedStreamSocket::GetNegotiatedProtocol() const {
2036   return transport_->GetNegotiatedProtocol();
2037 }
2038 
GetSSLInfo(SSLInfo * ssl_info)2039 bool WrappedStreamSocket::GetSSLInfo(SSLInfo* ssl_info) {
2040   return transport_->GetSSLInfo(ssl_info);
2041 }
2042 
GetTotalReceivedBytes() const2043 int64_t WrappedStreamSocket::GetTotalReceivedBytes() const {
2044   return transport_->GetTotalReceivedBytes();
2045 }
2046 
ApplySocketTag(const SocketTag & tag)2047 void WrappedStreamSocket::ApplySocketTag(const SocketTag& tag) {
2048   transport_->ApplySocketTag(tag);
2049 }
2050 
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)2051 int WrappedStreamSocket::Read(IOBuffer* buf,
2052                               int buf_len,
2053                               CompletionOnceCallback callback) {
2054   return transport_->Read(buf, buf_len, std::move(callback));
2055 }
2056 
ReadIfReady(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)2057 int WrappedStreamSocket::ReadIfReady(IOBuffer* buf,
2058                                      int buf_len,
2059                                      CompletionOnceCallback callback) {
2060   return transport_->ReadIfReady(buf, buf_len, std::move((callback)));
2061 }
2062 
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)2063 int WrappedStreamSocket::Write(
2064     IOBuffer* buf,
2065     int buf_len,
2066     CompletionOnceCallback callback,
2067     const NetworkTrafficAnnotationTag& traffic_annotation) {
2068   return transport_->Write(buf, buf_len, std::move(callback),
2069                            TRAFFIC_ANNOTATION_FOR_TESTS);
2070 }
2071 
SetReceiveBufferSize(int32_t size)2072 int WrappedStreamSocket::SetReceiveBufferSize(int32_t size) {
2073   return transport_->SetReceiveBufferSize(size);
2074 }
2075 
SetSendBufferSize(int32_t size)2076 int WrappedStreamSocket::SetSendBufferSize(int32_t size) {
2077   return transport_->SetSendBufferSize(size);
2078 }
2079 
Connect(CompletionOnceCallback callback)2080 int MockTaggingStreamSocket::Connect(CompletionOnceCallback callback) {
2081   connected_ = true;
2082   return WrappedStreamSocket::Connect(std::move(callback));
2083 }
2084 
ApplySocketTag(const SocketTag & tag)2085 void MockTaggingStreamSocket::ApplySocketTag(const SocketTag& tag) {
2086   tagged_before_connected_ &= !connected_ || tag == tag_;
2087   tag_ = tag;
2088   transport_->ApplySocketTag(tag);
2089 }
2090 
2091 std::unique_ptr<TransportClientSocket>
CreateTransportClientSocket(const AddressList & addresses,std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,NetworkQualityEstimator * network_quality_estimator,NetLog * net_log,const NetLogSource & source)2092 MockTaggingClientSocketFactory::CreateTransportClientSocket(
2093     const AddressList& addresses,
2094     std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,
2095     NetworkQualityEstimator* network_quality_estimator,
2096     NetLog* net_log,
2097     const NetLogSource& source) {
2098   auto socket = std::make_unique<MockTaggingStreamSocket>(
2099       MockClientSocketFactory::CreateTransportClientSocket(
2100           addresses, std::move(socket_performance_watcher),
2101           network_quality_estimator, net_log, source));
2102   tcp_socket_ = socket.get();
2103   return std::move(socket);
2104 }
2105 
2106 std::unique_ptr<DatagramClientSocket>
CreateDatagramClientSocket(DatagramSocket::BindType bind_type,NetLog * net_log,const NetLogSource & source)2107 MockTaggingClientSocketFactory::CreateDatagramClientSocket(
2108     DatagramSocket::BindType bind_type,
2109     NetLog* net_log,
2110     const NetLogSource& source) {
2111   std::unique_ptr<DatagramClientSocket> socket(
2112       MockClientSocketFactory::CreateDatagramClientSocket(bind_type, net_log,
2113                                                           source));
2114   udp_socket_ = static_cast<MockUDPClientSocket*>(socket.get());
2115   return socket;
2116 }
2117 
2118 const char kSOCKS4TestHost[] = "127.0.0.1";
2119 const int kSOCKS4TestPort = 80;
2120 
2121 const char kSOCKS4OkRequestLocalHostPort80[] = {0x04, 0x01, 0x00, 0x50, 127,
2122                                                 0,    0,    1,    0};
2123 const int kSOCKS4OkRequestLocalHostPort80Length =
2124     std::size(kSOCKS4OkRequestLocalHostPort80);
2125 
2126 const char kSOCKS4OkReply[] = {0x00, 0x5A, 0x00, 0x00, 0, 0, 0, 0};
2127 const int kSOCKS4OkReplyLength = std::size(kSOCKS4OkReply);
2128 
2129 const char kSOCKS5TestHost[] = "host";
2130 const int kSOCKS5TestPort = 80;
2131 
2132 const char kSOCKS5GreetRequest[] = {0x05, 0x01, 0x00};
2133 const int kSOCKS5GreetRequestLength = std::size(kSOCKS5GreetRequest);
2134 
2135 const char kSOCKS5GreetResponse[] = {0x05, 0x00};
2136 const int kSOCKS5GreetResponseLength = std::size(kSOCKS5GreetResponse);
2137 
2138 const char kSOCKS5OkRequest[] = {0x05, 0x01, 0x00, 0x03, 0x04, 'h',
2139                                  'o',  's',  't',  0x00, 0x50};
2140 const int kSOCKS5OkRequestLength = std::size(kSOCKS5OkRequest);
2141 
2142 const char kSOCKS5OkResponse[] = {0x05, 0x00, 0x00, 0x01, 127,
2143                                   0,    0,    1,    0x00, 0x50};
2144 const int kSOCKS5OkResponseLength = std::size(kSOCKS5OkResponse);
2145 
CountReadBytes(base::span<const MockRead> reads)2146 int64_t CountReadBytes(base::span<const MockRead> reads) {
2147   int64_t total = 0;
2148   for (const MockRead& read : reads)
2149     total += read.data_len;
2150   return total;
2151 }
2152 
CountWriteBytes(base::span<const MockWrite> writes)2153 int64_t CountWriteBytes(base::span<const MockWrite> writes) {
2154   int64_t total = 0;
2155   for (const MockWrite& write : writes)
2156     total += write.data_len;
2157   return total;
2158 }
2159 
2160 #if BUILDFLAG(IS_ANDROID)
CanGetTaggedBytes()2161 bool CanGetTaggedBytes() {
2162   // In Android P, /proc/net/xt_qtaguid/stats is no longer guaranteed to be
2163   // present, and has been replaced with eBPF Traffic Monitoring in netd. See:
2164   // https://source.android.com/devices/tech/datausage/ebpf-traffic-monitor
2165   //
2166   // To read traffic statistics from netd, apps should use the API
2167   // NetworkStatsManager.queryDetailsForUidTag(). But this API does not provide
2168   // statistics for local traffic, only mobile and WiFi traffic, so it would not
2169   // work in tests that spin up a local server. So for now, GetTaggedBytes is
2170   // only supported on Android releases older than P.
2171   return base::android::BuildInfo::GetInstance()->sdk_int() <
2172          base::android::SDK_VERSION_P;
2173 }
2174 
GetTaggedBytes(int32_t expected_tag)2175 uint64_t GetTaggedBytes(int32_t expected_tag) {
2176   EXPECT_TRUE(CanGetTaggedBytes());
2177 
2178   // To determine how many bytes the system saw with a particular tag read
2179   // the /proc/net/xt_qtaguid/stats file which contains the kernel's
2180   // dump of all the UIDs and their tags sent and received bytes.
2181   uint64_t bytes = 0;
2182   std::string contents;
2183   EXPECT_TRUE(base::ReadFileToString(
2184       base::FilePath::FromUTF8Unsafe("/proc/net/xt_qtaguid/stats"), &contents));
2185   for (size_t i = contents.find('\n');  // Skip first line which is headers.
2186        i != std::string::npos && i < contents.length();) {
2187     uint64_t tag, rx_bytes;
2188     uid_t uid;
2189     int n;
2190     // Parse out the numbers we care about. For reference here's the column
2191     // headers:
2192     // idx iface acct_tag_hex uid_tag_int cnt_set rx_bytes rx_packets tx_bytes
2193     // tx_packets rx_tcp_bytes rx_tcp_packets rx_udp_bytes rx_udp_packets
2194     // rx_other_bytes rx_other_packets tx_tcp_bytes tx_tcp_packets tx_udp_bytes
2195     // tx_udp_packets tx_other_bytes tx_other_packets
2196     EXPECT_EQ(sscanf(contents.c_str() + i,
2197                      "%*d %*s 0x%" SCNx64 " %d %*d %" SCNu64
2198                      " %*d %*d %*d %*d %*d %*d %*d %*d "
2199                      "%*d %*d %*d %*d %*d %*d %*d%n",
2200                      &tag, &uid, &rx_bytes, &n),
2201               3);
2202     // If this line matches our UID and |expected_tag| then add it to the total.
2203     if (uid == getuid() && (int32_t)(tag >> 32) == expected_tag) {
2204       bytes += rx_bytes;
2205     }
2206     // Move |i| to the next line.
2207     i += n + 1;
2208   }
2209   return bytes;
2210 }
2211 #endif
2212 
2213 }  // namespace net
2214