• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  * Based on net/ipv4/inet_hashtables.c
4  * Authors:	Lotsa people, from code originally in tcp
5  *
6  * Based on net/ipv6/inet6_hashtables.c
7  * Authors:	Lotsa people, from code originally in tcp, generalised here
8  *		by Arnaldo Carvalho de Melo <acme@mandriva.com>
9  *
10  * Based on include/net/ip.h
11  * Authors:	Ross Biro
12  *		Fred N. van Kempen, <waltje@uWalt.NL.Mugnet.ORG>
13  *		Alan Cox, <gw4pts@gw4pts.ampr.org>
14  *
15  * Changes:
16  *		Mike McLagan    :       Routing by source
17  *
18  * Based on include/net/ipv6.h
19  *	Authors:
20  *	Pedro Roque		<roque@di.fc.ul.pt>
21  *
22  * Based on net/core/secure_seq.c
23  * Copyright (C) 2016 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
24  *
25  * NewIP INET
26  * An implementation of the TCP/IP protocol suite for the LINUX
27  * operating system. NewIP INET is implemented using the BSD Socket
28  * interface as the means of communication with the user level.
29  *
30  * Generic NewIP INET transport hashtables
31  */
32 #define pr_fmt(fmt) KBUILD_MODNAME ": [%s:%d] " fmt, __func__, __LINE__
33 
34 #include <linux/module.h>
35 #include <linux/random.h>
36 
37 #include <net/nip_addrconf.h>
38 #include <net/inet_connection_sock.h>
39 #include <net/inet_hashtables.h>
40 #include <net/ninet_hashtables.h>
41 #include <net/secure_seq.h>
42 #include "tcp_nip_parameter.h"
43 
44 #define TCP_SEQ_SCALE_SHIFT 6
45 
46 static siphash_key_t net_secret __read_mostly;
47 
net_secret_init(void)48 static __always_inline void net_secret_init(void)
49 {
50 	net_get_random_once(&net_secret, sizeof(net_secret));
51 }
52 
53 #ifdef CONFIG_INET
seq_scale(u32 seq)54 static u32 seq_scale(u32 seq)
55 {
56 	/* As close as possible to RFC 793, which
57 	 * suggests using a 250 kHz clock.
58 	 * Further reading shows this assumes 2 Mb/s networks.
59 	 * or 10 Mb/s Ethernet, a 1 MHz clock is appropriate.
60 	 * For 10 Gb/s Ethernet, a 1 GHz clock should be ok, but
61 	 * we also need to limit the resolution so that the u32 seq
62 	 * overlaps less than one time per MSL (2 minutes).
63 	 * Choosing a clock of 64 ns period is OK. (period of 274 s)
64 	 */
65 	return seq + (ktime_get_real_ns() >> TCP_SEQ_SCALE_SHIFT);
66 }
67 #endif
68 
secure_tcp_nip_sequence_number(const __be32 * saddr,const __be32 * daddr,__be16 sport,__be16 dport)69 __u32 secure_tcp_nip_sequence_number(const __be32 *saddr, const __be32 *daddr,
70 				     __be16 sport, __be16 dport)
71 {
72 	const struct {
73 		struct nip_addr saddr;
74 		struct nip_addr daddr;
75 		__be16 sport;
76 		__be16 dport;
77 	} __aligned(SIPHASH_ALIGNMENT) combined = {
78 		.saddr = *(struct nip_addr *)saddr,
79 		.daddr = *(struct nip_addr *)daddr,
80 		.sport = sport,
81 		.dport = dport,
82 	};
83 	u32 hash;
84 
85 	net_secret_init();
86 	hash = siphash(&combined, offsetofend(typeof(combined), dport),
87 		       &net_secret);
88 	return seq_scale(hash);
89 }
90 EXPORT_SYMBOL_GPL(secure_tcp_nip_sequence_number);
91 
secure_newip_port_ephemeral(const __be32 * saddr,const __be32 * daddr,__be16 dport)92 u64 secure_newip_port_ephemeral(const __be32 *saddr, const __be32 *daddr,
93 				__be16 dport)
94 {
95 	const struct {
96 		struct nip_addr saddr;
97 		struct nip_addr daddr;
98 		__be16 dport;
99 	} __aligned(SIPHASH_ALIGNMENT) combined = {
100 		.saddr = *(struct nip_addr *)saddr,
101 		.daddr = *(struct nip_addr *)daddr,
102 		.dport = dport,
103 	};
104 	net_secret_init();
105 	return siphash(&combined, offsetofend(typeof(combined), dport),
106 		       &net_secret);
107 }
108 EXPORT_SYMBOL_GPL(secure_newip_port_ephemeral);
109 
nip_portaddr_hash(const struct net * net,const struct nip_addr * saddr,unsigned int port)110 static inline u32 nip_portaddr_hash(const struct net *net,
111 				    const struct nip_addr *saddr,
112 				    unsigned int port)
113 {
114 	u32 v = (__force u32)saddr->NIP_ADDR_FIELD32[0] ^ (__force u32)saddr->NIP_ADDR_FIELD32[1];
115 
116 	return jhash_1word(v, net_hash_mix(net)) ^ port;
117 }
118 
__nip_addr_jhash(const struct nip_addr * a,const u32 initval)119 static u32 __nip_addr_jhash(const struct nip_addr *a, const u32 initval)
120 {
121 	u32 v = (__force u32)a->NIP_ADDR_FIELD32[0] ^ (__force u32)a->NIP_ADDR_FIELD32[1];
122 
123 	return jhash_3words(v,
124 			    (__force u32)a->NIP_ADDR_FIELD32[0],
125 			    (__force u32)a->NIP_ADDR_FIELD32[1],
126 			    initval);
127 }
128 
129 static struct inet_listen_hashbucket *
ninet_lhash2_bucket_sk(struct inet_hashinfo * h,struct sock * sk)130 ninet_lhash2_bucket_sk(struct inet_hashinfo *h, struct sock *sk)
131 {
132 	u32 hash = nip_portaddr_hash(sock_net(sk),
133 					  &sk->SK_NIP_RCV_SADDR,
134 					  inet_sk(sk)->inet_num);
135 	return inet_lhash2_bucket(h, hash);
136 }
137 
ninet_hash2(struct inet_hashinfo * h,struct sock * sk)138 static void ninet_hash2(struct inet_hashinfo *h, struct sock *sk)
139 {
140 	struct inet_listen_hashbucket *ilb2;
141 
142 	if (!h->lhash2)
143 		return;
144 
145 	ilb2 = ninet_lhash2_bucket_sk(h, sk);
146 
147 	spin_lock(&ilb2->lock);
148 	hlist_add_head_rcu(&inet_csk(sk)->icsk_listen_portaddr_node, &ilb2->head);
149 
150 	ilb2->count++;
151 	spin_unlock(&ilb2->lock);
152 }
153 
154 /* Function
155  *    Returns the hash value based on the passed argument
156  * Parameter
157  *    net: The namespace
158  *    laddr: The destination address
159  *    lport: Destination port
160  *    faddr: Source address
161  *    fport: Source port
162  */
ninet_ehashfn(const struct net * net,const struct nip_addr * laddr,const u16 lport,const struct nip_addr * faddr,const __be16 fport)163 u32 ninet_ehashfn(const struct net *net,
164 		  const struct nip_addr *laddr, const u16 lport,
165 		  const struct nip_addr *faddr, const __be16 fport)
166 {
167 	static u32 ninet_ehash_secret __read_mostly;
168 	static u32 ninet_hash_secret __read_mostly;
169 
170 	u32 lhash, fhash;
171 
172 	net_get_random_once(&ninet_ehash_secret, sizeof(ninet_ehash_secret));
173 	net_get_random_once(&ninet_hash_secret, sizeof(ninet_hash_secret));
174 
175 	/* Ipv6 uses S6_ADdr32 [3], the last 32bits of the address */
176 	lhash = (__force u32)laddr->NIP_ADDR_FIELD32[0];
177 	fhash = __nip_addr_jhash(faddr, ninet_hash_secret);
178 
179 	return __ninet_ehashfn(lhash, lport, fhash, fport,
180 			       ninet_ehash_secret + net_hash_mix(net));
181 }
182 
183 /* Function
184  *    The socket is put into the Listen hash in case the server finds
185  *    the socket in the second handshake
186  * Parameter
187  *    sk: Transmission control block
188  *    osk: old socket
189  */
__ninet_hash(struct sock * sk,struct sock * osk)190 int __ninet_hash(struct sock *sk, struct sock *osk)
191 {
192 	struct inet_hashinfo *hashinfo = sk->sk_prot->h.hashinfo;
193 	struct inet_listen_hashbucket *ilb;
194 	int err = 0;
195 
196 	if (sk->sk_state != TCP_LISTEN) {
197 		local_bh_disable();
198 		inet_ehash_nolisten(sk, osk, NULL);
199 		local_bh_enable();
200 		return 0;
201 	}
202 	WARN_ON(!sk_unhashed(sk));
203 	ilb = &hashinfo->listening_hash[inet_sk_listen_hashfn(sk)];
204 
205 	spin_lock(&ilb->lock);
206 
207 	__sk_nulls_add_node_rcu(sk, &ilb->nulls_head);
208 
209 	ninet_hash2(hashinfo, sk);
210 	ilb->count++;
211 	sock_set_flag(sk, SOCK_RCU_FREE);
212 	sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1);
213 
214 	spin_unlock(&ilb->lock);
215 
216 	return err;
217 }
218 
ninet_hash(struct sock * sk)219 int ninet_hash(struct sock *sk)
220 {
221 	int err = 0;
222 
223 	if (sk->sk_state != TCP_CLOSE) {
224 		local_bh_disable();
225 		err = __ninet_hash(sk, NULL);
226 		local_bh_enable();
227 	}
228 
229 	return err;
230 }
231 
ninet_unhash2(struct inet_hashinfo * h,struct sock * sk)232 static void ninet_unhash2(struct inet_hashinfo *h, struct sock *sk)
233 {
234 	struct inet_listen_hashbucket *ilb2;
235 
236 	if (!h->lhash2 ||
237 	    WARN_ON_ONCE(hlist_unhashed(&inet_csk(sk)->icsk_listen_portaddr_node)))
238 		return;
239 
240 	ilb2 = ninet_lhash2_bucket_sk(h, sk);
241 
242 	spin_lock(&ilb2->lock);
243 	hlist_del_init_rcu(&inet_csk(sk)->icsk_listen_portaddr_node);
244 	if (ilb2->count)
245 		ilb2->count--;
246 	spin_unlock(&ilb2->lock);
247 }
248 
__ninet_unhash(struct sock * sk,struct inet_listen_hashbucket * ilb)249 static void __ninet_unhash(struct sock *sk, struct inet_listen_hashbucket *ilb)
250 {
251 	if (sk_unhashed(sk))
252 		return;
253 
254 	if (ilb) {
255 		struct inet_hashinfo *hashinfo = sk->sk_prot->h.hashinfo;
256 
257 		ninet_unhash2(hashinfo, sk);
258 		if (ilb->count)
259 			ilb->count--;
260 	}
261 	__sk_nulls_del_node_init_rcu(sk);
262 	sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1);
263 }
264 
ninet_unhash(struct sock * sk)265 void ninet_unhash(struct sock *sk)
266 {
267 	struct inet_hashinfo *hashinfo = sk->sk_prot->h.hashinfo;
268 
269 	if (sk_unhashed(sk))
270 		return;
271 
272 	if (sk->sk_state == TCP_LISTEN) {
273 		struct inet_listen_hashbucket *ilb;
274 
275 		ilb = &hashinfo->listening_hash[inet_sk_listen_hashfn(sk)];
276 		/* Don't disable bottom halves while acquiring the lock to
277 		 * avoid circular locking dependency on PREEMPT_RT.
278 		 */
279 		spin_lock(&ilb->lock);
280 		__ninet_unhash(sk, ilb);
281 		spin_unlock(&ilb->lock);
282 	} else {
283 		spinlock_t *lock = inet_ehash_lockp(hashinfo, sk->sk_hash);
284 
285 		spin_lock_bh(lock);
286 		__ninet_unhash(sk, NULL);
287 		spin_unlock_bh(lock);
288 	}
289 }
290 
291 /* Function
292  *    Find transport control blocks based on address and port in the ehash table.
293  *    If found, three handshakes have been made and a connection has been established,
294  *    and normal communication can proceed.
295  * Parameter
296  *    net: The namespace
297  *    hashinfo: A global scalar of type tcp_hashinfo that stores tcp_SOCK(including ESTABLISHED,
298  *              listen, and bind) for various states of the current system.
299  *    saddr: Source address
300  *    sport: Source port
301  *    daddr: The destination address
302  *    hnum: Destination port
303  */
__ninet_lookup_established(struct net * net,struct inet_hashinfo * hashinfo,const struct nip_addr * saddr,const __be16 sport,const struct nip_addr * daddr,const u16 hnum,const int dif)304 struct sock *__ninet_lookup_established(struct net *net,
305 					struct inet_hashinfo *hashinfo,
306 					   const struct nip_addr *saddr,
307 					   const __be16 sport,
308 					   const struct nip_addr *daddr,
309 					   const u16 hnum,
310 					   const int dif)
311 {
312 	struct sock *sk;
313 	const struct hlist_nulls_node *node;
314 
315 	const __portpair ports = INET_COMBINED_PORTS(sport, hnum);
316 	/* mask ensures that the hash index is valid without memory overruns */
317 	unsigned int hash = ninet_ehashfn(net, daddr, hnum, saddr, sport);
318 	unsigned int slot = hash & hashinfo->ehash_mask;
319 	struct inet_ehash_bucket *head = &hashinfo->ehash[slot];
320 
321 begin:
322 	sk_nulls_for_each_rcu(sk, node, &head->chain) {
323 		if (sk->sk_hash != hash)
324 			continue;
325 		if (!ninet_match(sk, net, saddr, daddr, ports, dif))
326 			continue;
327 		if (unlikely(!refcount_inc_not_zero(&sk->sk_refcnt))) {
328 			nip_dbg("sk->sk_refcnt == 0");
329 			goto out;
330 		}
331 
332 		if (unlikely(!ninet_match(sk, net, saddr, daddr, ports, dif))) {
333 			sock_gen_put(sk);
334 			goto begin;
335 		}
336 		goto found;
337 	}
338 	if (get_nulls_value(node) != slot)
339 		goto begin;
340 out:
341 	sk = NULL;
342 found:
343 	return sk;
344 }
345 
nip_tcp_compute_score(struct sock * sk,struct net * net,const unsigned short hnum,const struct nip_addr * daddr,const int dif,int sdif)346 static int nip_tcp_compute_score(struct sock *sk, struct net *net,
347 					const unsigned short hnum,
348 					const struct nip_addr *daddr,
349 					const int dif, int sdif)
350 {
351 	int score = -1;
352 
353 	if (inet_sk(sk)->inet_num == hnum && sk->sk_family == PF_NINET &&
354 	    net_eq(sock_net(sk), net)) {
355 		score = 1;
356 		if (!nip_addr_eq(&sk->SK_NIP_RCV_SADDR, &nip_any_addr)) {
357 			if (!nip_addr_eq(&sk->SK_NIP_RCV_SADDR, daddr))
358 				return -1;
359 			score++;
360 		}
361 		if (!inet_sk_bound_dev_eq(net, sk->sk_bound_dev_if, dif, sdif))
362 			return -1;
363 		score++;
364 		if (READ_ONCE(sk->sk_incoming_cpu) == raw_smp_processor_id())
365 			score++;
366 	}
367 
368 	return score;
369 }
370 
371 /* nip reuseport */
ninet_lhash2_lookup(struct net * net,struct inet_listen_hashbucket * ilb2,struct sk_buff * skb,int doff,const struct nip_addr * saddr,__be16 sport,const struct nip_addr * daddr,const unsigned short hnum,const int dif,const int sdif)372 static struct sock *ninet_lhash2_lookup(struct net *net,
373 					struct inet_listen_hashbucket *ilb2,
374 					struct sk_buff *skb, int doff,
375 					const struct nip_addr *saddr, __be16 sport,
376 					const struct nip_addr *daddr, const unsigned short hnum,
377 					const int dif, const int sdif)
378 {
379 	struct inet_connection_sock *icsk;
380 	struct sock *sk;
381 	struct sock *result = NULL;
382 	int hiscore = 0;
383 	int matches = 0;
384 	int reuseport = 0;
385 	u32 phash = 0;
386 
387 	inet_lhash2_for_each_icsk_rcu(icsk, &ilb2->head) {
388 		int score;
389 
390 		sk = (struct sock *)icsk;
391 		score = nip_tcp_compute_score(sk, net, hnum, daddr, dif, sdif);
392 		if (score > hiscore) {
393 			nip_dbg("find sock in lhash table");
394 			result = sk;
395 			hiscore = score;
396 			reuseport = sk->sk_reuseport;
397 			if (reuseport) {
398 				nip_dbg("find reuseport sock in lhash table");
399 				phash = ninet_ehashfn(net, daddr, hnum, saddr, sport);
400 				matches = 1;
401 			}
402 		} else if (score == hiscore && reuseport) {
403 			matches++;
404 			if (reciprocal_scale(phash, matches) == 0)
405 				result = sk;
406 			phash = next_pseudo_random32(phash);
407 		}
408 	}
409 	return result;
410 }
411 
ninet_lookup_listener(struct net * net,struct inet_hashinfo * hashinfo,struct sk_buff * skb,int doff,const struct nip_addr * saddr,const __be16 sport,const struct nip_addr * daddr,const unsigned short hnum,const int dif,const int sdif)412 struct sock *ninet_lookup_listener(struct net *net,
413 				   struct inet_hashinfo *hashinfo,
414 				   struct sk_buff *skb, int doff,
415 				   const struct nip_addr *saddr,
416 				   const __be16 sport, const struct nip_addr *daddr,
417 				   const unsigned short hnum, const int dif, const int sdif)
418 {
419 	struct inet_listen_hashbucket *ilb2;
420 	struct sock *result = NULL;
421 	unsigned int hash2 = nip_portaddr_hash(net, daddr, hnum);
422 
423 	ilb2 = inet_lhash2_bucket(hashinfo, hash2);
424 
425 	result = ninet_lhash2_lookup(net, ilb2, skb, doff,
426 				     saddr, sport, daddr, hnum,
427 				     dif, sdif);
428 	if (result)
429 		goto done;
430 
431 	hash2 = nip_portaddr_hash(net, &nip_any_addr, hnum);
432 	ilb2 = inet_lhash2_bucket(hashinfo, hash2);
433 
434 	result = ninet_lhash2_lookup(net, ilb2, skb, doff,
435 				     saddr, sport, &nip_any_addr, hnum,
436 				     dif, sdif);
437 done:
438 	if (IS_ERR_OR_NULL(result))
439 		return NULL;
440 	return result;
441 }
442 
443 /* Check whether the quad information in sock is bound by ehash. If not,
444  * the SK is inserted into the ehash and 0 is returned
445  */
__ninet_check_established(struct inet_timewait_death_row * death_row,struct sock * sk,const __u16 lport,struct inet_timewait_sock ** twp)446 static int __ninet_check_established(struct inet_timewait_death_row *death_row,
447 				     struct sock *sk, const __u16 lport,
448 				     struct inet_timewait_sock **twp)
449 {
450 	struct inet_hashinfo *hinfo = death_row->hashinfo;
451 	struct inet_sock *inet = inet_sk(sk);
452 	struct nip_addr *daddr = &sk->SK_NIP_RCV_SADDR;
453 	struct nip_addr *saddr = &sk->SK_NIP_DADDR;
454 	struct net *net = sock_net(sk);
455 	const __portpair ports = INET_COMBINED_PORTS(inet->inet_dport, lport);
456 	unsigned int hash = ninet_ehashfn(net, daddr, lport,
457 					 saddr, inet->inet_dport);
458 	struct inet_ehash_bucket *head = inet_ehash_bucket(hinfo, hash);
459 	spinlock_t *lock = inet_ehash_lockp(hinfo, hash);
460 	struct sock *sk2;
461 	const struct hlist_nulls_node *node;
462 
463 	spin_lock(lock);
464 
465 	sk_nulls_for_each(sk2, node, &head->chain) {
466 		if (sk2->sk_hash != hash)
467 			continue;
468 
469 		if (likely(ninet_match(sk2, net,
470 				       saddr, daddr, ports, sk->sk_bound_dev_if))) {
471 			nip_dbg("found same sk in ehash");
472 			goto not_unique;
473 		}
474 	}
475 
476 	/* Must record num and sport now. Otherwise we will see
477 	 * in hash table socket with a funny identity.
478 	 */
479 	nip_dbg("add tcp sock into ehash table. sport=%u", lport);
480 	inet->inet_num = lport;
481 	inet->inet_sport = htons(lport);
482 	sk->sk_hash = hash;
483 	WARN_ON(!sk_unhashed(sk));
484 	__sk_nulls_add_node_rcu(sk, &head->chain);
485 
486 	spin_unlock(lock);
487 	sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1);
488 	return 0;
489 
490 not_unique:
491 	spin_unlock(lock);
492 	return -EADDRNOTAVAIL;
493 }
494 
ninet_sk_port_offset(const struct sock * sk)495 static u64 ninet_sk_port_offset(const struct sock *sk)
496 {
497 	const struct inet_sock *inet = inet_sk(sk);
498 
499 	return secure_newip_port_ephemeral(sk->SK_NIP_RCV_SADDR.NIP_ADDR_FIELD32,
500 					  sk->SK_NIP_DADDR.NIP_ADDR_FIELD32,
501 					  inet->inet_dport);
502 }
503 
504 /* Bind local ports randomly */
ninet_hash_connect(struct inet_timewait_death_row * death_row,struct sock * sk)505 int ninet_hash_connect(struct inet_timewait_death_row *death_row,
506 		       struct sock *sk)
507 {
508 	u64 port_offset = 0;
509 
510 	if (!inet_sk(sk)->inet_num)
511 		port_offset = ninet_sk_port_offset(sk);
512 
513 	return __inet_hash_connect(death_row, sk, port_offset,
514 	__ninet_check_established);
515 }
516 
517