• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2021 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 #pragma once
15 
#include <cstddef>
#include <cstdint>
#include <utility>

#include "pw_assert/assert.h"
#include "pw_rpc/channel.h"
#include "pw_rpc/internal/call.h"
#include "pw_rpc/internal/fake_channel_output.h"
#include "pw_rpc/internal/method.h"
#include "pw_rpc/internal/packet.h"
#include "pw_rpc/server.h"
25 
26 namespace pw::rpc::internal::test {
27 
28 // Collects everything needed to invoke a particular RPC.
29 template <typename Output, typename Service, uint32_t kMethodId>
30 class InvocationContext {
31  public:
32   InvocationContext(const InvocationContext&) = delete;
33   InvocationContext(InvocationContext&&) = delete;
34 
35   InvocationContext& operator=(const InvocationContext&) = delete;
36   InvocationContext& operator=(InvocationContext&&) = delete;
37 
service()38   Service& service() { return service_; }
service()39   const Service& service() const { return service_; }
40 
41   // Sets the channel ID, which defaults to an arbitrary value.
set_channel_id(uint32_t channel_id)42   void set_channel_id(uint32_t channel_id) {
43     PW_ASSERT(channel_id != Channel::kUnassignedChannelId);
44 
45     // If using dynamic allocation, the channel objects are owned by the
46     // endpoint. The external channel is only used to initialize the endpoint's
47     // channels vector. To update that channel, remove and re-add the channel.
48     PW_ASSERT(server_.CloseChannel(context_.channel_id()).ok());
49     PW_ASSERT(server_.OpenChannel(channel_id, output_).ok());
50 
51     channel_ = Channel(channel_id, &output_);
52     context_.set_channel_id(channel_id);
53   }
54 
total_responses()55   size_t total_responses() const { return responses().size(); }
56 
max_packets()57   size_t max_packets() const { return output_.max_packets(); }
58 
59   // Returns the responses that have been recorded. The maximum number of
60   // responses is responses().max_size(). responses().back() is always the most
61   // recent response, even if total_responses() > responses().max_size().
responses()62   auto responses() const {
63     return output().payloads(method_type_,
64                              channel_.id(),
65                              UnwrapServiceId(service().service_id()),
66                              kMethodId);
67   }
68 
69   // True if the RPC has completed.
done()70   bool done() const { return output_.done(); }
71 
72   // The status of the stream. Only valid if done() is true.
status()73   Status status() const {
74     PW_ASSERT(done());
75     return output_.last_status();
76   }
77 
SendClientError(Status error)78   void SendClientError(Status error) {
79     using PacketType = ::pw::rpc::internal::pwpb::PacketType;
80 
81     std::byte packet[kNoPayloadPacketSizeBytes];
82     PW_ASSERT(server_
83                   .ProcessPacket(Packet(PacketType::CLIENT_ERROR,
84                                         channel_.id(),
85                                         UnwrapServiceId(service_.service_id()),
86                                         kMethodId,
87                                         0,
88                                         {},
89                                         error)
90                                      .Encode(packet)
91                                      .value())
92                   .ok());
93   }
94 
output()95   const Output& output() const { return output_; }
output()96   Output& output() { return output_; }
97 
98  protected:
99   // Constructs the invocation context. The args for the ChannelOutput type are
100   // passed in a std::tuple. The args for the Service are forwarded directly
101   // from the callsite.
102   template <typename... ServiceArgs>
InvocationContext(const Method & method,MethodType method_type,ServiceArgs &&...service_args)103   InvocationContext(const Method& method,
104                     MethodType method_type,
105                     ServiceArgs&&... service_args)
106       : method_type_(method_type),
107         channel_(123, &output_),
108         server_(span(static_cast<rpc::Channel*>(&channel_), 1)),
109         service_(std::forward<ServiceArgs>(service_args)...),
110         context_(server_, channel_.id(), service_, method, 0) {
111     server_.RegisterService(service_);
112   }
113 
channel_id()114   uint32_t channel_id() const { return channel_.id(); }
115 
116   template <size_t kMaxPayloadSize = 32>
SendClientStream(ConstByteSpan payload)117   void SendClientStream(ConstByteSpan payload) {
118     using PacketType = ::pw::rpc::internal::pwpb::PacketType;
119 
120     std::byte packet[kNoPayloadPacketSizeBytes + 3 + kMaxPayloadSize];
121     PW_ASSERT(server_
122                   .ProcessPacket(Packet(PacketType::CLIENT_STREAM,
123                                         channel_.id(),
124                                         UnwrapServiceId(service_.service_id()),
125                                         kMethodId,
126                                         0,
127                                         payload)
128                                      .Encode(packet)
129                                      .value())
130                   .ok());
131   }
132 
SendClientStreamEnd()133   void SendClientStreamEnd() {
134     using PacketType = ::pw::rpc::internal::pwpb::PacketType;
135 
136     std::byte packet[kNoPayloadPacketSizeBytes];
137     PW_ASSERT(server_
138                   .ProcessPacket(Packet(PacketType::CLIENT_STREAM_END,
139                                         channel_.id(),
140                                         UnwrapServiceId(service_.service_id()),
141                                         kMethodId)
142                                      .Encode(packet)
143                                      .value())
144                   .ok());
145   }
146 
147   // Invokes the RPC, optionally with a request argument.
148   template <auto kMethod, typename T, typename... RequestArg>
call(RequestArg &&...request)149   void call(RequestArg&&... request) PW_LOCKS_EXCLUDED(rpc_lock()) {
150     static_assert(sizeof...(request) <= 1);
151     output_.clear();
152     T responder = GetResponder<T>();
153     CallMethodImplFunction<kMethod>(
154         service(), std::forward<RequestArg>(request)..., responder);
155   }
156 
157   template <typename T>
GetResponder()158   T GetResponder() PW_LOCKS_EXCLUDED(rpc_lock()) {
159     RpcLockGuard lock;
160     return T(call_context().ClaimLocked());
161   }
162 
call_context()163   const internal::CallContext& call_context() const { return context_; }
164 
165  private:
166   static constexpr size_t kNoPayloadPacketSizeBytes =
167       2 /* type */ + 2 /* channel */ + 5 /* service */ + 5 /* method */ +
168       2 /* status */;
169 
170   const MethodType method_type_;
171   Output output_;
172   Channel channel_;
173   rpc::Server server_;
174   Service service_;
175   internal::CallContext context_;
176 };
177 
178 }  // namespace pw::rpc::internal::test
179