1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "chre_host/socket_server.h"
18
19 #include <poll.h>
20
21 #include <cassert>
22 #include <cinttypes>
23 #include <csignal>
24 #include <cstdlib>
25 #include <map>
26 #include <mutex>
27
28 #include <cutils/sockets.h>
29
30 #include "chre_host/log.h"
31
32 namespace android {
33 namespace chre {
34
35 std::atomic<bool> SocketServer::sSignalReceived(false);
36
37 namespace {
38
maskAllSignals()39 void maskAllSignals() {
40 sigset_t signalMask;
41 sigfillset(&signalMask);
42 if (sigprocmask(SIG_SETMASK, &signalMask, NULL) != 0) {
43 LOG_ERROR("Couldn't mask all signals", errno);
44 }
45 }
46
maskAllSignalsExceptIntAndTerm()47 void maskAllSignalsExceptIntAndTerm() {
48 sigset_t signalMask;
49 sigfillset(&signalMask);
50 sigdelset(&signalMask, SIGINT);
51 sigdelset(&signalMask, SIGTERM);
52 if (sigprocmask(SIG_SETMASK, &signalMask, NULL) != 0) {
53 LOG_ERROR("Couldn't mask all signals except INT/TERM", errno);
54 }
55 }
56
57 } // anonymous namespace
58
SocketServer()59 SocketServer::SocketServer() {
60 // Initialize the socket fds field for all inactive client slots to -1, so
61 // poll skips over it, and we don't attempt to send on it
62 for (size_t i = 1; i <= kMaxActiveClients; i++) {
63 mPollFds[i].fd = -1;
64 mPollFds[i].events = POLLIN;
65 }
66 }
67
run(const char * socketName,bool allowSocketCreation,ClientMessageCallback clientMessageCallback)68 void SocketServer::run(const char *socketName, bool allowSocketCreation,
69 ClientMessageCallback clientMessageCallback) {
70 mClientMessageCallback = clientMessageCallback;
71
72 mSockFd = android_get_control_socket(socketName);
73 if (mSockFd == INVALID_SOCKET && allowSocketCreation) {
74 LOGI("Didn't inherit socket, creating...");
75 mSockFd = socket_local_server(socketName, ANDROID_SOCKET_NAMESPACE_RESERVED,
76 SOCK_SEQPACKET);
77 }
78
79 if (mSockFd == INVALID_SOCKET) {
80 LOGE("Couldn't get/create socket");
81 } else {
82 int ret = listen(mSockFd, kMaxPendingConnectionRequests);
83 if (ret < 0) {
84 LOG_ERROR("Couldn't listen on socket", errno);
85 } else {
86 serviceSocket();
87 }
88
89 {
90 std::lock_guard<std::mutex> lock(mClientsMutex);
91 for (const auto &pair : mClients) {
92 int clientSocket = pair.first;
93 if (close(clientSocket) != 0) {
94 LOGI("Couldn't close client %" PRIu16 "'s socket: %s",
95 pair.second.clientId, strerror(errno));
96 }
97 }
98 mClients.clear();
99 }
100 close(mSockFd);
101 }
102 }
103
sendToAllClients(const void * data,size_t length)104 void SocketServer::sendToAllClients(const void *data, size_t length) {
105 std::lock_guard<std::mutex> lock(mClientsMutex);
106
107 int deliveredCount = 0;
108 for (const auto &pair : mClients) {
109 int clientSocket = pair.first;
110 uint16_t clientId = pair.second.clientId;
111 if (sendToClientSocket(data, length, clientSocket, clientId)) {
112 deliveredCount++;
113 } else if (errno == EINTR) {
114 // Exit early if we were interrupted - we should only get this for
115 // SIGINT/SIGTERM, so we should exit quickly
116 break;
117 }
118 }
119
120 if (deliveredCount == 0) {
121 LOGW("Got message but didn't deliver to any clients");
122 }
123 }
124
sendToClientById(const void * data,size_t length,uint16_t clientId)125 bool SocketServer::sendToClientById(const void *data, size_t length,
126 uint16_t clientId) {
127 std::lock_guard<std::mutex> lock(mClientsMutex);
128
129 bool sent = false;
130 for (const auto &pair : mClients) {
131 uint16_t thisClientId = pair.second.clientId;
132 if (thisClientId == clientId) {
133 int clientSocket = pair.first;
134 sent = sendToClientSocket(data, length, clientSocket, thisClientId);
135 break;
136 }
137 }
138
139 return sent;
140 }
141
acceptClientConnection()142 void SocketServer::acceptClientConnection() {
143 int clientSocket = accept(mSockFd, NULL, NULL);
144 if (clientSocket < 0) {
145 LOG_ERROR("Couldn't accept client connection", errno);
146 } else if (mClients.size() >= kMaxActiveClients) {
147 LOGW("Rejecting client request - maximum number of clients reached");
148 close(clientSocket);
149 } else {
150 ClientData clientData;
151 clientData.clientId = mNextClientId++;
152
153 // We currently don't handle wraparound - if we're getting this many
154 // connects/disconnects, then something is wrong.
155 // TODO: can handle this properly by iterating over the existing clients to
156 // avoid a conflict.
157 if (clientData.clientId == 0) {
158 LOGE("Couldn't allocate client ID");
159 std::exit(-1);
160 }
161
162 bool slotFound = false;
163 for (size_t i = 1; i <= kMaxActiveClients; i++) {
164 if (mPollFds[i].fd < 0) {
165 mPollFds[i].fd = clientSocket;
166 slotFound = true;
167 break;
168 }
169 }
170
171 if (!slotFound) {
172 LOGE("Couldn't find slot for client!");
173 assert(slotFound);
174 close(clientSocket);
175 } else {
176 {
177 std::lock_guard<std::mutex> lock(mClientsMutex);
178 mClients[clientSocket] = clientData;
179 }
180 LOGI(
181 "Accepted new client connection (count %zu), assigned client ID "
182 "%" PRIu16,
183 mClients.size(), clientData.clientId);
184 }
185 }
186 }
187
handleClientData(int clientSocket)188 void SocketServer::handleClientData(int clientSocket) {
189 const ClientData &clientData = mClients[clientSocket];
190 uint16_t clientId = clientData.clientId;
191
192 ssize_t packetSize =
193 recv(clientSocket, mRecvBuffer.data(), mRecvBuffer.size(), MSG_DONTWAIT);
194 if (packetSize < 0) {
195 LOGE("Couldn't get packet from client %" PRIu16 ": %s", clientId,
196 strerror(errno));
197 } else if (packetSize == 0) {
198 LOGI("Client %" PRIu16 " disconnected", clientId);
199 disconnectClient(clientSocket);
200 } else {
201 LOGV("Got %zd byte packet from client %" PRIu16, packetSize, clientId);
202 mClientMessageCallback(clientId, mRecvBuffer.data(), packetSize);
203 }
204 }
205
disconnectClient(int clientSocket)206 void SocketServer::disconnectClient(int clientSocket) {
207 {
208 std::lock_guard<std::mutex> lock(mClientsMutex);
209 mClients.erase(clientSocket);
210 }
211 close(clientSocket);
212
213 bool removed = false;
214 for (size_t i = 1; i <= kMaxActiveClients; i++) {
215 if (mPollFds[i].fd == clientSocket) {
216 mPollFds[i].fd = -1;
217 removed = true;
218 break;
219 }
220 }
221
222 if (!removed) {
223 LOGE("Out of sync");
224 assert(removed);
225 }
226 }
227
sendToClientSocket(const void * data,size_t length,int clientSocket,uint16_t clientId)228 bool SocketServer::sendToClientSocket(const void *data, size_t length,
229 int clientSocket, uint16_t clientId) {
230 errno = 0;
231 ssize_t bytesSent = send(clientSocket, data, length, 0);
232 if (bytesSent < 0) {
233 LOGE("Error sending packet of size %zu to client %" PRIu16 ": %s", length,
234 clientId, strerror(errno));
235 } else if (bytesSent == 0) {
236 LOGW("Client %" PRIu16 " disconnected before message could be delivered",
237 clientId);
238 } else {
239 LOGV("Delivered message of size %zu bytes to client %" PRIu16, length,
240 clientId);
241 }
242
243 return (bytesSent > 0);
244 }
245
serviceSocket()246 void SocketServer::serviceSocket() {
247 constexpr size_t kListenIndex = 0;
248 static_assert(kListenIndex == 0,
249 "Code assumes that the first index is always the listen "
250 "socket");
251
252 mPollFds[kListenIndex].fd = mSockFd;
253 mPollFds[kListenIndex].events = POLLIN;
254
255 // Signal mask used with ppoll() so we gracefully handle SIGINT and SIGTERM,
256 // and ignore other signals
257 sigset_t signalMask;
258 sigfillset(&signalMask);
259 sigdelset(&signalMask, SIGINT);
260 sigdelset(&signalMask, SIGTERM);
261
262 // Masking signals here ensure that after this point, we won't handle INT/TERM
263 // until after we call into ppoll()
264 maskAllSignals();
265 std::signal(SIGINT, signalHandler);
266 std::signal(SIGTERM, signalHandler);
267
268 LOGI("Ready to accept connections");
269 while (!sSignalReceived) {
270 int ret = ppoll(mPollFds, 1 + kMaxActiveClients, nullptr, &signalMask);
271 maskAllSignalsExceptIntAndTerm();
272 if (ret == -1) {
273 LOGI("Exiting poll loop: %s", strerror(errno));
274 break;
275 }
276
277 if (mPollFds[kListenIndex].revents & POLLIN) {
278 acceptClientConnection();
279 }
280
281 for (size_t i = 1; i <= kMaxActiveClients; i++) {
282 if (mPollFds[i].fd < 0) {
283 continue;
284 }
285
286 if (mPollFds[i].revents & POLLIN) {
287 handleClientData(mPollFds[i].fd);
288 }
289 }
290
291 // Mask all signals to ensure that sSignalReceived can't become true between
292 // checking it in the while condition and calling into ppoll()
293 maskAllSignals();
294 }
295 }
296
signalHandler(int signal)297 void SocketServer::signalHandler(int signal) {
298 LOGD("Caught signal %d", signal);
299 sSignalReceived = true;
300 }
301
302 } // namespace chre
303 } // namespace android
304