• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "RpcServer"
18 
19 #include <inttypes.h>
20 #include <poll.h>
21 #include <sys/socket.h>
22 #include <sys/un.h>
23 
24 #include <thread>
25 #include <vector>
26 
27 #include <android-base/file.h>
28 #include <android-base/hex.h>
29 #include <android-base/scopeguard.h>
30 #include <binder/Parcel.h>
31 #include <binder/RpcServer.h>
32 #include <binder/RpcTransportRaw.h>
33 #include <log/log.h>
34 
35 #include "FdTrigger.h"
36 #include "RpcSocketAddress.h"
37 #include "RpcState.h"
38 #include "RpcWireFormat.h"
39 
40 namespace android {
41 
42 constexpr size_t kSessionIdBytes = 32;
43 
44 using base::ScopeGuard;
45 using base::unique_fd;
46 
RpcServer(std::unique_ptr<RpcTransportCtx> ctx)47 RpcServer::RpcServer(std::unique_ptr<RpcTransportCtx> ctx) : mCtx(std::move(ctx)) {}
~RpcServer()48 RpcServer::~RpcServer() {
49     (void)shutdown();
50 }
51 
make(std::unique_ptr<RpcTransportCtxFactory> rpcTransportCtxFactory)52 sp<RpcServer> RpcServer::make(std::unique_ptr<RpcTransportCtxFactory> rpcTransportCtxFactory) {
53     // Default is without TLS.
54     if (rpcTransportCtxFactory == nullptr)
55         rpcTransportCtxFactory = RpcTransportCtxFactoryRaw::make();
56     auto ctx = rpcTransportCtxFactory->newServerCtx();
57     if (ctx == nullptr) return nullptr;
58     return sp<RpcServer>::make(std::move(ctx));
59 }
60 
setupUnixDomainServer(const char * path)61 status_t RpcServer::setupUnixDomainServer(const char* path) {
62     return setupSocketServer(UnixSocketAddress(path));
63 }
64 
setupVsockServer(unsigned int port)65 status_t RpcServer::setupVsockServer(unsigned int port) {
66     // realizing value w/ this type at compile time to avoid ubsan abort
67     constexpr unsigned int kAnyCid = VMADDR_CID_ANY;
68 
69     return setupSocketServer(VsockSocketAddress(kAnyCid, port));
70 }
71 
setupInetServer(const char * address,unsigned int port,unsigned int * assignedPort)72 status_t RpcServer::setupInetServer(const char* address, unsigned int port,
73                                     unsigned int* assignedPort) {
74     if (assignedPort != nullptr) *assignedPort = 0;
75     auto aiStart = InetSocketAddress::getAddrInfo(address, port);
76     if (aiStart == nullptr) return UNKNOWN_ERROR;
77     for (auto ai = aiStart.get(); ai != nullptr; ai = ai->ai_next) {
78         InetSocketAddress socketAddress(ai->ai_addr, ai->ai_addrlen, address, port);
79         if (status_t status = setupSocketServer(socketAddress); status != OK) {
80             continue;
81         }
82 
83         LOG_ALWAYS_FATAL_IF(socketAddress.addr()->sa_family != AF_INET, "expecting inet");
84         sockaddr_in addr{};
85         socklen_t len = sizeof(addr);
86         if (0 != getsockname(mServer.get(), reinterpret_cast<sockaddr*>(&addr), &len)) {
87             int savedErrno = errno;
88             ALOGE("Could not getsockname at %s: %s", socketAddress.toString().c_str(),
89                   strerror(savedErrno));
90             return -savedErrno;
91         }
92         LOG_ALWAYS_FATAL_IF(len != sizeof(addr), "Wrong socket type: len %zu vs len %zu",
93                             static_cast<size_t>(len), sizeof(addr));
94         unsigned int realPort = ntohs(addr.sin_port);
95         LOG_ALWAYS_FATAL_IF(port != 0 && realPort != port,
96                             "Requesting inet server on %s but it is set up on %u.",
97                             socketAddress.toString().c_str(), realPort);
98 
99         if (assignedPort != nullptr) {
100             *assignedPort = realPort;
101         }
102 
103         return OK;
104     }
105     ALOGE("None of the socket address resolved for %s:%u can be set up as inet server.", address,
106           port);
107     return UNKNOWN_ERROR;
108 }
109 
setMaxThreads(size_t threads)110 void RpcServer::setMaxThreads(size_t threads) {
111     LOG_ALWAYS_FATAL_IF(threads <= 0, "RpcServer is useless without threads");
112     LOG_ALWAYS_FATAL_IF(mJoinThreadRunning, "Cannot set max threads while running");
113     mMaxThreads = threads;
114 }
115 
getMaxThreads()116 size_t RpcServer::getMaxThreads() {
117     return mMaxThreads;
118 }
119 
setProtocolVersion(uint32_t version)120 void RpcServer::setProtocolVersion(uint32_t version) {
121     mProtocolVersion = version;
122 }
123 
setRootObject(const sp<IBinder> & binder)124 void RpcServer::setRootObject(const sp<IBinder>& binder) {
125     std::lock_guard<std::mutex> _l(mLock);
126     mRootObjectFactory = nullptr;
127     mRootObjectWeak = mRootObject = binder;
128 }
129 
setRootObjectWeak(const wp<IBinder> & binder)130 void RpcServer::setRootObjectWeak(const wp<IBinder>& binder) {
131     std::lock_guard<std::mutex> _l(mLock);
132     mRootObject.clear();
133     mRootObjectFactory = nullptr;
134     mRootObjectWeak = binder;
135 }
setPerSessionRootObject(std::function<sp<IBinder> (const sockaddr *,socklen_t)> && makeObject)136 void RpcServer::setPerSessionRootObject(
137         std::function<sp<IBinder>(const sockaddr*, socklen_t)>&& makeObject) {
138     std::lock_guard<std::mutex> _l(mLock);
139     mRootObject.clear();
140     mRootObjectWeak.clear();
141     mRootObjectFactory = std::move(makeObject);
142 }
143 
getRootObject()144 sp<IBinder> RpcServer::getRootObject() {
145     std::lock_guard<std::mutex> _l(mLock);
146     bool hasWeak = mRootObjectWeak.unsafe_get();
147     sp<IBinder> ret = mRootObjectWeak.promote();
148     ALOGW_IF(hasWeak && ret == nullptr, "RpcServer root object is freed, returning nullptr");
149     return ret;
150 }
151 
getCertificate(RpcCertificateFormat format)152 std::vector<uint8_t> RpcServer::getCertificate(RpcCertificateFormat format) {
153     std::lock_guard<std::mutex> _l(mLock);
154     return mCtx->getCertificate(format);
155 }
156 
joinRpcServer(sp<RpcServer> && thiz)157 static void joinRpcServer(sp<RpcServer>&& thiz) {
158     thiz->join();
159 }
160 
start()161 void RpcServer::start() {
162     std::lock_guard<std::mutex> _l(mLock);
163     LOG_ALWAYS_FATAL_IF(mJoinThread.get(), "Already started!");
164     mJoinThread = std::make_unique<std::thread>(&joinRpcServer, sp<RpcServer>::fromExisting(this));
165 }
166 
// Runs the accept loop on the calling thread.
//
// The loop continues until shutdown() fires mShutdownTrigger, or until triggerablePoll()
// returns any other non-OK status. Each accepted client socket is handed to a new thread
// running establishConnection(). Those threads are recorded in mConnectingThreads under
// mLock.
//
// Aborts if no server socket has been set up, or if join() is already active (i.e. a
// shutdown trigger is still installed).
void RpcServer::join() {

    {
        std::lock_guard<std::mutex> _l(mLock);
        LOG_ALWAYS_FATAL_IF(!mServer.ok(), "RpcServer must be setup to join.");
        LOG_ALWAYS_FATAL_IF(mShutdownTrigger != nullptr, "Already joined");
        // shutdown() waits for this flag to go back to false.
        mJoinThreadRunning = true;
        // shutdown() fires this trigger to interrupt the poll below.
        mShutdownTrigger = FdTrigger::make();
        LOG_ALWAYS_FATAL_IF(mShutdownTrigger == nullptr, "Cannot create join signaler");
    }

    status_t status;
    // Block until the listening socket is readable or the shutdown trigger fires.
    while ((status = mShutdownTrigger->triggerablePoll(mServer, POLLIN)) == OK) {
        sockaddr_storage addr;
        socklen_t addrLen = sizeof(addr);

        unique_fd clientFd(
                TEMP_FAILURE_RETRY(accept4(mServer.get(), reinterpret_cast<sockaddr*>(&addr),
                                           &addrLen, SOCK_CLOEXEC | SOCK_NONBLOCK)));

        // accept4 writes back the peer address's real length. A value larger than the buffer
        // means the address was truncated.
        LOG_ALWAYS_FATAL_IF(addrLen > static_cast<socklen_t>(sizeof(addr)), "Truncated address");

        if (clientFd < 0) {
            // NOTE(review): if accept4 keeps failing (e.g. fd exhaustion), the listening
            // socket may stay readable and this loop may spin logging; consider a backoff.
            ALOGE("Could not accept4 socket: %s", strerror(errno));
            continue;
        }
        LOG_RPC_DETAIL("accept4 on fd %d yields fd %d", mServer.get(), clientFd.get());

        {
            // The new thread's entry is registered while mLock is held. establishConnection()
            // looks itself up in mConnectingThreads under the same lock (and aborts if
            // absent), so it cannot run ahead of this insertion.
            std::lock_guard<std::mutex> _l(mLock);
            std::thread thread =
                    std::thread(&RpcServer::establishConnection, sp<RpcServer>::fromExisting(this),
                                std::move(clientFd), addr, addrLen);
            mConnectingThreads[thread.get_id()] = std::move(thread);
        }
    }
    LOG_RPC_DETAIL("RpcServer::join exiting with %s", statusToString(status).c_str());

    {
        std::lock_guard<std::mutex> _l(mLock);
        mJoinThreadRunning = false;
    }
    // Wake shutdown(), which waits on mShutdownCv for mJoinThreadRunning to become false.
    mShutdownCv.notify_all();
}
211 
// Stops the server and blocks until it has fully stopped.
//
// Steps:
//   1. Fire the join loop's shutdown trigger.
//   2. Fire every session's shutdown trigger.
//   3. Wait until the join loop has exited, no connecting threads remain, and every
//      session has been removed from mSessions.
//   4. If start() created the join thread, join that thread.
//
// Returns false if there was nothing to shut down: no shutdown trigger is installed because
// join() never ran or a previous shutdown() already cleared it. Returns true once shutdown
// completes.
//
// NOTE(review): this blocks until all sessions drain. Calling it from a thread that a
// session or connection depends on may hang; the wait loop logs once per second without
// progress.
bool RpcServer::shutdown() {
    std::unique_lock<std::mutex> _l(mLock);
    if (mShutdownTrigger == nullptr) {
        LOG_RPC_DETAIL("Cannot shutdown. No shutdown trigger installed (already shutdown?)");
        return false;
    }

    // Interrupt the join loop's poll.
    mShutdownTrigger->trigger();

    for (auto& [id, session] : mSessions) {
        (void)id;
        // server lock is a more general lock
        // (lock order: the server's mLock, held here, is taken before a session's mMutex)
        std::lock_guard<std::mutex> _lSession(session->mMutex);
        session->mShutdownTrigger->trigger();
    }

    // wait_for releases mLock while waiting. This lets join(), the connecting threads and
    // session teardown take mLock, update this state, and notify mShutdownCv.
    while (mJoinThreadRunning || !mConnectingThreads.empty() || !mSessions.empty()) {
        if (std::cv_status::timeout == mShutdownCv.wait_for(_l, std::chrono::seconds(1))) {
            ALOGE("Waiting for RpcServer to shut down (1s w/o progress). Join thread running: %d, "
                  "Connecting threads: "
                  "%zu, Sessions: %zu. Is your server deadlocked?",
                  mJoinThreadRunning, mConnectingThreads.size(), mSessions.size());
        }
    }

    // At this point, we know join() is about to exit, but the thread that calls
    // join() may not have exited yet.
    // If RpcServer owns the join thread (aka start() is called), make sure the thread exits;
    // otherwise ~thread() may call std::terminate(), which may crash the process.
    // If RpcServer does not own the join thread (aka join() is called directly),
    // then the owner of RpcServer is responsible for cleaning up that thread.
    // (Joining while holding mLock is fine: once join() has cleared mJoinThreadRunning, it only
    // calls notify_all() and returns, without taking mLock again.)
    if (mJoinThread.get()) {
        mJoinThread->join();
        mJoinThread.reset();
    }

    LOG_RPC_DETAIL("Finished waiting on shutdown.");

    // Clearing the trigger lets join() pass its "Already joined" check again, and makes a
    // repeated shutdown() return false.
    mShutdownTrigger = nullptr;
    return true;
}
253 
listSessions()254 std::vector<sp<RpcSession>> RpcServer::listSessions() {
255     std::lock_guard<std::mutex> _l(mLock);
256     std::vector<sp<RpcSession>> sessions;
257     for (auto& [id, session] : mSessions) {
258         (void)id;
259         sessions.push_back(session);
260     }
261     return sessions;
262 }
263 
numUninitializedSessions()264 size_t RpcServer::numUninitializedSessions() {
265     std::lock_guard<std::mutex> _l(mLock);
266     return mConnectingThreads.size();
267 }
268 
// Per-connection thread entry point, spawned by join() for each accepted socket.
//
// Phase 1 runs without mLock: the handshake.
//   - Wrap |clientFd| in a transport from mCtx.
//   - Read the RpcConnectionHeader.
//   - Read an optional session ID of exactly kSessionIdBytes.
//   - For a new session (no ID sent), reply with an RpcNewSessionResponse carrying the
//     negotiated protocol version.
//   Failures are not returned early. They are recorded in |status|, so that this thread
//   still deregisters itself below.
//
// Phase 2 runs under server->mLock.
//   - Remove this thread from mConnectingThreads.
//   - Then either create and register a new session, or look up the existing session
//     named by the ID.
//   - If the client set RPC_CONNECTION_OPTION_INCOMING, the transport is added to the
//     session via addOutgoingConnection() and this thread detaches.
//   - Otherwise, ownership of this thread is handed to the session.
//
// Phase 3 runs outside the lock: the session's preJoinSetup() and RpcSession::join() run
// on this thread.
//
// |addr|/|addrLen| are the client address reported by accept4. They are passed to the
// per-session root object factory, if one is set.
void RpcServer::establishConnection(sp<RpcServer>&& server, base::unique_fd clientFd,
                                    const sockaddr_storage addr, socklen_t addrLen) {
    // mShutdownTrigger can only be cleared once connection threads have joined.
    // It must be set before this thread is started
    LOG_ALWAYS_FATAL_IF(server->mShutdownTrigger == nullptr);
    LOG_ALWAYS_FATAL_IF(server->mCtx == nullptr);

    status_t status = OK;

    // Saved before |clientFd| is moved into the transport, so it can still be logged.
    int clientFdForLog = clientFd.get();
    auto client = server->mCtx->newTransport(std::move(clientFd), server->mShutdownTrigger.get());
    if (client == nullptr) {
        ALOGE("Dropping accept4()-ed socket because sslAccept fails");
        status = DEAD_OBJECT;
        // still need to cleanup before we can return
    } else {
        LOG_RPC_DETAIL("Created RpcTransport %p for client fd %d", client.get(), clientFdForLog);
    }

    // Only read when status == OK; otherwise it is left untouched and never inspected.
    RpcConnectionHeader header;
    if (status == OK) {
        iovec iov{&header, sizeof(header)};
        status = client->interruptableReadFully(server->mShutdownTrigger.get(), &iov, 1, {});
        if (status != OK) {
            ALOGE("Failed to read ID for client connecting to RPC server: %s",
                  statusToString(status).c_str());
            // still need to cleanup before we can return
        }
    }

    // An empty session ID means the client is asking for a new session. Otherwise the
    // client is joining an existing one, and the ID must be exactly kSessionIdBytes long.
    std::vector<uint8_t> sessionId;
    if (status == OK) {
        if (header.sessionIdSize > 0) {
            if (header.sessionIdSize == kSessionIdBytes) {
                sessionId.resize(header.sessionIdSize);
                iovec iov{sessionId.data(), sessionId.size()};
                status =
                        client->interruptableReadFully(server->mShutdownTrigger.get(), &iov, 1, {});
                if (status != OK) {
                    ALOGE("Failed to read session ID for client connecting to RPC server: %s",
                          statusToString(status).c_str());
                    // still need to cleanup before we can return
                }
            } else {
                ALOGE("Malformed session ID. Expecting session ID of size %zu but got %" PRIu16,
                      kSessionIdBytes, header.sessionIdSize);
                status = BAD_VALUE;
            }
        }
    }

    bool incoming = false;
    uint32_t protocolVersion = 0;
    bool requestingNewSession = false;

    if (status == OK) {
        incoming = header.options & RPC_CONNECTION_OPTION_INCOMING;
        // Negotiated version = the lower of the client's requested version and the server's
        // cap (RPC_WIRE_PROTOCOL_VERSION when no cap was set).
        protocolVersion = std::min(header.version,
                                   server->mProtocolVersion.value_or(RPC_WIRE_PROTOCOL_VERSION));
        requestingNewSession = sessionId.empty();

        if (requestingNewSession) {
            // Tell the client which protocol version the new session will use.
            RpcNewSessionResponse response{
                    .version = protocolVersion,
            };

            iovec iov{&response, sizeof(response)};
            status = client->interruptableWriteFully(server->mShutdownTrigger.get(), &iov, 1, {});
            if (status != OK) {
                ALOGE("Failed to send new session response: %s", statusToString(status).c_str());
                // still need to cleanup before we can return
            }
        }
    }

    std::thread thisThread;
    sp<RpcSession> session;
    {
        std::unique_lock<std::mutex> _l(server->mLock);

        // join() registered this thread under mLock before releasing it, so the entry must
        // exist here.
        auto threadId = server->mConnectingThreads.find(std::this_thread::get_id());
        LOG_ALWAYS_FATAL_IF(threadId == server->mConnectingThreads.end(),
                            "Must establish connection on owned thread");
        thisThread = std::move(threadId->second);
        // This guard runs on every return below until Disable() is called. It:
        //   - detaches this thread (it is no longer tracked anywhere),
        //   - releases mLock,
        //   - wakes shutdown(), which waits for mConnectingThreads to drain.
        ScopeGuard detachGuard = [&]() {
            thisThread.detach();
            _l.unlock();
            server->mShutdownCv.notify_all();
        };
        server->mConnectingThreads.erase(threadId);

        // Drop the connection if the handshake failed or the server is shutting down.
        if (status != OK || server->mShutdownTrigger->isTriggered()) {
            return;
        }

        if (requestingNewSession) {
            if (incoming) {
                ALOGE("Cannot create a new session with an incoming connection, would leak");
                return;
            }

            // Uniquely identify session at the application layer. Even if a
            // client/server use the same certificates, if they create multiple
            // sessions, we still want to distinguish between them.
            sessionId.resize(kSessionIdBytes);
            size_t tries = 0;
            // Draw random IDs until one does not collide with an existing session. There are
            // at most six draws before giving up.
            do {
                // don't block if there is some entropy issue
                if (tries++ > 5) {
                    ALOGE("Cannot find new address: %s",
                          base::HexString(sessionId.data(), sessionId.size()).c_str());
                    return;
                }

                // If open() fails, |fd| is invalid and ReadFully fails, which is reported
                // below.
                base::unique_fd fd(TEMP_FAILURE_RETRY(
                        open("/dev/urandom", O_RDONLY | O_CLOEXEC | O_NOFOLLOW)));
                if (!base::ReadFully(fd, sessionId.data(), sessionId.size())) {
                    ALOGE("Could not read from /dev/urandom to create session ID");
                    return;
                }
            } while (server->mSessions.end() != server->mSessions.find(sessionId));

            session = RpcSession::make();
            session->setMaxIncomingThreads(server->mMaxThreads);
            if (!session->setProtocolVersion(protocolVersion)) return;

            // if null, falls back to server root
            // (Note: the factory is invoked while server->mLock is held.)
            sp<IBinder> sessionSpecificRoot;
            if (server->mRootObjectFactory != nullptr) {
                sessionSpecificRoot =
                        server->mRootObjectFactory(reinterpret_cast<const sockaddr*>(&addr),
                                                   addrLen);
                if (sessionSpecificRoot == nullptr) {
                    ALOGE("Warning: server returned null from root object factory");
                }
            }

            // The server registers itself as the session's EventListener.
            if (!session->setForServer(server,
                                       sp<RpcServer::EventListener>::fromExisting(
                                               static_cast<RpcServer::EventListener*>(
                                                       server.get())),
                                       sessionId, sessionSpecificRoot)) {
                ALOGE("Failed to attach server to session");
                return;
            }

            server->mSessions[sessionId] = session;
        } else {
            // The client is adding a connection to a session it already has.
            auto it = server->mSessions.find(sessionId);
            if (it == server->mSessions.end()) {
                ALOGE("Cannot add thread, no record of session with ID %s",
                      base::HexString(sessionId.data(), sessionId.size()).c_str());
                return;
            }
            session = it->second;
        }

        // The client flagged this connection as incoming on its side. The server adds it as
        // an outgoing connection, and this thread then detaches via the guard.
        if (incoming) {
            LOG_ALWAYS_FATAL_IF(OK != session->addOutgoingConnection(std::move(client), true),
                                "server state must already be initialized");
            return;
        }

        // Hand ownership of this std::thread to the session instead of detaching it.
        detachGuard.Disable();
        session->preJoinThreadOwnership(std::move(thisThread));
    }

    auto setupResult = session->preJoinSetup(std::move(client));

    // avoid strong cycle
    // (drop this thread's strong reference to the server before running the session loop)
    server = nullptr;

    RpcSession::join(std::move(session), std::move(setupResult));
}
443 
setupSocketServer(const RpcSocketAddress & addr)444 status_t RpcServer::setupSocketServer(const RpcSocketAddress& addr) {
445     LOG_RPC_DETAIL("Setting up socket server %s", addr.toString().c_str());
446     LOG_ALWAYS_FATAL_IF(hasServer(), "Each RpcServer can only have one server.");
447 
448     unique_fd serverFd(TEMP_FAILURE_RETRY(
449             socket(addr.addr()->sa_family, SOCK_STREAM | SOCK_CLOEXEC | SOCK_NONBLOCK, 0)));
450     if (serverFd == -1) {
451         int savedErrno = errno;
452         ALOGE("Could not create socket: %s", strerror(savedErrno));
453         return -savedErrno;
454     }
455 
456     if (0 != TEMP_FAILURE_RETRY(bind(serverFd.get(), addr.addr(), addr.addrSize()))) {
457         int savedErrno = errno;
458         ALOGE("Could not bind socket at %s: %s", addr.toString().c_str(), strerror(savedErrno));
459         return -savedErrno;
460     }
461 
462     // Right now, we create all threads at once, making accept4 slow. To avoid hanging the client,
463     // the backlog is increased to a large number.
464     // TODO(b/189955605): Once we create threads dynamically & lazily, the backlog can be reduced
465     //  to 1.
466     if (0 != TEMP_FAILURE_RETRY(listen(serverFd.get(), 50 /*backlog*/))) {
467         int savedErrno = errno;
468         ALOGE("Could not listen socket at %s: %s", addr.toString().c_str(), strerror(savedErrno));
469         return -savedErrno;
470     }
471 
472     LOG_RPC_DETAIL("Successfully setup socket server %s", addr.toString().c_str());
473 
474     if (status_t status = setupExternalServer(std::move(serverFd)); status != OK) {
475         ALOGE("Another thread has set up server while calling setupSocketServer. Race?");
476         return status;
477     }
478     return OK;
479 }
480 
onSessionAllIncomingThreadsEnded(const sp<RpcSession> & session)481 void RpcServer::onSessionAllIncomingThreadsEnded(const sp<RpcSession>& session) {
482     const std::vector<uint8_t>& id = session->mId;
483     LOG_ALWAYS_FATAL_IF(id.empty(), "Server sessions must be initialized with ID");
484     LOG_RPC_DETAIL("Dropping session with address %s",
485                    base::HexString(id.data(), id.size()).c_str());
486 
487     std::lock_guard<std::mutex> _l(mLock);
488     auto it = mSessions.find(id);
489     LOG_ALWAYS_FATAL_IF(it == mSessions.end(), "Bad state, unknown session id %s",
490                         base::HexString(id.data(), id.size()).c_str());
491     LOG_ALWAYS_FATAL_IF(it->second != session, "Bad state, session has id mismatch %s",
492                         base::HexString(id.data(), id.size()).c_str());
493     (void)mSessions.erase(it);
494 }
495 
onSessionIncomingThreadEnded()496 void RpcServer::onSessionIncomingThreadEnded() {
497     mShutdownCv.notify_all();
498 }
499 
hasServer()500 bool RpcServer::hasServer() {
501     std::lock_guard<std::mutex> _l(mLock);
502     return mServer.ok();
503 }
504 
releaseServer()505 unique_fd RpcServer::releaseServer() {
506     std::lock_guard<std::mutex> _l(mLock);
507     return std::move(mServer);
508 }
509 
setupExternalServer(base::unique_fd serverFd)510 status_t RpcServer::setupExternalServer(base::unique_fd serverFd) {
511     std::lock_guard<std::mutex> _l(mLock);
512     if (mServer.ok()) {
513         ALOGE("Each RpcServer can only have one server.");
514         return INVALID_OPERATION;
515     }
516     mServer = std::move(serverFd);
517     return OK;
518 }
519 
520 } // namespace android
521