1 /*
2 * Copyright (C) 2021 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "common/libs/utils/vsock_connection.h"
18
19 #include <sys/socket.h>
20 #include <sys/time.h>
21
22 #include <functional>
23 #include <future>
24 #include <memory>
25 #include <mutex>
26 #include <new>
27 #include <ostream>
28 #include <string>
29 #include <tuple>
30 #include <utility>
31 #include <vector>
32
33 #include <android-base/logging.h>
34 #include <json/json.h>
35
36 #include "common/libs/fs/shared_buf.h"
37 #include "common/libs/fs/shared_select.h"
38
39 namespace cuttlefish {
40
~VsockConnection()41 VsockConnection::~VsockConnection() { Disconnect(); }
42
ConnectAsync(unsigned int port,unsigned int cid)43 std::future<bool> VsockConnection::ConnectAsync(unsigned int port,
44 unsigned int cid) {
45 return std::async(std::launch::async,
46 [this, port, cid]() { return Connect(port, cid); });
47 }
48
Disconnect()49 void VsockConnection::Disconnect() {
50 LOG(INFO) << "Disconnecting with fd status:" << fd_->StrError();
51 fd_->Shutdown(SHUT_RDWR);
52 if (disconnect_callback_) {
53 disconnect_callback_();
54 }
55 fd_->Close();
56 }
57
SetDisconnectCallback(std::function<void ()> callback)58 void VsockConnection::SetDisconnectCallback(std::function<void()> callback) {
59 disconnect_callback_ = callback;
60 }
61
IsConnected() const62 bool VsockConnection::IsConnected() const { return fd_->IsOpen(); }
63
DataAvailable() const64 bool VsockConnection::DataAvailable() const {
65 SharedFDSet read_set;
66 read_set.Set(fd_);
67 struct timeval timeout = {0, 0};
68 return Select(&read_set, nullptr, nullptr, &timeout) > 0;
69 }
70
Read()71 int32_t VsockConnection::Read() {
72 std::lock_guard<std::recursive_mutex> lock(read_mutex_);
73 int32_t result;
74 if (ReadExactBinary(fd_, &result) != sizeof(result)) {
75 Disconnect();
76 return 0;
77 }
78 return result;
79 }
80
Read(std::vector<char> & data)81 bool VsockConnection::Read(std::vector<char>& data) {
82 std::lock_guard<std::recursive_mutex> lock(read_mutex_);
83 return ReadExact(fd_, &data) == data.size();
84 }
85
Read(size_t size)86 std::vector<char> VsockConnection::Read(size_t size) {
87 if (size == 0) {
88 return {};
89 }
90 std::lock_guard<std::recursive_mutex> lock(read_mutex_);
91 std::vector<char> result(size);
92 if (ReadExact(fd_, &result) != size) {
93 Disconnect();
94 return {};
95 }
96 return result;
97 }
98
ReadAsync(size_t size)99 std::future<std::vector<char>> VsockConnection::ReadAsync(size_t size) {
100 return std::async(std::launch::async, [this, size]() { return Read(size); });
101 }
102
103 // Message format is buffer size followed by buffer data
ReadMessage()104 std::vector<char> VsockConnection::ReadMessage() {
105 std::lock_guard<std::recursive_mutex> lock(read_mutex_);
106 auto size = Read();
107 if (size < 0) {
108 Disconnect();
109 return {};
110 }
111 return Read(size);
112 }
113
ReadMessage(std::vector<char> & data)114 bool VsockConnection::ReadMessage(std::vector<char>& data) {
115 std::lock_guard<std::recursive_mutex> lock(read_mutex_);
116 auto size = Read();
117 if (size < 0) {
118 Disconnect();
119 return false;
120 }
121 data.resize(size);
122 return Read(data);
123 }
124
ReadMessageAsync()125 std::future<std::vector<char>> VsockConnection::ReadMessageAsync() {
126 return std::async(std::launch::async, [this]() { return ReadMessage(); });
127 }
128
ReadJsonMessage()129 Json::Value VsockConnection::ReadJsonMessage() {
130 auto msg = ReadMessage();
131 Json::CharReaderBuilder builder;
132 Json::CharReader* reader = builder.newCharReader();
133 Json::Value json_msg;
134 std::string errors;
135 if (!reader->parse(msg.data(), msg.data() + msg.size(), &json_msg, &errors)) {
136 return {};
137 }
138 return json_msg;
139 }
140
ReadJsonMessageAsync()141 std::future<Json::Value> VsockConnection::ReadJsonMessageAsync() {
142 return std::async(std::launch::async, [this]() { return ReadJsonMessage(); });
143 }
144
Write(int32_t data)145 bool VsockConnection::Write(int32_t data) {
146 std::lock_guard<std::recursive_mutex> lock(write_mutex_);
147 if (WriteAllBinary(fd_, &data) != sizeof(data)) {
148 Disconnect();
149 return false;
150 }
151 return true;
152 }
153
Write(const char * data,unsigned int size)154 bool VsockConnection::Write(const char* data, unsigned int size) {
155 std::lock_guard<std::recursive_mutex> lock(write_mutex_);
156 if (WriteAll(fd_, data, size) != size) {
157 Disconnect();
158 return false;
159 }
160 return true;
161 }
162
Write(const std::vector<char> & data)163 bool VsockConnection::Write(const std::vector<char>& data) {
164 return Write(data.data(), data.size());
165 }
166
167 // Message format is buffer size followed by buffer data
WriteMessage(const std::string & data)168 bool VsockConnection::WriteMessage(const std::string& data) {
169 return Write(data.size()) && Write(data.c_str(), data.length());
170 }
171
WriteMessage(const std::vector<char> & data)172 bool VsockConnection::WriteMessage(const std::vector<char>& data) {
173 std::lock_guard<std::recursive_mutex> lock(write_mutex_);
174 return Write(data.size()) && Write(data);
175 }
176
WriteMessage(const Json::Value & data)177 bool VsockConnection::WriteMessage(const Json::Value& data) {
178 Json::StreamWriterBuilder factory;
179 std::string message_str = Json::writeString(factory, data);
180 return WriteMessage(message_str);
181 }
182
WriteStrides(const char * data,unsigned int size,unsigned int num_strides,int stride_size)183 bool VsockConnection::WriteStrides(const char* data, unsigned int size,
184 unsigned int num_strides, int stride_size) {
185 const char* src = data;
186 for (unsigned int i = 0; i < num_strides; ++i, src += stride_size) {
187 if (!Write(src, size)) {
188 return false;
189 }
190 }
191 return true;
192 }
193
Connect(unsigned int port,unsigned int cid)194 bool VsockClientConnection::Connect(unsigned int port, unsigned int cid) {
195 fd_ = SharedFD::VsockClient(cid, port, SOCK_STREAM);
196 if (!fd_->IsOpen()) {
197 LOG(ERROR) << "Failed to connect:" << fd_->StrError();
198 }
199 return fd_->IsOpen();
200 }
201
~VsockServerConnection()202 VsockServerConnection::~VsockServerConnection() { ServerShutdown(); }
203
ServerShutdown()204 void VsockServerConnection::ServerShutdown() {
205 if (server_fd_->IsOpen()) {
206 LOG(INFO) << __FUNCTION__
207 << ": server fd status:" << server_fd_->StrError();
208 server_fd_->Shutdown(SHUT_RDWR);
209 server_fd_->Close();
210 }
211 }
212
Connect(unsigned int port,unsigned int cid)213 bool VsockServerConnection::Connect(unsigned int port, unsigned int cid) {
214 if (!server_fd_->IsOpen()) {
215 server_fd_ = cuttlefish::SharedFD::VsockServer(port, SOCK_STREAM, cid);
216 }
217 if (server_fd_->IsOpen()) {
218 fd_ = SharedFD::Accept(*server_fd_);
219 return fd_->IsOpen();
220 } else {
221 return false;
222 }
223 }
224
225 } // namespace cuttlefish
226