// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/bpf.h>
#include <linux/filter.h>
#include <linux/errno.h>
#include <linux/file.h>
#include <linux/net.h>
#include <linux/workqueue.h>
#include <linux/skmsg.h>
#include <linux/list.h>
#include <linux/jhash.h>

struct bpf_stab {
	struct bpf_map map;
	struct sock **sks;
	struct sk_psock_progs progs;
	raw_spinlock_t lock;
};

#define SOCK_CREATE_FLAG_MASK \
	(BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)

static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
{
	struct bpf_stab *stab;
	u64 cost;
	int err;

	if (!capable(CAP_NET_ADMIN))
		return ERR_PTR(-EPERM);
	if (attr->max_entries == 0 ||
	    attr->key_size != 4 ||
	    attr->value_size != 4 ||
	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
		return ERR_PTR(-EINVAL);

	stab = kzalloc(sizeof(*stab), GFP_USER);
	if (!stab)
		return ERR_PTR(-ENOMEM);

	bpf_map_init_from_attr(&stab->map, attr);
	raw_spin_lock_init(&stab->lock);

	/* Make sure page count doesn't overflow. */
	cost = (u64) stab->map.max_entries * sizeof(struct sock *);
	err = bpf_map_charge_init(&stab->map.memory, cost);
	if (err)
		goto free_stab;

	stab->sks = bpf_map_area_alloc((u64) stab->map.max_entries *
				       sizeof(struct sock *),
				       stab->map.numa_node);
	if (stab->sks)
		return &stab->map;
	err = -ENOMEM;
	bpf_map_charge_finish(&stab->map.memory);
free_stab:
	kfree(stab);
	return ERR_PTR(err);
}
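/* Illustrative sketch, not part of this file: given the checks above, a
 * user-space loader would create this map type roughly as follows (libbpf;
 * the 1024-entry size is chosen arbitrarily):
 *
 *	int map_fd = bpf_create_map(BPF_MAP_TYPE_SOCKMAP, sizeof(__u32),
 *				    sizeof(__u32), 1024, 0);
 *
 * The kernel rejects the map with -EINVAL unless key_size and value_size
 * are both 4 bytes, matching the validation in sock_map_alloc().
 */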

int sock_map_get_from_fd(const union bpf_attr *attr, struct bpf_prog *prog)
{
	u32 ufd = attr->target_fd;
	struct bpf_map *map;
	struct fd f;
	int ret;

	f = fdget(ufd);
	map = __bpf_map_get(f);
	if (IS_ERR(map))
		return PTR_ERR(map);
	ret = sock_map_prog_update(map, prog, NULL, attr->attach_type);
	fdput(f);
	return ret;
}

int sock_map_prog_detach(const union bpf_attr *attr, enum bpf_prog_type ptype)
{
	u32 ufd = attr->target_fd;
	struct bpf_prog *prog;
	struct bpf_map *map;
	struct fd f;
	int ret;

	if (attr->attach_flags)
		return -EINVAL;

	f = fdget(ufd);
	map = __bpf_map_get(f);
	if (IS_ERR(map))
		return PTR_ERR(map);

	prog = bpf_prog_get(attr->attach_bpf_fd);
	if (IS_ERR(prog)) {
		ret = PTR_ERR(prog);
		goto put_map;
	}

	if (prog->type != ptype) {
		ret = -EINVAL;
		goto put_prog;
	}

	ret = sock_map_prog_update(map, NULL, prog, attr->attach_type);
put_prog:
	bpf_prog_put(prog);
put_map:
	fdput(f);
	return ret;
}
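/* Illustrative sketch, not part of this file: the attach/detach paths above
 * are driven from user space through the bpf(2) BPF_PROG_ATTACH and
 * BPF_PROG_DETACH commands, with the map fd passed as the attach target.
 * With libbpf this is roughly:
 *
 *	bpf_prog_attach(verdict_fd, map_fd, BPF_SK_SKB_STREAM_VERDICT, 0);
 *	bpf_prog_attach(parser_fd, map_fd, BPF_SK_SKB_STREAM_PARSER, 0);
 *	bpf_prog_detach2(verdict_fd, map_fd, BPF_SK_SKB_STREAM_VERDICT);
 *
 * verdict_fd, parser_fd and map_fd are assumed to be fds obtained elsewhere.
 */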

static void sock_map_sk_acquire(struct sock *sk)
	__acquires(&sk->sk_lock.slock)
{
	lock_sock(sk);
	rcu_read_lock();
}

static void sock_map_sk_release(struct sock *sk)
	__releases(&sk->sk_lock.slock)
{
	rcu_read_unlock();
	release_sock(sk);
}

static void sock_map_add_link(struct sk_psock *psock,
			      struct sk_psock_link *link,
			      struct bpf_map *map, void *link_raw)
{
	link->link_raw = link_raw;
	link->map = map;
	spin_lock_bh(&psock->link_lock);
	list_add_tail(&link->list, &psock->link);
	spin_unlock_bh(&psock->link_lock);
}

static void sock_map_del_link(struct sock *sk,
			      struct sk_psock *psock, void *link_raw)
{
	struct sk_psock_link *link, *tmp;
	bool strp_stop = false;

	spin_lock_bh(&psock->link_lock);
	list_for_each_entry_safe(link, tmp, &psock->link, list) {
		if (link->link_raw == link_raw) {
			struct bpf_map *map = link->map;
			struct bpf_stab *stab = container_of(map, struct bpf_stab,
							     map);
			if (psock->parser.enabled && stab->progs.skb_parser)
				strp_stop = true;
			list_del(&link->list);
			sk_psock_free_link(link);
		}
	}
	spin_unlock_bh(&psock->link_lock);
	if (strp_stop) {
		write_lock_bh(&sk->sk_callback_lock);
		sk_psock_stop_strp(sk, psock);
		write_unlock_bh(&sk->sk_callback_lock);
	}
}

static void sock_map_unref(struct sock *sk, void *link_raw)
{
	struct sk_psock *psock = sk_psock(sk);

	if (likely(psock)) {
		sock_map_del_link(sk, psock, link_raw);
		sk_psock_put(sk, psock);
	}
}

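/* sock_map_link() ties a socket to the map's program set: it takes
 * references on the map's parser/verdict/msg programs, creates or reuses
 * the socket's psock, switches the socket over to the BPF-aware TCP
 * callbacks and, when both skb programs are present, starts the stream
 * parser. On any failure all acquired references are dropped again.
 */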
static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
			 struct sock *sk)
{
	struct bpf_prog *msg_parser, *skb_parser, *skb_verdict;
	bool skb_progs, sk_psock_is_new = false;
	struct sk_psock *psock;
	int ret;

	skb_verdict = READ_ONCE(progs->skb_verdict);
	skb_parser = READ_ONCE(progs->skb_parser);
	skb_progs = skb_parser && skb_verdict;
	if (skb_progs) {
		skb_verdict = bpf_prog_inc_not_zero(skb_verdict);
		if (IS_ERR(skb_verdict))
			return PTR_ERR(skb_verdict);
		skb_parser = bpf_prog_inc_not_zero(skb_parser);
		if (IS_ERR(skb_parser)) {
			bpf_prog_put(skb_verdict);
			return PTR_ERR(skb_parser);
		}
	}

	msg_parser = READ_ONCE(progs->msg_parser);
	if (msg_parser) {
		msg_parser = bpf_prog_inc_not_zero(msg_parser);
		if (IS_ERR(msg_parser)) {
			ret = PTR_ERR(msg_parser);
			goto out;
		}
	}

	psock = sk_psock_get_checked(sk);
	if (IS_ERR(psock)) {
		ret = PTR_ERR(psock);
		goto out_progs;
	}

	if (psock) {
		if ((msg_parser && READ_ONCE(psock->progs.msg_parser)) ||
		    (skb_progs && READ_ONCE(psock->progs.skb_parser))) {
			sk_psock_put(sk, psock);
			ret = -EBUSY;
			goto out_progs;
		}
	} else {
		psock = sk_psock_init(sk, map->numa_node);
		if (!psock) {
			ret = -ENOMEM;
			goto out_progs;
		}
		sk_psock_is_new = true;
	}

	if (msg_parser)
		psock_set_prog(&psock->progs.msg_parser, msg_parser);
	if (sk_psock_is_new) {
		ret = tcp_bpf_init(sk);
		if (ret < 0)
			goto out_drop;
	} else {
		tcp_bpf_reinit(sk);
	}

	write_lock_bh(&sk->sk_callback_lock);
	if (skb_progs && !psock->parser.enabled) {
		ret = sk_psock_init_strp(sk, psock);
		if (ret) {
			write_unlock_bh(&sk->sk_callback_lock);
			goto out_drop;
		}
		psock_set_prog(&psock->progs.skb_verdict, skb_verdict);
		psock_set_prog(&psock->progs.skb_parser, skb_parser);
		sk_psock_start_strp(sk, psock);
	}
	write_unlock_bh(&sk->sk_callback_lock);
	return 0;
out_drop:
	sk_psock_put(sk, psock);
out_progs:
	if (msg_parser)
		bpf_prog_put(msg_parser);
out:
	if (skb_progs) {
		bpf_prog_put(skb_verdict);
		bpf_prog_put(skb_parser);
	}
	return ret;
}

static void sock_map_free(struct bpf_map *map)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	int i;

	/* After the sync no updates or deletes will be in-flight so it
	 * is safe to walk map and remove entries without risking a race
	 * in EEXIST update case.
	 */
	synchronize_rcu();
	for (i = 0; i < stab->map.max_entries; i++) {
		struct sock **psk = &stab->sks[i];
		struct sock *sk;

		sk = xchg(psk, NULL);
		if (sk) {
			sock_hold(sk);
			lock_sock(sk);
			rcu_read_lock();
			sock_map_unref(sk, psk);
			rcu_read_unlock();
			release_sock(sk);
			sock_put(sk);
		}
	}

	/* wait for psock readers accessing its map link */
	synchronize_rcu();

	bpf_map_area_free(stab->sks);
	kfree(stab);
}

static void sock_map_release_progs(struct bpf_map *map)
{
	psock_progs_drop(&container_of(map, struct bpf_stab, map)->progs);
}

static struct sock *__sock_map_lookup_elem(struct bpf_map *map, u32 key)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);

	WARN_ON_ONCE(!rcu_read_lock_held());

	if (unlikely(key >= map->max_entries))
		return NULL;
	return READ_ONCE(stab->sks[key]);
}

static void *sock_map_lookup(struct bpf_map *map, void *key)
{
	return ERR_PTR(-EOPNOTSUPP);
}

static int __sock_map_delete(struct bpf_stab *stab, struct sock *sk_test,
			     struct sock **psk)
{
	struct sock *sk;
	int err = 0;

	if (irqs_disabled())
		return -EOPNOTSUPP; /* locks here are hardirq-unsafe */

	raw_spin_lock_bh(&stab->lock);
	sk = *psk;
	if (!sk_test || sk_test == sk)
		sk = xchg(psk, NULL);

	if (likely(sk))
		sock_map_unref(sk, psk);
	else
		err = -EINVAL;

	raw_spin_unlock_bh(&stab->lock);
	return err;
}

static void sock_map_delete_from_link(struct bpf_map *map, struct sock *sk,
				      void *link_raw)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);

	__sock_map_delete(stab, sk, link_raw);
}

static int sock_map_delete_elem(struct bpf_map *map, void *key)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	u32 i = *(u32 *)key;
	struct sock **psk;

	if (unlikely(i >= map->max_entries))
		return -EINVAL;

	psk = &stab->sks[i];
	return __sock_map_delete(stab, NULL, psk);
}

static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	u32 i = key ? *(u32 *)key : U32_MAX;
	u32 *key_next = next;

	if (i == stab->map.max_entries - 1)
		return -ENOENT;
	if (i >= stab->map.max_entries)
		*key_next = 0;
	else
		*key_next = i + 1;
	return 0;
}
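/* Illustrative sketch, not part of this file: keys are plain array indices,
 * so user space can walk them with the generic libbpf iteration helper:
 *
 *	__u32 key, next_key;
 *	int err = bpf_map_get_next_key(map_fd, NULL, &next_key);
 *
 *	while (!err) {
 *		key = next_key;
 *		// key now holds a valid index; look it up, print it, etc.
 *		err = bpf_map_get_next_key(map_fd, &key, &next_key);
 *	}
 *
 * A NULL (or out-of-range) key restarts iteration at index 0, and the last
 * index returns -ENOENT, matching sock_map_get_next_key() above.
 */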

static int sock_map_update_common(struct bpf_map *map, u32 idx,
				  struct sock *sk, u64 flags)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	struct inet_connection_sock *icsk = inet_csk(sk);
	struct sk_psock_link *link;
	struct sk_psock *psock;
	struct sock *osk;
	int ret;

	WARN_ON_ONCE(!rcu_read_lock_held());
	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;
	if (unlikely(idx >= map->max_entries))
		return -E2BIG;
	if (unlikely(rcu_access_pointer(icsk->icsk_ulp_data)))
		return -EINVAL;

	link = sk_psock_init_link();
	if (!link)
		return -ENOMEM;

	ret = sock_map_link(map, &stab->progs, sk);
	if (ret < 0)
		goto out_free;

	psock = sk_psock(sk);
	WARN_ON_ONCE(!psock);

	raw_spin_lock_bh(&stab->lock);
	osk = stab->sks[idx];
	if (osk && flags == BPF_NOEXIST) {
		ret = -EEXIST;
		goto out_unlock;
	} else if (!osk && flags == BPF_EXIST) {
		ret = -ENOENT;
		goto out_unlock;
	}

	sock_map_add_link(psock, link, map, &stab->sks[idx]);
	stab->sks[idx] = sk;
	if (osk)
		sock_map_unref(osk, &stab->sks[idx]);
	raw_spin_unlock_bh(&stab->lock);
	return 0;
out_unlock:
	raw_spin_unlock_bh(&stab->lock);
	if (psock)
		sk_psock_put(sk, psock);
out_free:
	sk_psock_free_link(link);
	return ret;
}

static bool sock_map_op_okay(const struct bpf_sock_ops_kern *ops)
{
	return ops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB ||
	       ops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB;
}

static bool sock_map_sk_is_suitable(const struct sock *sk)
{
	return sk->sk_type == SOCK_STREAM &&
	       sk->sk_protocol == IPPROTO_TCP;
}

static int sock_map_update_elem(struct bpf_map *map, void *key,
				void *value, u64 flags)
{
	u32 ufd = *(u32 *)value;
	u32 idx = *(u32 *)key;
	struct socket *sock;
	struct sock *sk;
	int ret;

	sock = sockfd_lookup(ufd, &ret);
	if (!sock)
		return ret;
	sk = sock->sk;
	if (!sk) {
		ret = -EINVAL;
		goto out;
	}
	if (!sock_map_sk_is_suitable(sk)) {
		ret = -EOPNOTSUPP;
		goto out;
	}

	sock_map_sk_acquire(sk);
	if (sk->sk_state != TCP_ESTABLISHED)
		ret = -EOPNOTSUPP;
	else
		ret = sock_map_update_common(map, idx, sk, flags);
	sock_map_sk_release(sk);
out:
	fput(sock->file);
	return ret;
}
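/* Illustrative sketch, not part of this file: inserting a connected TCP
 * socket into slot 0 of the map from user space looks roughly like this
 * (libbpf; connect_to_server() is a placeholder for whatever produces an
 * established socket fd):
 *
 *	__u32 idx = 0;
 *	int sock_fd = connect_to_server();
 *
 *	if (bpf_map_update_elem(map_fd, &idx, &sock_fd, BPF_ANY))
 *		perror("sockmap update");
 *
 * Per the checks above, the update is rejected unless the fd refers to an
 * established SOCK_STREAM/IPPROTO_TCP socket.
 */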

BPF_CALL_4(bpf_sock_map_update, struct bpf_sock_ops_kern *, sops,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	WARN_ON_ONCE(!rcu_read_lock_held());

	if (likely(sock_map_sk_is_suitable(sops->sk) &&
		   sock_map_op_okay(sops)))
		return sock_map_update_common(map, *(u32 *)key, sops->sk,
					      flags);
	return -EOPNOTSUPP;
}

const struct bpf_func_proto bpf_sock_map_update_proto = {
	.func = bpf_sock_map_update,
	.gpl_only = false,
	.pkt_access = true,
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_CTX,
	.arg2_type = ARG_CONST_MAP_PTR,
	.arg3_type = ARG_PTR_TO_MAP_KEY,
	.arg4_type = ARG_ANYTHING,
};
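/* Illustrative sketch, not part of this file: bpf_sock_map_update() is meant
 * to be called from a sockops program on the established callbacks accepted
 * by sock_map_op_okay(), e.g.:
 *
 *	SEC("sockops")
 *	int add_to_sockmap(struct bpf_sock_ops *skops)
 *	{
 *		__u32 key = 0;
 *
 *		if (skops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB ||
 *		    skops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB)
 *			bpf_sock_map_update(skops, &sock_map, &key, BPF_NOEXIST);
 *		return 0;
 *	}
 *
 * sock_map here is assumed to be a BPF_MAP_TYPE_SOCKMAP declared in the same
 * BPF object; the key value is arbitrary.
 */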

BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb,
	   struct bpf_map *, map, u32, key, u64, flags)
{
	struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;
	tcb->bpf.flags = flags;
	tcb->bpf.sk_redir = __sock_map_lookup_elem(map, key);
	if (!tcb->bpf.sk_redir)
		return SK_DROP;
	return SK_PASS;
}

const struct bpf_func_proto bpf_sk_redirect_map_proto = {
	.func = bpf_sk_redirect_map,
	.gpl_only = false,
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_CTX,
	.arg2_type = ARG_CONST_MAP_PTR,
	.arg3_type = ARG_ANYTHING,
	.arg4_type = ARG_ANYTHING,
};
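/* Illustrative sketch, not part of this file: a stream verdict program
 * attached with BPF_SK_SKB_STREAM_VERDICT typically uses this helper as its
 * return value, e.g.:
 *
 *	SEC("sk_skb/stream_verdict")
 *	int redirect_skb(struct __sk_buff *skb)
 *	{
 *		__u32 key = 0;
 *
 *		return bpf_sk_redirect_map(skb, &sock_map, key, 0);
 *	}
 *
 * A missing socket at that key makes the helper return SK_DROP, as above.
 */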

BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg *, msg,
	   struct bpf_map *, map, u32, key, u64, flags)
{
	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;
	msg->flags = flags;
	msg->sk_redir = __sock_map_lookup_elem(map, key);
	if (!msg->sk_redir)
		return SK_DROP;
	return SK_PASS;
}

const struct bpf_func_proto bpf_msg_redirect_map_proto = {
	.func = bpf_msg_redirect_map,
	.gpl_only = false,
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_CTX,
	.arg2_type = ARG_CONST_MAP_PTR,
	.arg3_type = ARG_ANYTHING,
	.arg4_type = ARG_ANYTHING,
};
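/* Illustrative sketch, not part of this file: the msg variant is used from a
 * BPF_SK_MSG_VERDICT program on the sendmsg path, e.g.:
 *
 *	SEC("sk_msg")
 *	int redirect_msg(struct sk_msg_md *msg)
 *	{
 *		__u32 key = 0;
 *
 *		return bpf_msg_redirect_map(msg, &sock_map, key, BPF_F_INGRESS);
 *	}
 *
 * BPF_F_INGRESS queues the data to the target socket's receive path instead
 * of its transmit path; it is the only flag accepted here.
 */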

const struct bpf_map_ops sock_map_ops = {
	.map_alloc = sock_map_alloc,
	.map_free = sock_map_free,
	.map_get_next_key = sock_map_get_next_key,
	.map_update_elem = sock_map_update_elem,
	.map_delete_elem = sock_map_delete_elem,
	.map_lookup_elem = sock_map_lookup,
	.map_release_uref = sock_map_release_progs,
	.map_check_btf = map_check_no_btf,
};

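/* The second half of this file implements BPF_MAP_TYPE_SOCKHASH: the same
 * psock/program plumbing as the sockmap above, but with sockets stored in a
 * hash table keyed by an arbitrary user-defined key instead of an array
 * index.
 */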
struct bpf_htab_elem {
	struct rcu_head rcu;
	u32 hash;
	struct sock *sk;
	struct hlist_node node;
	u8 key[0];
};

struct bpf_htab_bucket {
	struct hlist_head head;
	raw_spinlock_t lock;
};

struct bpf_htab {
	struct bpf_map map;
	struct bpf_htab_bucket *buckets;
	u32 buckets_num;
	u32 elem_size;
	struct sk_psock_progs progs;
	atomic_t count;
};

static inline u32 sock_hash_bucket_hash(const void *key, u32 len)
{
	return jhash(key, len, 0);
}

static struct bpf_htab_bucket *sock_hash_select_bucket(struct bpf_htab *htab,
							u32 hash)
{
	return &htab->buckets[hash & (htab->buckets_num - 1)];
}

static struct bpf_htab_elem *
sock_hash_lookup_elem_raw(struct hlist_head *head, u32 hash, void *key,
			  u32 key_size)
{
	struct bpf_htab_elem *elem;

	hlist_for_each_entry_rcu(elem, head, node) {
		if (elem->hash == hash &&
		    !memcmp(&elem->key, key, key_size))
			return elem;
	}

	return NULL;
}

static struct sock *__sock_hash_lookup_elem(struct bpf_map *map, void *key)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	u32 key_size = map->key_size, hash;
	struct bpf_htab_bucket *bucket;
	struct bpf_htab_elem *elem;

	WARN_ON_ONCE(!rcu_read_lock_held());

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);

	return elem ? elem->sk : NULL;
}

static void sock_hash_free_elem(struct bpf_htab *htab,
				struct bpf_htab_elem *elem)
{
	atomic_dec(&htab->count);
	kfree_rcu(elem, rcu);
}

static void sock_hash_delete_from_link(struct bpf_map *map, struct sock *sk,
				       void *link_raw)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_elem *elem_probe, *elem = link_raw;
	struct bpf_htab_bucket *bucket;

	WARN_ON_ONCE(!rcu_read_lock_held());
	bucket = sock_hash_select_bucket(htab, elem->hash);

	/* elem may be deleted in parallel from the map, but access here
	 * is okay since it's going away only after RCU grace period.
	 * However, we need to check whether it's still present.
	 */
	raw_spin_lock_bh(&bucket->lock);
	elem_probe = sock_hash_lookup_elem_raw(&bucket->head, elem->hash,
					       elem->key, map->key_size);
	if (elem_probe && elem_probe == elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
	}
	raw_spin_unlock_bh(&bucket->lock);
}

static int sock_hash_delete_elem(struct bpf_map *map, void *key)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	u32 hash, key_size = map->key_size;
	struct bpf_htab_bucket *bucket;
	struct bpf_htab_elem *elem;
	int ret = -ENOENT;

	if (irqs_disabled())
		return -EOPNOTSUPP; /* locks here are hardirq-unsafe */

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);

	raw_spin_lock_bh(&bucket->lock);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
	if (elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
		ret = 0;
	}
	raw_spin_unlock_bh(&bucket->lock);
	return ret;
}

static struct bpf_htab_elem *sock_hash_alloc_elem(struct bpf_htab *htab,
						  void *key, u32 key_size,
						  u32 hash, struct sock *sk,
						  struct bpf_htab_elem *old)
{
	struct bpf_htab_elem *new;

	if (atomic_inc_return(&htab->count) > htab->map.max_entries) {
		if (!old) {
			atomic_dec(&htab->count);
			return ERR_PTR(-E2BIG);
		}
	}

	new = kmalloc_node(htab->elem_size, GFP_ATOMIC | __GFP_NOWARN,
			   htab->map.numa_node);
	if (!new) {
		atomic_dec(&htab->count);
		return ERR_PTR(-ENOMEM);
	}
	memcpy(new->key, key, key_size);
	new->sk = sk;
	new->hash = hash;
	return new;
}

static int sock_hash_update_common(struct bpf_map *map, void *key,
				   struct sock *sk, u64 flags)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct inet_connection_sock *icsk = inet_csk(sk);
	u32 key_size = map->key_size, hash;
	struct bpf_htab_elem *elem, *elem_new;
	struct bpf_htab_bucket *bucket;
	struct sk_psock_link *link;
	struct sk_psock *psock;
	int ret;

	WARN_ON_ONCE(!rcu_read_lock_held());
	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;
	if (unlikely(icsk->icsk_ulp_data))
		return -EINVAL;

	link = sk_psock_init_link();
	if (!link)
		return -ENOMEM;

	ret = sock_map_link(map, &htab->progs, sk);
	if (ret < 0)
		goto out_free;

	psock = sk_psock(sk);
	WARN_ON_ONCE(!psock);

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);

	raw_spin_lock_bh(&bucket->lock);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
	if (elem && flags == BPF_NOEXIST) {
		ret = -EEXIST;
		goto out_unlock;
	} else if (!elem && flags == BPF_EXIST) {
		ret = -ENOENT;
		goto out_unlock;
	}

	elem_new = sock_hash_alloc_elem(htab, key, key_size, hash, sk, elem);
	if (IS_ERR(elem_new)) {
		ret = PTR_ERR(elem_new);
		goto out_unlock;
	}

	sock_map_add_link(psock, link, map, elem_new);
	/* Add new element to the head of the list, so that
	 * concurrent search will find it before old elem.
	 */
	hlist_add_head_rcu(&elem_new->node, &bucket->head);
	if (elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
	}
	raw_spin_unlock_bh(&bucket->lock);
	return 0;
out_unlock:
	raw_spin_unlock_bh(&bucket->lock);
	sk_psock_put(sk, psock);
out_free:
	sk_psock_free_link(link);
	return ret;
}

static int sock_hash_update_elem(struct bpf_map *map, void *key,
				 void *value, u64 flags)
{
	u32 ufd = *(u32 *)value;
	struct socket *sock;
	struct sock *sk;
	int ret;

	sock = sockfd_lookup(ufd, &ret);
	if (!sock)
		return ret;
	sk = sock->sk;
	if (!sk) {
		ret = -EINVAL;
		goto out;
	}
	if (!sock_map_sk_is_suitable(sk)) {
		ret = -EOPNOTSUPP;
		goto out;
	}

	sock_map_sk_acquire(sk);
	if (sk->sk_state != TCP_ESTABLISHED)
		ret = -EOPNOTSUPP;
	else
		ret = sock_hash_update_common(map, key, sk, flags);
	sock_map_sk_release(sk);
out:
	fput(sock->file);
	return ret;
}

static int sock_hash_get_next_key(struct bpf_map *map, void *key,
				  void *key_next)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_elem *elem, *elem_next;
	u32 hash, key_size = map->key_size;
	struct hlist_head *head;
	int i = 0;

	if (!key)
		goto find_first_elem;
	hash = sock_hash_bucket_hash(key, key_size);
	head = &sock_hash_select_bucket(htab, hash)->head;
	elem = sock_hash_lookup_elem_raw(head, hash, key, key_size);
	if (!elem)
		goto find_first_elem;

	elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_next_rcu(&elem->node)),
				     struct bpf_htab_elem, node);
	if (elem_next) {
		memcpy(key_next, elem_next->key, key_size);
		return 0;
	}

	i = hash & (htab->buckets_num - 1);
	i++;
find_first_elem:
	for (; i < htab->buckets_num; i++) {
		head = &sock_hash_select_bucket(htab, i)->head;
		elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_first_rcu(head)),
					     struct bpf_htab_elem, node);
		if (elem_next) {
			memcpy(key_next, elem_next->key, key_size);
			return 0;
		}
	}

	return -ENOENT;
}

static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
{
	struct bpf_htab *htab;
	int i, err;
	u64 cost;

	if (!capable(CAP_NET_ADMIN))
		return ERR_PTR(-EPERM);
	if (attr->max_entries == 0 ||
	    attr->key_size == 0 ||
	    attr->value_size != 4 ||
	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
		return ERR_PTR(-EINVAL);
	if (attr->key_size > MAX_BPF_STACK)
		return ERR_PTR(-E2BIG);

	htab = kzalloc(sizeof(*htab), GFP_USER);
	if (!htab)
		return ERR_PTR(-ENOMEM);

	bpf_map_init_from_attr(&htab->map, attr);

	htab->buckets_num = roundup_pow_of_two(htab->map.max_entries);
	htab->elem_size = sizeof(struct bpf_htab_elem) +
			  round_up(htab->map.key_size, 8);
	if (htab->buckets_num == 0 ||
	    htab->buckets_num > U32_MAX / sizeof(struct bpf_htab_bucket)) {
		err = -EINVAL;
		goto free_htab;
	}

	cost = (u64) htab->buckets_num * sizeof(struct bpf_htab_bucket) +
	       (u64) htab->elem_size * htab->map.max_entries;
	if (cost >= U32_MAX - PAGE_SIZE) {
		err = -EINVAL;
		goto free_htab;
	}
	err = bpf_map_charge_init(&htab->map.memory, cost);
	if (err)
		goto free_htab;

	htab->buckets = bpf_map_area_alloc(htab->buckets_num *
					   sizeof(struct bpf_htab_bucket),
					   htab->map.numa_node);
	if (!htab->buckets) {
		bpf_map_charge_finish(&htab->map.memory);
		err = -ENOMEM;
		goto free_htab;
	}

	for (i = 0; i < htab->buckets_num; i++) {
		INIT_HLIST_HEAD(&htab->buckets[i].head);
		raw_spin_lock_init(&htab->buckets[i].lock);
	}

	return &htab->map;
free_htab:
	kfree(htab);
	return ERR_PTR(err);
}
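/* Illustrative sketch, not part of this file: unlike the sockmap, a sockhash
 * accepts an arbitrary fixed-size key, commonly a flow tuple defined by the
 * application (the struct below is an example, not a kernel ABI):
 *
 *	struct sock_key {
 *		__u32 sip4;
 *		__u32 dip4;
 *		__u32 sport;
 *		__u32 dport;
 *	};
 *
 *	int map_fd = bpf_create_map(BPF_MAP_TYPE_SOCKHASH,
 *				    sizeof(struct sock_key), sizeof(__u32),
 *				    65535, 0);
 *
 * Per sock_hash_alloc(), value_size must still be 4 and key_size is capped
 * at MAX_BPF_STACK bytes.
 */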

static void sock_hash_free(struct bpf_map *map)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_bucket *bucket;
	struct hlist_head unlink_list;
	struct bpf_htab_elem *elem;
	struct hlist_node *node;
	int i;

	/* After the sync no updates or deletes will be in-flight so it
	 * is safe to walk map and remove entries without risking a race
	 * in EEXIST update case.
	 */
	synchronize_rcu();
	for (i = 0; i < htab->buckets_num; i++) {
		bucket = sock_hash_select_bucket(htab, i);

		/* We are racing with sock_hash_delete_from_link to
		 * enter the spin-lock critical section. Every socket on
		 * the list is still linked to sockhash. Since link
		 * exists, psock exists and holds a ref to socket. That
		 * lets us to grab a socket ref too.
		 */
		raw_spin_lock_bh(&bucket->lock);
		hlist_for_each_entry(elem, &bucket->head, node)
			sock_hold(elem->sk);
		hlist_move_list(&bucket->head, &unlink_list);
		raw_spin_unlock_bh(&bucket->lock);

		/* Process removed entries out of atomic context to
		 * block for socket lock before deleting the psock's
		 * link to sockhash.
		 */
		hlist_for_each_entry_safe(elem, node, &unlink_list, node) {
			hlist_del(&elem->node);
			lock_sock(elem->sk);
			rcu_read_lock();
			sock_map_unref(elem->sk, elem);
			rcu_read_unlock();
			release_sock(elem->sk);
			sock_put(elem->sk);
			sock_hash_free_elem(htab, elem);
		}
	}

	/* wait for psock readers accessing its map link */
	synchronize_rcu();

	/* wait for psock readers accessing its map link */
	synchronize_rcu();

	bpf_map_area_free(htab->buckets);
	kfree(htab);
}

static void sock_hash_release_progs(struct bpf_map *map)
{
	psock_progs_drop(&container_of(map, struct bpf_htab, map)->progs);
}

BPF_CALL_4(bpf_sock_hash_update, struct bpf_sock_ops_kern *, sops,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	WARN_ON_ONCE(!rcu_read_lock_held());

	if (likely(sock_map_sk_is_suitable(sops->sk) &&
		   sock_map_op_okay(sops)))
		return sock_hash_update_common(map, key, sops->sk, flags);
	return -EOPNOTSUPP;
}

const struct bpf_func_proto bpf_sock_hash_update_proto = {
	.func = bpf_sock_hash_update,
	.gpl_only = false,
	.pkt_access = true,
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_CTX,
	.arg2_type = ARG_CONST_MAP_PTR,
	.arg3_type = ARG_PTR_TO_MAP_KEY,
	.arg4_type = ARG_ANYTHING,
};

BPF_CALL_4(bpf_sk_redirect_hash, struct sk_buff *, skb,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;
	tcb->bpf.flags = flags;
	tcb->bpf.sk_redir = __sock_hash_lookup_elem(map, key);
	if (!tcb->bpf.sk_redir)
		return SK_DROP;
	return SK_PASS;
}

const struct bpf_func_proto bpf_sk_redirect_hash_proto = {
	.func = bpf_sk_redirect_hash,
	.gpl_only = false,
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_CTX,
	.arg2_type = ARG_CONST_MAP_PTR,
	.arg3_type = ARG_PTR_TO_MAP_KEY,
	.arg4_type = ARG_ANYTHING,
};

BPF_CALL_4(bpf_msg_redirect_hash, struct sk_msg *, msg,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;
	msg->flags = flags;
	msg->sk_redir = __sock_hash_lookup_elem(map, key);
	if (!msg->sk_redir)
		return SK_DROP;
	return SK_PASS;
}

const struct bpf_func_proto bpf_msg_redirect_hash_proto = {
	.func = bpf_msg_redirect_hash,
	.gpl_only = false,
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_CTX,
	.arg2_type = ARG_CONST_MAP_PTR,
	.arg3_type = ARG_PTR_TO_MAP_KEY,
	.arg4_type = ARG_ANYTHING,
};

const struct bpf_map_ops sock_hash_ops = {
	.map_alloc = sock_hash_alloc,
	.map_free = sock_hash_free,
	.map_get_next_key = sock_hash_get_next_key,
	.map_update_elem = sock_hash_update_elem,
	.map_delete_elem = sock_hash_delete_elem,
	.map_lookup_elem = sock_map_lookup,
	.map_release_uref = sock_hash_release_progs,
	.map_check_btf = map_check_no_btf,
};

static struct sk_psock_progs *sock_map_progs(struct bpf_map *map)
{
	switch (map->map_type) {
	case BPF_MAP_TYPE_SOCKMAP:
		return &container_of(map, struct bpf_stab, map)->progs;
	case BPF_MAP_TYPE_SOCKHASH:
		return &container_of(map, struct bpf_htab, map)->progs;
	default:
		break;
	}

	return NULL;
}

int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
			 struct bpf_prog *old, u32 which)
{
	struct sk_psock_progs *progs = sock_map_progs(map);
	struct bpf_prog **pprog;

	if (!progs)
		return -EOPNOTSUPP;

	switch (which) {
	case BPF_SK_MSG_VERDICT:
		pprog = &progs->msg_parser;
		break;
	case BPF_SK_SKB_STREAM_PARSER:
		pprog = &progs->skb_parser;
		break;
	case BPF_SK_SKB_STREAM_VERDICT:
		pprog = &progs->skb_verdict;
		break;
	default:
		return -EOPNOTSUPP;
	}

	if (old)
		return psock_replace_prog(pprog, prog, old);

	psock_set_prog(pprog, prog);
	return 0;
}

void sk_psock_unlink(struct sock *sk, struct sk_psock_link *link)
{
	switch (link->map->map_type) {
	case BPF_MAP_TYPE_SOCKMAP:
		return sock_map_delete_from_link(link->map, sk,
						 link->link_raw);
	case BPF_MAP_TYPE_SOCKHASH:
		return sock_hash_delete_from_link(link->map, sk,
						  link->link_raw);
	default:
		break;
	}
}