• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2021 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 #pragma once
15 
16 #include <cstddef>
17 
18 #include "pw_assert/assert.h"
19 #include "pw_rpc/channel.h"
20 #include "pw_rpc/internal/fake_channel_output.h"
21 #include "pw_rpc/internal/method.h"
22 #include "pw_rpc/internal/packet.h"
23 #include "pw_rpc/server.h"
24 
25 namespace pw::rpc::internal::test {
26 
27 // Collects everything needed to invoke a particular RPC.
28 template <typename Output, typename Service, uint32_t kMethodId>
29 class InvocationContext {
30  public:
31   InvocationContext(const InvocationContext&) = delete;
32   InvocationContext(InvocationContext&&) = delete;
33 
34   InvocationContext& operator=(const InvocationContext&) = delete;
35   InvocationContext& operator=(InvocationContext&&) = delete;
36 
service()37   Service& service() { return service_; }
service()38   const Service& service() const { return service_; }
39 
40   // Sets the channel ID, which defaults to an arbitrary value.
set_channel_id(uint32_t channel_id)41   void set_channel_id(uint32_t channel_id) {
42     PW_ASSERT(channel_id != Channel::kUnassignedChannelId);
43 
44     // If using dynamic allocation, the channel objects are owned by the
45     // endpoint. The external channel is only used to initialize the endpoint's
46     // channels vector. To update that channel, remove and re-add the channel.
47     PW_ASSERT(server_.CloseChannel(context_.channel_id()).ok());
48     PW_ASSERT(server_.OpenChannel(channel_id, output_).ok());
49 
50     channel_ = Channel(channel_id, &output_);
51     context_.set_channel_id(channel_id);
52   }
53 
total_responses()54   size_t total_responses() const { return responses().size(); }
55 
max_packets()56   size_t max_packets() const { return output_.max_packets(); }
57 
58   // Returns the responses that have been recorded. The maximum number of
59   // responses is responses().max_size(). responses().back() is always the most
60   // recent response, even if total_responses() > responses().max_size().
responses()61   auto responses() const {
62     return output().payloads(
63         method_type_, channel_.id(), service().id(), kMethodId);
64   }
65 
66   // True if the RPC has completed.
done()67   bool done() const { return output_.done(); }
68 
69   // The status of the stream. Only valid if done() is true.
status()70   Status status() const {
71     PW_ASSERT(done());
72     return output_.last_status();
73   }
74 
SendClientError(Status error)75   void SendClientError(Status error) {
76     std::byte packet[kNoPayloadPacketSizeBytes];
77     PW_ASSERT(server_
78                   .ProcessPacket(Packet(PacketType::CLIENT_ERROR,
79                                         channel_.id(),
80                                         service_.id(),
81                                         kMethodId,
82                                         0,
83                                         {},
84                                         error)
85                                      .Encode(packet)
86                                      .value(),
87                                  output_)
88                   .ok());
89   }
90 
output()91   const Output& output() const { return output_; }
output()92   Output& output() { return output_; }
93 
94  protected:
95   // Constructs the invocation context. The args for the ChannelOutput type are
96   // passed in a std::tuple. The args for the Service are forwarded directly
97   // from the callsite.
98   template <typename... ServiceArgs>
InvocationContext(const Method & method,MethodType method_type,ServiceArgs &&...service_args)99   InvocationContext(const Method& method,
100                     MethodType method_type,
101                     ServiceArgs&&... service_args)
102       : method_type_(method_type),
103         channel_(123, &output_),
104         server_(std::span(static_cast<rpc::Channel*>(&channel_), 1)),
105         service_(std::forward<ServiceArgs>(service_args)...),
106         context_(server_, channel_.id(), service_, method, 0) {
107     server_.RegisterService(service_);
108   }
109 
channel_id()110   uint32_t channel_id() const { return channel_.id(); }
111 
112   template <size_t kMaxPayloadSize = 32>
SendClientStream(ConstByteSpan payload)113   void SendClientStream(ConstByteSpan payload) {
114     std::byte packet[kNoPayloadPacketSizeBytes + 3 + kMaxPayloadSize];
115     PW_ASSERT(server_
116                   .ProcessPacket(Packet(PacketType::CLIENT_STREAM,
117                                         channel_.id(),
118                                         service_.id(),
119                                         kMethodId,
120                                         0,
121                                         payload)
122                                      .Encode(packet)
123                                      .value(),
124                                  output_)
125                   .ok());
126   }
127 
SendClientStreamEnd()128   void SendClientStreamEnd() {
129     std::byte packet[kNoPayloadPacketSizeBytes];
130     PW_ASSERT(server_
131                   .ProcessPacket(Packet(PacketType::CLIENT_STREAM_END,
132                                         channel_.id(),
133                                         service_.id(),
134                                         kMethodId)
135                                      .Encode(packet)
136                                      .value(),
137                                  output_)
138                   .ok());
139   }
140 
141   // Invokes the RPC, optionally with a request argument.
142   template <auto kMethod, typename T, typename... RequestArg>
call(RequestArg &&...request)143   void call(RequestArg&&... request) {
144     static_assert(sizeof...(request) <= 1);
145     output_.clear();
146     T responder = GetResponder<T>();
147     CallMethodImplFunction<kMethod>(
148         service(), std::forward<RequestArg>(request)..., responder);
149   }
150 
151   template <typename T>
GetResponder()152   T GetResponder() {
153     return T(call_context());
154   }
155 
call_context()156   const internal::CallContext& call_context() const { return context_; }
157 
158  private:
159   static constexpr size_t kNoPayloadPacketSizeBytes =
160       2 /* type */ + 2 /* channel */ + 5 /* service */ + 5 /* method */ +
161       2 /* status */;
162 
163   const MethodType method_type_;
164   Output output_;
165   Channel channel_;
166   rpc::Server server_;
167   Service service_;
168   internal::CallContext context_;
169 };
170 
171 }  // namespace pw::rpc::internal::test
172