• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2020 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 
15 #include "pw_rpc/nanopb/internal/method.h"
16 
17 #include <array>
18 
19 #include "gtest/gtest.h"
20 #include "pw_containers/algorithm.h"
21 #include "pw_rpc/internal/lock.h"
22 #include "pw_rpc/internal/method_impl_tester.h"
23 #include "pw_rpc/internal/test_utils.h"
24 #include "pw_rpc/nanopb/internal/method_union.h"
25 #include "pw_rpc/service.h"
26 #include "pw_rpc_nanopb_private/internal_test_utils.h"
27 #include "pw_rpc_test_protos/test.pb.h"
28 
29 PW_MODIFY_DIAGNOSTICS_PUSH();
30 PW_MODIFY_DIAGNOSTIC(ignored, "-Wmissing-field-initializers");
31 
32 namespace pw::rpc::internal {
33 namespace {
34 
35 using std::byte;
36 
37 struct FakePb {};
38 
39 // Create a fake service for use with the MethodImplTester.
40 class TestNanopbService final : public Service {
41  public:
42   // Unary signatures
43 
Unary(const FakePb &,FakePb &)44   Status Unary(const FakePb&, FakePb&) { return Status(); }
45 
StaticUnary(const FakePb &,FakePb &)46   static Status StaticUnary(const FakePb&, FakePb&) { return Status(); }
47 
AsyncUnary(const FakePb &,NanopbUnaryResponder<FakePb> &)48   void AsyncUnary(const FakePb&, NanopbUnaryResponder<FakePb>&) {}
49 
StaticAsyncUnary(const FakePb &,NanopbUnaryResponder<FakePb> &)50   static void StaticAsyncUnary(const FakePb&, NanopbUnaryResponder<FakePb>&) {}
51 
UnaryWrongArg(FakePb &,FakePb &)52   Status UnaryWrongArg(FakePb&, FakePb&) { return Status(); }
53 
StaticUnaryVoidReturn(const FakePb &,FakePb &)54   static void StaticUnaryVoidReturn(const FakePb&, FakePb&) {}
55 
56   // Server streaming signatures
57 
ServerStreaming(const FakePb &,NanopbServerWriter<FakePb> &)58   void ServerStreaming(const FakePb&, NanopbServerWriter<FakePb>&) {}
59 
StaticServerStreaming(const FakePb &,NanopbServerWriter<FakePb> &)60   static void StaticServerStreaming(const FakePb&,
61                                     NanopbServerWriter<FakePb>&) {}
62 
ServerStreamingBadReturn(const FakePb &,NanopbServerWriter<FakePb> &)63   int ServerStreamingBadReturn(const FakePb&, NanopbServerWriter<FakePb>&) {
64     return 5;
65   }
66 
StaticServerStreamingMissingArg(NanopbServerWriter<FakePb> &)67   static void StaticServerStreamingMissingArg(NanopbServerWriter<FakePb>&) {}
68 
69   // Client streaming signatures
70 
ClientStreaming(NanopbServerReader<FakePb,FakePb> &)71   void ClientStreaming(NanopbServerReader<FakePb, FakePb>&) {}
72 
StaticClientStreaming(NanopbServerReader<FakePb,FakePb> &)73   static void StaticClientStreaming(NanopbServerReader<FakePb, FakePb>&) {}
74 
ClientStreamingBadReturn(NanopbServerReader<FakePb,FakePb> &)75   int ClientStreamingBadReturn(NanopbServerReader<FakePb, FakePb>&) {
76     return 0;
77   }
78 
StaticClientStreamingMissingArg()79   static void StaticClientStreamingMissingArg() {}
80 
81   // Bidirectional streaming signatures
82 
BidirectionalStreaming(NanopbServerReaderWriter<FakePb,FakePb> &)83   void BidirectionalStreaming(NanopbServerReaderWriter<FakePb, FakePb>&) {}
84 
StaticBidirectionalStreaming(NanopbServerReaderWriter<FakePb,FakePb> &)85   static void StaticBidirectionalStreaming(
86       NanopbServerReaderWriter<FakePb, FakePb>&) {}
87 
BidirectionalStreamingBadReturn(NanopbServerReaderWriter<FakePb,FakePb> &)88   int BidirectionalStreamingBadReturn(
89       NanopbServerReaderWriter<FakePb, FakePb>&) {
90     return 0;
91   }
92 
StaticBidirectionalStreamingMissingArg()93   static void StaticBidirectionalStreamingMissingArg() {}
94 };
95 
96 struct WrongPb;
97 
98 // Test matches() rejects incorrect request/response types.
99 // clang-format off
100 static_assert(!NanopbMethod::template matches<&TestNanopbService::Unary, WrongPb, FakePb>());
101 static_assert(!NanopbMethod::template matches<&TestNanopbService::Unary, FakePb, WrongPb>());
102 static_assert(!NanopbMethod::template matches<&TestNanopbService::Unary, WrongPb, WrongPb>());
103 static_assert(!NanopbMethod::template matches<&TestNanopbService::StaticUnary, FakePb, WrongPb>());
104 
105 static_assert(!NanopbMethod::template matches<&TestNanopbService::ServerStreaming, WrongPb, FakePb>());
106 static_assert(!NanopbMethod::template matches<&TestNanopbService::StaticServerStreaming, FakePb, WrongPb>());
107 
108 static_assert(!NanopbMethod::template matches<&TestNanopbService::ClientStreaming, WrongPb, FakePb>());
109 static_assert(!NanopbMethod::template matches<&TestNanopbService::StaticClientStreaming, FakePb, WrongPb>());
110 
111 static_assert(!NanopbMethod::template matches<&TestNanopbService::BidirectionalStreaming, WrongPb, FakePb>());
112 static_assert(!NanopbMethod::template matches<&TestNanopbService::StaticBidirectionalStreaming, FakePb, WrongPb>());
113 // clang-format on
114 
115 static_assert(MethodImplTests<NanopbMethod, TestNanopbService>().Pass(
116     MatchesTypes<FakePb, FakePb>(),
117     std::tuple<const NanopbMethodSerde&>(
118         kNanopbMethodSerde<nullptr, nullptr>)));
119 
120 template <typename Impl>
121 class FakeServiceBase : public Service {
122  public:
FakeServiceBase(uint32_t id)123   FakeServiceBase(uint32_t id) : Service(id, kMethods) {}
124 
125   static constexpr std::array<NanopbMethodUnion, 5> kMethods = {
126       NanopbMethod::SynchronousUnary<&Impl::DoNothing>(
127           10u,
128           kNanopbMethodSerde<pw_rpc_test_Empty_fields,
129                              pw_rpc_test_Empty_fields>),
130       NanopbMethod::AsynchronousUnary<&Impl::AddFive>(
131           11u,
132           kNanopbMethodSerde<pw_rpc_test_TestRequest_fields,
133                              pw_rpc_test_TestResponse_fields>),
134       NanopbMethod::ServerStreaming<&Impl::StartStream>(
135           12u,
136           kNanopbMethodSerde<pw_rpc_test_TestRequest_fields,
137                              pw_rpc_test_TestResponse_fields>),
138       NanopbMethod::ClientStreaming<&Impl::ClientStream>(
139           13u,
140           kNanopbMethodSerde<pw_rpc_test_TestRequest_fields,
141                              pw_rpc_test_TestResponse_fields>),
142       NanopbMethod::BidirectionalStreaming<&Impl::BidirectionalStream>(
143           14u,
144           kNanopbMethodSerde<pw_rpc_test_TestRequest_fields,
145                              pw_rpc_test_TestResponse_fields>)};
146 };
147 
148 class FakeService : public FakeServiceBase<FakeService> {
149  public:
FakeService(uint32_t id)150   FakeService(uint32_t id) : FakeServiceBase(id) {}
151 
DoNothing(const pw_rpc_test_Empty &,pw_rpc_test_Empty &)152   Status DoNothing(const pw_rpc_test_Empty&, pw_rpc_test_Empty&) {
153     return Status::Unknown();
154   }
155 
AddFive(const pw_rpc_test_TestRequest & request,NanopbUnaryResponder<pw_rpc_test_TestResponse> & responder)156   void AddFive(const pw_rpc_test_TestRequest& request,
157                NanopbUnaryResponder<pw_rpc_test_TestResponse>& responder) {
158     last_request = request;
159 
160     if (fail_to_encode_async_unary_response) {
161       pw_rpc_test_TestResponse response = pw_rpc_test_TestResponse_init_default;
162       response.repeated_field.funcs.encode =
163           [](pb_ostream_t*, const pb_field_t*, void* const*) { return false; };
164       ASSERT_EQ(OkStatus(), responder.Finish(response, Status::NotFound()));
165     } else {
166       ASSERT_EQ(
167           OkStatus(),
168           responder.Finish({.value = static_cast<int32_t>(request.integer + 5)},
169                            Status::Unauthenticated()));
170     }
171   }
172 
StartStream(const pw_rpc_test_TestRequest & request,NanopbServerWriter<pw_rpc_test_TestResponse> & writer)173   void StartStream(const pw_rpc_test_TestRequest& request,
174                    NanopbServerWriter<pw_rpc_test_TestResponse>& writer) {
175     last_request = request;
176     last_writer = std::move(writer);
177   }
178 
ClientStream(NanopbServerReader<pw_rpc_test_TestRequest,pw_rpc_test_TestResponse> & reader)179   void ClientStream(NanopbServerReader<pw_rpc_test_TestRequest,
180                                        pw_rpc_test_TestResponse>& reader) {
181     last_reader = std::move(reader);
182   }
183 
BidirectionalStream(NanopbServerReaderWriter<pw_rpc_test_TestRequest,pw_rpc_test_TestResponse> & reader_writer)184   void BidirectionalStream(
185       NanopbServerReaderWriter<pw_rpc_test_TestRequest,
186                                pw_rpc_test_TestResponse>& reader_writer) {
187     last_reader_writer = std::move(reader_writer);
188   }
189 
190   bool fail_to_encode_async_unary_response = false;
191 
192   pw_rpc_test_TestRequest last_request;
193   NanopbServerWriter<pw_rpc_test_TestResponse> last_writer;
194   NanopbServerReader<pw_rpc_test_TestRequest, pw_rpc_test_TestResponse>
195       last_reader;
196   NanopbServerReaderWriter<pw_rpc_test_TestRequest, pw_rpc_test_TestResponse>
197       last_reader_writer;
198 };
199 
200 constexpr const NanopbMethod& kSyncUnary =
201     std::get<0>(FakeServiceBase<FakeService>::kMethods).nanopb_method();
202 constexpr const NanopbMethod& kAsyncUnary =
203     std::get<1>(FakeServiceBase<FakeService>::kMethods).nanopb_method();
204 constexpr const NanopbMethod& kServerStream =
205     std::get<2>(FakeServiceBase<FakeService>::kMethods).nanopb_method();
206 constexpr const NanopbMethod& kClientStream =
207     std::get<3>(FakeServiceBase<FakeService>::kMethods).nanopb_method();
208 constexpr const NanopbMethod& kBidirectionalStream =
209     std::get<4>(FakeServiceBase<FakeService>::kMethods).nanopb_method();
210 
// An asynchronous unary call responds with value = request.integer + 5 and
// the status chosen by the handler.
TEST(NanopbMethod, AsyncUnaryRpc_SendsResponse) {
  PW_ENCODE_PB(
      pw_rpc_test_TestRequest, request, .integer = 123, .status_code = 0);

  ServerContextForTest<FakeService> context(kAsyncUnary);
  rpc_lock().lock();
  kAsyncUnary.Invoke(context.get(), context.request(request));

  const Packet& response = context.output().last_packet();
  EXPECT_EQ(response.status(), Status::Unauthenticated());

  // 123 + 5 = 128: field 1 (key 1 << 3 = 0x08) followed by the varint
  // encoding of 128 (0x80 0x01).
  constexpr std::byte expected[]{
      std::byte{0x08}, std::byte{0x80}, std::byte{0x01}};

  EXPECT_EQ(sizeof(expected), response.payload().size());
  // Equal() compares sizes before touching any bytes. A fixed-length memcmp
  // here would read past a shorter payload, because the EXPECT above is
  // non-fatal and execution would continue after it failed.
  EXPECT_TRUE(pw::containers::Equal(expected, response.payload()));

  EXPECT_EQ(123, context.service().last_request.integer);
}
232 
// A request payload that fails to decode is answered with a SERVER_ERROR
// packet carrying DataLoss, addressed to the invoked service and method.
TEST(NanopbMethod, SyncUnaryRpc_InvalidPayload_SendsError) {
  // Three junk bytes followed by zero padding.
  std::array<byte, 8> garbage{byte{0xFF}, byte{0xAA}, byte{0xDD}};

  ServerContextForTest<FakeService> context(kSyncUnary);
  rpc_lock().lock();
  kSyncUnary.Invoke(context.get(), context.request(garbage));

  const Packet& error = context.output().last_packet();
  EXPECT_EQ(pwpb::PacketType::SERVER_ERROR, error.type());
  EXPECT_EQ(Status::DataLoss(), error.status());
  EXPECT_EQ(context.service_id(), error.service_id());
  EXPECT_EQ(kSyncUnary.id(), error.method_id());
}
246 
// If the async unary response cannot be encoded, the client receives a
// SERVER_ERROR packet with Internal instead of the response.
TEST(NanopbMethod, AsyncUnaryRpc_ResponseEncodingFails_SendsInternalError) {
  constexpr int64_t kRequestValue = 0x7FFFFFFF'FFFFFF00ll;
  PW_ENCODE_PB(pw_rpc_test_TestRequest,
               request,
               .integer = kRequestValue,
               .status_code = 0);

  ServerContextForTest<FakeService> context(kAsyncUnary);
  auto& service = context.service();
  // Makes AddFive() finish with a response whose encoder fails.
  service.fail_to_encode_async_unary_response = true;

  rpc_lock().lock();
  kAsyncUnary.Invoke(context.get(), context.request(request));

  const Packet& error = context.output().last_packet();
  EXPECT_EQ(pwpb::PacketType::SERVER_ERROR, error.type());
  EXPECT_EQ(Status::Internal(), error.status());
  EXPECT_EQ(context.service_id(), error.service_id());
  EXPECT_EQ(kAsyncUnary.id(), error.method_id());

  EXPECT_EQ(kRequestValue, service.last_request.integer);
}
266 
// Starting a server stream emits no packets until the handler writes.
TEST(NanopbMethod, ServerStreamingRpc_SendsNothingWhenInitiallyCalled) {
  PW_ENCODE_PB(
      pw_rpc_test_TestRequest, request, .integer = 555, .status_code = 0);

  ServerContextForTest<FakeService> context(kServerStream);

  rpc_lock().lock();
  kServerStream.Invoke(context.get(), context.request(request));

  const auto packets_sent = context.output().total_packets();
  EXPECT_EQ(0u, packets_sent);
  EXPECT_EQ(555, context.service().last_request.integer);
}
279 
// A Write() on the server writer sends the encoded response as the payload
// of the next outgoing packet.
TEST(NanopbMethod, ServerWriter_SendsResponse) {
  ServerContextForTest<FakeService> context(kServerStream);

  rpc_lock().lock();
  kServerStream.Invoke(context.get(), context.request({}));

  auto& writer = context.service().last_writer;
  EXPECT_EQ(OkStatus(), writer.Write({.value = 100}));

  PW_ENCODE_PB(pw_rpc_test_TestResponse, expected_payload, .value = 100);
  std::array<byte, 128> packet_buffer = {};
  auto encoded_packet =
      context.server_stream(expected_payload).Encode(packet_buffer);
  ASSERT_EQ(OkStatus(), encoded_packet.status());

  ConstByteSpan actual_payload = context.output().last_packet().payload();
  EXPECT_TRUE(pw::containers::Equal(expected_payload, actual_payload));
}
296 
// After Finish(), any further Write() on the writer reports
// FailedPrecondition.
TEST(NanopbMethod, ServerWriter_WriteWhenClosed_ReturnsFailedPrecondition) {
  ServerContextForTest<FakeService> context(kServerStream);

  rpc_lock().lock();
  kServerStream.Invoke(context.get(), context.request({}));

  auto& writer = context.service().last_writer;
  EXPECT_EQ(OkStatus(), writer.Finish());

  const Status write_after_finish = writer.Write({.value = 100});
  EXPECT_TRUE(write_after_finish.IsFailedPrecondition());
}
308 
// Moving a writer transfers the call: the new writer works, while the
// moved-from one is expected to reject Write() and Finish().
TEST(NanopbMethod, ServerWriter_WriteAfterMoved_ReturnsFailedPrecondition) {
  ServerContextForTest<FakeService> context(kServerStream);

  rpc_lock().lock();
  kServerStream.Invoke(context.get(), context.request({}));

  auto& service = context.service();
  NanopbServerWriter<pw_rpc_test_TestResponse> moved_to =
      std::move(service.last_writer);

  EXPECT_EQ(OkStatus(), moved_to.Write({.value = 100}));

  EXPECT_EQ(Status::FailedPrecondition(),
            service.last_writer.Write({.value = 100}));
  EXPECT_EQ(Status::FailedPrecondition(), service.last_writer.Finish());

  EXPECT_EQ(OkStatus(), moved_to.Finish());
}
326 
// Write() returns Internal when the response message cannot be encoded.
TEST(NanopbMethod, ServerStreamingRpc_ResponseEncodingFails_InternalError) {
  ServerContextForTest<FakeService> context(kServerStream);

  rpc_lock().lock();
  kServerStream.Invoke(context.get(), context.request({}));

  auto& writer = context.service().last_writer;
  EXPECT_EQ(OkStatus(), writer.Write({}));

  // A response whose repeated_field encode callback always reports failure.
  pw_rpc_test_TestResponse unencodable = pw_rpc_test_TestResponse_init_default;
  unencodable.repeated_field.funcs.encode =
      [](pb_ostream_t*, const pb_field_t*, void* const*) { return false; };
  EXPECT_EQ(Status::Internal(), writer.Write(unencodable));
}
340 
// A client-stream packet processed by the server is decoded and delivered to
// the reader's on_next callback.
TEST(NanopbMethod, ServerReader_HandlesRequests) {
  ServerContextForTest<FakeService> context(kClientStream);

  rpc_lock().lock();
  kClientStream.Invoke(context.get(), context.request({}));

  // Capture whatever request the reader hands to on_next.
  pw_rpc_test_TestRequest received{};
  context.service().last_reader.set_on_next(
      [&received](const pw_rpc_test_TestRequest& incoming) {
        received = incoming;
      });

  PW_ENCODE_PB(
      pw_rpc_test_TestRequest, request, .integer = 1 << 30, .status_code = 9);
  std::array<byte, 128> packet_buffer = {};
  auto encoded_packet = context.client_stream(request).Encode(packet_buffer);
  ASSERT_EQ(OkStatus(), encoded_packet.status());
  ASSERT_EQ(OkStatus(), context.server().ProcessPacket(*encoded_packet));

  EXPECT_EQ(received.integer, 1 << 30);
  EXPECT_EQ(received.status_code, 9u);
}
363 
// A Write() on the bidirectional reader/writer sends the encoded response as
// the payload of the next outgoing packet.
TEST(NanopbMethod, ServerReaderWriter_WritesResponses) {
  ServerContextForTest<FakeService> context(kBidirectionalStream);

  rpc_lock().lock();
  kBidirectionalStream.Invoke(context.get(), context.request({}));

  auto& reader_writer = context.service().last_reader_writer;
  EXPECT_EQ(OkStatus(), reader_writer.Write({.value = 100}));

  PW_ENCODE_PB(pw_rpc_test_TestResponse, expected_payload, .value = 100);
  std::array<byte, 128> packet_buffer = {};
  auto encoded_packet =
      context.server_stream(expected_payload).Encode(packet_buffer);
  ASSERT_EQ(OkStatus(), encoded_packet.status());

  ConstByteSpan actual_payload = context.output().last_packet().payload();
  EXPECT_TRUE(pw::containers::Equal(expected_payload, actual_payload));
}
381 
// A client-stream packet on a bidirectional call is decoded and delivered to
// the reader/writer's on_next callback.
TEST(NanopbMethod, ServerReaderWriter_HandlesRequests) {
  ServerContextForTest<FakeService> context(kBidirectionalStream);

  rpc_lock().lock();
  kBidirectionalStream.Invoke(context.get(), context.request({}));

  // Capture whatever request the reader/writer hands to on_next.
  pw_rpc_test_TestRequest received{};
  context.service().last_reader_writer.set_on_next(
      [&received](const pw_rpc_test_TestRequest& incoming) {
        received = incoming;
      });

  PW_ENCODE_PB(
      pw_rpc_test_TestRequest, request, .integer = 1 << 29, .status_code = 8);
  std::array<byte, 128> packet_buffer = {};
  auto encoded_packet = context.client_stream(request).Encode(packet_buffer);
  ASSERT_EQ(OkStatus(), encoded_packet.status());
  ASSERT_EQ(OkStatus(), context.server().ProcessPacket(*encoded_packet));

  EXPECT_EQ(received.integer, 1 << 29);
  EXPECT_EQ(received.status_code, 8u);
}
404 
405 }  // namespace
406 }  // namespace pw::rpc::internal
407 
408 PW_MODIFY_DIAGNOSTICS_POP();
409