• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "adbwifi/pairing/pairing_server.h"

#include <fcntl.h>
#include <sys/epoll.h>
#include <sys/eventfd.h>
#include <sys/socket.h>
#include <unistd.h>

#include <atomic>
#include <condition_variable>
#include <cstring>
#include <deque>
#include <iomanip>
#include <memory>
#include <mutex>
#include <optional>
#include <sstream>
#include <string>
#include <thread>
#include <tuple>
#include <unordered_map>
#include <variant>
#include <vector>

#include <adbwifi/pairing/pairing_connection.h>
#include <android-base/logging.h>
#include <android-base/parsenetaddress.h>
#include <android-base/thread_annotations.h>
#include <android-base/unique_fd.h>
#include <cutils/sockets.h>
39 
40 namespace adbwifi {
41 namespace pairing {
42 
43 using android::base::ScopedLockAssertion;
44 using android::base::unique_fd;
45 
46 namespace {
47 
48 // The implimentation has two background threads running: one to handle and
49 // accept any new pairing connection requests (socket accept), and the other to
50 // handle connection events (connection started, connection finished).
51 class PairingServerImpl : public PairingServer {
52   public:
53     virtual ~PairingServerImpl();
54 
55     // All parameters must be non-empty.
56     explicit PairingServerImpl(const Data& pswd, const PeerInfo& peer_info, const Data& cert,
57                                const Data& priv_key, int port);
58 
59     // Starts the pairing server. This call is non-blocking. Upon completion,
60     // if the pairing was successful, then |cb| will be called with the PublicKeyHeader
61     // containing the info of the trusted peer. Otherwise, |cb| will be
62     // called with an empty value. Start can only be called once in the lifetime
63     // of this object.
64     //
65     // Returns true if PairingServer was successfully started. Otherwise,
66     // returns false.
67     virtual bool start(PairingConnection::ResultCallback cb, void* opaque) override;
68 
69   private:
70     // Setup the server socket to accept incoming connections
71     bool setupServer();
72     // Force stop the server thread.
73     void stopServer();
74 
75     // handles a new pairing client connection
76     bool handleNewClientConnection(int fd) EXCLUDES(conn_mutex_);
77 
78     // ======== connection events thread =============
79     std::mutex conn_mutex_;
80     std::condition_variable conn_cv_;
81 
82     using FdVal = int;
83     using ConnectionPtr = std::unique_ptr<PairingConnection>;
84     using NewConnectionEvent = std::tuple<unique_fd, ConnectionPtr>;
85     // <fd, PeerInfo.name, PeerInfo.guid, certificate>
86     using ConnectionFinishedEvent = std::tuple<FdVal, std::optional<std::string>,
87                                                std::optional<std::string>, std::optional<Data>>;
88     using ConnectionEvent = std::variant<NewConnectionEvent, ConnectionFinishedEvent>;
89     // Queue for connections to write into. We have a separate queue to read
90     // from, in order to minimize the time the server thread is blocked.
91     std::deque<ConnectionEvent> conn_write_queue_ GUARDED_BY(conn_mutex_);
92     std::deque<ConnectionEvent> conn_read_queue_;
93     // Map of fds to their PairingConnections currently running.
94     std::unordered_map<FdVal, ConnectionPtr> connections_;
95 
96     // Two threads launched when starting the pairing server:
97     // 1) A server thread that waits for incoming client connections, and
98     // 2) A connection events thread that synchonizes events from all of the
99     //    clients, since each PairingConnection is running in it's own thread.
100     void startConnectionEventsThread();
101     void startServerThread();
102 
103     std::thread conn_events_thread_;
104     void connectionEventsWorker();
105     std::thread server_thread_;
106     void serverWorker();
107     bool is_terminate_ GUARDED_BY(conn_mutex_) = false;
108 
109     enum class State {
110         Ready,
111         Running,
112         Stopped,
113     };
114     State state_ = State::Ready;
115     Data pswd_;
116     PeerInfo peer_info_;
117     Data cert_;
118     Data priv_key_;
119     int port_ = -1;
120 
121     PairingConnection::ResultCallback cb_;
122     void* opaque_ = nullptr;
123     bool got_valid_pairing_ = false;
124 
125     static const int kEpollConstSocket = 0;
126     // Used to break the server thread from epoll_wait
127     static const int kEpollConstEventFd = 1;
128     unique_fd epoll_fd_;
129     unique_fd server_fd_;
130     unique_fd event_fd_;
131 };  // PairingServerImpl
132 
PairingServerImpl(const Data & pswd,const PeerInfo & peer_info,const Data & cert,const Data & priv_key,int port)133 PairingServerImpl::PairingServerImpl(const Data& pswd, const PeerInfo& peer_info, const Data& cert,
134                                      const Data& priv_key, int port)
135     : pswd_(pswd), peer_info_(peer_info), cert_(cert), priv_key_(priv_key), port_(port) {
136     CHECK(!pswd_.empty() && !cert_.empty() && !priv_key_.empty() && port_ > 0);
137     CHECK('\0' == peer_info.name[kPeerNameLength - 1] &&
138           '\0' == peer_info.guid[kPeerGuidLength - 1] && strlen(peer_info.name) > 0 &&
139           strlen(peer_info.guid) > 0);
140 }
141 
~PairingServerImpl()142 PairingServerImpl::~PairingServerImpl() {
143     // Since these connections have references to us, let's make sure they
144     // destruct before us.
145     if (server_thread_.joinable()) {
146         stopServer();
147         server_thread_.join();
148     }
149 
150     {
151         std::lock_guard<std::mutex> lock(conn_mutex_);
152         is_terminate_ = true;
153     }
154     conn_cv_.notify_one();
155     if (conn_events_thread_.joinable()) {
156         conn_events_thread_.join();
157     }
158 
159     // Notify the cb_ if it hasn't already.
160     if (!got_valid_pairing_ && cb_ != nullptr) {
161         cb_(nullptr, nullptr, opaque_);
162     }
163 }
164 
start(PairingConnection::ResultCallback cb,void * opaque)165 bool PairingServerImpl::start(PairingConnection::ResultCallback cb, void* opaque) {
166     cb_ = cb;
167     opaque_ = opaque;
168 
169     if (state_ != State::Ready) {
170         LOG(ERROR) << "PairingServer already running or stopped";
171         return false;
172     }
173 
174     if (!setupServer()) {
175         LOG(ERROR) << "Unable to start PairingServer";
176         state_ = State::Stopped;
177         return false;
178     }
179 
180     state_ = State::Running;
181     return true;
182 }
183 
stopServer()184 void PairingServerImpl::stopServer() {
185     if (event_fd_.get() == -1) {
186         return;
187     }
188     uint64_t value = 1;
189     ssize_t rc = write(event_fd_.get(), &value, sizeof(value));
190     if (rc == -1) {
191         // This can happen if the server didn't start.
192         PLOG(ERROR) << "write to eventfd failed";
193     } else if (rc != sizeof(value)) {
194         LOG(FATAL) << "write to event returned short (" << rc << ")";
195     }
196 }
197 
setupServer()198 bool PairingServerImpl::setupServer() {
199     epoll_fd_.reset(epoll_create1(EPOLL_CLOEXEC));
200     if (epoll_fd_ == -1) {
201         PLOG(ERROR) << "failed to create epoll fd";
202         return false;
203     }
204 
205     event_fd_.reset(eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK));
206     if (event_fd_ == -1) {
207         PLOG(ERROR) << "failed to create eventfd";
208         return false;
209     }
210 
211     server_fd_.reset(socket_inaddr_any_server(port_, SOCK_STREAM));
212     if (server_fd_.get() == -1) {
213         PLOG(ERROR) << "Failed to start pairing connection server";
214         return false;
215     } else if (fcntl(server_fd_.get(), F_SETFD, FD_CLOEXEC) != 0) {
216         PLOG(ERROR) << "Failed to make server socket cloexec";
217         return false;
218     } else if (fcntl(server_fd_.get(), F_SETFD, O_NONBLOCK) != 0) {
219         PLOG(ERROR) << "Failed to make server socket nonblocking";
220         return false;
221     }
222 
223     startConnectionEventsThread();
224     startServerThread();
225     return true;
226 }
227 
startServerThread()228 void PairingServerImpl::startServerThread() {
229     server_thread_ = std::thread([this]() { serverWorker(); });
230 }
231 
startConnectionEventsThread()232 void PairingServerImpl::startConnectionEventsThread() {
233     conn_events_thread_ = std::thread([this]() { connectionEventsWorker(); });
234 }
235 
serverWorker()236 void PairingServerImpl::serverWorker() {
237     {
238         struct epoll_event event;
239         event.events = EPOLLIN;
240         event.data.u64 = kEpollConstSocket;
241         CHECK_EQ(0, epoll_ctl(epoll_fd_.get(), EPOLL_CTL_ADD, server_fd_.get(), &event));
242     }
243 
244     {
245         struct epoll_event event;
246         event.events = EPOLLIN;
247         event.data.u64 = kEpollConstEventFd;
248         CHECK_EQ(0, epoll_ctl(epoll_fd_.get(), EPOLL_CTL_ADD, event_fd_.get(), &event));
249     }
250 
251     while (true) {
252         struct epoll_event events[2];
253         int rc = TEMP_FAILURE_RETRY(epoll_wait(epoll_fd_.get(), events, 2, -1));
254         if (rc == -1) {
255             PLOG(ERROR) << "epoll_wait failed";
256             return;
257         } else if (rc == 0) {
258             LOG(ERROR) << "epoll_wait returned 0";
259             return;
260         }
261 
262         for (int i = 0; i < rc; ++i) {
263             struct epoll_event& event = events[i];
264             switch (event.data.u64) {
265                 case kEpollConstSocket:
266                     handleNewClientConnection(server_fd_.get());
267                     break;
268                 case kEpollConstEventFd:
269                     uint64_t dummy;
270                     int rc = TEMP_FAILURE_RETRY(read(event_fd_.get(), &dummy, sizeof(dummy)));
271                     if (rc != sizeof(dummy)) {
272                         PLOG(FATAL) << "failed to read from eventfd (rc=" << rc << ")";
273                     }
274                     return;
275             }
276         }
277     }
278 }
279 
connectionEventsWorker()280 void PairingServerImpl::connectionEventsWorker() {
281     for (;;) {
282         // Transfer the write queue to the read queue.
283         {
284             std::unique_lock<std::mutex> lock(conn_mutex_);
285             ScopedLockAssertion assume_locked(conn_mutex_);
286 
287             if (is_terminate_) {
288                 // We check |is_terminate_| twice because condition_variable's
289                 // notify() only wakes up a thread if it is in the wait state
290                 // prior to notify(). Furthermore, we aren't holding the mutex
291                 // when processing the events in |conn_read_queue_|.
292                 return;
293             }
294             if (conn_write_queue_.empty()) {
295                 // We need to wait for new events, or the termination signal.
296                 conn_cv_.wait(lock, [this]() REQUIRES(conn_mutex_) {
297                     return (is_terminate_ || !conn_write_queue_.empty());
298                 });
299             }
300             if (is_terminate_) {
301                 // We're done.
302                 return;
303             }
304             // Move all events into the read queue.
305             conn_read_queue_ = std::move(conn_write_queue_);
306             conn_write_queue_.clear();
307         }
308 
309         // Process all events in the read queue.
310         while (conn_read_queue_.size() > 0) {
311             auto& event = conn_read_queue_.front();
312             if (auto* p = std::get_if<NewConnectionEvent>(&event)) {
313                 // Ignore if we are already at the max number of connections
314                 if (connections_.size() >= internal::kMaxConnections) {
315                     conn_read_queue_.pop_front();
316                     continue;
317                 }
318                 auto [ufd, connection] = std::move(*p);
319                 int fd = ufd.release();
320                 bool started = connection->start(
321                         fd,
322                         [fd](const PeerInfo* peer_info, const Data* cert, void* opaque) {
323                             auto* p = reinterpret_cast<PairingServerImpl*>(opaque);
324 
325                             ConnectionFinishedEvent event;
326                             if (peer_info != nullptr && cert != nullptr) {
327                                 event = std::make_tuple(fd, std::string(peer_info->name),
328                                                         std::string(peer_info->guid), Data(*cert));
329                             } else {
330                                 event = std::make_tuple(fd, std::nullopt, std::nullopt,
331                                                         std::nullopt);
332                             }
333                             {
334                                 std::lock_guard<std::mutex> lock(p->conn_mutex_);
335                                 p->conn_write_queue_.push_back(std::move(event));
336                             }
337                             p->conn_cv_.notify_one();
338                         },
339                         this);
340                 if (!started) {
341                     LOG(ERROR) << "PairingServer unable to start a PairingConnection fd=" << fd;
342                     ufd.reset(fd);
343                 } else {
344                     connections_[fd] = std::move(connection);
345                 }
346             } else if (auto* p = std::get_if<ConnectionFinishedEvent>(&event)) {
347                 auto [fd, name, guid, cert] = std::move(*p);
348                 if (name.has_value() && guid.has_value() && cert.has_value() && !name->empty() &&
349                     !guid->empty() && !cert->empty()) {
350                     // Valid pairing. Let's shutdown the server and close any
351                     // pairing connections in progress.
352                     stopServer();
353                     connections_.clear();
354 
355                     CHECK_LE(name->size(), kPeerNameLength);
356                     CHECK_LE(guid->size(), kPeerGuidLength);
357                     PeerInfo info = {};
358                     strncpy(info.name, name->data(), name->size());
359                     strncpy(info.guid, guid->data(), guid->size());
360 
361                     cb_(&info, &*cert, opaque_);
362 
363                     got_valid_pairing_ = true;
364                     return;
365                 }
366                 // Invalid pairing. Close the invalid connection.
367                 if (connections_.find(fd) != connections_.end()) {
368                     connections_.erase(fd);
369                 }
370             }
371             conn_read_queue_.pop_front();
372         }
373     }
374 }
375 
handleNewClientConnection(int fd)376 bool PairingServerImpl::handleNewClientConnection(int fd) {
377     unique_fd ufd(TEMP_FAILURE_RETRY(accept4(fd, nullptr, nullptr, SOCK_CLOEXEC)));
378     if (ufd == -1) {
379         PLOG(WARNING) << "adb_socket_accept failed fd=" << fd;
380         return false;
381     }
382     auto connection = PairingConnection::create(PairingConnection::Role::Server, pswd_, peer_info_,
383                                                 cert_, priv_key_);
384     if (connection == nullptr) {
385         LOG(ERROR) << "PairingServer unable to create a PairingConnection fd=" << fd;
386         return false;
387     }
388     // send the new connection to the connection thread for further processing
389     NewConnectionEvent event = std::make_tuple(std::move(ufd), std::move(connection));
390     {
391         std::lock_guard<std::mutex> lock(conn_mutex_);
392         conn_write_queue_.push_back(std::move(event));
393     }
394     conn_cv_.notify_one();
395 
396     return true;
397 }
398 
399 }  // namespace
400 
401 // static
create(const Data & pswd,const PeerInfo & peer_info,const Data & cert,const Data & priv_key,int port)402 std::unique_ptr<PairingServer> PairingServer::create(const Data& pswd, const PeerInfo& peer_info,
403                                                      const Data& cert, const Data& priv_key,
404                                                      int port) {
405     if (pswd.empty() || cert.empty() || priv_key.empty() || port <= 0) {
406         return nullptr;
407     }
408     // Make sure peer_info has a non-empty, null-terminated string for guid and
409     // name.
410     if ('\0' != peer_info.name[kPeerNameLength - 1] ||
411         '\0' != peer_info.guid[kPeerGuidLength - 1] || strlen(peer_info.name) == 0 ||
412         strlen(peer_info.guid) == 0) {
413         LOG(ERROR) << "The GUID/short name fields are empty or not null-terminated";
414         return nullptr;
415     }
416 
417     if (port != kDefaultPairingPort) {
418         LOG(WARNING) << "Starting server with non-default pairing port=" << port;
419     }
420 
421     return std::unique_ptr<PairingServer>(
422             new PairingServerImpl(pswd, peer_info, cert, priv_key, port));
423 }
424 
425 }  // namespace pairing
426 }  // namespace adbwifi
427