• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2016 The Android Open Source Project
3  * All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  *  * Redistributions of source code must retain the above copyright
9  *    notice, this list of conditions and the following disclaimer.
10  *  * Redistributions in binary form must reproduce the above copyright
11  *    notice, this list of conditions and the following disclaimer in
12  *    the documentation and/or other materials provided with the
13  *    distribution.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
16  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
17  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
18  * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
19  * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
20  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
21  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS
22  * OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED
23  * AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24  * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT
25  * OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
26  * SUCH DAMAGE.
27  */
28 
29 #include "tcp.h"
30 
31 #include <android-base/parseint.h>
32 #include <android-base/stringprintf.h>
33 
34 namespace tcp {
35 
// Protocol version this host advertises in its "FB%02d" handshake message.
static constexpr int kProtocolVersion{1};
// Size in bytes of the handshake message sent and expected in reply.
static constexpr size_t kHandshakeLength{4};
// Receive timeout (milliseconds, per the name) used while waiting for the handshake reply.
static constexpr int kHandshakeTimeoutMs{2000};
39 
// Decodes the 8-byte big-endian length prefix stored at |buffer| into a host-order 64-bit value.
// |buffer| must point to at least 8 readable bytes.
static uint64_t ExtractMessageLength(const void* buffer) {
    const uint8_t* bytes = static_cast<const uint8_t*>(buffer);
    uint64_t value = 0;
    // Most significant byte comes first, so shift the running value left before merging each byte.
    for (size_t idx = 0; idx < 8; ++idx) {
        value = (value << 8) | bytes[idx];
    }
    return value;
}
48 
// Writes |length| into |buffer| as an 8-byte big-endian value (most significant byte first).
// |buffer| must point to at least 8 writable bytes.
static void EncodeMessageLength(uint64_t length, void* buffer) {
    uint8_t* out = static_cast<uint8_t*>(buffer);
    // Peel off the least significant byte each iteration, filling from the end backwards.
    for (int pos = 7; pos >= 0; --pos) {
        out[pos] = static_cast<uint8_t>(length & 0xFF);
        length >>= 8;
    }
}
55 
56 class TcpTransport : public Transport {
57   public:
58     // Factory function so we can return nullptr if initialization fails.
59     static std::unique_ptr<TcpTransport> NewTransport(std::unique_ptr<Socket> socket,
60                                                       std::string* error);
61 
62     ~TcpTransport() override = default;
63 
64     ssize_t Read(void* data, size_t length) override;
65     ssize_t Write(const void* data, size_t length) override;
66     int Close() override;
67     int Reset() override;
68 
69   private:
TcpTransport(std::unique_ptr<Socket> sock)70     explicit TcpTransport(std::unique_ptr<Socket> sock) : socket_(std::move(sock)) {}
71 
72     // Connects to the device and performs the initial handshake. Returns false and fills |error|
73     // on failure.
74     bool InitializeProtocol(std::string* error);
75 
76     std::unique_ptr<Socket> socket_;
77     uint64_t message_bytes_left_ = 0;
78 
79     DISALLOW_COPY_AND_ASSIGN(TcpTransport);
80 };
81 
NewTransport(std::unique_ptr<Socket> socket,std::string * error)82 std::unique_ptr<TcpTransport> TcpTransport::NewTransport(std::unique_ptr<Socket> socket,
83                                                          std::string* error) {
84     std::unique_ptr<TcpTransport> transport(new TcpTransport(std::move(socket)));
85 
86     if (!transport->InitializeProtocol(error)) {
87         return nullptr;
88     }
89 
90     return transport;
91 }
92 
93 // These error strings are checked in tcp_test.cpp and should be kept in sync.
InitializeProtocol(std::string * error)94 bool TcpTransport::InitializeProtocol(std::string* error) {
95     std::string handshake_message(android::base::StringPrintf("FB%02d", kProtocolVersion));
96 
97     if (!socket_->Send(handshake_message.c_str(), kHandshakeLength)) {
98         *error = android::base::StringPrintf("Failed to send initialization message (%s)",
99                                              Socket::GetErrorMessage().c_str());
100         return false;
101     }
102 
103     char buffer[kHandshakeLength + 1];
104     buffer[kHandshakeLength] = '\0';
105     if (socket_->ReceiveAll(buffer, kHandshakeLength, kHandshakeTimeoutMs) != kHandshakeLength) {
106         *error = android::base::StringPrintf(
107                 "No initialization message received (%s). Target may not support TCP fastboot",
108                 Socket::GetErrorMessage().c_str());
109         return false;
110     }
111 
112     if (memcmp(buffer, "FB", 2) != 0) {
113         *error = "Unrecognized initialization message. Target may not support TCP fastboot";
114         return false;
115     }
116 
117     int version = 0;
118     if (!android::base::ParseInt(buffer + 2, &version) || version < kProtocolVersion) {
119         *error = android::base::StringPrintf("Unknown TCP protocol version %s (host version %02d)",
120                                              buffer + 2, kProtocolVersion);
121         return false;
122     }
123 
124     error->clear();
125     return true;
126 }
127 
Read(void * data,size_t length)128 ssize_t TcpTransport::Read(void* data, size_t length) {
129     if (socket_ == nullptr) {
130         return -1;
131     }
132 
133     // Unless we're mid-message, read the next 8-byte message length.
134     if (message_bytes_left_ == 0) {
135         char buffer[8];
136         if (socket_->ReceiveAll(buffer, 8, 0) != 8) {
137             Close();
138             return -1;
139         }
140         message_bytes_left_ = ExtractMessageLength(buffer);
141     }
142 
143     // Now read the message (up to |length| bytes).
144     if (length > message_bytes_left_) {
145         length = message_bytes_left_;
146     }
147     ssize_t bytes_read = socket_->ReceiveAll(data, length, 0);
148     if (bytes_read == -1) {
149         Close();
150     } else {
151         message_bytes_left_ -= bytes_read;
152     }
153     return bytes_read;
154 }
155 
Write(const void * data,size_t length)156 ssize_t TcpTransport::Write(const void* data, size_t length) {
157     if (socket_ == nullptr) {
158         return -1;
159     }
160 
161     // Use multi-buffer writes for better performance.
162     char header[8];
163     EncodeMessageLength(length, header);
164     if (!socket_->Send(std::vector<cutils_socket_buffer_t>{{header, 8}, {data, length}})) {
165         Close();
166         return -1;
167     }
168 
169     return length;
170 }
171 
Close()172 int TcpTransport::Close() {
173     if (socket_ == nullptr) {
174         return 0;
175     }
176 
177     int result = socket_->Close();
178     socket_.reset();
179     return result;
180 }
181 
Reset()182 int TcpTransport::Reset() {
183     return 0;
184 }
185 
Connect(const std::string & hostname,int port,std::string * error)186 std::unique_ptr<Transport> Connect(const std::string& hostname, int port, std::string* error) {
187     return internal::Connect(Socket::NewClient(Socket::Protocol::kTcp, hostname, port, error),
188                              error);
189 }
190 
191 namespace internal {
192 
Connect(std::unique_ptr<Socket> sock,std::string * error)193 std::unique_ptr<Transport> Connect(std::unique_ptr<Socket> sock, std::string* error) {
194     if (sock == nullptr) {
195         // If Socket creation failed |error| is already set.
196         return nullptr;
197     }
198 
199     return TcpTransport::NewTransport(std::move(sock), error);
200 }
201 
202 }  // namespace internal
203 
204 }  // namespace tcp
205