1 // SPDX-License-Identifier: GPL-2.0+
2 /*
3 * Copyright (c) 2024 Huawei Device Co., Ltd.
4 *
5 * Operations on the lowpower protocol
6 * Authors: yangyanjun
7 */
8 #ifdef CONFIG_LOWPOWER_PROTOCOL
9 #include <linux/types.h>
10 #include <linux/kernel.h>
11 #include <linux/proc_fs.h>
12 #include <linux/printk.h>
13 #include <linux/list.h>
14 #include <linux/rwlock_types.h>
15 #include <linux/net_namespace.h>
16 #include <net/sock.h>
17 #include <net/ip.h>
18 #include <net/tcp.h>
19 #include <net/lowpower_protocol.h>
20
21 static atomic_t g_foreground_uid = ATOMIC_INIT(FOREGROUND_UID_INIT);
22 #define OPT_LEN 3
23 #define TO_DECIMAL 10
24 #define LIST_MAX 500
25 #define DECIMAL_CHAR_NUM 10 // u32 decimal characters (4,294,967,295)
26 static DEFINE_RWLOCK(g_dpa_rwlock);
27 static u32 g_dpa_uid_list_cnt;
28 static struct list_head g_dpa_uid_list;
29 struct dpa_node {
30 struct list_head list_node;
31 uid_t uid;
32 };
33 static ext_init g_dpa_init_fun;
34
/* Store @val as the current foreground uid. */
static void foreground_uid_atomic_set(uid_t val)
{
	atomic_t *slot = &g_foreground_uid;

	atomic_set(slot, val);
}
39
foreground_uid_atomic_read(void)40 static uid_t foreground_uid_atomic_read(void)
41 {
42 return (uid_t)atomic_read(&g_foreground_uid);
43 }
44
/*
 * seq_file show callback for /proc/net/foreground_uid
 * (e.g. "cat /proc/net/foreground_uid"): prints the current foreground uid
 * followed by a newline.
 */
static int foreground_uid_show(struct seq_file *seq, void *v)
{
	seq_printf(seq, "%u\n", foreground_uid_atomic_read());
	return 0;
}
53
54 // echo xx > /proc/net/foreground_uid
foreground_uid_write(struct file * file,char * buf,size_t size)55 static int foreground_uid_write(struct file *file, char *buf, size_t size)
56 {
57 char *p = buf;
58 uid_t uid = simple_strtoul(p, &p, TO_DECIMAL);
59
60 if (!p)
61 return -EINVAL;
62
63 foreground_uid_atomic_set(uid);
64 return 0;
65 }
66
67 // cat /proc/net/dpa_uid
dpa_uid_show(struct seq_file * seq,void * v)68 static int dpa_uid_show(struct seq_file *seq, void *v)
69 {
70 struct dpa_node *node = NULL;
71 struct dpa_node *tmp_node = NULL;
72
73 read_lock(&g_dpa_rwlock);
74 seq_printf(seq, "uid list num: %u\n", g_dpa_uid_list_cnt);
75 list_for_each_entry_safe(node, tmp_node, &g_dpa_uid_list, list_node)
76 seq_printf(seq, "%u\n", node->uid);
77 read_unlock(&g_dpa_rwlock);
78 return 0;
79 }
80
81 // echo "add xx yy zz" > /proc/net/dpa_uid
82 // echo "del xx yy zz" > /proc/net/dpa_uid
83 static int dpa_uid_add(uid_t uid);
84 static int dpa_uid_del(uid_t uid);
85 static int get_dpa_uids(char *buf, size_t size, u32 *uid_list,
86 u32 index_max, u32 *index);
87 static void dpa_ext_init(void);
dpa_uid_write(struct file * file,char * buf,size_t size)88 static int dpa_uid_write(struct file *file, char *buf, size_t size)
89 {
90 u32 dpa_list[LIST_MAX];
91 u32 index = 0;
92 int ret = -EINVAL;
93 int i;
94
95 if (get_dpa_uids(buf, size, dpa_list, LIST_MAX, &index) != 0) {
96 pr_err("[dpa-uid-cfg] fail to parse dpa uids\n");
97 return ret;
98 }
99
100 if (strncmp(buf, "add", OPT_LEN) == 0) {
101 dpa_ext_init();
102 for (i = 0; i < index; i++) {
103 ret = dpa_uid_add(dpa_list[i]);
104 if (ret != 0) {
105 pr_err("[dpa-uid-cfg] add fail, index=%u\n", i);
106 return ret;
107 }
108 }
109 } else if (strncmp(buf, "del", OPT_LEN) == 0) {
110 for (i = 0; i < index; i++) {
111 ret = dpa_uid_del(dpa_list[i]);
112 if (ret != 0) {
113 pr_err("[dpa-uid-cfg] del fail, index=%u\n", i);
114 return ret;
115 }
116 }
117 } else {
118 pr_err("[dpa-uid-cfg] cmd unknown\n");
119 }
120 return ret;
121 }
122
/*
 * Add @uid to the DPA uid list unless it is already present.
 *
 * Returns 0 on success (including when @uid was already listed), -EFBIG if
 * the list already holds LIST_MAX entries, or -ENOMEM if the node could not
 * be allocated. The old code silently returned 0 on allocation failure.
 */
static int dpa_uid_add(uid_t uid)
{
	struct dpa_node *new_node;
	struct dpa_node *node;
	int ret = 0;

	/*
	 * Allocate before taking the spinning rwlock so GFP_KERNEL can be used.
	 * The only caller is dpa_uid_write(), a proc write handler.
	 */
	new_node = kzalloc(sizeof(*new_node), GFP_KERNEL);
	if (!new_node)
		return -ENOMEM;
	new_node->uid = uid;

	write_lock(&g_dpa_rwlock);
	if (g_dpa_uid_list_cnt >= LIST_MAX) {
		ret = -EFBIG;
		goto out_unlock;
	}

	list_for_each_entry(node, &g_dpa_uid_list, list_node) {
		if (node->uid == uid)
			goto out_unlock;	/* already listed: nothing to do */
	}

	list_add_tail(&new_node->list_node, &g_dpa_uid_list);
	g_dpa_uid_list_cnt++;
	new_node = NULL;	/* ownership moved to the list */

out_unlock:
	write_unlock(&g_dpa_rwlock);
	kfree(new_node);	/* no-op when the node was linked */
	return ret;
}
153
/*
 * Remove @uid from the DPA uid list and free its node. Removing a uid that
 * is not listed is not an error. Always returns 0.
 */
static int dpa_uid_del(uid_t uid)
{
	struct dpa_node *node;
	struct dpa_node *tmp_node;
	struct dpa_node *victim = NULL;

	write_lock(&g_dpa_rwlock);
	list_for_each_entry_safe(node, tmp_node, &g_dpa_uid_list, list_node) {
		if (node->uid == uid) {
			list_del(&node->list_node);
			if (g_dpa_uid_list_cnt)
				--g_dpa_uid_list_cnt;
			victim = node;
			break;
		}
	}
	write_unlock(&g_dpa_rwlock);

	/*
	 * The node is unlinked under the write lock, so no reader can still
	 * reach it. Freeing it fixes the leak in the previous version, which
	 * only unlinked it.
	 */
	kfree(victim);
	return 0;
}
171
parse_single_uid(char * begin,char * end)172 static uid_t parse_single_uid(char *begin, char *end)
173 {
174 char *cur = NULL;
175 uid_t uid = 0;
176 u32 len = end - begin;
177
178 // u32 decimal characters (4,294,967,295)
179 if (len > DECIMAL_CHAR_NUM) {
180 pr_err("[dpa-uid-cfg] single uid len(%u) overflow\n", len);
181 return uid;
182 }
183
184 cur = begin;
185 while (cur < end) {
186 if (*cur < '0' || *cur > '9') {
187 pr_err("[dpa-uid-cfg] invalid character '%c'\n", *cur);
188 return uid;
189 }
190 cur++;
191 }
192
193 uid = simple_strtoul(begin, &begin, TO_DECIMAL);
194 if (!begin || !uid) {
195 pr_err("[dpa-uid-cfg] fail to change str to data");
196 return uid;
197 }
198
199 return uid;
200 }
201
/*
 * Parse a list of decimal uids separated by single spaces and terminated by
 * '\n' (e.g. "10010 10011\n") into @uid_list.
 *
 * @args:      start of the uid list; must be NUL-terminated (scanned with
 *             strchr())
 * @args_len:  number of bytes of @args belonging to the write
 * @uid_list:  output array with room for @index_max entries
 * @index_max: capacity of @uid_list
 * @index:     in: first slot to fill; out: one past the last slot filled
 *
 * Returns 0 on success or -EINVAL on malformed input, a bad uid, or more
 * uids than @uid_list can hold.
 */
static int parse_uids(char *args, u32 args_len, u32 *uid_list,
		      u32 index_max, u32 *index)
{
	char *begin = args;
	char *end = strchr(args, ' ');
	uid_t uid = 0;
	u32 len = 0;

	/* Every uid except the last one is followed by a ' '. */
	while (end) {
		/* digits of this uid plus its ' ' separator */
		len += end - begin + 1;
		/*
		 * '>=': uid_list[*index] is written below, so *index must be
		 * strictly below the capacity. The old '>' allowed one write
		 * past the end of uid_list.
		 */
		if (len > args_len || *index >= index_max) {
			pr_err("[dpa-uid-cfg] str len(%u) or index(%u) overflow\n",
			       len, *index);
			return -EINVAL;
		}

		uid = parse_single_uid(begin, end);
		if (!uid)
			return -EINVAL;
		uid_list[(*index)++] = uid;
		begin = ++end; /* skip the ' ' to reach the next uid */
		end = strchr(begin, ' ');
	}

	/* The last uid must be terminated by '\n'. */
	end = strchr(begin, '\n');
	if (!end) {
		pr_err("[dpa-uid-cfg] last character is not '\\n'\n");
		return -EINVAL;
	}

	/* digits of the last uid plus its '\n' terminator */
	len += end - begin + 1;
	if (len > args_len || *index >= index_max) {
		pr_err("[dpa-uid-cfg] str len(%u) or last index(%u) overflow\n",
		       len, *index);
		return -EINVAL;
	}
	uid = parse_single_uid(begin, end);
	if (!uid)
		return -EINVAL;
	uid_list[(*index)++] = uid;
	return 0;
}
247
/*
 * Split "<cmd> <uid> [<uid> ...]\n" at the first space. The command must be
 * exactly OPT_LEN characters long. The remainder goes to parse_uids().
 * Returns 0 on success or -EINVAL (logged) on a malformed command.
 */
static int get_dpa_uids(char *buf, size_t size, u32 *uid_list,
			u32 index_max, u32 *index)
{
	char *space = strchr(buf, ' ');
	u32 cmd_len;

	if (space == NULL) {
		pr_err("[dpa-uid-cfg] cmd fmt invalid\n");
		return -EINVAL;
	}

	/* "add" and "del" are both OPT_LEN characters long. */
	cmd_len = space - buf;
	if (cmd_len != OPT_LEN) {
		pr_err("[dpa-uid-cfg] cmd len invalid\n");
		return -EINVAL;
	}

	/* Everything after "<cmd> " is the uid list. */
	return parse_uids(space + 1, size - (cmd_len + 1), uid_list,
			  index_max, index);
}
272
/*
 * Return true if @kuid is in the configured DPA uid list. uid 0 never
 * matches.
 *
 * NOTE(review): the list writers in this file take plain write_lock().
 * If any caller of this exported helper runs in softirq context, the write
 * side should use write_lock_bh() to avoid a same-CPU deadlock. Confirm
 * the callers.
 */
bool dpa_uid_match(uid_t kuid)
{
	struct dpa_node *entry;
	bool found = false;

	if (!kuid)
		return false;

	read_lock(&g_dpa_rwlock);
	list_for_each_entry(entry, &g_dpa_uid_list, list_node) {
		if (entry->uid == kuid) {
			found = true;
			break;
		}
	}
	read_unlock(&g_dpa_rwlock);

	return found;
}
EXPORT_SYMBOL(dpa_uid_match);
293
/*
 * Register @fun as the hook that dpa_ext_init() runs before "add" commands.
 * A NULL @fun is ignored and keeps any previously registered hook.
 */
void regist_dpa_init(ext_init fun)
{
	if (fun)
		g_dpa_init_fun = fun;
}
300
dpa_ext_init(void)301 static void dpa_ext_init(void)
302 {
303 if (g_dpa_init_fun)
304 g_dpa_init_fun();
305 }
306
307 // call this fun in net/ipv4/af_inet.c inet_init_net()
lowpower_protocol_net_init(struct net * net)308 void __net_init lowpower_protocol_net_init(struct net *net)
309 {
310 if (!proc_create_net_single_write("foreground_uid", 0644,
311 net->proc_net,
312 foreground_uid_show,
313 foreground_uid_write,
314 NULL))
315 pr_err("fail to create /proc/net/foreground_uid");
316
317 if (!proc_create_net_single_write("dpa_uid", 0644,
318 net->proc_net,
319 dpa_uid_show,
320 dpa_uid_write,
321 NULL))
322 pr_err("fail to create /proc/net/dpa_uid");
323 }
324
foreground_uid_match(struct net * net,struct sock * sk)325 static bool foreground_uid_match(struct net *net, struct sock *sk)
326 {
327 uid_t kuid;
328 uid_t foreground_uid;
329 struct sock *fullsk;
330
331 if (!net || !sk)
332 return false;
333
334 fullsk = sk_to_full_sk(sk);
335 if (!fullsk || !sk_fullsock(fullsk))
336 return false;
337
338 kuid = sock_net_uid(net, fullsk).val;
339 foreground_uid = foreground_uid_atomic_read();
340 if (kuid != foreground_uid)
341 return false;
342
343 return true;
344 }
345
346 /*
347 * ack optimization is only enable for large data receiving tasks and
348 * there is no packet loss scenario
349 */
tcp_ack_num(struct sock * sk)350 int tcp_ack_num(struct sock *sk)
351 {
352 if (!sk)
353 return 1;
354
355 if (foreground_uid_match(sock_net(sk), sk) == false)
356 return 1;
357
358 if (tcp_sk(sk)->bytes_received >= BIG_DATA_BYTES &&
359 tcp_sk(sk)->dup_ack_counter < TCP_FASTRETRANS_THRESH)
360 return TCP_ACK_NUM;
361 return 1;
362 }
363
netfilter_bypass_enable(struct net * net,struct sk_buff * skb,int (* fun)(struct net *,struct sock *,struct sk_buff *),int * ret)364 bool netfilter_bypass_enable(struct net *net, struct sk_buff *skb,
365 int (*fun)(struct net *, struct sock *, struct sk_buff *),
366 int *ret)
367 {
368 if (!net || !skb || !ip_hdr(skb) || ip_hdr(skb)->protocol != IPPROTO_TCP)
369 return false;
370
371 if (foreground_uid_match(net, skb->sk)) {
372 *ret = fun(net, NULL, skb);
373 return true;
374 }
375 return false;
376 }
377
lowpower_register(void)378 static int __init lowpower_register(void)
379 {
380 INIT_LIST_HEAD(&g_dpa_uid_list);
381 return 0;
382 }
383
384 module_init(lowpower_register);
385 #endif /* CONFIG_LOWPOWER_PROTOCOL */
386