1 // Copyright 2020 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14
15 #include "pw_rpc/raw/internal/method.h"
16
17 #include <array>
18
19 #include "pw_bytes/array.h"
20 #include "pw_containers/algorithm.h"
21 #include "pw_protobuf/decoder.h"
22 #include "pw_protobuf/encoder.h"
23 #include "pw_rpc/internal/config.h"
24 #include "pw_rpc/internal/method_impl_tester.h"
25 #include "pw_rpc/internal/test_utils.h"
26 #include "pw_rpc/raw/internal/method_union.h"
27 #include "pw_rpc/service.h"
28 #include "pw_rpc_test_protos/test.pwpb.h"
29 #include "pw_unit_test/framework.h"
30
31 namespace pw::rpc::internal {
32 namespace {
33
34 namespace TestRequest = ::pw::rpc::test::pwpb::TestRequest;
35 namespace TestResponse = ::pw::rpc::test::pwpb::TestResponse;
36
// Create a fake service for use with the MethodImplTester.
//
// Each member function declares one candidate handler signature. The
// static_assert after the class runs MethodImplTests<RawMethod, ...> over this
// class and fails the build if its checks do not pass. Members named
// *WrongArg, *BadReturn, or *MissingArg differ from their well-formed siblings
// in one parameter or in the return type. Presumably MethodImplTests expects
// RawMethod to reject them; confirm in pw_rpc/internal/method_impl_tester.h.
class TestRawService final : public Service {
 public:
  // Unary signatures

  void Unary(ConstByteSpan, RawUnaryResponder&) {}

  static void StaticUnary(ConstByteSpan, RawUnaryResponder&) {}

  void AsyncUnary(ConstByteSpan, RawUnaryResponder&) {}

  static void StaticAsyncUnary(ConstByteSpan, RawUnaryResponder&) {}

  // Second parameter is a ConstByteSpan rather than a RawUnaryResponder&.
  void UnaryWrongArg(ConstByteSpan, ConstByteSpan) {}

  // Server streaming signatures

  void ServerStreaming(ConstByteSpan, RawServerWriter&) {}

  static void StaticServerStreaming(ConstByteSpan, RawServerWriter&) {}

  // NOTE(review): takes (ConstByteSpan, ByteSpan) with no RawServerWriter&,
  // yet sits in the server streaming group. Presumably a signature that must
  // be rejected; confirm its role in MethodImplTests.
  static void StaticUnaryVoidReturn(ConstByteSpan, ByteSpan) {}

  // Same parameters as ServerStreaming, but returns Status instead of void.
  Status ServerStreamingBadReturn(ConstByteSpan, RawServerWriter&) {
    return Status();
  }

  // Omits the leading ConstByteSpan request parameter.
  static void StaticServerStreamingMissingArg(RawServerWriter&) {}

  // Client streaming signatures

  void ClientStreaming(RawServerReader&) {}

  static void StaticClientStreaming(RawServerReader&) {}

  // Same parameter as ClientStreaming, but returns int instead of void.
  int ClientStreamingBadReturn(RawServerReader&) { return 0; }

  // Omits the RawServerReader& parameter.
  static void StaticClientStreamingMissingArg() {}

  // Bidirectional streaming signatures

  void BidirectionalStreaming(RawServerReaderWriter&) {}

  static void StaticBidirectionalStreaming(RawServerReaderWriter&) {}

  // Same parameter as BidirectionalStreaming, but returns int instead of void.
  int BidirectionalStreamingBadReturn(RawServerReaderWriter&) { return 0; }

  // Omits the RawServerReaderWriter& parameter.
  static void StaticBidirectionalStreamingMissingArg() {}
};

// Compile-time check: the build fails unless
// MethodImplTests<RawMethod, TestRawService>().Pass() evaluates to true.
static_assert(MethodImplTests<RawMethod, TestRawService>().Pass());
88
89 template <typename Impl>
90 class FakeServiceBase : public Service {
91 public:
FakeServiceBase(uint32_t id)92 FakeServiceBase(uint32_t id) : Service(id, kMethods) {}
93
94 static constexpr std::array<RawMethodUnion, 5> kMethods = {
95 RawMethod::AsynchronousUnary<&Impl::DoNothing>(10u),
96 RawMethod::AsynchronousUnary<&Impl::AddFive>(11u),
97 RawMethod::ServerStreaming<&Impl::StartStream>(12u),
98 RawMethod::ClientStreaming<&Impl::ClientStream>(13u),
99 RawMethod::BidirectionalStreaming<&Impl::BidirectionalStream>(14u),
100 };
101 };
102
103 class FakeService : public FakeServiceBase<FakeService> {
104 public:
FakeService(uint32_t id)105 FakeService(uint32_t id) : FakeServiceBase(id) {}
106
DoNothing(ConstByteSpan,RawUnaryResponder & responder)107 void DoNothing(ConstByteSpan, RawUnaryResponder& responder) {
108 ASSERT_EQ(OkStatus(), responder.Finish({}, Status::Unknown()));
109 }
110
AddFive(ConstByteSpan request,RawUnaryResponder & responder)111 void AddFive(ConstByteSpan request, RawUnaryResponder& responder) {
112 DecodeRawTestRequest(request);
113
114 std::array<std::byte, 32> response;
115 TestResponse::MemoryEncoder test_response(response);
116 EXPECT_EQ(OkStatus(), test_response.WriteValue(last_request.integer + 5));
117 ConstByteSpan payload(test_response);
118
119 ASSERT_EQ(OkStatus(),
120 responder.Finish(span(response).first(payload.size()),
121 Status::Unauthenticated()));
122 }
123
StartStream(ConstByteSpan request,RawServerWriter & writer)124 void StartStream(ConstByteSpan request, RawServerWriter& writer) {
125 DecodeRawTestRequest(request);
126 last_writer = std::move(writer);
127 }
128
ClientStream(RawServerReader & reader)129 void ClientStream(RawServerReader& reader) {
130 last_reader = std::move(reader);
131 }
132
BidirectionalStream(RawServerReaderWriter & reader_writer)133 void BidirectionalStream(RawServerReaderWriter& reader_writer) {
134 last_reader_writer = std::move(reader_writer);
135 }
136
DecodeRawTestRequest(ConstByteSpan request)137 void DecodeRawTestRequest(ConstByteSpan request) {
138 protobuf::Decoder decoder(request);
139
140 while (decoder.Next().ok()) {
141 TestRequest::Fields field =
142 static_cast<TestRequest::Fields>(decoder.FieldNumber());
143
144 switch (field) {
145 case TestRequest::Fields::kInteger:
146 ASSERT_EQ(OkStatus(), decoder.ReadInt64(&last_request.integer));
147 break;
148 case TestRequest::Fields::kStatusCode:
149 ASSERT_EQ(OkStatus(), decoder.ReadUint32(&last_request.status_code));
150 break;
151 }
152 }
153 }
154
155 struct {
156 int64_t integer;
157 uint32_t status_code;
158 } last_request;
159
160 RawServerWriter last_writer;
161 RawServerReader last_reader;
162 RawServerReaderWriter last_reader_writer;
163 };
164
165 constexpr const RawMethod& kAsyncUnary0 =
166 std::get<0>(FakeServiceBase<FakeService>::kMethods).raw_method();
167 constexpr const RawMethod& kAsyncUnary1 =
168 std::get<1>(FakeServiceBase<FakeService>::kMethods).raw_method();
169 constexpr const RawMethod& kServerStream =
170 std::get<2>(FakeServiceBase<FakeService>::kMethods).raw_method();
171 constexpr const RawMethod& kClientStream =
172 std::get<3>(FakeServiceBase<FakeService>::kMethods).raw_method();
173 constexpr const RawMethod& kBidirectionalStream =
174 std::get<4>(FakeServiceBase<FakeService>::kMethods).raw_method();
175
// Invoking AddFive records the request fields. The last packet sent carries
// Status::Unauthenticated() and a payload whose first field decodes to the
// request's integer plus five.
TEST(RawMethod, AsyncUnaryRpc1_SendsResponse) {
  std::byte buffer[16];
  stream::MemoryWriter writer(buffer);
  TestRequest::StreamEncoder test_request(writer, ByteSpan());
  ASSERT_EQ(OkStatus(), test_request.WriteInteger(456));
  ASSERT_EQ(OkStatus(), test_request.WriteStatusCode(7));

  ServerContextForTest<FakeService> context(kAsyncUnary1);
  rpc_lock().lock();
  kAsyncUnary1.Invoke(context.get(), context.request(writer.WrittenData()));

  EXPECT_EQ(context.service().last_request.integer, 456);
  EXPECT_EQ(context.service().last_request.status_code, 7u);

  const Packet& response = context.output().last_packet();
  EXPECT_EQ(response.status(), Status::Unauthenticated());

  protobuf::Decoder decoder(response.payload());
  ASSERT_TRUE(decoder.Next().ok());
  // Initialized, and the read is asserted, so a decode failure never leads to
  // comparing an indeterminate value.
  int64_t value = 0;
  ASSERT_EQ(OkStatus(), decoder.ReadInt64(&value));
  EXPECT_EQ(value, 461);
}
199
// Invoking DoNothing leaves a RESPONSE packet with Status::Unknown() as the
// last packet sent, addressed to this service and method.
TEST(RawMethod, AsyncUnaryRpc0_SendsResponse) {
  ServerContextForTest<FakeService> context(kAsyncUnary0);

  rpc_lock().lock();
  kAsyncUnary0.Invoke(context.get(), context.request({}));

  const Packet& response = context.output().last_packet();
  EXPECT_EQ(response.type(), pwpb::PacketType::RESPONSE);
  EXPECT_EQ(response.status(), Status::Unknown());
  EXPECT_EQ(response.service_id(), context.service_id());
  EXPECT_EQ(response.method_id(), kAsyncUnary0.id());
}
212
// Starting a server stream records the request and keeps an active writer,
// but sends no packets until the handler writes or finishes.
TEST(RawMethod, ServerStreamingRpc_SendsNothingWhenInitiallyCalled) {
  std::byte request_buffer[16];
  stream::MemoryWriter request_writer(request_buffer);
  TestRequest::StreamEncoder encoder(request_writer, ByteSpan());
  ASSERT_EQ(OkStatus(), encoder.WriteInteger(777));
  ASSERT_EQ(OkStatus(), encoder.WriteStatusCode(2));

  ServerContextForTest<FakeService> context(kServerStream);
  rpc_lock().lock();
  kServerStream.Invoke(context.get(),
                       context.request(request_writer.WrittenData()));

  FakeService& service = context.service();
  EXPECT_EQ(0u, context.output().total_packets());
  EXPECT_EQ(777, service.last_request.integer);
  EXPECT_EQ(2u, service.last_request.status_code);
  EXPECT_TRUE(service.last_writer.active());
  EXPECT_EQ(OkStatus(), service.last_writer.Finish());
}
230
// A client stream packet processed by the server is delivered, byte for byte,
// to the on_next callback registered on the stored reader.
TEST(RawMethod, ServerReader_HandlesRequests) {
  ServerContextForTest<FakeService> context(kClientStream);
  rpc_lock().lock();
  kClientStream.Invoke(context.get(), context.request({}));

  ConstByteSpan request;
  context.service().last_reader.set_on_next(
      [&request](ConstByteSpan req) { request = req; });

  constexpr const char kRequestValue[] = "This is a request payload!!!";
  const auto request_bytes = as_bytes(span(kRequestValue));
  std::array<std::byte, 128> encoded_request = {};
  auto encoded = context.client_stream(request_bytes).Encode(encoded_request);
  ASSERT_EQ(OkStatus(), encoded.status());
  ASSERT_EQ(OkStatus(), context.server().ProcessPacket(*encoded));

  // Compare as bytes, length included, rather than with EXPECT_STREQ. STREQ
  // would read through request.data() even if on_next never ran (leaving an
  // empty, default-constructed span), and relies on the payload containing a
  // terminator.
  EXPECT_TRUE(pw::containers::Equal(request_bytes, request));
}
249
// Data written through the stored reader/writer is sent as the payload of the
// last outgoing packet, unchanged.
TEST(RawMethod, ServerReaderWriter_WritesResponses) {
  ServerContextForTest<FakeService> context(kBidirectionalStream);
  rpc_lock().lock();
  kBidirectionalStream.Invoke(context.get(), context.request({}));

  // Written by the server to the client, so this is response data.
  constexpr const char kResponseValue[] = "O_o";
  const auto response_bytes = as_bytes(span(kResponseValue));
  EXPECT_EQ(OkStatus(),
            context.service().last_reader_writer.Write(response_bytes));

  ConstByteSpan sent_payload = context.output().last_packet().payload();
  EXPECT_TRUE(pw::containers::Equal(response_bytes, sent_payload));
}
267
// Writing through the stored server writer leaves a SERVER_STREAM packet as
// the last packet sent. That packet is addressed to the invoked method and
// carries exactly the written bytes with OK status.
TEST(RawServerWriter, Write_SendsPayload) {
  ServerContextForTest<FakeService> context(kServerStream);
  rpc_lock().lock();
  kServerStream.Invoke(context.get(), context.request({}));

  constexpr auto data = bytes::Array<0x0d, 0x06, 0xf0, 0x0d>();
  EXPECT_EQ(context.service().last_writer.Write(data), OkStatus());

  const internal::Packet& packet = context.output().last_packet();
  EXPECT_EQ(packet.type(), pwpb::PacketType::SERVER_STREAM);
  EXPECT_EQ(packet.channel_id(), context.channel_id());
  EXPECT_EQ(packet.service_id(), context.service_id());
  EXPECT_EQ(packet.method_id(), context.get().method().id());
  // Checks length as well as contents. A memcmp over data.size() bytes would
  // read past the end of a shorter payload.
  EXPECT_TRUE(pw::containers::Equal(data, packet.payload()));
  EXPECT_EQ(packet.status(), OkStatus());
}
284
// Writing an empty payload still sends a SERVER_STREAM packet addressed to
// the invoked method, with an empty payload and OK status.
TEST(RawServerWriter, Write_EmptyBuffer) {
  ServerContextForTest<FakeService> context(kServerStream);
  rpc_lock().lock();
  kServerStream.Invoke(context.get(), context.request({}));

  RawServerWriter& writer = context.service().last_writer;
  ASSERT_EQ(writer.Write({}), OkStatus());

  const internal::Packet& sent = context.output().last_packet();
  EXPECT_EQ(sent.type(), pwpb::PacketType::SERVER_STREAM);
  EXPECT_EQ(sent.channel_id(), context.channel_id());
  EXPECT_EQ(sent.service_id(), context.service_id());
  EXPECT_EQ(sent.method_id(), context.get().method().id());
  EXPECT_TRUE(sent.payload().empty());
  EXPECT_EQ(sent.status(), OkStatus());
}
300
// Once the writer has been finished, further writes are rejected with
// Status::FailedPrecondition().
TEST(RawServerWriter, Write_Closed_ReturnsFailedPrecondition) {
  ServerContextForTest<FakeService> context(kServerStream);
  rpc_lock().lock();
  kServerStream.Invoke(context.get(), context.request({}));

  RawServerWriter& writer = context.service().last_writer;
  EXPECT_EQ(OkStatus(), writer.Finish());

  constexpr auto kPayload = bytes::Array<0x0d, 0x06, 0xf0, 0x0d>();
  EXPECT_EQ(writer.Write(kPayload), Status::FailedPrecondition());
}
311
// Writing a payload as large as the entire static encoding buffer fails with
// Status::Internal(). Skipped when dynamic allocation is enabled.
TEST(RawServerWriter, Write_PayloadTooLargeForEncodingBuffer_ReturnsInternal) {
  // The payload is never too large for the encoding buffer when dynamic
  // allocation is enabled.
#if PW_RPC_DYNAMIC_ALLOCATION
  GTEST_SKIP();
#endif  // PW_RPC_DYNAMIC_ALLOCATION

  ServerContextForTest<FakeService> context(kServerStream);
  rpc_lock().lock();
  kServerStream.Invoke(context.get(), context.request({}));

  // A kEncodingBufferSizeBytes payload will never fit in the encoding buffer.
  static constexpr std::array<std::byte, cfg::kEncodingBufferSizeBytes>
      kBigData = {};
  EXPECT_EQ(context.service().last_writer.Write(kBigData), Status::Internal());
}
328
329 } // namespace
330 } // namespace pw::rpc::internal
331