1 /* SPDX-License-Identifier: MIT */
2 
3 #include <stdio.h>
4 #include <unistd.h>
5 #include <errno.h>
6 #include <sys/mman.h>
7 #include <stdlib.h>
8 #include <string.h>
9 #include <netinet/udp.h>
10 #include <arpa/inet.h>
11 
12 #include "liburing.h"
13 
/* Submission queue depth handed to io_uring_queue_init_params(). */
#define QD 64
/* Default log2 of the size of one provided buffer (overridable with -b). */
#define BUF_SHIFT 12 /* 4k */
/* Capacity of the CQE batch array that main() drains per loop iteration. */
#define CQES (QD * 16)
/* Number of provided buffers, and of per-buffer send slots in struct ctx. */
#define BUFFERS CQES
/* msg_controllen used for the receive msghdr: no ancillary data requested. */
#define CONTROLLEN 0
19 
/*
 * State for one in-flight echo sendmsg. It has to stay valid until the send
 * completes, so one instance exists per provided buffer (see struct ctx).
 */
struct sendmsg_ctx {
	struct msghdr msg;	/* header passed to io_uring_prep_sendmsg() */
	struct iovec iov;	/* single payload segment referenced by msg.msg_iov */
};
24 
25 struct ctx {
26 	struct io_uring ring;
27 	struct io_uring_buf_ring *buf_ring;
28 	unsigned char *buffer_base;
29 	struct msghdr msg;
30 	int buf_shift;
31 	int af;
32 	bool verbose;
33 	struct sendmsg_ctx send[BUFFERS];
34 	size_t buf_ring_size;
35 };
36 
buffer_size(struct ctx * ctx)37 static size_t buffer_size(struct ctx *ctx)
38 {
39 	return 1U << ctx->buf_shift;
40 }
41 
get_buffer(struct ctx * ctx,int idx)42 static unsigned char *get_buffer(struct ctx *ctx, int idx)
43 {
44 	return ctx->buffer_base + (idx << ctx->buf_shift);
45 }
46 
setup_buffer_pool(struct ctx * ctx)47 static int setup_buffer_pool(struct ctx *ctx)
48 {
49 	int ret, i;
50 	void *mapped;
51 	struct io_uring_buf_reg reg = { .ring_addr = 0,
52 					.ring_entries = BUFFERS,
53 					.bgid = 0 };
54 
55 	ctx->buf_ring_size = (sizeof(struct io_uring_buf) + buffer_size(ctx)) * BUFFERS;
56 	mapped = mmap(NULL, ctx->buf_ring_size, PROT_READ | PROT_WRITE,
57 		      MAP_ANONYMOUS | MAP_PRIVATE, 0, 0);
58 	if (mapped == MAP_FAILED) {
59 		fprintf(stderr, "buf_ring mmap: %s\n", strerror(errno));
60 		return -1;
61 	}
62 	ctx->buf_ring = (struct io_uring_buf_ring *)mapped;
63 
64 	io_uring_buf_ring_init(ctx->buf_ring);
65 
66 	reg = (struct io_uring_buf_reg) {
67 		.ring_addr = (unsigned long)ctx->buf_ring,
68 		.ring_entries = BUFFERS,
69 		.bgid = 0
70 	};
71 	ctx->buffer_base = (unsigned char *)ctx->buf_ring +
72 			   sizeof(struct io_uring_buf) * BUFFERS;
73 
74 	ret = io_uring_register_buf_ring(&ctx->ring, ®, 0);
75 	if (ret) {
76 		fprintf(stderr, "buf_ring init failed: %s\n"
77 				"NB This requires a kernel version >= 6.0\n",
78 				strerror(-ret));
79 		return ret;
80 	}
81 
82 	for (i = 0; i < BUFFERS; i++) {
83 		io_uring_buf_ring_add(ctx->buf_ring, get_buffer(ctx, i), buffer_size(ctx), i,
84 				      io_uring_buf_ring_mask(BUFFERS), i);
85 	}
86 	io_uring_buf_ring_advance(ctx->buf_ring, BUFFERS);
87 
88 	return 0;
89 }
90 
setup_context(struct ctx * ctx)91 static int setup_context(struct ctx *ctx)
92 {
93 	struct io_uring_params params;
94 	int ret;
95 
96 	memset(¶ms, 0, sizeof(params));
97 	params.cq_entries = QD * 8;
98 	params.flags = IORING_SETUP_SUBMIT_ALL | IORING_SETUP_COOP_TASKRUN |
99 		       IORING_SETUP_CQSIZE;
100 
101 	ret = io_uring_queue_init_params(QD, &ctx->ring, ¶ms);
102 	if (ret < 0) {
103 		fprintf(stderr, "queue_init failed: %s\n"
104 				"NB: This requires a kernel version >= 6.0\n",
105 				strerror(-ret));
106 		return ret;
107 	}
108 
109 	ret = setup_buffer_pool(ctx);
110 	if (ret)
111 		io_uring_queue_exit(&ctx->ring);
112 
113 	memset(&ctx->msg, 0, sizeof(ctx->msg));
114 	ctx->msg.msg_namelen = sizeof(struct sockaddr_storage);
115 	ctx->msg.msg_controllen = CONTROLLEN;
116 	return ret;
117 }
118 
/*
 * Create a UDP socket of family af bound to the wildcard address.
 *
 * port <= 0 lets the system choose a port; the chosen port is then printed
 * to stderr. A port above 65535 is rejected instead of being silently
 * truncated to 16 bits.
 *
 * Returns the socket descriptor, or -1 on error (a message is printed).
 */
static int setup_sock(int af, int port)
{
	uint16_t nport;
	int ret;
	int fd;

	if (port > 65535) {
		fprintf(stderr, "sock_init: port %d out of range\n", port);
		return -1;
	}
	nport = port <= 0 ? 0 : htons((uint16_t)port);

	fd = socket(af, SOCK_DGRAM, 0);
	if (fd < 0) {
		fprintf(stderr, "sock_init: %s\n", strerror(errno));
		return -1;
	}

	if (af == AF_INET6) {
		struct sockaddr_in6 addr6 = {
			.sin6_family = af,
			.sin6_port = nport,
			.sin6_addr = IN6ADDR_ANY_INIT
		};

		ret = bind(fd, (struct sockaddr *) &addr6, sizeof(addr6));
	} else {
		struct sockaddr_in addr = {
			.sin_family = af,
			.sin_port = nport,
			.sin_addr = { INADDR_ANY }
		};

		ret = bind(fd, (struct sockaddr *) &addr, sizeof(addr));
	}

	if (ret) {
		fprintf(stderr, "sock_bind: %s\n", strerror(errno));
		close(fd);
		return -1;
	}

	if (port <= 0) {
		struct sockaddr_storage bound;
		socklen_t bound_len = sizeof(bound);
		int bound_port;

		if (getsockname(fd, (struct sockaddr *)&bound, &bound_len)) {
			fprintf(stderr, "getsockname failed\n");
			close(fd);
			return -1;
		}

		/* read the port through the structure matching the family */
		if (af == AF_INET6)
			bound_port = ntohs(((struct sockaddr_in6 *)&bound)->sin6_port);
		else
			bound_port = ntohs(((struct sockaddr_in *)&bound)->sin_port);
		fprintf(stderr, "port bound to %d\n", bound_port);
	}

	return fd;
}
172 
cleanup_context(struct ctx * ctx)173 static void cleanup_context(struct ctx *ctx)
174 {
175 	munmap(ctx->buf_ring, ctx->buf_ring_size);
176 	io_uring_queue_exit(&ctx->ring);
177 }
178 
get_sqe(struct ctx * ctx,struct io_uring_sqe ** sqe)179 static bool get_sqe(struct ctx *ctx, struct io_uring_sqe **sqe)
180 {
181 	*sqe = io_uring_get_sqe(&ctx->ring);
182 
183 	if (!*sqe) {
184 		io_uring_submit(&ctx->ring);
185 		*sqe = io_uring_get_sqe(&ctx->ring);
186 	}
187 	if (!*sqe) {
188 		fprintf(stderr, "cannot get sqe\n");
189 		return true;
190 	}
191 	return false;
192 }
193 
add_recv(struct ctx * ctx,int idx)194 static int add_recv(struct ctx *ctx, int idx)
195 {
196 	struct io_uring_sqe *sqe;
197 
198 	if (get_sqe(ctx, &sqe))
199 		return -1;
200 
201 	io_uring_prep_recvmsg_multishot(sqe, idx, &ctx->msg, MSG_TRUNC);
202 	sqe->flags |= IOSQE_FIXED_FILE;
203 
204 	sqe->flags |= IOSQE_BUFFER_SELECT;
205 	sqe->buf_group = 0;
206 	io_uring_sqe_set_data64(sqe, BUFFERS + 1);
207 	return 0;
208 }
209 
recycle_buffer(struct ctx * ctx,int idx)210 static void recycle_buffer(struct ctx *ctx, int idx)
211 {
212 	io_uring_buf_ring_add(ctx->buf_ring, get_buffer(ctx, idx), buffer_size(ctx), idx,
213 			      io_uring_buf_ring_mask(BUFFERS), 0);
214 	io_uring_buf_ring_advance(ctx->buf_ring, 1);
215 }
216 
process_cqe_send(struct ctx * ctx,struct io_uring_cqe * cqe)217 static int process_cqe_send(struct ctx *ctx, struct io_uring_cqe *cqe)
218 {
219 	int idx = cqe->user_data;
220 
221 	if (cqe->res < 0)
222 		fprintf(stderr, "bad send %s\n", strerror(-cqe->res));
223 	recycle_buffer(ctx, idx);
224 	return 0;
225 }
226 
process_cqe_recv(struct ctx * ctx,struct io_uring_cqe * cqe,int fdidx)227 static int process_cqe_recv(struct ctx *ctx, struct io_uring_cqe *cqe,
228 			    int fdidx)
229 {
230 	int ret, idx;
231 	struct io_uring_recvmsg_out *o;
232 	struct io_uring_sqe *sqe;
233 
234 	if (!(cqe->flags & IORING_CQE_F_MORE)) {
235 		ret = add_recv(ctx, fdidx);
236 		if (ret)
237 			return ret;
238 	}
239 
240 	if (cqe->res == -ENOBUFS)
241 		return 0;
242 
243 	if (!(cqe->flags & IORING_CQE_F_BUFFER) || cqe->res < 0) {
244 		fprintf(stderr, "recv cqe bad res %d\n", cqe->res);
245 		if (cqe->res == -EFAULT || cqe->res == -EINVAL)
246 			fprintf(stderr,
247 				"NB: This requires a kernel version >= 6.0\n");
248 		return -1;
249 	}
250 	idx = cqe->flags >> 16;
251 
252 	o = io_uring_recvmsg_validate(get_buffer(ctx, cqe->flags >> 16),
253 				      cqe->res, &ctx->msg);
254 	if (!o) {
255 		fprintf(stderr, "bad recvmsg\n");
256 		return -1;
257 	}
258 	if (o->namelen > ctx->msg.msg_namelen) {
259 		fprintf(stderr, "truncated name\n");
260 		recycle_buffer(ctx, idx);
261 		return 0;
262 	}
263 	if (o->flags & MSG_TRUNC) {
264 		unsigned int r;
265 
266 		r = io_uring_recvmsg_payload_length(o, cqe->res, &ctx->msg);
267 		fprintf(stderr, "truncated msg need %u received %u\n",
268 				o->payloadlen, r);
269 		recycle_buffer(ctx, idx);
270 		return 0;
271 	}
272 
273 	if (ctx->verbose) {
274 		struct sockaddr_in *addr = io_uring_recvmsg_name(o);
275 		struct sockaddr_in6 *addr6 = (void *)addr;
276 		char buff[INET6_ADDRSTRLEN + 1];
277 		const char *name;
278 		void *paddr;
279 
280 		if (ctx->af == AF_INET6)
281 			paddr = &addr6->sin6_addr;
282 		else
283 			paddr = &addr->sin_addr;
284 
285 		name = inet_ntop(ctx->af, paddr, buff, sizeof(buff));
286 		if (!name)
287 			name = "<INVALID>";
288 
289 		fprintf(stderr, "received %u bytes %d from [%s]:%d\n",
290 			io_uring_recvmsg_payload_length(o, cqe->res, &ctx->msg),
291 			o->namelen, name, (int)ntohs(addr->sin_port));
292 	}
293 
294 	if (get_sqe(ctx, &sqe))
295 		return -1;
296 
297 	ctx->send[idx].iov = (struct iovec) {
298 		.iov_base = io_uring_recvmsg_payload(o, &ctx->msg),
299 		.iov_len =
300 			io_uring_recvmsg_payload_length(o, cqe->res, &ctx->msg)
301 	};
302 	ctx->send[idx].msg = (struct msghdr) {
303 		.msg_namelen = o->namelen,
304 		.msg_name = io_uring_recvmsg_name(o),
305 		.msg_control = NULL,
306 		.msg_controllen = 0,
307 		.msg_iov = &ctx->send[idx].iov,
308 		.msg_iovlen = 1
309 	};
310 
311 	io_uring_prep_sendmsg(sqe, fdidx, &ctx->send[idx].msg, 0);
312 	io_uring_sqe_set_data64(sqe, idx);
313 	sqe->flags |= IOSQE_FIXED_FILE;
314 
315 	return 0;
316 }
process_cqe(struct ctx * ctx,struct io_uring_cqe * cqe,int fdidx)317 static int process_cqe(struct ctx *ctx, struct io_uring_cqe *cqe, int fdidx)
318 {
319 	if (cqe->user_data < BUFFERS)
320 		return process_cqe_send(ctx, cqe);
321 	else
322 		return process_cqe_recv(ctx, cqe, fdidx);
323 }
324 
main(int argc,char * argv[])325 int main(int argc, char *argv[])
326 {
327 	struct ctx ctx;
328 	int ret;
329 	int port = -1;
330 	int sockfd;
331 	int opt;
332 	struct io_uring_cqe *cqes[CQES];
333 	unsigned int count, i;
334 
335 	memset(&ctx, 0, sizeof(ctx));
336 	ctx.verbose = false;
337 	ctx.af = AF_INET;
338 	ctx.buf_shift = BUF_SHIFT;
339 
340 	while ((opt = getopt(argc, argv, "6vp:b:")) != -1) {
341 		switch (opt) {
342 		case '6':
343 			ctx.af = AF_INET6;
344 			break;
345 		case 'p':
346 			port = atoi(optarg);
347 			break;
348 		case 'b':
349 			ctx.buf_shift = atoi(optarg);
350 			break;
351 		case 'v':
352 			ctx.verbose = true;
353 			break;
354 		default:
355 			fprintf(stderr, "Usage: %s [-p port] "
356 					"[-b log2(BufferSize)] [-6] [-v]\n",
357 					argv[0]);
358 			exit(-1);
359 		}
360 	}
361 
362 	sockfd = setup_sock(ctx.af, port);
363 	if (sockfd < 0)
364 		return 1;
365 
366 	if (setup_context(&ctx)) {
367 		close(sockfd);
368 		return 1;
369 	}
370 
371 	ret = io_uring_register_files(&ctx.ring, &sockfd, 1);
372 	if (ret) {
373 		fprintf(stderr, "register files: %s\n", strerror(-ret));
374 		return -1;
375 	}
376 
377 	ret = add_recv(&ctx, 0);
378 	if (ret)
379 		return 1;
380 
381 	while (true) {
382 		ret = io_uring_submit_and_wait(&ctx.ring, 1);
383 		if (ret == -EINTR)
384 			continue;
385 		if (ret < 0) {
386 			fprintf(stderr, "submit and wait failed %d\n", ret);
387 			break;
388 		}
389 
390 		count = io_uring_peek_batch_cqe(&ctx.ring, &cqes[0], CQES);
391 		for (i = 0; i < count; i++) {
392 			ret = process_cqe(&ctx, cqes[i], 0);
393 			if (ret)
394 				goto cleanup;
395 		}
396 		io_uring_cq_advance(&ctx.ring, count);
397 	}
398 
399 cleanup:
400 	cleanup_context(&ctx);
401 	close(sockfd);
402 	return ret;
403 }
404