/* SPDX-License-Identifier: MIT */
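/*
 * UDP echo server: a single multishot recvmsg request receives datagrams
 * into a registered provided-buffer ring, and each datagram is echoed
 * back to its sender with a sendmsg request on the same socket.
 */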

#include <stdio.h>
#include <unistd.h>
#include <errno.h>
#include <stdbool.h>
#include <sys/mman.h>
#include <sys/socket.h>
#include <stdlib.h>
#include <string.h>
#include <netinet/udp.h>
#include <arpa/inet.h>

#include "liburing.h"

#define QD 64
#define BUF_SHIFT 12 /* 4k */
#define CQES (QD * 16)
#define BUFFERS CQES
#define CONTROLLEN 0
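
/*
 * State for an in-flight echo: the msghdr and iovec handed to sendmsg
 * must stay valid until the send completes, so keep one per buffer.
 */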
struct sendmsg_ctx {
	struct msghdr msg;
	struct iovec iov;
};

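/* Global server state: the ring, the provided-buffer pool, and per-send slots. */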
struct ctx {
	struct io_uring ring;
	struct io_uring_buf_ring *buf_ring;
	unsigned char *buffer_base;
	struct msghdr msg;
	int buf_shift;
	int af;
	bool verbose;
	struct sendmsg_ctx send[BUFFERS];
	size_t buf_ring_size;
};

static size_t buffer_size(struct ctx *ctx)
{
	return 1U << ctx->buf_shift;
}

static unsigned char *get_buffer(struct ctx *ctx, int idx)
{
	return ctx->buffer_base + (idx << ctx->buf_shift);
}

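/*
 * Set up the provided-buffer ring: a single anonymous mapping holds the
 * ring header (one struct io_uring_buf per entry) followed by the buffers
 * themselves. The ring is registered as buffer group 0, and every buffer
 * is made available to the kernel up front.
 */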
static int setup_buffer_pool(struct ctx *ctx)
{
	int ret, i;
	void *mapped;
	struct io_uring_buf_reg reg;

	ctx->buf_ring_size = (sizeof(struct io_uring_buf) + buffer_size(ctx)) * BUFFERS;
	mapped = mmap(NULL, ctx->buf_ring_size, PROT_READ | PROT_WRITE,
		      MAP_ANONYMOUS | MAP_PRIVATE, -1, 0);
	if (mapped == MAP_FAILED) {
		fprintf(stderr, "buf_ring mmap: %s\n", strerror(errno));
		return -1;
	}
	ctx->buf_ring = (struct io_uring_buf_ring *)mapped;

	io_uring_buf_ring_init(ctx->buf_ring);

	reg = (struct io_uring_buf_reg) {
		.ring_addr = (unsigned long)ctx->buf_ring,
		.ring_entries = BUFFERS,
		.bgid = 0
	};
	ctx->buffer_base = (unsigned char *)ctx->buf_ring +
			   sizeof(struct io_uring_buf) * BUFFERS;

	ret = io_uring_register_buf_ring(&ctx->ring, &reg, 0);
	if (ret) {
		fprintf(stderr, "buf_ring init failed: %s\n"
				"NB: This requires a kernel version >= 6.0\n",
				strerror(-ret));
		return ret;
	}

	for (i = 0; i < BUFFERS; i++) {
		io_uring_buf_ring_add(ctx->buf_ring, get_buffer(ctx, i),
				      buffer_size(ctx), i,
				      io_uring_buf_ring_mask(BUFFERS), i);
	}
	io_uring_buf_ring_advance(ctx->buf_ring, BUFFERS);

	return 0;
}

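/*
 * Create the ring. The CQ is sized well above the SQ depth because a
 * multishot receive can post many completions per submitted request.
 */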
static int setup_context(struct ctx *ctx)
{
	struct io_uring_params params;
	int ret;

	memset(&params, 0, sizeof(params));
	params.cq_entries = QD * 8;
	params.flags = IORING_SETUP_SUBMIT_ALL | IORING_SETUP_COOP_TASKRUN |
		       IORING_SETUP_CQSIZE;

	ret = io_uring_queue_init_params(QD, &ctx->ring, &params);
	if (ret < 0) {
		fprintf(stderr, "queue_init failed: %s\n"
				"NB: This requires a kernel version >= 6.0\n",
				strerror(-ret));
		return ret;
	}

	ret = setup_buffer_pool(ctx);
	if (ret)
		io_uring_queue_exit(&ctx->ring);

	memset(&ctx->msg, 0, sizeof(ctx->msg));
	ctx->msg.msg_namelen = sizeof(struct sockaddr_storage);
	ctx->msg.msg_controllen = CONTROLLEN;
	return ret;
}

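/*
 * Create a UDP socket for the given address family and bind it. If no
 * port is given (port <= 0), the kernel picks an ephemeral one, which is
 * reported on stderr so a client knows where to send.
 */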
static int setup_sock(int af, int port)
{
	int ret;
	int fd;
	uint16_t nport = port <= 0 ? 0 : htons(port);

	fd = socket(af, SOCK_DGRAM, 0);
	if (fd < 0) {
		fprintf(stderr, "sock_init: %s\n", strerror(errno));
		return -1;
	}

	if (af == AF_INET6) {
		struct sockaddr_in6 addr6 = {
			.sin6_family = af,
			.sin6_port = nport,
			.sin6_addr = IN6ADDR_ANY_INIT
		};

		ret = bind(fd, (struct sockaddr *)&addr6, sizeof(addr6));
	} else {
		struct sockaddr_in addr = {
			.sin_family = af,
			.sin_port = nport,
			.sin_addr = { INADDR_ANY }
		};

		ret = bind(fd, (struct sockaddr *)&addr, sizeof(addr));
	}

	if (ret) {
		fprintf(stderr, "sock_bind: %s\n", strerror(errno));
		close(fd);
		return -1;
	}

	if (port <= 0) {
		int port;
		struct sockaddr_storage s;
		socklen_t sz = sizeof(s);

		if (getsockname(fd, (struct sockaddr *)&s, &sz)) {
			fprintf(stderr, "getsockname failed\n");
			close(fd);
			return -1;
		}

		port = ntohs(((struct sockaddr_in *)&s)->sin_port);
		fprintf(stderr, "port bound to %d\n", port);
	}

	return fd;
}

static void cleanup_context(struct ctx *ctx)
{
	munmap(ctx->buf_ring, ctx->buf_ring_size);
	io_uring_queue_exit(&ctx->ring);
}

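/*
 * Fetch an SQE, submitting pending requests to free up space if the
 * submission queue is full. Returns true on failure.
 */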
static bool get_sqe(struct ctx *ctx, struct io_uring_sqe **sqe)
{
	*sqe = io_uring_get_sqe(&ctx->ring);

	if (!*sqe) {
		io_uring_submit(&ctx->ring);
		*sqe = io_uring_get_sqe(&ctx->ring);
	}
	if (!*sqe) {
		fprintf(stderr, "cannot get sqe\n");
		return true;
	}
	return false;
}

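/*
 * Arm a multishot recvmsg on the registered socket (fixed file index
 * idx). The kernel picks buffers from group 0; a user_data value of
 * BUFFERS + 1 marks receive CQEs, while send CQEs carry a buffer index
 * below BUFFERS.
 */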
static int add_recv(struct ctx *ctx, int idx)
{
	struct io_uring_sqe *sqe;

	if (get_sqe(ctx, &sqe))
		return -1;

	io_uring_prep_recvmsg_multishot(sqe, idx, &ctx->msg, MSG_TRUNC);
	sqe->flags |= IOSQE_FIXED_FILE;

	sqe->flags |= IOSQE_BUFFER_SELECT;
	sqe->buf_group = 0;
	io_uring_sqe_set_data64(sqe, BUFFERS + 1);
	return 0;
}

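/* Hand a buffer back to the ring so the kernel can pick it again. */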
static void recycle_buffer(struct ctx *ctx, int idx)
{
	io_uring_buf_ring_add(ctx->buf_ring, get_buffer(ctx, idx),
			      buffer_size(ctx), idx,
			      io_uring_buf_ring_mask(BUFFERS), 0);
	io_uring_buf_ring_advance(ctx->buf_ring, 1);
}

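/* A send completed: report any error and recycle the echoed buffer. */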
static int process_cqe_send(struct ctx *ctx, struct io_uring_cqe *cqe)
{
	int idx = cqe->user_data;

	if (cqe->res < 0)
		fprintf(stderr, "bad send %s\n", strerror(-cqe->res));
	recycle_buffer(ctx, idx);
	return 0;
}

static int process_cqe_recv(struct ctx *ctx, struct io_uring_cqe *cqe,
			    int fdidx)
{
	int ret, idx;
	struct io_uring_recvmsg_out *o;
	struct io_uring_sqe *sqe;

	if (!(cqe->flags & IORING_CQE_F_MORE)) {
		ret = add_recv(ctx, fdidx);
		if (ret)
			return ret;
	}

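	/* No provided buffer was available for this completion; the
	 * request was re-armed above, so just carry on. */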
	if (cqe->res == -ENOBUFS)
		return 0;

	if (!(cqe->flags & IORING_CQE_F_BUFFER) || cqe->res < 0) {
		fprintf(stderr, "recv cqe bad res %d\n", cqe->res);
		if (cqe->res == -EFAULT || cqe->res == -EINVAL)
			fprintf(stderr,
				"NB: This requires a kernel version >= 6.0\n");
		return -1;
	}
	/* the selected buffer id is carried in the upper 16 bits of cqe->flags */
	idx = cqe->flags >> 16;

	o = io_uring_recvmsg_validate(get_buffer(ctx, idx), cqe->res, &ctx->msg);
	if (!o) {
		fprintf(stderr, "bad recvmsg\n");
		return -1;
	}
	if (o->namelen > ctx->msg.msg_namelen) {
		fprintf(stderr, "truncated name\n");
		recycle_buffer(ctx, idx);
		return 0;
	}
	if (o->flags & MSG_TRUNC) {
		unsigned int r;

		r = io_uring_recvmsg_payload_length(o, cqe->res, &ctx->msg);
		fprintf(stderr, "truncated msg need %u received %u\n",
			o->payloadlen, r);
		recycle_buffer(ctx, idx);
		return 0;
	}

	if (ctx->verbose) {
		struct sockaddr_in *addr = io_uring_recvmsg_name(o);
		struct sockaddr_in6 *addr6 = (void *)addr;
		char buff[INET6_ADDRSTRLEN + 1];
		const char *name;
		void *paddr;

		if (ctx->af == AF_INET6)
			paddr = &addr6->sin6_addr;
		else
			paddr = &addr->sin_addr;

		name = inet_ntop(ctx->af, paddr, buff, sizeof(buff));
		if (!name)
			name = "<INVALID>";

		fprintf(stderr, "received %u bytes %d from [%s]:%d\n",
			io_uring_recvmsg_payload_length(o, cqe->res, &ctx->msg),
			o->namelen, name, (int)ntohs(addr->sin_port));
	}

	if (get_sqe(ctx, &sqe))
		return -1;

	ctx->send[idx].iov = (struct iovec) {
		.iov_base = io_uring_recvmsg_payload(o, &ctx->msg),
		.iov_len = io_uring_recvmsg_payload_length(o, cqe->res, &ctx->msg)
	};
	ctx->send[idx].msg = (struct msghdr) {
		.msg_namelen = o->namelen,
		.msg_name = io_uring_recvmsg_name(o),
		.msg_control = NULL,
		.msg_controllen = 0,
		.msg_iov = &ctx->send[idx].iov,
		.msg_iovlen = 1
	};

	io_uring_prep_sendmsg(sqe, fdidx, &ctx->send[idx].msg, 0);
	io_uring_sqe_set_data64(sqe, idx);
	sqe->flags |= IOSQE_FIXED_FILE;

	return 0;
}
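
/* Dispatch on user_data: send completions carry a buffer index
 * (< BUFFERS), receive completions carry BUFFERS + 1. */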
static int process_cqe(struct ctx *ctx, struct io_uring_cqe *cqe, int fdidx)
{
	if (cqe->user_data < BUFFERS)
		return process_cqe_send(ctx, cqe);
	else
		return process_cqe_recv(ctx, cqe, fdidx);
}

int main(int argc, char *argv[])
{
	struct ctx ctx;
	int ret;
	int port = -1;
	int sockfd;
	int opt;
	struct io_uring_cqe *cqes[CQES];
	unsigned int count, i;

	memset(&ctx, 0, sizeof(ctx));
	ctx.verbose = false;
	ctx.af = AF_INET;
	ctx.buf_shift = BUF_SHIFT;

	while ((opt = getopt(argc, argv, "6vp:b:")) != -1) {
		switch (opt) {
		case '6':
			ctx.af = AF_INET6;
			break;
		case 'p':
			port = atoi(optarg);
			break;
		case 'b':
			ctx.buf_shift = atoi(optarg);
			break;
		case 'v':
			ctx.verbose = true;
			break;
		default:
			fprintf(stderr, "Usage: %s [-p port] "
					"[-b log2(BufferSize)] [-6] [-v]\n",
					argv[0]);
			exit(-1);
		}
	}

	sockfd = setup_sock(ctx.af, port);
	if (sockfd < 0)
		return 1;

	if (setup_context(&ctx)) {
		close(sockfd);
		return 1;
	}

	ret = io_uring_register_files(&ctx.ring, &sockfd, 1);
	if (ret) {
		fprintf(stderr, "register files: %s\n", strerror(-ret));
		cleanup_context(&ctx);
		close(sockfd);
		return 1;
	}

	ret = add_recv(&ctx, 0);
	if (ret) {
		cleanup_context(&ctx);
		close(sockfd);
		return 1;
	}

	while (true) {
		ret = io_uring_submit_and_wait(&ctx.ring, 1);
		if (ret == -EINTR)
			continue;
		if (ret < 0) {
			fprintf(stderr, "submit and wait failed %d\n", ret);
			break;
		}

		count = io_uring_peek_batch_cqe(&ctx.ring, &cqes[0], CQES);
		for (i = 0; i < count; i++) {
			ret = process_cqe(&ctx, cqes[i], 0);
			if (ret)
				goto cleanup;
		}
		io_uring_cq_advance(&ctx.ring, count);
	}

cleanup:
	cleanup_context(&ctx);
	close(sockfd);
	return ret;
}