1 /*
2 * Copyright (c) 2021 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "netlink/netlink_listener.h"
17
18 #include <cerrno>
19 #include <iostream>
20 #include <memory>
21 #include <sys/socket.h>
22 #include <unistd.h>
23 #include <linux/netlink.h>
24
25 #include "securec.h"
26 #include "storage_service_errno.h"
27 #include "storage_service_log.h"
28
// Timeout, in milliseconds, handed to poll() by the listener loop.
constexpr int POLL_IDLE_TIME{1000};
// Size of the uevent receive buffer; messages that fill it completely are discarded.
constexpr int UEVENT_MSG_LEN{1024};
31
32 namespace OHOS {
33 namespace StorageDaemon {
UeventKernelMulticastRecv(int32_t socket,char * buffer,size_t length)34 ssize_t UeventKernelMulticastRecv(int32_t socket, char *buffer, size_t length)
35 {
36 struct iovec iov = { buffer, length };
37 struct sockaddr_nl addr;
38 char control[CMSG_SPACE(sizeof(struct ucred))];
39 struct msghdr hdr = {
40 .msg_name = &addr,
41 .msg_namelen = sizeof(addr),
42 .msg_iov = &iov,
43 .msg_iovlen = 1,
44 .msg_control = control,
45 .msg_controllen = sizeof(control),
46 .msg_flags = 0,
47 };
48 struct cmsghdr *cmsg;
49
50 ssize_t n = recvmsg(socket, &hdr, 0);
51 if (n <= 0) {
52 LOGE("Recvmsg failed, errno %{public}d", errno);
53 return n;
54 }
55
56 if (addr.nl_groups == 0 || addr.nl_pid != 0) {
57 return E_ERR;
58 }
59
60 cmsg = CMSG_FIRSTHDR(&hdr);
61 if (cmsg == nullptr || cmsg->cmsg_type != SCM_CREDENTIALS) {
62 LOGE("SCM_CREDENTIALS check failed");
63 return E_ERR;
64 }
65
66 struct ucred cred;
67 if (memcpy_s(&cred, sizeof(cred), CMSG_DATA(cmsg), sizeof(struct ucred)) != EOK || cred.uid != 0) {
68 LOGE("Uid check failed");
69 return E_ERR;
70 }
71
72 return n;
73 }
74
RecvUeventMsg()75 void NetlinkListener::RecvUeventMsg()
76 {
77 auto msg = std::make_unique<char[]>(UEVENT_MSG_LEN + 1);
78 ssize_t count;
79
80 while (1) {
81 count = UeventKernelMulticastRecv(socketFd_, msg.get(), UEVENT_MSG_LEN);
82 if (count <= 0) {
83 (void)memset_s(msg.get(), UEVENT_MSG_LEN + 1, 0, UEVENT_MSG_LEN + 1);
84 break;
85 }
86 if (count >= UEVENT_MSG_LEN) {
87 continue;
88 }
89
90 msg.get()[count] = '\0';
91 OnEvent(msg.get());
92 }
93 }
94
ReadMsg(int32_t fd_count,struct pollfd ufds[2])95 int32_t NetlinkListener::ReadMsg(int32_t fd_count, struct pollfd ufds[2])
96 {
97 int32_t i;
98 for (i = 0; i < fd_count; i++) {
99 if (ufds[i].revents == 0) {
100 continue;
101 }
102
103 if (ufds[i].fd == socketPipe_[0]) {
104 int32_t msg = 0;
105 if (read(socketPipe_[0], &msg, 1) < 0) {
106 LOGE("Read socket pipe failed");
107 return E_ERR;
108 }
109 if (msg == 0) {
110 LOGI("Stop listener");
111 return E_ERR;
112 }
113 } else if (ufds[i].fd == socketFd_) {
114 if ((static_cast<uint32_t>(ufds[i].revents) & POLLIN)) {
115 RecvUeventMsg();
116 continue;
117 }
118 if ((static_cast<uint32_t>(ufds[i].revents)) & (POLLERR | POLLHUP)) {
119 LOGE("POLLERR | POLLHUP");
120 return E_ERR;
121 }
122 }
123 }
124 return E_OK;
125 }
126
RunListener()127 void NetlinkListener::RunListener()
128 {
129 struct pollfd ufds[2];
130 int32_t idle_time = POLL_IDLE_TIME;
131
132 while (1) {
133 int32_t fd_count = 0;
134
135 ufds[fd_count].fd = socketPipe_[0];
136 ufds[fd_count].events = POLLIN;
137 ufds[fd_count].revents = 0;
138 fd_count++;
139
140 if (socketFd_ > -1) {
141 ufds[fd_count].fd = socketFd_;
142 ufds[fd_count].events = POLLIN;
143 ufds[fd_count].revents = 0;
144 fd_count++;
145 }
146
147 int32_t n = poll(ufds, fd_count, idle_time);
148 if (n < 0) {
149 if (errno == EAGAIN || errno == EINTR) {
150 continue;
151 }
152 break;
153 } else if (!n) {
154 continue;
155 }
156
157 if (ReadMsg(fd_count, ufds) != 0) {
158 return;
159 }
160 }
161 }
162
EventProcess(void * object)163 void NetlinkListener::EventProcess(void *object)
164 {
165 if (object == nullptr) {
166 return;
167 }
168
169 NetlinkListener* me = reinterpret_cast<NetlinkListener *>(object);
170 me->RunListener();
171 }
172
StartListener()173 int32_t NetlinkListener::StartListener()
174 {
175 if (socketFd_ < 0) {
176 return E_ERR;
177 }
178
179 if (pipe(socketPipe_) == -1) {
180 LOGE("Pipe error");
181 return E_ERR;
182 }
183 socketThread_ = std::make_unique<std::thread>(&NetlinkListener::EventProcess, this);
184 if (socketThread_ == nullptr) {
185 (void)close(socketPipe_[0]);
186 (void)close(socketPipe_[1]);
187 socketPipe_[0] = socketPipe_[1] = -1;
188 return E_ERR;
189 }
190
191 return E_OK;
192 }
193
StopListener()194 int32_t NetlinkListener::StopListener()
195 {
196 int32_t msg = 0;
197 write(socketPipe_[1], &msg, 1);
198
199 if (socketThread_ != nullptr && socketThread_->joinable()) {
200 socketThread_->join();
201 }
202
203 (void)close(socketPipe_[0]);
204 (void)close(socketPipe_[1]);
205 socketPipe_[0] = socketPipe_[1] = -1;
206
207 return E_OK;
208 }
209
NetlinkListener(int32_t socket)210 NetlinkListener::NetlinkListener(int32_t socket)
211 {
212 socketFd_ = socket;
213 }
214 } // StorageDaemon
215 } // OHOS
216