1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "chre_host/socket_server.h"

#include <poll.h>
#include <sys/socket.h>
#include <unistd.h>

#include <cassert>
#include <cerrno>
#include <cinttypes>
#include <csignal>
#include <cstdlib>
#include <cstring>
#include <map>
#include <mutex>

#include <cutils/sockets.h>

#include "chre_host/log.h"
31
32 namespace android {
33 namespace chre {
34
35 std::atomic<bool> SocketServer::sSignalReceived(false);
36
SocketServer()37 SocketServer::SocketServer() {
38 // Initialize the socket fds field for all inactive client slots to -1, so
39 // poll skips over it, and we don't attempt to send on it
40 for (size_t i = 1; i <= kMaxActiveClients; i++) {
41 mPollFds[i].fd = -1;
42 mPollFds[i].events = POLLIN;
43 }
44 }
45
run(const char * socketName,bool allowSocketCreation,ClientMessageCallback clientMessageCallback)46 void SocketServer::run(const char *socketName, bool allowSocketCreation,
47 ClientMessageCallback clientMessageCallback) {
48 mClientMessageCallback = clientMessageCallback;
49
50 mSockFd = android_get_control_socket(socketName);
51 if (mSockFd == INVALID_SOCKET && allowSocketCreation) {
52 LOGI("Didn't inherit socket, creating...");
53 mSockFd = socket_local_server(socketName, ANDROID_SOCKET_NAMESPACE_RESERVED,
54 SOCK_SEQPACKET);
55 }
56
57 if (mSockFd == INVALID_SOCKET) {
58 LOGE("Couldn't get/create socket");
59 } else {
60 int ret = listen(mSockFd, kMaxPendingConnectionRequests);
61 if (ret < 0) {
62 LOG_ERROR("Couldn't listen on socket", errno);
63 } else {
64 serviceSocket();
65 }
66
67 {
68 std::lock_guard<std::mutex> lock(mClientsMutex);
69 for (const auto &pair : mClients) {
70 int clientSocket = pair.first;
71 if (close(clientSocket) != 0) {
72 LOGI("Couldn't close client %" PRIu16 "'s socket: %s",
73 pair.second.clientId, strerror(errno));
74 }
75 }
76 mClients.clear();
77 }
78 close(mSockFd);
79 }
80 }
81
sendToAllClients(const void * data,size_t length)82 void SocketServer::sendToAllClients(const void *data, size_t length) {
83 std::lock_guard<std::mutex> lock(mClientsMutex);
84
85 int deliveredCount = 0;
86 for (const auto &pair : mClients) {
87 int clientSocket = pair.first;
88 uint16_t clientId = pair.second.clientId;
89 if (sendToClientSocket(data, length, clientSocket, clientId)) {
90 deliveredCount++;
91 } else if (errno == EINTR) {
92 // Exit early if we were interrupted - we should only get this for
93 // SIGINT/SIGTERM, so we should exit quickly
94 break;
95 }
96 }
97
98 if (deliveredCount == 0) {
99 LOGW("Got message but didn't deliver to any clients");
100 }
101 }
102
sendToClientById(const void * data,size_t length,uint16_t clientId)103 bool SocketServer::sendToClientById(const void *data, size_t length,
104 uint16_t clientId) {
105 std::lock_guard<std::mutex> lock(mClientsMutex);
106
107 bool sent = false;
108 for (const auto &pair : mClients) {
109 uint16_t thisClientId = pair.second.clientId;
110 if (thisClientId == clientId) {
111 int clientSocket = pair.first;
112 sent = sendToClientSocket(data, length, clientSocket, thisClientId);
113 break;
114 }
115 }
116
117 return sent;
118 }
119
acceptClientConnection()120 void SocketServer::acceptClientConnection() {
121 int clientSocket = accept(mSockFd, NULL, NULL);
122 if (clientSocket < 0) {
123 LOG_ERROR("Couldn't accept client connection", errno);
124 } else if (mClients.size() >= kMaxActiveClients) {
125 LOGW("Rejecting client request - maximum number of clients reached");
126 close(clientSocket);
127 } else {
128 ClientData clientData;
129 clientData.clientId = mNextClientId++;
130
131 // We currently don't handle wraparound - if we're getting this many
132 // connects/disconnects, then something is wrong.
133 // TODO: can handle this properly by iterating over the existing clients to
134 // avoid a conflict.
135 if (clientData.clientId == 0) {
136 LOGE("Couldn't allocate client ID");
137 std::exit(-1);
138 }
139
140 bool slotFound = false;
141 for (size_t i = 1; i <= kMaxActiveClients; i++) {
142 if (mPollFds[i].fd < 0) {
143 mPollFds[i].fd = clientSocket;
144 slotFound = true;
145 break;
146 }
147 }
148
149 if (!slotFound) {
150 LOGE("Couldn't find slot for client!");
151 assert(slotFound);
152 close(clientSocket);
153 } else {
154 {
155 std::lock_guard<std::mutex> lock(mClientsMutex);
156 mClients[clientSocket] = clientData;
157 }
158 LOGI(
159 "Accepted new client connection (count %zu), assigned client ID "
160 "%" PRIu16,
161 mClients.size(), clientData.clientId);
162 }
163 }
164 }
165
handleClientData(int clientSocket)166 void SocketServer::handleClientData(int clientSocket) {
167 const ClientData &clientData = mClients[clientSocket];
168 uint16_t clientId = clientData.clientId;
169
170 ssize_t packetSize =
171 recv(clientSocket, mRecvBuffer.data(), mRecvBuffer.size(), MSG_DONTWAIT);
172 if (packetSize < 0) {
173 LOGE("Couldn't get packet from client %" PRIu16 ": %s", clientId,
174 strerror(errno));
175 } else if (packetSize == 0) {
176 LOGI("Client %" PRIu16 " disconnected", clientId);
177 disconnectClient(clientSocket);
178 } else {
179 LOGV("Got %zd byte packet from client %" PRIu16, packetSize, clientId);
180 mClientMessageCallback(clientId, mRecvBuffer.data(), packetSize);
181 }
182 }
183
disconnectClient(int clientSocket)184 void SocketServer::disconnectClient(int clientSocket) {
185 {
186 std::lock_guard<std::mutex> lock(mClientsMutex);
187 mClients.erase(clientSocket);
188 }
189 close(clientSocket);
190
191 bool removed = false;
192 for (size_t i = 1; i <= kMaxActiveClients; i++) {
193 if (mPollFds[i].fd == clientSocket) {
194 mPollFds[i].fd = -1;
195 removed = true;
196 break;
197 }
198 }
199
200 if (!removed) {
201 LOGE("Out of sync");
202 assert(removed);
203 }
204 }
205
sendToClientSocket(const void * data,size_t length,int clientSocket,uint16_t clientId)206 bool SocketServer::sendToClientSocket(const void *data, size_t length,
207 int clientSocket, uint16_t clientId) {
208 errno = 0;
209 ssize_t bytesSent = send(clientSocket, data, length, 0);
210 if (bytesSent < 0) {
211 LOGE("Error sending packet of size %zu to client %" PRIu16 ": %s", length,
212 clientId, strerror(errno));
213 } else if (bytesSent == 0) {
214 LOGW("Client %" PRIu16 " disconnected before message could be delivered",
215 clientId);
216 } else {
217 LOGV("Delivered message of size %zu bytes to client %" PRIu16, length,
218 clientId);
219 }
220
221 return (bytesSent > 0);
222 }
223
serviceSocket()224 void SocketServer::serviceSocket() {
225 constexpr size_t kListenIndex = 0;
226 static_assert(kListenIndex == 0,
227 "Code assumes that the first index is always the listen "
228 "socket");
229
230 mPollFds[kListenIndex].fd = mSockFd;
231 mPollFds[kListenIndex].events = POLLIN;
232
233 // Signal mask used with ppoll() so we gracefully handle SIGINT and SIGTERM,
234 // and ignore other signals
235 sigset_t signalMask;
236 sigfillset(&signalMask);
237 sigdelset(&signalMask, SIGINT);
238 sigdelset(&signalMask, SIGTERM);
239
240 LOGI("Ready to accept connections");
241 while (!sSignalReceived) {
242 int ret = ppoll(mPollFds, 1 + kMaxActiveClients, nullptr, &signalMask);
243 if (ret == -1) {
244 // Don't use TEMP_FAILURE_RETRY since our logic needs to check
245 // sSignalReceived to see if it should exit where as TEMP_FAILURE_RETRY
246 // is a tight retry loop around ppoll.
247 if (errno == EINTR) {
248 continue;
249 }
250 LOGI("Exiting poll loop: %s", strerror(errno));
251 break;
252 }
253
254 if (mPollFds[kListenIndex].revents & POLLIN) {
255 acceptClientConnection();
256 }
257
258 for (size_t i = 1; i <= kMaxActiveClients; i++) {
259 if (mPollFds[i].fd < 0) {
260 continue;
261 }
262
263 if (mPollFds[i].revents & POLLIN) {
264 handleClientData(mPollFds[i].fd);
265 }
266 }
267 }
268 }
269
270 } // namespace chre
271 } // namespace android
272