1 /* SPDX-License-Identifier: MIT */
2 /*
3 * Test MSG_WAITALL for recv/recvmsg and include normal sync versions just
4 * for comparison.
5 */
6 #include <assert.h>
7 #include <errno.h>
8 #include <stdio.h>
9 #include <stdlib.h>
10 #include <string.h>
11 #include <unistd.h>
12 #include <fcntl.h>
13 #include <arpa/inet.h>
14 #include <sys/types.h>
15 #include <sys/socket.h>
16 #include <pthread.h>
17
18 #include "liburing.h"
19 #include "helpers.h"
20
21 #define MAX_MSG 128
22
23 struct recv_data {
24 pthread_mutex_t mutex;
25 int use_recvmsg;
26 int use_sync;
27 __be16 port;
28 };
29
get_conn_sock(struct recv_data * rd,int * sockout)30 static int get_conn_sock(struct recv_data *rd, int *sockout)
31 {
32 struct sockaddr_in saddr;
33 int sockfd, ret, val;
34
35 memset(&saddr, 0, sizeof(saddr));
36 saddr.sin_family = AF_INET;
37 saddr.sin_addr.s_addr = htonl(INADDR_ANY);
38
39 sockfd = socket(AF_INET, SOCK_STREAM | SOCK_CLOEXEC, IPPROTO_TCP);
40 if (sockfd < 0) {
41 perror("socket");
42 goto err;
43 }
44
45 val = 1;
46 setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, &val, sizeof(val));
47 setsockopt(sockfd, SOL_SOCKET, SO_REUSEPORT, &val, sizeof(val));
48
49 if (t_bind_ephemeral_port(sockfd, &saddr)) {
50 perror("bind");
51 goto err;
52 }
53 rd->port = saddr.sin_port;
54
55 ret = listen(sockfd, 16);
56 if (ret < 0) {
57 perror("listen");
58 goto err;
59 }
60
61 pthread_mutex_unlock(&rd->mutex);
62
63 ret = accept(sockfd, NULL, NULL);
64 if (ret < 0) {
65 perror("accept");
66 return -1;
67 }
68
69 *sockout = sockfd;
70 return ret;
71 err:
72 pthread_mutex_unlock(&rd->mutex);
73 return -1;
74 }
75
recv_prep(struct io_uring * ring,struct iovec * iov,int * sock,struct recv_data * rd)76 static int recv_prep(struct io_uring *ring, struct iovec *iov, int *sock,
77 struct recv_data *rd)
78 {
79 struct io_uring_sqe *sqe;
80 struct msghdr msg = { };
81 int sockfd, sockout = -1, ret;
82
83 sockfd = get_conn_sock(rd, &sockout);
84 if (sockfd < 0)
85 goto err;
86
87 sqe = io_uring_get_sqe(ring);
88 if (!rd->use_recvmsg) {
89 io_uring_prep_recv(sqe, sockfd, iov->iov_base, iov->iov_len,
90 MSG_WAITALL);
91 } else {
92 msg.msg_namelen = sizeof(struct sockaddr_in);
93 msg.msg_iov = iov;
94 msg.msg_iovlen = 1;
95 io_uring_prep_recvmsg(sqe, sockfd, &msg, MSG_WAITALL);
96 }
97
98 sqe->user_data = 2;
99
100 ret = io_uring_submit(ring);
101 if (ret <= 0) {
102 fprintf(stderr, "submit failed: %d\n", ret);
103 goto err;
104 }
105
106 *sock = sockfd;
107 return 0;
108 err:
109 if (sockout != -1) {
110 shutdown(sockout, SHUT_RDWR);
111 close(sockout);
112 }
113 if (sockfd != -1) {
114 shutdown(sockfd, SHUT_RDWR);
115 close(sockfd);
116 }
117 return 1;
118 }
119
do_recv(struct io_uring * ring)120 static int do_recv(struct io_uring *ring)
121 {
122 struct io_uring_cqe *cqe;
123 int ret;
124
125 ret = io_uring_wait_cqe(ring, &cqe);
126 if (ret) {
127 fprintf(stdout, "wait_cqe: %d\n", ret);
128 goto err;
129 }
130 if (cqe->res == -EINVAL) {
131 fprintf(stdout, "recv not supported, skipping\n");
132 return 0;
133 }
134 if (cqe->res < 0) {
135 fprintf(stderr, "failed cqe: %d\n", cqe->res);
136 goto err;
137 }
138 if (cqe->res != MAX_MSG * sizeof(int)) {
139 fprintf(stderr, "got wrong length: %d\n", cqe->res);
140 goto err;
141 }
142
143 io_uring_cqe_seen(ring, cqe);
144 return 0;
145 err:
146 return 1;
147 }
148
recv_sync(struct recv_data * rd)149 static int recv_sync(struct recv_data *rd)
150 {
151 int buf[MAX_MSG];
152 struct iovec iov = {
153 .iov_base = buf,
154 .iov_len = sizeof(buf),
155 };
156 int i, ret, sockfd, sockout = -1;
157
158 sockfd = get_conn_sock(rd, &sockout);
159
160 if (rd->use_recvmsg) {
161 struct msghdr msg = { };
162
163 msg.msg_namelen = sizeof(struct sockaddr_in);
164 msg.msg_iov = &iov;
165 msg.msg_iovlen = 1;
166 ret = recvmsg(sockfd, &msg, MSG_WAITALL);
167 } else {
168 ret = recv(sockfd, buf, sizeof(buf), MSG_WAITALL);
169 }
170
171 if (ret < 0) {
172 perror("receive");
173 goto err;
174 }
175
176 if (ret != sizeof(buf)) {
177 ret = -1;
178 goto err;
179 }
180
181 for (i = 0; i < MAX_MSG; i++) {
182 if (buf[i] != i)
183 goto err;
184 }
185 ret = 0;
186 err:
187 shutdown(sockout, SHUT_RDWR);
188 shutdown(sockfd, SHUT_RDWR);
189 close(sockout);
190 close(sockfd);
191 return ret;
192 }
193
recv_uring(struct recv_data * rd)194 static int recv_uring(struct recv_data *rd)
195 {
196 int buf[MAX_MSG];
197 struct iovec iov = {
198 .iov_base = buf,
199 .iov_len = sizeof(buf),
200 };
201 struct io_uring_params p = { };
202 struct io_uring ring;
203 int ret, sock = -1, sockout = -1;
204
205 ret = t_create_ring_params(1, &ring, &p);
206 if (ret == T_SETUP_SKIP) {
207 pthread_mutex_unlock(&rd->mutex);
208 ret = 0;
209 goto err;
210 } else if (ret < 0) {
211 pthread_mutex_unlock(&rd->mutex);
212 goto err;
213 }
214
215 sock = recv_prep(&ring, &iov, &sockout, rd);
216 if (ret) {
217 fprintf(stderr, "recv_prep failed: %d\n", ret);
218 goto err;
219 }
220 ret = do_recv(&ring);
221 if (!ret) {
222 int i;
223
224 for (i = 0; i < MAX_MSG; i++) {
225 if (buf[i] != i) {
226 fprintf(stderr, "found %d at %d\n", buf[i], i);
227 ret = 1;
228 break;
229 }
230 }
231 }
232
233 shutdown(sockout, SHUT_RDWR);
234 shutdown(sock, SHUT_RDWR);
235 close(sock);
236 close(sockout);
237 io_uring_queue_exit(&ring);
238 err:
239 if (sock != -1) {
240 shutdown(sock, SHUT_RDWR);
241 close(sock);
242 }
243 if (sockout != -1) {
244 shutdown(sockout, SHUT_RDWR);
245 close(sockout);
246 }
247 return ret;
248 }
249
recv_fn(void * data)250 static void *recv_fn(void *data)
251 {
252 struct recv_data *rd = data;
253
254 if (rd->use_sync)
255 return (void *) (uintptr_t) recv_sync(rd);
256
257 return (void *) (uintptr_t) recv_uring(rd);
258 }
259
do_send(struct recv_data * rd)260 static int do_send(struct recv_data *rd)
261 {
262 struct sockaddr_in saddr;
263 struct io_uring ring;
264 struct io_uring_cqe *cqe;
265 struct io_uring_sqe *sqe;
266 int sockfd, ret, i;
267 struct iovec iov;
268 int *buf;
269
270 ret = io_uring_queue_init(2, &ring, 0);
271 if (ret) {
272 fprintf(stderr, "queue init failed: %d\n", ret);
273 return 1;
274 }
275
276 buf = malloc(MAX_MSG * sizeof(int));
277 for (i = 0; i < MAX_MSG; i++)
278 buf[i] = i;
279
280 sockfd = socket(AF_INET, SOCK_STREAM | SOCK_CLOEXEC, IPPROTO_TCP);
281 if (sockfd < 0) {
282 perror("socket");
283 return 1;
284 }
285
286 pthread_mutex_lock(&rd->mutex);
287 assert(rd->port != 0);
288 memset(&saddr, 0, sizeof(saddr));
289 saddr.sin_family = AF_INET;
290 saddr.sin_port = rd->port;
291 inet_pton(AF_INET, "127.0.0.1", &saddr.sin_addr);
292
293 ret = connect(sockfd, (struct sockaddr *)&saddr, sizeof(saddr));
294 if (ret < 0) {
295 perror("connect");
296 return 1;
297 }
298
299 iov.iov_base = buf;
300 iov.iov_len = MAX_MSG * sizeof(int) / 2;
301 for (i = 0; i < 2; i++) {
302 sqe = io_uring_get_sqe(&ring);
303 io_uring_prep_send(sqe, sockfd, iov.iov_base, iov.iov_len, 0);
304 sqe->user_data = 1;
305
306 ret = io_uring_submit(&ring);
307 if (ret <= 0) {
308 fprintf(stderr, "submit failed: %d\n", ret);
309 goto err;
310 }
311 usleep(10000);
312 iov.iov_base += iov.iov_len;
313 }
314
315 for (i = 0; i < 2; i++) {
316 ret = io_uring_wait_cqe(&ring, &cqe);
317 if (cqe->res == -EINVAL) {
318 fprintf(stdout, "send not supported, skipping\n");
319 close(sockfd);
320 free(buf);
321 return 0;
322 }
323 if (cqe->res != iov.iov_len) {
324 fprintf(stderr, "failed cqe: %d\n", cqe->res);
325 goto err;
326 }
327 io_uring_cqe_seen(&ring, cqe);
328 }
329
330 shutdown(sockfd, SHUT_RDWR);
331 close(sockfd);
332 free(buf);
333 return 0;
334 err:
335 shutdown(sockfd, SHUT_RDWR);
336 close(sockfd);
337 free(buf);
338 return 1;
339 }
340
test(int use_recvmsg,int use_sync)341 static int test(int use_recvmsg, int use_sync)
342 {
343 pthread_mutexattr_t attr;
344 pthread_t recv_thread;
345 struct recv_data rd;
346 int ret;
347 void *retval;
348
349 pthread_mutexattr_init(&attr);
350 pthread_mutexattr_setpshared(&attr, 1);
351 pthread_mutex_init(&rd.mutex, &attr);
352 pthread_mutex_lock(&rd.mutex);
353 rd.use_recvmsg = use_recvmsg;
354 rd.use_sync = use_sync;
355 rd.port = 0;
356
357 ret = pthread_create(&recv_thread, NULL, recv_fn, &rd);
358 if (ret) {
359 fprintf(stderr, "Thread create failed: %d\n", ret);
360 pthread_mutex_unlock(&rd.mutex);
361 return 1;
362 }
363
364 do_send(&rd);
365 pthread_join(recv_thread, &retval);
366 return (intptr_t)retval;
367 }
368
/*
 * Run the four receive variants in order: io_uring recv, io_uring recvmsg,
 * sync recv, sync recvmsg. Stops at the first failure.
 *
 * Does nothing and exits 0 when any command line argument is given.
 * Exits 0 when every variant passes and 1 on the first failure. The raw
 * test() value is not used as the exit status because only its low 8 bits
 * would reach the parent, which can turn a failure into 0.
 */
int main(int argc, char *argv[])
{
	static const struct {
		int use_recvmsg;
		int use_sync;
		const char *name;
	} cases[] = {
		{ 0, 0, "recv" },
		{ 1, 0, "recvmsg" },
		{ 0, 1, "sync recv" },
		{ 1, 1, "sync recvmsg" },
	};
	size_t i;

	if (argc > 1)
		return 0;

	for (i = 0; i < sizeof(cases) / sizeof(cases[0]); i++) {
		if (test(cases[i].use_recvmsg, cases[i].use_sync)) {
			fprintf(stderr, "test %s failed\n", cases[i].name);
			return 1;
		}
	}

	return 0;
}
402