1 /*
2 * Copyright (C) 2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "mdns_manager.h"
17
18 #include <unistd.h>
19
20 #include "mdns_event_proxy.h"
21 #include "netmgr_ext_log_wrapper.h"
22
23 namespace OHOS {
24 namespace NetManagerStandard {
25
MDnsManager()26 MDnsManager::MDnsManager()
27 {
28 InitHandler();
29 }
30
RegisterService(const MDnsServiceInfo & serviceInfo,const sptr<IRegistrationCallback> & cb)31 int32_t MDnsManager::RegisterService(const MDnsServiceInfo &serviceInfo, const sptr<IRegistrationCallback> &cb)
32 {
33 if (cb == nullptr) {
34 NETMGR_EXT_LOG_E("callback is nullptr");
35 return NET_MDNS_ERR_ILLEGAL_ARGUMENT;
36 }
37 std::lock_guard<std::mutex> guard(mutex_);
38 if (std::find_if(registerMap_.begin(), registerMap_.end(), [cb](const auto &elem) {
39 return elem.first->AsObject().GetRefPtr() == cb->AsObject().GetRefPtr();
40 }) != registerMap_.end()) {
41 return NET_MDNS_ERR_CALLBACK_DUPLICATED;
42 }
43 MDnsProtocolImpl::Result result{.serviceName = serviceInfo.name,
44 .serviceType = serviceInfo.type,
45 .port = serviceInfo.port,
46 .txt = serviceInfo.txtRecord};
47 int32_t err = impl.Register(result);
48 impl.RunTaskLater([this, cb, serviceInfo, err]() { cb->HandleRegisterResult(serviceInfo, err); });
49 if (err == NETMANAGER_EXT_SUCCESS) {
50 registerMap_.emplace_back(cb, serviceInfo.name + MDNS_DOMAIN_SPLITER_STR + serviceInfo.type);
51 }
52 return err;
53 }
54
UnRegisterService(const sptr<IRegistrationCallback> & cb)55 int32_t MDnsManager::UnRegisterService(const sptr<IRegistrationCallback> &cb)
56 {
57 if (cb == nullptr) {
58 NETMGR_EXT_LOG_E("callback is nullptr");
59 return NET_MDNS_ERR_ILLEGAL_ARGUMENT;
60 }
61 std::lock_guard<std::mutex> guard(mutex_);
62 auto local = std::find_if(registerMap_.begin(), registerMap_.end(), [cb](const auto &elem) {
63 return elem.first->AsObject().GetRefPtr() == cb->AsObject().GetRefPtr();
64 });
65 if (local == registerMap_.end()) {
66 return NET_MDNS_ERR_CALLBACK_NOT_FOUND;
67 }
68 int32_t err = impl.UnRegister(local->second);
69 if (err == NETMANAGER_EXT_SUCCESS) {
70 registerMap_.erase(local);
71 }
72 return err;
73 }
74
StartDiscoverService(const std::string & serviceType,const sptr<IDiscoveryCallback> & cb)75 int32_t MDnsManager::StartDiscoverService(const std::string &serviceType, const sptr<IDiscoveryCallback> &cb)
76 {
77 if (cb == nullptr) {
78 NETMGR_EXT_LOG_E("callback is nullptr");
79 return NET_MDNS_ERR_ILLEGAL_ARGUMENT;
80 }
81 std::lock_guard<std::mutex> guard(mutex_);
82 if (std::find_if(discoveryMap_.begin(), discoveryMap_.end(), [cb](const auto &elem) {
83 return elem.first->AsObject().GetRefPtr() == cb->AsObject().GetRefPtr();
84 }) != discoveryMap_.end()) {
85 return NET_MDNS_ERR_CALLBACK_DUPLICATED;
86 }
87 int32_t err = impl.Discovery(serviceType);
88 if (err == NETMANAGER_EXT_SUCCESS) {
89 discoveryMap_.emplace_back(cb, serviceType);
90 }
91 return err;
92 }
93
StopDiscoverService(const sptr<IDiscoveryCallback> & cb)94 int32_t MDnsManager::StopDiscoverService(const sptr<IDiscoveryCallback> &cb)
95 {
96 if (cb == nullptr) {
97 NETMGR_EXT_LOG_E("callback is nullptr");
98 return NET_MDNS_ERR_ILLEGAL_ARGUMENT;
99 }
100 std::lock_guard<std::mutex> guard(mutex_);
101 auto local = std::find_if(discoveryMap_.begin(), discoveryMap_.end(), [cb](const auto &elem) {
102 return elem.first->AsObject().GetRefPtr() == cb->AsObject().GetRefPtr();
103 });
104 if (local == discoveryMap_.end()) {
105 return NET_MDNS_ERR_CALLBACK_NOT_FOUND;
106 }
107 int32_t err = impl.StopDiscovery(local->second);
108 if (err == NETMANAGER_EXT_SUCCESS) {
109 discoveryMap_.erase(local);
110 }
111 return err;
112 }
113
ResolveService(const MDnsServiceInfo & serviceInfo,const sptr<IResolveCallback> & cb)114 int32_t MDnsManager::ResolveService(const MDnsServiceInfo &serviceInfo, const sptr<IResolveCallback> &cb)
115 {
116 if (cb == nullptr) {
117 NETMGR_EXT_LOG_E("callback is nullptr");
118 return NET_MDNS_ERR_ILLEGAL_ARGUMENT;
119 }
120 std::lock_guard<std::mutex> guard(mutex_);
121 if (std::find_if(resolveMap_.begin(), resolveMap_.end(), [cb](const auto &elem) {
122 return elem.first->AsObject().GetRefPtr() == cb->AsObject().GetRefPtr();
123 }) != resolveMap_.end()) {
124 return NET_MDNS_ERR_CALLBACK_DUPLICATED;
125 }
126 std::string instance = serviceInfo.name + MDNS_DOMAIN_SPLITER_STR + serviceInfo.type;
127 int32_t err = impl.ResolveInstance(instance);
128 if (err == NETMANAGER_EXT_SUCCESS) {
129 resolveMap_.emplace_back(cb, instance);
130 }
131 return err;
132 }
133
InitHandler()134 void MDnsManager::InitHandler()
135 {
136 static auto handle = [this](const MDnsProtocolImpl::Result &result, int32_t error) {
137 ReceiveResult(result, error);
138 };
139 impl.SetHandler(handle);
140 }
141
ReceiveResult(const MDnsProtocolImpl::Result & result,int32_t error)142 void MDnsManager::ReceiveResult(const MDnsProtocolImpl::Result &result, int32_t error)
143 {
144 switch (result.type) {
145 case MDnsProtocolImpl::SERVICE_STARTED:
146 [[fallthrough]];
147 case MDnsProtocolImpl::SERVICE_STOPED:
148 return ReceiveRegister(result, error);
149 case MDnsProtocolImpl::SERVICE_FOUND:
150 [[fallthrough]];
151 case MDnsProtocolImpl::SERVICE_LOST:
152 return ReceiveDiscover(result, error);
153 case MDnsProtocolImpl::INSTANCE_RESOLVED:
154 return ReceiveInstanceResolve(result, error);
155 case MDnsProtocolImpl::DOMAIN_RESOLVED:
156 return ReceiveResolve(result, error);
157 case MDnsProtocolImpl::UNKNOWN:
158 [[fallthrough]];
159 default:
160 return;
161 }
162 }
163
ReceiveRegister(const MDnsProtocolImpl::Result & result,int32_t error)164 void MDnsManager::ReceiveRegister(const MDnsProtocolImpl::Result &result, int32_t error) {}
165
ReceiveDiscover(const MDnsProtocolImpl::Result & result,int32_t error)166 void MDnsManager::ReceiveDiscover(const MDnsProtocolImpl::Result &result, int32_t error)
167 {
168 std::lock_guard<std::mutex> guard(mutex_);
169 NETMGR_EXT_LOG_D("discoveryMap_ size: [%{public}zu]", discoveryMap_.size());
170 for (auto iter = discoveryMap_.begin(); iter != discoveryMap_.end(); ++iter) {
171 if (iter->second != result.serviceType) {
172 continue;
173 }
174 auto cb = iter->first;
175 MDnsServiceInfo info;
176 info.name = result.serviceName;
177 info.type = result.serviceType;
178 if (result.type == MDnsProtocolImpl::SERVICE_FOUND) {
179 cb->HandleServiceFound(info, error);
180 }
181 if (result.type == MDnsProtocolImpl::SERVICE_LOST) {
182 cb->HandleServiceLost(info, error);
183 }
184 }
185 }
186
ReceiveInstanceResolve(const MDnsProtocolImpl::Result & result,int32_t error)187 void MDnsManager::ReceiveInstanceResolve(const MDnsProtocolImpl::Result &result, int32_t error)
188 {
189 std::lock_guard<std::mutex> guard(mutex_);
190 if (result.addr.empty() && !result.domain.empty()) {
191 resolvResults_.emplace_back(result);
192 impl.Resolve(resolvResults_.back().domain);
193 return;
194 }
195
196 for (auto iter = resolveMap_.begin(); iter != resolveMap_.end();) {
197 if (iter->second != result.serviceName + MDNS_DOMAIN_SPLITER_STR + result.serviceType) {
198 ++iter;
199 continue;
200 }
201 auto cb = iter->first;
202 impl.StopResolveInstance(result.serviceName + MDNS_DOMAIN_SPLITER_STR + result.serviceType);
203 cb->HandleResolveResult(ConvertResultToInfo(result), error);
204 iter = resolveMap_.erase(iter);
205 }
206 }
207
ReceiveResolve(const MDnsProtocolImpl::Result & result,int32_t error)208 void MDnsManager::ReceiveResolve(const MDnsProtocolImpl::Result &result, int32_t error)
209 {
210 std::lock_guard<std::mutex> guard(mutex_);
211 auto res = resolvResults_.end();
212 res = std::find_if(resolvResults_.begin(), resolvResults_.end(),
213 [&](const auto &x) { return x.domain == result.domain; });
214 if (res == resolvResults_.end()) {
215 return;
216 }
217 res->ipv6 = result.ipv6;
218 res->addr = result.addr;
219
220 for (auto iter = resolveMap_.begin(); iter != resolveMap_.end();) {
221 if (iter->second != res->serviceName + MDNS_DOMAIN_SPLITER_STR + res->serviceType) {
222 ++iter;
223 continue;
224 }
225 auto cb = iter->first;
226 cb->HandleResolveResult(ConvertResultToInfo(*res), error);
227 iter = resolveMap_.erase(iter);
228 }
229
230 impl.StopResolve(res->domain);
231 impl.StopResolveInstance(result.serviceName + MDNS_DOMAIN_SPLITER_STR + result.serviceType);
232 resolvResults_.erase(res);
233 }
234
ConvertResultToInfo(const MDnsProtocolImpl::Result & result)235 MDnsServiceInfo MDnsManager::ConvertResultToInfo(const MDnsProtocolImpl::Result &result)
236 {
237 MDnsServiceInfo info;
238 info.name = result.serviceName;
239 info.type = result.serviceType;
240 if (!result.addr.empty()) {
241 info.family = result.ipv6 ? MDnsServiceInfo::IPV6 : MDnsServiceInfo::IPV4;
242 }
243 info.addr = result.addr;
244 info.port = result.port;
245 info.txtRecord = result.txt;
246 return info;
247 }
248
GetDumpMessage(std::string & message)249 void MDnsManager::GetDumpMessage(std::string &message)
250 {
251 message.append("mDNS Info:\n");
252 const auto &config = impl.GetConfig();
253 message.append("\tIPv6 Support: " + std::to_string(config.ipv6Support) + "\n");
254 message.append("\tAll Iface: " + std::to_string(config.configAllIface) + "\n");
255 message.append("\tTop Domain: " + config.topDomain + "\n");
256 message.append("\tHostname: " + config.hostname + "\n");
257 message.append("\tService Count: " + std::to_string(registerMap_.size()) + "\n");
258 }
259 } // namespace NetManagerStandard
260 } // namespace OHOS