1 /*
2 * Copyright (c) 2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #ifndef NETMANAGER_BASE_EPOLLER_RECVMSG_H
17 #define NETMANAGER_BASE_EPOLLER_RECVMSG_H
18
#include <arpa/inet.h>
#include <fcntl.h>
#include <sys/epoll.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <unistd.h>

#include <cerrno>
#include <cstring>
#include <functional>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>

#include "securec.h"
34
35 namespace OHOS::NetManagerStandard::FwmarkTool {
/// Raw POSIX file descriptor.
using FileDescriptor = int;
/// Callback invoked with a client descriptor once the epoll server sees it become readable.
using RecvMsgRunner = std::function<void(FileDescriptor fd)>;
/// Maximum number of events fetched by a single epoll_wait call.
/// `inline constexpr` gives one definition shared by every TU including this header.
inline constexpr size_t MAX_EPOLL_EVENTS = 32;
39
/**
 * Puts @p sock into non-blocking mode by adding O_NONBLOCK to its file status flags.
 *
 * Each fcntl call is retried (at most 30 extra times) while it fails with EINTR.
 * Declared `inline` because it is defined in a header: without it, including this
 * header from more than one translation unit is a multiple-definition link error.
 *
 * @param sock descriptor to modify; negative values are rejected without a syscall.
 * @return true if O_NONBLOCK is set on return, false if an fcntl call failed.
 */
inline bool MakeNonBlock(int sock)
{
    static constexpr uint32_t maxRetry = 30;
    if (sock < 0) {
        return false;
    }

    // fcntl can be interrupted by a signal; retry a bounded number of times on EINTR only.
    int flags = fcntl(sock, F_GETFL, 0);
    for (uint32_t retry = 0; flags == -1 && errno == EINTR && retry < maxRetry; ++retry) {
        flags = fcntl(sock, F_GETFL, 0);
    }
    if (flags == -1) {
        return false;
    }

    const uint32_t currentFlags = static_cast<uint32_t>(flags);
    if ((currentFlags & static_cast<uint32_t>(O_NONBLOCK)) != 0) {
        return true; // already non-blocking, nothing to change
    }

    // fcntl is variadic and reads an int for F_SETFL, so pass an int explicitly.
    const int newFlags = static_cast<int>(currentFlags | static_cast<uint32_t>(O_NONBLOCK));
    int ret = fcntl(sock, F_SETFL, newFlags);
    for (uint32_t retry = 0; ret == -1 && errno == EINTR && retry < maxRetry; ++retry) {
        ret = fcntl(sock, F_SETFL, newFlags);
    }
    return ret != -1;
}
66
67 struct Epoller {
EpollerEpoller68 Epoller()
69 {
70 underlying_ = epoll_create1(EPOLL_CLOEXEC);
71 }
72
~EpollerEpoller73 ~Epoller()
74 {
75 close(underlying_);
76 }
77
78 Epoller(const Epoller &) = delete;
79 Epoller(Epoller &&) = delete;
80 Epoller &operator=(const Epoller &) = delete;
81 Epoller &operator=(const Epoller &&) = delete;
82
RegisterMeEpoller83 void RegisterMe(FileDescriptor descriptor) const
84 {
85 RegisterMe(descriptor, EPOLLIN);
86 }
87
RegisterMeEpoller88 void RegisterMe(FileDescriptor descriptor, uint32_t flags) const
89 {
90 epoll_event ev{};
91 ev.events = flags;
92 ev.data.fd = descriptor;
93 epoll_ctl(underlying_, EPOLL_CTL_ADD, descriptor, &ev);
94 }
95
UnregisterMeEpoller96 void UnregisterMe(FileDescriptor descriptor) const
97 {
98 epoll_ctl(underlying_, EPOLL_CTL_DEL, descriptor, nullptr);
99 }
100
WaitEpoller101 int Wait(epoll_event *events, int maxEvents, int timeout) const
102 {
103 return epoll_wait(underlying_, events, maxEvents, timeout);
104 }
105
106 private:
107 FileDescriptor underlying_ = 0;
108 };
109
110 class FwmarkEpollServer {
111 public:
FwmarkEpollServer(FileDescriptor serverFd,RecvMsgRunner runner)112 FwmarkEpollServer(FileDescriptor serverFd, RecvMsgRunner runner) : serverFd_(serverFd), runner_(std::move(runner))
113 {
114 epoller_ = std::make_shared<Epoller>();
115 epoller_->RegisterMe(serverFd);
116 }
117
Run()118 void Run()
119 {
120 while (true) {
121 static constexpr int waitTimeoutMs = 5000;
122 if (!epoller_) {
123 return;
124 }
125
126 epoll_event events[MAX_EPOLL_EVENTS]{};
127 int eventsToHandle = epoller_->Wait(events, MAX_EPOLL_EVENTS, receivers_.empty() ? -1 : waitTimeoutMs);
128 if (eventsToHandle == -1) {
129 continue;
130 }
131 if (eventsToHandle == 0) {
132 for (const auto fd : receivers_) {
133 epoller_->UnregisterMe(fd);
134 close(fd);
135 }
136 receivers_.clear();
137 continue;
138 }
139 RunForReceivers(events, eventsToHandle);
140 }
141 }
142
143 private:
RunForReceivers(epoll_event events[MAX_EPOLL_EVENTS],int eventsToHandle)144 void RunForReceivers(epoll_event events[MAX_EPOLL_EVENTS], int eventsToHandle)
145 {
146 for (int idx = 0; idx < eventsToHandle; ++idx) {
147 if (serverFd_ == events[idx].data.fd) {
148 sockaddr_un clientAddr{};
149 socklen_t len = sizeof(clientAddr);
150 auto clientFd = accept(serverFd_, reinterpret_cast<sockaddr *>(&clientAddr), &len);
151 if (!MakeNonBlock(clientFd)) {
152 close(clientFd);
153 continue;
154 }
155 if (clientFd > 0) {
156 epoller_->RegisterMe(clientFd);
157 receivers_.insert(clientFd);
158 }
159 } else if (receivers_.count(events[idx].data.fd) > 0) {
160 epoller_->UnregisterMe(events[idx].data.fd);
161 receivers_.erase(events[idx].data.fd);
162 if (runner_) {
163 runner_(events[idx].data.fd);
164 } else {
165 close(events[idx].data.fd);
166 }
167 } else {
168 // maybe not my fd, just UnregisterMe
169 // this may not happen
170 // not in receivers and not serverFd, just unregister
171 epoller_->UnregisterMe(events[idx].data.fd);
172 }
173 }
174 }
175
176 std::shared_ptr<Epoller> epoller_;
177 FileDescriptor serverFd_ = 0;
178 RecvMsgRunner runner_;
179 std::unordered_set<FileDescriptor> receivers_;
180 };
181 } // namespace OHOS::NetManagerStandard::FwmarkTool
182 #endif // NETMANAGER_BASE_EPOLLER_RECVMSG_H
183