• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2020 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 
15 #include "pw_rpc/server.h"
16 
#include <array>
#include <cstdint>
#include <cstring>
#include <tuple>
#include <utility>
19 
20 #include "pw_assert/check.h"
21 #include "pw_rpc/internal/call.h"
22 #include "pw_rpc/internal/method.h"
23 #include "pw_rpc/internal/packet.h"
24 #include "pw_rpc/internal/test_utils.h"
25 #include "pw_rpc/service.h"
26 #include "pw_rpc_private/fake_server_reader_writer.h"
27 #include "pw_rpc_private/test_method.h"
28 #include "pw_unit_test/framework.h"
29 
30 namespace pw::rpc {
31 
32 class ServerTestHelper {
33  public:
FindMethod(Server & server,uint32_t service_id,uint32_t method_id)34   static std::tuple<Service*, const internal::Method*> FindMethod(
35       Server& server, uint32_t service_id, uint32_t method_id) {
36     return server.FindMethod(service_id, method_id);
37   }
38 };
39 
40 namespace {
41 
42 using std::byte;
43 
44 using internal::Packet;
45 using internal::TestMethod;
46 using internal::TestMethodUnion;
47 using internal::pwpb::PacketType;
48 
49 class TestService : public Service {
50  public:
TestService(uint32_t service_id)51   TestService(uint32_t service_id)
52       : Service(service_id, methods_),
53         methods_{
54             TestMethod(100, MethodType::kBidirectionalStreaming),
55             TestMethod(200),
56         } {}
57 
method(uint32_t id)58   const TestMethod& method(uint32_t id) {
59     for (TestMethodUnion& method : methods_) {
60       if (method.method().id() == id) {
61         return method.test_method();
62       }
63     }
64 
65     PW_CRASH("Invalid method ID %u", static_cast<unsigned>(id));
66   }
67 
68  private:
69   std::array<TestMethodUnion, 2> methods_;
70 };
71 
72 class EmptyService : public Service {
73  public:
EmptyService()74   constexpr EmptyService() : Service(200, methods_) {}
75 
76  private:
77   static constexpr std::array<TestMethodUnion, 0> methods_ = {};
78 };
79 
// Call ID used by the fixtures' packet helpers when a test does not pass one.
// constexpr: the original was a mutable global despite its constant naming.
constexpr uint32_t kDefaultCallId = 24601;
81 
82 class BasicServer : public ::testing::Test {
83  protected:
84   static constexpr byte kDefaultPayload[] = {
85       byte(0x82), byte(0x02), byte(0xff), byte(0xff)};
86 
BasicServer()87   BasicServer()
88       : channels_{
89             Channel::Create<1>(&output_),
90             Channel::Create<2>(&output_),
91             Channel(),  // available for assignment
92         },
93         server_(channels_),
94         service_1_(1),
95         service_42_(42) {
96     server_.RegisterService(service_1_, service_42_, empty_service_);
97   }
98 
EncodePacket(PacketType type,uint32_t channel_id,uint32_t service_id,uint32_t method_id,uint32_t call_id=kDefaultCallId)99   span<const byte> EncodePacket(PacketType type,
100                                 uint32_t channel_id,
101                                 uint32_t service_id,
102                                 uint32_t method_id,
103                                 uint32_t call_id = kDefaultCallId) {
104     return EncodePacketWithBody(type,
105                                 channel_id,
106                                 service_id,
107                                 method_id,
108                                 call_id,
109                                 kDefaultPayload,
110                                 OkStatus());
111   }
112 
EncodeCancel(uint32_t channel_id=1,uint32_t service_id=42,uint32_t method_id=100,uint32_t call_id=kDefaultCallId)113   span<const byte> EncodeCancel(uint32_t channel_id = 1,
114                                 uint32_t service_id = 42,
115                                 uint32_t method_id = 100,
116                                 uint32_t call_id = kDefaultCallId) {
117     return EncodePacketWithBody(PacketType::CLIENT_ERROR,
118                                 channel_id,
119                                 service_id,
120                                 method_id,
121                                 call_id,
122                                 {},
123                                 Status::Cancelled());
124   }
125 
126   template <typename T = ConstByteSpan>
PacketForRpc(PacketType type,Status status=OkStatus (),T && payload={},uint32_t call_id=kDefaultCallId)127   ConstByteSpan PacketForRpc(PacketType type,
128                              Status status = OkStatus(),
129                              T&& payload = {},
130                              uint32_t call_id = kDefaultCallId) {
131     return EncodePacketWithBody(
132         type, 1, 42, 100, call_id, as_bytes(span(payload)), status);
133   }
134 
135   RawFakeChannelOutput<2> output_;
136   std::array<Channel, 3> channels_;
137   Server server_;
138   TestService service_1_;
139   TestService service_42_;
140   EmptyService empty_service_;
141 
142  private:
143   byte request_buffer_[64];
144 
EncodePacketWithBody(PacketType type,uint32_t channel_id,uint32_t service_id,uint32_t method_id,uint32_t call_id,span<const byte> payload,Status status)145   span<const byte> EncodePacketWithBody(PacketType type,
146                                         uint32_t channel_id,
147                                         uint32_t service_id,
148                                         uint32_t method_id,
149                                         uint32_t call_id,
150                                         span<const byte> payload,
151                                         Status status) {
152     auto result =
153         Packet(
154             type, channel_id, service_id, method_id, call_id, payload, status)
155             .Encode(request_buffer_);
156     EXPECT_EQ(OkStatus(), result.status());
157     return result.value_or(ConstByteSpan());
158   }
159 };
160 
// Every service the fixture passes to RegisterService() is reported as
// registered; a service that was never registered is not.
TEST_F(BasicServer, IsServiceRegistered) {
  TestService unregisteredService(0);
  EXPECT_FALSE(server_.IsServiceRegistered(unregisteredService));

  // The fixture registers three services. Check all of them, not only the
  // first.
  EXPECT_TRUE(server_.IsServiceRegistered(service_1_));
  EXPECT_TRUE(server_.IsServiceRegistered(service_42_));
  EXPECT_TRUE(server_.IsServiceRegistered(empty_service_));
}
166 
// A REQUEST for service 1 / method 100 reaches that method. The method sees
// the originating channel and the exact request payload.
TEST_F(BasicServer, ProcessPacket_ValidMethodInService1_InvokesMethod) {
  EXPECT_EQ(
      OkStatus(),
      server_.ProcessPacket(EncodePacket(PacketType::REQUEST, 1, 1, 100)));

  const TestMethod& invoked = service_1_.method(100);
  EXPECT_EQ(1u, invoked.last_channel_id());

  const auto received = invoked.last_request().payload();
  ASSERT_EQ(sizeof(kDefaultPayload), received.size());
  EXPECT_EQ(0,
            std::memcmp(kDefaultPayload, received.data(), received.size()));
}
180 
// A REQUEST for service 42 / method 200 reaches that method. The method sees
// the originating channel and the exact request payload.
TEST_F(BasicServer, ProcessPacket_ValidMethodInService42_InvokesMethod) {
  EXPECT_EQ(
      OkStatus(),
      server_.ProcessPacket(EncodePacket(PacketType::REQUEST, 1, 42, 200)));

  const TestMethod& invoked = service_42_.method(200);
  EXPECT_EQ(1u, invoked.last_channel_id());

  const auto received = invoked.last_request().payload();
  ASSERT_EQ(sizeof(kDefaultPayload), received.size());
  EXPECT_EQ(0,
            std::memcmp(kDefaultPayload, received.data(), received.size()));
}
194 
// Once its services are unregistered, a request to one of them is answered
// with a SERVER_ERROR / NOT_FOUND packet that echoes the request's IDs.
TEST_F(BasicServer, UnregisterService_CannotCallMethod) {
  constexpr uint32_t kCallId = 8675309;
  server_.UnregisterService(service_1_, service_42_);

  EXPECT_EQ(OkStatus(),
            server_.ProcessPacket(
                EncodePacket(PacketType::REQUEST, 1, 1, 100, kCallId)));

  auto& fake_output = static_cast<internal::test::FakeChannelOutput&>(output_);
  const Packet& reply = fake_output.last_packet();
  EXPECT_EQ(reply.type(), PacketType::SERVER_ERROR);
  EXPECT_EQ(reply.channel_id(), 1u);
  EXPECT_EQ(reply.service_id(), 1u);
  EXPECT_EQ(reply.method_id(), 100u);
  EXPECT_EQ(reply.call_id(), kCallId);
  EXPECT_EQ(reply.status(), Status::NotFound());
}
212 
// Unregistering service 42 repeatedly (including in a single call) leaves
// service 1 registered and callable.
TEST_F(BasicServer, UnregisterService_AlreadyUnregistered_DoesNothing) {
  server_.UnregisterService(service_42_, service_42_, service_42_);
  server_.UnregisterService(service_42_);

  EXPECT_EQ(
      OkStatus(),
      server_.ProcessPacket(EncodePacket(PacketType::REQUEST, 1, 1, 100)));

  const TestMethod& invoked = service_1_.method(100);
  EXPECT_EQ(1u, invoked.last_channel_id());

  const auto received = invoked.last_request().payload();
  ASSERT_EQ(sizeof(kDefaultPayload), received.size());
  EXPECT_EQ(0,
            std::memcmp(kDefaultPayload, received.data(), received.size()));
}
229 
// Packets whose channel, service, or method ID is 0 are rejected with
// DATA_LOSS, and neither method of service 42 runs.
TEST_F(BasicServer, ProcessPacket_IncompletePacket_NothingIsInvoked) {
  constexpr Status kDataLoss = Status::DataLoss();

  // Each packet is encoded immediately before use because all encodes share
  // one buffer.
  EXPECT_EQ(kDataLoss,  // channel ID 0
            server_.ProcessPacket(EncodePacket(PacketType::REQUEST, 0, 42, 101)));
  EXPECT_EQ(kDataLoss,  // service ID 0
            server_.ProcessPacket(EncodePacket(PacketType::REQUEST, 1, 0, 101)));
  EXPECT_EQ(kDataLoss,  // method ID 0
            server_.ProcessPacket(EncodePacket(PacketType::REQUEST, 1, 42, 0)));

  EXPECT_EQ(0u, service_42_.method(100).last_channel_id());
  EXPECT_EQ(0u, service_42_.method(200).last_channel_id());
}
243 
// A packet with channel ID 0 is rejected with DATA_LOSS and gets no reply.
TEST_F(BasicServer, ProcessPacket_NoChannel_SendsNothing) {
  const auto request = EncodePacket(PacketType::REQUEST, 0, 42, 101);
  EXPECT_EQ(Status::DataLoss(), server_.ProcessPacket(request));

  EXPECT_EQ(output_.total_packets(), 0u);
}
251 
// A packet with service ID 0 is rejected with DATA_LOSS and gets no reply.
TEST_F(BasicServer, ProcessPacket_NoService_SendsNothing) {
  const auto request = EncodePacket(PacketType::REQUEST, 1, 0, 101);
  EXPECT_EQ(Status::DataLoss(), server_.ProcessPacket(request));

  EXPECT_EQ(output_.total_packets(), 0u);
}
259 
// A packet with method ID 0 is rejected with DATA_LOSS and gets no reply.
TEST_F(BasicServer, ProcessPacket_NoMethod_SendsNothing) {
  const auto request = EncodePacket(PacketType::REQUEST, 1, 42, 0);
  EXPECT_EQ(Status::DataLoss(), server_.ProcessPacket(request));

  EXPECT_EQ(output_.total_packets(), 0u);
}
266 
// A request for a method ID that service 42 does not have (101) is accepted
// by ProcessPacket(), but none of the service's methods runs.
TEST_F(BasicServer, ProcessPacket_InvalidMethod_NothingIsInvoked) {
  const auto request = EncodePacket(PacketType::REQUEST, 1, 42, 101);
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(request));

  EXPECT_EQ(0u, service_42_.method(100).last_channel_id());
  EXPECT_EQ(0u, service_42_.method(200).last_channel_id());
}
275 
// A CLIENT_ERROR for a nonexistent method is accepted silently: no packet is
// sent back.
TEST_F(BasicServer, ProcessPacket_ClientErrorWithInvalidMethod_NoResponse) {
  const auto client_error = EncodePacket(PacketType::CLIENT_ERROR, 1, 42, 101);
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(client_error));

  EXPECT_EQ(0u, output_.total_packets());
}
283 
// A REQUEST for an unknown method of a known service is answered with
// SERVER_ERROR / NOT_FOUND addressed to the requested IDs.
TEST_F(BasicServer, ProcessPacket_InvalidMethod_SendsError) {
  const auto request = EncodePacket(PacketType::REQUEST, 1, 42, 27);
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(request));

  auto& fake_output = static_cast<internal::test::FakeChannelOutput&>(output_);
  const Packet& reply = fake_output.last_packet();
  EXPECT_EQ(reply.type(), PacketType::SERVER_ERROR);
  EXPECT_EQ(reply.channel_id(), 1u);
  EXPECT_EQ(reply.service_id(), 42u);
  EXPECT_EQ(reply.method_id(), 27u);  // Service 42 has no method 27.
  EXPECT_EQ(reply.status(), Status::NotFound());
}
297 
// A REQUEST for an unregistered service is answered with SERVER_ERROR /
// NOT_FOUND addressed to the requested IDs.
TEST_F(BasicServer, ProcessPacket_InvalidService_SendsError) {
  const auto request = EncodePacket(PacketType::REQUEST, 1, 43, 27);
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(request));

  auto& fake_output = static_cast<internal::test::FakeChannelOutput&>(output_);
  const Packet& reply = fake_output.last_packet();
  EXPECT_EQ(reply.type(), PacketType::SERVER_ERROR);
  EXPECT_EQ(reply.channel_id(), 1u);
  EXPECT_EQ(reply.service_id(), 43u);  // No service is registered as 43.
  EXPECT_EQ(reply.method_id(), 27u);
  EXPECT_EQ(reply.status(), Status::NotFound());
}
311 
// A request on a channel ID the server does not know yields UNAVAILABLE.
TEST_F(BasicServer, ProcessPacket_UnassignedChannel) {
  constexpr uint32_t kUnknownChannelId = 99;
  const auto request =
      EncodePacket(PacketType::REQUEST, kUnknownChannelId, 42, 27);
  EXPECT_EQ(Status::Unavailable(), server_.ProcessPacket(request));
}
317 
// With every channel slot occupied, a CLIENT_ERROR on an unknown channel ID
// yields UNAVAILABLE and produces no outgoing packet.
TEST_F(BasicServer, ProcessPacket_ClientErrorOnUnassignedChannel_NoResponse) {
  // Fill the one slot the fixture leaves free.
  channels_[2] = Channel::Create<3>(&output_);

  constexpr uint32_t kUnknownChannelId = 99;
  const auto client_error =
      EncodePacket(PacketType::CLIENT_ERROR, kUnknownChannelId, 42, 27);
  EXPECT_EQ(Status::Unavailable(), server_.ProcessPacket(client_error));

  EXPECT_EQ(0u, output_.total_packets());
}
327 
// A cancellation (CLIENT_ERROR carrying CANCELLED) for a method with no active
// call is accepted and produces no response packet.
TEST_F(BasicServer, ProcessPacket_Cancel_MethodNotActive_SendsNothing) {
  // Deliberately no call is started here. The BasicServer fixture opens no
  // RPCs, so this cancel targets a method with nothing to cancel. (The
  // previous comment wrongly claimed a fake ServerWriter was set up.)
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(EncodeCancel(1, 42, 100)));

  EXPECT_EQ(output_.total_packets(), 0u);
}
334 
GetChannel(internal::Endpoint & endpoint,uint32_t id)335 const Channel* GetChannel(internal::Endpoint& endpoint, uint32_t id) {
336   internal::RpcLockGuard lock;
337   return endpoint.GetInternalChannel(id);
338 }
339 
// Closing an idle channel removes it and sends nothing.
TEST_F(BasicServer, CloseChannel_NoCalls) {
  constexpr uint32_t kChannelId = 2;
  EXPECT_NE(nullptr, GetChannel(server_, kChannelId));

  EXPECT_EQ(OkStatus(), server_.CloseChannel(kChannelId));

  EXPECT_EQ(nullptr, GetChannel(server_, kChannelId));
  ASSERT_EQ(output_.total_packets(), 0u);
}
346 
// Closing a channel ID the server never had reports NOT_FOUND.
TEST_F(BasicServer, CloseChannel_UnknownChannel) {
  constexpr uint32_t kUnknownChannelId = 13579;
  ASSERT_EQ(nullptr, GetChannel(server_, kUnknownChannelId));
  EXPECT_EQ(Status::NotFound(), server_.CloseChannel(kUnknownChannelId));
}
351 
// Closing a channel with an active call aborts that call. on_error receives
// ABORTED, but no packet is sent because the channel is already gone.
TEST_F(BasicServer, CloseChannel_PendingCall) {
  auto& endpoint = static_cast<internal::Endpoint&>(server_);

  EXPECT_NE(nullptr, GetChannel(server_, 1));
  EXPECT_EQ(endpoint.active_call_count(), 0u);

  // Have method 100 keep the incoming call open in `pending`.
  internal::test::FakeServerReaderWriter pending;
  service_42_.method(100).keep_call_active(pending);

  EXPECT_EQ(
      OkStatus(),
      server_.ProcessPacket(EncodePacket(PacketType::REQUEST, 1, 42, 100)));

  Status reported_error;
  pending.set_on_error(
      [&reported_error](Status error) { reported_error = error; });

  ASSERT_TRUE(pending.active());
  EXPECT_EQ(endpoint.active_call_count(), 1u);

  EXPECT_EQ(OkStatus(), server_.CloseChannel(1));
  EXPECT_EQ(nullptr, GetChannel(server_, 1));
  EXPECT_EQ(endpoint.active_call_count(), 0u);

  // on_error fires, yet nothing goes out on the closed channel.
  EXPECT_EQ(Status::Aborted(), reported_error);
  ASSERT_EQ(output_.total_packets(), 0u);
}
379 
// A request on channel 9 is rejected until OpenChannel(9) fills the free slot.
// After that, the same request succeeds and its response goes out on
// channel 9.
TEST_F(BasicServer, OpenChannel_UnusedSlot) {
  constexpr uint32_t kNewChannelId = 9;
  const span request =
      EncodePacket(PacketType::REQUEST, kNewChannelId, 42, 100);

  EXPECT_EQ(Status::Unavailable(), server_.ProcessPacket(request));

  EXPECT_EQ(OkStatus(), server_.OpenChannel(kNewChannelId, output_));
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(request));

  auto& fake_output = static_cast<internal::test::FakeChannelOutput&>(output_);
  const Packet& reply = fake_output.last_packet();
  EXPECT_EQ(reply.type(), PacketType::RESPONSE);
  EXPECT_EQ(reply.channel_id(), 9u);
  EXPECT_EQ(reply.service_id(), 42u);
  EXPECT_EQ(reply.method_id(), 100u);
}
394 
// Opening a channel whose ID is already in use reports ALREADY_EXISTS.
TEST_F(BasicServer, OpenChannel_AlreadyExists) {
  constexpr uint32_t kExistingChannelId = 1;
  ASSERT_NE(nullptr, GetChannel(server_, kExistingChannelId));
  EXPECT_EQ(Status::AlreadyExists(),
            server_.OpenChannel(kExistingChannelId, output_));
}
399 
// The free slot accepts channel 3. Another OpenChannel() then fails with
// RESOURCE_EXHAUSTED unless PW_RPC_DYNAMIC_ALLOCATION is enabled, in which
// case it is expected to succeed.
TEST_F(BasicServer, OpenChannel_AdditionalSlot) {
  EXPECT_EQ(OkStatus(), server_.OpenChannel(3, output_));

  constexpr bool kCanGrow = PW_RPC_DYNAMIC_ALLOCATION != 0;
  constexpr Status kExpected = kCanGrow ? OkStatus() : Status::ResourceExhausted();
  EXPECT_EQ(kExpected, server_.OpenChannel(19823, output_));
}
407 
// FindMethod() locates service 1 / method 100. When methods store their type,
// it is the bidirectional-streaming type given in TestService.
TEST_F(BasicServer, FindMethod_FoundOkOptionallyCheckType) {
  const auto [found_service, found_method] =
      ServerTestHelper::FindMethod(server_, 1, 100);
  ASSERT_TRUE(nullptr != found_service);
  ASSERT_TRUE(nullptr != found_method);
#if PW_RPC_METHOD_STORES_TYPE
  EXPECT_EQ(MethodType::kBidirectionalStreaming, found_method->type());
#endif
}
416 
// An unknown service yields (nullptr, nullptr). A known service with an
// unknown method yields (service, nullptr).
TEST_F(BasicServer, FindMethod_NotFound) {
  {  // Service 2 is not registered.
    const auto [found_service, found_method] =
        ServerTestHelper::FindMethod(server_, 2, 100);
    ASSERT_TRUE(nullptr == found_service);
    ASSERT_TRUE(nullptr == found_method);
  }

  {  // Service 1 exists but has no method 101.
    const auto [found_service, found_method] =
        ServerTestHelper::FindMethod(server_, 1, 101);
    ASSERT_TRUE(nullptr != found_service);
    ASSERT_TRUE(nullptr == found_method);
  }
}
432 
433 class BidiMethod : public BasicServer {
434  protected:
BidiMethod()435   BidiMethod() {
436     internal::rpc_lock().lock();
437     internal::CallContext context(server_,
438                                   channels_[0].id(),
439                                   service_42_,
440                                   service_42_.method(100),
441                                   kDefaultCallId);
442     // A local temporary is required since the constructor requires a lock,
443     // but the *move* constructor takes out the lock.
444     internal::test::FakeServerReaderWriter responder_temp(
445         context.ClaimLocked());
446     internal::rpc_lock().unlock();
447     responder_ = std::move(responder_temp);
448     PW_CHECK(responder_.active());
449   }
450 
451   internal::test::FakeServerReaderWriter responder_;
452 };
453 
// A new REQUEST reusing the active call's ID cancels the existing call once
// and invokes the method again.
TEST_F(BidiMethod, DuplicateCallId_CancelsExistingThenCallsAgain) {
  int cancel_count = 0;
  responder_.set_on_error([&cancel_count](Status error) {
    if (!error.IsCancelled()) {
      return;
    }
    ++cancel_count;
  });

  const TestMethod& bidi = service_42_.method(100);
  ASSERT_EQ(bidi.invocations(), 0u);

  EXPECT_EQ(OkStatus(),
            server_.ProcessPacket(PacketForRpc(PacketType::REQUEST)));

  EXPECT_EQ(cancel_count, 1);
  EXPECT_EQ(bidi.invocations(), 1u);
}
471 
// A REQUEST for the same method under a different call ID leaves the existing
// call alone.
TEST_F(BidiMethod, DuplicateMethodDifferentCallId_NotCancelled) {
  int cancel_count = 0;
  responder_.set_on_error([&cancel_count](Status error) {
    if (!error.IsCancelled()) {
      return;
    }
    ++cancel_count;
  });

  constexpr uint32_t kSecondCallId = 1625;
  const ConstByteSpan second_request =
      PacketForRpc(PacketType::REQUEST, OkStatus(), {}, kSecondCallId);
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(second_request));

  EXPECT_EQ(cancel_count, 0);
}
487 
span_as_cstr(ConstByteSpan span)488 const char* span_as_cstr(ConstByteSpan span) {
489   return reinterpret_cast<const char*>(span.data());
490 }
491 
// Two concurrent calls to the same method under different call IDs each get
// only the CLIENT_STREAM payload addressed to their own ID.
TEST_F(BidiMethod, DuplicateMethodDifferentCallIdEachCallGetsSeparateResponse) {
  const uint32_t kSecondCallId = 1625;

  internal::rpc_lock().lock();
  internal::test::FakeServerReaderWriter responder_2(
      internal::CallContext(server_,
                            channels_[0].id(),
                            service_42_,
                            service_42_.method(100),
                            kSecondCallId)
          .ClaimLocked());
  internal::rpc_lock().unlock();

  ConstByteSpan data_1 = as_bytes(span("data_1_unset"));
  responder_.set_on_next(
      [&data_1](ConstByteSpan payload) { data_1 = payload; });

  ConstByteSpan data_2 = as_bytes(span("data_2_unset"));
  responder_2.set_on_next(
      [&data_2](ConstByteSpan payload) { data_2 = payload; });

  // Arrays rather than pointers so the same constants can be both sent via
  // PacketForRpc() (span() needs the array extent) and checked afterwards.
  // Previously the sent literals were duplicated by hand and could drift.
  constexpr char kMessage1[] = "hello_1";
  constexpr char kMessage2[] = "hello_2";

  EXPECT_EQ(
      OkStatus(),
      server_.ProcessPacket(PacketForRpc(
          PacketType::CLIENT_STREAM, OkStatus(), kMessage2, kSecondCallId)));

  EXPECT_STREQ(span_as_cstr(data_2), kMessage2);

  EXPECT_EQ(
      OkStatus(),
      server_.ProcessPacket(PacketForRpc(
          PacketType::CLIENT_STREAM, OkStatus(), kMessage1, kDefaultCallId)));

  EXPECT_STREQ(span_as_cstr(data_1), kMessage1);
}
530 
// A cancellation for the active call closes it.
TEST_F(BidiMethod, Cancel_ClosesServerWriter) {
  const auto cancel = EncodeCancel();
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(cancel));

  EXPECT_FALSE(responder_.active());
}
536 
// Cancelling the active call does not produce any outgoing packet.
TEST_F(BidiMethod, Cancel_SendsNoResponse) {
  const auto cancel = EncodeCancel();
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(cancel));

  EXPECT_EQ(output_.total_packets(), 0u);
}
542 
// A CLIENT_ERROR for the active call closes it without sending a reply.
TEST_F(BidiMethod, ClientError_ClosesServerWriterWithoutResponse) {
  const ConstByteSpan client_error = PacketForRpc(PacketType::CLIENT_ERROR);
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(client_error));

  EXPECT_FALSE(responder_.active());
  EXPECT_EQ(output_.total_packets(), 0u);
}
550 
// The status carried by a CLIENT_ERROR is forwarded to the call's on_error.
TEST_F(BidiMethod, ClientError_CallsOnErrorCallback) {
  Status observed = Status::Unknown();
  responder_.set_on_error([&observed](Status error) { observed = error; });

  const ConstByteSpan client_error =
      PacketForRpc(PacketType::CLIENT_ERROR, Status::Unauthenticated());
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(client_error));

  EXPECT_EQ(observed, Status::Unauthenticated());
}
561 
// Cancelling the active call reports CANCELLED through on_error.
TEST_F(BidiMethod, Cancel_CallsOnErrorCallback) {
  Status observed = Status::Unknown();
  responder_.set_on_error([&observed](Status error) { observed = error; });

  const auto cancel = EncodeCancel();
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(cancel));
  EXPECT_EQ(observed, Status::Cancelled());
}
569 
// A cancel on a different channel leaves the active call open and sends
// nothing.
TEST_F(BidiMethod, Cancel_IncorrectChannel_SendsNothing) {
  const auto cancel = EncodeCancel(/*channel_id=*/2, 42, 100);
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(cancel));

  EXPECT_EQ(output_.total_packets(), 0u);
  EXPECT_TRUE(responder_.active());
}
576 
// A cancel naming a different service leaves the active call open and sends
// nothing.
TEST_F(BidiMethod, Cancel_IncorrectService_SendsNothing) {
  const auto cancel = EncodeCancel(1, /*service_id=*/43, 100);
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(cancel));
  EXPECT_EQ(output_.total_packets(), 0u);
  EXPECT_TRUE(responder_.active());
}
582 
// A cancel naming a different method leaves the active call open and sends
// nothing.
TEST_F(BidiMethod, Cancel_IncorrectMethod_SendsNothing) {
  const auto cancel = EncodeCancel(1, 42, /*method_id=*/101);
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(cancel));
  EXPECT_EQ(output_.total_packets(), 0u);
  EXPECT_TRUE(responder_.active());
}
588 
// A CLIENT_STREAM message for the active call is delivered to on_next without
// any reply.
TEST_F(BidiMethod, ClientStream_CallsCallback) {
  ConstByteSpan received = as_bytes(span("?"));
  responder_.set_on_next(
      [&received](ConstByteSpan payload) { received = payload; });

  const ConstByteSpan message =
      PacketForRpc(PacketType::CLIENT_STREAM, OkStatus(), "hello");
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(message));

  EXPECT_EQ(output_.total_packets(), 0u);
  EXPECT_STREQ(span_as_cstr(received), "hello");
}
600 
// A CLIENT_STREAM sent with internal::kOpenCallId reaches the active call's
// on_next.
TEST_F(BidiMethod, ClientStream_CallsCallbackOnCallWithOpenId) {
  ConstByteSpan received = as_bytes(span("?"));
  responder_.set_on_next(
      [&received](ConstByteSpan payload) { received = payload; });

  const ConstByteSpan message = PacketForRpc(
      PacketType::CLIENT_STREAM, OkStatus(), "hello", internal::kOpenCallId);
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(message));

  EXPECT_EQ(output_.total_packets(), 0u);
  EXPECT_STREQ(span_as_cstr(received), "hello");
}
613 
// A CLIENT_STREAM sent with internal::kLegacyOpenCallId reaches the active
// call's on_next.
TEST_F(BidiMethod, ClientStream_CallsCallbackOnCallWithLegacyOpenId) {
  ConstByteSpan received = as_bytes(span("?"));
  responder_.set_on_next(
      [&received](ConstByteSpan payload) { received = payload; });

  const ConstByteSpan message =
      PacketForRpc(PacketType::CLIENT_STREAM,
                   OkStatus(),
                   "hello",
                   internal::kLegacyOpenCallId);
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(message));

  EXPECT_EQ(output_.total_packets(), 0u);
  EXPECT_STREQ(span_as_cstr(received), "hello");
}
627 
// A call opened with internal::kOpenCallId accepts a CLIENT_STREAM carrying a
// concrete call ID and adopts that ID.
TEST_F(BidiMethod, ClientStream_CallsOpenIdOnCallWithDifferentId) {
  const uint32_t kSecondCallId = 1625;
  internal::CallContext open_context(server_,
                                     channels_[0].id(),
                                     service_42_,
                                     service_42_.method(100),
                                     internal::kOpenCallId);
  internal::rpc_lock().lock();
  auto claimed =
      internal::test::FakeServerReaderWriter(open_context.ClaimLocked());
  internal::rpc_lock().unlock();
  responder_ = std::move(claimed);

  ConstByteSpan received = as_bytes(span("?"));
  responder_.set_on_next(
      [&received](ConstByteSpan payload) { received = payload; });

  const ConstByteSpan message = PacketForRpc(
      PacketType::CLIENT_STREAM, OkStatus(), "hello", kSecondCallId);
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(message));

  EXPECT_EQ(output_.total_packets(), 0u);
  EXPECT_STREQ(span_as_cstr(received), "hello");

  {
    // The call ID is read under the RPC lock.
    internal::RpcLockGuard guard;
    EXPECT_EQ(responder_.as_server_call().id(), kSecondCallId);
  }
}
654 
// A call opened with internal::kLegacyOpenCallId accepts a CLIENT_STREAM
// carrying a concrete call ID and adopts that ID.
TEST_F(BidiMethod, ClientStream_CallsLegacyOpenIdOnCallWithDifferentId) {
  const uint32_t kSecondCallId = 1625;
  internal::CallContext legacy_context(server_,
                                       channels_[0].id(),
                                       service_42_,
                                       service_42_.method(100),
                                       internal::kLegacyOpenCallId);
  internal::rpc_lock().lock();
  auto claimed =
      internal::test::FakeServerReaderWriter(legacy_context.ClaimLocked());
  internal::rpc_lock().unlock();
  responder_ = std::move(claimed);

  ConstByteSpan received = as_bytes(span("?"));
  responder_.set_on_next(
      [&received](ConstByteSpan payload) { received = payload; });

  const ConstByteSpan message = PacketForRpc(
      PacketType::CLIENT_STREAM, OkStatus(), "hello", kSecondCallId);
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(message));

  EXPECT_EQ(output_.total_packets(), 0u);
  EXPECT_STREQ(span_as_cstr(received), "hello");

  {
    // The call ID is read under the RPC lock.
    internal::RpcLockGuard guard;
    EXPECT_EQ(responder_.as_server_call().id(), kSecondCallId);
  }
}
681 
// Unregistering a service aborts its active calls and reports ABORTED through
// on_error. (The test name keeps its existing spelling so any test filters
// that reference it keep matching.)
TEST_F(BidiMethod, UnregsiterService_AbortsActiveCalls) {
  ASSERT_TRUE(responder_.active());

  Status observed = OkStatus();
  responder_.set_on_error([&observed](Status status) { observed = status; });

  server_.UnregisterService(service_42_);

  EXPECT_FALSE(responder_.active());
  EXPECT_EQ(Status::Aborted(), observed);
}
694 
// CLIENT_REQUEST_COMPLETION sends nothing. It runs the completion callback
// exactly when PW_RPC_COMPLETION_REQUEST_CALLBACK is enabled.
TEST_F(BidiMethod, ClientRequestedCompletion_CallsCallback) {
  bool callback_ran = false;
#if PW_RPC_COMPLETION_REQUEST_CALLBACK
  responder_.set_on_completion_requested(
      [&callback_ran] { callback_ran = true; });
#endif
  const ConstByteSpan completion_request =
      PacketForRpc(PacketType::CLIENT_REQUEST_COMPLETION);
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(completion_request));

  EXPECT_EQ(output_.total_packets(), 0u);
  EXPECT_EQ(callback_ran, PW_RPC_COMPLETION_REQUEST_CALLBACK);
}
707 
// set_on_completion_requested_if_enabled() registers the callback only when
// PW_RPC_COMPLETION_REQUEST_CALLBACK is enabled, so the callback runs in
// exactly that configuration.
TEST_F(BidiMethod, ClientRequestedCompletion_CallsCallbackIfEnabled) {
  bool callback_ran = false;
  responder_.set_on_completion_requested_if_enabled(
      [&callback_ran] { callback_ran = true; });

  const ConstByteSpan completion_request =
      PacketForRpc(PacketType::CLIENT_REQUEST_COMPLETION);
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(completion_request));

  EXPECT_EQ(output_.total_packets(), 0u);
  EXPECT_EQ(callback_ran, PW_RPC_COMPLETION_REQUEST_CALLBACK);
}
720 
// Sending CLIENT_REQUEST_COMPLETION twice is accepted both times and produces
// no outgoing packets.
TEST_F(BidiMethod, ClientRequestedCompletion_ErrorWhenClosed) {
  const auto completion_request =
      PacketForRpc(PacketType::CLIENT_REQUEST_COMPLETION);
  for (int attempt = 0; attempt < 2; ++attempt) {
    ASSERT_EQ(OkStatus(), server_.ProcessPacket(completion_request));
  }

  ASSERT_EQ(output_.total_packets(), 0u);
}
728 
// After the call is cancelled, a CLIENT_REQUEST_COMPLETION for it is answered
// with a single SERVER_ERROR / FAILED_PRECONDITION packet.
TEST_F(BidiMethod, ClientRequestedCompletion_ErrorWhenAlreadyClosed) {
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(EncodeCancel()));
  EXPECT_FALSE(responder_.active());

  const auto completion_request =
      PacketForRpc(PacketType::CLIENT_REQUEST_COMPLETION);
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(completion_request));

  ASSERT_EQ(output_.total_packets(), 1u);
  auto& fake_output = static_cast<internal::test::FakeChannelOutput&>(output_);
  const Packet& reply = fake_output.last_packet();
  EXPECT_EQ(reply.type(), PacketType::SERVER_ERROR);
  EXPECT_EQ(reply.status(), Status::FailedPrecondition());
}
742 
743 class ServerStreamingMethod : public BasicServer {
744  protected:
ServerStreamingMethod()745   ServerStreamingMethod() {
746     internal::CallContext context(server_,
747                                   channels_[0].id(),
748                                   service_42_,
749                                   service_42_.method(100),
750                                   kDefaultCallId);
751     internal::rpc_lock().lock();
752     internal::test::FakeServerWriter responder_temp(context.ClaimLocked());
753     internal::rpc_lock().unlock();
754     responder_ = std::move(responder_temp);
755     PW_CHECK(responder_.active());
756   }
757 
758   internal::test::FakeServerWriter responder_;
759 };
760 
// A CLIENT_STREAM sent to a call held as a server writer is answered with a
// single SERVER_ERROR / INVALID_ARGUMENT packet.
TEST_F(ServerStreamingMethod, ClientStream_InvalidArgumentError) {
  const ConstByteSpan message = PacketForRpc(PacketType::CLIENT_STREAM);
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(message));

  ASSERT_EQ(output_.total_packets(), 1u);
  auto& fake_output = static_cast<internal::test::FakeChannelOutput&>(output_);
  const Packet& reply = fake_output.last_packet();
  EXPECT_EQ(reply.type(), PacketType::SERVER_ERROR);
  EXPECT_EQ(reply.status(), Status::InvalidArgument());
}
771 
// CLIENT_REQUEST_COMPLETION sends nothing. It runs the completion callback
// exactly when PW_RPC_COMPLETION_REQUEST_CALLBACK is enabled.
TEST_F(ServerStreamingMethod, ClientRequestedCompletion_CallsCallback) {
  bool callback_ran = false;
#if PW_RPC_COMPLETION_REQUEST_CALLBACK
  responder_.set_on_completion_requested(
      [&callback_ran] { callback_ran = true; });
#endif

  const ConstByteSpan completion_request =
      PacketForRpc(PacketType::CLIENT_REQUEST_COMPLETION);
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(completion_request));

  EXPECT_EQ(output_.total_packets(), 0u);
  EXPECT_EQ(callback_ran, PW_RPC_COMPLETION_REQUEST_CALLBACK);
}
785 
// set_on_completion_requested_if_enabled() registers the callback only when
// PW_RPC_COMPLETION_REQUEST_CALLBACK is enabled, so the callback runs in
// exactly that configuration.
TEST_F(ServerStreamingMethod,
       ClientRequestedCompletion_CallsCallbackIfEnabled) {
  bool callback_ran = false;
  responder_.set_on_completion_requested_if_enabled(
      [&callback_ran] { callback_ran = true; });

  const ConstByteSpan completion_request =
      PacketForRpc(PacketType::CLIENT_REQUEST_COMPLETION);
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(completion_request));

  EXPECT_EQ(output_.total_packets(), 0u);
  EXPECT_EQ(callback_ran, PW_RPC_COMPLETION_REQUEST_CALLBACK);
}
799 
// Sending CLIENT_REQUEST_COMPLETION twice is accepted both times and produces
// no outgoing packets.
TEST_F(ServerStreamingMethod, ClientRequestedCompletion_ErrorWhenClosed) {
  const auto completion_request =
      PacketForRpc(PacketType::CLIENT_REQUEST_COMPLETION);
  for (int attempt = 0; attempt < 2; ++attempt) {
    ASSERT_EQ(OkStatus(), server_.ProcessPacket(completion_request));
  }

  ASSERT_EQ(output_.total_packets(), 0u);
}
807 
// After the call is cancelled, a CLIENT_REQUEST_COMPLETION for it is answered
// with a single SERVER_ERROR / FAILED_PRECONDITION packet.
TEST_F(ServerStreamingMethod,
       ClientRequestedCompletion_ErrorWhenAlreadyClosed) {
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(EncodeCancel()));
  EXPECT_FALSE(responder_.active());

  const auto completion_request =
      PacketForRpc(PacketType::CLIENT_REQUEST_COMPLETION);
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(completion_request));

  ASSERT_EQ(output_.total_packets(), 1u);
  auto& fake_output = static_cast<internal::test::FakeChannelOutput&>(output_);
  const Packet& reply = fake_output.last_packet();
  EXPECT_EQ(reply.type(), PacketType::SERVER_ERROR);
  EXPECT_EQ(reply.status(), Status::FailedPrecondition());
}
822 
823 }  // namespace
824 }  // namespace pw::rpc
825