• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #include <set>
18 #include <thread>
19 #include <glog/logging.h>
20 #include <gflags/gflags.h>
21 
22 #include "common/libs/fs/shared_fd.h"
23 #include "common/vsoc/lib/socket_forward_region_view.h"
24 
25 #ifdef CUTTLEFISH_HOST
26 #include "host/libs/config/cuttlefish_config.h"
27 #endif
28 
29 using vsoc::socket_forward::Packet;
30 
31 DEFINE_uint32(tcp_port, 0, "TCP port (server on host, client on guest)");
32 DEFINE_uint32(vsock_port, 0, "vsock port (client on host, server on guest");
33 DEFINE_uint32(vsock_guest_cid, 0, "Guest identifier");
34 
35 namespace {
36 // Sends packets, Shutdown(SHUT_WR) on destruction
37 class SocketSender {
38  public:
SocketSender(cvd::SharedFD socket)39   explicit SocketSender(cvd::SharedFD socket) : socket_{socket} {}
40 
41   SocketSender(SocketSender&&) = default;
42   SocketSender& operator=(SocketSender&&) = default;
43 
44   SocketSender(const SocketSender&&) = delete;
45   SocketSender& operator=(const SocketSender&) = delete;
46 
~SocketSender()47   ~SocketSender() {
48     if (socket_.operator->()) {  // check that socket_ was not moved-from
49       socket_->Shutdown(SHUT_WR);
50     }
51   }
52 
SendAll(const Packet & packet)53   ssize_t SendAll(const Packet& packet) {
54     ssize_t written{};
55     while (written < static_cast<ssize_t>(packet.payload_length())) {
56       if (!socket_->IsOpen()) {
57         return -1;
58       }
59       auto just_written =
60           socket_->Send(packet.payload() + written,
61                         packet.payload_length() - written, MSG_NOSIGNAL);
62       if (just_written <= 0) {
63         LOG(INFO) << "Couldn't write to client: "
64                   << strerror(socket_->GetErrno());
65         return just_written;
66       }
67       written += just_written;
68     }
69     return written;
70   }
71 
72  private:
73   cvd::SharedFD socket_;
74 };
75 
76 class SocketReceiver {
77  public:
SocketReceiver(cvd::SharedFD socket)78   explicit SocketReceiver(cvd::SharedFD socket) : socket_{socket} {}
79 
80   SocketReceiver(SocketReceiver&&) = default;
81   SocketReceiver& operator=(SocketReceiver&&) = default;
82 
83   SocketReceiver(const SocketReceiver&&) = delete;
84   SocketReceiver& operator=(const SocketReceiver&) = delete;
85 
86   // *packet will be empty if Read returns 0 or error
Recv(Packet * packet)87   void Recv(Packet* packet) {
88     auto size = socket_->Read(packet->payload(), sizeof packet->payload());
89     if (size < 0) {
90       size = 0;
91     }
92     packet->set_payload_length(size);
93   }
94 
95  private:
96   cvd::SharedFD socket_;
97 };
98 
SocketToVsock(SocketReceiver socket_receiver,SocketSender vsock_sender)99 void SocketToVsock(SocketReceiver socket_receiver,
100                    SocketSender vsock_sender) {
101   while (true) {
102     auto packet = Packet::MakeData();
103     socket_receiver.Recv(&packet);
104     if (packet.empty() || vsock_sender.SendAll(packet) < 0) {
105       break;
106     }
107   }
108   LOG(INFO) << "Socket to vsock exiting";
109 }
110 
VsockToSocket(SocketSender socket_sender,SocketReceiver vsock_receiver)111 void VsockToSocket(SocketSender socket_sender,
112                    SocketReceiver vsock_receiver) {
113   auto packet = Packet::MakeData();
114   while (true) {
115     vsock_receiver.Recv(&packet);
116     CHECK(packet.IsData());
117     if (packet.empty()) {
118       break;
119     }
120     if (socket_sender.SendAll(packet) < 0) {
121       break;
122     }
123   }
124   LOG(INFO) << "Vsock to socket exiting";
125 }
126 
127 // One thread for reading from shm and writing into a socket.
128 // One thread for reading from a socket and writing into shm.
HandleConnection(cvd::SharedFD vsock,cvd::SharedFD socket)129 void HandleConnection(cvd::SharedFD vsock,
130                       cvd::SharedFD socket) {
131   auto socket_to_vsock =
132       std::thread(SocketToVsock, SocketReceiver{socket}, SocketSender{vsock});
133   VsockToSocket(SocketSender{socket}, SocketReceiver{vsock});
134   socket_to_vsock.join();
135 }
136 
137 #ifdef CUTTLEFISH_HOST
host()138 [[noreturn]] void host() {
139   LOG(INFO) << "starting server on " << FLAGS_tcp_port << " for vsock port "
140             << FLAGS_vsock_port;
141   auto server = cvd::SharedFD::SocketLocalServer(FLAGS_tcp_port, SOCK_STREAM);
142   CHECK(server->IsOpen()) << "Could not start server on " << FLAGS_tcp_port;
143   LOG(INFO) << "Accepting client connections";
144   int last_failure_reason = 0;
145   while (true) {
146     auto client_socket = cvd::SharedFD::Accept(*server);
147     CHECK(client_socket->IsOpen()) << "error creating client socket";
148     cvd::SharedFD vsock_socket = cvd::SharedFD::VsockClient(
149         FLAGS_vsock_guest_cid, FLAGS_vsock_port, SOCK_STREAM);
150     if (vsock_socket->IsOpen()) {
151       last_failure_reason = 0;
152       LOG(INFO) << "Connected to vsock:" << FLAGS_vsock_guest_cid << ":"
153                 << FLAGS_vsock_port;
154     } else {
155       // Don't log if the previous connection failed with the same error
156       if (last_failure_reason != vsock_socket->GetErrno()) {
157         last_failure_reason = vsock_socket->GetErrno();
158         LOG(ERROR) << "Unable to connect to vsock server: "
159                    << vsock_socket->StrError();
160       }
161       continue;
162     }
163     auto thread = std::thread(HandleConnection, std::move(vsock_socket),
164                               std::move(client_socket));
165     thread.detach();
166   }
167 }
168 
169 #else
OpenSocketConnection()170 cvd::SharedFD OpenSocketConnection() {
171   while (true) {
172     auto sock = cvd::SharedFD::SocketLocalClient(FLAGS_tcp_port, SOCK_STREAM);
173     if (sock->IsOpen()) {
174       return sock;
175     }
176     LOG(WARNING) << "could not connect on port " << FLAGS_tcp_port
177                  << ". sleeping for 1 second";
178     sleep(1);
179   }
180 }
181 
// Returns false for the errno values this program treats as unrecoverable
// (EACCES, EAFNOSUPPORT, EINVAL, EPROTONOSUPPORT) and true for any other
// value.
bool socketErrorIsRecoverable(int error) {
  // A switch over the fixed constants avoids building a container on every
  // call.
  switch (error) {
    case EACCES:
    case EAFNOSUPPORT:
    case EINVAL:
    case EPROTONOSUPPORT:
      return false;
    default:
      return true;
  }
}
186 
// Blocks the calling thread indefinitely; never returns. sleep() is called
// again each time it returns, always with the largest unsigned int value.
[[noreturn]] static void SleepForever() {
  for (;;) {
    sleep(std::numeric_limits<unsigned int>::max());
  }
}
192 
guest()193 [[noreturn]] void guest() {
194   LOG(INFO) << "Starting guest mainloop";
195   LOG(INFO) << "starting server on " << FLAGS_vsock_port;
196   cvd::SharedFD vsock;
197   do {
198     vsock = cvd::SharedFD::VsockServer(FLAGS_vsock_port, SOCK_STREAM);
199     if (!vsock->IsOpen() && !socketErrorIsRecoverable(vsock->GetErrno())) {
200       LOG(ERROR) << "Could not open vsock socket: " << vsock->StrError();
201       SleepForever();
202     }
203   } while (!vsock->IsOpen());
204   CHECK(vsock->IsOpen()) << "Could not start server on " << FLAGS_vsock_port;
205   while (true) {
206     LOG(INFO) << "waiting for vsock connection";
207     auto vsock_client = cvd::SharedFD::Accept(*vsock);
208     CHECK(vsock_client->IsOpen()) << "error creating vsock socket";
209     LOG(INFO) << "vsock socket accepted";
210     auto client = OpenSocketConnection();
211     CHECK(client->IsOpen()) << "error connecting to guest client";
212     auto thread = std::thread(HandleConnection, std::move(vsock_client),
213                               std::move(client));
214     thread.detach();
215   }
216 }
217 
218 #endif
219 }  // namespace
220 
main(int argc,char * argv[])221 int main(int argc, char* argv[]) {
222   gflags::ParseCommandLineFlags(&argc, &argv, true);
223 
224   CHECK(FLAGS_tcp_port != 0) << "Must specify -tcp_port flag";
225   CHECK(FLAGS_vsock_port != 0) << "Must specify -vsock_port flag";
226 #ifdef CUTTLEFISH_HOST
227   CHECK(FLAGS_vsock_guest_cid != 0) << "Must specify -vsock_guest_cid flag";
228   host();
229 #else
230   guest();
231 #endif
232 }
233