1 /*
2 * Copyright (c) 2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #ifndef NETMANAGER_BASE_EPOLLER_H
17 #define NETMANAGER_BASE_EPOLLER_H
18
#include <arpa/inet.h>
#include <fcntl.h>
#include <sys/epoll.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <unistd.h>

#include <cerrno>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

#include "securec.h"
34
35 namespace OHOS::NetManagerStandard {
36 static constexpr size_t MAX_EPOLL_EVENTS = 32;
37 typedef int FileDescriptor;
38 enum class FixedLengthReceiverState {
39 ONERROR,
40 DATA_ENOUGH,
41 CONTINUE,
42 };
43 using ReceiverRunner = std::function<FixedLengthReceiverState(FileDescriptor fd, const std::string &data)>;
44
MakeNonBlock(int sock)45 bool MakeNonBlock(int sock)
46 {
47 static constexpr uint32_t maxRetry = 30;
48 uint32_t retry = 0;
49 int flags = fcntl(sock, F_GETFL, 0);
50 while (flags == -1 && errno == EINTR && retry < maxRetry) {
51 flags = fcntl(sock, F_GETFL, 0);
52 ++retry;
53 }
54
55 if (flags == -1) {
56 return false;
57 }
58
59 retry = 0;
60 uint32_t tempFlags = static_cast<uint32_t>(flags) | O_NONBLOCK;
61 int ret = fcntl(sock, F_SETFL, tempFlags);
62 while (ret == -1 && errno == EINTR && retry < maxRetry) {
63 ret = fcntl(sock, F_SETFL, tempFlags);
64 ++retry;
65 }
66 if (ret == -1) {
67 return false;
68 }
69 return true;
70 }
71
72 struct Epoller {
EpollerEpoller73 Epoller()
74 {
75 underlying_ = epoll_create1(EPOLL_CLOEXEC);
76 }
77
~EpollerEpoller78 ~Epoller()
79 {
80 close(underlying_);
81 }
82
83 Epoller(const Epoller &) = delete;
84 Epoller(Epoller &&) = delete;
85 Epoller &operator=(const Epoller &) = delete;
86 Epoller &operator=(const Epoller &&) = delete;
87
RegisterMeEpoller88 void RegisterMe(FileDescriptor descriptor) const
89 {
90 RegisterMe(descriptor, EPOLLIN);
91 }
92
RegisterMeEpoller93 void RegisterMe(FileDescriptor descriptor, uint32_t flags) const
94 {
95 epoll_event ev{};
96 ev.events = flags;
97 ev.data.fd = descriptor;
98 epoll_ctl(underlying_, EPOLL_CTL_ADD, descriptor, &ev);
99 }
100
UnregisterMeEpoller101 void UnregisterMe(FileDescriptor descriptor) const
102 {
103 epoll_ctl(underlying_, EPOLL_CTL_DEL, descriptor, nullptr);
104 }
105
WaitEpoller106 int Wait(epoll_event *events, int maxEvents, int timeout) const
107 {
108 return epoll_wait(underlying_, events, maxEvents, timeout);
109 }
110
111 private:
112 FileDescriptor underlying_ = 0;
113 };
114
115 class FixedLengthReceiver {
116 public:
117 FixedLengthReceiver() = delete;
FixedLengthReceiver(FileDescriptor clientFd,size_t neededLength,ReceiverRunner runner)118 FixedLengthReceiver(FileDescriptor clientFd, size_t neededLength, ReceiverRunner runner)
119 : fd_(clientFd), neededLength_(neededLength), runner_(std::move(runner))
120 {
121 }
122
Run()123 FixedLengthReceiverState Run()
124 {
125 if (!runner_) {
126 return FixedLengthReceiverState::ONERROR;
127 }
128 auto res = GetData();
129 if (res == FixedLengthReceiverState::ONERROR) {
130 return res;
131 }
132 if (res == FixedLengthReceiverState::DATA_ENOUGH) {
133 return runner_(fd_, data_);
134 }
135 return FixedLengthReceiverState::CONTINUE;
136 }
137
138 private:
GetData()139 FixedLengthReceiverState GetData()
140 {
141 if (data_.size() >= neededLength_) {
142 return FixedLengthReceiverState::DATA_ENOUGH;
143 }
144 auto size = neededLength_ - data_.size();
145 auto buf = malloc(size);
146 if (buf == nullptr) {
147 return FixedLengthReceiverState::ONERROR;
148 }
149 if (memset_s(buf, size, 0, size) != EOK) {
150 free(buf);
151 return FixedLengthReceiverState::ONERROR;
152 }
153 auto recvSize = read(fd_, buf, size);
154 if (recvSize < 0) {
155 if (errno == EINTR) {
156 free(buf);
157 return FixedLengthReceiverState::CONTINUE;
158 }
159 free(buf);
160 return FixedLengthReceiverState::ONERROR;
161 }
162 if (recvSize == 0) {
163 free(buf);
164 return FixedLengthReceiverState::ONERROR;
165 }
166 data_.append(reinterpret_cast<char *>(buf), recvSize);
167 free(buf);
168 return data_.size() >= neededLength_ ? FixedLengthReceiverState::DATA_ENOUGH
169 : FixedLengthReceiverState::CONTINUE;
170 }
171
172 FileDescriptor fd_ = 0;
173 size_t neededLength_ = 0;
174 ReceiverRunner runner_;
175 std::string data_;
176 };
177
178 class EpollServer {
179 public:
EpollServer(FileDescriptor serverFd,size_t firstPackageSize,ReceiverRunner firstPackageRunner)180 EpollServer(FileDescriptor serverFd, size_t firstPackageSize, ReceiverRunner firstPackageRunner)
181 : serverFd_(serverFd), firstPackageSize_(firstPackageSize), firstPackageRunner_(std::move(firstPackageRunner))
182 {
183 epoller_ = std::make_shared<Epoller>();
184 epoller_->RegisterMe(serverFd);
185 }
186
AddReceiver(FileDescriptor clientFd,size_t neededLength,const ReceiverRunner & runner)187 void AddReceiver(FileDescriptor clientFd, size_t neededLength, const ReceiverRunner &runner)
188 {
189 auto receiver = std::make_shared<FixedLengthReceiver>(clientFd, neededLength, runner);
190 receivers_[clientFd] = receiver;
191 }
192
Run()193 void Run()
194 {
195 while (true) {
196 static constexpr int waitTimeoutMs = 5000;
197 if (!epoller_) {
198 return;
199 }
200
201 epoll_event events[MAX_EPOLL_EVENTS]{};
202 int eventsToHandle = epoller_->Wait(events, MAX_EPOLL_EVENTS, receivers_.empty() ? -1 : waitTimeoutMs);
203 if (eventsToHandle == -1) {
204 continue;
205 }
206 if (eventsToHandle == 0) {
207 for (const auto &[fd, receiver] : receivers_) {
208 epoller_->UnregisterMe(fd);
209 close(fd);
210 }
211 receivers_.clear();
212 continue;
213 }
214 RunForEvents(events, eventsToHandle);
215 }
216 }
217
218 private:
RunForFd(int fd)219 void RunForFd(int fd)
220 {
221 auto receiver = receivers_[fd];
222 if (receiver) {
223 if (receiver->Run() != FixedLengthReceiverState::CONTINUE) {
224 receivers_.erase(fd);
225 epoller_->UnregisterMe(fd);
226 close(fd);
227 }
228 } else {
229 // my fd, UnregisterMe and close
230 receivers_.erase(fd);
231 epoller_->UnregisterMe(fd);
232 close(fd);
233 }
234 }
235
RunForEvents(epoll_event events[MAX_EPOLL_EVENTS],int eventsToHandle)236 void RunForEvents(epoll_event events[MAX_EPOLL_EVENTS], int eventsToHandle)
237 {
238 for (int idx = 0; idx < eventsToHandle; ++idx) {
239 if (serverFd_ == events[idx].data.fd) {
240 sockaddr_un clientAddr{};
241 socklen_t len = sizeof(clientAddr);
242 auto clientFd = accept(serverFd_, reinterpret_cast<sockaddr *>(&clientAddr), &len);
243 if (!MakeNonBlock(clientFd)) {
244 close(clientFd);
245 continue;
246 }
247 if (clientFd > 0) {
248 epoller_->RegisterMe(clientFd);
249 AddReceiver(clientFd, firstPackageSize_, firstPackageRunner_);
250 }
251 } else if (receivers_.count(events[idx].data.fd) > 0) {
252 RunForFd(events[idx].data.fd);
253 } else {
254 // maybe not my fd, just UnregisterMe
255 // this may not happen
256 // not in receivers and not serverFd, just unregister
257 epoller_->UnregisterMe(events[idx].data.fd);
258 }
259 }
260 }
261
262 std::unordered_map<FileDescriptor, std::shared_ptr<FixedLengthReceiver>> receivers_;
263 std::shared_ptr<Epoller> epoller_;
264 FileDescriptor serverFd_ = 0;
265 size_t firstPackageSize_ = 0;
266 ReceiverRunner firstPackageRunner_;
267 };
268 } // namespace OHOS::NetManagerStandard
269 #endif // NETMANAGER_BASE_EPOLLER_H
270