1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3 * Based on net/ipv4/inet_hashtables.c
4 * Authors: Lotsa people, from code originally in tcp
5 *
6 * Based on net/ipv6/inet6_hashtables.c
7 * Authors: Lotsa people, from code originally in tcp, generalised here
8 * by Arnaldo Carvalho de Melo <acme@mandriva.com>
9 *
10 * Based on include/net/ip.h
11 * Authors: Ross Biro
12 * Fred N. van Kempen, <waltje@uWalt.NL.Mugnet.ORG>
13 * Alan Cox, <gw4pts@gw4pts.ampr.org>
14 *
15 * Changes:
16 * Mike McLagan : Routing by source
17 *
18 * Based on include/net/ipv6.h
19 * Authors:
20 * Pedro Roque <roque@di.fc.ul.pt>
21 *
22 * Based on net/core/secure_seq.c
23 * Copyright (C) 2016 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
24 *
25 * NewIP INET
26 * An implementation of the TCP/IP protocol suite for the LINUX
27 * operating system. NewIP INET is implemented using the BSD Socket
28 * interface as the means of communication with the user level.
29 *
30 * Generic NewIP INET transport hashtables
31 */
32 #define pr_fmt(fmt) KBUILD_MODNAME ": [%s:%d] " fmt, __func__, __LINE__
33
34 #include <linux/module.h>
35 #include <linux/random.h>
36
37 #include <net/nip_addrconf.h>
38 #include <net/inet_connection_sock.h>
39 #include <net/inet_hashtables.h>
40 #include <net/ninet_hashtables.h>
41 #include <net/secure_seq.h>
42 #include "tcp_nip_parameter.h"
43
/* Right shift applied to the nanosecond real-time clock in seq_scale();
 * shifting by 6 yields one tick per 64 ns.
 */
#define TCP_SEQ_SCALE_SHIFT 6

/* siphash key used by the secure_*() helpers in this file; it is seeded
 * through net_secret_init().
 */
static siphash_key_t net_secret __read_mostly;
47
net_secret_init(void)48 static __always_inline void net_secret_init(void)
49 {
50 net_get_random_once(&net_secret, sizeof(net_secret));
51 }
52
53 #ifdef CONFIG_INET
seq_scale(u32 seq)54 static u32 seq_scale(u32 seq)
55 {
56 /* As close as possible to RFC 793, which
57 * suggests using a 250 kHz clock.
58 * Further reading shows this assumes 2 Mb/s networks.
59 * or 10 Mb/s Ethernet, a 1 MHz clock is appropriate.
60 * For 10 Gb/s Ethernet, a 1 GHz clock should be ok, but
61 * we also need to limit the resolution so that the u32 seq
62 * overlaps less than one time per MSL (2 minutes).
63 * Choosing a clock of 64 ns period is OK. (period of 274 s)
64 */
65 return seq + (ktime_get_real_ns() >> TCP_SEQ_SCALE_SHIFT);
66 }
67 #endif
68
secure_tcp_nip_sequence_number(const __be32 * saddr,const __be32 * daddr,__be16 sport,__be16 dport)69 __u32 secure_tcp_nip_sequence_number(const __be32 *saddr, const __be32 *daddr,
70 __be16 sport, __be16 dport)
71 {
72 const struct {
73 struct nip_addr saddr;
74 struct nip_addr daddr;
75 __be16 sport;
76 __be16 dport;
77 } __aligned(SIPHASH_ALIGNMENT) combined = {
78 .saddr = *(struct nip_addr *)saddr,
79 .daddr = *(struct nip_addr *)daddr,
80 .sport = sport,
81 .dport = dport,
82 };
83 u32 hash;
84
85 net_secret_init();
86 hash = siphash(&combined, offsetofend(typeof(combined), dport),
87 &net_secret);
88 return seq_scale(hash);
89 }
90 EXPORT_SYMBOL_GPL(secure_tcp_nip_sequence_number);
91
secure_newip_port_ephemeral(const __be32 * saddr,const __be32 * daddr,__be16 dport)92 u64 secure_newip_port_ephemeral(const __be32 *saddr, const __be32 *daddr,
93 __be16 dport)
94 {
95 const struct {
96 struct nip_addr saddr;
97 struct nip_addr daddr;
98 __be16 dport;
99 } __aligned(SIPHASH_ALIGNMENT) combined = {
100 .saddr = *(struct nip_addr *)saddr,
101 .daddr = *(struct nip_addr *)daddr,
102 .dport = dport,
103 };
104 net_secret_init();
105 return siphash(&combined, offsetofend(typeof(combined), dport),
106 &net_secret);
107 }
108 EXPORT_SYMBOL_GPL(secure_newip_port_ephemeral);
109
nip_portaddr_hash(const struct net * net,const struct nip_addr * saddr,unsigned int port)110 static inline u32 nip_portaddr_hash(const struct net *net,
111 const struct nip_addr *saddr,
112 unsigned int port)
113 {
114 u32 v = (__force u32)saddr->NIP_ADDR_FIELD32[0] ^ (__force u32)saddr->NIP_ADDR_FIELD32[1];
115
116 return jhash_1word(v, net_hash_mix(net)) ^ port;
117 }
118
__nip_addr_jhash(const struct nip_addr * a,const u32 initval)119 static u32 __nip_addr_jhash(const struct nip_addr *a, const u32 initval)
120 {
121 u32 v = (__force u32)a->NIP_ADDR_FIELD32[0] ^ (__force u32)a->NIP_ADDR_FIELD32[1];
122
123 return jhash_3words(v,
124 (__force u32)a->NIP_ADDR_FIELD32[0],
125 (__force u32)a->NIP_ADDR_FIELD32[1],
126 initval);
127 }
128
129 static struct inet_listen_hashbucket *
ninet_lhash2_bucket_sk(struct inet_hashinfo * h,struct sock * sk)130 ninet_lhash2_bucket_sk(struct inet_hashinfo *h, struct sock *sk)
131 {
132 u32 hash = nip_portaddr_hash(sock_net(sk),
133 &sk->SK_NIP_RCV_SADDR,
134 inet_sk(sk)->inet_num);
135 return inet_lhash2_bucket(h, hash);
136 }
137
ninet_hash2(struct inet_hashinfo * h,struct sock * sk)138 static void ninet_hash2(struct inet_hashinfo *h, struct sock *sk)
139 {
140 struct inet_listen_hashbucket *ilb2;
141
142 if (!h->lhash2)
143 return;
144
145 ilb2 = ninet_lhash2_bucket_sk(h, sk);
146
147 spin_lock(&ilb2->lock);
148 hlist_add_head_rcu(&inet_csk(sk)->icsk_listen_portaddr_node, &ilb2->head);
149
150 ilb2->count++;
151 spin_unlock(&ilb2->lock);
152 }
153
154 /* Function
155 * Returns the hash value based on the passed argument
156 * Parameter
157 * net: The namespace
158 * laddr: The destination address
159 * lport: Destination port
160 * faddr: Source address
161 * fport: Source port
162 */
ninet_ehashfn(const struct net * net,const struct nip_addr * laddr,const u16 lport,const struct nip_addr * faddr,const __be16 fport)163 u32 ninet_ehashfn(const struct net *net,
164 const struct nip_addr *laddr, const u16 lport,
165 const struct nip_addr *faddr, const __be16 fport)
166 {
167 static u32 ninet_ehash_secret __read_mostly;
168 static u32 ninet_hash_secret __read_mostly;
169
170 u32 lhash, fhash;
171
172 net_get_random_once(&ninet_ehash_secret, sizeof(ninet_ehash_secret));
173 net_get_random_once(&ninet_hash_secret, sizeof(ninet_hash_secret));
174
175 /* Ipv6 uses S6_ADdr32 [3], the last 32bits of the address */
176 lhash = (__force u32)laddr->NIP_ADDR_FIELD32[0];
177 fhash = __nip_addr_jhash(faddr, ninet_hash_secret);
178
179 return __ninet_ehashfn(lhash, lport, fhash, fport,
180 ninet_ehash_secret + net_hash_mix(net));
181 }
182
183 /* Function
184 * The socket is put into the Listen hash in case the server finds
185 * the socket in the second handshake
186 * Parameter
187 * sk: Transmission control block
188 * osk: old socket
189 */
__ninet_hash(struct sock * sk,struct sock * osk)190 int __ninet_hash(struct sock *sk, struct sock *osk)
191 {
192 struct inet_hashinfo *hashinfo = sk->sk_prot->h.hashinfo;
193 struct inet_listen_hashbucket *ilb;
194 int err = 0;
195
196 if (sk->sk_state != TCP_LISTEN) {
197 local_bh_disable();
198 inet_ehash_nolisten(sk, osk, NULL);
199 local_bh_enable();
200 return 0;
201 }
202 WARN_ON(!sk_unhashed(sk));
203 ilb = &hashinfo->listening_hash[inet_sk_listen_hashfn(sk)];
204
205 spin_lock(&ilb->lock);
206
207 __sk_nulls_add_node_rcu(sk, &ilb->nulls_head);
208
209 ninet_hash2(hashinfo, sk);
210 ilb->count++;
211 sock_set_flag(sk, SOCK_RCU_FREE);
212 sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1);
213
214 spin_unlock(&ilb->lock);
215
216 return err;
217 }
218
ninet_hash(struct sock * sk)219 int ninet_hash(struct sock *sk)
220 {
221 int err = 0;
222
223 if (sk->sk_state != TCP_CLOSE) {
224 local_bh_disable();
225 err = __ninet_hash(sk, NULL);
226 local_bh_enable();
227 }
228
229 return err;
230 }
231
ninet_unhash2(struct inet_hashinfo * h,struct sock * sk)232 static void ninet_unhash2(struct inet_hashinfo *h, struct sock *sk)
233 {
234 struct inet_listen_hashbucket *ilb2;
235
236 if (!h->lhash2 ||
237 WARN_ON_ONCE(hlist_unhashed(&inet_csk(sk)->icsk_listen_portaddr_node)))
238 return;
239
240 ilb2 = ninet_lhash2_bucket_sk(h, sk);
241
242 spin_lock(&ilb2->lock);
243 hlist_del_init_rcu(&inet_csk(sk)->icsk_listen_portaddr_node);
244 if (ilb2->count)
245 ilb2->count--;
246 spin_unlock(&ilb2->lock);
247 }
248
__ninet_unhash(struct sock * sk,struct inet_listen_hashbucket * ilb)249 static void __ninet_unhash(struct sock *sk, struct inet_listen_hashbucket *ilb)
250 {
251 if (sk_unhashed(sk))
252 return;
253
254 if (ilb) {
255 struct inet_hashinfo *hashinfo = sk->sk_prot->h.hashinfo;
256
257 ninet_unhash2(hashinfo, sk);
258 if (ilb->count)
259 ilb->count--;
260 }
261 __sk_nulls_del_node_init_rcu(sk);
262 sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1);
263 }
264
ninet_unhash(struct sock * sk)265 void ninet_unhash(struct sock *sk)
266 {
267 struct inet_hashinfo *hashinfo = sk->sk_prot->h.hashinfo;
268
269 if (sk_unhashed(sk))
270 return;
271
272 if (sk->sk_state == TCP_LISTEN) {
273 struct inet_listen_hashbucket *ilb;
274
275 ilb = &hashinfo->listening_hash[inet_sk_listen_hashfn(sk)];
276 /* Don't disable bottom halves while acquiring the lock to
277 * avoid circular locking dependency on PREEMPT_RT.
278 */
279 spin_lock(&ilb->lock);
280 __ninet_unhash(sk, ilb);
281 spin_unlock(&ilb->lock);
282 } else {
283 spinlock_t *lock = inet_ehash_lockp(hashinfo, sk->sk_hash);
284
285 spin_lock_bh(lock);
286 __ninet_unhash(sk, NULL);
287 spin_unlock_bh(lock);
288 }
289 }
290
291 /* Function
292 * Find transport control blocks based on address and port in the ehash table.
293 * If found, three handshakes have been made and a connection has been established,
294 * and normal communication can proceed.
295 * Parameter
296 * net: The namespace
297 * hashinfo: A global scalar of type tcp_hashinfo that stores tcp_SOCK(including ESTABLISHED,
298 * listen, and bind) for various states of the current system.
299 * saddr: Source address
300 * sport: Source port
301 * daddr: The destination address
302 * hnum: Destination port
303 */
__ninet_lookup_established(struct net * net,struct inet_hashinfo * hashinfo,const struct nip_addr * saddr,const __be16 sport,const struct nip_addr * daddr,const u16 hnum,const int dif)304 struct sock *__ninet_lookup_established(struct net *net,
305 struct inet_hashinfo *hashinfo,
306 const struct nip_addr *saddr,
307 const __be16 sport,
308 const struct nip_addr *daddr,
309 const u16 hnum,
310 const int dif)
311 {
312 struct sock *sk;
313 const struct hlist_nulls_node *node;
314
315 const __portpair ports = INET_COMBINED_PORTS(sport, hnum);
316 /* mask ensures that the hash index is valid without memory overruns */
317 unsigned int hash = ninet_ehashfn(net, daddr, hnum, saddr, sport);
318 unsigned int slot = hash & hashinfo->ehash_mask;
319 struct inet_ehash_bucket *head = &hashinfo->ehash[slot];
320
321 begin:
322 sk_nulls_for_each_rcu(sk, node, &head->chain) {
323 if (sk->sk_hash != hash)
324 continue;
325 if (!ninet_match(sk, net, saddr, daddr, ports, dif))
326 continue;
327 if (unlikely(!refcount_inc_not_zero(&sk->sk_refcnt))) {
328 nip_dbg("sk->sk_refcnt == 0");
329 goto out;
330 }
331
332 if (unlikely(!ninet_match(sk, net, saddr, daddr, ports, dif))) {
333 sock_gen_put(sk);
334 goto begin;
335 }
336 goto found;
337 }
338 if (get_nulls_value(node) != slot)
339 goto begin;
340 out:
341 sk = NULL;
342 found:
343 return sk;
344 }
345
nip_tcp_compute_score(struct sock * sk,struct net * net,const unsigned short hnum,const struct nip_addr * daddr,const int dif,int sdif)346 static int nip_tcp_compute_score(struct sock *sk, struct net *net,
347 const unsigned short hnum,
348 const struct nip_addr *daddr,
349 const int dif, int sdif)
350 {
351 int score = -1;
352
353 if (inet_sk(sk)->inet_num == hnum && sk->sk_family == PF_NINET &&
354 net_eq(sock_net(sk), net)) {
355 score = 1;
356 if (!nip_addr_eq(&sk->SK_NIP_RCV_SADDR, &nip_any_addr)) {
357 if (!nip_addr_eq(&sk->SK_NIP_RCV_SADDR, daddr))
358 return -1;
359 score++;
360 }
361 if (!inet_sk_bound_dev_eq(net, sk->sk_bound_dev_if, dif, sdif))
362 return -1;
363 score++;
364 if (READ_ONCE(sk->sk_incoming_cpu) == raw_smp_processor_id())
365 score++;
366 }
367
368 return score;
369 }
370
371 /* nip reuseport */
ninet_lhash2_lookup(struct net * net,struct inet_listen_hashbucket * ilb2,struct sk_buff * skb,int doff,const struct nip_addr * saddr,__be16 sport,const struct nip_addr * daddr,const unsigned short hnum,const int dif,const int sdif)372 static struct sock *ninet_lhash2_lookup(struct net *net,
373 struct inet_listen_hashbucket *ilb2,
374 struct sk_buff *skb, int doff,
375 const struct nip_addr *saddr, __be16 sport,
376 const struct nip_addr *daddr, const unsigned short hnum,
377 const int dif, const int sdif)
378 {
379 struct inet_connection_sock *icsk;
380 struct sock *sk;
381 struct sock *result = NULL;
382 int hiscore = 0;
383 int matches = 0;
384 int reuseport = 0;
385 u32 phash = 0;
386
387 inet_lhash2_for_each_icsk_rcu(icsk, &ilb2->head) {
388 int score;
389
390 sk = (struct sock *)icsk;
391 score = nip_tcp_compute_score(sk, net, hnum, daddr, dif, sdif);
392 if (score > hiscore) {
393 nip_dbg("find sock in lhash table");
394 result = sk;
395 hiscore = score;
396 reuseport = sk->sk_reuseport;
397 if (reuseport) {
398 nip_dbg("find reuseport sock in lhash table");
399 phash = ninet_ehashfn(net, daddr, hnum, saddr, sport);
400 matches = 1;
401 }
402 } else if (score == hiscore && reuseport) {
403 matches++;
404 if (reciprocal_scale(phash, matches) == 0)
405 result = sk;
406 phash = next_pseudo_random32(phash);
407 }
408 }
409 return result;
410 }
411
ninet_lookup_listener(struct net * net,struct inet_hashinfo * hashinfo,struct sk_buff * skb,int doff,const struct nip_addr * saddr,const __be16 sport,const struct nip_addr * daddr,const unsigned short hnum,const int dif,const int sdif)412 struct sock *ninet_lookup_listener(struct net *net,
413 struct inet_hashinfo *hashinfo,
414 struct sk_buff *skb, int doff,
415 const struct nip_addr *saddr,
416 const __be16 sport, const struct nip_addr *daddr,
417 const unsigned short hnum, const int dif, const int sdif)
418 {
419 struct inet_listen_hashbucket *ilb2;
420 struct sock *result = NULL;
421 unsigned int hash2 = nip_portaddr_hash(net, daddr, hnum);
422
423 ilb2 = inet_lhash2_bucket(hashinfo, hash2);
424
425 result = ninet_lhash2_lookup(net, ilb2, skb, doff,
426 saddr, sport, daddr, hnum,
427 dif, sdif);
428 if (result)
429 goto done;
430
431 hash2 = nip_portaddr_hash(net, &nip_any_addr, hnum);
432 ilb2 = inet_lhash2_bucket(hashinfo, hash2);
433
434 result = ninet_lhash2_lookup(net, ilb2, skb, doff,
435 saddr, sport, &nip_any_addr, hnum,
436 dif, sdif);
437 done:
438 if (IS_ERR_OR_NULL(result))
439 return NULL;
440 return result;
441 }
442
443 /* Check whether the quad information in sock is bound by ehash. If not,
444 * the SK is inserted into the ehash and 0 is returned
445 */
__ninet_check_established(struct inet_timewait_death_row * death_row,struct sock * sk,const __u16 lport,struct inet_timewait_sock ** twp)446 static int __ninet_check_established(struct inet_timewait_death_row *death_row,
447 struct sock *sk, const __u16 lport,
448 struct inet_timewait_sock **twp)
449 {
450 struct inet_hashinfo *hinfo = death_row->hashinfo;
451 struct inet_sock *inet = inet_sk(sk);
452 struct nip_addr *daddr = &sk->SK_NIP_RCV_SADDR;
453 struct nip_addr *saddr = &sk->SK_NIP_DADDR;
454 struct net *net = sock_net(sk);
455 const __portpair ports = INET_COMBINED_PORTS(inet->inet_dport, lport);
456 unsigned int hash = ninet_ehashfn(net, daddr, lport,
457 saddr, inet->inet_dport);
458 struct inet_ehash_bucket *head = inet_ehash_bucket(hinfo, hash);
459 spinlock_t *lock = inet_ehash_lockp(hinfo, hash);
460 struct sock *sk2;
461 const struct hlist_nulls_node *node;
462
463 spin_lock(lock);
464
465 sk_nulls_for_each(sk2, node, &head->chain) {
466 if (sk2->sk_hash != hash)
467 continue;
468
469 if (likely(ninet_match(sk2, net,
470 saddr, daddr, ports, sk->sk_bound_dev_if))) {
471 nip_dbg("found same sk in ehash");
472 goto not_unique;
473 }
474 }
475
476 /* Must record num and sport now. Otherwise we will see
477 * in hash table socket with a funny identity.
478 */
479 nip_dbg("add tcp sock into ehash table. sport=%u", lport);
480 inet->inet_num = lport;
481 inet->inet_sport = htons(lport);
482 sk->sk_hash = hash;
483 WARN_ON(!sk_unhashed(sk));
484 __sk_nulls_add_node_rcu(sk, &head->chain);
485
486 spin_unlock(lock);
487 sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1);
488 return 0;
489
490 not_unique:
491 spin_unlock(lock);
492 return -EADDRNOTAVAIL;
493 }
494
ninet_sk_port_offset(const struct sock * sk)495 static u64 ninet_sk_port_offset(const struct sock *sk)
496 {
497 const struct inet_sock *inet = inet_sk(sk);
498
499 return secure_newip_port_ephemeral(sk->SK_NIP_RCV_SADDR.NIP_ADDR_FIELD32,
500 sk->SK_NIP_DADDR.NIP_ADDR_FIELD32,
501 inet->inet_dport);
502 }
503
504 /* Bind local ports randomly */
ninet_hash_connect(struct inet_timewait_death_row * death_row,struct sock * sk)505 int ninet_hash_connect(struct inet_timewait_death_row *death_row,
506 struct sock *sk)
507 {
508 u64 port_offset = 0;
509
510 if (!inet_sk(sk)->inet_num)
511 port_offset = ninet_sk_port_offset(sk);
512
513 return __inet_hash_connect(death_row, sk, port_offset,
514 __ninet_check_established);
515 }
516
517