/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include <condition_variable>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <mutex>
#include <thread>
#include <type_traits>
#include <utility>
#include <vector>

#include "common/vsoc/lib/typed_region_view.h"
#include "common/vsoc/shm/socket_forward_layout.h"
25 
26 namespace vsoc {
27 namespace socket_forward {
28 
// Fixed-size prefix stored at the start of every Packet.
struct Header {
  // Number of payload bytes recorded for the packet this header belongs to.
  std::uint32_t payload_length;
  // Kind of packet. The numeric values are spelled out so that inserting or
  // reordering enumerators cannot silently renumber the existing ones.
  enum MessageType : std::uint32_t {
    DATA = 0,
    BEGIN = 1,
    END = 2,
    RECV_CLOSED = 3,  // indicate that this side's receive end is closed
    RESTART = 4,
  };
  MessageType message_type;
};
40 
41 constexpr std::size_t kMaxPayloadSize =
42     layout::socket_forward::kMaxPacketSize - sizeof(Header);
43 
44 struct Packet {
45  private:
46   Header header_;
47   using Payload = char[kMaxPayloadSize];
48   Payload payload_data_;
49 
MakePacketPacket50   static constexpr Packet MakePacket(Header::MessageType type) {
51     Packet packet{};
52     packet.header_.message_type = type;
53     return packet;
54   }
55 
56  public:
57   // port is only revelant on the host-side.
58   static Packet MakeBegin(std::uint16_t port);
59 
MakeEndPacket60   static constexpr Packet MakeEnd() { return MakePacket(Header::END); }
61 
MakeRecvClosedPacket62   static constexpr Packet MakeRecvClosed() {
63     return MakePacket(Header::RECV_CLOSED);
64   }
65 
MakeRestartPacket66   static constexpr Packet MakeRestart() { return MakePacket(Header::RESTART); }
67 
68   // NOTE payload and payload_length must still be set.
MakeDataPacket69   static constexpr Packet MakeData() { return MakePacket(Header::DATA); }
70 
emptyPacket71   bool empty() const { return IsData() && header_.payload_length == 0; }
72 
set_payload_lengthPacket73   void set_payload_length(std::uint32_t length) {
74     CHECK_LE(length, sizeof payload_data_);
75     header_.payload_length = length;
76   }
77 
payloadPacket78   Payload& payload() { return payload_data_; }
79 
payloadPacket80   const Payload& payload() const { return payload_data_; }
81 
payload_lengthPacket82   constexpr std::uint32_t payload_length() const {
83     return header_.payload_length;
84   }
85 
IsBeginPacket86   constexpr bool IsBegin() const {
87     return header_.message_type == Header::BEGIN;
88   }
89 
IsEndPacket90   constexpr bool IsEnd() const { return header_.message_type == Header::END; }
91 
IsDataPacket92   constexpr bool IsData() const { return header_.message_type == Header::DATA; }
93 
IsRecvClosedPacket94   constexpr bool IsRecvClosed() const {
95     return header_.message_type == Header::RECV_CLOSED;
96   }
97 
IsRestartPacket98   constexpr bool IsRestart() const {
99     return header_.message_type == Header::RESTART;
100   }
101 
portPacket102   constexpr std::uint16_t port() const {
103     CHECK(IsBegin());
104     std::uint16_t port_number{};
105     CHECK_EQ(payload_length(), sizeof port_number);
106     std::memcpy(&port_number, payload(), sizeof port_number);
107     return port_number;
108   }
109 
raw_dataPacket110   char* raw_data() { return reinterpret_cast<char*>(this); }
111 
raw_dataPacket112   const char* raw_data() const { return reinterpret_cast<const char*>(this); }
113 
raw_data_lengthPacket114   constexpr size_t raw_data_length() const {
115     return payload_length() + sizeof header_;
116   }
117 };
118 
119 static_assert(sizeof(Packet) == layout::socket_forward::kMaxPacketSize, "");
120 static_assert(std::is_pod<Packet>{}, "");
121 
122 // Data sent will start with a uint32_t indicating the number of bytes being
123 // sent, followed be the data itself
124 class SocketForwardRegionView
125     : public TypedRegionView<SocketForwardRegionView,
126                              layout::socket_forward::SocketForwardLayout> {
127  private:
128   // Returns an empty data packet if the other side is closed.
129   void Recv(int connection_id, Packet* packet);
130   // Returns true on success
131   bool Send(int connection_id, const Packet& packet);
132 
133   // skip everything in the connection queue until seeing a BEGIN packet.
134   // returns port from begin packet.
135   int IgnoreUntilBegin(int connection_id);
136 
137  public:
138   class ShmSender;
139   class ShmReceiver;
140 
141   using ShmSenderReceiverPair = std::pair<ShmSender, ShmReceiver>;
142 
143   class ShmConnectionView {
144    public:
ShmConnectionView(SocketForwardRegionView * region_view,int connection_id)145     ShmConnectionView(SocketForwardRegionView* region_view, int connection_id)
146         : region_view_{region_view}, connection_id_{connection_id} {}
147 
148 #ifdef CUTTLEFISH_HOST
149     ShmSenderReceiverPair EstablishConnection(int port);
150 #else
151     // Should not be called while there is an active ShmSender or ShmReceiver
152     // for this connection.
153     ShmSenderReceiverPair WaitForNewConnection();
154 #endif
155 
port()156     int port() const { return port_; }
157 
158     bool Send(const Packet& packet);
159     void Recv(Packet* packet);
160 
161     ShmConnectionView(const ShmConnectionView&) = delete;
162     ShmConnectionView& operator=(const ShmConnectionView&) = delete;
163 
164     // Moving invalidates all existing ShmSenders and ShmReceiver
165     ShmConnectionView(ShmConnectionView&&) = default;
166     ShmConnectionView& operator=(ShmConnectionView&&) = default;
167     ~ShmConnectionView() = default;
168 
169     // NOTE should only be used for debugging/logging purposes.
170     // connection_ids are an implementation detail that are currently useful for
171     // debugging, but may go away in the future.
connection_id()172     int connection_id() const { return connection_id_; }
173 
174    private:
region_view()175     SocketForwardRegionView* region_view() const { return region_view_; }
176 
IsOtherSideRecvClosed()177     bool IsOtherSideRecvClosed() {
178       std::lock_guard<std::mutex> guard(*other_side_receive_closed_lock_);
179       return other_side_receive_closed_;
180     }
181 
MarkOtherSideRecvClosed()182     void MarkOtherSideRecvClosed() {
183       std::lock_guard<std::mutex> guard(*other_side_receive_closed_lock_);
184       other_side_receive_closed_ = true;
185     }
186 
187     void ReceiverThread();
188     ShmSenderReceiverPair ResetAndConnect();
189 
190     class Receiver {
191      public:
Receiver(ShmConnectionView * view)192       Receiver(ShmConnectionView* view)
193           : view_{view}
194       {
195         receiver_thread_ = std::thread([this] { Start(); });
196       }
197 
198       void Recv(Packet* packet);
199 
Join()200       void Join() { receiver_thread_.join(); }
201 
202       Receiver(const Receiver&) = delete;
203       Receiver& operator=(const Receiver&) = delete;
204 
205       ~Receiver() = default;
206      private:
207       void Start();
208       bool GotRecvClosed() const;
209       void ReceivePacket();
210       void CheckPacketForRecvClosed();
211       void CheckPacketForEnd();
212       void UpdatePacketAndSignalAvailable();
213       bool ShouldReceiveAnotherPacket() const;
214       bool ExpectMorePackets() const;
215 
216       std::mutex receive_thread_data_lock_;
217       std::condition_variable receive_thread_data_cv_;
218       bool received_packet_free_ = true;
219       Packet received_packet_{};
220 
221       ShmConnectionView* view_{};
222       bool saw_recv_closed_ = false;
223       bool saw_end_ = false;
224 #ifdef CUTTLEFISH_HOST
225       bool saw_data_ = false;
226 #endif
227 
228       std::thread receiver_thread_;
229     };
230 
231     SocketForwardRegionView* region_view_{};
232     int connection_id_ = -1;
233     int port_ = -1;
234 
235     std::unique_ptr<std::mutex> other_side_receive_closed_lock_ =
236         std::unique_ptr<std::mutex>{new std::mutex{}};
237     bool other_side_receive_closed_ = false;
238 
239     std::unique_ptr<Receiver> receiver_;
240   };
241 
242   class ShmSender {
243    public:
ShmSender(ShmConnectionView * view)244     explicit ShmSender(ShmConnectionView* view) : view_{view} {}
245 
246     ShmSender(const ShmSender&) = delete;
247     ShmSender& operator=(const ShmSender&) = delete;
248 
249     ShmSender(ShmSender&&) = default;
250     ShmSender& operator=(ShmSender&&) = default;
251     ~ShmSender() = default;
252 
253     // Returns true on success
254     bool Send(const Packet& packet);
255 
256    private:
257     struct EndSender {
operatorEndSender258       void operator()(ShmConnectionView* view) const {
259         if (view) {
260           view->Send(Packet::MakeEnd());
261         }
262       }
263     };
264 
265     // Doesn't actually own the View, responsible for sending the End
266     // indicator and marking the sending side as disconnected.
267     std::unique_ptr<ShmConnectionView, EndSender> view_;
268   };
269 
270   class ShmReceiver {
271    public:
ShmReceiver(ShmConnectionView * view)272     explicit ShmReceiver(ShmConnectionView* view) : view_{view} {}
273     ShmReceiver(const ShmReceiver&) = delete;
274     ShmReceiver& operator=(const ShmReceiver&) = delete;
275 
276     ShmReceiver(ShmReceiver&&) = default;
277     ShmReceiver& operator=(ShmReceiver&&) = default;
278     ~ShmReceiver() = default;
279 
280     void Recv(Packet* packet);
281 
282    private:
283     struct RecvClosedSender {
operatorRecvClosedSender284       void operator()(ShmConnectionView* view) const {
285         if (view) {
286           view->Send(Packet::MakeRecvClosed());
287         }
288       }
289     };
290 
291     // Doesn't actually own the view, responsible for sending the RecvClosed
292     // indicator
293     std::unique_ptr<ShmConnectionView, RecvClosedSender> view_{};
294   };
295 
296   friend ShmConnectionView;
297 
298   SocketForwardRegionView() = default;
299   ~SocketForwardRegionView() = default;
300   SocketForwardRegionView(const SocketForwardRegionView&) = delete;
301   SocketForwardRegionView& operator=(const SocketForwardRegionView&) = delete;
302 
303   using ConnectionViewCollection = std::vector<ShmConnectionView>;
304   ConnectionViewCollection AllConnections();
305 
306   int port(int connection_id);
307   void CleanUpPreviousConnections();
308 
309  private:
310 #ifndef CUTTLEFISH_HOST
311   std::uint32_t last_seq_number_{};
312 #endif
313 };
314 
315 }  // namespace socket_forward
316 }  // namespace vsoc
317