1 // Copyright 2020 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 #pragma once
15
16 #include <tuple>
17 #include <utility>
18
19 #include "pw_assert/light.h"
20 #include "pw_containers/vector.h"
21 #include "pw_preprocessor/arguments.h"
22 #include "pw_rpc/channel.h"
23 #include "pw_rpc/internal/hash.h"
24 #include "pw_rpc/internal/method_lookup.h"
25 #include "pw_rpc/internal/nanopb_method.h"
26 #include "pw_rpc/internal/packet.h"
27 #include "pw_rpc/internal/server.h"
28
29 namespace pw::rpc {
30
31 // Declares a context object that may be used to invoke an RPC. The context is
32 // declared with the name of the implemented service and the method to invoke.
33 // The RPC can then be invoked with the call method.
34 //
35 // For a unary RPC, context.call(request) returns the status, and the response
36 // struct can be accessed via context.response().
37 //
38 // PW_NANOPB_TEST_METHOD_CONTEXT(my::CoolService, TheMethod) context;
39 // EXPECT_EQ(OkStatus(), context.call({.some_arg = 123}));
40 // EXPECT_EQ(500, context.response().some_response_value);
41 //
42 // For a server streaming RPC, context.call(request) invokes the method. As in a
43 // normal RPC, the method completes when the ServerWriter's Finish method is
44 // called (or it goes out of scope).
45 //
46 // PW_NANOPB_TEST_METHOD_CONTEXT(my::CoolService, TheStreamingMethod) context;
47 // context.call({.some_arg = 123});
48 //
49 // EXPECT_TRUE(context.done()); // Check that the RPC completed
50 // EXPECT_EQ(OkStatus(), context.status()); // Check the status
51 //
52 // EXPECT_EQ(3u, context.responses().size());
53 // EXPECT_EQ(123, context.responses()[0].value); // check individual responses
54 //
55 // for (const MyResponse& response : context.responses()) {
56 // // iterate over the responses
57 // }
58 //
// PW_NANOPB_TEST_METHOD_CONTEXT forwards its constructor arguments to the
// constructor of the service under test. For example:
61 //
62 // PW_NANOPB_TEST_METHOD_CONTEXT(MyService, Go) context(service, args);
63 //
// PW_NANOPB_TEST_METHOD_CONTEXT takes two optional template arguments, in order:
65 //
66 // size_t kMaxResponse: maximum responses to store; ignored unless streaming
67 // size_t kOutputSizeBytes: buffer size; must be large enough for a packet
68 //
69 // Example:
70 //
71 // PW_NANOPB_TEST_METHOD_CONTEXT(MyService, BestMethod, 3, 256) context;
72 // ASSERT_EQ(3u, context.responses().max_size());
73 //
// The macro derives the method ID by hashing the stringized method name.
// `##__VA_ARGS__` (a GNU preprocessor extension) removes the preceding comma
// when neither optional template argument is supplied.
#define PW_NANOPB_TEST_METHOD_CONTEXT(service, method, ...)              \
  ::pw::rpc::NanopbTestMethodContext<service,                           \
                                     &service::method,                   \
                                     ::pw::rpc::internal::Hash(#method), \
                                     ##__VA_ARGS__>

// Forward declaration that carries the default template arguments
// (kMaxResponse = 4, kOutputSizeBytes = 128); the class itself is defined
// further down, after the internal implementation classes.
template <typename Service,
          auto method,
          uint32_t kMethodId,
          size_t kMaxResponse = 4,
          size_t kOutputSizeBytes = 128>
class NanopbTestMethodContext;
85
86 // Internal classes that implement NanopbTestMethodContext.
87 namespace internal::test::nanopb {
88
89 // A ChannelOutput implementation that stores the outgoing payloads and status.
90 template <typename Response>
91 class MessageOutput final : public ChannelOutput {
92 public:
MessageOutput(const internal::NanopbMethod & method,Vector<Response> & responses,std::span<std::byte> buffer)93 MessageOutput(const internal::NanopbMethod& method,
94 Vector<Response>& responses,
95 std::span<std::byte> buffer)
96 : ChannelOutput("internal::test::nanopb::MessageOutput"),
97 method_(method),
98 responses_(responses),
99 buffer_(buffer) {
100 clear();
101 }
102
last_status()103 Status last_status() const { return last_status_; }
set_last_status(Status status)104 void set_last_status(Status status) { last_status_ = status; }
105
total_responses()106 size_t total_responses() const { return total_responses_; }
107
stream_ended()108 bool stream_ended() const { return stream_ended_; }
109
110 void clear();
111
112 private:
AcquireBuffer()113 std::span<std::byte> AcquireBuffer() override { return buffer_; }
114
115 Status SendAndReleaseBuffer(std::span<const std::byte> buffer) override;
116
117 const internal::NanopbMethod& method_;
118 Vector<Response>& responses_;
119 std::span<std::byte> buffer_;
120 size_t total_responses_;
121 bool stream_ended_;
122 Status last_status_;
123 };
124
125 // Collects everything needed to invoke a particular RPC.
126 template <typename Service,
127 auto method,
128 uint32_t kMethodId,
129 size_t kMaxResponse,
130 size_t kOutputSize>
131 struct InvocationContext {
132 using Request = internal::Request<method>;
133 using Response = internal::Response<method>;
134
135 template <typename... Args>
InvocationContextInvocationContext136 InvocationContext(Args&&... args)
137 : output(MethodLookup::GetNanopbMethod<Service, kMethodId>(),
138 responses,
139 buffer),
140 channel(Channel::Create<123>(&output)),
141 server(std::span(&channel, 1)),
142 service(std::forward<Args>(args)...),
143 call(static_cast<internal::Server&>(server),
144 static_cast<internal::Channel&>(channel),
145 service,
146 MethodLookup::GetNanopbMethod<Service, kMethodId>()) {}
147
148 MessageOutput<Response> output;
149
150 rpc::Channel channel;
151 rpc::Server server;
152 Service service;
153 Vector<Response, kMaxResponse> responses;
154 std::array<std::byte, kOutputSize> buffer = {};
155
156 internal::ServerCall call;
157 };
158
159 // Method invocation context for a unary RPC. Returns the status in call() and
160 // provides the response through the response() method.
161 template <typename Service, auto method, uint32_t kMethodId, size_t kOutputSize>
162 class UnaryContext {
163 private:
164 InvocationContext<Service, method, kMethodId, 1, kOutputSize> ctx_;
165
166 public:
167 using Request = typename decltype(ctx_)::Request;
168 using Response = typename decltype(ctx_)::Response;
169
170 template <typename... Args>
UnaryContext(Args &&...args)171 UnaryContext(Args&&... args) : ctx_(std::forward<Args>(args)...) {}
172
service()173 Service& service() { return ctx_.service; }
174
175 // Invokes the RPC with the provided request. Returns the status.
call(const Request & request)176 Status call(const Request& request) {
177 ctx_.output.clear();
178 ctx_.responses.emplace_back();
179 ctx_.responses.back() = {};
180 return CallMethodImplFunction<method>(
181 ctx_.call, request, ctx_.responses.back());
182 }
183
184 // Gives access to the RPC's response.
response()185 const Response& response() const {
186 PW_ASSERT(ctx_.responses.size() > 0u);
187 return ctx_.responses.back();
188 }
189 };
190
191 // Method invocation context for a server streaming RPC.
192 template <typename Service,
193 auto method,
194 uint32_t kMethodId,
195 size_t kMaxResponse,
196 size_t kOutputSize>
197 class ServerStreamingContext {
198 private:
199 InvocationContext<Service, method, kMethodId, kMaxResponse, kOutputSize> ctx_;
200
201 public:
202 using Request = typename decltype(ctx_)::Request;
203 using Response = typename decltype(ctx_)::Response;
204
205 template <typename... Args>
ServerStreamingContext(Args &&...args)206 ServerStreamingContext(Args&&... args) : ctx_(std::forward<Args>(args)...) {}
207
service()208 Service& service() { return ctx_.service; }
209
210 // Invokes the RPC with the provided request.
call(const Request & request)211 void call(const Request& request) {
212 ctx_.output.clear();
213 internal::BaseServerWriter server_writer(ctx_.call);
214 return CallMethodImplFunction<method>(
215 ctx_.call,
216 request,
217 static_cast<ServerWriter<Response>&>(server_writer));
218 }
219
220 // Returns a server writer which writes responses into the context's buffer.
221 // This should not be called alongside call(); use one or the other.
writer()222 ServerWriter<Response> writer() {
223 ctx_.output.clear();
224 internal::BaseServerWriter server_writer(ctx_.call);
225 return std::move(static_cast<ServerWriter<Response>&>(server_writer));
226 }
227
228 // Returns the responses that have been recorded. The maximum number of
229 // responses is responses().max_size(). responses().back() is always the most
230 // recent response, even if total_responses() > responses().max_size().
responses()231 const Vector<Response>& responses() const { return ctx_.responses; }
232
233 // The total number of responses sent, which may be larger than
234 // responses.max_size().
total_responses()235 size_t total_responses() const { return ctx_.output.total_responses(); }
236
237 // True if the stream has terminated.
done()238 bool done() const { return ctx_.output.stream_ended(); }
239
240 // The status of the stream. Only valid if done() is true.
status()241 Status status() const {
242 PW_ASSERT(done());
243 return ctx_.output.last_status();
244 }
245 };
246
247 // Alias to select the type of the context object to use based on which type of
248 // RPC it is for.
249 template <typename Service,
250 auto method,
251 uint32_t kMethodId,
252 size_t kMaxResponse,
253 size_t kOutputSize>
254 using Context = std::tuple_element_t<
255 static_cast<size_t>(internal::MethodTraits<decltype(method)>::kType),
256 std::tuple<UnaryContext<Service, method, kMethodId, kOutputSize>,
257 ServerStreamingContext<Service,
258 method,
259 kMethodId,
260 kMaxResponse,
261 kOutputSize>
262 // TODO(hepler): Support client and bidi streaming
263 >>;
264
265 template <typename Response>
clear()266 void MessageOutput<Response>::clear() {
267 responses_.clear();
268 total_responses_ = 0;
269 stream_ended_ = false;
270 last_status_ = Status::Unknown();
271 }
272
273 template <typename Response>
SendAndReleaseBuffer(std::span<const std::byte> buffer)274 Status MessageOutput<Response>::SendAndReleaseBuffer(
275 std::span<const std::byte> buffer) {
276 PW_ASSERT(!stream_ended_);
277 PW_ASSERT(buffer.data() == buffer_.data());
278
279 if (buffer.empty()) {
280 return OkStatus();
281 }
282
283 Result<internal::Packet> result = internal::Packet::FromBuffer(buffer);
284 PW_ASSERT(result.ok());
285
286 last_status_ = result.value().status();
287
288 switch (result.value().type()) {
289 case internal::PacketType::RESPONSE:
290 // If we run out of space, the back message is always the most recent.
291 responses_.emplace_back();
292 responses_.back() = {};
293 PW_ASSERT(
294 method_.DecodeResponse(result.value().payload(), &responses_.back()));
295 total_responses_ += 1;
296 break;
297 case internal::PacketType::SERVER_STREAM_END:
298 stream_ended_ = true;
299 break;
300 default:
301 PW_CRASH("Unhandled PacketType");
302 }
303 return OkStatus();
304 }
305
306 } // namespace internal::test::nanopb
307
308 template <typename Service,
309 auto method,
310 uint32_t kMethodId,
311 size_t kMaxResponse,
312 size_t kOutputSizeBytes>
313 class NanopbTestMethodContext
314 : public internal::test::nanopb::
315 Context<Service, method, kMethodId, kMaxResponse, kOutputSizeBytes> {
316 public:
317 // Forwards constructor arguments to the service class.
318 template <typename... ServiceArgs>
NanopbTestMethodContext(ServiceArgs &&...service_args)319 NanopbTestMethodContext(ServiceArgs&&... service_args)
320 : internal::test::nanopb::
321 Context<Service, method, kMethodId, kMaxResponse, kOutputSizeBytes>(
322 std::forward<ServiceArgs>(service_args)...) {}
323 };
324
325 } // namespace pw::rpc
326