1 // Copyright 2013 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/socket/tcp_server_socket.h"
6
7 #include <memory>
8 #include <utility>
9
10 #include "base/check.h"
11 #include "base/functional/bind.h"
12 #include "base/functional/callback_helpers.h"
13 #include "base/notreached.h"
14 #include "net/base/net_errors.h"
15 #include "net/socket/socket_descriptor.h"
16 #include "net/socket/tcp_client_socket.h"
17
18 namespace net {
19
TCPServerSocket(NetLog * net_log,const NetLogSource & source)20 TCPServerSocket::TCPServerSocket(NetLog* net_log, const NetLogSource& source)
21 : TCPServerSocket(
22 std::make_unique<TCPSocket>(nullptr /* socket_performance_watcher */,
23 net_log,
24 source)) {}
25
TCPServerSocket(std::unique_ptr<TCPSocket> socket)26 TCPServerSocket::TCPServerSocket(std::unique_ptr<TCPSocket> socket)
27 : socket_(std::move(socket)) {}
28
AdoptSocket(SocketDescriptor socket)29 int TCPServerSocket::AdoptSocket(SocketDescriptor socket) {
30 return socket_->AdoptUnconnectedSocket(socket);
31 }
32
// Defaulted destructor. The sockets this file touches (|socket_| and
// |accepted_socket_|) are moved to/from std::unique_ptr elsewhere in this file,
// so no explicit teardown is written here.
TCPServerSocket::~TCPServerSocket() = default;
34
Listen(const IPEndPoint & address,int backlog,absl::optional<bool> ipv6_only)35 int TCPServerSocket::Listen(const IPEndPoint& address,
36 int backlog,
37 absl::optional<bool> ipv6_only) {
38 int result = socket_->Open(address.GetFamily());
39 if (result != OK)
40 return result;
41
42 if (ipv6_only.has_value()) {
43 CHECK_EQ(address.address(), net::IPAddress::IPv6AllZeros());
44 result = socket_->SetIPv6Only(*ipv6_only);
45 if (result != OK) {
46 socket_->Close();
47 return result;
48 }
49 }
50
51 result = socket_->SetDefaultOptionsForServer();
52 if (result != OK) {
53 socket_->Close();
54 return result;
55 }
56
57 result = socket_->Bind(address);
58 if (result != OK) {
59 socket_->Close();
60 return result;
61 }
62
63 result = socket_->Listen(backlog);
64 if (result != OK) {
65 socket_->Close();
66 return result;
67 }
68
69 return OK;
70 }
71
GetLocalAddress(IPEndPoint * address) const72 int TCPServerSocket::GetLocalAddress(IPEndPoint* address) const {
73 return socket_->GetLocalAddress(address);
74 }
75
Accept(std::unique_ptr<StreamSocket> * socket,CompletionOnceCallback callback)76 int TCPServerSocket::Accept(std::unique_ptr<StreamSocket>* socket,
77 CompletionOnceCallback callback) {
78 return Accept(socket, std::move(callback), nullptr);
79 }
80
Accept(std::unique_ptr<StreamSocket> * socket,CompletionOnceCallback callback,IPEndPoint * peer_address)81 int TCPServerSocket::Accept(std::unique_ptr<StreamSocket>* socket,
82 CompletionOnceCallback callback,
83 IPEndPoint* peer_address) {
84 DCHECK(socket);
85 DCHECK(!callback.is_null());
86
87 if (pending_accept_) {
88 NOTREACHED();
89 return ERR_UNEXPECTED;
90 }
91
92 // It is safe to use base::Unretained(this). |socket_| is owned by this class,
93 // and the callback won't be run after |socket_| is destroyed.
94 CompletionOnceCallback accept_callback = base::BindOnce(
95 &TCPServerSocket::OnAcceptCompleted, base::Unretained(this), socket,
96 peer_address, std::move(callback));
97 int result = socket_->Accept(&accepted_socket_, &accepted_address_,
98 std::move(accept_callback));
99 if (result != ERR_IO_PENDING) {
100 // |accept_callback| won't be called so we need to run
101 // ConvertAcceptedSocket() ourselves in order to do the conversion from
102 // |accepted_socket_| to |socket|.
103 result = ConvertAcceptedSocket(result, socket, peer_address);
104 } else {
105 pending_accept_ = true;
106 }
107
108 return result;
109 }
110
DetachFromThread()111 void TCPServerSocket::DetachFromThread() {
112 socket_->DetachFromThread();
113 }
114
ConvertAcceptedSocket(int result,std::unique_ptr<StreamSocket> * output_accepted_socket,IPEndPoint * output_accepted_address)115 int TCPServerSocket::ConvertAcceptedSocket(
116 int result,
117 std::unique_ptr<StreamSocket>* output_accepted_socket,
118 IPEndPoint* output_accepted_address) {
119 // Make sure the TCPSocket object is destroyed in any case.
120 std::unique_ptr<TCPSocket> temp_accepted_socket(std::move(accepted_socket_));
121 if (result != OK)
122 return result;
123
124 if (output_accepted_address)
125 *output_accepted_address = accepted_address_;
126
127 *output_accepted_socket = std::make_unique<TCPClientSocket>(
128 std::move(temp_accepted_socket), accepted_address_);
129
130 return OK;
131 }
132
OnAcceptCompleted(std::unique_ptr<StreamSocket> * output_accepted_socket,IPEndPoint * output_accepted_address,CompletionOnceCallback forward_callback,int result)133 void TCPServerSocket::OnAcceptCompleted(
134 std::unique_ptr<StreamSocket>* output_accepted_socket,
135 IPEndPoint* output_accepted_address,
136 CompletionOnceCallback forward_callback,
137 int result) {
138 result = ConvertAcceptedSocket(result, output_accepted_socket,
139 output_accepted_address);
140 pending_accept_ = false;
141 std::move(forward_callback).Run(result);
142 }
143
144 } // namespace net
145