// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/bpf.h>
#include <linux/filter.h>
#include <linux/errno.h>
#include <linux/file.h>
#include <linux/net.h>
#include <linux/workqueue.h>
#include <linux/skmsg.h>
#include <linux/list.h>
#include <linux/jhash.h>

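/* A sockmap (BPF_MAP_TYPE_SOCKMAP) is backed by a flat array of socket
 * pointers (sks) plus the BPF programs shared by every socket in the map;
 * the raw spinlock serializes updates to individual slots.
 */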
struct bpf_stab {
	struct bpf_map map;
	struct sock **sks;
	struct sk_psock_progs progs;
	raw_spinlock_t lock;
};

#define SOCK_CREATE_FLAG_MASK				\
	(BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)

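/* Map creation: keys and values are fixed at 4 bytes (slot index and
 * socket FD), and the memory charge covers one socket pointer per slot.
 */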
static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
{
	struct bpf_stab *stab;
	u64 cost;
	int err;

	if (!capable(CAP_NET_ADMIN))
		return ERR_PTR(-EPERM);
	if (attr->max_entries == 0 ||
	    attr->key_size    != 4 ||
	    attr->value_size  != 4 ||
	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
		return ERR_PTR(-EINVAL);

	stab = kzalloc(sizeof(*stab), GFP_USER);
	if (!stab)
		return ERR_PTR(-ENOMEM);

	bpf_map_init_from_attr(&stab->map, attr);
	raw_spin_lock_init(&stab->lock);

	/* Make sure page count doesn't overflow. */
	cost = (u64) stab->map.max_entries * sizeof(struct sock *);
	err = bpf_map_charge_init(&stab->map.memory, cost);
	if (err)
		goto free_stab;

	stab->sks = bpf_map_area_alloc((u64) stab->map.max_entries *
				       sizeof(struct sock *),
				       stab->map.numa_node);
	if (stab->sks)
		return &stab->map;
	err = -ENOMEM;
	bpf_map_charge_finish(&stab->map.memory);
free_stab:
	kfree(stab);
	return ERR_PTR(err);
}

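/* Resolve attr->target_fd to a sockmap/sockhash and install @prog for the
 * requested attach type (BPF_PROG_ATTACH path).
 */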
int sock_map_get_from_fd(const union bpf_attr *attr, struct bpf_prog *prog)
{
	u32 ufd = attr->target_fd;
	struct bpf_map *map;
	struct fd f;
	int ret;

	f = fdget(ufd);
	map = __bpf_map_get(f);
	if (IS_ERR(map))
		return PTR_ERR(map);
	ret = sock_map_prog_update(map, prog, NULL, attr->attach_type);
	fdput(f);
	return ret;
}

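/* Detach path: the program referenced by attr->attach_bpf_fd must match the
 * expected program type; it is then removed from the map's slot for
 * attr->attach_type via sock_map_prog_update().
 */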
int sock_map_prog_detach(const union bpf_attr *attr, enum bpf_prog_type ptype)
{
	u32 ufd = attr->target_fd;
	struct bpf_prog *prog;
	struct bpf_map *map;
	struct fd f;
	int ret;

	if (attr->attach_flags)
		return -EINVAL;

	f = fdget(ufd);
	map = __bpf_map_get(f);
	if (IS_ERR(map))
		return PTR_ERR(map);

	prog = bpf_prog_get(attr->attach_bpf_fd);
	if (IS_ERR(prog)) {
		ret = PTR_ERR(prog);
		goto put_map;
	}

	if (prog->type != ptype) {
		ret = -EINVAL;
		goto put_prog;
	}

	ret = sock_map_prog_update(map, NULL, prog, attr->attach_type);
put_prog:
	bpf_prog_put(prog);
put_map:
	fdput(f);
	return ret;
}

static void sock_map_sk_acquire(struct sock *sk)
	__acquires(&sk->sk_lock.slock)
{
	lock_sock(sk);
	rcu_read_lock();
}

static void sock_map_sk_release(struct sock *sk)
	__releases(&sk->sk_lock.slock)
{
	rcu_read_unlock();
	release_sock(sk);
}

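/* Each psock keeps a list of links back to the map entries that reference
 * it, so those entries can be removed when the socket itself goes away.
 * link_raw identifies the slot (sockmap) or element (sockhash) within the
 * map.
 */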
static void sock_map_add_link(struct sk_psock *psock,
			      struct sk_psock_link *link,
			      struct bpf_map *map, void *link_raw)
{
	link->link_raw = link_raw;
	link->map = map;
	spin_lock_bh(&psock->link_lock);
	list_add_tail(&link->list, &psock->link);
	spin_unlock_bh(&psock->link_lock);
}

static void sock_map_del_link(struct sock *sk,
			      struct sk_psock *psock, void *link_raw)
{
	struct sk_psock_link *link, *tmp;
	bool strp_stop = false;

	spin_lock_bh(&psock->link_lock);
	list_for_each_entry_safe(link, tmp, &psock->link, list) {
		if (link->link_raw == link_raw) {
			struct bpf_map *map = link->map;
			struct bpf_stab *stab = container_of(map, struct bpf_stab,
							     map);
			if (psock->parser.enabled && stab->progs.skb_parser)
				strp_stop = true;
			list_del(&link->list);
			sk_psock_free_link(link);
		}
	}
	spin_unlock_bh(&psock->link_lock);
	if (strp_stop) {
		write_lock_bh(&sk->sk_callback_lock);
		sk_psock_stop_strp(sk, psock);
		write_unlock_bh(&sk->sk_callback_lock);
	}
}

static void sock_map_unref(struct sock *sk, void *link_raw)
{
	struct sk_psock *psock = sk_psock(sk);

	if (likely(psock)) {
		sock_map_del_link(sk, psock, link_raw);
		sk_psock_put(sk, psock);
	}
}

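/* Prepare @sk for insertion into a map: take references on the map's BPF
 * programs, create a psock for the socket if it does not already have one,
 * install the programs, and start the strparser when both an skb parser
 * and an skb verdict program are present.
 */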
static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
			 struct sock *sk)
{
	struct bpf_prog *msg_parser, *skb_parser, *skb_verdict;
	bool skb_progs, sk_psock_is_new = false;
	struct sk_psock *psock;
	int ret;

	skb_verdict = READ_ONCE(progs->skb_verdict);
	skb_parser = READ_ONCE(progs->skb_parser);
	skb_progs = skb_parser && skb_verdict;
	if (skb_progs) {
		skb_verdict = bpf_prog_inc_not_zero(skb_verdict);
		if (IS_ERR(skb_verdict))
			return PTR_ERR(skb_verdict);
		skb_parser = bpf_prog_inc_not_zero(skb_parser);
		if (IS_ERR(skb_parser)) {
			bpf_prog_put(skb_verdict);
			return PTR_ERR(skb_parser);
		}
	}

	msg_parser = READ_ONCE(progs->msg_parser);
	if (msg_parser) {
		msg_parser = bpf_prog_inc_not_zero(msg_parser);
		if (IS_ERR(msg_parser)) {
			ret = PTR_ERR(msg_parser);
			goto out;
		}
	}

	psock = sk_psock_get_checked(sk);
	if (IS_ERR(psock)) {
		ret = PTR_ERR(psock);
		goto out_progs;
	}

	if (psock) {
		if ((msg_parser && READ_ONCE(psock->progs.msg_parser)) ||
		    (skb_progs  && READ_ONCE(psock->progs.skb_parser))) {
			sk_psock_put(sk, psock);
			ret = -EBUSY;
			goto out_progs;
		}
	} else {
		psock = sk_psock_init(sk, map->numa_node);
		if (!psock) {
			ret = -ENOMEM;
			goto out_progs;
		}
		sk_psock_is_new = true;
	}

	if (msg_parser)
		psock_set_prog(&psock->progs.msg_parser, msg_parser);
	if (sk_psock_is_new) {
		ret = tcp_bpf_init(sk);
		if (ret < 0)
			goto out_drop;
	} else {
		tcp_bpf_reinit(sk);
	}

	write_lock_bh(&sk->sk_callback_lock);
	if (skb_progs && !psock->parser.enabled) {
		ret = sk_psock_init_strp(sk, psock);
		if (ret) {
			write_unlock_bh(&sk->sk_callback_lock);
			goto out_drop;
		}
		psock_set_prog(&psock->progs.skb_verdict, skb_verdict);
		psock_set_prog(&psock->progs.skb_parser, skb_parser);
		sk_psock_start_strp(sk, psock);
	}
	write_unlock_bh(&sk->sk_callback_lock);
	return 0;
out_drop:
	sk_psock_put(sk, psock);
out_progs:
	if (msg_parser)
		bpf_prog_put(msg_parser);
out:
	if (skb_progs) {
		bpf_prog_put(skb_verdict);
		bpf_prog_put(skb_parser);
	}
	return ret;
}

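/* Map teardown. Every socket still present is removed from the map and its
 * psock reference dropped; lock_sock() is taken for each socket, so this
 * may sleep.
 */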
static void sock_map_free(struct bpf_map *map)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	int i;

	/* After the sync no updates or deletes will be in-flight so it
	 * is safe to walk map and remove entries without risking a race
	 * in EEXIST update case.
	 */
	synchronize_rcu();
	for (i = 0; i < stab->map.max_entries; i++) {
		struct sock **psk = &stab->sks[i];
		struct sock *sk;

		sk = xchg(psk, NULL);
		if (sk) {
			sock_hold(sk);
			lock_sock(sk);
			rcu_read_lock();
			sock_map_unref(sk, psk);
			rcu_read_unlock();
			release_sock(sk);
			sock_put(sk);
		}
	}

	/* wait for psock readers accessing its map link */
	synchronize_rcu();

	bpf_map_area_free(stab->sks);
	kfree(stab);
}

static void sock_map_release_progs(struct bpf_map *map)
{
	psock_progs_drop(&container_of(map, struct bpf_stab, map)->progs);
}

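/* Lookup used from BPF helpers, under RCU. Lookups from user space go
 * through sock_map_lookup() below and are not supported.
 */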
static struct sock *__sock_map_lookup_elem(struct bpf_map *map, u32 key)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);

	WARN_ON_ONCE(!rcu_read_lock_held());

	if (unlikely(key >= map->max_entries))
		return NULL;
	return READ_ONCE(stab->sks[key]);
}

static void *sock_map_lookup(struct bpf_map *map, void *key)
{
	return ERR_PTR(-EOPNOTSUPP);
}

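/* Clear a map slot. When @sk_test is non-NULL the slot is only cleared if
 * it still holds that particular socket, which is how deletes driven from
 * a psock link avoid removing a newer entry.
 */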
static int __sock_map_delete(struct bpf_stab *stab, struct sock *sk_test,
			     struct sock **psk)
{
	struct sock *sk;
	int err = 0;

	if (irqs_disabled())
		return -EOPNOTSUPP; /* locks here are hardirq-unsafe */

	raw_spin_lock_bh(&stab->lock);
	sk = *psk;
	if (!sk_test || sk_test == sk)
		sk = xchg(psk, NULL);

	if (likely(sk))
		sock_map_unref(sk, psk);
	else
		err = -EINVAL;

	raw_spin_unlock_bh(&stab->lock);
	return err;
}

static void sock_map_delete_from_link(struct bpf_map *map, struct sock *sk,
				      void *link_raw)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);

	__sock_map_delete(stab, sk, link_raw);
}

static int sock_map_delete_elem(struct bpf_map *map, void *key)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	u32 i = *(u32 *)key;
	struct sock **psk;

	if (unlikely(i >= map->max_entries))
		return -EINVAL;

	psk = &stab->sks[i];
	return __sock_map_delete(stab, NULL, psk);
}

static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	u32 i = key ? *(u32 *)key : U32_MAX;
	u32 *key_next = next;

	if (i == stab->map.max_entries - 1)
		return -ENOENT;
	if (i >= stab->map.max_entries)
		*key_next = 0;
	else
		*key_next = i + 1;
	return 0;
}

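/* Insert @sk at slot @idx, honouring the BPF_NOEXIST/BPF_EXIST update
 * flags. The socket is linked to the map's programs first; any socket
 * already occupying the slot is unlinked and released.
 */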
static int sock_map_update_common(struct bpf_map *map, u32 idx,
				  struct sock *sk, u64 flags)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	struct inet_connection_sock *icsk = inet_csk(sk);
	struct sk_psock_link *link;
	struct sk_psock *psock;
	struct sock *osk;
	int ret;

	WARN_ON_ONCE(!rcu_read_lock_held());
	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;
	if (unlikely(idx >= map->max_entries))
		return -E2BIG;
	if (unlikely(rcu_access_pointer(icsk->icsk_ulp_data)))
		return -EINVAL;

	link = sk_psock_init_link();
	if (!link)
		return -ENOMEM;

	ret = sock_map_link(map, &stab->progs, sk);
	if (ret < 0)
		goto out_free;

	psock = sk_psock(sk);
	WARN_ON_ONCE(!psock);

	raw_spin_lock_bh(&stab->lock);
	osk = stab->sks[idx];
	if (osk && flags == BPF_NOEXIST) {
		ret = -EEXIST;
		goto out_unlock;
	} else if (!osk && flags == BPF_EXIST) {
		ret = -ENOENT;
		goto out_unlock;
	}

	sock_map_add_link(psock, link, map, &stab->sks[idx]);
	stab->sks[idx] = sk;
	if (osk)
		sock_map_unref(osk, &stab->sks[idx]);
	raw_spin_unlock_bh(&stab->lock);
	return 0;
out_unlock:
	raw_spin_unlock_bh(&stab->lock);
	if (psock)
		sk_psock_put(sk, psock);
out_free:
	sk_psock_free_link(link);
	return ret;
}

static bool sock_map_op_okay(const struct bpf_sock_ops_kern *ops)
{
	return ops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB ||
	       ops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB;
}

static bool sock_map_sk_is_suitable(const struct sock *sk)
{
	return sk->sk_type == SOCK_STREAM &&
	       sk->sk_protocol == IPPROTO_TCP;
}

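/* Syscall-side update: the value is a socket FD. Only established TCP
 * sockets are accepted, and the socket lock plus the RCU read lock are
 * held across the update.
 */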
static int sock_map_update_elem(struct bpf_map *map, void *key,
				void *value, u64 flags)
{
	u32 ufd = *(u32 *)value;
	u32 idx = *(u32 *)key;
	struct socket *sock;
	struct sock *sk;
	int ret;

	sock = sockfd_lookup(ufd, &ret);
	if (!sock)
		return ret;
	sk = sock->sk;
	if (!sk) {
		ret = -EINVAL;
		goto out;
	}
	if (!sock_map_sk_is_suitable(sk)) {
		ret = -EOPNOTSUPP;
		goto out;
	}

	sock_map_sk_acquire(sk);
	if (sk->sk_state != TCP_ESTABLISHED)
		ret = -EOPNOTSUPP;
	else
		ret = sock_map_update_common(map, idx, sk, flags);
	sock_map_sk_release(sk);
out:
	fput(sock->file);
	return ret;
}

BPF_CALL_4(bpf_sock_map_update, struct bpf_sock_ops_kern *, sops,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	WARN_ON_ONCE(!rcu_read_lock_held());

	if (likely(sock_map_sk_is_suitable(sops->sk) &&
		   sock_map_op_okay(sops)))
		return sock_map_update_common(map, *(u32 *)key, sops->sk,
					      flags);
	return -EOPNOTSUPP;
}

const struct bpf_func_proto bpf_sock_map_update_proto = {
	.func		= bpf_sock_map_update,
	.gpl_only	= false,
	.pkt_access	= true,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb,
	   struct bpf_map *, map, u32, key, u64, flags)
{
	struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;
	tcb->bpf.flags = flags;
	tcb->bpf.sk_redir = __sock_map_lookup_elem(map, key);
	if (!tcb->bpf.sk_redir)
		return SK_DROP;
	return SK_PASS;
}

const struct bpf_func_proto bpf_sk_redirect_map_proto = {
	.func           = bpf_sk_redirect_map,
	.gpl_only       = false,
	.ret_type       = RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type      = ARG_CONST_MAP_PTR,
	.arg3_type      = ARG_ANYTHING,
	.arg4_type      = ARG_ANYTHING,
};

BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg *, msg,
	   struct bpf_map *, map, u32, key, u64, flags)
{
	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;
	msg->flags = flags;
	msg->sk_redir = __sock_map_lookup_elem(map, key);
	if (!msg->sk_redir)
		return SK_DROP;
	return SK_PASS;
}

const struct bpf_func_proto bpf_msg_redirect_map_proto = {
	.func           = bpf_msg_redirect_map,
	.gpl_only       = false,
	.ret_type       = RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type      = ARG_CONST_MAP_PTR,
	.arg3_type      = ARG_ANYTHING,
	.arg4_type      = ARG_ANYTHING,
};

const struct bpf_map_ops sock_map_ops = {
	.map_alloc		= sock_map_alloc,
	.map_free		= sock_map_free,
	.map_get_next_key	= sock_map_get_next_key,
	.map_update_elem	= sock_map_update_elem,
	.map_delete_elem	= sock_map_delete_elem,
	.map_lookup_elem	= sock_map_lookup,
	.map_release_uref	= sock_map_release_progs,
	.map_check_btf		= map_check_no_btf,
};

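/* BPF_MAP_TYPE_SOCKHASH: same idea as sockmap, but keyed by an arbitrary
 * key of up to MAX_BPF_STACK bytes and backed by a bucketed hash table.
 * Each element stores its hash, the socket pointer and a copy of the key.
 */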
struct bpf_htab_elem {
	struct rcu_head rcu;
	u32 hash;
	struct sock *sk;
	struct hlist_node node;
	u8 key[0];
};

struct bpf_htab_bucket {
	struct hlist_head head;
	raw_spinlock_t lock;
};

struct bpf_htab {
	struct bpf_map map;
	struct bpf_htab_bucket *buckets;
	u32 buckets_num;
	u32 elem_size;
	struct sk_psock_progs progs;
	atomic_t count;
};

static inline u32 sock_hash_bucket_hash(const void *key, u32 len)
{
	return jhash(key, len, 0);
}

static struct bpf_htab_bucket *sock_hash_select_bucket(struct bpf_htab *htab,
						       u32 hash)
{
	return &htab->buckets[hash & (htab->buckets_num - 1)];
}

static struct bpf_htab_elem *
sock_hash_lookup_elem_raw(struct hlist_head *head, u32 hash, void *key,
			  u32 key_size)
{
	struct bpf_htab_elem *elem;

	hlist_for_each_entry_rcu(elem, head, node) {
		if (elem->hash == hash &&
		    !memcmp(&elem->key, key, key_size))
			return elem;
	}

	return NULL;
}

static struct sock *__sock_hash_lookup_elem(struct bpf_map *map, void *key)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	u32 key_size = map->key_size, hash;
	struct bpf_htab_bucket *bucket;
	struct bpf_htab_elem *elem;

	WARN_ON_ONCE(!rcu_read_lock_held());

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);

	return elem ? elem->sk : NULL;
}

static void sock_hash_free_elem(struct bpf_htab *htab,
				struct bpf_htab_elem *elem)
{
	atomic_dec(&htab->count);
	kfree_rcu(elem, rcu);
}

static void sock_hash_delete_from_link(struct bpf_map *map, struct sock *sk,
				       void *link_raw)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_elem *elem_probe, *elem = link_raw;
	struct bpf_htab_bucket *bucket;

	WARN_ON_ONCE(!rcu_read_lock_held());
	bucket = sock_hash_select_bucket(htab, elem->hash);

	/* elem may be deleted in parallel from the map, but access here
	 * is okay since it's going away only after RCU grace period.
	 * However, we need to check whether it's still present.
	 */
	raw_spin_lock_bh(&bucket->lock);
	elem_probe = sock_hash_lookup_elem_raw(&bucket->head, elem->hash,
					       elem->key, map->key_size);
	if (elem_probe && elem_probe == elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
	}
	raw_spin_unlock_bh(&bucket->lock);
}

static int sock_hash_delete_elem(struct bpf_map *map, void *key)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	u32 hash, key_size = map->key_size;
	struct bpf_htab_bucket *bucket;
	struct bpf_htab_elem *elem;
	int ret = -ENOENT;

	if (irqs_disabled())
		return -EOPNOTSUPP; /* locks here are hardirq-unsafe */

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);

	raw_spin_lock_bh(&bucket->lock);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
	if (elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
		ret = 0;
	}
	raw_spin_unlock_bh(&bucket->lock);
	return ret;
}

static struct bpf_htab_elem *sock_hash_alloc_elem(struct bpf_htab *htab,
						  void *key, u32 key_size,
						  u32 hash, struct sock *sk,
						  struct bpf_htab_elem *old)
{
	struct bpf_htab_elem *new;

	if (atomic_inc_return(&htab->count) > htab->map.max_entries) {
		if (!old) {
			atomic_dec(&htab->count);
			return ERR_PTR(-E2BIG);
		}
	}

	new = kmalloc_node(htab->elem_size, GFP_ATOMIC | __GFP_NOWARN,
			   htab->map.numa_node);
	if (!new) {
		atomic_dec(&htab->count);
		return ERR_PTR(-ENOMEM);
	}
	memcpy(new->key, key, key_size);
	new->sk = sk;
	new->hash = hash;
	return new;
}

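/* Hash-table counterpart of sock_map_update_common(): link the socket to
 * the map's programs, then insert the new element at the head of its
 * bucket so concurrent RCU lookups find it before any element it replaces.
 */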
static int sock_hash_update_common(struct bpf_map *map, void *key,
				   struct sock *sk, u64 flags)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct inet_connection_sock *icsk = inet_csk(sk);
	u32 key_size = map->key_size, hash;
	struct bpf_htab_elem *elem, *elem_new;
	struct bpf_htab_bucket *bucket;
	struct sk_psock_link *link;
	struct sk_psock *psock;
	int ret;

	WARN_ON_ONCE(!rcu_read_lock_held());
	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;
	if (unlikely(icsk->icsk_ulp_data))
		return -EINVAL;

	link = sk_psock_init_link();
	if (!link)
		return -ENOMEM;

	ret = sock_map_link(map, &htab->progs, sk);
	if (ret < 0)
		goto out_free;

	psock = sk_psock(sk);
	WARN_ON_ONCE(!psock);

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);

	raw_spin_lock_bh(&bucket->lock);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
	if (elem && flags == BPF_NOEXIST) {
		ret = -EEXIST;
		goto out_unlock;
	} else if (!elem && flags == BPF_EXIST) {
		ret = -ENOENT;
		goto out_unlock;
	}

	elem_new = sock_hash_alloc_elem(htab, key, key_size, hash, sk, elem);
	if (IS_ERR(elem_new)) {
		ret = PTR_ERR(elem_new);
		goto out_unlock;
	}

	sock_map_add_link(psock, link, map, elem_new);
	/* Add new element to the head of the list, so that
	 * concurrent search will find it before old elem.
	 */
	hlist_add_head_rcu(&elem_new->node, &bucket->head);
	if (elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
	}
	raw_spin_unlock_bh(&bucket->lock);
	return 0;
out_unlock:
	raw_spin_unlock_bh(&bucket->lock);
	sk_psock_put(sk, psock);
out_free:
	sk_psock_free_link(link);
	return ret;
}

static int sock_hash_update_elem(struct bpf_map *map, void *key,
				 void *value, u64 flags)
{
	u32 ufd = *(u32 *)value;
	struct socket *sock;
	struct sock *sk;
	int ret;

	sock = sockfd_lookup(ufd, &ret);
	if (!sock)
		return ret;
	sk = sock->sk;
	if (!sk) {
		ret = -EINVAL;
		goto out;
	}
	if (!sock_map_sk_is_suitable(sk)) {
		ret = -EOPNOTSUPP;
		goto out;
	}

	sock_map_sk_acquire(sk);
	if (sk->sk_state != TCP_ESTABLISHED)
		ret = -EOPNOTSUPP;
	else
		ret = sock_hash_update_common(map, key, sk, flags);
	sock_map_sk_release(sk);
out:
	fput(sock->file);
	return ret;
}

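/* Key iteration for the syscall interface: continue within the current
 * element's bucket, then fall through to the first element of the
 * following buckets. A missing or stale @key restarts from bucket 0.
 */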
static int sock_hash_get_next_key(struct bpf_map *map, void *key,
				  void *key_next)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_elem *elem, *elem_next;
	u32 hash, key_size = map->key_size;
	struct hlist_head *head;
	int i = 0;

	if (!key)
		goto find_first_elem;
	hash = sock_hash_bucket_hash(key, key_size);
	head = &sock_hash_select_bucket(htab, hash)->head;
	elem = sock_hash_lookup_elem_raw(head, hash, key, key_size);
	if (!elem)
		goto find_first_elem;

	elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_next_rcu(&elem->node)),
				     struct bpf_htab_elem, node);
	if (elem_next) {
		memcpy(key_next, elem_next->key, key_size);
		return 0;
	}

	i = hash & (htab->buckets_num - 1);
	i++;
find_first_elem:
	for (; i < htab->buckets_num; i++) {
		head = &sock_hash_select_bucket(htab, i)->head;
		elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_first_rcu(head)),
					     struct bpf_htab_elem, node);
		if (elem_next) {
			memcpy(key_next, elem_next->key, key_size);
			return 0;
		}
	}

	return -ENOENT;
}

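/* Sockhash creation. The number of buckets is max_entries rounded up to a
 * power of two, and the memory charge accounts for both the bucket array
 * and the worst-case element allocations.
 */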
static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
{
	struct bpf_htab *htab;
	int i, err;
	u64 cost;

	if (!capable(CAP_NET_ADMIN))
		return ERR_PTR(-EPERM);
	if (attr->max_entries == 0 ||
	    attr->key_size    == 0 ||
	    attr->value_size  != 4 ||
	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
		return ERR_PTR(-EINVAL);
	if (attr->key_size > MAX_BPF_STACK)
		return ERR_PTR(-E2BIG);

	htab = kzalloc(sizeof(*htab), GFP_USER);
	if (!htab)
		return ERR_PTR(-ENOMEM);

	bpf_map_init_from_attr(&htab->map, attr);

	htab->buckets_num = roundup_pow_of_two(htab->map.max_entries);
	htab->elem_size = sizeof(struct bpf_htab_elem) +
			  round_up(htab->map.key_size, 8);
	if (htab->buckets_num == 0 ||
	    htab->buckets_num > U32_MAX / sizeof(struct bpf_htab_bucket)) {
		err = -EINVAL;
		goto free_htab;
	}

	cost = (u64) htab->buckets_num * sizeof(struct bpf_htab_bucket) +
	       (u64) htab->elem_size * htab->map.max_entries;
	if (cost >= U32_MAX - PAGE_SIZE) {
		err = -EINVAL;
		goto free_htab;
	}
	err = bpf_map_charge_init(&htab->map.memory, cost);
	if (err)
		goto free_htab;

	htab->buckets = bpf_map_area_alloc(htab->buckets_num *
					   sizeof(struct bpf_htab_bucket),
					   htab->map.numa_node);
	if (!htab->buckets) {
		bpf_map_charge_finish(&htab->map.memory);
		err = -ENOMEM;
		goto free_htab;
	}

	for (i = 0; i < htab->buckets_num; i++) {
		INIT_HLIST_HEAD(&htab->buckets[i].head);
		raw_spin_lock_init(&htab->buckets[i].lock);
	}

	return &htab->map;
free_htab:
	kfree(htab);
	return ERR_PTR(err);
}

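/* Sockhash teardown. Entries are moved off each bucket under the bucket
 * lock with an extra socket reference held, then unlinked outside the
 * lock where lock_sock() may sleep.
 */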
static void sock_hash_free(struct bpf_map *map)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_bucket *bucket;
	struct hlist_head unlink_list;
	struct bpf_htab_elem *elem;
	struct hlist_node *node;
	int i;

	/* After the sync no updates or deletes will be in-flight so it
	 * is safe to walk map and remove entries without risking a race
	 * in EEXIST update case.
	 */
	synchronize_rcu();
	for (i = 0; i < htab->buckets_num; i++) {
		bucket = sock_hash_select_bucket(htab, i);

		/* We are racing with sock_hash_delete_from_link to
		 * enter the spin-lock critical section. Every socket on
		 * the list is still linked to sockhash. Since link
		 * exists, psock exists and holds a ref to socket. That
		 * lets us to grab a socket ref too.
		 */
		raw_spin_lock_bh(&bucket->lock);
		hlist_for_each_entry(elem, &bucket->head, node)
			sock_hold(elem->sk);
		hlist_move_list(&bucket->head, &unlink_list);
		raw_spin_unlock_bh(&bucket->lock);

		/* Process removed entries out of atomic context to
		 * block for socket lock before deleting the psock's
		 * link to sockhash.
		 */
		hlist_for_each_entry_safe(elem, node, &unlink_list, node) {
			hlist_del(&elem->node);
			lock_sock(elem->sk);
			rcu_read_lock();
			sock_map_unref(elem->sk, elem);
			rcu_read_unlock();
			release_sock(elem->sk);
			sock_put(elem->sk);
			sock_hash_free_elem(htab, elem);
		}
	}

	/* wait for psock readers accessing its map link */
	synchronize_rcu();

	/* wait for psock readers accessing its map link */
	synchronize_rcu();

	bpf_map_area_free(htab->buckets);
	kfree(htab);
}

static void sock_hash_release_progs(struct bpf_map *map)
{
	psock_progs_drop(&container_of(map, struct bpf_htab, map)->progs);
}

BPF_CALL_4(bpf_sock_hash_update, struct bpf_sock_ops_kern *, sops,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	WARN_ON_ONCE(!rcu_read_lock_held());

	if (likely(sock_map_sk_is_suitable(sops->sk) &&
		   sock_map_op_okay(sops)))
		return sock_hash_update_common(map, key, sops->sk, flags);
	return -EOPNOTSUPP;
}

const struct bpf_func_proto bpf_sock_hash_update_proto = {
	.func		= bpf_sock_hash_update,
	.gpl_only	= false,
	.pkt_access	= true,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

BPF_CALL_4(bpf_sk_redirect_hash, struct sk_buff *, skb,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;
	tcb->bpf.flags = flags;
	tcb->bpf.sk_redir = __sock_hash_lookup_elem(map, key);
	if (!tcb->bpf.sk_redir)
		return SK_DROP;
	return SK_PASS;
}

const struct bpf_func_proto bpf_sk_redirect_hash_proto = {
	.func           = bpf_sk_redirect_hash,
	.gpl_only       = false,
	.ret_type       = RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type      = ARG_CONST_MAP_PTR,
	.arg3_type      = ARG_PTR_TO_MAP_KEY,
	.arg4_type      = ARG_ANYTHING,
};

BPF_CALL_4(bpf_msg_redirect_hash, struct sk_msg *, msg,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;
	msg->flags = flags;
	msg->sk_redir = __sock_hash_lookup_elem(map, key);
	if (!msg->sk_redir)
		return SK_DROP;
	return SK_PASS;
}

const struct bpf_func_proto bpf_msg_redirect_hash_proto = {
	.func           = bpf_msg_redirect_hash,
	.gpl_only       = false,
	.ret_type       = RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type      = ARG_CONST_MAP_PTR,
	.arg3_type      = ARG_PTR_TO_MAP_KEY,
	.arg4_type      = ARG_ANYTHING,
};

const struct bpf_map_ops sock_hash_ops = {
	.map_alloc		= sock_hash_alloc,
	.map_free		= sock_hash_free,
	.map_get_next_key	= sock_hash_get_next_key,
	.map_update_elem	= sock_hash_update_elem,
	.map_delete_elem	= sock_hash_delete_elem,
	.map_lookup_elem	= sock_map_lookup,
	.map_release_uref	= sock_hash_release_progs,
	.map_check_btf		= map_check_no_btf,
};

static struct sk_psock_progs *sock_map_progs(struct bpf_map *map)
{
	switch (map->map_type) {
	case BPF_MAP_TYPE_SOCKMAP:
		return &container_of(map, struct bpf_stab, map)->progs;
	case BPF_MAP_TYPE_SOCKHASH:
		return &container_of(map, struct bpf_htab, map)->progs;
	default:
		break;
	}

	return NULL;
}

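/* Attach (@old == NULL) or replace/detach a program on a sockmap or
 * sockhash. @which selects which program slot (msg parser, skb stream
 * parser or skb stream verdict) is affected.
 */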
int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
			 struct bpf_prog *old, u32 which)
{
	struct sk_psock_progs *progs = sock_map_progs(map);
	struct bpf_prog **pprog;

	if (!progs)
		return -EOPNOTSUPP;

	switch (which) {
	case BPF_SK_MSG_VERDICT:
		pprog = &progs->msg_parser;
		break;
	case BPF_SK_SKB_STREAM_PARSER:
		pprog = &progs->skb_parser;
		break;
	case BPF_SK_SKB_STREAM_VERDICT:
		pprog = &progs->skb_verdict;
		break;
	default:
		return -EOPNOTSUPP;
	}

	if (old)
		return psock_replace_prog(pprog, prog, old);

	psock_set_prog(pprog, prog);
	return 0;
}

void sk_psock_unlink(struct sock *sk, struct sk_psock_link *link)
{
	switch (link->map->map_type) {
	case BPF_MAP_TYPE_SOCKMAP:
		return sock_map_delete_from_link(link->map, sk,
						 link->link_raw);
	case BPF_MAP_TYPE_SOCKHASH:
		return sock_hash_delete_from_link(link->map, sk,
						  link->link_raw);
	default:
		break;
	}
}