• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2023 Shenzhen Kaihong Digital Industry Development Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
#include "udp_server.h"
#include <algorithm>
#include <arpa/inet.h>
#include <cerrno>
#include <iostream>
#include <securec.h>
#include <unistd.h>
#include <utility>
#include "common/media_log.h"
#include "network/session/udp_session.h"
#include "network/socket/socket_utils.h"
#include "network/socket/udp_socket.h"
#include "utils/utils.h"
27 namespace OHOS {
28 namespace Sharing {
Start(uint16_t port,const std::string & host,bool enableReuse,uint32_t backlog)29 bool UdpServer::Start(uint16_t port, const std::string &host, bool enableReuse, uint32_t backlog)
30 {
31     SHARING_LOGD("server ip:%{public}s, Port:%{public}d, thread_id: %{public}llu.",
32         GetAnonymousIp(host).c_str(), port, GetThreadId());
33     std::unique_lock<std::shared_mutex> lk(mutex_);
34     socket_ = std::make_unique<UdpSocket>();
35     if (socket_) {
36         if (socket_->Bind(port, host, enableReuse)) {
37             SHARING_LOGD("start success, fd: %{public}d.", socket_->GetLocalFd());
38 
39             auto eventRunner = OHOS::AppExecFwk::EventRunner::Create(true);
40             eventHandler_ = std::make_shared<UdpServerEventHandler>();
41             eventHandler_->SetServer(shared_from_this());
42             eventHandler_->SetEventRunner(eventRunner);
43             eventRunner->Run();
44 
45             eventListener_ = std::make_shared<UdpServerEventListener>();
46             eventListener_->SetServer(shared_from_this());
47 
48             return eventListener_->AddFdListener(socket_->GetLocalFd(), eventListener_, eventHandler_);
49         }
50     }
51 
52     SHARING_LOGE("start failed!");
53     return false;
54 }
55 
~UdpServer()56 UdpServer::~UdpServer()
57 {
58     SHARING_LOGD("trace.");
59     Stop();
60 }
61 
UdpServer()62 UdpServer::UdpServer()
63 {
64     SHARING_LOGD("trace.");
65 }
66 
Stop()67 void UdpServer::Stop()
68 {
69     SHARING_LOGD("stop.");
70     std::unique_lock<std::shared_mutex> lk(mutex_);
71 
72     for (auto kv : sessionMap_) {
73         if (kv.second) {
74             kv.second->Shutdown();
75             kv.second.reset();
76         }
77     }
78 
79     if (socket_ != nullptr) {
80         if (eventListener_) {
81             eventListener_->RemoveFdListener(socket_->GetLocalFd());
82         }
83         SocketUtils::ShutDownSocket(socket_->GetLocalFd());
84         SocketUtils::CloseSocket(socket_->GetLocalFd());
85         socket_.reset();
86     }
87 }
88 
GetSocketInfo()89 SocketInfo::Ptr UdpServer::GetSocketInfo()
90 {
91     SHARING_LOGD("trace.");
92     return socket_;
93 }
94 
CloseClientSocket(int32_t fd)95 void UdpServer::CloseClientSocket(int32_t fd)
96 {
97     SHARING_LOGD("fd: %{public}d.", fd);
98     std::unique_lock<std::shared_mutex> lk(mutex_);
99     if (fd > 0) {
100         auto itemItr = sessionMap_.find(fd);
101         if (itemItr != sessionMap_.end()) {
102             if (itemItr->second) {
103                 itemItr->second->Shutdown();
104                 itemItr->second.reset();
105             }
106             SocketUtils::CloseSocket(fd);
107             sessionMap_.erase(itemItr);
108             SHARING_LOGD("erase fd: %{public}d.", fd);
109         }
110     }
111 }
112 
OnServerReadable(int32_t fd)113 void UdpServer::OnServerReadable(int32_t fd)
114 {
115     MEDIA_LOGD("fd: %{public}d, thread_id: %{public}llu tid:%{public}d", fd, GetThreadId(), gettid());
116 
117     std::shared_lock<std::shared_mutex> lk(mutex_);
118     if (socket_ == nullptr) {
119         SHARING_LOGE("onReadable socket null!");
120         return;
121     }
122 
123     if (fd != socket_->GetLocalFd()) {
124         SHARING_LOGE("onReadable receive msg!");
125         return;
126     }
127 
128     auto callback = callback_.lock();
129     if (callback == nullptr) {
130         SHARING_LOGE("callback null!");
131         return;
132     }
133 
134     int32_t retry = 0;
135     int32_t retCode = 0;
136     bool firstRead = true;
137     bool reading = true;
138     while (reading) {
139         DataBuffer::Ptr buf = std::make_shared<DataBuffer>(DEFAULT_READ_BUFFER_SIZE);
140         struct sockaddr_in clientAddr;
141         socklen_t len = sizeof(struct sockaddr_in);
142         retCode = ::recvfrom(fd, buf->Data(), DEFAULT_READ_BUFFER_SIZE, 0, (struct sockaddr *)&clientAddr, &len);
143         MEDIA_LOGD("recvSocket len: %{public}d,address: %{public}s,port: %{public}d,socklen: %{public}d.", retCode,
144                    inet_ntoa(clientAddr.sin_addr), clientAddr.sin_port, len);
145 
146         if (retCode < 0) {
147             if (errno != EAGAIN) {
148                 char errmsg[256] = {0};
149                 strerror_r(errno, errmsg, sizeof(errmsg));
150                 MEDIA_LOGD("on read data error %{public}d : %{public}s!", errno, errmsg);
151                 callback->OnServerException(fd);
152                 break;
153             }
154 
155             if (firstRead && retry < 5) { // 5: retry 5 times
156                 char errmsg[256] = {0};
157                 strerror_r(errno, errmsg, sizeof(errmsg));
158                 SHARING_LOGE("first read error %{public}d : %{public}s retry: %{public}d", errno, errmsg,
159                              retry);
160                 usleep(1000 * 5); // 1000 * 5: sleep 1000 * 5 millionseconds
161                 retry++;
162                 continue;
163             }
164             break;
165         }
166 
167         if (retCode > 0) {
168             firstRead = false;
169             buf->UpdateSize(retCode);
170             BaseNetworkSession::Ptr session = FindOrCreateSession(clientAddr);
171             if (session) {
172                 callback->OnServerReadData(fd, std::move(buf), session);
173             }
174         } else {
175             char errmsg[256] = {0};
176             strerror_r(errno, errmsg, sizeof(errmsg));
177             SHARING_LOGE("onReadable error: %{public}s!", errmsg);
178             break;
179         }
180     }
181 
182     MEDIA_LOGE("fd: %{public}d, thread_id: %{public}llu tid:%{public}d exit.", fd, GetThreadId(), gettid());
183 }
184 
FindOrCreateSession(const struct sockaddr_in & addr)185 std::shared_ptr<BaseNetworkSession> UdpServer::FindOrCreateSession(const struct sockaddr_in &addr)
186 {
187     MEDIA_LOGD("trace.");
188 
189     auto it = std::find_if(addrToFdMap_.begin(), addrToFdMap_.end(),
190         [&addr](std::pair<std::shared_ptr<struct sockaddr_in>, int32_t> value) {
191             return value.first->sin_addr.s_addr == addr.sin_addr.s_addr && value.first->sin_port == addr.sin_port;
192         });
193     if (it != addrToFdMap_.end()) {
194         return sessionMap_[it->second];
195     } else if (socket_ != nullptr) {
196         MEDIA_LOGD("not find, create session!");
197         int32_t peerFd = 0;
198         bool createSocketResult = SocketUtils::CreateSocket(SOCK_DGRAM, peerFd);
199         if (!createSocketResult || !BindAndConnectClinetFd(peerFd, addr)) {
200             SHARING_LOGE("create socket failed!");
201             return nullptr;
202         }
203 
204         SocketInfo::Ptr socketInfo =
205             std::make_shared<SocketInfo>(socket_->GetLocalIp(), inet_ntoa(addr.sin_addr), socket_->GetLocalFd(), peerFd,
206                                          socket_->GetLocalPort(), addr.sin_port);
207         if (socketInfo == nullptr) {
208             SHARING_LOGE("create socket info failed!");
209             return nullptr;
210         }
211         auto ret = memcpy_s(&socketInfo->udpClientAddr_, sizeof(struct sockaddr_in), &addr, sizeof(struct sockaddr_in));
212         if (ret != EOK) {
213             MEDIA_LOGE("mem copy data failed.");
214             SocketUtils::CloseSocket(peerFd);
215             return nullptr;
216         }
217         socketInfo->SetSocketType(SOCKET_TYPE_UDP);
218 
219         BaseNetworkSession::Ptr session = std::make_shared<UdpSession>(std::move(socketInfo));
220         if (session) {
221             auto peerAddr = std::make_shared<struct sockaddr_in>();
222             auto ret = memcpy_s(peerAddr.get(), sizeof(struct sockaddr_in), &addr, sizeof(struct sockaddr_in));
223             if (ret != EOK) {
224                 MEDIA_LOGE("mem copy data failed.");
225                 SocketUtils::CloseSocket(peerFd);
226                 return nullptr;
227             }
228             addrToFdMap_.insert(make_pair(peerAddr, peerFd));
229             sessionMap_.insert(make_pair(peerFd, std::move(session)));
230             auto callback = callback_.lock();
231             if (callback) {
232                 callback->OnAccept(sessionMap_[peerFd]);
233             }
234 
235             return sessionMap_[peerFd];
236         }
237     }
238 
239     return nullptr;
240 }
241 
BindAndConnectClinetFd(int32_t fd,const struct sockaddr_in & addr)242 bool UdpServer::BindAndConnectClinetFd(int32_t fd, const struct sockaddr_in &addr)
243 {
244     SHARING_LOGD("trace.");
245 
246     int32_t ret = 0;
247     SocketUtils::SetNonBlocking(fd);
248     SocketUtils::SetReusePort(fd, true);
249     SocketUtils::SetReuseAddr(fd, true);
250     SocketUtils::SetSendBuf(fd);
251     SocketUtils::SetRecvBuf(fd);
252 
253     if (!SocketUtils::BindSocket(fd, "", socket_->GetLocalPort())) {
254         SocketUtils::ShutDownSocket(fd);
255         SHARING_LOGE("bind BindSocket Failed!");
256         return false;
257     }
258 
259     SocketUtils::ConnectSocket(fd, true, inet_ntoa(addr.sin_addr), addr.sin_port, ret);
260     if (ret < 0 && (errno != EINPROGRESS)) {
261         char errmsg[256] = {0};
262         strerror_r(errno, errmsg, sizeof(errmsg));
263         SHARING_LOGE("connectSocket error: %{public}s!", errmsg);
264         SocketUtils::CloseSocket(fd);
265         return false;
266     }
267 
268     return true;
269 }
270 
271 } // namespace Sharing
272 } // namespace OHOS
273