• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "chre_host/socket_server.h"
18 
#include <poll.h>
#include <sys/socket.h>
#include <unistd.h>

#include <cassert>
#include <cerrno>
#include <cinttypes>
#include <csignal>
#include <cstdlib>
#include <cstring>
#include <map>
#include <mutex>

#include <cutils/sockets.h>
29 
30 #include "chre_host/log.h"
31 
32 namespace android {
33 namespace chre {
34 
35 std::atomic<bool> SocketServer::sSignalReceived(false);
36 
SocketServer()37 SocketServer::SocketServer() {
38   // Initialize the socket fds field for all inactive client slots to -1, so
39   // poll skips over it, and we don't attempt to send on it
40   for (size_t i = 1; i <= kMaxActiveClients; i++) {
41     mPollFds[i].fd = -1;
42     mPollFds[i].events = POLLIN;
43   }
44 }
45 
run(const char * socketName,bool allowSocketCreation,ClientMessageCallback clientMessageCallback)46 void SocketServer::run(const char *socketName, bool allowSocketCreation,
47                        ClientMessageCallback clientMessageCallback) {
48   mClientMessageCallback = clientMessageCallback;
49 
50   mSockFd = android_get_control_socket(socketName);
51   if (mSockFd == INVALID_SOCKET && allowSocketCreation) {
52     LOGI("Didn't inherit socket, creating...");
53     mSockFd = socket_local_server(socketName, ANDROID_SOCKET_NAMESPACE_RESERVED,
54                                   SOCK_SEQPACKET);
55   }
56 
57   if (mSockFd == INVALID_SOCKET) {
58     LOGE("Couldn't get/create socket");
59   } else {
60     int ret = listen(mSockFd, kMaxPendingConnectionRequests);
61     if (ret < 0) {
62       LOG_ERROR("Couldn't listen on socket", errno);
63     } else {
64       serviceSocket();
65     }
66 
67     {
68       std::lock_guard<std::mutex> lock(mClientsMutex);
69       for (const auto &pair : mClients) {
70         int clientSocket = pair.first;
71         if (close(clientSocket) != 0) {
72           LOGI("Couldn't close client %" PRIu16 "'s socket: %s",
73                pair.second.clientId, strerror(errno));
74         }
75       }
76       mClients.clear();
77     }
78     close(mSockFd);
79   }
80 }
81 
sendToAllClients(const void * data,size_t length)82 void SocketServer::sendToAllClients(const void *data, size_t length) {
83   std::lock_guard<std::mutex> lock(mClientsMutex);
84 
85   int deliveredCount = 0;
86   for (const auto &pair : mClients) {
87     int clientSocket = pair.first;
88     uint16_t clientId = pair.second.clientId;
89     if (sendToClientSocket(data, length, clientSocket, clientId)) {
90       deliveredCount++;
91     } else if (errno == EINTR) {
92       // Exit early if we were interrupted - we should only get this for
93       // SIGINT/SIGTERM, so we should exit quickly
94       break;
95     }
96   }
97 
98   if (deliveredCount == 0) {
99     LOGW("Got message but didn't deliver to any clients");
100   }
101 }
102 
sendToClientById(const void * data,size_t length,uint16_t clientId)103 bool SocketServer::sendToClientById(const void *data, size_t length,
104                                     uint16_t clientId) {
105   std::lock_guard<std::mutex> lock(mClientsMutex);
106 
107   bool sent = false;
108   for (const auto &pair : mClients) {
109     uint16_t thisClientId = pair.second.clientId;
110     if (thisClientId == clientId) {
111       int clientSocket = pair.first;
112       sent = sendToClientSocket(data, length, clientSocket, thisClientId);
113       break;
114     }
115   }
116 
117   return sent;
118 }
119 
acceptClientConnection()120 void SocketServer::acceptClientConnection() {
121   int clientSocket = accept(mSockFd, NULL, NULL);
122   if (clientSocket < 0) {
123     LOG_ERROR("Couldn't accept client connection", errno);
124   } else if (mClients.size() >= kMaxActiveClients) {
125     LOGW("Rejecting client request - maximum number of clients reached");
126     close(clientSocket);
127   } else {
128     ClientData clientData;
129     clientData.clientId = mNextClientId++;
130 
131     // We currently don't handle wraparound - if we're getting this many
132     // connects/disconnects, then something is wrong.
133     // TODO: can handle this properly by iterating over the existing clients to
134     // avoid a conflict.
135     if (clientData.clientId == 0) {
136       LOGE("Couldn't allocate client ID");
137       std::exit(-1);
138     }
139 
140     bool slotFound = false;
141     for (size_t i = 1; i <= kMaxActiveClients; i++) {
142       if (mPollFds[i].fd < 0) {
143         mPollFds[i].fd = clientSocket;
144         slotFound = true;
145         break;
146       }
147     }
148 
149     if (!slotFound) {
150       LOGE("Couldn't find slot for client!");
151       assert(slotFound);
152       close(clientSocket);
153     } else {
154       {
155         std::lock_guard<std::mutex> lock(mClientsMutex);
156         mClients[clientSocket] = clientData;
157       }
158       LOGI(
159           "Accepted new client connection (count %zu), assigned client ID "
160           "%" PRIu16,
161           mClients.size(), clientData.clientId);
162     }
163   }
164 }
165 
handleClientData(int clientSocket)166 void SocketServer::handleClientData(int clientSocket) {
167   const ClientData &clientData = mClients[clientSocket];
168   uint16_t clientId = clientData.clientId;
169 
170   ssize_t packetSize =
171       recv(clientSocket, mRecvBuffer.data(), mRecvBuffer.size(), MSG_DONTWAIT);
172   if (packetSize < 0) {
173     LOGE("Couldn't get packet from client %" PRIu16 ": %s", clientId,
174          strerror(errno));
175   } else if (packetSize == 0) {
176     LOGI("Client %" PRIu16 " disconnected", clientId);
177     disconnectClient(clientSocket);
178   } else {
179     LOGV("Got %zd byte packet from client %" PRIu16, packetSize, clientId);
180     mClientMessageCallback(clientId, mRecvBuffer.data(), packetSize);
181   }
182 }
183 
disconnectClient(int clientSocket)184 void SocketServer::disconnectClient(int clientSocket) {
185   {
186     std::lock_guard<std::mutex> lock(mClientsMutex);
187     mClients.erase(clientSocket);
188   }
189   close(clientSocket);
190 
191   bool removed = false;
192   for (size_t i = 1; i <= kMaxActiveClients; i++) {
193     if (mPollFds[i].fd == clientSocket) {
194       mPollFds[i].fd = -1;
195       removed = true;
196       break;
197     }
198   }
199 
200   if (!removed) {
201     LOGE("Out of sync");
202     assert(removed);
203   }
204 }
205 
sendToClientSocket(const void * data,size_t length,int clientSocket,uint16_t clientId)206 bool SocketServer::sendToClientSocket(const void *data, size_t length,
207                                       int clientSocket, uint16_t clientId) {
208   errno = 0;
209   ssize_t bytesSent = send(clientSocket, data, length, 0);
210   if (bytesSent < 0) {
211     LOGE("Error sending packet of size %zu to client %" PRIu16 ": %s", length,
212          clientId, strerror(errno));
213   } else if (bytesSent == 0) {
214     LOGW("Client %" PRIu16 " disconnected before message could be delivered",
215          clientId);
216   } else {
217     LOGV("Delivered message of size %zu bytes to client %" PRIu16, length,
218          clientId);
219   }
220 
221   return (bytesSent > 0);
222 }
223 
serviceSocket()224 void SocketServer::serviceSocket() {
225   constexpr size_t kListenIndex = 0;
226   static_assert(kListenIndex == 0,
227                 "Code assumes that the first index is always the listen "
228                 "socket");
229 
230   mPollFds[kListenIndex].fd = mSockFd;
231   mPollFds[kListenIndex].events = POLLIN;
232 
233   // Signal mask used with ppoll() so we gracefully handle SIGINT and SIGTERM,
234   // and ignore other signals
235   sigset_t signalMask;
236   sigfillset(&signalMask);
237   sigdelset(&signalMask, SIGINT);
238   sigdelset(&signalMask, SIGTERM);
239 
240   LOGI("Ready to accept connections");
241   while (!sSignalReceived) {
242     int ret = ppoll(mPollFds, 1 + kMaxActiveClients, nullptr, &signalMask);
243     if (ret == -1) {
244       // Don't use TEMP_FAILURE_RETRY since our logic needs to check
245       // sSignalReceived to see if it should exit where as TEMP_FAILURE_RETRY
246       // is a tight retry loop around ppoll.
247       if (errno == EINTR) {
248         continue;
249       }
250       LOGI("Exiting poll loop: %s", strerror(errno));
251       break;
252     }
253 
254     if (mPollFds[kListenIndex].revents & POLLIN) {
255       acceptClientConnection();
256     }
257 
258     for (size_t i = 1; i <= kMaxActiveClients; i++) {
259       if (mPollFds[i].fd < 0) {
260         continue;
261       }
262 
263       if (mPollFds[i].revents & POLLIN) {
264         handleClientData(mPollFds[i].fd);
265       }
266     }
267   }
268 }
269 
270 }  // namespace chre
271 }  // namespace android
272