1 // Copyright 2021 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14
15 #include "pw_rpc/internal/call.h"
16
17 #include <algorithm>
18 #include <array>
19 #include <cstdint>
20 #include <cstring>
21 #include <optional>
22
23 #include "pw_rpc/internal/test_utils.h"
24 #include "pw_rpc/service.h"
25 #include "pw_rpc_private/fake_server_reader_writer.h"
26 #include "pw_rpc_private/test_method.h"
27 #include "pw_unit_test/framework.h"
28
29 namespace pw::rpc {
30
31 class TestService : public Service {
32 public:
TestService(uint32_t id)33 constexpr TestService(uint32_t id) : Service(id, method) {}
34
35 static constexpr internal::TestMethodUnion method = internal::TestMethod(8);
36 };
37
38 namespace internal {
39 namespace {
40
41 constexpr uint32_t kChannelId = 99;
42 constexpr uint32_t kServiceId = 16;
43 constexpr uint32_t kMethodId = 8;
44 constexpr uint32_t kCallId = 327;
45 constexpr Packet kPacket(
46 pwpb::PacketType::REQUEST, kChannelId, kServiceId, kMethodId, kCallId);
47
48 using ::pw::rpc::internal::test::FakeServerReader;
49 using ::pw::rpc::internal::test::FakeServerReaderWriter;
50 using ::pw::rpc::internal::test::FakeServerWriter;
51 using ::std::byte;
52 using ::testing::Test;
53
54 static_assert(sizeof(Call) ==
55 // IntrusiveList::Item pointer
56 sizeof(IntrusiveList<Call>::Item) +
57 // Endpoint pointer
58 sizeof(Endpoint*) +
59 // call_id, channel_id, service_id, method_id
60 4 * sizeof(uint32_t) +
61 // Packed state and properties
62 sizeof(void*) +
63 // on_error and on_next callbacks
64 2 * sizeof(Function<void(Status)>),
65 "Unexpected padding in Call!");
66
67 static_assert(sizeof(CallProperties) == sizeof(uint8_t));
68
// Every field packed into CallProperties must read back exactly the value it
// was constructed with. All checks happen at compile time.
TEST(CallProperties, ValuesMatch) {
  constexpr CallProperties kBidiClientRaw(
      MethodType::kBidirectionalStreaming, kClientCall, kRawProto);
  static_assert(kBidiClientRaw.method_type() ==
                MethodType::kBidirectionalStreaming);
  static_assert(kBidiClientRaw.call_type() == kClientCall);
  static_assert(kBidiClientRaw.callback_proto_type() == kRawProto);

  constexpr CallProperties kClientStreamServerStruct(
      MethodType::kClientStreaming, kServerCall, kProtoStruct);
  static_assert(kClientStreamServerStruct.method_type() ==
                MethodType::kClientStreaming);
  static_assert(kClientStreamServerStruct.call_type() == kServerCall);
  static_assert(kClientStreamServerStruct.callback_proto_type() ==
                kProtoStruct);

  constexpr CallProperties kUnaryClientStruct(
      MethodType::kUnary, kClientCall, kProtoStruct);
  static_assert(kUnaryClientStruct.method_type() == MethodType::kUnary);
  static_assert(kUnaryClientStruct.call_type() == kClientCall);
  static_assert(kUnaryClientStruct.callback_proto_type() == kProtoStruct);
}
88
// Fixture that claims a call from `context_` and holds it in `writer_` for
// every test.
class ServerWriterTest : public Test {
 public:
  ServerWriterTest() : context_(TestService::method.method()) {
    // ClaimLocked() is invoked with rpc_lock() held. The claimed call goes into
    // a temporary, and the lock is released before the temporary is moved into
    // `writer_`. NOTE(review): presumably the move assignment acquires
    // rpc_lock() itself; confirm in Call.
    rpc_lock().lock();
    FakeServerWriter writer_temp(context_.get().ClaimLocked());
    rpc_lock().unlock();
    writer_ = std::move(writer_temp);
  }

  // Channel/service/call IDs match kPacket, so the fixture's call can be found
  // with FindCall(kPacket).
  ServerContextForTest<TestService, kChannelId, kServiceId, kCallId> context_;
  FakeServerWriter writer_;
};
101
TEST_F(ServerWriterTest, ConstructWithContext_StartsOpen) {
  // The writer claimed from the context by the fixture is active right away.
  const bool writer_is_active = writer_.active();
  EXPECT_TRUE(writer_is_active);
}
105
TEST_F(ServerWriterTest, Move_ClosesOriginal) {
  // Moving transfers the call: the destination becomes active and the source
  // is left inactive.
  FakeServerWriter destination(std::move(writer_));
  EXPECT_TRUE(destination.active());

#ifndef __clang_analyzer__
  // Inspecting the moved-from writer is the point of this test.
  EXPECT_FALSE(writer_.active());
#endif  // ignore use-after-move
}
114
TEST_F(ServerWriterTest, DefaultConstruct_Closed) {
  // A default-constructed writer is not attached to any call.
  FakeServerWriter unattached_writer;
  const bool is_active = unattached_writer.active();
  EXPECT_FALSE(is_active);
}
119
TEST_F(ServerWriterTest, Construct_RegistersWithServer) {
  RpcLockGuard lock;

  // The server must track exactly the call object that lives in `writer_`.
  IntrusiveList<Call>::iterator found = context_.server().FindCall(kPacket);
  ASSERT_NE(found, context_.server().calls_end());

  void* const registered_call = static_cast<void*>(&*found);
  void* const fixture_call = static_cast<void*>(&writer_);
  EXPECT_EQ(registered_call, fixture_call);
}
126
// Destroying a writer must unregister its call from the server.
TEST_F(ServerWriterTest, Destruct_RemovesFromServer) {
  {
    // Note `lock_guard` cannot be used here, because while the constructor
    // of `FakeServerWriter` requires the lock be held, the destructor acquires
    // it!
    rpc_lock().lock();
    FakeServerWriter writer(context_.get().ClaimLocked());
    rpc_lock().unlock();
  }  // `writer` is destroyed here, with rpc_lock() not held.

  RpcLockGuard lock;
  EXPECT_EQ(context_.server().FindCall(kPacket), context_.server().calls_end());
}
140
TEST_F(ServerWriterTest, Finish_RemovesFromServer) {
  // Finish() is called without rpc_lock() held; the lock is taken only for
  // the server lookup afterwards.
  const Status finish_result = writer_.Finish();
  EXPECT_EQ(OkStatus(), finish_result);

  RpcLockGuard lock;
  // A finished call can no longer be found on the server.
  EXPECT_EQ(context_.server().FindCall(kPacket), context_.server().calls_end());
}
146
TEST_F(ServerWriterTest, Finish_SendsResponse) {
  EXPECT_EQ(OkStatus(), writer_.Finish());

  // Finishing emits exactly one packet: an empty, OK RESPONSE addressed to
  // this context's channel, service, and method.
  ASSERT_EQ(context_.output().total_packets(), 1u);
  const Packet& response = context_.output().last_packet();

  EXPECT_EQ(pwpb::PacketType::RESPONSE, response.type());
  EXPECT_EQ(context_.channel_id(), response.channel_id());
  EXPECT_EQ(context_.service_id(), response.service_id());
  EXPECT_EQ(context_.get().method().id(), response.method_id());
  EXPECT_TRUE(response.payload().empty());
  EXPECT_EQ(OkStatus(), response.status());
}
159
TEST_F(ServerWriterTest, Finish_ReturnsStatusFromChannelSend) {
  // Force the channel output to fail its next send.
  context_.output().set_send_status(Status::Unauthenticated());

  // All non-OK statuses are remapped to UNKNOWN, so the specific send error
  // is not propagated to the caller.
  const Status finish_result = writer_.Finish();
  EXPECT_EQ(Status::Unknown(), finish_result);
}
166
TEST_F(ServerWriterTest, Finish) {
  ASSERT_TRUE(writer_.active());

  // The first Finish() succeeds and closes the call...
  const Status first_finish = writer_.Finish();
  EXPECT_EQ(OkStatus(), first_finish);
  EXPECT_FALSE(writer_.active());

  // ...so a second Finish() is rejected.
  const Status second_finish = writer_.Finish();
  EXPECT_EQ(Status::FailedPrecondition(), second_finish);
}
173
// Write() on an open writer sends a packet carrying exactly the given bytes.
TEST_F(ServerWriterTest, Open_SendsPacketWithPayload) {
  constexpr byte data[] = {byte{0xf0}, byte{0x0d}};
  ASSERT_EQ(OkStatus(), writer_.Write(data));

  // The corresponding server stream packet must encode successfully.
  byte encoded[64];
  auto result = context_.server_stream(data).Encode(encoded);
  ASSERT_EQ(OkStatus(), result.status());

  ConstByteSpan payload = context_.output().last_packet().payload();
  // Fatal check: if the payload were shorter than `data`, the memcmp below
  // would read past the end of it.
  ASSERT_EQ(sizeof(data), payload.size());
  EXPECT_EQ(0, std::memcmp(data, payload.data(), sizeof(data)));
}
186
TEST_F(ServerWriterTest, Closed_IgnoresFinish) {
  // Close the call once; any further Finish() on the closed call is refused.
  const Status closing = writer_.Finish();
  EXPECT_EQ(OkStatus(), closing);

  const Status after_close = writer_.Finish();
  EXPECT_EQ(Status::FailedPrecondition(), after_close);
}
191
TEST_F(ServerWriterTest, DefaultConstructor_NoClientStream) {
  // Declared before the guard so it is destroyed after the guard releases
  // rpc_lock(); a writer's destructor acquires that lock.
  FakeServerWriter unattached_writer;
  RpcLockGuard lock;

  EXPECT_FALSE(unattached_writer.as_server_call().has_client_stream());
  EXPECT_FALSE(
      unattached_writer.as_server_call().client_requested_completion());
}
198
TEST_F(ServerWriterTest, Open_NoClientStream) {
  RpcLockGuard lock;
  // An open writer streams from the server side only, and no completion has
  // been requested by the client.
  EXPECT_TRUE(writer_.as_server_call().has_server_stream());
  EXPECT_FALSE(writer_.as_server_call().has_client_stream());
  EXPECT_FALSE(writer_.as_server_call().client_requested_completion());
}
205
// Fixture that claims a call from `context_` and holds it in `reader_` for
// every test.
class ServerReaderTest : public Test {
 public:
  ServerReaderTest() : context_(TestService::method.method()) {
    // Claim with rpc_lock() held, release the lock, then move the temporary
    // into `reader_` (same sequence as the other fixtures in this file).
    rpc_lock().lock();
    FakeServerReader reader_temp(context_.get().ClaimLocked());
    rpc_lock().unlock();
    reader_ = std::move(reader_temp);
  }

  // Uses ServerContextForTest's default template arguments for the channel,
  // service, and call IDs. NOTE(review): these are not guaranteed to match
  // kPacket.
  ServerContextForTest<TestService> context_;
  FakeServerReader reader_;
};
218
TEST_F(ServerReaderTest, DefaultConstructor_StreamClosed) {
  FakeServerReader unattached_reader;

  // active() is queried before rpc_lock() is taken below.
  const bool is_active = unattached_reader.as_server_call().active();
  EXPECT_FALSE(is_active);

  RpcLockGuard lock;
  EXPECT_FALSE(
      unattached_reader.as_server_call().client_requested_completion());
}
225
TEST_F(ServerReaderTest, Open_ClientStreamStartsOpen) {
  RpcLockGuard lock;
  // A freshly opened reader has a client stream and no completion request.
  const bool has_client_stream = reader_.as_server_call().has_client_stream();
  const bool completion_requested =
      reader_.as_server_call().client_requested_completion();

  EXPECT_TRUE(has_client_stream);
  EXPECT_FALSE(completion_requested);
}
231
// Closing a reader with a response deactivates the call and leaves
// client_requested_completion() reporting true.
TEST_F(ServerReaderTest, Close_ClosesStream) {
  EXPECT_TRUE(reader_.as_server_call().active());
  // client_requested_completion() is read with rpc_lock() held; the lock is
  // released before CloseAndSendResponse() is called.
  rpc_lock().lock();
  EXPECT_FALSE(reader_.as_server_call().client_requested_completion());
  rpc_lock().unlock();
  EXPECT_EQ(OkStatus(),
            reader_.as_server_call().CloseAndSendResponse(OkStatus()));

  EXPECT_FALSE(reader_.as_server_call().active());
  RpcLockGuard lock;
  EXPECT_TRUE(reader_.as_server_call().client_requested_completion());
}
244
// A client completion request marks the completion flag but leaves the call
// active.
TEST_F(ServerReaderTest, RequestCompletion_OnlyMakesClientNotReady) {
  EXPECT_TRUE(reader_.active());
  rpc_lock().lock();
  EXPECT_FALSE(reader_.as_server_call().client_requested_completion());
  // The lock taken above is never unlocked explicitly in this test, and it is
  // re-acquired by the guard below, so HandleClientRequestedCompletion() is
  // expected to release rpc_lock(). NOTE(review): confirm against ServerCall.
  reader_.as_server_call().HandleClientRequestedCompletion();

  EXPECT_TRUE(reader_.active());
  RpcLockGuard lock;
  EXPECT_TRUE(reader_.as_server_call().client_requested_completion());
}
255
// Fixture that claims a call from `context_` and holds it in `reader_writer_`
// for every test.
class ServerReaderWriterTest : public Test {
 public:
  ServerReaderWriterTest() : context_(TestService::method.method()) {
    // Claim with rpc_lock() held, release the lock, then move the temporary
    // into `reader_writer_` (same sequence as the other fixtures in this file).
    rpc_lock().lock();
    FakeServerReaderWriter reader_writer_temp(context_.get().ClaimLocked());
    rpc_lock().unlock();
    reader_writer_ = std::move(reader_writer_temp);
  }

  // Uses ServerContextForTest's default channel, service, and call IDs.
  ServerContextForTest<TestService> context_;
  FakeServerReaderWriter reader_writer_;
};
268
// Move-assigning an open reader/writer into a default-constructed one carries
// over the client stream and the "no completion requested" state.
TEST_F(ServerReaderWriterTest, Move_MaintainsClientStream) {
  FakeServerReaderWriter destination;

  // Read under rpc_lock(), then release it before the move assignment.
  rpc_lock().lock();
  EXPECT_FALSE(destination.as_server_call().client_requested_completion());
  rpc_lock().unlock();

  destination = std::move(reader_writer_);
  RpcLockGuard lock;
  EXPECT_TRUE(destination.as_server_call().has_client_stream());
  EXPECT_FALSE(destination.as_server_call().client_requested_completion());
}
281
// Callbacks registered on the source must fire on the moved-to object.
TEST_F(ServerReaderWriterTest, Move_MovesCallbacks) {
  // Each callback increments `calls`, so the final count shows how many ran.
  int calls = 0;
  reader_writer_.set_on_error([&calls](Status) { calls += 1; });
  reader_writer_.set_on_next([&calls](ConstByteSpan) { calls += 1; });
  reader_writer_.set_on_completion_requested_if_enabled(
      [&calls]() { calls += 1; });

  FakeServerReaderWriter destination(std::move(reader_writer_));
  // rpc_lock() is acquired before each Handle* call and never explicitly
  // unlocked, so each Handle* call is expected to release the lock before
  // returning. NOTE(review): confirm against ServerCall.
  rpc_lock().lock();
  destination.as_server_call().HandlePayload({});
  rpc_lock().lock();
  destination.as_server_call().HandleClientRequestedCompletion();
  rpc_lock().lock();
  destination.as_server_call().HandleError(Status::Unknown());

  // on_next and on_error always count. The completion-request callback only
  // counts when PW_RPC_COMPLETION_REQUEST_CALLBACK is enabled (presumably
  // defined as 0 or 1 — verify in the pw_rpc config).
  EXPECT_EQ(calls, 2 + PW_RPC_COMPLETION_REQUEST_CALLBACK);
}
299
// After a move, the moved-from object reports zero call and channel IDs.
TEST_F(ServerReaderWriterTest, Move_ClearsCallAndChannelId) {
  // Give the source a nonzero call ID and confirm its channel ID is nonzero,
  // so the zero checks after the move are meaningful.
  rpc_lock().lock();
  reader_writer_.set_id(999);
  EXPECT_NE(reader_writer_.channel_id_locked(), 0u);
  rpc_lock().unlock();

  FakeServerReaderWriter destination(std::move(reader_writer_));

  // Inspecting the moved-from object is intentional here.
  RpcLockGuard lock;
  EXPECT_EQ(reader_writer_.id(), 0u);
  EXPECT_EQ(reader_writer_.channel_id_locked(), 0u);
}
312
// If the source call was closed and marked for cleanup, moving it must run the
// pending cleanup, invoking the source's on_error once with the close status.
TEST_F(ServerReaderWriterTest, Move_SourceAwaitingCleanup_CleansUpCalls) {
  // Records the status passed to on_error; the ASSERT fails the test if
  // on_error runs more than once.
  std::optional<Status> on_error_cb;
  reader_writer_.set_on_error([&on_error_cb](Status error) {
    ASSERT_FALSE(on_error_cb.has_value());
    on_error_cb = error;
  });

  // Close the source with NOT_FOUND and mark it for cleanup, with
  // rpc_lock() held.
  rpc_lock().lock();
  context_.server().CloseCallAndMarkForCleanup(reader_writer_.as_server_call(),
                                               Status::NotFound());
  rpc_lock().unlock();

  FakeServerReaderWriter destination(std::move(reader_writer_));

  EXPECT_EQ(Status::NotFound(), on_error_cb);
}
329
// When both the destination and the source are awaiting cleanup, move
// assignment must run both pending cleanups, each with its own status.
TEST_F(ServerReaderWriterTest, Move_BothAwaitingCleanup_CleansUpCalls) {
  rpc_lock().lock();
  // Use call ID 123 so this call is distinct from the other.
  FakeServerReaderWriter destination(context_.get(123).ClaimLocked());
  rpc_lock().unlock();

  // Each recorder fails the test if its on_error runs more than once.
  std::optional<Status> destination_on_error_cb;
  destination.set_on_error([&destination_on_error_cb](Status error) {
    ASSERT_FALSE(destination_on_error_cb.has_value());
    destination_on_error_cb = error;
  });

  std::optional<Status> source_on_error_cb;
  reader_writer_.set_on_error([&source_on_error_cb](Status error) {
    ASSERT_FALSE(source_on_error_cb.has_value());
    source_on_error_cb = error;
  });

  // Simulate these two calls being closed by another thread.
  rpc_lock().lock();
  context_.server().CloseCallAndMarkForCleanup(destination.as_server_call(),
                                               Status::NotFound());
  context_.server().CloseCallAndMarkForCleanup(reader_writer_.as_server_call(),
                                               Status::Unauthenticated());
  rpc_lock().unlock();

  destination = std::move(reader_writer_);

  // Each call's on_error receives the status that call was closed with.
  EXPECT_EQ(Status::NotFound(), destination_on_error_cb);
  EXPECT_EQ(Status::Unauthenticated(), source_on_error_cb);
}
361
// Finishing a call resets its call and channel IDs to zero.
TEST_F(ServerReaderWriterTest, Close_ClearsCallAndChannelId) {
  // Start from a nonzero call ID and confirm the channel ID is nonzero, so the
  // zero checks below are meaningful.
  rpc_lock().lock();
  reader_writer_.set_id(999);
  EXPECT_NE(reader_writer_.channel_id_locked(), 0u);
  rpc_lock().unlock();

  // Finish() is called without rpc_lock() held.
  EXPECT_EQ(OkStatus(), reader_writer_.Finish());

  RpcLockGuard lock;
  EXPECT_EQ(reader_writer_.id(), 0u);
  EXPECT_EQ(reader_writer_.channel_id_locked(), 0u);
}
374
375 } // namespace
376 } // namespace internal
377 } // namespace pw::rpc
378