• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #ifndef NET_SOCKET_SOCKET_TEST_UTIL_H_
6 #define NET_SOCKET_SOCKET_TEST_UTIL_H_
7 
8 #include <stddef.h>
9 #include <stdint.h>
10 
11 #include <cstring>
12 #include <memory>
13 #include <string>
14 #include <utility>
15 #include <vector>
16 
17 #include "base/check_op.h"
18 #include "base/containers/span.h"
19 #include "base/functional/bind.h"
20 #include "base/functional/callback.h"
21 #include "base/memory/ptr_util.h"
22 #include "base/memory/raw_ptr.h"
23 #include "base/memory/raw_ptr_exclusion.h"
24 #include "base/memory/ref_counted.h"
25 #include "base/memory/weak_ptr.h"
26 #include "build/build_config.h"
27 #include "net/base/address_list.h"
28 #include "net/base/completion_once_callback.h"
29 #include "net/base/io_buffer.h"
30 #include "net/base/net_errors.h"
31 #include "net/base/test_completion_callback.h"
32 #include "net/http/http_auth_controller.h"
33 #include "net/log/net_log_with_source.h"
34 #include "net/socket/client_socket_factory.h"
35 #include "net/socket/client_socket_handle.h"
36 #include "net/socket/client_socket_pool.h"
37 #include "net/socket/datagram_client_socket.h"
38 #include "net/socket/socket_performance_watcher.h"
39 #include "net/socket/socket_tag.h"
40 #include "net/socket/ssl_client_socket.h"
41 #include "net/socket/transport_client_socket.h"
42 #include "net/socket/transport_client_socket_pool.h"
43 #include "net/ssl/ssl_config_service.h"
44 #include "net/ssl/ssl_info.h"
45 #include "testing/gtest/include/gtest/gtest.h"
46 #include "third_party/abseil-cpp/absl/types/optional.h"
47 
// Forward declaration: this header only refers to base::RunLoop by pointer.
namespace base {
class RunLoop;
}  // namespace base
51 
52 namespace net {
53 
// Forward declarations for types this header mentions without needing their
// full definitions.
class NetLog;
class X509Certificate;
struct CommonConnectJobParams;
struct NetworkTrafficAnnotationTag;
58 
59 const handles::NetworkHandle kDefaultNetworkForTests = 1;
60 const handles::NetworkHandle kNewNetworkForTests = 2;
61 
// Private network error code, used only by the socket test utility classes.
//
// A MockRead whose |result| equals ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ is
// purely a marker: it indicates that the peer will close the connection after
// the next MockRead. All other members of such a marker MockRead are ignored.
enum {
  ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ = -10000,
};
70 
// Forward declarations of socket types referenced below.
class AsyncSocket;
class MockClientSocket;
class SSLClientSocket;
class StreamSocket;

// Whether a mocked operation completes asynchronously (ASYNC) or inline with
// the call (SYNCHRONOUS).
enum IoMode {
  ASYNC = 0,
  SYNCHRONOUS = 1,
};
77 
78 struct MockConnect {
79   // Asynchronous connection success.
80   // Creates a MockConnect with |mode| ASYC, |result| OK, and
81   // |peer_addr| 192.0.2.33.
82   MockConnect();
83   // Creates a MockConnect with the specified mode and result, with
84   // |peer_addr| 192.0.2.33.
85   MockConnect(IoMode io_mode, int r);
86   MockConnect(IoMode io_mode, int r, IPEndPoint addr);
87   MockConnect(IoMode io_mode, int r, IPEndPoint addr, bool first_attempt_fails);
88   ~MockConnect();
89 
90   IoMode mode;
91   int result;
92   IPEndPoint peer_addr;
93   bool first_attempt_fails = false;
94 };
95 
96 struct MockConfirm {
97   // Asynchronous confirm success.
98   // Creates a MockConfirm with |mode| ASYC and |result| OK.
99   MockConfirm();
100   // Creates a MockConfirm with the specified mode and result.
101   MockConfirm(IoMode io_mode, int r);
102   ~MockConfirm();
103 
104   IoMode mode;
105   int result;
106 };
107 
// MockRead and MockWrite share the same interface and members, but we'd like
// them to be distinct types because they must not be used interchangeably.
// To do this, a struct template is defined, and MockRead and MockWrite are
// instantiated from this template. The template parameter |type| is not used
// in the struct definition (it exists purely to create a new type).
//
// |data| has different meanings in MockRead and MockWrite. In a MockRead,
// |data| is the data returned from the socket when MockTCPClientSocket::Read()
// is attempted. In a MockWrite, |data| is the expected data that should be
// given to MockTCPClientSocket::Write().
// Tag values that make MockReadWrite<MOCK_READ> and MockReadWrite<MOCK_WRITE>
// distinct types.
enum MockReadWriteType {
  MOCK_READ = 0,
  MOCK_WRITE = 1,
};
120 
121 template <MockReadWriteType type>
122 struct MockReadWrite {
123   // Flag to indicate that the message loop should be terminated.
124   enum { STOPLOOP = 1 << 31 };
125 
126   // Default
MockReadWriteMockReadWrite127   MockReadWrite()
128       : mode(SYNCHRONOUS),
129         result(0),
130         data(nullptr),
131         data_len(0),
132         sequence_number(0) {}
133 
134   // Read/write failure (no data).
MockReadWriteMockReadWrite135   MockReadWrite(IoMode io_mode, int result)
136       : mode(io_mode),
137         result(result),
138         data(nullptr),
139         data_len(0),
140         sequence_number(0) {}
141 
142   // Read/write failure (no data), with sequence information.
MockReadWriteMockReadWrite143   MockReadWrite(IoMode io_mode, int result, int seq)
144       : mode(io_mode),
145         result(result),
146         data(nullptr),
147         data_len(0),
148         sequence_number(seq) {}
149 
150   // Asynchronous read/write success (inferred data length).
MockReadWriteMockReadWrite151   explicit MockReadWrite(const char* data)
152       : mode(ASYNC),
153         result(0),
154         data(data),
155         data_len(strlen(data)),
156         sequence_number(0) {}
157 
158   // Read/write success (inferred data length).
MockReadWriteMockReadWrite159   MockReadWrite(IoMode io_mode, const char* data)
160       : mode(io_mode),
161         result(0),
162         data(data),
163         data_len(strlen(data)),
164         sequence_number(0) {}
165 
166   // Read/write success.
MockReadWriteMockReadWrite167   MockReadWrite(IoMode io_mode, const char* data, int data_len)
168       : mode(io_mode),
169         result(0),
170         data(data),
171         data_len(data_len),
172         sequence_number(0) {}
173 
174   // Read/write success (inferred data length) with sequence information.
MockReadWriteMockReadWrite175   MockReadWrite(IoMode io_mode, int seq, const char* data)
176       : mode(io_mode),
177         result(0),
178         data(data),
179         data_len(strlen(data)),
180         sequence_number(seq) {}
181 
182   // Read/write success with sequence information.
MockReadWriteMockReadWrite183   MockReadWrite(IoMode io_mode, const char* data, int data_len, int seq)
184       : mode(io_mode),
185         result(0),
186         data(data),
187         data_len(data_len),
188         sequence_number(seq) {}
189 
190   IoMode mode;
191   int result;
192   const char* data;
193   int data_len;
194 
195   // For data providers that only allows reads to occur in a particular
196   // sequence.  If a read occurs before the given |sequence_number| is reached,
197   // an ERR_IO_PENDING is returned.
198   int sequence_number;  // The sequence number at which a read is allowed
199                         // to occur.
200 };
201 
202 typedef MockReadWrite<MOCK_READ> MockRead;
203 typedef MockReadWrite<MOCK_WRITE> MockWrite;
204 
205 struct MockWriteResult {
MockWriteResultMockWriteResult206   MockWriteResult(IoMode io_mode, int result) : mode(io_mode), result(result) {}
207 
208   IoMode mode;
209   int result;
210 };
211 
212 // The SocketDataProvider is an interface used by the MockClientSocket
213 // for getting data about individual reads and writes on the socket.  Can be
214 // used with at most one socket at a time.
215 // TODO(mmenke):  Do these really need to be re-useable?
216 class SocketDataProvider {
217  public:
218   SocketDataProvider();
219 
220   SocketDataProvider(const SocketDataProvider&) = delete;
221   SocketDataProvider& operator=(const SocketDataProvider&) = delete;
222 
223   virtual ~SocketDataProvider();
224 
225   // Returns the buffer and result code for the next simulated read.
226   // If the |MockRead.result| is ERR_IO_PENDING, it informs the caller
227   // that it will be called via the AsyncSocket::OnReadComplete()
228   // function at a later time.
229   virtual MockRead OnRead() = 0;
230   virtual MockWriteResult OnWrite(const std::string& data) = 0;
231   virtual bool AllReadDataConsumed() const = 0;
232   virtual bool AllWriteDataConsumed() const = 0;
CancelPendingRead()233   virtual void CancelPendingRead() {}
234 
235   // Returns the last set receive buffer size, or -1 if never set.
receive_buffer_size()236   int receive_buffer_size() const { return receive_buffer_size_; }
set_receive_buffer_size(int receive_buffer_size)237   void set_receive_buffer_size(int receive_buffer_size) {
238     receive_buffer_size_ = receive_buffer_size;
239   }
240 
241   // Returns the last set send buffer size, or -1 if never set.
send_buffer_size()242   int send_buffer_size() const { return send_buffer_size_; }
set_send_buffer_size(int send_buffer_size)243   void set_send_buffer_size(int send_buffer_size) {
244     send_buffer_size_ = send_buffer_size;
245   }
246 
247   // Returns the last set value of TCP no delay, or false if never set.
no_delay()248   bool no_delay() const { return no_delay_; }
set_no_delay(bool no_delay)249   void set_no_delay(bool no_delay) { no_delay_ = no_delay; }
250 
251   // Returns whether TCP keepalives were enabled or not. Returns kDefault by
252   // default.
253   enum class KeepAliveState { kEnabled, kDisabled, kDefault };
keep_alive_state()254   KeepAliveState keep_alive_state() const { return keep_alive_state_; }
255   // Last set TCP keepalive delay.
keep_alive_delay()256   int keep_alive_delay() const { return keep_alive_delay_; }
set_keep_alive(bool enable,int delay)257   void set_keep_alive(bool enable, int delay) {
258     keep_alive_state_ =
259         enable ? KeepAliveState::kEnabled : KeepAliveState::kDisabled;
260     keep_alive_delay_ = delay;
261   }
262 
263   // Setters / getters for the return values of the corresponding Set*()
264   // methods. By default, they all succeed, if the socket is connected.
265 
set_set_receive_buffer_size_result(int receive_buffer_size_result)266   void set_set_receive_buffer_size_result(int receive_buffer_size_result) {
267     set_receive_buffer_size_result_ = receive_buffer_size_result;
268   }
set_receive_buffer_size_result()269   int set_receive_buffer_size_result() const {
270     return set_receive_buffer_size_result_;
271   }
272 
set_set_send_buffer_size_result(int set_send_buffer_size_result)273   void set_set_send_buffer_size_result(int set_send_buffer_size_result) {
274     set_send_buffer_size_result_ = set_send_buffer_size_result;
275   }
set_send_buffer_size_result()276   int set_send_buffer_size_result() const {
277     return set_send_buffer_size_result_;
278   }
279 
set_set_no_delay_result(bool set_no_delay_result)280   void set_set_no_delay_result(bool set_no_delay_result) {
281     set_no_delay_result_ = set_no_delay_result;
282   }
set_no_delay_result()283   bool set_no_delay_result() const { return set_no_delay_result_; }
284 
set_set_keep_alive_result(bool set_keep_alive_result)285   void set_set_keep_alive_result(bool set_keep_alive_result) {
286     set_keep_alive_result_ = set_keep_alive_result;
287   }
set_keep_alive_result()288   bool set_keep_alive_result() const { return set_keep_alive_result_; }
289 
expected_addresses()290   const absl::optional<AddressList>& expected_addresses() const {
291     return expected_addresses_;
292   }
set_expected_addresses(net::AddressList addresses)293   void set_expected_addresses(net::AddressList addresses) {
294     expected_addresses_ = std::move(addresses);
295   }
296 
297   // Returns true if the request should be considered idle, for the purposes of
298   // IsConnectedAndIdle.
299   virtual bool IsIdle() const;
300 
301   // Initializes the SocketDataProvider for use with |socket|.  Must be called
302   // before use
303   void Initialize(AsyncSocket* socket);
304   // Detaches the socket associated with a SocketDataProvider.  Must be called
305   // before |socket_| is destroyed, unless the SocketDataProvider has informed
306   // |socket_| it was destroyed.  Must also be called before Initialize() may
307   // be called again with a new socket.
308   void DetachSocket();
309 
310   // Accessor for the socket which is using the SocketDataProvider.
socket()311   AsyncSocket* socket() { return socket_; }
312 
connect_data()313   MockConnect connect_data() const { return connect_; }
set_connect_data(const MockConnect & connect)314   void set_connect_data(const MockConnect& connect) { connect_ = connect; }
315 
316  private:
317   // Called to inform subclasses of initialization.
318   virtual void Reset() = 0;
319 
320   MockConnect connect_;
321   raw_ptr<AsyncSocket> socket_ = nullptr;
322 
323   int receive_buffer_size_ = -1;
324   int send_buffer_size_ = -1;
325   // This reflects the default state of TCPClientSockets.
326   bool no_delay_ = true;
327 
328   KeepAliveState keep_alive_state_ = KeepAliveState::kDefault;
329   int keep_alive_delay_ = 0;
330 
331   int set_receive_buffer_size_result_ = net::OK;
332   int set_send_buffer_size_result_ = net::OK;
333   bool set_no_delay_result_ = true;
334   bool set_keep_alive_result_ = true;
335   absl::optional<AddressList> expected_addresses_;
336 };
337 
338 // The AsyncSocket is an interface used by the SocketDataProvider to
339 // complete the asynchronous read operation.
340 class AsyncSocket {
341  public:
342   // If an async IO is pending because the SocketDataProvider returned
343   // ERR_IO_PENDING, then the AsyncSocket waits until this OnReadComplete
344   // is called to complete the asynchronous read operation.
345   // data.async is ignored, and this read is completed synchronously as
346   // part of this call.
347   // TODO(rch): this should take a StringPiece since most of the fields
348   // are ignored.
349   virtual void OnReadComplete(const MockRead& data) = 0;
350   // If an async IO is pending because the SocketDataProvider returned
351   // ERR_IO_PENDING, then the AsyncSocket waits until this OnReadComplete
352   // is called to complete the asynchronous read operation.
353   virtual void OnWriteComplete(int rv) = 0;
354   virtual void OnConnectComplete(const MockConnect& data) = 0;
355 
356   // Called when the SocketDataProvider associated with the socket is destroyed.
357   // The socket may continue to be used after the data provider is destroyed,
358   // so it should be sure not to dereference the provider after this is called.
359   virtual void OnDataProviderDestroyed() = 0;
360 };
361 
// Interface for printing write data in a protocol-specific format.
class SocketDataPrinter {
 public:
  // Virtual: this is an abstract, polymorphic base, and a non-virtual public
  // destructor would make deleting a subclass through a SocketDataPrinter*
  // undefined behavior.
  virtual ~SocketDataPrinter() = default;

  // Prints the write in |data| using some sort of protocol-specific
  // format.
  virtual std::string PrintWrite(const std::string& data) = 0;
};
370 
371 // StaticSocketDataHelper manages a list of reads and writes.
372 class StaticSocketDataHelper {
373  public:
374   StaticSocketDataHelper(base::span<const MockRead> reads,
375                          base::span<const MockWrite> writes);
376 
377   StaticSocketDataHelper(const StaticSocketDataHelper&) = delete;
378   StaticSocketDataHelper& operator=(const StaticSocketDataHelper&) = delete;
379 
380   ~StaticSocketDataHelper();
381 
382   // These functions get access to the next available read and write data. They
383   // CHECK fail if there is no data available.
384   const MockRead& PeekRead() const;
385   const MockWrite& PeekWrite() const;
386 
387   // Returns the current read or write, and then advances to the next one.
388   const MockRead& AdvanceRead();
389   const MockWrite& AdvanceWrite();
390 
391   // Resets the read and write indexes to 0.
392   void Reset();
393 
394   // Returns true if |data| is valid data for the next write. In order
395   // to support short writes, the next write may be longer than |data|
396   // in which case this method will still return true.
397   bool VerifyWriteData(const std::string& data, SocketDataPrinter* printer);
398 
read_index()399   size_t read_index() const { return read_index_; }
write_index()400   size_t write_index() const { return write_index_; }
read_count()401   size_t read_count() const { return reads_.size(); }
write_count()402   size_t write_count() const { return writes_.size(); }
403 
AllReadDataConsumed()404   bool AllReadDataConsumed() const { return read_index() >= read_count(); }
AllWriteDataConsumed()405   bool AllWriteDataConsumed() const { return write_index() >= write_count(); }
406 
407  private:
408   // Returns the next available read or write that is not a pause event. CHECK
409   // fails if no data is available.
410   const MockWrite& PeekRealWrite() const;
411 
412   const base::span<const MockRead> reads_;
413   size_t read_index_ = 0;
414   const base::span<const MockWrite> writes_;
415   size_t write_index_ = 0;
416 };
417 
418 // SocketDataProvider which responds based on static tables of mock reads and
419 // writes.
420 class StaticSocketDataProvider : public SocketDataProvider {
421  public:
422   StaticSocketDataProvider();
423   StaticSocketDataProvider(base::span<const MockRead> reads,
424                            base::span<const MockWrite> writes);
425 
426   StaticSocketDataProvider(const StaticSocketDataProvider&) = delete;
427   StaticSocketDataProvider& operator=(const StaticSocketDataProvider&) = delete;
428 
429   ~StaticSocketDataProvider() override;
430 
431   // Pause/resume reads from this provider.
432   void Pause();
433   void Resume();
434 
435   // From SocketDataProvider:
436   MockRead OnRead() override;
437   MockWriteResult OnWrite(const std::string& data) override;
438   bool AllReadDataConsumed() const override;
439   bool AllWriteDataConsumed() const override;
440 
read_index()441   size_t read_index() const { return helper_.read_index(); }
write_index()442   size_t write_index() const { return helper_.write_index(); }
read_count()443   size_t read_count() const { return helper_.read_count(); }
write_count()444   size_t write_count() const { return helper_.write_count(); }
445 
set_printer(SocketDataPrinter * printer)446   void set_printer(SocketDataPrinter* printer) { printer_ = printer; }
447 
448  private:
449   // From SocketDataProvider:
450   void Reset() override;
451 
452   StaticSocketDataHelper helper_;
453   // This field is not a raw_ptr<> because it was filtered by the rewriter for:
454   // #union
455   RAW_PTR_EXCLUSION SocketDataPrinter* printer_ = nullptr;
456   bool paused_ = false;
457 };
458 
459 // SSLSocketDataProviders only need to keep track of the return code from calls
460 // to Connect().
461 struct SSLSocketDataProvider {
462   SSLSocketDataProvider(IoMode mode, int result);
463   SSLSocketDataProvider(const SSLSocketDataProvider& other);
464   ~SSLSocketDataProvider();
465 
466   // Returns whether MockConnect data has been consumed.
ConnectDataConsumedSSLSocketDataProvider467   bool ConnectDataConsumed() const { return is_connect_data_consumed; }
468 
469   // Returns whether MockConfirm data has been consumed.
ConfirmDataConsumedSSLSocketDataProvider470   bool ConfirmDataConsumed() const { return is_confirm_data_consumed; }
471 
472   // Returns whether a Write occurred before ConfirmHandshake completed.
WriteBeforeConfirmSSLSocketDataProvider473   bool WriteBeforeConfirm() const { return write_called_before_confirm; }
474 
475   // Result for Connect().
476   MockConnect connect;
477   // Callback to run when Connect() is called. This is called at most once per
478   // socket but is repeating because SSLSocketDataProvider is copyable.
479   base::RepeatingClosure connect_callback;
480 
481   // Result for ConfirmHandshake().
482   MockConfirm confirm;
483   // Callback to run when ConfirmHandshake() is called. This is called at most
484   // once per socket but is repeating because SSLSocketDataProvider is
485   // copyable.
486   base::RepeatingClosure confirm_callback;
487 
488   // Result for GetNegotiatedProtocol().
489   NextProto next_proto = kProtoUnknown;
490 
491   // Result for GetPeerApplicationSettings().
492   absl::optional<std::string> peer_application_settings;
493 
494   // Result for GetSSLInfo().
495   SSLInfo ssl_info;
496 
497   // Result for GetSSLCertRequestInfo().
498   scoped_refptr<SSLCertRequestInfo> cert_request_info;
499 
500   // Result for GetECHRetryConfigs().
501   std::vector<uint8_t> ech_retry_configs;
502 
503   absl::optional<NextProtoVector> next_protos_expected_in_ssl_config;
504 
505   uint16_t expected_ssl_version_min;
506   uint16_t expected_ssl_version_max;
507   absl::optional<bool> expected_send_client_cert;
508   scoped_refptr<X509Certificate> expected_client_cert;
509   absl::optional<HostPortPair> expected_host_and_port;
510   absl::optional<bool> expected_ignore_certificate_errors;
511   absl::optional<NetworkAnonymizationKey> expected_network_anonymization_key;
512   absl::optional<bool> expected_disable_sha1_server_signatures;
513   absl::optional<std::vector<uint8_t>> expected_ech_config_list;
514 
515   bool is_connect_data_consumed = false;
516   bool is_confirm_data_consumed = false;
517   bool write_called_before_confirm = false;
518 };
519 
520 // Uses the sequence_number field in the mock reads and writes to
521 // complete the operations in a specified order.
522 class SequencedSocketData : public SocketDataProvider {
523  public:
524   SequencedSocketData();
525 
526   // |reads| is the list of MockRead completions.
527   // |writes| is the list of MockWrite completions.
528   SequencedSocketData(base::span<const MockRead> reads,
529                       base::span<const MockWrite> writes);
530 
531   // |connect| is the result for the connect phase.
532   // |reads| is the list of MockRead completions.
533   // |writes| is the list of MockWrite completions.
534   SequencedSocketData(const MockConnect& connect,
535                       base::span<const MockRead> reads,
536                       base::span<const MockWrite> writes);
537 
538   SequencedSocketData(const SequencedSocketData&) = delete;
539   SequencedSocketData& operator=(const SequencedSocketData&) = delete;
540 
541   ~SequencedSocketData() override;
542 
543   // From SocketDataProvider:
544   MockRead OnRead() override;
545   MockWriteResult OnWrite(const std::string& data) override;
546   bool AllReadDataConsumed() const override;
547   bool AllWriteDataConsumed() const override;
548   bool IsIdle() const override;
549   void CancelPendingRead() override;
550 
551   // An ASYNC read event with a return value of ERR_IO_PENDING will cause the
552   // socket data to pause at that event, and advance no further, until Resume is
553   // invoked.  At that point, the socket will continue at the next event in the
554   // sequence.
555   //
556   // If a request just wants to simulate a connection that stays open and never
557   // receives any more data, instead of pausing and then resuming a request, it
558   // should use a SYNCHRONOUS event with a return value of ERR_IO_PENDING
559   // instead.
560   bool IsPaused() const;
561   // Resumes events once |this| is in the paused state.  The next event will
562   // occur synchronously with the call if it can.
563   void Resume();
564   void RunUntilPaused();
565 
566   // When true, IsConnectedAndIdle() will return false if the next event in the
567   // sequence is a synchronous.  Otherwise, the socket claims to be idle as
568   // long as it's connected.  Defaults to false.
569   // TODO(mmenke):  See if this can be made the default behavior, and consider
570   // removing this mehtod.  Need to make sure it doesn't change what code any
571   // tests are targetted at testing.
set_busy_before_sync_reads(bool busy_before_sync_reads)572   void set_busy_before_sync_reads(bool busy_before_sync_reads) {
573     busy_before_sync_reads_ = busy_before_sync_reads;
574   }
575 
set_printer(SocketDataPrinter * printer)576   void set_printer(SocketDataPrinter* printer) { printer_ = printer; }
577 
578  private:
579   // Defines the state for the read or write path.
580   enum class IoState {
581     kIdle,        // No async operation is in progress.
582     kPending,     // An async operation in waiting for another operation to
583                   // complete.
584     kCompleting,  // A task has been posted to complete an async operation.
585     kPaused,      // IO is paused until Resume() is called.
586   };
587 
588   // From SocketDataProvider:
589   void Reset() override;
590 
591   void OnReadComplete();
592   void OnWriteComplete();
593 
594   void MaybePostReadCompleteTask();
595   void MaybePostWriteCompleteTask();
596 
597   StaticSocketDataHelper helper_;
598   raw_ptr<SocketDataPrinter> printer_ = nullptr;
599   int sequence_number_ = 0;
600   IoState read_state_ = IoState::kIdle;
601   IoState write_state_ = IoState::kIdle;
602 
603   bool busy_before_sync_reads_ = false;
604 
605   // Used by RunUntilPaused.  NULL at all other times.
606   std::unique_ptr<base::RunLoop> run_until_paused_run_loop_;
607 
608   base::WeakPtrFactory<SequencedSocketData> weak_factory_{this};
609 };
610 
// Holds an array of data provider pointers. As mock client sockets (e.g.
// MockTCPClientSocket, MockSSLClientSocket) get instantiated, they take their
// data from the i'th element of this array. The pointers are not deleted by
// this class.
template <typename T>
class SocketDataProviderArray {
 public:
  SocketDataProviderArray() = default;

  // Returns the next provider and advances past it. CHECK-fails (even in
  // builds without DCHECKs) when none remain, instead of reading past the end
  // of |data_providers_|.
  T* GetNext() {
    CHECK_LT(next_index_, data_providers_.size());
    return data_providers_[next_index_++];
  }

  // Like GetNext(), but returns nullptr when the end of the array is reached,
  // instead of CHECKing. GetNext() should generally be preferred, unless
  // having no remaining elements is expected in some cases and is handled
  // safely.
  T* GetNextWithoutAsserting() {
    if (next_index_ >= data_providers_.size())
      return nullptr;
    return data_providers_[next_index_++];
  }

  // Appends |data_provider|, which must be non-null.
  void Add(T* data_provider) {
    DCHECK(data_provider);
    data_providers_.push_back(data_provider);
  }

  // Index of the element the next GetNext() call will return.
  size_t next_index() const { return next_index_; }

  // Makes the next GetNext() call start over from the first element.
  void ResetNextIndex() { next_index_ = 0; }

 private:
  // Index of the next |data_providers_| element to use. Not an iterator
  // because those are invalidated on vector reallocation.
  size_t next_index_ = 0;

  // SocketDataProviders to be returned.
  std::vector<T*> data_providers_;
};
651 
// Concrete mock socket types, declared ahead of MockClientSocketFactory.
class MockSSLClientSocket;
class MockTCPClientSocket;
class MockUDPClientSocket;
655 
656 // ClientSocketFactory which contains arrays of sockets of each type.
657 // You should first fill the arrays using Add{SSL,}SocketDataProvider(). When
658 // the factory is asked to create a socket, it takes next entry from appropriate
659 // array. You can use ResetNextMockIndexes to reset that next entry index for
660 // all mock socket types.
661 class MockClientSocketFactory : public ClientSocketFactory {
662  public:
663   MockClientSocketFactory();
664 
665   MockClientSocketFactory(const MockClientSocketFactory&) = delete;
666   MockClientSocketFactory& operator=(const MockClientSocketFactory&) = delete;
667 
668   ~MockClientSocketFactory() override;
669 
670   // Adds a SocketDataProvider that can be used to served either TCP or UDP
671   // connection requests. Sockets are returned in FIFO order.
672   void AddSocketDataProvider(SocketDataProvider* socket);
673 
674   // Like AddSocketDataProvider(), except sockets will only be used to service
675   // TCP connection requests. Sockets added with this method are used first,
676   // before sockets added with AddSocketDataProvider(). Particularly useful for
677   // QUIC tests with multiple sockets, where TCP connections may or may not be
678   // made, and have no guaranteed order, relative to UDP connections.
679   void AddTcpSocketDataProvider(SocketDataProvider* socket);
680 
681   void AddSSLSocketDataProvider(SSLSocketDataProvider* socket);
682   void ResetNextMockIndexes();
683 
mock_data()684   SocketDataProviderArray<SocketDataProvider>& mock_data() {
685     return mock_data_;
686   }
687 
set_enable_read_if_ready(bool enable_read_if_ready)688   void set_enable_read_if_ready(bool enable_read_if_ready) {
689     enable_read_if_ready_ = enable_read_if_ready;
690   }
691 
692   // ClientSocketFactory
693   std::unique_ptr<DatagramClientSocket> CreateDatagramClientSocket(
694       DatagramSocket::BindType bind_type,
695       NetLog* net_log,
696       const NetLogSource& source) override;
697   std::unique_ptr<TransportClientSocket> CreateTransportClientSocket(
698       const AddressList& addresses,
699       std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,
700       NetworkQualityEstimator* network_quality_estimator,
701       NetLog* net_log,
702       const NetLogSource& source) override;
703   std::unique_ptr<SSLClientSocket> CreateSSLClientSocket(
704       SSLClientContext* context,
705       std::unique_ptr<StreamSocket> stream_socket,
706       const HostPortPair& host_and_port,
707       const SSLConfig& ssl_config) override;
udp_client_socket_ports()708   const std::vector<uint16_t>& udp_client_socket_ports() const {
709     return udp_client_socket_ports_;
710   }
711 
712  private:
713   SocketDataProviderArray<SocketDataProvider> mock_data_;
714   SocketDataProviderArray<SocketDataProvider> mock_tcp_data_;
715   SocketDataProviderArray<SSLSocketDataProvider> mock_ssl_data_;
716   std::vector<uint16_t> udp_client_socket_ports_;
717 
718   // If true, ReadIfReady() is enabled; otherwise ReadIfReady() returns
719   // ERR_READ_IF_READY_NOT_IMPLEMENTED.
720   bool enable_read_if_ready_ = false;
721 };
722 
723 class MockClientSocket : public TransportClientSocket {
724  public:
725   // The NetLogWithSource is needed to test LoadTimingInfo, which uses NetLog
726   // IDs as
727   // unique socket IDs.
728   explicit MockClientSocket(const NetLogWithSource& net_log);
729 
730   MockClientSocket(const MockClientSocket&) = delete;
731   MockClientSocket& operator=(const MockClientSocket&) = delete;
732 
733   // Socket implementation.
734   int Read(IOBuffer* buf,
735            int buf_len,
736            CompletionOnceCallback callback) override = 0;
737   int Write(IOBuffer* buf,
738             int buf_len,
739             CompletionOnceCallback callback,
740             const NetworkTrafficAnnotationTag& traffic_annotation) override = 0;
741   int SetReceiveBufferSize(int32_t size) override;
742   int SetSendBufferSize(int32_t size) override;
743 
744   // TransportClientSocket implementation.
745   int Bind(const net::IPEndPoint& local_addr) override;
746   bool SetNoDelay(bool no_delay) override;
747   bool SetKeepAlive(bool enable, int delay) override;
748 
749   // StreamSocket implementation.
750   int Connect(CompletionOnceCallback callback) override = 0;
751   void Disconnect() override;
752   bool IsConnected() const override;
753   bool IsConnectedAndIdle() const override;
754   int GetPeerAddress(IPEndPoint* address) const override;
755   int GetLocalAddress(IPEndPoint* address) const override;
756   const NetLogWithSource& NetLog() const override;
757   NextProto GetNegotiatedProtocol() const override;
758   int64_t GetTotalReceivedBytes() const override;
ApplySocketTag(const SocketTag & tag)759   void ApplySocketTag(const SocketTag& tag) override {}
760 
761  protected:
762   ~MockClientSocket() override;
763   void RunCallbackAsync(CompletionOnceCallback callback, int result);
764   void RunCallback(CompletionOnceCallback callback, int result);
765 
766   // True if Connect completed successfully and Disconnect hasn't been called.
767   bool connected_ = false;
768 
769   IPEndPoint local_addr_;
770   IPEndPoint peer_addr_;
771 
772   NetLogWithSource net_log_;
773 
774  private:
775   base::WeakPtrFactory<MockClientSocket> weak_factory_{this};
776 };
777 
778 class MockTCPClientSocket : public MockClientSocket, public AsyncSocket {
779  public:
780   MockTCPClientSocket(const AddressList& addresses,
781                       net::NetLog* net_log,
782                       SocketDataProvider* socket);
783 
784   MockTCPClientSocket(const MockTCPClientSocket&) = delete;
785   MockTCPClientSocket& operator=(const MockTCPClientSocket&) = delete;
786 
787   ~MockTCPClientSocket() override;
788 
addresses()789   const AddressList& addresses() const { return addresses_; }
790 
791   // Socket implementation.
792   int Read(IOBuffer* buf,
793            int buf_len,
794            CompletionOnceCallback callback) override;
795   int ReadIfReady(IOBuffer* buf,
796                   int buf_len,
797                   CompletionOnceCallback callback) override;
798   int CancelReadIfReady() override;
799   int Write(IOBuffer* buf,
800             int buf_len,
801             CompletionOnceCallback callback,
802             const NetworkTrafficAnnotationTag& traffic_annotation) override;
803   int SetReceiveBufferSize(int32_t size) override;
804   int SetSendBufferSize(int32_t size) override;
805 
806   // TransportClientSocket implementation.
807   bool SetNoDelay(bool no_delay) override;
808   bool SetKeepAlive(bool enable, int delay) override;
809 
810   // StreamSocket implementation.
811   void SetBeforeConnectCallback(
812       const BeforeConnectCallback& before_connect_callback) override;
813   int Connect(CompletionOnceCallback callback) override;
814   void Disconnect() override;
815   bool IsConnected() const override;
816   bool IsConnectedAndIdle() const override;
817   int GetPeerAddress(IPEndPoint* address) const override;
818   bool WasEverUsed() const override;
819   bool GetSSLInfo(SSLInfo* ssl_info) override;
820 
821   // AsyncSocket:
822   void OnReadComplete(const MockRead& data) override;
823   void OnWriteComplete(int rv) override;
824   void OnConnectComplete(const MockConnect& data) override;
825   void OnDataProviderDestroyed() override;
826 
set_enable_read_if_ready(bool enable_read_if_ready)827   void set_enable_read_if_ready(bool enable_read_if_ready) {
828     enable_read_if_ready_ = enable_read_if_ready;
829   }
830 
831  private:
832   void RetryRead(int rv);
833   int ReadIfReadyImpl(IOBuffer* buf,
834                       int buf_len,
835                       CompletionOnceCallback callback);
836 
837   // Helper method to run |pending_read_if_ready_callback_| if it is not null.
838   void RunReadIfReadyCallback(int result);
839 
840   AddressList addresses_;
841 
842   raw_ptr<SocketDataProvider> data_;
843   int read_offset_ = 0;
844   MockRead read_data_;
845   bool need_read_data_ = true;
846 
847   // True if the peer has closed the connection.  This allows us to simulate
848   // the recv(..., MSG_PEEK) call in the IsConnectedAndIdle method of the real
849   // TCPClientSocket.
850   bool peer_closed_connection_ = false;
851 
852   // While an asynchronous read is pending, we save our user-buffer state.
853   scoped_refptr<IOBuffer> pending_read_buf_ = nullptr;
854   int pending_read_buf_len_ = 0;
855   CompletionOnceCallback pending_read_callback_;
856 
857   // Non-null when a ReadIfReady() is pending.
858   CompletionOnceCallback pending_read_if_ready_callback_;
859 
860   CompletionOnceCallback pending_connect_callback_;
861   CompletionOnceCallback pending_write_callback_;
862   bool was_used_to_convey_data_ = false;
863 
864   // If true, ReadIfReady() is enabled; otherwise ReadIfReady() returns
865   // ERR_READ_IF_READY_NOT_IMPLEMENTED.
866   bool enable_read_if_ready_ = false;
867 
868   BeforeConnectCallback before_connect_callback_;
869 };
870 
871 class MockSSLClientSocket : public AsyncSocket, public SSLClientSocket {
872  public:
873   MockSSLClientSocket(std::unique_ptr<StreamSocket> stream_socket,
874                       const HostPortPair& host_and_port,
875                       const SSLConfig& ssl_config,
876                       SSLSocketDataProvider* socket);
877 
878   MockSSLClientSocket(const MockSSLClientSocket&) = delete;
879   MockSSLClientSocket& operator=(const MockSSLClientSocket&) = delete;
880 
881   ~MockSSLClientSocket() override;
882 
883   // Socket implementation.
884   int Read(IOBuffer* buf,
885            int buf_len,
886            CompletionOnceCallback callback) override;
887   int ReadIfReady(IOBuffer* buf,
888                   int buf_len,
889                   CompletionOnceCallback callback) override;
890   int Write(IOBuffer* buf,
891             int buf_len,
892             CompletionOnceCallback callback,
893             const NetworkTrafficAnnotationTag& traffic_annotation) override;
894   int CancelReadIfReady() override;
895 
896   // StreamSocket implementation.
897   int Connect(CompletionOnceCallback callback) override;
898   void Disconnect() override;
899   int ConfirmHandshake(CompletionOnceCallback callback) override;
900   bool IsConnected() const override;
901   bool IsConnectedAndIdle() const override;
902   bool WasEverUsed() const override;
903   int GetPeerAddress(IPEndPoint* address) const override;
904   int GetLocalAddress(IPEndPoint* address) const override;
905   NextProto GetNegotiatedProtocol() const override;
906   absl::optional<base::StringPiece> GetPeerApplicationSettings() const override;
907   bool GetSSLInfo(SSLInfo* ssl_info) override;
908   void GetSSLCertRequestInfo(
909       SSLCertRequestInfo* cert_request_info) const override;
910   void ApplySocketTag(const SocketTag& tag) override;
911   const NetLogWithSource& NetLog() const override;
912   int64_t GetTotalReceivedBytes() const override;
913   int SetReceiveBufferSize(int32_t size) override;
914   int SetSendBufferSize(int32_t size) override;
915 
916   // SSLSocket implementation.
917   int ExportKeyingMaterial(base::StringPiece label,
918                            bool has_context,
919                            base::StringPiece context,
920                            unsigned char* out,
921                            unsigned int outlen) override;
922 
923   // SSLClientSocket implementation.
924   std::vector<uint8_t> GetECHRetryConfigs() override;
925 
926   // This MockSocket does not implement the manual async IO feature.
927   void OnReadComplete(const MockRead& data) override;
928   void OnWriteComplete(int rv) override;
929   void OnConnectComplete(const MockConnect& data) override;
930   // SSL sockets don't need magic to deal with destruction of their data
931   // provider.
932   // TODO(mmenke):  Probably a good idea to support it, anyways.
OnDataProviderDestroyed()933   void OnDataProviderDestroyed() override {}
934 
935  private:
936   static void ConnectCallback(MockSSLClientSocket* ssl_client_socket,
937                               CompletionOnceCallback callback,
938                               int rv);
939 
940   void RunCallbackAsync(CompletionOnceCallback callback, int result);
941   void RunCallback(CompletionOnceCallback callback, int result);
942 
943   void RunConfirmHandshakeCallback(CompletionOnceCallback callback, int result);
944 
945   bool connected_ = false;
946   bool in_confirm_handshake_ = false;
947   NetLogWithSource net_log_;
948   std::unique_ptr<StreamSocket> stream_socket_;
949   raw_ptr<SSLSocketDataProvider, AcrossTasksDanglingUntriaged> data_;
950   // Address of the "remote" peer we're connected to.
951   IPEndPoint peer_addr_;
952 
953   base::WeakPtrFactory<MockSSLClientSocket> weak_factory_{this};
954 };
955 
956 class MockUDPClientSocket : public DatagramClientSocket, public AsyncSocket {
957  public:
958   explicit MockUDPClientSocket(SocketDataProvider* data = nullptr,
959                                net::NetLog* net_log = nullptr);
960 
961   MockUDPClientSocket(const MockUDPClientSocket&) = delete;
962   MockUDPClientSocket& operator=(const MockUDPClientSocket&) = delete;
963 
964   ~MockUDPClientSocket() override;
965 
966   // Socket implementation.
967   int Read(IOBuffer* buf,
968            int buf_len,
969            CompletionOnceCallback callback) override;
970   int Write(IOBuffer* buf,
971             int buf_len,
972             CompletionOnceCallback callback,
973             const NetworkTrafficAnnotationTag& traffic_annotation) override;
974 
975   int SetReceiveBufferSize(int32_t size) override;
976   int SetSendBufferSize(int32_t size) override;
977   int SetDoNotFragment() override;
978   int SetRecvEcn() override;
979 
980   // DatagramSocket implementation.
981   void Close() override;
982   int GetPeerAddress(IPEndPoint* address) const override;
983   int GetLocalAddress(IPEndPoint* address) const override;
984   void UseNonBlockingIO() override;
985   int SetMulticastInterface(uint32_t interface_index) override;
986   const NetLogWithSource& NetLog() const override;
987 
988   // DatagramClientSocket implementation.
989   int Connect(const IPEndPoint& address) override;
990   int ConnectUsingNetwork(handles::NetworkHandle network,
991                           const IPEndPoint& address) override;
992   int ConnectUsingDefaultNetwork(const IPEndPoint& address) override;
993   int ConnectAsync(const IPEndPoint& address,
994                    CompletionOnceCallback callback) override;
995   int ConnectUsingNetworkAsync(handles::NetworkHandle network,
996                                const IPEndPoint& address,
997                                CompletionOnceCallback callback) override;
998   int ConnectUsingDefaultNetworkAsync(const IPEndPoint& address,
999                                       CompletionOnceCallback callback) override;
1000   handles::NetworkHandle GetBoundNetwork() const override;
1001   void ApplySocketTag(const SocketTag& tag) override;
SetMsgConfirm(bool confirm)1002   void SetMsgConfirm(bool confirm) override {}
1003 
1004   // AsyncSocket implementation.
1005   void OnReadComplete(const MockRead& data) override;
1006   void OnWriteComplete(int rv) override;
1007   void OnConnectComplete(const MockConnect& data) override;
1008   void OnDataProviderDestroyed() override;
1009 
set_source_port(uint16_t port)1010   void set_source_port(uint16_t port) { source_port_ = port; }
source_port()1011   uint16_t source_port() const { return source_port_; }
set_source_host(IPAddress addr)1012   void set_source_host(IPAddress addr) { source_host_ = addr; }
source_host()1013   IPAddress source_host() const { return source_host_; }
1014 
1015   // Returns last tag applied to socket.
tag()1016   SocketTag tag() const { return tag_; }
1017 
1018   // Returns false if socket's tag was changed after the socket was used for
1019   // data transfer (e.g. Read/Write() called), otherwise returns true.
tagged_before_data_transferred()1020   bool tagged_before_data_transferred() const {
1021     return tagged_before_data_transferred_;
1022   }
1023 
1024  private:
1025   int CompleteRead();
1026 
1027   void RunCallbackAsync(CompletionOnceCallback callback, int result);
1028   void RunCallback(CompletionOnceCallback callback, int result);
1029 
1030   bool connected_ = false;
1031   raw_ptr<SocketDataProvider> data_;
1032   int read_offset_ = 0;
1033   MockRead read_data_;
1034   bool need_read_data_ = true;
1035   IPAddress source_host_;
1036   uint16_t source_port_ = 123;  // Ephemeral source port.
1037 
1038   // Address of the "remote" peer we're connected to.
1039   IPEndPoint peer_addr_;
1040 
1041   // Network that the socket is bound to.
1042   handles::NetworkHandle network_ = handles::kInvalidNetworkHandle;
1043 
1044   // While an asynchronous IO is pending, we save our user-buffer state.
1045   scoped_refptr<IOBuffer> pending_read_buf_ = nullptr;
1046   int pending_read_buf_len_ = 0;
1047   CompletionOnceCallback pending_read_callback_;
1048   CompletionOnceCallback pending_write_callback_;
1049 
1050   NetLogWithSource net_log_;
1051 
1052   DatagramBuffers unwritten_buffers_;
1053 
1054   SocketTag tag_;
1055   bool data_transferred_ = false;
1056   bool tagged_before_data_transferred_ = true;
1057 
1058   base::WeakPtrFactory<MockUDPClientSocket> weak_factory_{this};
1059 };
1060 
1061 class TestSocketRequest : public TestCompletionCallbackBase {
1062  public:
1063   TestSocketRequest(std::vector<TestSocketRequest*>* request_order,
1064                     size_t* completion_count);
1065 
1066   TestSocketRequest(const TestSocketRequest&) = delete;
1067   TestSocketRequest& operator=(const TestSocketRequest&) = delete;
1068 
1069   ~TestSocketRequest() override;
1070 
handle()1071   ClientSocketHandle* handle() { return &handle_; }
1072 
callback()1073   CompletionOnceCallback callback() {
1074     return base::BindOnce(&TestSocketRequest::OnComplete,
1075                           base::Unretained(this));
1076   }
1077 
1078  private:
1079   void OnComplete(int result);
1080 
1081   ClientSocketHandle handle_;
1082   raw_ptr<std::vector<TestSocketRequest*>> request_order_;
1083   raw_ptr<size_t> completion_count_;
1084 };
1085 
1086 class ClientSocketPoolTest {
1087  public:
1088   enum KeepAlive {
1089     KEEP_ALIVE,
1090 
1091     // A socket will be disconnected in addition to handle being reset.
1092     NO_KEEP_ALIVE,
1093   };
1094 
1095   static const int kIndexOutOfBounds;
1096   static const int kRequestNotFound;
1097 
1098   ClientSocketPoolTest();
1099 
1100   ClientSocketPoolTest(const ClientSocketPoolTest&) = delete;
1101   ClientSocketPoolTest& operator=(const ClientSocketPoolTest&) = delete;
1102 
1103   ~ClientSocketPoolTest();
1104 
1105   template <typename PoolType>
StartRequestUsingPool(PoolType * socket_pool,const ClientSocketPool::GroupId & group_id,RequestPriority priority,ClientSocketPool::RespectLimits respect_limits,const scoped_refptr<typename PoolType::SocketParams> & socket_params)1106   int StartRequestUsingPool(
1107       PoolType* socket_pool,
1108       const ClientSocketPool::GroupId& group_id,
1109       RequestPriority priority,
1110       ClientSocketPool::RespectLimits respect_limits,
1111       const scoped_refptr<typename PoolType::SocketParams>& socket_params) {
1112     DCHECK(socket_pool);
1113     TestSocketRequest* request(
1114         new TestSocketRequest(&request_order_, &completion_count_));
1115     requests_.push_back(base::WrapUnique(request));
1116     int rv = request->handle()->Init(
1117         group_id, socket_params, absl::nullopt /* proxy_annotation_tag */,
1118         priority, SocketTag(), respect_limits, request->callback(),
1119         ClientSocketPool::ProxyAuthCallback(), socket_pool, NetLogWithSource());
1120     if (rv != ERR_IO_PENDING)
1121       request_order_.push_back(request);
1122     return rv;
1123   }
1124 
1125   // Provided there were n requests started, takes |index| in range 1..n
1126   // and returns order in which that request completed, in range 1..n,
1127   // or kIndexOutOfBounds if |index| is out of bounds, or kRequestNotFound
1128   // if that request did not complete (for example was canceled).
1129   int GetOrderOfRequest(size_t index) const;
1130 
1131   // Resets first initialized socket handle from |requests_|. If found such
1132   // a handle, returns true.
1133   bool ReleaseOneConnection(KeepAlive keep_alive);
1134 
1135   // Releases connections until there is nothing to release.
1136   void ReleaseAllConnections(KeepAlive keep_alive);
1137 
1138   // Note that this uses 0-based indices, while GetOrderOfRequest takes and
1139   // returns 1-based indices.
request(int i)1140   TestSocketRequest* request(int i) { return requests_[i].get(); }
1141 
requests_size()1142   size_t requests_size() const { return requests_.size(); }
requests()1143   std::vector<std::unique_ptr<TestSocketRequest>>* requests() {
1144     return &requests_;
1145   }
completion_count()1146   size_t completion_count() const { return completion_count_; }
1147 
1148  private:
1149   std::vector<std::unique_ptr<TestSocketRequest>> requests_;
1150   std::vector<TestSocketRequest*> request_order_;
1151   size_t completion_count_ = 0;
1152 };
1153 
1154 class MockTransportSocketParams
1155     : public base::RefCounted<MockTransportSocketParams> {
1156  public:
1157   MockTransportSocketParams(const MockTransportSocketParams&) = delete;
1158   MockTransportSocketParams& operator=(const MockTransportSocketParams&) =
1159       delete;
1160 
1161  private:
1162   friend class base::RefCounted<MockTransportSocketParams>;
1163   ~MockTransportSocketParams() = default;
1164 };
1165 
1166 class MockTransportClientSocketPool : public TransportClientSocketPool {
1167  public:
1168   class MockConnectJob {
1169    public:
1170     MockConnectJob(std::unique_ptr<StreamSocket> socket,
1171                    ClientSocketHandle* handle,
1172                    const SocketTag& socket_tag,
1173                    CompletionOnceCallback callback,
1174                    RequestPriority priority);
1175 
1176     MockConnectJob(const MockConnectJob&) = delete;
1177     MockConnectJob& operator=(const MockConnectJob&) = delete;
1178 
1179     ~MockConnectJob();
1180 
1181     int Connect();
1182     bool CancelHandle(const ClientSocketHandle* handle);
1183 
handle()1184     ClientSocketHandle* handle() const { return handle_; }
1185 
priority()1186     RequestPriority priority() const { return priority_; }
set_priority(RequestPriority priority)1187     void set_priority(RequestPriority priority) { priority_ = priority; }
1188 
1189    private:
1190     void OnConnect(int rv);
1191 
1192     std::unique_ptr<StreamSocket> socket_;
1193     raw_ptr<ClientSocketHandle> handle_;
1194     const SocketTag socket_tag_;
1195     CompletionOnceCallback user_callback_;
1196     RequestPriority priority_;
1197   };
1198 
1199   MockTransportClientSocketPool(
1200       int max_sockets,
1201       int max_sockets_per_group,
1202       const CommonConnectJobParams* common_connect_job_params);
1203 
1204   MockTransportClientSocketPool(const MockTransportClientSocketPool&) = delete;
1205   MockTransportClientSocketPool& operator=(
1206       const MockTransportClientSocketPool&) = delete;
1207 
1208   ~MockTransportClientSocketPool() override;
1209 
last_request_priority()1210   RequestPriority last_request_priority() const {
1211     return last_request_priority_;
1212   }
1213 
requests()1214   const std::vector<std::unique_ptr<MockConnectJob>>& requests() const {
1215     return job_list_;
1216   }
1217 
release_count()1218   int release_count() const { return release_count_; }
cancel_count()1219   int cancel_count() const { return cancel_count_; }
1220 
1221   // TransportClientSocketPool implementation.
1222   int RequestSocket(
1223       const GroupId& group_id,
1224       scoped_refptr<ClientSocketPool::SocketParams> socket_params,
1225       const absl::optional<NetworkTrafficAnnotationTag>& proxy_annotation_tag,
1226       RequestPriority priority,
1227       const SocketTag& socket_tag,
1228       RespectLimits respect_limits,
1229       ClientSocketHandle* handle,
1230       CompletionOnceCallback callback,
1231       const ProxyAuthCallback& on_auth_callback,
1232       const NetLogWithSource& net_log) override;
1233   void SetPriority(const GroupId& group_id,
1234                    ClientSocketHandle* handle,
1235                    RequestPriority priority) override;
1236   void CancelRequest(const GroupId& group_id,
1237                      ClientSocketHandle* handle,
1238                      bool cancel_connect_job) override;
1239   void ReleaseSocket(const GroupId& group_id,
1240                      std::unique_ptr<StreamSocket> socket,
1241                      int64_t generation) override;
1242 
1243  private:
1244   raw_ptr<ClientSocketFactory> client_socket_factory_;
1245   std::vector<std::unique_ptr<MockConnectJob>> job_list_;
1246   RequestPriority last_request_priority_ = DEFAULT_PRIORITY;
1247   int release_count_ = 0;
1248   int cancel_count_ = 0;
1249 };
1250 
1251 // WrappedStreamSocket is a base class that wraps an existing StreamSocket,
1252 // forwarding the Socket and StreamSocket interfaces to the underlying
1253 // transport.
1254 // This is to provide a common base class for subclasses to override specific
1255 // StreamSocket methods for testing, while still communicating with a 'real'
1256 // StreamSocket.
1257 class WrappedStreamSocket : public TransportClientSocket {
1258  public:
1259   explicit WrappedStreamSocket(std::unique_ptr<StreamSocket> transport);
1260   ~WrappedStreamSocket() override;
1261 
1262   // StreamSocket implementation:
1263   int Bind(const net::IPEndPoint& local_addr) override;
1264   int Connect(CompletionOnceCallback callback) override;
1265   void Disconnect() override;
1266   bool IsConnected() const override;
1267   bool IsConnectedAndIdle() const override;
1268   int GetPeerAddress(IPEndPoint* address) const override;
1269   int GetLocalAddress(IPEndPoint* address) const override;
1270   const NetLogWithSource& NetLog() const override;
1271   bool WasEverUsed() const override;
1272   NextProto GetNegotiatedProtocol() const override;
1273   bool GetSSLInfo(SSLInfo* ssl_info) override;
1274   int64_t GetTotalReceivedBytes() const override;
1275   void ApplySocketTag(const SocketTag& tag) override;
1276 
1277   // Socket implementation:
1278   int Read(IOBuffer* buf,
1279            int buf_len,
1280            CompletionOnceCallback callback) override;
1281   int ReadIfReady(IOBuffer* buf,
1282                   int buf_len,
1283                   CompletionOnceCallback callback) override;
1284   int Write(IOBuffer* buf,
1285             int buf_len,
1286             CompletionOnceCallback callback,
1287             const NetworkTrafficAnnotationTag& traffic_annotation) override;
1288   int SetReceiveBufferSize(int32_t size) override;
1289   int SetSendBufferSize(int32_t size) override;
1290 
1291  protected:
1292   std::unique_ptr<StreamSocket> transport_;
1293 };
1294 
1295 // StreamSocket that wraps another StreamSocket, but keeps track of any
1296 // SocketTag applied to the socket.
1297 class MockTaggingStreamSocket : public WrappedStreamSocket {
1298  public:
MockTaggingStreamSocket(std::unique_ptr<StreamSocket> transport)1299   explicit MockTaggingStreamSocket(std::unique_ptr<StreamSocket> transport)
1300       : WrappedStreamSocket(std::move(transport)) {}
1301 
1302   MockTaggingStreamSocket(const MockTaggingStreamSocket&) = delete;
1303   MockTaggingStreamSocket& operator=(const MockTaggingStreamSocket&) = delete;
1304 
1305   ~MockTaggingStreamSocket() override = default;
1306 
1307   // StreamSocket implementation.
1308   int Connect(CompletionOnceCallback callback) override;
1309   void ApplySocketTag(const SocketTag& tag) override;
1310 
1311   // Returns false if socket's tag was changed after the socket was connected,
1312   // otherwise returns true.
tagged_before_connected()1313   bool tagged_before_connected() const { return tagged_before_connected_; }
1314 
1315   // Returns last tag applied to socket.
tag()1316   SocketTag tag() const { return tag_; }
1317 
1318  private:
1319   bool connected_ = false;
1320   bool tagged_before_connected_ = true;
1321   SocketTag tag_;
1322 };
1323 
1324 // Extend MockClientSocketFactory to return MockTaggingStreamSockets and
1325 // keep track of last socket produced for test inspection.
1326 class MockTaggingClientSocketFactory : public MockClientSocketFactory {
1327  public:
1328   MockTaggingClientSocketFactory() = default;
1329 
1330   MockTaggingClientSocketFactory(const MockTaggingClientSocketFactory&) =
1331       delete;
1332   MockTaggingClientSocketFactory& operator=(
1333       const MockTaggingClientSocketFactory&) = delete;
1334 
1335   // ClientSocketFactory implementation.
1336   std::unique_ptr<DatagramClientSocket> CreateDatagramClientSocket(
1337       DatagramSocket::BindType bind_type,
1338       NetLog* net_log,
1339       const NetLogSource& source) override;
1340   std::unique_ptr<TransportClientSocket> CreateTransportClientSocket(
1341       const AddressList& addresses,
1342       std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,
1343       NetworkQualityEstimator* network_quality_estimator,
1344       NetLog* net_log,
1345       const NetLogSource& source) override;
1346 
1347   // These methods return pointers to last TCP and UDP sockets produced by this
1348   // factory. NOTE: Socket must still exist, or pointer will be to freed memory.
GetLastProducedTCPSocket()1349   MockTaggingStreamSocket* GetLastProducedTCPSocket() const {
1350     return tcp_socket_;
1351   }
GetLastProducedUDPSocket()1352   MockUDPClientSocket* GetLastProducedUDPSocket() const { return udp_socket_; }
1353 
1354  private:
1355   raw_ptr<MockTaggingStreamSocket, AcrossTasksDanglingUntriaged> tcp_socket_ =
1356       nullptr;
1357   raw_ptr<MockUDPClientSocket, AcrossTasksDanglingUntriaged> udp_socket_ =
1358       nullptr;
1359 };
1360 
// Host and port that the SOCKS4 test strings below refer to.
extern const char kSOCKS4TestHost[];
extern const int kSOCKS4TestPort;

// Byte strings for a successful SOCKS v4 handshake: the request (connecting
// to kSOCKS4TestHost on kSOCKS4TestPort) and the corresponding OK reply.
// Each array has a matching *Length constant.
extern const char kSOCKS4OkRequestLocalHostPort80[];
extern const int kSOCKS4OkRequestLocalHostPort80Length;

extern const char kSOCKS4OkReply[];
extern const int kSOCKS4OkReplyLength;

// Host and port that the SOCKS5 test strings below refer to.
extern const char kSOCKS5TestHost[];
extern const int kSOCKS5TestPort;

// Byte strings for a successful SOCKS v5 handshake (connecting to
// kSOCKS5TestHost on kSOCKS5TestPort): the greeting request and response,
// then the connect request and response. Each array has a matching *Length
// constant.
extern const char kSOCKS5GreetRequest[];
extern const int kSOCKS5GreetRequestLength;

extern const char kSOCKS5GreetResponse[];
extern const int kSOCKS5GreetResponseLength;

extern const char kSOCKS5OkRequest[];
extern const int kSOCKS5OkRequestLength;

extern const char kSOCKS5OkResponse[];
extern const int kSOCKS5OkResponseLength;
1390 
1391 // Helper function to get the total data size of the MockReads in |reads|.
1392 int64_t CountReadBytes(base::span<const MockRead> reads);
1393 
1394 // Helper function to get the total data size of the MockWrites in |writes|.
1395 int64_t CountWriteBytes(base::span<const MockWrite> writes);
1396 
1397 #if BUILDFLAG(IS_ANDROID)
1398 // Returns whether the device supports calling GetTaggedBytes().
1399 bool CanGetTaggedBytes();
1400 
1401 // Query the system to find out how many bytes were received with tag
1402 // |expected_tag| for our UID.  Return the count of received bytes.
1403 uint64_t GetTaggedBytes(int32_t expected_tag);
1404 #endif
1405 
1406 }  // namespace net
1407 
1408 #endif  // NET_SOCKET_SOCKET_TEST_UTIL_H_
1409