• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2021 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 #pragma once
15 
16 #include <atomic>
17 #include <cstdint>
18 #include <optional>
19 #include <thread>
20 
21 #include "pw_hdlc/encoded_size.h"
22 #include "pw_hdlc/rpc_channel.h"
23 #include "pw_hdlc/rpc_packets.h"
24 #include "pw_rpc/integration_testing.h"
25 #include "pw_span/span.h"
26 #include "pw_status/try.h"
27 #include "pw_stream/socket_stream.h"
28 
29 namespace pw::rpc::integration_test {
30 
31 // Wraps an RPC client with a socket stream and a channel configured to use it.
32 // Useful for integration tests that run across a socket.
33 template <size_t kMaxTransmissionUnit>
34 class SocketClientContext {
35  public:
SocketClientContext()36   constexpr SocketClientContext()
37       : rpc_dispatch_thread_handle_(std::nullopt),
38         channel_output_(stream_, hdlc::kDefaultRpcAddress, "socket"),
39         channel_output_with_manipulator_(channel_output_),
40         channel_(
41             Channel::Create<kChannelId>(&channel_output_with_manipulator_)),
42         client_(span(&channel_, 1)) {}
43 
client()44   Client& client() { return client_; }
45 
46   // Connects to the specified host:port and starts a background thread to read
47   // packets from the socket.
Start(const char * host,uint16_t port)48   Status Start(const char* host, uint16_t port) {
49     PW_TRY(stream_.Connect(host, port));
50     rpc_dispatch_thread_handle_.emplace(&SocketClientContext::ProcessPackets,
51                                         this);
52     return OkStatus();
53   }
54 
55   // Terminates the client, joining the RPC dispatch thread.
56   //
57   // WARNING: This may block forever if the socket is configured to block
58   // indefinitely on reads. Configuring the client socket's `SO_RCVTIMEO` to a
59   // nonzero timeout will allow the dispatch thread to always return.
Terminate()60   void Terminate() {
61     PW_ASSERT(rpc_dispatch_thread_handle_.has_value());
62     should_terminate_.test_and_set();
63     rpc_dispatch_thread_handle_->join();
64   }
65 
GetSocketFd()66   int GetSocketFd() { return stream_.connection_fd(); }
67 
SetEgressChannelManipulator(ChannelManipulator * new_channel_manipulator)68   void SetEgressChannelManipulator(
69       ChannelManipulator* new_channel_manipulator) {
70     channel_output_with_manipulator_.set_channel_manipulator(
71         new_channel_manipulator);
72   }
73 
SetIngressChannelManipulator(ChannelManipulator * new_channel_manipulator)74   void SetIngressChannelManipulator(
75       ChannelManipulator* new_channel_manipulator) {
76     if (new_channel_manipulator != nullptr) {
77       new_channel_manipulator->set_send_packet([&](ConstByteSpan payload) {
78         return client_.ProcessPacket(payload);
79       });
80     }
81     ingress_channel_manipulator_ = new_channel_manipulator;
82   }
83 
84   // Calls Start for localhost.
Start(uint16_t port)85   Status Start(uint16_t port) { return Start("localhost", port); }
86 
87  private:
88   void ProcessPackets();
89 
90   class ChannelOutputWithManipulator : public ChannelOutput {
91    public:
ChannelOutputWithManipulator(ChannelOutput & actual_output)92     ChannelOutputWithManipulator(ChannelOutput& actual_output)
93         : ChannelOutput(actual_output.name()),
94           actual_output_(actual_output),
95           channel_manipulator_(nullptr) {}
96 
set_channel_manipulator(ChannelManipulator * new_channel_manipulator)97     void set_channel_manipulator(ChannelManipulator* new_channel_manipulator) {
98       if (new_channel_manipulator != nullptr) {
99         new_channel_manipulator->set_send_packet(
100             ChannelManipulator::SendCallback([&](
101                 ConstByteSpan
102                     payload) __attribute__((no_thread_safety_analysis)) {
103               return actual_output_.Send(payload);
104             }));
105       }
106       channel_manipulator_ = new_channel_manipulator;
107     }
108 
MaximumTransmissionUnit()109     size_t MaximumTransmissionUnit() override {
110       return actual_output_.MaximumTransmissionUnit();
111     }
Send(span<const std::byte> buffer)112     Status Send(span<const std::byte> buffer) override
113         PW_EXCLUSIVE_LOCKS_REQUIRED(internal::rpc_lock()) {
114       if (channel_manipulator_ != nullptr) {
115         return channel_manipulator_->ProcessAndSend(buffer);
116       }
117 
118       return actual_output_.Send(buffer);
119     }
120 
121    private:
122     ChannelOutput& actual_output_;
123     ChannelManipulator* channel_manipulator_;
124   };
125 
126   std::atomic_flag should_terminate_ = ATOMIC_FLAG_INIT;
127   std::optional<std::thread> rpc_dispatch_thread_handle_;
128   stream::SocketStream stream_;
129   hdlc::FixedMtuChannelOutput<kMaxTransmissionUnit> channel_output_;
130   ChannelOutputWithManipulator channel_output_with_manipulator_;
131   ChannelManipulator* ingress_channel_manipulator_;
132   Channel channel_;
133   Client client_;
134 };
135 
136 template <size_t kMaxTransmissionUnit>
ProcessPackets()137 void SocketClientContext<kMaxTransmissionUnit>::ProcessPackets() {
138   constexpr size_t kDecoderBufferSize =
139       hdlc::Decoder::RequiredBufferSizeForFrameSize(kMaxTransmissionUnit);
140   std::array<std::byte, kDecoderBufferSize> decode_buffer;
141   hdlc::Decoder decoder(decode_buffer);
142 
143   while (true) {
144     std::byte byte[1];
145     Result<ByteSpan> read = stream_.Read(byte);
146 
147     if (should_terminate_.test()) {
148       return;
149     }
150 
151     if (!read.ok() || read->empty()) {
152       continue;
153     }
154 
155     if (auto result = decoder.Process(*byte); result.ok()) {
156       hdlc::Frame& frame = result.value();
157       if (frame.address() == hdlc::kDefaultRpcAddress) {
158         if (ingress_channel_manipulator_ != nullptr) {
159           PW_ASSERT(
160               ingress_channel_manipulator_->ProcessAndSend(frame.data()).ok());
161         } else {
162           PW_ASSERT(client_.ProcessPacket(frame.data()).ok());
163         }
164       }
165     }
166   }
167 }
168 
169 }  // namespace pw::rpc::integration_test
170