1 // Copyright 2023 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14
15 #include "pw_stream/socket_stream.h"
16
17 #include <thread>
18
19 #include "pw_result/result.h"
20 #include "pw_status/status.h"
21 #include "pw_unit_test/framework.h"
22
23 namespace pw::stream {
24 namespace {
25
26 // Helper function to create a ServerSocket and connect to it via loopback.
RunConnectTest(const char * hostname)27 void RunConnectTest(const char* hostname) {
28 ServerSocket server;
29 EXPECT_EQ(server.Listen(), OkStatus());
30
31 Result<SocketStream> server_stream = Status::Unavailable();
32 auto accept_thread = std::thread{[&]() { server_stream = server.Accept(); }};
33
34 SocketStream client;
35 EXPECT_EQ(client.Connect(hostname, server.port()), OkStatus());
36
37 accept_thread.join();
38 EXPECT_EQ(server_stream.status(), OkStatus());
39
40 server_stream.value().Close();
41 server.Close();
42 client.Close();
43 }
44
TEST(SocketStreamTest, ConnectIpv4) {
  // Connect through the IPv4 loopback address.
  const char* const kLoopbackAddress = "127.0.0.1";
  RunConnectTest(kLoopbackAddress);
}
46
TEST(SocketStreamTest, ConnectIpv6) {
  // Connect through the IPv6 loopback address.
  const char* const kLoopbackAddress = "::1";
  RunConnectTest(kLoopbackAddress);
}
48
TEST(SocketStreamTest, ConnectSpecificPort) {
  // We want to test the "listen on a specific port" functionality, but
  // hard-coding a port number in a test is inherently problematic, as port
  // numbers are a global resource.
  //
  // Instead, an initial server gets a port through automatic assignment. That
  // server is closed, and the same port is then requested explicitly by a new
  // server.
  //
  // Another process could still grab the port in between, but that should not
  // happen by chance.
  ServerSocket initial_server;
  ASSERT_EQ(initial_server.Listen(), OkStatus());
  const uint16_t port = initial_server.port();
  initial_server.Close();
  // Listen(0) is the automatic-assignment request, so a 0 here would not
  // exercise the specific-port path at all.
  ASSERT_NE(port, 0);

  ServerSocket server;
  // No thread has been started yet, so returning early here is safe.
  ASSERT_EQ(server.Listen(port), OkStatus());
  EXPECT_EQ(server.port(), port);

  Result<SocketStream> server_stream = Status::Unavailable();
  auto accept_thread = std::thread{[&]() { server_stream = server.Accept(); }};

  SocketStream client;
  EXPECT_EQ(client.Connect("localhost", server.port()), OkStatus());
  EXPECT_TRUE(client.IsReady());

  accept_thread.join();
  // Dereferencing server_stream below would crash if Accept() failed.
  ASSERT_EQ(server_stream.status(), OkStatus());

  server_stream->Close();
  server.Close();
  client.Close();
  EXPECT_FALSE(client.IsReady());
}
83
84 // Helper function to test exchanging data on a pair of sockets.
ExchangeData(SocketStream & stream1,SocketStream & stream2)85 void ExchangeData(SocketStream& stream1, SocketStream& stream2) {
86 auto kPayload1 = as_bytes(span("some data"));
87 auto kPayload2 = as_bytes(span("other bytes"));
88 std::array<char, 100> read_buffer{};
89
90 // Write data from stream1 and read it from stream2.
91 auto write_status = Status::Unavailable();
92 auto write_thread =
93 std::thread{[&]() { write_status = stream1.Write(kPayload1); }};
94 Result<ByteSpan> read_result =
95 stream2.Read(as_writable_bytes(span(read_buffer)));
96 EXPECT_EQ(read_result.status(), OkStatus());
97 EXPECT_EQ(read_result.value().size(), kPayload1.size());
98 EXPECT_TRUE(
99 std::equal(kPayload1.begin(), kPayload1.end(), read_result->begin()));
100
101 write_thread.join();
102 EXPECT_EQ(write_status, OkStatus());
103
104 // Read data in the client and write it from the server.
105 auto read_thread = std::thread{[&]() {
106 read_result = stream1.Read(as_writable_bytes(span(read_buffer)));
107 }};
108 EXPECT_EQ(stream2.Write(kPayload2), OkStatus());
109
110 read_thread.join();
111 EXPECT_EQ(read_result.status(), OkStatus());
112 EXPECT_EQ(read_result.value().size(), kPayload2.size());
113 EXPECT_TRUE(
114 std::equal(kPayload2.begin(), kPayload2.end(), read_result->begin()));
115
116 // Close stream1 and attempt to read from stream2.
117 stream1.Close();
118 read_result = stream2.Read(as_writable_bytes(span(read_buffer)));
119 EXPECT_EQ(read_result.status(), Status::OutOfRange());
120
121 stream2.Close();
122 }
123
TEST(SocketStreamTest, ReadWrite) {
  ServerSocket server;
  // No thread has been started yet, so returning early here is safe.
  ASSERT_EQ(server.Listen(), OkStatus());

  Result<SocketStream> server_stream = Status::Unavailable();
  auto accept_thread = std::thread{[&]() { server_stream = server.Accept(); }};

  SocketStream client;
  EXPECT_EQ(client.Connect("localhost", server.port()), OkStatus());

  accept_thread.join();
  // server_stream.value() below would crash if Accept() failed.
  ASSERT_EQ(server_stream.status(), OkStatus());

  // ExchangeData() closes both streams when it runs to completion.
  ExchangeData(client, server_stream.value());
  server.Close();
}
140
TEST(SocketStreamTest, MultipleClients) {
  ServerSocket server;
  // No thread has been started yet, so returning early here is safe.
  ASSERT_EQ(server.Listen(), OkStatus());

  // A single background thread accepts three connections in sequence. The
  // pairing below assumes Accept() returns them in the order the clients
  // connect.
  Result<SocketStream> server_stream1 = Status::Unavailable();
  Result<SocketStream> server_stream2 = Status::Unavailable();
  Result<SocketStream> server_stream3 = Status::Unavailable();
  auto accept_thread = std::thread{[&]() {
    server_stream1 = server.Accept();
    server_stream2 = server.Accept();
    server_stream3 = server.Accept();
  }};

  SocketStream client1;
  SocketStream client2;
  SocketStream client3;
  EXPECT_EQ(client1.Connect("localhost", server.port()), OkStatus());
  EXPECT_EQ(client2.Connect("localhost", server.port()), OkStatus());
  EXPECT_EQ(client3.Connect("localhost", server.port()), OkStatus());

  accept_thread.join();
  // The value() calls below would crash on any failed Accept().
  ASSERT_EQ(server_stream1.status(), OkStatus());
  ASSERT_EQ(server_stream2.status(), OkStatus());
  ASSERT_EQ(server_stream3.status(), OkStatus());

  ExchangeData(client1, server_stream1.value());
  ExchangeData(client2, server_stream2.value());
  ExchangeData(client3, server_stream3.value());
  server.Close();
}
171
TEST(SocketStreamTest, ReuseAutomaticServerPort) {
  SocketStream client_stream;
  ServerSocket server;

  // Listen(0) requests an automatically assigned port. No thread is running
  // yet, so returning early on failure is safe.
  ASSERT_EQ(server.Listen(0), OkStatus());
  const uint16_t server_port = server.port();
  ASSERT_NE(server_port, 0);

  Result<SocketStream> server_stream = Status::Unavailable();
  auto accept_thread = std::thread{[&]() { server_stream = server.Accept(); }};

  EXPECT_EQ(client_stream.Connect(nullptr, server_port), OkStatus());
  accept_thread.join();
  ASSERT_EQ(server_stream.status(), OkStatus());

  // Close the accepted connection from the server side, then the listening
  // socket, while the client end is still open. An active close like this
  // typically leaves the connection in TCP TIME_WAIT, which is what makes
  // re-listening on the same port a meaningful check.
  server_stream->Close();
  server.Close();

  // A fresh server must be able to listen on the port that was just in use.
  ServerSocket server2;
  EXPECT_EQ(server2.Listen(server_port), OkStatus());
  EXPECT_EQ(server2.port(), server_port);

  server2.Close();
  client_stream.Close();
}
194
195 } // namespace
196 } // namespace pw::stream
197