• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2023 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 
15 #include "pw_stream/socket_stream.h"
16 
17 #include <thread>
18 
19 #include "gtest/gtest.h"
20 #include "pw_result/result.h"
21 #include "pw_status/status.h"
22 
23 namespace pw::stream {
24 namespace {
25 
26 // Helper function to create a ServerSocket and connect to it via loopback.
RunConnectTest(const char * hostname)27 void RunConnectTest(const char* hostname) {
28   ServerSocket server;
29   EXPECT_EQ(server.Listen(), OkStatus());
30 
31   Result<SocketStream> server_stream = Status::Unavailable();
32   auto accept_thread = std::thread{[&]() { server_stream = server.Accept(); }};
33 
34   SocketStream client;
35   EXPECT_EQ(client.Connect(hostname, server.port()), OkStatus());
36 
37   accept_thread.join();
38   EXPECT_EQ(server_stream.status(), OkStatus());
39 
40   server_stream.value().Close();
41   server.Close();
42   client.Close();
43 }
44 
// Runs the loopback connect helper against the IPv4 loopback address.
TEST(SocketStreamTest, ConnectIpv4) {
  RunConnectTest("127.0.0.1");
}
46 
// Runs the loopback connect helper against the IPv6 loopback address.
TEST(SocketStreamTest, ConnectIpv6) {
  RunConnectTest("::1");
}
48 
// Verifies that a ServerSocket can listen on an explicitly requested port and
// accept a client connection on it.
TEST(SocketStreamTest, ConnectSpecificPort) {
  // We want to test the "listen on a specific port" functionality,
  // but hard-coding a port number in a test is inherently problematic, as
  // port numbers are a global resource.
  //
  // We use the automatic port assignment initially to get a port assignment,
  // close that server, and then use that port explicitly in a new server.
  //
  // There's still the possibility that the port will get swiped, but it
  // shouldn't happen by chance.
  ServerSocket initial_server;
  // Fatal on failure: without a listening socket there is no meaningful port
  // to reuse, and no thread has been started yet so returning is safe.
  ASSERT_EQ(initial_server.Listen(), OkStatus());
  const uint16_t port = initial_server.port();
  initial_server.Close();

  ServerSocket server;
  // Still before the accept thread is created, so a fatal assertion is safe.
  ASSERT_EQ(server.Listen(port), OkStatus());
  EXPECT_EQ(server.port(), port);

  Result<SocketStream> server_stream = Status::Unavailable();
  auto accept_thread = std::thread{[&]() { server_stream = server.Accept(); }};

  // Non-fatal on purpose: accept_thread must be joined before this test body
  // is allowed to return.
  SocketStream client;
  EXPECT_EQ(client.Connect("localhost", server.port()), OkStatus());

  accept_thread.join();
  EXPECT_EQ(server_stream.status(), OkStatus());

  // Do not dereference the Result unless Accept() succeeded.
  if (server_stream.ok()) {
    server_stream.value().Close();
  }
  server.Close();
  client.Close();
}
81 
82 // Helper function to test exchanging data on a pair of sockets.
ExchangeData(SocketStream & stream1,SocketStream & stream2)83 void ExchangeData(SocketStream& stream1, SocketStream& stream2) {
84   auto kPayload1 = as_bytes(span("some data"));
85   auto kPayload2 = as_bytes(span("other bytes"));
86   std::array<char, 100> read_buffer{};
87 
88   // Write data from stream1 and read it from stream2.
89   auto write_status = Status::Unavailable();
90   auto write_thread =
91       std::thread{[&]() { write_status = stream1.Write(kPayload1); }};
92   Result<ByteSpan> read_result =
93       stream2.Read(as_writable_bytes(span(read_buffer)));
94   EXPECT_EQ(read_result.status(), OkStatus());
95   EXPECT_EQ(read_result.value().size(), kPayload1.size());
96   EXPECT_TRUE(
97       std::equal(kPayload1.begin(), kPayload1.end(), read_result->begin()));
98 
99   write_thread.join();
100   EXPECT_EQ(write_status, OkStatus());
101 
102   // Read data in the client and write it from the server.
103   auto read_thread = std::thread{[&]() {
104     read_result = stream1.Read(as_writable_bytes(span(read_buffer)));
105   }};
106   EXPECT_EQ(stream2.Write(kPayload2), OkStatus());
107 
108   read_thread.join();
109   EXPECT_EQ(read_result.status(), OkStatus());
110   EXPECT_EQ(read_result.value().size(), kPayload2.size());
111   EXPECT_TRUE(
112       std::equal(kPayload2.begin(), kPayload2.end(), read_result->begin()));
113 
114   // Close stream1 and attempt to read from stream2.
115   stream1.Close();
116   read_result = stream2.Read(as_writable_bytes(span(read_buffer)));
117   EXPECT_EQ(read_result.status(), Status::OutOfRange());
118 
119   stream2.Close();
120 }
121 
// Connects a single client to a server over "localhost" and runs the
// bidirectional data exchange helper on the resulting pair of streams.
TEST(SocketStreamTest, ReadWrite) {
  ServerSocket server;
  // No thread exists yet, so bailing out here on failure is safe.
  ASSERT_EQ(server.Listen(), OkStatus());

  Result<SocketStream> server_stream = Status::Unavailable();
  auto accept_thread = std::thread{[&]() { server_stream = server.Accept(); }};

  // Non-fatal on purpose: accept_thread must be joined before returning.
  SocketStream client;
  EXPECT_EQ(client.Connect("localhost", server.port()), OkStatus());

  accept_thread.join();
  EXPECT_EQ(server_stream.status(), OkStatus());

  // ExchangeData needs a real accepted stream; skip it (the failure is
  // already recorded) rather than dereferencing a non-OK Result.
  if (server_stream.ok()) {
    ExchangeData(client, server_stream.value());
  }
  server.Close();
}
138 
// Accepts three client connections on one ServerSocket and exchanges data
// over each accepted stream in turn.
TEST(SocketStreamTest, MultipleClients) {
  ServerSocket server;
  // No thread exists yet, so bailing out here on failure is safe.
  ASSERT_EQ(server.Listen(), OkStatus());

  Result<SocketStream> server_stream1 = Status::Unavailable();
  Result<SocketStream> server_stream2 = Status::Unavailable();
  Result<SocketStream> server_stream3 = Status::Unavailable();
  // The three accepts run back to back on a single background thread.
  auto accept_thread = std::thread{[&]() {
    server_stream1 = server.Accept();
    server_stream2 = server.Accept();
    server_stream3 = server.Accept();
  }};

  // Non-fatal on purpose: accept_thread must be joined before returning.
  SocketStream client1;
  SocketStream client2;
  SocketStream client3;
  EXPECT_EQ(client1.Connect("localhost", server.port()), OkStatus());
  EXPECT_EQ(client2.Connect("localhost", server.port()), OkStatus());
  EXPECT_EQ(client3.Connect("localhost", server.port()), OkStatus());

  accept_thread.join();
  EXPECT_EQ(server_stream1.status(), OkStatus());
  EXPECT_EQ(server_stream2.status(), OkStatus());
  EXPECT_EQ(server_stream3.status(), OkStatus());

  // Exchange data only on connections that were actually accepted; a non-OK
  // Result must not be dereferenced.
  if (server_stream1.ok()) {
    ExchangeData(client1, server_stream1.value());
  }
  if (server_stream2.ok()) {
    ExchangeData(client2, server_stream2.value());
  }
  if (server_stream3.ok()) {
    ExchangeData(client3, server_stream3.value());
  }
  server.Close();
}
169 
// Verifies that a port handed out by automatic assignment can be listened on
// again by a new ServerSocket after the first server and its accepted stream
// have been closed.
TEST(SocketStreamTest, ReuseAutomaticServerPort) {
  uint16_t server_port = 0;
  SocketStream client_stream;
  ServerSocket server;

  // Port 0 requests automatic assignment. No thread exists yet, so a fatal
  // assertion can safely end the test if listening fails.
  ASSERT_EQ(server.Listen(0), OkStatus());
  server_port = server.port();
  EXPECT_NE(server_port, 0);

  Result<SocketStream> server_stream = Status::Unavailable();
  auto accept_thread = std::thread{[&]() { server_stream = server.Accept(); }};

  // Non-fatal on purpose: accept_thread must be joined before returning.
  EXPECT_EQ(client_stream.Connect(nullptr, server_port), OkStatus());
  accept_thread.join();
  ASSERT_EQ(server_stream.status(), OkStatus());

  server_stream->Close();
  server.Close();

  // The client end is intentionally still open while the port is re-bound.
  ServerSocket server2;
  EXPECT_EQ(server2.Listen(server_port), OkStatus());

  // Explicit cleanup, matching the other tests in this file.
  server2.Close();
  client_stream.Close();
}
192 
193 }  // namespace
194 }  // namespace pw::stream
195