// SPDX-License-Identifier: GPL-2.0+
/*
 * Copyright (c) 2024 Huawei Device Co., Ltd.
 *
 * Operations on the lowpower protocol
 * Authors: yangyanjun
 */
#ifdef CONFIG_LOWPOWER_PROTOCOL
#include <linux/types.h>
#include <linux/kernel.h>
#include <linux/module.h>
#include <linux/slab.h>
#include <linux/string.h>
#include <linux/proc_fs.h>
#include <linux/seq_file.h>
#include <linux/printk.h>
#include <linux/list.h>
#include <linux/rwlock_types.h>
#include <linux/net_namespace.h>
#include <net/sock.h>
#include <net/ip.h>
#include <net/tcp.h>
#include <net/lowpower_protocol.h>

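/*
 * g_foreground_uid caches the uid of the current foreground application and
 * is exposed via /proc/net/foreground_uid; g_dpa_uid_list is a uid allow-list
 * (at most LIST_MAX entries) configured via /proc/net/dpa_uid and protected
 * by g_dpa_rwlock.
 */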
static atomic_t g_foreground_uid = ATOMIC_INIT(FOREGROUND_UID_INIT);
#define OPT_LEN 3
#define TO_DECIMAL 10
#define LIST_MAX 500
#define DECIMAL_CHAR_NUM 10 // u32 decimal characters (4,294,967,295)
static DEFINE_RWLOCK(g_dpa_rwlock);
static u32 g_dpa_uid_list_cnt;
static struct list_head g_dpa_uid_list;
struct dpa_node {
	struct list_head list_node;
	uid_t uid;
};
static ext_init g_dpa_init_fun;

static void foreground_uid_atomic_set(uid_t val)
{
	atomic_set(&g_foreground_uid, val);
}

static uid_t foreground_uid_atomic_read(void)
{
	return (uid_t)atomic_read(&g_foreground_uid);
}

// cat /proc/net/foreground_uid
static int foreground_uid_show(struct seq_file *seq, void *v)
{
	uid_t uid = foreground_uid_atomic_read();

	seq_printf(seq, "%u\n", uid);
	return 0;
}

// echo xx > /proc/net/foreground_uid
static int foreground_uid_write(struct file *file, char *buf, size_t size)
{
	char *p = buf;
	uid_t uid = simple_strtoul(p, &p, TO_DECIMAL);

	if (p == buf) // no digits were consumed
		return -EINVAL;

	foreground_uid_atomic_set(uid);
	return 0;
}

// cat /proc/net/dpa_uid
static int dpa_uid_show(struct seq_file *seq, void *v)
{
	struct dpa_node *node = NULL;
	struct dpa_node *tmp_node = NULL;

	read_lock(&g_dpa_rwlock);
	seq_printf(seq, "uid list num: %u\n", g_dpa_uid_list_cnt);
	list_for_each_entry_safe(node, tmp_node, &g_dpa_uid_list, list_node)
		seq_printf(seq, "%u\n", node->uid);
	read_unlock(&g_dpa_rwlock);
	return 0;
}

// echo "add xx yy zz" > /proc/net/dpa_uid
// echo "del xx yy zz" > /proc/net/dpa_uid
static int dpa_uid_add(uid_t uid);
static int dpa_uid_del(uid_t uid);
static int get_dpa_uids(char *buf, size_t size, u32 *uid_list,
			u32 index_max, u32 *index);
static void dpa_ext_init(void);
static int dpa_uid_write(struct file *file, char *buf, size_t size)
{
	u32 dpa_list[LIST_MAX];
	u32 index = 0;
	int ret = -EINVAL;
	u32 i;

	if (get_dpa_uids(buf, size, dpa_list, LIST_MAX, &index) != 0) {
		pr_err("[dpa-uid-cfg] fail to parse dpa uids\n");
		return ret;
	}

	if (strncmp(buf, "add", OPT_LEN) == 0) {
		dpa_ext_init();
		for (i = 0; i < index; i++) {
			ret = dpa_uid_add(dpa_list[i]);
			if (ret != 0) {
				pr_err("[dpa-uid-cfg] add fail, index=%u\n", i);
				return ret;
			}
		}
	} else if (strncmp(buf, "del", OPT_LEN) == 0) {
		for (i = 0; i < index; i++) {
			ret = dpa_uid_del(dpa_list[i]);
			if (ret != 0) {
				pr_err("[dpa-uid-cfg] del fail, index=%u\n", i);
				return ret;
			}
		}
	} else {
		pr_err("[dpa-uid-cfg] cmd unknown\n");
	}
	return ret;
}

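/*
 * Append @uid to the DPA uid list unless it is already present.
 * The list is capped at LIST_MAX entries; the node is allocated with
 * GFP_ATOMIC because the rwlock is held across the update.
 */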
static int dpa_uid_add(uid_t uid)
{
	bool exist = false;
	struct dpa_node *node = NULL;
	struct dpa_node *tmp_node = NULL;

	write_lock(&g_dpa_rwlock);
	if (g_dpa_uid_list_cnt >= LIST_MAX) {
		write_unlock(&g_dpa_rwlock);
		return -EFBIG;
	}

	list_for_each_entry_safe(node, tmp_node, &g_dpa_uid_list, list_node) {
		if (node->uid == uid) {
			exist = true;
			break;
		}
	}

	if (!exist) {
		node = kzalloc(sizeof(*node), GFP_ATOMIC);
		if (!node) {
			write_unlock(&g_dpa_rwlock);
			return -ENOMEM;
		}
		node->uid = uid;
		list_add_tail(&node->list_node, &g_dpa_uid_list);
		g_dpa_uid_list_cnt++;
	}
	write_unlock(&g_dpa_rwlock);
	return 0;
}

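/* Remove @uid from the DPA uid list if present and update the count. */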
static int dpa_uid_del(uid_t uid)
{
	struct dpa_node *node = NULL;
	struct dpa_node *tmp_node = NULL;

	write_lock(&g_dpa_rwlock);
	list_for_each_entry_safe(node, tmp_node, &g_dpa_uid_list, list_node) {
		if (node->uid == uid) {
			list_del(&node->list_node);
			kfree(node);
			if (g_dpa_uid_list_cnt)
				--g_dpa_uid_list_cnt;
			break;
		}
	}
	write_unlock(&g_dpa_rwlock);
	return 0;
}

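/*
 * Convert one decimal uid token in [begin, end) to a uid. The token must be
 * at most DECIMAL_CHAR_NUM digits long and contain only '0'-'9'; 0 is
 * returned on any failure (uid 0 is never accepted).
 */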
static uid_t parse_single_uid(char *begin, char *end)
{
	char *cur = NULL;
	uid_t uid = 0;
	u32 len = end - begin;

	// u32 decimal characters (4,294,967,295)
	if (len > DECIMAL_CHAR_NUM) {
		pr_err("[dpa-uid-cfg] single uid len(%u) overflow\n", len);
		return uid;
	}

	cur = begin;
	while (cur < end) {
		if (*cur < '0' || *cur > '9') {
			pr_err("[dpa-uid-cfg] invalid character '%c'\n", *cur);
			return uid;
		}
		cur++;
	}

	uid = simple_strtoul(begin, &begin, TO_DECIMAL);
	if (!begin || !uid) {
		pr_err("[dpa-uid-cfg] fail to change str to data\n");
		return uid;
	}

	return uid;
}

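/*
 * Parse a space-separated, newline-terminated list of decimal uids
 * ("uid uid ... uid\n") into @uid_list and store the element count in @index.
 */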
static int parse_uids(char *args, u32 args_len, u32 *uid_list,
		      u32 index_max, u32 *index)
{
	char *begin = args;
	char *end = strchr(args, ' ');
	uid_t uid = 0;
	u32 len = 0;

	while (end) {
		// cur decimal characters cnt + ' ' or '\n'
		len += end - begin + 1;
		if (len > args_len || *index >= index_max) {
			pr_err("[dpa-uid-cfg] str len(%u) or index(%u) overflow\n",
			       len, *index);
			return -EINVAL;
		}

		uid = parse_single_uid(begin, end);
		if (!uid)
			return -EINVAL;
		uid_list[(*index)++] = uid;
		begin = ++end; // next decimal characters (skip ' ' or '\n')
		end = strchr(begin, ' ');
	}

	// find last uid characters
	end = strchr(begin, '\n');
	if (!end) {
		pr_err("[dpa-uid-cfg] last character is not '\\n'\n");
		return -EINVAL;
	}

	// cur decimal characters cnt + ' ' or '\n'
	len += end - begin + 1;
	if (len > args_len || *index >= index_max) {
		pr_err("[dpa-uid-cfg] str len(%u) or last index(%u) overflow\n",
		       len, *index);
		return -EINVAL;
	}
	uid = parse_single_uid(begin, end);
	if (!uid)
		return -EINVAL;
	uid_list[(*index)++] = uid;
	return 0;
}

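/*
 * Split the proc write buffer into the command ("add"/"del") and its uid
 * arguments, then hand the argument string to parse_uids().
 */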
static int get_dpa_uids(char *buf, size_t size, u32 *uid_list,
			u32 index_max, u32 *index)
{
	char *args = NULL;
	u32 opt_len;
	u32 data_len;

	// split into cmd and args list
	args = strchr(buf, ' ');
	if (!args) {
		pr_err("[dpa-uid-cfg] cmd fmt invalid\n");
		return -EINVAL;
	}

	// cmd is add or del, len is 3
	opt_len = args - buf;
	if (opt_len != OPT_LEN) {
		pr_err("[dpa-uid-cfg] cmd len invalid\n");
		return -EINVAL;
	}

	data_len = size - (opt_len + 1);
	return parse_uids(args + 1, data_len, uid_list, index_max, index);
}

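/* Return true if @kuid is on the DPA uid list; uid 0 never matches. */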
bool dpa_uid_match(uid_t kuid)
{
	bool match = false;
	struct dpa_node *node = NULL;
	struct dpa_node *tmp_node = NULL;

	if (kuid == 0)
		return match;

	read_lock(&g_dpa_rwlock);
	list_for_each_entry_safe(node, tmp_node, &g_dpa_uid_list, list_node) {
		if (node->uid == kuid) {
			match = true;
			break;
		}
	}
	read_unlock(&g_dpa_rwlock);
	return match;
}
EXPORT_SYMBOL(dpa_uid_match);

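/*
 * Register an external init hook; dpa_ext_init() invokes it each time an
 * "add" command arrives via /proc/net/dpa_uid.
 */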
void regist_dpa_init(ext_init fun)
{
	if (!fun)
		return;
	g_dpa_init_fun = fun;
}

static void dpa_ext_init(void)
{
	if (g_dpa_init_fun)
		g_dpa_init_fun();
}

// call this fun in net/ipv4/af_inet.c inet_init_net()
void __net_init lowpower_protocol_net_init(struct net *net)
{
	if (!proc_create_net_single_write("foreground_uid", 0644,
					  net->proc_net,
					  foreground_uid_show,
					  foreground_uid_write,
					  NULL))
		pr_err("fail to create /proc/net/foreground_uid\n");

	if (!proc_create_net_single_write("dpa_uid", 0644,
					  net->proc_net,
					  dpa_uid_show,
					  dpa_uid_write,
					  NULL))
		pr_err("fail to create /proc/net/dpa_uid\n");
}

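/*
 * Check whether the uid owning @sk (resolved to its full socket) matches the
 * uid currently stored in g_foreground_uid.
 */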
static bool foreground_uid_match(struct net *net, struct sock *sk)
{
	uid_t kuid;
	uid_t foreground_uid;
	struct sock *fullsk;

	if (!net || !sk)
		return false;

	fullsk = sk_to_full_sk(sk);
	if (!fullsk || !sk_fullsock(fullsk))
		return false;

	kuid = sock_net_uid(net, fullsk).val;
	foreground_uid = foreground_uid_atomic_read();
	if (kuid != foreground_uid)
		return false;

	return true;
}

/*
 * ACK optimization is only enabled for connections receiving large amounts
 * of data and only while no packet loss has been observed.
 */
int tcp_ack_num(struct sock *sk)
{
	if (!sk)
		return 1;

	if (!foreground_uid_match(sock_net(sk), sk))
		return 1;

	if (tcp_sk(sk)->bytes_received >= BIG_DATA_BYTES &&
	    tcp_sk(sk)->dup_ack_counter < TCP_FASTRETRANS_THRESH)
		return TCP_ACK_NUM;
	return 1;
}

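/*
 * For TCP packets that belong to the foreground uid, invoke @fun directly and
 * store its result in @ret; returning true lets the caller bypass the regular
 * netfilter traversal.
 */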
bool netfilter_bypass_enable(struct net *net, struct sk_buff *skb,
			     int (*fun)(struct net *, struct sock *, struct sk_buff *),
			     int *ret)
{
	if (!net || !skb || !ip_hdr(skb) || ip_hdr(skb)->protocol != IPPROTO_TCP)
		return false;

	if (foreground_uid_match(net, skb->sk)) {
		*ret = fun(net, NULL, skb);
		return true;
	}
	return false;
}

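/* Module init: set up the (initially empty) DPA uid list head. */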
static int __init lowpower_register(void)
{
	INIT_LIST_HEAD(&g_dpa_uid_list);
	return 0;
}

module_init(lowpower_register);
#endif /* CONFIG_LOWPOWER_PROTOCOL */