1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3 * Copyright (c) 2021 Linux Test Project
4 */
5
6 #include <stdlib.h>
7 #include <limits.h>
8 #include <asm/types.h>
9 #include <linux/netlink.h>
10 #include <linux/rtnetlink.h>
11 #include <sys/types.h>
12 #include <sys/socket.h>
13 #include <sys/poll.h>
14 #define TST_NO_DEFAULT_MAIN
15 #include "tst_test.h"
16 #include "tst_rtnetlink.h"
17
/*
 * State for building, sending and validating rtnetlink messages over one
 * NETLINK_ROUTE socket.
 */
struct tst_rtnl_context {
	/* netlink socket descriptor */
	int socket;
	/* value copied into nlmsg_pid of every queued message */
	pid_t pid;
	/* sequence number assigned to the next queued message */
	uint32_t seq;
	/* allocated size of buffer in bytes */
	size_t bufsize;
	/* bytes used by queued messages (or by the last sent batch) */
	size_t datalen;
	/* outgoing message buffer */
	char *buffer;
	/* last message of the batch being built, NULL when nothing is queued */
	struct nlmsghdr *curmsg;
};
26
tst_rtnl_grow_buffer(const char * file,const int lineno,struct tst_rtnl_context * ctx,size_t size)27 static int tst_rtnl_grow_buffer(const char *file, const int lineno,
28 struct tst_rtnl_context *ctx, size_t size)
29 {
30 size_t needed, offset, curlen = NLMSG_ALIGN(ctx->datalen);
31 char *buf;
32
33 if (ctx->bufsize - curlen >= size)
34 return 1;
35
36 needed = size - (ctx->bufsize - curlen);
37 size = ctx->bufsize + (ctx->bufsize > needed ? ctx->bufsize : needed);
38 size = NLMSG_ALIGN(size);
39 buf = safe_realloc(file, lineno, ctx->buffer, size);
40
41 if (!buf)
42 return 0;
43
44 memset(buf + ctx->bufsize, 0, size - ctx->bufsize);
45 offset = ((char *)ctx->curmsg) - ctx->buffer;
46 ctx->buffer = buf;
47 ctx->curmsg = (struct nlmsghdr *)(buf + offset);
48 ctx->bufsize = size;
49
50 return 1;
51 }
52
tst_rtnl_destroy_context(const char * file,const int lineno,struct tst_rtnl_context * ctx)53 void tst_rtnl_destroy_context(const char *file, const int lineno,
54 struct tst_rtnl_context *ctx)
55 {
56 safe_close(file, lineno, NULL, ctx->socket);
57 free(ctx->buffer);
58 free(ctx);
59 }
60
tst_rtnl_create_context(const char * file,const int lineno)61 struct tst_rtnl_context *tst_rtnl_create_context(const char *file,
62 const int lineno)
63 {
64 struct tst_rtnl_context *ctx;
65 struct sockaddr_nl addr = { .nl_family = AF_NETLINK };
66
67 ctx = safe_malloc(file, lineno, NULL, sizeof(struct tst_rtnl_context));
68
69 if (!ctx)
70 return NULL;
71
72 ctx->pid = 0;
73 ctx->seq = 0;
74 ctx->buffer = NULL;
75 ctx->bufsize = 1024;
76 ctx->datalen = 0;
77 ctx->curmsg = NULL;
78 ctx->socket = safe_socket(file, lineno, NULL, AF_NETLINK,
79 SOCK_DGRAM | SOCK_CLOEXEC, NETLINK_ROUTE);
80
81 if (ctx->socket < 0) {
82 free(ctx);
83 return NULL;
84 }
85
86 if (safe_bind(file, lineno, NULL, ctx->socket, (struct sockaddr *)&addr,
87 sizeof(addr))) {
88 tst_rtnl_destroy_context(file, lineno, ctx);
89 return NULL;
90 }
91
92 ctx->buffer = safe_malloc(file, lineno, NULL, ctx->bufsize);
93
94 if (!ctx->buffer) {
95 tst_rtnl_destroy_context(file, lineno, ctx);
96 return NULL;
97 }
98
99 memset(ctx->buffer, 0, ctx->bufsize);
100
101 return ctx;
102 }
103
tst_rtnl_free_message(struct tst_rtnl_message * msg)104 void tst_rtnl_free_message(struct tst_rtnl_message *msg)
105 {
106 if (!msg)
107 return;
108
109 // all ptr->header and ptr->info pointers point to the same buffer
110 // msg->header is the start of the buffer
111 free(msg->header);
112 free(msg);
113 }
114
tst_rtnl_send(const char * file,const int lineno,struct tst_rtnl_context * ctx)115 int tst_rtnl_send(const char *file, const int lineno,
116 struct tst_rtnl_context *ctx)
117 {
118 int ret;
119 struct sockaddr_nl addr = { .nl_family = AF_NETLINK };
120 struct iovec iov;
121 struct msghdr msg = {
122 .msg_name = &addr,
123 .msg_namelen = sizeof(addr),
124 .msg_iov = &iov,
125 .msg_iovlen = 1
126 };
127
128 if (!ctx->curmsg) {
129 tst_brk_(file, lineno, TBROK, "%s(): No message to send",
130 __func__);
131 return 0;
132 }
133
134 if (ctx->curmsg->nlmsg_flags & NLM_F_MULTI) {
135 struct nlmsghdr eom = { .nlmsg_type = NLMSG_DONE };
136
137 if (!tst_rtnl_add_message(file, lineno, ctx, &eom, NULL, 0))
138 return 0;
139
140 /* NLMSG_DONE message must not have NLM_F_MULTI flag */
141 ctx->curmsg->nlmsg_flags = 0;
142 }
143
144 iov.iov_base = ctx->buffer;
145 iov.iov_len = ctx->datalen;
146 ret = safe_sendmsg(file, lineno, ctx->datalen, ctx->socket, &msg, 0);
147
148 if (ret > 0)
149 ctx->curmsg = NULL;
150
151 return ret;
152 }
153
tst_rtnl_wait(struct tst_rtnl_context * ctx)154 int tst_rtnl_wait(struct tst_rtnl_context *ctx)
155 {
156 struct pollfd fdinfo = {
157 .fd = ctx->socket,
158 .events = POLLIN
159 };
160
161 return poll(&fdinfo, 1, 1000);
162 }
163
tst_rtnl_recv(const char * file,const int lineno,struct tst_rtnl_context * ctx)164 struct tst_rtnl_message *tst_rtnl_recv(const char *file, const int lineno,
165 struct tst_rtnl_context *ctx)
166 {
167 char tmp, *tmpbuf, *buffer = NULL;
168 struct tst_rtnl_message *ret;
169 struct nlmsghdr *ptr;
170 size_t retsize, bufsize = 0;
171 ssize_t size;
172 int i, size_left, msgcount;
173
174 /* Each recv() call returns one message, read all pending messages */
175 while (1) {
176 errno = 0;
177 size = recv(ctx->socket, &tmp, 1,
178 MSG_DONTWAIT | MSG_PEEK | MSG_TRUNC);
179
180 if (size < 0) {
181 if (errno != EAGAIN) {
182 tst_brk_(file, lineno, TBROK | TERRNO,
183 "recv() failed");
184 }
185
186 break;
187 }
188
189 tmpbuf = safe_realloc(file, lineno, buffer, bufsize + size);
190
191 if (!tmpbuf)
192 break;
193
194 buffer = tmpbuf;
195 size = safe_recv(file, lineno, size, ctx->socket,
196 buffer + bufsize, size, 0);
197
198 if (size < 0)
199 break;
200
201 bufsize += size;
202 }
203
204 if (!bufsize) {
205 free(buffer);
206 return NULL;
207 }
208
209 ptr = (struct nlmsghdr *)buffer;
210 size_left = bufsize;
211 msgcount = 0;
212
213 for (; size_left > 0 && NLMSG_OK(ptr, size_left); msgcount++)
214 ptr = NLMSG_NEXT(ptr, size_left);
215
216 retsize = (msgcount + 1) * sizeof(struct tst_rtnl_message);
217 ret = safe_malloc(file, lineno, NULL, retsize);
218
219 if (!ret) {
220 free(buffer);
221 return NULL;
222 }
223
224 memset(ret, 0, retsize);
225 ptr = (struct nlmsghdr *)buffer;
226 size_left = bufsize;
227
228 for (i = 0; i < msgcount; i++, ptr = NLMSG_NEXT(ptr, size_left)) {
229 ret[i].header = ptr;
230 ret[i].payload = NLMSG_DATA(ptr);
231 ret[i].payload_size = NLMSG_PAYLOAD(ptr, 0);
232
233 if (ptr->nlmsg_type == NLMSG_ERROR)
234 ret[i].err = NLMSG_DATA(ptr);
235 }
236
237 return ret;
238 }
239
tst_rtnl_add_message(const char * file,const int lineno,struct tst_rtnl_context * ctx,const struct nlmsghdr * header,const void * payload,size_t payload_size)240 int tst_rtnl_add_message(const char *file, const int lineno,
241 struct tst_rtnl_context *ctx, const struct nlmsghdr *header,
242 const void *payload, size_t payload_size)
243 {
244 size_t size;
245 unsigned int extra_flags = 0;
246
247 if (!tst_rtnl_grow_buffer(file, lineno, ctx, NLMSG_SPACE(payload_size)))
248 return 0;
249
250 if (!ctx->curmsg) {
251 /*
252 * datalen may hold the size of last sent message for ACK
253 * checking, reset it back to 0 here
254 */
255 ctx->datalen = 0;
256 ctx->curmsg = (struct nlmsghdr *)ctx->buffer;
257 } else {
258 size = NLMSG_ALIGN(ctx->curmsg->nlmsg_len);
259
260 extra_flags = NLM_F_MULTI;
261 ctx->curmsg->nlmsg_flags |= extra_flags;
262 ctx->curmsg = NLMSG_NEXT(ctx->curmsg, size);
263 ctx->datalen = NLMSG_ALIGN(ctx->datalen);
264 }
265
266 *ctx->curmsg = *header;
267 ctx->curmsg->nlmsg_len = NLMSG_LENGTH(payload_size);
268 ctx->curmsg->nlmsg_flags |= extra_flags;
269 ctx->curmsg->nlmsg_seq = ctx->seq++;
270 ctx->curmsg->nlmsg_pid = ctx->pid;
271
272 if (payload_size)
273 memcpy(NLMSG_DATA(ctx->curmsg), payload, payload_size);
274
275 ctx->datalen += ctx->curmsg->nlmsg_len;
276
277 return 1;
278 }
279
tst_rtnl_add_attr(const char * file,const int lineno,struct tst_rtnl_context * ctx,unsigned short type,const void * data,unsigned short len)280 int tst_rtnl_add_attr(const char *file, const int lineno,
281 struct tst_rtnl_context *ctx, unsigned short type,
282 const void *data, unsigned short len)
283 {
284 size_t size;
285 struct rtattr *attr;
286
287 if (!ctx->curmsg) {
288 tst_brk_(file, lineno, TBROK,
289 "%s(): No message to add attributes to", __func__);
290 return 0;
291 }
292
293 if (!tst_rtnl_grow_buffer(file, lineno, ctx, RTA_SPACE(len)))
294 return 0;
295
296 size = NLMSG_ALIGN(ctx->curmsg->nlmsg_len);
297 attr = (struct rtattr *)(((char *)ctx->curmsg) + size);
298 attr->rta_type = type;
299 attr->rta_len = RTA_LENGTH(len);
300 memcpy(RTA_DATA(attr), data, len);
301 ctx->curmsg->nlmsg_len = size + attr->rta_len;
302 ctx->datalen = NLMSG_ALIGN(ctx->datalen) + attr->rta_len;
303
304 return 1;
305 }
306
/*
 * Append a NUL-terminated string attribute of @type to the current message.
 * Returns the tst_rtnl_add_attr() result.
 */
int tst_rtnl_add_attr_string(const char *file, const int lineno,
	struct tst_rtnl_context *ctx, unsigned short type,
	const char *data)
{
	/* The terminating NUL byte is part of the attribute payload */
	size_t len = strlen(data) + 1;

	return tst_rtnl_add_attr(file, lineno, ctx, type, data, len);
}
314
tst_rtnl_add_attr_list(const char * file,const int lineno,struct tst_rtnl_context * ctx,const struct tst_rtnl_attr_list * list)315 int tst_rtnl_add_attr_list(const char *file, const int lineno,
316 struct tst_rtnl_context *ctx,
317 const struct tst_rtnl_attr_list *list)
318 {
319 int i, ret;
320 size_t offset;
321
322 for (i = 0; list[i].len >= 0; i++) {
323 if (list[i].len > USHRT_MAX) {
324 tst_brk_(file, lineno, TBROK,
325 "%s(): Attribute value too long", __func__);
326 return -1;
327 }
328
329 offset = NLMSG_ALIGN(ctx->datalen);
330 ret = tst_rtnl_add_attr(file, lineno, ctx, list[i].type,
331 list[i].data, list[i].len);
332
333 if (!ret)
334 return -1;
335
336 if (list[i].sublist) {
337 struct rtattr *attr;
338
339 ret = tst_rtnl_add_attr_list(file, lineno, ctx,
340 list[i].sublist);
341
342 if (ret < 0)
343 return ret;
344
345 attr = (struct rtattr *)(ctx->buffer + offset);
346
347 if (ctx->datalen - offset > USHRT_MAX) {
348 tst_brk_(file, lineno, TBROK,
349 "%s(): Sublist too long", __func__);
350 return -1;
351 }
352
353 attr->rta_len = ctx->datalen - offset;
354 }
355 }
356
357 return i;
358 }
359
tst_rtnl_check_acks(const char * file,const int lineno,struct tst_rtnl_context * ctx,struct tst_rtnl_message * res)360 int tst_rtnl_check_acks(const char *file, const int lineno,
361 struct tst_rtnl_context *ctx, struct tst_rtnl_message *res)
362 {
363 struct nlmsghdr *msg = (struct nlmsghdr *)ctx->buffer;
364 int size_left = ctx->datalen;
365
366 for (; size_left > 0 && NLMSG_OK(msg, size_left);
367 msg = NLMSG_NEXT(msg, size_left)) {
368
369 if (!(msg->nlmsg_flags & NLM_F_ACK))
370 continue;
371
372 while (res->header && res->header->nlmsg_seq != msg->nlmsg_seq)
373 res++;
374
375 if (!res->err || res->header->nlmsg_seq != msg->nlmsg_seq) {
376 tst_brk_(file, lineno, TBROK,
377 "No ACK found for Netlink message %u",
378 msg->nlmsg_seq);
379 return 0;
380 }
381
382 if (res->err->error) {
383 TST_ERR = -res->err->error;
384 return 0;
385 }
386 }
387
388 return 1;
389 }
390
tst_rtnl_send_validate(const char * file,const int lineno,struct tst_rtnl_context * ctx)391 int tst_rtnl_send_validate(const char *file, const int lineno,
392 struct tst_rtnl_context *ctx)
393 {
394 struct tst_rtnl_message *response;
395 int ret;
396
397 TST_ERR = 0;
398
399 if (tst_rtnl_send(file, lineno, ctx) <= 0)
400 return 0;
401
402 tst_rtnl_wait(ctx);
403 response = tst_rtnl_recv(file, lineno, ctx);
404
405 if (!response)
406 return 0;
407
408 ret = tst_rtnl_check_acks(file, lineno, ctx, response);
409 tst_rtnl_free_message(response);
410
411 return ret;
412 }
413