• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* SPDX-License-Identifier: MIT */
2 /*
3  * Test MSG_WAITALL for recv/recvmsg and include normal sync versions just
4  * for comparison.
5  */
6 #include <assert.h>
7 #include <errno.h>
8 #include <stdio.h>
9 #include <stdlib.h>
10 #include <string.h>
11 #include <unistd.h>
12 #include <fcntl.h>
13 #include <arpa/inet.h>
14 #include <sys/types.h>
15 #include <sys/socket.h>
16 #include <pthread.h>
17 
18 #include "liburing.h"
19 #include "helpers.h"
20 
21 #define MAX_MSG	128
22 
23 struct recv_data {
24 	pthread_mutex_t mutex;
25 	int use_recvmsg;
26 	int use_sync;
27 	__be16 port;
28 };
29 
get_conn_sock(struct recv_data * rd,int * sockout)30 static int get_conn_sock(struct recv_data *rd, int *sockout)
31 {
32 	struct sockaddr_in saddr;
33 	int sockfd, ret, val;
34 
35 	memset(&saddr, 0, sizeof(saddr));
36 	saddr.sin_family = AF_INET;
37 	saddr.sin_addr.s_addr = htonl(INADDR_ANY);
38 
39 	sockfd = socket(AF_INET, SOCK_STREAM | SOCK_CLOEXEC, IPPROTO_TCP);
40 	if (sockfd < 0) {
41 		perror("socket");
42 		goto err;
43 	}
44 
45 	val = 1;
46 	setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, &val, sizeof(val));
47 	setsockopt(sockfd, SOL_SOCKET, SO_REUSEPORT, &val, sizeof(val));
48 
49 	if (t_bind_ephemeral_port(sockfd, &saddr)) {
50 		perror("bind");
51 		goto err;
52 	}
53 	rd->port = saddr.sin_port;
54 
55 	ret = listen(sockfd, 16);
56 	if (ret < 0) {
57 		perror("listen");
58 		goto err;
59 	}
60 
61 	pthread_mutex_unlock(&rd->mutex);
62 
63 	ret = accept(sockfd, NULL, NULL);
64 	if (ret < 0) {
65 		perror("accept");
66 		return -1;
67 	}
68 
69 	*sockout = sockfd;
70 	return ret;
71 err:
72 	pthread_mutex_unlock(&rd->mutex);
73 	return -1;
74 }
75 
recv_prep(struct io_uring * ring,struct iovec * iov,int * sock,struct recv_data * rd)76 static int recv_prep(struct io_uring *ring, struct iovec *iov, int *sock,
77 		     struct recv_data *rd)
78 {
79 	struct io_uring_sqe *sqe;
80 	struct msghdr msg = { };
81 	int sockfd, sockout = -1, ret;
82 
83 	sockfd = get_conn_sock(rd, &sockout);
84 	if (sockfd < 0)
85 		goto err;
86 
87 	sqe = io_uring_get_sqe(ring);
88 	if (!rd->use_recvmsg) {
89 		io_uring_prep_recv(sqe, sockfd, iov->iov_base, iov->iov_len,
90 					MSG_WAITALL);
91 	} else {
92 		msg.msg_namelen = sizeof(struct sockaddr_in);
93 		msg.msg_iov = iov;
94 		msg.msg_iovlen = 1;
95 		io_uring_prep_recvmsg(sqe, sockfd, &msg, MSG_WAITALL);
96 	}
97 
98 	sqe->user_data = 2;
99 
100 	ret = io_uring_submit(ring);
101 	if (ret <= 0) {
102 		fprintf(stderr, "submit failed: %d\n", ret);
103 		goto err;
104 	}
105 
106 	*sock = sockfd;
107 	return 0;
108 err:
109 	if (sockout != -1) {
110 		shutdown(sockout, SHUT_RDWR);
111 		close(sockout);
112 	}
113 	if (sockfd != -1) {
114 		shutdown(sockfd, SHUT_RDWR);
115 		close(sockfd);
116 	}
117 	return 1;
118 }
119 
do_recv(struct io_uring * ring)120 static int do_recv(struct io_uring *ring)
121 {
122 	struct io_uring_cqe *cqe;
123 	int ret;
124 
125 	ret = io_uring_wait_cqe(ring, &cqe);
126 	if (ret) {
127 		fprintf(stdout, "wait_cqe: %d\n", ret);
128 		goto err;
129 	}
130 	if (cqe->res == -EINVAL) {
131 		fprintf(stdout, "recv not supported, skipping\n");
132 		return 0;
133 	}
134 	if (cqe->res < 0) {
135 		fprintf(stderr, "failed cqe: %d\n", cqe->res);
136 		goto err;
137 	}
138 	if (cqe->res != MAX_MSG * sizeof(int)) {
139 		fprintf(stderr, "got wrong length: %d\n", cqe->res);
140 		goto err;
141 	}
142 
143 	io_uring_cqe_seen(ring, cqe);
144 	return 0;
145 err:
146 	return 1;
147 }
148 
recv_sync(struct recv_data * rd)149 static int recv_sync(struct recv_data *rd)
150 {
151 	int buf[MAX_MSG];
152 	struct iovec iov = {
153 		.iov_base = buf,
154 		.iov_len = sizeof(buf),
155 	};
156 	int i, ret, sockfd, sockout = -1;
157 
158 	sockfd = get_conn_sock(rd, &sockout);
159 
160 	if (rd->use_recvmsg) {
161 		struct msghdr msg = { };
162 
163 		msg.msg_namelen = sizeof(struct sockaddr_in);
164 		msg.msg_iov = &iov;
165 		msg.msg_iovlen = 1;
166 		ret = recvmsg(sockfd, &msg, MSG_WAITALL);
167 	} else {
168 		ret = recv(sockfd, buf, sizeof(buf), MSG_WAITALL);
169 	}
170 
171 	if (ret < 0) {
172 		perror("receive");
173 		goto err;
174 	}
175 
176 	if (ret != sizeof(buf)) {
177 		ret = -1;
178 		goto err;
179 	}
180 
181 	for (i = 0; i < MAX_MSG; i++) {
182 		if (buf[i] != i)
183 			goto err;
184 	}
185 	ret = 0;
186 err:
187 	shutdown(sockout, SHUT_RDWR);
188 	shutdown(sockfd, SHUT_RDWR);
189 	close(sockout);
190 	close(sockfd);
191 	return ret;
192 }
193 
recv_uring(struct recv_data * rd)194 static int recv_uring(struct recv_data *rd)
195 {
196 	int buf[MAX_MSG];
197 	struct iovec iov = {
198 		.iov_base = buf,
199 		.iov_len = sizeof(buf),
200 	};
201 	struct io_uring_params p = { };
202 	struct io_uring ring;
203 	int ret, sock = -1, sockout = -1;
204 
205 	ret = t_create_ring_params(1, &ring, &p);
206 	if (ret == T_SETUP_SKIP) {
207 		pthread_mutex_unlock(&rd->mutex);
208 		ret = 0;
209 		goto err;
210 	} else if (ret < 0) {
211 		pthread_mutex_unlock(&rd->mutex);
212 		goto err;
213 	}
214 
215 	sock = recv_prep(&ring, &iov, &sockout, rd);
216 	if (ret) {
217 		fprintf(stderr, "recv_prep failed: %d\n", ret);
218 		goto err;
219 	}
220 	ret = do_recv(&ring);
221 	if (!ret) {
222 		int i;
223 
224 		for (i = 0; i < MAX_MSG; i++) {
225 			if (buf[i] != i) {
226 				fprintf(stderr, "found %d at %d\n", buf[i], i);
227 				ret = 1;
228 				break;
229 			}
230 		}
231 	}
232 
233 	shutdown(sockout, SHUT_RDWR);
234 	shutdown(sock, SHUT_RDWR);
235 	close(sock);
236 	close(sockout);
237 	io_uring_queue_exit(&ring);
238 err:
239 	if (sock != -1) {
240 		shutdown(sock, SHUT_RDWR);
241 		close(sock);
242 	}
243 	if (sockout != -1) {
244 		shutdown(sockout, SHUT_RDWR);
245 		close(sockout);
246 	}
247 	return ret;
248 }
249 
recv_fn(void * data)250 static void *recv_fn(void *data)
251 {
252 	struct recv_data *rd = data;
253 
254 	if (rd->use_sync)
255 		return (void *) (uintptr_t) recv_sync(rd);
256 
257 	return (void *) (uintptr_t) recv_uring(rd);
258 }
259 
do_send(struct recv_data * rd)260 static int do_send(struct recv_data *rd)
261 {
262 	struct sockaddr_in saddr;
263 	struct io_uring ring;
264 	struct io_uring_cqe *cqe;
265 	struct io_uring_sqe *sqe;
266 	int sockfd, ret, i;
267 	struct iovec iov;
268 	int *buf;
269 
270 	ret = io_uring_queue_init(2, &ring, 0);
271 	if (ret) {
272 		fprintf(stderr, "queue init failed: %d\n", ret);
273 		return 1;
274 	}
275 
276 	buf = malloc(MAX_MSG * sizeof(int));
277 	for (i = 0; i < MAX_MSG; i++)
278 		buf[i] = i;
279 
280 	sockfd = socket(AF_INET, SOCK_STREAM | SOCK_CLOEXEC, IPPROTO_TCP);
281 	if (sockfd < 0) {
282 		perror("socket");
283 		return 1;
284 	}
285 
286 	pthread_mutex_lock(&rd->mutex);
287 	assert(rd->port != 0);
288 	memset(&saddr, 0, sizeof(saddr));
289 	saddr.sin_family = AF_INET;
290 	saddr.sin_port = rd->port;
291 	inet_pton(AF_INET, "127.0.0.1", &saddr.sin_addr);
292 
293 	ret = connect(sockfd, (struct sockaddr *)&saddr, sizeof(saddr));
294 	if (ret < 0) {
295 		perror("connect");
296 		return 1;
297 	}
298 
299 	iov.iov_base = buf;
300 	iov.iov_len = MAX_MSG * sizeof(int) / 2;
301 	for (i = 0; i < 2; i++) {
302 		sqe = io_uring_get_sqe(&ring);
303 		io_uring_prep_send(sqe, sockfd, iov.iov_base, iov.iov_len, 0);
304 		sqe->user_data = 1;
305 
306 		ret = io_uring_submit(&ring);
307 		if (ret <= 0) {
308 			fprintf(stderr, "submit failed: %d\n", ret);
309 			goto err;
310 		}
311 		usleep(10000);
312 		iov.iov_base += iov.iov_len;
313 	}
314 
315 	for (i = 0; i < 2; i++) {
316 		ret = io_uring_wait_cqe(&ring, &cqe);
317 		if (cqe->res == -EINVAL) {
318 			fprintf(stdout, "send not supported, skipping\n");
319 			close(sockfd);
320 			free(buf);
321 			return 0;
322 		}
323 		if (cqe->res != iov.iov_len) {
324 			fprintf(stderr, "failed cqe: %d\n", cqe->res);
325 			goto err;
326 		}
327 		io_uring_cqe_seen(&ring, cqe);
328 	}
329 
330 	shutdown(sockfd, SHUT_RDWR);
331 	close(sockfd);
332 	free(buf);
333 	return 0;
334 err:
335 	shutdown(sockfd, SHUT_RDWR);
336 	close(sockfd);
337 	free(buf);
338 	return 1;
339 }
340 
test(int use_recvmsg,int use_sync)341 static int test(int use_recvmsg, int use_sync)
342 {
343 	pthread_mutexattr_t attr;
344 	pthread_t recv_thread;
345 	struct recv_data rd;
346 	int ret;
347 	void *retval;
348 
349 	pthread_mutexattr_init(&attr);
350 	pthread_mutexattr_setpshared(&attr, 1);
351 	pthread_mutex_init(&rd.mutex, &attr);
352 	pthread_mutex_lock(&rd.mutex);
353 	rd.use_recvmsg = use_recvmsg;
354 	rd.use_sync = use_sync;
355 	rd.port = 0;
356 
357 	ret = pthread_create(&recv_thread, NULL, recv_fn, &rd);
358 	if (ret) {
359 		fprintf(stderr, "Thread create failed: %d\n", ret);
360 		pthread_mutex_unlock(&rd.mutex);
361 		return 1;
362 	}
363 
364 	do_send(&rd);
365 	pthread_join(recv_thread, &retval);
366 	return (intptr_t)retval;
367 }
368 
main(int argc,char * argv[])369 int main(int argc, char *argv[])
370 {
371 	int ret;
372 
373 	if (argc > 1)
374 		return 0;
375 
376 	ret = test(0, 0);
377 	if (ret) {
378 		fprintf(stderr, "test recv failed\n");
379 		return ret;
380 	}
381 
382 	ret = test(1, 0);
383 	if (ret) {
384 		fprintf(stderr, "test recvmsg failed\n");
385 		return ret;
386 	}
387 
388 	ret = test(0, 1);
389 	if (ret) {
390 		fprintf(stderr, "test sync recv failed\n");
391 		return ret;
392 	}
393 
394 	ret = test(1, 1);
395 	if (ret) {
396 		fprintf(stderr, "test sync recvmsg failed\n");
397 		return ret;
398 	}
399 
400 	return 0;
401 }
402