1 // Copyright 2021 The Pigweed Authors 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 // use this file except in compliance with the License. You may obtain a copy of 5 // the License at 6 // 7 // https://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 // License for the specific language governing permissions and limitations under 13 // the License. 14 #pragma once 15 16 #include <cstddef> 17 18 #include "pw_assert/assert.h" 19 #include "pw_rpc/channel.h" 20 #include "pw_rpc/internal/fake_channel_output.h" 21 #include "pw_rpc/internal/method.h" 22 #include "pw_rpc/internal/packet.h" 23 #include "pw_rpc/server.h" 24 25 namespace pw::rpc::internal::test { 26 27 // Collects everything needed to invoke a particular RPC. 28 template <typename Output, typename Service, uint32_t kMethodId> 29 class InvocationContext { 30 public: 31 InvocationContext(const InvocationContext&) = delete; 32 InvocationContext(InvocationContext&&) = delete; 33 34 InvocationContext& operator=(const InvocationContext&) = delete; 35 InvocationContext& operator=(InvocationContext&&) = delete; 36 service()37 Service& service() { return service_; } service()38 const Service& service() const { return service_; } 39 40 // Sets the channel ID, which defaults to an arbitrary value. set_channel_id(uint32_t channel_id)41 void set_channel_id(uint32_t channel_id) { 42 PW_ASSERT(channel_id != Channel::kUnassignedChannelId); 43 44 // If using dynamic allocation, the channel objects are owned by the 45 // endpoint. The external channel is only used to initialize the endpoint's 46 // channels vector. To update that channel, remove and re-add the channel. 47 PW_ASSERT(server_.CloseChannel(context_.channel_id()).ok()); 48 PW_ASSERT(server_.OpenChannel(channel_id, output_).ok()); 49 50 channel_ = Channel(channel_id, &output_); 51 context_.set_channel_id(channel_id); 52 } 53 total_responses()54 size_t total_responses() const { return responses().size(); } 55 max_packets()56 size_t max_packets() const { return output_.max_packets(); } 57 58 // Returns the responses that have been recorded. The maximum number of 59 // responses is responses().max_size(). responses().back() is always the most 60 // recent response, even if total_responses() > responses().max_size(). responses()61 auto responses() const { 62 return output().payloads( 63 method_type_, channel_.id(), service().id(), kMethodId); 64 } 65 66 // True if the RPC has completed. done()67 bool done() const { return output_.done(); } 68 69 // The status of the stream. Only valid if done() is true. status()70 Status status() const { 71 PW_ASSERT(done()); 72 return output_.last_status(); 73 } 74 SendClientError(Status error)75 void SendClientError(Status error) { 76 std::byte packet[kNoPayloadPacketSizeBytes]; 77 PW_ASSERT(server_ 78 .ProcessPacket(Packet(PacketType::CLIENT_ERROR, 79 channel_.id(), 80 service_.id(), 81 kMethodId, 82 0, 83 {}, 84 error) 85 .Encode(packet) 86 .value(), 87 output_) 88 .ok()); 89 } 90 output()91 const Output& output() const { return output_; } output()92 Output& output() { return output_; } 93 94 protected: 95 // Constructs the invocation context. The args for the ChannelOutput type are 96 // passed in a std::tuple. The args for the Service are forwarded directly 97 // from the callsite. 98 template <typename... ServiceArgs> InvocationContext(const Method & method,MethodType method_type,ServiceArgs &&...service_args)99 InvocationContext(const Method& method, 100 MethodType method_type, 101 ServiceArgs&&... service_args) 102 : method_type_(method_type), 103 channel_(123, &output_), 104 server_(std::span(static_cast<rpc::Channel*>(&channel_), 1)), 105 service_(std::forward<ServiceArgs>(service_args)...), 106 context_(server_, channel_.id(), service_, method, 0) { 107 server_.RegisterService(service_); 108 } 109 channel_id()110 uint32_t channel_id() const { return channel_.id(); } 111 112 template <size_t kMaxPayloadSize = 32> SendClientStream(ConstByteSpan payload)113 void SendClientStream(ConstByteSpan payload) { 114 std::byte packet[kNoPayloadPacketSizeBytes + 3 + kMaxPayloadSize]; 115 PW_ASSERT(server_ 116 .ProcessPacket(Packet(PacketType::CLIENT_STREAM, 117 channel_.id(), 118 service_.id(), 119 kMethodId, 120 0, 121 payload) 122 .Encode(packet) 123 .value(), 124 output_) 125 .ok()); 126 } 127 SendClientStreamEnd()128 void SendClientStreamEnd() { 129 std::byte packet[kNoPayloadPacketSizeBytes]; 130 PW_ASSERT(server_ 131 .ProcessPacket(Packet(PacketType::CLIENT_STREAM_END, 132 channel_.id(), 133 service_.id(), 134 kMethodId) 135 .Encode(packet) 136 .value(), 137 output_) 138 .ok()); 139 } 140 141 // Invokes the RPC, optionally with a request argument. 142 template <auto kMethod, typename T, typename... RequestArg> call(RequestArg &&...request)143 void call(RequestArg&&... request) { 144 static_assert(sizeof...(request) <= 1); 145 output_.clear(); 146 T responder = GetResponder<T>(); 147 CallMethodImplFunction<kMethod>( 148 service(), std::forward<RequestArg>(request)..., responder); 149 } 150 151 template <typename T> GetResponder()152 T GetResponder() { 153 return T(call_context()); 154 } 155 call_context()156 const internal::CallContext& call_context() const { return context_; } 157 158 private: 159 static constexpr size_t kNoPayloadPacketSizeBytes = 160 2 /* type */ + 2 /* channel */ + 5 /* service */ + 5 /* method */ + 161 2 /* status */; 162 163 const MethodType method_type_; 164 Output output_; 165 Channel channel_; 166 rpc::Server server_; 167 Service service_; 168 internal::CallContext context_; 169 }; 170 171 } // namespace pw::rpc::internal::test 172