1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include <cassert>
#include <cstring>

#include "common/vsoc/lib/circqueue_impl.h"
#include "common/vsoc/lib/lock_guard.h"
#include "common/vsoc/lib/socket_forward_region_view.h"
#include "common/vsoc/shm/lock.h"
#include "common/vsoc/shm/socket_forward_layout.h"
24
25 using vsoc::layout::socket_forward::Queue;
26 using vsoc::layout::socket_forward::QueuePair;
27 // store the read and write direction as variables to keep the ifdefs and macros
28 // in later code to a minimum
29 constexpr auto ReadDirection = &QueuePair::
30 #ifdef CUTTLEFISH_HOST
31 guest_to_host;
32 #else
33 host_to_guest;
34 #endif
35
36 constexpr auto WriteDirection = &QueuePair::
37 #ifdef CUTTLEFISH_HOST
38 host_to_guest;
39 #else
40 guest_to_host;
41 #endif
42
43 using vsoc::socket_forward::SocketForwardRegionView;
44
MakeBegin(std::uint16_t port)45 vsoc::socket_forward::Packet vsoc::socket_forward::Packet::MakeBegin(
46 std::uint16_t port) {
47 auto packet = MakePacket(Header::BEGIN);
48 std::memcpy(packet.payload(), &port, sizeof port);
49 packet.set_payload_length(sizeof port);
50 return packet;
51 }
52
Recv(int connection_id,Packet * packet)53 void SocketForwardRegionView::Recv(int connection_id, Packet* packet) {
54 CHECK(packet != nullptr);
55 do {
56 (data()->queues_[connection_id].*ReadDirection)
57 .queue.Read(this, reinterpret_cast<char*>(packet), sizeof *packet);
58 } while (packet->IsBegin());
59 CHECK(!packet->empty()) << "zero-size data message received";
60 CHECK_LE(packet->payload_length(), kMaxPayloadSize) << "invalid size";
61 }
62
Send(int connection_id,const Packet & packet)63 bool SocketForwardRegionView::Send(int connection_id, const Packet& packet) {
64 CHECK(!packet.empty());
65 CHECK_LE(packet.payload_length(), kMaxPayloadSize);
66
67 (data()->queues_[connection_id].*WriteDirection)
68 .queue.Write(this, packet.raw_data(), packet.raw_data_length());
69 return true;
70 }
71
IgnoreUntilBegin(int connection_id)72 int SocketForwardRegionView::IgnoreUntilBegin(int connection_id) {
73 Packet packet{};
74 do {
75 (data()->queues_[connection_id].*ReadDirection)
76 .queue.Read(this, reinterpret_cast<char*>(&packet), sizeof packet);
77 } while (!packet.IsBegin());
78 return packet.port();
79 }
80
81 constexpr int kNumQueues =
82 static_cast<int>(vsoc::layout::socket_forward::kNumQueues);
83
CleanUpPreviousConnections()84 void SocketForwardRegionView::CleanUpPreviousConnections() {
85 data()->Recover();
86
87 static constexpr auto kRestartPacket = Packet::MakeRestart();
88 for (int connection_id = 0; connection_id < kNumQueues; ++connection_id) {
89 Send(connection_id, kRestartPacket);
90 }
91 }
92
93 SocketForwardRegionView::ConnectionViewCollection
AllConnections()94 SocketForwardRegionView::AllConnections() {
95 SocketForwardRegionView::ConnectionViewCollection all_queues;
96 for (int connection_id = 0; connection_id < kNumQueues; ++connection_id) {
97 all_queues.emplace_back(this, connection_id);
98 }
99 return all_queues;
100 }
101
102 // --- Connection ---- //
103
Recv(Packet * packet)104 void SocketForwardRegionView::ShmConnectionView::Receiver::Recv(Packet* packet) {
105 std::unique_lock<std::mutex> guard(receive_thread_data_lock_);
106 while (received_packet_free_) {
107 receive_thread_data_cv_.wait(guard);
108 }
109 CHECK(received_packet_.IsData());
110 *packet = received_packet_;
111 received_packet_free_ = true;
112 receive_thread_data_cv_.notify_one();
113 }
114
GotRecvClosed() const115 bool SocketForwardRegionView::ShmConnectionView::Receiver::GotRecvClosed() const {
116 return received_packet_.IsRecvClosed() || (received_packet_.IsRestart()
117 #ifdef CUTTLEFISH_HOST
118 && saw_data_
119 #endif
120 );
121 }
122
ShouldReceiveAnotherPacket() const123 bool SocketForwardRegionView::ShmConnectionView::Receiver::ShouldReceiveAnotherPacket() const {
124 return (received_packet_.IsRecvClosed() && !saw_end_) ||
125 (saw_end_ && received_packet_.IsEnd())
126 #ifdef CUTTLEFISH_HOST
127 || (received_packet_.IsRestart() && !saw_data_) ||
128 (received_packet_.IsBegin())
129 #endif
130 ;
131 }
132
ReceivePacket()133 void SocketForwardRegionView::ShmConnectionView::Receiver::ReceivePacket() {
134 view_->region_view()->Recv(view_->connection_id(), &received_packet_);
135 }
136
CheckPacketForRecvClosed()137 void SocketForwardRegionView::ShmConnectionView::Receiver::CheckPacketForRecvClosed() {
138 if (GotRecvClosed()) {
139 saw_recv_closed_ = true;
140 view_->MarkOtherSideRecvClosed();
141 }
142 #ifdef CUTTLEFISH_HOST
143 if (received_packet_.IsData()) {
144 saw_data_ = true;
145 }
146 #endif
147 }
148
CheckPacketForEnd()149 void SocketForwardRegionView::ShmConnectionView::Receiver::CheckPacketForEnd() {
150 if (received_packet_.IsEnd() || received_packet_.IsRestart()) {
151 CHECK(!saw_end_ || received_packet_.IsRestart());
152 saw_end_ = true;
153 }
154 }
155
156
ExpectMorePackets() const157 bool SocketForwardRegionView::ShmConnectionView::Receiver::ExpectMorePackets() const {
158 return !saw_recv_closed_ || !saw_end_;
159 }
160
UpdatePacketAndSignalAvailable()161 void SocketForwardRegionView::ShmConnectionView::Receiver::UpdatePacketAndSignalAvailable() {
162 if (!received_packet_.IsData()) {
163 static constexpr auto kEmptyPacket = Packet::MakeData();
164 received_packet_ = kEmptyPacket;
165 }
166 received_packet_free_ = false;
167 receive_thread_data_cv_.notify_one();
168 }
169
Start()170 void SocketForwardRegionView::ShmConnectionView::Receiver::Start() {
171 while (ExpectMorePackets()) {
172 std::unique_lock<std::mutex> guard(receive_thread_data_lock_);
173 while (!received_packet_free_) {
174 receive_thread_data_cv_.wait(guard);
175 }
176
177 do {
178 ReceivePacket();
179 CheckPacketForRecvClosed();
180 } while (ShouldReceiveAnotherPacket());
181
182 if (received_packet_.empty()) {
183 LOG(ERROR) << "Received empty packet.";
184 }
185
186 CheckPacketForEnd();
187
188 UpdatePacketAndSignalAvailable();
189 }
190 }
191
ResetAndConnect()192 auto SocketForwardRegionView::ShmConnectionView::ResetAndConnect()
193 -> ShmSenderReceiverPair {
194 if (receiver_) {
195 receiver_->Join();
196 }
197
198 {
199 std::lock_guard<std::mutex> guard(*other_side_receive_closed_lock_);
200 other_side_receive_closed_ = false;
201 }
202
203 #ifdef CUTTLEFISH_HOST
204 region_view()->IgnoreUntilBegin(connection_id());
205 region_view()->Send(connection_id(), Packet::MakeBegin(port_));
206 #else
207 region_view()->Send(connection_id(), Packet::MakeBegin(port_));
208 port_ =
209 region_view()->IgnoreUntilBegin(connection_id());
210 #endif
211
212 receiver_.reset(new Receiver{this});
213 return {ShmSender{this}, ShmReceiver{this}};
214 }
215
216 #ifdef CUTTLEFISH_HOST
EstablishConnection(int port)217 auto SocketForwardRegionView::ShmConnectionView::EstablishConnection(int port)
218 -> ShmSenderReceiverPair {
219 port_ = port;
220 return ResetAndConnect();
221 }
222 #else
WaitForNewConnection()223 auto SocketForwardRegionView::ShmConnectionView::WaitForNewConnection()
224 -> ShmSenderReceiverPair {
225 port_ = 0;
226 return ResetAndConnect();
227 }
228 #endif
229
Send(const Packet & packet)230 bool SocketForwardRegionView::ShmConnectionView::Send(const Packet& packet) {
231 if (packet.empty()) {
232 LOG(ERROR) << "Sending empty packet";
233 }
234 if (packet.IsData() && IsOtherSideRecvClosed()) {
235 return false;
236 }
237 return region_view()->Send(connection_id(), packet);
238 }
239
Recv(Packet * packet)240 void SocketForwardRegionView::ShmConnectionView::Recv(Packet* packet) {
241 receiver_->Recv(packet);
242 }
243
Recv(Packet * packet)244 void SocketForwardRegionView::ShmReceiver::Recv(Packet* packet) {
245 view_->Recv(packet);
246 }
247
Send(const Packet & packet)248 bool SocketForwardRegionView::ShmSender::Send(const Packet& packet) {
249 return view_->Send(packet);
250 }
251