1 /* SPDX-License-Identifier: MIT */
2 /*
3 * Description: check that multiple receives on the same socket don't get
4 * stalled if multiple wakers race with the socket readiness.
5 */
6 #include <stdio.h>
7 #include <stdlib.h>
8 #include <unistd.h>
9 #include <pthread.h>
10 #include <sys/socket.h>
11
12 #include "liburing.h"
13 #include "helpers.h"
14
15 #define NREQS 64
16
17 struct data {
18 pthread_barrier_t barrier;
19 int fd;
20 };
21
thread(void * data)22 static void *thread(void *data)
23 {
24 struct data *d = data;
25 char buf[64];
26 int ret, i;
27
28 pthread_barrier_wait(&d->barrier);
29 for (i = 0; i < NREQS; i++) {
30 ret = write(d->fd, buf, sizeof(buf));
31 if (ret != 64)
32 fprintf(stderr, "wrote short %d\n", ret);
33 }
34 return NULL;
35 }
36
test(struct io_uring * ring,struct data * d)37 static int test(struct io_uring *ring, struct data *d)
38 {
39 struct io_uring_sqe *sqe;
40 struct io_uring_cqe *cqe;
41 int fd[2], ret, i;
42 char buf[64];
43 pthread_t t;
44 void *ret2;
45
46 if (socketpair(PF_LOCAL, SOCK_STREAM, 0, fd) < 0) {
47 perror("socketpair");
48 return T_EXIT_FAIL;
49 }
50
51 d->fd = fd[1];
52
53 pthread_create(&t, NULL, thread, d);
54
55 for (i = 0; i < NREQS; i++) {
56 sqe = io_uring_get_sqe(ring);
57 io_uring_prep_recv(sqe, fd[0], buf, sizeof(buf), 0);
58 }
59
60 pthread_barrier_wait(&d->barrier);
61
62 ret = io_uring_submit(ring);
63 if (ret != NREQS) {
64 fprintf(stderr, "submit %d\n", ret);
65 return T_EXIT_FAIL;
66 }
67
68 for (i = 0; i < NREQS; i++) {
69 ret = io_uring_wait_cqe(ring, &cqe);
70 if (ret) {
71 fprintf(stderr, "cqe wait %d\n", ret);
72 return T_EXIT_FAIL;
73 }
74 io_uring_cqe_seen(ring, cqe);
75 }
76
77 close(fd[0]);
78 close(fd[1]);
79 pthread_join(t, &ret2);
80 return T_EXIT_PASS;
81 }
82
main(int argc,char * argv[])83 int main(int argc, char *argv[])
84 {
85 struct io_uring ring;
86 struct data d;
87 int i, ret;
88
89 if (argc > 1)
90 return T_EXIT_SKIP;
91
92 pthread_barrier_init(&d.barrier, NULL, 2);
93
94 io_uring_queue_init(NREQS, &ring, 0);
95
96 for (i = 0; i < 1000; i++) {
97 ret = test(&ring, &d);
98 if (ret != T_EXIT_PASS) {
99 fprintf(stderr, "Test failed\n");
100 return T_EXIT_FAIL;
101 }
102 }
103
104 return T_EXIT_PASS;
105 }
106