/*
 * Copyright (c) 2021 Huawei Device Co., Ltd.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
15
16 #include "netlink/netlink_listener.h"
17
18 #include <sys/socket.h>
19 #include <linux/netlink.h>
20
21 #include "securec.h"
22 #include "storage_service_errno.h"
23 #include "storage_service_log.h"
24
// Timeout handed to poll() by the listener loop, in milliseconds.
constexpr int POLL_IDLE_TIME{1000};
// Maximum uevent payload read per datagram; the receive buffer gets one extra byte for the '\0'.
constexpr int UEVENT_MSG_LEN{1024};
27
28 namespace OHOS {
29 namespace StorageDaemon {
UeventKernelMulticastRecv(int32_t socket,char * buffer,size_t length)30 ssize_t UeventKernelMulticastRecv(int32_t socket, char *buffer, size_t length)
31 {
32 struct iovec iov = { buffer, length };
33 struct sockaddr_nl addr;
34 char control[CMSG_SPACE(sizeof(struct ucred))];
35 struct msghdr hdr = {
36 .msg_name = &addr,
37 .msg_namelen = sizeof(addr),
38 .msg_iov = &iov,
39 .msg_iovlen = 1,
40 .msg_control = control,
41 .msg_controllen = sizeof(control),
42 .msg_flags = 0,
43 };
44 struct cmsghdr *cmsg;
45
46 ssize_t n = recvmsg(socket, &hdr, 0);
47 if (n <= 0) {
48 LOGE("Recvmsg failed, errno %{public}d", errno);
49 return n;
50 }
51
52 if (addr.nl_groups == 0 || addr.nl_pid != 0) {
53 return E_ERR;
54 }
55
56 cmsg = CMSG_FIRSTHDR(&hdr);
57 if (cmsg == nullptr || cmsg->cmsg_type != SCM_CREDENTIALS) {
58 LOGE("SCM_CREDENTIALS check failed");
59 return E_ERR;
60 }
61
62 struct ucred cred;
63 if (memcpy_s(&cred, sizeof(cred), CMSG_DATA(cmsg), sizeof(struct ucred)) != EOK || cred.uid != 0) {
64 LOGE("Uid check failed");
65 return E_ERR;
66 }
67
68 return n;
69 }
70
RecvUeventMsg()71 void NetlinkListener::RecvUeventMsg()
72 {
73 auto msg = std::make_unique<char[]>(UEVENT_MSG_LEN + 1);
74
75 while (1) {
76 auto count = UeventKernelMulticastRecv(socketFd_, msg.get(), UEVENT_MSG_LEN);
77 if (count <= 0) {
78 (void)memset_s(msg.get(), UEVENT_MSG_LEN + 1, 0, UEVENT_MSG_LEN + 1);
79 break;
80 }
81 if (count >= UEVENT_MSG_LEN) {
82 continue;
83 }
84
85 msg.get()[count] = '\0';
86 OnEvent(msg.get());
87 }
88 }
89
ReadMsg(int32_t fd_count,struct pollfd ufds[2])90 int32_t NetlinkListener::ReadMsg(int32_t fd_count, struct pollfd ufds[2])
91 {
92 int32_t i;
93 for (i = 0; i < fd_count; i++) {
94 if (ufds[i].revents == 0) {
95 continue;
96 }
97
98 if (ufds[i].fd == socketPipe_[0]) {
99 int32_t msg = 0;
100 if (read(socketPipe_[0], &msg, 1) < 0) {
101 LOGE("Read socket pipe failed");
102 return E_ERR;
103 }
104 if (msg == 0) {
105 LOGE("Stop listener");
106 return E_ERR;
107 }
108 } else if (ufds[i].fd == socketFd_) {
109 if ((static_cast<uint32_t>(ufds[i].revents) & POLLIN)) {
110 RecvUeventMsg();
111 continue;
112 }
113 if ((static_cast<uint32_t>(ufds[i].revents)) & (POLLERR | POLLHUP)) {
114 LOGE("POLLERR | POLLHUP");
115 return E_ERR;
116 }
117 }
118 }
119 return E_OK;
120 }
121
RunListener()122 void NetlinkListener::RunListener()
123 {
124 struct pollfd ufds[2];
125 int32_t idleTime = POLL_IDLE_TIME;
126
127 while (1) {
128 int32_t fdCount = 0;
129 ufds[fdCount].fd = socketPipe_[0];
130 ufds[fdCount].events = POLLIN;
131 ufds[fdCount].revents = 0;
132 fdCount++;
133
134 if (socketFd_ > -1) {
135 ufds[fdCount].fd = socketFd_;
136 ufds[fdCount].events = POLLIN;
137 ufds[fdCount].revents = 0;
138 fdCount++;
139 }
140
141 int32_t fdEventCount = poll(ufds, fdCount, idleTime);
142 if (fdEventCount < 0) {
143 if (errno == EAGAIN || errno == EINTR) {
144 continue;
145 }
146 break;
147 } else if (fdEventCount == 0) {
148 continue;
149 }
150
151 if (ReadMsg(fdCount, ufds) != 0) {
152 return;
153 }
154 }
155 }
156
EventProcess(void * object)157 void NetlinkListener::EventProcess(void *object)
158 {
159 if (object == nullptr) {
160 LOGE("object is NULL");
161 return;
162 }
163
164 NetlinkListener* client = reinterpret_cast<NetlinkListener *>(object);
165 client->RunListener();
166 }
167
StartListener()168 int32_t NetlinkListener::StartListener()
169 {
170 if (socketFd_ < 0) {
171 LOGE("socketFD < 0");
172 return E_ERR;
173 }
174
175 if (pipe(socketPipe_) == -1) {
176 LOGE("Pipe error");
177 return E_ERR;
178 }
179 socketThread_ = std::make_unique<std::thread>([this]() { this->EventProcess(static_cast<void *>(this)); });
180 if (socketThread_ == nullptr) {
181 (void)close(socketPipe_[0]);
182 (void)close(socketPipe_[1]);
183 socketPipe_[0] = -1;
184 socketPipe_[1] = -1;
185 return E_ERR;
186 }
187
188 return E_OK;
189 }
190
StopListener()191 int32_t NetlinkListener::StopListener()
192 {
193 int32_t msg = 0;
194 write(socketPipe_[1], &msg, 1);
195
196 if (socketThread_ != nullptr && socketThread_->joinable()) {
197 socketThread_->join();
198 }
199
200 (void)close(socketPipe_[0]);
201 (void)close(socketPipe_[1]);
202 socketPipe_[0] = -1;
203 socketPipe_[1] = -1;
204
205 return E_OK;
206 }
207
NetlinkListener(int32_t socket)208 NetlinkListener::NetlinkListener(int32_t socket)
209 {
210 socketFd_ = socket;
211 }
212 } // StorageDaemon
213 } // OHOS
214