• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #ifndef NETMANAGER_BASE_EPOLLER_H
17 #define NETMANAGER_BASE_EPOLLER_H
18 
#include <arpa/inet.h>
#include <fcntl.h>
#include <sys/epoll.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <unistd.h>

#include <cerrno>
#include <cstdint>
#include <cstring>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

#include "securec.h"
34 
35 namespace OHOS::NetManagerStandard {
// Maximum number of ready events fetched by one EpollServer wait call.
static constexpr size_t MAX_EPOLL_EVENTS = 32;

// A raw POSIX file descriptor.
using FileDescriptor = int;

// Result of one step of a FixedLengthReceiver (and of a ReceiverRunner callback).
enum class FixedLengthReceiverState {
    ONERROR,     // the step failed (e.g. read error or EOF); the connection should be dropped
    DATA_ENOUGH, // the required number of bytes has been collected
    CONTINUE,    // more bytes are still needed; wait for the next readiness event
};

// Callback receiving the descriptor and the bytes collected for it; its return value
// decides whether the connection is kept (CONTINUE) or closed (anything else).
using ReceiverRunner = std::function<FixedLengthReceiverState(FileDescriptor fd, const std::string &data)>;
44 
/**
 * Sets O_NONBLOCK on a descriptor, preserving its other file status flags.
 *
 * Defined `inline` because this lives in a header: a plain definition would be
 * emitted in every translation unit that includes it and fail to link.
 *
 * @param sock descriptor to modify.
 * @return true on success; false if either fcntl() call fails (including after
 *         exhausting the EINTR retry budget).
 */
inline bool MakeNonBlock(int sock)
{
    static constexpr uint32_t maxRetry = 30;

    // Runs `call` once, then retries up to maxRetry more times while it fails with EINTR.
    auto retryOnEintr = [](auto &&call) {
        int ret = call();
        for (uint32_t retry = 0; ret == -1 && errno == EINTR && retry < maxRetry; ++retry) {
            ret = call();
        }
        return ret;
    };

    const int flags = retryOnEintr([sock]() { return fcntl(sock, F_GETFL, 0); });
    if (flags == -1) {
        return false;
    }

    // F_SETFL takes an int argument; build the mask unsigned, pass it back as int.
    const int newFlags = static_cast<int>(static_cast<uint32_t>(flags) | static_cast<uint32_t>(O_NONBLOCK));
    return retryOnEintr([sock, newFlags]() { return fcntl(sock, F_SETFL, newFlags); }) != -1;
}
71 
72 struct Epoller {
EpollerEpoller73     Epoller()
74     {
75         underlying_ = epoll_create1(EPOLL_CLOEXEC);
76     }
77 
~EpollerEpoller78     ~Epoller()
79     {
80         close(underlying_);
81     }
82 
83     Epoller(const Epoller &) = delete;
84     Epoller(Epoller &&) = delete;
85     Epoller &operator=(const Epoller &) = delete;
86     Epoller &operator=(const Epoller &&) = delete;
87 
RegisterMeEpoller88     void RegisterMe(FileDescriptor descriptor) const
89     {
90         RegisterMe(descriptor, EPOLLIN);
91     }
92 
RegisterMeEpoller93     void RegisterMe(FileDescriptor descriptor, uint32_t flags) const
94     {
95         epoll_event ev{};
96         ev.events = flags;
97         ev.data.fd = descriptor;
98         epoll_ctl(underlying_, EPOLL_CTL_ADD, descriptor, &ev);
99     }
100 
UnregisterMeEpoller101     void UnregisterMe(FileDescriptor descriptor) const
102     {
103         epoll_ctl(underlying_, EPOLL_CTL_DEL, descriptor, nullptr);
104     }
105 
WaitEpoller106     int Wait(epoll_event *events, int maxEvents, int timeout) const
107     {
108         return epoll_wait(underlying_, events, maxEvents, timeout);
109     }
110 
111 private:
112     FileDescriptor underlying_ = 0;
113 };
114 
115 class FixedLengthReceiver {
116 public:
117     FixedLengthReceiver() = delete;
FixedLengthReceiver(FileDescriptor clientFd,size_t neededLength,ReceiverRunner runner)118     FixedLengthReceiver(FileDescriptor clientFd, size_t neededLength, ReceiverRunner runner)
119         : fd_(clientFd), neededLength_(neededLength), runner_(std::move(runner))
120     {
121     }
122 
Run()123     FixedLengthReceiverState Run()
124     {
125         if (!runner_) {
126             return FixedLengthReceiverState::ONERROR;
127         }
128         auto res = GetData();
129         if (res == FixedLengthReceiverState::ONERROR) {
130             return res;
131         }
132         if (res == FixedLengthReceiverState::DATA_ENOUGH) {
133             return runner_(fd_, data_);
134         }
135         return FixedLengthReceiverState::CONTINUE;
136     }
137 
138 private:
GetData()139     FixedLengthReceiverState GetData()
140     {
141         if (data_.size() >= neededLength_) {
142             return FixedLengthReceiverState::DATA_ENOUGH;
143         }
144         auto size = neededLength_ - data_.size();
145         auto buf = malloc(size);
146         if (buf == nullptr) {
147             return FixedLengthReceiverState::ONERROR;
148         }
149         if (memset_s(buf, size, 0, size) != EOK) {
150             free(buf);
151             return FixedLengthReceiverState::ONERROR;
152         }
153         auto recvSize = read(fd_, buf, size);
154         if (recvSize < 0) {
155             if (errno == EINTR) {
156                 free(buf);
157                 return FixedLengthReceiverState::CONTINUE;
158             }
159             free(buf);
160             return FixedLengthReceiverState::ONERROR;
161         }
162         if (recvSize == 0) {
163             free(buf);
164             return FixedLengthReceiverState::ONERROR;
165         }
166         data_.append(reinterpret_cast<char *>(buf), recvSize);
167         free(buf);
168         return data_.size() >= neededLength_ ? FixedLengthReceiverState::DATA_ENOUGH
169                                              : FixedLengthReceiverState::CONTINUE;
170     }
171 
172     FileDescriptor fd_ = 0;
173     size_t neededLength_ = 0;
174     ReceiverRunner runner_;
175     std::string data_;
176 };
177 
178 class EpollServer {
179 public:
EpollServer(FileDescriptor serverFd,size_t firstPackageSize,ReceiverRunner firstPackageRunner)180     EpollServer(FileDescriptor serverFd, size_t firstPackageSize, ReceiverRunner firstPackageRunner)
181         : serverFd_(serverFd), firstPackageSize_(firstPackageSize), firstPackageRunner_(std::move(firstPackageRunner))
182     {
183         epoller_ = std::make_shared<Epoller>();
184         epoller_->RegisterMe(serverFd);
185     }
186 
AddReceiver(FileDescriptor clientFd,size_t neededLength,const ReceiverRunner & runner)187     void AddReceiver(FileDescriptor clientFd, size_t neededLength, const ReceiverRunner &runner)
188     {
189         auto receiver = std::make_shared<FixedLengthReceiver>(clientFd, neededLength, runner);
190         receivers_[clientFd] = receiver;
191     }
192 
Run()193     void Run()
194     {
195         while (true) {
196             static constexpr int waitTimeoutMs = 5000;
197             if (!epoller_) {
198                 return;
199             }
200 
201             epoll_event events[MAX_EPOLL_EVENTS]{};
202             int eventsToHandle = epoller_->Wait(events, MAX_EPOLL_EVENTS, receivers_.empty() ? -1 : waitTimeoutMs);
203             if (eventsToHandle == -1) {
204                 continue;
205             }
206             if (eventsToHandle == 0) {
207                 for (const auto &[fd, receiver] : receivers_) {
208                     epoller_->UnregisterMe(fd);
209                     close(fd);
210                 }
211                 receivers_.clear();
212                 continue;
213             }
214             RunForEvents(events, eventsToHandle);
215         }
216     }
217 
218 private:
RunForFd(int fd)219     void RunForFd(int fd)
220     {
221         auto receiver = receivers_[fd];
222         if (receiver) {
223             if (receiver->Run() != FixedLengthReceiverState::CONTINUE) {
224                 receivers_.erase(fd);
225                 epoller_->UnregisterMe(fd);
226                 close(fd);
227             }
228         } else {
229             // my fd, UnregisterMe and close
230             receivers_.erase(fd);
231             epoller_->UnregisterMe(fd);
232             close(fd);
233         }
234     }
235 
RunForEvents(epoll_event events[MAX_EPOLL_EVENTS],int eventsToHandle)236     void RunForEvents(epoll_event events[MAX_EPOLL_EVENTS], int eventsToHandle)
237     {
238         for (int idx = 0; idx < eventsToHandle; ++idx) {
239             if (serverFd_ == events[idx].data.fd) {
240                 sockaddr_un clientAddr{};
241                 socklen_t len = sizeof(clientAddr);
242                 auto clientFd = accept(serverFd_, reinterpret_cast<sockaddr *>(&clientAddr), &len);
243                 if (!MakeNonBlock(clientFd)) {
244                     close(clientFd);
245                     continue;
246                 }
247                 if (clientFd > 0) {
248                     epoller_->RegisterMe(clientFd);
249                     AddReceiver(clientFd, firstPackageSize_, firstPackageRunner_);
250                 }
251             } else if (receivers_.count(events[idx].data.fd) > 0) {
252                 RunForFd(events[idx].data.fd);
253             } else {
254                 // maybe not my fd, just UnregisterMe
255                 // this may not happen
256                 // not in receivers and not serverFd, just unregister
257                 epoller_->UnregisterMe(events[idx].data.fd);
258             }
259         }
260     }
261 
262     std::unordered_map<FileDescriptor, std::shared_ptr<FixedLengthReceiver>> receivers_;
263     std::shared_ptr<Epoller> epoller_;
264     FileDescriptor serverFd_ = 0;
265     size_t firstPackageSize_ = 0;
266     ReceiverRunner firstPackageRunner_;
267 };
268 } // namespace OHOS::NetManagerStandard
269 #endif // NETMANAGER_BASE_EPOLLER_H
270