1 // SPDX-License-Identifier: GPL-2.0
2 #define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
3 #include <linux/init.h>
4 #include <linux/module.h>
5 #include <linux/umh.h>
6 #include <linux/bpfilter.h>
7 #include <linux/sched.h>
8 #include <linux/sched/signal.h>
9 #include <linux/fs.h>
10 #include <linux/file.h>
11 #include "msgfmt.h"
12
/*
 * Start/end markers of the embedded bpfilter_umh blob. Only their addresses
 * are used: load_umh() passes [&bpfilter_umh_start, &bpfilter_umh_end) to
 * umd_load_blob(). They are not defined in this file (presumably provided by
 * a separate assembly/linker stub - confirm in the build rules).
 */
extern char bpfilter_umh_start;
extern char bpfilter_umh_end;
15
shutdown_umh(void)16 static void shutdown_umh(void)
17 {
18 struct umd_info *info = &bpfilter_ops.info;
19 struct pid *tgid = info->tgid;
20
21 if (tgid) {
22 kill_pid(tgid, SIGKILL, 1);
23 wait_event(tgid->wait_pidfd, thread_group_exited(tgid));
24 bpfilter_umh_cleanup(info);
25 }
26 }
27
__stop_umh(void)28 static void __stop_umh(void)
29 {
30 if (IS_ENABLED(CONFIG_INET))
31 shutdown_umh();
32 }
33
bpfilter_send_req(struct mbox_request * req)34 static int bpfilter_send_req(struct mbox_request *req)
35 {
36 struct mbox_reply reply;
37 loff_t pos = 0;
38 ssize_t n;
39
40 if (!bpfilter_ops.info.tgid)
41 return -EFAULT;
42 pos = 0;
43 n = kernel_write(bpfilter_ops.info.pipe_to_umh, req, sizeof(*req),
44 &pos);
45 if (n != sizeof(*req)) {
46 pr_err("write fail %zd\n", n);
47 goto stop;
48 }
49 pos = 0;
50 n = kernel_read(bpfilter_ops.info.pipe_from_umh, &reply, sizeof(reply),
51 &pos);
52 if (n != sizeof(reply)) {
53 pr_err("read fail %zd\n", n);
54 goto stop;
55 }
56 return reply.status;
57 stop:
58 __stop_umh();
59 return -EFAULT;
60 }
61
/*
 * bpfilter_process_sockopt - forward a get/setsockopt call to the helper.
 * @sk: socket the option was issued on (not used here).
 * @optname: option number, sent to the helper as the request command.
 * @optval: caller's option buffer; only user-space pointers are accepted.
 * @optlen: length of @optval.
 * @is_set: true for setsockopt, false for getsockopt.
 *
 * The request carries the raw user address of @optval rather than a copy of
 * its contents, so a kernel-space @optval cannot be forwarded.
 *
 * Return: the result of bpfilter_send_req(), or -EFAULT for a kernel
 * @optval.
 */
static int bpfilter_process_sockopt(struct sock *sk, int optname,
				    sockptr_t optval, unsigned int optlen,
				    bool is_set)
{
	struct mbox_request req = {
		.pid	= current->pid,
		.cmd	= optname,
		.is_set	= is_set,
		.addr	= (uintptr_t)optval.user,
		.len	= optlen,
	};

	if (!sockptr_is_kernel(optval))
		return bpfilter_send_req(&req);

	pr_err("kernel access not supported\n");
	return -EFAULT;
}
79
start_umh(void)80 static int start_umh(void)
81 {
82 struct mbox_request req = { .pid = current->pid };
83 int err;
84
85 /* fork usermode process */
86 err = fork_usermode_driver(&bpfilter_ops.info);
87 if (err)
88 return err;
89 pr_info("Loaded bpfilter_umh pid %d\n", pid_nr(bpfilter_ops.info.tgid));
90
91 /* health check that usermode process started correctly */
92 if (bpfilter_send_req(&req) != 0) {
93 shutdown_umh();
94 return -EFAULT;
95 }
96
97 return 0;
98 }
99
load_umh(void)100 static int __init load_umh(void)
101 {
102 int err;
103
104 err = umd_load_blob(&bpfilter_ops.info,
105 &bpfilter_umh_start,
106 &bpfilter_umh_end - &bpfilter_umh_start);
107 if (err)
108 return err;
109
110 mutex_lock(&bpfilter_ops.lock);
111 err = start_umh();
112 if (!err && IS_ENABLED(CONFIG_INET)) {
113 bpfilter_ops.sockopt = &bpfilter_process_sockopt;
114 bpfilter_ops.start = &start_umh;
115 }
116 mutex_unlock(&bpfilter_ops.lock);
117 if (err)
118 umd_unload_blob(&bpfilter_ops.info);
119 return err;
120 }
121
fini_umh(void)122 static void __exit fini_umh(void)
123 {
124 mutex_lock(&bpfilter_ops.lock);
125 if (IS_ENABLED(CONFIG_INET)) {
126 shutdown_umh();
127 bpfilter_ops.start = NULL;
128 bpfilter_ops.sockopt = NULL;
129 }
130 mutex_unlock(&bpfilter_ops.lock);
131
132 umd_unload_blob(&bpfilter_ops.info);
133 }
/* Register load_umh()/fini_umh() as this module's init/exit hooks. */
module_init(load_umh);
module_exit(fini_umh);
MODULE_LICENSE("GPL");
137