1 // Copyright 2020 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14
#include "pw_rpc/server.h"

#include <array>
#include <cstdint>
#include <cstring>
#include <tuple>
#include <utility>

#include "pw_assert/check.h"
#include "pw_rpc/internal/call.h"
#include "pw_rpc/internal/method.h"
#include "pw_rpc/internal/packet.h"
#include "pw_rpc/internal/test_utils.h"
#include "pw_rpc/service.h"
#include "pw_rpc_private/fake_server_reader_writer.h"
#include "pw_rpc_private/test_method.h"
#include "pw_unit_test/framework.h"
29
30 namespace pw::rpc {
31
32 class ServerTestHelper {
33 public:
FindMethod(Server & server,uint32_t service_id,uint32_t method_id)34 static std::tuple<Service*, const internal::Method*> FindMethod(
35 Server& server, uint32_t service_id, uint32_t method_id) {
36 return server.FindMethod(service_id, method_id);
37 }
38 };
39
40 namespace {
41
42 using std::byte;
43
44 using internal::Packet;
45 using internal::TestMethod;
46 using internal::TestMethodUnion;
47 using internal::pwpb::PacketType;
48
49 class TestService : public Service {
50 public:
TestService(uint32_t service_id)51 TestService(uint32_t service_id)
52 : Service(service_id, methods_),
53 methods_{
54 TestMethod(100, MethodType::kBidirectionalStreaming),
55 TestMethod(200),
56 } {}
57
method(uint32_t id)58 const TestMethod& method(uint32_t id) {
59 for (TestMethodUnion& method : methods_) {
60 if (method.method().id() == id) {
61 return method.test_method();
62 }
63 }
64
65 PW_CRASH("Invalid method ID %u", static_cast<unsigned>(id));
66 }
67
68 private:
69 std::array<TestMethodUnion, 2> methods_;
70 };
71
72 class EmptyService : public Service {
73 public:
EmptyService()74 constexpr EmptyService() : Service(200, methods_) {}
75
76 private:
77 static constexpr std::array<TestMethodUnion, 0> methods_ = {};
78 };
79
// Call ID used whenever a test does not specify one. It is constexpr so this
// shared default cannot be modified by any test; it was previously a mutable
// global despite its constant-style name.
constexpr uint32_t kDefaultCallId = 24601;
81
82 class BasicServer : public ::testing::Test {
83 protected:
84 static constexpr byte kDefaultPayload[] = {
85 byte(0x82), byte(0x02), byte(0xff), byte(0xff)};
86
BasicServer()87 BasicServer()
88 : channels_{
89 Channel::Create<1>(&output_),
90 Channel::Create<2>(&output_),
91 Channel(), // available for assignment
92 },
93 server_(channels_),
94 service_1_(1),
95 service_42_(42) {
96 server_.RegisterService(service_1_, service_42_, empty_service_);
97 }
98
EncodePacket(PacketType type,uint32_t channel_id,uint32_t service_id,uint32_t method_id,uint32_t call_id=kDefaultCallId)99 span<const byte> EncodePacket(PacketType type,
100 uint32_t channel_id,
101 uint32_t service_id,
102 uint32_t method_id,
103 uint32_t call_id = kDefaultCallId) {
104 return EncodePacketWithBody(type,
105 channel_id,
106 service_id,
107 method_id,
108 call_id,
109 kDefaultPayload,
110 OkStatus());
111 }
112
EncodeCancel(uint32_t channel_id=1,uint32_t service_id=42,uint32_t method_id=100,uint32_t call_id=kDefaultCallId)113 span<const byte> EncodeCancel(uint32_t channel_id = 1,
114 uint32_t service_id = 42,
115 uint32_t method_id = 100,
116 uint32_t call_id = kDefaultCallId) {
117 return EncodePacketWithBody(PacketType::CLIENT_ERROR,
118 channel_id,
119 service_id,
120 method_id,
121 call_id,
122 {},
123 Status::Cancelled());
124 }
125
126 template <typename T = ConstByteSpan>
PacketForRpc(PacketType type,Status status=OkStatus (),T && payload={},uint32_t call_id=kDefaultCallId)127 ConstByteSpan PacketForRpc(PacketType type,
128 Status status = OkStatus(),
129 T&& payload = {},
130 uint32_t call_id = kDefaultCallId) {
131 return EncodePacketWithBody(
132 type, 1, 42, 100, call_id, as_bytes(span(payload)), status);
133 }
134
135 RawFakeChannelOutput<2> output_;
136 std::array<Channel, 3> channels_;
137 Server server_;
138 TestService service_1_;
139 TestService service_42_;
140 EmptyService empty_service_;
141
142 private:
143 byte request_buffer_[64];
144
EncodePacketWithBody(PacketType type,uint32_t channel_id,uint32_t service_id,uint32_t method_id,uint32_t call_id,span<const byte> payload,Status status)145 span<const byte> EncodePacketWithBody(PacketType type,
146 uint32_t channel_id,
147 uint32_t service_id,
148 uint32_t method_id,
149 uint32_t call_id,
150 span<const byte> payload,
151 Status status) {
152 auto result =
153 Packet(
154 type, channel_id, service_id, method_id, call_id, payload, status)
155 .Encode(request_buffer_);
156 EXPECT_EQ(OkStatus(), result.status());
157 return result.value_or(ConstByteSpan());
158 }
159 };
160
TEST_F(BasicServer, IsServiceRegistered) {
  // A service that exists but was never passed to RegisterService().
  TestService never_registered(0);

  const bool unregistered_found = server_.IsServiceRegistered(never_registered);
  const bool registered_found = server_.IsServiceRegistered(service_1_);

  EXPECT_FALSE(unregistered_found);
  EXPECT_TRUE(registered_found);
}
166
TEST_F(BasicServer, ProcessPacket_ValidMethodInService1_InvokesMethod) {
  const auto request = EncodePacket(PacketType::REQUEST, 1, 1, 100);
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(request));

  // Method 100 of service 1 should have seen the request on channel 1,
  // carrying the default payload byte-for-byte.
  const TestMethod& invoked = service_1_.method(100);
  EXPECT_EQ(1u, invoked.last_channel_id());

  const auto received = invoked.last_request().payload();
  ASSERT_EQ(sizeof(kDefaultPayload), received.size());
  EXPECT_EQ(std::memcmp(kDefaultPayload, received.data(), received.size()), 0);
}
180
TEST_F(BasicServer, ProcessPacket_ValidMethodInService42_InvokesMethod) {
  const Status result =
      server_.ProcessPacket(EncodePacket(PacketType::REQUEST, 1, 42, 200));
  EXPECT_EQ(OkStatus(), result);

  // Method 200 of service 42 receives the request verbatim on channel 1.
  const TestMethod& target = service_42_.method(200);
  EXPECT_EQ(1u, target.last_channel_id());

  const auto payload = target.last_request().payload();
  ASSERT_EQ(sizeof(kDefaultPayload), payload.size());
  EXPECT_EQ(std::memcmp(kDefaultPayload, payload.data(), payload.size()), 0);
}
194
TEST_F(BasicServer, UnregisterService_CannotCallMethod) {
  const uint32_t kCallId = 8675309;
  server_.UnregisterService(service_1_, service_42_);

  const auto request = EncodePacket(PacketType::REQUEST, 1, 1, 100, kCallId);
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(request));

  // The server still replies, but with a NOT_FOUND error echoing the IDs.
  auto& fake_output = static_cast<internal::test::FakeChannelOutput&>(output_);
  const Packet& reply = fake_output.last_packet();
  EXPECT_EQ(PacketType::SERVER_ERROR, reply.type());
  EXPECT_EQ(1u, reply.channel_id());
  EXPECT_EQ(1u, reply.service_id());
  EXPECT_EQ(100u, reply.method_id());
  EXPECT_EQ(kCallId, reply.call_id());
  EXPECT_EQ(Status::NotFound(), reply.status());
}
212
TEST_F(BasicServer, UnregisterService_AlreadyUnregistered_DoesNothing) {
  // Removing the same service repeatedly must be harmless.
  server_.UnregisterService(service_42_, service_42_, service_42_);
  server_.UnregisterService(service_42_);

  // Service 1 is still reachable and receives the request unchanged.
  const auto request = EncodePacket(PacketType::REQUEST, 1, 1, 100);
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(request));

  const TestMethod& still_registered = service_1_.method(100);
  EXPECT_EQ(1u, still_registered.last_channel_id());

  const auto bytes = still_registered.last_request().payload();
  ASSERT_EQ(sizeof(kDefaultPayload), bytes.size());
  EXPECT_EQ(std::memcmp(kDefaultPayload, bytes.data(), bytes.size()), 0);
}
229
TEST_F(BasicServer, ProcessPacket_IncompletePacket_NothingIsInvoked) {
  // Each packet has exactly one of its IDs set to 0.
  struct Ids {
    uint32_t channel;
    uint32_t service;
    uint32_t method;
  };
  constexpr Ids kIncompleteIds[] = {
      {0, 42, 101},  // channel ID 0
      {1, 0, 101},   // service ID 0
      {1, 42, 0},    // method ID 0
  };

  for (const Ids& ids : kIncompleteIds) {
    EXPECT_EQ(Status::DataLoss(),
              server_.ProcessPacket(EncodePacket(
                  PacketType::REQUEST, ids.channel, ids.service, ids.method)));
  }

  // Neither method of service 42 was invoked.
  EXPECT_EQ(0u, service_42_.method(100).last_channel_id());
  EXPECT_EQ(0u, service_42_.method(200).last_channel_id());
}
243
TEST_F(BasicServer, ProcessPacket_NoChannel_SendsNothing) {
  // Channel ID 0: rejected with DATA_LOSS and no reply of any kind.
  const auto request =
      EncodePacket(PacketType::REQUEST, /*channel_id=*/0, 42, 101);
  EXPECT_EQ(Status::DataLoss(), server_.ProcessPacket(request));

  EXPECT_EQ(0u, output_.total_packets());
}
251
TEST_F(BasicServer, ProcessPacket_NoService_SendsNothing) {
  // Service ID 0: rejected with DATA_LOSS and no reply of any kind.
  const auto request =
      EncodePacket(PacketType::REQUEST, 1, /*service_id=*/0, 101);
  EXPECT_EQ(Status::DataLoss(), server_.ProcessPacket(request));

  EXPECT_EQ(0u, output_.total_packets());
}
259
TEST_F(BasicServer, ProcessPacket_NoMethod_SendsNothing) {
  // Method ID 0: rejected with DATA_LOSS and no reply of any kind.
  const auto request =
      EncodePacket(PacketType::REQUEST, 1, 42, /*method_id=*/0);
  EXPECT_EQ(Status::DataLoss(), server_.ProcessPacket(request));

  EXPECT_EQ(0u, output_.total_packets());
}
266
TEST_F(BasicServer, ProcessPacket_InvalidMethod_NothingIsInvoked) {
  // Service 42 has no method 101, so neither of its real methods may run.
  const auto request = EncodePacket(PacketType::REQUEST, 1, 42, 101);
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(request));

  constexpr uint32_t kExistingMethodIds[] = {100, 200};
  for (const uint32_t id : kExistingMethodIds) {
    EXPECT_EQ(0u, service_42_.method(id).last_channel_id());
  }
}
275
TEST_F(BasicServer, ProcessPacket_ClientErrorWithInvalidMethod_NoResponse) {
  // A client error aimed at a nonexistent method is accepted silently.
  const auto client_error = EncodePacket(PacketType::CLIENT_ERROR, 1, 42, 101);
  const Status result = server_.ProcessPacket(client_error);

  EXPECT_EQ(OkStatus(), result);
  EXPECT_EQ(0u, output_.total_packets());
}
283
TEST_F(BasicServer, ProcessPacket_InvalidMethod_SendsError) {
  constexpr uint32_t kMissingMethodId = 27;  // Service 42 has no method 27.
  EXPECT_EQ(OkStatus(),
            server_.ProcessPacket(EncodePacket(
                PacketType::REQUEST, 1, 42, kMissingMethodId)));

  // The server answers with NOT_FOUND addressed back to the same IDs.
  auto& fake_output = static_cast<internal::test::FakeChannelOutput&>(output_);
  const Packet& error = fake_output.last_packet();
  EXPECT_EQ(PacketType::SERVER_ERROR, error.type());
  EXPECT_EQ(1u, error.channel_id());
  EXPECT_EQ(42u, error.service_id());
  EXPECT_EQ(kMissingMethodId, error.method_id());
  EXPECT_EQ(Status::NotFound(), error.status());
}
297
TEST_F(BasicServer, ProcessPacket_InvalidService_SendsError) {
  constexpr uint32_t kMissingServiceId = 43;  // No service 43 is registered.
  EXPECT_EQ(OkStatus(),
            server_.ProcessPacket(EncodePacket(
                PacketType::REQUEST, 1, kMissingServiceId, 27)));

  // The server answers with NOT_FOUND addressed back to the same IDs.
  auto& fake_output = static_cast<internal::test::FakeChannelOutput&>(output_);
  const Packet& error = fake_output.last_packet();
  EXPECT_EQ(PacketType::SERVER_ERROR, error.type());
  EXPECT_EQ(1u, error.channel_id());
  EXPECT_EQ(kMissingServiceId, error.service_id());
  EXPECT_EQ(27u, error.method_id());
  EXPECT_EQ(Status::NotFound(), error.status());
}
311
TEST_F(BasicServer, ProcessPacket_UnassignedChannel) {
  // Channel 99 was never opened.
  constexpr uint32_t kUnknownChannelId = 99;
  const auto request =
      EncodePacket(PacketType::REQUEST, kUnknownChannelId, 42, 27);

  EXPECT_EQ(Status::Unavailable(), server_.ProcessPacket(request));
}
317
TEST_F(BasicServer, ProcessPacket_ClientErrorOnUnassignedChannel_NoResponse) {
  // Occupy the fixture's only unassigned channel slot before sending on an
  // unknown channel.
  channels_[2] = Channel::Create<3>(&output_);

  const auto client_error =
      EncodePacket(PacketType::CLIENT_ERROR, /*channel_id=*/99, 42, 27);
  EXPECT_EQ(Status::Unavailable(), server_.ProcessPacket(client_error));

  EXPECT_EQ(0u, output_.total_packets());
}
327
TEST_F(BasicServer, ProcessPacket_Cancel_MethodNotActive_SendsNothing) {
  // No call to method 100 of service 42 has been started, so there is nothing
  // to cancel. The cancel packet is accepted and no response is sent.
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(EncodeCancel(1, 42, 100)));

  EXPECT_EQ(output_.total_packets(), 0u);
}
334
GetChannel(internal::Endpoint & endpoint,uint32_t id)335 const Channel* GetChannel(internal::Endpoint& endpoint, uint32_t id) {
336 internal::RpcLockGuard lock;
337 return endpoint.GetInternalChannel(id);
338 }
339
TEST_F(BasicServer, CloseChannel_NoCalls) {
  constexpr uint32_t kChannelId = 2;

  // The channel exists, closes cleanly, then disappears; nothing is sent.
  EXPECT_NE(nullptr, GetChannel(server_, kChannelId));
  EXPECT_EQ(OkStatus(), server_.CloseChannel(kChannelId));
  EXPECT_EQ(nullptr, GetChannel(server_, kChannelId));
  ASSERT_EQ(0u, output_.total_packets());
}
346
TEST_F(BasicServer, CloseChannel_UnknownChannel) {
  constexpr uint32_t kUnknownChannelId = 13579;

  ASSERT_EQ(nullptr, GetChannel(server_, kUnknownChannelId));
  const Status result = server_.CloseChannel(kUnknownChannelId);
  EXPECT_EQ(Status::NotFound(), result);
}
351
TEST_F(BasicServer, CloseChannel_PendingCall) {
  internal::Endpoint& endpoint = server_;

  EXPECT_NE(nullptr, GetChannel(server_, 1));
  EXPECT_EQ(endpoint.active_call_count(), 0u);

  // Have method 100 of service 42 keep its call open in `call`.
  internal::test::FakeServerReaderWriter call;
  service_42_.method(100).keep_call_active(call);

  const auto request = EncodePacket(PacketType::REQUEST, 1, 42, 100);
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(request));

  Status reported_error;
  call.set_on_error([&reported_error](Status error) { reported_error = error; });

  ASSERT_TRUE(call.active());
  EXPECT_EQ(endpoint.active_call_count(), 1u);

  // Closing the channel removes it and tears down the call on it.
  EXPECT_EQ(OkStatus(), server_.CloseChannel(1));
  EXPECT_EQ(nullptr, GetChannel(server_, 1));
  EXPECT_EQ(endpoint.active_call_count(), 0u);

  // on_error fires with ABORTED, but nothing is sent on the closed channel.
  EXPECT_EQ(Status::Aborted(), reported_error);
  ASSERT_EQ(output_.total_packets(), 0u);
}
379
TEST_F(BasicServer, OpenChannel_UnusedSlot) {
  constexpr uint32_t kNewChannelId = 9;
  const auto request =
      EncodePacket(PacketType::REQUEST, kNewChannelId, 42, 100);

  // Rejected while channel 9 does not exist...
  EXPECT_EQ(Status::Unavailable(), server_.ProcessPacket(request));

  // ...and handled once it has been opened.
  EXPECT_EQ(OkStatus(), server_.OpenChannel(kNewChannelId, output_));
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(request));

  auto& fake_output = static_cast<internal::test::FakeChannelOutput&>(output_);
  const Packet& response = fake_output.last_packet();
  EXPECT_EQ(PacketType::RESPONSE, response.type());
  EXPECT_EQ(kNewChannelId, response.channel_id());
  EXPECT_EQ(42u, response.service_id());
  EXPECT_EQ(100u, response.method_id());
}
394
TEST_F(BasicServer, OpenChannel_AlreadyExists) {
  // Channel 1 is opened by the fixture, so reopening it must fail.
  ASSERT_NE(nullptr, GetChannel(server_, 1));

  const Status result = server_.OpenChannel(1, output_);
  EXPECT_EQ(Status::AlreadyExists(), result);
}
399
TEST_F(BasicServer, OpenChannel_AdditionalSlot) {
  // Channel 3 takes the fixture's last unassigned slot.
  EXPECT_EQ(OkStatus(), server_.OpenChannel(3, output_));

  // A further channel only fits when dynamic allocation is enabled.
#if PW_RPC_DYNAMIC_ALLOCATION
  constexpr Status kExpected = OkStatus();
#else
  constexpr Status kExpected = Status::ResourceExhausted();
#endif
  EXPECT_EQ(kExpected, server_.OpenChannel(19823, output_));
}
407
TEST_F(BasicServer, FindMethod_FoundOkOptionallyCheckType) {
  // Method 100 of service 1 exists. Its type is only checked when methods
  // store their type.
  const auto [found_service, found_method] =
      ServerTestHelper::FindMethod(server_, 1, 100);

  ASSERT_TRUE(found_service != nullptr);
  ASSERT_TRUE(found_method != nullptr);
#if PW_RPC_METHOD_STORES_TYPE
  EXPECT_EQ(MethodType::kBidirectionalStreaming, found_method->type());
#endif
}
416
TEST_F(BasicServer, FindMethod_NotFound) {
  // Service 2 is not registered: neither a service nor a method is found.
  const auto [absent_service, method_of_absent_service] =
      ServerTestHelper::FindMethod(server_, 2, 100);
  ASSERT_TRUE(absent_service == nullptr);
  ASSERT_TRUE(method_of_absent_service == nullptr);

  // Service 1 exists but has no method 101: only the service is found.
  const auto [present_service, absent_method] =
      ServerTestHelper::FindMethod(server_, 1, 101);
  ASSERT_TRUE(present_service != nullptr);
  ASSERT_TRUE(absent_method == nullptr);
}
432
433 class BidiMethod : public BasicServer {
434 protected:
BidiMethod()435 BidiMethod() {
436 internal::rpc_lock().lock();
437 internal::CallContext context(server_,
438 channels_[0].id(),
439 service_42_,
440 service_42_.method(100),
441 kDefaultCallId);
442 // A local temporary is required since the constructor requires a lock,
443 // but the *move* constructor takes out the lock.
444 internal::test::FakeServerReaderWriter responder_temp(
445 context.ClaimLocked());
446 internal::rpc_lock().unlock();
447 responder_ = std::move(responder_temp);
448 PW_CHECK(responder_.active());
449 }
450
451 internal::test::FakeServerReaderWriter responder_;
452 };
453
TEST_F(BidiMethod, DuplicateCallId_CancelsExistingThenCallsAgain) {
  int cancelled = 0;
  responder_.set_on_error([&cancelled](Status error) {
    if (!error.IsCancelled()) {
      return;
    }
    ++cancelled;
  });

  const TestMethod& method = service_42_.method(100);
  ASSERT_EQ(method.invocations(), 0u);

  // A REQUEST reusing the open call's ID cancels that call, then invokes the
  // method again.
  const auto duplicate_request = PacketForRpc(PacketType::REQUEST);
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(duplicate_request));

  EXPECT_EQ(1, cancelled);
  EXPECT_EQ(1u, method.invocations());
}
471
TEST_F(BidiMethod, DuplicateMethodDifferentCallId_NotCancelled) {
  int cancelled = 0;
  responder_.set_on_error([&cancelled](Status error) {
    if (!error.IsCancelled()) {
      return;
    }
    ++cancelled;
  });

  // A second REQUEST to the same method under a new call ID must leave the
  // existing call alone.
  const uint32_t kSecondCallId = 1625;
  const auto second_request =
      PacketForRpc(PacketType::REQUEST, OkStatus(), {}, kSecondCallId);
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(second_request));

  EXPECT_EQ(0, cancelled);
}
487
span_as_cstr(ConstByteSpan span)488 const char* span_as_cstr(ConstByteSpan span) {
489 return reinterpret_cast<const char*>(span.data());
490 }
491
TEST_F(BidiMethod, DuplicateMethodDifferentCallIdEachCallGetsSeparateResponse) {
  const uint32_t kSecondCallId = 1625;

  // Open a second call to the same method under a different call ID.
  internal::rpc_lock().lock();
  internal::test::FakeServerReaderWriter responder_2(
      internal::CallContext(server_,
                            channels_[0].id(),
                            service_42_,
                            service_42_.method(100),
                            kSecondCallId)
          .ClaimLocked());
  internal::rpc_lock().unlock();

  ConstByteSpan data_1 = as_bytes(span("data_1_unset"));
  responder_.set_on_next(
      [&data_1](ConstByteSpan payload) { data_1 = payload; });

  ConstByteSpan data_2 = as_bytes(span("data_2_unset"));
  responder_2.set_on_next(
      [&data_2](ConstByteSpan payload) { data_2 = payload; });

  // Character arrays (not pointers), so each can be sent directly as a payload
  // (NUL terminator included) and compared with what the call received.
  constexpr char kMessage1[] = "hello_1";
  constexpr char kMessage2[] = "hello_2";

  EXPECT_EQ(OkStatus(),
            server_.ProcessPacket(PacketForRpc(
                PacketType::CLIENT_STREAM, OkStatus(), kMessage2, kSecondCallId)));

  EXPECT_STREQ(span_as_cstr(data_2), kMessage2);
  // The first call must not have received the second call's message.
  EXPECT_STREQ(span_as_cstr(data_1), "data_1_unset");

  EXPECT_EQ(OkStatus(),
            server_.ProcessPacket(PacketForRpc(
                PacketType::CLIENT_STREAM, OkStatus(), kMessage1, kDefaultCallId)));

  // data_2 is not re-checked here: it may view the fixture's shared request
  // buffer, which the second encode overwrote.
  EXPECT_STREQ(span_as_cstr(data_1), kMessage1);
}
530
TEST_F(BidiMethod, Cancel_ClosesServerWriter) {
  const auto cancel = EncodeCancel();
  const Status result = server_.ProcessPacket(cancel);

  EXPECT_EQ(OkStatus(), result);
  EXPECT_FALSE(responder_.active());
}
536
TEST_F(BidiMethod, Cancel_SendsNoResponse) {
  // A client cancellation is never acknowledged with a packet.
  const auto cancel = EncodeCancel();
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(cancel));

  EXPECT_EQ(0u, output_.total_packets());
}
542
TEST_F(BidiMethod, ClientError_ClosesServerWriterWithoutResponse) {
  const auto client_error = PacketForRpc(PacketType::CLIENT_ERROR);
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(client_error));

  // The call ends, and nothing is sent back.
  EXPECT_FALSE(responder_.active());
  EXPECT_EQ(0u, output_.total_packets());
}
550
TEST_F(BidiMethod, ClientError_CallsOnErrorCallback) {
  Status reported = Status::Unknown();
  responder_.set_on_error([&reported](Status error) { reported = error; });

  // The status carried by the client error is forwarded to on_error.
  const auto client_error =
      PacketForRpc(PacketType::CLIENT_ERROR, Status::Unauthenticated());
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(client_error));

  EXPECT_EQ(Status::Unauthenticated(), reported);
}
561
TEST_F(BidiMethod, Cancel_CallsOnErrorCallback) {
  Status reported = Status::Unknown();
  responder_.set_on_error([&reported](Status error) { reported = error; });

  const auto cancel = EncodeCancel();
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(cancel));

  EXPECT_EQ(Status::Cancelled(), reported);
}
569
TEST_F(BidiMethod, Cancel_IncorrectChannel_SendsNothing) {
  // The open call is on channel 1; cancelling on channel 2 must not touch it.
  constexpr uint32_t kOtherChannelId = 2;
  EXPECT_EQ(OkStatus(),
            server_.ProcessPacket(EncodeCancel(kOtherChannelId, 42, 100)));

  EXPECT_EQ(0u, output_.total_packets());
  EXPECT_TRUE(responder_.active());
}
576
TEST_F(BidiMethod, Cancel_IncorrectService_SendsNothing) {
  // The open call belongs to service 42; cancelling service 43 is a no-op.
  constexpr uint32_t kOtherServiceId = 43;
  EXPECT_EQ(OkStatus(),
            server_.ProcessPacket(EncodeCancel(1, kOtherServiceId, 100)));

  EXPECT_EQ(0u, output_.total_packets());
  EXPECT_TRUE(responder_.active());
}
582
TEST_F(BidiMethod, Cancel_IncorrectMethod_SendsNothing) {
  // The open call is for method 100; cancelling method 101 is a no-op.
  constexpr uint32_t kOtherMethodId = 101;
  EXPECT_EQ(OkStatus(),
            server_.ProcessPacket(EncodeCancel(1, 42, kOtherMethodId)));

  EXPECT_EQ(0u, output_.total_packets());
  EXPECT_TRUE(responder_.active());
}
588
TEST_F(BidiMethod, ClientStream_CallsCallback) {
  ConstByteSpan received = as_bytes(span("?"));
  responder_.set_on_next(
      [&received](ConstByteSpan payload) { received = payload; });

  const auto stream_packet =
      PacketForRpc(PacketType::CLIENT_STREAM, {}, "hello");
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(stream_packet));

  // Delivered to on_next without any reply.
  EXPECT_EQ(0u, output_.total_packets());
  EXPECT_STREQ(span_as_cstr(received), "hello");
}
600
TEST_F(BidiMethod, ClientStream_CallsCallbackOnCallWithOpenId) {
  ConstByteSpan received = as_bytes(span("?"));
  responder_.set_on_next(
      [&received](ConstByteSpan payload) { received = payload; });

  // A packet sent with kOpenCallId still reaches the existing call.
  const auto stream_packet = PacketForRpc(
      PacketType::CLIENT_STREAM, {}, "hello", internal::kOpenCallId);
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(stream_packet));

  EXPECT_EQ(0u, output_.total_packets());
  EXPECT_STREQ(span_as_cstr(received), "hello");
}
613
TEST_F(BidiMethod, ClientStream_CallsCallbackOnCallWithLegacyOpenId) {
  ConstByteSpan received = as_bytes(span("?"));
  responder_.set_on_next(
      [&received](ConstByteSpan payload) { received = payload; });

  // A packet sent with kLegacyOpenCallId still reaches the existing call.
  const auto stream_packet = PacketForRpc(
      PacketType::CLIENT_STREAM, {}, "hello", internal::kLegacyOpenCallId);
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(stream_packet));

  EXPECT_EQ(0u, output_.total_packets());
  EXPECT_STREQ(span_as_cstr(received), "hello");
}
627
TEST_F(BidiMethod, ClientStream_CallsOpenIdOnCallWithDifferentId) {
  const uint32_t kSecondCallId = 1625;

  // Replace the fixture's call with one opened under kOpenCallId.
  internal::CallContext context(server_,
                                channels_[0].id(),
                                service_42_,
                                service_42_.method(100),
                                internal::kOpenCallId);
  internal::rpc_lock().lock();
  internal::test::FakeServerReaderWriter claimed(context.ClaimLocked());
  internal::rpc_lock().unlock();
  responder_ = std::move(claimed);

  ConstByteSpan received = as_bytes(span("?"));
  responder_.set_on_next(
      [&received](ConstByteSpan payload) { received = payload; });

  const auto stream_packet =
      PacketForRpc(PacketType::CLIENT_STREAM, {}, "hello", kSecondCallId);
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(stream_packet));

  EXPECT_EQ(0u, output_.total_packets());
  EXPECT_STREQ(span_as_cstr(received), "hello");

  // The call now carries the ID from the packet it received.
  internal::RpcLockGuard lock;
  EXPECT_EQ(responder_.as_server_call().id(), kSecondCallId);
}
654
TEST_F(BidiMethod, ClientStream_CallsLegacyOpenIdOnCallWithDifferentId) {
  const uint32_t kSecondCallId = 1625;

  // Replace the fixture's call with one opened under kLegacyOpenCallId.
  internal::CallContext context(server_,
                                channels_[0].id(),
                                service_42_,
                                service_42_.method(100),
                                internal::kLegacyOpenCallId);
  internal::rpc_lock().lock();
  internal::test::FakeServerReaderWriter claimed(context.ClaimLocked());
  internal::rpc_lock().unlock();
  responder_ = std::move(claimed);

  ConstByteSpan received = as_bytes(span("?"));
  responder_.set_on_next(
      [&received](ConstByteSpan payload) { received = payload; });

  const auto stream_packet =
      PacketForRpc(PacketType::CLIENT_STREAM, {}, "hello", kSecondCallId);
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(stream_packet));

  EXPECT_EQ(0u, output_.total_packets());
  EXPECT_STREQ(span_as_cstr(received), "hello");

  // The call now carries the ID from the packet it received.
  internal::RpcLockGuard lock;
  EXPECT_EQ(responder_.as_server_call().id(), kSecondCallId);
}
681
// NOTE: "Unregsiter" is a typo for "Unregister". The name is left unchanged so
// the test's identity (and any filter that names it) stays stable.
TEST_F(BidiMethod, UnregsiterService_AbortsActiveCalls) {
  ASSERT_TRUE(responder_.active());

  Status reported = OkStatus();
  responder_.set_on_error([&reported](Status status) { reported = status; });

  // Unregistering service 42 must abort the call it is serving.
  server_.UnregisterService(service_42_);

  EXPECT_FALSE(responder_.active());
  EXPECT_EQ(Status::Aborted(), reported);
}
694
TEST_F(BidiMethod, ClientRequestedCompletion_CallsCallback) {
  bool completion_requested = false;
#if PW_RPC_COMPLETION_REQUEST_CALLBACK
  // Only registered when PW_RPC_COMPLETION_REQUEST_CALLBACK is enabled.
  responder_.set_on_completion_requested(
      [&completion_requested]() { completion_requested = true; });
#endif
  const auto request_completion =
      PacketForRpc(PacketType::CLIENT_REQUEST_COMPLETION);
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(request_completion));

  EXPECT_EQ(0u, output_.total_packets());
  // The callback ran if and only if the feature is compiled in.
  EXPECT_EQ(completion_requested, PW_RPC_COMPLETION_REQUEST_CALLBACK);
}
707
TEST_F(BidiMethod, ClientRequestedCompletion_CallsCallbackIfEnabled) {
  bool completion_requested = false;
  // Always compiles; the callback only fires when the feature is enabled.
  responder_.set_on_completion_requested_if_enabled(
      [&completion_requested]() { completion_requested = true; });

  const auto request_completion =
      PacketForRpc(PacketType::CLIENT_REQUEST_COMPLETION);
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(request_completion));

  EXPECT_EQ(0u, output_.total_packets());
  EXPECT_EQ(completion_requested, PW_RPC_COMPLETION_REQUEST_CALLBACK);
}
720
TEST_F(BidiMethod, ClientRequestedCompletion_ErrorWhenClosed) {
  const auto end = PacketForRpc(PacketType::CLIENT_REQUEST_COMPLETION);

  // Requesting completion twice is accepted both times and never answered.
  for (int attempt = 0; attempt < 2; ++attempt) {
    ASSERT_EQ(OkStatus(), server_.ProcessPacket(end));
  }

  ASSERT_EQ(0u, output_.total_packets());
}
728
TEST_F(BidiMethod, ClientRequestedCompletion_ErrorWhenAlreadyClosed) {
  // Close the call via a client cancellation first.
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(EncodeCancel()));
  EXPECT_FALSE(responder_.active());

  // A completion request for the closed call draws a single
  // FAILED_PRECONDITION error.
  const auto end = PacketForRpc(PacketType::CLIENT_REQUEST_COMPLETION);
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(end));

  ASSERT_EQ(1u, output_.total_packets());
  auto& fake_output = static_cast<internal::test::FakeChannelOutput&>(output_);
  const Packet& error = fake_output.last_packet();
  EXPECT_EQ(PacketType::SERVER_ERROR, error.type());
  EXPECT_EQ(Status::FailedPrecondition(), error.status());
}
742
743 class ServerStreamingMethod : public BasicServer {
744 protected:
ServerStreamingMethod()745 ServerStreamingMethod() {
746 internal::CallContext context(server_,
747 channels_[0].id(),
748 service_42_,
749 service_42_.method(100),
750 kDefaultCallId);
751 internal::rpc_lock().lock();
752 internal::test::FakeServerWriter responder_temp(context.ClaimLocked());
753 internal::rpc_lock().unlock();
754 responder_ = std::move(responder_temp);
755 PW_CHECK(responder_.active());
756 }
757
758 internal::test::FakeServerWriter responder_;
759 };
760
TEST_F(ServerStreamingMethod, ClientStream_InvalidArgumentError) {
  // A CLIENT_STREAM packet sent to this call is rejected with one
  // INVALID_ARGUMENT error.
  const auto stream_packet = PacketForRpc(PacketType::CLIENT_STREAM);
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(stream_packet));

  ASSERT_EQ(1u, output_.total_packets());
  auto& fake_output = static_cast<internal::test::FakeChannelOutput&>(output_);
  const Packet& error = fake_output.last_packet();
  EXPECT_EQ(PacketType::SERVER_ERROR, error.type());
  EXPECT_EQ(Status::InvalidArgument(), error.status());
}
771
TEST_F(ServerStreamingMethod, ClientRequestedCompletion_CallsCallback) {
  bool completion_requested = false;
#if PW_RPC_COMPLETION_REQUEST_CALLBACK
  // Only registered when PW_RPC_COMPLETION_REQUEST_CALLBACK is enabled.
  responder_.set_on_completion_requested(
      [&completion_requested]() { completion_requested = true; });
#endif

  const auto request_completion =
      PacketForRpc(PacketType::CLIENT_REQUEST_COMPLETION);
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(request_completion));

  EXPECT_EQ(0u, output_.total_packets());
  // The callback ran if and only if the feature is compiled in.
  EXPECT_EQ(completion_requested, PW_RPC_COMPLETION_REQUEST_CALLBACK);
}
785
TEST_F(ServerStreamingMethod,
       ClientRequestedCompletion_CallsCallbackIfEnabled) {
  bool completion_requested = false;
  // Always compiles; the callback only fires when the feature is enabled.
  responder_.set_on_completion_requested_if_enabled(
      [&completion_requested]() { completion_requested = true; });

  const auto request_completion =
      PacketForRpc(PacketType::CLIENT_REQUEST_COMPLETION);
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(request_completion));

  EXPECT_EQ(0u, output_.total_packets());
  EXPECT_EQ(completion_requested, PW_RPC_COMPLETION_REQUEST_CALLBACK);
}
799
TEST_F(ServerStreamingMethod, ClientRequestedCompletion_ErrorWhenClosed) {
  const auto end = PacketForRpc(PacketType::CLIENT_REQUEST_COMPLETION);

  // Requesting completion twice is accepted both times and never answered.
  for (int attempt = 0; attempt < 2; ++attempt) {
    ASSERT_EQ(OkStatus(), server_.ProcessPacket(end));
  }

  ASSERT_EQ(0u, output_.total_packets());
}
807
TEST_F(ServerStreamingMethod,
       ClientRequestedCompletion_ErrorWhenAlreadyClosed) {
  // Close the call via a client cancellation first.
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(EncodeCancel()));
  EXPECT_FALSE(responder_.active());

  // A completion request for the closed call draws a single
  // FAILED_PRECONDITION error.
  const auto end = PacketForRpc(PacketType::CLIENT_REQUEST_COMPLETION);
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(end));

  ASSERT_EQ(1u, output_.total_packets());
  auto& fake_output = static_cast<internal::test::FakeChannelOutput&>(output_);
  const Packet& error = fake_output.last_packet();
  EXPECT_EQ(PacketType::SERVER_ERROR, error.type());
  EXPECT_EQ(Status::FailedPrecondition(), error.status());
}
822
823 } // namespace
824 } // namespace pw::rpc
825