• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2022 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 #pragma once
15 
#include <array>
#include <cinttypes>
#include <cstddef>
#include <mutex>

#include "pw_rpc/channel.h"
#include "pw_rpc/client_server.h"
#include "pw_rpc/internal/client_server_testing.h"
#include "pw_span/span.h"
#include "pw_status/status.h"
#include "pw_sync/binary_semaphore.h"
#include "pw_sync/mutex.h"
#include "pw_thread/thread.h"
26 
27 namespace pw::rpc {
28 namespace internal {
29 
30 // Expands on a Forwarding Channel Output implementation to allow for
31 // observation of packets.
32 template <typename FakeChannelOutputImpl,
33           size_t kOutputSize,
34           size_t kMaxPackets,
35           size_t kPayloadsBufferSizeBytes>
36 class WatchableChannelOutput
37     : public ForwardingChannelOutput<FakeChannelOutputImpl,
38                                      kOutputSize,
39                                      kMaxPackets,
40                                      kPayloadsBufferSizeBytes> {
41  private:
42   using Base = ForwardingChannelOutput<FakeChannelOutputImpl,
43                                        kOutputSize,
44                                        kMaxPackets,
45                                        kPayloadsBufferSizeBytes>;
46 
47  public:
MaximumTransmissionUnit()48   size_t MaximumTransmissionUnit() PW_LOCKS_EXCLUDED(mutex_) override {
49     std::lock_guard lock(mutex_);
50     return Base::MaximumTransmissionUnit();
51   }
52 
Send(span<const std::byte> buffer)53   Status Send(span<const std::byte> buffer) PW_LOCKS_EXCLUDED(mutex_) override {
54     Status status;
55     mutex_.lock();
56     status = Base::Send(buffer);
57     mutex_.unlock();
58     output_semaphore_.release();
59     return status;
60   }
61 
62   // Returns true if should continue waiting for additional output
WaitForOutput()63   bool WaitForOutput() PW_LOCKS_EXCLUDED(mutex_) {
64     output_semaphore_.acquire();
65     std::lock_guard lock(mutex_);
66     return should_wait_;
67   }
68 
StopWaitingForOutput()69   void StopWaitingForOutput() PW_LOCKS_EXCLUDED(mutex_) {
70     std::lock_guard lock(mutex_);
71     should_wait_ = false;
72     output_semaphore_.release();
73   }
74 
75  protected:
76   constexpr WatchableChannelOutput() = default;
77 
PacketCount()78   size_t PacketCount() const PW_EXCLUSIVE_LOCKS_REQUIRED(mutex_) override {
79     return Base::PacketCount();
80   }
81 
82   sync::Mutex mutex_;
83 
84  private:
EncodeNextUnsentPacket(std::array<std::byte,kPayloadsBufferSizeBytes> & packet_buffer)85   Result<ConstByteSpan> EncodeNextUnsentPacket(
86       std::array<std::byte, kPayloadsBufferSizeBytes>& packet_buffer)
87       PW_LOCKS_EXCLUDED(mutex_) override {
88     std::lock_guard lock(mutex_);
89     return Base::EncodeNextUnsentPacket(packet_buffer);
90   }
91   sync::BinarySemaphore output_semaphore_;
92   bool should_wait_ PW_GUARDED_BY(mutex_) = true;
93 };
94 
95 // Provides a testing context with a real client and server
96 template <typename WatchableChannelOutputImpl,
97           size_t kOutputSize = 128,
98           size_t kMaxPackets = 16,
99           size_t kPayloadsBufferSizeBytes = 128>
100 class ClientServerTestContextThreaded
101     : public ClientServerTestContext<WatchableChannelOutputImpl,
102                                      kOutputSize,
103                                      kMaxPackets,
104                                      kPayloadsBufferSizeBytes> {
105  private:
106   using Instance = ClientServerTestContextThreaded<WatchableChannelOutputImpl,
107                                                    kOutputSize,
108                                                    kMaxPackets,
109                                                    kPayloadsBufferSizeBytes>;
110   using Base = ClientServerTestContext<WatchableChannelOutputImpl,
111                                        kOutputSize,
112                                        kMaxPackets,
113                                        kPayloadsBufferSizeBytes>;
114 
115  public:
~ClientServerTestContextThreaded()116   ~ClientServerTestContextThreaded() {
117     Base::channel_output_.StopWaitingForOutput();
118     thread_.join();
119   }
120 
121  protected:
ClientServerTestContextThreaded(const thread::Options & options)122   explicit ClientServerTestContextThreaded(const thread::Options& options)
123       : thread_(options, Instance::Run, this) {}
124 
125  private:
126   using Base::ForwardNewPackets;
Run(void * arg)127   static void Run(void* arg) {
128     auto& ctx = *static_cast<Instance*>(arg);
129     while (ctx.channel_output_.WaitForOutput()) {
130       ctx.ForwardNewPackets();
131     }
132   }
133   thread::Thread thread_;
134 };
135 
136 }  // namespace internal
137 }  // namespace pw::rpc
138