• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #ifndef NET_SOCKET_SOCKET_TEST_UTIL_H_
6 #define NET_SOCKET_SOCKET_TEST_UTIL_H_
7 
8 #include <stddef.h>
9 #include <stdint.h>
10 
11 #include <cstring>
12 #include <memory>
13 #include <string>
14 #include <utility>
15 #include <vector>
16 
17 #include "base/check_op.h"
18 #include "base/containers/span.h"
19 #include "base/functional/bind.h"
20 #include "base/functional/callback.h"
21 #include "base/memory/ptr_util.h"
22 #include "base/memory/raw_ptr.h"
23 #include "base/memory/raw_ptr_exclusion.h"
24 #include "base/memory/ref_counted.h"
25 #include "base/memory/weak_ptr.h"
26 #include "build/build_config.h"
27 #include "net/base/address_list.h"
28 #include "net/base/completion_once_callback.h"
29 #include "net/base/io_buffer.h"
30 #include "net/base/net_errors.h"
31 #include "net/base/test_completion_callback.h"
32 #include "net/http/http_auth_controller.h"
33 #include "net/log/net_log_with_source.h"
34 #include "net/socket/client_socket_factory.h"
35 #include "net/socket/client_socket_handle.h"
36 #include "net/socket/client_socket_pool.h"
37 #include "net/socket/datagram_client_socket.h"
38 #include "net/socket/socket_performance_watcher.h"
39 #include "net/socket/socket_tag.h"
40 #include "net/socket/ssl_client_socket.h"
41 #include "net/socket/transport_client_socket.h"
42 #include "net/socket/transport_client_socket_pool.h"
43 #include "net/ssl/ssl_config_service.h"
44 #include "net/ssl/ssl_info.h"
45 #include "testing/gtest/include/gtest/gtest.h"
46 #include "third_party/abseil-cpp/absl/types/optional.h"
47 
namespace base {
class RunLoop;
}  // namespace base
51 
52 namespace net {
53 
54 struct CommonConnectJobParams;
55 class NetLog;
56 struct NetworkTrafficAnnotationTag;
57 class X509Certificate;
58 
59 const handles::NetworkHandle kDefaultNetworkForTests = 1;
60 const handles::NetworkHandle kNewNetworkForTests = 2;
61 
// A private network error code used by the socket test utility classes.
// When a MockRead's |result| is ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ, that
// MockRead is only a marker saying the peer will close the connection after
// the next MockRead; every other member of the marker MockRead is ignored.
enum { ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ = -10000 };

class AsyncSocket;
class MockClientSocket;
class SSLClientSocket;
class StreamSocket;

// Whether a mocked operation completes asynchronously or synchronously.
enum IoMode { ASYNC, SYNCHRONOUS };
77 
78 struct MockConnect {
79   // Asynchronous connection success.
80   // Creates a MockConnect with |mode| ASYC, |result| OK, and
81   // |peer_addr| 192.0.2.33.
82   MockConnect();
83   // Creates a MockConnect with the specified mode and result, with
84   // |peer_addr| 192.0.2.33.
85   MockConnect(IoMode io_mode, int r);
86   MockConnect(IoMode io_mode, int r, IPEndPoint addr);
87   MockConnect(IoMode io_mode, int r, IPEndPoint addr, bool first_attempt_fails);
88   ~MockConnect();
89 
90   IoMode mode;
91   int result;
92   IPEndPoint peer_addr;
93   bool first_attempt_fails = false;
94 };
95 
96 struct MockConfirm {
97   // Asynchronous confirm success.
98   // Creates a MockConfirm with |mode| ASYC and |result| OK.
99   MockConfirm();
100   // Creates a MockConfirm with the specified mode and result.
101   MockConfirm(IoMode io_mode, int r);
102   ~MockConfirm();
103 
104   IoMode mode;
105   int result;
106 };
107 
108 // MockRead and MockWrite shares the same interface and members, but we'd like
109 // to have distinct types because we don't want to have them used
110 // interchangably. To do this, a struct template is defined, and MockRead and
111 // MockWrite are instantiated by using this template. Template parameter |type|
112 // is not used in the struct definition (it purely exists for creating a new
113 // type).
114 //
115 // |data| in MockRead and MockWrite has different meanings: |data| in MockRead
116 // is the data returned from the socket when MockTCPClientSocket::Read() is
117 // attempted, while |data| in MockWrite is the expected data that should be
118 // given in MockTCPClientSocket::Write().
119 enum MockReadWriteType { MOCK_READ, MOCK_WRITE };
120 
121 template <MockReadWriteType type>
122 struct MockReadWrite {
123   // Flag to indicate that the message loop should be terminated.
124   enum { STOPLOOP = 1 << 31 };
125 
126   // Default
MockReadWriteMockReadWrite127   MockReadWrite()
128       : mode(SYNCHRONOUS),
129         result(0),
130         data(nullptr),
131         data_len(0),
132         sequence_number(0) {}
133 
134   // Read/write failure (no data).
MockReadWriteMockReadWrite135   MockReadWrite(IoMode io_mode, int result)
136       : mode(io_mode),
137         result(result),
138         data(nullptr),
139         data_len(0),
140         sequence_number(0) {}
141 
142   // Read/write failure (no data), with sequence information.
MockReadWriteMockReadWrite143   MockReadWrite(IoMode io_mode, int result, int seq)
144       : mode(io_mode),
145         result(result),
146         data(nullptr),
147         data_len(0),
148         sequence_number(seq) {}
149 
150   // Asynchronous read/write success (inferred data length).
MockReadWriteMockReadWrite151   explicit MockReadWrite(const char* data)
152       : mode(ASYNC),
153         result(0),
154         data(data),
155         data_len(strlen(data)),
156         sequence_number(0) {}
157 
158   // Read/write success (inferred data length).
MockReadWriteMockReadWrite159   MockReadWrite(IoMode io_mode, const char* data)
160       : mode(io_mode),
161         result(0),
162         data(data),
163         data_len(strlen(data)),
164         sequence_number(0) {}
165 
166   // Read/write success.
MockReadWriteMockReadWrite167   MockReadWrite(IoMode io_mode, const char* data, int data_len)
168       : mode(io_mode),
169         result(0),
170         data(data),
171         data_len(data_len),
172         sequence_number(0) {}
173 
174   // Read/write success (inferred data length) with sequence information.
MockReadWriteMockReadWrite175   MockReadWrite(IoMode io_mode, int seq, const char* data)
176       : mode(io_mode),
177         result(0),
178         data(data),
179         data_len(strlen(data)),
180         sequence_number(seq) {}
181 
182   // Read/write success with sequence information.
MockReadWriteMockReadWrite183   MockReadWrite(IoMode io_mode, const char* data, int data_len, int seq)
184       : mode(io_mode),
185         result(0),
186         data(data),
187         data_len(data_len),
188         sequence_number(seq) {}
189 
190   IoMode mode;
191   int result;
192   const char* data;
193   int data_len;
194 
195   // For data providers that only allows reads to occur in a particular
196   // sequence.  If a read occurs before the given |sequence_number| is reached,
197   // an ERR_IO_PENDING is returned.
198   int sequence_number;  // The sequence number at which a read is allowed
199                         // to occur.
200 };
201 
202 typedef MockReadWrite<MOCK_READ> MockRead;
203 typedef MockReadWrite<MOCK_WRITE> MockWrite;
204 
205 struct MockWriteResult {
MockWriteResultMockWriteResult206   MockWriteResult(IoMode io_mode, int result) : mode(io_mode), result(result) {}
207 
208   IoMode mode;
209   int result;
210 };
211 
212 // The SocketDataProvider is an interface used by the MockClientSocket
213 // for getting data about individual reads and writes on the socket.  Can be
214 // used with at most one socket at a time.
215 // TODO(mmenke):  Do these really need to be re-useable?
216 class SocketDataProvider {
217  public:
218   SocketDataProvider();
219 
220   SocketDataProvider(const SocketDataProvider&) = delete;
221   SocketDataProvider& operator=(const SocketDataProvider&) = delete;
222 
223   virtual ~SocketDataProvider();
224 
225   // Returns the buffer and result code for the next simulated read.
226   // If the |MockRead.result| is ERR_IO_PENDING, it informs the caller
227   // that it will be called via the AsyncSocket::OnReadComplete()
228   // function at a later time.
229   virtual MockRead OnRead() = 0;
230   virtual MockWriteResult OnWrite(const std::string& data) = 0;
231   virtual bool AllReadDataConsumed() const = 0;
232   virtual bool AllWriteDataConsumed() const = 0;
CancelPendingRead()233   virtual void CancelPendingRead() {}
234 
235   // Returns the last set receive buffer size, or -1 if never set.
receive_buffer_size()236   int receive_buffer_size() const { return receive_buffer_size_; }
set_receive_buffer_size(int receive_buffer_size)237   void set_receive_buffer_size(int receive_buffer_size) {
238     receive_buffer_size_ = receive_buffer_size;
239   }
240 
241   // Returns the last set send buffer size, or -1 if never set.
send_buffer_size()242   int send_buffer_size() const { return send_buffer_size_; }
set_send_buffer_size(int send_buffer_size)243   void set_send_buffer_size(int send_buffer_size) {
244     send_buffer_size_ = send_buffer_size;
245   }
246 
247   // Returns the last set value of TCP no delay, or false if never set.
no_delay()248   bool no_delay() const { return no_delay_; }
set_no_delay(bool no_delay)249   void set_no_delay(bool no_delay) { no_delay_ = no_delay; }
250 
251   // Returns whether TCP keepalives were enabled or not. Returns kDefault by
252   // default.
253   enum class KeepAliveState { kEnabled, kDisabled, kDefault };
keep_alive_state()254   KeepAliveState keep_alive_state() const { return keep_alive_state_; }
255   // Last set TCP keepalive delay.
keep_alive_delay()256   int keep_alive_delay() const { return keep_alive_delay_; }
set_keep_alive(bool enable,int delay)257   void set_keep_alive(bool enable, int delay) {
258     keep_alive_state_ =
259         enable ? KeepAliveState::kEnabled : KeepAliveState::kDisabled;
260     keep_alive_delay_ = delay;
261   }
262 
263   // Setters / getters for the return values of the corresponding Set*()
264   // methods. By default, they all succeed, if the socket is connected.
265 
set_set_receive_buffer_size_result(int receive_buffer_size_result)266   void set_set_receive_buffer_size_result(int receive_buffer_size_result) {
267     set_receive_buffer_size_result_ = receive_buffer_size_result;
268   }
set_receive_buffer_size_result()269   int set_receive_buffer_size_result() const {
270     return set_receive_buffer_size_result_;
271   }
272 
set_set_send_buffer_size_result(int set_send_buffer_size_result)273   void set_set_send_buffer_size_result(int set_send_buffer_size_result) {
274     set_send_buffer_size_result_ = set_send_buffer_size_result;
275   }
set_send_buffer_size_result()276   int set_send_buffer_size_result() const {
277     return set_send_buffer_size_result_;
278   }
279 
set_set_no_delay_result(bool set_no_delay_result)280   void set_set_no_delay_result(bool set_no_delay_result) {
281     set_no_delay_result_ = set_no_delay_result;
282   }
set_no_delay_result()283   bool set_no_delay_result() const { return set_no_delay_result_; }
284 
set_set_keep_alive_result(bool set_keep_alive_result)285   void set_set_keep_alive_result(bool set_keep_alive_result) {
286     set_keep_alive_result_ = set_keep_alive_result;
287   }
set_keep_alive_result()288   bool set_keep_alive_result() const { return set_keep_alive_result_; }
289 
expected_addresses()290   const absl::optional<AddressList>& expected_addresses() const {
291     return expected_addresses_;
292   }
set_expected_addresses(net::AddressList addresses)293   void set_expected_addresses(net::AddressList addresses) {
294     expected_addresses_ = std::move(addresses);
295   }
296 
297   // Returns true if the request should be considered idle, for the purposes of
298   // IsConnectedAndIdle.
299   virtual bool IsIdle() const;
300 
301   // Initializes the SocketDataProvider for use with |socket|.  Must be called
302   // before use
303   void Initialize(AsyncSocket* socket);
304   // Detaches the socket associated with a SocketDataProvider.  Must be called
305   // before |socket_| is destroyed, unless the SocketDataProvider has informed
306   // |socket_| it was destroyed.  Must also be called before Initialize() may
307   // be called again with a new socket.
308   void DetachSocket();
309 
310   // Accessor for the socket which is using the SocketDataProvider.
socket()311   AsyncSocket* socket() { return socket_; }
312 
connect_data()313   MockConnect connect_data() const { return connect_; }
set_connect_data(const MockConnect & connect)314   void set_connect_data(const MockConnect& connect) { connect_ = connect; }
315 
316  private:
317   // Called to inform subclasses of initialization.
318   virtual void Reset() = 0;
319 
320   MockConnect connect_;
321   raw_ptr<AsyncSocket> socket_ = nullptr;
322 
323   int receive_buffer_size_ = -1;
324   int send_buffer_size_ = -1;
325   // This reflects the default state of TCPClientSockets.
326   bool no_delay_ = true;
327 
328   KeepAliveState keep_alive_state_ = KeepAliveState::kDefault;
329   int keep_alive_delay_ = 0;
330 
331   int set_receive_buffer_size_result_ = net::OK;
332   int set_send_buffer_size_result_ = net::OK;
333   bool set_no_delay_result_ = true;
334   bool set_keep_alive_result_ = true;
335   absl::optional<AddressList> expected_addresses_;
336 };
337 
338 // The AsyncSocket is an interface used by the SocketDataProvider to
339 // complete the asynchronous read operation.
340 class AsyncSocket {
341  public:
342   // If an async IO is pending because the SocketDataProvider returned
343   // ERR_IO_PENDING, then the AsyncSocket waits until this OnReadComplete
344   // is called to complete the asynchronous read operation.
345   // data.async is ignored, and this read is completed synchronously as
346   // part of this call.
347   // TODO(rch): this should take a StringPiece since most of the fields
348   // are ignored.
349   virtual void OnReadComplete(const MockRead& data) = 0;
350   // If an async IO is pending because the SocketDataProvider returned
351   // ERR_IO_PENDING, then the AsyncSocket waits until this OnReadComplete
352   // is called to complete the asynchronous read operation.
353   virtual void OnWriteComplete(int rv) = 0;
354   virtual void OnConnectComplete(const MockConnect& data) = 0;
355 
356   // Called when the SocketDataProvider associated with the socket is destroyed.
357   // The socket may continue to be used after the data provider is destroyed,
358   // so it should be sure not to dereference the provider after this is called.
359   virtual void OnDataProviderDestroyed() = 0;
360 };
361 
// Formats the contents of mocked socket writes in a protocol-specific way
// (a printer can be supplied to StaticSocketDataHelper::VerifyWriteData()).
class SocketDataPrinter {
 public:
  // Virtual: this is an abstract interface, and deleting an implementation
  // through a SocketDataPrinter* with a non-virtual destructor is undefined
  // behavior.
  virtual ~SocketDataPrinter() = default;

  // Prints the write in |data| using some sort of protocol-specific
  // format.
  virtual std::string PrintWrite(const std::string& data) = 0;
};
370 
371 // StaticSocketDataHelper manages a list of reads and writes.
372 class StaticSocketDataHelper {
373  public:
374   StaticSocketDataHelper(base::span<const MockRead> reads,
375                          base::span<const MockWrite> writes);
376 
377   StaticSocketDataHelper(const StaticSocketDataHelper&) = delete;
378   StaticSocketDataHelper& operator=(const StaticSocketDataHelper&) = delete;
379 
380   ~StaticSocketDataHelper();
381 
382   // These functions get access to the next available read and write data. They
383   // CHECK fail if there is no data available.
384   const MockRead& PeekRead() const;
385   const MockWrite& PeekWrite() const;
386 
387   // Returns the current read or write, and then advances to the next one.
388   const MockRead& AdvanceRead();
389   const MockWrite& AdvanceWrite();
390 
391   // Resets the read and write indexes to 0.
392   void Reset();
393 
394   // Returns true if |data| is valid data for the next write. In order
395   // to support short writes, the next write may be longer than |data|
396   // in which case this method will still return true.
397   bool VerifyWriteData(const std::string& data, SocketDataPrinter* printer);
398 
read_index()399   size_t read_index() const { return read_index_; }
write_index()400   size_t write_index() const { return write_index_; }
read_count()401   size_t read_count() const { return reads_.size(); }
write_count()402   size_t write_count() const { return writes_.size(); }
403 
AllReadDataConsumed()404   bool AllReadDataConsumed() const { return read_index() >= read_count(); }
AllWriteDataConsumed()405   bool AllWriteDataConsumed() const { return write_index() >= write_count(); }
406 
407  private:
408   // Returns the next available read or write that is not a pause event. CHECK
409   // fails if no data is available.
410   const MockWrite& PeekRealWrite() const;
411 
412   const base::span<const MockRead> reads_;
413   size_t read_index_ = 0;
414   const base::span<const MockWrite> writes_;
415   size_t write_index_ = 0;
416 };
417 
418 // SocketDataProvider which responds based on static tables of mock reads and
419 // writes.
420 class StaticSocketDataProvider : public SocketDataProvider {
421  public:
422   StaticSocketDataProvider();
423   StaticSocketDataProvider(base::span<const MockRead> reads,
424                            base::span<const MockWrite> writes);
425 
426   StaticSocketDataProvider(const StaticSocketDataProvider&) = delete;
427   StaticSocketDataProvider& operator=(const StaticSocketDataProvider&) = delete;
428 
429   ~StaticSocketDataProvider() override;
430 
431   // Pause/resume reads from this provider.
432   void Pause();
433   void Resume();
434 
435   // From SocketDataProvider:
436   MockRead OnRead() override;
437   MockWriteResult OnWrite(const std::string& data) override;
438   bool AllReadDataConsumed() const override;
439   bool AllWriteDataConsumed() const override;
440 
read_index()441   size_t read_index() const { return helper_.read_index(); }
write_index()442   size_t write_index() const { return helper_.write_index(); }
read_count()443   size_t read_count() const { return helper_.read_count(); }
write_count()444   size_t write_count() const { return helper_.write_count(); }
445 
set_printer(SocketDataPrinter * printer)446   void set_printer(SocketDataPrinter* printer) { printer_ = printer; }
447 
448  private:
449   // From SocketDataProvider:
450   void Reset() override;
451 
452   StaticSocketDataHelper helper_;
453   // This field is not a raw_ptr<> because it was filtered by the rewriter for:
454   // #union
455   RAW_PTR_EXCLUSION SocketDataPrinter* printer_ = nullptr;
456   bool paused_ = false;
457 };
458 
459 // SSLSocketDataProviders only need to keep track of the return code from calls
460 // to Connect().
461 struct SSLSocketDataProvider {
462   SSLSocketDataProvider(IoMode mode, int result);
463   SSLSocketDataProvider(const SSLSocketDataProvider& other);
464   ~SSLSocketDataProvider();
465 
466   // Returns whether MockConnect data has been consumed.
ConnectDataConsumedSSLSocketDataProvider467   bool ConnectDataConsumed() const { return is_connect_data_consumed; }
468 
469   // Returns whether MockConfirm data has been consumed.
ConfirmDataConsumedSSLSocketDataProvider470   bool ConfirmDataConsumed() const { return is_confirm_data_consumed; }
471 
472   // Returns whether a Write occurred before ConfirmHandshake completed.
WriteBeforeConfirmSSLSocketDataProvider473   bool WriteBeforeConfirm() const { return write_called_before_confirm; }
474 
475   // Result for Connect().
476   MockConnect connect;
477   // Callback to run when Connect() is called. This is called at most once per
478   // socket but is repeating because SSLSocketDataProvider is copyable.
479   base::RepeatingClosure connect_callback;
480 
481   // Result for ConfirmHandshake().
482   MockConfirm confirm;
483   // Callback to run when ConfirmHandshake() is called. This is called at most
484   // once per socket but is repeating because SSLSocketDataProvider is
485   // copyable.
486   base::RepeatingClosure confirm_callback;
487 
488   // Result for GetNegotiatedProtocol().
489   NextProto next_proto = kProtoUnknown;
490 
491   // Result for GetPeerApplicationSettings().
492   absl::optional<std::string> peer_application_settings;
493 
494   // Result for GetSSLInfo().
495   SSLInfo ssl_info;
496 
497   // Result for GetSSLCertRequestInfo().
498   // This field is not a raw_ptr<> because it was filtered by the rewriter for:
499   // #union
500   RAW_PTR_EXCLUSION SSLCertRequestInfo* cert_request_info = nullptr;
501 
502   // Result for GetECHRetryConfigs().
503   std::vector<uint8_t> ech_retry_configs;
504 
505   absl::optional<NextProtoVector> next_protos_expected_in_ssl_config;
506 
507   uint16_t expected_ssl_version_min;
508   uint16_t expected_ssl_version_max;
509   absl::optional<bool> expected_send_client_cert;
510   scoped_refptr<X509Certificate> expected_client_cert;
511   absl::optional<HostPortPair> expected_host_and_port;
512   absl::optional<NetworkAnonymizationKey> expected_network_anonymization_key;
513   absl::optional<bool> expected_disable_sha1_server_signatures;
514   absl::optional<std::vector<uint8_t>> expected_ech_config_list;
515 
516   bool is_connect_data_consumed = false;
517   bool is_confirm_data_consumed = false;
518   bool write_called_before_confirm = false;
519 };
520 
521 // Uses the sequence_number field in the mock reads and writes to
522 // complete the operations in a specified order.
523 class SequencedSocketData : public SocketDataProvider {
524  public:
525   SequencedSocketData();
526 
527   // |reads| is the list of MockRead completions.
528   // |writes| is the list of MockWrite completions.
529   SequencedSocketData(base::span<const MockRead> reads,
530                       base::span<const MockWrite> writes);
531 
532   // |connect| is the result for the connect phase.
533   // |reads| is the list of MockRead completions.
534   // |writes| is the list of MockWrite completions.
535   SequencedSocketData(const MockConnect& connect,
536                       base::span<const MockRead> reads,
537                       base::span<const MockWrite> writes);
538 
539   SequencedSocketData(const SequencedSocketData&) = delete;
540   SequencedSocketData& operator=(const SequencedSocketData&) = delete;
541 
542   ~SequencedSocketData() override;
543 
544   // From SocketDataProvider:
545   MockRead OnRead() override;
546   MockWriteResult OnWrite(const std::string& data) override;
547   bool AllReadDataConsumed() const override;
548   bool AllWriteDataConsumed() const override;
549   bool IsIdle() const override;
550   void CancelPendingRead() override;
551 
552   // An ASYNC read event with a return value of ERR_IO_PENDING will cause the
553   // socket data to pause at that event, and advance no further, until Resume is
554   // invoked.  At that point, the socket will continue at the next event in the
555   // sequence.
556   //
557   // If a request just wants to simulate a connection that stays open and never
558   // receives any more data, instead of pausing and then resuming a request, it
559   // should use a SYNCHRONOUS event with a return value of ERR_IO_PENDING
560   // instead.
561   bool IsPaused() const;
562   // Resumes events once |this| is in the paused state.  The next event will
563   // occur synchronously with the call if it can.
564   void Resume();
565   void RunUntilPaused();
566 
567   // When true, IsConnectedAndIdle() will return false if the next event in the
568   // sequence is a synchronous.  Otherwise, the socket claims to be idle as
569   // long as it's connected.  Defaults to false.
570   // TODO(mmenke):  See if this can be made the default behavior, and consider
571   // removing this mehtod.  Need to make sure it doesn't change what code any
572   // tests are targetted at testing.
set_busy_before_sync_reads(bool busy_before_sync_reads)573   void set_busy_before_sync_reads(bool busy_before_sync_reads) {
574     busy_before_sync_reads_ = busy_before_sync_reads;
575   }
576 
set_printer(SocketDataPrinter * printer)577   void set_printer(SocketDataPrinter* printer) { printer_ = printer; }
578 
579  private:
580   // Defines the state for the read or write path.
581   enum class IoState {
582     kIdle,        // No async operation is in progress.
583     kPending,     // An async operation in waiting for another operation to
584                   // complete.
585     kCompleting,  // A task has been posted to complete an async operation.
586     kPaused,      // IO is paused until Resume() is called.
587   };
588 
589   // From SocketDataProvider:
590   void Reset() override;
591 
592   void OnReadComplete();
593   void OnWriteComplete();
594 
595   void MaybePostReadCompleteTask();
596   void MaybePostWriteCompleteTask();
597 
598   StaticSocketDataHelper helper_;
599   raw_ptr<SocketDataPrinter> printer_ = nullptr;
600   int sequence_number_ = 0;
601   IoState read_state_ = IoState::kIdle;
602   IoState write_state_ = IoState::kIdle;
603 
604   bool busy_before_sync_reads_ = false;
605 
606   // Used by RunUntilPaused.  NULL at all other times.
607   std::unique_ptr<base::RunLoop> run_until_paused_run_loop_;
608 
609   base::WeakPtrFactory<SequencedSocketData> weak_factory_{this};
610 };
611 
// Holds an array of SocketDataProvider elements.  As Mock{TCP,SSL}StreamSocket
// objects get instantiated, they take their data from the i'th element of this
// array. Elements are stored as raw pointers and are not owned by the array.
template <typename T>
class SocketDataProviderArray {
 public:
  SocketDataProviderArray() = default;

  // Returns the next element and advances past it. DCHECKs that an element
  // remains.
  T* GetNext() {
    DCHECK_LT(next_index_, data_providers_.size());
    return data_providers_[next_index_++];
  }

  // Like GetNext(), but returns nullptr when the end of the array is reached,
  // instead of DCHECKing. GetNext() should generally be preferred, unless
  // having no remaining elements is expected in some cases and is handled
  // safely.
  T* GetNextWithoutAsserting() {
    // ">=" rather than "==": when DCHECKs are compiled out, GetNext() can
    // advance |next_index_| past the end, and indexing would then be out of
    // bounds.
    if (next_index_ >= data_providers_.size())
      return nullptr;
    return data_providers_[next_index_++];
  }

  // Appends |data_provider|, which must be non-null. Ownership is not taken.
  void Add(T* data_provider) {
    DCHECK(data_provider);
    data_providers_.push_back(data_provider);
  }

  // Index of the element the next GetNext*() call will return.
  size_t next_index() const { return next_index_; }

  // Restarts iteration from the first element.
  void ResetNextIndex() { next_index_ = 0; }

 private:
  // Index of the next |data_providers_| element to use. Not an iterator
  // because those are invalidated on vector reallocation.
  size_t next_index_ = 0;

  // SocketDataProviders to be returned.
  std::vector<T*> data_providers_;
};
652 
653 class MockUDPClientSocket;
654 class MockTCPClientSocket;
655 class MockSSLClientSocket;
656 
657 // ClientSocketFactory which contains arrays of sockets of each type.
658 // You should first fill the arrays using Add{SSL,}SocketDataProvider(). When
659 // the factory is asked to create a socket, it takes next entry from appropriate
660 // array. You can use ResetNextMockIndexes to reset that next entry index for
661 // all mock socket types.
662 class MockClientSocketFactory : public ClientSocketFactory {
663  public:
664   MockClientSocketFactory();
665 
666   MockClientSocketFactory(const MockClientSocketFactory&) = delete;
667   MockClientSocketFactory& operator=(const MockClientSocketFactory&) = delete;
668 
669   ~MockClientSocketFactory() override;
670 
671   // Adds a SocketDataProvider that can be used to served either TCP or UDP
672   // connection requests. Sockets are returned in FIFO order.
673   void AddSocketDataProvider(SocketDataProvider* socket);
674 
675   // Like AddSocketDataProvider(), except sockets will only be used to service
676   // TCP connection requests. Sockets added with this method are used first,
677   // before sockets added with AddSocketDataProvider(). Particularly useful for
678   // QUIC tests with multiple sockets, where TCP connections may or may not be
679   // made, and have no guaranteed order, relative to UDP connections.
680   void AddTcpSocketDataProvider(SocketDataProvider* socket);
681 
682   void AddSSLSocketDataProvider(SSLSocketDataProvider* socket);
683   void ResetNextMockIndexes();
684 
mock_data()685   SocketDataProviderArray<SocketDataProvider>& mock_data() {
686     return mock_data_;
687   }
688 
set_enable_read_if_ready(bool enable_read_if_ready)689   void set_enable_read_if_ready(bool enable_read_if_ready) {
690     enable_read_if_ready_ = enable_read_if_ready;
691   }
692 
693   // ClientSocketFactory
694   std::unique_ptr<DatagramClientSocket> CreateDatagramClientSocket(
695       DatagramSocket::BindType bind_type,
696       NetLog* net_log,
697       const NetLogSource& source) override;
698   std::unique_ptr<TransportClientSocket> CreateTransportClientSocket(
699       const AddressList& addresses,
700       std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,
701       NetworkQualityEstimator* network_quality_estimator,
702       NetLog* net_log,
703       const NetLogSource& source) override;
704   std::unique_ptr<SSLClientSocket> CreateSSLClientSocket(
705       SSLClientContext* context,
706       std::unique_ptr<StreamSocket> stream_socket,
707       const HostPortPair& host_and_port,
708       const SSLConfig& ssl_config) override;
udp_client_socket_ports()709   const std::vector<uint16_t>& udp_client_socket_ports() const {
710     return udp_client_socket_ports_;
711   }
712 
713  private:
714   SocketDataProviderArray<SocketDataProvider> mock_data_;
715   SocketDataProviderArray<SocketDataProvider> mock_tcp_data_;
716   SocketDataProviderArray<SSLSocketDataProvider> mock_ssl_data_;
717   std::vector<uint16_t> udp_client_socket_ports_;
718 
719   // If true, ReadIfReady() is enabled; otherwise ReadIfReady() returns
720   // ERR_READ_IF_READY_NOT_IMPLEMENTED.
721   bool enable_read_if_ready_ = false;
722 };
723 
724 class MockClientSocket : public TransportClientSocket {
725  public:
726   // The NetLogWithSource is needed to test LoadTimingInfo, which uses NetLog
727   // IDs as
728   // unique socket IDs.
729   explicit MockClientSocket(const NetLogWithSource& net_log);
730 
731   MockClientSocket(const MockClientSocket&) = delete;
732   MockClientSocket& operator=(const MockClientSocket&) = delete;
733 
734   // Socket implementation.
735   int Read(IOBuffer* buf,
736            int buf_len,
737            CompletionOnceCallback callback) override = 0;
738   int Write(IOBuffer* buf,
739             int buf_len,
740             CompletionOnceCallback callback,
741             const NetworkTrafficAnnotationTag& traffic_annotation) override = 0;
742   int SetReceiveBufferSize(int32_t size) override;
743   int SetSendBufferSize(int32_t size) override;
744 
745   // TransportClientSocket implementation.
746   int Bind(const net::IPEndPoint& local_addr) override;
747   bool SetNoDelay(bool no_delay) override;
748   bool SetKeepAlive(bool enable, int delay) override;
749 
750   // StreamSocket implementation.
751   int Connect(CompletionOnceCallback callback) override = 0;
752   void Disconnect() override;
753   bool IsConnected() const override;
754   bool IsConnectedAndIdle() const override;
755   int GetPeerAddress(IPEndPoint* address) const override;
756   int GetLocalAddress(IPEndPoint* address) const override;
757   const NetLogWithSource& NetLog() const override;
758   bool WasAlpnNegotiated() const override;
759   NextProto GetNegotiatedProtocol() const override;
760   int64_t GetTotalReceivedBytes() const override;
ApplySocketTag(const SocketTag & tag)761   void ApplySocketTag(const SocketTag& tag) override {}
762 
763  protected:
764   ~MockClientSocket() override;
765   void RunCallbackAsync(CompletionOnceCallback callback, int result);
766   void RunCallback(CompletionOnceCallback callback, int result);
767 
768   // True if Connect completed successfully and Disconnect hasn't been called.
769   bool connected_ = false;
770 
771   IPEndPoint local_addr_;
772   IPEndPoint peer_addr_;
773 
774   NetLogWithSource net_log_;
775 
776  private:
777   base::WeakPtrFactory<MockClientSocket> weak_factory_{this};
778 };
779 
780 class MockTCPClientSocket : public MockClientSocket, public AsyncSocket {
781  public:
782   MockTCPClientSocket(const AddressList& addresses,
783                       net::NetLog* net_log,
784                       SocketDataProvider* socket);
785 
786   MockTCPClientSocket(const MockTCPClientSocket&) = delete;
787   MockTCPClientSocket& operator=(const MockTCPClientSocket&) = delete;
788 
789   ~MockTCPClientSocket() override;
790 
addresses()791   const AddressList& addresses() const { return addresses_; }
792 
793   // Socket implementation.
794   int Read(IOBuffer* buf,
795            int buf_len,
796            CompletionOnceCallback callback) override;
797   int ReadIfReady(IOBuffer* buf,
798                   int buf_len,
799                   CompletionOnceCallback callback) override;
800   int CancelReadIfReady() override;
801   int Write(IOBuffer* buf,
802             int buf_len,
803             CompletionOnceCallback callback,
804             const NetworkTrafficAnnotationTag& traffic_annotation) override;
805   int SetReceiveBufferSize(int32_t size) override;
806   int SetSendBufferSize(int32_t size) override;
807 
808   // TransportClientSocket implementation.
809   bool SetNoDelay(bool no_delay) override;
810   bool SetKeepAlive(bool enable, int delay) override;
811 
812   // StreamSocket implementation.
813   void SetBeforeConnectCallback(
814       const BeforeConnectCallback& before_connect_callback) override;
815   int Connect(CompletionOnceCallback callback) override;
816   void Disconnect() override;
817   bool IsConnected() const override;
818   bool IsConnectedAndIdle() const override;
819   int GetPeerAddress(IPEndPoint* address) const override;
820   bool WasEverUsed() const override;
821   bool GetSSLInfo(SSLInfo* ssl_info) override;
822 
823   // AsyncSocket:
824   void OnReadComplete(const MockRead& data) override;
825   void OnWriteComplete(int rv) override;
826   void OnConnectComplete(const MockConnect& data) override;
827   void OnDataProviderDestroyed() override;
828 
set_enable_read_if_ready(bool enable_read_if_ready)829   void set_enable_read_if_ready(bool enable_read_if_ready) {
830     enable_read_if_ready_ = enable_read_if_ready;
831   }
832 
833  private:
834   void RetryRead(int rv);
835   int ReadIfReadyImpl(IOBuffer* buf,
836                       int buf_len,
837                       CompletionOnceCallback callback);
838 
839   // Helper method to run |pending_read_if_ready_callback_| if it is not null.
840   void RunReadIfReadyCallback(int result);
841 
842   AddressList addresses_;
843 
844   raw_ptr<SocketDataProvider> data_;
845   int read_offset_ = 0;
846   MockRead read_data_;
847   bool need_read_data_ = true;
848 
849   // True if the peer has closed the connection.  This allows us to simulate
850   // the recv(..., MSG_PEEK) call in the IsConnectedAndIdle method of the real
851   // TCPClientSocket.
852   bool peer_closed_connection_ = false;
853 
854   // While an asynchronous read is pending, we save our user-buffer state.
855   scoped_refptr<IOBuffer> pending_read_buf_ = nullptr;
856   int pending_read_buf_len_ = 0;
857   CompletionOnceCallback pending_read_callback_;
858 
859   // Non-null when a ReadIfReady() is pending.
860   CompletionOnceCallback pending_read_if_ready_callback_;
861 
862   CompletionOnceCallback pending_connect_callback_;
863   CompletionOnceCallback pending_write_callback_;
864   bool was_used_to_convey_data_ = false;
865 
866   // If true, ReadIfReady() is enabled; otherwise ReadIfReady() returns
867   // ERR_READ_IF_READY_NOT_IMPLEMENTED.
868   bool enable_read_if_ready_ = false;
869 
870   BeforeConnectCallback before_connect_callback_;
871 };
872 
873 class MockSSLClientSocket : public AsyncSocket, public SSLClientSocket {
874  public:
875   MockSSLClientSocket(std::unique_ptr<StreamSocket> stream_socket,
876                       const HostPortPair& host_and_port,
877                       const SSLConfig& ssl_config,
878                       SSLSocketDataProvider* socket);
879 
880   MockSSLClientSocket(const MockSSLClientSocket&) = delete;
881   MockSSLClientSocket& operator=(const MockSSLClientSocket&) = delete;
882 
883   ~MockSSLClientSocket() override;
884 
885   // Socket implementation.
886   int Read(IOBuffer* buf,
887            int buf_len,
888            CompletionOnceCallback callback) override;
889   int ReadIfReady(IOBuffer* buf,
890                   int buf_len,
891                   CompletionOnceCallback callback) override;
892   int Write(IOBuffer* buf,
893             int buf_len,
894             CompletionOnceCallback callback,
895             const NetworkTrafficAnnotationTag& traffic_annotation) override;
896   int CancelReadIfReady() override;
897 
898   // StreamSocket implementation.
899   int Connect(CompletionOnceCallback callback) override;
900   void Disconnect() override;
901   int ConfirmHandshake(CompletionOnceCallback callback) override;
902   bool IsConnected() const override;
903   bool IsConnectedAndIdle() const override;
904   bool WasEverUsed() const override;
905   int GetPeerAddress(IPEndPoint* address) const override;
906   int GetLocalAddress(IPEndPoint* address) const override;
907   bool WasAlpnNegotiated() const override;
908   NextProto GetNegotiatedProtocol() const override;
909   absl::optional<base::StringPiece> GetPeerApplicationSettings() const override;
910   bool GetSSLInfo(SSLInfo* ssl_info) override;
911   void GetSSLCertRequestInfo(
912       SSLCertRequestInfo* cert_request_info) const override;
913   void ApplySocketTag(const SocketTag& tag) override;
914   const NetLogWithSource& NetLog() const override;
915   int64_t GetTotalReceivedBytes() const override;
916   int SetReceiveBufferSize(int32_t size) override;
917   int SetSendBufferSize(int32_t size) override;
918 
919   // SSLSocket implementation.
920   int ExportKeyingMaterial(base::StringPiece label,
921                            bool has_context,
922                            base::StringPiece context,
923                            unsigned char* out,
924                            unsigned int outlen) override;
925 
926   // SSLClientSocket implementation.
927   std::vector<uint8_t> GetECHRetryConfigs() override;
928 
929   // This MockSocket does not implement the manual async IO feature.
930   void OnReadComplete(const MockRead& data) override;
931   void OnWriteComplete(int rv) override;
932   void OnConnectComplete(const MockConnect& data) override;
933   // SSL sockets don't need magic to deal with destruction of their data
934   // provider.
935   // TODO(mmenke):  Probably a good idea to support it, anyways.
OnDataProviderDestroyed()936   void OnDataProviderDestroyed() override {}
937 
938  private:
939   static void ConnectCallback(MockSSLClientSocket* ssl_client_socket,
940                               CompletionOnceCallback callback,
941                               int rv);
942 
943   void RunCallbackAsync(CompletionOnceCallback callback, int result);
944   void RunCallback(CompletionOnceCallback callback, int result);
945 
946   void RunConfirmHandshakeCallback(CompletionOnceCallback callback, int result);
947 
948   bool connected_ = false;
949   bool in_confirm_handshake_ = false;
950   NetLogWithSource net_log_;
951   std::unique_ptr<StreamSocket> stream_socket_;
952   raw_ptr<SSLSocketDataProvider> data_;
953   // Address of the "remote" peer we're connected to.
954   IPEndPoint peer_addr_;
955 
956   base::WeakPtrFactory<MockSSLClientSocket> weak_factory_{this};
957 };
958 
959 class MockUDPClientSocket : public DatagramClientSocket, public AsyncSocket {
960  public:
961   explicit MockUDPClientSocket(SocketDataProvider* data = nullptr,
962                                net::NetLog* net_log = nullptr);
963 
964   MockUDPClientSocket(const MockUDPClientSocket&) = delete;
965   MockUDPClientSocket& operator=(const MockUDPClientSocket&) = delete;
966 
967   ~MockUDPClientSocket() override;
968 
969   // Socket implementation.
970   int Read(IOBuffer* buf,
971            int buf_len,
972            CompletionOnceCallback callback) override;
973   int Write(IOBuffer* buf,
974             int buf_len,
975             CompletionOnceCallback callback,
976             const NetworkTrafficAnnotationTag& traffic_annotation) override;
977 
978   int SetReceiveBufferSize(int32_t size) override;
979   int SetSendBufferSize(int32_t size) override;
980   int SetDoNotFragment() override;
981 
982   // DatagramSocket implementation.
983   void Close() override;
984   int GetPeerAddress(IPEndPoint* address) const override;
985   int GetLocalAddress(IPEndPoint* address) const override;
986   void UseNonBlockingIO() override;
987   int SetMulticastInterface(uint32_t interface_index) override;
988   const NetLogWithSource& NetLog() const override;
989 
990   // DatagramClientSocket implementation.
991   int Connect(const IPEndPoint& address) override;
992   int ConnectUsingNetwork(handles::NetworkHandle network,
993                           const IPEndPoint& address) override;
994   int ConnectUsingDefaultNetwork(const IPEndPoint& address) override;
995   int ConnectAsync(const IPEndPoint& address,
996                    CompletionOnceCallback callback) override;
997   int ConnectUsingNetworkAsync(handles::NetworkHandle network,
998                                const IPEndPoint& address,
999                                CompletionOnceCallback callback) override;
1000   int ConnectUsingDefaultNetworkAsync(const IPEndPoint& address,
1001                                       CompletionOnceCallback callback) override;
1002   handles::NetworkHandle GetBoundNetwork() const override;
1003   void ApplySocketTag(const SocketTag& tag) override;
SetMsgConfirm(bool confirm)1004   void SetMsgConfirm(bool confirm) override {}
1005 
1006   // AsyncSocket implementation.
1007   void OnReadComplete(const MockRead& data) override;
1008   void OnWriteComplete(int rv) override;
1009   void OnConnectComplete(const MockConnect& data) override;
1010   void OnDataProviderDestroyed() override;
1011 
set_source_port(uint16_t port)1012   void set_source_port(uint16_t port) { source_port_ = port; }
source_port()1013   uint16_t source_port() const { return source_port_; }
set_source_host(IPAddress addr)1014   void set_source_host(IPAddress addr) { source_host_ = addr; }
source_host()1015   IPAddress source_host() const { return source_host_; }
1016 
1017   // Returns last tag applied to socket.
tag()1018   SocketTag tag() const { return tag_; }
1019 
1020   // Returns false if socket's tag was changed after the socket was used for
1021   // data transfer (e.g. Read/Write() called), otherwise returns true.
tagged_before_data_transferred()1022   bool tagged_before_data_transferred() const {
1023     return tagged_before_data_transferred_;
1024   }
1025 
1026  private:
1027   int CompleteRead();
1028 
1029   void RunCallbackAsync(CompletionOnceCallback callback, int result);
1030   void RunCallback(CompletionOnceCallback callback, int result);
1031 
1032   bool connected_ = false;
1033   raw_ptr<SocketDataProvider> data_;
1034   int read_offset_ = 0;
1035   MockRead read_data_;
1036   bool need_read_data_ = true;
1037   IPAddress source_host_;
1038   uint16_t source_port_ = 123;  // Ephemeral source port.
1039 
1040   // Address of the "remote" peer we're connected to.
1041   IPEndPoint peer_addr_;
1042 
1043   // Network that the socket is bound to.
1044   handles::NetworkHandle network_ = handles::kInvalidNetworkHandle;
1045 
1046   // While an asynchronous IO is pending, we save our user-buffer state.
1047   scoped_refptr<IOBuffer> pending_read_buf_ = nullptr;
1048   int pending_read_buf_len_ = 0;
1049   CompletionOnceCallback pending_read_callback_;
1050   CompletionOnceCallback pending_write_callback_;
1051 
1052   NetLogWithSource net_log_;
1053 
1054   DatagramBuffers unwritten_buffers_;
1055 
1056   SocketTag tag_;
1057   bool data_transferred_ = false;
1058   bool tagged_before_data_transferred_ = true;
1059 
1060   base::WeakPtrFactory<MockUDPClientSocket> weak_factory_{this};
1061 };
1062 
1063 class TestSocketRequest : public TestCompletionCallbackBase {
1064  public:
1065   TestSocketRequest(std::vector<TestSocketRequest*>* request_order,
1066                     size_t* completion_count);
1067 
1068   TestSocketRequest(const TestSocketRequest&) = delete;
1069   TestSocketRequest& operator=(const TestSocketRequest&) = delete;
1070 
1071   ~TestSocketRequest() override;
1072 
handle()1073   ClientSocketHandle* handle() { return &handle_; }
1074 
callback()1075   CompletionOnceCallback callback() {
1076     return base::BindOnce(&TestSocketRequest::OnComplete,
1077                           base::Unretained(this));
1078   }
1079 
1080  private:
1081   void OnComplete(int result);
1082 
1083   ClientSocketHandle handle_;
1084   raw_ptr<std::vector<TestSocketRequest*>> request_order_;
1085   raw_ptr<size_t> completion_count_;
1086 };
1087 
1088 class ClientSocketPoolTest {
1089  public:
1090   enum KeepAlive {
1091     KEEP_ALIVE,
1092 
1093     // A socket will be disconnected in addition to handle being reset.
1094     NO_KEEP_ALIVE,
1095   };
1096 
1097   static const int kIndexOutOfBounds;
1098   static const int kRequestNotFound;
1099 
1100   ClientSocketPoolTest();
1101 
1102   ClientSocketPoolTest(const ClientSocketPoolTest&) = delete;
1103   ClientSocketPoolTest& operator=(const ClientSocketPoolTest&) = delete;
1104 
1105   ~ClientSocketPoolTest();
1106 
1107   template <typename PoolType>
StartRequestUsingPool(PoolType * socket_pool,const ClientSocketPool::GroupId & group_id,RequestPriority priority,ClientSocketPool::RespectLimits respect_limits,const scoped_refptr<typename PoolType::SocketParams> & socket_params)1108   int StartRequestUsingPool(
1109       PoolType* socket_pool,
1110       const ClientSocketPool::GroupId& group_id,
1111       RequestPriority priority,
1112       ClientSocketPool::RespectLimits respect_limits,
1113       const scoped_refptr<typename PoolType::SocketParams>& socket_params) {
1114     DCHECK(socket_pool);
1115     TestSocketRequest* request(
1116         new TestSocketRequest(&request_order_, &completion_count_));
1117     requests_.push_back(base::WrapUnique(request));
1118     int rv = request->handle()->Init(
1119         group_id, socket_params, absl::nullopt /* proxy_annotation_tag */,
1120         priority, SocketTag(), respect_limits, request->callback(),
1121         ClientSocketPool::ProxyAuthCallback(), socket_pool, NetLogWithSource());
1122     if (rv != ERR_IO_PENDING)
1123       request_order_.push_back(request);
1124     return rv;
1125   }
1126 
1127   // Provided there were n requests started, takes |index| in range 1..n
1128   // and returns order in which that request completed, in range 1..n,
1129   // or kIndexOutOfBounds if |index| is out of bounds, or kRequestNotFound
1130   // if that request did not complete (for example was canceled).
1131   int GetOrderOfRequest(size_t index) const;
1132 
1133   // Resets first initialized socket handle from |requests_|. If found such
1134   // a handle, returns true.
1135   bool ReleaseOneConnection(KeepAlive keep_alive);
1136 
1137   // Releases connections until there is nothing to release.
1138   void ReleaseAllConnections(KeepAlive keep_alive);
1139 
1140   // Note that this uses 0-based indices, while GetOrderOfRequest takes and
1141   // returns 1-based indices.
request(int i)1142   TestSocketRequest* request(int i) { return requests_[i].get(); }
1143 
requests_size()1144   size_t requests_size() const { return requests_.size(); }
requests()1145   std::vector<std::unique_ptr<TestSocketRequest>>* requests() {
1146     return &requests_;
1147   }
completion_count()1148   size_t completion_count() const { return completion_count_; }
1149 
1150  private:
1151   std::vector<std::unique_ptr<TestSocketRequest>> requests_;
1152   std::vector<TestSocketRequest*> request_order_;
1153   size_t completion_count_ = 0;
1154 };
1155 
1156 class MockTransportSocketParams
1157     : public base::RefCounted<MockTransportSocketParams> {
1158  public:
1159   MockTransportSocketParams(const MockTransportSocketParams&) = delete;
1160   MockTransportSocketParams& operator=(const MockTransportSocketParams&) =
1161       delete;
1162 
1163  private:
1164   friend class base::RefCounted<MockTransportSocketParams>;
1165   ~MockTransportSocketParams() = default;
1166 };
1167 
1168 class MockTransportClientSocketPool : public TransportClientSocketPool {
1169  public:
1170   class MockConnectJob {
1171    public:
1172     MockConnectJob(std::unique_ptr<StreamSocket> socket,
1173                    ClientSocketHandle* handle,
1174                    const SocketTag& socket_tag,
1175                    CompletionOnceCallback callback,
1176                    RequestPriority priority);
1177 
1178     MockConnectJob(const MockConnectJob&) = delete;
1179     MockConnectJob& operator=(const MockConnectJob&) = delete;
1180 
1181     ~MockConnectJob();
1182 
1183     int Connect();
1184     bool CancelHandle(const ClientSocketHandle* handle);
1185 
handle()1186     ClientSocketHandle* handle() const { return handle_; }
1187 
priority()1188     RequestPriority priority() const { return priority_; }
set_priority(RequestPriority priority)1189     void set_priority(RequestPriority priority) { priority_ = priority; }
1190 
1191    private:
1192     void OnConnect(int rv);
1193 
1194     std::unique_ptr<StreamSocket> socket_;
1195     raw_ptr<ClientSocketHandle> handle_;
1196     const SocketTag socket_tag_;
1197     CompletionOnceCallback user_callback_;
1198     RequestPriority priority_;
1199   };
1200 
1201   MockTransportClientSocketPool(
1202       int max_sockets,
1203       int max_sockets_per_group,
1204       const CommonConnectJobParams* common_connect_job_params);
1205 
1206   MockTransportClientSocketPool(const MockTransportClientSocketPool&) = delete;
1207   MockTransportClientSocketPool& operator=(
1208       const MockTransportClientSocketPool&) = delete;
1209 
1210   ~MockTransportClientSocketPool() override;
1211 
last_request_priority()1212   RequestPriority last_request_priority() const {
1213     return last_request_priority_;
1214   }
1215 
requests()1216   const std::vector<std::unique_ptr<MockConnectJob>>& requests() const {
1217     return job_list_;
1218   }
1219 
release_count()1220   int release_count() const { return release_count_; }
cancel_count()1221   int cancel_count() const { return cancel_count_; }
1222 
1223   // TransportClientSocketPool implementation.
1224   int RequestSocket(
1225       const GroupId& group_id,
1226       scoped_refptr<ClientSocketPool::SocketParams> socket_params,
1227       const absl::optional<NetworkTrafficAnnotationTag>& proxy_annotation_tag,
1228       RequestPriority priority,
1229       const SocketTag& socket_tag,
1230       RespectLimits respect_limits,
1231       ClientSocketHandle* handle,
1232       CompletionOnceCallback callback,
1233       const ProxyAuthCallback& on_auth_callback,
1234       const NetLogWithSource& net_log) override;
1235   void SetPriority(const GroupId& group_id,
1236                    ClientSocketHandle* handle,
1237                    RequestPriority priority) override;
1238   void CancelRequest(const GroupId& group_id,
1239                      ClientSocketHandle* handle,
1240                      bool cancel_connect_job) override;
1241   void ReleaseSocket(const GroupId& group_id,
1242                      std::unique_ptr<StreamSocket> socket,
1243                      int64_t generation) override;
1244 
1245  private:
1246   raw_ptr<ClientSocketFactory> client_socket_factory_;
1247   std::vector<std::unique_ptr<MockConnectJob>> job_list_;
1248   RequestPriority last_request_priority_ = DEFAULT_PRIORITY;
1249   int release_count_ = 0;
1250   int cancel_count_ = 0;
1251 };
1252 
1253 // WrappedStreamSocket is a base class that wraps an existing StreamSocket,
1254 // forwarding the Socket and StreamSocket interfaces to the underlying
1255 // transport.
1256 // This is to provide a common base class for subclasses to override specific
1257 // StreamSocket methods for testing, while still communicating with a 'real'
1258 // StreamSocket.
1259 class WrappedStreamSocket : public TransportClientSocket {
1260  public:
1261   explicit WrappedStreamSocket(std::unique_ptr<StreamSocket> transport);
1262   ~WrappedStreamSocket() override;
1263 
1264   // StreamSocket implementation:
1265   int Bind(const net::IPEndPoint& local_addr) override;
1266   int Connect(CompletionOnceCallback callback) override;
1267   void Disconnect() override;
1268   bool IsConnected() const override;
1269   bool IsConnectedAndIdle() const override;
1270   int GetPeerAddress(IPEndPoint* address) const override;
1271   int GetLocalAddress(IPEndPoint* address) const override;
1272   const NetLogWithSource& NetLog() const override;
1273   bool WasEverUsed() const override;
1274   bool WasAlpnNegotiated() const override;
1275   NextProto GetNegotiatedProtocol() const override;
1276   bool GetSSLInfo(SSLInfo* ssl_info) override;
1277   int64_t GetTotalReceivedBytes() const override;
1278   void ApplySocketTag(const SocketTag& tag) override;
1279 
1280   // Socket implementation:
1281   int Read(IOBuffer* buf,
1282            int buf_len,
1283            CompletionOnceCallback callback) override;
1284   int ReadIfReady(IOBuffer* buf,
1285                   int buf_len,
1286                   CompletionOnceCallback callback) override;
1287   int Write(IOBuffer* buf,
1288             int buf_len,
1289             CompletionOnceCallback callback,
1290             const NetworkTrafficAnnotationTag& traffic_annotation) override;
1291   int SetReceiveBufferSize(int32_t size) override;
1292   int SetSendBufferSize(int32_t size) override;
1293 
1294  protected:
1295   std::unique_ptr<StreamSocket> transport_;
1296 };
1297 
1298 // StreamSocket that wraps another StreamSocket, but keeps track of any
1299 // SocketTag applied to the socket.
1300 class MockTaggingStreamSocket : public WrappedStreamSocket {
1301  public:
MockTaggingStreamSocket(std::unique_ptr<StreamSocket> transport)1302   explicit MockTaggingStreamSocket(std::unique_ptr<StreamSocket> transport)
1303       : WrappedStreamSocket(std::move(transport)) {}
1304 
1305   MockTaggingStreamSocket(const MockTaggingStreamSocket&) = delete;
1306   MockTaggingStreamSocket& operator=(const MockTaggingStreamSocket&) = delete;
1307 
1308   ~MockTaggingStreamSocket() override = default;
1309 
1310   // StreamSocket implementation.
1311   int Connect(CompletionOnceCallback callback) override;
1312   void ApplySocketTag(const SocketTag& tag) override;
1313 
1314   // Returns false if socket's tag was changed after the socket was connected,
1315   // otherwise returns true.
tagged_before_connected()1316   bool tagged_before_connected() const { return tagged_before_connected_; }
1317 
1318   // Returns last tag applied to socket.
tag()1319   SocketTag tag() const { return tag_; }
1320 
1321  private:
1322   bool connected_ = false;
1323   bool tagged_before_connected_ = true;
1324   SocketTag tag_;
1325 };
1326 
1327 // Extend MockClientSocketFactory to return MockTaggingStreamSockets and
1328 // keep track of last socket produced for test inspection.
1329 class MockTaggingClientSocketFactory : public MockClientSocketFactory {
1330  public:
1331   MockTaggingClientSocketFactory() = default;
1332 
1333   MockTaggingClientSocketFactory(const MockTaggingClientSocketFactory&) =
1334       delete;
1335   MockTaggingClientSocketFactory& operator=(
1336       const MockTaggingClientSocketFactory&) = delete;
1337 
1338   // ClientSocketFactory implementation.
1339   std::unique_ptr<DatagramClientSocket> CreateDatagramClientSocket(
1340       DatagramSocket::BindType bind_type,
1341       NetLog* net_log,
1342       const NetLogSource& source) override;
1343   std::unique_ptr<TransportClientSocket> CreateTransportClientSocket(
1344       const AddressList& addresses,
1345       std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,
1346       NetworkQualityEstimator* network_quality_estimator,
1347       NetLog* net_log,
1348       const NetLogSource& source) override;
1349 
1350   // These methods return pointers to last TCP and UDP sockets produced by this
1351   // factory. NOTE: Socket must still exist, or pointer will be to freed memory.
GetLastProducedTCPSocket()1352   MockTaggingStreamSocket* GetLastProducedTCPSocket() const {
1353     return tcp_socket_;
1354   }
GetLastProducedUDPSocket()1355   MockUDPClientSocket* GetLastProducedUDPSocket() const { return udp_socket_; }
1356 
1357  private:
1358   raw_ptr<MockTaggingStreamSocket> tcp_socket_ = nullptr;
1359   raw_ptr<MockUDPClientSocket> udp_socket_ = nullptr;
1360 };
1361 
// Host / port used for SOCKS4 test strings.
extern const char kSOCKS4TestHost[];
extern const int kSOCKS4TestPort;

// Constants for a successful SOCKS v4 handshake (connecting to kSOCKS4TestHost
// on port kSOCKS4TestPort, for the request). Each *Length constant pairs with
// the array of the same base name; presumably it is that array's byte length,
// so confirm against the definitions in the .cc file.
extern const char kSOCKS4OkRequestLocalHostPort80[];
extern const int kSOCKS4OkRequestLocalHostPort80Length;

extern const char kSOCKS4OkReply[];
extern const int kSOCKS4OkReplyLength;

// Host / port used for SOCKS5 test strings.
extern const char kSOCKS5TestHost[];
extern const int kSOCKS5TestPort;

// Constants for a successful SOCKS v5 handshake (connecting to kSOCKS5TestHost
// on port kSOCKS5TestPort, for the request). The *Length constants pair with
// their arrays as described for SOCKS v4 above.
extern const char kSOCKS5GreetRequest[];
extern const int kSOCKS5GreetRequestLength;

extern const char kSOCKS5GreetResponse[];
extern const int kSOCKS5GreetResponseLength;

extern const char kSOCKS5OkRequest[];
extern const int kSOCKS5OkRequestLength;

extern const char kSOCKS5OkResponse[];
extern const int kSOCKS5OkResponseLength;
1391 
1392 // Helper function to get the total data size of the MockReads in |reads|.
1393 int64_t CountReadBytes(base::span<const MockRead> reads);
1394 
1395 // Helper function to get the total data size of the MockWrites in |writes|.
1396 int64_t CountWriteBytes(base::span<const MockWrite> writes);
1397 
1398 #if BUILDFLAG(IS_ANDROID)
1399 // Returns whether the device supports calling GetTaggedBytes().
1400 bool CanGetTaggedBytes();
1401 
1402 // Query the system to find out how many bytes were received with tag
1403 // |expected_tag| for our UID.  Return the count of received bytes.
1404 uint64_t GetTaggedBytes(int32_t expected_tag);
1405 #endif
1406 
1407 }  // namespace net
1408 
1409 #endif  // NET_SOCKET_SOCKET_TEST_UTIL_H_
1410