1 // SPDX-License-Identifier: GPL-2.0-only
2 // Copyright (C) 2019-2020 Arm Ltd.
3
4 #include <linux/compiler.h>
5 #include <linux/kasan-checks.h>
6 #include <linux/kernel.h>
7
8 #include <net/checksum.h>
9
/*
 * Ones'-complement accumulation step: add @data to @sum and, if the 64-bit
 * addition wrapped, fold the lost carry back in at bit 0 ("end-around
 * carry"). The result cannot wrap a second time, since a wrapped total is
 * at most 2^64 - 2.
 */
static u64 accumulate(u64 sum, u64 data)
{
	u64 total = sum + data;

	/* Unsigned wrap-around happened iff the total is below an addend. */
	return total + (total < data);
}
17
18 /*
19 * We over-read the buffer and this makes KASAN unhappy. Instead, disable
20 * instrumentation and call kasan explicitly.
21 */
do_csum(const unsigned char * buff,int len)22 unsigned int __no_sanitize_address do_csum(const unsigned char *buff, int len)
23 {
24 unsigned int offset, shift, sum;
25 const u64 *ptr;
26 u64 data, sum64 = 0;
27
28 if (unlikely(len == 0))
29 return 0;
30
31 offset = (unsigned long)buff & 7;
32 /*
33 * This is to all intents and purposes safe, since rounding down cannot
34 * result in a different page or cache line being accessed, and @buff
35 * should absolutely not be pointing to anything read-sensitive. We do,
36 * however, have to be careful not to piss off KASAN, which means using
37 * unchecked reads to accommodate the head and tail, for which we'll
38 * compensate with an explicit check up-front.
39 */
40 kasan_check_read(buff, len);
41 ptr = (u64 *)(buff - offset);
42 len = len + offset - 8;
43
44 /*
45 * Head: zero out any excess leading bytes. Shifting back by the same
46 * amount should be at least as fast as any other way of handling the
47 * odd/even alignment, and means we can ignore it until the very end.
48 */
49 shift = offset * 8;
50 data = *ptr++;
51 data = (data >> shift) << shift;
52
53 /*
54 * Body: straightforward aligned loads from here on (the paired loads
55 * underlying the quadword type still only need dword alignment). The
56 * main loop strictly excludes the tail, so the second loop will always
57 * run at least once.
58 */
59 while (unlikely(len > 64)) {
60 __uint128_t tmp1, tmp2, tmp3, tmp4;
61
62 tmp1 = *(__uint128_t *)ptr;
63 tmp2 = *(__uint128_t *)(ptr + 2);
64 tmp3 = *(__uint128_t *)(ptr + 4);
65 tmp4 = *(__uint128_t *)(ptr + 6);
66
67 len -= 64;
68 ptr += 8;
69
70 /* This is the "don't dump the carry flag into a GPR" idiom */
71 tmp1 += (tmp1 >> 64) | (tmp1 << 64);
72 tmp2 += (tmp2 >> 64) | (tmp2 << 64);
73 tmp3 += (tmp3 >> 64) | (tmp3 << 64);
74 tmp4 += (tmp4 >> 64) | (tmp4 << 64);
75 tmp1 = ((tmp1 >> 64) << 64) | (tmp2 >> 64);
76 tmp1 += (tmp1 >> 64) | (tmp1 << 64);
77 tmp3 = ((tmp3 >> 64) << 64) | (tmp4 >> 64);
78 tmp3 += (tmp3 >> 64) | (tmp3 << 64);
79 tmp1 = ((tmp1 >> 64) << 64) | (tmp3 >> 64);
80 tmp1 += (tmp1 >> 64) | (tmp1 << 64);
81 tmp1 = ((tmp1 >> 64) << 64) | sum64;
82 tmp1 += (tmp1 >> 64) | (tmp1 << 64);
83 sum64 = tmp1 >> 64;
84 }
85 while (len > 8) {
86 __uint128_t tmp;
87
88 sum64 = accumulate(sum64, data);
89 tmp = *(__uint128_t *)ptr;
90
91 len -= 16;
92 ptr += 2;
93
94 data = tmp >> 64;
95 sum64 = accumulate(sum64, tmp);
96 }
97 if (len > 0) {
98 sum64 = accumulate(sum64, data);
99 data = *ptr;
100 len -= 8;
101 }
102 /*
103 * Tail: zero any over-read bytes similarly to the head, again
104 * preserving odd/even alignment.
105 */
106 shift = len * -8;
107 data = (data << shift) >> shift;
108 sum64 = accumulate(sum64, data);
109
110 /* Finally, folding */
111 sum64 += (sum64 >> 32) | (sum64 << 32);
112 sum = sum64 >> 32;
113 sum += (sum >> 16) | (sum << 16);
114 if (offset & 1)
115 return (u16)swab32(sum);
116
117 return sum >> 16;
118 }
119
csum_ipv6_magic(const struct in6_addr * saddr,const struct in6_addr * daddr,__u32 len,__u8 proto,__wsum csum)120 __sum16 csum_ipv6_magic(const struct in6_addr *saddr,
121 const struct in6_addr *daddr,
122 __u32 len, __u8 proto, __wsum csum)
123 {
124 __uint128_t src, dst;
125 u64 sum = (__force u64)csum;
126
127 src = *(const __uint128_t *)saddr->s6_addr;
128 dst = *(const __uint128_t *)daddr->s6_addr;
129
130 sum += (__force u32)htonl(len);
131 sum += (u32)proto << 24;
132 src += (src >> 64) | (src << 64);
133 dst += (dst >> 64) | (dst << 64);
134
135 sum = accumulate(sum, src >> 64);
136 sum = accumulate(sum, dst >> 64);
137
138 sum += ((sum >> 32) | (sum << 32));
139 return csum_fold((__force __wsum)(sum >> 32));
140 }
141 EXPORT_SYMBOL(csum_ipv6_magic);
142