1 /*
2 * Copyright (c) 2021 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
#include "netlink/netlink_listener.h"

#include <cerrno>
#include <iostream>
#include <memory>

#include <fcntl.h>
#include <poll.h>
#include <sys/socket.h>
#include <unistd.h>
#include <linux/netlink.h>

#include "securec.h"
#include "storage_service_errno.h"
#include "storage_service_log.h"
28
// Timeout in milliseconds handed to poll() by the listener loop; expiry just re-polls.
constexpr int POLL_IDLE_TIME { 1000 };
// Maximum uevent payload read per recvmsg(); the receive buffer reserves one more byte for '\0'.
constexpr int UEVENT_MSG_LEN { 1024 };
31
32 namespace OHOS {
33 namespace StorageDaemon {
UeventKernelMulticastRecv(int32_t socket,char * buffer,size_t length)34 ssize_t UeventKernelMulticastRecv(int32_t socket, char *buffer, size_t length)
35 {
36 struct iovec iov = { buffer, length };
37 struct sockaddr_nl addr;
38 char control[CMSG_SPACE(sizeof(struct ucred))];
39 struct msghdr hdr = {
40 .msg_name = &addr,
41 .msg_namelen = sizeof(addr),
42 .msg_iov = &iov,
43 .msg_iovlen = 1,
44 .msg_control = control,
45 .msg_controllen = sizeof(control),
46 .msg_flags = 0,
47 };
48 struct cmsghdr *cmsg;
49
50 ssize_t n = recvmsg(socket, &hdr, 0);
51 if (n <= 0) {
52 LOGE("Recvmsg failed, errno %{public}d", errno);
53 return n;
54 }
55
56 if (addr.nl_groups == 0 || addr.nl_pid != 0) {
57 LOGE("Groups or pid check failed");
58 return E_ERR;
59 }
60
61 cmsg = CMSG_FIRSTHDR(&hdr);
62 if (cmsg == nullptr || cmsg->cmsg_type != SCM_CREDENTIALS) {
63 LOGE("SCM_CREDENTIALS check failed");
64 return E_ERR;
65 }
66
67 struct ucred cred;
68 if (memcpy_s(&cred, sizeof(cred), CMSG_DATA(cmsg), sizeof(struct ucred)) != EOK || cred.uid != 0) {
69 LOGE("Uid check failed");
70 return E_ERR;
71 }
72
73 return n;
74 }
75
RecvUeventMsg()76 void NetlinkListener::RecvUeventMsg()
77 {
78 auto msg = std::make_unique<char[]>(UEVENT_MSG_LEN + 1);
79 ssize_t count;
80
81 while (1) {
82 count = UeventKernelMulticastRecv(socketFd_, msg.get(), UEVENT_MSG_LEN);
83 if (count <= 0) {
84 (void)memset_s(msg.get(), UEVENT_MSG_LEN + 1, 0, UEVENT_MSG_LEN + 1);
85 break;
86 }
87 if (count >= UEVENT_MSG_LEN) {
88 continue;
89 }
90
91 msg.get()[count] = '\0';
92 OnEvent(msg.get());
93 }
94 }
95
ReadMsg(int32_t fd_count,struct pollfd ufds[2])96 int32_t NetlinkListener::ReadMsg(int32_t fd_count, struct pollfd ufds[2])
97 {
98 int32_t i;
99 for (i = 0; i < fd_count; i++) {
100 if (ufds[i].revents == 0) {
101 continue;
102 }
103
104 if (ufds[i].fd == socketPipe_[0]) {
105 int32_t msg = 0;
106 if (read(socketPipe_[0], &msg, 1) < 0) {
107 LOGE("Read socket pipe failed");
108 return E_ERR;
109 }
110 if (msg == 0) {
111 LOGI("Stop listener");
112 return E_ERR;
113 }
114 } else if (ufds[i].fd == socketFd_) {
115 if ((static_cast<uint32_t>(ufds[i].revents) & POLLIN)) {
116 RecvUeventMsg();
117 continue;
118 }
119 if ((static_cast<uint32_t>(ufds[i].revents)) & (POLLERR | POLLHUP)) {
120 LOGE("POLLERR | POLLHUP");
121 return E_ERR;
122 }
123 }
124 }
125 return E_OK;
126 }
127
RunListener()128 void NetlinkListener::RunListener()
129 {
130 struct pollfd ufds[2];
131 int32_t idle_time = POLL_IDLE_TIME;
132
133 while (1) {
134 int32_t fd_count = 0;
135
136 ufds[fd_count].fd = socketPipe_[0];
137 ufds[fd_count].events = POLLIN;
138 ufds[fd_count].revents = 0;
139 fd_count++;
140
141 if (socketFd_ > -1) {
142 ufds[fd_count].fd = socketFd_;
143 ufds[fd_count].events = POLLIN;
144 ufds[fd_count].revents = 0;
145 fd_count++;
146 }
147
148 int32_t n = poll(ufds, fd_count, idle_time);
149 if (n < 0) {
150 if (errno == EAGAIN || errno == EINTR) {
151 continue;
152 }
153 break;
154 } else if (!n) {
155 continue;
156 }
157
158 if (ReadMsg(fd_count, ufds) != 0) {
159 return;
160 }
161 }
162 }
163
EventProcess(void * object)164 void NetlinkListener::EventProcess(void *object)
165 {
166 if (object == nullptr) {
167 return;
168 }
169
170 NetlinkListener* me = reinterpret_cast<NetlinkListener *>(object);
171 me->RunListener();
172 }
173
StartListener()174 int32_t NetlinkListener::StartListener()
175 {
176 if (socketFd_ < 0) {
177 return E_ERR;
178 }
179
180 if (pipe(socketPipe_) == -1) {
181 LOGE("Pipe error");
182 return E_ERR;
183 }
184 socketThread_ = std::make_unique<std::thread>(&NetlinkListener::EventProcess, this);
185 if (socketThread_ == nullptr) {
186 (void)close(socketPipe_[0]);
187 (void)close(socketPipe_[1]);
188 socketPipe_[0] = socketPipe_[1] = -1;
189 return E_ERR;
190 }
191
192 return E_OK;
193 }
194
StopListener()195 int32_t NetlinkListener::StopListener()
196 {
197 int32_t msg = 0;
198 write(socketPipe_[1], &msg, 1);
199
200 if (socketThread_ != nullptr && socketThread_->joinable()) {
201 socketThread_->join();
202 }
203
204 (void)close(socketPipe_[0]);
205 (void)close(socketPipe_[1]);
206 socketPipe_[0] = socketPipe_[1] = -1;
207
208 return E_OK;
209 }
210
NetlinkListener(int32_t socket)211 NetlinkListener::NetlinkListener(int32_t socket)
212 {
213 socketFd_ = socket;
214 }
215 } // StorageDaemon
216 } // OHOS
217