• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2023 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 
15 #include "pw_rpc_transport/socket_rpc_transport.h"
16 
17 #include <algorithm>
18 #include <random>
19 
20 #include "pw_bytes/span.h"
21 #include "pw_log/log.h"
22 #include "pw_rpc_transport/socket_rpc_transport.h"
23 #include "pw_span/span.h"
24 #include "pw_status/status.h"
25 #include "pw_stream/socket_stream.h"
26 #include "pw_sync/thread_notification.h"
27 #include "pw_thread/thread.h"
28 #include "pw_thread_stl/options.h"
29 #include "pw_unit_test/framework.h"
30 
31 namespace pw::rpc {
32 namespace {
33 
34 using namespace std::chrono_literals;
35 
// Upper bound, in bytes, on the size of a single frame generated by the test
// senders below.
constexpr size_t kMaxWriteSize{64};
// Template argument used for every SocketRpcTransport instantiated in this
// file.
constexpr size_t kReadBufferSize{64};
// Port 0 lets the kernel pick the port number; tests read the actual port back
// from the server transport once it is ready.
constexpr uint16_t kServerPort{0};
40 
41 class TestIngress : public RpcIngressHandler {
42  public:
TestIngress(size_t num_bytes_expected)43   explicit TestIngress(size_t num_bytes_expected)
44       : num_bytes_expected_(num_bytes_expected) {}
45 
ProcessIncomingData(ConstByteSpan buffer)46   Status ProcessIncomingData(ConstByteSpan buffer) override {
47     if (num_bytes_expected_ > 0) {
48       std::copy(buffer.begin(), buffer.end(), std::back_inserter(received_));
49       num_bytes_expected_ -= std::min(num_bytes_expected_, buffer.size());
50     }
51     if (num_bytes_expected_ == 0) {
52       done_.release();
53     }
54     return OkStatus();
55   }
56 
received() const57   std::vector<std::byte> received() const { return received_; }
Wait()58   void Wait() { done_.acquire(); }
59 
60  private:
61   size_t num_bytes_expected_ = 0;
62   sync::ThreadNotification done_;
63   std::vector<std::byte> received_;
64 };
65 
66 class SocketSender {
67  public:
SocketSender(SocketRpcTransport<kReadBufferSize> & transport)68   SocketSender(SocketRpcTransport<kReadBufferSize>& transport)
69       : transport_(transport) {
70     unsigned char c = 0;
71     for (auto& i : data_) {
72       i = std::byte{c++};
73     }
74     std::mt19937 rg{0x12345678};
75     std::shuffle(data_.begin(), data_.end(), rg);
76   }
77 
sent()78   std::vector<std::byte> sent() { return sent_; }
79 
MakeFrame(size_t max_size)80   RpcFrame MakeFrame(size_t max_size) {
81     std::mt19937 rg{0x12345678};
82     size_t offset = offset_dist_(rg);
83     size_t message_size = std::min(size_dist_(rg), max_size);
84     size_t header_size = message_size > 4 ? 4 : message_size;
85     size_t payload_size = message_size > 4 ? message_size - 4 : 0;
86 
87     return RpcFrame{.header = span(data_).subspan(offset, header_size),
88                     .payload = span(data_).subspan(offset, payload_size)};
89   }
90 
Send(size_t num_bytes)91   void Send(size_t num_bytes) {
92     size_t bytes_written = 0;
93     while (bytes_written < num_bytes) {
94       auto frame = MakeFrame(num_bytes - bytes_written);
95       std::copy(
96           frame.header.begin(), frame.header.end(), std::back_inserter(sent_));
97       std::copy(frame.payload.begin(),
98                 frame.payload.end(),
99                 std::back_inserter(sent_));
100 
101       // Tests below expect to see all data written to the socket to be received
102       // by the other end, so we keep retrying on any errors that could happen
103       // during reconnection: in reality it would be up to the higher level
104       // abstractions to do this depending on how they manage buffers etc. For
105       // the tests we just keep retrying indefinitely: if there is a
106       // non-transient problem then the test will eventually time out.
107       while (true) {
108         const auto send_status = transport_.Send(frame);
109         if (send_status.ok()) {
110           break;
111         }
112       }
113 
114       bytes_written += frame.header.size() + frame.payload.size();
115     }
116   }
117 
118  private:
119   SocketRpcTransport<kReadBufferSize>& transport_;
120   std::vector<std::byte> sent_;
121   std::array<std::byte, 256> data_{};
122   std::uniform_int_distribution<size_t> offset_dist_{0, 255};
123   std::uniform_int_distribution<size_t> size_dist_{1, kMaxWriteSize};
124 };
125 
126 class SocketSenderThreadCore : public SocketSender, public thread::ThreadCore {
127  public:
SocketSenderThreadCore(SocketRpcTransport<kReadBufferSize> & transport,size_t write_size)128   SocketSenderThreadCore(SocketRpcTransport<kReadBufferSize>& transport,
129                          size_t write_size)
130       : SocketSender(transport), write_size_(write_size) {}
131 
132  private:
Run()133   void Run() override { Send(write_size_); }
134   size_t write_size_;
135 };
136 
// Connects a client transport to a server transport, has both sides send
// kWriteSize bytes concurrently, and checks that each side received exactly
// the bytes the other side sent.
TEST(SocketRpcTransportTest, SendAndReceiveFramesOverSocketConnection) {
  constexpr size_t kWriteSize = 8192;

  TestIngress server_ingress(kWriteSize);
  TestIngress client_ingress(kWriteSize);

  auto server = SocketRpcTransport<kReadBufferSize>(
      SocketRpcTransport<kReadBufferSize>::kAsServer,
      kServerPort,
      server_ingress);
  auto server_thread = thread::Thread(thread::stl::Options(), server);

  // The server was asked for port 0, so the real port is only known once it
  // is ready.
  server.WaitUntilReady();
  const auto server_port = server.port();

  auto client = SocketRpcTransport<kReadBufferSize>(
      SocketRpcTransport<kReadBufferSize>::kAsClient,
      "localhost",
      server_port,
      client_ingress);
  auto client_thread = thread::Thread(thread::stl::Options(), client);

  client.WaitUntilConnected();
  server.WaitUntilConnected();

  // Send in both directions at the same time, each from its own thread.
  SocketSenderThreadCore client_sender(client, kWriteSize);
  SocketSenderThreadCore server_sender(server, kWriteSize);

  auto client_sender_thread =
      thread::Thread(thread::stl::Options(), client_sender);
  auto server_sender_thread =
      thread::Thread(thread::stl::Options(), server_sender);

  client_sender_thread.join();
  server_sender_thread.join();

  // Wait for both ingress handlers to report that all expected bytes arrived
  // before tearing the transports down.
  server_ingress.Wait();
  client_ingress.Wait();

  server.Stop();
  client.Stop();

  server_thread.join();
  client_thread.join();

  // The four-iterator std::equal also compares lengths, so a size mismatch
  // fails the expectation instead of reading past the end of a range.
  const auto received_by_server = server_ingress.received();
  const auto sent_by_client = client_sender.sent();
  EXPECT_EQ(received_by_server.size(), kWriteSize);
  EXPECT_TRUE(std::equal(received_by_server.begin(),
                         received_by_server.end(),
                         sent_by_client.begin(),
                         sent_by_client.end()));

  const auto received_by_client = client_ingress.received();
  const auto sent_by_server = server_sender.sent();
  EXPECT_EQ(received_by_client.size(), kWriteSize);
  EXPECT_TRUE(std::equal(received_by_client.begin(),
                         received_by_client.end(),
                         sent_by_server.begin(),
                         sent_by_server.end()));
}
194 
TEST(SocketRpcTransportTest,ServerReconnects)195 TEST(SocketRpcTransportTest, ServerReconnects) {
196   // Set up a server and a client that reconnects multiple times. The server
197   // must accept the new connection gracefully.
198   constexpr size_t kWriteSize = 8192;
199   std::vector<std::byte> received;
200 
201   TestIngress server_ingress(0);
202   auto server = SocketRpcTransport<kReadBufferSize>(
203       SocketRpcTransport<kReadBufferSize>::kAsServer,
204       kServerPort,
205       server_ingress);
206   auto server_thread = thread::Thread(thread::stl::Options(), server);
207 
208   server.WaitUntilReady();
209   auto server_port = server.port();
210   SocketSender server_sender(server);
211 
212   {
213     TestIngress client_ingress(kWriteSize);
214     auto client = SocketRpcTransport<kReadBufferSize>(
215         SocketRpcTransport<kReadBufferSize>::kAsClient,
216         "localhost",
217         server_port,
218         client_ingress);
219     auto client_thread = thread::Thread(thread::stl::Options(), client);
220 
221     client.WaitUntilConnected();
222     server.WaitUntilConnected();
223 
224     server_sender.Send(kWriteSize);
225     client_ingress.Wait();
226     auto client_received = client_ingress.received();
227     std::copy(client_received.begin(),
228               client_received.end(),
229               std::back_inserter(received));
230     EXPECT_EQ(received.size(), kWriteSize);
231 
232     // Stop the client but not the server: we're re-using the same server
233     // with a new client below.
234     client.Stop();
235     client_thread.join();
236   }
237 
238   // Reconnect to the server and keep sending frames.
239   {
240     TestIngress client_ingress(kWriteSize);
241     auto client = SocketRpcTransport<kReadBufferSize>(
242         SocketRpcTransport<kReadBufferSize>::kAsClient,
243         "localhost",
244         server_port,
245         client_ingress);
246     auto client_thread = thread::Thread(thread::stl::Options(), client);
247 
248     client.WaitUntilConnected();
249     server.WaitUntilConnected();
250 
251     server_sender.Send(kWriteSize);
252     client_ingress.Wait();
253     auto client_received = client_ingress.received();
254     std::copy(client_received.begin(),
255               client_received.end(),
256               std::back_inserter(received));
257 
258     client.Stop();
259     client_thread.join();
260 
261     // This time stop the server as well.
262     SocketSender client_sender(client);
263     server.Stop();
264     server_thread.join();
265   }
266 
267   EXPECT_EQ(received.size(), 2 * kWriteSize);
268   EXPECT_EQ(server_sender.sent().size(), 2 * kWriteSize);
269   EXPECT_TRUE(std::equal(
270       received.begin(), received.end(), server_sender.sent().begin()));
271 }
272 
TEST(SocketRpcTransportTest,ClientReconnects)273 TEST(SocketRpcTransportTest, ClientReconnects) {
274   // Set up a server and a client, then recycle the server. The client must
275   // must reconnect gracefully.
276   constexpr size_t kWriteSize = 8192;
277   uint16_t server_port = 0;
278 
279   TestIngress server_ingress(0);
280   TestIngress client_ingress(2 * kWriteSize);
281 
282   auto server = std::make_unique<SocketRpcTransport<kReadBufferSize>>(
283       SocketRpcTransport<kReadBufferSize>::kAsServer,
284       kServerPort,
285       server_ingress);
286   auto server_thread = thread::Thread(thread::stl::Options(), *server);
287 
288   server->WaitUntilReady();
289   server_port = server->port();
290 
291   auto client = SocketRpcTransport<kReadBufferSize>(
292       SocketRpcTransport<kReadBufferSize>::kAsClient,
293       "localhost",
294       server_port,
295       client_ingress);
296   auto client_thread = thread::Thread(thread::stl::Options(), client);
297 
298   client.WaitUntilConnected();
299   server->WaitUntilConnected();
300 
301   SocketSender client_sender(client);
302   SocketSender server1_sender(*server);
303   std::vector<std::byte> sent_by_server;
304 
305   server1_sender.Send(kWriteSize);
306   server->Stop();
307   auto server1_sent = server1_sender.sent();
308   std::copy(server1_sent.begin(),
309             server1_sent.end(),
310             std::back_inserter(sent_by_server));
311 
312   server_thread.join();
313   server = nullptr;
314 
315   server = std::make_unique<SocketRpcTransport<kReadBufferSize>>(
316       SocketRpcTransport<kReadBufferSize>::kAsServer,
317       server_port,
318       server_ingress);
319   SocketSender server2_sender(*server);
320   server_thread = thread::Thread(thread::stl::Options(), *server);
321 
322   client.WaitUntilConnected();
323   server->WaitUntilConnected();
324 
325   server2_sender.Send(kWriteSize);
326   client_ingress.Wait();
327 
328   server->Stop();
329   auto server2_sent = server2_sender.sent();
330   std::copy(server2_sent.begin(),
331             server2_sent.end(),
332             std::back_inserter(sent_by_server));
333 
334   server_thread.join();
335 
336   client.Stop();
337   client_thread.join();
338   server = nullptr;
339 
340   auto received_by_client = client_ingress.received();
341   EXPECT_EQ(received_by_client.size(), 2 * kWriteSize);
342   EXPECT_TRUE(std::equal(received_by_client.begin(),
343                          received_by_client.end(),
344                          sent_by_server.begin()));
345 }
346 
347 }  // namespace
348 }  // namespace pw::rpc
349