1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #pragma once 17 18 #include <cstdlib> 19 #include <utility> 20 #include <vector> 21 #include <memory> 22 23 #include "common/vsoc/lib/typed_region_view.h" 24 #include "common/vsoc/shm/socket_forward_layout.h" 25 26 namespace vsoc { 27 namespace socket_forward { 28 29 struct Header { 30 std::uint32_t payload_length; 31 enum MessageType : std::uint32_t { 32 DATA = 0, 33 BEGIN, 34 END, 35 RECV_CLOSED, // indicate that this side's receive end is closed 36 RESTART, 37 }; 38 MessageType message_type; 39 }; 40 41 constexpr std::size_t kMaxPayloadSize = 42 layout::socket_forward::kMaxPacketSize - sizeof(Header); 43 44 struct Packet { 45 private: 46 Header header_; 47 using Payload = char[kMaxPayloadSize]; 48 Payload payload_data_; 49 MakePacketPacket50 static constexpr Packet MakePacket(Header::MessageType type) { 51 Packet packet{}; 52 packet.header_.message_type = type; 53 return packet; 54 } 55 56 public: 57 // port is only revelant on the host-side. 58 static Packet MakeBegin(std::uint16_t port); 59 MakeEndPacket60 static constexpr Packet MakeEnd() { return MakePacket(Header::END); } 61 MakeRecvClosedPacket62 static constexpr Packet MakeRecvClosed() { 63 return MakePacket(Header::RECV_CLOSED); 64 } 65 MakeRestartPacket66 static constexpr Packet MakeRestart() { return MakePacket(Header::RESTART); } 67 68 // NOTE payload and payload_length must still be set. MakeDataPacket69 static constexpr Packet MakeData() { return MakePacket(Header::DATA); } 70 emptyPacket71 bool empty() const { return IsData() && header_.payload_length == 0; } 72 set_payload_lengthPacket73 void set_payload_length(std::uint32_t length) { 74 CHECK_LE(length, sizeof payload_data_); 75 header_.payload_length = length; 76 } 77 payloadPacket78 Payload& payload() { return payload_data_; } 79 payloadPacket80 const Payload& payload() const { return payload_data_; } 81 payload_lengthPacket82 constexpr std::uint32_t payload_length() const { 83 return header_.payload_length; 84 } 85 IsBeginPacket86 constexpr bool IsBegin() const { 87 return header_.message_type == Header::BEGIN; 88 } 89 IsEndPacket90 constexpr bool IsEnd() const { return header_.message_type == Header::END; } 91 IsDataPacket92 constexpr bool IsData() const { return header_.message_type == Header::DATA; } 93 IsRecvClosedPacket94 constexpr bool IsRecvClosed() const { 95 return header_.message_type == Header::RECV_CLOSED; 96 } 97 IsRestartPacket98 constexpr bool IsRestart() const { 99 return header_.message_type == Header::RESTART; 100 } 101 portPacket102 constexpr std::uint16_t port() const { 103 CHECK(IsBegin()); 104 std::uint16_t port_number{}; 105 CHECK_EQ(payload_length(), sizeof port_number); 106 std::memcpy(&port_number, payload(), sizeof port_number); 107 return port_number; 108 } 109 raw_dataPacket110 char* raw_data() { return reinterpret_cast<char*>(this); } 111 raw_dataPacket112 const char* raw_data() const { return reinterpret_cast<const char*>(this); } 113 raw_data_lengthPacket114 constexpr size_t raw_data_length() const { 115 return payload_length() + sizeof header_; 116 } 117 }; 118 119 static_assert(sizeof(Packet) == layout::socket_forward::kMaxPacketSize, ""); 120 static_assert(std::is_pod<Packet>{}, ""); 121 122 // Data sent will start with a uint32_t indicating the number of bytes being 123 // sent, followed be the data itself 124 class SocketForwardRegionView 125 : public TypedRegionView<SocketForwardRegionView, 126 layout::socket_forward::SocketForwardLayout> { 127 private: 128 // Returns an empty data packet if the other side is closed. 129 void Recv(int connection_id, Packet* packet); 130 // Returns true on success 131 bool Send(int connection_id, const Packet& packet); 132 133 // skip everything in the connection queue until seeing a BEGIN packet. 134 // returns port from begin packet. 135 int IgnoreUntilBegin(int connection_id); 136 137 public: 138 class ShmSender; 139 class ShmReceiver; 140 141 using ShmSenderReceiverPair = std::pair<ShmSender, ShmReceiver>; 142 143 class ShmConnectionView { 144 public: ShmConnectionView(SocketForwardRegionView * region_view,int connection_id)145 ShmConnectionView(SocketForwardRegionView* region_view, int connection_id) 146 : region_view_{region_view}, connection_id_{connection_id} {} 147 148 #ifdef CUTTLEFISH_HOST 149 ShmSenderReceiverPair EstablishConnection(int port); 150 #else 151 // Should not be called while there is an active ShmSender or ShmReceiver 152 // for this connection. 153 ShmSenderReceiverPair WaitForNewConnection(); 154 #endif 155 port()156 int port() const { return port_; } 157 158 bool Send(const Packet& packet); 159 void Recv(Packet* packet); 160 161 ShmConnectionView(const ShmConnectionView&) = delete; 162 ShmConnectionView& operator=(const ShmConnectionView&) = delete; 163 164 // Moving invalidates all existing ShmSenders and ShmReceiver 165 ShmConnectionView(ShmConnectionView&&) = default; 166 ShmConnectionView& operator=(ShmConnectionView&&) = default; 167 ~ShmConnectionView() = default; 168 169 // NOTE should only be used for debugging/logging purposes. 170 // connection_ids are an implementation detail that are currently useful for 171 // debugging, but may go away in the future. connection_id()172 int connection_id() const { return connection_id_; } 173 174 private: region_view()175 SocketForwardRegionView* region_view() const { return region_view_; } 176 IsOtherSideRecvClosed()177 bool IsOtherSideRecvClosed() { 178 std::lock_guard<std::mutex> guard(*other_side_receive_closed_lock_); 179 return other_side_receive_closed_; 180 } 181 MarkOtherSideRecvClosed()182 void MarkOtherSideRecvClosed() { 183 std::lock_guard<std::mutex> guard(*other_side_receive_closed_lock_); 184 other_side_receive_closed_ = true; 185 } 186 187 void ReceiverThread(); 188 ShmSenderReceiverPair ResetAndConnect(); 189 190 class Receiver { 191 public: Receiver(ShmConnectionView * view)192 Receiver(ShmConnectionView* view) 193 : view_{view} 194 { 195 receiver_thread_ = std::thread([this] { Start(); }); 196 } 197 198 void Recv(Packet* packet); 199 Join()200 void Join() { receiver_thread_.join(); } 201 202 Receiver(const Receiver&) = delete; 203 Receiver& operator=(const Receiver&) = delete; 204 205 ~Receiver() = default; 206 private: 207 void Start(); 208 bool GotRecvClosed() const; 209 void ReceivePacket(); 210 void CheckPacketForRecvClosed(); 211 void CheckPacketForEnd(); 212 void UpdatePacketAndSignalAvailable(); 213 bool ShouldReceiveAnotherPacket() const; 214 bool ExpectMorePackets() const; 215 216 std::mutex receive_thread_data_lock_; 217 std::condition_variable receive_thread_data_cv_; 218 bool received_packet_free_ = true; 219 Packet received_packet_{}; 220 221 ShmConnectionView* view_{}; 222 bool saw_recv_closed_ = false; 223 bool saw_end_ = false; 224 #ifdef CUTTLEFISH_HOST 225 bool saw_data_ = false; 226 #endif 227 228 std::thread receiver_thread_; 229 }; 230 231 SocketForwardRegionView* region_view_{}; 232 int connection_id_ = -1; 233 int port_ = -1; 234 235 std::unique_ptr<std::mutex> other_side_receive_closed_lock_ = 236 std::unique_ptr<std::mutex>{new std::mutex{}}; 237 bool other_side_receive_closed_ = false; 238 239 std::unique_ptr<Receiver> receiver_; 240 }; 241 242 class ShmSender { 243 public: ShmSender(ShmConnectionView * view)244 explicit ShmSender(ShmConnectionView* view) : view_{view} {} 245 246 ShmSender(const ShmSender&) = delete; 247 ShmSender& operator=(const ShmSender&) = delete; 248 249 ShmSender(ShmSender&&) = default; 250 ShmSender& operator=(ShmSender&&) = default; 251 ~ShmSender() = default; 252 253 // Returns true on success 254 bool Send(const Packet& packet); 255 256 private: 257 struct EndSender { operatorEndSender258 void operator()(ShmConnectionView* view) const { 259 if (view) { 260 view->Send(Packet::MakeEnd()); 261 } 262 } 263 }; 264 265 // Doesn't actually own the View, responsible for sending the End 266 // indicator and marking the sending side as disconnected. 267 std::unique_ptr<ShmConnectionView, EndSender> view_; 268 }; 269 270 class ShmReceiver { 271 public: ShmReceiver(ShmConnectionView * view)272 explicit ShmReceiver(ShmConnectionView* view) : view_{view} {} 273 ShmReceiver(const ShmReceiver&) = delete; 274 ShmReceiver& operator=(const ShmReceiver&) = delete; 275 276 ShmReceiver(ShmReceiver&&) = default; 277 ShmReceiver& operator=(ShmReceiver&&) = default; 278 ~ShmReceiver() = default; 279 280 void Recv(Packet* packet); 281 282 private: 283 struct RecvClosedSender { operatorRecvClosedSender284 void operator()(ShmConnectionView* view) const { 285 if (view) { 286 view->Send(Packet::MakeRecvClosed()); 287 } 288 } 289 }; 290 291 // Doesn't actually own the view, responsible for sending the RecvClosed 292 // indicator 293 std::unique_ptr<ShmConnectionView, RecvClosedSender> view_{}; 294 }; 295 296 friend ShmConnectionView; 297 298 SocketForwardRegionView() = default; 299 ~SocketForwardRegionView() = default; 300 SocketForwardRegionView(const SocketForwardRegionView&) = delete; 301 SocketForwardRegionView& operator=(const SocketForwardRegionView&) = delete; 302 303 using ConnectionViewCollection = std::vector<ShmConnectionView>; 304 ConnectionViewCollection AllConnections(); 305 306 int port(int connection_id); 307 void CleanUpPreviousConnections(); 308 309 private: 310 #ifndef CUTTLEFISH_HOST 311 std::uint32_t last_seq_number_{}; 312 #endif 313 }; 314 315 } // namespace socket_forward 316 } // namespace vsoc 317