1 /*
2 * inet_diag.c Module for monitoring INET transport protocols sockets.
3 *
4 * Authors: Alexey Kuznetsov, <kuznet@ms2.inr.ac.ru>
5 *
6 * This program is free software; you can redistribute it and/or
7 * modify it under the terms of the GNU General Public License
8 * as published by the Free Software Foundation; either version
9 * 2 of the License, or (at your option) any later version.
10 */
11
12 #include <linux/kernel.h>
13 #include <linux/module.h>
14 #include <linux/types.h>
15 #include <linux/fcntl.h>
16 #include <linux/random.h>
17 #include <linux/slab.h>
18 #include <linux/cache.h>
19 #include <linux/init.h>
20 #include <linux/time.h>
21
22 #include <net/icmp.h>
23 #include <net/tcp.h>
24 #include <net/ipv6.h>
25 #include <net/inet_common.h>
26 #include <net/inet_connection_sock.h>
27 #include <net/inet_hashtables.h>
28 #include <net/inet_timewait_sock.h>
29 #include <net/inet6_hashtables.h>
30 #include <net/netlink.h>
31
32 #include <linux/inet.h>
33 #include <linux/stddef.h>
34
35 #include <linux/inet_diag.h>
36 #include <linux/sock_diag.h>
37
/* Registered per-protocol diag handlers, indexed by IPPROTO_* value.
 * Entries are installed and looked up under inet_diag_table_mutex.
 */
static const struct inet_diag_handler **inet_diag_table;

/* Flat description of one socket, built by the dump paths and matched
 * against user-supplied filter bytecode by inet_diag_bc_run().
 */
struct inet_diag_entry {
	__be32 *saddr;	/* points at 1 (IPv4) or 4 (IPv6) address words */
	__be32 *daddr;
	u16 sport;	/* host byte order */
	u16 dport;	/* host byte order */
	u16 family;
	u16 userlocks;
#if IS_ENABLED(CONFIG_IPV6)
	struct in6_addr saddr_storage;	/* for IPv4-mapped-IPv6 addresses */
	struct in6_addr daddr_storage;	/* for IPv4-mapped-IPv6 addresses */
#endif
	u32 ifindex;
	u32 mark;
};
54
55 static DEFINE_MUTEX(inet_diag_table_mutex);
56
/* Look up the diag handler for @proto, requesting its module first if it
 * is not yet registered.
 *
 * Returns the handler or ERR_PTR(-ENOENT).  In BOTH cases
 * inet_diag_table_mutex is left held; the caller must always pair this
 * with inet_diag_unlock_handler().
 */
static const struct inet_diag_handler *inet_diag_lock_handler(int proto)
{
	if (!inet_diag_table[proto])
		request_module("net-pf-%d-proto-%d-type-%d-%d", PF_NETLINK,
			       NETLINK_SOCK_DIAG, AF_INET, proto);

	mutex_lock(&inet_diag_table_mutex);
	if (!inet_diag_table[proto])
		return ERR_PTR(-ENOENT);

	return inet_diag_table[proto];
}
69
/* Release the mutex taken by inet_diag_lock_handler().  @handler may be
 * an ERR_PTR; it is never dereferenced here.
 */
static inline void inet_diag_unlock_handler(
		const struct inet_diag_handler *handler)
{
	mutex_unlock(&inet_diag_table_mutex);
}
75
/* Build one inet_diag netlink message describing full socket @sk.
 *
 * @icsk:	@sk's inet_connection_sock part, or NULL for sockets with
 *		no connection state (then only idiag_get_info() runs and
 *		the timer/cong/info attributes are skipped).
 * @req:	request being answered; req->idiag_ext selects optional
 *		attributes (TOS, meminfo, tcp_info, congestion ops, ...).
 * @user_ns:	namespace used to translate the owner uid.
 * @net_admin:	whether the requester may see the socket mark.
 *
 * Returns the nlmsg_end() result on success or -EMSGSIZE when @skb ran
 * out of room (the partial message is cancelled).
 */
int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk,
		struct sk_buff *skb, struct inet_diag_req_v2 *req,
		struct user_namespace *user_ns,
		u32 portid, u32 seq, u16 nlmsg_flags,
		const struct nlmsghdr *unlh, bool net_admin)
{
	const struct inet_sock *inet = inet_sk(sk);
	struct inet_diag_msg *r;
	struct nlmsghdr *nlh;
	struct nlattr *attr;
	void *info = NULL;
	const struct inet_diag_handler *handler;
	int ext = req->idiag_ext;

	handler = inet_diag_table[req->sdiag_protocol];
	BUG_ON(handler == NULL);

	nlh = nlmsg_put(skb, portid, seq, unlh->nlmsg_type, sizeof(*r),
			nlmsg_flags);
	if (!nlh)
		return -EMSGSIZE;

	r = nlmsg_data(nlh);
	/* Timewait sockets must go through inet_twsk_diag_fill() instead. */
	BUG_ON(sk->sk_state == TCP_TIME_WAIT);

	r->idiag_family = sk->sk_family;
	r->idiag_state = sk->sk_state;
	r->idiag_timer = 0;
	r->idiag_retrans = 0;

	r->id.idiag_if = sk->sk_bound_dev_if;
	sock_diag_save_cookie(sk, r->id.idiag_cookie);

	r->id.idiag_sport = inet->inet_sport;
	r->id.idiag_dport = inet->inet_dport;

	/* Clear both 128-bit address slots; IPv4 fills only word 0. */
	memset(&r->id.idiag_src, 0, sizeof(r->id.idiag_src));
	memset(&r->id.idiag_dst, 0, sizeof(r->id.idiag_dst));

	r->id.idiag_src[0] = inet->inet_rcv_saddr;
	r->id.idiag_dst[0] = inet->inet_daddr;

	if (nla_put_u8(skb, INET_DIAG_SHUTDOWN, sk->sk_shutdown))
		goto errout;

	/* IPv6 dual-stack sockets use inet->tos for IPv4 connections,
	 * hence this needs to be included regardless of socket family.
	 */
	if (ext & (1 << (INET_DIAG_TOS - 1)))
		if (nla_put_u8(skb, INET_DIAG_TOS, inet->tos) < 0)
			goto errout;

#if IS_ENABLED(CONFIG_IPV6)
	if (r->idiag_family == AF_INET6) {
		const struct ipv6_pinfo *np = inet6_sk(sk);

		/* Overwrite the v4 words with the full 128-bit addresses. */
		*(struct in6_addr *)r->id.idiag_src = np->rcv_saddr;
		*(struct in6_addr *)r->id.idiag_dst = np->daddr;

		if (ext & (1 << (INET_DIAG_TCLASS - 1)))
			if (nla_put_u8(skb, INET_DIAG_TCLASS, np->tclass) < 0)
				goto errout;
	}
#endif

	/* The socket mark is only disclosed to CAP_NET_ADMIN requesters. */
	if (net_admin && nla_put_u32(skb, INET_DIAG_MARK, sk->sk_mark))
		goto errout;

	r->idiag_uid = from_kuid_munged(user_ns, sock_i_uid(sk));
	r->idiag_inode = sock_i_ino(sk);

	if (ext & (1 << (INET_DIAG_MEMINFO - 1))) {
		struct inet_diag_meminfo minfo = {
			.idiag_rmem = sk_rmem_alloc_get(sk),
			.idiag_wmem = sk->sk_wmem_queued,
			.idiag_fmem = sk->sk_forward_alloc,
			.idiag_tmem = sk_wmem_alloc_get(sk),
		};

		if (nla_put(skb, INET_DIAG_MEMINFO, sizeof(minfo), &minfo) < 0)
			goto errout;
	}

	if (ext & (1 << (INET_DIAG_SKMEMINFO - 1)))
		if (sock_diag_put_meminfo(sk, skb, INET_DIAG_SKMEMINFO))
			goto errout;

	/* No connection state: let the protocol report what it can, done. */
	if (icsk == NULL) {
		handler->idiag_get_info(sk, r, NULL);
		goto out;
	}

	/* Milliseconds until the jiffies deadline @tmo, rounded up. */
#define EXPIRES_IN_MS(tmo)  DIV_ROUND_UP((tmo - jiffies) * 1000, HZ)

	/* idiag_timer codes used here: 1 = retransmit/loss-probe timer,
	 * 2 = sk_timer (keepalive), 4 = zero-window probe; code 3
	 * (timewait) is emitted by inet_twsk_diag_fill().
	 */
	if (icsk->icsk_pending == ICSK_TIME_RETRANS ||
	    icsk->icsk_pending == ICSK_TIME_EARLY_RETRANS ||
	    icsk->icsk_pending == ICSK_TIME_LOSS_PROBE) {
		r->idiag_timer = 1;
		r->idiag_retrans = icsk->icsk_retransmits;
		r->idiag_expires = EXPIRES_IN_MS(icsk->icsk_timeout);
	} else if (icsk->icsk_pending == ICSK_TIME_PROBE0) {
		r->idiag_timer = 4;
		r->idiag_retrans = icsk->icsk_probes_out;
		r->idiag_expires = EXPIRES_IN_MS(icsk->icsk_timeout);
	} else if (timer_pending(&sk->sk_timer)) {
		r->idiag_timer = 2;
		r->idiag_retrans = icsk->icsk_probes_out;
		r->idiag_expires = EXPIRES_IN_MS(sk->sk_timer.expires);
	} else {
		r->idiag_timer = 0;
		r->idiag_expires = 0;
	}
#undef EXPIRES_IN_MS

	/* Reserve tcp_info space now; the handler fills it in below. */
	if (ext & (1 << (INET_DIAG_INFO - 1))) {
		attr = nla_reserve(skb, INET_DIAG_INFO,
				   sizeof(struct tcp_info));
		if (!attr)
			goto errout;

		info = nla_data(attr);
	}

	if ((ext & (1 << (INET_DIAG_CONG - 1))) && icsk->icsk_ca_ops)
		if (nla_put_string(skb, INET_DIAG_CONG,
				   icsk->icsk_ca_ops->name) < 0)
			goto errout;

	handler->idiag_get_info(sk, r, info);

	if (sk->sk_state < TCP_TIME_WAIT &&
	    icsk->icsk_ca_ops && icsk->icsk_ca_ops->get_info)
		icsk->icsk_ca_ops->get_info(sk, ext, skb);

out:
	return nlmsg_end(skb, nlh);

errout:
	nlmsg_cancel(skb, nlh);
	return -EMSGSIZE;
}
217 EXPORT_SYMBOL_GPL(inet_sk_diag_fill);
218
/* Report a full (non-timewait) connection socket by delegating to
 * inet_sk_diag_fill() with @sk's inet_connection_sock part.
 */
static int inet_csk_diag_fill(struct sock *sk,
			      struct sk_buff *skb, struct inet_diag_req_v2 *req,
			      struct user_namespace *user_ns,
			      u32 portid, u32 seq, u16 nlmsg_flags,
			      const struct nlmsghdr *unlh,
			      bool net_admin)
{
	struct inet_connection_sock *icsk = inet_csk(sk);

	return inet_sk_diag_fill(sk, icsk, skb, req, user_ns, portid, seq,
				 nlmsg_flags, unlh, net_admin);
}
229
/* Build one inet_diag message for timewait mini-socket @tw.
 *
 * Timewait sockets keep only the address/port tuple and a death timer,
 * so uid, inode and queue sizes are reported as zero.  Returns the
 * nlmsg_end() result or -EMSGSIZE when @skb is full.
 */
static int inet_twsk_diag_fill(struct inet_timewait_sock *tw,
		struct sk_buff *skb, struct inet_diag_req_v2 *req,
		u32 portid, u32 seq, u16 nlmsg_flags,
		const struct nlmsghdr *unlh)
{
	long tmo;
	struct inet_diag_msg *r;
	struct nlmsghdr *nlh;

	nlh = nlmsg_put(skb, portid, seq, unlh->nlmsg_type, sizeof(*r),
			nlmsg_flags);
	if (!nlh)
		return -EMSGSIZE;

	r = nlmsg_data(nlh);
	BUG_ON(tw->tw_state != TCP_TIME_WAIT);

	/* Remaining lifetime of the timewait socket, clamped to zero. */
	tmo = tw->tw_ttd - jiffies;
	if (tmo < 0)
		tmo = 0;

	r->idiag_family = tw->tw_family;
	r->idiag_retrans = 0;

	r->id.idiag_if = tw->tw_bound_dev_if;
	sock_diag_save_cookie(tw, r->id.idiag_cookie);

	r->id.idiag_sport = tw->tw_sport;
	r->id.idiag_dport = tw->tw_dport;

	/* Clear both 128-bit address slots; IPv4 fills only word 0. */
	memset(&r->id.idiag_src, 0, sizeof(r->id.idiag_src));
	memset(&r->id.idiag_dst, 0, sizeof(r->id.idiag_dst));

	r->id.idiag_src[0] = tw->tw_rcv_saddr;
	r->id.idiag_dst[0] = tw->tw_daddr;

	/* Report tw_substate rather than TCP_TIME_WAIT itself. */
	r->idiag_state = tw->tw_substate;
	r->idiag_timer = 3;	/* 3 == timewait timer */
	r->idiag_expires = DIV_ROUND_UP(tmo * 1000, HZ);
	r->idiag_rqueue = 0;
	r->idiag_wqueue = 0;
	r->idiag_uid = 0;
	r->idiag_inode = 0;
#if IS_ENABLED(CONFIG_IPV6)
	if (tw->tw_family == AF_INET6) {
		const struct inet6_timewait_sock *tw6 =
						inet6_twsk((struct sock *)tw);

		*(struct in6_addr *)r->id.idiag_src = tw6->tw_v6_rcv_saddr;
		*(struct in6_addr *)r->id.idiag_dst = tw6->tw_v6_daddr;
	}
#endif

	return nlmsg_end(skb, nlh);
}
285
/* Dispatch to the right fill routine for @sk: timewait mini-sockets get
 * inet_twsk_diag_fill(), everything else inet_csk_diag_fill().
 */
static int sk_diag_fill(struct sock *sk, struct sk_buff *skb,
			struct inet_diag_req_v2 *r,
			struct user_namespace *user_ns,
			u32 portid, u32 seq, u16 nlmsg_flags,
			const struct nlmsghdr *unlh, bool net_admin)
{
	if (sk->sk_state != TCP_TIME_WAIT)
		return inet_csk_diag_fill(sk, skb, r, user_ns, portid, seq,
					  nlmsg_flags, unlh, net_admin);

	return inet_twsk_diag_fill((struct inet_timewait_sock *)sk, skb, r,
				   portid, seq, nlmsg_flags, unlh);
}
299
/* Look up the single socket in @net/@hashinfo that matches the
 * (family, addresses, ports, ifindex) tuple in @req.
 *
 * AF_INET6 requests whose addresses are both IPv4-mapped are redirected
 * to the IPv4 hash (word 3 of each 4-word address holds the v4 value).
 *
 * On success the socket is returned with a reference held (a timewait
 * reference for TIME_WAIT sockets); the caller must drop it.  Returns
 * ERR_PTR(-EINVAL) for an unsupported family and ERR_PTR(-ENOENT) when
 * nothing matches or the caller's cookie does not match the socket.
 */
struct sock *inet_diag_find_one_icsk(struct net *net,
				     struct inet_hashinfo *hashinfo,
				     struct inet_diag_req_v2 *req)
{
	struct sock *sk;

	if (req->sdiag_family == AF_INET) {
		sk = inet_lookup(net, hashinfo, req->id.idiag_dst[0],
				 req->id.idiag_dport, req->id.idiag_src[0],
				 req->id.idiag_sport, req->id.idiag_if);
	}
#if IS_ENABLED(CONFIG_IPV6)
	else if (req->sdiag_family == AF_INET6) {
		if (ipv6_addr_v4mapped((struct in6_addr *)req->id.idiag_dst) &&
		    ipv6_addr_v4mapped((struct in6_addr *)req->id.idiag_src))
			sk = inet_lookup(net, hashinfo, req->id.idiag_dst[3],
					 req->id.idiag_dport, req->id.idiag_src[3],
					 req->id.idiag_sport, req->id.idiag_if);
		else
			sk = inet6_lookup(net, hashinfo,
					  (struct in6_addr *)req->id.idiag_dst,
					  req->id.idiag_dport,
					  (struct in6_addr *)req->id.idiag_src,
					  req->id.idiag_sport,
					  req->id.idiag_if);
	}
#endif
	else {
		return ERR_PTR(-EINVAL);
	}

	if (!sk)
		return ERR_PTR(-ENOENT);

	/* Reject the socket if the caller-supplied cookie doesn't match. */
	if (sock_diag_check_cookie(sk, req->id.idiag_cookie)) {
		/* NOTE: forward-ports should use sock_gen_put(sk) instead. */
		if (sk->sk_state == TCP_TIME_WAIT)
			inet_twsk_put((struct inet_timewait_sock *)sk);
		else
			sock_put(sk);
		return ERR_PTR(-ENOENT);
	}

	return sk;
}
345 EXPORT_SYMBOL_GPL(inet_diag_find_one_icsk);
346
/* Handle an exact-socket request: find the socket, build one diag
 * message for it and unicast the reply to the requester.
 *
 * The reply buffer is pre-sized for inet_diag_msg + meminfo + tcp_info
 * plus slack, so an -EMSGSIZE from sk_diag_fill() would indicate a
 * sizing bug (hence the WARN_ON).  Returns 0 or a negative errno.
 */
int inet_diag_dump_one_icsk(struct inet_hashinfo *hashinfo,
			    struct sk_buff *in_skb,
			    const struct nlmsghdr *nlh,
			    struct inet_diag_req_v2 *req)
{
	struct net *net = sock_net(in_skb->sk);
	struct sk_buff *rep;
	struct sock *sk;
	int err;

	sk = inet_diag_find_one_icsk(net, hashinfo, req);
	if (IS_ERR(sk))
		return PTR_ERR(sk);

	rep = nlmsg_new(sizeof(struct inet_diag_msg) +
			sizeof(struct inet_diag_meminfo) +
			sizeof(struct tcp_info) + 64, GFP_KERNEL);
	if (!rep) {
		err = -ENOMEM;
		goto out;
	}

	err = sk_diag_fill(sk, rep, req,
			   sk_user_ns(NETLINK_CB(in_skb).sk),
			   NETLINK_CB(in_skb).portid,
			   nlh->nlmsg_seq, 0, nlh,
			   ns_capable(sock_net(in_skb->sk)->user_ns,
				      CAP_NET_ADMIN));
	if (err < 0) {
		WARN_ON(err == -EMSGSIZE);
		nlmsg_free(rep);
		goto out;
	}
	/* netlink_unicast() returns bytes sent on success. */
	err = netlink_unicast(net->diag_nlsk, rep, NETLINK_CB(in_skb).portid,
			      MSG_DONTWAIT);
	if (err > 0)
		err = 0;

out:
	/* Drop the reference taken by inet_diag_find_one_icsk(). */
	if (sk) {
		if (sk->sk_state == TCP_TIME_WAIT)
			inet_twsk_put((struct inet_timewait_sock *)sk);
		else
			sock_put(sk);
	}
	return err;
}
394 EXPORT_SYMBOL_GPL(inet_diag_dump_one_icsk);
395
inet_diag_cmd_exact(int cmd,struct sk_buff * in_skb,const struct nlmsghdr * nlh,struct inet_diag_req_v2 * req)396 static int inet_diag_cmd_exact(int cmd, struct sk_buff *in_skb,
397 const struct nlmsghdr *nlh,
398 struct inet_diag_req_v2 *req)
399 {
400 const struct inet_diag_handler *handler;
401 int err;
402
403 handler = inet_diag_lock_handler(req->sdiag_protocol);
404 if (IS_ERR(handler))
405 err = PTR_ERR(handler);
406 else if (cmd == SOCK_DIAG_BY_FAMILY)
407 err = handler->dump_one(in_skb, nlh, req);
408 else if (cmd == SOCK_DESTROY_BACKPORT && handler->destroy)
409 err = handler->destroy(in_skb, req);
410 else
411 err = -EOPNOTSUPP;
412 inet_diag_unlock_handler(handler);
413
414 return err;
415 }
416
/* Compare the first @bits bits of big-endian words @a1 and @a2.
 * Returns 1 when the prefixes are equal, 0 otherwise.
 */
static int bitstring_match(const __be32 *a1, const __be32 *a2, int bits)
{
	int words = bits >> 5;	/* whole 32-bit words to compare */

	bits &= 0x1f;		/* leftover bits in the final word */

	if (words && memcmp(a1, a2, words << 2))
		return 0;

	if (bits) {
		__be32 mask = htonl(0xffffffff << (32 - bits));

		if ((a1[words] ^ a2[words]) & mask)
			return 0;
	}

	return 1;
}
442
443
/* Execute filter bytecode @_bc against socket description @entry.
 *
 * Every instruction carries two forward jump offsets: "yes" (condition
 * holds) and "no" (it doesn't).  Execution accepts the socket when it
 * walks exactly off the end of the program (len reaches 0) and rejects
 * it when a jump overshoots (len goes negative).
 *
 * Returns nonzero to accept @entry, 0 to reject.
 */
static int inet_diag_bc_run(const struct nlattr *_bc,
			    const struct inet_diag_entry *entry)
{
	const void *bc = nla_data(_bc);
	int len = nla_len(_bc);

	while (len > 0) {
		int yes = 1;
		const struct inet_diag_bc_op *op = bc;

		switch (op->code) {
		case INET_DIAG_BC_NOP:
			break;
		case INET_DIAG_BC_JMP:
			yes = 0;	/* unconditional: take the "no" link */
			break;
		case INET_DIAG_BC_S_GE:
			/* Port operands live in the next op's "no" field. */
			yes = entry->sport >= op[1].no;
			break;
		case INET_DIAG_BC_S_LE:
			yes = entry->sport <= op[1].no;
			break;
		case INET_DIAG_BC_D_GE:
			yes = entry->dport >= op[1].no;
			break;
		case INET_DIAG_BC_D_LE:
			yes = entry->dport <= op[1].no;
			break;
		case INET_DIAG_BC_AUTO:
			/* Match sockets that did not bind an explicit port. */
			yes = !(entry->userlocks & SOCK_BINDPORT_LOCK);
			break;
		case INET_DIAG_BC_S_COND:
		case INET_DIAG_BC_D_COND: {
			struct inet_diag_hostcond *cond;
			__be32 *addr;

			cond = (struct inet_diag_hostcond *)(op + 1);
			/* A port of -1 is a wildcard. */
			if (cond->port != -1 &&
			    cond->port != (op->code == INET_DIAG_BC_S_COND ?
					     entry->sport : entry->dport)) {
				yes = 0;
				break;
			}

			if (op->code == INET_DIAG_BC_S_COND)
				addr = entry->saddr;
			else
				addr = entry->daddr;

			if (cond->family != AF_UNSPEC &&
			    cond->family != entry->family) {
				/* An AF_INET condition can still match an
				 * IPv4-mapped address of an IPv6 socket. */
				if (entry->family == AF_INET6 &&
				    cond->family == AF_INET) {
					if (addr[0] == 0 && addr[1] == 0 &&
					    addr[2] == htonl(0xffff) &&
					    bitstring_match(addr + 3,
							    cond->addr,
							    cond->prefix_len))
						break;
				}
				yes = 0;
				break;
			}

			/* A zero prefix length matches any address. */
			if (cond->prefix_len == 0)
				break;
			if (bitstring_match(addr, cond->addr,
					    cond->prefix_len))
				break;
			yes = 0;
			break;
		}
		case INET_DIAG_BC_DEV_COND: {
			u32 ifindex;

			/* The ifindex operand immediately follows the op. */
			ifindex = *((const u32 *)(op + 1));
			if (ifindex != entry->ifindex)
				yes = 0;
			break;
		}
		case INET_DIAG_BC_MARK_COND: {
			struct inet_diag_markcond *cond;

			cond = (struct inet_diag_markcond *)(op + 1);
			if ((entry->mark & cond->mask) != cond->mark)
				yes = 0;
			break;
		}
		}

		if (yes) {
			len -= op->yes;
			bc += op->yes;
		} else {
			len -= op->no;
			bc += op->no;
		}
	}
	/* Accept only when we ran exactly off the end of the program. */
	return len == 0;
}
544
/* Describe @sk in an inet_diag_entry and run filter bytecode @bc on it.
 * A NULL @bc accepts every socket.  Returns nonzero to accept @sk.
 */
int inet_diag_bc_sk(const struct nlattr *bc, struct sock *sk)
{
	struct inet_diag_entry entry;
	struct inet_sock *inet = inet_sk(sk);

	if (bc == NULL)
		return 1;

	entry.family = sk->sk_family;
#if IS_ENABLED(CONFIG_IPV6)
	if (entry.family == AF_INET6) {
		struct ipv6_pinfo *np = inet6_sk(sk);

		entry.saddr = np->rcv_saddr.s6_addr32;
		entry.daddr = np->daddr.s6_addr32;
	} else
#endif
	{
		entry.saddr = &inet->inet_rcv_saddr;
		entry.daddr = &inet->inet_daddr;
	}
	entry.sport = inet->inet_num;		/* host byte order */
	entry.dport = ntohs(inet->inet_dport);	/* host byte order */
	entry.ifindex = sk->sk_bound_dev_if;
	entry.userlocks = sk->sk_userlocks;
	entry.mark = sk->sk_mark;

	return inet_diag_bc_run(bc, &entry);
}
574 EXPORT_SYMBOL_GPL(inet_diag_bc_sk);
575
/* Check that a jump target with @cc bytes remaining after it lies on an
 * instruction boundary of program @bc/@len, by walking the chain of
 * "yes" links from the start.  Returns 1 when @cc coincides with the
 * remaining length at some reachable op (or the program end), else 0.
 */
static int valid_cc(const void *bc, int len, int cc)
{
	while (len >= 0) {
		const struct inet_diag_bc_op *op = bc;

		if (cc > len)
			return 0;	/* target before any boundary */
		if (cc == len)
			return 1;	/* landed exactly on this op */
		if (op->yes < 4 || op->yes & 3)
			return 0;	/* malformed step size */
		len -= op->yes;
		bc += op->yes;
	}
	return 0;
}
592
593 /* data is u32 ifindex */
valid_devcond(const struct inet_diag_bc_op * op,int len,int * min_len)594 static bool valid_devcond(const struct inet_diag_bc_op *op, int len,
595 int *min_len)
596 {
597 /* Check ifindex space. */
598 *min_len += sizeof(u32);
599 if (len < *min_len)
600 return false;
601
602 return true;
603 }
604 /* Validate an inet_diag_hostcond. */
static bool valid_hostcond(const struct inet_diag_bc_op *op, int len,
			   int *min_len)
{
	int addr_len;
	struct inet_diag_hostcond *cond;

	/* Check hostcond space. */
	*min_len += sizeof(struct inet_diag_hostcond);
	if (len < *min_len)
		return false;
	cond = (struct inet_diag_hostcond *)(op + 1);

	/* Check address family and address length. */
	switch (cond->family) {
	case AF_UNSPEC:
		addr_len = 0;	/* wildcard: no address payload */
		break;
	case AF_INET:
		addr_len = sizeof(struct in_addr);
		break;
	case AF_INET6:
		addr_len = sizeof(struct in6_addr);
		break;
	default:
		return false;
	}
	*min_len += addr_len;
	if (len < *min_len)
		return false;

	/* Check prefix length (in bits) vs address length (in bytes). */
	if (cond->prefix_len > 8 * addr_len)
		return false;

	return true;
}
641
642 /* Validate a port comparison operator. */
valid_port_comparison(const struct inet_diag_bc_op * op,int len,int * min_len)643 static inline bool valid_port_comparison(const struct inet_diag_bc_op *op,
644 int len, int *min_len)
645 {
646 /* Port comparisons put the port in a follow-on inet_diag_bc_op. */
647 *min_len += sizeof(struct inet_diag_bc_op);
648 if (len < *min_len)
649 return false;
650 return true;
651 }
652
valid_markcond(const struct inet_diag_bc_op * op,int len,int * min_len)653 static bool valid_markcond(const struct inet_diag_bc_op *op, int len,
654 int *min_len)
655 {
656 *min_len += sizeof(struct inet_diag_markcond);
657 return len >= *min_len;
658 }
659
/* Validate user-supplied filter bytecode before it may be executed by
 * inet_diag_bc_run().
 *
 * Walks the program along each op's "yes" link, checking that:
 *  - each op's payload (host/dev/port/mark condition) fits within the
 *    remaining program (the valid_*() helpers grow min_len),
 *  - both jump offsets are >= the op's minimum size, 4-byte aligned and
 *    at most len + 4 (jumping to exactly len + 4 drives the running
 *    program's len negative, i.e. an explicit reject),
 *  - every "no" target lands on an instruction boundary (valid_cc()).
 *
 * INET_DIAG_BC_MARK_COND is restricted to CAP_NET_ADMIN requesters.
 * Returns 0 when safe to run, -EINVAL when malformed, -EPERM when not
 * permitted.
 */
static int inet_diag_bc_audit(const struct nlattr *attr,
			      const struct sk_buff *skb)
{
	bool net_admin = ns_capable(sock_net(skb->sk)->user_ns, CAP_NET_ADMIN);
	const void *bytecode, *bc;
	int bytecode_len, len;

	if (!attr || nla_len(attr) < sizeof(struct inet_diag_bc_op))
		return -EINVAL;

	bytecode = bc = nla_data(attr);
	len = bytecode_len = nla_len(attr);

	while (len > 0) {
		const struct inet_diag_bc_op *op = bc;
		int min_len = sizeof(struct inet_diag_bc_op);

		switch (op->code) {
		case INET_DIAG_BC_S_COND:
		case INET_DIAG_BC_D_COND:
			if (!valid_hostcond(bc, len, &min_len))
				return -EINVAL;
			break;
		case INET_DIAG_BC_DEV_COND:
			if (!valid_devcond(bc, len, &min_len))
				return -EINVAL;
			break;
		case INET_DIAG_BC_S_GE:
		case INET_DIAG_BC_S_LE:
		case INET_DIAG_BC_D_GE:
		case INET_DIAG_BC_D_LE:
			if (!valid_port_comparison(bc, len, &min_len))
				return -EINVAL;
			break;
		case INET_DIAG_BC_MARK_COND:
			if (!net_admin)
				return -EPERM;
			if (!valid_markcond(bc, len, &min_len))
				return -EINVAL;
			break;
		case INET_DIAG_BC_AUTO:
		case INET_DIAG_BC_JMP:
		case INET_DIAG_BC_NOP:
			break;
		default:
			return -EINVAL;
		}

		if (op->code != INET_DIAG_BC_NOP) {
			if (op->no < min_len || op->no > len + 4 || op->no & 3)
				return -EINVAL;
			if (op->no < len &&
			    !valid_cc(bytecode, bytecode_len, len - op->no))
				return -EINVAL;
		}

		if (op->yes < min_len || op->yes > len + 4 || op->yes & 3)
			return -EINVAL;
		bc += op->yes;
		len -= op->yes;
	}
	return len == 0 ? 0 : -EINVAL;
}
724
/* Dump one full socket during an iteration, after running the optional
 * filter bytecode @bc against it.  Returns 0 when the socket is
 * filtered out, otherwise the inet_csk_diag_fill() result.
 */
static int inet_csk_diag_dump(struct sock *sk,
			      struct sk_buff *skb,
			      struct netlink_callback *cb,
			      struct inet_diag_req_v2 *r,
			      const struct nlattr *bc,
			      bool net_admin)
{
	struct user_namespace *user_ns = sk_user_ns(NETLINK_CB(cb->skb).sk);

	if (!inet_diag_bc_sk(bc, sk))
		return 0;

	return inet_csk_diag_fill(sk, skb, r, user_ns,
				  NETLINK_CB(cb->skb).portid,
				  cb->nlh->nlmsg_seq, NLM_F_MULTI, cb->nlh,
				  net_admin);
}
741
inet_twsk_diag_dump(struct inet_timewait_sock * tw,struct sk_buff * skb,struct netlink_callback * cb,struct inet_diag_req_v2 * r,const struct nlattr * bc)742 static int inet_twsk_diag_dump(struct inet_timewait_sock *tw,
743 struct sk_buff *skb,
744 struct netlink_callback *cb,
745 struct inet_diag_req_v2 *r,
746 const struct nlattr *bc)
747 {
748 if (bc != NULL) {
749 struct inet_diag_entry entry;
750
751 entry.family = tw->tw_family;
752 #if IS_ENABLED(CONFIG_IPV6)
753 if (tw->tw_family == AF_INET6) {
754 struct inet6_timewait_sock *tw6 =
755 inet6_twsk((struct sock *)tw);
756 entry.saddr = tw6->tw_v6_rcv_saddr.s6_addr32;
757 entry.daddr = tw6->tw_v6_daddr.s6_addr32;
758 } else
759 #endif
760 {
761 entry.saddr = &tw->tw_rcv_saddr;
762 entry.daddr = &tw->tw_daddr;
763 }
764 entry.sport = tw->tw_num;
765 entry.dport = ntohs(tw->tw_dport);
766 entry.userlocks = 0;
767 entry.mark = 0;
768
769 if (!inet_diag_bc_run(bc, &entry))
770 return 0;
771 }
772
773 return inet_twsk_diag_fill(tw, skb, r,
774 NETLINK_CB(cb->skb).portid,
775 cb->nlh->nlmsg_seq, NLM_F_MULTI, cb->nlh);
776 }
777
/* Get the IPv4, IPv6, or IPv4-mapped-IPv6 local and remote addresses
 * from a request_sock into @entry->saddr/@entry->daddr.  For an IPv4
 * request handled by an IPv6 listener, the v4 addresses are written as
 * IPv4-mapped-IPv6 into the entry's own in6_addr storage so they can be
 * compared like real IPv6 addresses.
 */
static inline void inet_diag_req_addrs(const struct sock *sk,
				       const struct request_sock *req,
				       struct inet_diag_entry *entry)
{
	struct inet_request_sock *ireq = inet_rsk(req);

#if IS_ENABLED(CONFIG_IPV6)
	if (sk->sk_family == AF_INET6) {
		if (req->rsk_ops->family == AF_INET6) {
			entry->saddr = inet6_rsk(req)->loc_addr.s6_addr32;
			entry->daddr = inet6_rsk(req)->rmt_addr.s6_addr32;
		} else if (req->rsk_ops->family == AF_INET) {
			ipv6_addr_set_v4mapped(ireq->loc_addr,
					       &entry->saddr_storage);
			ipv6_addr_set_v4mapped(ireq->rmt_addr,
					       &entry->daddr_storage);
			entry->saddr = entry->saddr_storage.s6_addr32;
			entry->daddr = entry->daddr_storage.s6_addr32;
		}
	} else
#endif
	{
		entry->saddr = &ireq->loc_addr;
		entry->daddr = &ireq->rmt_addr;
	}
}
807
/* Emit one diag message for request (SYN_RECV) socket @req belonging to
 * listener @sk.  Reports state TCP_SYN_RECV with timer code 1
 * (retransmit) and the time until the request expires.  Returns the
 * nlmsg_end() result or -EMSGSIZE.
 */
static int inet_diag_fill_req(struct sk_buff *skb, struct sock *sk,
			      struct request_sock *req,
			      struct user_namespace *user_ns,
			      u32 portid, u32 seq,
			      const struct nlmsghdr *unlh)
{
	const struct inet_request_sock *ireq = inet_rsk(req);
	struct inet_sock *inet = inet_sk(sk);
	struct inet_diag_msg *r;
	struct nlmsghdr *nlh;
	long tmo;

	nlh = nlmsg_put(skb, portid, seq, unlh->nlmsg_type, sizeof(*r),
			NLM_F_MULTI);
	if (!nlh)
		return -EMSGSIZE;

	r = nlmsg_data(nlh);
	r->idiag_family = sk->sk_family;
	r->idiag_state = TCP_SYN_RECV;
	r->idiag_timer = 1;	/* retransmit timer */
	r->idiag_retrans = req->num_retrans;

	r->id.idiag_if = sk->sk_bound_dev_if;
	sock_diag_save_cookie(req, r->id.idiag_cookie);

	/* Time until the request is dropped, clamped to zero. */
	tmo = req->expires - jiffies;
	if (tmo < 0)
		tmo = 0;

	r->id.idiag_sport = inet->inet_sport;
	r->id.idiag_dport = ireq->rmt_port;

	/* Clear both 128-bit address slots; IPv4 fills only word 0. */
	memset(&r->id.idiag_src, 0, sizeof(r->id.idiag_src));
	memset(&r->id.idiag_dst, 0, sizeof(r->id.idiag_dst));

	r->id.idiag_src[0] = ireq->loc_addr;
	r->id.idiag_dst[0] = ireq->rmt_addr;
	r->idiag_expires = jiffies_to_msecs(tmo);
	r->idiag_rqueue = 0;
	r->idiag_wqueue = 0;
	r->idiag_uid = from_kuid_munged(user_ns, sock_i_uid(sk));
	r->idiag_inode = 0;
#if IS_ENABLED(CONFIG_IPV6)
	if (r->idiag_family == AF_INET6) {
		struct inet_diag_entry entry;

		/* Overwrite with the full (possibly mapped) v6 addresses. */
		inet_diag_req_addrs(sk, req, &entry);
		memcpy(r->id.idiag_src, entry.saddr, sizeof(struct in6_addr));
		memcpy(r->id.idiag_dst, entry.daddr, sizeof(struct in6_addr));
	}
#endif

	return nlmsg_end(skb, nlh);
}
862
inet_diag_dump_reqs(struct sk_buff * skb,struct sock * sk,struct netlink_callback * cb,struct inet_diag_req_v2 * r,const struct nlattr * bc)863 static int inet_diag_dump_reqs(struct sk_buff *skb, struct sock *sk,
864 struct netlink_callback *cb,
865 struct inet_diag_req_v2 *r,
866 const struct nlattr *bc)
867 {
868 struct inet_diag_entry entry;
869 struct inet_connection_sock *icsk = inet_csk(sk);
870 struct listen_sock *lopt;
871 struct inet_sock *inet = inet_sk(sk);
872 int j, s_j;
873 int reqnum, s_reqnum;
874 int err = 0;
875
876 s_j = cb->args[3];
877 s_reqnum = cb->args[4];
878
879 if (s_j > 0)
880 s_j--;
881
882 entry.family = sk->sk_family;
883
884 read_lock_bh(&icsk->icsk_accept_queue.syn_wait_lock);
885
886 lopt = icsk->icsk_accept_queue.listen_opt;
887 if (!lopt || !lopt->qlen)
888 goto out;
889
890 if (bc != NULL) {
891 entry.sport = inet->inet_num;
892 entry.userlocks = sk->sk_userlocks;
893 entry.mark = sk->sk_mark;
894 }
895
896 for (j = s_j; j < lopt->nr_table_entries; j++) {
897 struct request_sock *req, *head = lopt->syn_table[j];
898
899 reqnum = 0;
900 for (req = head; req; reqnum++, req = req->dl_next) {
901 struct inet_request_sock *ireq = inet_rsk(req);
902
903 if (reqnum < s_reqnum)
904 continue;
905 if (r->id.idiag_dport != ireq->rmt_port &&
906 r->id.idiag_dport)
907 continue;
908
909 if (bc) {
910 inet_diag_req_addrs(sk, req, &entry);
911 entry.dport = ntohs(ireq->rmt_port);
912
913 if (!inet_diag_bc_run(bc, &entry))
914 continue;
915 }
916
917 err = inet_diag_fill_req(skb, sk, req,
918 sk_user_ns(NETLINK_CB(cb->skb).sk),
919 NETLINK_CB(cb->skb).portid,
920 cb->nlh->nlmsg_seq, cb->nlh);
921 if (err < 0) {
922 cb->args[3] = j + 1;
923 cb->args[4] = reqnum;
924 goto out;
925 }
926 }
927
928 s_reqnum = 0;
929 }
930
931 out:
932 read_unlock_bh(&icsk->icsk_accept_queue.syn_wait_lock);
933
934 return err;
935 }
936
/* Walk @hashinfo and emit one diag message for every socket matching
 * request @r and filter bytecode @bc.
 *
 * Dump resume state kept in cb->args[]:
 *   args[0]  phase: 0 = listening hash, 1 = established/timewait hash
 *   args[1]  hash bucket to resume at
 *   args[2]  entry number within that bucket
 *   args[3]/args[4]  SYN-table position (see inet_diag_dump_reqs())
 */
void inet_diag_dump_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *skb,
	struct netlink_callback *cb, struct inet_diag_req_v2 *r, struct nlattr *bc)
{
	int i, num;
	int s_i, s_num;
	struct net *net = sock_net(skb->sk);
	bool net_admin = ns_capable(net->user_ns, CAP_NET_ADMIN);

	s_i = cb->args[1];
	s_num = num = cb->args[2];

	if (cb->args[0] == 0) {
		/* Phase 0: listeners and their SYN_RECV request sockets. */
		if (!(r->idiag_states & (TCPF_LISTEN | TCPF_SYN_RECV)))
			goto skip_listen_ht;

		for (i = s_i; i < INET_LHTABLE_SIZE; i++) {
			struct sock *sk;
			struct hlist_nulls_node *node;
			struct inet_listen_hashbucket *ilb;

			num = 0;
			ilb = &hashinfo->listening_hash[i];
			spin_lock_bh(&ilb->lock);
			sk_nulls_for_each(sk, node, &ilb->head) {
				struct inet_sock *inet = inet_sk(sk);

				if (!net_eq(sock_net(sk), net))
					continue;

				if (num < s_num) {
					num++;
					continue;
				}

				if (r->sdiag_family != AF_UNSPEC &&
				    sk->sk_family != r->sdiag_family)
					goto next_listen;

				if (r->id.idiag_sport != inet->inet_sport &&
				    r->id.idiag_sport)
					goto next_listen;

				/* Skip the listener itself when LISTEN is
				 * not requested, a dport filter is set, or
				 * we are resuming inside its SYN table. */
				if (!(r->idiag_states & TCPF_LISTEN) ||
				    r->id.idiag_dport ||
				    cb->args[3] > 0)
					goto syn_recv;

				if (inet_csk_diag_dump(sk, skb, cb, r,
						       bc, net_admin) < 0) {
					spin_unlock_bh(&ilb->lock);
					goto done;
				}

syn_recv:
				if (!(r->idiag_states & TCPF_SYN_RECV))
					goto next_listen;

				if (inet_diag_dump_reqs(skb, sk, cb, r, bc) < 0) {
					spin_unlock_bh(&ilb->lock);
					goto done;
				}

next_listen:
				cb->args[3] = 0;
				cb->args[4] = 0;
				++num;
			}
			spin_unlock_bh(&ilb->lock);

			s_num = 0;
			cb->args[3] = 0;
			cb->args[4] = 0;
		}
skip_listen_ht:
		cb->args[0] = 1;
		s_i = num = s_num = 0;
	}

	/* Phase 1: established and timewait sockets. */
	if (!(r->idiag_states & ~(TCPF_LISTEN | TCPF_SYN_RECV)))
		goto out;

	for (i = s_i; i <= hashinfo->ehash_mask; i++) {
		struct inet_ehash_bucket *head = &hashinfo->ehash[i];
		spinlock_t *lock = inet_ehash_lockp(hashinfo, i);
		struct sock *sk;
		struct hlist_nulls_node *node;

		num = 0;

		if (hlist_nulls_empty(&head->chain) &&
		    hlist_nulls_empty(&head->twchain))
			continue;

		if (i > s_i)
			s_num = 0;

		spin_lock_bh(lock);
		sk_nulls_for_each(sk, node, &head->chain) {
			struct inet_sock *inet = inet_sk(sk);

			if (!net_eq(sock_net(sk), net))
				continue;
			if (num < s_num)
				goto next_normal;
			if (!(r->idiag_states & (1 << sk->sk_state)))
				goto next_normal;
			if (r->sdiag_family != AF_UNSPEC &&
			    sk->sk_family != r->sdiag_family)
				goto next_normal;
			if (r->id.idiag_sport != inet->inet_sport &&
			    r->id.idiag_sport)
				goto next_normal;
			if (r->id.idiag_dport != inet->inet_dport &&
			    r->id.idiag_dport)
				goto next_normal;
			if (inet_csk_diag_dump(sk, skb, cb, r,
					       bc, net_admin) < 0) {
				spin_unlock_bh(lock);
				goto done;
			}
next_normal:
			++num;
		}

		if (r->idiag_states & TCPF_TIME_WAIT) {
			struct inet_timewait_sock *tw;

			inet_twsk_for_each(tw, node,
				    &head->twchain) {
				if (!net_eq(twsk_net(tw), net))
					continue;

				if (num < s_num)
					goto next_dying;
				if (r->sdiag_family != AF_UNSPEC &&
				    tw->tw_family != r->sdiag_family)
					goto next_dying;
				if (r->id.idiag_sport != tw->tw_sport &&
				    r->id.idiag_sport)
					goto next_dying;
				if (r->id.idiag_dport != tw->tw_dport &&
				    r->id.idiag_dport)
					goto next_dying;
				if (inet_twsk_diag_dump(tw, skb, cb, r, bc) < 0) {
					spin_unlock_bh(lock);
					goto done;
				}
next_dying:
				++num;
			}
		}
		spin_unlock_bh(lock);
	}

done:
	cb->args[1] = i;
	cb->args[2] = num;
out:
	;
}
1097 EXPORT_SYMBOL_GPL(inet_diag_dump_icsk);
1098
__inet_diag_dump(struct sk_buff * skb,struct netlink_callback * cb,struct inet_diag_req_v2 * r,struct nlattr * bc)1099 static int __inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
1100 struct inet_diag_req_v2 *r, struct nlattr *bc)
1101 {
1102 const struct inet_diag_handler *handler;
1103 int err = 0;
1104
1105 handler = inet_diag_lock_handler(r->sdiag_protocol);
1106 if (!IS_ERR(handler))
1107 handler->dump(skb, cb, r, bc);
1108 else
1109 err = PTR_ERR(handler);
1110 inet_diag_unlock_handler(handler);
1111
1112 return err ? : skb->len;
1113 }
1114
inet_diag_dump(struct sk_buff * skb,struct netlink_callback * cb)1115 static int inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
1116 {
1117 struct nlattr *bc = NULL;
1118 int hdrlen = sizeof(struct inet_diag_req_v2);
1119
1120 if (nlmsg_attrlen(cb->nlh, hdrlen))
1121 bc = nlmsg_find_attr(cb->nlh, hdrlen, INET_DIAG_REQ_BYTECODE);
1122
1123 return __inet_diag_dump(skb, cb, nlmsg_data(cb->nlh), bc);
1124 }
1125
inet_diag_type2proto(int type)1126 static inline int inet_diag_type2proto(int type)
1127 {
1128 switch (type) {
1129 case TCPDIAG_GETSOCK:
1130 return IPPROTO_TCP;
1131 case DCCPDIAG_GETSOCK:
1132 return IPPROTO_DCCP;
1133 default:
1134 return 0;
1135 }
1136 }
1137
inet_diag_dump_compat(struct sk_buff * skb,struct netlink_callback * cb)1138 static int inet_diag_dump_compat(struct sk_buff *skb, struct netlink_callback *cb)
1139 {
1140 struct inet_diag_req *rc = nlmsg_data(cb->nlh);
1141 struct inet_diag_req_v2 req;
1142 struct nlattr *bc = NULL;
1143 int hdrlen = sizeof(struct inet_diag_req);
1144
1145 req.sdiag_family = AF_UNSPEC; /* compatibility */
1146 req.sdiag_protocol = inet_diag_type2proto(cb->nlh->nlmsg_type);
1147 req.idiag_ext = rc->idiag_ext;
1148 req.idiag_states = rc->idiag_states;
1149 req.id = rc->id;
1150
1151 if (nlmsg_attrlen(cb->nlh, hdrlen))
1152 bc = nlmsg_find_attr(cb->nlh, hdrlen, INET_DIAG_REQ_BYTECODE);
1153
1154 return __inet_diag_dump(skb, cb, &req, bc);
1155 }
1156
inet_diag_get_exact_compat(struct sk_buff * in_skb,const struct nlmsghdr * nlh)1157 static int inet_diag_get_exact_compat(struct sk_buff *in_skb,
1158 const struct nlmsghdr *nlh)
1159 {
1160 struct inet_diag_req *rc = nlmsg_data(nlh);
1161 struct inet_diag_req_v2 req;
1162
1163 req.sdiag_family = rc->idiag_family;
1164 req.sdiag_protocol = inet_diag_type2proto(nlh->nlmsg_type);
1165 req.idiag_ext = rc->idiag_ext;
1166 req.idiag_states = rc->idiag_states;
1167 req.id = rc->id;
1168
1169 return inet_diag_cmd_exact(SOCK_DIAG_BY_FAMILY, in_skb, nlh, &req);
1170 }
1171
inet_diag_rcv_msg_compat(struct sk_buff * skb,struct nlmsghdr * nlh)1172 static int inet_diag_rcv_msg_compat(struct sk_buff *skb, struct nlmsghdr *nlh)
1173 {
1174 int hdrlen = sizeof(struct inet_diag_req);
1175 struct net *net = sock_net(skb->sk);
1176
1177 if (nlh->nlmsg_type >= INET_DIAG_GETSOCK_MAX ||
1178 nlmsg_len(nlh) < hdrlen)
1179 return -EINVAL;
1180
1181 if (nlh->nlmsg_flags & NLM_F_DUMP) {
1182 if (nlmsg_attrlen(nlh, hdrlen)) {
1183 struct nlattr *attr;
1184 int err;
1185
1186 attr = nlmsg_find_attr(nlh, hdrlen,
1187 INET_DIAG_REQ_BYTECODE);
1188 err = inet_diag_bc_audit(attr, skb);
1189 if (err)
1190 return err;
1191 }
1192 {
1193 struct netlink_dump_control c = {
1194 .dump = inet_diag_dump_compat,
1195 };
1196 return netlink_dump_start(net->diag_nlsk, skb, nlh, &c);
1197 }
1198 }
1199
1200 return inet_diag_get_exact_compat(skb, nlh);
1201 }
1202
inet_diag_handler_cmd(struct sk_buff * skb,struct nlmsghdr * h)1203 static int inet_diag_handler_cmd(struct sk_buff *skb, struct nlmsghdr *h)
1204 {
1205 int hdrlen = sizeof(struct inet_diag_req_v2);
1206 struct net *net = sock_net(skb->sk);
1207
1208 if (nlmsg_len(h) < hdrlen)
1209 return -EINVAL;
1210
1211 if (h->nlmsg_type == SOCK_DIAG_BY_FAMILY &&
1212 h->nlmsg_flags & NLM_F_DUMP) {
1213 if (nlmsg_attrlen(h, hdrlen)) {
1214 struct nlattr *attr;
1215 int err;
1216
1217 attr = nlmsg_find_attr(h, hdrlen,
1218 INET_DIAG_REQ_BYTECODE);
1219 err = inet_diag_bc_audit(attr, skb);
1220 if (err)
1221 return err;
1222 }
1223 {
1224 struct netlink_dump_control c = {
1225 .dump = inet_diag_dump,
1226 };
1227 return netlink_dump_start(net->diag_nlsk, skb, h, &c);
1228 }
1229 }
1230
1231 return inet_diag_cmd_exact(h->nlmsg_type, skb, h, nlmsg_data(h));
1232 }
1233
/* sock_diag dispatcher entry for AF_INET: dump and destroy requests
 * share the same command handler. */
static const struct sock_diag_handler inet_diag_handler = {
	.family = AF_INET,
	.dump = inet_diag_handler_cmd,
	.destroy = inet_diag_handler_cmd,
};
1239
/* sock_diag dispatcher entry for AF_INET6: same command handler as
 * the IPv4 entry. */
static const struct sock_diag_handler inet6_diag_handler = {
	.family = AF_INET6,
	.dump = inet_diag_handler_cmd,
	.destroy = inet_diag_handler_cmd,
};
1245
inet_diag_register(const struct inet_diag_handler * h)1246 int inet_diag_register(const struct inet_diag_handler *h)
1247 {
1248 const __u16 type = h->idiag_type;
1249 int err = -EINVAL;
1250
1251 if (type >= IPPROTO_MAX)
1252 goto out;
1253
1254 mutex_lock(&inet_diag_table_mutex);
1255 err = -EEXIST;
1256 if (inet_diag_table[type] == NULL) {
1257 inet_diag_table[type] = h;
1258 err = 0;
1259 }
1260 mutex_unlock(&inet_diag_table_mutex);
1261 out:
1262 return err;
1263 }
1264 EXPORT_SYMBOL_GPL(inet_diag_register);
1265
inet_diag_unregister(const struct inet_diag_handler * h)1266 void inet_diag_unregister(const struct inet_diag_handler *h)
1267 {
1268 const __u16 type = h->idiag_type;
1269
1270 if (type >= IPPROTO_MAX)
1271 return;
1272
1273 mutex_lock(&inet_diag_table_mutex);
1274 inet_diag_table[type] = NULL;
1275 mutex_unlock(&inet_diag_table_mutex);
1276 }
1277 EXPORT_SYMBOL_GPL(inet_diag_unregister);
1278
inet_diag_init(void)1279 static int __init inet_diag_init(void)
1280 {
1281 const int inet_diag_table_size = (IPPROTO_MAX *
1282 sizeof(struct inet_diag_handler *));
1283 int err = -ENOMEM;
1284
1285 inet_diag_table = kzalloc(inet_diag_table_size, GFP_KERNEL);
1286 if (!inet_diag_table)
1287 goto out;
1288
1289 err = sock_diag_register(&inet_diag_handler);
1290 if (err)
1291 goto out_free_nl;
1292
1293 err = sock_diag_register(&inet6_diag_handler);
1294 if (err)
1295 goto out_free_inet;
1296
1297 sock_diag_register_inet_compat(inet_diag_rcv_msg_compat);
1298 out:
1299 return err;
1300
1301 out_free_inet:
1302 sock_diag_unregister(&inet_diag_handler);
1303 out_free_nl:
1304 kfree(inet_diag_table);
1305 goto out;
1306 }
1307
/* Module teardown: unregister both family handlers and the legacy
 * compat hook, then release the protocol handler table. */
static void __exit inet_diag_exit(void)
{
	sock_diag_unregister(&inet6_diag_handler);
	sock_diag_unregister(&inet_diag_handler);
	sock_diag_unregister_inet_compat(inet_diag_rcv_msg_compat);
	kfree(inet_diag_table);
}
1315
module_init(inet_diag_init);
module_exit(inet_diag_exit);
MODULE_LICENSE("GPL");
/* Aliases let request_module() autoload this module when a
 * NETLINK_SOCK_DIAG query arrives for AF_INET (2) or AF_INET6 (10). */
MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 2 /* AF_INET */);
MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 10 /* AF_INET6 */);
1321