1 // Copyright 2022 The Pigweed Authors 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 // use this file except in compliance with the License. You may obtain a copy of 5 // the License at 6 // 7 // https://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 // License for the specific language governing permissions and limitations under 13 // the License. 14 #pragma once 15 16 #include <cinttypes> 17 18 #include "pw_rpc/channel.h" 19 #include "pw_rpc/client_server.h" 20 #include "pw_rpc/internal/client_server_testing.h" 21 #include "pw_span/span.h" 22 #include "pw_status/status.h" 23 #include "pw_sync/binary_semaphore.h" 24 #include "pw_sync/mutex.h" 25 #include "pw_thread/thread.h" 26 27 namespace pw::rpc { 28 namespace internal { 29 30 // Expands on a Forwarding Channel Output implementation to allow for 31 // observation of packets. 32 template <typename FakeChannelOutputImpl, 33 size_t kOutputSize, 34 size_t kMaxPackets, 35 size_t kPayloadsBufferSizeBytes> 36 class WatchableChannelOutput 37 : public ForwardingChannelOutput<FakeChannelOutputImpl, 38 kOutputSize, 39 kMaxPackets, 40 kPayloadsBufferSizeBytes> { 41 private: 42 using Base = ForwardingChannelOutput<FakeChannelOutputImpl, 43 kOutputSize, 44 kMaxPackets, 45 kPayloadsBufferSizeBytes>; 46 47 public: MaximumTransmissionUnit()48 size_t MaximumTransmissionUnit() PW_LOCKS_EXCLUDED(mutex_) override { 49 std::lock_guard lock(mutex_); 50 return Base::MaximumTransmissionUnit(); 51 } 52 Send(span<const std::byte> buffer)53 Status Send(span<const std::byte> buffer) PW_LOCKS_EXCLUDED(mutex_) override { 54 Status status; 55 mutex_.lock(); 56 status = Base::Send(buffer); 57 mutex_.unlock(); 58 output_semaphore_.release(); 59 return status; 60 } 61 62 // Returns true if should continue waiting for additional output WaitForOutput()63 bool WaitForOutput() PW_LOCKS_EXCLUDED(mutex_) { 64 output_semaphore_.acquire(); 65 std::lock_guard lock(mutex_); 66 return should_wait_; 67 } 68 StopWaitingForOutput()69 void StopWaitingForOutput() PW_LOCKS_EXCLUDED(mutex_) { 70 std::lock_guard lock(mutex_); 71 should_wait_ = false; 72 output_semaphore_.release(); 73 } 74 75 protected: 76 constexpr WatchableChannelOutput() = default; 77 PacketCount()78 size_t PacketCount() const PW_EXCLUSIVE_LOCKS_REQUIRED(mutex_) override { 79 return Base::PacketCount(); 80 } 81 82 sync::Mutex mutex_; 83 84 private: EncodeNextUnsentPacket(std::array<std::byte,kPayloadsBufferSizeBytes> & packet_buffer)85 Result<ConstByteSpan> EncodeNextUnsentPacket( 86 std::array<std::byte, kPayloadsBufferSizeBytes>& packet_buffer) 87 PW_LOCKS_EXCLUDED(mutex_) override { 88 std::lock_guard lock(mutex_); 89 return Base::EncodeNextUnsentPacket(packet_buffer); 90 } 91 sync::BinarySemaphore output_semaphore_; 92 bool should_wait_ PW_GUARDED_BY(mutex_) = true; 93 }; 94 95 // Provides a testing context with a real client and server 96 template <typename WatchableChannelOutputImpl, 97 size_t kOutputSize = 128, 98 size_t kMaxPackets = 16, 99 size_t kPayloadsBufferSizeBytes = 128> 100 class ClientServerTestContextThreaded 101 : public ClientServerTestContext<WatchableChannelOutputImpl, 102 kOutputSize, 103 kMaxPackets, 104 kPayloadsBufferSizeBytes> { 105 private: 106 using Instance = ClientServerTestContextThreaded<WatchableChannelOutputImpl, 107 kOutputSize, 108 kMaxPackets, 109 kPayloadsBufferSizeBytes>; 110 using Base = ClientServerTestContext<WatchableChannelOutputImpl, 111 kOutputSize, 112 kMaxPackets, 113 kPayloadsBufferSizeBytes>; 114 115 public: ~ClientServerTestContextThreaded()116 ~ClientServerTestContextThreaded() { 117 Base::channel_output_.StopWaitingForOutput(); 118 thread_.join(); 119 } 120 121 protected: ClientServerTestContextThreaded(const thread::Options & options)122 explicit ClientServerTestContextThreaded(const thread::Options& options) 123 : thread_(options, Instance::Run, this) {} 124 125 private: 126 using Base::ForwardNewPackets; Run(void * arg)127 static void Run(void* arg) { 128 auto& ctx = *static_cast<Instance*>(arg); 129 while (ctx.channel_output_.WaitForOutput()) { 130 ctx.ForwardNewPackets(); 131 } 132 } 133 thread::Thread thread_; 134 }; 135 136 } // namespace internal 137 } // namespace pw::rpc 138