/*
 * Copyright (c) 2021 Huawei Device Co., Ltd.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
15
#include "netlink/netlink_listener.h"

#include <fcntl.h>
#include <linux/netlink.h>
#include <sys/socket.h>
#include <unistd.h>

#include <memory>

#include "securec.h"
#include "storage_service_errno.h"
#include "storage_service_log.h"
25
// Timeout, in milliseconds, handed to poll() by the listener loop before it re-arms.
constexpr int POLL_IDLE_TIME{1000};
// Capacity of the uevent receive buffer; the buffer itself is allocated one byte larger
// so a received payload can always be NUL-terminated.
constexpr int UEVENT_MSG_LEN{1024};
28
29 namespace OHOS {
30 namespace StorageDaemon {
UeventKernelMulticastRecv(int32_t socket,char * buffer,size_t length)31 ssize_t UeventKernelMulticastRecv(int32_t socket, char *buffer, size_t length)
32 {
33 struct iovec iov = { buffer, length };
34 struct sockaddr_nl addr;
35 char control[CMSG_SPACE(sizeof(struct ucred))];
36 struct msghdr hdr = {
37 .msg_name = &addr,
38 .msg_namelen = sizeof(addr),
39 .msg_iov = &iov,
40 .msg_iovlen = 1,
41 .msg_control = control,
42 .msg_controllen = sizeof(control),
43 .msg_flags = 0,
44 };
45 struct cmsghdr *cmsg;
46
47 ssize_t n = recvmsg(socket, &hdr, 0);
48 if (n <= 0) {
49 LOGE("Recvmsg failed, errno %{public}d", errno);
50 return n;
51 }
52
53 if (addr.nl_groups == 0 || addr.nl_pid != 0) {
54 return E_ERR;
55 }
56
57 cmsg = CMSG_FIRSTHDR(&hdr);
58 if (cmsg == nullptr || cmsg->cmsg_type != SCM_CREDENTIALS) {
59 LOGE("SCM_CREDENTIALS check failed");
60 return E_ERR;
61 }
62
63 struct ucred cred;
64 if (memcpy_s(&cred, sizeof(cred), CMSG_DATA(cmsg), sizeof(struct ucred)) != EOK || cred.uid != 0) {
65 LOGE("Uid check failed");
66 return E_ERR;
67 }
68
69 return n;
70 }
71
RecvUeventMsg()72 void NetlinkListener::RecvUeventMsg()
73 {
74 auto msg = std::make_unique<char[]>(UEVENT_MSG_LEN + 1);
75
76 while (1) {
77 auto count = UeventKernelMulticastRecv(socketFd_, msg.get(), UEVENT_MSG_LEN);
78 if (count <= 0) {
79 (void)memset_s(msg.get(), UEVENT_MSG_LEN + 1, 0, UEVENT_MSG_LEN + 1);
80 break;
81 }
82 if (count >= UEVENT_MSG_LEN) {
83 continue;
84 }
85
86 msg.get()[count] = '\0';
87 OnEvent(msg.get());
88 }
89 }
90
ReadMsg(int32_t fd_count,struct pollfd ufds[2])91 int32_t NetlinkListener::ReadMsg(int32_t fd_count, struct pollfd ufds[2])
92 {
93 int32_t i;
94 for (i = 0; i < fd_count; i++) {
95 if (ufds[i].revents == 0) {
96 continue;
97 }
98
99 if (ufds[i].fd == socketPipe_[0]) {
100 int32_t msg = 0;
101 if (read(socketPipe_[0], &msg, 1) < 0) {
102 LOGE("Read socket pipe failed");
103 return E_ERR;
104 }
105 if (msg == 0) {
106 LOGI("Stop listener");
107 return E_ERR;
108 }
109 } else if (ufds[i].fd == socketFd_) {
110 if ((static_cast<uint32_t>(ufds[i].revents) & POLLIN)) {
111 RecvUeventMsg();
112 continue;
113 }
114 if ((static_cast<uint32_t>(ufds[i].revents)) & (POLLERR | POLLHUP)) {
115 LOGE("POLLERR | POLLHUP");
116 return E_ERR;
117 }
118 }
119 }
120 return E_OK;
121 }
122
RunListener()123 void NetlinkListener::RunListener()
124 {
125 struct pollfd ufds[2];
126 int32_t idleTime = POLL_IDLE_TIME;
127
128 while (1) {
129 int32_t fdCount = 0;
130 ufds[fdCount].fd = socketPipe_[0];
131 ufds[fdCount].events = POLLIN;
132 ufds[fdCount].revents = 0;
133 fdCount++;
134
135 if (socketFd_ > -1) {
136 ufds[fdCount].fd = socketFd_;
137 ufds[fdCount].events = POLLIN;
138 ufds[fdCount].revents = 0;
139 fdCount++;
140 }
141
142 int32_t fdEventCount = poll(ufds, fdCount, idleTime);
143 if (fdEventCount < 0) {
144 if (errno == EAGAIN || errno == EINTR) {
145 continue;
146 }
147 break;
148 } else if (fdEventCount == 0) {
149 continue;
150 }
151
152 if (ReadMsg(fdCount, ufds) != 0) {
153 return;
154 }
155 }
156 }
157
EventProcess(void * object)158 void NetlinkListener::EventProcess(void *object)
159 {
160 if (object == nullptr) {
161 LOGE("object is NULL");
162 return;
163 }
164
165 NetlinkListener* client = reinterpret_cast<NetlinkListener *>(object);
166 client->RunListener();
167 }
168
StartListener()169 int32_t NetlinkListener::StartListener()
170 {
171 if (socketFd_ < 0) {
172 LOGE("socketFD < 0");
173 return E_ERR;
174 }
175
176 if (pipe(socketPipe_) == -1) {
177 LOGE("Pipe error");
178 return E_ERR;
179 }
180 socketThread_ = std::make_unique<std::thread>([this]() { this->EventProcess(static_cast<void *>(this)); });
181 if (socketThread_ == nullptr) {
182 (void)close(socketPipe_[0]);
183 (void)close(socketPipe_[1]);
184 socketPipe_[0] = -1;
185 socketPipe_[1] = -1;
186 return E_ERR;
187 }
188
189 return E_OK;
190 }
191
StopListener()192 int32_t NetlinkListener::StopListener()
193 {
194 int32_t msg = 0;
195 write(socketPipe_[1], &msg, 1);
196
197 if (socketThread_ != nullptr && socketThread_->joinable()) {
198 socketThread_->join();
199 }
200
201 (void)close(socketPipe_[0]);
202 (void)close(socketPipe_[1]);
203 socketPipe_[0] = -1;
204 socketPipe_[1] = -1;
205
206 return E_OK;
207 }
208
NetlinkListener(int32_t socket)209 NetlinkListener::NetlinkListener(int32_t socket)
210 {
211 socketFd_ = socket;
212 }
213 } // StorageDaemon
214 } // OHOS
215