1 // Copyright 2012 The Chromium Authors 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #ifndef NET_SOCKET_SOCKET_TEST_UTIL_H_ 6 #define NET_SOCKET_SOCKET_TEST_UTIL_H_ 7 8 #include <stddef.h> 9 #include <stdint.h> 10 11 #include <cstring> 12 #include <memory> 13 #include <string> 14 #include <utility> 15 #include <vector> 16 17 #include "base/check_op.h" 18 #include "base/containers/span.h" 19 #include "base/functional/bind.h" 20 #include "base/functional/callback.h" 21 #include "base/memory/ptr_util.h" 22 #include "base/memory/raw_ptr.h" 23 #include "base/memory/raw_ptr_exclusion.h" 24 #include "base/memory/ref_counted.h" 25 #include "base/memory/weak_ptr.h" 26 #include "build/build_config.h" 27 #include "net/base/address_list.h" 28 #include "net/base/completion_once_callback.h" 29 #include "net/base/io_buffer.h" 30 #include "net/base/net_errors.h" 31 #include "net/base/test_completion_callback.h" 32 #include "net/http/http_auth_controller.h" 33 #include "net/log/net_log_with_source.h" 34 #include "net/socket/client_socket_factory.h" 35 #include "net/socket/client_socket_handle.h" 36 #include "net/socket/client_socket_pool.h" 37 #include "net/socket/datagram_client_socket.h" 38 #include "net/socket/socket_performance_watcher.h" 39 #include "net/socket/socket_tag.h" 40 #include "net/socket/ssl_client_socket.h" 41 #include "net/socket/transport_client_socket.h" 42 #include "net/socket/transport_client_socket_pool.h" 43 #include "net/ssl/ssl_config_service.h" 44 #include "net/ssl/ssl_info.h" 45 #include "testing/gtest/include/gtest/gtest.h" 46 #include "third_party/abseil-cpp/absl/types/optional.h" 47 48 namespace base { 49 class RunLoop; 50 } 51 52 namespace net { 53 54 struct CommonConnectJobParams; 55 class NetLog; 56 struct NetworkTrafficAnnotationTag; 57 class X509Certificate; 58 59 const handles::NetworkHandle kDefaultNetworkForTests = 1; 60 const handles::NetworkHandle kNewNetworkForTests = 2; 61 
enum {
  // A private network error code used by the socket test utility classes.
  // If the |result| member of a MockRead is
  // ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ, that MockRead is just a
  // marker that indicates the peer will close the connection after the next
  // MockRead. The other members of that MockRead are ignored.
  ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ = -10000,
};

class AsyncSocket;
class MockClientSocket;
class SSLClientSocket;
class StreamSocket;

// Selects whether a mocked operation is reported asynchronously (ASYNC) or
// synchronously (SYNCHRONOUS).
enum IoMode { ASYNC, SYNCHRONOUS };

// Scripted outcome of a mocked Connect() call.
struct MockConnect {
  // Asynchronous connection success.
  // Creates a MockConnect with |mode| ASYNC, |result| OK, and
  // |peer_addr| 192.0.2.33.
  MockConnect();
  // Creates a MockConnect with the specified mode and result, with
  // |peer_addr| 192.0.2.33.
  MockConnect(IoMode io_mode, int r);
  // Creates a MockConnect with the specified mode, result and peer address.
  MockConnect(IoMode io_mode, int r, IPEndPoint addr);
  // As above, additionally taking |first_attempt_fails| (presumably stored in
  // the member of the same name; the definition is not in this header).
  MockConnect(IoMode io_mode, int r, IPEndPoint addr, bool first_attempt_fails);
  ~MockConnect();

  // Whether the connect completes synchronously or asynchronously.
  IoMode mode;
  // Net error code reported for the connect.
  int result;
  // Address reported as the peer of the connected socket.
  IPEndPoint peer_addr;
  // NOTE(review): consumed by the mock socket implementations, which are not
  // visible here; false unless a constructor sets it.
  bool first_attempt_fails = false;
};

// Scripted outcome of a mocked ConfirmHandshake() call.
struct MockConfirm {
  // Asynchronous confirm success.
  // Creates a MockConfirm with |mode| ASYNC and |result| OK.
  MockConfirm();
  // Creates a MockConfirm with the specified mode and result.
  MockConfirm(IoMode io_mode, int r);
  ~MockConfirm();

  // Whether the confirm completes synchronously or asynchronously.
  IoMode mode;
  // Net error code reported for the confirm.
  int result;
};

// MockRead and MockWrite share the same interface and members, but we'd like
// to have distinct types because we don't want to have them used
// interchangeably. To do this, a struct template is defined, and MockRead and
// MockWrite are instantiated by using this template. Template parameter |type|
// is not used in the struct definition (it purely exists for creating a new
// type).
114 // 115 // |data| in MockRead and MockWrite has different meanings: |data| in MockRead 116 // is the data returned from the socket when MockTCPClientSocket::Read() is 117 // attempted, while |data| in MockWrite is the expected data that should be 118 // given in MockTCPClientSocket::Write(). 119 enum MockReadWriteType { MOCK_READ, MOCK_WRITE }; 120 121 template <MockReadWriteType type> 122 struct MockReadWrite { 123 // Flag to indicate that the message loop should be terminated. 124 enum { STOPLOOP = 1 << 31 }; 125 126 // Default MockReadWriteMockReadWrite127 MockReadWrite() 128 : mode(SYNCHRONOUS), 129 result(0), 130 data(nullptr), 131 data_len(0), 132 sequence_number(0) {} 133 134 // Read/write failure (no data). MockReadWriteMockReadWrite135 MockReadWrite(IoMode io_mode, int result) 136 : mode(io_mode), 137 result(result), 138 data(nullptr), 139 data_len(0), 140 sequence_number(0) {} 141 142 // Read/write failure (no data), with sequence information. MockReadWriteMockReadWrite143 MockReadWrite(IoMode io_mode, int result, int seq) 144 : mode(io_mode), 145 result(result), 146 data(nullptr), 147 data_len(0), 148 sequence_number(seq) {} 149 150 // Asynchronous read/write success (inferred data length). MockReadWriteMockReadWrite151 explicit MockReadWrite(const char* data) 152 : mode(ASYNC), 153 result(0), 154 data(data), 155 data_len(strlen(data)), 156 sequence_number(0) {} 157 158 // Read/write success (inferred data length). MockReadWriteMockReadWrite159 MockReadWrite(IoMode io_mode, const char* data) 160 : mode(io_mode), 161 result(0), 162 data(data), 163 data_len(strlen(data)), 164 sequence_number(0) {} 165 166 // Read/write success. MockReadWriteMockReadWrite167 MockReadWrite(IoMode io_mode, const char* data, int data_len) 168 : mode(io_mode), 169 result(0), 170 data(data), 171 data_len(data_len), 172 sequence_number(0) {} 173 174 // Read/write success (inferred data length) with sequence information. 
MockReadWriteMockReadWrite175 MockReadWrite(IoMode io_mode, int seq, const char* data) 176 : mode(io_mode), 177 result(0), 178 data(data), 179 data_len(strlen(data)), 180 sequence_number(seq) {} 181 182 // Read/write success with sequence information. MockReadWriteMockReadWrite183 MockReadWrite(IoMode io_mode, const char* data, int data_len, int seq) 184 : mode(io_mode), 185 result(0), 186 data(data), 187 data_len(data_len), 188 sequence_number(seq) {} 189 190 IoMode mode; 191 int result; 192 const char* data; 193 int data_len; 194 195 // For data providers that only allows reads to occur in a particular 196 // sequence. If a read occurs before the given |sequence_number| is reached, 197 // an ERR_IO_PENDING is returned. 198 int sequence_number; // The sequence number at which a read is allowed 199 // to occur. 200 }; 201 202 typedef MockReadWrite<MOCK_READ> MockRead; 203 typedef MockReadWrite<MOCK_WRITE> MockWrite; 204 205 struct MockWriteResult { MockWriteResultMockWriteResult206 MockWriteResult(IoMode io_mode, int result) : mode(io_mode), result(result) {} 207 208 IoMode mode; 209 int result; 210 }; 211 212 // The SocketDataProvider is an interface used by the MockClientSocket 213 // for getting data about individual reads and writes on the socket. Can be 214 // used with at most one socket at a time. 215 // TODO(mmenke): Do these really need to be re-useable? 216 class SocketDataProvider { 217 public: 218 SocketDataProvider(); 219 220 SocketDataProvider(const SocketDataProvider&) = delete; 221 SocketDataProvider& operator=(const SocketDataProvider&) = delete; 222 223 virtual ~SocketDataProvider(); 224 225 // Returns the buffer and result code for the next simulated read. 226 // If the |MockRead.result| is ERR_IO_PENDING, it informs the caller 227 // that it will be called via the AsyncSocket::OnReadComplete() 228 // function at a later time. 
229 virtual MockRead OnRead() = 0; 230 virtual MockWriteResult OnWrite(const std::string& data) = 0; 231 virtual bool AllReadDataConsumed() const = 0; 232 virtual bool AllWriteDataConsumed() const = 0; CancelPendingRead()233 virtual void CancelPendingRead() {} 234 235 // Returns the last set receive buffer size, or -1 if never set. receive_buffer_size()236 int receive_buffer_size() const { return receive_buffer_size_; } set_receive_buffer_size(int receive_buffer_size)237 void set_receive_buffer_size(int receive_buffer_size) { 238 receive_buffer_size_ = receive_buffer_size; 239 } 240 241 // Returns the last set send buffer size, or -1 if never set. send_buffer_size()242 int send_buffer_size() const { return send_buffer_size_; } set_send_buffer_size(int send_buffer_size)243 void set_send_buffer_size(int send_buffer_size) { 244 send_buffer_size_ = send_buffer_size; 245 } 246 247 // Returns the last set value of TCP no delay, or false if never set. no_delay()248 bool no_delay() const { return no_delay_; } set_no_delay(bool no_delay)249 void set_no_delay(bool no_delay) { no_delay_ = no_delay; } 250 251 // Returns whether TCP keepalives were enabled or not. Returns kDefault by 252 // default. 253 enum class KeepAliveState { kEnabled, kDisabled, kDefault }; keep_alive_state()254 KeepAliveState keep_alive_state() const { return keep_alive_state_; } 255 // Last set TCP keepalive delay. keep_alive_delay()256 int keep_alive_delay() const { return keep_alive_delay_; } set_keep_alive(bool enable,int delay)257 void set_keep_alive(bool enable, int delay) { 258 keep_alive_state_ = 259 enable ? KeepAliveState::kEnabled : KeepAliveState::kDisabled; 260 keep_alive_delay_ = delay; 261 } 262 263 // Setters / getters for the return values of the corresponding Set*() 264 // methods. By default, they all succeed, if the socket is connected. 
265 set_set_receive_buffer_size_result(int receive_buffer_size_result)266 void set_set_receive_buffer_size_result(int receive_buffer_size_result) { 267 set_receive_buffer_size_result_ = receive_buffer_size_result; 268 } set_receive_buffer_size_result()269 int set_receive_buffer_size_result() const { 270 return set_receive_buffer_size_result_; 271 } 272 set_set_send_buffer_size_result(int set_send_buffer_size_result)273 void set_set_send_buffer_size_result(int set_send_buffer_size_result) { 274 set_send_buffer_size_result_ = set_send_buffer_size_result; 275 } set_send_buffer_size_result()276 int set_send_buffer_size_result() const { 277 return set_send_buffer_size_result_; 278 } 279 set_set_no_delay_result(bool set_no_delay_result)280 void set_set_no_delay_result(bool set_no_delay_result) { 281 set_no_delay_result_ = set_no_delay_result; 282 } set_no_delay_result()283 bool set_no_delay_result() const { return set_no_delay_result_; } 284 set_set_keep_alive_result(bool set_keep_alive_result)285 void set_set_keep_alive_result(bool set_keep_alive_result) { 286 set_keep_alive_result_ = set_keep_alive_result; 287 } set_keep_alive_result()288 bool set_keep_alive_result() const { return set_keep_alive_result_; } 289 expected_addresses()290 const absl::optional<AddressList>& expected_addresses() const { 291 return expected_addresses_; 292 } set_expected_addresses(net::AddressList addresses)293 void set_expected_addresses(net::AddressList addresses) { 294 expected_addresses_ = std::move(addresses); 295 } 296 297 // Returns true if the request should be considered idle, for the purposes of 298 // IsConnectedAndIdle. 299 virtual bool IsIdle() const; 300 301 // Initializes the SocketDataProvider for use with |socket|. Must be called 302 // before use 303 void Initialize(AsyncSocket* socket); 304 // Detaches the socket associated with a SocketDataProvider. 
Must be called 305 // before |socket_| is destroyed, unless the SocketDataProvider has informed 306 // |socket_| it was destroyed. Must also be called before Initialize() may 307 // be called again with a new socket. 308 void DetachSocket(); 309 310 // Accessor for the socket which is using the SocketDataProvider. socket()311 AsyncSocket* socket() { return socket_; } 312 connect_data()313 MockConnect connect_data() const { return connect_; } set_connect_data(const MockConnect & connect)314 void set_connect_data(const MockConnect& connect) { connect_ = connect; } 315 316 private: 317 // Called to inform subclasses of initialization. 318 virtual void Reset() = 0; 319 320 MockConnect connect_; 321 raw_ptr<AsyncSocket> socket_ = nullptr; 322 323 int receive_buffer_size_ = -1; 324 int send_buffer_size_ = -1; 325 // This reflects the default state of TCPClientSockets. 326 bool no_delay_ = true; 327 328 KeepAliveState keep_alive_state_ = KeepAliveState::kDefault; 329 int keep_alive_delay_ = 0; 330 331 int set_receive_buffer_size_result_ = net::OK; 332 int set_send_buffer_size_result_ = net::OK; 333 bool set_no_delay_result_ = true; 334 bool set_keep_alive_result_ = true; 335 absl::optional<AddressList> expected_addresses_; 336 }; 337 338 // The AsyncSocket is an interface used by the SocketDataProvider to 339 // complete the asynchronous read operation. 340 class AsyncSocket { 341 public: 342 // If an async IO is pending because the SocketDataProvider returned 343 // ERR_IO_PENDING, then the AsyncSocket waits until this OnReadComplete 344 // is called to complete the asynchronous read operation. 345 // data.async is ignored, and this read is completed synchronously as 346 // part of this call. 347 // TODO(rch): this should take a StringPiece since most of the fields 348 // are ignored. 
349 virtual void OnReadComplete(const MockRead& data) = 0; 350 // If an async IO is pending because the SocketDataProvider returned 351 // ERR_IO_PENDING, then the AsyncSocket waits until this OnReadComplete 352 // is called to complete the asynchronous read operation. 353 virtual void OnWriteComplete(int rv) = 0; 354 virtual void OnConnectComplete(const MockConnect& data) = 0; 355 356 // Called when the SocketDataProvider associated with the socket is destroyed. 357 // The socket may continue to be used after the data provider is destroyed, 358 // so it should be sure not to dereference the provider after this is called. 359 virtual void OnDataProviderDestroyed() = 0; 360 }; 361 362 class SocketDataPrinter { 363 public: 364 ~SocketDataPrinter() = default; 365 366 // Prints the write in |data| using some sort of protocol-specific 367 // format. 368 virtual std::string PrintWrite(const std::string& data) = 0; 369 }; 370 371 // StaticSocketDataHelper manages a list of reads and writes. 372 class StaticSocketDataHelper { 373 public: 374 StaticSocketDataHelper(base::span<const MockRead> reads, 375 base::span<const MockWrite> writes); 376 377 StaticSocketDataHelper(const StaticSocketDataHelper&) = delete; 378 StaticSocketDataHelper& operator=(const StaticSocketDataHelper&) = delete; 379 380 ~StaticSocketDataHelper(); 381 382 // These functions get access to the next available read and write data. They 383 // CHECK fail if there is no data available. 384 const MockRead& PeekRead() const; 385 const MockWrite& PeekWrite() const; 386 387 // Returns the current read or write, and then advances to the next one. 388 const MockRead& AdvanceRead(); 389 const MockWrite& AdvanceWrite(); 390 391 // Resets the read and write indexes to 0. 392 void Reset(); 393 394 // Returns true if |data| is valid data for the next write. In order 395 // to support short writes, the next write may be longer than |data| 396 // in which case this method will still return true. 
397 bool VerifyWriteData(const std::string& data, SocketDataPrinter* printer); 398 read_index()399 size_t read_index() const { return read_index_; } write_index()400 size_t write_index() const { return write_index_; } read_count()401 size_t read_count() const { return reads_.size(); } write_count()402 size_t write_count() const { return writes_.size(); } 403 AllReadDataConsumed()404 bool AllReadDataConsumed() const { return read_index() >= read_count(); } AllWriteDataConsumed()405 bool AllWriteDataConsumed() const { return write_index() >= write_count(); } 406 407 private: 408 // Returns the next available read or write that is not a pause event. CHECK 409 // fails if no data is available. 410 const MockWrite& PeekRealWrite() const; 411 412 const base::span<const MockRead> reads_; 413 size_t read_index_ = 0; 414 const base::span<const MockWrite> writes_; 415 size_t write_index_ = 0; 416 }; 417 418 // SocketDataProvider which responds based on static tables of mock reads and 419 // writes. 420 class StaticSocketDataProvider : public SocketDataProvider { 421 public: 422 StaticSocketDataProvider(); 423 StaticSocketDataProvider(base::span<const MockRead> reads, 424 base::span<const MockWrite> writes); 425 426 StaticSocketDataProvider(const StaticSocketDataProvider&) = delete; 427 StaticSocketDataProvider& operator=(const StaticSocketDataProvider&) = delete; 428 429 ~StaticSocketDataProvider() override; 430 431 // Pause/resume reads from this provider. 
432 void Pause(); 433 void Resume(); 434 435 // From SocketDataProvider: 436 MockRead OnRead() override; 437 MockWriteResult OnWrite(const std::string& data) override; 438 bool AllReadDataConsumed() const override; 439 bool AllWriteDataConsumed() const override; 440 read_index()441 size_t read_index() const { return helper_.read_index(); } write_index()442 size_t write_index() const { return helper_.write_index(); } read_count()443 size_t read_count() const { return helper_.read_count(); } write_count()444 size_t write_count() const { return helper_.write_count(); } 445 set_printer(SocketDataPrinter * printer)446 void set_printer(SocketDataPrinter* printer) { printer_ = printer; } 447 448 private: 449 // From SocketDataProvider: 450 void Reset() override; 451 452 StaticSocketDataHelper helper_; 453 // This field is not a raw_ptr<> because it was filtered by the rewriter for: 454 // #union 455 RAW_PTR_EXCLUSION SocketDataPrinter* printer_ = nullptr; 456 bool paused_ = false; 457 }; 458 459 // SSLSocketDataProviders only need to keep track of the return code from calls 460 // to Connect(). 461 struct SSLSocketDataProvider { 462 SSLSocketDataProvider(IoMode mode, int result); 463 SSLSocketDataProvider(const SSLSocketDataProvider& other); 464 ~SSLSocketDataProvider(); 465 466 // Returns whether MockConnect data has been consumed. ConnectDataConsumedSSLSocketDataProvider467 bool ConnectDataConsumed() const { return is_connect_data_consumed; } 468 469 // Returns whether MockConfirm data has been consumed. ConfirmDataConsumedSSLSocketDataProvider470 bool ConfirmDataConsumed() const { return is_confirm_data_consumed; } 471 472 // Returns whether a Write occurred before ConfirmHandshake completed. WriteBeforeConfirmSSLSocketDataProvider473 bool WriteBeforeConfirm() const { return write_called_before_confirm; } 474 475 // Result for Connect(). 476 MockConnect connect; 477 // Callback to run when Connect() is called. 
This is called at most once per 478 // socket but is repeating because SSLSocketDataProvider is copyable. 479 base::RepeatingClosure connect_callback; 480 481 // Result for ConfirmHandshake(). 482 MockConfirm confirm; 483 // Callback to run when ConfirmHandshake() is called. This is called at most 484 // once per socket but is repeating because SSLSocketDataProvider is 485 // copyable. 486 base::RepeatingClosure confirm_callback; 487 488 // Result for GetNegotiatedProtocol(). 489 NextProto next_proto = kProtoUnknown; 490 491 // Result for GetPeerApplicationSettings(). 492 absl::optional<std::string> peer_application_settings; 493 494 // Result for GetSSLInfo(). 495 SSLInfo ssl_info; 496 497 // Result for GetSSLCertRequestInfo(). 498 // This field is not a raw_ptr<> because it was filtered by the rewriter for: 499 // #union 500 RAW_PTR_EXCLUSION SSLCertRequestInfo* cert_request_info = nullptr; 501 502 // Result for GetECHRetryConfigs(). 503 std::vector<uint8_t> ech_retry_configs; 504 505 absl::optional<NextProtoVector> next_protos_expected_in_ssl_config; 506 507 uint16_t expected_ssl_version_min; 508 uint16_t expected_ssl_version_max; 509 absl::optional<bool> expected_send_client_cert; 510 scoped_refptr<X509Certificate> expected_client_cert; 511 absl::optional<HostPortPair> expected_host_and_port; 512 absl::optional<NetworkAnonymizationKey> expected_network_anonymization_key; 513 absl::optional<bool> expected_disable_sha1_server_signatures; 514 absl::optional<std::vector<uint8_t>> expected_ech_config_list; 515 516 bool is_connect_data_consumed = false; 517 bool is_confirm_data_consumed = false; 518 bool write_called_before_confirm = false; 519 }; 520 521 // Uses the sequence_number field in the mock reads and writes to 522 // complete the operations in a specified order. 523 class SequencedSocketData : public SocketDataProvider { 524 public: 525 SequencedSocketData(); 526 527 // |reads| is the list of MockRead completions. 
528 // |writes| is the list of MockWrite completions. 529 SequencedSocketData(base::span<const MockRead> reads, 530 base::span<const MockWrite> writes); 531 532 // |connect| is the result for the connect phase. 533 // |reads| is the list of MockRead completions. 534 // |writes| is the list of MockWrite completions. 535 SequencedSocketData(const MockConnect& connect, 536 base::span<const MockRead> reads, 537 base::span<const MockWrite> writes); 538 539 SequencedSocketData(const SequencedSocketData&) = delete; 540 SequencedSocketData& operator=(const SequencedSocketData&) = delete; 541 542 ~SequencedSocketData() override; 543 544 // From SocketDataProvider: 545 MockRead OnRead() override; 546 MockWriteResult OnWrite(const std::string& data) override; 547 bool AllReadDataConsumed() const override; 548 bool AllWriteDataConsumed() const override; 549 bool IsIdle() const override; 550 void CancelPendingRead() override; 551 552 // An ASYNC read event with a return value of ERR_IO_PENDING will cause the 553 // socket data to pause at that event, and advance no further, until Resume is 554 // invoked. At that point, the socket will continue at the next event in the 555 // sequence. 556 // 557 // If a request just wants to simulate a connection that stays open and never 558 // receives any more data, instead of pausing and then resuming a request, it 559 // should use a SYNCHRONOUS event with a return value of ERR_IO_PENDING 560 // instead. 561 bool IsPaused() const; 562 // Resumes events once |this| is in the paused state. The next event will 563 // occur synchronously with the call if it can. 564 void Resume(); 565 void RunUntilPaused(); 566 567 // When true, IsConnectedAndIdle() will return false if the next event in the 568 // sequence is a synchronous. Otherwise, the socket claims to be idle as 569 // long as it's connected. Defaults to false. 570 // TODO(mmenke): See if this can be made the default behavior, and consider 571 // removing this mehtod. 
Need to make sure it doesn't change what code any 572 // tests are targetted at testing. set_busy_before_sync_reads(bool busy_before_sync_reads)573 void set_busy_before_sync_reads(bool busy_before_sync_reads) { 574 busy_before_sync_reads_ = busy_before_sync_reads; 575 } 576 set_printer(SocketDataPrinter * printer)577 void set_printer(SocketDataPrinter* printer) { printer_ = printer; } 578 579 private: 580 // Defines the state for the read or write path. 581 enum class IoState { 582 kIdle, // No async operation is in progress. 583 kPending, // An async operation in waiting for another operation to 584 // complete. 585 kCompleting, // A task has been posted to complete an async operation. 586 kPaused, // IO is paused until Resume() is called. 587 }; 588 589 // From SocketDataProvider: 590 void Reset() override; 591 592 void OnReadComplete(); 593 void OnWriteComplete(); 594 595 void MaybePostReadCompleteTask(); 596 void MaybePostWriteCompleteTask(); 597 598 StaticSocketDataHelper helper_; 599 raw_ptr<SocketDataPrinter> printer_ = nullptr; 600 int sequence_number_ = 0; 601 IoState read_state_ = IoState::kIdle; 602 IoState write_state_ = IoState::kIdle; 603 604 bool busy_before_sync_reads_ = false; 605 606 // Used by RunUntilPaused. NULL at all other times. 607 std::unique_ptr<base::RunLoop> run_until_paused_run_loop_; 608 609 base::WeakPtrFactory<SequencedSocketData> weak_factory_{this}; 610 }; 611 612 // Holds an array of SocketDataProvider elements. As Mock{TCP,SSL}StreamSocket 613 // objects get instantiated, they take their data from the i'th element of this 614 // array. 615 template <typename T> 616 class SocketDataProviderArray { 617 public: 618 SocketDataProviderArray() = default; 619 GetNext()620 T* GetNext() { 621 DCHECK_LT(next_index_, data_providers_.size()); 622 return data_providers_[next_index_++]; 623 } 624 625 // Like GetNext(), but returns nullptr when the end of the array is reached, 626 // instead of DCHECKing. 
GetNext() should generally be preferred, unless 627 // having no remaining elements is expected in some cases and is handled 628 // safely. GetNextWithoutAsserting()629 T* GetNextWithoutAsserting() { 630 if (next_index_ == data_providers_.size()) 631 return nullptr; 632 return data_providers_[next_index_++]; 633 } 634 Add(T * data_provider)635 void Add(T* data_provider) { 636 DCHECK(data_provider); 637 data_providers_.push_back(data_provider); 638 } 639 next_index()640 size_t next_index() { return next_index_; } 641 ResetNextIndex()642 void ResetNextIndex() { next_index_ = 0; } 643 644 private: 645 // Index of the next |data_providers_| element to use. Not an iterator 646 // because those are invalidated on vector reallocation. 647 size_t next_index_ = 0; 648 649 // SocketDataProviders to be returned. 650 std::vector<T*> data_providers_; 651 }; 652 653 class MockUDPClientSocket; 654 class MockTCPClientSocket; 655 class MockSSLClientSocket; 656 657 // ClientSocketFactory which contains arrays of sockets of each type. 658 // You should first fill the arrays using Add{SSL,}SocketDataProvider(). When 659 // the factory is asked to create a socket, it takes next entry from appropriate 660 // array. You can use ResetNextMockIndexes to reset that next entry index for 661 // all mock socket types. 662 class MockClientSocketFactory : public ClientSocketFactory { 663 public: 664 MockClientSocketFactory(); 665 666 MockClientSocketFactory(const MockClientSocketFactory&) = delete; 667 MockClientSocketFactory& operator=(const MockClientSocketFactory&) = delete; 668 669 ~MockClientSocketFactory() override; 670 671 // Adds a SocketDataProvider that can be used to served either TCP or UDP 672 // connection requests. Sockets are returned in FIFO order. 673 void AddSocketDataProvider(SocketDataProvider* socket); 674 675 // Like AddSocketDataProvider(), except sockets will only be used to service 676 // TCP connection requests. 
Sockets added with this method are used first, 677 // before sockets added with AddSocketDataProvider(). Particularly useful for 678 // QUIC tests with multiple sockets, where TCP connections may or may not be 679 // made, and have no guaranteed order, relative to UDP connections. 680 void AddTcpSocketDataProvider(SocketDataProvider* socket); 681 682 void AddSSLSocketDataProvider(SSLSocketDataProvider* socket); 683 void ResetNextMockIndexes(); 684 mock_data()685 SocketDataProviderArray<SocketDataProvider>& mock_data() { 686 return mock_data_; 687 } 688 set_enable_read_if_ready(bool enable_read_if_ready)689 void set_enable_read_if_ready(bool enable_read_if_ready) { 690 enable_read_if_ready_ = enable_read_if_ready; 691 } 692 693 // ClientSocketFactory 694 std::unique_ptr<DatagramClientSocket> CreateDatagramClientSocket( 695 DatagramSocket::BindType bind_type, 696 NetLog* net_log, 697 const NetLogSource& source) override; 698 std::unique_ptr<TransportClientSocket> CreateTransportClientSocket( 699 const AddressList& addresses, 700 std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher, 701 NetworkQualityEstimator* network_quality_estimator, 702 NetLog* net_log, 703 const NetLogSource& source) override; 704 std::unique_ptr<SSLClientSocket> CreateSSLClientSocket( 705 SSLClientContext* context, 706 std::unique_ptr<StreamSocket> stream_socket, 707 const HostPortPair& host_and_port, 708 const SSLConfig& ssl_config) override; udp_client_socket_ports()709 const std::vector<uint16_t>& udp_client_socket_ports() const { 710 return udp_client_socket_ports_; 711 } 712 713 private: 714 SocketDataProviderArray<SocketDataProvider> mock_data_; 715 SocketDataProviderArray<SocketDataProvider> mock_tcp_data_; 716 SocketDataProviderArray<SSLSocketDataProvider> mock_ssl_data_; 717 std::vector<uint16_t> udp_client_socket_ports_; 718 719 // If true, ReadIfReady() is enabled; otherwise ReadIfReady() returns 720 // ERR_READ_IF_READY_NOT_IMPLEMENTED. 
721 bool enable_read_if_ready_ = false; 722 }; 723 724 class MockClientSocket : public TransportClientSocket { 725 public: 726 // The NetLogWithSource is needed to test LoadTimingInfo, which uses NetLog 727 // IDs as 728 // unique socket IDs. 729 explicit MockClientSocket(const NetLogWithSource& net_log); 730 731 MockClientSocket(const MockClientSocket&) = delete; 732 MockClientSocket& operator=(const MockClientSocket&) = delete; 733 734 // Socket implementation. 735 int Read(IOBuffer* buf, 736 int buf_len, 737 CompletionOnceCallback callback) override = 0; 738 int Write(IOBuffer* buf, 739 int buf_len, 740 CompletionOnceCallback callback, 741 const NetworkTrafficAnnotationTag& traffic_annotation) override = 0; 742 int SetReceiveBufferSize(int32_t size) override; 743 int SetSendBufferSize(int32_t size) override; 744 745 // TransportClientSocket implementation. 746 int Bind(const net::IPEndPoint& local_addr) override; 747 bool SetNoDelay(bool no_delay) override; 748 bool SetKeepAlive(bool enable, int delay) override; 749 750 // StreamSocket implementation. 751 int Connect(CompletionOnceCallback callback) override = 0; 752 void Disconnect() override; 753 bool IsConnected() const override; 754 bool IsConnectedAndIdle() const override; 755 int GetPeerAddress(IPEndPoint* address) const override; 756 int GetLocalAddress(IPEndPoint* address) const override; 757 const NetLogWithSource& NetLog() const override; 758 bool WasAlpnNegotiated() const override; 759 NextProto GetNegotiatedProtocol() const override; 760 int64_t GetTotalReceivedBytes() const override; ApplySocketTag(const SocketTag & tag)761 void ApplySocketTag(const SocketTag& tag) override {} 762 763 protected: 764 ~MockClientSocket() override; 765 void RunCallbackAsync(CompletionOnceCallback callback, int result); 766 void RunCallback(CompletionOnceCallback callback, int result); 767 768 // True if Connect completed successfully and Disconnect hasn't been called. 
769 bool connected_ = false; 770 771 IPEndPoint local_addr_; 772 IPEndPoint peer_addr_; 773 774 NetLogWithSource net_log_; 775 776 private: 777 base::WeakPtrFactory<MockClientSocket> weak_factory_{this}; 778 }; 779 780 class MockTCPClientSocket : public MockClientSocket, public AsyncSocket { 781 public: 782 MockTCPClientSocket(const AddressList& addresses, 783 net::NetLog* net_log, 784 SocketDataProvider* socket); 785 786 MockTCPClientSocket(const MockTCPClientSocket&) = delete; 787 MockTCPClientSocket& operator=(const MockTCPClientSocket&) = delete; 788 789 ~MockTCPClientSocket() override; 790 addresses()791 const AddressList& addresses() const { return addresses_; } 792 793 // Socket implementation. 794 int Read(IOBuffer* buf, 795 int buf_len, 796 CompletionOnceCallback callback) override; 797 int ReadIfReady(IOBuffer* buf, 798 int buf_len, 799 CompletionOnceCallback callback) override; 800 int CancelReadIfReady() override; 801 int Write(IOBuffer* buf, 802 int buf_len, 803 CompletionOnceCallback callback, 804 const NetworkTrafficAnnotationTag& traffic_annotation) override; 805 int SetReceiveBufferSize(int32_t size) override; 806 int SetSendBufferSize(int32_t size) override; 807 808 // TransportClientSocket implementation. 809 bool SetNoDelay(bool no_delay) override; 810 bool SetKeepAlive(bool enable, int delay) override; 811 812 // StreamSocket implementation. 
813 void SetBeforeConnectCallback( 814 const BeforeConnectCallback& before_connect_callback) override; 815 int Connect(CompletionOnceCallback callback) override; 816 void Disconnect() override; 817 bool IsConnected() const override; 818 bool IsConnectedAndIdle() const override; 819 int GetPeerAddress(IPEndPoint* address) const override; 820 bool WasEverUsed() const override; 821 bool GetSSLInfo(SSLInfo* ssl_info) override; 822 823 // AsyncSocket: 824 void OnReadComplete(const MockRead& data) override; 825 void OnWriteComplete(int rv) override; 826 void OnConnectComplete(const MockConnect& data) override; 827 void OnDataProviderDestroyed() override; 828 set_enable_read_if_ready(bool enable_read_if_ready)829 void set_enable_read_if_ready(bool enable_read_if_ready) { 830 enable_read_if_ready_ = enable_read_if_ready; 831 } 832 833 private: 834 void RetryRead(int rv); 835 int ReadIfReadyImpl(IOBuffer* buf, 836 int buf_len, 837 CompletionOnceCallback callback); 838 839 // Helper method to run |pending_read_if_ready_callback_| if it is not null. 840 void RunReadIfReadyCallback(int result); 841 842 AddressList addresses_; 843 844 raw_ptr<SocketDataProvider> data_; 845 int read_offset_ = 0; 846 MockRead read_data_; 847 bool need_read_data_ = true; 848 849 // True if the peer has closed the connection. This allows us to simulate 850 // the recv(..., MSG_PEEK) call in the IsConnectedAndIdle method of the real 851 // TCPClientSocket. 852 bool peer_closed_connection_ = false; 853 854 // While an asynchronous read is pending, we save our user-buffer state. 855 scoped_refptr<IOBuffer> pending_read_buf_ = nullptr; 856 int pending_read_buf_len_ = 0; 857 CompletionOnceCallback pending_read_callback_; 858 859 // Non-null when a ReadIfReady() is pending. 
860 CompletionOnceCallback pending_read_if_ready_callback_; 861 862 CompletionOnceCallback pending_connect_callback_; 863 CompletionOnceCallback pending_write_callback_; 864 bool was_used_to_convey_data_ = false; 865 866 // If true, ReadIfReady() is enabled; otherwise ReadIfReady() returns 867 // ERR_READ_IF_READY_NOT_IMPLEMENTED. 868 bool enable_read_if_ready_ = false; 869 870 BeforeConnectCallback before_connect_callback_; 871 }; 872 873 class MockSSLClientSocket : public AsyncSocket, public SSLClientSocket { 874 public: 875 MockSSLClientSocket(std::unique_ptr<StreamSocket> stream_socket, 876 const HostPortPair& host_and_port, 877 const SSLConfig& ssl_config, 878 SSLSocketDataProvider* socket); 879 880 MockSSLClientSocket(const MockSSLClientSocket&) = delete; 881 MockSSLClientSocket& operator=(const MockSSLClientSocket&) = delete; 882 883 ~MockSSLClientSocket() override; 884 885 // Socket implementation. 886 int Read(IOBuffer* buf, 887 int buf_len, 888 CompletionOnceCallback callback) override; 889 int ReadIfReady(IOBuffer* buf, 890 int buf_len, 891 CompletionOnceCallback callback) override; 892 int Write(IOBuffer* buf, 893 int buf_len, 894 CompletionOnceCallback callback, 895 const NetworkTrafficAnnotationTag& traffic_annotation) override; 896 int CancelReadIfReady() override; 897 898 // StreamSocket implementation. 
899 int Connect(CompletionOnceCallback callback) override; 900 void Disconnect() override; 901 int ConfirmHandshake(CompletionOnceCallback callback) override; 902 bool IsConnected() const override; 903 bool IsConnectedAndIdle() const override; 904 bool WasEverUsed() const override; 905 int GetPeerAddress(IPEndPoint* address) const override; 906 int GetLocalAddress(IPEndPoint* address) const override; 907 bool WasAlpnNegotiated() const override; 908 NextProto GetNegotiatedProtocol() const override; 909 absl::optional<base::StringPiece> GetPeerApplicationSettings() const override; 910 bool GetSSLInfo(SSLInfo* ssl_info) override; 911 void GetSSLCertRequestInfo( 912 SSLCertRequestInfo* cert_request_info) const override; 913 void ApplySocketTag(const SocketTag& tag) override; 914 const NetLogWithSource& NetLog() const override; 915 int64_t GetTotalReceivedBytes() const override; 916 int SetReceiveBufferSize(int32_t size) override; 917 int SetSendBufferSize(int32_t size) override; 918 919 // SSLSocket implementation. 920 int ExportKeyingMaterial(base::StringPiece label, 921 bool has_context, 922 base::StringPiece context, 923 unsigned char* out, 924 unsigned int outlen) override; 925 926 // SSLClientSocket implementation. 927 std::vector<uint8_t> GetECHRetryConfigs() override; 928 929 // This MockSocket does not implement the manual async IO feature. 930 void OnReadComplete(const MockRead& data) override; 931 void OnWriteComplete(int rv) override; 932 void OnConnectComplete(const MockConnect& data) override; 933 // SSL sockets don't need magic to deal with destruction of their data 934 // provider. 935 // TODO(mmenke): Probably a good idea to support it, anyways. 
OnDataProviderDestroyed()936 void OnDataProviderDestroyed() override {} 937 938 private: 939 static void ConnectCallback(MockSSLClientSocket* ssl_client_socket, 940 CompletionOnceCallback callback, 941 int rv); 942 943 void RunCallbackAsync(CompletionOnceCallback callback, int result); 944 void RunCallback(CompletionOnceCallback callback, int result); 945 946 void RunConfirmHandshakeCallback(CompletionOnceCallback callback, int result); 947 948 bool connected_ = false; 949 bool in_confirm_handshake_ = false; 950 NetLogWithSource net_log_; 951 std::unique_ptr<StreamSocket> stream_socket_; 952 raw_ptr<SSLSocketDataProvider> data_; 953 // Address of the "remote" peer we're connected to. 954 IPEndPoint peer_addr_; 955 956 base::WeakPtrFactory<MockSSLClientSocket> weak_factory_{this}; 957 }; 958 959 class MockUDPClientSocket : public DatagramClientSocket, public AsyncSocket { 960 public: 961 explicit MockUDPClientSocket(SocketDataProvider* data = nullptr, 962 net::NetLog* net_log = nullptr); 963 964 MockUDPClientSocket(const MockUDPClientSocket&) = delete; 965 MockUDPClientSocket& operator=(const MockUDPClientSocket&) = delete; 966 967 ~MockUDPClientSocket() override; 968 969 // Socket implementation. 970 int Read(IOBuffer* buf, 971 int buf_len, 972 CompletionOnceCallback callback) override; 973 int Write(IOBuffer* buf, 974 int buf_len, 975 CompletionOnceCallback callback, 976 const NetworkTrafficAnnotationTag& traffic_annotation) override; 977 978 int SetReceiveBufferSize(int32_t size) override; 979 int SetSendBufferSize(int32_t size) override; 980 int SetDoNotFragment() override; 981 982 // DatagramSocket implementation. 983 void Close() override; 984 int GetPeerAddress(IPEndPoint* address) const override; 985 int GetLocalAddress(IPEndPoint* address) const override; 986 void UseNonBlockingIO() override; 987 int SetMulticastInterface(uint32_t interface_index) override; 988 const NetLogWithSource& NetLog() const override; 989 990 // DatagramClientSocket implementation. 
991 int Connect(const IPEndPoint& address) override; 992 int ConnectUsingNetwork(handles::NetworkHandle network, 993 const IPEndPoint& address) override; 994 int ConnectUsingDefaultNetwork(const IPEndPoint& address) override; 995 int ConnectAsync(const IPEndPoint& address, 996 CompletionOnceCallback callback) override; 997 int ConnectUsingNetworkAsync(handles::NetworkHandle network, 998 const IPEndPoint& address, 999 CompletionOnceCallback callback) override; 1000 int ConnectUsingDefaultNetworkAsync(const IPEndPoint& address, 1001 CompletionOnceCallback callback) override; 1002 handles::NetworkHandle GetBoundNetwork() const override; 1003 void ApplySocketTag(const SocketTag& tag) override; SetMsgConfirm(bool confirm)1004 void SetMsgConfirm(bool confirm) override {} 1005 1006 // AsyncSocket implementation. 1007 void OnReadComplete(const MockRead& data) override; 1008 void OnWriteComplete(int rv) override; 1009 void OnConnectComplete(const MockConnect& data) override; 1010 void OnDataProviderDestroyed() override; 1011 set_source_port(uint16_t port)1012 void set_source_port(uint16_t port) { source_port_ = port; } source_port()1013 uint16_t source_port() const { return source_port_; } set_source_host(IPAddress addr)1014 void set_source_host(IPAddress addr) { source_host_ = addr; } source_host()1015 IPAddress source_host() const { return source_host_; } 1016 1017 // Returns last tag applied to socket. tag()1018 SocketTag tag() const { return tag_; } 1019 1020 // Returns false if socket's tag was changed after the socket was used for 1021 // data transfer (e.g. Read/Write() called), otherwise returns true. 
tagged_before_data_transferred()1022 bool tagged_before_data_transferred() const { 1023 return tagged_before_data_transferred_; 1024 } 1025 1026 private: 1027 int CompleteRead(); 1028 1029 void RunCallbackAsync(CompletionOnceCallback callback, int result); 1030 void RunCallback(CompletionOnceCallback callback, int result); 1031 1032 bool connected_ = false; 1033 raw_ptr<SocketDataProvider> data_; 1034 int read_offset_ = 0; 1035 MockRead read_data_; 1036 bool need_read_data_ = true; 1037 IPAddress source_host_; 1038 uint16_t source_port_ = 123; // Ephemeral source port. 1039 1040 // Address of the "remote" peer we're connected to. 1041 IPEndPoint peer_addr_; 1042 1043 // Network that the socket is bound to. 1044 handles::NetworkHandle network_ = handles::kInvalidNetworkHandle; 1045 1046 // While an asynchronous IO is pending, we save our user-buffer state. 1047 scoped_refptr<IOBuffer> pending_read_buf_ = nullptr; 1048 int pending_read_buf_len_ = 0; 1049 CompletionOnceCallback pending_read_callback_; 1050 CompletionOnceCallback pending_write_callback_; 1051 1052 NetLogWithSource net_log_; 1053 1054 DatagramBuffers unwritten_buffers_; 1055 1056 SocketTag tag_; 1057 bool data_transferred_ = false; 1058 bool tagged_before_data_transferred_ = true; 1059 1060 base::WeakPtrFactory<MockUDPClientSocket> weak_factory_{this}; 1061 }; 1062 1063 class TestSocketRequest : public TestCompletionCallbackBase { 1064 public: 1065 TestSocketRequest(std::vector<TestSocketRequest*>* request_order, 1066 size_t* completion_count); 1067 1068 TestSocketRequest(const TestSocketRequest&) = delete; 1069 TestSocketRequest& operator=(const TestSocketRequest&) = delete; 1070 1071 ~TestSocketRequest() override; 1072 handle()1073 ClientSocketHandle* handle() { return &handle_; } 1074 callback()1075 CompletionOnceCallback callback() { 1076 return base::BindOnce(&TestSocketRequest::OnComplete, 1077 base::Unretained(this)); 1078 } 1079 1080 private: 1081 void OnComplete(int result); 1082 1083 
ClientSocketHandle handle_; 1084 raw_ptr<std::vector<TestSocketRequest*>> request_order_; 1085 raw_ptr<size_t> completion_count_; 1086 }; 1087 1088 class ClientSocketPoolTest { 1089 public: 1090 enum KeepAlive { 1091 KEEP_ALIVE, 1092 1093 // A socket will be disconnected in addition to handle being reset. 1094 NO_KEEP_ALIVE, 1095 }; 1096 1097 static const int kIndexOutOfBounds; 1098 static const int kRequestNotFound; 1099 1100 ClientSocketPoolTest(); 1101 1102 ClientSocketPoolTest(const ClientSocketPoolTest&) = delete; 1103 ClientSocketPoolTest& operator=(const ClientSocketPoolTest&) = delete; 1104 1105 ~ClientSocketPoolTest(); 1106 1107 template <typename PoolType> StartRequestUsingPool(PoolType * socket_pool,const ClientSocketPool::GroupId & group_id,RequestPriority priority,ClientSocketPool::RespectLimits respect_limits,const scoped_refptr<typename PoolType::SocketParams> & socket_params)1108 int StartRequestUsingPool( 1109 PoolType* socket_pool, 1110 const ClientSocketPool::GroupId& group_id, 1111 RequestPriority priority, 1112 ClientSocketPool::RespectLimits respect_limits, 1113 const scoped_refptr<typename PoolType::SocketParams>& socket_params) { 1114 DCHECK(socket_pool); 1115 TestSocketRequest* request( 1116 new TestSocketRequest(&request_order_, &completion_count_)); 1117 requests_.push_back(base::WrapUnique(request)); 1118 int rv = request->handle()->Init( 1119 group_id, socket_params, absl::nullopt /* proxy_annotation_tag */, 1120 priority, SocketTag(), respect_limits, request->callback(), 1121 ClientSocketPool::ProxyAuthCallback(), socket_pool, NetLogWithSource()); 1122 if (rv != ERR_IO_PENDING) 1123 request_order_.push_back(request); 1124 return rv; 1125 } 1126 1127 // Provided there were n requests started, takes |index| in range 1..n 1128 // and returns order in which that request completed, in range 1..n, 1129 // or kIndexOutOfBounds if |index| is out of bounds, or kRequestNotFound 1130 // if that request did not complete (for example was canceled). 
1131 int GetOrderOfRequest(size_t index) const; 1132 1133 // Resets first initialized socket handle from |requests_|. If found such 1134 // a handle, returns true. 1135 bool ReleaseOneConnection(KeepAlive keep_alive); 1136 1137 // Releases connections until there is nothing to release. 1138 void ReleaseAllConnections(KeepAlive keep_alive); 1139 1140 // Note that this uses 0-based indices, while GetOrderOfRequest takes and 1141 // returns 1-based indices. request(int i)1142 TestSocketRequest* request(int i) { return requests_[i].get(); } 1143 requests_size()1144 size_t requests_size() const { return requests_.size(); } requests()1145 std::vector<std::unique_ptr<TestSocketRequest>>* requests() { 1146 return &requests_; 1147 } completion_count()1148 size_t completion_count() const { return completion_count_; } 1149 1150 private: 1151 std::vector<std::unique_ptr<TestSocketRequest>> requests_; 1152 std::vector<TestSocketRequest*> request_order_; 1153 size_t completion_count_ = 0; 1154 }; 1155 1156 class MockTransportSocketParams 1157 : public base::RefCounted<MockTransportSocketParams> { 1158 public: 1159 MockTransportSocketParams(const MockTransportSocketParams&) = delete; 1160 MockTransportSocketParams& operator=(const MockTransportSocketParams&) = 1161 delete; 1162 1163 private: 1164 friend class base::RefCounted<MockTransportSocketParams>; 1165 ~MockTransportSocketParams() = default; 1166 }; 1167 1168 class MockTransportClientSocketPool : public TransportClientSocketPool { 1169 public: 1170 class MockConnectJob { 1171 public: 1172 MockConnectJob(std::unique_ptr<StreamSocket> socket, 1173 ClientSocketHandle* handle, 1174 const SocketTag& socket_tag, 1175 CompletionOnceCallback callback, 1176 RequestPriority priority); 1177 1178 MockConnectJob(const MockConnectJob&) = delete; 1179 MockConnectJob& operator=(const MockConnectJob&) = delete; 1180 1181 ~MockConnectJob(); 1182 1183 int Connect(); 1184 bool CancelHandle(const ClientSocketHandle* handle); 1185 handle()1186 
ClientSocketHandle* handle() const { return handle_; } 1187 priority()1188 RequestPriority priority() const { return priority_; } set_priority(RequestPriority priority)1189 void set_priority(RequestPriority priority) { priority_ = priority; } 1190 1191 private: 1192 void OnConnect(int rv); 1193 1194 std::unique_ptr<StreamSocket> socket_; 1195 raw_ptr<ClientSocketHandle> handle_; 1196 const SocketTag socket_tag_; 1197 CompletionOnceCallback user_callback_; 1198 RequestPriority priority_; 1199 }; 1200 1201 MockTransportClientSocketPool( 1202 int max_sockets, 1203 int max_sockets_per_group, 1204 const CommonConnectJobParams* common_connect_job_params); 1205 1206 MockTransportClientSocketPool(const MockTransportClientSocketPool&) = delete; 1207 MockTransportClientSocketPool& operator=( 1208 const MockTransportClientSocketPool&) = delete; 1209 1210 ~MockTransportClientSocketPool() override; 1211 last_request_priority()1212 RequestPriority last_request_priority() const { 1213 return last_request_priority_; 1214 } 1215 requests()1216 const std::vector<std::unique_ptr<MockConnectJob>>& requests() const { 1217 return job_list_; 1218 } 1219 release_count()1220 int release_count() const { return release_count_; } cancel_count()1221 int cancel_count() const { return cancel_count_; } 1222 1223 // TransportClientSocketPool implementation. 
1224 int RequestSocket( 1225 const GroupId& group_id, 1226 scoped_refptr<ClientSocketPool::SocketParams> socket_params, 1227 const absl::optional<NetworkTrafficAnnotationTag>& proxy_annotation_tag, 1228 RequestPriority priority, 1229 const SocketTag& socket_tag, 1230 RespectLimits respect_limits, 1231 ClientSocketHandle* handle, 1232 CompletionOnceCallback callback, 1233 const ProxyAuthCallback& on_auth_callback, 1234 const NetLogWithSource& net_log) override; 1235 void SetPriority(const GroupId& group_id, 1236 ClientSocketHandle* handle, 1237 RequestPriority priority) override; 1238 void CancelRequest(const GroupId& group_id, 1239 ClientSocketHandle* handle, 1240 bool cancel_connect_job) override; 1241 void ReleaseSocket(const GroupId& group_id, 1242 std::unique_ptr<StreamSocket> socket, 1243 int64_t generation) override; 1244 1245 private: 1246 raw_ptr<ClientSocketFactory> client_socket_factory_; 1247 std::vector<std::unique_ptr<MockConnectJob>> job_list_; 1248 RequestPriority last_request_priority_ = DEFAULT_PRIORITY; 1249 int release_count_ = 0; 1250 int cancel_count_ = 0; 1251 }; 1252 1253 // WrappedStreamSocket is a base class that wraps an existing StreamSocket, 1254 // forwarding the Socket and StreamSocket interfaces to the underlying 1255 // transport. 1256 // This is to provide a common base class for subclasses to override specific 1257 // StreamSocket methods for testing, while still communicating with a 'real' 1258 // StreamSocket. 
1259 class WrappedStreamSocket : public TransportClientSocket { 1260 public: 1261 explicit WrappedStreamSocket(std::unique_ptr<StreamSocket> transport); 1262 ~WrappedStreamSocket() override; 1263 1264 // StreamSocket implementation: 1265 int Bind(const net::IPEndPoint& local_addr) override; 1266 int Connect(CompletionOnceCallback callback) override; 1267 void Disconnect() override; 1268 bool IsConnected() const override; 1269 bool IsConnectedAndIdle() const override; 1270 int GetPeerAddress(IPEndPoint* address) const override; 1271 int GetLocalAddress(IPEndPoint* address) const override; 1272 const NetLogWithSource& NetLog() const override; 1273 bool WasEverUsed() const override; 1274 bool WasAlpnNegotiated() const override; 1275 NextProto GetNegotiatedProtocol() const override; 1276 bool GetSSLInfo(SSLInfo* ssl_info) override; 1277 int64_t GetTotalReceivedBytes() const override; 1278 void ApplySocketTag(const SocketTag& tag) override; 1279 1280 // Socket implementation: 1281 int Read(IOBuffer* buf, 1282 int buf_len, 1283 CompletionOnceCallback callback) override; 1284 int ReadIfReady(IOBuffer* buf, 1285 int buf_len, 1286 CompletionOnceCallback callback) override; 1287 int Write(IOBuffer* buf, 1288 int buf_len, 1289 CompletionOnceCallback callback, 1290 const NetworkTrafficAnnotationTag& traffic_annotation) override; 1291 int SetReceiveBufferSize(int32_t size) override; 1292 int SetSendBufferSize(int32_t size) override; 1293 1294 protected: 1295 std::unique_ptr<StreamSocket> transport_; 1296 }; 1297 1298 // StreamSocket that wraps another StreamSocket, but keeps track of any 1299 // SocketTag applied to the socket. 
1300 class MockTaggingStreamSocket : public WrappedStreamSocket { 1301 public: MockTaggingStreamSocket(std::unique_ptr<StreamSocket> transport)1302 explicit MockTaggingStreamSocket(std::unique_ptr<StreamSocket> transport) 1303 : WrappedStreamSocket(std::move(transport)) {} 1304 1305 MockTaggingStreamSocket(const MockTaggingStreamSocket&) = delete; 1306 MockTaggingStreamSocket& operator=(const MockTaggingStreamSocket&) = delete; 1307 1308 ~MockTaggingStreamSocket() override = default; 1309 1310 // StreamSocket implementation. 1311 int Connect(CompletionOnceCallback callback) override; 1312 void ApplySocketTag(const SocketTag& tag) override; 1313 1314 // Returns false if socket's tag was changed after the socket was connected, 1315 // otherwise returns true. tagged_before_connected()1316 bool tagged_before_connected() const { return tagged_before_connected_; } 1317 1318 // Returns last tag applied to socket. tag()1319 SocketTag tag() const { return tag_; } 1320 1321 private: 1322 bool connected_ = false; 1323 bool tagged_before_connected_ = true; 1324 SocketTag tag_; 1325 }; 1326 1327 // Extend MockClientSocketFactory to return MockTaggingStreamSockets and 1328 // keep track of last socket produced for test inspection. 1329 class MockTaggingClientSocketFactory : public MockClientSocketFactory { 1330 public: 1331 MockTaggingClientSocketFactory() = default; 1332 1333 MockTaggingClientSocketFactory(const MockTaggingClientSocketFactory&) = 1334 delete; 1335 MockTaggingClientSocketFactory& operator=( 1336 const MockTaggingClientSocketFactory&) = delete; 1337 1338 // ClientSocketFactory implementation. 
1339 std::unique_ptr<DatagramClientSocket> CreateDatagramClientSocket( 1340 DatagramSocket::BindType bind_type, 1341 NetLog* net_log, 1342 const NetLogSource& source) override; 1343 std::unique_ptr<TransportClientSocket> CreateTransportClientSocket( 1344 const AddressList& addresses, 1345 std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher, 1346 NetworkQualityEstimator* network_quality_estimator, 1347 NetLog* net_log, 1348 const NetLogSource& source) override; 1349 1350 // These methods return pointers to last TCP and UDP sockets produced by this 1351 // factory. NOTE: Socket must still exist, or pointer will be to freed memory. GetLastProducedTCPSocket()1352 MockTaggingStreamSocket* GetLastProducedTCPSocket() const { 1353 return tcp_socket_; 1354 } GetLastProducedUDPSocket()1355 MockUDPClientSocket* GetLastProducedUDPSocket() const { return udp_socket_; } 1356 1357 private: 1358 raw_ptr<MockTaggingStreamSocket> tcp_socket_ = nullptr; 1359 raw_ptr<MockUDPClientSocket> udp_socket_ = nullptr; 1360 }; 1361 1362 // Host / port used for SOCKS4 test strings. 1363 extern const char kSOCKS4TestHost[]; 1364 extern const int kSOCKS4TestPort; 1365 1366 // Constants for a successful SOCKS v4 handshake (connecting to kSOCKS4TestHost 1367 // on port kSOCKS4TestPort, for the request). 1368 extern const char kSOCKS4OkRequestLocalHostPort80[]; 1369 extern const int kSOCKS4OkRequestLocalHostPort80Length; 1370 1371 extern const char kSOCKS4OkReply[]; 1372 extern const int kSOCKS4OkReplyLength; 1373 1374 // Host / port used for SOCKS5 test strings. 1375 extern const char kSOCKS5TestHost[]; 1376 extern const int kSOCKS5TestPort; 1377 1378 // Constants for a successful SOCKS v5 handshake (connecting to kSOCKS5TestHost 1379 // on port kSOCKS5TestPort, for the request).. 
1380 extern const char kSOCKS5GreetRequest[]; 1381 extern const int kSOCKS5GreetRequestLength; 1382 1383 extern const char kSOCKS5GreetResponse[]; 1384 extern const int kSOCKS5GreetResponseLength; 1385 1386 extern const char kSOCKS5OkRequest[]; 1387 extern const int kSOCKS5OkRequestLength; 1388 1389 extern const char kSOCKS5OkResponse[]; 1390 extern const int kSOCKS5OkResponseLength; 1391 1392 // Helper function to get the total data size of the MockReads in |reads|. 1393 int64_t CountReadBytes(base::span<const MockRead> reads); 1394 1395 // Helper function to get the total data size of the MockWrites in |writes|. 1396 int64_t CountWriteBytes(base::span<const MockWrite> writes); 1397 1398 #if BUILDFLAG(IS_ANDROID) 1399 // Returns whether the device supports calling GetTaggedBytes(). 1400 bool CanGetTaggedBytes(); 1401 1402 // Query the system to find out how many bytes were received with tag 1403 // |expected_tag| for our UID. Return the count of received bytes. 1404 uint64_t GetTaggedBytes(int32_t expected_tag); 1405 #endif 1406 1407 } // namespace net 1408 1409 #endif // NET_SOCKET_SOCKET_TEST_UTIL_H_ 1410