1 // Copyright 2023 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14
#include "pw_rpc_transport/socket_rpc_transport.h"

#include <algorithm>
#include <array>
#include <cstddef>
#include <iterator>
#include <memory>
#include <random>
#include <vector>

#include "pw_bytes/span.h"
#include "pw_log/log.h"
#include "pw_rpc_transport/socket_rpc_transport.h"
#include "pw_span/span.h"
#include "pw_status/status.h"
#include "pw_stream/socket_stream.h"
#include "pw_sync/thread_notification.h"
#include "pw_thread/thread.h"
#include "pw_thread_stl/options.h"
#include "pw_unit_test/framework.h"
30
31 namespace pw::rpc {
32 namespace {
33
34 using namespace std::chrono_literals;
35
// Largest number of bytes a single generated frame may contain.
constexpr size_t kMaxWriteSize{64};
// Read buffer size used as the SocketRpcTransport template argument.
constexpr size_t kReadBufferSize{64};
// Port 0 lets the kernel choose a free port; the tests read the chosen port
// back from the server with port().
constexpr uint16_t kServerPort{0};
40
41 class TestIngress : public RpcIngressHandler {
42 public:
TestIngress(size_t num_bytes_expected)43 explicit TestIngress(size_t num_bytes_expected)
44 : num_bytes_expected_(num_bytes_expected) {}
45
ProcessIncomingData(ConstByteSpan buffer)46 Status ProcessIncomingData(ConstByteSpan buffer) override {
47 if (num_bytes_expected_ > 0) {
48 std::copy(buffer.begin(), buffer.end(), std::back_inserter(received_));
49 num_bytes_expected_ -= std::min(num_bytes_expected_, buffer.size());
50 }
51 if (num_bytes_expected_ == 0) {
52 done_.release();
53 }
54 return OkStatus();
55 }
56
received() const57 std::vector<std::byte> received() const { return received_; }
Wait()58 void Wait() { done_.acquire(); }
59
60 private:
61 size_t num_bytes_expected_ = 0;
62 sync::ThreadNotification done_;
63 std::vector<std::byte> received_;
64 };
65
66 class SocketSender {
67 public:
SocketSender(SocketRpcTransport<kReadBufferSize> & transport)68 SocketSender(SocketRpcTransport<kReadBufferSize>& transport)
69 : transport_(transport) {
70 unsigned char c = 0;
71 for (auto& i : data_) {
72 i = std::byte{c++};
73 }
74 std::mt19937 rg{0x12345678};
75 std::shuffle(data_.begin(), data_.end(), rg);
76 }
77
sent()78 std::vector<std::byte> sent() { return sent_; }
79
MakeFrame(size_t max_size)80 RpcFrame MakeFrame(size_t max_size) {
81 std::mt19937 rg{0x12345678};
82 size_t offset = offset_dist_(rg);
83 size_t message_size = std::min(size_dist_(rg), max_size);
84 size_t header_size = message_size > 4 ? 4 : message_size;
85 size_t payload_size = message_size > 4 ? message_size - 4 : 0;
86
87 return RpcFrame{.header = span(data_).subspan(offset, header_size),
88 .payload = span(data_).subspan(offset, payload_size)};
89 }
90
Send(size_t num_bytes)91 void Send(size_t num_bytes) {
92 size_t bytes_written = 0;
93 while (bytes_written < num_bytes) {
94 auto frame = MakeFrame(num_bytes - bytes_written);
95 std::copy(
96 frame.header.begin(), frame.header.end(), std::back_inserter(sent_));
97 std::copy(frame.payload.begin(),
98 frame.payload.end(),
99 std::back_inserter(sent_));
100
101 // Tests below expect to see all data written to the socket to be received
102 // by the other end, so we keep retrying on any errors that could happen
103 // during reconnection: in reality it would be up to the higher level
104 // abstractions to do this depending on how they manage buffers etc. For
105 // the tests we just keep retrying indefinitely: if there is a
106 // non-transient problem then the test will eventually time out.
107 while (true) {
108 const auto send_status = transport_.Send(frame);
109 if (send_status.ok()) {
110 break;
111 }
112 }
113
114 bytes_written += frame.header.size() + frame.payload.size();
115 }
116 }
117
118 private:
119 SocketRpcTransport<kReadBufferSize>& transport_;
120 std::vector<std::byte> sent_;
121 std::array<std::byte, 256> data_{};
122 std::uniform_int_distribution<size_t> offset_dist_{0, 255};
123 std::uniform_int_distribution<size_t> size_dist_{1, kMaxWriteSize};
124 };
125
126 class SocketSenderThreadCore : public SocketSender, public thread::ThreadCore {
127 public:
SocketSenderThreadCore(SocketRpcTransport<kReadBufferSize> & transport,size_t write_size)128 SocketSenderThreadCore(SocketRpcTransport<kReadBufferSize>& transport,
129 size_t write_size)
130 : SocketSender(transport), write_size_(write_size) {}
131
132 private:
Run()133 void Run() override { Send(write_size_); }
134 size_t write_size_;
135 };
136
TEST(SocketRpcTransportTest, SendAndReceiveFramesOverSocketConnection) {
  // A server and a client send kWriteSize bytes to each other concurrently;
  // each side must receive exactly what the other sent.
  constexpr size_t kWriteSize = 8192;

  TestIngress server_ingress(kWriteSize);
  TestIngress client_ingress(kWriteSize);

  auto server = SocketRpcTransport<kReadBufferSize>(
      SocketRpcTransport<kReadBufferSize>::kAsServer,
      kServerPort,
      server_ingress);
  auto server_thread = thread::Thread(thread::stl::Options(), server);

  server.WaitUntilReady();
  auto server_port = server.port();

  auto client = SocketRpcTransport<kReadBufferSize>(
      SocketRpcTransport<kReadBufferSize>::kAsClient,
      "localhost",
      server_port,
      client_ingress);
  auto client_thread = thread::Thread(thread::stl::Options(), client);

  client.WaitUntilConnected();
  server.WaitUntilConnected();

  SocketSenderThreadCore client_sender(client, kWriteSize);
  SocketSenderThreadCore server_sender(server, kWriteSize);

  auto client_sender_thread =
      thread::Thread(thread::stl::Options(), client_sender);
  auto server_sender_thread =
      thread::Thread(thread::stl::Options(), server_sender);

  client_sender_thread.join();
  server_sender_thread.join();

  server_ingress.Wait();
  client_ingress.Wait();

  server.Stop();
  client.Stop();

  server_thread.join();
  client_thread.join();

  // Use the four-iterator std::equal so that a length mismatch is reported as
  // a failure instead of reading past the end of the shorter buffer
  // (EXPECT_EQ on the sizes does not stop the test).
  const auto received_by_server = server_ingress.received();
  const auto sent_by_client = client_sender.sent();
  EXPECT_EQ(received_by_server.size(), kWriteSize);
  EXPECT_TRUE(std::equal(received_by_server.begin(),
                         received_by_server.end(),
                         sent_by_client.begin(),
                         sent_by_client.end()));

  const auto received_by_client = client_ingress.received();
  const auto sent_by_server = server_sender.sent();
  EXPECT_EQ(received_by_client.size(), kWriteSize);
  EXPECT_TRUE(std::equal(received_by_client.begin(),
                         received_by_client.end(),
                         sent_by_server.begin(),
                         sent_by_server.end()));
}
194
TEST(SocketRpcTransportTest, ServerReconnects) {
  // Set up a server and a client that reconnects multiple times. The server
  // must accept the new connection gracefully.
  constexpr size_t kWriteSize = 8192;
  std::vector<std::byte> received;

  TestIngress server_ingress(0);
  auto server = SocketRpcTransport<kReadBufferSize>(
      SocketRpcTransport<kReadBufferSize>::kAsServer,
      kServerPort,
      server_ingress);
  auto server_thread = thread::Thread(thread::stl::Options(), server);

  server.WaitUntilReady();
  const auto server_port = server.port();
  SocketSender server_sender(server);

  {
    TestIngress client_ingress(kWriteSize);
    auto client = SocketRpcTransport<kReadBufferSize>(
        SocketRpcTransport<kReadBufferSize>::kAsClient,
        "localhost",
        server_port,
        client_ingress);
    auto client_thread = thread::Thread(thread::stl::Options(), client);

    client.WaitUntilConnected();
    server.WaitUntilConnected();

    server_sender.Send(kWriteSize);
    client_ingress.Wait();
    const auto client_received = client_ingress.received();
    received.insert(
        received.end(), client_received.begin(), client_received.end());
    EXPECT_EQ(received.size(), kWriteSize);

    // Stop the client but not the server: we're re-using the same server
    // with a new client below.
    client.Stop();
    client_thread.join();
  }

  // Reconnect to the server and keep sending frames.
  {
    TestIngress client_ingress(kWriteSize);
    auto client = SocketRpcTransport<kReadBufferSize>(
        SocketRpcTransport<kReadBufferSize>::kAsClient,
        "localhost",
        server_port,
        client_ingress);
    auto client_thread = thread::Thread(thread::stl::Options(), client);

    client.WaitUntilConnected();
    server.WaitUntilConnected();

    server_sender.Send(kWriteSize);
    client_ingress.Wait();
    const auto client_received = client_ingress.received();
    received.insert(
        received.end(), client_received.begin(), client_received.end());

    client.Stop();
    client_thread.join();

    // This time stop the server as well.
    server.Stop();
    server_thread.join();
  }

  // Four-iterator std::equal: a length mismatch fails the check instead of
  // reading past the end of the shorter buffer.
  const auto sent_by_server = server_sender.sent();
  EXPECT_EQ(received.size(), 2 * kWriteSize);
  EXPECT_EQ(sent_by_server.size(), 2 * kWriteSize);
  EXPECT_TRUE(std::equal(received.begin(),
                         received.end(),
                         sent_by_server.begin(),
                         sent_by_server.end()));
}
272
TEST(SocketRpcTransportTest, ClientReconnects) {
  // Set up a server and a client, then recycle the server. The client must
  // reconnect gracefully.
  constexpr size_t kWriteSize = 8192;

  TestIngress server_ingress(0);
  TestIngress client_ingress(2 * kWriteSize);

  auto server = std::make_unique<SocketRpcTransport<kReadBufferSize>>(
      SocketRpcTransport<kReadBufferSize>::kAsServer,
      kServerPort,
      server_ingress);
  auto server_thread = thread::Thread(thread::stl::Options(), *server);

  server->WaitUntilReady();
  // The replacement server below must listen on the same port the kernel
  // picked for the first one, so the client can reconnect to it.
  const uint16_t server_port = server->port();

  auto client = SocketRpcTransport<kReadBufferSize>(
      SocketRpcTransport<kReadBufferSize>::kAsClient,
      "localhost",
      server_port,
      client_ingress);
  auto client_thread = thread::Thread(thread::stl::Options(), client);

  client.WaitUntilConnected();
  server->WaitUntilConnected();

  SocketSender client_sender(client);
  SocketSender server1_sender(*server);
  std::vector<std::byte> sent_by_server;

  server1_sender.Send(kWriteSize);
  server->Stop();
  const auto server1_sent = server1_sender.sent();
  sent_by_server.insert(
      sent_by_server.end(), server1_sent.begin(), server1_sent.end());

  server_thread.join();
  server = nullptr;

  server = std::make_unique<SocketRpcTransport<kReadBufferSize>>(
      SocketRpcTransport<kReadBufferSize>::kAsServer,
      server_port,
      server_ingress);
  SocketSender server2_sender(*server);
  server_thread = thread::Thread(thread::stl::Options(), *server);

  client.WaitUntilConnected();
  server->WaitUntilConnected();

  server2_sender.Send(kWriteSize);
  client_ingress.Wait();

  server->Stop();
  const auto server2_sent = server2_sender.sent();
  sent_by_server.insert(
      sent_by_server.end(), server2_sent.begin(), server2_sent.end());

  server_thread.join();

  client.Stop();
  client_thread.join();
  server = nullptr;

  // Four-iterator std::equal: a length mismatch fails the check instead of
  // reading past the end of the shorter buffer.
  const auto received_by_client = client_ingress.received();
  EXPECT_EQ(received_by_client.size(), 2 * kWriteSize);
  EXPECT_TRUE(std::equal(received_by_client.begin(),
                         received_by_client.end(),
                         sent_by_server.begin(),
                         sent_by_server.end()));
}
346
347 } // namespace
348 } // namespace pw::rpc
349