// SPDX-License-Identifier: GPL-2.0+
/*
 * Copyright (c) 2024 Huawei Device Co., Ltd.
 *
 * Operations on the lowpower protocol
 * Authors: yangyanjun
 */
#ifdef CONFIG_LOWPOWER_PROTOCOL
#include <linux/types.h>
#include <linux/kernel.h>
#include <linux/proc_fs.h>
#include <linux/printk.h>
#include <linux/list.h>
#include <linux/rwlock_types.h>
#include <linux/net_namespace.h>
#include <net/sock.h>
#include <net/ip.h>
#include <net/tcp.h>
#include <net/lowpower_protocol.h>

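/*
 * /proc/net/foreground_uid holds the UID of the current foreground
 * application; /proc/net/dpa_uid holds a list of UIDs that are matched
 * via dpa_uid_match(). The list is protected by g_dpa_rwlock and capped
 * at LIST_MAX entries. FOREGROUND_UID_INIT, ext_init, BIG_DATA_BYTES and
 * TCP_ACK_NUM are assumed to be provided by <net/lowpower_protocol.h>,
 * which is not shown here.
 */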
static atomic_t g_foreground_uid = ATOMIC_INIT(FOREGROUND_UID_INIT);
#define OPT_LEN 3
#define TO_DECIMAL 10
#define LIST_MAX 500
#define DECIMAL_CHAR_NUM 10 // max decimal digits of a u32 (4,294,967,295)
static DEFINE_RWLOCK(g_dpa_rwlock);
static u32 g_dpa_uid_list_cnt;
static struct list_head g_dpa_uid_list;
struct dpa_node {
	struct list_head list_node;
	uid_t uid;
};
static ext_init g_dpa_init_fun;

static void foreground_uid_atomic_set(uid_t val)
{
	atomic_set(&g_foreground_uid, val);
}

static uid_t foreground_uid_atomic_read(void)
{
	return (uid_t)atomic_read(&g_foreground_uid);
}

// cat /proc/net/foreground_uid
static int foreground_uid_show(struct seq_file *seq, void *v)
{
	uid_t uid = foreground_uid_atomic_read();

	seq_printf(seq, "%u\n", uid);
	return 0;
}

// echo xx > /proc/net/foreground_uid
static int foreground_uid_write(struct file *file, char *buf, size_t size)
{
	char *p = buf;
	uid_t uid = simple_strtoul(p, &p, TO_DECIMAL);

	if (p == buf) // no digits were consumed, reject the input
		return -EINVAL;

	foreground_uid_atomic_set(uid);
	return 0;
}

// cat /proc/net/dpa_uid
static int dpa_uid_show(struct seq_file *seq, void *v)
{
	struct dpa_node *node = NULL;
	struct dpa_node *tmp_node = NULL;

	read_lock(&g_dpa_rwlock);
	seq_printf(seq, "uid list num: %u\n", g_dpa_uid_list_cnt);
	list_for_each_entry_safe(node, tmp_node, &g_dpa_uid_list, list_node)
		seq_printf(seq, "%u\n", node->uid);
	read_unlock(&g_dpa_rwlock);
	return 0;
}

// echo "add xx yy zz" > /proc/net/dpa_uid
// echo "del xx yy zz" > /proc/net/dpa_uid
static int dpa_uid_add(uid_t uid);
static int dpa_uid_del(uid_t uid);
static int get_dpa_uids(char *buf, size_t size, u32 *uid_list,
			u32 index_max, u32 *index);
static void dpa_ext_init(void);
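/*
 * Write handler for /proc/net/dpa_uid: the buffer must start with a
 * three-character command ("add" or "del") followed by space-separated
 * decimal UIDs and a terminating '\n'; at most LIST_MAX UIDs are kept
 * and a UID of 0 is rejected.
 */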
static int dpa_uid_write(struct file *file, char *buf, size_t size)
{
	u32 *dpa_list = kmalloc(LIST_MAX * sizeof(u32), GFP_KERNEL);
	u32 index = 0;
	int ret = -EINVAL;
	int i;

	if (!dpa_list)
		return -ENOMEM;

	if (get_dpa_uids(buf, size, dpa_list, LIST_MAX, &index) != 0) {
		kfree(dpa_list);
		pr_err("[dpa-uid-cfg] fail to parse dpa uids\n");
		return ret;
	}

	if (strncmp(buf, "add", OPT_LEN) == 0) {
		dpa_ext_init();
		for (i = 0; i < index; i++) {
			ret = dpa_uid_add(dpa_list[i]);
			if (ret != 0) {
				kfree(dpa_list);
				return ret;
			}
		}
	} else if (strncmp(buf, "del", OPT_LEN) == 0) {
		for (i = 0; i < index; i++) {
			ret = dpa_uid_del(dpa_list[i]);
			if (ret != 0) {
				kfree(dpa_list);
				return ret;
			}
		}
	} else {
		pr_err("[dpa-uid-cfg] cmd unknown\n");
	}
	kfree(dpa_list);
	return ret;
}

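/*
 * Append @uid to the DPA list if it is not already present; the list is
 * capped at LIST_MAX entries. A failed node allocation is silently skipped.
 */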
static int dpa_uid_add(uid_t uid)
{
	bool exist = false;
	struct dpa_node *node = NULL;
	struct dpa_node *tmp_node = NULL;

	write_lock(&g_dpa_rwlock);
	if (g_dpa_uid_list_cnt >= LIST_MAX) {
		write_unlock(&g_dpa_rwlock);
		return -EFBIG;
	}

	list_for_each_entry_safe(node, tmp_node, &g_dpa_uid_list, list_node) {
		if (node->uid == uid) {
			exist = true;
			break;
		}
	}

	if (!exist) {
		node = kzalloc(sizeof(*node), GFP_ATOMIC);
		if (node) {
			node->uid = uid;
			list_add_tail(&node->list_node, &g_dpa_uid_list);
			g_dpa_uid_list_cnt++;
		}
	}
	write_unlock(&g_dpa_rwlock);
	return 0;
}

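/*
 * Remove @uid from the DPA list; deleting a UID that is not listed is
 * not an error.
 */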
static int dpa_uid_del(uid_t uid)
{
	struct dpa_node *node = NULL;
	struct dpa_node *tmp_node = NULL;

	write_lock(&g_dpa_rwlock);
	list_for_each_entry_safe(node, tmp_node, &g_dpa_uid_list, list_node) {
		if (node->uid == uid) {
			list_del(&node->list_node);
			kfree(node); // the node is unreachable once unlinked
			if (g_dpa_uid_list_cnt)
				--g_dpa_uid_list_cnt;
			break;
		}
	}
	write_unlock(&g_dpa_rwlock);
	return 0;
}

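/*
 * Parse the decimal digits in [begin, end) into a UID. Returns 0 on any
 * error (more than DECIMAL_CHAR_NUM characters, a non-digit character,
 * or the value 0 itself), so UID 0 can never be configured.
 */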
static uid_t parse_single_uid(char *begin, char *end)
{
	char *cur = NULL;
	uid_t uid = 0;
	u32 len = end - begin;

	// reject anything longer than the max decimal digits of a u32
	if (len > DECIMAL_CHAR_NUM) {
		pr_err("[dpa-uid-cfg] single uid len(%u) overflow\n", len);
		return uid;
	}

	cur = begin;
	while (cur < end) {
		if (*cur < '0' || *cur > '9') {
			pr_err("[dpa-uid-cfg] invalid character '%c'\n", *cur);
			return uid;
		}
		cur++;
	}

	uid = simple_strtoul(begin, &begin, TO_DECIMAL);
	if (!uid) {
		pr_err("[dpa-uid-cfg] failed to convert string to uid\n");
		return uid;
	}

	return uid;
}

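/*
 * Split @args into space-separated decimal UIDs and store them in
 * @uid_list. The last UID must be terminated by '\n' (as written by
 * echo); parsing stops with -EINVAL on the first invalid token.
 */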
static int parse_uids(char *args, u32 args_len, u32 *uid_list,
		      u32 index_max, u32 *index)
{
	char *begin = args;
	char *end = strchr(args, ' ');
	uid_t uid = 0;
	u32 len = 0;

	while (end) {
		// cur decimal characters cnt + ' ' or '\n'
		len += end - begin + 1;
		if (len > args_len || *index >= index_max) {
			pr_err("[dpa-uid-cfg] str len(%u) or index(%u) overflow\n",
			       len, *index);
			return -EINVAL;
		}

		uid = parse_single_uid(begin, end);
		if (!uid)
			return -EINVAL;
		uid_list[(*index)++] = uid;
		begin = ++end; // next decimal characters (skip ' ' or '\n')
		end = strchr(begin, ' ');
	}

	// find last uid characters
	end = strchr(begin, '\n');
	if (!end) {
		pr_err("[dpa-uid-cfg] last character is not '\\n'\n");
		return -EINVAL;
	}

	// cur decimal characters cnt + ' ' or '\n'
	len += end - begin + 1;
	if (len > args_len || *index >= index_max) {
		pr_err("[dpa-uid-cfg] str len(%u) or last index(%u) overflow\n",
		       len, *index);
		return -EINVAL;
	}
	uid = parse_single_uid(begin, end);
	if (!uid)
		return -EINVAL;
	uid_list[(*index)++] = uid;
	return 0;
}

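/*
 * Split the write buffer into a three-character command ("add"/"del")
 * and the UID list that follows the first space, then hand the list to
 * parse_uids().
 */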
static int get_dpa_uids(char *buf, size_t size, u32 *uid_list,
			u32 index_max, u32 *index)
{
	char *args = NULL;
	u32 opt_len;
	u32 data_len;

	// split into cmd and args list
	args = strchr(buf, ' ');
	if (!args) {
		pr_err("[dpa-uid-cfg] cmd fmt invalid\n");
		return -EINVAL;
	}

	// cmd is add or del, len is 3
	opt_len = args - buf;
	if (opt_len != OPT_LEN) {
		pr_err("[dpa-uid-cfg] cmd len invalid\n");
		return -EINVAL;
	}

	data_len = size - (opt_len + 1);
	return parse_uids(args + 1, data_len, uid_list, index_max, index);
}

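/*
 * Return true if @kuid is in the DPA UID list; UID 0 never matches.
 * Exported for callers outside this file.
 */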
bool dpa_uid_match(uid_t kuid)
{
	bool match = false;
	struct dpa_node *node = NULL;
	struct dpa_node *tmp_node = NULL;

	if (kuid == 0)
		return match;

	read_lock(&g_dpa_rwlock);
	list_for_each_entry_safe(node, tmp_node, &g_dpa_uid_list, list_node) {
		if (node->uid == kuid) {
			match = true;
			break;
		}
	}
	read_unlock(&g_dpa_rwlock);
	return match;
}
EXPORT_SYMBOL(dpa_uid_match);

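/*
 * Allow an external module to register an init callback; dpa_ext_init()
 * invokes it whenever an "add" command is processed.
 */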
void regist_dpa_init(ext_init fun)
{
	if (!fun)
		return;
	g_dpa_init_fun = fun;
}

static void dpa_ext_init(void)
{
	if (g_dpa_init_fun)
		g_dpa_init_fun();
}

void __net_exit lowpower_protocol_net_exit(struct net *net)
{
	remove_proc_entry("foreground_uid", net->proc_net);
	remove_proc_entry("dpa_uid", net->proc_net);
}

// called from inet_init_net() in net/ipv4/af_inet.c
void __net_init lowpower_protocol_net_init(struct net *net)
{
	if (!proc_create_net_single_write("foreground_uid", 0644,
					  net->proc_net,
					  foreground_uid_show,
					  foreground_uid_write,
					  NULL))
		pr_err("fail to create /proc/net/foreground_uid\n");

	if (!proc_create_net_single_write("dpa_uid", 0644,
					  net->proc_net,
					  dpa_uid_show,
					  dpa_uid_write,
					  NULL))
		pr_err("fail to create /proc/net/dpa_uid\n");
}

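/*
 * Return true when the socket's owner UID equals the UID that was last
 * written to /proc/net/foreground_uid.
 */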
static bool foreground_uid_match(struct net *net, struct sock *sk)
{
	uid_t kuid;
	uid_t foreground_uid;
	struct sock *fullsk;

	if (!net || !sk)
		return false;

	fullsk = sk_to_full_sk(sk);
	if (!fullsk || !sk_fullsock(fullsk))
		return false;

	kuid = sock_net_uid(net, fullsk).val;
	foreground_uid = foreground_uid_atomic_read();
	if (kuid != foreground_uid)
		return false;

	return true;
}

/*
 * ACK optimization is only enabled for sockets of the foreground
 * application that are receiving large amounts of data and have not
 * seen packet loss.
 */
int tcp_ack_num(struct sock *sk)
{
	if (!sk)
		return 1;

	if (!foreground_uid_match(sock_net(sk), sk))
		return 1;

	if (tcp_sk(sk)->bytes_received >= BIG_DATA_BYTES &&
	    tcp_sk(sk)->dup_ack_counter < TCP_FASTRETRANS_THRESH)
		return TCP_ACK_NUM;
	return 1;
}

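/*
 * For TCP packets whose socket belongs to the foreground UID, invoke
 * @fun directly (presumably the okfn handed down by the netfilter hook
 * caller) and report that the hook chain was bypassed.
 */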
bool netfilter_bypass_enable(struct net *net, struct sk_buff *skb,
			     int (*fun)(struct net *, struct sock *, struct sk_buff *),
			     int *ret)
{
	if (!net || !skb || !ip_hdr(skb) || ip_hdr(skb)->protocol != IPPROTO_TCP)
		return false;

	if (foreground_uid_match(net, skb->sk)) {
		*ret = fun(net, NULL, skb);
		return true;
	}
	return false;
}

static int __init lowpower_register(void)
{
	INIT_LIST_HEAD(&g_dpa_uid_list);
	return 0;
}

module_init(lowpower_register);
#endif /* CONFIG_LOWPOWER_PROTOCOL */