1 // Copyright 2012 The Chromium Authors 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #ifndef NET_SOCKET_SOCKET_TEST_UTIL_H_ 6 #define NET_SOCKET_SOCKET_TEST_UTIL_H_ 7 8 #include <stddef.h> 9 #include <stdint.h> 10 11 #include <cstring> 12 #include <memory> 13 #include <optional> 14 #include <string> 15 #include <string_view> 16 #include <utility> 17 #include <vector> 18 19 #include "base/check_op.h" 20 #include "base/containers/span.h" 21 #include "base/functional/bind.h" 22 #include "base/functional/callback.h" 23 #include "base/memory/ptr_util.h" 24 #include "base/memory/raw_ptr.h" 25 #include "base/memory/raw_span.h" 26 #include "base/memory/ref_counted.h" 27 #include "base/memory/weak_ptr.h" 28 #include "build/build_config.h" 29 #include "net/base/address_list.h" 30 #include "net/base/completion_once_callback.h" 31 #include "net/base/io_buffer.h" 32 #include "net/base/net_errors.h" 33 #include "net/base/test_completion_callback.h" 34 #include "net/http/http_auth_controller.h" 35 #include "net/log/net_log_with_source.h" 36 #include "net/socket/client_socket_factory.h" 37 #include "net/socket/client_socket_handle.h" 38 #include "net/socket/client_socket_pool.h" 39 #include "net/socket/datagram_client_socket.h" 40 #include "net/socket/socket_performance_watcher.h" 41 #include "net/socket/socket_tag.h" 42 #include "net/socket/ssl_client_socket.h" 43 #include "net/socket/transport_client_socket.h" 44 #include "net/socket/transport_client_socket_pool.h" 45 #include "net/ssl/ssl_config_service.h" 46 #include "net/ssl/ssl_info.h" 47 #include "testing/gtest/include/gtest/gtest.h" 48 49 namespace base { 50 class RunLoop; 51 } 52 53 namespace net { 54 55 struct CommonConnectJobParams; 56 class NetLog; 57 struct NetworkTrafficAnnotationTag; 58 class X509Certificate; 59 60 const handles::NetworkHandle kDefaultNetworkForTests = 1; 61 const handles::NetworkHandle kNewNetworkForTests = 2; 62 63 enum { 64 // A 
private network error code used by the socket test utility classes. 65 // If the |result| member of a MockRead is 66 // ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ, that MockRead is just a 67 // marker that indicates the peer will close the connection after the next 68 // MockRead. The other members of that MockRead are ignored. 69 ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ = -10000, 70 }; 71 72 class AsyncSocket; 73 class MockClientSocket; 74 class MockTCPClientSocket; 75 class MockSSLClientSocket; 76 class SSLClientSocket; 77 class StreamSocket; 78 79 enum IoMode { ASYNC, SYNCHRONOUS }; 80 81 // Used to delay MockClientSocket::Connect. 82 // Example usage: 83 // TEST(FooTest, Test) { 84 // MockClientSocketFactory socket_factory; 85 // 86 // MockConnectCompleter completer; 87 // SequencedSocketData data; 88 // data.set_connect_data(MockConnect(&completer)); 89 // socket_factory.AddSocketDataProvider(&data); 90 // 91 // // Create a MockClientSocket somehow. 92 // std::unique_ptr<StreamSocket> stream = CreateStreamSocket(); 93 // std::optional<int> delayed_result; 94 // int rv = stream->Connect(base::BindLambdaForTesting([&](int result){ 95 // delayed_result = result; 96 // })); 97 // // Connect() returns ERR_IO_PENDING. 98 // EXPECT_THAT(rv, IsError(ERR_IO_PENDING)); 99 // 100 // RunUntilIdle(); 101 // // Connect() is still blocked. 102 // ASSERT_FALSE(delayed_result.has_value()); 103 // 104 // completer.Complete(OK); 105 // RunUntilIdle(); 106 // EXPECT_THAT(delayed_result, Optional(IsOk())); 107 // } 108 class MockConnectCompleter { 109 public: 110 MockConnectCompleter(); 111 112 MockConnectCompleter(const MockConnectCompleter&) = delete; 113 MockConnectCompleter& operator=(const MockConnectCompleter&) = delete; 114 115 ~MockConnectCompleter(); 116 117 // Completes Connect() with `result`. 
118 void Complete(int result); 119 120 private: 121 friend class MockTCPClientSocket; 122 friend class MockSSLClientSocket; 123 friend class MockUDPClientSocket; 124 125 // Sets a completion callback that is passed to Connect(). Called by 126 // MockClientSocket implementations. 127 void SetCallback(CompletionOnceCallback callback); 128 129 CompletionOnceCallback callback_; 130 }; 131 132 struct MockConnect { 133 // Asynchronous connection success. 134 // Creates a MockConnect with |mode| ASYC, |result| OK, and 135 // |peer_addr| 192.0.2.33. 136 MockConnect(); 137 // Creates a MockConnect with the specified mode and result, with 138 // |peer_addr| 192.0.2.33. 139 MockConnect(IoMode io_mode, int r); 140 MockConnect(IoMode io_mode, int r, IPEndPoint addr); 141 MockConnect(IoMode io_mode, int r, IPEndPoint addr, bool first_attempt_fails); 142 // Creates a MockConnect that delays connection until `completer` invokes 143 // Complete(). 144 explicit MockConnect(MockConnectCompleter* completer); 145 ~MockConnect(); 146 147 IoMode mode; 148 int result; 149 IPEndPoint peer_addr; 150 bool first_attempt_fails = false; 151 raw_ptr<MockConnectCompleter> completer; 152 }; 153 154 struct MockConfirm { 155 // Asynchronous confirm success. 156 // Creates a MockConfirm with |mode| ASYC and |result| OK. 157 MockConfirm(); 158 // Creates a MockConfirm with the specified mode and result. 159 MockConfirm(IoMode io_mode, int r); 160 ~MockConfirm(); 161 162 IoMode mode; 163 int result; 164 }; 165 166 // MockRead and MockWrite shares the same interface and members, but we'd like 167 // to have distinct types because we don't want to have them used 168 // interchangably. To do this, a struct template is defined, and MockRead and 169 // MockWrite are instantiated by using this template. Template parameter |type| 170 // is not used in the struct definition (it purely exists for creating a new 171 // type). 
172 // 173 // |data| in MockRead and MockWrite has different meanings: |data| in MockRead 174 // is the data returned from the socket when MockTCPClientSocket::Read() is 175 // attempted, while |data| in MockWrite is the expected data that should be 176 // given in MockTCPClientSocket::Write(). 177 enum MockReadWriteType { MOCK_READ, MOCK_WRITE }; 178 179 template <MockReadWriteType type> 180 struct MockReadWrite { 181 // Flag to indicate that the message loop should be terminated. 182 enum { STOPLOOP = 1 << 31 }; 183 184 // Default MockReadWriteMockReadWrite185 MockReadWrite() 186 : mode(SYNCHRONOUS), 187 result(0), 188 data(nullptr), 189 data_len(0), 190 sequence_number(0), 191 tos(0) {} 192 193 // Read/write failure (no data). MockReadWriteMockReadWrite194 MockReadWrite(IoMode io_mode, int result) 195 : mode(io_mode), 196 result(result), 197 data(nullptr), 198 data_len(0), 199 sequence_number(0), 200 tos(0) {} 201 202 // Read/write failure (no data), with sequence information. MockReadWriteMockReadWrite203 MockReadWrite(IoMode io_mode, int result, int seq) 204 : mode(io_mode), 205 result(result), 206 data(nullptr), 207 data_len(0), 208 sequence_number(seq), 209 tos(0) {} 210 211 // Asynchronous read/write success (inferred data length). MockReadWriteMockReadWrite212 explicit MockReadWrite(const char* data) 213 : mode(ASYNC), 214 result(0), 215 data(data), 216 data_len(strlen(data)), 217 sequence_number(0), 218 tos(0) {} 219 220 // Read/write success (inferred data length). MockReadWriteMockReadWrite221 MockReadWrite(IoMode io_mode, const char* data) 222 : mode(io_mode), 223 result(0), 224 data(data), 225 data_len(strlen(data)), 226 sequence_number(0), 227 tos(0) {} 228 229 // Read/write success. 
MockReadWriteMockReadWrite230 MockReadWrite(IoMode io_mode, const char* data, int data_len) 231 : mode(io_mode), 232 result(0), 233 data(data), 234 data_len(data_len), 235 sequence_number(0), 236 tos(0) {} 237 238 // Read/write success (inferred data length) with sequence information. MockReadWriteMockReadWrite239 MockReadWrite(IoMode io_mode, int seq, const char* data) 240 : mode(io_mode), 241 result(0), 242 data(data), 243 data_len(strlen(data)), 244 sequence_number(seq), 245 tos(0) {} 246 247 // Read/write success with sequence information. MockReadWriteMockReadWrite248 MockReadWrite(IoMode io_mode, const char* data, int data_len, int seq) 249 : mode(io_mode), 250 result(0), 251 data(data), 252 data_len(data_len), 253 sequence_number(seq), 254 tos(0) {} 255 256 // Read/write success with sequence and TOS information. MockReadWriteMockReadWrite257 MockReadWrite(IoMode io_mode, 258 const char* data, 259 int data_len, 260 int seq, 261 uint8_t tos_byte) 262 : mode(io_mode), 263 result(0), 264 data(data), 265 data_len(data_len), 266 sequence_number(seq), 267 tos(tos_byte) {} 268 269 IoMode mode; 270 int result; 271 const char* data; 272 int data_len; 273 274 // For data providers that only allows reads to occur in a particular 275 // sequence. If a read occurs before the given |sequence_number| is reached, 276 // an ERR_IO_PENDING is returned. 277 int sequence_number; // The sequence number at which a read is allowed 278 // to occur. 279 280 // The TOS byte of the datagram, for datagram sockets only. 
281 uint8_t tos; 282 }; 283 284 typedef MockReadWrite<MOCK_READ> MockRead; 285 typedef MockReadWrite<MOCK_WRITE> MockWrite; 286 287 struct MockWriteResult { MockWriteResultMockWriteResult288 MockWriteResult(IoMode io_mode, int result) : mode(io_mode), result(result) {} 289 290 IoMode mode; 291 int result; 292 }; 293 294 class SocketDataPrinter { 295 public: 296 ~SocketDataPrinter() = default; 297 298 // Prints the write in |data| using some sort of protocol-specific 299 // format. 300 virtual std::string PrintWrite(const std::string& data) = 0; 301 }; 302 303 // The SocketDataProvider is an interface used by the MockClientSocket 304 // for getting data about individual reads and writes on the socket. Can be 305 // used with at most one socket at a time. 306 // TODO(mmenke): Do these really need to be re-useable? 307 class SocketDataProvider { 308 public: 309 SocketDataProvider(); 310 311 SocketDataProvider(const SocketDataProvider&) = delete; 312 SocketDataProvider& operator=(const SocketDataProvider&) = delete; 313 314 virtual ~SocketDataProvider(); 315 316 // Returns the buffer and result code for the next simulated read. 317 // If the |MockRead.result| is ERR_IO_PENDING, it informs the caller 318 // that it will be called via the AsyncSocket::OnReadComplete() 319 // function at a later time. 320 virtual MockRead OnRead() = 0; 321 virtual MockWriteResult OnWrite(const std::string& data) = 0; 322 virtual bool AllReadDataConsumed() const = 0; 323 virtual bool AllWriteDataConsumed() const = 0; CancelPendingRead()324 virtual void CancelPendingRead() {} 325 326 // Returns the last set receive buffer size, or -1 if never set. receive_buffer_size()327 int receive_buffer_size() const { return receive_buffer_size_; } set_receive_buffer_size(int receive_buffer_size)328 void set_receive_buffer_size(int receive_buffer_size) { 329 receive_buffer_size_ = receive_buffer_size; 330 } 331 332 // Returns the last set send buffer size, or -1 if never set. 
send_buffer_size()333 int send_buffer_size() const { return send_buffer_size_; } set_send_buffer_size(int send_buffer_size)334 void set_send_buffer_size(int send_buffer_size) { 335 send_buffer_size_ = send_buffer_size; 336 } 337 338 // Returns the last set value of TCP no delay, or false if never set. no_delay()339 bool no_delay() const { return no_delay_; } set_no_delay(bool no_delay)340 void set_no_delay(bool no_delay) { no_delay_ = no_delay; } 341 342 // Returns whether TCP keepalives were enabled or not. Returns kDefault by 343 // default. 344 enum class KeepAliveState { kEnabled, kDisabled, kDefault }; keep_alive_state()345 KeepAliveState keep_alive_state() const { return keep_alive_state_; } 346 // Last set TCP keepalive delay. keep_alive_delay()347 int keep_alive_delay() const { return keep_alive_delay_; } set_keep_alive(bool enable,int delay)348 void set_keep_alive(bool enable, int delay) { 349 keep_alive_state_ = 350 enable ? KeepAliveState::kEnabled : KeepAliveState::kDisabled; 351 keep_alive_delay_ = delay; 352 } 353 354 // Setters / getters for the return values of the corresponding Set*() 355 // methods. By default, they all succeed, if the socket is connected. 
356 set_set_receive_buffer_size_result(int receive_buffer_size_result)357 void set_set_receive_buffer_size_result(int receive_buffer_size_result) { 358 set_receive_buffer_size_result_ = receive_buffer_size_result; 359 } set_receive_buffer_size_result()360 int set_receive_buffer_size_result() const { 361 return set_receive_buffer_size_result_; 362 } 363 set_set_send_buffer_size_result(int set_send_buffer_size_result)364 void set_set_send_buffer_size_result(int set_send_buffer_size_result) { 365 set_send_buffer_size_result_ = set_send_buffer_size_result; 366 } set_send_buffer_size_result()367 int set_send_buffer_size_result() const { 368 return set_send_buffer_size_result_; 369 } 370 set_set_no_delay_result(bool set_no_delay_result)371 void set_set_no_delay_result(bool set_no_delay_result) { 372 set_no_delay_result_ = set_no_delay_result; 373 } set_no_delay_result()374 bool set_no_delay_result() const { return set_no_delay_result_; } 375 set_set_keep_alive_result(bool set_keep_alive_result)376 void set_set_keep_alive_result(bool set_keep_alive_result) { 377 set_keep_alive_result_ = set_keep_alive_result; 378 } set_keep_alive_result()379 bool set_keep_alive_result() const { return set_keep_alive_result_; } 380 expected_addresses()381 const std::optional<AddressList>& expected_addresses() const { 382 return expected_addresses_; 383 } set_expected_addresses(net::AddressList addresses)384 void set_expected_addresses(net::AddressList addresses) { 385 expected_addresses_ = std::move(addresses); 386 } 387 388 // Returns true if the request should be considered idle, for the purposes of 389 // IsConnectedAndIdle. 390 virtual bool IsIdle() const; 391 392 // Initializes the SocketDataProvider for use with |socket|. Must be called 393 // before use 394 void Initialize(AsyncSocket* socket); 395 // Detaches the socket associated with a SocketDataProvider. 
Must be called 396 // before |socket_| is destroyed, unless the SocketDataProvider has informed 397 // |socket_| it was destroyed. Must also be called before Initialize() may 398 // be called again with a new socket. 399 void DetachSocket(); 400 401 // Accessor for the socket which is using the SocketDataProvider. socket()402 AsyncSocket* socket() { return socket_; } 403 connect_data()404 MockConnect connect_data() const { return connect_; } set_connect_data(const MockConnect & connect)405 void set_connect_data(const MockConnect& connect) { connect_ = connect; } 406 407 private: 408 // Called to inform subclasses of initialization. 409 virtual void Reset() = 0; 410 411 MockConnect connect_; 412 raw_ptr<AsyncSocket> socket_ = nullptr; 413 414 int receive_buffer_size_ = -1; 415 int send_buffer_size_ = -1; 416 // This reflects the default state of TCPClientSockets. 417 bool no_delay_ = true; 418 419 KeepAliveState keep_alive_state_ = KeepAliveState::kDefault; 420 int keep_alive_delay_ = 0; 421 422 int set_receive_buffer_size_result_ = net::OK; 423 int set_send_buffer_size_result_ = net::OK; 424 bool set_no_delay_result_ = true; 425 bool set_keep_alive_result_ = true; 426 std::optional<AddressList> expected_addresses_; 427 }; 428 429 // The AsyncSocket is an interface used by the SocketDataProvider to 430 // complete the asynchronous read operation. 431 class AsyncSocket { 432 public: 433 // If an async IO is pending because the SocketDataProvider returned 434 // ERR_IO_PENDING, then the AsyncSocket waits until this OnReadComplete 435 // is called to complete the asynchronous read operation. 436 // data.async is ignored, and this read is completed synchronously as 437 // part of this call. 438 // TODO(rch): this should take a std::string_view since most of the fields 439 // are ignored. 
440 virtual void OnReadComplete(const MockRead& data) = 0; 441 // If an async IO is pending because the SocketDataProvider returned 442 // ERR_IO_PENDING, then the AsyncSocket waits until this OnReadComplete 443 // is called to complete the asynchronous read operation. 444 virtual void OnWriteComplete(int rv) = 0; 445 virtual void OnConnectComplete(const MockConnect& data) = 0; 446 447 // Called when the SocketDataProvider associated with the socket is destroyed. 448 // The socket may continue to be used after the data provider is destroyed, 449 // so it should be sure not to dereference the provider after this is called. 450 virtual void OnDataProviderDestroyed() = 0; 451 }; 452 453 // StaticSocketDataHelper manages a list of reads and writes. 454 class StaticSocketDataHelper { 455 public: 456 StaticSocketDataHelper(base::span<const MockRead> reads, 457 base::span<const MockWrite> writes); 458 459 StaticSocketDataHelper(const StaticSocketDataHelper&) = delete; 460 StaticSocketDataHelper& operator=(const StaticSocketDataHelper&) = delete; 461 462 ~StaticSocketDataHelper(); 463 464 // These functions get access to the next available read and write data. They 465 // CHECK fail if there is no data available. 466 const MockRead& PeekRead() const; 467 const MockWrite& PeekWrite() const; 468 469 // Returns the current read or write, and then advances to the next one. 470 const MockRead& AdvanceRead(); 471 const MockWrite& AdvanceWrite(); 472 473 // Resets the read and write indexes to 0. 474 void Reset(); 475 476 // Returns true if |data| is valid data for the next write. In order 477 // to support short writes, the next write may be longer than |data| 478 // in which case this method will still return true. 
479 bool VerifyWriteData(const std::string& data, SocketDataPrinter* printer); 480 read_index()481 size_t read_index() const { return read_index_; } write_index()482 size_t write_index() const { return write_index_; } read_count()483 size_t read_count() const { return reads_.size(); } write_count()484 size_t write_count() const { return writes_.size(); } 485 AllReadDataConsumed()486 bool AllReadDataConsumed() const { return read_index() >= read_count(); } AllWriteDataConsumed()487 bool AllWriteDataConsumed() const { return write_index() >= write_count(); } 488 489 void ExpectAllReadDataConsumed(SocketDataPrinter* printer) const; 490 void ExpectAllWriteDataConsumed(SocketDataPrinter* printer) const; 491 492 private: 493 // Returns the next available read or write that is not a pause event. CHECK 494 // fails if no data is available. 495 const MockWrite& PeekRealWrite() const; 496 497 const base::raw_span<const MockRead, DanglingUntriaged> reads_; 498 size_t read_index_ = 0; 499 const base::raw_span<const MockWrite, DanglingUntriaged> writes_; 500 size_t write_index_ = 0; 501 }; 502 503 // SocketDataProvider which responds based on static tables of mock reads and 504 // writes. 505 class StaticSocketDataProvider : public SocketDataProvider { 506 public: 507 StaticSocketDataProvider(); 508 StaticSocketDataProvider(base::span<const MockRead> reads, 509 base::span<const MockWrite> writes); 510 511 StaticSocketDataProvider(const StaticSocketDataProvider&) = delete; 512 StaticSocketDataProvider& operator=(const StaticSocketDataProvider&) = delete; 513 514 ~StaticSocketDataProvider() override; 515 516 // Pause/resume reads from this provider. 
517 void Pause(); 518 void Resume(); 519 520 // From SocketDataProvider: 521 MockRead OnRead() override; 522 MockWriteResult OnWrite(const std::string& data) override; 523 bool AllReadDataConsumed() const override; 524 bool AllWriteDataConsumed() const override; 525 read_index()526 size_t read_index() const { return helper_.read_index(); } write_index()527 size_t write_index() const { return helper_.write_index(); } read_count()528 size_t read_count() const { return helper_.read_count(); } write_count()529 size_t write_count() const { return helper_.write_count(); } 530 set_printer(SocketDataPrinter * printer)531 void set_printer(SocketDataPrinter* printer) { printer_ = printer; } 532 533 private: 534 // From SocketDataProvider: 535 void Reset() override; 536 537 StaticSocketDataHelper helper_; 538 raw_ptr<SocketDataPrinter> printer_ = nullptr; 539 bool paused_ = false; 540 }; 541 542 // SSLSocketDataProviders only need to keep track of the return code from calls 543 // to Connect(). 544 struct SSLSocketDataProvider { 545 SSLSocketDataProvider(IoMode mode, int result); 546 explicit SSLSocketDataProvider(MockConnectCompleter* completer); 547 SSLSocketDataProvider(const SSLSocketDataProvider& other); 548 ~SSLSocketDataProvider(); 549 550 // Returns whether MockConnect data has been consumed. ConnectDataConsumedSSLSocketDataProvider551 bool ConnectDataConsumed() const { return is_connect_data_consumed; } 552 553 // Returns whether MockConfirm data has been consumed. ConfirmDataConsumedSSLSocketDataProvider554 bool ConfirmDataConsumed() const { return is_confirm_data_consumed; } 555 556 // Returns whether a Write occurred before ConfirmHandshake completed. WriteBeforeConfirmSSLSocketDataProvider557 bool WriteBeforeConfirm() const { return write_called_before_confirm; } 558 559 // Result for Connect(). 560 MockConnect connect; 561 // Callback to run when Connect() is called. 
This is called at most once per 562 // socket but is repeating because SSLSocketDataProvider is copyable. 563 base::RepeatingClosure connect_callback; 564 565 // Result for ConfirmHandshake(). 566 MockConfirm confirm; 567 // Callback to run when ConfirmHandshake() is called. This is called at most 568 // once per socket but is repeating because SSLSocketDataProvider is 569 // copyable. 570 base::RepeatingClosure confirm_callback; 571 572 // Result for GetNegotiatedProtocol(). 573 NextProto next_proto = kProtoUnknown; 574 575 // Result for GetPeerApplicationSettings(). 576 std::optional<std::string> peer_application_settings; 577 578 // Result for GetSSLInfo(). 579 SSLInfo ssl_info; 580 581 // Result for GetSSLCertRequestInfo(). 582 scoped_refptr<SSLCertRequestInfo> cert_request_info; 583 584 // Result for GetECHRetryConfigs(). 585 std::vector<uint8_t> ech_retry_configs; 586 587 std::optional<NextProtoVector> next_protos_expected_in_ssl_config; 588 std::optional<SSLConfig::ApplicationSettings> expected_application_settings; 589 590 uint16_t expected_ssl_version_min; 591 uint16_t expected_ssl_version_max; 592 std::optional<bool> expected_early_data_enabled; 593 std::optional<bool> expected_send_client_cert; 594 scoped_refptr<X509Certificate> expected_client_cert; 595 std::optional<HostPortPair> expected_host_and_port; 596 std::optional<bool> expected_ignore_certificate_errors; 597 std::optional<NetworkAnonymizationKey> expected_network_anonymization_key; 598 std::optional<std::vector<uint8_t>> expected_ech_config_list; 599 600 bool is_connect_data_consumed = false; 601 bool is_confirm_data_consumed = false; 602 bool write_called_before_confirm = false; 603 }; 604 605 // Uses the sequence_number field in the mock reads and writes to 606 // complete the operations in a specified order. 607 class SequencedSocketData : public SocketDataProvider { 608 public: 609 SequencedSocketData(); 610 611 // |reads| is the list of MockRead completions. 
612 // |writes| is the list of MockWrite completions. 613 SequencedSocketData(base::span<const MockRead> reads, 614 base::span<const MockWrite> writes); 615 616 // |connect| is the result for the connect phase. 617 // |reads| is the list of MockRead completions. 618 // |writes| is the list of MockWrite completions. 619 SequencedSocketData(const MockConnect& connect, 620 base::span<const MockRead> reads, 621 base::span<const MockWrite> writes); 622 623 SequencedSocketData(const SequencedSocketData&) = delete; 624 SequencedSocketData& operator=(const SequencedSocketData&) = delete; 625 626 ~SequencedSocketData() override; 627 628 // From SocketDataProvider: 629 MockRead OnRead() override; 630 MockWriteResult OnWrite(const std::string& data) override; 631 bool AllReadDataConsumed() const override; 632 bool AllWriteDataConsumed() const override; 633 bool IsIdle() const override; 634 void CancelPendingRead() override; 635 636 // EXPECTs that all data has been consumed, printing any un-consumed data. 637 void ExpectAllReadDataConsumed() const; 638 void ExpectAllWriteDataConsumed() const; 639 640 // An ASYNC read event with a return value of ERR_IO_PENDING will cause the 641 // socket data to pause at that event, and advance no further, until Resume is 642 // invoked. At that point, the socket will continue at the next event in the 643 // sequence. 644 // 645 // If a request just wants to simulate a connection that stays open and never 646 // receives any more data, instead of pausing and then resuming a request, it 647 // should use a SYNCHRONOUS event with a return value of ERR_IO_PENDING 648 // instead. 649 bool IsPaused() const; 650 // Resumes events once |this| is in the paused state. The next event will 651 // occur synchronously with the call if it can. 652 void Resume(); 653 void RunUntilPaused(); 654 655 // When true, IsConnectedAndIdle() will return false if the next event in the 656 // sequence is a synchronous. 
Otherwise, the socket claims to be idle as 657 // long as it's connected. Defaults to false. 658 // TODO(mmenke): See if this can be made the default behavior, and consider 659 // removing this mehtod. Need to make sure it doesn't change what code any 660 // tests are targetted at testing. set_busy_before_sync_reads(bool busy_before_sync_reads)661 void set_busy_before_sync_reads(bool busy_before_sync_reads) { 662 busy_before_sync_reads_ = busy_before_sync_reads; 663 } 664 set_printer(SocketDataPrinter * printer)665 void set_printer(SocketDataPrinter* printer) { printer_ = printer; } 666 667 private: 668 // Defines the state for the read or write path. 669 enum class IoState { 670 kIdle, // No async operation is in progress. 671 kPending, // An async operation in waiting for another operation to 672 // complete. 673 kCompleting, // A task has been posted to complete an async operation. 674 kPaused, // IO is paused until Resume() is called. 675 }; 676 677 // From SocketDataProvider: 678 void Reset() override; 679 680 void OnReadComplete(); 681 void OnWriteComplete(); 682 683 void MaybePostReadCompleteTask(); 684 void MaybePostWriteCompleteTask(); 685 686 StaticSocketDataHelper helper_; 687 raw_ptr<SocketDataPrinter> printer_ = nullptr; 688 int sequence_number_ = 0; 689 IoState read_state_ = IoState::kIdle; 690 IoState write_state_ = IoState::kIdle; 691 692 bool busy_before_sync_reads_ = false; 693 694 // Used by RunUntilPaused. NULL at all other times. 695 std::unique_ptr<base::RunLoop> run_until_paused_run_loop_; 696 697 base::WeakPtrFactory<SequencedSocketData> weak_factory_{this}; 698 }; 699 700 // Holds an array of SocketDataProvider elements. As Mock{TCP,SSL}StreamSocket 701 // objects get instantiated, they take their data from the i'th element of this 702 // array. 
template <typename T>
class SocketDataProviderArray {
 public:
  SocketDataProviderArray() = default;

  // Returns the next data provider and advances the index. CHECK-fails if all
  // providers have already been handed out. CHECK (not DCHECK) so that release
  // test builds fail loudly instead of indexing past the end of the vector.
  T* GetNext() {
    CHECK_LT(next_index_, data_providers_.size());
    return data_providers_[next_index_++];
  }

  // Like GetNext(), but returns nullptr when the end of the array is reached,
  // instead of CHECK-failing. GetNext() should generally be preferred, unless
  // having no remaining elements is expected in some cases and is handled
  // safely.
  T* GetNextWithoutAsserting() {
    // `>=` rather than `==` so an index that is somehow already past the end
    // still yields nullptr instead of an out-of-bounds read.
    if (next_index_ >= data_providers_.size()) {
      return nullptr;
    }
    return data_providers_[next_index_++];
  }

  // Appends |data_provider|, which must be non-null. It is stored as a raw
  // pointer and never deleted by this class.
  void Add(T* data_provider) {
    DCHECK(data_provider);
    data_providers_.push_back(data_provider);
  }

  // Index of the element the next GetNext() call will return.
  size_t next_index() const { return next_index_; }

  // Rewinds so providers are handed out again from the first one.
  void ResetNextIndex() { next_index_ = 0; }

 private:
  // Index of the next |data_providers_| element to use. Not an iterator
  // because those are invalidated on vector reallocation.
  size_t next_index_ = 0;

  // SocketDataProviders to be returned.
  std::vector<T*> data_providers_;
};

class MockUDPClientSocket;
class MockTCPClientSocket;
class MockSSLClientSocket;

// ClientSocketFactory which contains arrays of sockets of each type.
// You should first fill the arrays using Add{SSL,Tcp,}SocketDataProvider().
// When the factory is asked to create a socket, it takes the next entry from
// the appropriate array. You can use ResetNextMockIndexes to reset that next
// entry index for all mock socket types.
750 class MockClientSocketFactory : public ClientSocketFactory { 751 public: 752 MockClientSocketFactory(); 753 754 MockClientSocketFactory(const MockClientSocketFactory&) = delete; 755 MockClientSocketFactory& operator=(const MockClientSocketFactory&) = delete; 756 757 ~MockClientSocketFactory() override; 758 759 // Adds a SocketDataProvider that can be used to served either TCP or UDP 760 // connection requests. Sockets are returned in FIFO order. 761 void AddSocketDataProvider(SocketDataProvider* socket); 762 763 // Like AddSocketDataProvider(), except sockets will only be used to service 764 // TCP connection requests. Sockets added with this method are used first, 765 // before sockets added with AddSocketDataProvider(). Particularly useful for 766 // QUIC tests with multiple sockets, where TCP connections may or may not be 767 // made, and have no guaranteed order, relative to UDP connections. 768 void AddTcpSocketDataProvider(SocketDataProvider* socket); 769 770 void AddSSLSocketDataProvider(SSLSocketDataProvider* socket); 771 void ResetNextMockIndexes(); 772 mock_data()773 SocketDataProviderArray<SocketDataProvider>& mock_data() { 774 return mock_data_; 775 } 776 set_enable_read_if_ready(bool enable_read_if_ready)777 void set_enable_read_if_ready(bool enable_read_if_ready) { 778 enable_read_if_ready_ = enable_read_if_ready; 779 } 780 781 // ClientSocketFactory 782 std::unique_ptr<DatagramClientSocket> CreateDatagramClientSocket( 783 DatagramSocket::BindType bind_type, 784 NetLog* net_log, 785 const NetLogSource& source) override; 786 std::unique_ptr<TransportClientSocket> CreateTransportClientSocket( 787 const AddressList& addresses, 788 std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher, 789 NetworkQualityEstimator* network_quality_estimator, 790 NetLog* net_log, 791 const NetLogSource& source) override; 792 std::unique_ptr<SSLClientSocket> CreateSSLClientSocket( 793 SSLClientContext* context, 794 std::unique_ptr<StreamSocket> 
stream_socket, 795 const HostPortPair& host_and_port, 796 const SSLConfig& ssl_config) override; udp_client_socket_ports()797 const std::vector<uint16_t>& udp_client_socket_ports() const { 798 return udp_client_socket_ports_; 799 } 800 801 private: 802 SocketDataProviderArray<SocketDataProvider> mock_data_; 803 SocketDataProviderArray<SocketDataProvider> mock_tcp_data_; 804 SocketDataProviderArray<SSLSocketDataProvider> mock_ssl_data_; 805 std::vector<uint16_t> udp_client_socket_ports_; 806 807 // If true, ReadIfReady() is enabled; otherwise ReadIfReady() returns 808 // ERR_READ_IF_READY_NOT_IMPLEMENTED. 809 bool enable_read_if_ready_ = false; 810 }; 811 812 class MockClientSocket : public TransportClientSocket { 813 public: 814 // The NetLogWithSource is needed to test LoadTimingInfo, which uses NetLog 815 // IDs as 816 // unique socket IDs. 817 explicit MockClientSocket(const NetLogWithSource& net_log); 818 819 MockClientSocket(const MockClientSocket&) = delete; 820 MockClientSocket& operator=(const MockClientSocket&) = delete; 821 822 // Socket implementation. 823 int Read(IOBuffer* buf, 824 int buf_len, 825 CompletionOnceCallback callback) override = 0; 826 int Write(IOBuffer* buf, 827 int buf_len, 828 CompletionOnceCallback callback, 829 const NetworkTrafficAnnotationTag& traffic_annotation) override = 0; 830 int SetReceiveBufferSize(int32_t size) override; 831 int SetSendBufferSize(int32_t size) override; 832 833 // TransportClientSocket implementation. 834 int Bind(const net::IPEndPoint& local_addr) override; 835 bool SetNoDelay(bool no_delay) override; 836 bool SetKeepAlive(bool enable, int delay) override; 837 838 // StreamSocket implementation. 
839 int Connect(CompletionOnceCallback callback) override = 0; 840 void Disconnect() override; 841 bool IsConnected() const override; 842 bool IsConnectedAndIdle() const override; 843 int GetPeerAddress(IPEndPoint* address) const override; 844 int GetLocalAddress(IPEndPoint* address) const override; 845 const NetLogWithSource& NetLog() const override; 846 NextProto GetNegotiatedProtocol() const override; 847 int64_t GetTotalReceivedBytes() const override; ApplySocketTag(const SocketTag & tag)848 void ApplySocketTag(const SocketTag& tag) override {} 849 850 protected: 851 ~MockClientSocket() override; 852 void RunCallbackAsync(CompletionOnceCallback callback, int result); 853 void RunCallback(CompletionOnceCallback callback, int result); 854 855 // True if Connect completed successfully and Disconnect hasn't been called. 856 bool connected_ = false; 857 858 IPEndPoint local_addr_; 859 IPEndPoint peer_addr_; 860 861 NetLogWithSource net_log_; 862 863 private: 864 base::WeakPtrFactory<MockClientSocket> weak_factory_{this}; 865 }; 866 867 class MockTCPClientSocket : public MockClientSocket, public AsyncSocket { 868 public: 869 MockTCPClientSocket(const AddressList& addresses, 870 net::NetLog* net_log, 871 SocketDataProvider* socket); 872 873 MockTCPClientSocket(const MockTCPClientSocket&) = delete; 874 MockTCPClientSocket& operator=(const MockTCPClientSocket&) = delete; 875 876 ~MockTCPClientSocket() override; 877 addresses()878 const AddressList& addresses() const { return addresses_; } 879 880 // Socket implementation. 
881 int Read(IOBuffer* buf, 882 int buf_len, 883 CompletionOnceCallback callback) override; 884 int ReadIfReady(IOBuffer* buf, 885 int buf_len, 886 CompletionOnceCallback callback) override; 887 int CancelReadIfReady() override; 888 int Write(IOBuffer* buf, 889 int buf_len, 890 CompletionOnceCallback callback, 891 const NetworkTrafficAnnotationTag& traffic_annotation) override; 892 int SetReceiveBufferSize(int32_t size) override; 893 int SetSendBufferSize(int32_t size) override; 894 895 // TransportClientSocket implementation. 896 bool SetNoDelay(bool no_delay) override; 897 bool SetKeepAlive(bool enable, int delay) override; 898 899 // StreamSocket implementation. 900 void SetBeforeConnectCallback( 901 const BeforeConnectCallback& before_connect_callback) override; 902 int Connect(CompletionOnceCallback callback) override; 903 void Disconnect() override; 904 bool IsConnected() const override; 905 bool IsConnectedAndIdle() const override; 906 int GetPeerAddress(IPEndPoint* address) const override; 907 bool WasEverUsed() const override; 908 bool GetSSLInfo(SSLInfo* ssl_info) override; 909 910 // AsyncSocket: 911 void OnReadComplete(const MockRead& data) override; 912 void OnWriteComplete(int rv) override; 913 void OnConnectComplete(const MockConnect& data) override; 914 void OnDataProviderDestroyed() override; 915 set_enable_read_if_ready(bool enable_read_if_ready)916 void set_enable_read_if_ready(bool enable_read_if_ready) { 917 enable_read_if_ready_ = enable_read_if_ready; 918 } 919 920 private: 921 void RetryRead(int rv); 922 int ReadIfReadyImpl(IOBuffer* buf, 923 int buf_len, 924 CompletionOnceCallback callback); 925 926 // Helper method to run |pending_read_if_ready_callback_| if it is not null. 927 void RunReadIfReadyCallback(int result); 928 929 AddressList addresses_; 930 931 raw_ptr<SocketDataProvider> data_; 932 int read_offset_ = 0; 933 MockRead read_data_; 934 bool need_read_data_ = true; 935 936 // True if the peer has closed the connection. 
This allows us to simulate 937 // the recv(..., MSG_PEEK) call in the IsConnectedAndIdle method of the real 938 // TCPClientSocket. 939 bool peer_closed_connection_ = false; 940 941 // While an asynchronous read is pending, we save our user-buffer state. 942 scoped_refptr<IOBuffer> pending_read_buf_ = nullptr; 943 int pending_read_buf_len_ = 0; 944 CompletionOnceCallback pending_read_callback_; 945 946 // Non-null when a ReadIfReady() is pending. 947 CompletionOnceCallback pending_read_if_ready_callback_; 948 949 CompletionOnceCallback pending_connect_callback_; 950 CompletionOnceCallback pending_write_callback_; 951 bool was_used_to_convey_data_ = false; 952 953 // If true, ReadIfReady() is enabled; otherwise ReadIfReady() returns 954 // ERR_READ_IF_READY_NOT_IMPLEMENTED. 955 bool enable_read_if_ready_ = false; 956 957 BeforeConnectCallback before_connect_callback_; 958 }; 959 960 class MockSSLClientSocket : public AsyncSocket, public SSLClientSocket { 961 public: 962 MockSSLClientSocket(std::unique_ptr<StreamSocket> stream_socket, 963 const HostPortPair& host_and_port, 964 const SSLConfig& ssl_config, 965 SSLSocketDataProvider* socket); 966 967 MockSSLClientSocket(const MockSSLClientSocket&) = delete; 968 MockSSLClientSocket& operator=(const MockSSLClientSocket&) = delete; 969 970 ~MockSSLClientSocket() override; 971 972 // Socket implementation. 973 int Read(IOBuffer* buf, 974 int buf_len, 975 CompletionOnceCallback callback) override; 976 int ReadIfReady(IOBuffer* buf, 977 int buf_len, 978 CompletionOnceCallback callback) override; 979 int Write(IOBuffer* buf, 980 int buf_len, 981 CompletionOnceCallback callback, 982 const NetworkTrafficAnnotationTag& traffic_annotation) override; 983 int CancelReadIfReady() override; 984 985 // StreamSocket implementation. 
986 int Connect(CompletionOnceCallback callback) override; 987 void Disconnect() override; 988 int ConfirmHandshake(CompletionOnceCallback callback) override; 989 bool IsConnected() const override; 990 bool IsConnectedAndIdle() const override; 991 bool WasEverUsed() const override; 992 int GetPeerAddress(IPEndPoint* address) const override; 993 int GetLocalAddress(IPEndPoint* address) const override; 994 NextProto GetNegotiatedProtocol() const override; 995 std::optional<std::string_view> GetPeerApplicationSettings() const override; 996 bool GetSSLInfo(SSLInfo* ssl_info) override; 997 void GetSSLCertRequestInfo( 998 SSLCertRequestInfo* cert_request_info) const override; 999 void ApplySocketTag(const SocketTag& tag) override; 1000 const NetLogWithSource& NetLog() const override; 1001 int64_t GetTotalReceivedBytes() const override; 1002 int SetReceiveBufferSize(int32_t size) override; 1003 int SetSendBufferSize(int32_t size) override; 1004 1005 // SSLSocket implementation. 1006 int ExportKeyingMaterial(std::string_view label, 1007 std::optional<base::span<const uint8_t>> context, 1008 base::span<uint8_t> out) override; 1009 1010 // SSLClientSocket implementation. 1011 std::vector<uint8_t> GetECHRetryConfigs() override; 1012 1013 // This MockSocket does not implement the manual async IO feature. 1014 void OnReadComplete(const MockRead& data) override; 1015 void OnWriteComplete(int rv) override; 1016 void OnConnectComplete(const MockConnect& data) override; 1017 // SSL sockets don't need magic to deal with destruction of their data 1018 // provider. 1019 // TODO(mmenke): Probably a good idea to support it, anyways. 
OnDataProviderDestroyed()1020 void OnDataProviderDestroyed() override {} 1021 1022 private: 1023 static void ConnectCallback(MockSSLClientSocket* ssl_client_socket, 1024 CompletionOnceCallback callback, 1025 int rv); 1026 1027 void RunCallbackAsync(CompletionOnceCallback callback, int result); 1028 void RunCallback(CompletionOnceCallback callback, int result); 1029 1030 void RunConfirmHandshakeCallback(CompletionOnceCallback callback, int result); 1031 1032 bool connected_ = false; 1033 bool in_confirm_handshake_ = false; 1034 NetLogWithSource net_log_; 1035 std::unique_ptr<StreamSocket> stream_socket_; 1036 raw_ptr<SSLSocketDataProvider, AcrossTasksDanglingUntriaged> data_; 1037 // Address of the "remote" peer we're connected to. 1038 IPEndPoint peer_addr_; 1039 1040 base::WeakPtrFactory<MockSSLClientSocket> weak_factory_{this}; 1041 }; 1042 1043 class MockUDPClientSocket : public DatagramClientSocket, public AsyncSocket { 1044 public: 1045 explicit MockUDPClientSocket(SocketDataProvider* data = nullptr, 1046 net::NetLog* net_log = nullptr); 1047 1048 MockUDPClientSocket(const MockUDPClientSocket&) = delete; 1049 MockUDPClientSocket& operator=(const MockUDPClientSocket&) = delete; 1050 1051 ~MockUDPClientSocket() override; 1052 1053 // Socket implementation. 1054 int Read(IOBuffer* buf, 1055 int buf_len, 1056 CompletionOnceCallback callback) override; 1057 int Write(IOBuffer* buf, 1058 int buf_len, 1059 CompletionOnceCallback callback, 1060 const NetworkTrafficAnnotationTag& traffic_annotation) override; 1061 1062 int SetReceiveBufferSize(int32_t size) override; 1063 int SetSendBufferSize(int32_t size) override; 1064 int SetDoNotFragment() override; 1065 int SetRecvTos() override; 1066 int SetTos(DiffServCodePoint dscp, EcnCodePoint ecn) override; 1067 1068 // DatagramSocket implementation. 
1069 void Close() override; 1070 int GetPeerAddress(IPEndPoint* address) const override; 1071 int GetLocalAddress(IPEndPoint* address) const override; 1072 void UseNonBlockingIO() override; 1073 int SetMulticastInterface(uint32_t interface_index) override; 1074 const NetLogWithSource& NetLog() const override; 1075 1076 // DatagramClientSocket implementation. 1077 int Connect(const IPEndPoint& address) override; 1078 int ConnectUsingNetwork(handles::NetworkHandle network, 1079 const IPEndPoint& address) override; 1080 int ConnectUsingDefaultNetwork(const IPEndPoint& address) override; 1081 int ConnectAsync(const IPEndPoint& address, 1082 CompletionOnceCallback callback) override; 1083 int ConnectUsingNetworkAsync(handles::NetworkHandle network, 1084 const IPEndPoint& address, 1085 CompletionOnceCallback callback) override; 1086 int ConnectUsingDefaultNetworkAsync(const IPEndPoint& address, 1087 CompletionOnceCallback callback) override; 1088 handles::NetworkHandle GetBoundNetwork() const override; 1089 void ApplySocketTag(const SocketTag& tag) override; SetMsgConfirm(bool confirm)1090 void SetMsgConfirm(bool confirm) override {} 1091 DscpAndEcn GetLastTos() const override; 1092 1093 // AsyncSocket implementation. 1094 void OnReadComplete(const MockRead& data) override; 1095 void OnWriteComplete(int rv) override; 1096 void OnConnectComplete(const MockConnect& data) override; 1097 void OnDataProviderDestroyed() override; 1098 set_source_port(uint16_t port)1099 void set_source_port(uint16_t port) { source_port_ = port; } source_port()1100 uint16_t source_port() const { return source_port_; } set_source_host(IPAddress addr)1101 void set_source_host(IPAddress addr) { source_host_ = addr; } source_host()1102 IPAddress source_host() const { return source_host_; } 1103 1104 // Returns last tag applied to socket. tag()1105 SocketTag tag() const { return tag_; } 1106 1107 // Returns false if socket's tag was changed after the socket was used for 1108 // data transfer (e.g. 
Read/Write() called), otherwise returns true. tagged_before_data_transferred()1109 bool tagged_before_data_transferred() const { 1110 return tagged_before_data_transferred_; 1111 } 1112 1113 private: 1114 int CompleteRead(); 1115 1116 void RunCallbackAsync(CompletionOnceCallback callback, int result); 1117 void RunCallback(CompletionOnceCallback callback, int result); 1118 1119 bool connected_ = false; 1120 raw_ptr<SocketDataProvider> data_; 1121 int read_offset_ = 0; 1122 MockRead read_data_; 1123 bool need_read_data_ = true; 1124 IPAddress source_host_; 1125 uint16_t source_port_ = 123; // Ephemeral source port. 1126 1127 // Address of the "remote" peer we're connected to. 1128 IPEndPoint peer_addr_; 1129 1130 // Network that the socket is bound to. 1131 handles::NetworkHandle network_ = handles::kInvalidNetworkHandle; 1132 1133 // While an asynchronous IO is pending, we save our user-buffer state. 1134 scoped_refptr<IOBuffer> pending_read_buf_ = nullptr; 1135 int pending_read_buf_len_ = 0; 1136 CompletionOnceCallback pending_read_callback_; 1137 CompletionOnceCallback pending_write_callback_; 1138 1139 NetLogWithSource net_log_; 1140 1141 DatagramBuffers unwritten_buffers_; 1142 1143 SocketTag tag_; 1144 bool data_transferred_ = false; 1145 bool tagged_before_data_transferred_ = true; 1146 1147 uint8_t last_tos_ = 0; 1148 1149 base::WeakPtrFactory<MockUDPClientSocket> weak_factory_{this}; 1150 }; 1151 1152 class TestSocketRequest : public TestCompletionCallbackBase { 1153 public: 1154 TestSocketRequest(std::vector<raw_ptr<TestSocketRequest, VectorExperimental>>* 1155 request_order, 1156 size_t* completion_count); 1157 1158 TestSocketRequest(const TestSocketRequest&) = delete; 1159 TestSocketRequest& operator=(const TestSocketRequest&) = delete; 1160 1161 ~TestSocketRequest() override; 1162 handle()1163 ClientSocketHandle* handle() { return &handle_; } 1164 callback()1165 CompletionOnceCallback callback() { 1166 return 
base::BindOnce(&TestSocketRequest::OnComplete, 1167 base::Unretained(this)); 1168 } 1169 1170 private: 1171 void OnComplete(int result); 1172 1173 ClientSocketHandle handle_; 1174 raw_ptr<std::vector<raw_ptr<TestSocketRequest, VectorExperimental>>> 1175 request_order_; 1176 raw_ptr<size_t> completion_count_; 1177 }; 1178 1179 class ClientSocketPoolTest { 1180 public: 1181 enum KeepAlive { 1182 KEEP_ALIVE, 1183 1184 // A socket will be disconnected in addition to handle being reset. 1185 NO_KEEP_ALIVE, 1186 }; 1187 1188 static const int kIndexOutOfBounds; 1189 static const int kRequestNotFound; 1190 1191 ClientSocketPoolTest(); 1192 1193 ClientSocketPoolTest(const ClientSocketPoolTest&) = delete; 1194 ClientSocketPoolTest& operator=(const ClientSocketPoolTest&) = delete; 1195 1196 ~ClientSocketPoolTest(); 1197 1198 template <typename PoolType> StartRequestUsingPool(PoolType * socket_pool,const ClientSocketPool::GroupId & group_id,RequestPriority priority,ClientSocketPool::RespectLimits respect_limits,const scoped_refptr<typename PoolType::SocketParams> & socket_params)1199 int StartRequestUsingPool( 1200 PoolType* socket_pool, 1201 const ClientSocketPool::GroupId& group_id, 1202 RequestPriority priority, 1203 ClientSocketPool::RespectLimits respect_limits, 1204 const scoped_refptr<typename PoolType::SocketParams>& socket_params) { 1205 DCHECK(socket_pool); 1206 TestSocketRequest* request( 1207 new TestSocketRequest(&request_order_, &completion_count_)); 1208 requests_.push_back(base::WrapUnique(request)); 1209 int rv = request->handle()->Init( 1210 group_id, socket_params, std::nullopt /* proxy_annotation_tag */, 1211 priority, SocketTag(), respect_limits, request->callback(), 1212 ClientSocketPool::ProxyAuthCallback(), socket_pool, NetLogWithSource()); 1213 if (rv != ERR_IO_PENDING) 1214 request_order_.push_back(request); 1215 return rv; 1216 } 1217 1218 // Provided there were n requests started, takes |index| in range 1..n 1219 // and returns order in which that 
request completed, in range 1..n, 1220 // or kIndexOutOfBounds if |index| is out of bounds, or kRequestNotFound 1221 // if that request did not complete (for example was canceled). 1222 int GetOrderOfRequest(size_t index) const; 1223 1224 // Resets first initialized socket handle from |requests_|. If found such 1225 // a handle, returns true. 1226 bool ReleaseOneConnection(KeepAlive keep_alive); 1227 1228 // Releases connections until there is nothing to release. 1229 void ReleaseAllConnections(KeepAlive keep_alive); 1230 1231 // Note that this uses 0-based indices, while GetOrderOfRequest takes and 1232 // returns 1-based indices. request(int i)1233 TestSocketRequest* request(int i) { return requests_[i].get(); } 1234 requests_size()1235 size_t requests_size() const { return requests_.size(); } requests()1236 std::vector<std::unique_ptr<TestSocketRequest>>* requests() { 1237 return &requests_; 1238 } completion_count()1239 size_t completion_count() const { return completion_count_; } 1240 1241 private: 1242 std::vector<std::unique_ptr<TestSocketRequest>> requests_; 1243 std::vector<raw_ptr<TestSocketRequest, VectorExperimental>> request_order_; 1244 size_t completion_count_ = 0; 1245 }; 1246 1247 class MockTransportSocketParams 1248 : public base::RefCounted<MockTransportSocketParams> { 1249 public: 1250 MockTransportSocketParams(const MockTransportSocketParams&) = delete; 1251 MockTransportSocketParams& operator=(const MockTransportSocketParams&) = 1252 delete; 1253 1254 private: 1255 friend class base::RefCounted<MockTransportSocketParams>; 1256 ~MockTransportSocketParams() = default; 1257 }; 1258 1259 class MockTransportClientSocketPool : public TransportClientSocketPool { 1260 public: 1261 class MockConnectJob { 1262 public: 1263 MockConnectJob(std::unique_ptr<StreamSocket> socket, 1264 ClientSocketHandle* handle, 1265 const SocketTag& socket_tag, 1266 CompletionOnceCallback callback, 1267 RequestPriority priority); 1268 1269 MockConnectJob(const 
MockConnectJob&) = delete; 1270 MockConnectJob& operator=(const MockConnectJob&) = delete; 1271 1272 ~MockConnectJob(); 1273 1274 int Connect(); 1275 bool CancelHandle(const ClientSocketHandle* handle); 1276 handle()1277 ClientSocketHandle* handle() const { return handle_; } 1278 priority()1279 RequestPriority priority() const { return priority_; } set_priority(RequestPriority priority)1280 void set_priority(RequestPriority priority) { priority_ = priority; } 1281 1282 private: 1283 void OnConnect(int rv); 1284 1285 std::unique_ptr<StreamSocket> socket_; 1286 raw_ptr<ClientSocketHandle> handle_; 1287 const SocketTag socket_tag_; 1288 CompletionOnceCallback user_callback_; 1289 RequestPriority priority_; 1290 }; 1291 1292 MockTransportClientSocketPool( 1293 int max_sockets, 1294 int max_sockets_per_group, 1295 const CommonConnectJobParams* common_connect_job_params); 1296 1297 MockTransportClientSocketPool(const MockTransportClientSocketPool&) = delete; 1298 MockTransportClientSocketPool& operator=( 1299 const MockTransportClientSocketPool&) = delete; 1300 1301 ~MockTransportClientSocketPool() override; 1302 last_request_priority()1303 RequestPriority last_request_priority() const { 1304 return last_request_priority_; 1305 } 1306 requests()1307 const std::vector<std::unique_ptr<MockConnectJob>>& requests() const { 1308 return job_list_; 1309 } 1310 release_count()1311 int release_count() const { return release_count_; } cancel_count()1312 int cancel_count() const { return cancel_count_; } 1313 1314 // TransportClientSocketPool implementation. 
1315 int RequestSocket( 1316 const GroupId& group_id, 1317 scoped_refptr<ClientSocketPool::SocketParams> socket_params, 1318 const std::optional<NetworkTrafficAnnotationTag>& proxy_annotation_tag, 1319 RequestPriority priority, 1320 const SocketTag& socket_tag, 1321 RespectLimits respect_limits, 1322 ClientSocketHandle* handle, 1323 CompletionOnceCallback callback, 1324 const ProxyAuthCallback& on_auth_callback, 1325 const NetLogWithSource& net_log) override; 1326 void SetPriority(const GroupId& group_id, 1327 ClientSocketHandle* handle, 1328 RequestPriority priority) override; 1329 void CancelRequest(const GroupId& group_id, 1330 ClientSocketHandle* handle, 1331 bool cancel_connect_job) override; 1332 void ReleaseSocket(const GroupId& group_id, 1333 std::unique_ptr<StreamSocket> socket, 1334 int64_t generation) override; 1335 1336 private: 1337 raw_ptr<ClientSocketFactory> client_socket_factory_; 1338 std::vector<std::unique_ptr<MockConnectJob>> job_list_; 1339 RequestPriority last_request_priority_ = DEFAULT_PRIORITY; 1340 int release_count_ = 0; 1341 int cancel_count_ = 0; 1342 }; 1343 1344 // WrappedStreamSocket is a base class that wraps an existing StreamSocket, 1345 // forwarding the Socket and StreamSocket interfaces to the underlying 1346 // transport. 1347 // This is to provide a common base class for subclasses to override specific 1348 // StreamSocket methods for testing, while still communicating with a 'real' 1349 // StreamSocket. 
1350 class WrappedStreamSocket : public TransportClientSocket { 1351 public: 1352 explicit WrappedStreamSocket(std::unique_ptr<StreamSocket> transport); 1353 ~WrappedStreamSocket() override; 1354 1355 // StreamSocket implementation: 1356 int Bind(const net::IPEndPoint& local_addr) override; 1357 int Connect(CompletionOnceCallback callback) override; 1358 void Disconnect() override; 1359 bool IsConnected() const override; 1360 bool IsConnectedAndIdle() const override; 1361 int GetPeerAddress(IPEndPoint* address) const override; 1362 int GetLocalAddress(IPEndPoint* address) const override; 1363 const NetLogWithSource& NetLog() const override; 1364 bool WasEverUsed() const override; 1365 NextProto GetNegotiatedProtocol() const override; 1366 bool GetSSLInfo(SSLInfo* ssl_info) override; 1367 int64_t GetTotalReceivedBytes() const override; 1368 void ApplySocketTag(const SocketTag& tag) override; 1369 1370 // Socket implementation: 1371 int Read(IOBuffer* buf, 1372 int buf_len, 1373 CompletionOnceCallback callback) override; 1374 int ReadIfReady(IOBuffer* buf, 1375 int buf_len, 1376 CompletionOnceCallback callback) override; 1377 int Write(IOBuffer* buf, 1378 int buf_len, 1379 CompletionOnceCallback callback, 1380 const NetworkTrafficAnnotationTag& traffic_annotation) override; 1381 int SetReceiveBufferSize(int32_t size) override; 1382 int SetSendBufferSize(int32_t size) override; 1383 1384 protected: 1385 std::unique_ptr<StreamSocket> transport_; 1386 }; 1387 1388 // StreamSocket that wraps another StreamSocket, but keeps track of any 1389 // SocketTag applied to the socket. 
1390 class MockTaggingStreamSocket : public WrappedStreamSocket { 1391 public: MockTaggingStreamSocket(std::unique_ptr<StreamSocket> transport)1392 explicit MockTaggingStreamSocket(std::unique_ptr<StreamSocket> transport) 1393 : WrappedStreamSocket(std::move(transport)) {} 1394 1395 MockTaggingStreamSocket(const MockTaggingStreamSocket&) = delete; 1396 MockTaggingStreamSocket& operator=(const MockTaggingStreamSocket&) = delete; 1397 1398 ~MockTaggingStreamSocket() override = default; 1399 1400 // StreamSocket implementation. 1401 int Connect(CompletionOnceCallback callback) override; 1402 void ApplySocketTag(const SocketTag& tag) override; 1403 1404 // Returns false if socket's tag was changed after the socket was connected, 1405 // otherwise returns true. tagged_before_connected()1406 bool tagged_before_connected() const { return tagged_before_connected_; } 1407 1408 // Returns last tag applied to socket. tag()1409 SocketTag tag() const { return tag_; } 1410 1411 private: 1412 bool connected_ = false; 1413 bool tagged_before_connected_ = true; 1414 SocketTag tag_; 1415 }; 1416 1417 // Extend MockClientSocketFactory to return MockTaggingStreamSockets and 1418 // keep track of last socket produced for test inspection. 1419 class MockTaggingClientSocketFactory : public MockClientSocketFactory { 1420 public: 1421 MockTaggingClientSocketFactory() = default; 1422 1423 MockTaggingClientSocketFactory(const MockTaggingClientSocketFactory&) = 1424 delete; 1425 MockTaggingClientSocketFactory& operator=( 1426 const MockTaggingClientSocketFactory&) = delete; 1427 1428 // ClientSocketFactory implementation. 
1429 std::unique_ptr<DatagramClientSocket> CreateDatagramClientSocket( 1430 DatagramSocket::BindType bind_type, 1431 NetLog* net_log, 1432 const NetLogSource& source) override; 1433 std::unique_ptr<TransportClientSocket> CreateTransportClientSocket( 1434 const AddressList& addresses, 1435 std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher, 1436 NetworkQualityEstimator* network_quality_estimator, 1437 NetLog* net_log, 1438 const NetLogSource& source) override; 1439 1440 // These methods return pointers to last TCP and UDP sockets produced by this 1441 // factory. NOTE: Socket must still exist, or pointer will be to freed memory. GetLastProducedTCPSocket()1442 MockTaggingStreamSocket* GetLastProducedTCPSocket() const { 1443 return tcp_socket_; 1444 } GetLastProducedUDPSocket()1445 MockUDPClientSocket* GetLastProducedUDPSocket() const { return udp_socket_; } 1446 1447 private: 1448 raw_ptr<MockTaggingStreamSocket, AcrossTasksDanglingUntriaged> tcp_socket_ = 1449 nullptr; 1450 raw_ptr<MockUDPClientSocket, AcrossTasksDanglingUntriaged> udp_socket_ = 1451 nullptr; 1452 }; 1453 1454 // Host / port used for SOCKS4 test strings. 1455 extern const char kSOCKS4TestHost[]; 1456 extern const int kSOCKS4TestPort; 1457 1458 // Constants for a successful SOCKS v4 handshake (connecting to kSOCKS4TestHost 1459 // on port kSOCKS4TestPort, for the request). 1460 extern const char kSOCKS4OkRequestLocalHostPort80[]; 1461 extern const int kSOCKS4OkRequestLocalHostPort80Length; 1462 1463 extern const char kSOCKS4OkReply[]; 1464 extern const int kSOCKS4OkReplyLength; 1465 1466 // Host / port used for SOCKS5 test strings. 1467 extern const char kSOCKS5TestHost[]; 1468 extern const int kSOCKS5TestPort; 1469 1470 // Constants for a successful SOCKS v5 handshake (connecting to kSOCKS5TestHost 1471 // on port kSOCKS5TestPort, for the request).. 
1472 extern const char kSOCKS5GreetRequest[]; 1473 extern const int kSOCKS5GreetRequestLength; 1474 1475 extern const char kSOCKS5GreetResponse[]; 1476 extern const int kSOCKS5GreetResponseLength; 1477 1478 extern const char kSOCKS5OkRequest[]; 1479 extern const int kSOCKS5OkRequestLength; 1480 1481 extern const char kSOCKS5OkResponse[]; 1482 extern const int kSOCKS5OkResponseLength; 1483 1484 // Helper function to get the total data size of the MockReads in |reads|. 1485 int64_t CountReadBytes(base::span<const MockRead> reads); 1486 1487 // Helper function to get the total data size of the MockWrites in |writes|. 1488 int64_t CountWriteBytes(base::span<const MockWrite> writes); 1489 1490 #if BUILDFLAG(IS_ANDROID) 1491 // Returns whether the device supports calling GetTaggedBytes(). 1492 bool CanGetTaggedBytes(); 1493 1494 // Query the system to find out how many bytes were received with tag 1495 // |expected_tag| for our UID. Return the count of received bytes. 1496 uint64_t GetTaggedBytes(int32_t expected_tag); 1497 #endif 1498 1499 } // namespace net 1500 1501 #endif // NET_SOCKET_SOCKET_TEST_UTIL_H_ 1502