• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2021 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 
15 #include "pw_rpc/internal/call.h"
16 
17 #include <algorithm>
18 #include <array>
19 #include <cstdint>
20 #include <cstring>
21 #include <optional>
22 
23 #include "pw_rpc/internal/test_utils.h"
24 #include "pw_rpc/service.h"
25 #include "pw_rpc_private/fake_server_reader_writer.h"
26 #include "pw_rpc_private/test_method.h"
27 #include "pw_unit_test/framework.h"
28 
29 namespace pw::rpc {
30 
31 class TestService : public Service {
32  public:
TestService(uint32_t id)33   constexpr TestService(uint32_t id) : Service(id, method) {}
34 
35   static constexpr internal::TestMethodUnion method = internal::TestMethod(8);
36 };
37 
38 namespace internal {
39 namespace {
40 
41 constexpr uint32_t kChannelId = 99;
42 constexpr uint32_t kServiceId = 16;
43 constexpr uint32_t kMethodId = 8;
44 constexpr uint32_t kCallId = 327;
45 constexpr Packet kPacket(
46     pwpb::PacketType::REQUEST, kChannelId, kServiceId, kMethodId, kCallId);
47 
48 using ::pw::rpc::internal::test::FakeServerReader;
49 using ::pw::rpc::internal::test::FakeServerReaderWriter;
50 using ::pw::rpc::internal::test::FakeServerWriter;
51 using ::std::byte;
52 using ::testing::Test;
53 
54 static_assert(sizeof(Call) ==
55                   // IntrusiveList::Item pointer
56                   sizeof(IntrusiveList<Call>::Item) +
57                       // Endpoint pointer
58                       sizeof(Endpoint*) +
59                       // call_id, channel_id, service_id, method_id
60                       4 * sizeof(uint32_t) +
61                       // Packed state and properties
62                       sizeof(void*) +
63                       // on_error and on_next callbacks
64                       2 * sizeof(Function<void(Status)>),
65               "Unexpected padding in Call!");
66 
67 static_assert(sizeof(CallProperties) == sizeof(uint8_t));
68 
// Each field packed into CallProperties reads back unchanged, checked at
// compile time for three different combinations.
TEST(CallProperties, ValuesMatch) {
  constexpr CallProperties bidi_client_raw(
      MethodType::kBidirectionalStreaming, kClientCall, kRawProto);
  static_assert(bidi_client_raw.method_type() ==
                MethodType::kBidirectionalStreaming);
  static_assert(bidi_client_raw.call_type() == kClientCall);
  static_assert(bidi_client_raw.callback_proto_type() == kRawProto);

  constexpr CallProperties client_streaming_server_struct(
      MethodType::kClientStreaming, kServerCall, kProtoStruct);
  static_assert(client_streaming_server_struct.method_type() ==
                MethodType::kClientStreaming);
  static_assert(client_streaming_server_struct.call_type() == kServerCall);
  static_assert(client_streaming_server_struct.callback_proto_type() ==
                kProtoStruct);

  constexpr CallProperties unary_client_struct(
      MethodType::kUnary, kClientCall, kProtoStruct);
  static_assert(unary_client_struct.method_type() == MethodType::kUnary);
  static_assert(unary_client_struct.call_type() == kClientCall);
  static_assert(unary_client_struct.callback_proto_type() == kProtoStruct);
}
88 
89 class ServerWriterTest : public Test {
90  public:
ServerWriterTest()91   ServerWriterTest() : context_(TestService::method.method()) {
92     rpc_lock().lock();
93     FakeServerWriter writer_temp(context_.get().ClaimLocked());
94     rpc_lock().unlock();
95     writer_ = std::move(writer_temp);
96   }
97 
98   ServerContextForTest<TestService, kChannelId, kServiceId, kCallId> context_;
99   FakeServerWriter writer_;
100 };
101 
// A writer claimed from the test context is active right away.
TEST_F(ServerWriterTest, ConstructWithContext_StartsOpen) {
  const bool is_active = writer_.active();
  EXPECT_TRUE(is_active);
}
105 
// Move-constructing from a writer transfers the active call and leaves the
// source inactive.
TEST_F(ServerWriterTest, Move_ClosesOriginal) {
  FakeServerWriter destination(std::move(writer_));

  // Inspecting the moved-from writer is deliberate; hide it from the clang
  // static analyzer, which would report a use-after-move.
#ifndef __clang_analyzer__
  EXPECT_FALSE(writer_.active());
#endif  // __clang_analyzer__
  EXPECT_TRUE(destination.active());
}
114 
// A default-constructed writer has no call and reports itself inactive.
TEST_F(ServerWriterTest, DefaultConstruct_Closed) {
  FakeServerWriter unclaimed;
  const bool is_active = unclaimed.active();
  EXPECT_FALSE(is_active);
}
119 
// Claiming a writer registers its call with the server: looking up kPacket
// finds exactly the fixture's writer object.
TEST_F(ServerWriterTest, Construct_RegistersWithServer) {
  RpcLockGuard lock;
  auto& server = context_.server();

  IntrusiveList<Call>::iterator registered = server.FindCall(kPacket);
  ASSERT_NE(registered, server.calls_end());

  const void* const expected_address = static_cast<void*>(&writer_);
  EXPECT_EQ(static_cast<void*>(&*registered), expected_address);
}
126 
// Destroying a writer unregisters its call from the server.
TEST_F(ServerWriterTest, Destruct_RemovesFromServer) {
  {
    // A scoped guard cannot be used here. FakeServerWriter's constructor
    // requires rpc_lock() to be held, but its destructor acquires the lock,
    // so the lock must be dropped before `scoped_writer` leaves this scope.
    rpc_lock().lock();
    FakeServerWriter scoped_writer(context_.get().ClaimLocked());
    rpc_lock().unlock();
  }  // scoped_writer is destroyed here, outside the lock.

  RpcLockGuard lock;
  auto& server = context_.server();
  EXPECT_EQ(server.FindCall(kPacket), server.calls_end());
}
140 
// After a successful Finish(), the server no longer tracks the call.
TEST_F(ServerWriterTest, Finish_RemovesFromServer) {
  const Status finish_status = writer_.Finish();
  EXPECT_EQ(OkStatus(), finish_status);

  RpcLockGuard lock;
  auto& server = context_.server();
  EXPECT_EQ(server.FindCall(kPacket), server.calls_end());
}
146 
// Finish() sends exactly one RESPONSE packet, addressed to this context's
// channel/service/method, with an empty payload and an OK status.
TEST_F(ServerWriterTest, Finish_SendsResponse) {
  EXPECT_EQ(OkStatus(), writer_.Finish());

  auto& output = context_.output();
  ASSERT_EQ(output.total_packets(), 1u);

  const Packet& response = output.last_packet();
  EXPECT_EQ(response.type(), pwpb::PacketType::RESPONSE);
  EXPECT_EQ(response.channel_id(), context_.channel_id());
  EXPECT_EQ(response.service_id(), context_.service_id());
  EXPECT_EQ(response.method_id(), context_.get().method().id());
  EXPECT_TRUE(response.payload().empty());
  EXPECT_EQ(response.status(), OkStatus());
}
159 
// Finish() reports a failed channel send. All non-OK statuses are remapped
// to UNKNOWN rather than passed through unchanged.
TEST_F(ServerWriterTest, Finish_ReturnsStatusFromChannelSend) {
  context_.output().set_send_status(Status::Unauthenticated());

  const Status finish_status = writer_.Finish();
  EXPECT_EQ(Status::Unknown(), finish_status);
}
166 
// The first Finish() closes an active writer; a second one is rejected.
TEST_F(ServerWriterTest, Finish) {
  ASSERT_TRUE(writer_.active());

  const Status first_finish = writer_.Finish();
  const bool active_after_finish = writer_.active();
  const Status second_finish = writer_.Finish();

  EXPECT_EQ(OkStatus(), first_finish);
  EXPECT_FALSE(active_after_finish);
  EXPECT_EQ(Status::FailedPrecondition(), second_finish);
}
173 
// Writing on an open writer sends one server-stream packet whose payload is
// byte-for-byte the data that was written.
TEST_F(ServerWriterTest, Open_SendsPacketWithPayload) {
  constexpr byte data[] = {byte{0xf0}, byte{0x0d}};
  ASSERT_EQ(OkStatus(), writer_.Write(data));

  // The server-stream packet this context expects for `data` must encode
  // successfully into a small buffer.
  const auto expected = context_.server_stream(data);
  byte encoded[64];
  auto result = expected.Encode(encoded);
  ASSERT_EQ(OkStatus(), result.status());

  // Guard last_packet(): the Write() above must have produced the only
  // packet on the output.
  ASSERT_EQ(context_.output().total_packets(), 1u);
  const Packet& sent = context_.output().last_packet();
  EXPECT_EQ(sent.type(), expected.type());

  // This must be ASSERT, not EXPECT: the memcmp below reads sizeof(data)
  // bytes from the payload and would over-read a shorter span.
  ConstByteSpan payload = sent.payload();
  ASSERT_EQ(sizeof(data), payload.size());
  EXPECT_EQ(0, std::memcmp(data, payload.data(), sizeof(data)));
}
186 
// Once a writer is finished, later Finish() calls fail instead of sending.
TEST_F(ServerWriterTest, Closed_IgnoresFinish) {
  const Status first_finish = writer_.Finish();
  const Status repeated_finish = writer_.Finish();

  EXPECT_EQ(OkStatus(), first_finish);
  EXPECT_EQ(Status::FailedPrecondition(), repeated_finish);
}
191 
// A default-constructed writer has no client stream and no pending
// completion request.
TEST_F(ServerWriterTest, DefaultConstructor_NoClientStream) {
  // Declared before the guard so that the lock is released before the writer
  // is destroyed.
  FakeServerWriter unclaimed;
  RpcLockGuard lock;

  auto& call = unclaimed.as_server_call();
  EXPECT_FALSE(call.has_client_stream());
  EXPECT_FALSE(call.client_requested_completion());
}
198 
// An open server writer streams only from server to client: it has a server
// stream, no client stream, and no completion request yet.
TEST_F(ServerWriterTest, Open_NoClientStream) {
  RpcLockGuard lock;

  auto& call = writer_.as_server_call();
  EXPECT_FALSE(call.has_client_stream());
  EXPECT_TRUE(call.has_server_stream());
  EXPECT_FALSE(call.client_requested_completion());
}
205 
206 class ServerReaderTest : public Test {
207  public:
ServerReaderTest()208   ServerReaderTest() : context_(TestService::method.method()) {
209     rpc_lock().lock();
210     FakeServerReader reader_temp(context_.get().ClaimLocked());
211     rpc_lock().unlock();
212     reader_ = std::move(reader_temp);
213   }
214 
215   ServerContextForTest<TestService> context_;
216   FakeServerReader reader_;
217 };
218 
// A default-constructed reader is inactive and has no completion request.
TEST_F(ServerReaderTest, DefaultConstructor_StreamClosed) {
  FakeServerReader unclaimed;
  auto& call = unclaimed.as_server_call();

  // active() is queried without holding rpc_lock().
  EXPECT_FALSE(call.active());

  RpcLockGuard lock;
  EXPECT_FALSE(call.client_requested_completion());
}
225 
// A freshly claimed reader has an open client stream, and the client has not
// yet requested completion.
TEST_F(ServerReaderTest, Open_ClientStreamStartsOpen) {
  RpcLockGuard lock;

  auto& call = reader_.as_server_call();
  EXPECT_TRUE(call.has_client_stream());
  EXPECT_FALSE(call.client_requested_completion());
}
231 
// Closing the call with a response deactivates it. Afterwards the call
// reports client_requested_completion() as true.
TEST_F(ServerReaderTest, Close_ClosesStream) {
  auto& call = reader_.as_server_call();
  EXPECT_TRUE(call.active());

  {
    RpcLockGuard guard;
    EXPECT_FALSE(call.client_requested_completion());
  }

  EXPECT_EQ(OkStatus(), call.CloseAndSendResponse(OkStatus()));
  EXPECT_FALSE(call.active());

  RpcLockGuard lock;
  EXPECT_TRUE(call.client_requested_completion());
}
244 
// A client completion request is recorded on the call but does not, by
// itself, deactivate the reader.
TEST_F(ServerReaderTest, RequestCompletion_OnlyMakesClientNotReady) {
  EXPECT_TRUE(reader_.active());

  rpc_lock().lock();
  auto& call = reader_.as_server_call();
  EXPECT_FALSE(call.client_requested_completion());
  // No explicit unlock follows. HandleClientRequestedCompletion() is relied
  // on to release rpc_lock() before it returns.
  call.HandleClientRequestedCompletion();

  EXPECT_TRUE(reader_.active());

  RpcLockGuard lock;
  EXPECT_TRUE(call.client_requested_completion());
}
255 
256 class ServerReaderWriterTest : public Test {
257  public:
ServerReaderWriterTest()258   ServerReaderWriterTest() : context_(TestService::method.method()) {
259     rpc_lock().lock();
260     FakeServerReaderWriter reader_writer_temp(context_.get().ClaimLocked());
261     rpc_lock().unlock();
262     reader_writer_ = std::move(reader_writer_temp);
263   }
264 
265   ServerContextForTest<TestService> context_;
266   FakeServerReaderWriter reader_writer_;
267 };
268 
// Move-assigning into an empty reader/writer carries over the open client
// stream without marking completion as requested.
TEST_F(ServerReaderWriterTest, Move_MaintainsClientStream) {
  FakeServerReaderWriter destination;

  {
    RpcLockGuard guard;
    EXPECT_FALSE(destination.as_server_call().client_requested_completion());
  }

  destination = std::move(reader_writer_);

  RpcLockGuard lock;
  auto& moved_call = destination.as_server_call();
  EXPECT_TRUE(moved_call.has_client_stream());
  EXPECT_FALSE(moved_call.client_requested_completion());
}
281 
// Moving a reader/writer transfers its on_error and on_next callbacks, plus
// the completion-requested callback when that feature is compiled in.
TEST_F(ServerReaderWriterTest, Move_MovesCallbacks) {
  int invocations = 0;
  // One counting callback shared by every handler; it ignores its arguments.
  const auto count_invocation = [&invocations](auto&&...) { invocations += 1; };
  reader_writer_.set_on_error(count_invocation);
  reader_writer_.set_on_next(count_invocation);
  reader_writer_.set_on_completion_requested_if_enabled(count_invocation);

  FakeServerReaderWriter destination(std::move(reader_writer_));
  auto& call = destination.as_server_call();

  // No unlock appears between these calls. Each Handle*() is entered with
  // rpc_lock() held and is relied on to release it, so the lock is taken
  // again before every call.
  rpc_lock().lock();
  call.HandlePayload({});
  rpc_lock().lock();
  call.HandleClientRequestedCompletion();
  rpc_lock().lock();
  call.HandleError(Status::Unknown());

  // on_next and on_error always count. The completion callback contributes
  // PW_RPC_COMPLETION_REQUEST_CALLBACK (0 or 1).
  EXPECT_EQ(invocations, 2 + PW_RPC_COMPLETION_REQUEST_CALLBACK);
}
299 
// After a move, the source no longer refers to any call ID or channel ID.
TEST_F(ServerReaderWriterTest, Move_ClearsCallAndChannelId) {
  {
    RpcLockGuard guard;
    reader_writer_.set_id(999);
    EXPECT_NE(reader_writer_.channel_id_locked(), 0u);
  }

  FakeServerReaderWriter destination(std::move(reader_writer_));

  RpcLockGuard lock;
  EXPECT_EQ(reader_writer_.id(), 0u);
  EXPECT_EQ(reader_writer_.channel_id_locked(), 0u);
}
312 
// Moving out of a call that was closed and marked for cleanup runs that
// cleanup: the source's on_error fires exactly once with the close status.
TEST_F(ServerReaderWriterTest, Move_SourceAwaitingCleanup_CleansUpCalls) {
  std::optional<Status> reported_error;
  reader_writer_.set_on_error([&reported_error](Status error) {
    ASSERT_FALSE(reported_error.has_value());  // Must fire only once.
    reported_error = error;
  });

  {
    RpcLockGuard guard;
    context_.server().CloseCallAndMarkForCleanup(
        reader_writer_.as_server_call(), Status::NotFound());
  }

  FakeServerReaderWriter destination(std::move(reader_writer_));

  EXPECT_EQ(Status::NotFound(), reported_error);
}
329 
// When both the destination and the source of a move-assignment are awaiting
// cleanup, each call's on_error fires once with its own close status.
TEST_F(ServerReaderWriterTest, Move_BothAwaitingCleanup_CleansUpCalls) {
  // Call ID 123 keeps the destination distinct from the fixture's call.
  // Manual locking, because the claimed object outlives the locked region.
  rpc_lock().lock();
  FakeServerReaderWriter destination(context_.get(123).ClaimLocked());
  rpc_lock().unlock();

  std::optional<Status> destination_error;
  destination.set_on_error([&destination_error](Status error) {
    ASSERT_FALSE(destination_error.has_value());  // Must fire only once.
    destination_error = error;
  });

  std::optional<Status> source_error;
  reader_writer_.set_on_error([&source_error](Status error) {
    ASSERT_FALSE(source_error.has_value());  // Must fire only once.
    source_error = error;
  });

  // Simulate another thread closing both calls before the move happens.
  {
    RpcLockGuard guard;
    auto& server = context_.server();
    server.CloseCallAndMarkForCleanup(destination.as_server_call(),
                                      Status::NotFound());
    server.CloseCallAndMarkForCleanup(reader_writer_.as_server_call(),
                                      Status::Unauthenticated());
  }

  destination = std::move(reader_writer_);

  EXPECT_EQ(Status::NotFound(), destination_error);
  EXPECT_EQ(Status::Unauthenticated(), source_error);
}
361 
// Finishing a reader/writer clears both its call ID and its channel ID.
TEST_F(ServerReaderWriterTest, Close_ClearsCallAndChannelId) {
  {
    RpcLockGuard guard;
    reader_writer_.set_id(999);
    EXPECT_NE(reader_writer_.channel_id_locked(), 0u);
  }

  const Status finish_status = reader_writer_.Finish();
  EXPECT_EQ(OkStatus(), finish_status);

  RpcLockGuard lock;
  EXPECT_EQ(reader_writer_.id(), 0u);
  EXPECT_EQ(reader_writer_.channel_id_locked(), 0u);
}
374 
375 }  // namespace
376 }  // namespace internal
377 }  // namespace pw::rpc
378