• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #ifndef NET_SOCKET_SOCKET_TEST_UTIL_H_
6 #define NET_SOCKET_SOCKET_TEST_UTIL_H_
7 
8 #include <stddef.h>
9 #include <stdint.h>
10 
11 #include <cstring>
12 #include <memory>
13 #include <optional>
14 #include <string>
15 #include <string_view>
16 #include <utility>
17 #include <vector>
18 
19 #include "base/check_op.h"
20 #include "base/containers/span.h"
21 #include "base/functional/bind.h"
22 #include "base/functional/callback.h"
23 #include "base/memory/ptr_util.h"
24 #include "base/memory/raw_ptr.h"
25 #include "base/memory/raw_span.h"
26 #include "base/memory/ref_counted.h"
27 #include "base/memory/weak_ptr.h"
28 #include "build/build_config.h"
29 #include "net/base/address_list.h"
30 #include "net/base/completion_once_callback.h"
31 #include "net/base/io_buffer.h"
32 #include "net/base/net_errors.h"
33 #include "net/base/test_completion_callback.h"
34 #include "net/http/http_auth_controller.h"
35 #include "net/log/net_log_with_source.h"
36 #include "net/socket/client_socket_factory.h"
37 #include "net/socket/client_socket_handle.h"
38 #include "net/socket/client_socket_pool.h"
39 #include "net/socket/datagram_client_socket.h"
40 #include "net/socket/socket_performance_watcher.h"
41 #include "net/socket/socket_tag.h"
42 #include "net/socket/ssl_client_socket.h"
43 #include "net/socket/transport_client_socket.h"
44 #include "net/socket/transport_client_socket_pool.h"
45 #include "net/ssl/ssl_config_service.h"
46 #include "net/ssl/ssl_info.h"
47 #include "testing/gtest/include/gtest/gtest.h"
48 
namespace base {
class RunLoop;
}  // namespace base
52 
53 namespace net {
54 
// Forward declarations of types used by this header.
class NetLog;
class X509Certificate;
struct CommonConnectJobParams;
struct NetworkTrafficAnnotationTag;
59 
60 const handles::NetworkHandle kDefaultNetworkForTests = 1;
61 const handles::NetworkHandle kNewNetworkForTests = 2;
62 
// Test-only network error code private to the socket test utilities. A
// MockRead whose |result| is ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ is just
// a marker: it says the peer will close the connection after the following
// MockRead. All other fields of such a marker MockRead are ignored.
enum { ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ = -10000 };
71 
class AsyncSocket;
class MockClientSocket;
class MockSSLClientSocket;
class MockTCPClientSocket;
class SSLClientSocket;
class StreamSocket;

// How a mocked operation completes: ASYNC or SYNCHRONOUS.
enum IoMode {
  ASYNC,
  SYNCHRONOUS,
};
80 
81 // Used to delay MockClientSocket::Connect.
82 // Example usage:
83 // TEST(FooTest, Test) {
84 //   MockClientSocketFactory socket_factory;
85 //
86 //   MockConnectCompleter completer;
87 //   SequencedSocketData data;
88 //   data.set_connect_data(MockConnect(&completer));
89 //   socket_factory.AddSocketDataProvider(&data);
90 //
91 //   // Create a MockClientSocket somehow.
92 //   std::unique_ptr<StreamSocket> stream = CreateStreamSocket();
93 //   std::optional<int> delayed_result;
94 //   int rv = stream->Connect(base::BindLambdaForTesting([&](int result){
95 //      delayed_result = result;
96 //   }));
97 //   // Connect() returns ERR_IO_PENDING.
98 //   EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
99 //
100 //   RunUntilIdle();
101 //   // Connect() is still blocked.
102 //   ASSERT_FALSE(delayed_result.has_value());
103 //
104 //   completer.Complete(OK);
105 //   RunUntilIdle();
106 //   EXPECT_THAT(delayed_result, Optional(IsOk()));
107 // }
108 class MockConnectCompleter {
109  public:
110   MockConnectCompleter();
111 
112   MockConnectCompleter(const MockConnectCompleter&) = delete;
113   MockConnectCompleter& operator=(const MockConnectCompleter&) = delete;
114 
115   ~MockConnectCompleter();
116 
117   // Completes Connect() with `result`.
118   void Complete(int result);
119 
120  private:
121   friend class MockTCPClientSocket;
122   friend class MockSSLClientSocket;
123   friend class MockUDPClientSocket;
124 
125   // Sets a completion callback that is passed to Connect(). Called by
126   // MockClientSocket implementations.
127   void SetCallback(CompletionOnceCallback callback);
128 
129   CompletionOnceCallback callback_;
130 };
131 
132 struct MockConnect {
133   // Asynchronous connection success.
134   // Creates a MockConnect with |mode| ASYC, |result| OK, and
135   // |peer_addr| 192.0.2.33.
136   MockConnect();
137   // Creates a MockConnect with the specified mode and result, with
138   // |peer_addr| 192.0.2.33.
139   MockConnect(IoMode io_mode, int r);
140   MockConnect(IoMode io_mode, int r, IPEndPoint addr);
141   MockConnect(IoMode io_mode, int r, IPEndPoint addr, bool first_attempt_fails);
142   // Creates a MockConnect that delays connection until `completer` invokes
143   // Complete().
144   explicit MockConnect(MockConnectCompleter* completer);
145   ~MockConnect();
146 
147   IoMode mode;
148   int result;
149   IPEndPoint peer_addr;
150   bool first_attempt_fails = false;
151   raw_ptr<MockConnectCompleter> completer;
152 };
153 
154 struct MockConfirm {
155   // Asynchronous confirm success.
156   // Creates a MockConfirm with |mode| ASYC and |result| OK.
157   MockConfirm();
158   // Creates a MockConfirm with the specified mode and result.
159   MockConfirm(IoMode io_mode, int r);
160   ~MockConfirm();
161 
162   IoMode mode;
163   int result;
164 };
165 
166 // MockRead and MockWrite shares the same interface and members, but we'd like
167 // to have distinct types because we don't want to have them used
168 // interchangably. To do this, a struct template is defined, and MockRead and
169 // MockWrite are instantiated by using this template. Template parameter |type|
170 // is not used in the struct definition (it purely exists for creating a new
171 // type).
172 //
173 // |data| in MockRead and MockWrite has different meanings: |data| in MockRead
174 // is the data returned from the socket when MockTCPClientSocket::Read() is
175 // attempted, while |data| in MockWrite is the expected data that should be
176 // given in MockTCPClientSocket::Write().
177 enum MockReadWriteType { MOCK_READ, MOCK_WRITE };
178 
179 template <MockReadWriteType type>
180 struct MockReadWrite {
181   // Flag to indicate that the message loop should be terminated.
182   enum { STOPLOOP = 1 << 31 };
183 
184   // Default
MockReadWriteMockReadWrite185   MockReadWrite()
186       : mode(SYNCHRONOUS),
187         result(0),
188         data(nullptr),
189         data_len(0),
190         sequence_number(0),
191         tos(0) {}
192 
193   // Read/write failure (no data).
MockReadWriteMockReadWrite194   MockReadWrite(IoMode io_mode, int result)
195       : mode(io_mode),
196         result(result),
197         data(nullptr),
198         data_len(0),
199         sequence_number(0),
200         tos(0) {}
201 
202   // Read/write failure (no data), with sequence information.
MockReadWriteMockReadWrite203   MockReadWrite(IoMode io_mode, int result, int seq)
204       : mode(io_mode),
205         result(result),
206         data(nullptr),
207         data_len(0),
208         sequence_number(seq),
209         tos(0) {}
210 
211   // Asynchronous read/write success (inferred data length).
MockReadWriteMockReadWrite212   explicit MockReadWrite(const char* data)
213       : mode(ASYNC),
214         result(0),
215         data(data),
216         data_len(strlen(data)),
217         sequence_number(0),
218         tos(0) {}
219 
220   // Read/write success (inferred data length).
MockReadWriteMockReadWrite221   MockReadWrite(IoMode io_mode, const char* data)
222       : mode(io_mode),
223         result(0),
224         data(data),
225         data_len(strlen(data)),
226         sequence_number(0),
227         tos(0) {}
228 
229   // Read/write success.
MockReadWriteMockReadWrite230   MockReadWrite(IoMode io_mode, const char* data, int data_len)
231       : mode(io_mode),
232         result(0),
233         data(data),
234         data_len(data_len),
235         sequence_number(0),
236         tos(0) {}
237 
238   // Read/write success (inferred data length) with sequence information.
MockReadWriteMockReadWrite239   MockReadWrite(IoMode io_mode, int seq, const char* data)
240       : mode(io_mode),
241         result(0),
242         data(data),
243         data_len(strlen(data)),
244         sequence_number(seq),
245         tos(0) {}
246 
247   // Read/write success with sequence information.
MockReadWriteMockReadWrite248   MockReadWrite(IoMode io_mode, const char* data, int data_len, int seq)
249       : mode(io_mode),
250         result(0),
251         data(data),
252         data_len(data_len),
253         sequence_number(seq),
254         tos(0) {}
255 
256   // Read/write success with sequence and TOS information.
MockReadWriteMockReadWrite257   MockReadWrite(IoMode io_mode,
258                 const char* data,
259                 int data_len,
260                 int seq,
261                 uint8_t tos_byte)
262       : mode(io_mode),
263         result(0),
264         data(data),
265         data_len(data_len),
266         sequence_number(seq),
267         tos(tos_byte) {}
268 
269   IoMode mode;
270   int result;
271   const char* data;
272   int data_len;
273 
274   // For data providers that only allows reads to occur in a particular
275   // sequence.  If a read occurs before the given |sequence_number| is reached,
276   // an ERR_IO_PENDING is returned.
277   int sequence_number;  // The sequence number at which a read is allowed
278                         // to occur.
279 
280   // The TOS byte of the datagram, for datagram sockets only.
281   uint8_t tos;
282 };
283 
284 typedef MockReadWrite<MOCK_READ> MockRead;
285 typedef MockReadWrite<MOCK_WRITE> MockWrite;
286 
287 struct MockWriteResult {
MockWriteResultMockWriteResult288   MockWriteResult(IoMode io_mode, int result) : mode(io_mode), result(result) {}
289 
290   IoMode mode;
291   int result;
292 };
293 
// Formats write payloads for diagnostics. Instances are handed to
// StaticSocketDataHelper::VerifyWriteData() and ExpectAll*DataConsumed().
class SocketDataPrinter {
 public:
  // Virtual because this is a polymorphic interface. Deleting an
  // implementation through a SocketDataPrinter* with a non-virtual destructor
  // is undefined behavior.
  virtual ~SocketDataPrinter() = default;

  // Prints the write in |data| using some sort of protocol-specific
  // format.
  virtual std::string PrintWrite(const std::string& data) = 0;
};
302 
303 // The SocketDataProvider is an interface used by the MockClientSocket
304 // for getting data about individual reads and writes on the socket.  Can be
305 // used with at most one socket at a time.
306 // TODO(mmenke):  Do these really need to be re-useable?
307 class SocketDataProvider {
308  public:
309   SocketDataProvider();
310 
311   SocketDataProvider(const SocketDataProvider&) = delete;
312   SocketDataProvider& operator=(const SocketDataProvider&) = delete;
313 
314   virtual ~SocketDataProvider();
315 
316   // Returns the buffer and result code for the next simulated read.
317   // If the |MockRead.result| is ERR_IO_PENDING, it informs the caller
318   // that it will be called via the AsyncSocket::OnReadComplete()
319   // function at a later time.
320   virtual MockRead OnRead() = 0;
321   virtual MockWriteResult OnWrite(const std::string& data) = 0;
322   virtual bool AllReadDataConsumed() const = 0;
323   virtual bool AllWriteDataConsumed() const = 0;
CancelPendingRead()324   virtual void CancelPendingRead() {}
325 
326   // Returns the last set receive buffer size, or -1 if never set.
receive_buffer_size()327   int receive_buffer_size() const { return receive_buffer_size_; }
set_receive_buffer_size(int receive_buffer_size)328   void set_receive_buffer_size(int receive_buffer_size) {
329     receive_buffer_size_ = receive_buffer_size;
330   }
331 
332   // Returns the last set send buffer size, or -1 if never set.
send_buffer_size()333   int send_buffer_size() const { return send_buffer_size_; }
set_send_buffer_size(int send_buffer_size)334   void set_send_buffer_size(int send_buffer_size) {
335     send_buffer_size_ = send_buffer_size;
336   }
337 
338   // Returns the last set value of TCP no delay, or false if never set.
no_delay()339   bool no_delay() const { return no_delay_; }
set_no_delay(bool no_delay)340   void set_no_delay(bool no_delay) { no_delay_ = no_delay; }
341 
342   // Returns whether TCP keepalives were enabled or not. Returns kDefault by
343   // default.
344   enum class KeepAliveState { kEnabled, kDisabled, kDefault };
keep_alive_state()345   KeepAliveState keep_alive_state() const { return keep_alive_state_; }
346   // Last set TCP keepalive delay.
keep_alive_delay()347   int keep_alive_delay() const { return keep_alive_delay_; }
set_keep_alive(bool enable,int delay)348   void set_keep_alive(bool enable, int delay) {
349     keep_alive_state_ =
350         enable ? KeepAliveState::kEnabled : KeepAliveState::kDisabled;
351     keep_alive_delay_ = delay;
352   }
353 
354   // Setters / getters for the return values of the corresponding Set*()
355   // methods. By default, they all succeed, if the socket is connected.
356 
set_set_receive_buffer_size_result(int receive_buffer_size_result)357   void set_set_receive_buffer_size_result(int receive_buffer_size_result) {
358     set_receive_buffer_size_result_ = receive_buffer_size_result;
359   }
set_receive_buffer_size_result()360   int set_receive_buffer_size_result() const {
361     return set_receive_buffer_size_result_;
362   }
363 
set_set_send_buffer_size_result(int set_send_buffer_size_result)364   void set_set_send_buffer_size_result(int set_send_buffer_size_result) {
365     set_send_buffer_size_result_ = set_send_buffer_size_result;
366   }
set_send_buffer_size_result()367   int set_send_buffer_size_result() const {
368     return set_send_buffer_size_result_;
369   }
370 
set_set_no_delay_result(bool set_no_delay_result)371   void set_set_no_delay_result(bool set_no_delay_result) {
372     set_no_delay_result_ = set_no_delay_result;
373   }
set_no_delay_result()374   bool set_no_delay_result() const { return set_no_delay_result_; }
375 
set_set_keep_alive_result(bool set_keep_alive_result)376   void set_set_keep_alive_result(bool set_keep_alive_result) {
377     set_keep_alive_result_ = set_keep_alive_result;
378   }
set_keep_alive_result()379   bool set_keep_alive_result() const { return set_keep_alive_result_; }
380 
expected_addresses()381   const std::optional<AddressList>& expected_addresses() const {
382     return expected_addresses_;
383   }
set_expected_addresses(net::AddressList addresses)384   void set_expected_addresses(net::AddressList addresses) {
385     expected_addresses_ = std::move(addresses);
386   }
387 
388   // Returns true if the request should be considered idle, for the purposes of
389   // IsConnectedAndIdle.
390   virtual bool IsIdle() const;
391 
392   // Initializes the SocketDataProvider for use with |socket|.  Must be called
393   // before use
394   void Initialize(AsyncSocket* socket);
395   // Detaches the socket associated with a SocketDataProvider.  Must be called
396   // before |socket_| is destroyed, unless the SocketDataProvider has informed
397   // |socket_| it was destroyed.  Must also be called before Initialize() may
398   // be called again with a new socket.
399   void DetachSocket();
400 
401   // Accessor for the socket which is using the SocketDataProvider.
socket()402   AsyncSocket* socket() { return socket_; }
403 
connect_data()404   MockConnect connect_data() const { return connect_; }
set_connect_data(const MockConnect & connect)405   void set_connect_data(const MockConnect& connect) { connect_ = connect; }
406 
407  private:
408   // Called to inform subclasses of initialization.
409   virtual void Reset() = 0;
410 
411   MockConnect connect_;
412   raw_ptr<AsyncSocket> socket_ = nullptr;
413 
414   int receive_buffer_size_ = -1;
415   int send_buffer_size_ = -1;
416   // This reflects the default state of TCPClientSockets.
417   bool no_delay_ = true;
418 
419   KeepAliveState keep_alive_state_ = KeepAliveState::kDefault;
420   int keep_alive_delay_ = 0;
421 
422   int set_receive_buffer_size_result_ = net::OK;
423   int set_send_buffer_size_result_ = net::OK;
424   bool set_no_delay_result_ = true;
425   bool set_keep_alive_result_ = true;
426   std::optional<AddressList> expected_addresses_;
427 };
428 
429 // The AsyncSocket is an interface used by the SocketDataProvider to
430 // complete the asynchronous read operation.
431 class AsyncSocket {
432  public:
433   // If an async IO is pending because the SocketDataProvider returned
434   // ERR_IO_PENDING, then the AsyncSocket waits until this OnReadComplete
435   // is called to complete the asynchronous read operation.
436   // data.async is ignored, and this read is completed synchronously as
437   // part of this call.
438   // TODO(rch): this should take a std::string_view since most of the fields
439   // are ignored.
440   virtual void OnReadComplete(const MockRead& data) = 0;
441   // If an async IO is pending because the SocketDataProvider returned
442   // ERR_IO_PENDING, then the AsyncSocket waits until this OnReadComplete
443   // is called to complete the asynchronous read operation.
444   virtual void OnWriteComplete(int rv) = 0;
445   virtual void OnConnectComplete(const MockConnect& data) = 0;
446 
447   // Called when the SocketDataProvider associated with the socket is destroyed.
448   // The socket may continue to be used after the data provider is destroyed,
449   // so it should be sure not to dereference the provider after this is called.
450   virtual void OnDataProviderDestroyed() = 0;
451 };
452 
453 // StaticSocketDataHelper manages a list of reads and writes.
454 class StaticSocketDataHelper {
455  public:
456   StaticSocketDataHelper(base::span<const MockRead> reads,
457                          base::span<const MockWrite> writes);
458 
459   StaticSocketDataHelper(const StaticSocketDataHelper&) = delete;
460   StaticSocketDataHelper& operator=(const StaticSocketDataHelper&) = delete;
461 
462   ~StaticSocketDataHelper();
463 
464   // These functions get access to the next available read and write data. They
465   // CHECK fail if there is no data available.
466   const MockRead& PeekRead() const;
467   const MockWrite& PeekWrite() const;
468 
469   // Returns the current read or write, and then advances to the next one.
470   const MockRead& AdvanceRead();
471   const MockWrite& AdvanceWrite();
472 
473   // Resets the read and write indexes to 0.
474   void Reset();
475 
476   // Returns true if |data| is valid data for the next write. In order
477   // to support short writes, the next write may be longer than |data|
478   // in which case this method will still return true.
479   bool VerifyWriteData(const std::string& data, SocketDataPrinter* printer);
480 
read_index()481   size_t read_index() const { return read_index_; }
write_index()482   size_t write_index() const { return write_index_; }
read_count()483   size_t read_count() const { return reads_.size(); }
write_count()484   size_t write_count() const { return writes_.size(); }
485 
AllReadDataConsumed()486   bool AllReadDataConsumed() const { return read_index() >= read_count(); }
AllWriteDataConsumed()487   bool AllWriteDataConsumed() const { return write_index() >= write_count(); }
488 
489   void ExpectAllReadDataConsumed(SocketDataPrinter* printer) const;
490   void ExpectAllWriteDataConsumed(SocketDataPrinter* printer) const;
491 
492  private:
493   // Returns the next available read or write that is not a pause event. CHECK
494   // fails if no data is available.
495   const MockWrite& PeekRealWrite() const;
496 
497   const base::raw_span<const MockRead, DanglingUntriaged> reads_;
498   size_t read_index_ = 0;
499   const base::raw_span<const MockWrite, DanglingUntriaged> writes_;
500   size_t write_index_ = 0;
501 };
502 
503 // SocketDataProvider which responds based on static tables of mock reads and
504 // writes.
505 class StaticSocketDataProvider : public SocketDataProvider {
506  public:
507   StaticSocketDataProvider();
508   StaticSocketDataProvider(base::span<const MockRead> reads,
509                            base::span<const MockWrite> writes);
510 
511   StaticSocketDataProvider(const StaticSocketDataProvider&) = delete;
512   StaticSocketDataProvider& operator=(const StaticSocketDataProvider&) = delete;
513 
514   ~StaticSocketDataProvider() override;
515 
516   // Pause/resume reads from this provider.
517   void Pause();
518   void Resume();
519 
520   // From SocketDataProvider:
521   MockRead OnRead() override;
522   MockWriteResult OnWrite(const std::string& data) override;
523   bool AllReadDataConsumed() const override;
524   bool AllWriteDataConsumed() const override;
525 
read_index()526   size_t read_index() const { return helper_.read_index(); }
write_index()527   size_t write_index() const { return helper_.write_index(); }
read_count()528   size_t read_count() const { return helper_.read_count(); }
write_count()529   size_t write_count() const { return helper_.write_count(); }
530 
set_printer(SocketDataPrinter * printer)531   void set_printer(SocketDataPrinter* printer) { printer_ = printer; }
532 
533  private:
534   // From SocketDataProvider:
535   void Reset() override;
536 
537   StaticSocketDataHelper helper_;
538   raw_ptr<SocketDataPrinter> printer_ = nullptr;
539   bool paused_ = false;
540 };
541 
542 // SSLSocketDataProviders only need to keep track of the return code from calls
543 // to Connect().
544 struct SSLSocketDataProvider {
545   SSLSocketDataProvider(IoMode mode, int result);
546   explicit SSLSocketDataProvider(MockConnectCompleter* completer);
547   SSLSocketDataProvider(const SSLSocketDataProvider& other);
548   ~SSLSocketDataProvider();
549 
550   // Returns whether MockConnect data has been consumed.
ConnectDataConsumedSSLSocketDataProvider551   bool ConnectDataConsumed() const { return is_connect_data_consumed; }
552 
553   // Returns whether MockConfirm data has been consumed.
ConfirmDataConsumedSSLSocketDataProvider554   bool ConfirmDataConsumed() const { return is_confirm_data_consumed; }
555 
556   // Returns whether a Write occurred before ConfirmHandshake completed.
WriteBeforeConfirmSSLSocketDataProvider557   bool WriteBeforeConfirm() const { return write_called_before_confirm; }
558 
559   // Result for Connect().
560   MockConnect connect;
561   // Callback to run when Connect() is called. This is called at most once per
562   // socket but is repeating because SSLSocketDataProvider is copyable.
563   base::RepeatingClosure connect_callback;
564 
565   // Result for ConfirmHandshake().
566   MockConfirm confirm;
567   // Callback to run when ConfirmHandshake() is called. This is called at most
568   // once per socket but is repeating because SSLSocketDataProvider is
569   // copyable.
570   base::RepeatingClosure confirm_callback;
571 
572   // Result for GetNegotiatedProtocol().
573   NextProto next_proto = kProtoUnknown;
574 
575   // Result for GetPeerApplicationSettings().
576   std::optional<std::string> peer_application_settings;
577 
578   // Result for GetSSLInfo().
579   SSLInfo ssl_info;
580 
581   // Result for GetSSLCertRequestInfo().
582   scoped_refptr<SSLCertRequestInfo> cert_request_info;
583 
584   // Result for GetECHRetryConfigs().
585   std::vector<uint8_t> ech_retry_configs;
586 
587   std::optional<NextProtoVector> next_protos_expected_in_ssl_config;
588   std::optional<SSLConfig::ApplicationSettings> expected_application_settings;
589 
590   uint16_t expected_ssl_version_min;
591   uint16_t expected_ssl_version_max;
592   std::optional<bool> expected_early_data_enabled;
593   std::optional<bool> expected_send_client_cert;
594   scoped_refptr<X509Certificate> expected_client_cert;
595   std::optional<HostPortPair> expected_host_and_port;
596   std::optional<bool> expected_ignore_certificate_errors;
597   std::optional<NetworkAnonymizationKey> expected_network_anonymization_key;
598   std::optional<std::vector<uint8_t>> expected_ech_config_list;
599 
600   bool is_connect_data_consumed = false;
601   bool is_confirm_data_consumed = false;
602   bool write_called_before_confirm = false;
603 };
604 
605 // Uses the sequence_number field in the mock reads and writes to
606 // complete the operations in a specified order.
607 class SequencedSocketData : public SocketDataProvider {
608  public:
609   SequencedSocketData();
610 
611   // |reads| is the list of MockRead completions.
612   // |writes| is the list of MockWrite completions.
613   SequencedSocketData(base::span<const MockRead> reads,
614                       base::span<const MockWrite> writes);
615 
616   // |connect| is the result for the connect phase.
617   // |reads| is the list of MockRead completions.
618   // |writes| is the list of MockWrite completions.
619   SequencedSocketData(const MockConnect& connect,
620                       base::span<const MockRead> reads,
621                       base::span<const MockWrite> writes);
622 
623   SequencedSocketData(const SequencedSocketData&) = delete;
624   SequencedSocketData& operator=(const SequencedSocketData&) = delete;
625 
626   ~SequencedSocketData() override;
627 
628   // From SocketDataProvider:
629   MockRead OnRead() override;
630   MockWriteResult OnWrite(const std::string& data) override;
631   bool AllReadDataConsumed() const override;
632   bool AllWriteDataConsumed() const override;
633   bool IsIdle() const override;
634   void CancelPendingRead() override;
635 
636   // EXPECTs that all data has been consumed, printing any un-consumed data.
637   void ExpectAllReadDataConsumed() const;
638   void ExpectAllWriteDataConsumed() const;
639 
640   // An ASYNC read event with a return value of ERR_IO_PENDING will cause the
641   // socket data to pause at that event, and advance no further, until Resume is
642   // invoked.  At that point, the socket will continue at the next event in the
643   // sequence.
644   //
645   // If a request just wants to simulate a connection that stays open and never
646   // receives any more data, instead of pausing and then resuming a request, it
647   // should use a SYNCHRONOUS event with a return value of ERR_IO_PENDING
648   // instead.
649   bool IsPaused() const;
650   // Resumes events once |this| is in the paused state.  The next event will
651   // occur synchronously with the call if it can.
652   void Resume();
653   void RunUntilPaused();
654 
655   // When true, IsConnectedAndIdle() will return false if the next event in the
656   // sequence is a synchronous.  Otherwise, the socket claims to be idle as
657   // long as it's connected.  Defaults to false.
658   // TODO(mmenke):  See if this can be made the default behavior, and consider
659   // removing this mehtod.  Need to make sure it doesn't change what code any
660   // tests are targetted at testing.
set_busy_before_sync_reads(bool busy_before_sync_reads)661   void set_busy_before_sync_reads(bool busy_before_sync_reads) {
662     busy_before_sync_reads_ = busy_before_sync_reads;
663   }
664 
set_printer(SocketDataPrinter * printer)665   void set_printer(SocketDataPrinter* printer) { printer_ = printer; }
666 
667  private:
668   // Defines the state for the read or write path.
669   enum class IoState {
670     kIdle,        // No async operation is in progress.
671     kPending,     // An async operation in waiting for another operation to
672                   // complete.
673     kCompleting,  // A task has been posted to complete an async operation.
674     kPaused,      // IO is paused until Resume() is called.
675   };
676 
677   // From SocketDataProvider:
678   void Reset() override;
679 
680   void OnReadComplete();
681   void OnWriteComplete();
682 
683   void MaybePostReadCompleteTask();
684   void MaybePostWriteCompleteTask();
685 
686   StaticSocketDataHelper helper_;
687   raw_ptr<SocketDataPrinter> printer_ = nullptr;
688   int sequence_number_ = 0;
689   IoState read_state_ = IoState::kIdle;
690   IoState write_state_ = IoState::kIdle;
691 
692   bool busy_before_sync_reads_ = false;
693 
694   // Used by RunUntilPaused.  NULL at all other times.
695   std::unique_ptr<base::RunLoop> run_until_paused_run_loop_;
696 
697   base::WeakPtrFactory<SequencedSocketData> weak_factory_{this};
698 };
699 
// Holds an array of SocketDataProvider elements. As Mock{TCP,SSL}StreamSocket
// objects are instantiated, they take their data from the i'th element of
// this array.
template <typename T>
class SocketDataProviderArray {
 public:
  SocketDataProviderArray() = default;

  // Returns the next provider and advances. CHECK-fails when the array is
  // exhausted. A DCHECK alone is not enough: with DCHECKs compiled out, the
  // unchecked index would read past the end of |data_providers_|.
  T* GetNext() {
    CHECK_LT(next_index_, data_providers_.size());
    return data_providers_[next_index_++];
  }

  // Like GetNext(), but returns nullptr when the end of the array is reached,
  // instead of CHECK-failing. GetNext() should generally be preferred, unless
  // having no remaining elements is expected in some cases and is handled
  // safely.
  T* GetNextWithoutAsserting() {
    if (next_index_ >= data_providers_.size()) {
      return nullptr;
    }
    return data_providers_[next_index_++];
  }

  // Appends |data_provider|, which must be non-null.
  void Add(T* data_provider) {
    DCHECK(data_provider);
    data_providers_.push_back(data_provider);
  }

  size_t next_index() const { return next_index_; }

  void ResetNextIndex() { next_index_ = 0; }

 private:
  // Index of the next |data_providers_| element to use. Not an iterator
  // because those are invalidated on vector reallocation.
  size_t next_index_ = 0;

  // SocketDataProviders to be returned.
  std::vector<T*> data_providers_;
};
740 
// Forward declarations of the concrete mock socket classes.
class MockSSLClientSocket;
class MockTCPClientSocket;
class MockUDPClientSocket;
744 
745 // ClientSocketFactory which contains arrays of sockets of each type.
746 // You should first fill the arrays using Add{SSL,}SocketDataProvider(). When
747 // the factory is asked to create a socket, it takes next entry from appropriate
748 // array. You can use ResetNextMockIndexes to reset that next entry index for
749 // all mock socket types.
750 class MockClientSocketFactory : public ClientSocketFactory {
751  public:
752   MockClientSocketFactory();
753 
754   MockClientSocketFactory(const MockClientSocketFactory&) = delete;
755   MockClientSocketFactory& operator=(const MockClientSocketFactory&) = delete;
756 
757   ~MockClientSocketFactory() override;
758 
759   // Adds a SocketDataProvider that can be used to served either TCP or UDP
760   // connection requests. Sockets are returned in FIFO order.
761   void AddSocketDataProvider(SocketDataProvider* socket);
762 
763   // Like AddSocketDataProvider(), except sockets will only be used to service
764   // TCP connection requests. Sockets added with this method are used first,
765   // before sockets added with AddSocketDataProvider(). Particularly useful for
766   // QUIC tests with multiple sockets, where TCP connections may or may not be
767   // made, and have no guaranteed order, relative to UDP connections.
768   void AddTcpSocketDataProvider(SocketDataProvider* socket);
769 
770   void AddSSLSocketDataProvider(SSLSocketDataProvider* socket);
771   void ResetNextMockIndexes();
772 
mock_data()773   SocketDataProviderArray<SocketDataProvider>& mock_data() {
774     return mock_data_;
775   }
776 
set_enable_read_if_ready(bool enable_read_if_ready)777   void set_enable_read_if_ready(bool enable_read_if_ready) {
778     enable_read_if_ready_ = enable_read_if_ready;
779   }
780 
781   // ClientSocketFactory
782   std::unique_ptr<DatagramClientSocket> CreateDatagramClientSocket(
783       DatagramSocket::BindType bind_type,
784       NetLog* net_log,
785       const NetLogSource& source) override;
786   std::unique_ptr<TransportClientSocket> CreateTransportClientSocket(
787       const AddressList& addresses,
788       std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,
789       NetworkQualityEstimator* network_quality_estimator,
790       NetLog* net_log,
791       const NetLogSource& source) override;
792   std::unique_ptr<SSLClientSocket> CreateSSLClientSocket(
793       SSLClientContext* context,
794       std::unique_ptr<StreamSocket> stream_socket,
795       const HostPortPair& host_and_port,
796       const SSLConfig& ssl_config) override;
udp_client_socket_ports()797   const std::vector<uint16_t>& udp_client_socket_ports() const {
798     return udp_client_socket_ports_;
799   }
800 
801  private:
802   SocketDataProviderArray<SocketDataProvider> mock_data_;
803   SocketDataProviderArray<SocketDataProvider> mock_tcp_data_;
804   SocketDataProviderArray<SSLSocketDataProvider> mock_ssl_data_;
805   std::vector<uint16_t> udp_client_socket_ports_;
806 
807   // If true, ReadIfReady() is enabled; otherwise ReadIfReady() returns
808   // ERR_READ_IF_READY_NOT_IMPLEMENTED.
809   bool enable_read_if_ready_ = false;
810 };
811 
812 class MockClientSocket : public TransportClientSocket {
813  public:
814   // The NetLogWithSource is needed to test LoadTimingInfo, which uses NetLog
815   // IDs as
816   // unique socket IDs.
817   explicit MockClientSocket(const NetLogWithSource& net_log);
818 
819   MockClientSocket(const MockClientSocket&) = delete;
820   MockClientSocket& operator=(const MockClientSocket&) = delete;
821 
822   // Socket implementation.
823   int Read(IOBuffer* buf,
824            int buf_len,
825            CompletionOnceCallback callback) override = 0;
826   int Write(IOBuffer* buf,
827             int buf_len,
828             CompletionOnceCallback callback,
829             const NetworkTrafficAnnotationTag& traffic_annotation) override = 0;
830   int SetReceiveBufferSize(int32_t size) override;
831   int SetSendBufferSize(int32_t size) override;
832 
833   // TransportClientSocket implementation.
834   int Bind(const net::IPEndPoint& local_addr) override;
835   bool SetNoDelay(bool no_delay) override;
836   bool SetKeepAlive(bool enable, int delay) override;
837 
838   // StreamSocket implementation.
839   int Connect(CompletionOnceCallback callback) override = 0;
840   void Disconnect() override;
841   bool IsConnected() const override;
842   bool IsConnectedAndIdle() const override;
843   int GetPeerAddress(IPEndPoint* address) const override;
844   int GetLocalAddress(IPEndPoint* address) const override;
845   const NetLogWithSource& NetLog() const override;
846   NextProto GetNegotiatedProtocol() const override;
847   int64_t GetTotalReceivedBytes() const override;
ApplySocketTag(const SocketTag & tag)848   void ApplySocketTag(const SocketTag& tag) override {}
849 
850  protected:
851   ~MockClientSocket() override;
852   void RunCallbackAsync(CompletionOnceCallback callback, int result);
853   void RunCallback(CompletionOnceCallback callback, int result);
854 
855   // True if Connect completed successfully and Disconnect hasn't been called.
856   bool connected_ = false;
857 
858   IPEndPoint local_addr_;
859   IPEndPoint peer_addr_;
860 
861   NetLogWithSource net_log_;
862 
863  private:
864   base::WeakPtrFactory<MockClientSocket> weak_factory_{this};
865 };
866 
867 class MockTCPClientSocket : public MockClientSocket, public AsyncSocket {
868  public:
869   MockTCPClientSocket(const AddressList& addresses,
870                       net::NetLog* net_log,
871                       SocketDataProvider* socket);
872 
873   MockTCPClientSocket(const MockTCPClientSocket&) = delete;
874   MockTCPClientSocket& operator=(const MockTCPClientSocket&) = delete;
875 
876   ~MockTCPClientSocket() override;
877 
addresses()878   const AddressList& addresses() const { return addresses_; }
879 
880   // Socket implementation.
881   int Read(IOBuffer* buf,
882            int buf_len,
883            CompletionOnceCallback callback) override;
884   int ReadIfReady(IOBuffer* buf,
885                   int buf_len,
886                   CompletionOnceCallback callback) override;
887   int CancelReadIfReady() override;
888   int Write(IOBuffer* buf,
889             int buf_len,
890             CompletionOnceCallback callback,
891             const NetworkTrafficAnnotationTag& traffic_annotation) override;
892   int SetReceiveBufferSize(int32_t size) override;
893   int SetSendBufferSize(int32_t size) override;
894 
895   // TransportClientSocket implementation.
896   bool SetNoDelay(bool no_delay) override;
897   bool SetKeepAlive(bool enable, int delay) override;
898 
899   // StreamSocket implementation.
900   void SetBeforeConnectCallback(
901       const BeforeConnectCallback& before_connect_callback) override;
902   int Connect(CompletionOnceCallback callback) override;
903   void Disconnect() override;
904   bool IsConnected() const override;
905   bool IsConnectedAndIdle() const override;
906   int GetPeerAddress(IPEndPoint* address) const override;
907   bool WasEverUsed() const override;
908   bool GetSSLInfo(SSLInfo* ssl_info) override;
909 
910   // AsyncSocket:
911   void OnReadComplete(const MockRead& data) override;
912   void OnWriteComplete(int rv) override;
913   void OnConnectComplete(const MockConnect& data) override;
914   void OnDataProviderDestroyed() override;
915 
set_enable_read_if_ready(bool enable_read_if_ready)916   void set_enable_read_if_ready(bool enable_read_if_ready) {
917     enable_read_if_ready_ = enable_read_if_ready;
918   }
919 
920  private:
921   void RetryRead(int rv);
922   int ReadIfReadyImpl(IOBuffer* buf,
923                       int buf_len,
924                       CompletionOnceCallback callback);
925 
926   // Helper method to run |pending_read_if_ready_callback_| if it is not null.
927   void RunReadIfReadyCallback(int result);
928 
929   AddressList addresses_;
930 
931   raw_ptr<SocketDataProvider> data_;
932   int read_offset_ = 0;
933   MockRead read_data_;
934   bool need_read_data_ = true;
935 
936   // True if the peer has closed the connection.  This allows us to simulate
937   // the recv(..., MSG_PEEK) call in the IsConnectedAndIdle method of the real
938   // TCPClientSocket.
939   bool peer_closed_connection_ = false;
940 
941   // While an asynchronous read is pending, we save our user-buffer state.
942   scoped_refptr<IOBuffer> pending_read_buf_ = nullptr;
943   int pending_read_buf_len_ = 0;
944   CompletionOnceCallback pending_read_callback_;
945 
946   // Non-null when a ReadIfReady() is pending.
947   CompletionOnceCallback pending_read_if_ready_callback_;
948 
949   CompletionOnceCallback pending_connect_callback_;
950   CompletionOnceCallback pending_write_callback_;
951   bool was_used_to_convey_data_ = false;
952 
953   // If true, ReadIfReady() is enabled; otherwise ReadIfReady() returns
954   // ERR_READ_IF_READY_NOT_IMPLEMENTED.
955   bool enable_read_if_ready_ = false;
956 
957   BeforeConnectCallback before_connect_callback_;
958 };
959 
960 class MockSSLClientSocket : public AsyncSocket, public SSLClientSocket {
961  public:
962   MockSSLClientSocket(std::unique_ptr<StreamSocket> stream_socket,
963                       const HostPortPair& host_and_port,
964                       const SSLConfig& ssl_config,
965                       SSLSocketDataProvider* socket);
966 
967   MockSSLClientSocket(const MockSSLClientSocket&) = delete;
968   MockSSLClientSocket& operator=(const MockSSLClientSocket&) = delete;
969 
970   ~MockSSLClientSocket() override;
971 
972   // Socket implementation.
973   int Read(IOBuffer* buf,
974            int buf_len,
975            CompletionOnceCallback callback) override;
976   int ReadIfReady(IOBuffer* buf,
977                   int buf_len,
978                   CompletionOnceCallback callback) override;
979   int Write(IOBuffer* buf,
980             int buf_len,
981             CompletionOnceCallback callback,
982             const NetworkTrafficAnnotationTag& traffic_annotation) override;
983   int CancelReadIfReady() override;
984 
985   // StreamSocket implementation.
986   int Connect(CompletionOnceCallback callback) override;
987   void Disconnect() override;
988   int ConfirmHandshake(CompletionOnceCallback callback) override;
989   bool IsConnected() const override;
990   bool IsConnectedAndIdle() const override;
991   bool WasEverUsed() const override;
992   int GetPeerAddress(IPEndPoint* address) const override;
993   int GetLocalAddress(IPEndPoint* address) const override;
994   NextProto GetNegotiatedProtocol() const override;
995   std::optional<std::string_view> GetPeerApplicationSettings() const override;
996   bool GetSSLInfo(SSLInfo* ssl_info) override;
997   void GetSSLCertRequestInfo(
998       SSLCertRequestInfo* cert_request_info) const override;
999   void ApplySocketTag(const SocketTag& tag) override;
1000   const NetLogWithSource& NetLog() const override;
1001   int64_t GetTotalReceivedBytes() const override;
1002   int SetReceiveBufferSize(int32_t size) override;
1003   int SetSendBufferSize(int32_t size) override;
1004 
1005   // SSLSocket implementation.
1006   int ExportKeyingMaterial(std::string_view label,
1007                            std::optional<base::span<const uint8_t>> context,
1008                            base::span<uint8_t> out) override;
1009 
1010   // SSLClientSocket implementation.
1011   std::vector<uint8_t> GetECHRetryConfigs() override;
1012 
1013   // This MockSocket does not implement the manual async IO feature.
1014   void OnReadComplete(const MockRead& data) override;
1015   void OnWriteComplete(int rv) override;
1016   void OnConnectComplete(const MockConnect& data) override;
1017   // SSL sockets don't need magic to deal with destruction of their data
1018   // provider.
1019   // TODO(mmenke):  Probably a good idea to support it, anyways.
OnDataProviderDestroyed()1020   void OnDataProviderDestroyed() override {}
1021 
1022  private:
1023   static void ConnectCallback(MockSSLClientSocket* ssl_client_socket,
1024                               CompletionOnceCallback callback,
1025                               int rv);
1026 
1027   void RunCallbackAsync(CompletionOnceCallback callback, int result);
1028   void RunCallback(CompletionOnceCallback callback, int result);
1029 
1030   void RunConfirmHandshakeCallback(CompletionOnceCallback callback, int result);
1031 
1032   bool connected_ = false;
1033   bool in_confirm_handshake_ = false;
1034   NetLogWithSource net_log_;
1035   std::unique_ptr<StreamSocket> stream_socket_;
1036   raw_ptr<SSLSocketDataProvider, AcrossTasksDanglingUntriaged> data_;
1037   // Address of the "remote" peer we're connected to.
1038   IPEndPoint peer_addr_;
1039 
1040   base::WeakPtrFactory<MockSSLClientSocket> weak_factory_{this};
1041 };
1042 
1043 class MockUDPClientSocket : public DatagramClientSocket, public AsyncSocket {
1044  public:
1045   explicit MockUDPClientSocket(SocketDataProvider* data = nullptr,
1046                                net::NetLog* net_log = nullptr);
1047 
1048   MockUDPClientSocket(const MockUDPClientSocket&) = delete;
1049   MockUDPClientSocket& operator=(const MockUDPClientSocket&) = delete;
1050 
1051   ~MockUDPClientSocket() override;
1052 
1053   // Socket implementation.
1054   int Read(IOBuffer* buf,
1055            int buf_len,
1056            CompletionOnceCallback callback) override;
1057   int Write(IOBuffer* buf,
1058             int buf_len,
1059             CompletionOnceCallback callback,
1060             const NetworkTrafficAnnotationTag& traffic_annotation) override;
1061 
1062   int SetReceiveBufferSize(int32_t size) override;
1063   int SetSendBufferSize(int32_t size) override;
1064   int SetDoNotFragment() override;
1065   int SetRecvTos() override;
1066   int SetTos(DiffServCodePoint dscp, EcnCodePoint ecn) override;
1067 
1068   // DatagramSocket implementation.
1069   void Close() override;
1070   int GetPeerAddress(IPEndPoint* address) const override;
1071   int GetLocalAddress(IPEndPoint* address) const override;
1072   void UseNonBlockingIO() override;
1073   int SetMulticastInterface(uint32_t interface_index) override;
1074   const NetLogWithSource& NetLog() const override;
1075 
1076   // DatagramClientSocket implementation.
1077   int Connect(const IPEndPoint& address) override;
1078   int ConnectUsingNetwork(handles::NetworkHandle network,
1079                           const IPEndPoint& address) override;
1080   int ConnectUsingDefaultNetwork(const IPEndPoint& address) override;
1081   int ConnectAsync(const IPEndPoint& address,
1082                    CompletionOnceCallback callback) override;
1083   int ConnectUsingNetworkAsync(handles::NetworkHandle network,
1084                                const IPEndPoint& address,
1085                                CompletionOnceCallback callback) override;
1086   int ConnectUsingDefaultNetworkAsync(const IPEndPoint& address,
1087                                       CompletionOnceCallback callback) override;
1088   handles::NetworkHandle GetBoundNetwork() const override;
1089   void ApplySocketTag(const SocketTag& tag) override;
SetMsgConfirm(bool confirm)1090   void SetMsgConfirm(bool confirm) override {}
1091   DscpAndEcn GetLastTos() const override;
1092 
1093   // AsyncSocket implementation.
1094   void OnReadComplete(const MockRead& data) override;
1095   void OnWriteComplete(int rv) override;
1096   void OnConnectComplete(const MockConnect& data) override;
1097   void OnDataProviderDestroyed() override;
1098 
set_source_port(uint16_t port)1099   void set_source_port(uint16_t port) { source_port_ = port; }
source_port()1100   uint16_t source_port() const { return source_port_; }
set_source_host(IPAddress addr)1101   void set_source_host(IPAddress addr) { source_host_ = addr; }
source_host()1102   IPAddress source_host() const { return source_host_; }
1103 
1104   // Returns last tag applied to socket.
tag()1105   SocketTag tag() const { return tag_; }
1106 
1107   // Returns false if socket's tag was changed after the socket was used for
1108   // data transfer (e.g. Read/Write() called), otherwise returns true.
tagged_before_data_transferred()1109   bool tagged_before_data_transferred() const {
1110     return tagged_before_data_transferred_;
1111   }
1112 
1113  private:
1114   int CompleteRead();
1115 
1116   void RunCallbackAsync(CompletionOnceCallback callback, int result);
1117   void RunCallback(CompletionOnceCallback callback, int result);
1118 
1119   bool connected_ = false;
1120   raw_ptr<SocketDataProvider> data_;
1121   int read_offset_ = 0;
1122   MockRead read_data_;
1123   bool need_read_data_ = true;
1124   IPAddress source_host_;
1125   uint16_t source_port_ = 123;  // Ephemeral source port.
1126 
1127   // Address of the "remote" peer we're connected to.
1128   IPEndPoint peer_addr_;
1129 
1130   // Network that the socket is bound to.
1131   handles::NetworkHandle network_ = handles::kInvalidNetworkHandle;
1132 
1133   // While an asynchronous IO is pending, we save our user-buffer state.
1134   scoped_refptr<IOBuffer> pending_read_buf_ = nullptr;
1135   int pending_read_buf_len_ = 0;
1136   CompletionOnceCallback pending_read_callback_;
1137   CompletionOnceCallback pending_write_callback_;
1138 
1139   NetLogWithSource net_log_;
1140 
1141   DatagramBuffers unwritten_buffers_;
1142 
1143   SocketTag tag_;
1144   bool data_transferred_ = false;
1145   bool tagged_before_data_transferred_ = true;
1146 
1147   uint8_t last_tos_ = 0;
1148 
1149   base::WeakPtrFactory<MockUDPClientSocket> weak_factory_{this};
1150 };
1151 
1152 class TestSocketRequest : public TestCompletionCallbackBase {
1153  public:
1154   TestSocketRequest(std::vector<raw_ptr<TestSocketRequest, VectorExperimental>>*
1155                         request_order,
1156                     size_t* completion_count);
1157 
1158   TestSocketRequest(const TestSocketRequest&) = delete;
1159   TestSocketRequest& operator=(const TestSocketRequest&) = delete;
1160 
1161   ~TestSocketRequest() override;
1162 
handle()1163   ClientSocketHandle* handle() { return &handle_; }
1164 
callback()1165   CompletionOnceCallback callback() {
1166     return base::BindOnce(&TestSocketRequest::OnComplete,
1167                           base::Unretained(this));
1168   }
1169 
1170  private:
1171   void OnComplete(int result);
1172 
1173   ClientSocketHandle handle_;
1174   raw_ptr<std::vector<raw_ptr<TestSocketRequest, VectorExperimental>>>
1175       request_order_;
1176   raw_ptr<size_t> completion_count_;
1177 };
1178 
1179 class ClientSocketPoolTest {
1180  public:
1181   enum KeepAlive {
1182     KEEP_ALIVE,
1183 
1184     // A socket will be disconnected in addition to handle being reset.
1185     NO_KEEP_ALIVE,
1186   };
1187 
1188   static const int kIndexOutOfBounds;
1189   static const int kRequestNotFound;
1190 
1191   ClientSocketPoolTest();
1192 
1193   ClientSocketPoolTest(const ClientSocketPoolTest&) = delete;
1194   ClientSocketPoolTest& operator=(const ClientSocketPoolTest&) = delete;
1195 
1196   ~ClientSocketPoolTest();
1197 
1198   template <typename PoolType>
StartRequestUsingPool(PoolType * socket_pool,const ClientSocketPool::GroupId & group_id,RequestPriority priority,ClientSocketPool::RespectLimits respect_limits,const scoped_refptr<typename PoolType::SocketParams> & socket_params)1199   int StartRequestUsingPool(
1200       PoolType* socket_pool,
1201       const ClientSocketPool::GroupId& group_id,
1202       RequestPriority priority,
1203       ClientSocketPool::RespectLimits respect_limits,
1204       const scoped_refptr<typename PoolType::SocketParams>& socket_params) {
1205     DCHECK(socket_pool);
1206     TestSocketRequest* request(
1207         new TestSocketRequest(&request_order_, &completion_count_));
1208     requests_.push_back(base::WrapUnique(request));
1209     int rv = request->handle()->Init(
1210         group_id, socket_params, std::nullopt /* proxy_annotation_tag */,
1211         priority, SocketTag(), respect_limits, request->callback(),
1212         ClientSocketPool::ProxyAuthCallback(), socket_pool, NetLogWithSource());
1213     if (rv != ERR_IO_PENDING)
1214       request_order_.push_back(request);
1215     return rv;
1216   }
1217 
1218   // Provided there were n requests started, takes |index| in range 1..n
1219   // and returns order in which that request completed, in range 1..n,
1220   // or kIndexOutOfBounds if |index| is out of bounds, or kRequestNotFound
1221   // if that request did not complete (for example was canceled).
1222   int GetOrderOfRequest(size_t index) const;
1223 
1224   // Resets first initialized socket handle from |requests_|. If found such
1225   // a handle, returns true.
1226   bool ReleaseOneConnection(KeepAlive keep_alive);
1227 
1228   // Releases connections until there is nothing to release.
1229   void ReleaseAllConnections(KeepAlive keep_alive);
1230 
1231   // Note that this uses 0-based indices, while GetOrderOfRequest takes and
1232   // returns 1-based indices.
request(int i)1233   TestSocketRequest* request(int i) { return requests_[i].get(); }
1234 
requests_size()1235   size_t requests_size() const { return requests_.size(); }
requests()1236   std::vector<std::unique_ptr<TestSocketRequest>>* requests() {
1237     return &requests_;
1238   }
completion_count()1239   size_t completion_count() const { return completion_count_; }
1240 
1241  private:
1242   std::vector<std::unique_ptr<TestSocketRequest>> requests_;
1243   std::vector<raw_ptr<TestSocketRequest, VectorExperimental>> request_order_;
1244   size_t completion_count_ = 0;
1245 };
1246 
1247 class MockTransportSocketParams
1248     : public base::RefCounted<MockTransportSocketParams> {
1249  public:
1250   MockTransportSocketParams(const MockTransportSocketParams&) = delete;
1251   MockTransportSocketParams& operator=(const MockTransportSocketParams&) =
1252       delete;
1253 
1254  private:
1255   friend class base::RefCounted<MockTransportSocketParams>;
1256   ~MockTransportSocketParams() = default;
1257 };
1258 
1259 class MockTransportClientSocketPool : public TransportClientSocketPool {
1260  public:
1261   class MockConnectJob {
1262    public:
1263     MockConnectJob(std::unique_ptr<StreamSocket> socket,
1264                    ClientSocketHandle* handle,
1265                    const SocketTag& socket_tag,
1266                    CompletionOnceCallback callback,
1267                    RequestPriority priority);
1268 
1269     MockConnectJob(const MockConnectJob&) = delete;
1270     MockConnectJob& operator=(const MockConnectJob&) = delete;
1271 
1272     ~MockConnectJob();
1273 
1274     int Connect();
1275     bool CancelHandle(const ClientSocketHandle* handle);
1276 
handle()1277     ClientSocketHandle* handle() const { return handle_; }
1278 
priority()1279     RequestPriority priority() const { return priority_; }
set_priority(RequestPriority priority)1280     void set_priority(RequestPriority priority) { priority_ = priority; }
1281 
1282    private:
1283     void OnConnect(int rv);
1284 
1285     std::unique_ptr<StreamSocket> socket_;
1286     raw_ptr<ClientSocketHandle> handle_;
1287     const SocketTag socket_tag_;
1288     CompletionOnceCallback user_callback_;
1289     RequestPriority priority_;
1290   };
1291 
1292   MockTransportClientSocketPool(
1293       int max_sockets,
1294       int max_sockets_per_group,
1295       const CommonConnectJobParams* common_connect_job_params);
1296 
1297   MockTransportClientSocketPool(const MockTransportClientSocketPool&) = delete;
1298   MockTransportClientSocketPool& operator=(
1299       const MockTransportClientSocketPool&) = delete;
1300 
1301   ~MockTransportClientSocketPool() override;
1302 
last_request_priority()1303   RequestPriority last_request_priority() const {
1304     return last_request_priority_;
1305   }
1306 
requests()1307   const std::vector<std::unique_ptr<MockConnectJob>>& requests() const {
1308     return job_list_;
1309   }
1310 
release_count()1311   int release_count() const { return release_count_; }
cancel_count()1312   int cancel_count() const { return cancel_count_; }
1313 
1314   // TransportClientSocketPool implementation.
1315   int RequestSocket(
1316       const GroupId& group_id,
1317       scoped_refptr<ClientSocketPool::SocketParams> socket_params,
1318       const std::optional<NetworkTrafficAnnotationTag>& proxy_annotation_tag,
1319       RequestPriority priority,
1320       const SocketTag& socket_tag,
1321       RespectLimits respect_limits,
1322       ClientSocketHandle* handle,
1323       CompletionOnceCallback callback,
1324       const ProxyAuthCallback& on_auth_callback,
1325       const NetLogWithSource& net_log) override;
1326   void SetPriority(const GroupId& group_id,
1327                    ClientSocketHandle* handle,
1328                    RequestPriority priority) override;
1329   void CancelRequest(const GroupId& group_id,
1330                      ClientSocketHandle* handle,
1331                      bool cancel_connect_job) override;
1332   void ReleaseSocket(const GroupId& group_id,
1333                      std::unique_ptr<StreamSocket> socket,
1334                      int64_t generation) override;
1335 
1336  private:
1337   raw_ptr<ClientSocketFactory> client_socket_factory_;
1338   std::vector<std::unique_ptr<MockConnectJob>> job_list_;
1339   RequestPriority last_request_priority_ = DEFAULT_PRIORITY;
1340   int release_count_ = 0;
1341   int cancel_count_ = 0;
1342 };
1343 
1344 // WrappedStreamSocket is a base class that wraps an existing StreamSocket,
1345 // forwarding the Socket and StreamSocket interfaces to the underlying
1346 // transport.
1347 // This is to provide a common base class for subclasses to override specific
1348 // StreamSocket methods for testing, while still communicating with a 'real'
1349 // StreamSocket.
1350 class WrappedStreamSocket : public TransportClientSocket {
1351  public:
1352   explicit WrappedStreamSocket(std::unique_ptr<StreamSocket> transport);
1353   ~WrappedStreamSocket() override;
1354 
1355   // StreamSocket implementation:
1356   int Bind(const net::IPEndPoint& local_addr) override;
1357   int Connect(CompletionOnceCallback callback) override;
1358   void Disconnect() override;
1359   bool IsConnected() const override;
1360   bool IsConnectedAndIdle() const override;
1361   int GetPeerAddress(IPEndPoint* address) const override;
1362   int GetLocalAddress(IPEndPoint* address) const override;
1363   const NetLogWithSource& NetLog() const override;
1364   bool WasEverUsed() const override;
1365   NextProto GetNegotiatedProtocol() const override;
1366   bool GetSSLInfo(SSLInfo* ssl_info) override;
1367   int64_t GetTotalReceivedBytes() const override;
1368   void ApplySocketTag(const SocketTag& tag) override;
1369 
1370   // Socket implementation:
1371   int Read(IOBuffer* buf,
1372            int buf_len,
1373            CompletionOnceCallback callback) override;
1374   int ReadIfReady(IOBuffer* buf,
1375                   int buf_len,
1376                   CompletionOnceCallback callback) override;
1377   int Write(IOBuffer* buf,
1378             int buf_len,
1379             CompletionOnceCallback callback,
1380             const NetworkTrafficAnnotationTag& traffic_annotation) override;
1381   int SetReceiveBufferSize(int32_t size) override;
1382   int SetSendBufferSize(int32_t size) override;
1383 
1384  protected:
1385   std::unique_ptr<StreamSocket> transport_;
1386 };
1387 
1388 // StreamSocket that wraps another StreamSocket, but keeps track of any
1389 // SocketTag applied to the socket.
1390 class MockTaggingStreamSocket : public WrappedStreamSocket {
1391  public:
MockTaggingStreamSocket(std::unique_ptr<StreamSocket> transport)1392   explicit MockTaggingStreamSocket(std::unique_ptr<StreamSocket> transport)
1393       : WrappedStreamSocket(std::move(transport)) {}
1394 
1395   MockTaggingStreamSocket(const MockTaggingStreamSocket&) = delete;
1396   MockTaggingStreamSocket& operator=(const MockTaggingStreamSocket&) = delete;
1397 
1398   ~MockTaggingStreamSocket() override = default;
1399 
1400   // StreamSocket implementation.
1401   int Connect(CompletionOnceCallback callback) override;
1402   void ApplySocketTag(const SocketTag& tag) override;
1403 
1404   // Returns false if socket's tag was changed after the socket was connected,
1405   // otherwise returns true.
tagged_before_connected()1406   bool tagged_before_connected() const { return tagged_before_connected_; }
1407 
1408   // Returns last tag applied to socket.
tag()1409   SocketTag tag() const { return tag_; }
1410 
1411  private:
1412   bool connected_ = false;
1413   bool tagged_before_connected_ = true;
1414   SocketTag tag_;
1415 };
1416 
1417 // Extend MockClientSocketFactory to return MockTaggingStreamSockets and
1418 // keep track of last socket produced for test inspection.
1419 class MockTaggingClientSocketFactory : public MockClientSocketFactory {
1420  public:
1421   MockTaggingClientSocketFactory() = default;
1422 
1423   MockTaggingClientSocketFactory(const MockTaggingClientSocketFactory&) =
1424       delete;
1425   MockTaggingClientSocketFactory& operator=(
1426       const MockTaggingClientSocketFactory&) = delete;
1427 
1428   // ClientSocketFactory implementation.
1429   std::unique_ptr<DatagramClientSocket> CreateDatagramClientSocket(
1430       DatagramSocket::BindType bind_type,
1431       NetLog* net_log,
1432       const NetLogSource& source) override;
1433   std::unique_ptr<TransportClientSocket> CreateTransportClientSocket(
1434       const AddressList& addresses,
1435       std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,
1436       NetworkQualityEstimator* network_quality_estimator,
1437       NetLog* net_log,
1438       const NetLogSource& source) override;
1439 
1440   // These methods return pointers to last TCP and UDP sockets produced by this
1441   // factory. NOTE: Socket must still exist, or pointer will be to freed memory.
GetLastProducedTCPSocket()1442   MockTaggingStreamSocket* GetLastProducedTCPSocket() const {
1443     return tcp_socket_;
1444   }
GetLastProducedUDPSocket()1445   MockUDPClientSocket* GetLastProducedUDPSocket() const { return udp_socket_; }
1446 
1447  private:
1448   raw_ptr<MockTaggingStreamSocket, AcrossTasksDanglingUntriaged> tcp_socket_ =
1449       nullptr;
1450   raw_ptr<MockUDPClientSocket, AcrossTasksDanglingUntriaged> udp_socket_ =
1451       nullptr;
1452 };
1453 
// Host and port used by the SOCKS4 test strings below.
extern const char kSOCKS4TestHost[];
extern const int kSOCKS4TestPort;

// Constants for a successful SOCKS v4 handshake. The request connects to
// kSOCKS4TestHost on port kSOCKS4TestPort. Each array is paired with a
// matching ...Length constant.
extern const char kSOCKS4OkRequestLocalHostPort80[];
extern const int kSOCKS4OkRequestLocalHostPort80Length;

extern const char kSOCKS4OkReply[];
extern const int kSOCKS4OkReplyLength;

// Host and port used by the SOCKS5 test strings below.
extern const char kSOCKS5TestHost[];
extern const int kSOCKS5TestPort;

// Constants for a successful SOCKS v5 handshake. The request connects to
// kSOCKS5TestHost on port kSOCKS5TestPort. Each array is paired with a
// matching ...Length constant.
extern const char kSOCKS5GreetRequest[];
extern const int kSOCKS5GreetRequestLength;

extern const char kSOCKS5GreetResponse[];
extern const int kSOCKS5GreetResponseLength;

extern const char kSOCKS5OkRequest[];
extern const int kSOCKS5OkRequestLength;

extern const char kSOCKS5OkResponse[];
extern const int kSOCKS5OkResponseLength;
1483 
1484 // Helper function to get the total data size of the MockReads in |reads|.
1485 int64_t CountReadBytes(base::span<const MockRead> reads);
1486 
1487 // Helper function to get the total data size of the MockWrites in |writes|.
1488 int64_t CountWriteBytes(base::span<const MockWrite> writes);
1489 
1490 #if BUILDFLAG(IS_ANDROID)
1491 // Returns whether the device supports calling GetTaggedBytes().
1492 bool CanGetTaggedBytes();
1493 
1494 // Query the system to find out how many bytes were received with tag
1495 // |expected_tag| for our UID.  Return the count of received bytes.
1496 uint64_t GetTaggedBytes(int32_t expected_tag);
1497 #endif
1498 
1499 }  // namespace net
1500 
1501 #endif  // NET_SOCKET_SOCKET_TEST_UTIL_H_
1502