/*
 * Copyright (c) 2024 Huawei Device Co., Ltd.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
15 
16 #include <unistd.h>
17 
18 #include "dm_transport.h"
19 #include "dm_anonymous.h"
20 #include "dm_comm_tool.h"
21 #include "dm_constants.h"
22 #include "dm_log.h"
23 #include "dm_softbus_cache.h"
24 #include "dm_transport_msg.h"
25 #include "softbus_error_code.h"
26 
27 namespace OHOS {
28 namespace DistributedHardware {
29 namespace {
30 // Dsoftbus sendBytes max message length: 4MB
31 constexpr uint32_t MAX_SEND_MSG_LENGTH = 4 * 1024 * 1024;
32 constexpr uint32_t INTERCEPT_STRING_LENGTH = 20;
33 constexpr uint32_t MAX_ROUND_SIZE = 1000;
34 const int32_t USLEEP_TIME_US_200000 = 200000;           // 200ms
35 static QosTV g_qosInfo[] = {
36     { .qos = QOS_TYPE_MIN_BW, .value = 256 * 1024},
37     { .qos = QOS_TYPE_MAX_LATENCY, .value = 8000 },
38     { .qos = QOS_TYPE_MIN_LATENCY, .value = 2000 }
39 };
40 static uint32_t g_qosTvParamIndex = static_cast<uint32_t>(sizeof(g_qosInfo) / sizeof(g_qosInfo[0]));
41 static std::weak_ptr<DMCommTool> g_dmCommToolWPtr_;
42 }
43 
DMTransport(std::shared_ptr<DMCommTool> dmCommToolPtr)44 DMTransport::DMTransport(std::shared_ptr<DMCommTool> dmCommToolPtr) : remoteDevSocketIds_({}), localServerSocket_(-1),
45     localSocketName_(""), isSocketSvrCreateFlag_(false), dmCommToolWPtr_(dmCommToolPtr)
46 {
47     LOGI("Ctor DMTransport");
48     g_dmCommToolWPtr_ = dmCommToolPtr;
49 }
50 
OnSocketOpened(int32_t socketId,const PeerSocketInfo & info)51 int32_t DMTransport::OnSocketOpened(int32_t socketId, const PeerSocketInfo &info)
52 {
53     LOGI("OnSocketOpened, socket: %{public}d, peerSocketName: %{public}s, peerNetworkId: %{public}s, "
54         "peerPkgName: %{public}s", socketId, info.name, GetAnonyString(info.networkId).c_str(), info.pkgName);
55     std::lock_guard<std::mutex> lock(rmtSocketIdMtx_);
56     if (remoteDevSocketIds_.find(info.networkId) == remoteDevSocketIds_.end()) {
57         std::set<int32_t> socketSet;
58         socketSet.insert(socketId);
59         remoteDevSocketIds_[info.networkId] = socketSet;
60         return DM_OK;
61     }
62     remoteDevSocketIds_.at(info.networkId).insert(socketId);
63     return DM_OK;
64 }
65 
OnSocketClosed(int32_t socketId,ShutdownReason reason)66 void DMTransport::OnSocketClosed(int32_t socketId, ShutdownReason reason)
67 {
68     LOGI("OnSocketClosed, socket: %{public}d, reason: %{public}d", socketId, (int32_t)reason);
69     std::lock_guard<std::mutex> lock(rmtSocketIdMtx_);
70     for (auto iter = remoteDevSocketIds_.begin(); iter != remoteDevSocketIds_.end();) {
71         iter->second.erase(socketId);
72         if (iter->second.empty()) {
73             iter = remoteDevSocketIds_.erase(iter);
74         } else {
75             ++iter;
76         }
77     }
78     sourceSocketIds_.erase(socketId);
79 }
80 
OnBytesReceived(int32_t socketId,const void * data,uint32_t dataLen)81 void DMTransport::OnBytesReceived(int32_t socketId, const void *data, uint32_t dataLen)
82 {
83     if (socketId < 0 || data == nullptr || dataLen == 0 || dataLen > MAX_SEND_MSG_LENGTH) {
84         LOGE("OnBytesReceived param check failed");
85         return;
86     }
87 
88     std::string remoteNeworkId = GetRemoteNetworkIdBySocketId(socketId);
89     if (remoteNeworkId.empty()) {
90         LOGE("Can not find the remote network id by socketId: %{public}d", socketId);
91         return;
92     }
93 
94     uint8_t *buf = reinterpret_cast<uint8_t *>(calloc(dataLen + 1, sizeof(uint8_t)));
95     if (buf == nullptr) {
96         LOGE("OnBytesReceived: malloc memory failed");
97         return;
98     }
99 
100     if (memcpy_s(buf, dataLen + 1,  reinterpret_cast<const uint8_t *>(data), dataLen) != EOK) {
101         LOGE("OnBytesReceived: memcpy memory failed");
102         free(buf);
103         return;
104     }
105 
106     std::string message(buf, buf + dataLen);
107     LOGI("Receive message size: %{public}" PRIu32, dataLen);
108     HandleReceiveMessage(socketId, message);
109     free(buf);
110     return;
111 }
112 
HandleReceiveMessage(const int32_t socketId,const std::string & payload)113 void DMTransport::HandleReceiveMessage(const int32_t socketId, const std::string &payload)
114 {
115     std::string rmtNetworkId = GetRemoteNetworkIdBySocketId(socketId);
116     if (rmtNetworkId.empty()) {
117         LOGE("Can not find networkId by socketId: %{public}d", socketId);
118         return;
119     }
120     if (payload.empty() || payload.size() > MAX_SEND_MSG_LENGTH) {
121         LOGE("payload invalid");
122         return;
123     }
124     LOGI("Receive msg: %{public}s", GetAnonyString(payload).c_str());
125     cJSON *root = cJSON_Parse(payload.c_str());
126     if (root == NULL) {
127         LOGE("the msg is not json format");
128         return;
129     }
130     std::shared_ptr<CommMsg> commMsg = std::make_shared<CommMsg>();
131     FromJson(root, *commMsg);
132     cJSON_Delete(root);
133 
134     std::shared_ptr<InnerCommMsg> innerMsg = std::make_shared<InnerCommMsg>(rmtNetworkId, commMsg, socketId);
135 
136     LOGI("Receive DM msg, code: %{public}d, msg: %{public}s", commMsg->code, GetAnonyString(commMsg->msg).c_str());
137     AppExecFwk::InnerEvent::Pointer msgEvent = AppExecFwk::InnerEvent::Get(commMsg->code, innerMsg);
138     std::shared_ptr<DMCommTool> dmCommToolSPtr = dmCommToolWPtr_.lock();
139     if (dmCommToolSPtr == nullptr) {
140         LOGE("Can not get DMCommTool ptr");
141         return;
142     }
143     if (dmCommToolSPtr->GetEventHandler() == nullptr) {
144         LOGE("Can not get eventHandler");
145         return;
146     }
147     dmCommToolSPtr->GetEventHandler()->SendEvent(msgEvent, 0, AppExecFwk::EventQueue::Priority::IMMEDIATE);
148 }
149 
150 //LCOV_EXCL_START
GetDMCommToolPtr()151 std::shared_ptr<DMCommTool> GetDMCommToolPtr()
152 {
153     if (g_dmCommToolWPtr_.expired()) {
154         LOGE("DMCommTool Weak ptr expired");
155         return nullptr;
156     }
157 
158     std::shared_ptr<DMCommTool> dmCommToolSPtr = g_dmCommToolWPtr_.lock();
159     if (dmCommToolSPtr == nullptr) {
160         LOGE("Can not get DMCommTool ptr");
161         return nullptr;
162     }
163 
164     return dmCommToolSPtr;
165 }
166 //LCOV_EXCL_STOP
167 
OnBind(int32_t socket,PeerSocketInfo info)168 void OnBind(int32_t socket, PeerSocketInfo info)
169 {
170     std::shared_ptr<DMCommTool> dmCommToolSPtr = GetDMCommToolPtr();
171     if (dmCommToolSPtr == nullptr) {
172         LOGE("Can not get DMCommTool ptr");
173         return;
174     }
175     dmCommToolSPtr->GetDMTransportPtr()->OnSocketOpened(socket, info);
176 }
177 
OnShutdown(int32_t socket,ShutdownReason reason)178 void OnShutdown(int32_t socket, ShutdownReason reason)
179 {
180     std::shared_ptr<DMCommTool> dmCommToolSPtr = GetDMCommToolPtr();
181     if (dmCommToolSPtr == nullptr) {
182         LOGE("Can not get DMCommTool ptr");
183         return;
184     }
185     dmCommToolSPtr->GetDMTransportPtr()->OnSocketClosed(socket, reason);
186 }
187 
OnBytes(int32_t socket,const void * data,uint32_t dataLen)188 void OnBytes(int32_t socket, const void *data, uint32_t dataLen)
189 {
190     std::shared_ptr<DMCommTool> dmCommToolSPtr = GetDMCommToolPtr();
191     if (dmCommToolSPtr == nullptr) {
192         LOGE("Can not get DMCommTool ptr");
193         return;
194     }
195     dmCommToolSPtr->GetDMTransportPtr()->OnBytesReceived(socket, data, dataLen);
196 }
197 
OnMessage(int32_t socket,const void * data,uint32_t dataLen)198 void OnMessage(int32_t socket, const void *data, uint32_t dataLen)
199 {
200     (void)socket;
201     (void)data;
202     (void)dataLen;
203     LOGI("socket: %{public}d, dataLen:%{public}" PRIu32, socket, dataLen);
204 }
205 
OnStream(int32_t socket,const StreamData * data,const StreamData * ext,const StreamFrameInfo * param)206 void OnStream(int32_t socket, const StreamData *data, const StreamData *ext,
207     const StreamFrameInfo *param)
208 {
209     (void)socket;
210     (void)data;
211     (void)ext;
212     (void)param;
213     LOGI("socket: %{public}d", socket);
214 }
215 
OnFile(int32_t socket,FileEvent * event)216 void OnFile(int32_t socket, FileEvent *event)
217 {
218     (void)event;
219     LOGI("socket: %{public}d", socket);
220 }
221 
OnQos(int32_t socket,QoSEvent eventId,const QosTV * qos,uint32_t qosCount)222 void OnQos(int32_t socket, QoSEvent eventId, const QosTV *qos, uint32_t qosCount)
223 {
224     if (qosCount == 0 || qosCount > MAX_ROUND_SIZE) {
225         LOGE("qosCount is invalid!");
226         return;
227     }
228     LOGI("OnQos, socket: %{public}d, QoSEvent: %{public}d, qosCount: %{public}" PRIu32,
229         socket, (int32_t)eventId, qosCount);
230     for (uint32_t idx = 0; idx < qosCount; idx++) {
231         LOGI("QosTV: type: %{public}d, value: %{public}d", (int32_t)qos[idx].qos, qos[idx].value);
232     }
233 }
234 
235 ISocketListener iSocketListener = {
236     .OnBind = OnBind,
237     .OnShutdown = OnShutdown,
238     .OnBytes = OnBytes,
239     .OnMessage = OnMessage,
240     .OnStream = OnStream,
241     .OnFile = OnFile,
242     .OnQos = OnQos
243 };
244 
CreateServerSocket()245 int32_t DMTransport::CreateServerSocket()
246 {
247     LOGI("CreateServerSocket start");
248     localSocketName_ = DM_SYNC_USERID_SESSION_NAME;
249     LOGI("CreateServerSocket , local socketName: %{public}s", localSocketName_.c_str());
250     std::string dmPkgName(DM_PKG_NAME);
251     SocketInfo info = {
252         .name = const_cast<char*>(localSocketName_.c_str()),
253         .pkgName = const_cast<char*>(dmPkgName.c_str()),
254         .dataType = DATA_TYPE_BYTES
255     };
256     int32_t socket = Socket(info);
257     LOGI("CreateServerSocket Finish, socket: %{public}d", socket);
258     return socket;
259 }
260 
CreateClientSocket(const std::string & rmtNetworkId)261 int32_t DMTransport::CreateClientSocket(const std::string &rmtNetworkId)
262 {
263     if (!IsIdLengthValid(rmtNetworkId)) {
264         return ERR_DM_INPUT_PARA_INVALID;
265     }
266     LOGI("CreateClientSocket start, peerNetworkId: %{public}s", GetAnonyString(rmtNetworkId).c_str());
267     std::string peerSocketName = DM_SYNC_USERID_SESSION_NAME;
268     std::string dmPkgName(DM_PKG_NAME);
269     SocketInfo info = {
270         .name = const_cast<char*>(localSocketName_.c_str()),
271         .peerName = const_cast<char*>(peerSocketName.c_str()),
272         .peerNetworkId = const_cast<char*>(rmtNetworkId.c_str()),
273         .pkgName = const_cast<char*>(dmPkgName.c_str()),
274         .dataType = DATA_TYPE_BYTES
275     };
276     int32_t socket = Socket(info);
277     LOGI("Bind Socket server, socket: %{public}d, localSocketName: %{public}s, peerSocketName: %{public}s",
278         socket, localSocketName_.c_str(), peerSocketName.c_str());
279     return socket;
280 }
281 
Init()282 int32_t DMTransport::Init()
283 {
284     LOGI("Init DMTransport");
285     if (isSocketSvrCreateFlag_.load()) {
286         LOGI("SocketServer already create success.");
287         return DM_OK;
288     }
289     int32_t socket = CreateServerSocket();
290     if (socket < DM_OK) {
291         LOGE("CreateSocketServer failed, ret: %{public}d", socket);
292         return ERR_DM_FAILED;
293     }
294 
295     int32_t ret = Listen(socket, g_qosInfo, g_qosTvParamIndex, &iSocketListener);
296     if (ret != DM_OK) {
297         LOGE("Socket Listen failed, error code %{public}d.", ret);
298         return ERR_DM_FAILED;
299     }
300     isSocketSvrCreateFlag_.store(true);
301     localServerSocket_ = socket;
302     LOGI("Finish Init DSoftBus Server Socket, socket: %{public}d", socket);
303     return DM_OK;
304 }
305 
UnInit()306 int32_t DMTransport::UnInit()
307 {
308     {
309         std::lock_guard<std::mutex> lock(rmtSocketIdMtx_);
310         for (auto iter = remoteDevSocketIds_.begin(); iter != remoteDevSocketIds_.end(); ++iter) {
311             for (auto iter1 = iter->second.begin(); iter1 != iter->second.end(); ++iter1) {
312                 LOGI("Shutdown client socket: %{public}d to remote dev: %{public}s", *iter1,
313                     GetAnonyString(iter->first).c_str());
314                 Shutdown(*iter1);
315             }
316         }
317         remoteDevSocketIds_.clear();
318         sourceSocketIds_.clear();
319     }
320 
321     if (!isSocketSvrCreateFlag_.load()) {
322         LOGI("DSoftBus Server Socket already remove success.");
323     } else {
324         LOGI("Shutdown DSoftBus Server Socket, socket: %{public}d", localServerSocket_.load());
325         Shutdown(localServerSocket_.load());
326         localServerSocket_ = -1;
327         isSocketSvrCreateFlag_.store(false);
328     }
329     return DM_OK;
330 }
331 
IsDeviceSessionOpened(const std::string & rmtNetworkId,int32_t & socketId)332 bool DMTransport::IsDeviceSessionOpened(const std::string &rmtNetworkId, int32_t &socketId)
333 {
334     if (!IsIdLengthValid(rmtNetworkId)) {
335         return false;
336     }
337     std::lock_guard<std::mutex> lock(rmtSocketIdMtx_);
338     auto iter = remoteDevSocketIds_.find(rmtNetworkId);
339     if (iter == remoteDevSocketIds_.end()) {
340         return false;
341     }
342     for (auto iter1 = iter->second.begin(); iter1 != iter->second.end(); ++iter1) {
343         if (sourceSocketIds_.find(*iter1) != sourceSocketIds_.end()) {
344             socketId = *iter1;
345             return true;
346         }
347     }
348     return false;
349 }
350 
GetRemoteNetworkIdBySocketId(int32_t socketId)351 std::string DMTransport::GetRemoteNetworkIdBySocketId(int32_t socketId)
352 {
353     std::lock_guard<std::mutex> lock(rmtSocketIdMtx_);
354     std::string networkId = "";
355     for (auto const &item : remoteDevSocketIds_) {
356         if (item.second.find(socketId) != item.second.end()) {
357             networkId = item.first;
358             break;
359         }
360     }
361     return networkId;
362 }
363 
ClearDeviceSocketOpened(const std::string & remoteDevId,int32_t socketId)364 void DMTransport::ClearDeviceSocketOpened(const std::string &remoteDevId, int32_t socketId)
365 {
366     if (!IsIdLengthValid(remoteDevId)) {
367         return;
368     }
369     std::lock_guard<std::mutex> lock(rmtSocketIdMtx_);
370     auto iter = remoteDevSocketIds_.find(remoteDevId);
371     if (iter == remoteDevSocketIds_.end()) {
372         return;
373     }
374     iter->second.erase(socketId);
375     if (iter->second.empty()) {
376         remoteDevSocketIds_.erase(iter);
377     }
378     sourceSocketIds_.erase(socketId);
379 }
380 
StartSocket(const std::string & rmtNetworkId,int32_t & socketId)381 int32_t DMTransport::StartSocket(const std::string &rmtNetworkId, int32_t &socketId)
382 {
383     int32_t errCode = ERR_DM_FAILED;
384     int32_t count = 0;
385     const int32_t maxCount = 10;
386 
387     do {
388         errCode = StartSocketInner(rmtNetworkId, socketId);
389         if (errCode != ERR_DM_SOCKET_IN_USED) {
390             break;
391         }
392         count++;
393         usleep(USLEEP_TIME_US_200000);
394     } while (count < maxCount);
395 
396     return errCode;
397 }
398 
StartSocketInner(const std::string & rmtNetworkId,int32_t & socketId)399 int32_t DMTransport::StartSocketInner(const std::string &rmtNetworkId, int32_t &socketId)
400 {
401     if (!IsIdLengthValid(rmtNetworkId)) {
402         return ERR_DM_INPUT_PARA_INVALID;
403     }
404     if (IsDeviceSessionOpened(rmtNetworkId, socketId)) {
405         LOGE("Softbus session has already opened, deviceId: %{public}s", GetAnonyString(rmtNetworkId).c_str());
406         return ERR_DM_SOCKET_IN_USED;
407     }
408 
409     int32_t socket = CreateClientSocket(rmtNetworkId);
410     if (socket < DM_OK) {
411         LOGE("StartSocket failed, ret: %{public}d", socket);
412         return ERR_DM_FAILED;
413     }
414 
415     int32_t ret = Bind(socket, g_qosInfo, g_qosTvParamIndex, &iSocketListener);
416     if (ret < DM_OK) {
417         if (ret == SOFTBUS_TRANS_SOCKET_IN_USE) {
418             LOGI("Softbus trans socket in use.");
419             return ERR_DM_SOCKET_IN_USED;
420         }
421         LOGE("OpenSession fail, rmtNetworkId: %{public}s, socket: %{public}d, ret: %{public}d",
422             GetAnonyString(rmtNetworkId).c_str(), socket, ret);
423         Shutdown(socket);
424         return ERR_DM_FAILED;
425     }
426 
427     LOGI("Bind Socket success, rmtNetworkId:%{public}s, socketId: %{public}d",
428         GetAnonyString(rmtNetworkId).c_str(), socket);
429     std::string peerSocketName = DM_SYNC_USERID_SESSION_NAME;
430     std::string dmPkgName(DM_PKG_NAME);
431     PeerSocketInfo peerSocketInfo = {
432         .name = const_cast<char*>(peerSocketName.c_str()),
433         .networkId = const_cast<char*>(rmtNetworkId.c_str()),
434         .pkgName = const_cast<char*>(dmPkgName.c_str()),
435         .dataType = DATA_TYPE_BYTES
436     };
437     OnSocketOpened(socket, peerSocketInfo);
438     sourceSocketIds_.insert(socket);
439     socketId = socket;
440     return DM_OK;
441 }
442 
StopSocket(const std::string & rmtNetworkId)443 int32_t DMTransport::StopSocket(const std::string &rmtNetworkId)
444 {
445     if (!IsIdLengthValid(rmtNetworkId)) {
446         return ERR_DM_INPUT_PARA_INVALID;
447     }
448     int32_t socketId = -1;
449     if (!IsDeviceSessionOpened(rmtNetworkId, socketId)) {
450         LOGI("remote dev may be not opened, rmtNetworkId: %{public}s", GetAnonyString(rmtNetworkId).c_str());
451         return ERR_DM_FAILED;
452     }
453 
454     LOGI("StopSocket rmtNetworkId: %{public}s, socketId: %{public}d",
455         GetAnonyString(rmtNetworkId).c_str(), socketId);
456     Shutdown(socketId);
457     ClearDeviceSocketOpened(rmtNetworkId, socketId);
458     return DM_OK;
459 }
460 
Send(const std::string & rmtNetworkId,const std::string & payload,int32_t socketId)461 int32_t DMTransport::Send(const std::string &rmtNetworkId, const std::string &payload, int32_t socketId)
462 {
463     if (!IsIdLengthValid(rmtNetworkId) || !IsMessageLengthValid(payload)) {
464         return ERR_DM_INPUT_PARA_INVALID;
465     }
466     if (socketId <= 0) {
467         LOGI("The session is not open, target networkId: %{public}s", GetAnonyString(rmtNetworkId).c_str());
468         return ERR_DM_FAILED;
469     }
470     uint32_t payLoadSize = payload.size();
471     LOGI("Send payload size: %{public}" PRIu32 ", target networkId: %{public}s, socketId: %{public}d",
472         static_cast<uint32_t>(payload.size()), GetAnonyString(rmtNetworkId).c_str(), socketId);
473 
474     if (payLoadSize > MAX_SEND_MSG_LENGTH) {
475         LOGE("Send error: msg size: %{public}" PRIu32 " too long", payLoadSize);
476         return ERR_DM_FAILED;
477     }
478     uint8_t *buf = reinterpret_cast<uint8_t *>(calloc((payLoadSize), sizeof(uint8_t)));
479     if (buf == nullptr) {
480         LOGE("Send: malloc memory failed");
481         return ERR_DM_FAILED;
482     }
483 
484     if (memcpy_s(buf, payLoadSize, reinterpret_cast<const uint8_t *>(payload.c_str()),
485                  payLoadSize) != EOK) {
486         LOGE("Send: memcpy memory failed");
487         free(buf);
488         return ERR_DM_FAILED;
489     }
490 
491     int32_t ret = SendBytes(socketId, buf, payLoadSize);
492     free(buf);
493     if (ret != DM_OK) {
494         LOGE("dsoftbus send error, ret: %{public}d", ret);
495         return ERR_DM_FAILED;
496     }
497     LOGI("Send payload success");
498     return DM_OK;
499 }
500 } // DistributedHardware
501 } // OHOS