• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2007 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define TRACE_TAG TRANSPORT
18 
19 #include "sysdeps.h"
20 
21 #include "transport.h"
22 
23 #include <ctype.h>
24 #include <errno.h>
25 #include <inttypes.h>
26 #include <stdio.h>
27 #include <stdlib.h>
28 #include <string.h>
29 #include <unistd.h>
30 
31 #include <algorithm>
32 #include <list>
33 #include <memory>
34 #include <mutex>
35 #include <set>
36 #include <thread>
37 
38 #include <adb/crypto/rsa_2048_key.h>
39 #include <adb/crypto/x509_generator.h>
40 #include <adb/tls/tls_connection.h>
41 #include <android-base/logging.h>
42 #include <android-base/no_destructor.h>
43 #include <android-base/parsenetaddress.h>
44 #include <android-base/stringprintf.h>
45 #include <android-base/strings.h>
46 #include <android-base/thread_annotations.h>
47 
48 #include <diagnose_usb.h>
49 
50 #include "adb.h"
51 #include "adb_auth.h"
52 #include "adb_io.h"
53 #include "adb_trace.h"
54 #include "adb_utils.h"
55 #include "fdevent/fdevent.h"
56 #include "sysdeps/chrono.h"
57 
58 using namespace adb::crypto;
59 using namespace adb::tls;
60 using android::base::ScopedLockAssertion;
61 using TlsError = TlsConnection::TlsError;
62 
63 static void remove_transport(atransport* transport);
64 static void transport_destroy(atransport* transport);
65 
66 // TODO: unordered_map<TransportId, atransport*>
67 static auto& transport_list = *new std::list<atransport*>();
68 static auto& pending_list = *new std::list<atransport*>();
69 
70 static auto& transport_lock = *new std::recursive_mutex();
71 
72 const char* const kFeatureShell2 = "shell_v2";
73 const char* const kFeatureCmd = "cmd";
74 const char* const kFeatureStat2 = "stat_v2";
75 const char* const kFeatureLs2 = "ls_v2";
76 const char* const kFeatureLibusb = "libusb";
77 const char* const kFeaturePushSync = "push_sync";
78 const char* const kFeatureApex = "apex";
79 const char* const kFeatureFixedPushMkdir = "fixed_push_mkdir";
80 const char* const kFeatureAbb = "abb";
81 const char* const kFeatureFixedPushSymlinkTimestamp = "fixed_push_symlink_timestamp";
82 const char* const kFeatureAbbExec = "abb_exec";
83 const char* const kFeatureRemountShell = "remount_shell";
84 const char* const kFeatureTrackApp = "track_app";
85 const char* const kFeatureSendRecv2 = "sendrecv_v2";
86 const char* const kFeatureSendRecv2Brotli = "sendrecv_v2_brotli";
87 const char* const kFeatureSendRecv2LZ4 = "sendrecv_v2_lz4";
88 const char* const kFeatureSendRecv2Zstd = "sendrecv_v2_zstd";
89 const char* const kFeatureSendRecv2DryRunSend = "sendrecv_v2_dry_run_send";
90 // TODO(joshuaduong): Bump to v2 when openscreen discovery is enabled by default
91 const char* const kFeatureOpenscreenMdns = "openscreen_mdns";
92 
93 namespace {
94 
95 #if ADB_HOST
96 // Tracks and handles atransport*s that are attempting reconnection.
97 class ReconnectHandler {
98   public:
99     ReconnectHandler() = default;
100     ~ReconnectHandler() = default;
101 
102     // Starts the ReconnectHandler thread.
103     void Start();
104 
105     // Requests the ReconnectHandler thread to stop.
106     void Stop();
107 
108     // Adds the atransport* to the queue of reconnect attempts.
109     void TrackTransport(atransport* transport);
110 
111     // Wake up the ReconnectHandler thread to have it check for kicked transports.
112     void CheckForKicked();
113 
114   private:
115     // The main thread loop.
116     void Run();
117 
118     // Tracks a reconnection attempt.
119     struct ReconnectAttempt {
120         atransport* transport;
121         std::chrono::steady_clock::time_point reconnect_time;
122         size_t attempts_left;
123 
operator <__anon7cf071d00111::ReconnectHandler::ReconnectAttempt124         bool operator<(const ReconnectAttempt& rhs) const {
125             if (reconnect_time == rhs.reconnect_time) {
126                 return reinterpret_cast<uintptr_t>(transport) <
127                        reinterpret_cast<uintptr_t>(rhs.transport);
128             }
129             return reconnect_time < rhs.reconnect_time;
130         }
131     };
132 
133     // Only retry for up to one minute.
134     static constexpr const std::chrono::seconds kDefaultTimeout = 3s;
135     static constexpr const size_t kMaxAttempts = 20;
136 
137     // Protects all members.
138     std::mutex reconnect_mutex_;
139     bool running_ GUARDED_BY(reconnect_mutex_) = true;
140     std::thread handler_thread_;
141     std::condition_variable reconnect_cv_;
142     std::set<ReconnectAttempt> reconnect_queue_ GUARDED_BY(reconnect_mutex_);
143 
144     DISALLOW_COPY_AND_ASSIGN(ReconnectHandler);
145 };
146 
Start()147 void ReconnectHandler::Start() {
148     check_main_thread();
149     handler_thread_ = std::thread(&ReconnectHandler::Run, this);
150 }
151 
Stop()152 void ReconnectHandler::Stop() {
153     check_main_thread();
154     {
155         std::lock_guard<std::mutex> lock(reconnect_mutex_);
156         running_ = false;
157     }
158     reconnect_cv_.notify_one();
159     handler_thread_.join();
160 
161     // Drain the queue to free all resources.
162     std::lock_guard<std::mutex> lock(reconnect_mutex_);
163     while (!reconnect_queue_.empty()) {
164         ReconnectAttempt attempt = *reconnect_queue_.begin();
165         reconnect_queue_.erase(reconnect_queue_.begin());
166         remove_transport(attempt.transport);
167     }
168 }
169 
TrackTransport(atransport * transport)170 void ReconnectHandler::TrackTransport(atransport* transport) {
171     check_main_thread();
172     {
173         std::lock_guard<std::mutex> lock(reconnect_mutex_);
174         if (!running_) return;
175         // Arbitrary sleep to give adbd time to get ready, if we disconnected because it exited.
176         auto reconnect_time = std::chrono::steady_clock::now() + 250ms;
177         reconnect_queue_.emplace(
178                 ReconnectAttempt{transport, reconnect_time, ReconnectHandler::kMaxAttempts});
179     }
180     reconnect_cv_.notify_one();
181 }
182 
CheckForKicked()183 void ReconnectHandler::CheckForKicked() {
184     reconnect_cv_.notify_one();
185 }
186 
Run()187 void ReconnectHandler::Run() {
188     while (true) {
189         ReconnectAttempt attempt;
190         {
191             std::unique_lock<std::mutex> lock(reconnect_mutex_);
192             ScopedLockAssertion assume_lock(reconnect_mutex_);
193 
194             if (!reconnect_queue_.empty()) {
195                 // FIXME: libstdc++ (used on Windows) implements condition_variable with
196                 //        system_clock as its clock, so we're probably hosed if the clock changes,
197                 //        even if we use steady_clock throughout. This problem goes away once we
198                 //        switch to libc++.
199                 reconnect_cv_.wait_until(lock, reconnect_queue_.begin()->reconnect_time);
200             } else {
201                 reconnect_cv_.wait(lock);
202             }
203 
204             if (!running_) return;
205 
206             // Scan the whole list for kicked transports, so that we immediately handle an explicit
207             // disconnect request.
208             bool kicked = false;
209             for (auto it = reconnect_queue_.begin(); it != reconnect_queue_.end();) {
210                 if (it->transport->kicked()) {
211                     D("transport %s was kicked. giving up on it.", it->transport->serial.c_str());
212                     remove_transport(it->transport);
213                     it = reconnect_queue_.erase(it);
214                 } else {
215                     ++it;
216                 }
217                 kicked = true;
218             }
219 
220             if (reconnect_queue_.empty()) continue;
221 
222             // Go back to sleep if we either woke up spuriously, or we were woken up to remove
223             // a kicked transport, and the first transport isn't ready for reconnection yet.
224             auto now = std::chrono::steady_clock::now();
225             if (reconnect_queue_.begin()->reconnect_time > now) {
226                 continue;
227             }
228 
229             attempt = *reconnect_queue_.begin();
230             reconnect_queue_.erase(reconnect_queue_.begin());
231         }
232         D("attempting to reconnect %s", attempt.transport->serial.c_str());
233 
234         switch (attempt.transport->Reconnect()) {
235             case ReconnectResult::Retry: {
236                 D("attempting to reconnect %s failed.", attempt.transport->serial.c_str());
237                 if (attempt.attempts_left == 0) {
238                     D("transport %s exceeded the number of retry attempts. giving up on it.",
239                       attempt.transport->serial.c_str());
240                     remove_transport(attempt.transport);
241                     continue;
242                 }
243 
244                 std::lock_guard<std::mutex> lock(reconnect_mutex_);
245                 reconnect_queue_.emplace(ReconnectAttempt{
246                         attempt.transport,
247                         std::chrono::steady_clock::now() + ReconnectHandler::kDefaultTimeout,
248                         attempt.attempts_left - 1});
249                 continue;
250             }
251 
252             case ReconnectResult::Success:
253                 D("reconnection to %s succeeded.", attempt.transport->serial.c_str());
254                 register_transport(attempt.transport);
255                 continue;
256 
257             case ReconnectResult::Abort:
258                 D("cancelling reconnection attempt to %s.", attempt.transport->serial.c_str());
259                 remove_transport(attempt.transport);
260                 continue;
261         }
262     }
263 }
264 
265 static auto& reconnect_handler = *new ReconnectHandler();
266 
267 #endif
268 
269 }  // namespace
270 
NextTransportId()271 TransportId NextTransportId() {
272     static std::atomic<TransportId> next(1);
273     return next++;
274 }
275 
Reset()276 void Connection::Reset() {
277     LOG(INFO) << "Connection::Reset(): stopping";
278     Stop();
279 }
280 
BlockingConnectionAdapter(std::unique_ptr<BlockingConnection> connection)281 BlockingConnectionAdapter::BlockingConnectionAdapter(std::unique_ptr<BlockingConnection> connection)
282     : underlying_(std::move(connection)) {}
283 
~BlockingConnectionAdapter()284 BlockingConnectionAdapter::~BlockingConnectionAdapter() {
285     LOG(INFO) << "BlockingConnectionAdapter(" << this->transport_name_ << "): destructing";
286     Stop();
287 }
288 
Start()289 void BlockingConnectionAdapter::Start() {
290     std::lock_guard<std::mutex> lock(mutex_);
291     if (started_) {
292         LOG(FATAL) << "BlockingConnectionAdapter(" << this->transport_name_
293                    << "): started multiple times";
294     }
295 
296     StartReadThread();
297 
298     write_thread_ = std::thread([this]() {
299         LOG(INFO) << this->transport_name_ << ": write thread spawning";
300         while (true) {
301             std::unique_lock<std::mutex> lock(mutex_);
302             ScopedLockAssertion assume_locked(mutex_);
303             cv_.wait(lock, [this]() REQUIRES(mutex_) {
304                 return this->stopped_ || !this->write_queue_.empty();
305             });
306 
307             if (this->stopped_) {
308                 return;
309             }
310 
311             std::unique_ptr<apacket> packet = std::move(this->write_queue_.front());
312             this->write_queue_.pop_front();
313             lock.unlock();
314 
315             if (!this->underlying_->Write(packet.get())) {
316                 break;
317             }
318         }
319         std::call_once(this->error_flag_, [this]() { this->error_callback_(this, "write failed"); });
320     });
321 
322     started_ = true;
323 }
324 
StartReadThread()325 void BlockingConnectionAdapter::StartReadThread() {
326     read_thread_ = std::thread([this]() {
327         LOG(INFO) << this->transport_name_ << ": read thread spawning";
328         while (true) {
329             auto packet = std::make_unique<apacket>();
330             if (!underlying_->Read(packet.get())) {
331                 PLOG(INFO) << this->transport_name_ << ": read failed";
332                 break;
333             }
334 
335             bool got_stls_cmd = false;
336             if (packet->msg.command == A_STLS) {
337                 got_stls_cmd = true;
338             }
339 
340             read_callback_(this, std::move(packet));
341 
342             // If we received the STLS packet, we are about to perform the TLS
343             // handshake. So this read thread must stop and resume after the
344             // handshake completes otherwise this will interfere in the process.
345             if (got_stls_cmd) {
346                 LOG(INFO) << this->transport_name_
347                           << ": Received STLS packet. Stopping read thread.";
348                 return;
349             }
350         }
351         std::call_once(this->error_flag_, [this]() { this->error_callback_(this, "read failed"); });
352     });
353 }
354 
DoTlsHandshake(RSA * key,std::string * auth_key)355 bool BlockingConnectionAdapter::DoTlsHandshake(RSA* key, std::string* auth_key) {
356     std::lock_guard<std::mutex> lock(mutex_);
357     if (read_thread_.joinable()) {
358         read_thread_.join();
359     }
360     bool success = this->underlying_->DoTlsHandshake(key, auth_key);
361     StartReadThread();
362     return success;
363 }
364 
Reset()365 void BlockingConnectionAdapter::Reset() {
366     {
367         std::lock_guard<std::mutex> lock(mutex_);
368         if (!started_) {
369             LOG(INFO) << "BlockingConnectionAdapter(" << this->transport_name_ << "): not started";
370             return;
371         }
372 
373         if (stopped_) {
374             LOG(INFO) << "BlockingConnectionAdapter(" << this->transport_name_
375                       << "): already stopped";
376             return;
377         }
378     }
379 
380     LOG(INFO) << "BlockingConnectionAdapter(" << this->transport_name_ << "): resetting";
381     this->underlying_->Reset();
382     Stop();
383 }
384 
Stop()385 void BlockingConnectionAdapter::Stop() {
386     {
387         std::lock_guard<std::mutex> lock(mutex_);
388         if (!started_) {
389             LOG(INFO) << "BlockingConnectionAdapter(" << this->transport_name_ << "): not started";
390             return;
391         }
392 
393         if (stopped_) {
394             LOG(INFO) << "BlockingConnectionAdapter(" << this->transport_name_
395                       << "): already stopped";
396             return;
397         }
398 
399         stopped_ = true;
400     }
401 
402     LOG(INFO) << "BlockingConnectionAdapter(" << this->transport_name_ << "): stopping";
403 
404     this->underlying_->Close();
405     this->cv_.notify_one();
406 
407     // Move the threads out into locals with the lock taken, and then unlock to let them exit.
408     std::thread read_thread;
409     std::thread write_thread;
410 
411     {
412         std::lock_guard<std::mutex> lock(mutex_);
413         read_thread = std::move(read_thread_);
414         write_thread = std::move(write_thread_);
415     }
416 
417     read_thread.join();
418     write_thread.join();
419 
420     LOG(INFO) << "BlockingConnectionAdapter(" << this->transport_name_ << "): stopped";
421     std::call_once(this->error_flag_, [this]() { this->error_callback_(this, "requested stop"); });
422 }
423 
Write(std::unique_ptr<apacket> packet)424 bool BlockingConnectionAdapter::Write(std::unique_ptr<apacket> packet) {
425     {
426         std::lock_guard<std::mutex> lock(this->mutex_);
427         write_queue_.emplace_back(std::move(packet));
428     }
429 
430     cv_.notify_one();
431     return true;
432 }
433 
FdConnection(unique_fd fd)434 FdConnection::FdConnection(unique_fd fd) : fd_(std::move(fd)) {}
435 
~FdConnection()436 FdConnection::~FdConnection() {}
437 
DispatchRead(void * buf,size_t len)438 bool FdConnection::DispatchRead(void* buf, size_t len) {
439     if (tls_ != nullptr) {
440         // The TlsConnection doesn't allow 0 byte reads
441         if (len == 0) {
442             return true;
443         }
444         return tls_->ReadFully(buf, len);
445     }
446 
447     return ReadFdExactly(fd_.get(), buf, len);
448 }
449 
DispatchWrite(void * buf,size_t len)450 bool FdConnection::DispatchWrite(void* buf, size_t len) {
451     if (tls_ != nullptr) {
452         // The TlsConnection doesn't allow 0 byte writes
453         if (len == 0) {
454             return true;
455         }
456         return tls_->WriteFully(std::string_view(reinterpret_cast<const char*>(buf), len));
457     }
458 
459     return WriteFdExactly(fd_.get(), buf, len);
460 }
461 
Read(apacket * packet)462 bool FdConnection::Read(apacket* packet) {
463     if (!DispatchRead(&packet->msg, sizeof(amessage))) {
464         D("remote local: read terminated (message)");
465         return false;
466     }
467 
468     if (packet->msg.data_length > MAX_PAYLOAD) {
469         D("remote local: read overflow (data length = %" PRIu32 ")", packet->msg.data_length);
470         return false;
471     }
472 
473     packet->payload.resize(packet->msg.data_length);
474 
475     if (!DispatchRead(&packet->payload[0], packet->payload.size())) {
476         D("remote local: terminated (data)");
477         return false;
478     }
479 
480     return true;
481 }
482 
Write(apacket * packet)483 bool FdConnection::Write(apacket* packet) {
484     if (!DispatchWrite(&packet->msg, sizeof(packet->msg))) {
485         D("remote local: write terminated");
486         return false;
487     }
488 
489     if (packet->msg.data_length) {
490         if (!DispatchWrite(&packet->payload[0], packet->msg.data_length)) {
491             D("remote local: write terminated");
492             return false;
493         }
494     }
495 
496     return true;
497 }
498 
DoTlsHandshake(RSA * key,std::string * auth_key)499 bool FdConnection::DoTlsHandshake(RSA* key, std::string* auth_key) {
500     bssl::UniquePtr<EVP_PKEY> evp_pkey(EVP_PKEY_new());
501     if (!EVP_PKEY_set1_RSA(evp_pkey.get(), key)) {
502         LOG(ERROR) << "EVP_PKEY_set1_RSA failed";
503         return false;
504     }
505     auto x509 = GenerateX509Certificate(evp_pkey.get());
506     auto x509_str = X509ToPEMString(x509.get());
507     auto evp_str = Key::ToPEMString(evp_pkey.get());
508 
509     int osh = cast_handle_to_int(adb_get_os_handle(fd_));
510 #if ADB_HOST
511     tls_ = TlsConnection::Create(TlsConnection::Role::Client, x509_str, evp_str, osh);
512 #else
513     tls_ = TlsConnection::Create(TlsConnection::Role::Server, x509_str, evp_str, osh);
514 #endif
515     CHECK(tls_);
516 #if ADB_HOST
517     // TLS 1.3 gives the client no message if the server rejected the
518     // certificate. This will enable a check in the tls connection to check
519     // whether the client certificate got rejected. Note that this assumes
520     // that, on handshake success, the server speaks first.
521     tls_->EnableClientPostHandshakeCheck(true);
522     // Add callback to set the certificate when server issues the
523     // CertificateRequest.
524     tls_->SetCertificateCallback(adb_tls_set_certificate);
525     // Allow any server certificate
526     tls_->SetCertVerifyCallback([](X509_STORE_CTX*) { return 1; });
527 #else
528     // Add callback to check certificate against a list of known public keys
529     tls_->SetCertVerifyCallback(
530             [auth_key](X509_STORE_CTX* ctx) { return adbd_tls_verify_cert(ctx, auth_key); });
531     // Add the list of allowed client CA issuers
532     auto ca_list = adbd_tls_client_ca_list();
533     tls_->SetClientCAList(ca_list.get());
534 #endif
535 
536     auto err = tls_->DoHandshake();
537     if (err == TlsError::Success) {
538         return true;
539     }
540 
541     tls_.reset();
542     return false;
543 }
544 
Close()545 void FdConnection::Close() {
546     adb_shutdown(fd_.get());
547     fd_.reset();
548 }
549 
send_packet(apacket * p,atransport * t)550 void send_packet(apacket* p, atransport* t) {
551     p->msg.magic = p->msg.command ^ 0xffffffff;
552     // compute a checksum for connection/auth packets for compatibility reasons
553     if (t->get_protocol_version() >= A_VERSION_SKIP_CHECKSUM) {
554         p->msg.data_check = 0;
555     } else {
556         p->msg.data_check = calculate_apacket_checksum(p);
557     }
558 
559     VLOG(TRANSPORT) << dump_packet(t->serial.c_str(), "to remote", p);
560 
561     if (t == nullptr) {
562         LOG(FATAL) << "Transport is null";
563     }
564 
565     if (t->Write(p) != 0) {
566         D("%s: failed to enqueue packet, closing transport", t->serial.c_str());
567         t->Kick();
568     }
569 }
570 
kick_transport(atransport * t,bool reset)571 void kick_transport(atransport* t, bool reset) {
572     std::lock_guard<std::recursive_mutex> lock(transport_lock);
573     // As kick_transport() can be called from threads without guarantee that t is valid,
574     // check if the transport is in transport_list first.
575     //
576     // TODO(jmgao): WTF? Is this actually true?
577     if (std::find(transport_list.begin(), transport_list.end(), t) != transport_list.end()) {
578         if (reset) {
579             t->Reset();
580         } else {
581             t->Kick();
582         }
583     }
584 
585 #if ADB_HOST
586     reconnect_handler.CheckForKicked();
587 #endif
588 }
589 
590 static int transport_registration_send = -1;
591 static int transport_registration_recv = -1;
592 static fdevent* transport_registration_fde;
593 
594 #if ADB_HOST
595 
596 /* this adds support required by the 'track-devices' service.
597  * this is used to send the content of "list_transport" to any
598  * number of client connections that want it through a single
599  * live TCP connection
600  */
601 struct device_tracker {
602     asocket socket;
603     bool update_needed = false;
604     bool long_output = false;
605     device_tracker* next = nullptr;
606 };
607 
608 /* linked list of all device trackers */
609 static device_tracker* device_tracker_list;
610 
device_tracker_remove(device_tracker * tracker)611 static void device_tracker_remove(device_tracker* tracker) {
612     device_tracker** pnode = &device_tracker_list;
613     device_tracker* node = *pnode;
614 
615     std::lock_guard<std::recursive_mutex> lock(transport_lock);
616     while (node) {
617         if (node == tracker) {
618             *pnode = node->next;
619             break;
620         }
621         pnode = &node->next;
622         node = *pnode;
623     }
624 }
625 
device_tracker_close(asocket * socket)626 static void device_tracker_close(asocket* socket) {
627     device_tracker* tracker = (device_tracker*)socket;
628     asocket* peer = socket->peer;
629 
630     D("device tracker %p removed", tracker);
631     if (peer) {
632         peer->peer = nullptr;
633         peer->close(peer);
634     }
635     device_tracker_remove(tracker);
636     delete tracker;
637 }
638 
device_tracker_enqueue(asocket * socket,apacket::payload_type)639 static int device_tracker_enqueue(asocket* socket, apacket::payload_type) {
640     /* you can't read from a device tracker, close immediately */
641     device_tracker_close(socket);
642     return -1;
643 }
644 
device_tracker_send(device_tracker * tracker,const std::string & string)645 static int device_tracker_send(device_tracker* tracker, const std::string& string) {
646     asocket* peer = tracker->socket.peer;
647 
648     apacket::payload_type data;
649     data.resize(4 + string.size());
650     char buf[5];
651     snprintf(buf, sizeof(buf), "%04x", static_cast<int>(string.size()));
652     memcpy(&data[0], buf, 4);
653     memcpy(&data[4], string.data(), string.size());
654     return peer->enqueue(peer, std::move(data));
655 }
656 
device_tracker_ready(asocket * socket)657 static void device_tracker_ready(asocket* socket) {
658     device_tracker* tracker = reinterpret_cast<device_tracker*>(socket);
659 
660     // We want to send the device list when the tracker connects
661     // for the first time, even if no update occurred.
662     if (tracker->update_needed) {
663         tracker->update_needed = false;
664         device_tracker_send(tracker, list_transports(tracker->long_output));
665     }
666 }
667 
create_device_tracker(bool long_output)668 asocket* create_device_tracker(bool long_output) {
669     device_tracker* tracker = new device_tracker();
670     if (tracker == nullptr) LOG(FATAL) << "cannot allocate device tracker";
671 
672     D("device tracker %p created", tracker);
673 
674     tracker->socket.enqueue = device_tracker_enqueue;
675     tracker->socket.ready = device_tracker_ready;
676     tracker->socket.close = device_tracker_close;
677     tracker->update_needed = true;
678     tracker->long_output = long_output;
679 
680     tracker->next = device_tracker_list;
681     device_tracker_list = tracker;
682 
683     return &tracker->socket;
684 }
685 
686 // Check if all of the USB transports are connected.
iterate_transports(std::function<bool (const atransport *)> fn)687 bool iterate_transports(std::function<bool(const atransport*)> fn) {
688     std::lock_guard<std::recursive_mutex> lock(transport_lock);
689     for (const auto& t : transport_list) {
690         if (!fn(t)) {
691             return false;
692         }
693     }
694     for (const auto& t : pending_list) {
695         if (!fn(t)) {
696             return false;
697         }
698     }
699     return true;
700 }
701 
702 // Call this function each time the transport list has changed.
update_transports()703 void update_transports() {
704     update_transport_status();
705 
706     // Notify `adb track-devices` clients.
707     device_tracker* tracker = device_tracker_list;
708     while (tracker != nullptr) {
709         device_tracker* next = tracker->next;
710         // This may destroy the tracker if the connection is closed.
711         device_tracker_send(tracker, list_transports(tracker->long_output));
712         tracker = next;
713     }
714 }
715 
716 #else
717 
update_transports()718 void update_transports() {
719     // Nothing to do on the device side.
720 }
721 
722 #endif  // ADB_HOST
723 
724 struct tmsg {
725     atransport* transport;
726     int action;
727 };
728 
transport_read_action(int fd,struct tmsg * m)729 static int transport_read_action(int fd, struct tmsg* m) {
730     char* p = (char*)m;
731     int len = sizeof(*m);
732     int r;
733 
734     while (len > 0) {
735         r = adb_read(fd, p, len);
736         if (r > 0) {
737             len -= r;
738             p += r;
739         } else {
740             D("transport_read_action: on fd %d: %s", fd, strerror(errno));
741             return -1;
742         }
743     }
744     return 0;
745 }
746 
transport_write_action(int fd,struct tmsg * m)747 static int transport_write_action(int fd, struct tmsg* m) {
748     char* p = (char*)m;
749     int len = sizeof(*m);
750     int r;
751 
752     while (len > 0) {
753         r = adb_write(fd, p, len);
754         if (r > 0) {
755             len -= r;
756             p += r;
757         } else {
758             D("transport_write_action: on fd %d: %s", fd, strerror(errno));
759             return -1;
760         }
761     }
762     return 0;
763 }
764 
transport_registration_func(int _fd,unsigned ev,void *)765 static void transport_registration_func(int _fd, unsigned ev, void*) {
766     tmsg m;
767     atransport* t;
768 
769     if (!(ev & FDE_READ)) {
770         return;
771     }
772 
773     if (transport_read_action(_fd, &m)) {
774         PLOG(FATAL) << "cannot read transport registration socket";
775     }
776 
777     t = m.transport;
778 
779     if (m.action == 0) {
780         D("transport: %s deleting", t->serial.c_str());
781 
782         {
783             std::lock_guard<std::recursive_mutex> lock(transport_lock);
784             transport_list.remove(t);
785         }
786 
787         delete t;
788 
789         update_transports();
790         return;
791     }
792 
793     /* don't create transport threads for inaccessible devices */
794     if (t->GetConnectionState() != kCsNoPerm) {
795         // The connection gets a reference to the atransport. It will release it
796         // upon a read/write error.
797         t->connection()->SetTransportName(t->serial_name());
798         t->connection()->SetReadCallback([t](Connection*, std::unique_ptr<apacket> p) {
799             if (!check_header(p.get(), t)) {
800                 D("%s: remote read: bad header", t->serial.c_str());
801                 return false;
802             }
803 
804             VLOG(TRANSPORT) << dump_packet(t->serial.c_str(), "from remote", p.get());
805             apacket* packet = p.release();
806 
807             // TODO: Does this need to run on the main thread?
808             fdevent_run_on_main_thread([packet, t]() { handle_packet(packet, t); });
809             return true;
810         });
811         t->connection()->SetErrorCallback([t](Connection*, const std::string& error) {
812             LOG(INFO) << t->serial_name() << ": connection terminated: " << error;
813             fdevent_run_on_main_thread([t]() {
814                 handle_offline(t);
815                 transport_destroy(t);
816             });
817         });
818 
819         t->connection()->Start();
820 #if ADB_HOST
821         send_connect(t);
822 #endif
823     }
824 
825     {
826         std::lock_guard<std::recursive_mutex> lock(transport_lock);
827         auto it = std::find(pending_list.begin(), pending_list.end(), t);
828         if (it != pending_list.end()) {
829             pending_list.remove(t);
830             transport_list.push_front(t);
831         }
832     }
833 
834     update_transports();
835 }
836 
837 #if ADB_HOST
init_reconnect_handler(void)838 void init_reconnect_handler(void) {
839     reconnect_handler.Start();
840 }
841 #endif
842 
init_transport_registration(void)843 void init_transport_registration(void) {
844     int s[2];
845 
846     if (adb_socketpair(s)) {
847         PLOG(FATAL) << "cannot open transport registration socketpair";
848     }
849     D("socketpair: (%d,%d)", s[0], s[1]);
850 
851     transport_registration_send = s[0];
852     transport_registration_recv = s[1];
853 
854     transport_registration_fde =
855         fdevent_create(transport_registration_recv, transport_registration_func, nullptr);
856     fdevent_set(transport_registration_fde, FDE_READ);
857 }
858 
kick_all_transports()859 void kick_all_transports() {
860 #if ADB_HOST
861     reconnect_handler.Stop();
862 #endif
863     // To avoid only writing part of a packet to a transport after exit, kick all transports.
864     std::lock_guard<std::recursive_mutex> lock(transport_lock);
865     for (auto t : transport_list) {
866         t->Kick();
867     }
868 }
869 
kick_all_tcp_tls_transports()870 void kick_all_tcp_tls_transports() {
871     std::lock_guard<std::recursive_mutex> lock(transport_lock);
872     for (auto t : transport_list) {
873         if (t->IsTcpDevice() && t->use_tls) {
874             t->Kick();
875         }
876     }
877 }
878 
879 #if !ADB_HOST
kick_all_transports_by_auth_key(std::string_view auth_key)880 void kick_all_transports_by_auth_key(std::string_view auth_key) {
881     std::lock_guard<std::recursive_mutex> lock(transport_lock);
882     for (auto t : transport_list) {
883         if (auth_key == t->auth_key) {
884             t->Kick();
885         }
886     }
887 }
888 #endif
889 
890 /* the fdevent select pump is single threaded */
register_transport(atransport * transport)891 void register_transport(atransport* transport) {
892     tmsg m;
893     m.transport = transport;
894     m.action = 1;
895     D("transport: %s registered", transport->serial.c_str());
896     if (transport_write_action(transport_registration_send, &m)) {
897         PLOG(FATAL) << "cannot write transport registration socket";
898     }
899 }
900 
remove_transport(atransport * transport)901 static void remove_transport(atransport* transport) {
902     tmsg m;
903     m.transport = transport;
904     m.action = 0;
905     D("transport: %s removed", transport->serial.c_str());
906     if (transport_write_action(transport_registration_send, &m)) {
907         PLOG(FATAL) << "cannot write transport registration socket";
908     }
909 }
910 
transport_destroy(atransport * t)911 static void transport_destroy(atransport* t) {
912     check_main_thread();
913     CHECK(t != nullptr);
914 
915     std::lock_guard<std::recursive_mutex> lock(transport_lock);
916     LOG(INFO) << "destroying transport " << t->serial_name();
917     t->connection()->Stop();
918 #if ADB_HOST
919     if (t->IsTcpDevice() && !t->kicked()) {
920         D("transport: %s destroy (attempting reconnection)", t->serial.c_str());
921 
922         // We need to clear the transport's keys, so that on the next connection, it tries
923         // again from the beginning.
924         t->ResetKeys();
925         reconnect_handler.TrackTransport(t);
926         return;
927     }
928 #endif
929 
930     D("transport: %s destroy (kicking and closing)", t->serial.c_str());
931     remove_transport(t);
932 }
933 
934 #if ADB_HOST
// Returns non-zero if |to_test| is exactly |prefix| (may be null) followed by |qual|.
// When |sanitize_qual| is set, every non-alphanumeric character of |qual| is compared as '_'.
// Special cases:
//   - an empty |to_test| matches only an empty |qual|;
//   - an empty |qual| never matches a non-empty |to_test|.
static int qual_match(const std::string& to_test, const char* prefix, const std::string& qual,
                      bool sanitize_qual) {
    if (to_test.empty()) return qual.empty();
    if (qual.empty()) return 0;

    size_t pos = 0;
    if (prefix) {
        const size_t prefix_len = strlen(prefix);
        if (to_test.compare(0, prefix_len, prefix) != 0) return 0;
        pos = prefix_len;
    }

    // Whatever follows the prefix must be exactly as long as |qual|. Checking the length up
    // front means we never index past the end of |to_test|, even if |qual| has embedded NULs.
    if (to_test.size() - pos != qual.size()) return 0;

    for (size_t i = 0; i < qual.size(); ++i) {
        char ch = qual[i];
        // <ctype.h> classifiers require a value representable as unsigned char; a negative
        // plain char (e.g. a UTF-8 byte >= 0x80) would be undefined behavior.
        if (sanitize_qual && !isalnum(static_cast<unsigned char>(ch))) ch = '_';
        if (ch != to_test[pos + i]) return 0;
    }
    return 1;
}
957 
acquire_one_transport(TransportType type,const char * serial,TransportId transport_id,bool * is_ambiguous,std::string * error_out,bool accept_any_state)958 atransport* acquire_one_transport(TransportType type, const char* serial, TransportId transport_id,
959                                   bool* is_ambiguous, std::string* error_out,
960                                   bool accept_any_state) {
961     atransport* result = nullptr;
962 
963     if (transport_id != 0) {
964         *error_out =
965             android::base::StringPrintf("no device with transport id '%" PRIu64 "'", transport_id);
966     } else if (serial) {
967         *error_out = android::base::StringPrintf("device '%s' not found", serial);
968     } else if (type == kTransportLocal) {
969         *error_out = "no emulators found";
970     } else if (type == kTransportAny) {
971         *error_out = "no devices/emulators found";
972     } else {
973         *error_out = "no devices found";
974     }
975 
976     std::unique_lock<std::recursive_mutex> lock(transport_lock);
977     for (const auto& t : transport_list) {
978         if (t->GetConnectionState() == kCsNoPerm) {
979             *error_out = UsbNoPermissionsLongHelpText();
980             continue;
981         }
982 
983         if (transport_id) {
984             if (t->id == transport_id) {
985                 result = t;
986                 break;
987             }
988         } else if (serial) {
989             if (t->MatchesTarget(serial)) {
990                 if (result) {
991                     *error_out = "more than one device";
992                     if (is_ambiguous) *is_ambiguous = true;
993                     result = nullptr;
994                     break;
995                 }
996                 result = t;
997             }
998         } else {
999             if (type == kTransportUsb && t->type == kTransportUsb) {
1000                 if (result) {
1001                     *error_out = "more than one device";
1002                     if (is_ambiguous) *is_ambiguous = true;
1003                     result = nullptr;
1004                     break;
1005                 }
1006                 result = t;
1007             } else if (type == kTransportLocal && t->type == kTransportLocal) {
1008                 if (result) {
1009                     *error_out = "more than one emulator";
1010                     if (is_ambiguous) *is_ambiguous = true;
1011                     result = nullptr;
1012                     break;
1013                 }
1014                 result = t;
1015             } else if (type == kTransportAny) {
1016                 if (result) {
1017                     *error_out = "more than one device/emulator";
1018                     if (is_ambiguous) *is_ambiguous = true;
1019                     result = nullptr;
1020                     break;
1021                 }
1022                 result = t;
1023             }
1024         }
1025     }
1026     lock.unlock();
1027 
1028     if (result && !accept_any_state) {
1029         // The caller requires an active transport.
1030         // Make sure that we're actually connected.
1031         ConnectionState state = result->GetConnectionState();
1032         switch (state) {
1033             case kCsConnecting:
1034                 *error_out = "device still connecting";
1035                 result = nullptr;
1036                 break;
1037 
1038             case kCsAuthorizing:
1039                 *error_out = "device still authorizing";
1040                 result = nullptr;
1041                 break;
1042 
1043             case kCsUnauthorized: {
1044                 *error_out = "device unauthorized.\n";
1045                 char* ADB_VENDOR_KEYS = getenv("ADB_VENDOR_KEYS");
1046                 *error_out += "This adb server's $ADB_VENDOR_KEYS is ";
1047                 *error_out += ADB_VENDOR_KEYS ? ADB_VENDOR_KEYS : "not set";
1048                 *error_out += "\n";
1049                 *error_out += "Try 'adb kill-server' if that seems wrong.\n";
1050                 *error_out += "Otherwise check for a confirmation dialog on your device.";
1051                 result = nullptr;
1052                 break;
1053             }
1054 
1055             case kCsOffline:
1056                 *error_out = "device offline";
1057                 result = nullptr;
1058                 break;
1059 
1060             default:
1061                 break;
1062         }
1063     }
1064 
1065     if (result) {
1066         *error_out = "success";
1067     }
1068 
1069     return result;
1070 }
1071 
WaitForConnection(std::chrono::milliseconds timeout)1072 bool ConnectionWaitable::WaitForConnection(std::chrono::milliseconds timeout) {
1073     std::unique_lock<std::mutex> lock(mutex_);
1074     ScopedLockAssertion assume_locked(mutex_);
1075     return cv_.wait_for(lock, timeout, [&]() REQUIRES(mutex_) {
1076         return connection_established_ready_;
1077     }) && connection_established_;
1078 }
1079 
SetConnectionEstablished(bool success)1080 void ConnectionWaitable::SetConnectionEstablished(bool success) {
1081     {
1082         std::lock_guard<std::mutex> lock(mutex_);
1083         if (connection_established_ready_) return;
1084         connection_established_ready_ = true;
1085         connection_established_ = success;
1086         D("connection established with %d", success);
1087     }
1088     cv_.notify_one();
1089 }
1090 #endif
1091 
~atransport()1092 atransport::~atransport() {
1093 #if ADB_HOST
1094     // If the connection callback had not been run before, run it now.
1095     SetConnectionEstablished(false);
1096 #endif
1097 }
1098 
Write(apacket * p)1099 int atransport::Write(apacket* p) {
1100     return this->connection()->Write(std::unique_ptr<apacket>(p)) ? 0 : -1;
1101 }
1102 
Reset()1103 void atransport::Reset() {
1104     if (!kicked_.exchange(true)) {
1105         LOG(INFO) << "resetting transport " << this << " " << this->serial;
1106         this->connection()->Reset();
1107     }
1108 }
1109 
Kick()1110 void atransport::Kick() {
1111     if (!kicked_.exchange(true)) {
1112         LOG(INFO) << "kicking transport " << this << " " << this->serial;
1113         this->connection()->Stop();
1114     }
1115 }
1116 
GetConnectionState() const1117 ConnectionState atransport::GetConnectionState() const {
1118     return connection_state_;
1119 }
1120 
SetConnectionState(ConnectionState state)1121 void atransport::SetConnectionState(ConnectionState state) {
1122     check_main_thread();
1123     connection_state_ = state;
1124     update_transports();
1125 }
1126 
SetConnection(std::shared_ptr<Connection> connection)1127 void atransport::SetConnection(std::shared_ptr<Connection> connection) {
1128     std::lock_guard<std::mutex> lock(mutex_);
1129     connection_ = std::shared_ptr<Connection>(std::move(connection));
1130 }
1131 
connection_state_name() const1132 std::string atransport::connection_state_name() const {
1133     ConnectionState state = GetConnectionState();
1134     switch (state) {
1135         case kCsOffline:
1136             return "offline";
1137         case kCsBootloader:
1138             return "bootloader";
1139         case kCsDevice:
1140             return "device";
1141         case kCsHost:
1142             return "host";
1143         case kCsRecovery:
1144             return "recovery";
1145         case kCsRescue:
1146             return "rescue";
1147         case kCsNoPerm:
1148             return UsbNoPermissionsShortHelpText();
1149         case kCsSideload:
1150             return "sideload";
1151         case kCsUnauthorized:
1152             return "unauthorized";
1153         case kCsAuthorizing:
1154             return "authorizing";
1155         case kCsConnecting:
1156             return "connecting";
1157         default:
1158             return "unknown";
1159     }
1160 }
1161 
update_version(int version,size_t payload)1162 void atransport::update_version(int version, size_t payload) {
1163     protocol_version = std::min(version, A_VERSION);
1164     max_payload = std::min(payload, MAX_PAYLOAD);
1165 }
1166 
get_protocol_version() const1167 int atransport::get_protocol_version() const {
1168     return protocol_version;
1169 }
1170 
get_tls_version() const1171 int atransport::get_tls_version() const {
1172     return tls_version;
1173 }
1174 
get_max_payload() const1175 size_t atransport::get_max_payload() const {
1176     return max_payload;
1177 }
1178 
supported_features()1179 const FeatureSet& supported_features() {
1180     static const android::base::NoDestructor<FeatureSet> features([] {
1181         return FeatureSet{
1182                 kFeatureShell2,
1183                 kFeatureCmd,
1184                 kFeatureStat2,
1185                 kFeatureLs2,
1186                 kFeatureFixedPushMkdir,
1187                 kFeatureApex,
1188                 kFeatureAbb,
1189                 kFeatureFixedPushSymlinkTimestamp,
1190                 kFeatureAbbExec,
1191                 kFeatureRemountShell,
1192                 kFeatureTrackApp,
1193                 kFeatureSendRecv2,
1194                 kFeatureSendRecv2Brotli,
1195                 kFeatureSendRecv2LZ4,
1196                 kFeatureSendRecv2Zstd,
1197                 kFeatureSendRecv2DryRunSend,
1198                 kFeatureOpenscreenMdns,
1199                 // Increment ADB_SERVER_VERSION when adding a feature that adbd needs
1200                 // to know about. Otherwise, the client can be stuck running an old
1201                 // version of the server even after upgrading their copy of adb.
1202                 // (http://b/24370690)
1203         };
1204     }());
1205 
1206     return *features;
1207 }
1208 
FeatureSetToString(const FeatureSet & features)1209 std::string FeatureSetToString(const FeatureSet& features) {
1210     return android::base::Join(features, ',');
1211 }
1212 
StringToFeatureSet(const std::string & features_string)1213 FeatureSet StringToFeatureSet(const std::string& features_string) {
1214     if (features_string.empty()) {
1215         return FeatureSet();
1216     }
1217 
1218     return android::base::Split(features_string, ",");
1219 }
1220 
// Returns true if any element of range |r| compares equal (==) to |v|. Linear scan.
template <class Range, class Value>
static bool contains(const Range& r, const Value& v) {
    return std::any_of(std::begin(r), std::end(r),
                       [&v](const auto& element) { return element == v; });
}
1225 
CanUseFeature(const FeatureSet & feature_set,const std::string & feature)1226 bool CanUseFeature(const FeatureSet& feature_set, const std::string& feature) {
1227     return contains(feature_set, feature) && contains(supported_features(), feature);
1228 }
1229 
has_feature(const std::string & feature) const1230 bool atransport::has_feature(const std::string& feature) const {
1231     return contains(features_, feature);
1232 }
1233 
SetFeatures(const std::string & features_string)1234 void atransport::SetFeatures(const std::string& features_string) {
1235     features_ = StringToFeatureSet(features_string);
1236 }
1237 
AddDisconnect(adisconnect * disconnect)1238 void atransport::AddDisconnect(adisconnect* disconnect) {
1239     disconnects_.push_back(disconnect);
1240 }
1241 
RemoveDisconnect(adisconnect * disconnect)1242 void atransport::RemoveDisconnect(adisconnect* disconnect) {
1243     disconnects_.remove(disconnect);
1244 }
1245 
RunDisconnects()1246 void atransport::RunDisconnects() {
1247     for (const auto& disconnect : disconnects_) {
1248         disconnect->func(disconnect->opaque, this);
1249     }
1250     disconnects_.clear();
1251 }
1252 
1253 #if ADB_HOST
MatchesTarget(const std::string & target) const1254 bool atransport::MatchesTarget(const std::string& target) const {
1255     if (!serial.empty()) {
1256         if (target == serial) {
1257             return true;
1258         } else if (type == kTransportLocal) {
1259             // Local transports can match [tcp:|udp:]<hostname>[:port].
1260             const char* local_target_ptr = target.c_str();
1261 
1262             // For fastboot compatibility, ignore protocol prefixes.
1263             if (android::base::StartsWith(target, "tcp:") ||
1264                 android::base::StartsWith(target, "udp:")) {
1265                 local_target_ptr += 4;
1266             }
1267 
1268             // Parse our |serial| and the given |target| to check if the hostnames and ports match.
1269             std::string serial_host, error;
1270             int serial_port = -1;
1271             if (android::base::ParseNetAddress(serial, &serial_host, &serial_port, nullptr, &error)) {
1272                 // |target| may omit the port to default to ours.
1273                 std::string target_host;
1274                 int target_port = serial_port;
1275                 if (android::base::ParseNetAddress(local_target_ptr, &target_host, &target_port,
1276                                                    nullptr, &error) &&
1277                     serial_host == target_host && serial_port == target_port) {
1278                     return true;
1279                 }
1280             }
1281         }
1282     }
1283 
1284     return (target == devpath) || qual_match(target, "product:", product, false) ||
1285            qual_match(target, "model:", model, true) ||
1286            qual_match(target, "device:", device, false);
1287 }
1288 
SetConnectionEstablished(bool success)1289 void atransport::SetConnectionEstablished(bool success) {
1290     connection_waitable_->SetConnectionEstablished(success);
1291 }
1292 
Reconnect()1293 ReconnectResult atransport::Reconnect() {
1294     return reconnect_(this);
1295 }
1296 
// We use newline as our delimiter, make sure to never output it.
// In |alphanumeric| mode every character that is not [A-Za-z0-9] becomes '_'; otherwise only
// '\n' is replaced with '_'.
static std::string sanitize(std::string str, bool alphanumeric) {
    std::replace_if(
            str.begin(), str.end(),
            [alphanumeric](char c) {
                // <ctype.h> classifiers are only defined for unsigned char values (and EOF);
                // a negative plain char such as a UTF-8 byte >= 0x80 would be undefined behavior.
                return alphanumeric ? !isalnum(static_cast<unsigned char>(c)) : c == '\n';
            },
            '_');
    return str;
}
1304 
append_transport_info(std::string * result,const char * key,const std::string & value,bool alphanumeric)1305 static void append_transport_info(std::string* result, const char* key, const std::string& value,
1306                                   bool alphanumeric) {
1307     if (value.empty()) {
1308         return;
1309     }
1310 
1311     *result += ' ';
1312     *result += key;
1313     *result += sanitize(value, alphanumeric);
1314 }
1315 
append_transport(const atransport * t,std::string * result,bool long_listing)1316 static void append_transport(const atransport* t, std::string* result, bool long_listing) {
1317     std::string serial = t->serial;
1318     if (serial.empty()) {
1319         serial = "(no serial number)";
1320     }
1321 
1322     if (!long_listing) {
1323         *result += serial;
1324         *result += '\t';
1325         *result += t->connection_state_name();
1326     } else {
1327         android::base::StringAppendF(result, "%-22s %s", serial.c_str(),
1328                                      t->connection_state_name().c_str());
1329 
1330         append_transport_info(result, "", t->devpath, false);
1331         append_transport_info(result, "product:", t->product, false);
1332         append_transport_info(result, "model:", t->model, true);
1333         append_transport_info(result, "device:", t->device, false);
1334 
1335         // Put id at the end, so that anyone parsing the output here can always find it by scanning
1336         // backwards from newlines, even with hypothetical devices named 'transport_id:1'.
1337         *result += " transport_id:";
1338         *result += std::to_string(t->id);
1339     }
1340     *result += '\n';
1341 }
1342 
list_transports(bool long_listing)1343 std::string list_transports(bool long_listing) {
1344     std::lock_guard<std::recursive_mutex> lock(transport_lock);
1345 
1346     auto sorted_transport_list = transport_list;
1347     sorted_transport_list.sort([](atransport*& x, atransport*& y) {
1348         if (x->type != y->type) {
1349             return x->type < y->type;
1350         }
1351         return x->serial < y->serial;
1352     });
1353 
1354     std::string result;
1355     for (const auto& t : sorted_transport_list) {
1356         append_transport(t, &result, long_listing);
1357     }
1358     return result;
1359 }
1360 
close_usb_devices(std::function<bool (const atransport *)> predicate,bool reset)1361 void close_usb_devices(std::function<bool(const atransport*)> predicate, bool reset) {
1362     std::lock_guard<std::recursive_mutex> lock(transport_lock);
1363     for (auto& t : transport_list) {
1364         if (predicate(t)) {
1365             if (reset) {
1366                 t->Reset();
1367             } else {
1368                 t->Kick();
1369             }
1370         }
1371     }
1372 }
1373 
1374 /* hack for osx */
close_usb_devices(bool reset)1375 void close_usb_devices(bool reset) {
1376     close_usb_devices([](const atransport*) { return true; }, reset);
1377 }
1378 #endif
1379 
register_socket_transport(unique_fd s,std::string serial,int port,int local,atransport::ReconnectCallback reconnect,bool use_tls,int * error)1380 bool register_socket_transport(unique_fd s, std::string serial, int port, int local,
1381                                atransport::ReconnectCallback reconnect, bool use_tls, int* error) {
1382     atransport* t = new atransport(std::move(reconnect), kCsOffline);
1383     t->use_tls = use_tls;
1384 
1385     D("transport: %s init'ing for socket %d, on port %d", serial.c_str(), s.get(), port);
1386     if (init_socket_transport(t, std::move(s), port, local) < 0) {
1387         delete t;
1388         if (error) *error = errno;
1389         return false;
1390     }
1391 
1392     std::unique_lock<std::recursive_mutex> lock(transport_lock);
1393     for (const auto& transport : pending_list) {
1394         if (serial == transport->serial) {
1395             VLOG(TRANSPORT) << "socket transport " << transport->serial
1396                             << " is already in pending_list and fails to register";
1397             delete t;
1398             if (error) *error = EALREADY;
1399             return false;
1400         }
1401     }
1402 
1403     for (const auto& transport : transport_list) {
1404         if (serial == transport->serial) {
1405             VLOG(TRANSPORT) << "socket transport " << transport->serial
1406                             << " is already in transport_list and fails to register";
1407             delete t;
1408             if (error) *error = EALREADY;
1409             return false;
1410         }
1411     }
1412 
1413     t->serial = std::move(serial);
1414     pending_list.push_front(t);
1415 
1416     lock.unlock();
1417 
1418 #if ADB_HOST
1419     auto waitable = t->connection_waitable();
1420 #endif
1421     register_transport(t);
1422 
1423     if (local == 1) {
1424         // Do not wait for emulator transports.
1425         return true;
1426     }
1427 
1428 #if ADB_HOST
1429     if (!waitable->WaitForConnection(std::chrono::seconds(10))) {
1430         if (error) *error = ETIMEDOUT;
1431         return false;
1432     }
1433 
1434     if (t->GetConnectionState() == kCsUnauthorized) {
1435         if (error) *error = EPERM;
1436         return false;
1437     }
1438 #endif
1439 
1440     return true;
1441 }
1442 
1443 #if ADB_HOST
find_transport(const char * serial)1444 atransport* find_transport(const char* serial) {
1445     atransport* result = nullptr;
1446 
1447     std::lock_guard<std::recursive_mutex> lock(transport_lock);
1448     for (auto& t : transport_list) {
1449         if (strcmp(serial, t->serial.c_str()) == 0) {
1450             result = t;
1451             break;
1452         }
1453     }
1454 
1455     return result;
1456 }
1457 
kick_all_tcp_devices()1458 void kick_all_tcp_devices() {
1459     std::lock_guard<std::recursive_mutex> lock(transport_lock);
1460     for (auto& t : transport_list) {
1461         if (t->IsTcpDevice()) {
1462             // Kicking breaks the read_transport thread of this transport out of any read, then
1463             // the read_transport thread will notify the main thread to make this transport
1464             // offline. Then the main thread will notify the write_transport thread to exit.
1465             // Finally, this transport will be closed and freed in the main thread.
1466             t->Kick();
1467         }
1468     }
1469     reconnect_handler.CheckForKicked();
1470 }
1471 
register_usb_transport(std::shared_ptr<Connection> connection,const char * serial,const char * devpath,unsigned writeable)1472 void register_usb_transport(std::shared_ptr<Connection> connection, const char* serial,
1473                             const char* devpath, unsigned writeable) {
1474     atransport* t = new atransport(writeable ? kCsOffline : kCsNoPerm);
1475     if (serial) {
1476         t->serial = serial;
1477     }
1478     if (devpath) {
1479         t->devpath = devpath;
1480     }
1481 
1482     t->SetConnection(std::move(connection));
1483     t->type = kTransportUsb;
1484 
1485     {
1486         std::lock_guard<std::recursive_mutex> lock(transport_lock);
1487         pending_list.push_front(t);
1488     }
1489 
1490     register_transport(t);
1491 }
1492 
register_usb_transport(usb_handle * usb,const char * serial,const char * devpath,unsigned writeable)1493 void register_usb_transport(usb_handle* usb, const char* serial, const char* devpath,
1494                             unsigned writeable) {
1495     atransport* t = new atransport(writeable ? kCsOffline : kCsNoPerm);
1496 
1497     D("transport: %p init'ing for usb_handle %p (sn='%s')", t, usb, serial ? serial : "");
1498     init_usb_transport(t, usb);
1499     if (serial) {
1500         t->serial = serial;
1501     }
1502 
1503     if (devpath) {
1504         t->devpath = devpath;
1505     }
1506 
1507     {
1508         std::lock_guard<std::recursive_mutex> lock(transport_lock);
1509         pending_list.push_front(t);
1510     }
1511 
1512     register_transport(t);
1513 }
1514 
1515 // This should only be used for transports with connection_state == kCsNoPerm.
unregister_usb_transport(usb_handle * usb)1516 void unregister_usb_transport(usb_handle* usb) {
1517     std::lock_guard<std::recursive_mutex> lock(transport_lock);
1518     transport_list.remove_if([usb](atransport* t) {
1519         return t->GetUsbHandle() == usb && t->GetConnectionState() == kCsNoPerm;
1520     });
1521 }
1522 #endif
1523 
check_header(apacket * p,atransport * t)1524 bool check_header(apacket* p, atransport* t) {
1525     if (p->msg.magic != (p->msg.command ^ 0xffffffff)) {
1526         VLOG(RWX) << "check_header(): invalid magic command = " << std::hex << p->msg.command
1527                   << ", magic = " << p->msg.magic;
1528         return false;
1529     }
1530 
1531     if (p->msg.data_length > t->get_max_payload()) {
1532         VLOG(RWX) << "check_header(): " << p->msg.data_length
1533                   << " atransport::max_payload = " << t->get_max_payload();
1534         return false;
1535     }
1536 
1537     return true;
1538 }
1539 
1540 #if ADB_HOST
Key()1541 std::shared_ptr<RSA> atransport::Key() {
1542     if (keys_.empty()) {
1543         return nullptr;
1544     }
1545 
1546     std::shared_ptr<RSA> result = keys_[0];
1547     return result;
1548 }
1549 
NextKey()1550 std::shared_ptr<RSA> atransport::NextKey() {
1551     if (keys_.empty()) {
1552         LOG(INFO) << "fetching keys for transport " << this->serial_name();
1553         keys_ = adb_auth_get_private_keys();
1554 
1555         // We should have gotten at least one key: the one that's automatically generated.
1556         CHECK(!keys_.empty());
1557     } else {
1558         keys_.pop_front();
1559     }
1560 
1561     return Key();
1562 }
1563 
ResetKeys()1564 void atransport::ResetKeys() {
1565     keys_.clear();
1566 }
1567 #endif
1568