1 // Copyright 2012 The Chromium Authors 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #ifndef NET_SOCKET_SOCKET_TEST_UTIL_H_ 6 #define NET_SOCKET_SOCKET_TEST_UTIL_H_ 7 8 #include <stddef.h> 9 #include <stdint.h> 10 11 #include <cstring> 12 #include <memory> 13 #include <string> 14 #include <utility> 15 #include <vector> 16 17 #include "base/check_op.h" 18 #include "base/containers/span.h" 19 #include "base/functional/bind.h" 20 #include "base/functional/callback.h" 21 #include "base/memory/ptr_util.h" 22 #include "base/memory/raw_ptr.h" 23 #include "base/memory/raw_ptr_exclusion.h" 24 #include "base/memory/ref_counted.h" 25 #include "base/memory/weak_ptr.h" 26 #include "build/build_config.h" 27 #include "net/base/address_list.h" 28 #include "net/base/completion_once_callback.h" 29 #include "net/base/io_buffer.h" 30 #include "net/base/net_errors.h" 31 #include "net/base/test_completion_callback.h" 32 #include "net/http/http_auth_controller.h" 33 #include "net/log/net_log_with_source.h" 34 #include "net/socket/client_socket_factory.h" 35 #include "net/socket/client_socket_handle.h" 36 #include "net/socket/client_socket_pool.h" 37 #include "net/socket/datagram_client_socket.h" 38 #include "net/socket/socket_performance_watcher.h" 39 #include "net/socket/socket_tag.h" 40 #include "net/socket/ssl_client_socket.h" 41 #include "net/socket/transport_client_socket.h" 42 #include "net/socket/transport_client_socket_pool.h" 43 #include "net/ssl/ssl_config_service.h" 44 #include "net/ssl/ssl_info.h" 45 #include "testing/gtest/include/gtest/gtest.h" 46 #include "third_party/abseil-cpp/absl/types/optional.h" 47 48 namespace base { 49 class RunLoop; 50 } 51 52 namespace net { 53 54 struct CommonConnectJobParams; 55 class NetLog; 56 struct NetworkTrafficAnnotationTag; 57 class X509Certificate; 58 59 const handles::NetworkHandle kDefaultNetworkForTests = 1; 60 const handles::NetworkHandle kNewNetworkForTests = 2; 61 62 enum { 63 // A private network error code used by the socket test utility classes. 64 // If the |result| member of a MockRead is 65 // ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ, that MockRead is just a 66 // marker that indicates the peer will close the connection after the next 67 // MockRead. The other members of that MockRead are ignored. 68 ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ = -10000, 69 }; 70 71 class AsyncSocket; 72 class MockClientSocket; 73 class SSLClientSocket; 74 class StreamSocket; 75 76 enum IoMode { ASYNC, SYNCHRONOUS }; 77 78 struct MockConnect { 79 // Asynchronous connection success. 80 // Creates a MockConnect with |mode| ASYC, |result| OK, and 81 // |peer_addr| 192.0.2.33. 82 MockConnect(); 83 // Creates a MockConnect with the specified mode and result, with 84 // |peer_addr| 192.0.2.33. 85 MockConnect(IoMode io_mode, int r); 86 MockConnect(IoMode io_mode, int r, IPEndPoint addr); 87 MockConnect(IoMode io_mode, int r, IPEndPoint addr, bool first_attempt_fails); 88 ~MockConnect(); 89 90 IoMode mode; 91 int result; 92 IPEndPoint peer_addr; 93 bool first_attempt_fails = false; 94 }; 95 96 struct MockConfirm { 97 // Asynchronous confirm success. 98 // Creates a MockConfirm with |mode| ASYC and |result| OK. 99 MockConfirm(); 100 // Creates a MockConfirm with the specified mode and result. 101 MockConfirm(IoMode io_mode, int r); 102 ~MockConfirm(); 103 104 IoMode mode; 105 int result; 106 }; 107 108 // MockRead and MockWrite shares the same interface and members, but we'd like 109 // to have distinct types because we don't want to have them used 110 // interchangably. To do this, a struct template is defined, and MockRead and 111 // MockWrite are instantiated by using this template. Template parameter |type| 112 // is not used in the struct definition (it purely exists for creating a new 113 // type). 114 // 115 // |data| in MockRead and MockWrite has different meanings: |data| in MockRead 116 // is the data returned from the socket when MockTCPClientSocket::Read() is 117 // attempted, while |data| in MockWrite is the expected data that should be 118 // given in MockTCPClientSocket::Write(). 119 enum MockReadWriteType { MOCK_READ, MOCK_WRITE }; 120 121 template <MockReadWriteType type> 122 struct MockReadWrite { 123 // Flag to indicate that the message loop should be terminated. 124 enum { STOPLOOP = 1 << 31 }; 125 126 // Default MockReadWriteMockReadWrite127 MockReadWrite() 128 : mode(SYNCHRONOUS), 129 result(0), 130 data(nullptr), 131 data_len(0), 132 sequence_number(0) {} 133 134 // Read/write failure (no data). MockReadWriteMockReadWrite135 MockReadWrite(IoMode io_mode, int result) 136 : mode(io_mode), 137 result(result), 138 data(nullptr), 139 data_len(0), 140 sequence_number(0) {} 141 142 // Read/write failure (no data), with sequence information. MockReadWriteMockReadWrite143 MockReadWrite(IoMode io_mode, int result, int seq) 144 : mode(io_mode), 145 result(result), 146 data(nullptr), 147 data_len(0), 148 sequence_number(seq) {} 149 150 // Asynchronous read/write success (inferred data length). MockReadWriteMockReadWrite151 explicit MockReadWrite(const char* data) 152 : mode(ASYNC), 153 result(0), 154 data(data), 155 data_len(strlen(data)), 156 sequence_number(0) {} 157 158 // Read/write success (inferred data length). MockReadWriteMockReadWrite159 MockReadWrite(IoMode io_mode, const char* data) 160 : mode(io_mode), 161 result(0), 162 data(data), 163 data_len(strlen(data)), 164 sequence_number(0) {} 165 166 // Read/write success. MockReadWriteMockReadWrite167 MockReadWrite(IoMode io_mode, const char* data, int data_len) 168 : mode(io_mode), 169 result(0), 170 data(data), 171 data_len(data_len), 172 sequence_number(0) {} 173 174 // Read/write success (inferred data length) with sequence information. MockReadWriteMockReadWrite175 MockReadWrite(IoMode io_mode, int seq, const char* data) 176 : mode(io_mode), 177 result(0), 178 data(data), 179 data_len(strlen(data)), 180 sequence_number(seq) {} 181 182 // Read/write success with sequence information. MockReadWriteMockReadWrite183 MockReadWrite(IoMode io_mode, const char* data, int data_len, int seq) 184 : mode(io_mode), 185 result(0), 186 data(data), 187 data_len(data_len), 188 sequence_number(seq) {} 189 190 IoMode mode; 191 int result; 192 const char* data; 193 int data_len; 194 195 // For data providers that only allows reads to occur in a particular 196 // sequence. If a read occurs before the given |sequence_number| is reached, 197 // an ERR_IO_PENDING is returned. 198 int sequence_number; // The sequence number at which a read is allowed 199 // to occur. 200 }; 201 202 typedef MockReadWrite<MOCK_READ> MockRead; 203 typedef MockReadWrite<MOCK_WRITE> MockWrite; 204 205 struct MockWriteResult { MockWriteResultMockWriteResult206 MockWriteResult(IoMode io_mode, int result) : mode(io_mode), result(result) {} 207 208 IoMode mode; 209 int result; 210 }; 211 212 // The SocketDataProvider is an interface used by the MockClientSocket 213 // for getting data about individual reads and writes on the socket. Can be 214 // used with at most one socket at a time. 215 // TODO(mmenke): Do these really need to be re-useable? 216 class SocketDataProvider { 217 public: 218 SocketDataProvider(); 219 220 SocketDataProvider(const SocketDataProvider&) = delete; 221 SocketDataProvider& operator=(const SocketDataProvider&) = delete; 222 223 virtual ~SocketDataProvider(); 224 225 // Returns the buffer and result code for the next simulated read. 226 // If the |MockRead.result| is ERR_IO_PENDING, it informs the caller 227 // that it will be called via the AsyncSocket::OnReadComplete() 228 // function at a later time. 229 virtual MockRead OnRead() = 0; 230 virtual MockWriteResult OnWrite(const std::string& data) = 0; 231 virtual bool AllReadDataConsumed() const = 0; 232 virtual bool AllWriteDataConsumed() const = 0; CancelPendingRead()233 virtual void CancelPendingRead() {} 234 235 // Returns the last set receive buffer size, or -1 if never set. receive_buffer_size()236 int receive_buffer_size() const { return receive_buffer_size_; } set_receive_buffer_size(int receive_buffer_size)237 void set_receive_buffer_size(int receive_buffer_size) { 238 receive_buffer_size_ = receive_buffer_size; 239 } 240 241 // Returns the last set send buffer size, or -1 if never set. send_buffer_size()242 int send_buffer_size() const { return send_buffer_size_; } set_send_buffer_size(int send_buffer_size)243 void set_send_buffer_size(int send_buffer_size) { 244 send_buffer_size_ = send_buffer_size; 245 } 246 247 // Returns the last set value of TCP no delay, or false if never set. no_delay()248 bool no_delay() const { return no_delay_; } set_no_delay(bool no_delay)249 void set_no_delay(bool no_delay) { no_delay_ = no_delay; } 250 251 // Returns whether TCP keepalives were enabled or not. Returns kDefault by 252 // default. 253 enum class KeepAliveState { kEnabled, kDisabled, kDefault }; keep_alive_state()254 KeepAliveState keep_alive_state() const { return keep_alive_state_; } 255 // Last set TCP keepalive delay. keep_alive_delay()256 int keep_alive_delay() const { return keep_alive_delay_; } set_keep_alive(bool enable,int delay)257 void set_keep_alive(bool enable, int delay) { 258 keep_alive_state_ = 259 enable ? KeepAliveState::kEnabled : KeepAliveState::kDisabled; 260 keep_alive_delay_ = delay; 261 } 262 263 // Setters / getters for the return values of the corresponding Set*() 264 // methods. By default, they all succeed, if the socket is connected. 265 set_set_receive_buffer_size_result(int receive_buffer_size_result)266 void set_set_receive_buffer_size_result(int receive_buffer_size_result) { 267 set_receive_buffer_size_result_ = receive_buffer_size_result; 268 } set_receive_buffer_size_result()269 int set_receive_buffer_size_result() const { 270 return set_receive_buffer_size_result_; 271 } 272 set_set_send_buffer_size_result(int set_send_buffer_size_result)273 void set_set_send_buffer_size_result(int set_send_buffer_size_result) { 274 set_send_buffer_size_result_ = set_send_buffer_size_result; 275 } set_send_buffer_size_result()276 int set_send_buffer_size_result() const { 277 return set_send_buffer_size_result_; 278 } 279 set_set_no_delay_result(bool set_no_delay_result)280 void set_set_no_delay_result(bool set_no_delay_result) { 281 set_no_delay_result_ = set_no_delay_result; 282 } set_no_delay_result()283 bool set_no_delay_result() const { return set_no_delay_result_; } 284 set_set_keep_alive_result(bool set_keep_alive_result)285 void set_set_keep_alive_result(bool set_keep_alive_result) { 286 set_keep_alive_result_ = set_keep_alive_result; 287 } set_keep_alive_result()288 bool set_keep_alive_result() const { return set_keep_alive_result_; } 289 expected_addresses()290 const absl::optional<AddressList>& expected_addresses() const { 291 return expected_addresses_; 292 } set_expected_addresses(net::AddressList addresses)293 void set_expected_addresses(net::AddressList addresses) { 294 expected_addresses_ = std::move(addresses); 295 } 296 297 // Returns true if the request should be considered idle, for the purposes of 298 // IsConnectedAndIdle. 299 virtual bool IsIdle() const; 300 301 // Initializes the SocketDataProvider for use with |socket|. Must be called 302 // before use 303 void Initialize(AsyncSocket* socket); 304 // Detaches the socket associated with a SocketDataProvider. Must be called 305 // before |socket_| is destroyed, unless the SocketDataProvider has informed 306 // |socket_| it was destroyed. Must also be called before Initialize() may 307 // be called again with a new socket. 308 void DetachSocket(); 309 310 // Accessor for the socket which is using the SocketDataProvider. socket()311 AsyncSocket* socket() { return socket_; } 312 connect_data()313 MockConnect connect_data() const { return connect_; } set_connect_data(const MockConnect & connect)314 void set_connect_data(const MockConnect& connect) { connect_ = connect; } 315 316 private: 317 // Called to inform subclasses of initialization. 318 virtual void Reset() = 0; 319 320 MockConnect connect_; 321 raw_ptr<AsyncSocket> socket_ = nullptr; 322 323 int receive_buffer_size_ = -1; 324 int send_buffer_size_ = -1; 325 // This reflects the default state of TCPClientSockets. 326 bool no_delay_ = true; 327 328 KeepAliveState keep_alive_state_ = KeepAliveState::kDefault; 329 int keep_alive_delay_ = 0; 330 331 int set_receive_buffer_size_result_ = net::OK; 332 int set_send_buffer_size_result_ = net::OK; 333 bool set_no_delay_result_ = true; 334 bool set_keep_alive_result_ = true; 335 absl::optional<AddressList> expected_addresses_; 336 }; 337 338 // The AsyncSocket is an interface used by the SocketDataProvider to 339 // complete the asynchronous read operation. 340 class AsyncSocket { 341 public: 342 // If an async IO is pending because the SocketDataProvider returned 343 // ERR_IO_PENDING, then the AsyncSocket waits until this OnReadComplete 344 // is called to complete the asynchronous read operation. 345 // data.async is ignored, and this read is completed synchronously as 346 // part of this call. 347 // TODO(rch): this should take a StringPiece since most of the fields 348 // are ignored. 349 virtual void OnReadComplete(const MockRead& data) = 0; 350 // If an async IO is pending because the SocketDataProvider returned 351 // ERR_IO_PENDING, then the AsyncSocket waits until this OnReadComplete 352 // is called to complete the asynchronous read operation. 353 virtual void OnWriteComplete(int rv) = 0; 354 virtual void OnConnectComplete(const MockConnect& data) = 0; 355 356 // Called when the SocketDataProvider associated with the socket is destroyed. 357 // The socket may continue to be used after the data provider is destroyed, 358 // so it should be sure not to dereference the provider after this is called. 359 virtual void OnDataProviderDestroyed() = 0; 360 }; 361 362 class SocketDataPrinter { 363 public: 364 ~SocketDataPrinter() = default; 365 366 // Prints the write in |data| using some sort of protocol-specific 367 // format. 368 virtual std::string PrintWrite(const std::string& data) = 0; 369 }; 370 371 // StaticSocketDataHelper manages a list of reads and writes. 372 class StaticSocketDataHelper { 373 public: 374 StaticSocketDataHelper(base::span<const MockRead> reads, 375 base::span<const MockWrite> writes); 376 377 StaticSocketDataHelper(const StaticSocketDataHelper&) = delete; 378 StaticSocketDataHelper& operator=(const StaticSocketDataHelper&) = delete; 379 380 ~StaticSocketDataHelper(); 381 382 // These functions get access to the next available read and write data. They 383 // CHECK fail if there is no data available. 384 const MockRead& PeekRead() const; 385 const MockWrite& PeekWrite() const; 386 387 // Returns the current read or write, and then advances to the next one. 388 const MockRead& AdvanceRead(); 389 const MockWrite& AdvanceWrite(); 390 391 // Resets the read and write indexes to 0. 392 void Reset(); 393 394 // Returns true if |data| is valid data for the next write. In order 395 // to support short writes, the next write may be longer than |data| 396 // in which case this method will still return true. 397 bool VerifyWriteData(const std::string& data, SocketDataPrinter* printer); 398 read_index()399 size_t read_index() const { return read_index_; } write_index()400 size_t write_index() const { return write_index_; } read_count()401 size_t read_count() const { return reads_.size(); } write_count()402 size_t write_count() const { return writes_.size(); } 403 AllReadDataConsumed()404 bool AllReadDataConsumed() const { return read_index() >= read_count(); } AllWriteDataConsumed()405 bool AllWriteDataConsumed() const { return write_index() >= write_count(); } 406 407 private: 408 // Returns the next available read or write that is not a pause event. CHECK 409 // fails if no data is available. 410 const MockWrite& PeekRealWrite() const; 411 412 const base::span<const MockRead> reads_; 413 size_t read_index_ = 0; 414 const base::span<const MockWrite> writes_; 415 size_t write_index_ = 0; 416 }; 417 418 // SocketDataProvider which responds based on static tables of mock reads and 419 // writes. 420 class StaticSocketDataProvider : public SocketDataProvider { 421 public: 422 StaticSocketDataProvider(); 423 StaticSocketDataProvider(base::span<const MockRead> reads, 424 base::span<const MockWrite> writes); 425 426 StaticSocketDataProvider(const StaticSocketDataProvider&) = delete; 427 StaticSocketDataProvider& operator=(const StaticSocketDataProvider&) = delete; 428 429 ~StaticSocketDataProvider() override; 430 431 // Pause/resume reads from this provider. 432 void Pause(); 433 void Resume(); 434 435 // From SocketDataProvider: 436 MockRead OnRead() override; 437 MockWriteResult OnWrite(const std::string& data) override; 438 bool AllReadDataConsumed() const override; 439 bool AllWriteDataConsumed() const override; 440 read_index()441 size_t read_index() const { return helper_.read_index(); } write_index()442 size_t write_index() const { return helper_.write_index(); } read_count()443 size_t read_count() const { return helper_.read_count(); } write_count()444 size_t write_count() const { return helper_.write_count(); } 445 set_printer(SocketDataPrinter * printer)446 void set_printer(SocketDataPrinter* printer) { printer_ = printer; } 447 448 private: 449 // From SocketDataProvider: 450 void Reset() override; 451 452 StaticSocketDataHelper helper_; 453 // This field is not a raw_ptr<> because it was filtered by the rewriter for: 454 // #union 455 RAW_PTR_EXCLUSION SocketDataPrinter* printer_ = nullptr; 456 bool paused_ = false; 457 }; 458 459 // SSLSocketDataProviders only need to keep track of the return code from calls 460 // to Connect(). 461 struct SSLSocketDataProvider { 462 SSLSocketDataProvider(IoMode mode, int result); 463 SSLSocketDataProvider(const SSLSocketDataProvider& other); 464 ~SSLSocketDataProvider(); 465 466 // Returns whether MockConnect data has been consumed. ConnectDataConsumedSSLSocketDataProvider467 bool ConnectDataConsumed() const { return is_connect_data_consumed; } 468 469 // Returns whether MockConfirm data has been consumed. ConfirmDataConsumedSSLSocketDataProvider470 bool ConfirmDataConsumed() const { return is_confirm_data_consumed; } 471 472 // Returns whether a Write occurred before ConfirmHandshake completed. WriteBeforeConfirmSSLSocketDataProvider473 bool WriteBeforeConfirm() const { return write_called_before_confirm; } 474 475 // Result for Connect(). 476 MockConnect connect; 477 // Callback to run when Connect() is called. This is called at most once per 478 // socket but is repeating because SSLSocketDataProvider is copyable. 479 base::RepeatingClosure connect_callback; 480 481 // Result for ConfirmHandshake(). 482 MockConfirm confirm; 483 // Callback to run when ConfirmHandshake() is called. This is called at most 484 // once per socket but is repeating because SSLSocketDataProvider is 485 // copyable. 486 base::RepeatingClosure confirm_callback; 487 488 // Result for GetNegotiatedProtocol(). 489 NextProto next_proto = kProtoUnknown; 490 491 // Result for GetPeerApplicationSettings(). 492 absl::optional<std::string> peer_application_settings; 493 494 // Result for GetSSLInfo(). 495 SSLInfo ssl_info; 496 497 // Result for GetSSLCertRequestInfo(). 498 scoped_refptr<SSLCertRequestInfo> cert_request_info; 499 500 // Result for GetECHRetryConfigs(). 501 std::vector<uint8_t> ech_retry_configs; 502 503 absl::optional<NextProtoVector> next_protos_expected_in_ssl_config; 504 505 uint16_t expected_ssl_version_min; 506 uint16_t expected_ssl_version_max; 507 absl::optional<bool> expected_send_client_cert; 508 scoped_refptr<X509Certificate> expected_client_cert; 509 absl::optional<HostPortPair> expected_host_and_port; 510 absl::optional<bool> expected_ignore_certificate_errors; 511 absl::optional<NetworkAnonymizationKey> expected_network_anonymization_key; 512 absl::optional<bool> expected_disable_sha1_server_signatures; 513 absl::optional<std::vector<uint8_t>> expected_ech_config_list; 514 515 bool is_connect_data_consumed = false; 516 bool is_confirm_data_consumed = false; 517 bool write_called_before_confirm = false; 518 }; 519 520 // Uses the sequence_number field in the mock reads and writes to 521 // complete the operations in a specified order. 522 class SequencedSocketData : public SocketDataProvider { 523 public: 524 SequencedSocketData(); 525 526 // |reads| is the list of MockRead completions. 527 // |writes| is the list of MockWrite completions. 528 SequencedSocketData(base::span<const MockRead> reads, 529 base::span<const MockWrite> writes); 530 531 // |connect| is the result for the connect phase. 532 // |reads| is the list of MockRead completions. 533 // |writes| is the list of MockWrite completions. 534 SequencedSocketData(const MockConnect& connect, 535 base::span<const MockRead> reads, 536 base::span<const MockWrite> writes); 537 538 SequencedSocketData(const SequencedSocketData&) = delete; 539 SequencedSocketData& operator=(const SequencedSocketData&) = delete; 540 541 ~SequencedSocketData() override; 542 543 // From SocketDataProvider: 544 MockRead OnRead() override; 545 MockWriteResult OnWrite(const std::string& data) override; 546 bool AllReadDataConsumed() const override; 547 bool AllWriteDataConsumed() const override; 548 bool IsIdle() const override; 549 void CancelPendingRead() override; 550 551 // An ASYNC read event with a return value of ERR_IO_PENDING will cause the 552 // socket data to pause at that event, and advance no further, until Resume is 553 // invoked. At that point, the socket will continue at the next event in the 554 // sequence. 555 // 556 // If a request just wants to simulate a connection that stays open and never 557 // receives any more data, instead of pausing and then resuming a request, it 558 // should use a SYNCHRONOUS event with a return value of ERR_IO_PENDING 559 // instead. 560 bool IsPaused() const; 561 // Resumes events once |this| is in the paused state. The next event will 562 // occur synchronously with the call if it can. 563 void Resume(); 564 void RunUntilPaused(); 565 566 // When true, IsConnectedAndIdle() will return false if the next event in the 567 // sequence is a synchronous. Otherwise, the socket claims to be idle as 568 // long as it's connected. Defaults to false. 569 // TODO(mmenke): See if this can be made the default behavior, and consider 570 // removing this mehtod. Need to make sure it doesn't change what code any 571 // tests are targetted at testing. set_busy_before_sync_reads(bool busy_before_sync_reads)572 void set_busy_before_sync_reads(bool busy_before_sync_reads) { 573 busy_before_sync_reads_ = busy_before_sync_reads; 574 } 575 set_printer(SocketDataPrinter * printer)576 void set_printer(SocketDataPrinter* printer) { printer_ = printer; } 577 578 private: 579 // Defines the state for the read or write path. 580 enum class IoState { 581 kIdle, // No async operation is in progress. 582 kPending, // An async operation in waiting for another operation to 583 // complete. 584 kCompleting, // A task has been posted to complete an async operation. 585 kPaused, // IO is paused until Resume() is called. 586 }; 587 588 // From SocketDataProvider: 589 void Reset() override; 590 591 void OnReadComplete(); 592 void OnWriteComplete(); 593 594 void MaybePostReadCompleteTask(); 595 void MaybePostWriteCompleteTask(); 596 597 StaticSocketDataHelper helper_; 598 raw_ptr<SocketDataPrinter> printer_ = nullptr; 599 int sequence_number_ = 0; 600 IoState read_state_ = IoState::kIdle; 601 IoState write_state_ = IoState::kIdle; 602 603 bool busy_before_sync_reads_ = false; 604 605 // Used by RunUntilPaused. NULL at all other times. 606 std::unique_ptr<base::RunLoop> run_until_paused_run_loop_; 607 608 base::WeakPtrFactory<SequencedSocketData> weak_factory_{this}; 609 }; 610 611 // Holds an array of SocketDataProvider elements. As Mock{TCP,SSL}StreamSocket 612 // objects get instantiated, they take their data from the i'th element of this 613 // array. 614 template <typename T> 615 class SocketDataProviderArray { 616 public: 617 SocketDataProviderArray() = default; 618 GetNext()619 T* GetNext() { 620 DCHECK_LT(next_index_, data_providers_.size()); 621 return data_providers_[next_index_++]; 622 } 623 624 // Like GetNext(), but returns nullptr when the end of the array is reached, 625 // instead of DCHECKing. GetNext() should generally be preferred, unless 626 // having no remaining elements is expected in some cases and is handled 627 // safely. GetNextWithoutAsserting()628 T* GetNextWithoutAsserting() { 629 if (next_index_ == data_providers_.size()) 630 return nullptr; 631 return data_providers_[next_index_++]; 632 } 633 Add(T * data_provider)634 void Add(T* data_provider) { 635 DCHECK(data_provider); 636 data_providers_.push_back(data_provider); 637 } 638 next_index()639 size_t next_index() { return next_index_; } 640 ResetNextIndex()641 void ResetNextIndex() { next_index_ = 0; } 642 643 private: 644 // Index of the next |data_providers_| element to use. Not an iterator 645 // because those are invalidated on vector reallocation. 646 size_t next_index_ = 0; 647 648 // SocketDataProviders to be returned. 649 std::vector<T*> data_providers_; 650 }; 651 652 class MockUDPClientSocket; 653 class MockTCPClientSocket; 654 class MockSSLClientSocket; 655 656 // ClientSocketFactory which contains arrays of sockets of each type. 657 // You should first fill the arrays using Add{SSL,}SocketDataProvider(). When 658 // the factory is asked to create a socket, it takes next entry from appropriate 659 // array. You can use ResetNextMockIndexes to reset that next entry index for 660 // all mock socket types. 661 class MockClientSocketFactory : public ClientSocketFactory { 662 public: 663 MockClientSocketFactory(); 664 665 MockClientSocketFactory(const MockClientSocketFactory&) = delete; 666 MockClientSocketFactory& operator=(const MockClientSocketFactory&) = delete; 667 668 ~MockClientSocketFactory() override; 669 670 // Adds a SocketDataProvider that can be used to served either TCP or UDP 671 // connection requests. Sockets are returned in FIFO order. 672 void AddSocketDataProvider(SocketDataProvider* socket); 673 674 // Like AddSocketDataProvider(), except sockets will only be used to service 675 // TCP connection requests. Sockets added with this method are used first, 676 // before sockets added with AddSocketDataProvider(). Particularly useful for 677 // QUIC tests with multiple sockets, where TCP connections may or may not be 678 // made, and have no guaranteed order, relative to UDP connections. 679 void AddTcpSocketDataProvider(SocketDataProvider* socket); 680 681 void AddSSLSocketDataProvider(SSLSocketDataProvider* socket); 682 void ResetNextMockIndexes(); 683 mock_data()684 SocketDataProviderArray<SocketDataProvider>& mock_data() { 685 return mock_data_; 686 } 687 set_enable_read_if_ready(bool enable_read_if_ready)688 void set_enable_read_if_ready(bool enable_read_if_ready) { 689 enable_read_if_ready_ = enable_read_if_ready; 690 } 691 692 // ClientSocketFactory 693 std::unique_ptr<DatagramClientSocket> CreateDatagramClientSocket( 694 DatagramSocket::BindType bind_type, 695 NetLog* net_log, 696 const NetLogSource& source) override; 697 std::unique_ptr<TransportClientSocket> CreateTransportClientSocket( 698 const AddressList& addresses, 699 std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher, 700 NetworkQualityEstimator* network_quality_estimator, 701 NetLog* net_log, 702 const NetLogSource& source) override; 703 std::unique_ptr<SSLClientSocket> CreateSSLClientSocket( 704 SSLClientContext* context, 705 std::unique_ptr<StreamSocket> stream_socket, 706 const HostPortPair& host_and_port, 707 const SSLConfig& ssl_config) override; udp_client_socket_ports()708 const std::vector<uint16_t>& udp_client_socket_ports() const { 709 return udp_client_socket_ports_; 710 } 711 712 private: 713 SocketDataProviderArray<SocketDataProvider> mock_data_; 714 SocketDataProviderArray<SocketDataProvider> mock_tcp_data_; 715 SocketDataProviderArray<SSLSocketDataProvider> mock_ssl_data_; 716 std::vector<uint16_t> udp_client_socket_ports_; 717 718 // If true, ReadIfReady() is enabled; otherwise ReadIfReady() returns 719 // ERR_READ_IF_READY_NOT_IMPLEMENTED. 720 bool enable_read_if_ready_ = false; 721 }; 722 723 class MockClientSocket : public TransportClientSocket { 724 public: 725 // The NetLogWithSource is needed to test LoadTimingInfo, which uses NetLog 726 // IDs as 727 // unique socket IDs. 728 explicit MockClientSocket(const NetLogWithSource& net_log); 729 730 MockClientSocket(const MockClientSocket&) = delete; 731 MockClientSocket& operator=(const MockClientSocket&) = delete; 732 733 // Socket implementation. 734 int Read(IOBuffer* buf, 735 int buf_len, 736 CompletionOnceCallback callback) override = 0; 737 int Write(IOBuffer* buf, 738 int buf_len, 739 CompletionOnceCallback callback, 740 const NetworkTrafficAnnotationTag& traffic_annotation) override = 0; 741 int SetReceiveBufferSize(int32_t size) override; 742 int SetSendBufferSize(int32_t size) override; 743 744 // TransportClientSocket implementation. 745 int Bind(const net::IPEndPoint& local_addr) override; 746 bool SetNoDelay(bool no_delay) override; 747 bool SetKeepAlive(bool enable, int delay) override; 748 749 // StreamSocket implementation. 750 int Connect(CompletionOnceCallback callback) override = 0; 751 void Disconnect() override; 752 bool IsConnected() const override; 753 bool IsConnectedAndIdle() const override; 754 int GetPeerAddress(IPEndPoint* address) const override; 755 int GetLocalAddress(IPEndPoint* address) const override; 756 const NetLogWithSource& NetLog() const override; 757 NextProto GetNegotiatedProtocol() const override; 758 int64_t GetTotalReceivedBytes() const override; ApplySocketTag(const SocketTag & tag)759 void ApplySocketTag(const SocketTag& tag) override {} 760 761 protected: 762 ~MockClientSocket() override; 763 void RunCallbackAsync(CompletionOnceCallback callback, int result); 764 void RunCallback(CompletionOnceCallback callback, int result); 765 766 // True if Connect completed successfully and Disconnect hasn't been called. 767 bool connected_ = false; 768 769 IPEndPoint local_addr_; 770 IPEndPoint peer_addr_; 771 772 NetLogWithSource net_log_; 773 774 private: 775 base::WeakPtrFactory<MockClientSocket> weak_factory_{this}; 776 }; 777 778 class MockTCPClientSocket : public MockClientSocket, public AsyncSocket { 779 public: 780 MockTCPClientSocket(const AddressList& addresses, 781 net::NetLog* net_log, 782 SocketDataProvider* socket); 783 784 MockTCPClientSocket(const MockTCPClientSocket&) = delete; 785 MockTCPClientSocket& operator=(const MockTCPClientSocket&) = delete; 786 787 ~MockTCPClientSocket() override; 788 addresses()789 const AddressList& addresses() const { return addresses_; } 790 791 // Socket implementation. 792 int Read(IOBuffer* buf, 793 int buf_len, 794 CompletionOnceCallback callback) override; 795 int ReadIfReady(IOBuffer* buf, 796 int buf_len, 797 CompletionOnceCallback callback) override; 798 int CancelReadIfReady() override; 799 int Write(IOBuffer* buf, 800 int buf_len, 801 CompletionOnceCallback callback, 802 const NetworkTrafficAnnotationTag& traffic_annotation) override; 803 int SetReceiveBufferSize(int32_t size) override; 804 int SetSendBufferSize(int32_t size) override; 805 806 // TransportClientSocket implementation. 807 bool SetNoDelay(bool no_delay) override; 808 bool SetKeepAlive(bool enable, int delay) override; 809 810 // StreamSocket implementation. 811 void SetBeforeConnectCallback( 812 const BeforeConnectCallback& before_connect_callback) override; 813 int Connect(CompletionOnceCallback callback) override; 814 void Disconnect() override; 815 bool IsConnected() const override; 816 bool IsConnectedAndIdle() const override; 817 int GetPeerAddress(IPEndPoint* address) const override; 818 bool WasEverUsed() const override; 819 bool GetSSLInfo(SSLInfo* ssl_info) override; 820 821 // AsyncSocket: 822 void OnReadComplete(const MockRead& data) override; 823 void OnWriteComplete(int rv) override; 824 void OnConnectComplete(const MockConnect& data) override; 825 void OnDataProviderDestroyed() override; 826 set_enable_read_if_ready(bool enable_read_if_ready)827 void set_enable_read_if_ready(bool enable_read_if_ready) { 828 enable_read_if_ready_ = enable_read_if_ready; 829 } 830 831 private: 832 void RetryRead(int rv); 833 int ReadIfReadyImpl(IOBuffer* buf, 834 int buf_len, 835 CompletionOnceCallback callback); 836 837 // Helper method to run |pending_read_if_ready_callback_| if it is not null. 838 void RunReadIfReadyCallback(int result); 839 840 AddressList addresses_; 841 842 raw_ptr<SocketDataProvider> data_; 843 int read_offset_ = 0; 844 MockRead read_data_; 845 bool need_read_data_ = true; 846 847 // True if the peer has closed the connection. This allows us to simulate 848 // the recv(..., MSG_PEEK) call in the IsConnectedAndIdle method of the real 849 // TCPClientSocket. 850 bool peer_closed_connection_ = false; 851 852 // While an asynchronous read is pending, we save our user-buffer state. 853 scoped_refptr<IOBuffer> pending_read_buf_ = nullptr; 854 int pending_read_buf_len_ = 0; 855 CompletionOnceCallback pending_read_callback_; 856 857 // Non-null when a ReadIfReady() is pending. 858 CompletionOnceCallback pending_read_if_ready_callback_; 859 860 CompletionOnceCallback pending_connect_callback_; 861 CompletionOnceCallback pending_write_callback_; 862 bool was_used_to_convey_data_ = false; 863 864 // If true, ReadIfReady() is enabled; otherwise ReadIfReady() returns 865 // ERR_READ_IF_READY_NOT_IMPLEMENTED. 866 bool enable_read_if_ready_ = false; 867 868 BeforeConnectCallback before_connect_callback_; 869 }; 870 871 class MockSSLClientSocket : public AsyncSocket, public SSLClientSocket { 872 public: 873 MockSSLClientSocket(std::unique_ptr<StreamSocket> stream_socket, 874 const HostPortPair& host_and_port, 875 const SSLConfig& ssl_config, 876 SSLSocketDataProvider* socket); 877 878 MockSSLClientSocket(const MockSSLClientSocket&) = delete; 879 MockSSLClientSocket& operator=(const MockSSLClientSocket&) = delete; 880 881 ~MockSSLClientSocket() override; 882 883 // Socket implementation. 884 int Read(IOBuffer* buf, 885 int buf_len, 886 CompletionOnceCallback callback) override; 887 int ReadIfReady(IOBuffer* buf, 888 int buf_len, 889 CompletionOnceCallback callback) override; 890 int Write(IOBuffer* buf, 891 int buf_len, 892 CompletionOnceCallback callback, 893 const NetworkTrafficAnnotationTag& traffic_annotation) override; 894 int CancelReadIfReady() override; 895 896 // StreamSocket implementation. 897 int Connect(CompletionOnceCallback callback) override; 898 void Disconnect() override; 899 int ConfirmHandshake(CompletionOnceCallback callback) override; 900 bool IsConnected() const override; 901 bool IsConnectedAndIdle() const override; 902 bool WasEverUsed() const override; 903 int GetPeerAddress(IPEndPoint* address) const override; 904 int GetLocalAddress(IPEndPoint* address) const override; 905 NextProto GetNegotiatedProtocol() const override; 906 absl::optional<base::StringPiece> GetPeerApplicationSettings() const override; 907 bool GetSSLInfo(SSLInfo* ssl_info) override; 908 void GetSSLCertRequestInfo( 909 SSLCertRequestInfo* cert_request_info) const override; 910 void ApplySocketTag(const SocketTag& tag) override; 911 const NetLogWithSource& NetLog() const override; 912 int64_t GetTotalReceivedBytes() const override; 913 int SetReceiveBufferSize(int32_t size) override; 914 int SetSendBufferSize(int32_t size) override; 915 916 // SSLSocket implementation. 917 int ExportKeyingMaterial(base::StringPiece label, 918 bool has_context, 919 base::StringPiece context, 920 unsigned char* out, 921 unsigned int outlen) override; 922 923 // SSLClientSocket implementation. 924 std::vector<uint8_t> GetECHRetryConfigs() override; 925 926 // This MockSocket does not implement the manual async IO feature. 927 void OnReadComplete(const MockRead& data) override; 928 void OnWriteComplete(int rv) override; 929 void OnConnectComplete(const MockConnect& data) override; 930 // SSL sockets don't need magic to deal with destruction of their data 931 // provider. 932 // TODO(mmenke): Probably a good idea to support it, anyways. OnDataProviderDestroyed()933 void OnDataProviderDestroyed() override {} 934 935 private: 936 static void ConnectCallback(MockSSLClientSocket* ssl_client_socket, 937 CompletionOnceCallback callback, 938 int rv); 939 940 void RunCallbackAsync(CompletionOnceCallback callback, int result); 941 void RunCallback(CompletionOnceCallback callback, int result); 942 943 void RunConfirmHandshakeCallback(CompletionOnceCallback callback, int result); 944 945 bool connected_ = false; 946 bool in_confirm_handshake_ = false; 947 NetLogWithSource net_log_; 948 std::unique_ptr<StreamSocket> stream_socket_; 949 raw_ptr<SSLSocketDataProvider, AcrossTasksDanglingUntriaged> data_; 950 // Address of the "remote" peer we're connected to. 951 IPEndPoint peer_addr_; 952 953 base::WeakPtrFactory<MockSSLClientSocket> weak_factory_{this}; 954 }; 955 956 class MockUDPClientSocket : public DatagramClientSocket, public AsyncSocket { 957 public: 958 explicit MockUDPClientSocket(SocketDataProvider* data = nullptr, 959 net::NetLog* net_log = nullptr); 960 961 MockUDPClientSocket(const MockUDPClientSocket&) = delete; 962 MockUDPClientSocket& operator=(const MockUDPClientSocket&) = delete; 963 964 ~MockUDPClientSocket() override; 965 966 // Socket implementation. 967 int Read(IOBuffer* buf, 968 int buf_len, 969 CompletionOnceCallback callback) override; 970 int Write(IOBuffer* buf, 971 int buf_len, 972 CompletionOnceCallback callback, 973 const NetworkTrafficAnnotationTag& traffic_annotation) override; 974 975 int SetReceiveBufferSize(int32_t size) override; 976 int SetSendBufferSize(int32_t size) override; 977 int SetDoNotFragment() override; 978 int SetRecvEcn() override; 979 980 // DatagramSocket implementation. 981 void Close() override; 982 int GetPeerAddress(IPEndPoint* address) const override; 983 int GetLocalAddress(IPEndPoint* address) const override; 984 void UseNonBlockingIO() override; 985 int SetMulticastInterface(uint32_t interface_index) override; 986 const NetLogWithSource& NetLog() const override; 987 988 // DatagramClientSocket implementation. 989 int Connect(const IPEndPoint& address) override; 990 int ConnectUsingNetwork(handles::NetworkHandle network, 991 const IPEndPoint& address) override; 992 int ConnectUsingDefaultNetwork(const IPEndPoint& address) override; 993 int ConnectAsync(const IPEndPoint& address, 994 CompletionOnceCallback callback) override; 995 int ConnectUsingNetworkAsync(handles::NetworkHandle network, 996 const IPEndPoint& address, 997 CompletionOnceCallback callback) override; 998 int ConnectUsingDefaultNetworkAsync(const IPEndPoint& address, 999 CompletionOnceCallback callback) override; 1000 handles::NetworkHandle GetBoundNetwork() const override; 1001 void ApplySocketTag(const SocketTag& tag) override; SetMsgConfirm(bool confirm)1002 void SetMsgConfirm(bool confirm) override {} 1003 1004 // AsyncSocket implementation. 1005 void OnReadComplete(const MockRead& data) override; 1006 void OnWriteComplete(int rv) override; 1007 void OnConnectComplete(const MockConnect& data) override; 1008 void OnDataProviderDestroyed() override; 1009 set_source_port(uint16_t port)1010 void set_source_port(uint16_t port) { source_port_ = port; } source_port()1011 uint16_t source_port() const { return source_port_; } set_source_host(IPAddress addr)1012 void set_source_host(IPAddress addr) { source_host_ = addr; } source_host()1013 IPAddress source_host() const { return source_host_; } 1014 1015 // Returns last tag applied to socket. tag()1016 SocketTag tag() const { return tag_; } 1017 1018 // Returns false if socket's tag was changed after the socket was used for 1019 // data transfer (e.g. Read/Write() called), otherwise returns true. tagged_before_data_transferred()1020 bool tagged_before_data_transferred() const { 1021 return tagged_before_data_transferred_; 1022 } 1023 1024 private: 1025 int CompleteRead(); 1026 1027 void RunCallbackAsync(CompletionOnceCallback callback, int result); 1028 void RunCallback(CompletionOnceCallback callback, int result); 1029 1030 bool connected_ = false; 1031 raw_ptr<SocketDataProvider> data_; 1032 int read_offset_ = 0; 1033 MockRead read_data_; 1034 bool need_read_data_ = true; 1035 IPAddress source_host_; 1036 uint16_t source_port_ = 123; // Ephemeral source port. 1037 1038 // Address of the "remote" peer we're connected to. 1039 IPEndPoint peer_addr_; 1040 1041 // Network that the socket is bound to. 1042 handles::NetworkHandle network_ = handles::kInvalidNetworkHandle; 1043 1044 // While an asynchronous IO is pending, we save our user-buffer state. 1045 scoped_refptr<IOBuffer> pending_read_buf_ = nullptr; 1046 int pending_read_buf_len_ = 0; 1047 CompletionOnceCallback pending_read_callback_; 1048 CompletionOnceCallback pending_write_callback_; 1049 1050 NetLogWithSource net_log_; 1051 1052 DatagramBuffers unwritten_buffers_; 1053 1054 SocketTag tag_; 1055 bool data_transferred_ = false; 1056 bool tagged_before_data_transferred_ = true; 1057 1058 base::WeakPtrFactory<MockUDPClientSocket> weak_factory_{this}; 1059 }; 1060 1061 class TestSocketRequest : public TestCompletionCallbackBase { 1062 public: 1063 TestSocketRequest(std::vector<TestSocketRequest*>* request_order, 1064 size_t* completion_count); 1065 1066 TestSocketRequest(const TestSocketRequest&) = delete; 1067 TestSocketRequest& operator=(const TestSocketRequest&) = delete; 1068 1069 ~TestSocketRequest() override; 1070 handle()1071 ClientSocketHandle* handle() { return &handle_; } 1072 callback()1073 CompletionOnceCallback callback() { 1074 return base::BindOnce(&TestSocketRequest::OnComplete, 1075 base::Unretained(this)); 1076 } 1077 1078 private: 1079 void OnComplete(int result); 1080 1081 ClientSocketHandle handle_; 1082 raw_ptr<std::vector<TestSocketRequest*>> request_order_; 1083 raw_ptr<size_t> completion_count_; 1084 }; 1085 1086 class ClientSocketPoolTest { 1087 public: 1088 enum KeepAlive { 1089 KEEP_ALIVE, 1090 1091 // A socket will be disconnected in addition to handle being reset. 1092 NO_KEEP_ALIVE, 1093 }; 1094 1095 static const int kIndexOutOfBounds; 1096 static const int kRequestNotFound; 1097 1098 ClientSocketPoolTest(); 1099 1100 ClientSocketPoolTest(const ClientSocketPoolTest&) = delete; 1101 ClientSocketPoolTest& operator=(const ClientSocketPoolTest&) = delete; 1102 1103 ~ClientSocketPoolTest(); 1104 1105 template <typename PoolType> StartRequestUsingPool(PoolType * socket_pool,const ClientSocketPool::GroupId & group_id,RequestPriority priority,ClientSocketPool::RespectLimits respect_limits,const scoped_refptr<typename PoolType::SocketParams> & socket_params)1106 int StartRequestUsingPool( 1107 PoolType* socket_pool, 1108 const ClientSocketPool::GroupId& group_id, 1109 RequestPriority priority, 1110 ClientSocketPool::RespectLimits respect_limits, 1111 const scoped_refptr<typename PoolType::SocketParams>& socket_params) { 1112 DCHECK(socket_pool); 1113 TestSocketRequest* request( 1114 new TestSocketRequest(&request_order_, &completion_count_)); 1115 requests_.push_back(base::WrapUnique(request)); 1116 int rv = request->handle()->Init( 1117 group_id, socket_params, absl::nullopt /* proxy_annotation_tag */, 1118 priority, SocketTag(), respect_limits, request->callback(), 1119 ClientSocketPool::ProxyAuthCallback(), socket_pool, NetLogWithSource()); 1120 if (rv != ERR_IO_PENDING) 1121 request_order_.push_back(request); 1122 return rv; 1123 } 1124 1125 // Provided there were n requests started, takes |index| in range 1..n 1126 // and returns order in which that request completed, in range 1..n, 1127 // or kIndexOutOfBounds if |index| is out of bounds, or kRequestNotFound 1128 // if that request did not complete (for example was canceled). 1129 int GetOrderOfRequest(size_t index) const; 1130 1131 // Resets first initialized socket handle from |requests_|. If found such 1132 // a handle, returns true. 1133 bool ReleaseOneConnection(KeepAlive keep_alive); 1134 1135 // Releases connections until there is nothing to release. 1136 void ReleaseAllConnections(KeepAlive keep_alive); 1137 1138 // Note that this uses 0-based indices, while GetOrderOfRequest takes and 1139 // returns 1-based indices. request(int i)1140 TestSocketRequest* request(int i) { return requests_[i].get(); } 1141 requests_size()1142 size_t requests_size() const { return requests_.size(); } requests()1143 std::vector<std::unique_ptr<TestSocketRequest>>* requests() { 1144 return &requests_; 1145 } completion_count()1146 size_t completion_count() const { return completion_count_; } 1147 1148 private: 1149 std::vector<std::unique_ptr<TestSocketRequest>> requests_; 1150 std::vector<TestSocketRequest*> request_order_; 1151 size_t completion_count_ = 0; 1152 }; 1153 1154 class MockTransportSocketParams 1155 : public base::RefCounted<MockTransportSocketParams> { 1156 public: 1157 MockTransportSocketParams(const MockTransportSocketParams&) = delete; 1158 MockTransportSocketParams& operator=(const MockTransportSocketParams&) = 1159 delete; 1160 1161 private: 1162 friend class base::RefCounted<MockTransportSocketParams>; 1163 ~MockTransportSocketParams() = default; 1164 }; 1165 1166 class MockTransportClientSocketPool : public TransportClientSocketPool { 1167 public: 1168 class MockConnectJob { 1169 public: 1170 MockConnectJob(std::unique_ptr<StreamSocket> socket, 1171 ClientSocketHandle* handle, 1172 const SocketTag& socket_tag, 1173 CompletionOnceCallback callback, 1174 RequestPriority priority); 1175 1176 MockConnectJob(const MockConnectJob&) = delete; 1177 MockConnectJob& operator=(const MockConnectJob&) = delete; 1178 1179 ~MockConnectJob(); 1180 1181 int Connect(); 1182 bool CancelHandle(const ClientSocketHandle* handle); 1183 handle()1184 ClientSocketHandle* handle() const { return handle_; } 1185 priority()1186 RequestPriority priority() const { return priority_; } set_priority(RequestPriority priority)1187 void set_priority(RequestPriority priority) { priority_ = priority; } 1188 1189 private: 1190 void OnConnect(int rv); 1191 1192 std::unique_ptr<StreamSocket> socket_; 1193 raw_ptr<ClientSocketHandle> handle_; 1194 const SocketTag socket_tag_; 1195 CompletionOnceCallback user_callback_; 1196 RequestPriority priority_; 1197 }; 1198 1199 MockTransportClientSocketPool( 1200 int max_sockets, 1201 int max_sockets_per_group, 1202 const CommonConnectJobParams* common_connect_job_params); 1203 1204 MockTransportClientSocketPool(const MockTransportClientSocketPool&) = delete; 1205 MockTransportClientSocketPool& operator=( 1206 const MockTransportClientSocketPool&) = delete; 1207 1208 ~MockTransportClientSocketPool() override; 1209 last_request_priority()1210 RequestPriority last_request_priority() const { 1211 return last_request_priority_; 1212 } 1213 requests()1214 const std::vector<std::unique_ptr<MockConnectJob>>& requests() const { 1215 return job_list_; 1216 } 1217 release_count()1218 int release_count() const { return release_count_; } cancel_count()1219 int cancel_count() const { return cancel_count_; } 1220 1221 // TransportClientSocketPool implementation. 1222 int RequestSocket( 1223 const GroupId& group_id, 1224 scoped_refptr<ClientSocketPool::SocketParams> socket_params, 1225 const absl::optional<NetworkTrafficAnnotationTag>& proxy_annotation_tag, 1226 RequestPriority priority, 1227 const SocketTag& socket_tag, 1228 RespectLimits respect_limits, 1229 ClientSocketHandle* handle, 1230 CompletionOnceCallback callback, 1231 const ProxyAuthCallback& on_auth_callback, 1232 const NetLogWithSource& net_log) override; 1233 void SetPriority(const GroupId& group_id, 1234 ClientSocketHandle* handle, 1235 RequestPriority priority) override; 1236 void CancelRequest(const GroupId& group_id, 1237 ClientSocketHandle* handle, 1238 bool cancel_connect_job) override; 1239 void ReleaseSocket(const GroupId& group_id, 1240 std::unique_ptr<StreamSocket> socket, 1241 int64_t generation) override; 1242 1243 private: 1244 raw_ptr<ClientSocketFactory> client_socket_factory_; 1245 std::vector<std::unique_ptr<MockConnectJob>> job_list_; 1246 RequestPriority last_request_priority_ = DEFAULT_PRIORITY; 1247 int release_count_ = 0; 1248 int cancel_count_ = 0; 1249 }; 1250 1251 // WrappedStreamSocket is a base class that wraps an existing StreamSocket, 1252 // forwarding the Socket and StreamSocket interfaces to the underlying 1253 // transport. 1254 // This is to provide a common base class for subclasses to override specific 1255 // StreamSocket methods for testing, while still communicating with a 'real' 1256 // StreamSocket. 1257 class WrappedStreamSocket : public TransportClientSocket { 1258 public: 1259 explicit WrappedStreamSocket(std::unique_ptr<StreamSocket> transport); 1260 ~WrappedStreamSocket() override; 1261 1262 // StreamSocket implementation: 1263 int Bind(const net::IPEndPoint& local_addr) override; 1264 int Connect(CompletionOnceCallback callback) override; 1265 void Disconnect() override; 1266 bool IsConnected() const override; 1267 bool IsConnectedAndIdle() const override; 1268 int GetPeerAddress(IPEndPoint* address) const override; 1269 int GetLocalAddress(IPEndPoint* address) const override; 1270 const NetLogWithSource& NetLog() const override; 1271 bool WasEverUsed() const override; 1272 NextProto GetNegotiatedProtocol() const override; 1273 bool GetSSLInfo(SSLInfo* ssl_info) override; 1274 int64_t GetTotalReceivedBytes() const override; 1275 void ApplySocketTag(const SocketTag& tag) override; 1276 1277 // Socket implementation: 1278 int Read(IOBuffer* buf, 1279 int buf_len, 1280 CompletionOnceCallback callback) override; 1281 int ReadIfReady(IOBuffer* buf, 1282 int buf_len, 1283 CompletionOnceCallback callback) override; 1284 int Write(IOBuffer* buf, 1285 int buf_len, 1286 CompletionOnceCallback callback, 1287 const NetworkTrafficAnnotationTag& traffic_annotation) override; 1288 int SetReceiveBufferSize(int32_t size) override; 1289 int SetSendBufferSize(int32_t size) override; 1290 1291 protected: 1292 std::unique_ptr<StreamSocket> transport_; 1293 }; 1294 1295 // StreamSocket that wraps another StreamSocket, but keeps track of any 1296 // SocketTag applied to the socket. 1297 class MockTaggingStreamSocket : public WrappedStreamSocket { 1298 public: MockTaggingStreamSocket(std::unique_ptr<StreamSocket> transport)1299 explicit MockTaggingStreamSocket(std::unique_ptr<StreamSocket> transport) 1300 : WrappedStreamSocket(std::move(transport)) {} 1301 1302 MockTaggingStreamSocket(const MockTaggingStreamSocket&) = delete; 1303 MockTaggingStreamSocket& operator=(const MockTaggingStreamSocket&) = delete; 1304 1305 ~MockTaggingStreamSocket() override = default; 1306 1307 // StreamSocket implementation. 1308 int Connect(CompletionOnceCallback callback) override; 1309 void ApplySocketTag(const SocketTag& tag) override; 1310 1311 // Returns false if socket's tag was changed after the socket was connected, 1312 // otherwise returns true. tagged_before_connected()1313 bool tagged_before_connected() const { return tagged_before_connected_; } 1314 1315 // Returns last tag applied to socket. tag()1316 SocketTag tag() const { return tag_; } 1317 1318 private: 1319 bool connected_ = false; 1320 bool tagged_before_connected_ = true; 1321 SocketTag tag_; 1322 }; 1323 1324 // Extend MockClientSocketFactory to return MockTaggingStreamSockets and 1325 // keep track of last socket produced for test inspection. 1326 class MockTaggingClientSocketFactory : public MockClientSocketFactory { 1327 public: 1328 MockTaggingClientSocketFactory() = default; 1329 1330 MockTaggingClientSocketFactory(const MockTaggingClientSocketFactory&) = 1331 delete; 1332 MockTaggingClientSocketFactory& operator=( 1333 const MockTaggingClientSocketFactory&) = delete; 1334 1335 // ClientSocketFactory implementation. 1336 std::unique_ptr<DatagramClientSocket> CreateDatagramClientSocket( 1337 DatagramSocket::BindType bind_type, 1338 NetLog* net_log, 1339 const NetLogSource& source) override; 1340 std::unique_ptr<TransportClientSocket> CreateTransportClientSocket( 1341 const AddressList& addresses, 1342 std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher, 1343 NetworkQualityEstimator* network_quality_estimator, 1344 NetLog* net_log, 1345 const NetLogSource& source) override; 1346 1347 // These methods return pointers to last TCP and UDP sockets produced by this 1348 // factory. NOTE: Socket must still exist, or pointer will be to freed memory. GetLastProducedTCPSocket()1349 MockTaggingStreamSocket* GetLastProducedTCPSocket() const { 1350 return tcp_socket_; 1351 } GetLastProducedUDPSocket()1352 MockUDPClientSocket* GetLastProducedUDPSocket() const { return udp_socket_; } 1353 1354 private: 1355 raw_ptr<MockTaggingStreamSocket, AcrossTasksDanglingUntriaged> tcp_socket_ = 1356 nullptr; 1357 raw_ptr<MockUDPClientSocket, AcrossTasksDanglingUntriaged> udp_socket_ = 1358 nullptr; 1359 }; 1360 1361 // Host / port used for SOCKS4 test strings. 1362 extern const char kSOCKS4TestHost[]; 1363 extern const int kSOCKS4TestPort; 1364 1365 // Constants for a successful SOCKS v4 handshake (connecting to kSOCKS4TestHost 1366 // on port kSOCKS4TestPort, for the request). 1367 extern const char kSOCKS4OkRequestLocalHostPort80[]; 1368 extern const int kSOCKS4OkRequestLocalHostPort80Length; 1369 1370 extern const char kSOCKS4OkReply[]; 1371 extern const int kSOCKS4OkReplyLength; 1372 1373 // Host / port used for SOCKS5 test strings. 1374 extern const char kSOCKS5TestHost[]; 1375 extern const int kSOCKS5TestPort; 1376 1377 // Constants for a successful SOCKS v5 handshake (connecting to kSOCKS5TestHost 1378 // on port kSOCKS5TestPort, for the request).. 1379 extern const char kSOCKS5GreetRequest[]; 1380 extern const int kSOCKS5GreetRequestLength; 1381 1382 extern const char kSOCKS5GreetResponse[]; 1383 extern const int kSOCKS5GreetResponseLength; 1384 1385 extern const char kSOCKS5OkRequest[]; 1386 extern const int kSOCKS5OkRequestLength; 1387 1388 extern const char kSOCKS5OkResponse[]; 1389 extern const int kSOCKS5OkResponseLength; 1390 1391 // Helper function to get the total data size of the MockReads in |reads|. 1392 int64_t CountReadBytes(base::span<const MockRead> reads); 1393 1394 // Helper function to get the total data size of the MockWrites in |writes|. 1395 int64_t CountWriteBytes(base::span<const MockWrite> writes); 1396 1397 #if BUILDFLAG(IS_ANDROID) 1398 // Returns whether the device supports calling GetTaggedBytes(). 1399 bool CanGetTaggedBytes(); 1400 1401 // Query the system to find out how many bytes were received with tag 1402 // |expected_tag| for our UID. Return the count of received bytes. 1403 uint64_t GetTaggedBytes(int32_t expected_tag); 1404 #endif 1405 1406 } // namespace net 1407 1408 #endif // NET_SOCKET_SOCKET_TEST_UTIL_H_ 1409