1 // SPDX-License-Identifier: GPL-2.0+
2 /*
3 * Copyright (c) 2024 Huawei Device Co., Ltd.
4 *
5 * Operations on the lowpower protocol
6 * Authors: yangyanjun
7 */
8 #ifdef CONFIG_LOWPOWER_PROTOCOL
#include <linux/types.h>
#include <linux/kernel.h>
#include <linux/proc_fs.h>
#include <linux/printk.h>
#include <linux/list.h>
#include <linux/rwlock_types.h>
#include <linux/seq_file.h>
#include <linux/slab.h>
#include <linux/string.h>
#include <linux/net_namespace.h>
#include <net/sock.h>
#include <net/ip.h>
#include <net/tcp.h>
#include <net/lowpower_protocol.h>
20
21 static atomic_t g_foreground_uid = ATOMIC_INIT(FOREGROUND_UID_INIT);
22 #define OPT_LEN 3
23 #define TO_DECIMAL 10
24 #define LIST_MAX 500
25 #define DECIMAL_CHAR_NUM 10 // u32 decimal characters (4,294,967,295)
26 static DEFINE_RWLOCK(g_dpa_rwlock);
27 static u32 g_dpa_uid_list_cnt;
28 static struct list_head g_dpa_uid_list;
29 struct dpa_node {
30 struct list_head list_node;
31 uid_t uid;
32 };
33 static ext_init g_dpa_init_fun;
34
foreground_uid_atomic_set(uid_t val)35 static void foreground_uid_atomic_set(uid_t val)
36 {
37 atomic_set(&g_foreground_uid, val);
38 }
39
foreground_uid_atomic_read(void)40 static uid_t foreground_uid_atomic_read(void)
41 {
42 return (uid_t)atomic_read(&g_foreground_uid);
43 }
44
/*
 * Show handler for /proc/net/foreground_uid (cat /proc/net/foreground_uid):
 * prints the current foreground uid as an unsigned decimal plus newline.
 */
static int foreground_uid_show(struct seq_file *seq, void *v)
{
	seq_printf(seq, "%u\n", foreground_uid_atomic_read());
	return 0;
}
53
54 // echo xx > /proc/net/foreground_uid
foreground_uid_write(struct file * file,char * buf,size_t size)55 static int foreground_uid_write(struct file *file, char *buf, size_t size)
56 {
57 char *p = buf;
58 uid_t uid = simple_strtoul(p, &p, TO_DECIMAL);
59
60 if (!p)
61 return -EINVAL;
62
63 foreground_uid_atomic_set(uid);
64 return 0;
65 }
66
67 // cat /proc/net/dpa_uid
dpa_uid_show(struct seq_file * seq,void * v)68 static int dpa_uid_show(struct seq_file *seq, void *v)
69 {
70 struct dpa_node *node = NULL;
71 struct dpa_node *tmp_node = NULL;
72
73 read_lock(&g_dpa_rwlock);
74 seq_printf(seq, "uid list num: %u\n", g_dpa_uid_list_cnt);
75 list_for_each_entry_safe(node, tmp_node, &g_dpa_uid_list, list_node)
76 seq_printf(seq, "%u\n", node->uid);
77 read_unlock(&g_dpa_rwlock);
78 return 0;
79 }
80
81 // echo "add xx yy zz" > /proc/net/dpa_uid
82 // echo "del xx yy zz" > /proc/net/dpa_uid
83 static int dpa_uid_add(uid_t uid);
84 static int dpa_uid_del(uid_t uid);
85 static int get_dpa_uids(char *buf, size_t size, u32 *uid_list,
86 u32 index_max, u32 *index);
87 static void dpa_ext_init(void);
dpa_uid_write(struct file * file,char * buf,size_t size)88 static int dpa_uid_write(struct file *file, char *buf, size_t size)
89 {
90 u32 *dpa_list = (u32 *)kmalloc(LIST_MAX * sizeof(u32), GFP_KERNEL);
91 u32 index = 0;
92 int ret = -EINVAL;
93 int i;
94
95 if (!dpa_list)
96 return ret;
97
98 if (get_dpa_uids(buf, size, dpa_list, LIST_MAX, &index) != 0) {
99 kfree(dpa_list);
100 pr_err("[dpa-uid-cfg] fail to parse dpa uids\n");
101 return ret;
102 }
103
104 if (strncmp(buf, "add", OPT_LEN) == 0) {
105 dpa_ext_init();
106 for (i = 0; i < index; i++) {
107 ret = dpa_uid_add(dpa_list[i]);
108 if (ret != 0) {
109 kfree(dpa_list);
110 return ret;
111 }
112 }
113 } else if (strncmp(buf, "del", OPT_LEN) == 0) {
114 for (i = 0; i < index; i++) {
115 ret = dpa_uid_del(dpa_list[i]);
116 if (ret != 0) {
117 kfree(dpa_list);
118 return ret;
119 }
120 }
121 } else {
122 pr_err("[dpa-uid-cfg] cmd unknown\n");
123 }
124 kfree(dpa_list);
125 return ret;
126 }
127
/*
 * dpa_uid_add() - insert @uid into the dpa uid list unless already present.
 *
 * The node is allocated with GFP_KERNEL before the lock is taken, so the
 * allocation may sleep and a failure is reported rather than ignored. The
 * caller (dpa_uid_write()) runs in process context. Duplicates are detected
 * before the capacity check, so re-adding an existing uid succeeds even
 * when the list is full.
 *
 * Return: 0 if @uid was added or already present, -ENOMEM if the node
 * cannot be allocated, -EFBIG if the list already holds LIST_MAX entries.
 */
static int dpa_uid_add(uid_t uid)
{
	struct dpa_node *new_node;
	struct dpa_node *node;
	int ret = 0;

	new_node = kzalloc(sizeof(*new_node), GFP_KERNEL);
	if (!new_node)
		return -ENOMEM;
	new_node->uid = uid;

	write_lock(&g_dpa_rwlock);
	list_for_each_entry(node, &g_dpa_uid_list, list_node) {
		if (node->uid == uid)
			goto unlock; /* already present: nothing to do */
	}

	if (g_dpa_uid_list_cnt >= LIST_MAX) {
		ret = -EFBIG;
		goto unlock;
	}

	list_add_tail(&new_node->list_node, &g_dpa_uid_list);
	g_dpa_uid_list_cnt++;
	new_node = NULL; /* ownership moved to the list */

unlock:
	write_unlock(&g_dpa_rwlock);
	kfree(new_node); /* no-op when the node was linked */
	return ret;
}
158
/*
 * dpa_uid_del() - remove @uid from the dpa uid list and free its node.
 *
 * Deleting a uid that is not in the list is not an error.
 * Return: always 0.
 */
static int dpa_uid_del(uid_t uid)
{
	struct dpa_node *node;
	struct dpa_node *victim = NULL;

	write_lock(&g_dpa_rwlock);
	list_for_each_entry(node, &g_dpa_uid_list, list_node) {
		if (node->uid == uid) {
			list_del(&node->list_node);
			if (g_dpa_uid_list_cnt)
				--g_dpa_uid_list_cnt;
			victim = node;
			break; /* uids are unique; stop before touching node->next */
		}
	}
	write_unlock(&g_dpa_rwlock);

	/*
	 * Every list walker in this file holds g_dpa_rwlock, so the unlinked
	 * node is unreachable once the write lock has been dropped.
	 */
	kfree(victim);
	return 0;
}
176
parse_single_uid(char * begin,char * end)177 static uid_t parse_single_uid(char *begin, char *end)
178 {
179 char *cur = NULL;
180 uid_t uid = 0;
181 u32 len = end - begin;
182
183 // u32 decimal characters (4,294,967,295)
184 if (len > DECIMAL_CHAR_NUM) {
185 pr_err("[dpa-uid-cfg] single uid len(%u) overflow\n", len);
186 return uid;
187 }
188
189 cur = begin;
190 while (cur < end) {
191 if (*cur < '0' || *cur > '9') {
192 pr_err("[dpa-uid-cfg] invalid character '%c'\n", *cur);
193 return uid;
194 }
195 cur++;
196 }
197
198 uid = simple_strtoul(begin, &begin, TO_DECIMAL);
199 if (!begin || !uid) {
200 pr_err("[dpa-uid-cfg] fail to change str to data");
201 return uid;
202 }
203
204 return uid;
205 }
206
/*
 * parse_uids() - parse a space-separated, '\n'-terminated list of uids.
 * @args:      start of the uid list (just past the command word and space)
 * @args_len:  number of bytes available from @args
 * @uid_list:  output array with room for @index_max entries
 * @index_max: capacity of @uid_list
 * @index:     in/out count of entries stored in @uid_list
 *
 * Every token must be a non-zero uid accepted by parse_single_uid(), and
 * the last one must be terminated by '\n'. The capacity check runs before
 * each store, so at most @index_max entries are ever written.
 *
 * Return: 0 on success, -EINVAL on malformed input or too many uids.
 */
static int parse_uids(char *args, u32 args_len, u32 *uid_list,
		      u32 index_max, u32 *index)
{
	char *begin = args;
	char *end = strchr(args, ' ');
	uid_t uid = 0;
	u32 len = 0;

	while (end) {
		// cur decimal characters cnt + ' '
		len += end - begin + 1;
		/* '>=': uid_list[index_max] would already be out of bounds */
		if (len > args_len || *index >= index_max) {
			pr_err("[dpa-uid-cfg] str len(%u) or index(%u) overflow\n",
			       len, *index);
			return -EINVAL;
		}

		uid = parse_single_uid(begin, end);
		if (!uid)
			return -EINVAL;
		uid_list[(*index)++] = uid;
		begin = ++end; // next decimal characters (skip ' ')
		end = strchr(begin, ' ');
	}

	// find last uid characters
	end = strchr(begin, '\n');
	if (!end) {
		pr_err("[dpa-uid-cfg] last character is not '\\n'");
		return -EINVAL;
	}

	// cur decimal characters cnt + '\n'
	len += end - begin + 1;
	if (len > args_len || *index >= index_max) {
		pr_err("[dpa-uid-cfg] str len(%u) or last index(%u) overflow\n",
		       len, *index);
		return -EINVAL;
	}

	uid = parse_single_uid(begin, end);
	if (!uid)
		return -EINVAL;
	uid_list[(*index)++] = uid;
	return 0;
}
252
/*
 * get_dpa_uids() - split "<cmd> uid [uid ...]\n" at the first space and
 * hand the uid part to parse_uids().
 *
 * Only the command length (OPT_LEN) is checked here; whether the command
 * is "add" or "del" is decided by the caller.
 *
 * Return: 0 on success, -EINVAL otherwise.
 */
static int get_dpa_uids(char *buf, size_t size, u32 *uid_list,
			u32 index_max, u32 *index)
{
	char *space = strchr(buf, ' ');
	u32 cmd_len;

	if (space == NULL) {
		pr_err("[dpa-uid-cfg] cmd fmt invalid\n");
		return -EINVAL;
	}

	/* cmd is add or del, len is 3 */
	cmd_len = space - buf;
	if (cmd_len != OPT_LEN) {
		pr_err("[dpa-uid-cfg] cmd len invalid\n");
		return -EINVAL;
	}

	/* remaining bytes after the command word and its separating space */
	return parse_uids(space + 1, size - (cmd_len + 1), uid_list,
			  index_max, index);
}
277
/*
 * dpa_uid_match() - report whether @kuid is in the dpa uid list.
 * uid 0 never matches. The list is scanned under the read lock.
 */
bool dpa_uid_match(uid_t kuid)
{
	struct dpa_node *entry;
	bool found = false;

	if (!kuid)
		return false;

	read_lock(&g_dpa_rwlock);
	list_for_each_entry(entry, &g_dpa_uid_list, list_node) {
		found = (entry->uid == kuid);
		if (found)
			break;
	}
	read_unlock(&g_dpa_rwlock);

	return found;
}
EXPORT_SYMBOL(dpa_uid_match);
298
/*
 * regist_dpa_init() - register the hook that dpa_ext_init() runs.
 * A NULL @fun is ignored, so a previously registered hook is kept.
 */
void regist_dpa_init(ext_init fun)
{
	if (fun)
		g_dpa_init_fun = fun;
}
305
dpa_ext_init(void)306 static void dpa_ext_init(void)
307 {
308 if (g_dpa_init_fun)
309 g_dpa_init_fun();
310 }
311
lowpower_protocol_net_exit(struct net * net)312 void __net_exit lowpower_protocol_net_exit(struct net *net)
313 {
314 remove_proc_entry("foreground_uid", net->proc_net);
315 remove_proc_entry("dpa_uid", net->proc_net);
316 }
317
318 // call this fun in net/ipv4/af_inet.c inet_init_net()
lowpower_protocol_net_init(struct net * net)319 void __net_init lowpower_protocol_net_init(struct net *net)
320 {
321 if (!proc_create_net_single_write("foreground_uid", 0644,
322 net->proc_net,
323 foreground_uid_show,
324 foreground_uid_write,
325 NULL))
326 pr_err("fail to create /proc/net/foreground_uid");
327
328 if (!proc_create_net_single_write("dpa_uid", 0644,
329 net->proc_net,
330 dpa_uid_show,
331 dpa_uid_write,
332 NULL))
333 pr_err("fail to create /proc/net/dpa_uid");
334 }
335
foreground_uid_match(struct net * net,struct sock * sk)336 static bool foreground_uid_match(struct net *net, struct sock *sk)
337 {
338 uid_t kuid;
339 uid_t foreground_uid;
340 struct sock *fullsk;
341
342 if (!net || !sk)
343 return false;
344
345 fullsk = sk_to_full_sk(sk);
346 if (!fullsk || !sk_fullsock(fullsk))
347 return false;
348
349 kuid = sock_net_uid(net, fullsk).val;
350 foreground_uid = foreground_uid_atomic_read();
351 if (kuid != foreground_uid)
352 return false;
353
354 return true;
355 }
356
357 /*
358 * ack optimization is only enable for large data receiving tasks and
359 * there is no packet loss scenario
360 */
tcp_ack_num(struct sock * sk)361 int tcp_ack_num(struct sock *sk)
362 {
363 if (!sk)
364 return 1;
365
366 if (foreground_uid_match(sock_net(sk), sk) == false)
367 return 1;
368
369 if (tcp_sk(sk)->bytes_received >= BIG_DATA_BYTES &&
370 tcp_sk(sk)->dup_ack_counter < TCP_FASTRETRANS_THRESH)
371 return TCP_ACK_NUM;
372 return 1;
373 }
374
netfilter_bypass_enable(struct net * net,struct sk_buff * skb,int (* fun)(struct net *,struct sock *,struct sk_buff *),int * ret)375 bool netfilter_bypass_enable(struct net *net, struct sk_buff *skb,
376 int (*fun)(struct net *, struct sock *, struct sk_buff *),
377 int *ret)
378 {
379 if (!net || !skb || !ip_hdr(skb) || ip_hdr(skb)->protocol != IPPROTO_TCP)
380 return false;
381
382 if (foreground_uid_match(net, skb->sk)) {
383 *ret = fun(net, NULL, skb);
384 return true;
385 }
386 return false;
387 }
388
/*
 * lowpower_register() - init-time setup for this file, run via module_init().
 *
 * Initialises the dpa uid list head to an empty list and always returns 0.
 * NOTE(review): INIT_LIST_HEAD() discards any entries already linked, so
 * this relies on nothing having been added to g_dpa_uid_list before this
 * initcall runs.
 */
static int __init lowpower_register(void)
{
	INIT_LIST_HEAD(&g_dpa_uid_list);
	return 0;
}

module_init(lowpower_register);
396 #endif /* CONFIG_LOWPOWER_PROTOCOL */
397