1 // Copyright 2021 The Pigweed Authors 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 // use this file except in compliance with the License. You may obtain a copy of 5 // the License at 6 // 7 // https://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 // License for the specific language governing permissions and limitations under 13 // the License. 14 #pragma once 15 16 #include <cstddef> 17 18 #include "pw_assert/assert.h" 19 #include "pw_rpc/channel.h" 20 #include "pw_rpc/internal/call.h" 21 #include "pw_rpc/internal/fake_channel_output.h" 22 #include "pw_rpc/internal/method.h" 23 #include "pw_rpc/internal/packet.h" 24 #include "pw_rpc/server.h" 25 26 namespace pw::rpc::internal::test { 27 28 // Collects everything needed to invoke a particular RPC. 29 template <typename Output, typename Service, uint32_t kMethodId> 30 class InvocationContext { 31 public: 32 InvocationContext(const InvocationContext&) = delete; 33 InvocationContext(InvocationContext&&) = delete; 34 35 InvocationContext& operator=(const InvocationContext&) = delete; 36 InvocationContext& operator=(InvocationContext&&) = delete; 37 service()38 Service& service() { return service_; } service()39 const Service& service() const { return service_; } 40 41 // Sets the channel ID, which defaults to an arbitrary value. set_channel_id(uint32_t channel_id)42 void set_channel_id(uint32_t channel_id) { 43 PW_ASSERT(channel_id != Channel::kUnassignedChannelId); 44 45 // If using dynamic allocation, the channel objects are owned by the 46 // endpoint. The external channel is only used to initialize the endpoint's 47 // channels vector. To update that channel, remove and re-add the channel. 48 PW_ASSERT(server_.CloseChannel(context_.channel_id()).ok()); 49 PW_ASSERT(server_.OpenChannel(channel_id, output_).ok()); 50 51 channel_ = Channel(channel_id, &output_); 52 context_.set_channel_id(channel_id); 53 } 54 total_responses()55 size_t total_responses() const { return responses().size(); } 56 max_packets()57 size_t max_packets() const { return output_.max_packets(); } 58 59 // Returns the responses that have been recorded. The maximum number of 60 // responses is responses().max_size(). responses().back() is always the most 61 // recent response, even if total_responses() > responses().max_size(). responses()62 auto responses() const { 63 return output().payloads(method_type_, 64 channel_.id(), 65 UnwrapServiceId(service().service_id()), 66 kMethodId); 67 } 68 69 // True if the RPC has completed. done()70 bool done() const { return output_.done(); } 71 72 // The status of the stream. Only valid if done() is true. status()73 Status status() const { 74 PW_ASSERT(done()); 75 return output_.last_status(); 76 } 77 SendClientError(Status error)78 void SendClientError(Status error) { 79 using PacketType = ::pw::rpc::internal::pwpb::PacketType; 80 81 std::byte packet[kNoPayloadPacketSizeBytes]; 82 PW_ASSERT(server_ 83 .ProcessPacket(Packet(PacketType::CLIENT_ERROR, 84 channel_.id(), 85 UnwrapServiceId(service_.service_id()), 86 kMethodId, 87 0, 88 {}, 89 error) 90 .Encode(packet) 91 .value()) 92 .ok()); 93 } 94 output()95 const Output& output() const { return output_; } output()96 Output& output() { return output_; } 97 98 protected: 99 // Constructs the invocation context. The args for the ChannelOutput type are 100 // passed in a std::tuple. The args for the Service are forwarded directly 101 // from the callsite. 102 template <typename... ServiceArgs> InvocationContext(const Method & method,MethodType method_type,ServiceArgs &&...service_args)103 InvocationContext(const Method& method, 104 MethodType method_type, 105 ServiceArgs&&... service_args) 106 : method_type_(method_type), 107 channel_(123, &output_), 108 server_(span(static_cast<rpc::Channel*>(&channel_), 1)), 109 service_(std::forward<ServiceArgs>(service_args)...), 110 context_(server_, channel_.id(), service_, method, 0) { 111 server_.RegisterService(service_); 112 } 113 channel_id()114 uint32_t channel_id() const { return channel_.id(); } 115 116 template <size_t kMaxPayloadSize = 32> SendClientStream(ConstByteSpan payload)117 void SendClientStream(ConstByteSpan payload) { 118 using PacketType = ::pw::rpc::internal::pwpb::PacketType; 119 120 std::byte packet[kNoPayloadPacketSizeBytes + 3 + kMaxPayloadSize]; 121 PW_ASSERT(server_ 122 .ProcessPacket(Packet(PacketType::CLIENT_STREAM, 123 channel_.id(), 124 UnwrapServiceId(service_.service_id()), 125 kMethodId, 126 0, 127 payload) 128 .Encode(packet) 129 .value()) 130 .ok()); 131 } 132 SendClientStreamEnd()133 void SendClientStreamEnd() { 134 using PacketType = ::pw::rpc::internal::pwpb::PacketType; 135 136 std::byte packet[kNoPayloadPacketSizeBytes]; 137 PW_ASSERT(server_ 138 .ProcessPacket(Packet(PacketType::CLIENT_STREAM_END, 139 channel_.id(), 140 UnwrapServiceId(service_.service_id()), 141 kMethodId) 142 .Encode(packet) 143 .value()) 144 .ok()); 145 } 146 147 // Invokes the RPC, optionally with a request argument. 148 template <auto kMethod, typename T, typename... RequestArg> call(RequestArg &&...request)149 void call(RequestArg&&... request) PW_LOCKS_EXCLUDED(rpc_lock()) { 150 static_assert(sizeof...(request) <= 1); 151 output_.clear(); 152 T responder = GetResponder<T>(); 153 CallMethodImplFunction<kMethod>( 154 service(), std::forward<RequestArg>(request)..., responder); 155 } 156 157 template <typename T> GetResponder()158 T GetResponder() PW_LOCKS_EXCLUDED(rpc_lock()) { 159 RpcLockGuard lock; 160 return T(call_context().ClaimLocked()); 161 } 162 call_context()163 const internal::CallContext& call_context() const { return context_; } 164 165 private: 166 static constexpr size_t kNoPayloadPacketSizeBytes = 167 2 /* type */ + 2 /* channel */ + 5 /* service */ + 5 /* method */ + 168 2 /* status */; 169 170 const MethodType method_type_; 171 Output output_; 172 Channel channel_; 173 rpc::Server server_; 174 Service service_; 175 internal::CallContext context_; 176 }; 177 178 } // namespace pw::rpc::internal::test 179